diff --git a/src/agents/productionAgent/index.ts b/src/agents/productionAgent/index.ts index 709df70..84a1757 100644 --- a/src/agents/productionAgent/index.ts +++ b/src/agents/productionAgent/index.ts @@ -123,7 +123,7 @@ function runSubAgent(parentCtx: AgentContext) { prompt: z.string().max(100).describe("交给子Agent的任务简约描述"), }), execute: async ({ agent, prompt }) => { - //todo 传入md有问题 + //todo 传入md有问题 const fn = [executionAI, supervisionAI][subAgentList.indexOf(agent)]; //运行子Agent const subTextStream = await fn({ ...parentCtx, text: prompt }); diff --git a/src/agents/scriptAgent/index.ts b/src/agents/scriptAgent/index.ts index 0f727bd..174fb8f 100644 --- a/src/agents/scriptAgent/index.ts +++ b/src/agents/scriptAgent/index.ts @@ -11,8 +11,10 @@ export interface AgentContext { socket: Socket; isolationKey: string; text: string; + userMessageTime?: number; abortSignal?: AbortSignal; resTool: ResTool; + msg: ReturnType; } function buildSystemPrompt(skillPrompt: string, mem: Awaited>): string { @@ -35,26 +37,23 @@ function buildSystemPrompt(skillPrompt: string, mem: Awaited { - console.log("%c Line:73 🍧 completion", "background:#93c0a4", completion); - await memory.add("assistant:decision", completion.text); - }, }); return textStream; @@ -83,9 +78,6 @@ export async function decisionAI(ctx: AgentContext) { export async function executionAI(ctx: AgentContext) { const { isolationKey, text, abortSignal, resTool } = ctx; - - resTool.systemMessage("执行层AI 接管聊天"); - const memory = new Memory("scriptAgent", isolationKey); const [skill, mem] = await Promise.all([useSkill("script_agent_execution.md"), memory.get(text)]); @@ -100,10 +92,6 @@ export async function executionAI(ctx: AgentContext) { ...memory.getTools(), ...useTools(ctx.resTool), }, - onFinish: async (completion) => { - console.log("%c Line:102 🍻 completion", "background:#fca650", completion); - await memory.add("assistant:execution", completion.text); - }, }); return textStream; @@ -112,8 +100,6 @@ export async function executionAI(ctx: AgentContext) { export async function supervisionAI(ctx: AgentContext) { const { isolationKey, text, abortSignal, resTool } = ctx; - resTool.systemMessage("监督层AI 接管聊天"); - const memory = new Memory("scriptAgent", isolationKey); const [skill, mem] = await Promise.all([useSkill("script_agent_supervision.md"), memory.get(text)]); @@ -127,10 +113,6 @@ export async function supervisionAI(ctx: AgentContext) { ...skill.tools, ...useTools(ctx.resTool), }, - onFinish: async (completion) => { - console.log("%c Line:129 🍣 completion", "background:#3f7cff", completion); - await memory.add("assistant:supervision", completion.text); - }, }); return textStream; @@ -138,6 +120,7 @@ export async function supervisionAI(ctx: AgentContext) { //工具函数 function runSubAgent(parentCtx: AgentContext) { + const memory = new Memory("scriptAgent", parentCtx.isolationKey); return tool({ description: "启动子Agent执行独立任务。可用子Agent:executionAI, decisionAI, supervisionAI", inputSchema: z.object({ @@ -146,17 +129,30 @@ function runSubAgent(parentCtx: AgentContext) { }), execute: async ({ agent, prompt }) => { const fn = [executionAI, supervisionAI][subAgentList.indexOf(agent)]; - //运行子Agent + + const subMsg = parentCtx.resTool.newMessage("assistant", agent == "executionAI" ? "编剧" : "编辑"); + + // 先完成主Agent当前的消息 + parentCtx.msg.complete(); + // 子Agent用新消息回复 const subTextStream = await fn({ ...parentCtx, text: prompt }); - - let msg = parentCtx.resTool.textMessage(); + let text = subMsg.text(); let fullResponse = ""; - for await (const chunk of subTextStream) { - msg.send(chunk); + text.append(chunk); fullResponse += chunk; } - msg!.end(); + text.complete(); + subMsg.complete(); + if (fullResponse.trim()) { + await memory.add(`assistant:${agent === "executionAI" ? "execution" : "supervision"}`, fullResponse, { + name: agent === "executionAI" ? "编剧" : "编辑", + createTime: new Date(subMsg.datetime).getTime(), + }); + } + + // 为主Agent后续输出创建新消息 + parentCtx.msg = parentCtx.resTool.newMessage("assistant", "统筹"); return fullResponse; }, diff --git a/src/agents/scriptAgent/tools.ts b/src/agents/scriptAgent/tools.ts index f05d736..e3bc2f6 100644 --- a/src/agents/scriptAgent/tools.ts +++ b/src/agents/scriptAgent/tools.ts @@ -38,7 +38,7 @@ export default (resTool: ResTool, toolsNames?: string[]) => { ids: z.array(z.number()).describe("章节id,注意区分"), }), execute: async ({ ids }) => { - resTool.systemMessage(`正在阅读 章节事件 数据...`); + resTool.newMessage('system').text(`正在获取章节事件,章节ID:${ids.join(",")}`).complete(); console.log("[tools] get_novel_events", ids); const data = await u .db("o_novel") @@ -55,7 +55,7 @@ export default (resTool: ResTool, toolsNames?: string[]) => { key: keySchema.describe("数据key"), }), execute: async ({ key }) => { - resTool.systemMessage(`正在阅读 ${planDataKeyLabels[key]} 数据...`); + resTool.newMessage('system').text(`正在阅读 ${planDataKeyLabels[key]} 数据...`).complete(); console.log("[tools] get_planData", key); const planData: planData = await new Promise((resolve) => socket.emit("getPlanData", { key }, (res: any) => resolve(res))); return planData[key]; @@ -67,8 +67,9 @@ export default (resTool: ResTool, toolsNames?: string[]) => { id: z.string().describe("章节id"), }), execute: async ({ id }) => { - console.log(id); - return ""; + console.log("[tools] get_novel_text", id); + const data = await u.db("o_novel").where({ id }).select("chapterData").first(); + return data && data?.chapterData ? data.chapterData : ""; }, }), set_planData_storySkeleton: tool({ @@ -76,7 +77,7 @@ export default (resTool: ResTool, toolsNames?: string[]) => { inputSchema: z.object({ value: planData.shape.storySkeleton }), execute: async ({ value }) => { console.log("[tools] set_planData storySkeleton", value); - resTool.systemMessage("正在保存 故事骨架 数据"); + resTool.newMessage('system').text("正在保存 故事骨架 数据").complete(); socket.emit("setPlanData", { key: "storySkeleton", value }); return true; }, @@ -86,7 +87,7 @@ export default (resTool: ResTool, toolsNames?: string[]) => { inputSchema: z.object({ value: planData.shape.adaptationStrategy }), execute: async ({ value }) => { console.log("[tools] set_planData adaptationStrategy", value); - resTool.systemMessage("正在保存 改编策略 数据"); + resTool.newMessage('system').text("正在保存 改编策略 数据").complete(); socket.emit("setPlanData", { key: "adaptationStrategy", value }); return true; }, diff --git a/src/lib/fixDB.ts b/src/lib/fixDB.ts index b0ba75d..0da1cd3 100644 --- a/src/lib/fixDB.ts +++ b/src/lib/fixDB.ts @@ -27,4 +27,5 @@ export default async (knex: Knex): Promise => { // memories 表新增字段 await addColumn("memories", "episodesId", "text"); await addColumn("memories", "agentType", "text"); + await addColumn("memories", "name", "text"); }; diff --git a/src/lib/initDB.ts b/src/lib/initDB.ts index 839b7a5..c2fbdf6 100644 --- a/src/lib/initDB.ts +++ b/src/lib/initDB.ts @@ -268,7 +268,9 @@ export default async (knex: Knex, forceInit: boolean = false): Promise => table.text("name"); table.text("content"); table.integer("projectId"); + table.integer("extractState"); table.integer("createTime"); + table.text("errorReason"); table.primary(["id"]); table.unique(["id"]); }, @@ -802,6 +804,7 @@ export default async (knex: Knex, forceInit: boolean = false): Promise => table.text("isolationKey").notNullable(); // 记忆隔离键 table.text("type").notNullable(); // 'message' | 'summary' table.text("role"); // 'user' | 'assistant' + table.text("name"); table.text("content").notNullable(); table.text("embedding"); // 向量嵌入 JSON table.text("relatedMessageIds"); // summary关联的message id列表 JSON diff --git a/src/router.ts b/src/router.ts index 63253f9..d2367b0 100644 --- a/src/router.ts +++ b/src/router.ts @@ -1,4 +1,4 @@ -// @routes-hash 63d067de9d3f97b0602ef91a69334bc8 +// @routes-hash cf18bce0fd0b7cd4d351d3c319cc7205 import { Express } from "express"; import route1 from "./routes/agents/clearMemory"; @@ -70,41 +70,43 @@ import route66 from "./routes/script/delScript"; import route67 from "./routes/script/exportScript"; import route68 from "./routes/script/extractAssets"; import route69 from "./routes/script/getScrptApi"; -import route70 from "./routes/script/updateScript"; -import route71 from "./routes/scriptAgent/getPlanData"; -import route72 from "./routes/scriptAgent/setPlanData"; -import route73 from "./routes/setting/about/checkUpdate"; -import route74 from "./routes/setting/agentDeploy/agentSetKey"; -import route75 from "./routes/setting/agentDeploy/deployAgentModel"; -import route76 from "./routes/setting/agentDeploy/getAgentDeploy"; -import route77 from "./routes/setting/dbConfig/clearData"; -import route78 from "./routes/setting/dev/getSwitchAiDevTool"; -import route79 from "./routes/setting/dev/updateSwitchAiDevTool"; -import route80 from "./routes/setting/fileManagement/openFolder"; -import route81 from "./routes/setting/getTextModel"; -import route82 from "./routes/setting/loginConfig/getUser"; -import route83 from "./routes/setting/loginConfig/updateUserPwd"; -import route84 from "./routes/setting/memoryConfig/delAllMemory"; -import route85 from "./routes/setting/memoryConfig/getMemory"; -import route86 from "./routes/setting/memoryConfig/sureMemory"; -import route87 from "./routes/setting/skillManagement/addSkill"; -import route88 from "./routes/setting/skillManagement/deleteSkill"; -import route89 from "./routes/setting/skillManagement/embeddingSkill"; -import route90 from "./routes/setting/skillManagement/generateDescription"; -import route91 from "./routes/setting/skillManagement/getSkillList"; -import route92 from "./routes/setting/skillManagement/scanSkills"; -import route93 from "./routes/setting/skillManagement/updateSkill"; -import route94 from "./routes/setting/vendorConfig/addVendor"; -import route95 from "./routes/setting/vendorConfig/deleteVendor"; -import route96 from "./routes/setting/vendorConfig/getVendorList"; -import route97 from "./routes/setting/vendorConfig/modelTest"; -import route98 from "./routes/setting/vendorConfig/updateCode"; -import route99 from "./routes/setting/vendorConfig/updateVendor"; -import route100 from "./routes/task/getProject"; -import route101 from "./routes/task/getTaskApi"; -import route102 from "./routes/task/getTaskCategories"; -import route103 from "./routes/task/taskDetails"; -import route104 from "./routes/test/test"; +import route70 from "./routes/script/pollScriptAssets"; +import route71 from "./routes/script/updateScript"; +import route72 from "./routes/scriptAgent/getPlanData"; +import route73 from "./routes/scriptAgent/setPlanData"; +import route74 from "./routes/setting/about/checkUpdate"; +import route75 from "./routes/setting/about/downloadApp"; +import route76 from "./routes/setting/agentDeploy/agentSetKey"; +import route77 from "./routes/setting/agentDeploy/deployAgentModel"; +import route78 from "./routes/setting/agentDeploy/getAgentDeploy"; +import route79 from "./routes/setting/dbConfig/clearData"; +import route80 from "./routes/setting/dev/getSwitchAiDevTool"; +import route81 from "./routes/setting/dev/updateSwitchAiDevTool"; +import route82 from "./routes/setting/fileManagement/openFolder"; +import route83 from "./routes/setting/getTextModel"; +import route84 from "./routes/setting/loginConfig/getUser"; +import route85 from "./routes/setting/loginConfig/updateUserPwd"; +import route86 from "./routes/setting/memoryConfig/delAllMemory"; +import route87 from "./routes/setting/memoryConfig/getMemory"; +import route88 from "./routes/setting/memoryConfig/sureMemory"; +import route89 from "./routes/setting/skillManagement/addSkill"; +import route90 from "./routes/setting/skillManagement/deleteSkill"; +import route91 from "./routes/setting/skillManagement/embeddingSkill"; +import route92 from "./routes/setting/skillManagement/generateDescription"; +import route93 from "./routes/setting/skillManagement/getSkillList"; +import route94 from "./routes/setting/skillManagement/scanSkills"; +import route95 from "./routes/setting/skillManagement/updateSkill"; +import route96 from "./routes/setting/vendorConfig/addVendor"; +import route97 from "./routes/setting/vendorConfig/deleteVendor"; +import route98 from "./routes/setting/vendorConfig/getVendorList"; +import route99 from "./routes/setting/vendorConfig/modelTest"; +import route100 from "./routes/setting/vendorConfig/updateCode"; +import route101 from "./routes/setting/vendorConfig/updateVendor"; +import route102 from "./routes/task/getProject"; +import route103 from "./routes/task/getTaskApi"; +import route104 from "./routes/task/getTaskCategories"; +import route105 from "./routes/task/taskDetails"; +import route106 from "./routes/test/test"; export default async (app: Express) => { app.use("/api/agents/clearMemory", route1); @@ -176,39 +178,41 @@ export default async (app: Express) => { app.use("/api/script/exportScript", route67); app.use("/api/script/extractAssets", route68); app.use("/api/script/getScrptApi", route69); - app.use("/api/script/updateScript", route70); - app.use("/api/scriptAgent/getPlanData", route71); - app.use("/api/scriptAgent/setPlanData", route72); - app.use("/api/setting/about/checkUpdate", route73); - app.use("/api/setting/agentDeploy/agentSetKey", route74); - app.use("/api/setting/agentDeploy/deployAgentModel", route75); - app.use("/api/setting/agentDeploy/getAgentDeploy", route76); - app.use("/api/setting/dbConfig/clearData", route77); - app.use("/api/setting/dev/getSwitchAiDevTool", route78); - app.use("/api/setting/dev/updateSwitchAiDevTool", route79); - app.use("/api/setting/fileManagement/openFolder", route80); - app.use("/api/setting/getTextModel", route81); - app.use("/api/setting/loginConfig/getUser", route82); - app.use("/api/setting/loginConfig/updateUserPwd", route83); - app.use("/api/setting/memoryConfig/delAllMemory", route84); - app.use("/api/setting/memoryConfig/getMemory", route85); - app.use("/api/setting/memoryConfig/sureMemory", route86); - app.use("/api/setting/skillManagement/addSkill", route87); - app.use("/api/setting/skillManagement/deleteSkill", route88); - app.use("/api/setting/skillManagement/embeddingSkill", route89); - app.use("/api/setting/skillManagement/generateDescription", route90); - app.use("/api/setting/skillManagement/getSkillList", route91); - app.use("/api/setting/skillManagement/scanSkills", route92); - app.use("/api/setting/skillManagement/updateSkill", route93); - app.use("/api/setting/vendorConfig/addVendor", route94); - app.use("/api/setting/vendorConfig/deleteVendor", route95); - app.use("/api/setting/vendorConfig/getVendorList", route96); - app.use("/api/setting/vendorConfig/modelTest", route97); - app.use("/api/setting/vendorConfig/updateCode", route98); - app.use("/api/setting/vendorConfig/updateVendor", route99); - app.use("/api/task/getProject", route100); - app.use("/api/task/getTaskApi", route101); - app.use("/api/task/getTaskCategories", route102); - app.use("/api/task/taskDetails", route103); - app.use("/api/test/test", route104); + app.use("/api/script/pollScriptAssets", route70); + app.use("/api/script/updateScript", route71); + app.use("/api/scriptAgent/getPlanData", route72); + app.use("/api/scriptAgent/setPlanData", route73); + app.use("/api/setting/about/checkUpdate", route74); + app.use("/api/setting/about/downloadApp", route75); + app.use("/api/setting/agentDeploy/agentSetKey", route76); + app.use("/api/setting/agentDeploy/deployAgentModel", route77); + app.use("/api/setting/agentDeploy/getAgentDeploy", route78); + app.use("/api/setting/dbConfig/clearData", route79); + app.use("/api/setting/dev/getSwitchAiDevTool", route80); + app.use("/api/setting/dev/updateSwitchAiDevTool", route81); + app.use("/api/setting/fileManagement/openFolder", route82); + app.use("/api/setting/getTextModel", route83); + app.use("/api/setting/loginConfig/getUser", route84); + app.use("/api/setting/loginConfig/updateUserPwd", route85); + app.use("/api/setting/memoryConfig/delAllMemory", route86); + app.use("/api/setting/memoryConfig/getMemory", route87); + app.use("/api/setting/memoryConfig/sureMemory", route88); + app.use("/api/setting/skillManagement/addSkill", route89); + app.use("/api/setting/skillManagement/deleteSkill", route90); + app.use("/api/setting/skillManagement/embeddingSkill", route91); + app.use("/api/setting/skillManagement/generateDescription", route92); + app.use("/api/setting/skillManagement/getSkillList", route93); + app.use("/api/setting/skillManagement/scanSkills", route94); + app.use("/api/setting/skillManagement/updateSkill", route95); + app.use("/api/setting/vendorConfig/addVendor", route96); + app.use("/api/setting/vendorConfig/deleteVendor", route97); + app.use("/api/setting/vendorConfig/getVendorList", route98); + app.use("/api/setting/vendorConfig/modelTest", route99); + app.use("/api/setting/vendorConfig/updateCode", route100); + app.use("/api/setting/vendorConfig/updateVendor", route101); + app.use("/api/task/getProject", route102); + app.use("/api/task/getTaskApi", route103); + app.use("/api/task/getTaskCategories", route104); + app.use("/api/task/taskDetails", route105); + app.use("/api/test/test", route106); } diff --git a/src/routes/agents/getMemory.ts b/src/routes/agents/getMemory.ts index 4aece69..a02cee5 100644 --- a/src/routes/agents/getMemory.ts +++ b/src/routes/agents/getMemory.ts @@ -9,11 +9,6 @@ function normalizeRole(role?: string | null): "user" | "assistant" { return role?.startsWith("assistant") ? "assistant" : "user"; } -function getAssistantName(role?: string | null): string | undefined { - if (!role?.startsWith("assistant:")) return undefined; - return role.split(":")[1] || "assistant"; -} - export default router.post( "/", validateFields({ @@ -29,12 +24,14 @@ export default router.post( .db("memories") .where({ isolationKey, type: "message" }) .orderBy("createTime", "asc") - .select("id", "role", "content", "createTime"); + .select("id", "role", "name", "content", "createTime"); const history = rows.map((row) => ({ id: row.id, role: normalizeRole(row.role), - name: getAssistantName(row.role), + name: row.name ?? undefined, + status: "complete", + datetime: new Date(row.createTime).toISOString(), content: [{ type: "markdown", status: "complete", data: row.content }], createTime: row.createTime, })); diff --git a/src/routes/assets/uploadClip.ts b/src/routes/assets/uploadClip.ts index 3890e72..271b7ab 100644 --- a/src/routes/assets/uploadClip.ts +++ b/src/routes/assets/uploadClip.ts @@ -50,7 +50,7 @@ export default router.post( filePath: savePath, type, assetsId: id, - state: '已完成', + state: "已完成", }); await u.db("o_assets").where("id", id).update({ imageId: imageId, diff --git a/src/routes/novel/event/generateEvents.ts b/src/routes/novel/event/generateEvents.ts index 96d10fc..e256d87 100644 --- a/src/routes/novel/event/generateEvents.ts +++ b/src/routes/novel/event/generateEvents.ts @@ -20,8 +20,13 @@ export default router.post( u.db("o_novel").where("projectId", projectId).whereIn("id", novelIds), Promise.resolve(new u.cleanNovel()), ]); - - await u.db("o_novel").where("projectId", projectId).update({ eventState: 0, event: null }); + if (allChapters.length === 0) { + return res.status(400).send(success("没有对应章节")); + } + if (allChapters.filter((item) => item.eventState === 0).length) { + return res.status(400).send(success("存在未完成事件,请先等待事件完成")); + } + await u.db("o_novel").where("projectId", projectId).whereIn("id", novelIds).update({ eventState: 0, event: null }); novel.emitter.on("item", async (item) => { await u .db("o_novel") diff --git a/src/routes/novel/getNovelEventState.ts b/src/routes/novel/getNovelEventState.ts index 3593c8c..bad3f0d 100644 --- a/src/routes/novel/getNovelEventState.ts +++ b/src/routes/novel/getNovelEventState.ts @@ -5,7 +5,6 @@ import { success } from "@/lib/responseFormat"; import { validateFields } from "@/middleware/middleware"; const router = express.Router(); -// 获取原文数据 export default router.post( "/", validateFields({ @@ -13,7 +12,7 @@ export default router.post( }), async (req, res) => { const { ids } = req.body; - const data = await u.db("o_novel").whereIn("id", ids).whereNot("eventState", 0).select("id", "event", "eventState"); + const data = await u.db("o_novel").whereIn("id", ids).whereNot("eventState", 0).select("id", "event", "eventState", "errorReason"); res.status(200).send(success(data)); }, ); diff --git a/src/routes/script/extractAssets.ts b/src/routes/script/extractAssets.ts index a9be363..0d5f73d 100644 --- a/src/routes/script/extractAssets.ts +++ b/src/routes/script/extractAssets.ts @@ -17,18 +17,21 @@ export const AssetSchema = z.object({ type Asset = z.infer; -/** 控制并发的辅助函数 */ -async function pMap(items: T[], fn: (item: T) => Promise, concurrency: number): Promise { - const results: R[] = []; - let index = 0; - async function worker() { - while (index < items.length) { - const i = index++; - results[i] = await fn(items[i]); - } +/** 按批次并发执行,每批 batchSize 个同时跑,批次完成后调用 onBatchDone */ +async function pMapBatch( + items: T[], + fn: (item: T) => Promise, + batchSize: number, + onBatchDone?: (batchResults: R[]) => Promise, +): Promise { + const allResults: R[] = []; + for (let i = 0; i < items.length; i += batchSize) { + const batch = items.slice(i, i + batchSize); + const batchResults = await Promise.all(batch.map(fn)); + allResults.push(...batchResults); + if (onBatchDone) await onBatchDone(batchResults); } - await Promise.all(Array.from({ length: Math.min(concurrency, items.length) }, () => worker())); - return results; + return allResults; } export default router.post( @@ -45,23 +48,94 @@ export default router.post( const intansce = u.Ai.Text("universalAgent"); const novelData = await u.db("o_novel").where("projectId", projectId).select("chapterData"); if (!novelData || novelData.length === 0) return res.status(400).send(error("请先上传小说")); - - // 每个 scriptId 对应提取出的资产列表 - const scriptAssetsMap = new Map(); - + await u.db("o_script").whereIn("id", scriptIds).update({ + extractState: 0, + }); // 构建 scriptId -> script 内容的映射 const scriptMap = new Map(scripts.map((s: o_script) => [s.id, s])); const errors: { scriptId: number; error: string }[] = []; + let successCount = 0; - // 并发提取所有剧本的资产,每个剧本单独跑一次 AI - await pMap( + // 每批提取结果:scriptId -> 资产列表 + type BatchResult = { scriptId: number; assets: Asset[] } | null; + + /** 一批剧本提取完成后统一入库并建立关联 */ + async function persistBatch(batchResults: BatchResult[]) { + const validResults = batchResults.filter((r): r is { scriptId: number; assets: Asset[] } => r !== null && r.assets.length > 0); + if (!validResults.length) return; + + // 合并本批所有资产,同名去重 + const mergedAssetsMap = new Map(); + const assetScriptIds = new Map(); + for (const { scriptId, assets } of validResults) { + for (const asset of assets) { + if (!mergedAssetsMap.has(asset.name)) { + mergedAssetsMap.set(asset.name, asset); + } + const ids = assetScriptIds.get(asset.name) || []; + ids.push(scriptId); + assetScriptIds.set(asset.name, ids); + } + } + + // 查询已有资产,避免重复插入 + const existingAssets = await u.db("o_assets").where("projectId", projectId).select("id", "name"); + const existingMap = new Map(existingAssets.map((a) => [a.name!, a.id!])); + + // 插入不存在的资产 + const toInsert = [...mergedAssetsMap.values()].filter((asset) => !existingMap.has(asset.name)); + if (toInsert.length) { + await u.db("o_assets").insert( + toInsert.map((asset) => ({ + name: asset.name, + prompt: asset.prompt, + type: asset.type, + describe: asset.desc, + projectId: projectId, + startTime: Date.now(), + })), + ); + } + + // 重新查询获取完整的 name -> id 映射 + const allAssets = await u.db("o_assets").where("projectId", projectId).select("id", "name"); + const nameToId = new Map(allAssets.map((a) => [a.name, a.id])); + + // 建立本批各 scriptId 与资产的关联 + const batchScriptIds = validResults.map((r) => r.scriptId); + const scriptAssetRows: { scriptId: number; assetId: number }[] = []; + for (const [name, sIds] of assetScriptIds) { + const assetId = nameToId.get(name); + if (assetId) { + for (const sid of sIds) { + scriptAssetRows.push({ scriptId: sid, assetId }); + } + } + } + + // 先删除本批 scriptId 的旧关联,再插入新的 + await u.db("o_scriptAssets").whereIn("scriptId", batchScriptIds).delete(); + if (scriptAssetRows.length) { + await u.db("o_scriptAssets").insert(scriptAssetRows); + } + + // 本批成功的剧本状态更新为 1(成功) + await u.db("o_script").whereIn("id", batchScriptIds).update({ + extractState: 1, + errorReason: null, + }); + } + + // 按批次并发提取剧本资产,每批完成后统一入库 + await pMapBatch( scriptIds, async (scriptId: number) => { const script = scriptMap.get(scriptId); if (!script) { errors.push({ scriptId, error: "未找到对应剧本" }); - return; + await u.db("o_script").where("id", scriptId).update({ extractState: -1, errorReason: "未找到对应剧本" }); + return null; } // 用闭包收集当前 scriptId 的资产 @@ -102,78 +176,23 @@ export default router.post( const msg = e?.message || String(e); console.error(`[extractAssets] scriptId=${scriptId} name=${script.name} 提取失败:`, msg); errors.push({ scriptId, error: script.name + ":" + u.error(e).message }); - return; + await u.db("o_script").where("id", scriptId).update({ extractState: -1, errorReason: u.error(e).message }); + return null; } if (!collected.length) { errors.push({ scriptId, error: "AI 未返回任何资产" }); - return; + await u.db("o_script").where("id", scriptId).update({ extractState: -1, errorReason: "AI 未返回任何资产" }); + return null; } - scriptAssetsMap.set(scriptId, collected); + successCount++; + return { scriptId, assets: collected }; }, concurrency, + persistBatch, ); - // 如果全部失败,直接返回错误 - if (!scriptAssetsMap.size) { - return res.status(500).send(error("所有剧本资产提取均失败", errors)); - } - - // 按 name 合并所有资产,同名资产只保留第一个 - const mergedAssetsMap = new Map(); - // 同时记录每个资产名称关联的 scriptId 列表 - const assetScriptIds = new Map(); - - for (const [scriptId, assets] of scriptAssetsMap) { - for (const asset of assets) { - if (!mergedAssetsMap.has(asset.name)) { - mergedAssetsMap.set(asset.name, asset); - } - const ids = assetScriptIds.get(asset.name) || []; - ids.push(scriptId); - assetScriptIds.set(asset.name, ids); - } - } - - // 一次性查询数据库中已有的资产 - const existingAssets = await u.db("o_assets").where("projectId", projectId).select("id", "name"); - const existingMap = new Map(existingAssets.map((a) => [a.name!, a.id!])); - - // 批量插入不存在的资产 - const toInsert = [...mergedAssetsMap.values()].filter((asset) => !existingMap.has(asset.name)); - if (toInsert.length) { - await u.db("o_assets").insert( - toInsert.map((asset) => ({ - name: asset.name, - prompt: asset.prompt, - type: asset.type, - describe: asset.desc, - projectId: projectId, - startTime: Date.now(), - })), - ); - } - - // 重新查询所有资产,获取完整的 name -> id 映射 - const allAssets = await u.db("o_assets").where("projectId", projectId).select("id", "name"); - const nameToId = new Map(allAssets.map((a) => [a.name, a.id])); - - // 批量建立 scriptId <-> assetId 的关联 - const scriptAssetRows: { scriptId: number; assetId: number }[] = []; - for (const [name, sIds] of assetScriptIds) { - const assetId = nameToId.get(name); - if (assetId) { - for (const sid of sIds) { - scriptAssetRows.push({ scriptId: sid, assetId }); - } - } - } - await u.db("o_scriptAssets").whereIn("scriptId", scriptIds).delete(); - if (scriptAssetRows.length) { - await u.db("o_scriptAssets").insert(scriptAssetRows); - } - - return res.send(success(errors.length ? `部分剧本资产提取失败\n${errors.map((i) => i.error).join("\n")}` : "资产提取完成")); + return res.send(success("开始提取资产")); }, ); diff --git a/src/routes/script/getScrptApi.ts b/src/routes/script/getScrptApi.ts index 9b253b8..e31a7ce 100644 --- a/src/routes/script/getScrptApi.ts +++ b/src/routes/script/getScrptApi.ts @@ -22,8 +22,10 @@ export default router.post( const assetsData = await u .db("o_assets") .leftJoin("o_scriptAssets", "o_assets.id", "o_scriptAssets.assetId") - // @ts-ignore - .whereIn( "o_scriptAssets.scriptId", data.map((i) => i.id)) + .whereIn( + "o_scriptAssets.scriptId", + data.map((i) => i.id!), + ) .select("o_assets.id", "o_assets.name", "o_scriptAssets.scriptId"); const scriptAssetsMap: Record = {}; assetsData.forEach((i) => { @@ -37,6 +39,8 @@ export default router.post( id: i.id, name: i.name, content: i.content, + extractState: i.extractState, + errorReason: i.errorReason, createTime: i.createTime, relatedAssets: scriptAssetsMap[i.id!] || [], })); diff --git a/src/routes/script/pollScriptAssets.ts b/src/routes/script/pollScriptAssets.ts new file mode 100644 index 0000000..76e0f93 --- /dev/null +++ b/src/routes/script/pollScriptAssets.ts @@ -0,0 +1,18 @@ +import express from "express"; +import u from "@/utils"; +import { z } from "zod"; +import { success } from "@/lib/responseFormat"; +import { validateFields } from "@/middleware/middleware"; +const router = express.Router(); + +export default router.post( + "/", + validateFields({ + ids: z.array(z.number()), + }), + async (req, res) => { + const { ids } = req.body; + const data = await u.db("o_script").whereIn("id", ids).whereNot("extractState", "生成中").select("id", "extractState", "errorReason"); + res.status(200).send(success(data)); + }, +); diff --git a/src/routes/setting/about/downloadApp.ts b/src/routes/setting/about/downloadApp.ts new file mode 100644 index 0000000..5b146c1 --- /dev/null +++ b/src/routes/setting/about/downloadApp.ts @@ -0,0 +1,238 @@ +import express from "express"; +import { success, error } from "@/lib/responseFormat"; +import getPath from "@/utils/getPath"; +import z from "zod"; +import fs from "fs"; +import path from "path"; +import axios from "axios"; +import compressing from "compressing"; +import { validateFields } from "@/middleware/middleware"; +import { spawn } from "child_process"; + +const router = express.Router(); + +/** 仓库源配置 */ +const REPO_SOURCES = { + github: { + repo: "HBAI-Ltd/Toonflow-app", + api: "https://api.github.com/repos/HBAI-Ltd/Toonflow-app/releases/latest", + headers: { Accept: "application/vnd.github.v3+json" }, + }, + gitee: { + repo: "HBAI-Ltd/Toonflow-app", + api: "https://gitee.com/api/v5/repos/HBAI-Ltd/Toonflow-app/releases/latest", + headers: {}, + }, +} as const; + +type SourceType = keyof typeof REPO_SOURCES; + +function normalizeAssets(source: SourceType, release: any): { name: string; browser_download_url: string }[] { + if (source === "github") { + return (release.assets ?? []).map((a: any) => ({ + name: a.name, + browser_download_url: a.browser_download_url, + })); + } + return (release.assets ?? []).map((a: any) => ({ + name: a.name, + browser_download_url: a.browser_download_url, + })); +} + +/** 获取当前系统平台和架构标识,用于匹配安装包文件名 */ +function getPlatformArch(): { platform: string; arch: string } { + const platform = process.platform === "win32" ? "win" : process.platform === "darwin" ? "mac" : "linux"; + const arch = process.arch === "arm64" ? "arm64" : "x64"; + return { platform, arch }; +} + +/** 匹配安装包资产(.exe / .dmg / .AppImage / .portable.exe) */ +function findInstallerAsset(assets: any[]): any | null { + const { platform, arch } = getPlatformArch(); + const installerExtensions: Record = { + win: [".exe"], + mac: [".dmg"], + linux: [".AppImage"], + }; + const exts = installerExtensions[platform] || [".exe"]; + // 优先找 nsis 安装包(排除 portable),如果没有再找 portable + return ( + assets.find( + (a: any) => + exts.some((ext) => a.name.endsWith(ext)) && + a.name.includes(arch) && + !a.name.toLowerCase().includes("portable") && + !a.name.endsWith(".blockmap"), + ) ?? + assets.find((a: any) => exts.some((ext) => a.name.endsWith(ext)) && a.name.includes(arch) && !a.name.endsWith(".blockmap")) ?? + null + ); +} + +/** + * 下载文件到指定路径(支持流式写入与进度) + */ +async function downloadFile(url: string, destPath: string): Promise { + const dir = path.dirname(destPath); + if (!fs.existsSync(dir)) fs.mkdirSync(dir, { recursive: true }); + + const response = await axios.get(url, { + responseType: "stream", + headers: { Accept: "application/octet-stream" }, + timeout: 600_000, // 10 分钟超时 + }); + + const writer = fs.createWriteStream(destPath); + response.data.pipe(writer); + + return new Promise((resolve, reject) => { + writer.on("finish", resolve); + writer.on("error", reject); + }); +} +export default router.post( + "/", + validateFields({ + source: z.enum(["github", "gitee"]), + reinstall: z.boolean(), + latestVersion: z.string(), + }), + async (req, res) => { + try { + const { reinstall, latestVersion, source } = req.body as { + reinstall: boolean; + latestVersion: string; + source: string; + }; + + if (!latestVersion) { + return res.status(400).send(error("缺少目标版本号 latestVersion")); + } + + const sourceConfig = REPO_SOURCES[source as SourceType] ?? REPO_SOURCES.github; + + // ─── 获取 Release 信息(支持 GitHub / Gitee) ────────────────────── + let releaseRes; + try { + releaseRes = await axios.get(sourceConfig.api, { + headers: sourceConfig.headers, + timeout: 30_000, + }); + } catch (e) { + return res.status(500).send(error(`获取 ${source} Release 信息失败`)); + } + + const release = releaseRes.data; + + const assets = normalizeAssets(source as SourceType, release); + + if (reinstall) { + // ═══════════════ 模式 A:下载完整安装包 ═══════════════ + const installerAsset = findInstallerAsset(assets); + + if (!installerAsset) { + return res.status(404).send(error("未找到当前平台的安装包,请前往 GitHub Releases 手动下载")); + } + + const tempDir = getPath(["temp"]); + + if (!fs.existsSync(tempDir)) fs.mkdirSync(tempDir, { recursive: true }); + const installerPath = path.join(tempDir, installerAsset.name); + + // 如果已经下载过相同文件,跳过下载 + if (!fs.existsSync(installerPath)) { + await downloadFile(installerAsset.browser_download_url, installerPath); + } + + // 使用 shell 打开安装程序 + const sub = spawn("cmd", ["/c", `${installerPath}`], { + cwd: tempDir, + detached: true, + stdio: "ignore", + windowsHide: false, + }); + + sub.unref(); + + return res.status(200).send( + success({ + type: "reinstall", + version: latestVersion, + filePath: installerPath, + message: "安装包已下载并打开,请按照安装向导完成更新", + }), + ); + } else { + // ═══════════════ 模式 B:data 补丁热更新 ═══════════════ + const patchAsset = assets.find((a: any) => a.name.startsWith(latestVersion) && a.name.endsWith(".zip")) ?? null; + + if (!patchAsset) { + return res.status(404).send(error("未找到 data 补丁包,请前往 GitHub Releases 手动下载")); + } + // + + const tempDir = getPath(["temp"]); + if (!fs.existsSync(tempDir)) fs.mkdirSync(tempDir, { recursive: true }); + const patchZipPath = path.join(tempDir, `${latestVersion}.zip`); + + // 下载补丁 zip + await downloadFile(patchAsset.browser_download_url, patchZipPath); + + // 解压覆盖到 data 目录(同名文件夹先删除再解压,确保完全替换) + const dataDir = getPath(); + + // 先读取 zip 内的顶层文件夹/文件列表,删除 data 目录下的同名项 + const zipStream = new compressing.zip.UncompressStream({ source: patchZipPath, zipFileNameEncoding: "utf8" }); + const topLevelEntries = new Set(); + await new Promise((resolve, reject) => { + zipStream.on("entry", (_header: any, stream: any, next: () => void) => { + const entryName: string = _header.name || ""; + // 取顶层名称(第一个 / 之前的部分) + const topName = entryName.split("/")[0]; + if (topName) topLevelEntries.add(topName); + stream.resume(); + next(); + }); + zipStream.on("finish", resolve); + zipStream.on("error", reject); + }); + + // 删除 data 目录下与 zip 顶层同名的文件夹/文件 + for (const name of topLevelEntries) { + const targetPath = path.join(dataDir, name); + if (fs.existsSync(targetPath)) { + const stat = fs.statSync(targetPath); + if (stat.isDirectory()) { + fs.rmSync(targetPath, { recursive: true, force: true }); + } else { + fs.unlinkSync(targetPath); + } + } + } + + await compressing.zip.uncompress(patchZipPath, dataDir, { zipFileNameEncoding: "utf8" }); + + // 清理临时文件 + try { + fs.unlinkSync(patchZipPath); + } catch { + // 忽略清理失败 + } + + return res.status(200).send( + success({ + type: "patch", + version: latestVersion, + message: "补丁更新完成,请重启应用以使更新生效", + restartRequired: true, + }), + ); + } + } catch (err: any) { + console.error("[downloadApp] 更新失败:", err); + const message = err?.response?.status === 404 ? "未找到更新资源,请检查版本号或稍后重试" : (err?.message ?? "更新失败,请稍后重试"); + return res.status(500).send(error(message)); + } + }, +); diff --git a/src/routes/setting/vendorConfig/updateVendor.ts b/src/routes/setting/vendorConfig/updateVendor.ts index 499c3a4..df0ddaa 100644 --- a/src/routes/setting/vendorConfig/updateVendor.ts +++ b/src/routes/setting/vendorConfig/updateVendor.ts @@ -73,7 +73,6 @@ export default router.post( "/", validateFields({ id: z.string(), - tsCode: z.string(), inputValues: z.record(z.string(), z.string()), inputs: z.array( z.object({ @@ -121,57 +120,16 @@ export default router.post( ), }), async (req, res) => { - const { id, tsCode, name, models, inputs, inputValues, icon } = req.body; - - const jsCode = transform(tsCode, { transforms: ["typescript"] }).code; - const exports = u.vm(jsCode); - if (!exports) return res.status(400).send(success("脚本文件必须导出对象")); - if (!exports.textRequest) return res.status(400).send(success("脚本文件必须导出文本请求对象")); - if (!exports.imageRequest) return res.status(400).send(success("脚本文件必须导出图像请求对象")); - if (!exports.videoRequest) return res.status(400).send(success("脚本文件必须导出视频请求对象")); - if (!exports.vendor) return res.status(400).send(success("脚本文件必须导出vendor对象")); - const vendor = exports.vendor; - const result = vendorConfigSchema.safeParse(vendor); - if (!result.success) { - const errorMsg = result.error.issues.map((e) => `${e.path.join(".")}: ${e.message}`).join("; "); - return res.status(400).send(error(`vendor配置校验失败: ${errorMsg}`)); - } - const replaceBlockValue = (code: string, key: string, newValue: string): string => { - const open = newValue.trimStart()[0] as "[" | "{"; - const close = open === "[" ? "]" : "}"; - const keyMatch = code.match(new RegExp(`\\b${key}\\s*:\\s*[\\[{]`)); - if (!keyMatch || keyMatch.index === undefined) return code; - const valueStart = keyMatch.index + keyMatch[0].length - 1; - let depth = 0; - let valueEnd = -1; - for (let i = valueStart; i < code.length; i++) { - if (code[i] === open) depth++; - else if (code[i] === close) { - depth--; - if (depth === 0) { - valueEnd = i; - break; - } - } - } - if (valueEnd === -1) return code; - return code.slice(0, valueStart) + newValue + code.slice(valueEnd + 1); - }; - - let updatedTsCode = tsCode; - updatedTsCode = replaceBlockValue(updatedTsCode, "inputs", JSON.stringify(inputs ?? vendor.inputs, null, 2)); - updatedTsCode = replaceBlockValue(updatedTsCode, "inputValues", JSON.stringify(inputValues ?? vendor.inputValues, null, 2)); - updatedTsCode = replaceBlockValue(updatedTsCode, "models", JSON.stringify(models ?? vendor.models, null, 2)); + const { id, name, models, inputs, inputValues, icon } = req.body; await u .db("o_vendorConfig") .where("id", id) .update({ - inputs: inputs ? JSON.stringify(inputs) : JSON.stringify(vendor.inputs), - inputValues: inputValues ? JSON.stringify(inputValues) : JSON.stringify(vendor.inputValues), - models: models ? JSON.stringify(models) : JSON.stringify(vendor.models), - code: updatedTsCode, + inputs: JSON.stringify(inputs), + inputValues: JSON.stringify(inputValues), + models: JSON.stringify(models), }); - res.status(200).send(success(result.data)); + res.status(200).send(success("更新成功")); }, ); diff --git a/src/socket/chatMessagesData.d.ts b/src/socket/chatMessagesData.d.ts new file mode 100644 index 0000000..2059aae --- /dev/null +++ b/src/socket/chatMessagesData.d.ts @@ -0,0 +1,58 @@ +import type { ToolCallEventType } from './adapters/agui/types/events'; + +export type ChatMessageStatus = 'pending' | 'streaming' | 'complete' | 'stop' | 'error'; +export type AttachmentType = 'image' | 'video' | 'audio' | 'pdf' | 'doc' | 'ppt' | 'txt'; +export type ChatComment = 'good' | 'bad' | ''; + +// 基础内容接口 +export interface ChatBaseContent { + type: T; + data: D; + status?: ChatMessageStatus; + id?: string; + strategy?: 'merge' | 'append'; + ext?: Record; +} + +// 内容类型定义 +export type TextContent = ChatBaseContent<'text', string>; +export type MarkdownContent = ChatBaseContent<'markdown', string>; +export type ImageContent = ChatBaseContent<'image', { name?: string; url?: string; width?: number; height?: number }>; +export type ThinkingContent = ChatBaseContent<'thinking', { text?: string; title?: string }>; +export type SearchContent = ChatBaseContent<'search', { title?: string; references?: { title: string; icon?: string; type?: string; url?: string; content?: string; site?: string; date?: string }[] }>; +export type SuggestionContent = ChatBaseContent<'suggestion', { title: string; prompt?: string }[]>; +export type AttachmentContent = ChatBaseContent<'attachment', { fileType: AttachmentType; size?: number; name?: string; url?: string; isReference?: boolean; width?: number; height?: number; extension?: string; metadata?: Record }[]>; +export type ToolCallContent = ChatBaseContent<'toolcall', { toolCallId: string; toolCallName: string; eventType?: ToolCallEventType; parentMessageId?: string; args?: string; chunk?: string; result?: string }>; +export type ActivityContent> = ChatBaseContent<'activity', { activityType: string; messageId?: string; content: T; deltaInfo?: { fromIndex: number; toIndex: number } }>; + +// 聚合内容类型 +export type AIMessageContent = TextContent | MarkdownContent | ImageContent | ThinkingContent | SearchContent | SuggestionContent | ReasoningContent | ToolCallContent | ActivityContent; +export type ReasoningContent = ChatBaseContent<'reasoning', AIMessageContent[]>; +export type UserMessageContent = TextContent | AttachmentContent; + +// 消息类型定义 +export interface ChatBaseMessage { + id: string; + status?: ChatMessageStatus; + datetime?: string; + ext?: any; +} + +export interface UserMessage extends ChatBaseMessage { + role: 'user'; + content: UserMessageContent[]; +} + +export interface AIMessage extends ChatBaseMessage { + role: 'assistant'; + content?: AIMessageContent[]; + history?: AIMessageContent[][]; + comment?: ChatComment; +} + +export interface SystemMessage extends ChatBaseMessage { + role: 'system'; + content: TextContent[]; +} + +export type ChatMessagesData = UserMessage | AIMessage | SystemMessage; \ No newline at end of file diff --git a/src/socket/resTool copy.ts b/src/socket/resTool copy.ts new file mode 100644 index 0000000..fbac487 --- /dev/null +++ b/src/socket/resTool copy.ts @@ -0,0 +1,79 @@ +import u from "@/utils"; +import { Socket } from "socket.io"; + +class ResTool { + public socket: Socket; + public data: Record; + constructor(socket: Socket, data: Record = {}) { + this.socket = socket; + this.data = data; + } + + textMessage(name: string = "AI") { + const messageId = u.uuid(); + this.socket.emit("textMessage", { + type: "start", + messageId, + delta: null, + role: "assistant", + name, + }); + const handle = { + send: (delta: string) => { + this.socket.emit("textMessage", { + type: "content", + messageId, + delta, + role: "assistant", + name, + }); + return handle; + }, + end: () => { + this.socket.emit("textMessage", { + type: "end", + messageId, + delta: null, + role: "assistant", + name, + }); + }, + }; + return handle; + } + thinkMessage() { + const messageId = u.uuid(); + this.socket.emit("thinkMessage", { + type: "start", + messageId, + delta: null, + role: "assistant", + }); + const handle = { + send: (delta: string) => { + this.socket.emit("thinkMessage", { + type: "content", + messageId, + delta, + role: "assistant", + }); + return handle; + }, + end: () => { + this.socket.emit("thinkMessage", { + type: "end", + messageId, + delta: null, + role: "assistant", + }); + }, + }; + return handle; + } + systemMessage(content: string) { + const messageId = u.uuid(); + this.socket.emit("systemMessage", { messageId, content }); + } +} + +export default ResTool; diff --git a/src/socket/resTool.ts b/src/socket/resTool.ts index fbac487..d2fb422 100644 --- a/src/socket/resTool.ts +++ b/src/socket/resTool.ts @@ -1,79 +1,544 @@ import u from "@/utils"; import { Socket } from "socket.io"; +import type { + ChatMessageStatus, + AIMessageContent, + UserMessageContent, + TextContent, + MarkdownContent, + ImageContent, + ThinkingContent, + SearchContent, + SuggestionContent, + ToolCallContent, + ActivityContent, + ReasoningContent, + AttachmentContent, +} from "./ChatMessagesData"; + +type ContentType = AIMessageContent["type"]; class ResTool { public socket: Socket; public data: Record; + constructor(socket: Socket, data: Record = {}) { this.socket = socket; this.data = data; } - textMessage(name: string = "AI") { + // 创建新消息 + newMessage(role: "assistant" | "user" | "system" = "assistant", name?: string) { const messageId = u.uuid(); - this.socket.emit("textMessage", { - type: "start", - messageId, - delta: null, - role: "assistant", + const datetime = new Date().toISOString(); + + this.socket.emit("message", { + id: messageId, + role, name, + status: "pending" as ChatMessageStatus, + datetime, + content: [], }); - const handle = { - send: (delta: string) => { - this.socket.emit("textMessage", { - type: "content", - messageId, - delta, - role: "assistant", - name, - }); - return handle; - }, - end: () => { - this.socket.emit("textMessage", { - type: "end", - messageId, - delta: null, - role: "assistant", - name, - }); - }, - }; - return handle; + + return new MessageBuilder(this.socket, messageId, role, name, datetime); } - thinkMessage() { - const messageId = u.uuid(); - this.socket.emit("thinkMessage", { - type: "start", - messageId, - delta: null, - role: "assistant", + + // 发送错误消息 + sendError(messageId: string, error: string) { + this.socket.emit("message:update", { + id: messageId, + status: "error" as ChatMessageStatus, + ext: { error }, }); - const handle = { - send: (delta: string) => { - this.socket.emit("thinkMessage", { - type: "content", - messageId, - delta, - role: "assistant", - }); - return handle; - }, - end: () => { - this.socket.emit("thinkMessage", { - type: "end", - messageId, - delta: null, - role: "assistant", - }); - }, - }; - return handle; } - systemMessage(content: string) { - const messageId = u.uuid(); - this.socket.emit("systemMessage", { messageId, content }); + + // 发送完成状态 + sendComplete(messageId: string) { + this.socket.emit("message:update", { + id: messageId, + status: "complete" as ChatMessageStatus, + }); + } +} + +// 消息构建器 +class MessageBuilder { + private socket: Socket; + private messageId: string; + private messageRole: "assistant" | "user" | "system"; + private messageName?: string; + private messageDatetime: string; + + constructor(socket: Socket, messageId: string, role: "assistant" | "user" | "system", name?: string, datetime?: string) { + this.socket = socket; + this.messageId = messageId; + this.messageRole = role; + this.messageName = name; + this.messageDatetime = datetime ?? new Date().toISOString(); + } + + get id() { + return this.messageId; + } + + get role() { + return this.messageRole; + } + + get name() { + return this.messageName; + } + + get datetime() { + return this.messageDatetime; + } + + // 更新消息状态 + updateStatus(status: ChatMessageStatus) { + this.socket.emit("message:update", { + id: this.messageId, + status, + }); + return this; + } + + // 添加文本内容 + text(initialText = "") { + const contentId = u.uuid(); + const content: TextContent = { + type: "text", + id: contentId, + data: initialText, + status: "pending", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return new ContentStream(this.socket, this.messageId, contentId, "text"); + } + + // 添加 Markdown 内容 + markdown(initialText = "") { + const contentId = u.uuid(); + const content: MarkdownContent = { + type: "markdown", + id: contentId, + data: initialText, + status: "pending", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return new ContentStream(this.socket, this.messageId, contentId, "markdown"); + } + + // 添加思考内容 + thinking(title = "思考中...") { + const contentId = u.uuid(); + const content: ThinkingContent = { + type: "thinking", + id: contentId, + data: { title, text: "" }, + status: "pending", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return new ThinkingStream(this.socket, this.messageId, contentId); + } + + // 添加搜索内容 + search(title = "搜索中...") { + const contentId = u.uuid(); + const content: SearchContent = { + type: "search", + id: contentId, + data: { title, references: [] }, + status: "pending", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return new SearchStream(this.socket, this.messageId, contentId); + } + + // 添加图片内容 + image(data: ImageContent["data"]) { + const contentId = u.uuid(); + const content: ImageContent = { + type: "image", + id: contentId, + data, + status: "complete", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return this; + } + + // 添加建议内容 + suggestion(suggestions: SuggestionContent["data"]) { + const contentId = u.uuid(); + const content: SuggestionContent = { + type: "suggestion", + id: contentId, + data: suggestions, + status: "complete", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return this; + } + + // 添加工具调用内容 + toolCall(data: ToolCallContent["data"]) { + const contentId = u.uuid(); + const content: ToolCallContent = { + type: "toolcall", + id: contentId, + data: { ...data, parentMessageId: this.messageId }, + status: "pending", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return new ToolCallStream(this.socket, this.messageId, contentId, data.toolCallId); + } + + // 添加活动内容 + activity>(activityType: string, content: T) { + const contentId = u.uuid(); + const activityContent: ActivityContent = { + type: "activity", + id: contentId, + data: { + activityType, + messageId: this.messageId, + content, + }, + status: "complete", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content: activityContent, + }); + + return this; + } + + // 添加推理内容 + reasoning() { + const contentId = u.uuid(); + const content: ReasoningContent = { + type: "reasoning", + id: contentId, + data: [], + status: "pending", + }; + + this.socket.emit("content:add", { + messageId: this.messageId, + content, + }); + + return new ReasoningBuilder(this.socket, this.messageId, contentId); + } + + // 完成消息 + complete() { + this.socket.emit("message:update", { + id: this.messageId, + status: "complete" as ChatMessageStatus, + }); + } + + // 停止消息 + stop() { + this.socket.emit("message:update", { + id: this.messageId, + status: "stop" as ChatMessageStatus, + }); + } + + // 错误 + error(errorMsg?: string) { + this.socket.emit("message:update", { + id: this.messageId, + status: "error" as ChatMessageStatus, + ext: errorMsg ? { error: errorMsg } : undefined, + }); + } +} + +// 内容流基类 +class ContentStream { + protected socket: Socket; + protected messageId: string; + protected contentId: string; + protected contentType: ContentType; + + constructor(socket: Socket, messageId: string, contentId: string, contentType: ContentType) { + this.socket = socket; + this.messageId = messageId; + this.contentId = contentId; + this.contentType = contentType; + } + + get id() { + return this.contentId; + } + + // 流式追加数据 + append(chunk: string) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: this.contentType, + data: chunk, + strategy: "append", + status: "streaming", + }); + return this; + } + + // 合并/替换数据 + merge(data: T) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: this.contentType, + data, + strategy: "merge", + status: "streaming", + }); + return this; + } + + // 完成内容 + complete(finalData?: T) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: this.contentType, + data: finalData, + status: "complete", + }); + return this; + } + + // 错误 + error() { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + status: "error", + }); + return this; + } +} + +// 思考内容流 +class ThinkingStream extends ContentStream { + constructor(socket: Socket, messageId: string, contentId: string) { + super(socket, messageId, contentId, "thinking"); + } + + // 追加思考文本 + appendText(chunk: string) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "thinking", + data: { text: chunk }, + strategy: "append", + status: "streaming", + }); + return this; + } + + // 更新标题 + updateTitle(title: string) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "thinking", + data: { title }, + strategy: "merge", + status: "streaming", + }); + return this; + } +} + +// 搜索内容流 +class SearchStream extends ContentStream { + constructor(socket: Socket, messageId: string, contentId: string) { + super(socket, messageId, contentId, "search"); + } + + // 添加引用 + addReference(ref: Exclude[0]) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "search", + data: { references: [ref] }, + strategy: "append", + status: "streaming", + }); + return this; + } + + // 批量添加引用 + addReferences(refs: SearchContent["data"]["references"]) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "search", + data: { references: refs }, + strategy: "append", + status: "streaming", + }); + return this; + } + + // 更新标题 + updateTitle(title: string) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "search", + data: { title }, + strategy: "merge", + status: "streaming", + }); + return this; + } +} + +// 工具调用流 +class ToolCallStream extends ContentStream { + private toolCallId: string; + + constructor(socket: Socket, messageId: string, contentId: string, toolCallId: string) { + super(socket, messageId, contentId, "toolcall"); + this.toolCallId = toolCallId; + } + + // 追加参数块 + appendArgs(chunk: string) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "toolcall", + data: { toolCallId: this.toolCallId, args: chunk }, + strategy: "append", + status: "streaming", + }); + return this; + } + + // 追加结果块 + appendResult(chunk: string) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "toolcall", + data: { toolCallId: this.toolCallId, chunk }, + strategy: "append", + status: "streaming", + }); + return this; + } + + // 设置完整结果 + setResult(result: string) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "toolcall", + data: { toolCallId: this.toolCallId, result }, + strategy: "merge", + status: "complete", + }); + return this; + } + + // 更新事件类型 + updateEventType(eventType: ToolCallContent["data"]["eventType"]) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "toolcall", + data: { toolCallId: this.toolCallId, eventType }, + strategy: "merge", + status: "streaming", + }); + return this; + } +} + +// 推理构建器 +class ReasoningBuilder { + private socket: Socket; + private messageId: string; + private contentId: string; + + constructor(socket: Socket, messageId: string, contentId: string) { + this.socket = socket; + this.messageId = messageId; + this.contentId = contentId; + } + + // 添加子内容 + addContent(content: AIMessageContent) { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "reasoning", + data: [content], + strategy: "append", + status: "streaming", + }); + return this; + } + + // 完成推理 + complete() { + this.socket.emit("content:update", { + messageId: this.messageId, + contentId: this.contentId, + type: "reasoning", + status: "complete", + }); + return this; } } export default ResTool; +export { MessageBuilder, ContentStream, ThinkingStream, SearchStream, ToolCallStream, ReasoningBuilder }; diff --git a/src/socket/routes/scriptAgent.ts b/src/socket/routes/scriptAgent.ts index 18071c9..5867c79 100644 --- a/src/socket/routes/scriptAgent.ts +++ b/src/socket/routes/scriptAgent.ts @@ -3,6 +3,7 @@ import u from "@/utils"; import { Namespace, Socket } from "socket.io"; import * as agent from "@/agents/scriptAgent/index"; import ResTool from "@/socket/resTool"; +import Memory from "@/utils/agent/memory"; async function verifyToken(rawToken: string): Promise { const setting = await u.db("o_setting").where("key", "tokenKey").select("value").first(); @@ -40,23 +41,61 @@ export default (nsp: Namespace) => { }); let abortController: AbortController | null = null; - socket.on("message", async (text: string) => { + socket.on("chat", async (data: { content: string }) => { + const { content } = data; abortController?.abort(); abortController = new AbortController(); const currentController = abortController; + const memory = new Memory("scriptAgent", isolationKey); - const textStream = await agent.decisionAI({ socket, isolationKey, text, abortSignal: currentController.signal, resTool }); + const msg = resTool.newMessage("assistant", "统筹"); + const ctx: agent.AgentContext = { + socket, + isolationKey, + text: content, + userMessageTime: new Date(msg.datetime).getTime() - 1, + abortSignal: currentController.signal, + resTool, + msg, + }; - let msg = resTool.textMessage(); + const textStream = await agent.decisionAI(ctx); + + let currentMsg = ctx.msg; + let text = currentMsg.text(); + let currentContent = ""; + + const persistCurrentMessage = async () => { + if (!currentContent.trim()) return; + await memory.add("assistant:decision", currentContent, { + name: "统筹", + createTime: new Date(currentMsg.datetime).getTime(), + }); + currentContent = ""; + }; + + const syncCurrentMessage = async () => { + if (ctx.msg === currentMsg) return; + text.complete(); + currentMsg.complete(); + await persistCurrentMessage(); + currentMsg = ctx.msg; + text = currentMsg.text(); + }; try { for await (const chunk of textStream) { - msg.send(chunk); + await syncCurrentMessage(); + text.append(chunk); + currentContent += chunk; } } catch (err: any) { if (err.name !== "AbortError") throw err; } finally { - msg.end(); + await syncCurrentMessage(); + text.complete(); + currentMsg.complete(); + await persistCurrentMessage(); if (abortController === currentController) { abortController = null; } diff --git a/src/utils/agent/memory.ts b/src/utils/agent/memory.ts index 16fde04..1ba2413 100644 --- a/src/utils/agent/memory.ts +++ b/src/utils/agent/memory.ts @@ -82,7 +82,8 @@ class Memory { } return result; } - async add(role: string = "user", content: string) { + + async add(role: string = "user", content: string, options?: { name?: string; createTime?: number }) { const { messagesPerSummary } = await this.getConfigData({ messagesPerSummary: DEFAULTS.messagesPerSummary }); const id = uuidv4(); const embedding = await getEmbedding(content); @@ -93,11 +94,12 @@ class Memory { isolationKey, type: "message", role, + name: options?.name, content, embedding: JSON.stringify(embedding), relatedMessageIds: null, summarized: 0, - createTime: Date.now(), + createTime: options?.createTime ?? Date.now(), } as any); // 检查未总结消息数量 @@ -154,7 +156,7 @@ class Memory { const ragResults = vectorSearch(allMessages, queryEmbedding, Number(ragLimit)); return { - shortTerm: shortTerm.map((m: any) => ({ id: m.id, role: m.role, content: m.content, createTime: m.createTime })), + shortTerm: shortTerm.map((m: any) => ({ id: m.id, role: m.role, name: m.name, content: m.content, createTime: m.createTime })), summaries: summaries.map((s) => ({ id: s.id, content: s.content,