diff --git a/src/utils/ai.ts b/src/utils/ai.ts index 87db692..148dcff 100644 --- a/src/utils/ai.ts +++ b/src/utils/ai.ts @@ -4,25 +4,26 @@ import axios from "axios"; import { transform } from "sucrase"; import u from "@/utils"; -type AiType = "scriptAgent" | "productionAgent" | "assetsAi" | "polishingAi" | "ttsDubbing" | "eventExtractAi"; +type AiType = "scriptAgent" | "productionAgent" | "assetsAi" | "polishingAi" | "eventExtractAi" | "ttsDubbing" | "test"; type FnName = "textRequest" | "imageRequest" | "videoRequest" | "ttsRequest"; -const AiTypeValues: AiType[] = ["scriptAgent", "productionAgent", "assetsAi", "polishingAi", "ttsDubbing", "eventExtractAi"]; -async function getVendorTemplateFn(fnName: FnName, value: AiType | `${number}:${string}`) { - let id, modelName; - const isAgent = AiTypeValues.includes(value as AiType); - if (isAgent) { +const AiTypeValues: AiType[] = ["scriptAgent", "productionAgent", "assetsAi", "polishingAi", "eventExtractAi", "ttsDubbing"]; +async function resolveModelName(value: AiType | `${number}:${string}`): Promise<`${number}:${string}`> { + if (AiTypeValues.includes(value as AiType)) { const agentDeployData = await u.db("o_agentDeploy").where("key", value).first(); if (!agentDeployData?.modelName) throw new Error(`${value}模型未配置`); - [id, modelName] = agentDeployData.modelName.split(":"); - } else { - [id, modelName] = value.split(":"); + return agentDeployData.modelName as `${number}:${string}`; } + return value as `${number}:${string}`; +} + +async function getVendorTemplateFn(fnName: FnName, modelName: `${number}:${string}`) { + const [id, name] = modelName.split(":"); const vendorConfigData = await u.db("o_vendorConfig").where("id", id).first(); if (!vendorConfigData) throw new Error(`未找到供应商配置 id=${id}`); const modelList = JSON.parse(vendorConfigData.models ?? "[]"); - const selectedModel = modelList.find((i: any) => i.modelName == modelName); - if (!selectedModel) throw new Error(`未找到模型 ${modelName} id=${id}`); + const selectedModel = modelList.find((i: any) => i.modelName == name); + if (!selectedModel) throw new Error(`未找到模型 ${name} id=${id}`); const jsCode = transform(vendorConfigData.code!, { transforms: ["typescript"] }).code; const fn = u.vm(jsCode)[fnName]; if (!fn) throw new Error(`未找到供应商配置中的函数 ${fnName} id=${id}`); @@ -30,6 +31,23 @@ async function getVendorTemplateFn(fnName: FnName, value: AiType | `${number}:${ else return (input: T) => fn(input, selectedModel); } +async function withTaskRecord( + modelKey: AiType | `${number}:${string}`, + taskClass: string, + fn: (modelName: `${number}:${string}`) => Promise, +): Promise { + const modelName = await resolveModelName(modelKey); + const taskRecord = await u.task(1, taskClass, modelName, { describe: "", content: "" }); + try { + const result = await fn(modelName); + taskRecord(1); + return result; + } catch (e) { + taskRecord(-1, u.error(e).message); + throw e; + } +} + async function urlToBase64(url: string): Promise { const res = await axios.get(url, { responseType: "arraybuffer" }); const base64 = Buffer.from(res.data).toString("base64"); @@ -42,18 +60,22 @@ class AiText { this.AiType = AiType; } async invoke(input: Omit[0], "model">) { - return generateText({ - ...(input.tools && { stopWhen: stepCountIs(Object.keys(input.tools).length * 5) }), - ...input, - model: await getVendorTemplateFn("textRequest", this.AiType), - } as Parameters[0]); + return withTaskRecord(this.AiType, "TaskClass", async (modelName) => + generateText({ + ...(input.tools && { stopWhen: stepCountIs(Object.keys(input.tools).length * 5) }), + ...input, + model: await getVendorTemplateFn("textRequest", modelName), + } as Parameters[0]), + ); } async stream(input: Omit[0], "model">) { - return streamText({ - ...(input.tools && { stopWhen: stepCountIs(Object.keys(input.tools).length * 5) }), - ...input, - model: await getVendorTemplateFn("textRequest", this.AiType), - } as Parameters[0]); + return withTaskRecord(this.AiType, "TaskClass", async (modelName) => + streamText({ + ...(input.tools && { stopWhen: stepCountIs(Object.keys(input.tools).length * 5) }), + ...input, + model: await getVendorTemplateFn("textRequest", modelName), + } as Parameters[0]), + ); } } @@ -72,11 +94,12 @@ class AiImage { this.key = key; } async run(input: ImageConfig) { - const fn = await getVendorTemplateFn("imageRequest", this.key); - this.result = await fn(input); - if (this.result.startsWith("http")) this.result = await urlToBase64(this.result); - - return this; + return withTaskRecord(this.key, "TaskClass", async (modelName) => { + const fn = await getVendorTemplateFn("imageRequest", modelName); + this.result = await fn(input); + if (this.result.startsWith("http")) this.result = await urlToBase64(this.result); + return this; + }); } async save(path: string) { await u.oss.writeFile(path, this.result); @@ -90,10 +113,12 @@ class AiVideo { this.key = key; } async run(input: ImageConfig) { - const fn = await getVendorTemplateFn("videoRequest", this.key); - this.result = await fn(input); - if (this.result.startsWith("http")) this.result = await urlToBase64(this.result); - return this; + return withTaskRecord(this.key, "TaskClass", async (modelName) => { + const fn = await getVendorTemplateFn("videoRequest", modelName); + this.result = await fn(input); + if (this.result.startsWith("http")) this.result = await urlToBase64(this.result); + return this; + }); } async save(path: string) { await u.oss.writeFile(path, this.result); @@ -107,10 +132,12 @@ class AiAudio { this.key = key; } async run(input: ImageConfig) { - const fn = await getVendorTemplateFn("ttsRequest", this.key); - this.result = await fn(input); - if (this.result.startsWith("http")) this.result = await urlToBase64(this.result); - return this; + return withTaskRecord(this.key, "TaskClass", async (modelName) => { + const fn = await getVendorTemplateFn("ttsRequest", modelName); + this.result = await fn(input); + if (this.result.startsWith("http")) this.result = await urlToBase64(this.result); + return this; + }); } async save(path: string) { await u.oss.writeFile(path, this.result);