import u from "@/utils";
import { generateText, streamText, Output, stepCountIs, ModelMessage, LanguageModel, Tool, GenerateTextResult } from "ai";
import { wrapLanguageModel } from "ai";
import { devToolsMiddleware } from "@ai-sdk/devtools";
import { parse } from "best-effort-json-parser";
import modelList from "./modelList";
import { z } from "zod";

/**
 * Request description shared by `ai.invoke` and `ai.stream`.
 *
 * `output` is a zod *shape* (field name -> zod type), not a finished schema;
 * it is wrapped with `z.object(...)` when the request is built.
 */
interface AIInput<T extends Record<string, z.ZodType> | undefined = undefined> {
  system?: string;
  tools?: Record<string, Tool>;
  maxStep?: number;
  output?: T;
  prompt?: string;
  messages?: Array<ModelMessage>;
}

/** Per-call model override; missing fields are filled from `u.getConfig("text")`. */
interface AIConfig {
  model?: string;
  apiKey?: string;
  baseURL?: string;
}

/**
 * Renders the instruction appended to the system prompt for models that cannot
 * take a native structured-output schema (`responseFormat === "object"`).
 */
const buildJsonSchemaPrompt = (shape: Record<string, z.ZodType>): string =>
  `\n请按照以下 JSON Schema 格式返回结果:\n${JSON.stringify(
    z.toJSONSchema(z.object(shape)),
    null,
    2,
  )}\n只返回结果,不要将Schema返回。`;

/**
 * Returns the first top-level `{...}` span found in `text`.
 *
 * This is a single linear scan that tracks string literals, so braces inside
 * string values do not affect nesting depth and arbitrarily deep nesting is
 * handled. When the braces never balance (for example the model output was
 * cut off), the remainder starting at the first `{` is returned so that a
 * tolerant parser can still recover a partial object.
 *
 * @returns the candidate JSON text, or `undefined` when `text` has no `{`.
 */
const extractFirstJsonObject = (text: string): string | undefined => {
  const start = text.indexOf("{");
  if (start === -1) return undefined;

  let depth = 0;
  let inString = false;
  let escaped = false;
  for (let i = start; i < text.length; i++) {
    const ch = text[i];
    if (inString) {
      if (escaped) escaped = false;
      else if (ch === "\\") escaped = true;
      else if (ch === '"') inString = false;
      continue;
    }
    if (ch === '"') {
      inString = true;
    } else if (ch === "{") {
      depth++;
    } else if (ch === "}") {
      depth--;
      if (depth === 0) return text.slice(start, i + 1);
    }
  }
  // Unbalanced: hand the tail to the best-effort parser instead of dropping it.
  return text.slice(start);
};

/**
 * Resolves the model and assembles the argument object for `generateText` /
 * `streamText`.
 *
 * The caller's `input` is never modified: for the `"object"` response format
 * the JSON-Schema instruction is appended to a local copy of the system
 * prompt, so reusing the same input object across calls does not accumulate
 * repeated instructions.
 *
 * @param input  request description (prompt or messages, tools, output shape)
 * @param config optional model override; incomplete configs are completed
 *               from `u.getConfig("text")`
 * @returns `config` to pass to the SDK call and the model's `responseFormat`
 * @throws Error when the resolved model name is not present in `modelList`
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- the output shape is validated by the public `ai` signature
const buildOptions = async (input: AIInput<any>, config: AIConfig) => {
  let sqlTextModelConfig = {};
  if (!config || !config?.model || !config?.apiKey || !config?.baseURL) sqlTextModelConfig = await u.getConfig("text");
  const { model, apiKey, baseURL } = { ...sqlTextModelConfig, ...config };

  const owned = modelList.find((m) => m.model === model);
  if (!owned) throw new Error("不支持的模型或厂商");

  const modelInstance = owned.instance({ apiKey, baseURL });

  // Default step budget: five steps per registered tool when the caller gave none.
  const maxStep = input.maxStep ?? (input.tools ? Object.keys(input.tools).length * 5 : undefined);

  let system = input.system;
  // eslint-disable-next-line @typescript-eslint/no-explicit-any -- SDK output spec type depends on the schema generic
  let output: any = null;
  if (input.output) {
    if (owned.responseFormat === "schema") {
      output = Output.object({ schema: z.object(input.output) });
    } else if (owned.responseFormat === "object") {
      // No native structured output: describe the schema in the system prompt
      // and let `ai.invoke` pull the JSON out of the free-text reply.
      system = (system ?? "") + buildJsonSchemaPrompt(input.output);
    }
  }

  return {
    config: {
      model:
        process.env.NODE_ENV === "dev"
          ? wrapLanguageModel({
              // eslint-disable-next-line @typescript-eslint/no-explicit-any
              model: modelInstance(model!) as any,
              middleware: devToolsMiddleware(),
            })
          : (modelInstance(model!) as LanguageModel),
      ...(system && { system }),
      ...(input.prompt ? { prompt: input.prompt } : { messages: input.messages! }),
      // Tools are only forwarded when the model entry declares tool support.
      ...(input.tools && owned.tool && { tools: input.tools }),
      ...(maxStep && { stopWhen: stepCountIs(maxStep) }),
      ...(output && { output }),
    },
    responseFormat: owned.responseFormat,
  };
};

/** Result type of `ai.invoke`: the inferred object when an output shape is given, otherwise the raw SDK result. */
type InferOutput<T> = T extends Record<string, z.ZodType>
  ? z.infer<z.ZodObject<T>>
  : GenerateTextResult<Record<string, Tool>, never>;

const ai = Object.create({}) as {
  invoke<T extends Record<string, z.ZodType> | undefined = undefined>(
    input: AIInput<T>,
    config?: AIConfig,
  ): Promise<InferOutput<T>>;
  stream(input: AIInput, config?: AIConfig): Promise<ReturnType<typeof streamText>>;
};

/**
 * Runs a non-streaming generation.
 *
 * - With an output shape and `responseFormat === "object"`: returns the first
 *   JSON object found in the reply, parsed tolerantly; `undefined` when the
 *   reply contains no `{`.
 * - With an output shape and `responseFormat === "schema"`: returns
 *   `JSON.parse(result.text)`.
 * - Otherwise: returns the SDK result unchanged.
 */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- implementation signature; callers use the typed declaration above
ai.invoke = async (input: AIInput<any>, config: AIConfig = {}) => {
  const options = await buildOptions(input, config);
  const result = await generateText(options.config);

  if (options.responseFormat === "object" && input.output) {
    const jsonText = extractFirstJsonObject(result.text);
    return jsonText === undefined ? undefined : parse(jsonText);
  }
  if (options.responseFormat === "schema" && input.output) {
    return JSON.parse(result.text);
  }
  return result;
};

/** Starts a streaming generation and returns the SDK stream result. */
// eslint-disable-next-line @typescript-eslint/no-explicit-any -- implementation signature; callers use the typed declaration above
ai.stream = async (input: AIInput<any>, config: AIConfig = {}) => {
  const options = await buildOptions(input, config);
  return streamText(options.config);
};

export default ai;