250 lines
9.4 KiB
TypeScript
250 lines
9.4 KiB
TypeScript
import { generateText, streamText, wrapLanguageModel, stepCountIs, extractReasoningMiddleware } from "ai";
|
|
import { devToolsMiddleware } from "@ai-sdk/devtools";
|
|
import axios from "axios";
|
|
import { transform } from "sucrase";
|
|
import u from "@/utils";
|
|
|
|
/** Agent keys whose concrete model is looked up in the `o_agentDeploy` table by `resolveModelName`. */
type AiType =
  | "scriptAgent"
  | "productionAgent"
  | "universalAi";
|
|
/** Names of the request functions that `getVendorTemplateFn` looks up on an evaluated vendor template. */
type FnName =
  | "textRequest"
  | "imageRequest"
  | "videoRequest"
  | "ttsRequest";
|
|
|
|
/** Runtime list of every `AiType` key; used to tell agent keys apart from explicit model identifiers. */
const AiTypeValues: AiType[] = [
  "scriptAgent",
  "productionAgent",
  "universalAi",
];
|
|
/**
 * Resolves a model key to a concrete `vendorId:modelName` identifier.
 *
 * When `value` is one of the agent keys listed in `AiTypeValues`, the model
 * name is read from the row of the `o_agentDeploy` table whose `key` column
 * equals `value`. Any other value is returned unchanged.
 *
 * @param value - An agent key or an explicit `vendorId:modelName` identifier.
 * @returns The resolved `vendorId:modelName` identifier.
 * @throws Error when no row exists for the agent key or its `modelName` is empty.
 */
async function resolveModelName(value: AiType | `${string}:${string}`): Promise<`${string}:${string}`> {
  // Anything that is not a known agent key is already an explicit identifier.
  if (!AiTypeValues.includes(value as AiType)) {
    return value as `${string}:${string}`;
  }

  const agentDeployData = await u.db("o_agentDeploy").where("key", value).first();
  if (!agentDeployData?.modelName) throw new Error(`${value}模型未配置`);

  // The assertion matches the declared return type (the id part is not required to be numeric).
  return agentDeployData.modelName as `${string}:${string}`;
}
|
|
|
|
/**
 * Loads a vendor's template code and returns a callable bound to the selected model.
 *
 * Steps, in order:
 * 1. Splits `modelName` into the vendor id (before the first colon) and the
 *    model name (everything after it).
 * 2. Reads the vendor row from `o_vendorConfig` and the vendor's model list.
 * 3. Strips TypeScript syntax from the vendor code with sucrase and hands the
 *    resulting JavaScript to `u.vm`.
 * 4. If the evaluated template exposes `vendor`, merges the stored
 *    `inputValues` JSON into `vendor.inputValues` and sets `vendor.models`.
 * 5. Wraps the requested function so callers do not pass the model themselves.
 *
 * @param fnName - Which template function to bind.
 * @param modelName - Identifier in `vendorId:modelName` form.
 * @returns For `"textRequest"`: `(think?, thinkLevel?) => fn(model, think, thinkLevel)`,
 *   where `think` defaults to the truthiness of the model's own `think` field and
 *   `thinkLevel` defaults to `0`. For every other name: `(input) => fn(input, model)`.
 * @throws Error when the vendor config, the model, or the requested function is missing.
 */
async function getVendorTemplateFn(
  fnName: "textRequest",
  modelName: `${string}:${string}`,
): Promise<(think?: boolean, thinkLevel?: 0 | 1 | 2 | 3) => any>;
async function getVendorTemplateFn(fnName: Exclude<FnName, "textRequest">, modelName: `${string}:${string}`): Promise<(input: any) => any>;
async function getVendorTemplateFn(fnName: FnName, modelName: `${string}:${string}`): Promise<any> {
  // Split on the FIRST colon only: a model name may itself contain colons
  // (e.g. "7:llama3:8b"), and a plain two-way split(":") would truncate it.
  const separatorIndex = modelName.indexOf(":");
  const id = separatorIndex === -1 ? modelName : modelName.slice(0, separatorIndex);
  const name = separatorIndex === -1 ? "" : modelName.slice(separatorIndex + 1);

  const vendorConfigData = await u.db("o_vendorConfig").where("id", id).first();
  if (!vendorConfigData) throw new Error(`未找到供应商配置 id=${id}`);

  const modelList = await u.vendor.getModelList(id);
  const selectedModel = modelList.find((i: any) => i.modelName === name);
  if (!selectedModel) throw new Error(`未找到模型 ${name} id=${id}`);

  // NOTE(review): getCode is used synchronously here; confirm it does not return a Promise.
  const code = u.vendor.getCode(id);
  const jsCode = transform(code, { transforms: ["typescript"] }).code;
  const running = u.vm(jsCode);

  if (running.vendor) {
    // Overlay the persisted per-vendor settings on top of the template's own values.
    Object.assign(running.vendor.inputValues, JSON.parse(vendorConfigData.inputValues ?? "{}"));
    running.vendor.models = modelList;
  }

  const fn = running[fnName];
  if (!fn) throw new Error(`未找到供应商配置中的函数 ${fnName} id=${id}`);

  if (fnName === "textRequest") {
    return (think?: boolean, thinkLevel: 0 | 1 | 2 | 3 = 0) => {
      // An explicit `think` argument wins; otherwise fall back to the model's own flag.
      const effectiveThink = think ?? !!selectedModel.think;
      return fn(selectedModel, effectiveThink, thinkLevel);
    };
  }

  return <T>(input: T) => fn(input, selectedModel);
}
|
|
|
|
/**
 * Runs `fn` while tracking it as a task created through `u.task`.
 *
 * The model key is resolved first, then a task is created with the project
 * id, the task class, the model name and `{ describe, content: relatedObjects }`.
 * `u.task` yields a reporter callback: it is called with `1` once `fn`
 * resolves, and with `-1` plus the error message when anything inside the
 * guarded section throws. The original error is always rethrown.
 *
 * @param modelKey - Agent key or explicit `vendorId:modelName` identifier.
 * @param taskClass - Task category stored on the task.
 * @param describe - Task description stored on the task.
 * @param relatedObjects - Stored as the task's `content`.
 * @param projectId - Project the task belongs to.
 * @param fn - Work to run; invoked as `fn(resolvedModelName, false, 0)`.
 * @returns Whatever `fn` resolves to.
 */
async function withTaskRecord<T>(
  modelKey: AiType | `${string}:${string}`,
  taskClass: string,
  describe: string,
  relatedObjects: string,
  projectId: number,
  fn: (modelName: `${string}:${string}`, think: boolean, thinkLevel: 0 | 1 | 2 | 3) => Promise<T>,
): Promise<T> {
  const modelName = await resolveModelName(modelKey);

  // Everything after the FIRST colon is the model name; a model name may
  // itself contain colons, so a two-way split(":") would truncate it.
  const model = modelName.slice(modelName.indexOf(":") + 1);

  const taskRecord = await u.task(projectId, taskClass, model, { describe: describe, content: relatedObjects });

  try {
    const result = await fn(modelName, false, 0);
    taskRecord(1);
    return result;
  } catch (e) {
    taskRecord(-1, u.error(e).message);
    throw e;
  }
}
|
|
|
|
/**
 * Downloads `url` as binary data and returns it base64-encoded.
 *
 * Up to `retries` attempts are made. After a failed attempt that is not the
 * last one, the function waits `delay * attempt` milliseconds before trying
 * again. The error of the final attempt is rethrown as-is.
 *
 * @param url - Address to download.
 * @param retries - Maximum number of attempts (default 3).
 * @param delay - Base wait in milliseconds, multiplied by the attempt number (default 1000).
 * @returns The downloaded bytes as a base64 string (no data-URL prefix is added).
 * @throws The last download error, or `Error("urlToBase64 failed")` when no attempt returned or rethrew (e.g. `retries` is 0).
 */
async function urlToBase64(url: string, retries = 3, delay = 1000): Promise<string> {
  let attempt = 1;

  while (attempt <= retries) {
    try {
      const response = await axios.get(url, { responseType: "arraybuffer" });
      return Buffer.from(response.data).toString("base64");
    } catch (error) {
      const isFinalAttempt = attempt === retries;
      if (isFinalAttempt) throw error;

      // Linear back-off: wait longer after each failed attempt.
      const waitMs = delay * attempt;
      await new Promise((wake) => setTimeout(wake, waitMs));
    }
    attempt += 1;
  }

  throw new Error("urlToBase64 failed");
}
|
|
/**
 * Text-generation handle bound to an agent key or an explicit
 * `vendorId:modelName` identifier.
 *
 * Both `invoke` and `stream` resolve the model on every call, read the
 * `switchAiDevTool` row from `o_setting`, and add the devtools middleware
 * when that row's value is `"1"`. `stream` additionally always applies
 * `extractReasoningMiddleware` with the `reasoning_content` tag.
 */
class AiText {
  private AiType: AiType | `${string}:${string}`;
  private think?: boolean;
  private thinkLevel: 0 | 1 | 2 | 3;

  /**
   * @param AiType - Agent key or explicit `vendorId:modelName` identifier.
   * @param think - Forwarded to the vendor's `textRequest`; when omitted the model's own setting is used.
   * @param thinkLevel - Forwarded to the vendor's `textRequest` (default 0).
   */
  constructor(AiType: AiType | `${string}:${string}`, think?: boolean, thinkLevel: 0 | 1 | 2 | 3 = 0) {
    this.AiType = AiType;
    this.think = think;
    this.thinkLevel = thinkLevel;
  }

  /**
   * Runs `generateText` with the resolved model.
   *
   * When `input.tools` is set, a default `stopWhen` of 50 steps per tool is
   * applied; `input` is spread afterwards, so a caller-supplied `stopWhen`
   * takes precedence.
   *
   * @param input - `generateText` options without `model`.
   * @returns The promise returned by `generateText`.
   */
  async invoke(input: Omit<Parameters<typeof generateText>[0], "model">) {
    const switchAiDevTool = await u.db("o_setting").where("key", "switchAiDevTool").first();
    const modelName = await resolveModelName(this.AiType);
    const sdkFn = await getVendorTemplateFn("textRequest", modelName);

    // Build the vendor model once; both branches below use the same instance.
    const baseModel = await sdkFn(this.think, this.thinkLevel);
    const model =
      switchAiDevTool?.value === "1"
        ? wrapLanguageModel({
            model: baseModel,
            middleware: devToolsMiddleware(),
          })
        : baseModel;

    return generateText({
      ...(input.tools && { stopWhen: stepCountIs(Object.keys(input.tools).length * 50) }),
      ...input,
      model,
    } as Parameters<typeof generateText>[0]);
  }

  /**
   * Runs `streamText` with the resolved model.
   *
   * The same default `stopWhen` rule as `invoke` applies. The model is always
   * wrapped with `extractReasoningMiddleware` (`reasoning_content` tag); the
   * devtools middleware is placed in front of it when enabled.
   *
   * @param input - `streamText` options without `model`.
   * @returns The value returned by `streamText`.
   */
  async stream(input: Omit<Parameters<typeof streamText>[0], "model">) {
    const switchAiDevTool = await u.db("o_setting").where("key", "switchAiDevTool").first();
    const modelName = await resolveModelName(this.AiType);
    const sdkFn = await getVendorTemplateFn("textRequest", modelName);

    // Await the vendor function exactly like `invoke` does: if a vendor's
    // textRequest is async, passing its Promise to wrapLanguageModel would break.
    const baseModel = await sdkFn(this.think, this.thinkLevel);
    const reasoning = extractReasoningMiddleware({
      tagName: "reasoning_content",
    });

    return streamText({
      ...(input.tools && { stopWhen: stepCountIs(Object.keys(input.tools).length * 50) }),
      ...input,
      model: wrapLanguageModel({
        model: baseModel,
        middleware: switchAiDevTool?.value === "1" ? [devToolsMiddleware(), reasoning] : reasoning,
      }),
    } as Parameters<typeof streamText>[0]);
  }
}
|
|
|
|
/**
 * One reference item attached to a generation request, tagged by media type.
 * `base64` presumably carries the base64-encoded media payload — confirm with the vendor templates.
 */
type ReferenceList =
  | { type: "image"; base64: string }
  | { type: "audio"; base64: string }
  | { type: "video"; base64: string };
|
|
|
|
/** Input accepted by `AiImage.run` and forwarded unchanged to the vendor's `imageRequest` function. */
interface ImageConfig {
  /** Prompt text for the request. */
  prompt: string;
  /** Optional reference items; only entries tagged `type: "image"` are allowed. */
  referenceList?: Array<Extract<ReferenceList, { type: "image" }>>;
  /** One of the three supported size labels. */
  size: "1K" | "2K" | "4K";
  /** Two numbers separated by a colon, e.g. `16:9`. */
  aspectRatio: `${number}:${number}`;
}
|
|
|
|
/** Task metadata a caller may pass to `run` so the request is tracked through `withTaskRecord`. */
interface TaskRecord {
  /** Task category. */
  taskClass: string;
  /** Task description. */
  describe: string;
  /** Information about related objects, kept for later analysis and tracing. */
  relatedObjects: string;
  /** Project ID. */
  projectId: number;
}
|
|
|
|
/**
 * Image-generation handle bound to one `vendorId:modelName` key.
 * `run` produces a result that is kept on the instance; `save` writes it out.
 */
class AiImage {
  private key: `${string}:${string}`;
  private result: string = "";

  /** @param key - Model identifier in `vendorId:modelName` form. */
  constructor(key: `${string}:${string}`) {
    this.key = key;
  }

  /**
   * Calls the vendor's `imageRequest` with `input` and stores what it returns.
   * A result that starts with `"http"` is treated as a URL: it is downloaded
   * and replaced by its base64 encoding.
   *
   * @param input - Image request options, forwarded as-is.
   * @param taskRecord - When given, the request runs inside `withTaskRecord`.
   * @returns This instance, for chaining.
   */
  async run(input: ImageConfig, taskRecord?: TaskRecord) {
    const modelName = await resolveModelName(this.key);

    const generate = async (resolvedName: `${string}:${string}`) => {
      const requestImage = await getVendorTemplateFn("imageRequest", resolvedName);
      this.result = await requestImage(input);

      const isRemoteUrl = this.result.startsWith("http");
      if (isRemoteUrl) {
        this.result = await urlToBase64(this.result);
      }
      return this;
    };

    // Untracked path: execute directly with the already-resolved model.
    if (!taskRecord) return generate(modelName);

    const { taskClass, describe, relatedObjects, projectId } = taskRecord;
    return withTaskRecord(this.key, taskClass, describe, relatedObjects, projectId, generate);
  }

  /**
   * Writes the stored result to `path` through `u.oss.writeFile`.
   *
   * @param path - Destination passed to `u.oss.writeFile`.
   * @returns This instance, for chaining.
   */
  async save(path: string) {
    const payload = this.result;
    await u.oss.writeFile(path, payload);
    return this;
  }
}
|
|
|
|
/** Input modes for video generation; `VideoConfig.mode` holds a list of these. */
type VideoMode =
  // Single reference image.
  | "singleImage"
  // First and last frame (both are required).
  | "startEndRequired"
  // First and last frame (last frame is optional).
  | "endFrameOptional"
  // First and last frame (first frame is optional).
  | "startFrameOptional"
  // Text only.
  | "text"
  // Multiple references (the number is the allowed count for that kind).
  | Array<`videoReference:${number}` | `imageReference:${number}` | `audioReference:${number}`>;
|
|
|
|
/** Input accepted by `AiVideo.run` and forwarded unchanged to the vendor's `videoRequest` function. */
interface VideoConfig {
  /** Clip duration; the unit is not visible here (presumably seconds — confirm). */
  duration: number;
  /** Resolution label; accepted values are vendor-defined as far as this file shows. */
  resolution: string;
  /** Landscape (`16:9`) or portrait (`9:16`). */
  aspectRatio: "16:9" | "9:16";
  /** Prompt text for the request. */
  prompt: string;
  /** Optional reference items of any media type. */
  referenceList?: Array<ReferenceList>;
  /** Optional audio flag (presumably whether audio is generated — confirm). */
  audio?: boolean;
  /** List of input modes. */
  mode: Array<VideoMode>;
}
|
|
|
|
/**
 * Video-generation handle bound to one `vendorId:modelName` key.
 * `run` produces a result that is kept on the instance; `save` writes it out.
 */
class AiVideo {
  private key: `${string}:${string}`;
  private result: string = "";

  /** @param key - Model identifier in `vendorId:modelName` form. */
  constructor(key: `${string}:${string}`) {
    this.key = key;
  }

  /**
   * Calls the vendor's `videoRequest` with `input` and stores what it returns.
   * A result that starts with `"http"` is treated as a URL: it is downloaded
   * and replaced by its base64 encoding.
   *
   * @param input - Video request options, forwarded as-is.
   * @param taskRecord - When given, the request runs inside `withTaskRecord`.
   * @returns This instance, for chaining.
   */
  async run(input: VideoConfig, taskRecord?: TaskRecord) {
    const modelName = await resolveModelName(this.key);

    const render = async (resolvedName: `${string}:${string}`) => {
      const requestVideo = await getVendorTemplateFn("videoRequest", resolvedName);
      this.result = await requestVideo(input);

      const isRemoteUrl = this.result.startsWith("http");
      if (isRemoteUrl) {
        this.result = await urlToBase64(this.result);
      }
      return this;
    };

    // Untracked path: execute directly with the already-resolved model.
    if (!taskRecord) return render(modelName);

    const { taskClass, describe, relatedObjects, projectId } = taskRecord;
    return withTaskRecord(this.key, taskClass, describe, relatedObjects, projectId, render);
  }

  /**
   * Writes the stored result to `path` through `u.oss.writeFile`.
   *
   * @param path - Destination passed to `u.oss.writeFile`.
   * @returns This instance, for chaining.
   */
  async save(path: string) {
    const payload = this.result;
    await u.oss.writeFile(path, payload);
    return this;
  }
}
|
|
/**
 * Audio (TTS) handle bound to one `vendorId:modelName` key.
 * `run` produces a result that is kept on the instance; `save` writes it out.
 */
class AiAudio {
  private key: `${string}:${string}`;
  private result: string = "";

  /** @param key - Model identifier in `vendorId:modelName` form. */
  constructor(key: `${string}:${string}`) {
    this.key = key;
  }

  /**
   * Calls the vendor's `ttsRequest` with `input` and stores what it returns.
   * A result that starts with `"http"` is treated as a URL: it is downloaded
   * and replaced by its base64 encoding.
   *
   * NOTE(review): `input` is typed as `VideoConfig` although the call goes to
   * `ttsRequest` — confirm whether a dedicated TTS input type was intended.
   *
   * @param input - Request options, forwarded as-is.
   * @param taskRecord - When given, the request runs inside `withTaskRecord`.
   * @returns This instance, for chaining.
   */
  async run(input: VideoConfig, taskRecord?: TaskRecord) {
    const modelName = await resolveModelName(this.key);

    const synthesize = async (resolvedName: `${string}:${string}`) => {
      const requestSpeech = await getVendorTemplateFn("ttsRequest", resolvedName);
      this.result = await requestSpeech(input);

      const isRemoteUrl = this.result.startsWith("http");
      if (isRemoteUrl) {
        this.result = await urlToBase64(this.result);
      }
      return this;
    };

    // Untracked path: execute directly with the already-resolved model.
    if (!taskRecord) return synthesize(modelName);

    const { taskClass, describe, relatedObjects, projectId } = taskRecord;
    return withTaskRecord(this.key, taskClass, describe, relatedObjects, projectId, synthesize);
  }

  /**
   * Writes the stored result to `path` through `u.oss.writeFile`.
   *
   * @param path - Destination passed to `u.oss.writeFile`.
   * @returns This instance, for chaining.
   */
  async save(path: string) {
    const payload = this.result;
    await u.oss.writeFile(path, payload);
    return this;
  }
}
|
|
|
|
/**
 * Module entry point: one factory per modality. Each call returns a fresh
 * handle (`AiText`, `AiImage`, `AiVideo`, `AiAudio`) bound to the given key.
 */
export default {
  /** Creates a text handle for an agent key or an explicit `vendorId:modelName` identifier. */
  Text(AiType: AiType | `${string}:${string}`, think?: boolean, thinkLevel?: 0 | 1 | 2 | 3) {
    return new AiText(AiType, think, thinkLevel);
  },

  /** Creates an image handle for a `vendorId:modelName` identifier. */
  Image(key: `${string}:${string}`) {
    return new AiImage(key);
  },

  /** Creates a video handle for a `vendorId:modelName` identifier. */
  Video(key: `${string}:${string}`) {
    return new AiVideo(key);
  },

  /** Creates an audio handle for a `vendorId:modelName` identifier. */
  Audio(key: `${string}:${string}`) {
    return new AiAudio(key);
  },
};
|