haixunMaster/haixun-backend/internal/worker/job/expand_graph.go

498 lines
14 KiB
Go
Raw Normal View History

2026-06-24 10:02:42 +00:00
package job
import (
"context"
"fmt"
"strings"
libbrave "haixun-backend/internal/library/brave"
"haixun-backend/internal/library/clock"
app "haixun-backend/internal/library/errors"
"haixun-backend/internal/library/errors/code"
libkg "haixun-backend/internal/library/knowledge"
"haixun-backend/internal/library/placement"
libprompt "haixun-backend/internal/library/prompt"
"haixun-backend/internal/model/ai/domain/enum"
domai "haixun-backend/internal/model/ai/domain/usecase"
aiusecase "haixun-backend/internal/model/ai/usecase"
brandentity "haixun-backend/internal/model/brand/domain/entity"
branddomain "haixun-backend/internal/model/brand/domain/usecase"
jobdom "haixun-backend/internal/model/job/domain/usecase"
kgusecase "haixun-backend/internal/model/knowledge_graph/domain/usecase"
placementusecase "haixun-backend/internal/model/placement/usecase"
threadsaccountdomain "haixun-backend/internal/model/threads_account/domain/usecase"
)
// ExpandGraphDeps bundles the use cases that the "expand" step handler calls
// while building a brand's knowledge graph.
type ExpandGraphDeps struct {
	// Jobs reports progress, checks for cancellation requests and completes the run.
	Jobs jobdom.UseCase
	// Brand loads the brand and persists the generated research map.
	Brand branddomain.UseCase
	// KnowledgeGraph reads the stored graph on supplemental runs and upserts the result.
	KnowledgeGraph kgusecase.UseCase
	// ThreadsAccount resolves the member's placement context and AI credential.
	ThreadsAccount threadsaccountdomain.UseCase
	// Placement supplies the research settings used to resolve the member context.
	Placement placementusecase.UseCase
	// AI generates the text for the research map and the graph synthesis.
	AI aiusecase.UseCase
}
// RegisterExpandGraphHandler wires runExpandGraph into runner as the handler
// for the "expand" step. A nil runner is ignored.
func RegisterExpandGraphHandler(runner *Runner, deps ExpandGraphDeps) {
	if runner == nil {
		return
	}
	// The closure captures deps so the runner only has to supply ctx and step.
	expandStep := func(ctx context.Context, step StepContext) error {
		return runExpandGraph(ctx, step, deps)
	}
	runner.RegisterStepHandler("expand", expandStep)
}
// brandIDFromPayload returns the payload's "brand_id", falling back to
// "persona_id" when "brand_id" is absent or blank. It returns "" when neither
// key carries a value.
func brandIDFromPayload(payload map[string]any) string {
	if id := stringField(payload, "brand_id"); id != "" {
		return id
	}
	return stringField(payload, "persona_id")
}
// runExpandGraph executes the "expand" step of a knowledge-graph job.
//
// It reads tenant_id, owner_uid, brand_id (or persona_id), seed_query,
// supplemental and bootstrap from the run payload and then:
//
//  1. loads the brand and, when bootstrap is set or the brand has no research
//     map, generates one via ensureResearchMap and reloads the brand;
//  2. plans Brave queries and collects search snippets;
//  3. asks the member's AI provider to synthesize a graph from the snippets;
//  4. on a supplemental run with a stored graph, merges the new graph into it;
//  5. on a non-supplemental run whose pain-tag count is below
//     libkg.MinPainTagCandidates(), performs one extra search-and-synthesis
//     pass and merges its output when that pass succeeds;
//  6. upserts the graph and completes the run with a placement handoff.
//
// It returns an error when required payload fields are missing or when any
// mandatory stage fails. Progress updates and heartbeats are best-effort and
// never fail the step.
func runExpandGraph(ctx context.Context, step StepContext, deps ExpandGraphDeps) error {
	payload := step.Run.Payload
	tenantID := stringField(payload, "tenant_id")
	ownerUID := stringField(payload, "owner_uid")
	brandID := brandIDFromPayload(payload)
	seed := stringField(payload, "seed_query")
	supplemental := boolField(payload, "supplemental")
	if tenantID == "" || ownerUID == "" || brandID == "" {
		return fmt.Errorf("expand-graph payload missing tenant_id, owner_uid, or brand_id")
	}
	if seed == "" {
		return fmt.Errorf("expand-graph payload missing seed_query")
	}

	brand, err := deps.Brand.Get(ctx, tenantID, ownerUID, brandID)
	if err != nil {
		return err
	}
	// A brief formatted from the structured product context wins over the
	// free-text brief stored on the brand.
	productBrief := strings.TrimSpace(brand.ProductBrief)
	if formatted := placement.ProductBriefFromContext(brand.ProductContext); formatted != "" {
		productBrief = formatted
	}

	research, err := deps.Placement.ResearchSettings(ctx, tenantID, ownerUID)
	if err != nil {
		return err
	}
	memberCtx, err := deps.ThreadsAccount.ResolveMemberPlacementContext(ctx, tenantID, ownerUID, research)
	if err != nil {
		return err
	}
	braveClient := libbrave.NewClient(memberCtx.BraveAPIKey)
	credential, err := deps.ThreadsAccount.ResolveMemberAiCredential(ctx, tenantID, ownerUID)
	if err != nil {
		return err
	}
	providerID, err := aiusecase.MapWorkerProvider(credential.Provider)
	if err != nil {
		return err
	}

	// updateProgress is best-effort: a failed heartbeat or progress write must
	// not abort the step, so both errors are intentionally dropped.
	updateProgress := func(summary string, percentage int) {
		_ = step.Heartbeat(ctx)
		_, _ = deps.Jobs.UpdateProgress(ctx, jobdom.UpdateProgressRequest{
			JobID:      step.JobID,
			WorkerID:   step.WorkerID,
			Phase:      "expand",
			Summary:    summary,
			Percentage: percentage,
		})
	}

	// checkCancelled is shared by both Brave passes. A failed cancellation
	// lookup is treated as "not cancelled"; context cancellation is still
	// reported through ctx.Err().
	checkCancelled := func() error {
		cancelled, _ := deps.Jobs.IsCancelRequested(ctx, step.JobID)
		if cancelled {
			return errJobCancelled
		}
		return ctx.Err()
	}

	if boolField(payload, "bootstrap") || brand.ResearchMap.IsEmpty() {
		updateProgress("產生研究地圖…", 5)
		if err := ensureResearchMap(ctx, step, deps, brand, productBrief, providerID, credential, updateProgress); err != nil {
			return err
		}
		// Reload so the research map and any audience/brief patch are visible below.
		brand, err = deps.Brand.Get(ctx, tenantID, ownerUID, brandID)
		if err != nil {
			return err
		}
	}

	updateProgress("規劃 Brave 查詢…", 10)
	var existing *kgusecase.GraphSummary
	if supplemental {
		// The lookup error is dropped on purpose: when the stored graph cannot
		// be loaded the run proceeds as if there were none.
		existing, _ = deps.KnowledgeGraph.Get(ctx, tenantID, ownerUID, brandID)
	}
	l1Labels := []string{}
	if existing != nil {
		l1Labels = libkg.L1LabelsFromNodes(existing.Nodes)
	}
	queries := libkg.PlanQueries(libkg.PlanInput{
		Seed:           seed,
		TargetAudience: brand.TargetAudience,
		ProductBrief:   productBrief,
		L1Labels:       l1Labels,
		Supplemental:   supplemental,
	})

	updateProgress(fmt.Sprintf("Brave 知識擴展(%d 查詢)…", len(queries)), 25)
	braveSources, err := runBraveKnowledgeExpand(ctx, braveClient, memberCtx, queries, func(i, total int) {
		// Spread the search phase over the 25%..55% progress band.
		pct := 25 + ((i + 1) * 30 / max(total, 1))
		updateProgress(fmt.Sprintf("Brave 查詢進行中 %d/%d…", i+1, total), pct)
	}, checkCancelled)
	if err != nil {
		return err
	}

	updateProgress("AI 合成知識圖譜…", 60)
	systemPrompt, err := libprompt.KnowledgeGraphSystem()
	if err != nil {
		return app.For(code.AI).SysInternal("knowledge graph prompt load failed")
	}

	// synthInput builds the prompt input for the given source list; only the
	// sources differ between the first and the supplemental pass.
	synthInput := func(sources []libkg.BraveSource) libkg.SynthInput {
		return libkg.SynthInput{
			Seed:           seed,
			ProductBrief:   productBrief,
			TargetAudience: brand.TargetAudience,
			Persona:        brand.Brief,
			Sources:        sources,
		}
	}
	// newSynthRequest builds the single-message synthesis request used by
	// both passes.
	newSynthRequest := func(userContent string) domai.GenerateRequest {
		return domai.GenerateRequest{
			Provider: providerID,
			Model:    credential.Model,
			Credential: domai.Credential{
				APIKey: credential.APIKey,
			},
			System: systemPrompt,
			Messages: []domai.Message{
				{Role: "user", Content: userContent},
			},
		}
	}

	userPrompt, err := libkg.BuildUserPrompt(synthInput(braveSources))
	if err != nil {
		return app.For(code.AI).SysInternal("knowledge graph user prompt load failed")
	}
	result, err := deps.AI.GenerateText(ctx, newSynthRequest(userPrompt))
	if err != nil {
		return err
	}
	graph, err := libkg.ParseSynthOutput(result.Text, libkg.SynthInput{
		Seed:           seed,
		ProductBrief:   productBrief,
		TargetAudience: brand.TargetAudience,
	}, braveSources)
	if err != nil {
		return app.For(code.AI).SvcThirdParty("知識圖譜 LLM 回傳無法解析:" + err.Error())
	}
	if supplemental && existing != nil {
		graph = mergeGraphs(existing, graph, braveSources)
	}

	needsSupplemental := graph.PainTagCount < libkg.MinPainTagCandidates() && !supplemental
	if needsSupplemental {
		updateProgress(fmt.Sprintf("痛點 tag 僅 %d執行補充查詢…", graph.PainTagCount), 75)
		suppQueries := libkg.PlanQueries(libkg.PlanInput{
			Seed:         seed,
			L1Labels:     libkg.L1LabelsFromNodes(graph.Nodes),
			Supplemental: true,
		})
		extraSources, err := runBraveKnowledgeExpand(ctx, braveClient, memberCtx, suppQueries, nil, checkCancelled)
		if err != nil {
			return err
		}
		braveSources = append(braveSources, extraSources...)

		suppInstruction, err := libprompt.KnowledgeGraphSupplemental()
		if err != nil {
			return app.For(code.AI).SysInternal("knowledge graph supplemental prompt load failed")
		}
		suppUserPrompt, err := libkg.BuildUserPrompt(synthInput(braveSources))
		if err != nil {
			// Wrapped like every other prompt-load failure in this function.
			return app.For(code.AI).SysInternal("knowledge graph supplemental user prompt load failed")
		}
		suppResult, err := deps.AI.GenerateText(ctx, newSynthRequest(suppUserPrompt+"\n\n"+suppInstruction))
		// The supplemental synthesis is best-effort: when generation or
		// parsing fails the first-pass graph is kept unchanged.
		if err == nil {
			if patched, parseErr := libkg.ParseSynthOutput(suppResult.Text, libkg.SynthInput{Seed: seed}, braveSources); parseErr == nil {
				graph = mergeGraphs(&kgusecase.GraphSummary{
					Seed:  graph.Seed,
					Nodes: graph.Nodes,
					Edges: graph.Edges,
				}, patched, braveSources)
			}
		}
		// Re-evaluate so the handoff tells the caller whether the graph is still thin.
		needsSupplemental = graph.PainTagCount < libkg.MinPainTagCandidates()
	}

	updateProgress("寫入知識圖譜…", 90)
	// NOTE(review): this replaces whatever source list mergeGraphs produced,
	// so on a supplemental run the previously stored sources are not carried
	// over — confirm that is intended.
	graph.BraveSources = braveSources
	now := clock.NowUnixNano()
	saved, err := deps.KnowledgeGraph.Upsert(ctx, kgusecase.UpsertRequest{
		TenantID:     tenantID,
		OwnerUID:     ownerUID,
		BrandID:      brandID,
		Seed:         graph.Seed,
		Nodes:        graph.Nodes,
		Edges:        graph.Edges,
		BraveSources: graph.BraveSources,
		PainTagCount: graph.PainTagCount,
		GeneratedAt:  now,
	})
	if err != nil {
		return err
	}

	handoff := map[string]any{
		"flow":                      "placement",
		"brand_id":                  brandID,
		"pain_tag_count":            saved.PainTagCount,
		"summary":                   fmt.Sprintf("圖譜 %d 節點,痛點候選 %d", len(saved.Nodes), saved.PainTagCount),
		"next_route":                "/research?brand=" + brandID,
		"needs_supplemental_expand": needsSupplemental,
		"search_source_mode":        string(memberCtx.SearchSourceMode),
		"dev_mode":                  memberCtx.DevMode,
	}
	_, err = deps.Jobs.CompleteRun(ctx, jobdom.CompleteRunRequest{
		JobID:    step.JobID,
		WorkerID: step.WorkerID,
		Result: map[string]any{
			"graph_id":           saved.ID,
			"seed":               saved.Seed,
			"pain_tag_count":     saved.PainTagCount,
			"node_count":         len(saved.Nodes),
			"search_source_mode": string(memberCtx.SearchSourceMode),
			"handoff":            handoff,
		},
	})
	return err
}
// runBraveKnowledgeExpand runs every query against Brave in knowledge-expand
// mode (up to 3 results each, using the member's country and search language)
// and flattens the hits into BraveSource records tagged with their query.
//
// heartbeat, when non-nil, is called before each query; a non-nil error from
// it aborts the run and is returned unchanged. onProgress, when non-nil, is
// called after each query with the zero-based index and the total, whether or
// not that query succeeded.
//
// A failed search is skipped so one bad query does not sink the whole
// expansion, except when ctx is already done: then ctx.Err() is returned
// immediately instead of a misleading "no results" error.
//
// It returns an InputMissingRequired error when client is nil or not enabled,
// and a SvcThirdParty error when no query produced any result.
func runBraveKnowledgeExpand(
	ctx context.Context,
	client *libbrave.Client,
	member placement.MemberContext,
	queries []string,
	onProgress func(i, total int),
	heartbeat func() error,
) ([]libkg.BraveSource, error) {
	if client == nil || !client.Enabled() {
		return nil, app.For(code.Setting).InputMissingRequired("請在設定頁設定 Brave Search API key跟隨此登入帳號")
	}
	out := make([]libkg.BraveSource, 0, len(queries)*2)
	for i, query := range queries {
		if heartbeat != nil {
			if err := heartbeat(); err != nil {
				return nil, err
			}
		}
		res, err := client.Search(ctx, libbrave.SearchOptions{
			Query:      query,
			Limit:      3,
			Mode:       libbrave.ModeKnowledgeExpand,
			Country:    member.BraveCountry,
			SearchLang: member.BraveSearchLang,
		})
		if err == nil {
			for _, item := range res.Results {
				out = append(out, libkg.BraveSource{
					Query:   query,
					Snippet: item.Snippet,
					URL:     item.URL,
					Title:   item.Title,
				})
			}
		} else if ctxErr := ctx.Err(); ctxErr != nil {
			// Cancellation or deadline: further queries would fail the same way.
			return nil, ctxErr
		}
		// Any other search error is tolerated; res is not read in that case
		// because a result returned alongside an error is not meaningful.
		if onProgress != nil {
			onProgress(i, len(queries))
		}
	}
	if len(out) == 0 {
		return nil, app.For(code.Setting).SvcThirdParty("Brave 查詢無結果,請確認 API key 或稍後重試")
	}
	return out, nil
}
// mergeGraphs folds incoming into existing and returns the combined graph.
//
// When existing is nil, incoming is returned untouched. Otherwise the result
// keeps existing's seed, all of existing's nodes and edges, plus:
//   - incoming nodes whose label (trimmed, case-insensitive) is not already
//     present;
//   - incoming edges whose "from->to" pair is not already present;
//   - existing's Brave sources followed by extraSources.
//
// libkg.DeriveSearchTagsFromGraph is applied to the merged graph before it is
// returned.
//
// NOTE(review): an incoming node dropped as a label duplicate may still be
// referenced by an incoming edge that is kept — confirm whether edge endpoints
// need remapping.
func mergeGraphs(existing *kgusecase.GraphSummary, incoming libkg.Graph, extraSources []libkg.BraveSource) libkg.Graph {
	if existing == nil {
		return incoming
	}

	labelKey := func(n libkg.Node) string {
		return strings.ToLower(strings.TrimSpace(n.Label))
	}
	edgeKey := func(e libkg.Edge) string {
		return e.From + "->" + e.To
	}

	// Nodes: start from a copy of the stored ones, then add unseen labels.
	nodes := append([]libkg.Node{}, existing.Nodes...)
	knownLabels := make(map[string]struct{}, len(nodes)+len(incoming.Nodes))
	for _, n := range nodes {
		knownLabels[labelKey(n)] = struct{}{}
	}
	for _, n := range incoming.Nodes {
		k := labelKey(n)
		if _, dup := knownLabels[k]; !dup {
			knownLabels[k] = struct{}{}
			nodes = append(nodes, n)
		}
	}

	// Edges: same approach, keyed by the directed endpoint pair.
	edges := append([]libkg.Edge{}, existing.Edges...)
	knownEdges := make(map[string]struct{}, len(edges)+len(incoming.Edges))
	for _, e := range edges {
		knownEdges[edgeKey(e)] = struct{}{}
	}
	for _, e := range incoming.Edges {
		k := edgeKey(e)
		if _, dup := knownEdges[k]; !dup {
			knownEdges[k] = struct{}{}
			edges = append(edges, e)
		}
	}

	// Sources are concatenated without de-duplication.
	sources := append([]libkg.BraveSource{}, existing.BraveSources...)
	sources = append(sources, extraSources...)

	merged := libkg.Graph{
		Seed:         existing.Seed,
		Nodes:        nodes,
		Edges:        edges,
		BraveSources: sources,
	}
	libkg.DeriveSearchTagsFromGraph(&merged)
	return merged
}
// stringField returns payload[key] as a whitespace-trimmed string.
//
// A nil payload, a missing key, or a nil value yields "". String values are
// trimmed directly; any other value is rendered with fmt.Sprint and trimmed.
func stringField(payload map[string]any, key string) string {
	// Reading from a nil map is safe and a missing key yields a nil interface,
	// so one nil check covers "no payload", "no key" and "explicit nil".
	value := payload[key]
	if value == nil {
		return ""
	}
	if text, isString := value.(string); isString {
		return strings.TrimSpace(text)
	}
	return strings.TrimSpace(fmt.Sprint(value))
}
// boolField returns payload[key] interpreted as a boolean.
//
// Bool values are returned as-is. String values are true only when, after
// trimming whitespace, they equal "true" case-insensitively. A nil payload, a
// missing key, a nil value, or any other type yields false.
func boolField(payload map[string]any, key string) bool {
	// A nil map read and a missing key both produce a nil interface, which
	// fails both assertions below and falls through to false.
	value := payload[key]
	if flag, isBool := value.(bool); isBool {
		return flag
	}
	if text, isString := value.(string); isString {
		return strings.EqualFold(strings.TrimSpace(text), "true")
	}
	return false
}
// max returns the larger of a and b.
//
// Within this package it shadows the max builtin available since Go 1.21.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
// ensureResearchMap generates a research map for brand with the member's AI
// provider and stores it on the brand.
//
// tenant_id and owner_uid are read from the step's run payload. The brand
// update always carries the parsed research map. It also sets the target
// audience (the brand's own value, or the parsed audience summary when the
// brand has none) when that is non-empty, and the product brief when the
// brand has none and productBrief is non-empty. On success it reports
// progress through updateProgress.
//
// It returns an error when the actor identifiers or brand are missing, when
// text generation fails, when the model output cannot be parsed, or when the
// brand update fails.
func ensureResearchMap(
	ctx context.Context,
	step StepContext,
	deps ExpandGraphDeps,
	brand *branddomain.BrandSummary,
	productBrief string,
	providerID enum.ProviderID,
	credential *threadsaccountdomain.WorkerAiCredential,
	updateProgress func(string, int),
) error {
	runPayload := step.Run.Payload
	tenantID := stringField(runPayload, "tenant_id")
	ownerUID := stringField(runPayload, "owner_uid")
	if brand == nil || tenantID == "" || ownerUID == "" {
		return fmt.Errorf("research map: missing actor or brand")
	}

	// Build both prompts up front, system first, then issue one request.
	systemPrompt := placement.BuildResearchMapSystemPrompt()
	userPrompt := placement.BuildResearchMapUserPrompt(placement.ResearchMapInput{
		Label:          brand.DisplayName,
		SeedQuery:      brand.SeedQuery,
		Brief:          brand.Brief,
		ProductContext: brand.ProductContext,
	})
	request := domai.GenerateRequest{
		Provider:   providerID,
		Model:      credential.Model,
		Credential: domai.Credential{APIKey: credential.APIKey},
		System:     systemPrompt,
		Messages: []domai.Message{
			{Role: "user", Content: userPrompt},
		},
	}
	generated, err := deps.AI.GenerateText(ctx, request)
	if err != nil {
		return err
	}
	parsed, err := placement.ParseResearchMapOutput(generated.Text)
	if err != nil {
		return app.For(code.AI).SvcThirdParty("研究地圖 LLM 回傳無法解析:" + err.Error())
	}

	patch := branddomain.BrandPatch{
		ResearchMap: &brandentity.ResearchMap{
			AudienceSummary: parsed.AudienceSummary,
			ContentGoal:     parsed.ContentGoal,
			Questions:       parsed.Questions,
			Pillars:         parsed.Pillars,
			Exclusions:      parsed.Exclusions,
		},
	}

	// Prefer the audience already on the brand; otherwise adopt the model's summary.
	audience := strings.TrimSpace(brand.TargetAudience)
	if audience == "" {
		audience = parsed.AudienceSummary
	}
	if audience != "" {
		patch.TargetAudience = &audience
	}
	// Only backfill the product brief; never overwrite one the brand already has.
	if productBrief != "" && strings.TrimSpace(brand.ProductBrief) == "" {
		patch.ProductBrief = &productBrief
	}

	if _, err := deps.Brand.Update(ctx, branddomain.UpdateRequest{
		TenantID: tenantID,
		OwnerUID: ownerUID,
		BrandID:  brand.ID,
		Patch:    patch,
	}); err != nil {
		return err
	}
	updateProgress("研究地圖已就緒", 8)
	return nil
}