diff --git a/package.json b/package.json index a005f94..f8e872d 100644 --- a/package.json +++ b/package.json @@ -49,6 +49,7 @@ "langchain": "^1.2.10", "morgan": "^1.10.1", "qwen-ai-provider": "^0.1.1", + "serialize-error": "^13.0.1", "sharp": "^0.34.5", "sqlite3": "^5.1.7", "zhipu-ai-provider": "^0.2.2", diff --git a/src/agents/storyboard/index.ts b/src/agents/storyboard/index.ts index f215888..efbb10b 100644 --- a/src/agents/storyboard/index.ts +++ b/src/agents/storyboard/index.ts @@ -1,10 +1,8 @@ // @/agents/Storyboard.ts import u from "@/utils"; -import { createAgent } from "langchain"; +import { tool, ModelMessage, Tool } from "ai"; import { EventEmitter } from "events"; -import { openAI } from "@/agents/models"; import { z } from "zod"; -import { tool } from "@langchain/core/tools"; import type { DB } from "@/types/database"; import generateImageTool from "./generateImageTool"; import imageSplitting from "./imageSplitting"; @@ -46,7 +44,7 @@ export default class Storyboard { private readonly projectId: number; private readonly scriptId: number; readonly emitter = new EventEmitter(); - history: Array<[string, string]> = []; + history: ModelMessage[] = []; novelChapters: DB["t_novel"][] = []; // 存储 segmentAgent 生成的片段结果 @@ -58,10 +56,6 @@ export default class Storyboard { // 存储正在生成分镜图的分镜ID private generatingShots: Set = new Set(); - modelName = "gpt-4.1"; - apiKey = ""; - baseURL = ""; - constructor(projectId: number, scriptId: number) { this.projectId = projectId; this.scriptId = scriptId; @@ -105,28 +99,28 @@ export default class Storyboard { // ==================== 剧本相关操作 ==================== - getScript = tool( - async () => { + getScript = tool({ + title: "getScript", + description: "获取剧本内容", + inputSchema: z.object({}), + execute: async () => { this.log("获取剧本", `scriptId: ${this.scriptId}`); const script = await u.db("t_script").where({ id: this.scriptId, projectId: this.projectId }).first(); if (!script) throw new Error("剧本不存在"); return `剧本集:${script.name}\n\n内容:\n\`\`\`${script.content}\`\`\``; }, - { - name: "getScript", - description: "获取剧本内容", - schema: z.object({}), - verboseParsingErrors: true, - }, - ); + }); // ==================== 资产相关操作 ==================== /** * 获取资产列表(供 segmentAgent 和 shotAgent 调用) */ - getAssets = tool( - async () => { + getAssets = tool({ + title: "getAssets", + description: "获取资产列表(角色、道具、场景),包含名称和详细介绍。生成片段和分镜时必须先调用此工具获取资产信息,确保名称一致性", + inputSchema: z.object({}), + execute: async () => { this.log("获取资产列表", `scriptId: ${this.scriptId}`); const scriptData = await u.db("t_script").where({ id: this.scriptId, projectId: this.projectId }).first(); const row = await u.db("t_outline").where({ id: scriptData?.outlineId!, projectId: this.projectId }).first(); @@ -171,69 +165,69 @@ ${sections.join("\n\n")} 2. 禁止在资产名称前后添加修饰词 3. 禁止捏造资产列表中不存在的角色、场景、道具`; }, - { - name: "getAssets", - description: "获取资产列表(角色、道具、场景),包含名称和详细介绍。生成片段和分镜时必须先调用此工具获取资产信息,确保名称一致性", - schema: z.object({}), - verboseParsingErrors: true, - }, - ); + }); // ==================== 片段和分镜工具 ==================== /** * 获取当前存储的片段数据(供 shotAgent 调用) */ - getSegments = tool( - async () => { + getSegments = tool({ + title: "getSegments", + description: "获取当前已生成的片段数据,用于生成分镜", + inputSchema: z.object({}), + execute: async () => { this.log("获取片段数据", `共 ${this.segments.length} 个片段`); if (this.segments.length === 0) { return "暂无片段数据,请先调用 segmentAgent 生成片段"; } return JSON.stringify(this.segments, null, 2); }, - { - name: "getSegments", - description: "获取当前已生成的片段数据,用于生成分镜", - schema: z.object({}), - verboseParsingErrors: true, - }, - ); + }); /** * 更新/存储片段数据(供 segmentAgent 调用) */ - updateSegments = tool( - async ({ segments }: { segments: Segment[] }) => { + updateSegments = tool({ + title: "updateSegments", + description: "存储生成的片段数据,segmentAgent 在生成片段后必须调用此工具保存结果", + inputSchema: z.object({ + segments: z + .array( + z.object({ + index: z.number().describe("片段序号"), + description: z.string().describe("片段描述"), + emotion: z.string().optional().describe("情绪氛围"), + action: z.string().optional().describe("主要动作"), + }), + ) + .describe("片段数组"), + }), + execute: async ({ segments }: { segments: Segment[] }) => { this.log("更新片段数据", `共 ${segments.length} 个片段`); this.segments = segments; this.emit("segmentsUpdated", this.segments); return `成功存储 ${segments.length} 个片段`; }, - { - name: "updateSegments", - description: "存储生成的片段数据,segmentAgent 在生成片段后必须调用此工具保存结果", - schema: z.object({ - segments: z - .array( - z.object({ - index: z.number().describe("片段序号"), - description: z.string().describe("片段描述"), - emotion: z.string().optional().describe("情绪氛围"), - action: z.string().optional().describe("主要动作"), - }), - ) - .describe("片段数组"), - }), - verboseParsingErrors: true, - }, - ); + }); /** * 添加分镜(供 shotAgent 调用) */ - addShots = tool( - async ({ shots }: { shots: Array<{ segmentIndex: number; prompts: string[] }> }) => { + addShots = tool({ + title: "addShots", + description: "添加新的分镜。每个分镜有独立ID,包含多个镜头(每个镜头对应一个提示词)。如果片段已存在分镜会跳过", + inputSchema: z.object({ + shots: z + .array( + z.object({ + segmentIndex: z.number().describe("对应的片段序号"), + prompts: z.array(z.string()).describe("镜头提示词数组,每个提示词对应一个镜头(中文)"), + }), + ) + .describe("要添加的分镜数组"), + }), + execute: async ({ shots }: { shots: Array<{ segmentIndex: number; prompts: string[] }> }) => { const added: { id: number; segmentIndex: number }[] = []; const skipped: number[] = []; @@ -266,29 +260,20 @@ ${sections.join("\n\n")} } return `已添加${addedInfo}。当前共 ${this.shots.length} 个分镜`; }, - { - name: "addShots", - description: "添加新的分镜。每个分镜有独立ID,包含多个镜头(每个镜头对应一个提示词)。如果片段已存在分镜会跳过", - schema: z.object({ - shots: z - .array( - z.object({ - segmentIndex: z.number().describe("对应的片段序号"), - prompts: z.array(z.string()).describe("镜头提示词数组,每个提示词对应一个镜头(中文)"), - }), - ) - .describe("要添加的分镜数组"), - }), - verboseParsingErrors: true, - }, - ); + }); /** * 更新指定分镜(供 shotAgent 调用) * 保留原有 cells 的 id 和 src 字段,只更新 prompt */ - updateShots = tool( - async ({ shotId, prompts }: { shotId: number; prompts: string[] }) => { + updateShots = tool({ + title: "updateShots", + description: "更新指定分镜的镜头提示词。通过分镜ID指定要修改的分镜", + inputSchema: z.object({ + shotId: z.number().describe("要更新的分镜ID"), + prompts: z.array(z.string()).describe("新的镜头提示词数组,每个提示词对应一个镜头"), + }), + execute: async ({ shotId, prompts }: { shotId: number; prompts: string[] }) => { const existingIndex = this.shots.findIndex((item) => item.id === shotId); if (existingIndex === -1) { @@ -314,22 +299,18 @@ ${sections.join("\n\n")} return `已更新分镜 ${shotId}`; }, - { - name: "updateShots", - description: "更新指定分镜的镜头提示词。通过分镜ID指定要修改的分镜", - schema: z.object({ - shotId: z.number().describe("要更新的分镜ID"), - prompts: z.array(z.string()).describe("新的镜头提示词数组,每个提示词对应一个镜头"), - }), - verboseParsingErrors: true, - }, - ); + }); /** * 删除指定分镜(供 shotAgent 调用) */ - deleteShots = tool( - async ({ shotIds }: { shotIds: number[] }) => { + deleteShots = tool({ + title: "deleteShots", + description: "删除指定的分镜。通过分镜ID指定要删除的分镜", + inputSchema: z.object({ + shotIds: z.array(z.number()).describe("要删除的分镜ID数组"), + }), + execute: async ({ shotIds }: { shotIds: number[] }) => { const deleted: number[] = []; const notFound: number[] = []; @@ -351,21 +332,19 @@ ${sections.join("\n\n")} } return `已删除分镜 ${deleted.join(", ")}。当前共 ${this.shots.length} 个分镜`; }, - { - name: "deleteShots", - description: "删除指定的分镜。通过分镜ID指定要删除的分镜", - schema: z.object({ - shotIds: z.array(z.number()).describe("要删除的分镜ID数组"), - }), - verboseParsingErrors: true, - }, - ); + }); /** * 生成分镜图(异步执行,使用 nanoBanana) */ - generateShotImage = tool( - async ({ shotIds }: { shotIds: number[] }) => { + generateShotImage = tool({ + title: "generateShotImage", + description: + "为指定分镜生成分镜图。每个分镜会根据其所有提示词生成一张完整宫格图,然后自动分割为单格图片。通过分镜ID指定,不需要指定具体格子,整个分镜是一个完整的生成单元", + inputSchema: z.object({ + shotIds: z.array(z.number()).describe("要生成分镜图的分镜ID数组"), + }), + execute: async ({ shotIds }: { shotIds: number[] }) => { const toGenerate: number[] = []; const alreadyGenerating: number[] = []; const notFound: number[] = []; @@ -417,16 +396,7 @@ ${sections.join("\n\n")} } return result; }, - { - name: "generateShotImage", - description: - "为指定分镜生成分镜图。每个分镜会根据其所有提示词生成一张完整宫格图,然后自动分割为单格图片。通过分镜ID指定,不需要指定具体格子,整个分镜是一个完整的生成单元", - schema: z.object({ - shotIds: z.array(z.number()).describe("要生成分镜图的分镜ID数组"), - }), - verboseParsingErrors: true, - }, - ); + }); /** * 执行分镜图生成的具体逻辑(异步并发) @@ -566,7 +536,7 @@ ${assetList} private buildConversationHistory(): string { if (!this.history.length) return "无对话历史"; - return this.history.map(([role, content]) => `${role}: ${content}`).join("\n\n"); + return this.history.map(({ role, content }) => `${role}: ${content}`).join("\n\n"); } private async buildFullContext(task: string): Promise { @@ -586,26 +556,33 @@ ${task} // ==================== Sub-Agent ==================== - private createModel() { - return openAI({ - modelName: this.modelName, - configuration: { apiKey: this.apiKey, baseURL: this.baseURL }, - }); - } - /** * 获取不同 Sub-Agent 可用的工具 */ - private getSubAgentTools(agentType: AgentType) { + private getSubAgentTools(agentType: AgentType): Record { switch (agentType) { case "segmentAgent": // segmentAgent 可以获取剧本和资产,并需要调用 updateSegments 保存结果 - return [this.getScript, this.getAssets, this.updateSegments]; + return { + getScript: this.getScript, + getAssets: this.getAssets, + updateSegments: this.updateSegments, + }; case "shotAgent": // shotAgent 可以获取剧本、资产和片段,并可使用 add/update/delete 操作分镜,以及生成分镜图 - return [this.getScript, this.getAssets, this.getSegments, this.addShots, this.updateShots, this.deleteShots, this.generateShotImage]; + return { + getScript: this.getScript, + getAssets: this.getAssets, + getSegments: this.getSegments, + addShots: this.addShots, + updateShots: this.updateShots, + deleteShots: this.deleteShots, + generateShotImage: this.generateShotImage, + }; default: - return [this.getScript]; + return { + getScript: this.getScript, + }; } } @@ -627,74 +604,71 @@ ${task} const context = await this.buildFullContext(task); - const agent = createAgent({ - model: this.createModel(), - systemPrompt: SYSTEM_PROMPTS[agentType], + const { fullStream } = await u.ai.text.stream({ + system: SYSTEM_PROMPTS[agentType], tools: this.getSubAgentTools(agentType), + messages: [{ role: "user", content: context }], + maxStep: 100, }); - const stream = await agent.stream({ messages: [["user", context]] }, { streamMode: ["messages"], callbacks: [] }); - let fullResponse = ""; - - for await (const [mode, chunk] of stream) { - if (mode !== "messages") continue; - const [token] = chunk as any; - const block = token.contentBlocks?.[0]; - - // 处理 AI 文本流 - if (token.type === "ai" && block?.text) { - fullResponse += block.text; - this.emit("subAgentStream", { agent: agentType, text: block.text }); + for await (const item of fullStream) { + if (item.type == "tool-call") { + this.emit("toolCall", { agent: "main", name: item.title, args: null }); } - // 处理 tool 调用 - if (token.type === "ai" && token.tool_calls?.length) { - for (const toolCall of token.tool_calls) { - this.emit("toolCall", { agent: agentType, name: toolCall.name, args: toolCall.args }); - } + if (item.type == "text-delta") { + fullResponse += item.text; + this.emit("subAgentStream", { agent: agentType, text: item.text }); } } this.emit("subAgentEnd", { agent: agentType }); - this.history.push(["ai", fullResponse]); + this.history.push({ + role: "assistant", + content: fullResponse, + }); this.log(`Sub-Agent 完成`, agentType); - return fullResponse; + + return fullResponse ?? `${agentType}已完成任务`; } private createSubAgentTool(agentType: AgentType, description: string) { - return tool(async ({ taskDescription }) => this.invokeSubAgent(agentType, taskDescription), { - name: agentType, + return tool({ + title: agentType, description, - schema: z.object({ + inputSchema: z.object({ taskDescription: z.string().describe("具体的任务描述,包含章节范围、修改要求等详细信息"), }), + execute: async ({ taskDescription }) => this.invokeSubAgent(agentType, taskDescription), }); } // ==================== 主入口 ==================== private getAllTools() { - return [ - this.createSubAgentTool( + return { + segmentAgent: this.createSubAgentTool( "segmentAgent", "调用片段师。负责根据剧本生成片段,会自行调用 getScript 获取剧本内容,并调用 updateSegments 保存片段结果。", ), - this.createSubAgentTool( + shotAgent: this.createSubAgentTool( "shotAgent", "调用分镜师。负责根据片段生成分镜提示词,会自行调用 getSegments 获取片段数据,并调用 addShots/updateShots 保存分镜结果。", ), // this.createSubAgentTool("director", "调用导演。负责审核故事线和大纲,会自行调用 updateOutline 或 saveStoryline 进行修改。"), - this.getScript, - this.getSegments, - this.generateShotImage, + getScript: this.getScript, + getSegments: this.getSegments, + generateShotImage: this.generateShotImage, ...this.getSubAgentTools("segmentAgent"), ...this.getSubAgentTools("shotAgent"), - ]; + }; } async call(msg: string): Promise { - console.log("模型名称:", this.modelName); - this.history.push(["user", msg]); + this.history.push({ + role: "user", + content: msg, + }); const envContext = await this.buildEnvironmentContext(); @@ -702,34 +676,28 @@ ${task} const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么,请直接输出Agent配置异常"; - const mainAgent = createAgent({ - model: this.createModel(), + const { fullStream } = await u.ai.text.stream({ + system: `${envContext}\n${mainPrompts}`, tools: this.getAllTools(), - systemPrompt: `${envContext}\n${mainPrompts}`, + messages: this.history, + maxStep: 100, }); - const stream = await mainAgent.stream({ messages: this.history }, { streamMode: ["messages"], callbacks: [] }); let fullResponse = ""; - - for await (const [mode, chunk] of stream) { - if (mode !== "messages") continue; - const [token] = chunk as any; - const block = token.contentBlocks?.[0]; - // 处理 AI 文本流 - if (token.type === "ai" && block?.text) { - fullResponse += block.text; - this.emit("data", block.text); + for await (const item of fullStream) { + if (item.type == "tool-call") { + this.emit("toolCall", { agent: "main", name: item.title, args: null }); } - - // 处理 tool 调用 - if (token.type === "ai" && token.tool_calls?.length) { - for (const toolCall of token.tool_calls) { - this.emit("toolCall", { agent: "main", name: toolCall.name, args: toolCall.args }); - } + if (item.type == "text-delta") { + fullResponse += item.text; + this.emit("data", item.text); } } + this.history.push({ + role: "assistant", + content: fullResponse, + }); - this.history.push(["assistant", fullResponse]); this.emit("response", fullResponse); return fullResponse; diff --git a/src/routes/other/testImage.ts b/src/routes/other/testImage.ts index 11d52f3..1fc3ef1 100644 --- a/src/routes/other/testImage.ts +++ b/src/routes/other/testImage.ts @@ -16,25 +16,34 @@ export default router.post( }), async (req, res) => { const { modelName, apiKey, baseURL, manufacturer } = req.body; - try { - const contentStr = await u.ai.generateImage( - { - prompt: "2D cat", - imageBase64: [], - aspectRatio: "16:9", - size: "1K", - }, - { - model: modelName, - apiKey, - baseURL, - manufacturer, - }, - ); - res.status(200).send(success(contentStr)); - } catch (err: any) { - const message = err?.response?.data?.error?.message || err?.error?.message || "模型调用失败"; - res.status(500).send(error(message)); - } + + const image =await u.ai.image({ + prompt: "2D cat", + imageBase64: [], + aspectRatio: "16:9", + size: "1K", + }); + res.status(200).send(success(image)); + + // try { + // const contentStr = await u.ai.generateImage( + // { + // prompt: "2D cat", + // imageBase64: [], + // aspectRatio: "16:9", + // size: "1K", + // }, + // { + // model: modelName, + // apiKey, + // baseURL, + // manufacturer, + // }, + // ); + // res.status(200).send(success(contentStr)); + // } catch (err: any) { + // const message = err?.response?.data?.error?.message || err?.error?.message || "模型调用失败"; + // res.status(500).send(error(message)); + // } }, ); diff --git a/src/utils.ts b/src/utils.ts index e9f14c6..4022fd3 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -6,18 +6,24 @@ import number2Chinese from "@/utils/number2Chinese"; import deleteOutline from "@/utils/deleteOutline"; import getConfig from "./utils/getConfig"; import { v4 as uuid } from "uuid"; +import error from "@/utils/error"; +import * as imageTools from "@/utils/imageTools"; -import AIText from "@/utils/ai/text"; +import AIText from "@/utils/ai/text/index"; +import AIImage from "@/utils/ai/image/index"; export default { db, oss, ai: { text: AIText, + image: AIImage, }, editImage, number2Chinese, deleteOutline, getConfig, uuid, + error, + imageTools, }; diff --git a/src/utils/ai/image/index.ts b/src/utils/ai/image/index.ts new file mode 100644 index 0000000..8aefebe --- /dev/null +++ b/src/utils/ai/image/index.ts @@ -0,0 +1,44 @@ +import "./type"; +import u from "@/utils"; +import modelList from "./modelList"; +import axios from "axios"; + +import volcengine from "./owned/volcengine"; +import kling from "./owned/kling"; + + +interface AIConfig { + model?: string; + apiKey?: string; + baseURL?: string; +} + +const urlToBase64 = async (url: string): Promise => { + const res = await axios.get(url, { responseType: "arraybuffer" }); + const base64 = Buffer.from(res.data).toString("base64"); + const mimeType = res.headers["content-type"] || "image/png"; + return `data:${mimeType};base64,${base64}`; +}; + +const modelInstance = { + gemini: null, + volcengine: volcengine, + kling: kling, + vidu: null, + runninghub: null, + apimart: null, +} as const; + +export default async (input: ImageConfig, config?: AIConfig) => { + const sqlTextModelConfig = await u.getConfig("image"); + const { model, apiKey, baseURL, manufacturer } = { ...sqlTextModelConfig, ...config }; + const manufacturerFn = modelInstance[manufacturer as keyof typeof modelInstance]; + if (!manufacturerFn) if (!manufacturerFn) throw new Error("不支持的图片厂商"); + const owned = modelList.find((m) => m.model === model); + if (!owned) throw new Error("不支持的模型"); + + let imageUrl = await manufacturerFn(input, { model, apiKey, baseURL }); + if (!input.resType) input.resType = "b64"; + if (input.resType === "b64" && imageUrl.startsWith("http")) imageUrl = await urlToBase64(imageUrl); + return input; +}; diff --git a/src/utils/ai/image/modelList.ts b/src/utils/ai/image/modelList.ts new file mode 100644 index 0000000..2b9bae4 --- /dev/null +++ b/src/utils/ai/image/modelList.ts @@ -0,0 +1,77 @@ +interface Owned { + manufacturer: string; + model: string; + grid: boolean; + type: "t2i" | "ti2i" | "i2i"; +} + +const modelList: Owned[] = [ + // 火山引擎 + { + manufacturer: "volcengine", + model: "doubao-seedream-4-5-251128", + grid: false, + type: "ti2i", + }, + { + manufacturer: "volcengine", + model: "doubao-seedream-4-0-250828", + grid: false, + type: "ti2i", + }, + //可灵 + { + manufacturer: "kling", + model: "kling-image-o1", + grid: false, + type: "ti2i", + }, + //gemini + { + manufacturer: "gemini", + model: "gemini-2.5-flash-image", + grid: true, + type: "ti2i", + }, + { + manufacturer: "gemini", + model: "gemini-2.5-flash-image-preview", + grid: true, + type: "ti2i", + }, + { + manufacturer: "gemini", + model: "gemini-2.5-flash-image-preview-all", + grid: true, + type: "ti2i", + }, + { + manufacturer: "gemini", + model: "gemini-3-pro-image-preview", + grid: true, + type: "ti2i", + }, + //Vidu + { + manufacturer: "vidu", + model: "viduq2", + grid: false, + type: "ti2i", + }, + //RunningHub + { + manufacturer: "runninghub", + model: "nanobanana", + grid: true, + type: "ti2i", + }, + //ApiMart + { + manufacturer: "apimart", + model: "nanobanana", + grid: true, + type: "ti2i", + }, +]; + +export default modelList; diff --git a/src/utils/ai/image/owned/gemini.ts b/src/utils/ai/image/owned/gemini.ts new file mode 100644 index 0000000..e710879 --- /dev/null +++ b/src/utils/ai/image/owned/gemini.ts @@ -0,0 +1,34 @@ +import "../type"; +import { createGoogleGenerativeAI } from "@ai-sdk/google"; +import { generateImage } from "ai"; + +export default async (input: ImageConfig, config: AIConfig): Promise => { + if (!config.model) throw new Error("缺少Model名称"); + if (!config.apiKey) throw new Error("缺少API Key"); + if (!input.prompt) throw new Error("缺少提示词"); + + const google = createGoogleGenerativeAI({ + apiKey: config.apiKey, + baseURL: config.baseURL, + }); + + // 构建完整的提示词 + const fullPrompt = input.systemPrompt ? `${input.systemPrompt}\n\n${input.prompt}` : input.prompt; + + // 根据 size 配置映射到具体尺寸 + const sizeMap: Record = { + "1K": "1024x1024", + "2K": "2048x2048", + "4K": "4096x4096", + }; + + const { image } = await generateImage({ + model: google.image(config.model), + prompt: fullPrompt, + aspectRatio: input.aspectRatio as "1:1" | "3:4" | "4:3" | "9:16" | "16:9", + size: sizeMap[input.size] ?? "1024x1024", + }); + + // 返回生成的图片 base64 + return image.base64; +}; diff --git a/src/utils/ai/image/owned/kling.ts b/src/utils/ai/image/owned/kling.ts new file mode 100644 index 0000000..fa32aa4 --- /dev/null +++ b/src/utils/ai/image/owned/kling.ts @@ -0,0 +1,107 @@ +import "../type"; +import axios from "axios"; +import jwt from "jsonwebtoken"; +import u from "@/utils"; +import { pollTask } from "@/utils/ai/utils"; + +function generateJwtToken(ak: string, sk: string): string { + const now = Math.floor(Date.now() / 1000); + const payload = { + iss: ak, + exp: now + 1800, + nbf: now - 5, + }; + return jwt.sign(payload, sk, { + algorithm: "HS256", + header: { alg: "HS256", typ: "JWT" }, + }); +} + +function getApiToken(apiKey: string): string { + const trimmedKey = apiKey.replace(/^Bearer\s+/i, "").trim(); + + if (trimmedKey.includes("|")) { + const parts = trimmedKey.split("|"); + if (parts.length !== 2 || !parts[0].trim() || !parts[1].trim()) { + throw new Error("API Key格式错误,请使用 ak|sk 格式"); + } + return generateJwtToken(parts[0].trim(), parts[1].trim()); + } + + return trimmedKey; +} + +async function processImages(imageBase64: string[]): Promise> { + let images = imageBase64.filter((img) => img?.trim()); + if (images.length === 0) return []; + + // 压缩所有图片到10MB以内 + images = await Promise.all(images.map((img) => u.imageTools.compressImage(img, "10mb"))); + + // 参考主体数量和参考图片数量之和不得超过10 + if (images.length > 10) { + const mergeImageList = images.splice(9); + const mergedImage = await u.imageTools.mergeImages(mergeImageList, "10mb"); + images.push(mergedImage); + } + + return images.map((img) => ({ + image: img.replace(/^data:image\/[a-z]+;base64,/i, ""), + })); +} + +export default async (input: ImageConfig, config: AIConfig): Promise => { + if (!config.apiKey) throw new Error("缺少API Key"); + if (!input.prompt) throw new Error("缺少提示词,prompt为必填项"); + + const authorization = `Bearer ${getApiToken(config.apiKey)}`; + const baseURL = (config.baseURL ?? "https://api-beijing.klingai.com/v1/images/omni-image").replace(/\/+$/, ""); + const imageList = await processImages(input.imageBase64); + + const body: Record = { + model_name: config.model || "kling-image-o1", + prompt: input.prompt, + n: 1, + ...(input.size !== "4K" && { resolution: input.size.toLowerCase() }), + ...(imageList.length > 0 && { image_list: imageList }), + }; + + const headers = { + "Content-Type": "application/json", + Authorization: authorization, + }; + + try { + const { data: createData } = await axios.post(baseURL, body, { headers }); + + if (createData.code !== 0) { + throw new Error(createData.message || "创建任务失败"); + } + + const taskId = createData.data?.task_id; + if (!taskId) throw new Error("未获取到任务ID"); + + const queryUrl = `${baseURL}/${taskId}`; + return await pollTask(async () => { + const { data: queryData } = await axios.get(queryUrl, { headers }); + + if (queryData.code !== 0) { + return { completed: false, error: queryData.message || "查询任务失败" }; + } + + const { task_status, task_status_msg, task_result } = queryData.data || {}; + + if (task_status === "failed") { + return { completed: false, error: task_status_msg || "图片生成失败" }; + } + + if (task_status === "succeed") { + return { completed: true, imageUrl: task_result?.images?.[0]?.url }; + } + + return { completed: false }; + }); + } catch (error) { + throw new Error(u.error(error).message || "可灵图片生成失败"); + } +} diff --git a/src/utils/ai/image/owned/volcengine.ts b/src/utils/ai/image/owned/volcengine.ts new file mode 100644 index 0000000..2d93fdf --- /dev/null +++ b/src/utils/ai/image/owned/volcengine.ts @@ -0,0 +1,31 @@ +import "../type"; +import axios from "axios"; +import u from "@/utils"; + +export default async (input: ImageConfig, config: AIConfig): Promise => { + if (!config.model) throw new Error("缺少Model名称"); + if (!config.apiKey) throw new Error("缺少API Key"); + + const apiKey = "Bearer " + config.apiKey.replace(/Bearer\s+/g, "").trim(); + const size = input.size === "1K" ? "2K" : input.size; + + const body: Record = { + model: config.model, + prompt: input.prompt, + size, + response_format: "url", + sequential_image_generation: "disabled", + stream: false, + watermark: false, + ...(input.imageBase64 && { image: input.imageBase64 }), + }; + + const url = config.baseURL ?? "https://ark.cn-beijing.volces.com/api/v3/images/generations"; + try { + const { data } = await axios.post(url, body, { headers: { Authorization: apiKey } }); + return data.data[0]?.url; + } catch (error) { + const msg = u.error(error).message || "Volcengine 图片生成失败"; + throw new Error(msg); + } +} diff --git a/src/utils/ai/image/type.ts b/src/utils/ai/image/type.ts new file mode 100644 index 0000000..bbd7eec --- /dev/null +++ b/src/utils/ai/image/type.ts @@ -0,0 +1,14 @@ +interface ImageConfig { + systemPrompt?: string; + prompt: string; + imageBase64: string[]; + size: "1K" | "2K" | "4K"; + aspectRatio: string; + resType?: "url" | "b64"; +} + +interface AIConfig { + model?: string; + apiKey?: string; + baseURL?: string; +} \ No newline at end of file diff --git a/src/utils/ai/text.ts b/src/utils/ai/text/index.ts similarity index 100% rename from src/utils/ai/text.ts rename to src/utils/ai/text/index.ts diff --git a/src/utils/ai/modelList.ts b/src/utils/ai/text/modelList.ts similarity index 100% rename from src/utils/ai/modelList.ts rename to src/utils/ai/text/modelList.ts diff --git a/src/utils/ai/utils.ts b/src/utils/ai/utils.ts new file mode 100644 index 0000000..9bcc2a9 --- /dev/null +++ b/src/utils/ai/utils.ts @@ -0,0 +1,13 @@ +export const pollTask = async ( + queryFn: () => Promise<{ completed: boolean; imageUrl?: string; error?: string }>, + maxAttempts = 500, + interval = 2000, +): Promise => { + for (let i = 0; i < maxAttempts; i++) { + await new Promise((resolve) => setTimeout(resolve, interval)); + const { completed, imageUrl, error } = await queryFn(); + if (error) throw new Error(error); + if (completed && imageUrl) return imageUrl; + } + throw new Error(`任务轮询超时,已尝试 ${maxAttempts} 次`); +}; \ No newline at end of file diff --git a/src/utils/ai/video/owned/volcengine.ts b/src/utils/ai/video/owned/volcengine.ts new file mode 100644 index 0000000..8729fad --- /dev/null +++ b/src/utils/ai/video/owned/volcengine.ts @@ -0,0 +1,70 @@ +import "../type"; +import axios from "axios"; +import u from "@/utils"; + +interface DoubaoVideoConfig { + prompt: string; + savePath: string; + imageBase64?: string[]; // 单张参考图片 base64 + duration: 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12; // 支持 2~12 秒 + aspectRatio: "16:9" | "9:16" | "1:1" | "4:3" | "3:4" | "21:9" | "adaptive"; + audio?: boolean; +} + +const pollTask = async ( + queryFn: () => Promise<{ completed: boolean; imageUrl?: string; error?: string }>, + maxAttempts = 500, + interval = 2000, +): Promise => { + for (let i = 0; i < maxAttempts; i++) { + await new Promise((resolve) => setTimeout(resolve, interval)); + const { completed, imageUrl, error } = await queryFn(); + if (error) throw new Error(error); + if (completed && imageUrl) return imageUrl; + } + throw new Error(`任务轮询超时,已尝试 ${maxAttempts} 次`); +}; + +export default async (input: ImageConfig, config: AIConfig) => { + console.log("%c Line:5 🍓 input", "background:#7f2b82", input); + console.log("%c Line:5 🍎 config", "background:#93c0a4", config); + if (!config.model) throw new Error("缺少Model名称"); + if (!config.apiKey) throw new Error("缺少API Key"); + + const key = "Bearer " + config.apiKey.replaceAll("Bearer ", "").trim(); + + const doubaoConfig = config as DoubaoVideoConfig; + const createRes = await axios.post( + config.baseURL ?? "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks", + { + model: "doubao-seedance-1-5-pro-251215", + content: [ + { type: "text", text: input.prompt }, + ...(doubaoConfig.imageBase64 + ? doubaoConfig.imageBase64.map((base64, i) => ({ + type: "image_url", + image_url: { url: base64 }, + role: i === 0 ? "first_frame" : "last_frame", + })) + : []), + ], + generate_audio: doubaoConfig.audio ?? false, + duration: doubaoConfig.duration, + resolution: doubaoConfig.aspectRatio, + watermark: false, + }, + { headers: { "Content-Type": "application/json", Authorization: key } }, + ); + const taskId = createRes.data.id; + if (!taskId) throw new Error("视频任务创建失败"); + return await pollTask(async () => { + const res = await axios.get(`${config.baseURL ?? "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"}/${taskId}`, { + headers: { Authorization: key }, + }); + const { status, content } = res.data; + if (status === "succeeded") return { completed: true, imageUrl: content?.video_url }; + if (["failed", "cancelled", "expired"].includes(status)) return { completed: false, error: `任务${status}` }; + if (["queued", "running"].includes(status)) return { completed: false }; + return { completed: false, error: `未知状态: ${status}` }; + }); +}; diff --git a/src/utils/error.ts b/src/utils/error.ts new file mode 100644 index 0000000..a08789b --- /dev/null +++ b/src/utils/error.ts @@ -0,0 +1,68 @@ +// utils/error.ts +import { serializeError } from "serialize-error"; +import { isAxiosError } from "axios"; + +export interface NormalizedError { + name: string; + message: string; + code?: string; + status?: number; + stack?: string; + cause?: NormalizedError; + responseData?: unknown; + meta?: Record; +} + +export function normalizeError(error: unknown): NormalizedError { + // Axios 特殊处理 + if (isAxiosError(error)) { + return { + name: "AxiosError", + message: error.response?.data?.error?.message || error.response?.data?.message || error.message, + code: error.code, + status: error.response?.status, + stack: error.stack, + responseData: error.response?.data, + meta: { + url: error.config?.url, + method: error.config?.method, + }, + }; + } + + // 普通 Error,用 serialize-error 处理 + if (error instanceof Error) { + const serialized = serializeError(error); + return { + name: serialized.name || "Error", + message: serialized.message || "未知错误", + code: (serialized as any).code, + stack: serialized.stack, + cause: error.cause ? normalizeError(error.cause) : undefined, + meta: extractMeta(serialized), + }; + } + + // 非 Error + return { + name: "UnknownError", + message: String(error), + meta: { raw: serializeError(error) }, + }; +} + +// 提取自定义属性 +function extractMeta(obj: Record): Record | undefined { + const standardKeys = ["name", "message", "stack", "cause"]; + const meta: Record = {}; + + for (const [key, value] of Object.entries(obj)) { + if (!standardKeys.includes(key) && value !== undefined) { + meta[key] = value; + } + } + + return Object.keys(meta).length > 0 ? meta : undefined; +} + +export default normalizeError; diff --git a/src/utils/getConfig.ts b/src/utils/getConfig.ts index 3bf3859..86a7791 100644 --- a/src/utils/getConfig.ts +++ b/src/utils/getConfig.ts @@ -13,8 +13,9 @@ interface TextResData extends BaseConfig { manufacturer: "deepseek" | "openAi" | "doubao"; } +// 图像模型配置接口 interface ImageResData extends BaseConfig { - manufacturer: "openAi" | "gemini" | "volcengine" | "runninghub" | "apimart"; + manufacturer: "gemini" | "volcengine" | "kling" | "vidu" | "runninghub" | "apimart"; } interface VideoResData extends BaseConfig { diff --git a/src/utils/imageTools.ts b/src/utils/imageTools.ts new file mode 100644 index 0000000..82b519b --- /dev/null +++ b/src/utils/imageTools.ts @@ -0,0 +1,122 @@ +import sharp from "sharp"; + +/** + * 解析大小字符串为字节数 + */ +function parseSize(size: string): number { + const match = size.toLowerCase().match(/^(\d+(?:\.\d+)?)\s*(kb|mb|gb|b)?$/); + if (!match) { + throw new Error(`无效的大小格式: ${size}`); + } + const value = parseFloat(match[1]); + const unit = match[2] || "b"; + const multipliers: Record = { + b: 1, + kb: 1024, + mb: 1024 * 1024, + gb: 1024 * 1024 * 1024, + }; + return Math.floor(value * multipliers[unit]); +} + +/** + * 将base64字符串转换为Buffer + */ +function base64ToBuffer(base64: string): Buffer { + const base64Data = base64.replace(/^data:image\/\w+;base64,/, ""); + return Buffer.from(base64Data, "base64"); +} + +/** + * 压缩Buffer到指定大小以内 + */ +async function compressToSize(imageBuffer: Buffer, maxBytes: number, originalWidth: number, originalHeight: number): Promise { + let quality = 90; + let scale = 1; + + while (true) { + const targetWidth = Math.round(originalWidth * scale); + const targetHeight = Math.round(originalHeight * scale); + + const resultBuffer = await sharp(imageBuffer).resize(targetWidth, targetHeight, { fit: "fill" }).jpeg({ quality }).toBuffer(); + + if (resultBuffer.length <= maxBytes) { + return resultBuffer; + } + + if (quality > 10) { + quality -= 10; + } else { + quality = 90; + scale *= 0.8; + } + } +} + +/** + * 压缩单张图片到指定大小以内 + * @param imageBase64 - base64编码的图片 + * @param maxSize - 最大输出大小,支持格式如 "10mb", "5MB", "1024kb" 等 + * @returns 压缩后的图片base64字符串 + */ +export async function compressImage(imageBase64: string, maxSize = "10mb"): Promise { + const maxBytes = parseSize(maxSize); + const imageBuffer = base64ToBuffer(imageBase64); + const metadata = await sharp(imageBuffer).metadata(); + const resultBuffer = await compressToSize(imageBuffer, maxBytes, metadata.width || 1, metadata.height || 1); + return resultBuffer.toString("base64"); +} + +/** + * 将多张图片横向拼接为一张,并确保输出大小不超过指定限制 + * @param imageBase64List - base64编码的图片数组 + * @param maxSize - 最大输出大小,支持格式如 "10mb", "5MB", "1024kb" 等 + * @returns 拼接后的图片base64字符串 + */ +export async function mergeImages(imageBase64List: string[], maxSize = "10mb"): Promise { + if (imageBase64List.length === 0) { + throw new Error("图片列表不能为空"); + } + + const maxBytes = parseSize(maxSize); + const imageBuffers = imageBase64List.map(base64ToBuffer); + const imageMetadatas = await Promise.all(imageBuffers.map((buffer) => sharp(buffer).metadata())); + const maxHeight = Math.max(...imageMetadatas.map((m) => m.height || 0)); + + // 计算各图片调整后的宽度 + const imageWidths = imageMetadatas.map((metadata) => { + const aspectRatio = (metadata.width || 1) / (metadata.height || 1); + return Math.round(maxHeight * aspectRatio); + }); + const totalWidth = imageWidths.reduce((sum, w) => sum + w, 0); + + // 拼接图片 + const resizedImages = await Promise.all( + imageBuffers.map(async (buffer, index) => { + return sharp(buffer).resize(imageWidths[index], maxHeight, { fit: "cover" }).toBuffer(); + }), + ); + + let currentX = 0; + const compositeInputs = resizedImages.map((buffer, index) => { + const input = { input: buffer, left: currentX, top: 0 }; + currentX += imageWidths[index]; + return input; + }); + + const mergedBuffer = await sharp({ + create: { + width: totalWidth, + height: maxHeight, + channels: 4, + background: { r: 255, g: 255, b: 255, alpha: 1 }, + }, + }) + .composite(compositeInputs) + .jpeg({ quality: 90 }) + .toBuffer(); + + // 复用压缩逻辑 + const resultBuffer = await compressToSize(mergedBuffer, maxBytes, totalWidth, maxHeight); + return resultBuffer.toString("base64"); +} \ No newline at end of file diff --git a/yarn.lock b/yarn.lock index f52c0c3..a00a10f 100644 --- a/yarn.lock +++ b/yarn.lock @@ -4564,6 +4564,11 @@ nodemon@^3.1.11: touch "^3.1.0" undefsafe "^2.0.5" +non-error@^0.1.0: + version "0.1.0" + resolved "https://registry.npmmirror.com/non-error/-/non-error-0.1.0.tgz#b78b7d9a67ccb03ac979f9758813336ca7521cf2" + integrity sha512-TMB1uHiGsHRGv1uYclfhivcnf0/PdFp2pNqRxXjncaAsjYMoisaQJI+SSZCqRq+VliwRTC8tsMQfmrWjDMhkPQ== + nopt@^4.0.1: version "4.0.3" resolved "https://registry.npmmirror.com/nopt/-/nopt-4.0.3.tgz#a375cad9d02fd921278d954c2254d5aa57e15e48" @@ -5330,6 +5335,14 @@ send@^1.1.0, send@^1.2.0: range-parser "^1.2.1" statuses "^2.0.2" +serialize-error@^13.0.1: + version "13.0.1" + resolved "https://registry.npmmirror.com/serialize-error/-/serialize-error-13.0.1.tgz#dd1e1bf6d3e3d01037d126bd95e919f48b0c8ec0" + integrity sha512-bBZaRwLH9PN5HbLCjPId4dP5bNGEtumcErgOX952IsvOhVPrm3/AeK1y0UHA/QaPG701eg0yEnOKsCOC6X/kaA== + dependencies: + non-error "^0.1.0" + type-fest "^5.4.1" + serialize-error@^7.0.1: version "7.0.1" resolved "https://registry.npmmirror.com/serialize-error/-/serialize-error-7.0.1.tgz#f1360b0447f61ffb483ec4157c737fab7d778e18" @@ -5760,6 +5773,11 @@ supports-preserve-symlinks-flag@^1.0.0: resolved "https://registry.npmmirror.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09" integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w== +tagged-tag@^1.0.0: + version "1.0.0" + resolved "https://registry.npmmirror.com/tagged-tag/-/tagged-tag-1.0.0.tgz#a0b5917c2864cba54841495abfa3f6b13edcf4d6" + integrity sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng== + tar-fs@^2.0.0: version "2.1.4" resolved "https://registry.npmmirror.com/tar-fs/-/tar-fs-2.1.4.tgz#800824dbf4ef06ded9afea4acafe71c67c76b930" @@ -5929,6 +5947,13 @@ type-fest@^0.13.1: resolved "https://registry.npmmirror.com/type-fest/-/type-fest-0.13.1.tgz#0172cb5bce80b0bd542ea348db50c7e21834d934" integrity sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg== +type-fest@^5.4.1: + version "5.4.3" + resolved "https://registry.npmmirror.com/type-fest/-/type-fest-5.4.3.tgz#b4c7e028da129098911ee2162a0c30df8a1be904" + integrity sha512-AXSAQJu79WGc79/3e9/CR77I/KQgeY1AhNvcShIH4PTcGYyC4xv6H4R4AUOwkPS5799KlVDAu8zExeCrkGquiA== + dependencies: + tagged-tag "^1.0.0" + type-is@^2.0.1: version "2.0.1" resolved "https://registry.npmmirror.com/type-is/-/type-is-2.0.1.tgz#64f6cf03f92fce4015c2b224793f6bdd4b068c97"