haixunMaster/haixun-backend/internal/worker/job/expand_graph.go

498 lines
14 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package job
import (
"context"
"fmt"
"strings"
libbrave "haixun-backend/internal/library/brave"
"haixun-backend/internal/library/clock"
app "haixun-backend/internal/library/errors"
"haixun-backend/internal/library/errors/code"
libkg "haixun-backend/internal/library/knowledge"
"haixun-backend/internal/library/placement"
libprompt "haixun-backend/internal/library/prompt"
"haixun-backend/internal/model/ai/domain/enum"
domai "haixun-backend/internal/model/ai/domain/usecase"
aiusecase "haixun-backend/internal/model/ai/usecase"
brandentity "haixun-backend/internal/model/brand/domain/entity"
branddomain "haixun-backend/internal/model/brand/domain/usecase"
jobdom "haixun-backend/internal/model/job/domain/usecase"
kgusecase "haixun-backend/internal/model/knowledge_graph/domain/usecase"
placementusecase "haixun-backend/internal/model/placement/usecase"
threadsaccountdomain "haixun-backend/internal/model/threads_account/domain/usecase"
)
// ExpandGraphDeps bundles the use cases that the "expand" step handler calls
// while building a brand's knowledge graph.
type ExpandGraphDeps struct {
// Jobs reports progress, checks for cancellation requests, and completes the run.
Jobs jobdom.UseCase
// Brand loads the brand and stores the generated research map on it.
Brand branddomain.UseCase
// KnowledgeGraph reads the existing graph (supplemental runs) and upserts the result.
KnowledgeGraph kgusecase.UseCase
// ThreadsAccount resolves the member's placement context and AI credential.
ThreadsAccount threadsaccountdomain.UseCase
// Placement supplies the research settings used to resolve the placement context.
Placement placementusecase.UseCase
// AI generates the research map and synthesizes the knowledge graph text.
AI aiusecase.UseCase
}
// RegisterExpandGraphHandler installs the handler for the "expand" step on
// runner. The handler forwards each step to runExpandGraph together with deps.
// A nil runner is ignored.
func RegisterExpandGraphHandler(runner *Runner, deps ExpandGraphDeps) {
	if runner != nil {
		expand := func(ctx context.Context, step StepContext) error {
			return runExpandGraph(ctx, step, deps)
		}
		runner.RegisterStepHandler("expand", expand)
	}
}
// brandIDFromPayload returns the payload's "brand_id", or its "persona_id"
// when "brand_id" is absent or empty. The result is "" when neither is set.
func brandIDFromPayload(payload map[string]any) string {
	if id := stringField(payload, "brand_id"); id != "" {
		return id
	}
	return stringField(payload, "persona_id")
}
// runExpandGraph executes the "expand" step for one job run.
//
// It validates the run payload, loads the brand plus the member's Brave and AI
// settings, generates the research map when requested or missing, runs the
// planned Brave queries, asks the AI to synthesize a knowledge graph from the
// results, optionally performs one supplemental pass when too few pain tags
// were produced, upserts the graph, and finally completes the run with a
// result that embeds a handoff payload.
//
// It returns the first error from any required stage. Heartbeat and progress
// updates are fire-and-forget: their errors are discarded.
func runExpandGraph(ctx context.Context, step StepContext, deps ExpandGraphDeps) error {
payload := step.Run.Payload
tenantID := stringField(payload, "tenant_id")
ownerUID := stringField(payload, "owner_uid")
brandID := brandIDFromPayload(payload)
seed := stringField(payload, "seed_query")
supplemental := boolField(payload, "supplemental")
// Reject payloads that lack the identifiers or the seed query before doing any I/O.
if tenantID == "" || ownerUID == "" || brandID == "" {
return fmt.Errorf("expand-graph payload missing tenant_id, owner_uid, or brand_id")
}
if seed == "" {
return fmt.Errorf("expand-graph payload missing seed_query")
}
brand, err := deps.Brand.Get(ctx, tenantID, ownerUID, brandID)
if err != nil {
return err
}
// Use the trimmed ProductBrief unless a non-empty brief can be formatted from
// ProductContext, in which case the formatted one wins.
productBrief := strings.TrimSpace(brand.ProductBrief)
if formatted := placement.ProductBriefFromContext(brand.ProductContext); formatted != "" {
productBrief = formatted
}
research, err := deps.Placement.ResearchSettings(ctx, tenantID, ownerUID)
if err != nil {
return err
}
memberCtx, err := deps.ThreadsAccount.ResolveMemberPlacementContext(ctx, tenantID, ownerUID, research)
if err != nil {
return err
}
braveClient := libbrave.NewClient(memberCtx.BraveAPIKey)
credential, err := deps.ThreadsAccount.ResolveMemberAiCredential(ctx, tenantID, ownerUID)
if err != nil {
return err
}
providerID, err := aiusecase.MapWorkerProvider(credential.Provider)
if err != nil {
return err
}
// updateProgress sends a heartbeat and then a progress update for the "expand"
// phase. Both results are discarded, so a failed update never aborts the step.
updateProgress := func(summary string, percentage int) {
_ = step.Heartbeat(ctx)
_, _ = deps.Jobs.UpdateProgress(ctx, jobdom.UpdateProgressRequest{
JobID: step.JobID,
WorkerID: step.WorkerID,
Phase: "expand",
Summary: summary,
Percentage: percentage,
})
}
// Generate the research map when the payload asks for a bootstrap or the brand
// has none yet, then reload the brand so later stages read the updated record.
bootstrap := boolField(payload, "bootstrap")
if bootstrap || brand.ResearchMap.IsEmpty() {
updateProgress("產生研究地圖…", 5)
if err := ensureResearchMap(ctx, step, deps, brand, productBrief, providerID, credential, updateProgress); err != nil {
return err
}
brand, err = deps.Brand.Get(ctx, tenantID, ownerUID, brandID)
if err != nil {
return err
}
}
updateProgress("規劃 Brave 查詢…", 10)
// In supplemental mode the existing graph seeds the query plan and is merged
// with the new graph further down.
// NOTE(review): the error from KnowledgeGraph.Get is discarded, so a failed
// lookup is indistinguishable from "no existing graph" and the merge is
// skipped — confirm this is intended.
var existing *kgusecase.GraphSummary
if supplemental {
existing, _ = deps.KnowledgeGraph.Get(ctx, tenantID, ownerUID, brandID)
}
l1Labels := []string{}
if existing != nil {
l1Labels = libkg.L1LabelsFromNodes(existing.Nodes)
}
queries := libkg.PlanQueries(libkg.PlanInput{
Seed: seed,
TargetAudience: brand.TargetAudience,
ProductBrief: productBrief,
L1Labels: l1Labels,
Supplemental: supplemental,
})
updateProgress(fmt.Sprintf("Brave 知識擴展(%d 查詢)…", len(queries)), 25)
braveSources, err := runBraveKnowledgeExpand(ctx, braveClient, memberCtx, queries, func(i, total int) {
// Spread the query loop over the progress range above 25%, reaching 55% on the
// last query; max guards against a zero total.
pct := 25 + ((i + 1) * 30 / max(total, 1))
updateProgress(fmt.Sprintf("Brave 查詢進行中 %d/%d…", i+1, total), pct)
}, func() error {
// Called before every query: stop on a cancel request or a finished context.
// NOTE(review): the IsCancelRequested error is discarded.
cancelled, _ := deps.Jobs.IsCancelRequested(ctx, step.JobID)
if cancelled {
return errJobCancelled
}
return ctx.Err()
})
if err != nil {
return err
}
updateProgress("AI 合成知識圖譜…", 60)
// Prompt-loading failures are reported as internal AI errors; the underlying
// error text is not included in the returned error.
systemPrompt, err := libprompt.KnowledgeGraphSystem()
if err != nil {
return app.For(code.AI).SysInternal("knowledge graph prompt load failed")
}
userPrompt, err := libkg.BuildUserPrompt(libkg.SynthInput{
Seed: seed,
ProductBrief: productBrief,
TargetAudience: brand.TargetAudience,
Persona: brand.Brief,
Sources: braveSources,
})
if err != nil {
return app.For(code.AI).SysInternal("knowledge graph user prompt load failed")
}
result, err := deps.AI.GenerateText(ctx, domai.GenerateRequest{
Provider: providerID,
Model: credential.Model,
Credential: domai.Credential{
APIKey: credential.APIKey,
},
System: systemPrompt,
Messages: []domai.Message{
{Role: "user", Content: userPrompt},
},
})
if err != nil {
return err
}
graph, err := libkg.ParseSynthOutput(result.Text, libkg.SynthInput{
Seed: seed,
ProductBrief: productBrief,
TargetAudience: brand.TargetAudience,
}, braveSources)
if err != nil {
return app.For(code.AI).SvcThirdParty("知識圖譜 LLM 回傳無法解析:" + err.Error())
}
if supplemental && existing != nil {
graph = mergeGraphs(existing, graph, braveSources)
}
// Run an extra pass only when this run is not itself supplemental and the graph
// has fewer pain tags than the configured minimum.
needsSupplemental := graph.PainTagCount < libkg.MinPainTagCandidates() && !supplemental
if needsSupplemental {
updateProgress(fmt.Sprintf("痛點 tag 僅 %d執行補充查詢…", graph.PainTagCount), 75)
// The supplemental plan is built from the seed and the labels of the graph just
// synthesized; TargetAudience and ProductBrief are not passed here.
suppQueries := libkg.PlanQueries(libkg.PlanInput{
Seed: seed,
L1Labels: libkg.L1LabelsFromNodes(graph.Nodes),
Supplemental: true,
})
// The := below declares a new err scoped to this block, shadowing the outer one.
extraSources, err := runBraveKnowledgeExpand(ctx, braveClient, memberCtx, suppQueries, nil, func() error {
cancelled, _ := deps.Jobs.IsCancelRequested(ctx, step.JobID)
if cancelled {
return errJobCancelled
}
return ctx.Err()
})
if err != nil {
return err
}
braveSources = append(braveSources, extraSources...)
suppInstruction, err := libprompt.KnowledgeGraphSupplemental()
if err != nil {
return app.For(code.AI).SysInternal("knowledge graph supplemental prompt load failed")
}
suppUserPrompt, err := libkg.BuildUserPrompt(libkg.SynthInput{
Seed: seed,
ProductBrief: productBrief,
TargetAudience: brand.TargetAudience,
Persona: brand.Brief,
Sources: braveSources,
})
if err != nil {
return err
}
suppResult, err := deps.AI.GenerateText(ctx, domai.GenerateRequest{
Provider: providerID,
Model: credential.Model,
Credential: domai.Credential{
APIKey: credential.APIKey,
},
System: systemPrompt,
Messages: []domai.Message{
{Role: "user", Content: suppUserPrompt + "\n\n" + suppInstruction},
},
})
// A failed supplemental AI call or an unparseable reply is ignored: the graph
// from the first pass is kept unchanged.
// NOTE(review): only Seed is passed to ParseSynthOutput here, unlike the first
// pass — confirm that is intended.
if err == nil {
if patched, parseErr := libkg.ParseSynthOutput(suppResult.Text, libkg.SynthInput{Seed: seed}, braveSources); parseErr == nil {
graph = mergeGraphs(&kgusecase.GraphSummary{
Seed: graph.Seed,
Nodes: graph.Nodes,
Edges: graph.Edges,
}, patched, braveSources)
}
}
// Re-evaluate so the handoff reports whether the graph is still short of pain tags.
needsSupplemental = graph.PainTagCount < libkg.MinPainTagCandidates()
}
updateProgress("寫入知識圖譜…", 90)
// NOTE(review): this replaces whatever BraveSources mergeGraphs assembled; in
// supplemental mode that drops the sources copied from the existing graph —
// confirm intended.
graph.BraveSources = braveSources
now := clock.NowUnixNano()
saved, err := deps.KnowledgeGraph.Upsert(ctx, kgusecase.UpsertRequest{
TenantID: tenantID,
OwnerUID: ownerUID,
BrandID: brandID,
Seed: graph.Seed,
Nodes: graph.Nodes,
Edges: graph.Edges,
BraveSources: graph.BraveSources,
PainTagCount: graph.PainTagCount,
GeneratedAt: now,
})
if err != nil {
return err
}
// handoff is nested in the run result under "handoff"; its values are taken
// from the saved graph and the member context.
handoff := map[string]any{
"flow": "placement",
"brand_id": brandID,
"pain_tag_count": saved.PainTagCount,
"summary": fmt.Sprintf("圖譜 %d 節點,痛點候選 %d", len(saved.Nodes), saved.PainTagCount),
"next_route": "/research?brand=" + brandID,
"needs_supplemental_expand": needsSupplemental,
"search_source_mode": string(memberCtx.SearchSourceMode),
"dev_mode": memberCtx.DevMode,
}
_, err = deps.Jobs.CompleteRun(ctx, jobdom.CompleteRunRequest{
JobID: step.JobID,
WorkerID: step.WorkerID,
Result: map[string]any{
"graph_id": saved.ID,
"seed": saved.Seed,
"pain_tag_count": saved.PainTagCount,
"node_count": len(saved.Nodes),
"search_source_mode": string(memberCtx.SearchSourceMode),
"handoff": handoff,
},
})
return err
}
// runBraveKnowledgeExpand runs every query against Brave in knowledge-expand
// mode and flattens the hits into BraveSource records, each tagged with the
// query that produced it.
//
// heartbeat, when non-nil, is called before each query; a non-nil error from it
// aborts the run and is returned as-is. onProgress, when non-nil, is called
// after each query with the zero-based query index and the total query count.
//
// A single failed search is tolerated: its result is not read (it is
// meaningless once an error is returned) and the loop moves on to the next
// query. The exception is a finished context, in which case ctx.Err() is
// returned immediately instead of being masked as "no results".
//
// An error is returned when the client is nil or disabled, and when no query
// produced any result.
func runBraveKnowledgeExpand(
	ctx context.Context,
	client *libbrave.Client,
	member placement.MemberContext,
	queries []string,
	onProgress func(i, total int),
	heartbeat func() error,
) ([]libkg.BraveSource, error) {
	if client == nil || !client.Enabled() {
		return nil, app.For(code.Setting).InputMissingRequired("請在設定頁設定 Brave Search API key跟隨此登入帳號")
	}

	// Maximum hits requested per query; also sizes the result slice.
	const perQueryLimit = 3

	out := make([]libkg.BraveSource, 0, len(queries)*perQueryLimit)
	for i, query := range queries {
		if heartbeat != nil {
			if err := heartbeat(); err != nil {
				return nil, err
			}
		}

		res, err := client.Search(ctx, libbrave.SearchOptions{
			Query:      query,
			Limit:      perQueryLimit,
			Mode:       libbrave.ModeKnowledgeExpand,
			Country:    member.BraveCountry,
			SearchLang: member.BraveSearchLang,
		})
		switch {
		case err == nil:
			for _, item := range res.Results {
				out = append(out, libkg.BraveSource{
					Query:   query,
					Snippet: item.Snippet,
					URL:     item.URL,
					Title:   item.Title,
				})
			}
		case ctx.Err() != nil:
			// The search failed because the step itself was cancelled or timed
			// out; later queries would fail the same way.
			return nil, ctx.Err()
		}
		// Any other search error is deliberately skipped so that one bad query
		// does not sink the whole expansion.

		if onProgress != nil {
			onProgress(i, len(queries))
		}
	}

	if len(out) == 0 {
		return nil, app.For(code.Setting).SvcThirdParty("Brave 查詢無結果,請確認 API key 或稍後重試")
	}
	return out, nil
}
// mergeGraphs folds incoming into existing and returns the combined graph.
//
// When existing is nil, incoming is returned untouched. Otherwise the result
// keeps existing's seed and all of its nodes, edges and Brave sources, then
// adds:
//   - incoming nodes whose label (trimmed, case-insensitive) is not yet present,
//   - incoming edges whose From/To pair is not yet present,
//   - every entry of extraSources.
//
// libkg.DeriveSearchTagsFromGraph is applied to the result before returning.
func mergeGraphs(existing *kgusecase.GraphSummary, incoming libkg.Graph, extraSources []libkg.BraveSource) libkg.Graph {
	if existing == nil {
		return incoming
	}

	labelKey := func(label string) string {
		return strings.ToLower(strings.TrimSpace(label))
	}
	edgeKey := func(e libkg.Edge) string {
		return e.From + "->" + e.To
	}

	// Nodes: start from the existing ones, then admit unseen labels only.
	nodes := make([]libkg.Node, 0, len(existing.Nodes)+len(incoming.Nodes))
	nodes = append(nodes, existing.Nodes...)
	knownLabels := make(map[string]struct{}, len(nodes))
	for _, n := range nodes {
		knownLabels[labelKey(n.Label)] = struct{}{}
	}
	for _, candidate := range incoming.Nodes {
		k := labelKey(candidate.Label)
		if _, dup := knownLabels[k]; !dup {
			knownLabels[k] = struct{}{}
			nodes = append(nodes, candidate)
		}
	}

	// Edges: same approach, keyed by the directed From/To pair.
	edges := make([]libkg.Edge, 0, len(existing.Edges)+len(incoming.Edges))
	edges = append(edges, existing.Edges...)
	knownEdges := make(map[string]struct{}, len(edges))
	for _, e := range edges {
		knownEdges[edgeKey(e)] = struct{}{}
	}
	for _, candidate := range incoming.Edges {
		k := edgeKey(candidate)
		if _, dup := knownEdges[k]; !dup {
			knownEdges[k] = struct{}{}
			edges = append(edges, candidate)
		}
	}

	// Sources are concatenated without de-duplication.
	sources := make([]libkg.BraveSource, 0, len(existing.BraveSources)+len(extraSources))
	sources = append(sources, existing.BraveSources...)
	sources = append(sources, extraSources...)

	result := libkg.Graph{
		Seed:         existing.Seed,
		Nodes:        nodes,
		Edges:        edges,
		BraveSources: sources,
	}
	libkg.DeriveSearchTagsFromGraph(&result)
	return result
}
// stringField returns payload[key] as a whitespace-trimmed string.
//
// A nil payload, a missing key, or a nil value yields "". String values are
// trimmed and returned. Whole-number float64 values within ±1e15 — the type
// encoding/json uses for JSON numbers decoded into any — are rendered as plain
// decimal digits, so 12345678 becomes "12345678" rather than fmt's default
// exponent form "1.2345678e+07". Every other value is formatted with
// fmt.Sprint and trimmed.
func stringField(payload map[string]any, key string) string {
	if payload == nil {
		return ""
	}
	raw, ok := payload[key]
	if !ok || raw == nil {
		return ""
	}
	switch v := raw.(type) {
	case string:
		return strings.TrimSpace(v)
	case float64:
		// 1e15 is below 2^53, so every whole number in range is exactly
		// representable and the int64 round-trip below is well defined.
		const maxWhole = 1e15
		if v >= -maxWhole && v <= maxWhole && v == float64(int64(v)) {
			return fmt.Sprintf("%.0f", v)
		}
		return strings.TrimSpace(fmt.Sprint(v))
	default:
		return strings.TrimSpace(fmt.Sprint(v))
	}
}
// boolField reads payload[key] as a boolean.
//
// A bool value is returned as-is. A string value is true only when, after
// trimming whitespace, it equals "true" ignoring case. A nil payload, a
// missing key, a nil value, or any other type yields false.
func boolField(payload map[string]any, key string) bool {
	// Indexing a nil map or a missing key yields a nil interface, which matches
	// none of the cases below, so no separate presence checks are needed.
	switch val := payload[key].(type) {
	case bool:
		return val
	case string:
		return strings.EqualFold(strings.TrimSpace(val), "true")
	}
	return false
}
// max returns the larger of a and b.
func max(a, b int) int {
	if b > a {
		return b
	}
	return a
}
// ensureResearchMap asks the AI for a research map describing brand and stores
// it on the brand record.
//
// tenant_id and owner_uid are read from step.Run.Payload; an error is returned
// when either is empty or brand is nil. The AI reply is parsed with
// placement.ParseResearchMapOutput; an unparseable reply is returned as a
// third-party AI error. The brand patch always carries the new research map.
// It also carries a target audience when one is available (the brand's own
// trimmed value, otherwise the parsed audience summary), and productBrief when
// the brand has no product brief of its own and productBrief is non-empty.
//
// On success updateProgress is called once with 8 percent.
func ensureResearchMap(
	ctx context.Context,
	step StepContext,
	deps ExpandGraphDeps,
	brand *branddomain.BrandSummary,
	productBrief string,
	providerID enum.ProviderID,
	credential *threadsaccountdomain.WorkerAiCredential,
	updateProgress func(string, int),
) error {
	tenantID, ownerUID := stringField(step.Run.Payload, "tenant_id"), stringField(step.Run.Payload, "owner_uid")
	if brand == nil || tenantID == "" || ownerUID == "" {
		return fmt.Errorf("research map: missing actor or brand")
	}

	// Assemble the prompts first so the request literal stays compact.
	systemPrompt := placement.BuildResearchMapSystemPrompt()
	userPrompt := placement.BuildResearchMapUserPrompt(placement.ResearchMapInput{
		Label:          brand.DisplayName,
		SeedQuery:      brand.SeedQuery,
		Brief:          brand.Brief,
		ProductContext: brand.ProductContext,
	})
	request := domai.GenerateRequest{
		Provider:   providerID,
		Model:      credential.Model,
		Credential: domai.Credential{APIKey: credential.APIKey},
		System:     systemPrompt,
		Messages:   []domai.Message{{Role: "user", Content: userPrompt}},
	}

	reply, err := deps.AI.GenerateText(ctx, request)
	if err != nil {
		return err
	}
	parsed, err := placement.ParseResearchMapOutput(reply.Text)
	if err != nil {
		return app.For(code.AI).SvcThirdParty("研究地圖 LLM 回傳無法解析:" + err.Error())
	}

	researchMap := brandentity.ResearchMap{
		AudienceSummary: parsed.AudienceSummary,
		ContentGoal:     parsed.ContentGoal,
		Questions:       parsed.Questions,
		Pillars:         parsed.Pillars,
		Exclusions:      parsed.Exclusions,
	}
	patch := branddomain.BrandPatch{ResearchMap: &researchMap}

	// Keep the brand's own audience when it has one; otherwise adopt the AI's.
	audience := strings.TrimSpace(brand.TargetAudience)
	if audience == "" {
		audience = parsed.AudienceSummary
	}
	if audience != "" {
		patch.TargetAudience = &audience
	}

	// Only fill in the product brief when the brand does not already have one.
	if productBrief != "" && strings.TrimSpace(brand.ProductBrief) == "" {
		patch.ProductBrief = &productBrief
	}

	if _, err := deps.Brand.Update(ctx, branddomain.UpdateRequest{
		TenantID: tenantID,
		OwnerUID: ownerUID,
		BrandID:  brand.ID,
		Patch:    patch,
	}); err != nil {
		return err
	}

	updateProgress("研究地圖已就緒", 8)
	return nil
}