498 lines
14 KiB
Go
498 lines
14 KiB
Go
package job
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"strings"
|
||
|
||
libbrave "haixun-backend/internal/library/brave"
|
||
"haixun-backend/internal/library/clock"
|
||
app "haixun-backend/internal/library/errors"
|
||
"haixun-backend/internal/library/errors/code"
|
||
libkg "haixun-backend/internal/library/knowledge"
|
||
"haixun-backend/internal/library/placement"
|
||
libprompt "haixun-backend/internal/library/prompt"
|
||
"haixun-backend/internal/model/ai/domain/enum"
|
||
domai "haixun-backend/internal/model/ai/domain/usecase"
|
||
aiusecase "haixun-backend/internal/model/ai/usecase"
|
||
brandentity "haixun-backend/internal/model/brand/domain/entity"
|
||
branddomain "haixun-backend/internal/model/brand/domain/usecase"
|
||
jobdom "haixun-backend/internal/model/job/domain/usecase"
|
||
kgusecase "haixun-backend/internal/model/knowledge_graph/domain/usecase"
|
||
placementusecase "haixun-backend/internal/model/placement/usecase"
|
||
threadsaccountdomain "haixun-backend/internal/model/threads_account/domain/usecase"
|
||
)
|
||
|
||
type ExpandGraphDeps struct {
|
||
Jobs jobdom.UseCase
|
||
Brand branddomain.UseCase
|
||
KnowledgeGraph kgusecase.UseCase
|
||
ThreadsAccount threadsaccountdomain.UseCase
|
||
Placement placementusecase.UseCase
|
||
AI aiusecase.UseCase
|
||
}
|
||
|
||
func RegisterExpandGraphHandler(runner *Runner, deps ExpandGraphDeps) {
|
||
if runner == nil {
|
||
return
|
||
}
|
||
runner.RegisterStepHandler("expand", func(ctx context.Context, step StepContext) error {
|
||
return runExpandGraph(ctx, step, deps)
|
||
})
|
||
}
|
||
|
||
func brandIDFromPayload(payload map[string]any) string {
|
||
brandID := stringField(payload, "brand_id")
|
||
if brandID == "" {
|
||
brandID = stringField(payload, "persona_id")
|
||
}
|
||
return brandID
|
||
}
|
||
|
||
func runExpandGraph(ctx context.Context, step StepContext, deps ExpandGraphDeps) error {
|
||
payload := step.Run.Payload
|
||
tenantID := stringField(payload, "tenant_id")
|
||
ownerUID := stringField(payload, "owner_uid")
|
||
brandID := brandIDFromPayload(payload)
|
||
seed := stringField(payload, "seed_query")
|
||
supplemental := boolField(payload, "supplemental")
|
||
|
||
if tenantID == "" || ownerUID == "" || brandID == "" {
|
||
return fmt.Errorf("expand-graph payload missing tenant_id, owner_uid, or brand_id")
|
||
}
|
||
if seed == "" {
|
||
return fmt.Errorf("expand-graph payload missing seed_query")
|
||
}
|
||
|
||
brand, err := deps.Brand.Get(ctx, tenantID, ownerUID, brandID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
productBrief := strings.TrimSpace(brand.ProductBrief)
|
||
if formatted := placement.ProductBriefFromContext(brand.ProductContext); formatted != "" {
|
||
productBrief = formatted
|
||
}
|
||
|
||
research, err := deps.Placement.ResearchSettings(ctx, tenantID, ownerUID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
memberCtx, err := deps.ThreadsAccount.ResolveMemberPlacementContext(ctx, tenantID, ownerUID, research)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
braveClient := libbrave.NewClient(memberCtx.BraveAPIKey)
|
||
|
||
credential, err := deps.ThreadsAccount.ResolveMemberAiCredential(ctx, tenantID, ownerUID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
providerID, err := aiusecase.MapWorkerProvider(credential.Provider)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
updateProgress := func(summary string, percentage int) {
|
||
_ = step.Heartbeat(ctx)
|
||
_, _ = deps.Jobs.UpdateProgress(ctx, jobdom.UpdateProgressRequest{
|
||
JobID: step.JobID,
|
||
WorkerID: step.WorkerID,
|
||
Phase: "expand",
|
||
Summary: summary,
|
||
Percentage: percentage,
|
||
})
|
||
}
|
||
|
||
bootstrap := boolField(payload, "bootstrap")
|
||
if bootstrap || brand.ResearchMap.IsEmpty() {
|
||
updateProgress("產生研究地圖…", 5)
|
||
if err := ensureResearchMap(ctx, step, deps, brand, productBrief, providerID, credential, updateProgress); err != nil {
|
||
return err
|
||
}
|
||
brand, err = deps.Brand.Get(ctx, tenantID, ownerUID, brandID)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
}
|
||
|
||
updateProgress("規劃 Brave 查詢…", 10)
|
||
|
||
var existing *kgusecase.GraphSummary
|
||
if supplemental {
|
||
existing, _ = deps.KnowledgeGraph.Get(ctx, tenantID, ownerUID, brandID)
|
||
}
|
||
|
||
l1Labels := []string{}
|
||
if existing != nil {
|
||
l1Labels = libkg.L1LabelsFromNodes(existing.Nodes)
|
||
}
|
||
queries := libkg.PlanQueries(libkg.PlanInput{
|
||
Seed: seed,
|
||
TargetAudience: brand.TargetAudience,
|
||
ProductBrief: productBrief,
|
||
L1Labels: l1Labels,
|
||
Supplemental: supplemental,
|
||
})
|
||
|
||
updateProgress(fmt.Sprintf("Brave 知識擴展(%d 查詢)…", len(queries)), 25)
|
||
|
||
braveSources, err := runBraveKnowledgeExpand(ctx, braveClient, memberCtx, queries, func(i, total int) {
|
||
pct := 25 + ((i + 1) * 30 / max(total, 1))
|
||
updateProgress(fmt.Sprintf("Brave 查詢進行中 %d/%d…", i+1, total), pct)
|
||
}, func() error {
|
||
cancelled, _ := deps.Jobs.IsCancelRequested(ctx, step.JobID)
|
||
if cancelled {
|
||
return errJobCancelled
|
||
}
|
||
return ctx.Err()
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
updateProgress("AI 合成知識圖譜…", 60)
|
||
|
||
systemPrompt, err := libprompt.KnowledgeGraphSystem()
|
||
if err != nil {
|
||
return app.For(code.AI).SysInternal("knowledge graph prompt load failed")
|
||
}
|
||
userPrompt, err := libkg.BuildUserPrompt(libkg.SynthInput{
|
||
Seed: seed,
|
||
ProductBrief: productBrief,
|
||
TargetAudience: brand.TargetAudience,
|
||
Persona: brand.Brief,
|
||
Sources: braveSources,
|
||
})
|
||
if err != nil {
|
||
return app.For(code.AI).SysInternal("knowledge graph user prompt load failed")
|
||
}
|
||
result, err := deps.AI.GenerateText(ctx, domai.GenerateRequest{
|
||
Provider: providerID,
|
||
Model: credential.Model,
|
||
Credential: domai.Credential{
|
||
APIKey: credential.APIKey,
|
||
},
|
||
System: systemPrompt,
|
||
Messages: []domai.Message{
|
||
{Role: "user", Content: userPrompt},
|
||
},
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
graph, err := libkg.ParseSynthOutput(result.Text, libkg.SynthInput{
|
||
Seed: seed,
|
||
ProductBrief: productBrief,
|
||
TargetAudience: brand.TargetAudience,
|
||
}, braveSources)
|
||
if err != nil {
|
||
return app.For(code.AI).SvcThirdParty("知識圖譜 LLM 回傳無法解析:" + err.Error())
|
||
}
|
||
|
||
if supplemental && existing != nil {
|
||
graph = mergeGraphs(existing, graph, braveSources)
|
||
}
|
||
|
||
needsSupplemental := graph.PainTagCount < libkg.MinPainTagCandidates() && !supplemental
|
||
if needsSupplemental {
|
||
updateProgress(fmt.Sprintf("痛點 tag 僅 %d,執行補充查詢…", graph.PainTagCount), 75)
|
||
suppQueries := libkg.PlanQueries(libkg.PlanInput{
|
||
Seed: seed,
|
||
L1Labels: libkg.L1LabelsFromNodes(graph.Nodes),
|
||
Supplemental: true,
|
||
})
|
||
extraSources, err := runBraveKnowledgeExpand(ctx, braveClient, memberCtx, suppQueries, nil, func() error {
|
||
cancelled, _ := deps.Jobs.IsCancelRequested(ctx, step.JobID)
|
||
if cancelled {
|
||
return errJobCancelled
|
||
}
|
||
return ctx.Err()
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
braveSources = append(braveSources, extraSources...)
|
||
suppInstruction, err := libprompt.KnowledgeGraphSupplemental()
|
||
if err != nil {
|
||
return app.For(code.AI).SysInternal("knowledge graph supplemental prompt load failed")
|
||
}
|
||
suppUserPrompt, err := libkg.BuildUserPrompt(libkg.SynthInput{
|
||
Seed: seed,
|
||
ProductBrief: productBrief,
|
||
TargetAudience: brand.TargetAudience,
|
||
Persona: brand.Brief,
|
||
Sources: braveSources,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
suppResult, err := deps.AI.GenerateText(ctx, domai.GenerateRequest{
|
||
Provider: providerID,
|
||
Model: credential.Model,
|
||
Credential: domai.Credential{
|
||
APIKey: credential.APIKey,
|
||
},
|
||
System: systemPrompt,
|
||
Messages: []domai.Message{
|
||
{Role: "user", Content: suppUserPrompt + "\n\n" + suppInstruction},
|
||
},
|
||
})
|
||
if err == nil {
|
||
if patched, parseErr := libkg.ParseSynthOutput(suppResult.Text, libkg.SynthInput{Seed: seed}, braveSources); parseErr == nil {
|
||
graph = mergeGraphs(&kgusecase.GraphSummary{
|
||
Seed: graph.Seed,
|
||
Nodes: graph.Nodes,
|
||
Edges: graph.Edges,
|
||
}, patched, braveSources)
|
||
}
|
||
}
|
||
needsSupplemental = graph.PainTagCount < libkg.MinPainTagCandidates()
|
||
}
|
||
|
||
updateProgress("寫入知識圖譜…", 90)
|
||
|
||
graph.BraveSources = braveSources
|
||
now := clock.NowUnixNano()
|
||
saved, err := deps.KnowledgeGraph.Upsert(ctx, kgusecase.UpsertRequest{
|
||
TenantID: tenantID,
|
||
OwnerUID: ownerUID,
|
||
BrandID: brandID,
|
||
Seed: graph.Seed,
|
||
Nodes: graph.Nodes,
|
||
Edges: graph.Edges,
|
||
BraveSources: graph.BraveSources,
|
||
PainTagCount: graph.PainTagCount,
|
||
GeneratedAt: now,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
handoff := map[string]any{
|
||
"flow": "placement",
|
||
"brand_id": brandID,
|
||
"pain_tag_count": saved.PainTagCount,
|
||
"summary": fmt.Sprintf("圖譜 %d 節點,痛點候選 %d", len(saved.Nodes), saved.PainTagCount),
|
||
"next_route": "/research?brand=" + brandID,
|
||
"needs_supplemental_expand": needsSupplemental,
|
||
"search_source_mode": string(memberCtx.SearchSourceMode),
|
||
"dev_mode": memberCtx.DevMode,
|
||
}
|
||
_, err = deps.Jobs.CompleteRun(ctx, jobdom.CompleteRunRequest{
|
||
JobID: step.JobID,
|
||
WorkerID: step.WorkerID,
|
||
Result: map[string]any{
|
||
"graph_id": saved.ID,
|
||
"seed": saved.Seed,
|
||
"pain_tag_count": saved.PainTagCount,
|
||
"node_count": len(saved.Nodes),
|
||
"search_source_mode": string(memberCtx.SearchSourceMode),
|
||
"handoff": handoff,
|
||
},
|
||
})
|
||
return err
|
||
}
|
||
|
||
func runBraveKnowledgeExpand(
|
||
ctx context.Context,
|
||
client *libbrave.Client,
|
||
member placement.MemberContext,
|
||
queries []string,
|
||
onProgress func(i, total int),
|
||
heartbeat func() error,
|
||
) ([]libkg.BraveSource, error) {
|
||
if client == nil || !client.Enabled() {
|
||
return nil, app.For(code.Setting).InputMissingRequired("請在設定頁設定 Brave Search API key(跟隨此登入帳號)")
|
||
}
|
||
out := make([]libkg.BraveSource, 0, len(queries)*2)
|
||
for i, query := range queries {
|
||
if heartbeat != nil {
|
||
if err := heartbeat(); err != nil {
|
||
return nil, err
|
||
}
|
||
}
|
||
res, _ := client.Search(ctx, libbrave.SearchOptions{
|
||
Query: query,
|
||
Limit: 3,
|
||
Mode: libbrave.ModeKnowledgeExpand,
|
||
Country: member.BraveCountry,
|
||
SearchLang: member.BraveSearchLang,
|
||
})
|
||
for _, item := range res.Results {
|
||
out = append(out, libkg.BraveSource{
|
||
Query: query,
|
||
Snippet: item.Snippet,
|
||
URL: item.URL,
|
||
Title: item.Title,
|
||
})
|
||
}
|
||
if onProgress != nil {
|
||
onProgress(i, len(queries))
|
||
}
|
||
}
|
||
if len(out) == 0 {
|
||
return nil, app.For(code.Setting).SvcThirdParty("Brave 查詢無結果,請確認 API key 或稍後重試")
|
||
}
|
||
return out, nil
|
||
}
|
||
|
||
func mergeGraphs(existing *kgusecase.GraphSummary, incoming libkg.Graph, extraSources []libkg.BraveSource) libkg.Graph {
|
||
if existing == nil {
|
||
return incoming
|
||
}
|
||
merged := libkg.Graph{
|
||
Seed: existing.Seed,
|
||
Nodes: append([]libkg.Node{}, existing.Nodes...),
|
||
Edges: append([]libkg.Edge{}, existing.Edges...),
|
||
BraveSources: append([]libkg.BraveSource{}, existing.BraveSources...),
|
||
}
|
||
seenLabel := map[string]struct{}{}
|
||
for _, node := range merged.Nodes {
|
||
seenLabel[strings.ToLower(strings.TrimSpace(node.Label))] = struct{}{}
|
||
}
|
||
for _, node := range incoming.Nodes {
|
||
key := strings.ToLower(strings.TrimSpace(node.Label))
|
||
if _, ok := seenLabel[key]; ok {
|
||
continue
|
||
}
|
||
seenLabel[key] = struct{}{}
|
||
merged.Nodes = append(merged.Nodes, node)
|
||
}
|
||
edgeSeen := map[string]struct{}{}
|
||
for _, edge := range merged.Edges {
|
||
edgeSeen[edge.From+"->"+edge.To] = struct{}{}
|
||
}
|
||
for _, edge := range incoming.Edges {
|
||
key := edge.From + "->" + edge.To
|
||
if _, ok := edgeSeen[key]; ok {
|
||
continue
|
||
}
|
||
edgeSeen[key] = struct{}{}
|
||
merged.Edges = append(merged.Edges, edge)
|
||
}
|
||
merged.BraveSources = append(merged.BraveSources, extraSources...)
|
||
libkg.DeriveSearchTagsFromGraph(&merged)
|
||
return merged
|
||
}
|
||
|
||
// stringField returns payload[key] rendered as a whitespace-trimmed string.
//
// A nil payload, a missing key, or a nil value yields "". String values are
// trimmed. A float64 value (which is how encoding/json decodes numbers into a
// map[string]any) that is an exact integer within ±2^53 is printed without an
// exponent, so a numeric ID such as 1234567 becomes "1234567" rather than
// fmt.Sprint's "1.234567e+06". Every other value falls back to fmt.Sprint.
func stringField(payload map[string]any, key string) string {
	// Largest magnitude below which every integer is exactly representable in float64.
	const maxExactFloatInt = 1 << 53

	// Indexing a nil map is safe and reports ok == false.
	raw, ok := payload[key]
	if !ok || raw == nil {
		return ""
	}
	switch v := raw.(type) {
	case string:
		return strings.TrimSpace(v)
	case float64:
		// The range check runs first so the int64 conversion is always well defined.
		if v >= -maxExactFloatInt && v <= maxExactFloatInt && v == float64(int64(v)) {
			return fmt.Sprintf("%.0f", v)
		}
	}
	return strings.TrimSpace(fmt.Sprint(raw))
}
|
||
|
||
// boolField reports whether payload[key] is true.
//
// It returns true for a bool true, or for a string equal to "true" compared
// case-insensitively with surrounding whitespace ignored. Everything else is
// false: a nil payload, a missing or nil value, numbers, and any other string.
func boolField(payload map[string]any, key string) bool {
	if payload == nil {
		return false
	}
	// A missing key yields a nil interface, which matches no case below.
	switch flag := payload[key].(type) {
	case bool:
		return flag
	case string:
		return strings.EqualFold(strings.TrimSpace(flag), "true")
	}
	return false
}
|
||
|
||
// max returns the larger of a and b.
//
// If the module targets Go 1.21 or later, this package-level function shadows
// the built-in max within the package.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
|
||
|
||
func ensureResearchMap(
|
||
ctx context.Context,
|
||
step StepContext,
|
||
deps ExpandGraphDeps,
|
||
brand *branddomain.BrandSummary,
|
||
productBrief string,
|
||
providerID enum.ProviderID,
|
||
credential *threadsaccountdomain.WorkerAiCredential,
|
||
updateProgress func(string, int),
|
||
) error {
|
||
tenantID := stringField(step.Run.Payload, "tenant_id")
|
||
ownerUID := stringField(step.Run.Payload, "owner_uid")
|
||
if tenantID == "" || ownerUID == "" || brand == nil {
|
||
return fmt.Errorf("research map: missing actor or brand")
|
||
}
|
||
|
||
result, err := deps.AI.GenerateText(ctx, domai.GenerateRequest{
|
||
Provider: providerID,
|
||
Model: credential.Model,
|
||
Credential: domai.Credential{
|
||
APIKey: credential.APIKey,
|
||
},
|
||
System: placement.BuildResearchMapSystemPrompt(),
|
||
Messages: []domai.Message{
|
||
{
|
||
Role: "user",
|
||
Content: placement.BuildResearchMapUserPrompt(placement.ResearchMapInput{
|
||
Label: brand.DisplayName,
|
||
SeedQuery: brand.SeedQuery,
|
||
Brief: brand.Brief,
|
||
ProductContext: brand.ProductContext,
|
||
}),
|
||
},
|
||
},
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
parsed, err := placement.ParseResearchMapOutput(result.Text)
|
||
if err != nil {
|
||
return app.For(code.AI).SvcThirdParty("研究地圖 LLM 回傳無法解析:" + err.Error())
|
||
}
|
||
|
||
entityMap := brandentity.ResearchMap{
|
||
AudienceSummary: parsed.AudienceSummary,
|
||
ContentGoal: parsed.ContentGoal,
|
||
Questions: parsed.Questions,
|
||
Pillars: parsed.Pillars,
|
||
Exclusions: parsed.Exclusions,
|
||
}
|
||
targetAudience := strings.TrimSpace(brand.TargetAudience)
|
||
if targetAudience == "" {
|
||
targetAudience = parsed.AudienceSummary
|
||
}
|
||
patch := branddomain.BrandPatch{
|
||
ResearchMap: &entityMap,
|
||
}
|
||
if targetAudience != "" {
|
||
patch.TargetAudience = &targetAudience
|
||
}
|
||
if strings.TrimSpace(brand.ProductBrief) == "" && productBrief != "" {
|
||
patch.ProductBrief = &productBrief
|
||
}
|
||
|
||
_, err = deps.Brand.Update(ctx, branddomain.UpdateRequest{
|
||
TenantID: tenantID,
|
||
OwnerUID: ownerUID,
|
||
BrandID: brand.ID,
|
||
Patch: patch,
|
||
})
|
||
if err != nil {
|
||
return err
|
||
}
|
||
updateProgress("研究地圖已就緒", 8)
|
||
return nil
|
||
}
|