192 lines
4.0 KiB
Go
192 lines
4.0 KiB
Go
|
|
package exa
|
||
|
|
|
||
|
|
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"
)
|
||
|
|
|
||
|
|
// defaultBaseURL is the Exa search endpoint that NewClient configures on
// every Client it returns.
const defaultBaseURL = "https://api.exa.ai/search"
|
||
|
|
|
||
|
|
// Mode names the kind of search a caller wants Client.Search to perform.
type Mode string

// Search modes accepted in SearchOptions.Mode.
const (
	// ModeKnowledgeExpand is the general-purpose mode; Search applies no
	// domain restriction for it.
	ModeKnowledgeExpand = Mode("knowledge_expand")

	// ModeThreadsDiscover makes Search restrict the request to the Threads
	// domains and drop any returned URL that is not a Threads URL.
	ModeThreadsDiscover = Mode("threads_discover")
)
|
||
|
|
|
||
|
|
// SearchResult is a single hit produced by Client.Search.
type SearchResult struct {
	// Title is the result title with surrounding whitespace removed.
	Title string

	// Snippet is the first non-blank highlight of the result, or the trimmed
	// title when the result carries no usable highlight.
	Snippet string

	// URL is the trimmed result URL; Search never emits a result without one.
	URL string

	// PublishedDate is the trimmed publication date string exactly as the
	// API reported it (it is not parsed here).
	PublishedDate string

	// Author is the trimmed author string reported by the API.
	Author string

	// HighlightScore is the first highlight score of the result, or 0 when
	// the API returned no scores.
	HighlightScore float64
}
|
||
|
|
|
||
|
|
type SearchResponse struct {
|
||
|
|
Results []SearchResult
|
||
|
|
Query string
|
||
|
|
Status string // success | unavailable
|
||
|
|
}
|
||
|
|
|
||
|
|
// Client talks to the Exa search API. Build one with NewClient.
type Client struct {
	// apiKey is sent in the x-api-key request header; an empty key means the
	// client is disabled (see Enabled).
	apiKey string

	// baseURL is the endpoint every search request is POSTed to.
	baseURL string

	// http is the underlying HTTP client used to execute requests.
	http *http.Client
}
|
||
|
|
|
||
|
|
func NewClient(apiKey string) *Client {
|
||
|
|
return &Client{
|
||
|
|
apiKey: strings.TrimSpace(apiKey),
|
||
|
|
baseURL: defaultBaseURL,
|
||
|
|
http: &http.Client{
|
||
|
|
Timeout: 25 * time.Second,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Client) Enabled() bool {
|
||
|
|
return c != nil && c.apiKey != ""
|
||
|
|
}
|
||
|
|
|
||
|
|
type SearchOptions struct {
|
||
|
|
Query string
|
||
|
|
Limit int
|
||
|
|
Mode Mode
|
||
|
|
UserLocation string
|
||
|
|
StartPublishedDate string
|
||
|
|
}
|
||
|
|
|
||
|
|
func (c *Client) Search(ctx context.Context, opts SearchOptions) (SearchResponse, error) {
|
||
|
|
out := SearchResponse{Query: strings.TrimSpace(opts.Query), Status: "unavailable"}
|
||
|
|
if !c.Enabled() {
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
if out.Query == "" {
|
||
|
|
return out, fmt.Errorf("exa search query is required")
|
||
|
|
}
|
||
|
|
|
||
|
|
limit := opts.Limit
|
||
|
|
if limit <= 0 {
|
||
|
|
limit = 5
|
||
|
|
}
|
||
|
|
if limit > 20 {
|
||
|
|
limit = 20
|
||
|
|
}
|
||
|
|
|
||
|
|
userLocation := strings.TrimSpace(opts.UserLocation)
|
||
|
|
if userLocation == "" {
|
||
|
|
userLocation = "TW"
|
||
|
|
}
|
||
|
|
|
||
|
|
body := map[string]any{
|
||
|
|
"query": out.Query,
|
||
|
|
"type": "auto",
|
||
|
|
"numResults": limit,
|
||
|
|
"userLocation": userLocation,
|
||
|
|
"contents": map[string]any{
|
||
|
|
"highlights": true,
|
||
|
|
},
|
||
|
|
}
|
||
|
|
if opts.Mode == ModeThreadsDiscover {
|
||
|
|
body["includeDomains"] = []string{"threads.net", "threads.com"}
|
||
|
|
}
|
||
|
|
if start := strings.TrimSpace(opts.StartPublishedDate); start != "" {
|
||
|
|
body["startPublishedDate"] = start
|
||
|
|
}
|
||
|
|
|
||
|
|
payload, err := json.Marshal(body)
|
||
|
|
if err != nil {
|
||
|
|
return out, err
|
||
|
|
}
|
||
|
|
|
||
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.baseURL, bytes.NewReader(payload))
|
||
|
|
if err != nil {
|
||
|
|
return out, err
|
||
|
|
}
|
||
|
|
req.Header.Set("Accept", "application/json")
|
||
|
|
req.Header.Set("Content-Type", "application/json")
|
||
|
|
req.Header.Set("x-api-key", c.apiKey)
|
||
|
|
|
||
|
|
res, err := c.http.Do(req)
|
||
|
|
if err != nil {
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
defer res.Body.Close()
|
||
|
|
|
||
|
|
if res.StatusCode != http.StatusOK {
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
raw, err := io.ReadAll(io.LimitReader(res.Body, 1<<20))
|
||
|
|
if err != nil {
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
var parsed struct {
|
||
|
|
Results []struct {
|
||
|
|
Title string `json:"title"`
|
||
|
|
URL string `json:"url"`
|
||
|
|
PublishedDate string `json:"publishedDate"`
|
||
|
|
Author string `json:"author"`
|
||
|
|
Highlights []string `json:"highlights"`
|
||
|
|
HighlightScores []float64 `json:"highlightScores"`
|
||
|
|
} `json:"results"`
|
||
|
|
}
|
||
|
|
if err := json.Unmarshal(raw, &parsed); err != nil {
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
threadsOnly := opts.Mode == ModeThreadsDiscover
|
||
|
|
for _, item := range parsed.Results {
|
||
|
|
rawURL := strings.TrimSpace(item.URL)
|
||
|
|
if rawURL == "" {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
if threadsOnly && !isThreadsURL(rawURL) {
|
||
|
|
continue
|
||
|
|
}
|
||
|
|
snippet := firstHighlight(item.Highlights)
|
||
|
|
if snippet == "" {
|
||
|
|
snippet = strings.TrimSpace(item.Title)
|
||
|
|
}
|
||
|
|
score := 0.0
|
||
|
|
if len(item.HighlightScores) > 0 {
|
||
|
|
score = item.HighlightScores[0]
|
||
|
|
}
|
||
|
|
out.Results = append(out.Results, SearchResult{
|
||
|
|
Title: strings.TrimSpace(item.Title),
|
||
|
|
Snippet: snippet,
|
||
|
|
URL: rawURL,
|
||
|
|
PublishedDate: strings.TrimSpace(item.PublishedDate),
|
||
|
|
Author: strings.TrimSpace(item.Author),
|
||
|
|
HighlightScore: score,
|
||
|
|
})
|
||
|
|
if len(out.Results) >= limit {
|
||
|
|
break
|
||
|
|
}
|
||
|
|
}
|
||
|
|
out.Status = "success"
|
||
|
|
return out, nil
|
||
|
|
}
|
||
|
|
|
||
|
|
// firstHighlight returns the first entry of items that is non-empty after
// trimming whitespace, in its trimmed form. It returns "" when items is empty
// or contains only blank entries.
func firstHighlight(items []string) string {
	for i := range items {
		candidate := strings.TrimSpace(items[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return ""
}
|
||
|
|
|
||
|
|
// isThreadsURL reports whether raw points at a Threads host, i.e. its
// hostname is threads.net or threads.com or a subdomain of either. The
// comparison is case-insensitive and ignores any port.
//
// The hostname is parsed rather than substring-matched so that look-alike
// hosts ("notthreads.com", "threads.net.evil.io") and URLs that merely mention
// a Threads domain in their path, query or userinfo are rejected. Inputs
// without a scheme ("threads.net/@user", "//threads.net/@user") are treated
// as https URLs. Blank or unparsable input yields false.
func isThreadsURL(raw string) bool {
	candidate := strings.TrimSpace(raw)
	if candidate == "" {
		return false
	}

	// A "://" only marks a scheme when it appears before any path, query or
	// fragment delimiter; otherwise assume https so the host can be parsed.
	if i := strings.Index(candidate, "://"); i < 0 || strings.ContainsAny(candidate[:i], "/?#") {
		candidate = "https://" + strings.TrimLeft(candidate, "/")
	}

	parsed, err := url.Parse(candidate)
	if err != nil {
		return false
	}

	// Drop a trailing root dot so "threads.net." matches as well.
	host := strings.TrimSuffix(strings.ToLower(parsed.Hostname()), ".")
	for _, domain := range []string{"threads.net", "threads.com"} {
		if host == domain || strings.HasSuffix(host, "."+domain) {
			return true
		}
	}
	return false
}
|