242 lines
5.7 KiB
Go
242 lines
5.7 KiB
Go
package knowledge
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"regexp"
|
||
"strings"
|
||
|
||
libprompt "haixun-backend/internal/library/prompt"
|
||
|
||
"github.com/google/uuid"
|
||
)
|
||
|
||
type SynthInput struct {
|
||
Seed string
|
||
ProductBrief string
|
||
TargetAudience string
|
||
Persona string
|
||
Sources []BraveSource
|
||
}
|
||
|
||
// rawSynthOutput is the JSON shape decoded from the model's reply before it is
// normalized into a Graph by ParseSynthOutput. Fields are kept loosely typed
// because every one of them may be missing or blank in the reply.
type rawSynthOutput struct {
	// Nodes are the graph vertices as emitted by the model.
	Nodes []struct {
		// Label is the display text; entries with a blank label are skipped by
		// ParseSynthOutput.
		Label string `json:"label"`
		// NodeKind is the node category (e.g. "pain", "symptom", "cause",
		// "knowledge"); a default is derived when it is blank.
		NodeKind string `json:"nodeKind"`
		// Type marks the core node when equal to "core".
		Type string `json:"type"`
		// Layer is the depth of the node; 0 is treated as the core layer.
		Layer int `json:"layer"`
		// Relation is copied, trimmed, onto the resulting Node.
		Relation string `json:"relation"`
		// PlacementValue is expected to be "high", "medium" or "low"; other
		// values are replaced by normalizePlacement.
		PlacementValue string `json:"placementValue"`
		// ProductFitScore falls back to defaultProductFit when <= 0.
		ProductFitScore int `json:"productFitScore"`
		// EvidenceURLs are matched against the known sources by exact URL.
		EvidenceURLs []string `json:"evidenceUrls"`
	} `json:"nodes"`

	// Edges connect nodes; From and To are resolved by resolveNodeRef, which
	// accepts either a node ID or a (case-insensitive) node label.
	Edges []struct {
		From     string `json:"from"`
		To       string `json:"to"`
		Relation string `json:"relation"`
	} `json:"edges"`
}
|
||
|
||
// codeFenceRE matches a string that consists, from start to end, of a single
// Markdown code fence (``` optionally tagged "json") and captures the fenced
// content without its surrounding whitespace. The (?s) flag lets the captured
// content span multiple lines.
var codeFenceRE = regexp.MustCompile("(?s)^```(?:json)?\\s*(.*?)\\s*```$")
|
||
|
||
func BuildUserPrompt(in SynthInput) (string, error) {
|
||
var sources strings.Builder
|
||
limit := len(in.Sources)
|
||
if limit > 30 {
|
||
limit = 30
|
||
}
|
||
for i := 0; i < limit; i++ {
|
||
src := in.Sources[i]
|
||
fmt.Fprintf(&sources, "[%d] query=%s\nurl=%s\ntitle=%s\nsnippet=%s\n\n",
|
||
i+1, src.Query, src.URL, src.Title, src.Snippet)
|
||
}
|
||
vars := map[string]string{
|
||
"seed": strings.TrimSpace(in.Seed),
|
||
"product_brief_line": optionalLine("產品簡述", in.ProductBrief),
|
||
"target_audience_line": optionalLine("目標受眾", in.TargetAudience),
|
||
"persona_line": optionalLine("人設", in.Persona),
|
||
"sources": strings.TrimSpace(sources.String()),
|
||
}
|
||
return libprompt.KnowledgeGraphUser(vars)
|
||
}
|
||
|
||
// optionalLine formats an optional prompt field as "label:value\n".
// The value is whitespace-trimmed first; a blank value yields the empty
// string so the line disappears from the prompt entirely.
func optionalLine(label, value string) string {
	if trimmed := strings.TrimSpace(value); trimmed != "" {
		return label + ":" + trimmed + "\n"
	}
	return ""
}
|
||
|
||
func ParseSynthOutput(raw string, in SynthInput, sources []BraveSource) (Graph, error) {
|
||
payload, err := extractJSONObject(raw)
|
||
if err != nil {
|
||
return Graph{}, err
|
||
}
|
||
var out rawSynthOutput
|
||
if err := json.Unmarshal(payload, &out); err != nil {
|
||
return Graph{}, fmt.Errorf("parse knowledge graph json: %w", err)
|
||
}
|
||
|
||
seed := strings.TrimSpace(in.Seed)
|
||
graph := Graph{
|
||
Seed: seed,
|
||
BraveSources: sources,
|
||
Nodes: []Node{},
|
||
Edges: []Edge{},
|
||
}
|
||
|
||
sourceByURL := map[string]BraveSource{}
|
||
for _, src := range sources {
|
||
if src.URL != "" {
|
||
sourceByURL[src.URL] = src
|
||
}
|
||
}
|
||
|
||
hasCore := false
|
||
for _, item := range out.Nodes {
|
||
label := strings.TrimSpace(item.Label)
|
||
if label == "" {
|
||
continue
|
||
}
|
||
layer := item.Layer
|
||
nodeType := strings.TrimSpace(item.Type)
|
||
nodeKind := strings.TrimSpace(item.NodeKind)
|
||
if layer == 0 || nodeType == "core" {
|
||
layer = 0
|
||
nodeType = "core"
|
||
if nodeKind == "" {
|
||
nodeKind = "pain"
|
||
}
|
||
hasCore = true
|
||
}
|
||
if nodeKind == "" {
|
||
if layer >= 2 {
|
||
nodeKind = "cause"
|
||
} else if layer == 1 {
|
||
nodeKind = "symptom"
|
||
} else {
|
||
nodeKind = "knowledge"
|
||
}
|
||
}
|
||
evidence := make([]Evidence, 0, len(item.EvidenceURLs))
|
||
for _, u := range item.EvidenceURLs {
|
||
u = strings.TrimSpace(u)
|
||
if u == "" {
|
||
continue
|
||
}
|
||
ev := Evidence{URL: u}
|
||
if src, ok := sourceByURL[u]; ok {
|
||
ev.Snippet = src.Snippet
|
||
ev.Query = src.Query
|
||
}
|
||
evidence = append(evidence, ev)
|
||
}
|
||
fit := item.ProductFitScore
|
||
if fit <= 0 {
|
||
fit = defaultProductFit(nodeKind, layer)
|
||
}
|
||
graph.Nodes = append(graph.Nodes, Node{
|
||
ID: uuid.NewString(),
|
||
Label: label,
|
||
NodeKind: nodeKind,
|
||
Type: nodeType,
|
||
Layer: layer,
|
||
Relation: strings.TrimSpace(item.Relation),
|
||
PlacementValue: normalizePlacement(item.PlacementValue, nodeKind),
|
||
ProductFitScore: fit,
|
||
Evidence: evidence,
|
||
})
|
||
}
|
||
|
||
if !hasCore && seed != "" {
|
||
graph.Nodes = append([]Node{{
|
||
ID: uuid.NewString(),
|
||
Label: seed,
|
||
NodeKind: "pain",
|
||
Type: "core",
|
||
Layer: 0,
|
||
PlacementValue: "high",
|
||
ProductFitScore: 90,
|
||
}}, graph.Nodes...)
|
||
}
|
||
|
||
labelToID := map[string]string{}
|
||
for _, node := range graph.Nodes {
|
||
labelToID[strings.ToLower(strings.TrimSpace(node.Label))] = node.ID
|
||
}
|
||
for _, edge := range out.Edges {
|
||
from := resolveNodeRef(edge.From, labelToID, graph.Nodes)
|
||
to := resolveNodeRef(edge.To, labelToID, graph.Nodes)
|
||
if from == "" || to == "" || from == to {
|
||
continue
|
||
}
|
||
graph.Edges = append(graph.Edges, Edge{
|
||
From: from,
|
||
To: to,
|
||
Relation: strings.TrimSpace(edge.Relation),
|
||
})
|
||
}
|
||
|
||
DeriveSearchTagsFromGraph(&graph)
|
||
return graph, nil
|
||
}
|
||
|
||
// defaultProductFit returns the product-fit score used when the model did not
// supply a positive one: 90 for a "pain" node on layer 0, 80 for any other
// "pain" node, 70 for "symptom" and "cause" nodes, and 50 for everything else.
func defaultProductFit(nodeKind string, layer int) int {
	switch {
	case nodeKind == "pain" && layer == 0:
		return 90
	case nodeKind == "pain":
		return 80
	case nodeKind == "symptom", nodeKind == "cause":
		return 70
	}
	return 50
}
|
||
|
||
func normalizePlacement(value, nodeKind string) string {
|
||
value = strings.TrimSpace(strings.ToLower(value))
|
||
switch value {
|
||
case "high", "medium", "low":
|
||
return value
|
||
}
|
||
if IsPainNode(Node{NodeKind: nodeKind}) {
|
||
return "high"
|
||
}
|
||
return "low"
|
||
}
|
||
|
||
func resolveNodeRef(ref string, labelToID map[string]string, nodes []Node) string {
|
||
ref = strings.TrimSpace(ref)
|
||
if ref == "" {
|
||
return ""
|
||
}
|
||
for _, node := range nodes {
|
||
if node.ID == ref {
|
||
return node.ID
|
||
}
|
||
}
|
||
if id, ok := labelToID[strings.ToLower(ref)]; ok {
|
||
return id
|
||
}
|
||
return ""
|
||
}
|
||
|
||
func extractJSONObject(raw string) ([]byte, error) {
|
||
text := strings.TrimSpace(raw)
|
||
if text == "" {
|
||
return nil, fmt.Errorf("empty LLM response")
|
||
}
|
||
if m := codeFenceRE.FindStringSubmatch(text); len(m) == 2 {
|
||
text = strings.TrimSpace(m[1])
|
||
}
|
||
start := strings.Index(text, "{")
|
||
end := strings.LastIndex(text, "}")
|
||
if start < 0 || end <= start {
|
||
return nil, fmt.Errorf("LLM response does not contain JSON object")
|
||
}
|
||
return []byte(text[start : end+1]), nil
|
||
}
|