226 lines
5.2 KiB
Go
226 lines
5.2 KiB
Go
|
|
package knowledge
|
||
|
|
|
||
|
|
import (
|
||
|
|
"encoding/json"
|
||
|
|
"fmt"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
|
||
|
|
libprompt "haixun-backend/internal/library/prompt"
|
||
|
|
)
|
||
|
|
|
||
|
|
// queryConfig is the typed form of the knowledge-graph query configuration
// returned by libprompt.KnowledgeGraphQueryConfig; loadQueryConfig fills it
// through the json tags below. Template strings use {{name}} placeholders that
// renderQueryTemplate substitutes.
type queryConfig struct {
	// Limits. The exported accessors in this file fall back to built-in
	// defaults when these are non-positive.
	MaxPlanQueries        int `json:"max_plan_queries"`         // per-round cap for PlanQueries
	MaxSupplemental       int `json:"max_supplemental_queries"` // cap for supplemental rounds
	MinPainTagCandidates  int `json:"min_pain_tag_candidates"`  // reported by MinPainTagCandidates
	MinTotalTagCandidates int `json:"min_total_tag_candidates"` // not read in this file; presumably used elsewhere in the package

	// Planning-round templates.
	PlanBase       []string `json:"plan_base"`       // rendered with {{seed}} and {{audience}}
	PlanPeripheral []string `json:"plan_peripheral"` // rendered after PlanBase with the same variables
	PlanAudience   string   `json:"plan_audience"`   // rendered only when a target audience is given
	PlanL1Cause    string   `json:"plan_l1_cause"`   // per L1 label, with {{seed}} and {{label}}
	PlanL1Pain     string   `json:"plan_l1_pain"`    // per L1 label, with {{seed}} and {{label}}

	// Supplemental-round templates.
	Supplemental   []string `json:"supplemental"`    // rendered with {{seed}}
	SupplementalL1 string   `json:"supplemental_l1"` // per L1 label, with {{seed}} and {{label}}

	// Recency queries built by BuildRecencyQuery.
	RecencySuffix      string `json:"recency_suffix"`       // appended to the label; "請問" when blank
	RecencyHelpMarkers string `json:"recency_help_markers"` // set of characters; a label containing any is used unchanged
}
|
||
|
|
|
||
|
|
var (
|
||
|
|
queryCfgOnce sync.Once
|
||
|
|
queryCfg queryConfig
|
||
|
|
queryCfgErr error
|
||
|
|
)
|
||
|
|
|
||
|
|
func loadQueryConfig() (queryConfig, error) {
|
||
|
|
queryCfgOnce.Do(func() {
|
||
|
|
raw, err := libprompt.KnowledgeGraphQueryConfig()
|
||
|
|
if err != nil {
|
||
|
|
queryCfgErr = err
|
||
|
|
return
|
||
|
|
}
|
||
|
|
payload, err := json.Marshal(raw)
|
||
|
|
if err != nil {
|
||
|
|
queryCfgErr = err
|
||
|
|
return
|
||
|
|
}
|
||
|
|
queryCfgErr = json.Unmarshal(payload, &queryCfg)
|
||
|
|
})
|
||
|
|
return queryCfg, queryCfgErr
|
||
|
|
}
|
||
|
|
|
||
|
|
func MaxPlanQueriesPerRound() int {
|
||
|
|
cfg, err := loadQueryConfig()
|
||
|
|
if err != nil || cfg.MaxPlanQueries <= 0 {
|
||
|
|
return 15
|
||
|
|
}
|
||
|
|
return cfg.MaxPlanQueries
|
||
|
|
}
|
||
|
|
|
||
|
|
func MaxSupplementalQueries() int {
|
||
|
|
cfg, err := loadQueryConfig()
|
||
|
|
if err != nil || cfg.MaxSupplemental <= 0 {
|
||
|
|
return 5
|
||
|
|
}
|
||
|
|
return cfg.MaxSupplemental
|
||
|
|
}
|
||
|
|
|
||
|
|
func MinPainTagCandidates() int {
|
||
|
|
cfg, err := loadQueryConfig()
|
||
|
|
if err != nil || cfg.MinPainTagCandidates <= 0 {
|
||
|
|
return 8
|
||
|
|
}
|
||
|
|
return cfg.MinPainTagCandidates
|
||
|
|
}
|
||
|
|
|
||
|
|
// PlanInput describes one request to PlanQueries.
//
// Field meanings:
//   - Seed is the topic every query is built around. A blank Seed yields no queries.
//   - TargetAudience fills the {{audience}} placeholder. When non-blank it also
//     enables the audience template.
//   - ProductBrief is not read by PlanQueries in this file.
//   - L1Labels are the first-layer labels (see L1LabelsFromNodes). Each one gets
//     per-label templates.
//   - Supplemental switches PlanQueries to the supplemental query set.
type PlanInput struct {
	Seed, TargetAudience, ProductBrief string
	L1Labels                           []string
	Supplemental                       bool
}
|
||
|
|
|
||
|
|
func PlanQueries(in PlanInput) []string {
|
||
|
|
cfg, err := loadQueryConfig()
|
||
|
|
if err != nil {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
seed := strings.TrimSpace(in.Seed)
|
||
|
|
if seed == "" {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if in.Supplemental {
|
||
|
|
return supplementalQueries(cfg, seed, in.L1Labels)
|
||
|
|
}
|
||
|
|
|
||
|
|
seen := map[string]struct{}{}
|
||
|
|
out := make([]string, 0, cfg.MaxPlanQueries)
|
||
|
|
add := func(q string) {
|
||
|
|
q = strings.TrimSpace(q)
|
||
|
|
if q == "" {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if _, ok := seen[q]; ok {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
seen[q] = struct{}{}
|
||
|
|
out = append(out, q)
|
||
|
|
}
|
||
|
|
|
||
|
|
vars := map[string]string{"seed": seed, "audience": strings.TrimSpace(in.TargetAudience)}
|
||
|
|
for _, tpl := range cfg.PlanBase {
|
||
|
|
add(renderQueryTemplate(tpl, vars))
|
||
|
|
if len(out) >= cfg.MaxPlanQueries {
|
||
|
|
return capQueries(out, cfg.MaxPlanQueries)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
for _, tpl := range cfg.PlanPeripheral {
|
||
|
|
add(renderQueryTemplate(tpl, vars))
|
||
|
|
if len(out) >= cfg.MaxPlanQueries {
|
||
|
|
return capQueries(out, cfg.MaxPlanQueries)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
if vars["audience"] != "" && strings.TrimSpace(cfg.PlanAudience) != "" {
|
||
|
|
add(renderQueryTemplate(cfg.PlanAudience, vars))
|
||
|
|
}
|
||
|
|
for _, label := range in.L1Labels {
|
||
|
|
label = strings.TrimSpace(label)
|
||
|
|
if label == "" || label == seed {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
l1vars := map[string]string{"seed": seed, "label": label}
|
||
|
|
add(renderQueryTemplate(cfg.PlanL1Cause, l1vars))
|
||
|
|
if len(out) >= cfg.MaxPlanQueries {
|
||
|
|
return capQueries(out, cfg.MaxPlanQueries)
|
||
|
|
}
|
||
|
|
add(renderQueryTemplate(cfg.PlanL1Pain, l1vars))
|
||
|
|
if len(out) >= cfg.MaxPlanQueries {
|
||
|
|
return capQueries(out, cfg.MaxPlanQueries)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return capQueries(out, cfg.MaxPlanQueries)
|
||
|
|
}
|
||
|
|
|
||
|
|
func supplementalQueries(cfg queryConfig, seed string, l1Labels []string) []string {
|
||
|
|
seed = strings.TrimSpace(seed)
|
||
|
|
if seed == "" {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
seen := map[string]struct{}{}
|
||
|
|
out := make([]string, 0, cfg.MaxSupplemental)
|
||
|
|
add := func(q string) {
|
||
|
|
q = strings.TrimSpace(q)
|
||
|
|
if q == "" {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if _, ok := seen[q]; ok {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
seen[q] = struct{}{}
|
||
|
|
out = append(out, q)
|
||
|
|
}
|
||
|
|
vars := map[string]string{"seed": seed}
|
||
|
|
for _, tpl := range cfg.Supplemental {
|
||
|
|
add(renderQueryTemplate(tpl, vars))
|
||
|
|
}
|
||
|
|
for _, label := range l1Labels {
|
||
|
|
label = strings.TrimSpace(label)
|
||
|
|
if label == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
add(renderQueryTemplate(cfg.SupplementalL1, map[string]string{"seed": seed, "label": label}))
|
||
|
|
if len(out) >= cfg.MaxSupplemental {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return capQueries(out, cfg.MaxSupplemental)
|
||
|
|
}
|
||
|
|
|
||
|
|
func BuildRecencyQuery(label string) string {
|
||
|
|
cfg, err := loadQueryConfig()
|
||
|
|
if err != nil {
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
label = strings.TrimSpace(label)
|
||
|
|
if label == "" {
|
||
|
|
return ""
|
||
|
|
}
|
||
|
|
if strings.ContainsAny(label, cfg.RecencyHelpMarkers) {
|
||
|
|
return label
|
||
|
|
}
|
||
|
|
suffix := strings.TrimSpace(cfg.RecencySuffix)
|
||
|
|
if suffix == "" {
|
||
|
|
suffix = "請問"
|
||
|
|
}
|
||
|
|
return fmt.Sprintf("%s %s", label, suffix)
|
||
|
|
}
|
||
|
|
|
||
|
|
// renderQueryTemplate replaces every {{key}} placeholder in tpl with vars[key]
// and trims surrounding whitespace from the result. Placeholders with no
// matching key are left untouched.
//
// Substitution is a single left-to-right pass (strings.Replacer), so text
// inside a substituted value is never expanded again. Sequential ReplaceAll
// calls in map-iteration order would re-expand a value that itself contains a
// placeholder (e.g. a seed typed as "{{label}}"), producing output that varies
// from run to run.
func renderQueryTemplate(tpl string, vars map[string]string) string {
	if len(vars) == 0 {
		return strings.TrimSpace(tpl)
	}
	pairs := make([]string, 0, 2*len(vars))
	for key, value := range vars {
		pairs = append(pairs, "{{"+key+"}}", value)
	}
	// Argument order only matters when two placeholders can match at the same
	// position (one is a prefix of the other). The "{{" / "}}" delimiters rule
	// that out unless a key itself contains "}}", so map order is irrelevant.
	return strings.TrimSpace(strings.NewReplacer(pairs...).Replace(tpl))
}
|
||
|
|
|
||
|
|
// capQueries truncates items to at most limit entries. A non-positive limit
// means "no cap" and returns items unchanged. The truncated result shares
// items' backing array.
func capQueries(items []string, limit int) []string {
	if limit > 0 && len(items) > limit {
		return items[:limit]
	}
	return items
}
|
||
|
|
|
||
|
|
func L1LabelsFromNodes(nodes []Node) []string {
|
||
|
|
out := make([]string, 0, len(nodes))
|
||
|
|
for _, node := range nodes {
|
||
|
|
if node.Layer != 1 {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
label := strings.TrimSpace(node.Label)
|
||
|
|
if label != "" {
|
||
|
|
out = append(out, label)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return out
|
||
|
|
}
|