haixunMaster/haixun-backend/internal/handler/job/worker_handlers.go

550 lines
17 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package job
import (
	"crypto/subtle"
	"net/http"
	"strings"

	"github.com/zeromicro/go-zero/rest/httpx"

	"haixun-backend/internal/library/clock"
	app "haixun-backend/internal/library/errors"
	"haixun-backend/internal/library/errors/code"
	libprompt "haixun-backend/internal/library/prompt"
	"haixun-backend/internal/library/style8d"
	joblogic "haixun-backend/internal/logic/job"
	"haixun-backend/internal/model/ai/domain/enum"
	domai "haixun-backend/internal/model/ai/domain/usecase"
	jobentity "haixun-backend/internal/model/job/domain/entity"
	jobenum "haixun-backend/internal/model/job/domain/enum"
	jobusecase "haixun-backend/internal/model/job/domain/usecase"
	personausecase "haixun-backend/internal/model/persona/domain/usecase"
	"haixun-backend/internal/response"
	"haixun-backend/internal/svc"
	"haixun-backend/internal/types"
)
// workerSecretHeader names the HTTP header that carries the shared worker
// secret; requireWorkerSecret compares it with Config.InternalWorker.Secret.
const workerSecretHeader = "X-Worker-Secret"

// workerJobPath binds the job ID from the request path ({id}). The per-job
// worker request types below embed it so each of them exposes ID.
type workerJobPath struct {
	ID string `path:"id"`
}

// claimWorkerJobReq is the request body for ClaimWorkerJobHandler.
type claimWorkerJobReq struct {
	WorkerType string `json:"worker_type"` // forwarded to ClaimNext as the worker type to claim for
	WorkerID   string `json:"worker_id"`
}

// workerJobReq is used by the cancel-check and cancel-acknowledge handlers.
type workerJobReq struct {
	workerJobPath
	WorkerID string `json:"worker_id"`
}

// workerHeartbeatReq is the request for RefreshWorkerJobLockHandler.
type workerHeartbeatReq struct {
	workerJobPath
	WorkerID   string `json:"worker_id"`
	TTLSeconds int    `json:"ttl_seconds,optional"` // lock TTL in seconds; the handler substitutes 300 when <= 0
}

// workerProgressReq is the request for UpdateWorkerJobProgressHandler.
type workerProgressReq struct {
	workerJobPath
	WorkerID   string                      `json:"worker_id"`
	Phase      string                      `json:"phase,optional"`
	Summary    string                      `json:"summary,optional"`
	Percentage *int                        `json:"percentage,optional"` // pointer so an omitted value can be told apart from 0
	Steps      []types.JobStepProgressData `json:"steps,optional"`
}

// workerCompleteReq is the request for CompleteWorkerJobHandler.
type workerCompleteReq struct {
	workerJobPath
	WorkerID string                 `json:"worker_id"`
	Result   map[string]interface{} `json:"result,optional"` // free-form result forwarded to CompleteRun
}

// workerFailReq is the request for FailWorkerJobHandler.
type workerFailReq struct {
	workerJobPath
	WorkerID string `json:"worker_id"`
	Error    string `json:"error"`
	Phase    string `json:"phase,optional"`
}

// storePersonaStyleProfileReq is the request for
// StorePersonaStyleProfileFromWorkerHandler. The path ID is the persona ID.
type storePersonaStyleProfileReq struct {
	ID             string `path:"id"`
	TenantID       string `json:"tenant_id"`
	OwnerUID       string `json:"owner_uid"`
	StyleProfile   string `json:"style_profile"`
	StyleBenchmark string `json:"style_benchmark,optional"`
}

// workerThreadsAccountSessionReq is the request for
// GetWorkerThreadsAccountSessionHandler. The path ID is passed to
// GetBrowserSession as the account identifier (presumably the Threads account ID).
type workerThreadsAccountSessionReq struct {
	ID       string `path:"id"`
	TenantID string `json:"tenant_id"`
	OwnerUID string `json:"owner_uid"`
}

// analyzeStyle8DPostReq is one post supplied for 8D style analysis.
type analyzeStyle8DPostReq struct {
	Text       string `json:"text"`
	Permalink  string `json:"permalink,optional"`
	LikeCount  int    `json:"like_count,optional"`
	ReplyCount int    `json:"reply_count,optional"`
}

// analyzeStyle8DReq is the request for AnalyzeStyle8DFromWorkerHandler. The
// embedded path ID is the job being progressed; Steps is the worker's current
// step list, which the handler extends with "style" and "store" updates.
type analyzeStyle8DReq struct {
	workerJobPath
	WorkerID         string                      `json:"worker_id"`
	TenantID         string                      `json:"tenant_id"`
	OwnerUID         string                      `json:"owner_uid"`
	PersonaID        string                      `json:"persona_id"`
	ThreadsAccountID string                      `json:"threads_account_id"`
	Username         string                      `json:"username"`
	Posts            []analyzeStyle8DPostReq     `json:"posts"`
	Steps            []types.JobStepProgressData `json:"steps,optional"`
}
// ClaimWorkerJobHandler lets an authenticated internal worker claim the next
// job for its worker type.
//
// The response carries the claimed job. When ClaimNext returns neither a job
// nor an error, the response is written with no data and no error.
func ClaimWorkerJobHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body claimWorkerJobReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		claimed, claimErr := svcCtx.Job.ClaimNext(ctx, jobusecase.ClaimNextRequest{
			WorkerType: body.WorkerType,
			WorkerID:   body.WorkerID,
		})
		switch {
		case claimErr != nil, claimed == nil:
			response.Write(ctx, w, nil, claimErr)
		default:
			payload := joblogic.ToJobData(claimed)
			response.Write(ctx, w, &payload, nil)
		}
	}
}
// RefreshWorkerJobLockHandler extends the run lock a worker holds on a job
// (heartbeat). The response is {"ok": true} on success and {"ok": false}
// together with the error otherwise.
func RefreshWorkerJobLockHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	// defaultLockTTLSeconds is used when ttl_seconds is omitted or not positive.
	const defaultLockTTLSeconds = 300

	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body workerHeartbeatReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		ttlSeconds := defaultLockTTLSeconds
		if body.TTLSeconds > 0 {
			ttlSeconds = body.TTLSeconds
		}

		refreshErr := svcCtx.Job.RefreshRunLock(ctx, body.ID, body.WorkerID, ttlSeconds)
		response.Write(ctx, w, map[string]bool{"ok": refreshErr == nil}, refreshErr)
	}
}
// CheckWorkerJobCancelHandler reports whether cancellation has been requested
// for the job in the path, as {"cancelled": bool}. Any lookup error is
// returned alongside the payload.
//
// worker_id is parsed from the body but is not used by this check.
func CheckWorkerJobCancelHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body workerJobReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		cancelled, lookupErr := svcCtx.Job.IsCancelRequested(ctx, body.ID)
		payload := map[string]bool{"cancelled": cancelled}
		response.Write(ctx, w, payload, lookupErr)
	}
}
// AckWorkerJobCancelHandler records that the worker identified by worker_id has
// observed a cancellation request for the job in the path. The response is the
// updated job.
func AckWorkerJobCancelHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body workerJobReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		ackReq := jobusecase.AcknowledgeCancelRequest{
			JobID:    body.ID,
			WorkerID: body.WorkerID,
		}
		acked, ackErr := svcCtx.Job.AcknowledgeCancel(ctx, ackReq)
		if ackErr != nil {
			response.Write(ctx, w, nil, ackErr)
			return
		}

		payload := joblogic.ToJobData(acked)
		response.Write(ctx, w, &payload, nil)
	}
}
// UpdateWorkerJobProgressHandler stores a progress report (phase, summary,
// percentage and step list) from the worker for the job in the path, and
// responds with the updated job.
//
// When the request omits "percentage", -1 is forwarded to UpdateProgress as a
// sentinel (presumably meaning "leave the percentage unchanged" — confirm in
// the usecase).
func UpdateWorkerJobProgressHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body workerProgressReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		update := jobusecase.UpdateProgressRequest{
			JobID:      body.ID,
			WorkerID:   body.WorkerID,
			Phase:      body.Phase,
			Summary:    body.Summary,
			Percentage: -1,
			Steps:      toEntitySteps(body.Steps),
		}
		if body.Percentage != nil {
			update.Percentage = *body.Percentage
		}

		updated, updateErr := svcCtx.Job.UpdateProgress(ctx, update)
		if updateErr != nil {
			response.Write(ctx, w, nil, updateErr)
			return
		}

		payload := joblogic.ToJobData(updated)
		response.Write(ctx, w, &payload, nil)
	}
}
// CompleteWorkerJobHandler marks the job in the path as completed by the
// worker identified by worker_id, storing the optional free-form result. The
// response is the updated job.
func CompleteWorkerJobHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body workerCompleteReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		completeReq := jobusecase.CompleteRunRequest{
			JobID:    body.ID,
			WorkerID: body.WorkerID,
			Result:   body.Result,
		}
		completed, completeErr := svcCtx.Job.CompleteRun(ctx, completeReq)
		if completeErr != nil {
			response.Write(ctx, w, nil, completeErr)
			return
		}

		payload := joblogic.ToJobData(completed)
		response.Write(ctx, w, &payload, nil)
	}
}
// FailWorkerJobHandler marks the job in the path as failed on behalf of the
// worker identified by worker_id. It forwards the error text and, optionally,
// the phase the job failed in. The response is the updated job.
func FailWorkerJobHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body workerFailReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		failReq := jobusecase.FailRunRequest{
			JobID:    body.ID,
			WorkerID: body.WorkerID,
			Error:    body.Error,
			Phase:    body.Phase,
		}
		failed, failErr := svcCtx.Job.FailRun(ctx, failReq)
		if failErr != nil {
			response.Write(ctx, w, nil, failErr)
			return
		}

		payload := joblogic.ToJobData(failed)
		response.Write(ctx, w, &payload, nil)
	}
}
// StorePersonaStyleProfileFromWorkerHandler writes a worker-produced style
// profile (and style benchmark) onto the persona identified by the path ID.
//
// style_profile is trimmed and must be non-empty. style_benchmark is trimmed
// and a leading "@" is removed.
//
// NOTE(review): the benchmark pointer is always set, so an omitted
// style_benchmark is sent to Persona.Update as "". Whether that clears a
// stored benchmark depends on Persona.Update — confirm this is intended.
//
// The response is {"id", "update_at"} of the updated persona.
func StorePersonaStyleProfileFromWorkerHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body storePersonaStyleProfileReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		profile := strings.TrimSpace(body.StyleProfile)
		if profile == "" {
			response.Write(ctx, w, nil, app.For(code.Persona).InputMissingRequired("style_profile is required"))
			return
		}
		benchmark := strings.TrimPrefix(strings.TrimSpace(body.StyleBenchmark), "@")

		patch := personausecase.PersonaPatch{
			StyleProfile:   &profile,
			StyleBenchmark: &benchmark,
		}
		updated, updateErr := svcCtx.Persona.Update(ctx, personausecase.UpdateRequest{
			TenantID:  body.TenantID,
			OwnerUID:  body.OwnerUID,
			PersonaID: body.ID,
			Patch:     patch,
		})
		if updateErr != nil {
			response.Write(ctx, w, nil, updateErr)
			return
		}

		response.Write(ctx, w, map[string]any{"id": updated.ID, "update_at": updated.UpdateAt}, nil)
	}
}
// AnalyzeStyle8DFromWorkerHandler runs the 8D style analysis on behalf of an
// internal worker and stores the result on a persona.
//
// Flow:
//  1. Authenticate the worker (requireWorkerSecret) and parse analyzeStyle8DReq.
//  2. Validate worker_id, persona_id, threads_account_id and posts.
//  3. Resolve the Threads account's AI credential and map its provider
//     (mapWorkerAIProvider accepts only opencode and xai).
//  4. Normalise posts: trim text and permalink, and drop posts with blank text.
//  5. Call the AI provider with the Style8D system prompt, parse its output,
//     and build the stored profile JSON.
//  6. Write the profile JSON and the benchmark username to the persona.
//
// Job progress is reported three times: at 55% ("style" running), at 88%
// ("store" running) and at 92% ("store" succeeded). Each UpdateProgress result
// is discarded, so a progress failure never aborts the analysis.
//
// On success the response contains persona_id, post_count (the number of
// posts left after filtering), style_profile and style_benchmark.
//
// NOTE(review): this handler neither completes nor fails the job. On an error
// return, the last reported progress leaves the "style" or "store" step in the
// running state. Presumably the worker calls the complete/fail endpoints
// afterwards — confirm.
func AnalyzeStyle8DFromWorkerHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if err := requireWorkerSecret(r, svcCtx); err != nil {
			response.Write(r.Context(), w, nil, err)
			return
		}
		var req analyzeStyle8DReq
		if err := httpx.Parse(r, &req); err != nil {
			response.Write(r.Context(), w, nil, response.WrapRequestError(err))
			return
		}
		// Required-field checks. tenant_id, owner_uid and username are not
		// validated here.
		if strings.TrimSpace(req.WorkerID) == "" {
			response.Write(r.Context(), w, nil, app.For(code.Job).InputMissingRequired("worker_id is required"))
			return
		}
		if strings.TrimSpace(req.PersonaID) == "" {
			response.Write(r.Context(), w, nil, app.For(code.Persona).InputMissingRequired("persona_id is required"))
			return
		}
		if strings.TrimSpace(req.ThreadsAccountID) == "" {
			response.Write(r.Context(), w, nil, app.For(code.ThreadsAccount).InputMissingRequired("threads_account_id is required"))
			return
		}
		if len(req.Posts) == 0 {
			response.Write(r.Context(), w, nil, app.For(code.Persona).InputMissingRequired("posts is required"))
			return
		}
		// Resolve the AI credential configured for this Threads account. Its
		// Provider, Model and APIKey drive the GenerateText call below.
		credential, err := svcCtx.ThreadsAccount.ResolveWorkerAiCredential(
			r.Context(),
			req.TenantID,
			req.OwnerUID,
			req.ThreadsAccountID,
		)
		if err != nil {
			response.Write(r.Context(), w, nil, err)
			return
		}
		providerID, err := mapWorkerAIProvider(credential.Provider)
		if err != nil {
			response.Write(r.Context(), w, nil, err)
			return
		}
		// Keep only posts with non-blank text, trimmed. post_count in the
		// response reflects this filtered list.
		posts := make([]style8d.Post, 0, len(req.Posts))
		for _, item := range req.Posts {
			text := strings.TrimSpace(item.Text)
			if text == "" {
				continue
			}
			posts = append(posts, style8d.Post{
				Text:       text,
				Permalink:  strings.TrimSpace(item.Permalink),
				LikeCount:  item.LikeCount,
				ReplyCount: item.ReplyCount,
			})
		}
		if len(posts) == 0 {
			response.Write(r.Context(), w, nil, app.For(code.Persona).InputInvalidFormat("posts contain no readable text"))
			return
		}
		// Start from the step list the worker sent and mark "style" as running.
		// Progress reporting is best-effort: the UpdateProgress result and error
		// are discarded.
		steps := toEntitySteps(req.Steps)
		steps = markWorkerStep(steps, "style", jobenum.StepStatusRunning, "AI 正在分析 D1D8…")
		_, _ = svcCtx.Job.UpdateProgress(r.Context(), jobusecase.UpdateProgressRequest{
			JobID:      req.ID,
			WorkerID:   req.WorkerID,
			Phase:      "style",
			Summary:    "AI 正在分析八個風格維度…",
			Percentage: 55,
			Steps:      steps,
		})
		// The username is trimmed and one leading "@" is removed. It is used in
		// the prompt and stored as the persona's StyleBenchmark.
		username := strings.TrimPrefix(strings.TrimSpace(req.Username), "@")
		systemPrompt, err := libprompt.Style8DSystem()
		if err != nil {
			// The underlying load error is not included in the response.
			response.Write(r.Context(), w, nil, app.For(code.AI).SysInternal("prompt config load failed"))
			return
		}
		result, err := svcCtx.AI.GenerateText(r.Context(), domai.GenerateRequest{
			Provider: providerID,
			Model:    credential.Model,
			Credential: domai.Credential{
				APIKey: credential.APIKey,
			},
			System: systemPrompt,
			Messages: []domai.Message{
				{Role: "user", Content: style8d.BuildUserPrompt(username, posts)},
			},
		})
		if err != nil {
			// Rewrite an upstream authorization failure into an actionable
			// message naming the configured provider and model. All other errors
			// are returned unchanged.
			// NOTE(review): detection relies on the AI client's error text
			// containing "HTTP 401".
			if strings.Contains(err.Error(), "HTTP 401") {
				err = app.For(code.AI).SvcThirdParty(
					"8D AI 分析授權失敗:目前帳號的研究用 Provider API key 無效或未授權。請到「設定 > 帳號 AI 設定」確認 research provider=" +
						credential.Provider + "、model=" + credential.Model + ",並重新貼上對應 provider 的 API key",
				)
			}
			response.Write(r.Context(), w, nil, err)
			return
		}
		// Parse the model output into the 8D structure and serialise the
		// profile that will be stored on the persona.
		parsed, err := style8d.ParseLLMOutput(result.Text)
		if err != nil {
			response.Write(r.Context(), w, nil, app.For(code.AI).SvcThirdParty("8D LLM 回傳無法解析:"+err.Error()))
			return
		}
		profile := style8d.BuildStoredProfile(username, posts, parsed)
		profileJSON, err := profile.JSON()
		if err != nil {
			response.Write(r.Context(), w, nil, err)
			return
		}
		// Analysis done: close "style", open "store" (best-effort report).
		steps = markWorkerStep(steps, "style", jobenum.StepStatusSucceeded, "8D 風格策略已產生")
		steps = markWorkerStep(steps, "store", jobenum.StepStatusRunning, "寫入人設風格策略…")
		_, _ = svcCtx.Job.UpdateProgress(r.Context(), jobusecase.UpdateProgressRequest{
			JobID:      req.ID,
			WorkerID:   req.WorkerID,
			Phase:      "store",
			Summary:    "8D 分析完成,寫入人設…",
			Percentage: 88,
			Steps:      steps,
		})
		// Persist the profile JSON and the benchmark username on the persona.
		_, err = svcCtx.Persona.Update(r.Context(), personausecase.UpdateRequest{
			TenantID:  req.TenantID,
			OwnerUID:  req.OwnerUID,
			PersonaID: req.PersonaID,
			Patch: personausecase.PersonaPatch{
				StyleProfile:   &profileJSON,
				StyleBenchmark: &username,
			},
		})
		if err != nil {
			response.Write(r.Context(), w, nil, err)
			return
		}
		// Close "store" (best-effort report).
		steps = markWorkerStep(steps, "store", jobenum.StepStatusSucceeded, "8D 策略已寫入人設")
		_, _ = svcCtx.Job.UpdateProgress(r.Context(), jobusecase.UpdateProgressRequest{
			JobID:      req.ID,
			WorkerID:   req.WorkerID,
			Phase:      "store",
			Summary:    "8D 策略已寫入人設",
			Percentage: 92,
			Steps:      steps,
		})
		response.Write(r.Context(), w, map[string]any{
			"persona_id":      req.PersonaID,
			"post_count":      len(posts),
			"style_profile":   profileJSON,
			"style_benchmark": username,
		}, nil)
	}
}
// GetWorkerThreadsAccountSessionHandler returns the stored browser session of
// a Threads account to an internal worker. The response contains account_id,
// storage_state and update_at.
//
// NOTE(review): storage_state presumably holds login cookies/local storage,
// which is sensitive. This endpoint is protected only by requireWorkerSecret,
// which accepts every request when no worker secret is configured.
func GetWorkerThreadsAccountSessionHandler(svcCtx *svc.ServiceContext) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		if authErr := requireWorkerSecret(r, svcCtx); authErr != nil {
			response.Write(ctx, w, nil, authErr)
			return
		}

		var body workerThreadsAccountSessionReq
		if parseErr := httpx.Parse(r, &body); parseErr != nil {
			response.Write(ctx, w, nil, response.WrapRequestError(parseErr))
			return
		}

		session, sessionErr := svcCtx.ThreadsAccount.GetBrowserSession(ctx, body.TenantID, body.OwnerUID, body.ID)
		if sessionErr != nil {
			response.Write(ctx, w, nil, sessionErr)
			return
		}

		payload := map[string]any{
			"account_id":    session.AccountID,
			"storage_state": session.StorageState,
			"update_at":     session.UpdateAt,
		}
		response.Write(ctx, w, payload, nil)
	}
}
// requireWorkerSecret authenticates an internal worker request. It compares
// the X-Worker-Secret header with the configured InternalWorker.Secret
// (surrounding whitespace trimmed).
//
// It returns nil when the request is allowed, and an Auth "unauthorized"
// error otherwise. The comparison uses crypto/subtle, so the time taken does
// not depend on how many leading bytes of a guessed secret are correct. It
// still reveals whether the lengths match.
//
// NOTE(review): when no secret is configured (empty or whitespace-only), the
// check is disabled and every request is accepted. The endpoints behind it
// include GetWorkerThreadsAccountSessionHandler, which returns stored browser
// session state, so production deployments should always configure a secret.
func requireWorkerSecret(r *http.Request, svcCtx *svc.ServiceContext) error {
	secret := strings.TrimSpace(svcCtx.Config.InternalWorker.Secret)
	if secret == "" {
		return nil
	}
	provided := r.Header.Get(workerSecretHeader)
	// Constant-time comparison: a plain != can short-circuit on the first
	// differing byte and leak the secret through response timing.
	if subtle.ConstantTimeCompare([]byte(provided), []byte(secret)) != 1 {
		return app.For(code.Auth).AuthUnauthorized("invalid worker secret")
	}
	return nil
}
// mapWorkerAIProvider converts the provider name stored on a Threads account's
// AI credential (surrounding whitespace ignored) into an AI provider ID.
//
// Only enum.ProviderOpenCode and enum.ProviderXAI are supported for worker 8D
// analysis. Any other value yields an AI InputInvalidFormat error and an empty
// provider ID.
func mapWorkerAIProvider(provider string) (enum.ProviderID, error) {
	name := strings.TrimSpace(provider)
	for _, supported := range []enum.ProviderID{enum.ProviderOpenCode, enum.ProviderXAI} {
		if name == string(supported) {
			return supported, nil
		}
	}
	return "", app.For(code.AI).InputInvalidFormat("worker 8D 分析目前僅支援 opencode-go 與 xai請在 AI 設定調整 research provider")
}
// markWorkerStep sets the status and message of every step whose ID equals
// stepID and stamps its timing fields. If no step matches, it appends a new
// step with that ID.
//
// Timing rules (timestamps come from clock.NowUnixNano):
//   - running: StartedAt is set only if it is not already set, preserving the
//     first start time. Any EndedAt left over from an earlier terminal state
//     is cleared, so a running step never also looks finished.
//   - succeeded / failed: EndedAt is set to now.
//   - any other status: only Status and Message change.
//
// Matching steps are modified in place in the input slice; callers must use
// the returned slice because a new step may have been appended.
func markWorkerStep(steps []jobentity.StepProgress, stepID string, status jobenum.StepStatus, message string) []jobentity.StepProgress {
	now := clock.NowUnixNano()

	// stamp applies status, message and timestamps to one step. Each timestamp
	// gets its own copy, so steps never share (and cannot mutate) one pointer.
	stamp := func(step *jobentity.StepProgress) {
		step.Status = status
		step.Message = message
		switch status {
		case jobenum.StepStatusRunning:
			if step.StartedAt == nil {
				startedAt := now
				step.StartedAt = &startedAt
			}
			step.EndedAt = nil
		case jobenum.StepStatusSucceeded, jobenum.StepStatusFailed:
			endedAt := now
			step.EndedAt = &endedAt
		}
	}

	found := false
	for i := range steps {
		if steps[i].ID == stepID {
			stamp(&steps[i])
			found = true
		}
	}
	if !found {
		item := jobentity.StepProgress{ID: stepID}
		stamp(&item)
		steps = append(steps, item)
	}
	return steps
}
// toEntitySteps converts step progress received in a worker request into job
// entity steps, preserving order and converting each Status to
// jobenum.StepStatus.
//
// The StartedAt/EndedAt pointers are copied as-is, so the result shares them
// with the input. A nil or empty input yields a non-nil, empty slice.
func toEntitySteps(steps []types.JobStepProgressData) []jobentity.StepProgress {
	converted := make([]jobentity.StepProgress, len(steps))
	for i := range steps {
		src := &steps[i]
		converted[i] = jobentity.StepProgress{
			ID:        src.ID,
			Status:    jobenum.StepStatus(src.Status),
			StartedAt: src.StartedAt,
			EndedAt:   src.EndedAt,
			Message:   src.Message,
		}
	}
	return converted
}