haixunMaster/haixun-backend/internal/library/knowledge/queries.go

226 lines
5.2 KiB
Go

package knowledge
import (
"encoding/json"
"fmt"
"strings"
"sync"
libprompt "haixun-backend/internal/library/prompt"
)
// queryConfig is the decoded knowledge-graph query configuration. The json
// tags name the keys expected in the payload obtained from
// libprompt.KnowledgeGraphQueryConfig. Template fields are plain strings whose
// {{key}} placeholders are substituted by renderQueryTemplate.
type queryConfig struct {
	// MaxPlanQueries caps how many queries PlanQueries returns per round.
	// MaxPlanQueriesPerRound reports 15 instead when this is not positive.
	MaxPlanQueries int `json:"max_plan_queries"`

	// MaxSupplemental caps how many queries supplementalQueries returns.
	// MaxSupplementalQueries reports 5 instead when this is not positive.
	MaxSupplemental int `json:"max_supplemental_queries"`

	// MinPainTagCandidates is exposed through MinPainTagCandidates, which
	// reports 8 when this is not positive. Its exact use is outside this code;
	// presumably a lower bound on pain-tag candidates — confirm with callers.
	MinPainTagCandidates int `json:"min_pain_tag_candidates"`

	// MinTotalTagCandidates is not read by the code shown alongside this type;
	// presumably a lower bound on the total tag candidates — confirm with users.
	MinTotalTagCandidates int `json:"min_total_tag_candidates"`

	// PlanBase holds the first templates PlanQueries renders; they may use
	// {{seed}} and {{audience}}.
	PlanBase []string `json:"plan_base"`

	// PlanPeripheral holds templates rendered after PlanBase with the same
	// {{seed}}/{{audience}} variables.
	PlanPeripheral []string `json:"plan_peripheral"`

	// PlanAudience is rendered only when both it and the target audience are
	// non-blank.
	PlanAudience string `json:"plan_audience"`

	// PlanL1Cause and PlanL1Pain are rendered once per L1 label, with
	// {{seed}} and {{label}} available.
	PlanL1Cause string `json:"plan_l1_cause"`
	PlanL1Pain  string `json:"plan_l1_pain"`

	// Supplemental holds templates for supplemental rounds; only {{seed}} is
	// substituted.
	Supplemental []string `json:"supplemental"`

	// SupplementalL1 is rendered once per L1 label in supplemental rounds,
	// with {{seed}} and {{label}} available.
	SupplementalL1 string `json:"supplemental_l1"`

	// RecencySuffix is appended (after a space) by BuildRecencyQuery; a blank
	// value falls back to "請問" there.
	RecencySuffix string `json:"recency_suffix"`

	// RecencyHelpMarkers is treated as a set of characters: BuildRecencyQuery
	// returns a label unchanged if it contains any of them.
	RecencyHelpMarkers string `json:"recency_help_markers"`
}
// Lazily loaded query configuration. queryCfgOnce guards a single load
// attempt whose outcome (config or error) is cached for the life of the
// process; a failed load is not retried.
var (
	queryCfgOnce sync.Once
	queryCfg     queryConfig
	queryCfgErr  error
)

// loadQueryConfig returns the knowledge-graph query configuration. It fetches
// the raw value from libprompt.KnowledgeGraphQueryConfig and converts it into
// queryConfig by round-tripping through JSON, so the json tags on queryConfig
// define which keys are read. (The raw value's concrete type is not visible
// here — presumably a generic map; confirm in libprompt.)
//
// Only the first call does any work; later calls return the cached result,
// including a cached error. On error the returned config is always the zero
// value, never a partially decoded one.
func loadQueryConfig() (queryConfig, error) {
	queryCfgOnce.Do(func() {
		raw, err := libprompt.KnowledgeGraphQueryConfig()
		if err != nil {
			queryCfgErr = fmt.Errorf("loading knowledge graph query config: %w", err)
			return
		}
		payload, err := json.Marshal(raw)
		if err != nil {
			queryCfgErr = fmt.Errorf("encoding knowledge graph query config: %w", err)
			return
		}
		// Decode into a local so a half-decoded value (Unmarshal may fill some
		// fields before failing) never becomes the cached configuration.
		var cfg queryConfig
		if err := json.Unmarshal(payload, &cfg); err != nil {
			queryCfgErr = fmt.Errorf("decoding knowledge graph query config: %w", err)
			return
		}
		queryCfg = cfg
	})
	if queryCfgErr != nil {
		return queryConfig{}, queryCfgErr
	}
	return queryCfg, nil
}
// MaxPlanQueriesPerRound reports the per-round budget of plan queries: the
// configured max_plan_queries when it is positive, and 15 otherwise —
// including when the query config cannot be loaded.
func MaxPlanQueriesPerRound() int {
	if cfg, err := loadQueryConfig(); err == nil && cfg.MaxPlanQueries > 0 {
		return cfg.MaxPlanQueries
	}
	return 15
}
// MaxSupplementalQueries reports the budget for supplemental queries: the
// configured max_supplemental_queries when it is positive, and 5 otherwise —
// including when the query config cannot be loaded.
func MaxSupplementalQueries() int {
	if cfg, err := loadQueryConfig(); err == nil && cfg.MaxSupplemental > 0 {
		return cfg.MaxSupplemental
	}
	return 5
}
// MinPainTagCandidates reports the configured min_pain_tag_candidates when it
// is positive, and 8 otherwise — including when the query config cannot be
// loaded.
func MinPainTagCandidates() int {
	if cfg, err := loadQueryConfig(); err == nil && cfg.MinPainTagCandidates > 0 {
		return cfg.MinPainTagCandidates
	}
	return 8
}
// PlanInput is the argument to PlanQueries.
type PlanInput struct {
	// Seed is the central term of the plan; it is trimmed, and a blank seed
	// yields no queries.
	Seed string

	// TargetAudience is trimmed and substituted for {{audience}}; the
	// plan_audience template is rendered only when it is non-blank.
	TargetAudience string

	// ProductBrief is not read by PlanQueries in this file — confirm whether
	// other code consumes it.
	ProductBrief string

	// L1Labels are labels for which per-label templates are rendered with
	// {{label}}; blank entries are skipped.
	L1Labels []string

	// Supplemental switches PlanQueries to the supplemental template set.
	Supplemental bool
}
// PlanQueries builds the deduplicated list of search queries for one planning
// round seeded by in.Seed.
//
// When in.Supplemental is set it delegates to supplementalQueries. Otherwise it
// renders, in order:
//   - the plan_base templates,
//   - the plan_peripheral templates,
//   - the plan_audience template (only when both the trimmed audience and the
//     template are non-blank),
//   - a plan_l1_cause / plan_l1_pain pair for every L1 label that is non-blank
//     and differs from the seed.
//
// Base, peripheral and audience templates see {{seed}} and {{audience}}; the
// L1 pair sees {{seed}} and {{label}}. Blank and duplicate queries are
// dropped. Rendering stops once the configured max_plan_queries is reached, or
// the MaxPlanQueriesPerRound default when that setting is not positive.
//
// It returns nil when the seed is blank or the query config cannot be loaded.
func PlanQueries(in PlanInput) []string {
	cfg, err := loadQueryConfig()
	if err != nil {
		return nil
	}
	seed := strings.TrimSpace(in.Seed)
	if seed == "" {
		return nil
	}
	if in.Supplemental {
		return supplementalQueries(cfg, seed, in.L1Labels)
	}

	// Default a missing or non-positive max_plan_queries. Using the raw value
	// meant 0 stopped after the very first template (len(out) >= 0 is always
	// true) and a negative value made make() panic.
	limit := cfg.MaxPlanQueries
	if limit <= 0 {
		limit = MaxPlanQueriesPerRound()
	}

	seen := make(map[string]struct{}, limit)
	out := make([]string, 0, limit)
	// add appends q when it is non-blank and not seen before, and reports
	// whether the limit has been reached so rendering can stop early.
	add := func(q string) bool {
		q = strings.TrimSpace(q)
		if q != "" {
			if _, dup := seen[q]; !dup {
				seen[q] = struct{}{}
				out = append(out, q)
			}
		}
		return len(out) >= limit
	}

	vars := map[string]string{"seed": seed, "audience": strings.TrimSpace(in.TargetAudience)}
	for _, tpl := range cfg.PlanBase {
		if add(renderQueryTemplate(tpl, vars)) {
			return out
		}
	}
	for _, tpl := range cfg.PlanPeripheral {
		if add(renderQueryTemplate(tpl, vars)) {
			return out
		}
	}
	if vars["audience"] != "" && strings.TrimSpace(cfg.PlanAudience) != "" {
		if add(renderQueryTemplate(cfg.PlanAudience, vars)) {
			return out
		}
	}
	for _, label := range in.L1Labels {
		label = strings.TrimSpace(label)
		if label == "" || label == seed {
			continue
		}
		l1vars := map[string]string{"seed": seed, "label": label}
		if add(renderQueryTemplate(cfg.PlanL1Cause, l1vars)) {
			return out
		}
		if add(renderQueryTemplate(cfg.PlanL1Pain, l1vars)) {
			return out
		}
	}
	return out
}
// supplementalQueries builds the query list for a supplemental round:
//   - every cfg.Supplemental template rendered with {{seed}},
//   - then cfg.SupplementalL1 rendered once per non-blank L1 label, with
//     {{seed}} and {{label}}.
//
// Blank and duplicate queries are dropped. The result holds at most
// cfg.MaxSupplemental entries, or the MaxSupplementalQueries default when that
// value is not positive. It returns nil for a blank seed.
func supplementalQueries(cfg queryConfig, seed string, l1Labels []string) []string {
	seed = strings.TrimSpace(seed)
	if seed == "" {
		return nil
	}

	// Default a missing or non-positive max_supplemental_queries. Using the
	// raw value meant 0 stopped the L1 loop after one label while leaving the
	// template results uncapped, and a negative value made make() panic.
	limit := cfg.MaxSupplemental
	if limit <= 0 {
		limit = MaxSupplementalQueries()
	}

	seen := make(map[string]struct{}, limit)
	out := make([]string, 0, limit)
	// add appends q when it is non-blank and not seen before, and reports
	// whether the limit has been reached so rendering can stop early.
	add := func(q string) bool {
		q = strings.TrimSpace(q)
		if q != "" {
			if _, dup := seen[q]; !dup {
				seen[q] = struct{}{}
				out = append(out, q)
			}
		}
		return len(out) >= limit
	}

	vars := map[string]string{"seed": seed}
	for _, tpl := range cfg.Supplemental {
		if add(renderQueryTemplate(tpl, vars)) {
			return out
		}
	}
	for _, label := range l1Labels {
		label = strings.TrimSpace(label)
		if label == "" {
			continue
		}
		if add(renderQueryTemplate(cfg.SupplementalL1, map[string]string{"seed": seed, "label": label})) {
			break
		}
	}
	return out
}
// BuildRecencyQuery turns a label into a recency search query.
//
// Behavior on the trimmed label:
//   - If it already contains any character of the configured
//     recency_help_markers, it is returned unchanged.
//   - Otherwise recency_suffix is appended after a single space; when that
//     setting is blank, the suffix is "請問".
//
// It returns "" for a blank label or when the query config cannot be loaded.
func BuildRecencyQuery(label string) string {
	cfg, err := loadQueryConfig()
	if err != nil {
		return ""
	}
	trimmed := strings.TrimSpace(label)
	switch {
	case trimmed == "":
		return ""
	case strings.ContainsAny(trimmed, cfg.RecencyHelpMarkers):
		return trimmed
	}
	suffix := "請問"
	if configured := strings.TrimSpace(cfg.RecencySuffix); configured != "" {
		suffix = configured
	}
	return fmt.Sprintf("%s %s", trimmed, suffix)
}
// renderQueryTemplate replaces every {{key}} placeholder in tpl with vars[key]
// and trims surrounding whitespace from the result. Placeholders with no
// matching key are left as-is.
//
// All placeholders are substituted in one left-to-right pass, so text that
// came from a substituted value is never scanned again. Using one
// strings.ReplaceAll per map entry would make the output depend on map
// iteration order whenever a value (e.g. a user-supplied seed) itself contains
// a placeholder such as "{{label}}".
func renderQueryTemplate(tpl string, vars map[string]string) string {
	if len(vars) == 0 {
		return strings.TrimSpace(tpl)
	}
	oldnew := make([]string, 0, 2*len(vars))
	for key, value := range vars {
		oldnew = append(oldnew, "{{"+key+"}}", value)
	}
	// strings.Replacer breaks ties between patterns matching at the same
	// position by argument order (here: map order). That only matters if one
	// placeholder is a prefix of another, which cannot happen for the keys
	// used in this file (seed, audience, label).
	return strings.TrimSpace(strings.NewReplacer(oldnew...).Replace(tpl))
}
// capQueries returns at most limit leading elements of items. A non-positive
// limit means "no cap" and returns items unchanged. A truncated result has its
// capacity clipped to its length, so appending to it allocates a fresh array
// instead of overwriting items[limit:].
func capQueries(items []string, limit int) []string {
	if limit <= 0 || len(items) <= limit {
		return items
	}
	return items[:limit:limit]
}
// L1LabelsFromNodes collects the labels of nodes whose Layer is 1. Each label
// is trimmed; empty ones are skipped. Input order is preserved and duplicates
// are kept. The result is a non-nil (possibly empty) slice.
func L1LabelsFromNodes(nodes []Node) []string {
	labels := make([]string, 0, len(nodes))
	for i := range nodes {
		if nodes[i].Layer != 1 {
			continue
		}
		if l := strings.TrimSpace(nodes[i].Label); l != "" {
			labels = append(labels, l)
		}
	}
	return labels
}