thread-master/backend/internal/library/knowledge/brave_collect.go

294 lines
6.3 KiB
Go
Raw Normal View History

2026-06-26 08:37:04 +00:00
package knowledge
import (
"context"
"strings"
"sync"
"sync/atomic"
"haixun-backend/internal/library/websearch"
)
// BraveSearchLocale holds the locale hints that are copied into every search
// request issued by this file (see searchWebQuery): Country, SearchLang and
// UserLocation are forwarded verbatim to websearch.SearchOptions.
// NOTE(review): the expected value formats (e.g. ISO country / language codes)
// are defined by the websearch client, not here — confirm there.
type BraveSearchLocale struct {
Country string
SearchLang string
UserLocation string
}
// BraveCollectConfig controls how CollectWebSources gathers search results.
type BraveCollectConfig struct {
// ResultsPerQuery is the result limit passed to the search client for each query.
ResultsPerQuery int
// MinSourcesBeforeStop: once at least this many collected sources with distinct
// URLs are held, no further queries are started.
MinSourcesBeforeStop int
// MaxSourcesCap: no further queries are started once this many sources are held.
// A query's whole batch of results is appended before the check, so the final
// count can exceed this value.
MaxSourcesCap int
// Concurrency is the maximum number of parallel search workers; values <= 1
// make CollectWebSources run the queries sequentially.
Concurrency int
}
// BraveCollectConfigFromQueryCfg maps the collection-related fields of cfg onto
// a BraveCollectConfig. Any field that is zero or negative is replaced by its
// built-in default: ResultsPerQuery 8, MinSourcesBeforeStop 18,
// MaxSourcesCap 32, Concurrency 4 (taken from cfg.BraveCollectConcurrency).
func BraveCollectConfigFromQueryCfg(cfg queryConfig) BraveCollectConfig {
	positiveOr := func(v, fallback int) int {
		if v > 0 {
			return v
		}
		return fallback
	}
	return BraveCollectConfig{
		ResultsPerQuery:      positiveOr(cfg.ResultsPerQuery, 8),
		MinSourcesBeforeStop: positiveOr(cfg.MinSourcesBeforeStop, 18),
		MaxSourcesCap:        positiveOr(cfg.MaxSourcesCap, 32),
		Concurrency:          positiveOr(cfg.BraveCollectConcurrency, 4),
	}
}
// DefaultBraveCollectConfig returns the collection settings derived from the
// loaded query configuration. If loadQueryConfig fails, it falls back to the
// built-in defaults: 8 results per query, stop after 18 sources, cap at 32
// sources, 4 workers.
func DefaultBraveCollectConfig() BraveCollectConfig {
	if cfg, err := loadQueryConfig(); err == nil {
		return BraveCollectConfigFromQueryCfg(cfg)
	}
	return BraveCollectConfig{
		ResultsPerQuery:      8,
		MinSourcesBeforeStop: 18,
		MaxSourcesCap:        32,
		Concurrency:          4,
	}
}
// CollectBraveSources forwards every argument unchanged to CollectWebSources
// and returns its result; see CollectWebSources for the full contract.
func CollectBraveSources(
	ctx context.Context,
	client websearch.Client,
	locale BraveSearchLocale,
	queries []string,
	cfg BraveCollectConfig,
	onProgress func(i, total int),
	heartbeat func() error,
) []BraveSource {
	sources := CollectWebSources(ctx, client, locale, queries, cfg, onProgress, heartbeat)
	return sources
}
// CollectWebSources runs each query through client and gathers the returned
// sources, deduplicated by trimmed URL, until the stop condition derived from
// cfg is met (see shouldStopCollect).
//
// It returns nil when client is nil or disabled, or when queries is empty.
// With cfg.Concurrency <= 1 the queries run one at a time; otherwise they are
// spread across up to cfg.Concurrency goroutines. onProgress and heartbeat are
// optional and may be nil.
func CollectWebSources(
	ctx context.Context,
	client websearch.Client,
	locale BraveSearchLocale,
	queries []string,
	cfg BraveCollectConfig,
	onProgress func(i, total int),
	heartbeat func() error,
) []BraveSource {
	switch {
	case client == nil, !client.Enabled(), len(queries) == 0:
		// Nothing to do; the case list short-circuits, so Enabled is never
		// called on a nil client.
		return nil
	case cfg.Concurrency > 1:
		return collectWebSourcesParallel(ctx, client, locale, queries, cfg, onProgress, heartbeat)
	default:
		return collectWebSourcesSequential(ctx, client, locale, queries, cfg, onProgress, heartbeat)
	}
}
// collectWebSourcesSequential is the single-goroutine variant of
// CollectWebSources. It issues one search per query, in order, and merges the
// results (deduplicated by trimmed URL) into the returned slice.
//
// Collection ends early and returns whatever has been gathered so far when
// any of the following holds:
//   - shouldStopCollect reports the target or the cap has been reached,
//   - ctx is done (there is no point issuing searches the caller cancelled), or
//   - heartbeat returns an error.
//
// onProgress, if non-nil, is called after each completed query with that
// query's index and the total number of queries.
func collectWebSourcesSequential(
	ctx context.Context,
	client websearch.Client,
	locale BraveSearchLocale,
	queries []string,
	cfg BraveCollectConfig,
	onProgress func(i, total int),
	heartbeat func() error,
) []BraveSource {
	// Pre-size for the smaller of the hard cap and the most results the
	// queries could return. Clamp at zero: make panics on a negative capacity,
	// which a misconfigured negative cap or per-query limit would produce.
	capHint := cfg.MaxSourcesCap
	if est := len(queries) * cfg.ResultsPerQuery; est < capHint {
		capHint = est
	}
	if capHint < 0 {
		capHint = 0
	}
	out := make([]BraveSource, 0, capHint)
	seenURL := map[string]struct{}{}
	for i, query := range queries {
		if shouldStopCollect(out, cfg) || ctx.Err() != nil {
			break
		}
		if heartbeat != nil {
			if err := heartbeat(); err != nil {
				return out
			}
		}
		appendBraveResults(&out, seenURL, query, searchWebQuery(ctx, client, locale, query, cfg.ResultsPerQuery))
		if onProgress != nil {
			onProgress(i, len(queries))
		}
	}
	return out
}
// braveCollectState is the accumulator shared by the workers of
// collectWebSourcesParallel. mu guards out, seenURL and stop (both methods
// below lock it before touching them); completed is updated atomically.
type braveCollectState struct {
// cfg is the collection config consulted by appendResults for the stop check.
cfg BraveCollectConfig
mu sync.Mutex
// out holds the collected sources; seenURL records their trimmed URLs for dedup.
out []BraveSource
seenURL map[string]struct{}
// stop latches to true once shouldStopCollect is satisfied; afterwards
// appendResults discards incoming results.
stop bool
// completed counts finished queries (incremented with atomic.AddInt32).
completed int32
}
// shouldStop reports, under s.mu, whether workers should stop taking new
// queries. The answer is latched: once shouldStopCollect has been satisfied
// for the supplied cfg (or appendResults has set the flag), it stays true.
func (s *braveCollectState) shouldStop(cfg BraveCollectConfig) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.stop && shouldStopCollect(s.out, cfg) {
		s.stop = true
	}
	return s.stop
}
// appendResults merges items (tagged with query where they lack one) into the
// shared result set under s.mu, then latches s.stop if the stop condition from
// s.cfg is now met. Once stopped, further results are ignored.
func (s *braveCollectState) appendResults(query string, items []BraveSource) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if !s.stop {
		appendBraveResults(&s.out, s.seenURL, query, items)
		s.stop = shouldStopCollect(s.out, s.cfg)
	}
}
// collectWebSourcesParallel is the concurrent variant of CollectWebSources.
// Query indices are queued on a buffered channel and consumed by up to
// cfg.Concurrency worker goroutines. Results are merged, deduplicated by
// trimmed URL, into a braveCollectState under its mutex.
//
// A worker stops taking new queries when any of the following holds:
//   - the shared stop condition (shouldStopCollect) is met,
//   - ctx is done, or
//   - any worker's heartbeat has failed. This mirrors the sequential path,
//     where a heartbeat error ends the whole collection.
//
// heartbeat and onProgress are caller-supplied and are never invoked
// concurrently by the sequential path, so they are serialized here as well.
// onProgress receives a 0-based, monotonically increasing completion index
// and the total number of queries.
func collectWebSourcesParallel(
	ctx context.Context,
	client websearch.Client,
	locale BraveSearchLocale,
	queries []string,
	cfg BraveCollectConfig,
	onProgress func(i, total int),
	heartbeat func() error,
) []BraveSource {
	capHint := cfg.MaxSourcesCap
	if capHint < 0 {
		// make panics on a negative capacity; tolerate a misconfigured cap.
		capHint = 0
	}
	state := &braveCollectState{
		cfg:     cfg,
		out:     make([]BraveSource, 0, capHint),
		seenURL: map[string]struct{}{},
	}
	workers := cfg.Concurrency
	if workers > len(queries) {
		workers = len(queries)
	}
	if workers <= 0 {
		workers = 1
	}
	jobs := make(chan int, len(queries))
	for i := range queries {
		jobs <- i
	}
	close(jobs)

	// aborted is set when a heartbeat fails so that every worker stops, not
	// just the one that saw the error. It is kept separate from state.stop so
	// that searches already in flight can still merge their results.
	var aborted int32
	// callbackMu serializes heartbeat and onProgress, which may not be
	// goroutine-safe.
	var callbackMu sync.Mutex

	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for i := range jobs {
				if state.shouldStop(cfg) || atomic.LoadInt32(&aborted) != 0 || ctx.Err() != nil {
					return
				}
				if heartbeat != nil {
					callbackMu.Lock()
					err := heartbeat()
					callbackMu.Unlock()
					if err != nil {
						atomic.StoreInt32(&aborted, 1)
						return
					}
				}
				query := queries[i]
				items := searchWebQuery(ctx, client, locale, query, cfg.ResultsPerQuery)
				state.appendResults(query, items)

				// Count and report under callbackMu so onProgress observes
				// completion indices in increasing order.
				callbackMu.Lock()
				done := int(atomic.AddInt32(&state.completed, 1))
				if onProgress != nil {
					onProgress(done-1, len(queries))
				}
				callbackMu.Unlock()
			}
		}()
	}
	wg.Wait()
	return state.out
}
// shouldStopCollect reports whether collection can end. That is the case when
// out already holds cfg.MaxSourcesCap entries or more. It is also the case
// when out holds at least cfg.MinSourcesBeforeStop entries and at least that
// many of them have distinct, non-blank URLs.
func shouldStopCollect(out []BraveSource, cfg BraveCollectConfig) bool {
	n := len(out)
	switch {
	case n >= cfg.MaxSourcesCap:
		return true
	case n < cfg.MinSourcesBeforeStop:
		return false
	default:
		return uniqueSourceCount(out) >= cfg.MinSourcesBeforeStop
	}
}
// searchWebQuery runs a single knowledge-expansion search for query, using
// limit and the locale hints, and converts the hits into BraveSources tagged
// with that query. Hits whose URL is blank after trimming are dropped; kept
// URLs are stored trimmed.
//
// The search is best-effort: if the client returns an error, this query
// contributes no sources (nil) and collection continues with the others. The
// response is not read on error, since a (value, error) result makes no
// promise about the value when err != nil.
func searchWebQuery(
	ctx context.Context,
	client websearch.Client,
	locale BraveSearchLocale,
	query string,
	limit int,
) []BraveSource {
	res, err := client.Search(ctx, websearch.SearchOptions{
		Query:        query,
		Limit:        limit,
		Mode:         websearch.ModeKnowledgeExpand,
		Country:      locale.Country,
		SearchLang:   locale.SearchLang,
		UserLocation: locale.UserLocation,
	})
	if err != nil {
		return nil
	}
	items := make([]BraveSource, 0, len(res.Results))
	for _, hit := range res.Results {
		link := strings.TrimSpace(hit.URL)
		if link == "" {
			continue
		}
		items = append(items, BraveSource{
			Query:   query,
			Snippet: hit.Snippet,
			URL:     link,
			Title:   hit.Title,
		})
	}
	return items
}
// appendBraveResults appends items to *out, in order, with the following rules:
//   - Items whose trimmed URL is blank, or already present in seenURL, are skipped.
//   - Each accepted URL (trimmed) is recorded in seenURL.
//   - Accepted items with an empty Query are tagged with query.
//   - The appended item keeps its original, untrimmed URL.
func appendBraveResults(out *[]BraveSource, seenURL map[string]struct{}, query string, items []BraveSource) {
	for _, src := range items {
		key := strings.TrimSpace(src.URL)
		_, dup := seenURL[key]
		if key == "" || dup {
			continue
		}
		seenURL[key] = struct{}{}
		if src.Query == "" {
			src.Query = query
		}
		*out = append(*out, src)
	}
}
// MergeBraveSources concatenates chunks in order, keeping only the first
// source seen for each trimmed URL and dropping sources with a blank URL.
// Kept sources are returned unmodified. The result is never nil: with no
// usable input it is an empty, non-nil slice.
func MergeBraveSources(chunks ...[]BraveSource) []BraveSource {
	merged := []BraveSource{}
	seen := make(map[string]struct{})
	for _, chunk := range chunks {
		for _, src := range chunk {
			key := strings.TrimSpace(src.URL)
			if key == "" {
				continue
			}
			if _, dup := seen[key]; !dup {
				seen[key] = struct{}{}
				merged = append(merged, src)
			}
		}
	}
	return merged
}
// uniqueSourceCount returns how many distinct non-blank URLs (compared after
// trimming surrounding whitespace) appear in items.
func uniqueSourceCount(items []BraveSource) int {
	distinct := make(map[string]struct{}, len(items))
	for i := range items {
		if key := strings.TrimSpace(items[i].URL); key != "" {
			distinct[key] = struct{}{}
		}
	}
	return len(distinct)
}