76 lines
2.5 KiB
TypeScript
76 lines
2.5 KiB
TypeScript
import { generateObject, generateText, type LanguageModel } from "ai";
|
||
import type { z } from "zod";
|
||
import { extractJsonFromText } from "./coerce-research-map";
|
||
import {
|
||
getOpenCodeGenerationSettings,
|
||
prefersOpenCodeTextFirst,
|
||
} from "./opencode-go-settings";
|
||
|
||
/**
 * Options accepted by `generateStructuredObject`.
 *
 * @typeParam S - Zod schema that the generated value must satisfy.
 */
export interface GenerateStructuredObjectOptions<S extends z.ZodTypeAny> {
  /** Model instance handed to `generateObject` / `generateText` from the `ai` package. */
  model: LanguageModel;
  /** Schema used for structured generation and for the final `schema.parse` validation. */
  schema: S;
  /** User prompt sent to the model. */
  prompt: string;
  /** Provider identifier; passed with `modelId` to the OpenCode settings helpers. */
  provider: string;
  /** Model identifier; passed with `provider` to the OpenCode settings helpers. */
  modelId: string;

  /** Optional system prompt. */
  system?: string;
  /** Optional replacement for the default JSON instruction appended to the prompt on the text path. */
  jsonPromptSuffix?: string;
  /** Optional hook that reshapes the raw model output before it is validated with `schema.parse`. */
  normalize?: (value: unknown) => unknown;
}
|
||
|
||
/** Rule added to every request: all text fields must use Traditional Chinese (Taiwan usage); Simplified characters are forbidden. */
const TRADITIONAL_CHINESE_RULE = "強制:所有文字欄位必須使用繁體中文(台灣用語),絕對禁止簡體字。";

/** Default text-path suffix: "reply with JSON only, no markdown or extra explanation", followed by the language rule. */
const DEFAULT_JSON_PROMPT_SUFFIX = `\n\n只回傳 JSON,不要 markdown 或額外說明。\n\n${TRADITIONAL_CHINESE_RULE}`;

/** Extra hint for the final text attempt: "the previous format was incomplete; you must return valid JSON". */
const RETRY_JSON_SUFFIX = "\n\n上次格式不完整,請務必回傳合法 JSON。";

/**
 * Asks a language model for a value matching `options.schema`, falling back
 * between structured generation (`generateObject`) and free-text generation
 * (`generateText` + JSON extraction).
 *
 * Attempt order depends on `prefersOpenCodeTextFirst(provider, modelId)`:
 * - text-first models: text, object, then text with a "return valid JSON" hint;
 * - all others:        object, text, then text with the same hint.
 * The first attempt whose output passes `options.schema.parse` (after the
 * optional `options.normalize`) is returned.
 *
 * @param options - Model, schema, prompts and provider/model identifiers.
 * @returns The schema-validated value.
 * @throws The error from the last failed attempt when it is an `Error`;
 *   otherwise a generic `Error("結構化 AI 回傳失敗")`.
 */
export async function generateStructuredObject<S extends z.ZodTypeAny>(
  options: GenerateStructuredObjectOptions<S>
): Promise<z.infer<S>> {
  const settings = getOpenCodeGenerationSettings(options.provider, options.modelId);
  const preferText = prefersOpenCodeTextFirst(options.provider, options.modelId);
  const jsonSuffix = options.jsonPromptSuffix ?? DEFAULT_JSON_PROMPT_SUFFIX;

  // The object path always carries the language rule in the system prompt.
  // Avoid a dangling "\n\n" prefix when no caller system prompt is given.
  const objectSystem = options.system
    ? `${options.system}\n\n${TRADITIONAL_CHINESE_RULE}`
    : TRADITIONAL_CHINESE_RULE;

  /** Applies the optional normalizer, then validates against the schema (throws on mismatch). */
  function validate(value: unknown): z.infer<S> {
    return options.schema.parse(options.normalize ? options.normalize(value) : value);
  }

  /** Structured generation via `generateObject`, re-validated after normalization. */
  async function viaObject(): Promise<z.infer<S>> {
    const { object } = await generateObject({
      model: options.model,
      schema: options.schema,
      system: objectSystem,
      prompt: options.prompt,
      ...settings,
    });
    return validate(object);
  }

  /** Free-text generation; JSON is extracted from the reply and validated. */
  async function viaText(extra = ""): Promise<z.infer<S>> {
    const { text } = await generateText({
      model: options.model,
      system: options.system,
      prompt: `${options.prompt}${jsonSuffix}${extra}`,
      ...settings,
    });
    const extracted: unknown = extractJsonFromText(text);
    if (!Array.isArray(extracted) || extracted.length === 0) {
      return validate(extracted);
    }
    // Models sometimes wrap a single object in an array. Try the first element
    // first (historical behaviour). If that fails, try the whole array so that
    // array-shaped schemas are not broken by the unwrapping.
    try {
      return validate(extracted[0]);
    } catch (firstElementError) {
      try {
        return validate(extracted);
      } catch {
        // Report the original (first-element) failure, as before.
        throw firstElementError;
      }
    }
  }

  const attempts: Array<() => Promise<z.infer<S>>> = preferText
    ? [() => viaText(), () => viaObject(), () => viaText(RETRY_JSON_SUFFIX)]
    : [() => viaObject(), () => viaText(), () => viaText(RETRY_JSON_SUFFIX)];

  let lastError: unknown;
  for (const attempt of attempts) {
    try {
      // `return await` (not bare `return`) so rejections are caught here.
      return await attempt();
    } catch (error) {
      lastError = error;
    }
  }

  throw lastError instanceof Error ? lastError : new Error("結構化 AI 回傳失敗");
}
|