294 lines
6.3 KiB
Go
294 lines
6.3 KiB
Go
|
|
package knowledge
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
"strings"
|
||
|
|
"sync"
|
||
|
|
"sync/atomic"
|
||
|
|
|
||
|
|
"haixun-backend/internal/library/websearch"
|
||
|
|
)
|
||
|
|
|
||
|
|
// BraveSearchLocale bundles the locale hints that are copied verbatim into
// the search options of every query issued by the collectors in this file.
type BraveSearchLocale struct {
	// Country is forwarded as the search options' Country value.
	Country string
	// SearchLang is forwarded as the search options' SearchLang value.
	SearchLang string
	// UserLocation is forwarded as the search options' UserLocation value.
	UserLocation string
}
|
||
|
|
|
||
|
|
// BraveCollectConfig tunes how many search results are gathered and whether
// the queries run one after another or on several workers.
type BraveCollectConfig struct {
	// ResultsPerQuery is the result limit requested for each single query.
	ResultsPerQuery int
	// MinSourcesBeforeStop is the number of distinct source URLs after which
	// collection stops early.
	MinSourcesBeforeStop int
	// MaxSourcesCap stops collection as soon as at least this many sources
	// have been gathered.
	MaxSourcesCap int
	// Concurrency is the number of parallel workers; a value of 1 or less
	// selects the sequential collector.
	Concurrency int
}
|
||
|
|
|
||
|
|
func BraveCollectConfigFromQueryCfg(cfg queryConfig) BraveCollectConfig {
|
||
|
|
out := BraveCollectConfig{
|
||
|
|
ResultsPerQuery: cfg.ResultsPerQuery,
|
||
|
|
MinSourcesBeforeStop: cfg.MinSourcesBeforeStop,
|
||
|
|
MaxSourcesCap: cfg.MaxSourcesCap,
|
||
|
|
Concurrency: cfg.BraveCollectConcurrency,
|
||
|
|
}
|
||
|
|
if out.ResultsPerQuery <= 0 {
|
||
|
|
out.ResultsPerQuery = 8
|
||
|
|
}
|
||
|
|
if out.MinSourcesBeforeStop <= 0 {
|
||
|
|
out.MinSourcesBeforeStop = 18
|
||
|
|
}
|
||
|
|
if out.MaxSourcesCap <= 0 {
|
||
|
|
out.MaxSourcesCap = 32
|
||
|
|
}
|
||
|
|
if out.Concurrency <= 0 {
|
||
|
|
out.Concurrency = 4
|
||
|
|
}
|
||
|
|
return out
|
||
|
|
}
|
||
|
|
|
||
|
|
func DefaultBraveCollectConfig() BraveCollectConfig {
|
||
|
|
cfg, err := loadQueryConfig()
|
||
|
|
if err != nil {
|
||
|
|
return BraveCollectConfig{ResultsPerQuery: 8, MinSourcesBeforeStop: 18, MaxSourcesCap: 32, Concurrency: 4}
|
||
|
|
}
|
||
|
|
return BraveCollectConfigFromQueryCfg(cfg)
|
||
|
|
}
|
||
|
|
|
||
|
|
func CollectBraveSources(
|
||
|
|
ctx context.Context,
|
||
|
|
client websearch.Client,
|
||
|
|
locale BraveSearchLocale,
|
||
|
|
queries []string,
|
||
|
|
cfg BraveCollectConfig,
|
||
|
|
onProgress func(i, total int),
|
||
|
|
heartbeat func() error,
|
||
|
|
) []BraveSource {
|
||
|
|
return CollectWebSources(ctx, client, locale, queries, cfg, onProgress, heartbeat)
|
||
|
|
}
|
||
|
|
|
||
|
|
func CollectWebSources(
|
||
|
|
ctx context.Context,
|
||
|
|
client websearch.Client,
|
||
|
|
locale BraveSearchLocale,
|
||
|
|
queries []string,
|
||
|
|
cfg BraveCollectConfig,
|
||
|
|
onProgress func(i, total int),
|
||
|
|
heartbeat func() error,
|
||
|
|
) []BraveSource {
|
||
|
|
if client == nil || !client.Enabled() || len(queries) == 0 {
|
||
|
|
return nil
|
||
|
|
}
|
||
|
|
if cfg.Concurrency <= 1 {
|
||
|
|
return collectWebSourcesSequential(ctx, client, locale, queries, cfg, onProgress, heartbeat)
|
||
|
|
}
|
||
|
|
return collectWebSourcesParallel(ctx, client, locale, queries, cfg, onProgress, heartbeat)
|
||
|
|
}
|
||
|
|
|
||
|
|
func collectWebSourcesSequential(
|
||
|
|
ctx context.Context,
|
||
|
|
client websearch.Client,
|
||
|
|
locale BraveSearchLocale,
|
||
|
|
queries []string,
|
||
|
|
cfg BraveCollectConfig,
|
||
|
|
onProgress func(i, total int),
|
||
|
|
heartbeat func() error,
|
||
|
|
) []BraveSource {
|
||
|
|
capHint := cfg.MaxSourcesCap
|
||
|
|
if est := len(queries) * cfg.ResultsPerQuery; est < capHint {
|
||
|
|
capHint = est
|
||
|
|
}
|
||
|
|
out := make([]BraveSource, 0, capHint)
|
||
|
|
seenURL := map[string]struct{}{}
|
||
|
|
|
||
|
|
for i, query := range queries {
|
||
|
|
if shouldStopCollect(out, cfg) {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
if heartbeat != nil {
|
||
|
|
if err := heartbeat(); err != nil {
|
||
|
|
return out
|
||
|
|
}
|
||
|
|
}
|
||
|
|
appendBraveResults(&out, seenURL, query, searchWebQuery(ctx, client, locale, query, cfg.ResultsPerQuery))
|
||
|
|
if onProgress != nil {
|
||
|
|
onProgress(i, len(queries))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return out
|
||
|
|
}
|
||
|
|
|
||
|
|
type braveCollectState struct {
|
||
|
|
cfg BraveCollectConfig
|
||
|
|
mu sync.Mutex
|
||
|
|
out []BraveSource
|
||
|
|
seenURL map[string]struct{}
|
||
|
|
stop bool
|
||
|
|
completed int32
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *braveCollectState) shouldStop(cfg BraveCollectConfig) bool {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
if s.stop {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
if shouldStopCollect(s.out, cfg) {
|
||
|
|
s.stop = true
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
return false
|
||
|
|
}
|
||
|
|
|
||
|
|
func (s *braveCollectState) appendResults(query string, items []BraveSource) {
|
||
|
|
s.mu.Lock()
|
||
|
|
defer s.mu.Unlock()
|
||
|
|
if s.stop {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
appendBraveResults(&s.out, s.seenURL, query, items)
|
||
|
|
if shouldStopCollect(s.out, s.cfg) {
|
||
|
|
s.stop = true
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func collectWebSourcesParallel(
|
||
|
|
ctx context.Context,
|
||
|
|
client websearch.Client,
|
||
|
|
locale BraveSearchLocale,
|
||
|
|
queries []string,
|
||
|
|
cfg BraveCollectConfig,
|
||
|
|
onProgress func(i, total int),
|
||
|
|
heartbeat func() error,
|
||
|
|
) []BraveSource {
|
||
|
|
state := &braveCollectState{
|
||
|
|
cfg: cfg,
|
||
|
|
out: make([]BraveSource, 0, cfg.MaxSourcesCap),
|
||
|
|
seenURL: map[string]struct{}{},
|
||
|
|
}
|
||
|
|
|
||
|
|
workers := cfg.Concurrency
|
||
|
|
if workers > len(queries) {
|
||
|
|
workers = len(queries)
|
||
|
|
}
|
||
|
|
if workers <= 0 {
|
||
|
|
workers = 1
|
||
|
|
}
|
||
|
|
|
||
|
|
jobs := make(chan int, len(queries))
|
||
|
|
for i := range queries {
|
||
|
|
jobs <- i
|
||
|
|
}
|
||
|
|
close(jobs)
|
||
|
|
|
||
|
|
var wg sync.WaitGroup
|
||
|
|
for w := 0; w < workers; w++ {
|
||
|
|
wg.Add(1)
|
||
|
|
go func() {
|
||
|
|
defer wg.Done()
|
||
|
|
for i := range jobs {
|
||
|
|
if state.shouldStop(cfg) {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
if heartbeat != nil {
|
||
|
|
if err := heartbeat(); err != nil {
|
||
|
|
return
|
||
|
|
}
|
||
|
|
}
|
||
|
|
query := queries[i]
|
||
|
|
items := searchWebQuery(ctx, client, locale, query, cfg.ResultsPerQuery)
|
||
|
|
state.appendResults(query, items)
|
||
|
|
done := int(atomic.AddInt32(&state.completed, 1))
|
||
|
|
if onProgress != nil {
|
||
|
|
onProgress(done-1, len(queries))
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
}
|
||
|
|
wg.Wait()
|
||
|
|
return state.out
|
||
|
|
}
|
||
|
|
|
||
|
|
func shouldStopCollect(out []BraveSource, cfg BraveCollectConfig) bool {
|
||
|
|
if len(out) >= cfg.MaxSourcesCap {
|
||
|
|
return true
|
||
|
|
}
|
||
|
|
return len(out) >= cfg.MinSourcesBeforeStop && uniqueSourceCount(out) >= cfg.MinSourcesBeforeStop
|
||
|
|
}
|
||
|
|
|
||
|
|
func searchWebQuery(
|
||
|
|
ctx context.Context,
|
||
|
|
client websearch.Client,
|
||
|
|
locale BraveSearchLocale,
|
||
|
|
query string,
|
||
|
|
limit int,
|
||
|
|
) []BraveSource {
|
||
|
|
res, _ := client.Search(ctx, websearch.SearchOptions{
|
||
|
|
Query: query,
|
||
|
|
Limit: limit,
|
||
|
|
Mode: websearch.ModeKnowledgeExpand,
|
||
|
|
Country: locale.Country,
|
||
|
|
SearchLang: locale.SearchLang,
|
||
|
|
UserLocation: locale.UserLocation,
|
||
|
|
})
|
||
|
|
items := make([]BraveSource, 0, len(res.Results))
|
||
|
|
for _, item := range res.Results {
|
||
|
|
url := strings.TrimSpace(item.URL)
|
||
|
|
if url == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
items = append(items, BraveSource{
|
||
|
|
Query: query,
|
||
|
|
Snippet: item.Snippet,
|
||
|
|
URL: url,
|
||
|
|
Title: item.Title,
|
||
|
|
})
|
||
|
|
}
|
||
|
|
return items
|
||
|
|
}
|
||
|
|
|
||
|
|
func appendBraveResults(out *[]BraveSource, seenURL map[string]struct{}, query string, items []BraveSource) {
|
||
|
|
for _, item := range items {
|
||
|
|
url := strings.TrimSpace(item.URL)
|
||
|
|
if url == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if _, ok := seenURL[url]; ok {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
seenURL[url] = struct{}{}
|
||
|
|
if item.Query == "" {
|
||
|
|
item.Query = query
|
||
|
|
}
|
||
|
|
*out = append(*out, item)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func MergeBraveSources(chunks ...[]BraveSource) []BraveSource {
|
||
|
|
seen := map[string]struct{}{}
|
||
|
|
out := make([]BraveSource, 0)
|
||
|
|
for _, chunk := range chunks {
|
||
|
|
for _, item := range chunk {
|
||
|
|
url := strings.TrimSpace(item.URL)
|
||
|
|
if url == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if _, ok := seen[url]; ok {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
seen[url] = struct{}{}
|
||
|
|
out = append(out, item)
|
||
|
|
}
|
||
|
|
}
|
||
|
|
return out
|
||
|
|
}
|
||
|
|
|
||
|
|
func uniqueSourceCount(items []BraveSource) int {
|
||
|
|
seen := map[string]struct{}{}
|
||
|
|
for _, item := range items {
|
||
|
|
url := strings.TrimSpace(item.URL)
|
||
|
|
if url == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
seen[url] = struct{}{}
|
||
|
|
}
|
||
|
|
return len(seen)
|
||
|
|
}
|