2026-04-01 03:26:54 +08:00

190 lines
8.0 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import { tool, Tool } from "ai";
import { z } from "zod";
import _ from "lodash";
import ResTool from "@/socket/resTool";
import u from "@/utils";
/**
 * Asset categories shared by top-level assets and their derived assets.
 * Kept in one place so the two schemas below cannot drift apart; each use
 * attaches its own `.describe()` text (which returns a new schema instance).
 */
const assetTypeSchema = z.enum(["role", "tool", "scene", "clip"]);

/**
 * A derived asset as it appears in the workspace data. It is linked to its
 * parent asset through `assetsId`.
 * NOTE(review): `id` is described as "empty when newly added" but the schema
 * is a plain (non-nullable) number — confirm which one is intended before
 * relying on either.
 */
const deriveAssetSchema = z.object({
  id: z.number().describe("衍生资产ID,如果新增则为空"),
  assetsId: z.number().describe("关联的资产ID"),
  prompt: z.string().describe("生成提示词"),
  name: z.string().describe("衍生资产名称"),
  desc: z.string().describe("衍生资产描述"),
  src: z.string().nullable().describe("衍生资产资源路径"),
  state: z.enum(["未生成", "生成中", "已完成", "生成失败"]).describe("衍生资产生成状态"),
  type: assetTypeSchema.describe("衍生资产类型"),
});

/** A top-level asset together with the list of its derived assets. */
export const assetItemSchema = z.object({
  id: z.number().describe("资产唯一标识"),
  name: z.string().describe("资产名称"),
  type: assetTypeSchema.describe("资产类型"),
  prompt: z.string().describe("生成提示词"),
  desc: z.string().describe("资产描述"),
  derive: z.array(deriveAssetSchema).describe("衍生资产列表"),
});
/**
 * One storyboard panel in the workspace data.
 * `index` may be absent or null; `src` is null until a resource exists for the panel.
 */
const storyboardSchema = z.object({
  id: z.number().describe("分镜ID必须为真实id"),
  duration: z.number().describe("持续时长(秒)"),
  prompt: z.string().describe("生成提示词"),
  associateAssetsIds: z.number().array().describe("关联资产ID列表"),
  src: z.string().nullable().describe("分镜资源路径"),
  index: z.number().nullish().describe("分镜排序字段"),
});
// NOTE(review): neither workbenchDataSchema nor posterItemSchema is referenced
// elsewhere in this module and neither is exported — confirm they are still needed.

/** Project-level workbench settings; all values are carried as strings. */
const workbenchDataSchema = z.object({
  name: z.string().describe("项目名称"),
  duration: z.string().describe("视频时长"),
  resolution: z.string().describe("分辨率"),
  fps: z.string().describe("帧率"),
  cover: z.optional(z.string()).describe("封面图片路径"),
  gradient: z.optional(z.string()).describe("渐变色配置"),
});

/** A single poster entry: its id and image path. */
const posterItemSchema = z.object({
  id: z.number().describe("海报ID"),
  image: z.string().describe("海报图片路径"),
});
/**
 * Shape of the workspace ("flow") data the client returns for `getFlowData`.
 * The property order here also fixes the order of the requestable keys and of
 * their labels, since both are derived from this shape.
 */
export const flowDataSchema = z.object({
  script: z.string().describe("剧本内容"),
  scriptPlan: z.string().describe("拍摄计划"),
  assets: assetItemSchema.array().describe("衍生资产"),
  storyboardTable: z.string().describe("分镜表"),
  storyboard: storyboardSchema.array().describe("分镜面板"),
});

/** Parsed (output) type of {@link flowDataSchema}. */
export type FlowData = z.output<typeof flowDataSchema>;
/**
 * Enum of the top-level FlowData keys, derived directly from the schema shape
 * (same order as the shape), so it stays in sync without an unchecked tuple cast.
 */
const keySchema = flowDataSchema.keyof();
/**
 * Display label for every FlowData key: the field's `.describe()` text, or the
 * key itself when a field has no description. Used in the thinking messages.
 */
const flowDataKeyLabels = Object.keys(flowDataSchema.shape).reduce(
  (labels, fieldName) => {
    const fieldSchema = flowDataSchema.shape[fieldName as keyof FlowData] as z.ZodTypeAny;
    labels[fieldName as keyof FlowData] = fieldSchema.description ?? fieldName;
    return labels;
  },
  {} as Record<keyof FlowData, string>,
);
/** Dependencies handed to the tool factory exported from this module. */
interface ToolConfig {
  /** Message handle (from `resTool.newMessage`) used to open "thinking" steps. */
  msg: ReturnType<ResTool["newMessage"]>;
  /** Session wrapper; the tools read its `socket` and `data` (projectId, scriptId). */
  resTool: ResTool;
  /** Optional allow-list of tool names; when omitted, every tool is returned. */
  toolsNames?: string[];
}
/**
 * Build the workspace tools available to the model.
 *
 * @param toolConfig - session wrapper, message handle and optional tool allow-list.
 * @returns a name → tool map; restricted to `toolsNames` when that list is given.
 */
export default (toolConfig: ToolConfig) => {
  const { resTool, toolsNames, msg } = toolConfig;
  const { socket } = resTool;

  type Thinking = ReturnType<typeof msg.thinking>;

  /**
   * Emit `event` with `payload` and resolve with whatever the client passes to
   * its acknowledgement callback.
   * NOTE(review): there is no timeout — if the client never acknowledges, the
   * returned promise never settles.
   */
  function emitWithAck<T = unknown>(event: string, payload: unknown): Promise<T> {
    return new Promise<T>((resolve) => socket.emit(event, payload, (res: T) => resolve(res)));
  }

  /** Close a thinking step as failed so the client does not keep showing it as in progress. */
  function failThinking(thinking: Thinking, title: string, detail: string): void {
    thinking.appendText(`${title}:\n${detail}`);
    thinking.updateTitle(title);
    thinking.complete();
  }

  const tools: Record<string, Tool> = {
    get_flowData: tool({
      description: "获取工作区数据",
      inputSchema: z.object({
        key: keySchema.describe("数据key"),
      }),
      execute: async ({ key }) => {
        const label = flowDataKeyLabels[key];
        const thinking = msg.thinking(`正在获取${label}工作区数据...`);
        console.log("[tools] get_flowData", key);
        const flowData = await emitWithAck<Partial<FlowData> | null | undefined>("getFlowData", { key });
        // The client may acknowledge with nothing; report "no data" (null) instead of throwing on flowData[key].
        const value = flowData?.[key] ?? null;
        thinking.appendText(`获取到${label}:\n` + JSON.stringify(value, null, 2));
        thinking.updateTitle(`获取${label}完成`);
        thinking.complete();
        return value;
      },
    }),
    add_deriveAsset: tool({
      description: "新增或更新衍生资产",
      inputSchema: z.object({
        assetsId: z.number().describe("关联的资产ID"),
        id: z.number().nullable().describe("衍生资产ID,如果新增则为空"),
        name: z.string().describe("衍生资产名称"),
        desc: z.string().describe("衍生资产描述"),
      }),
      execute: async (deriveAsset) => {
        const thinking = msg.thinking("正在操作资产...");
        const { projectId, scriptId } = resTool.data;
        const startTime = Date.now();
        try {
          const parentAssets = await u.db("o_assets").where("id", deriveAsset.assetsId).select("id", "type").first();
          if (!parentAssets) {
            failThinking(thinking, "资产操作失败", "关联的资产不存在");
            return "关联的资产不存在";
          }
          // null and 0 both mean "create new"; never insert an explicit id of 0.
          const targetId = deriveAsset.id ? deriveAsset.id : undefined;
          // A derived asset inherits its parent's type.
          const fields = {
            assetsId: deriveAsset.assetsId,
            projectId,
            name: deriveAsset.name,
            type: parentAssets.type,
            describe: deriveAsset.desc,
            startTime,
          };
          let id: number | undefined;
          if (targetId !== undefined) {
            // Only update a row that really belongs to `assetsId`, so a wrong id from the
            // model cannot overwrite an unrelated asset (or re-parent the parent onto itself).
            const existing = await u
              .db("o_assets")
              .where({ id: targetId, assetsId: deriveAsset.assetsId })
              .select("id")
              .first();
            if (!existing) {
              failThinking(thinking, "资产操作失败", "衍生资产不存在");
              return "衍生资产不存在";
            }
            await u.db("o_assets").where("id", targetId).update(fields);
            id = targetId;
            thinking.appendText(`已更新衍生资产ID: ${id}\n`);
          } else {
            const [insertedId] = await u.db("o_assets").insert(fields);
            id = insertedId;
            await u.db("o_scriptAssets").insert({ scriptId, assetId: insertedId });
            thinking.appendText(`已新增衍生资产ID: ${insertedId}\n`);
          }
          const res = await emitWithAck("addDeriveAsset", { id, ...fields });
          thinking.updateTitle("资产操作完成");
          thinking.complete();
          return res ?? "操作成功";
        } catch (e) {
          failThinking(thinking, "资产操作失败", u.error(e).message);
          throw e;
        }
      },
    }),
    del_deriveAsset: tool({
      description: "删除衍生资产",
      inputSchema: z.object({
        assetsId: z.number().describe("关联的资产ID"),
        id: z.number().describe("衍生资产ID"),
      }),
      execute: async ({ assetsId, id }) => {
        const thinking = msg.thinking("正在操作资产...");
        const { scriptId } = resTool.data;
        try {
          // Scope the delete to the given parent so a wrong id cannot remove an unrelated (or parent) asset.
          const deleted = await u.db("o_assets").where({ id, assetsId }).del();
          if (!deleted) {
            failThinking(thinking, "资产操作失败", "衍生资产不存在");
            return "衍生资产不存在";
          }
          await u.db("o_scriptAssets").where({ scriptId, assetId: id }).del();
          thinking.appendText(`已删除衍生资产ID: ${id}\n`);
          const res = await emitWithAck("delDeriveAsset", { assetsId, id });
          thinking.updateTitle("资产操作完成");
          thinking.complete();
          return res ?? "删除成功";
        } catch (e) {
          failThinking(thinking, "资产操作失败", u.error(e).message);
          throw e;
        }
      },
    }),
    generate_deriveAsset: tool({
      description: "生成衍生资产",
      inputSchema: z.object({
        ids: z.array(z.number()).describe("需要生成的 衍生资产ID"),
      }),
      execute: async ({ ids }) => {
        const thinking = msg.thinking("正在生成衍生资产...");
        // Fire-and-forget: the tool answers right away; the thinking step closes when the client acknowledges.
        void emitWithAck("generateDeriveAsset", { ids })
          .then((res) => {
            thinking.appendText(`已生成衍生资产ID: ${JSON.stringify(res, null, 2)}\n`);
            thinking.updateTitle("衍生资产生成完成");
            thinking.complete();
          })
          .catch((e) => failThinking(thinking, "衍生资产生成失败", u.error(e).message));
        return "开始生成衍生资产";
      },
    }),
    generate_storyboard: tool({
      description: "生成分镜图片",
      inputSchema: z.object({
        ids: z.array(z.number()).describe("必须获取真实的分镜ID支持批量生成"),
      }),
      execute: async ({ ids }) => {
        const thinking = msg.thinking("正在生成分镜...");
        // Fire-and-forget: the tool answers right away; the thinking step closes when the client acknowledges.
        void emitWithAck("generateStoryboard", { ids })
          .then((res) => {
            thinking.appendText("生成的分镜数据:\n" + JSON.stringify(res, null, 2));
            thinking.updateTitle("分镜生成完成");
            thinking.complete();
          })
          .catch((e) => failThinking(thinking, "分镜生成失败", u.error(e).message));
        return "开始生成分镜";
      },
    }),
  };
  return toolsNames ? Object.fromEntries(Object.entries(tools).filter(([n]) => toolsNames.includes(n))) : tools;
};