haixunMaster/haixun-backend/internal/library/knowledge/synth.go

242 lines
5.7 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package knowledge
import (
"encoding/json"
"fmt"
"regexp"
"strings"
libprompt "haixun-backend/internal/library/prompt"
"github.com/google/uuid"
)
// SynthInput carries the values that BuildUserPrompt substitutes into the
// knowledge-graph user prompt.
type SynthInput struct {
// Seed is whitespace-trimmed before use. ParseSynthOutput also stores the
// trimmed seed on the resulting Graph and uses it as the label of the
// fallback core node.
Seed string
// ProductBrief, TargetAudience and Persona are optional prompt lines; a
// blank value yields an empty string for the corresponding prompt variable.
ProductBrief string
TargetAudience string
Persona string
// Sources are listed in the prompt; BuildUserPrompt includes at most the
// first 30 entries.
Sources []BraveSource
}
// rawSynthOutput mirrors the JSON object that ParseSynthOutput decodes from
// the LLM reply, before values are trimmed, defaulted and converted into
// Graph nodes and edges.
type rawSynthOutput struct {
// Nodes are the candidate graph nodes as emitted by the model.
Nodes []struct {
// Label is trimmed; entries with a blank label are skipped.
Label string `json:"label"`
// NodeKind, when blank, is defaulted by ParseSynthOutput from the layer.
NodeKind string `json:"nodeKind"`
// Type equal to "core" marks the entry as a core node.
Type string `json:"type"`
// Layer 0 is treated as the core layer. A missing "layer" key also decodes
// to 0 because this is a plain int.
Layer int `json:"layer"`
// Relation is trimmed and copied onto the node.
Relation string `json:"relation"`
// PlacementValue is passed through normalizePlacement.
PlacementValue string `json:"placementValue"`
// ProductFitScore values <= 0 are replaced by defaultProductFit.
ProductFitScore int `json:"productFitScore"`
// EvidenceURLs are trimmed; blank entries are dropped and the rest are
// matched by exact URL against the supplied sources.
EvidenceURLs []string `json:"evidenceUrls"`
} `json:"nodes"`
// Edges connect nodes. From and To are resolved by resolveNodeRef against
// node IDs first and lower-cased labels second.
Edges []struct {
From string `json:"from"`
To string `json:"to"`
Relation string `json:"relation"`
} `json:"edges"`
}
// codeFenceRE matches a string that consists entirely of a single Markdown
// code fence, optionally tagged "json", and captures the fenced content with
// the whitespace next to the fence markers excluded. The (?s) flag lets "."
// span newlines; without the multi-line flag, ^ and $ anchor to the start and
// end of the whole text.
var codeFenceRE = regexp.MustCompile("(?s)^```(?:json)?\\s*(.*?)\\s*```$")
// BuildUserPrompt renders the knowledge-graph user prompt for in.
//
// At most the first 30 sources are listed, each as a 1-based numbered entry
// with its query, URL, title and snippet followed by a blank line. The seed is
// whitespace-trimmed and the optional fields are rendered through
// optionalLine. The result and error are whatever
// libprompt.KnowledgeGraphUser returns for those variables.
func BuildUserPrompt(in SynthInput) (string, error) {
	// Cap the number of sources that reach the prompt.
	listed := in.Sources
	if len(listed) > 30 {
		listed = listed[:30]
	}

	var listing strings.Builder
	for idx, s := range listed {
		fmt.Fprintf(&listing, "[%d] query=%s\nurl=%s\ntitle=%s\nsnippet=%s\n\n",
			idx+1, s.Query, s.URL, s.Title, s.Snippet)
	}

	return libprompt.KnowledgeGraphUser(map[string]string{
		"seed":                 strings.TrimSpace(in.Seed),
		"product_brief_line":   optionalLine("產品簡述", in.ProductBrief),
		"target_audience_line": optionalLine("目標受眾", in.TargetAudience),
		"persona_line":         optionalLine("人設", in.Persona),
		"sources":              strings.TrimSpace(listing.String()),
	})
}
// optionalLine formats an optional prompt line as label, a fullwidth colon,
// the trimmed value and a trailing newline. It returns "" when value is blank
// so the line disappears from the prompt entirely.
func optionalLine(label, value string) string {
	// Written as an escape (U+FF1A, fullwidth colon) so the separator cannot
	// be silently dropped or confused with an ASCII ':' by tooling.
	const sep = "\uFF1A"

	value = strings.TrimSpace(value)
	if value == "" {
		return ""
	}
	return label + sep + value + "\n"
}
// ParseSynthOutput converts the raw LLM reply into a Graph.
//
// The JSON object is located with extractJSONObject and decoded into
// rawSynthOutput. Each node with a non-blank label is normalised:
//   - layer 0 or type "core" makes the node a core node (layer 0, type "core");
//   - a blank node kind defaults to "pain" for core nodes, "cause" for
//     layer >= 2, "symptom" for layer 1 and "knowledge" otherwise;
//   - evidence URLs are trimmed, blanks dropped, and snippet/query copied from
//     the source with the same URL when one exists;
//   - a product-fit score <= 0 is replaced by defaultProductFit.
//
// If no core node was produced and the trimmed seed is non-empty, a core node
// labelled with the seed is prepended. Edges are kept only when both endpoints
// resolve (by node ID or case-insensitive label) to two different nodes.
// DeriveSearchTagsFromGraph is invoked on the graph before it is returned.
//
// An error is returned when no JSON object can be extracted or decoding fails.
func ParseSynthOutput(raw string, in SynthInput, sources []BraveSource) (Graph, error) {
	blob, err := extractJSONObject(raw)
	if err != nil {
		return Graph{}, err
	}
	var parsed rawSynthOutput
	if err := json.Unmarshal(blob, &parsed); err != nil {
		return Graph{}, fmt.Errorf("parse knowledge graph json: %w", err)
	}

	seed := strings.TrimSpace(in.Seed)
	graph := Graph{
		Seed:         seed,
		BraveSources: sources,
		Nodes:        []Node{},
		Edges:        []Edge{},
	}

	// Index the sources by URL so evidence entries can pick up their
	// snippet and query; sources without a URL cannot be referenced.
	byURL := map[string]BraveSource{}
	for _, s := range sources {
		if s.URL == "" {
			continue
		}
		byURL[s.URL] = s
	}

	coreSeen := false
	for i := range parsed.Nodes {
		rawNode := &parsed.Nodes[i]

		label := strings.TrimSpace(rawNode.Label)
		if label == "" {
			continue
		}

		layer := rawNode.Layer
		typ := strings.TrimSpace(rawNode.Type)
		kind := strings.TrimSpace(rawNode.NodeKind)

		// Either signal is enough to classify the node as core; both
		// fields are then forced to the canonical core values.
		isCore := layer == 0 || typ == "core"
		if isCore {
			layer, typ, coreSeen = 0, "core", true
		}
		if kind == "" {
			switch {
			case isCore:
				kind = "pain"
			case layer >= 2:
				kind = "cause"
			case layer == 1:
				kind = "symptom"
			default:
				kind = "knowledge"
			}
		}

		evidence := make([]Evidence, 0, len(rawNode.EvidenceURLs))
		for _, link := range rawNode.EvidenceURLs {
			link = strings.TrimSpace(link)
			if link == "" {
				continue
			}
			ev := Evidence{URL: link}
			if s, found := byURL[link]; found {
				ev.Snippet, ev.Query = s.Snippet, s.Query
			}
			evidence = append(evidence, ev)
		}

		score := rawNode.ProductFitScore
		if score <= 0 {
			score = defaultProductFit(kind, layer)
		}

		graph.Nodes = append(graph.Nodes, Node{
			ID:              uuid.NewString(),
			Label:           label,
			NodeKind:        kind,
			Type:            typ,
			Layer:           layer,
			Relation:        strings.TrimSpace(rawNode.Relation),
			PlacementValue:  normalizePlacement(rawNode.PlacementValue, kind),
			ProductFitScore: score,
			Evidence:        evidence,
		})
	}

	// Synthesise a core node from the seed when the model did not supply one.
	if seed != "" && !coreSeen {
		core := Node{
			ID:              uuid.NewString(),
			Label:           seed,
			NodeKind:        "pain",
			Type:            "core",
			Layer:           0,
			PlacementValue:  "high",
			ProductFitScore: 90,
		}
		graph.Nodes = append([]Node{core}, graph.Nodes...)
	}

	// Later nodes win when two labels collide after lower-casing.
	idByLabel := map[string]string{}
	for _, n := range graph.Nodes {
		idByLabel[strings.ToLower(strings.TrimSpace(n.Label))] = n.ID
	}

	for _, e := range parsed.Edges {
		src := resolveNodeRef(e.From, idByLabel, graph.Nodes)
		dst := resolveNodeRef(e.To, idByLabel, graph.Nodes)
		// Keep only edges whose endpoints both resolve and are distinct.
		if src != "" && dst != "" && src != dst {
			graph.Edges = append(graph.Edges, Edge{
				From:     src,
				To:       dst,
				Relation: strings.TrimSpace(e.Relation),
			})
		}
	}

	DeriveSearchTagsFromGraph(&graph)
	return graph, nil
}
// defaultProductFit returns the fallback product-fit score for a node that
// did not come with a positive score of its own: 90 for a layer-0 "pain"
// node, 80 for any other "pain" node, 70 for "symptom" and "cause" nodes, and
// 50 for every other kind.
func defaultProductFit(nodeKind string, layer int) int {
	if nodeKind == "pain" {
		if layer == 0 {
			return 90
		}
		return 80
	}
	if nodeKind == "symptom" || nodeKind == "cause" {
		return 70
	}
	return 50
}
// normalizePlacement lower-cases and trims value and returns it unchanged
// when it is one of "high", "medium" or "low". Any other value is replaced by
// "high" if IsPainNode reports true for a node of the given kind, and by "low"
// otherwise.
func normalizePlacement(value, nodeKind string) string {
	normalized := strings.TrimSpace(strings.ToLower(value))
	if normalized == "high" || normalized == "medium" || normalized == "low" {
		return normalized
	}

	placement := "low"
	if IsPainNode(Node{NodeKind: nodeKind}) {
		placement = "high"
	}
	return placement
}
// resolveNodeRef maps an edge endpoint reference to a node ID. The trimmed
// reference is first compared against every node's ID; failing that, its
// lower-cased form is looked up in labelToID. It returns "" when the
// reference is blank or matches nothing.
func resolveNodeRef(ref string, labelToID map[string]string, nodes []Node) string {
	key := strings.TrimSpace(ref)
	if key == "" {
		return ""
	}

	// An exact ID match takes precedence over a label match.
	for i := range nodes {
		if nodes[i].ID == key {
			return key
		}
	}

	// A missing key yields the zero value "", which is the unresolved marker.
	return labelToID[strings.ToLower(key)]
}
// extractJSONObject pulls the JSON object out of an LLM reply. The reply is
// trimmed; if it is wholly wrapped in a Markdown code fence (per codeFenceRE)
// the fence is removed. The bytes from the first "{" through the last "}" are
// then returned. No JSON validation happens here.
//
// An error is returned for an empty reply, or when no "{" is followed by a
// later "}".
func extractJSONObject(raw string) ([]byte, error) {
	body := strings.TrimSpace(raw)
	if body == "" {
		return nil, fmt.Errorf("empty LLM response")
	}

	if groups := codeFenceRE.FindStringSubmatch(body); len(groups) == 2 {
		body = strings.TrimSpace(groups[1])
	}

	first := strings.Index(body, "{")
	if first < 0 {
		return nil, fmt.Errorf("LLM response does not contain JSON object")
	}
	last := strings.LastIndex(body, "}")
	if last <= first {
		return nil, fmt.Errorf("LLM response does not contain JSON object")
	}
	return []byte(body[first : last+1]), nil
}