完善分镜langchain移除,补充图片生成厂商
This commit is contained in:
parent
f29704a4ad
commit
88a58b886b
@ -49,6 +49,7 @@
|
|||||||
"langchain": "^1.2.10",
|
"langchain": "^1.2.10",
|
||||||
"morgan": "^1.10.1",
|
"morgan": "^1.10.1",
|
||||||
"qwen-ai-provider": "^0.1.1",
|
"qwen-ai-provider": "^0.1.1",
|
||||||
|
"serialize-error": "^13.0.1",
|
||||||
"sharp": "^0.34.5",
|
"sharp": "^0.34.5",
|
||||||
"sqlite3": "^5.1.7",
|
"sqlite3": "^5.1.7",
|
||||||
"zhipu-ai-provider": "^0.2.2",
|
"zhipu-ai-provider": "^0.2.2",
|
||||||
|
|||||||
@ -1,10 +1,8 @@
|
|||||||
// @/agents/Storyboard.ts
|
// @/agents/Storyboard.ts
|
||||||
import u from "@/utils";
|
import u from "@/utils";
|
||||||
import { createAgent } from "langchain";
|
import { tool, ModelMessage, Tool } from "ai";
|
||||||
import { EventEmitter } from "events";
|
import { EventEmitter } from "events";
|
||||||
import { openAI } from "@/agents/models";
|
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
import { tool } from "@langchain/core/tools";
|
|
||||||
import type { DB } from "@/types/database";
|
import type { DB } from "@/types/database";
|
||||||
import generateImageTool from "./generateImageTool";
|
import generateImageTool from "./generateImageTool";
|
||||||
import imageSplitting from "./imageSplitting";
|
import imageSplitting from "./imageSplitting";
|
||||||
@ -46,7 +44,7 @@ export default class Storyboard {
|
|||||||
private readonly projectId: number;
|
private readonly projectId: number;
|
||||||
private readonly scriptId: number;
|
private readonly scriptId: number;
|
||||||
readonly emitter = new EventEmitter();
|
readonly emitter = new EventEmitter();
|
||||||
history: Array<[string, string]> = [];
|
history: ModelMessage[] = [];
|
||||||
novelChapters: DB["t_novel"][] = [];
|
novelChapters: DB["t_novel"][] = [];
|
||||||
|
|
||||||
// 存储 segmentAgent 生成的片段结果
|
// 存储 segmentAgent 生成的片段结果
|
||||||
@ -58,10 +56,6 @@ export default class Storyboard {
|
|||||||
// 存储正在生成分镜图的分镜ID
|
// 存储正在生成分镜图的分镜ID
|
||||||
private generatingShots: Set<number> = new Set();
|
private generatingShots: Set<number> = new Set();
|
||||||
|
|
||||||
modelName = "gpt-4.1";
|
|
||||||
apiKey = "";
|
|
||||||
baseURL = "";
|
|
||||||
|
|
||||||
constructor(projectId: number, scriptId: number) {
|
constructor(projectId: number, scriptId: number) {
|
||||||
this.projectId = projectId;
|
this.projectId = projectId;
|
||||||
this.scriptId = scriptId;
|
this.scriptId = scriptId;
|
||||||
@ -105,28 +99,28 @@ export default class Storyboard {
|
|||||||
|
|
||||||
// ==================== 剧本相关操作 ====================
|
// ==================== 剧本相关操作 ====================
|
||||||
|
|
||||||
getScript = tool(
|
getScript = tool({
|
||||||
async () => {
|
title: "getScript",
|
||||||
|
description: "获取剧本内容",
|
||||||
|
inputSchema: z.object({}),
|
||||||
|
execute: async () => {
|
||||||
this.log("获取剧本", `scriptId: ${this.scriptId}`);
|
this.log("获取剧本", `scriptId: ${this.scriptId}`);
|
||||||
const script = await u.db("t_script").where({ id: this.scriptId, projectId: this.projectId }).first();
|
const script = await u.db("t_script").where({ id: this.scriptId, projectId: this.projectId }).first();
|
||||||
if (!script) throw new Error("剧本不存在");
|
if (!script) throw new Error("剧本不存在");
|
||||||
return `剧本集:${script.name}\n\n内容:\n\`\`\`${script.content}\`\`\``;
|
return `剧本集:${script.name}\n\n内容:\n\`\`\`${script.content}\`\`\``;
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "getScript",
|
|
||||||
description: "获取剧本内容",
|
|
||||||
schema: z.object({}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// ==================== 资产相关操作 ====================
|
// ==================== 资产相关操作 ====================
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取资产列表(供 segmentAgent 和 shotAgent 调用)
|
* 获取资产列表(供 segmentAgent 和 shotAgent 调用)
|
||||||
*/
|
*/
|
||||||
getAssets = tool(
|
getAssets = tool({
|
||||||
async () => {
|
title: "getAssets",
|
||||||
|
description: "获取资产列表(角色、道具、场景),包含名称和详细介绍。生成片段和分镜时必须先调用此工具获取资产信息,确保名称一致性",
|
||||||
|
inputSchema: z.object({}),
|
||||||
|
execute: async () => {
|
||||||
this.log("获取资产列表", `scriptId: ${this.scriptId}`);
|
this.log("获取资产列表", `scriptId: ${this.scriptId}`);
|
||||||
const scriptData = await u.db("t_script").where({ id: this.scriptId, projectId: this.projectId }).first();
|
const scriptData = await u.db("t_script").where({ id: this.scriptId, projectId: this.projectId }).first();
|
||||||
const row = await u.db("t_outline").where({ id: scriptData?.outlineId!, projectId: this.projectId }).first();
|
const row = await u.db("t_outline").where({ id: scriptData?.outlineId!, projectId: this.projectId }).first();
|
||||||
@ -171,69 +165,69 @@ ${sections.join("\n\n")}
|
|||||||
2. 禁止在资产名称前后添加修饰词
|
2. 禁止在资产名称前后添加修饰词
|
||||||
3. 禁止捏造资产列表中不存在的角色、场景、道具`;
|
3. 禁止捏造资产列表中不存在的角色、场景、道具`;
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "getAssets",
|
|
||||||
description: "获取资产列表(角色、道具、场景),包含名称和详细介绍。生成片段和分镜时必须先调用此工具获取资产信息,确保名称一致性",
|
|
||||||
schema: z.object({}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
// ==================== 片段和分镜工具 ====================
|
// ==================== 片段和分镜工具 ====================
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取当前存储的片段数据(供 shotAgent 调用)
|
* 获取当前存储的片段数据(供 shotAgent 调用)
|
||||||
*/
|
*/
|
||||||
getSegments = tool(
|
getSegments = tool({
|
||||||
async () => {
|
title: "getSegments",
|
||||||
|
description: "获取当前已生成的片段数据,用于生成分镜",
|
||||||
|
inputSchema: z.object({}),
|
||||||
|
execute: async () => {
|
||||||
this.log("获取片段数据", `共 ${this.segments.length} 个片段`);
|
this.log("获取片段数据", `共 ${this.segments.length} 个片段`);
|
||||||
if (this.segments.length === 0) {
|
if (this.segments.length === 0) {
|
||||||
return "暂无片段数据,请先调用 segmentAgent 生成片段";
|
return "暂无片段数据,请先调用 segmentAgent 生成片段";
|
||||||
}
|
}
|
||||||
return JSON.stringify(this.segments, null, 2);
|
return JSON.stringify(this.segments, null, 2);
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "getSegments",
|
|
||||||
description: "获取当前已生成的片段数据,用于生成分镜",
|
|
||||||
schema: z.object({}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 更新/存储片段数据(供 segmentAgent 调用)
|
* 更新/存储片段数据(供 segmentAgent 调用)
|
||||||
*/
|
*/
|
||||||
updateSegments = tool(
|
updateSegments = tool({
|
||||||
async ({ segments }: { segments: Segment[] }) => {
|
title: "updateSegments",
|
||||||
|
description: "存储生成的片段数据,segmentAgent 在生成片段后必须调用此工具保存结果",
|
||||||
|
inputSchema: z.object({
|
||||||
|
segments: z
|
||||||
|
.array(
|
||||||
|
z.object({
|
||||||
|
index: z.number().describe("片段序号"),
|
||||||
|
description: z.string().describe("片段描述"),
|
||||||
|
emotion: z.string().optional().describe("情绪氛围"),
|
||||||
|
action: z.string().optional().describe("主要动作"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.describe("片段数组"),
|
||||||
|
}),
|
||||||
|
execute: async ({ segments }: { segments: Segment[] }) => {
|
||||||
this.log("更新片段数据", `共 ${segments.length} 个片段`);
|
this.log("更新片段数据", `共 ${segments.length} 个片段`);
|
||||||
this.segments = segments;
|
this.segments = segments;
|
||||||
this.emit("segmentsUpdated", this.segments);
|
this.emit("segmentsUpdated", this.segments);
|
||||||
return `成功存储 ${segments.length} 个片段`;
|
return `成功存储 ${segments.length} 个片段`;
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "updateSegments",
|
|
||||||
description: "存储生成的片段数据,segmentAgent 在生成片段后必须调用此工具保存结果",
|
|
||||||
schema: z.object({
|
|
||||||
segments: z
|
|
||||||
.array(
|
|
||||||
z.object({
|
|
||||||
index: z.number().describe("片段序号"),
|
|
||||||
description: z.string().describe("片段描述"),
|
|
||||||
emotion: z.string().optional().describe("情绪氛围"),
|
|
||||||
action: z.string().optional().describe("主要动作"),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.describe("片段数组"),
|
|
||||||
}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 添加分镜(供 shotAgent 调用)
|
* 添加分镜(供 shotAgent 调用)
|
||||||
*/
|
*/
|
||||||
addShots = tool(
|
addShots = tool({
|
||||||
async ({ shots }: { shots: Array<{ segmentIndex: number; prompts: string[] }> }) => {
|
title: "addShots",
|
||||||
|
description: "添加新的分镜。每个分镜有独立ID,包含多个镜头(每个镜头对应一个提示词)。如果片段已存在分镜会跳过",
|
||||||
|
inputSchema: z.object({
|
||||||
|
shots: z
|
||||||
|
.array(
|
||||||
|
z.object({
|
||||||
|
segmentIndex: z.number().describe("对应的片段序号"),
|
||||||
|
prompts: z.array(z.string()).describe("镜头提示词数组,每个提示词对应一个镜头(中文)"),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.describe("要添加的分镜数组"),
|
||||||
|
}),
|
||||||
|
execute: async ({ shots }: { shots: Array<{ segmentIndex: number; prompts: string[] }> }) => {
|
||||||
const added: { id: number; segmentIndex: number }[] = [];
|
const added: { id: number; segmentIndex: number }[] = [];
|
||||||
const skipped: number[] = [];
|
const skipped: number[] = [];
|
||||||
|
|
||||||
@ -266,29 +260,20 @@ ${sections.join("\n\n")}
|
|||||||
}
|
}
|
||||||
return `已添加${addedInfo}。当前共 ${this.shots.length} 个分镜`;
|
return `已添加${addedInfo}。当前共 ${this.shots.length} 个分镜`;
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "addShots",
|
|
||||||
description: "添加新的分镜。每个分镜有独立ID,包含多个镜头(每个镜头对应一个提示词)。如果片段已存在分镜会跳过",
|
|
||||||
schema: z.object({
|
|
||||||
shots: z
|
|
||||||
.array(
|
|
||||||
z.object({
|
|
||||||
segmentIndex: z.number().describe("对应的片段序号"),
|
|
||||||
prompts: z.array(z.string()).describe("镜头提示词数组,每个提示词对应一个镜头(中文)"),
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
.describe("要添加的分镜数组"),
|
|
||||||
}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 更新指定分镜(供 shotAgent 调用)
|
* 更新指定分镜(供 shotAgent 调用)
|
||||||
* 保留原有 cells 的 id 和 src 字段,只更新 prompt
|
* 保留原有 cells 的 id 和 src 字段,只更新 prompt
|
||||||
*/
|
*/
|
||||||
updateShots = tool(
|
updateShots = tool({
|
||||||
async ({ shotId, prompts }: { shotId: number; prompts: string[] }) => {
|
title: "updateShots",
|
||||||
|
description: "更新指定分镜的镜头提示词。通过分镜ID指定要修改的分镜",
|
||||||
|
inputSchema: z.object({
|
||||||
|
shotId: z.number().describe("要更新的分镜ID"),
|
||||||
|
prompts: z.array(z.string()).describe("新的镜头提示词数组,每个提示词对应一个镜头"),
|
||||||
|
}),
|
||||||
|
execute: async ({ shotId, prompts }: { shotId: number; prompts: string[] }) => {
|
||||||
const existingIndex = this.shots.findIndex((item) => item.id === shotId);
|
const existingIndex = this.shots.findIndex((item) => item.id === shotId);
|
||||||
|
|
||||||
if (existingIndex === -1) {
|
if (existingIndex === -1) {
|
||||||
@ -314,22 +299,18 @@ ${sections.join("\n\n")}
|
|||||||
|
|
||||||
return `已更新分镜 ${shotId}`;
|
return `已更新分镜 ${shotId}`;
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "updateShots",
|
|
||||||
description: "更新指定分镜的镜头提示词。通过分镜ID指定要修改的分镜",
|
|
||||||
schema: z.object({
|
|
||||||
shotId: z.number().describe("要更新的分镜ID"),
|
|
||||||
prompts: z.array(z.string()).describe("新的镜头提示词数组,每个提示词对应一个镜头"),
|
|
||||||
}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 删除指定分镜(供 shotAgent 调用)
|
* 删除指定分镜(供 shotAgent 调用)
|
||||||
*/
|
*/
|
||||||
deleteShots = tool(
|
deleteShots = tool({
|
||||||
async ({ shotIds }: { shotIds: number[] }) => {
|
title: "deleteShots",
|
||||||
|
description: "删除指定的分镜。通过分镜ID指定要删除的分镜",
|
||||||
|
inputSchema: z.object({
|
||||||
|
shotIds: z.array(z.number()).describe("要删除的分镜ID数组"),
|
||||||
|
}),
|
||||||
|
execute: async ({ shotIds }: { shotIds: number[] }) => {
|
||||||
const deleted: number[] = [];
|
const deleted: number[] = [];
|
||||||
const notFound: number[] = [];
|
const notFound: number[] = [];
|
||||||
|
|
||||||
@ -351,21 +332,19 @@ ${sections.join("\n\n")}
|
|||||||
}
|
}
|
||||||
return `已删除分镜 ${deleted.join(", ")}。当前共 ${this.shots.length} 个分镜`;
|
return `已删除分镜 ${deleted.join(", ")}。当前共 ${this.shots.length} 个分镜`;
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "deleteShots",
|
|
||||||
description: "删除指定的分镜。通过分镜ID指定要删除的分镜",
|
|
||||||
schema: z.object({
|
|
||||||
shotIds: z.array(z.number()).describe("要删除的分镜ID数组"),
|
|
||||||
}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 生成分镜图(异步执行,使用 nanoBanana)
|
* 生成分镜图(异步执行,使用 nanoBanana)
|
||||||
*/
|
*/
|
||||||
generateShotImage = tool(
|
generateShotImage = tool({
|
||||||
async ({ shotIds }: { shotIds: number[] }) => {
|
title: "generateShotImage",
|
||||||
|
description:
|
||||||
|
"为指定分镜生成分镜图。每个分镜会根据其所有提示词生成一张完整宫格图,然后自动分割为单格图片。通过分镜ID指定,不需要指定具体格子,整个分镜是一个完整的生成单元",
|
||||||
|
inputSchema: z.object({
|
||||||
|
shotIds: z.array(z.number()).describe("要生成分镜图的分镜ID数组"),
|
||||||
|
}),
|
||||||
|
execute: async ({ shotIds }: { shotIds: number[] }) => {
|
||||||
const toGenerate: number[] = [];
|
const toGenerate: number[] = [];
|
||||||
const alreadyGenerating: number[] = [];
|
const alreadyGenerating: number[] = [];
|
||||||
const notFound: number[] = [];
|
const notFound: number[] = [];
|
||||||
@ -417,16 +396,7 @@ ${sections.join("\n\n")}
|
|||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
},
|
},
|
||||||
{
|
});
|
||||||
name: "generateShotImage",
|
|
||||||
description:
|
|
||||||
"为指定分镜生成分镜图。每个分镜会根据其所有提示词生成一张完整宫格图,然后自动分割为单格图片。通过分镜ID指定,不需要指定具体格子,整个分镜是一个完整的生成单元",
|
|
||||||
schema: z.object({
|
|
||||||
shotIds: z.array(z.number()).describe("要生成分镜图的分镜ID数组"),
|
|
||||||
}),
|
|
||||||
verboseParsingErrors: true,
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 执行分镜图生成的具体逻辑(异步并发)
|
* 执行分镜图生成的具体逻辑(异步并发)
|
||||||
@ -566,7 +536,7 @@ ${assetList}
|
|||||||
|
|
||||||
private buildConversationHistory(): string {
|
private buildConversationHistory(): string {
|
||||||
if (!this.history.length) return "无对话历史";
|
if (!this.history.length) return "无对话历史";
|
||||||
return this.history.map(([role, content]) => `${role}: ${content}`).join("\n\n");
|
return this.history.map(({ role, content }) => `${role}: ${content}`).join("\n\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
private async buildFullContext(task: string): Promise<string> {
|
private async buildFullContext(task: string): Promise<string> {
|
||||||
@ -586,26 +556,33 @@ ${task}
|
|||||||
|
|
||||||
// ==================== Sub-Agent ====================
|
// ==================== Sub-Agent ====================
|
||||||
|
|
||||||
private createModel() {
|
|
||||||
return openAI({
|
|
||||||
modelName: this.modelName,
|
|
||||||
configuration: { apiKey: this.apiKey, baseURL: this.baseURL },
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取不同 Sub-Agent 可用的工具
|
* 获取不同 Sub-Agent 可用的工具
|
||||||
*/
|
*/
|
||||||
private getSubAgentTools(agentType: AgentType) {
|
private getSubAgentTools(agentType: AgentType): Record<string, Tool> {
|
||||||
switch (agentType) {
|
switch (agentType) {
|
||||||
case "segmentAgent":
|
case "segmentAgent":
|
||||||
// segmentAgent 可以获取剧本和资产,并需要调用 updateSegments 保存结果
|
// segmentAgent 可以获取剧本和资产,并需要调用 updateSegments 保存结果
|
||||||
return [this.getScript, this.getAssets, this.updateSegments];
|
return {
|
||||||
|
getScript: this.getScript,
|
||||||
|
getAssets: this.getAssets,
|
||||||
|
updateSegments: this.updateSegments,
|
||||||
|
};
|
||||||
case "shotAgent":
|
case "shotAgent":
|
||||||
// shotAgent 可以获取剧本、资产和片段,并可使用 add/update/delete 操作分镜,以及生成分镜图
|
// shotAgent 可以获取剧本、资产和片段,并可使用 add/update/delete 操作分镜,以及生成分镜图
|
||||||
return [this.getScript, this.getAssets, this.getSegments, this.addShots, this.updateShots, this.deleteShots, this.generateShotImage];
|
return {
|
||||||
|
getScript: this.getScript,
|
||||||
|
getAssets: this.getAssets,
|
||||||
|
getSegments: this.getSegments,
|
||||||
|
addShots: this.addShots,
|
||||||
|
updateShots: this.updateShots,
|
||||||
|
deleteShots: this.deleteShots,
|
||||||
|
generateShotImage: this.generateShotImage,
|
||||||
|
};
|
||||||
default:
|
default:
|
||||||
return [this.getScript];
|
return {
|
||||||
|
getScript: this.getScript,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -627,74 +604,71 @@ ${task}
|
|||||||
|
|
||||||
const context = await this.buildFullContext(task);
|
const context = await this.buildFullContext(task);
|
||||||
|
|
||||||
const agent = createAgent({
|
const { fullStream } = await u.ai.text.stream({
|
||||||
model: this.createModel(),
|
system: SYSTEM_PROMPTS[agentType],
|
||||||
systemPrompt: SYSTEM_PROMPTS[agentType],
|
|
||||||
tools: this.getSubAgentTools(agentType),
|
tools: this.getSubAgentTools(agentType),
|
||||||
|
messages: [{ role: "user", content: context }],
|
||||||
|
maxStep: 100,
|
||||||
});
|
});
|
||||||
|
|
||||||
const stream = await agent.stream({ messages: [["user", context]] }, { streamMode: ["messages"], callbacks: [] });
|
|
||||||
|
|
||||||
let fullResponse = "";
|
let fullResponse = "";
|
||||||
|
for await (const item of fullStream) {
|
||||||
for await (const [mode, chunk] of stream) {
|
if (item.type == "tool-call") {
|
||||||
if (mode !== "messages") continue;
|
this.emit("toolCall", { agent: "main", name: item.title, args: null });
|
||||||
const [token] = chunk as any;
|
|
||||||
const block = token.contentBlocks?.[0];
|
|
||||||
|
|
||||||
// 处理 AI 文本流
|
|
||||||
if (token.type === "ai" && block?.text) {
|
|
||||||
fullResponse += block.text;
|
|
||||||
this.emit("subAgentStream", { agent: agentType, text: block.text });
|
|
||||||
}
|
}
|
||||||
// 处理 tool 调用
|
if (item.type == "text-delta") {
|
||||||
if (token.type === "ai" && token.tool_calls?.length) {
|
fullResponse += item.text;
|
||||||
for (const toolCall of token.tool_calls) {
|
this.emit("subAgentStream", { agent: agentType, text: item.text });
|
||||||
this.emit("toolCall", { agent: agentType, name: toolCall.name, args: toolCall.args });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
this.emit("subAgentEnd", { agent: agentType });
|
this.emit("subAgentEnd", { agent: agentType });
|
||||||
this.history.push(["ai", fullResponse]);
|
this.history.push({
|
||||||
|
role: "assistant",
|
||||||
|
content: fullResponse,
|
||||||
|
});
|
||||||
this.log(`Sub-Agent 完成`, agentType);
|
this.log(`Sub-Agent 完成`, agentType);
|
||||||
return fullResponse;
|
|
||||||
|
return fullResponse ?? `${agentType}已完成任务`;
|
||||||
}
|
}
|
||||||
|
|
||||||
private createSubAgentTool(agentType: AgentType, description: string) {
|
private createSubAgentTool(agentType: AgentType, description: string) {
|
||||||
return tool(async ({ taskDescription }) => this.invokeSubAgent(agentType, taskDescription), {
|
return tool({
|
||||||
name: agentType,
|
title: agentType,
|
||||||
description,
|
description,
|
||||||
schema: z.object({
|
inputSchema: z.object({
|
||||||
taskDescription: z.string().describe("具体的任务描述,包含章节范围、修改要求等详细信息"),
|
taskDescription: z.string().describe("具体的任务描述,包含章节范围、修改要求等详细信息"),
|
||||||
}),
|
}),
|
||||||
|
execute: async ({ taskDescription }) => this.invokeSubAgent(agentType, taskDescription),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// ==================== 主入口 ====================
|
// ==================== 主入口 ====================
|
||||||
|
|
||||||
private getAllTools() {
|
private getAllTools() {
|
||||||
return [
|
return {
|
||||||
this.createSubAgentTool(
|
segmentAgent: this.createSubAgentTool(
|
||||||
"segmentAgent",
|
"segmentAgent",
|
||||||
"调用片段师。负责根据剧本生成片段,会自行调用 getScript 获取剧本内容,并调用 updateSegments 保存片段结果。",
|
"调用片段师。负责根据剧本生成片段,会自行调用 getScript 获取剧本内容,并调用 updateSegments 保存片段结果。",
|
||||||
),
|
),
|
||||||
this.createSubAgentTool(
|
shotAgent: this.createSubAgentTool(
|
||||||
"shotAgent",
|
"shotAgent",
|
||||||
"调用分镜师。负责根据片段生成分镜提示词,会自行调用 getSegments 获取片段数据,并调用 addShots/updateShots 保存分镜结果。",
|
"调用分镜师。负责根据片段生成分镜提示词,会自行调用 getSegments 获取片段数据,并调用 addShots/updateShots 保存分镜结果。",
|
||||||
),
|
),
|
||||||
// this.createSubAgentTool("director", "调用导演。负责审核故事线和大纲,会自行调用 updateOutline 或 saveStoryline 进行修改。"),
|
// this.createSubAgentTool("director", "调用导演。负责审核故事线和大纲,会自行调用 updateOutline 或 saveStoryline 进行修改。"),
|
||||||
this.getScript,
|
getScript: this.getScript,
|
||||||
this.getSegments,
|
getSegments: this.getSegments,
|
||||||
this.generateShotImage,
|
generateShotImage: this.generateShotImage,
|
||||||
...this.getSubAgentTools("segmentAgent"),
|
...this.getSubAgentTools("segmentAgent"),
|
||||||
...this.getSubAgentTools("shotAgent"),
|
...this.getSubAgentTools("shotAgent"),
|
||||||
];
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
async call(msg: string): Promise<string> {
|
async call(msg: string): Promise<string> {
|
||||||
console.log("模型名称:", this.modelName);
|
this.history.push({
|
||||||
this.history.push(["user", msg]);
|
role: "user",
|
||||||
|
content: msg,
|
||||||
|
});
|
||||||
|
|
||||||
const envContext = await this.buildEnvironmentContext();
|
const envContext = await this.buildEnvironmentContext();
|
||||||
|
|
||||||
@ -702,34 +676,28 @@ ${task}
|
|||||||
|
|
||||||
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么,请直接输出Agent配置异常";
|
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么,请直接输出Agent配置异常";
|
||||||
|
|
||||||
const mainAgent = createAgent({
|
const { fullStream } = await u.ai.text.stream({
|
||||||
model: this.createModel(),
|
system: `${envContext}\n${mainPrompts}`,
|
||||||
tools: this.getAllTools(),
|
tools: this.getAllTools(),
|
||||||
systemPrompt: `${envContext}\n${mainPrompts}`,
|
messages: this.history,
|
||||||
|
maxStep: 100,
|
||||||
});
|
});
|
||||||
const stream = await mainAgent.stream({ messages: this.history }, { streamMode: ["messages"], callbacks: [] });
|
|
||||||
|
|
||||||
let fullResponse = "";
|
let fullResponse = "";
|
||||||
|
for await (const item of fullStream) {
|
||||||
for await (const [mode, chunk] of stream) {
|
if (item.type == "tool-call") {
|
||||||
if (mode !== "messages") continue;
|
this.emit("toolCall", { agent: "main", name: item.title, args: null });
|
||||||
const [token] = chunk as any;
|
|
||||||
const block = token.contentBlocks?.[0];
|
|
||||||
// 处理 AI 文本流
|
|
||||||
if (token.type === "ai" && block?.text) {
|
|
||||||
fullResponse += block.text;
|
|
||||||
this.emit("data", block.text);
|
|
||||||
}
|
}
|
||||||
|
if (item.type == "text-delta") {
|
||||||
// 处理 tool 调用
|
fullResponse += item.text;
|
||||||
if (token.type === "ai" && token.tool_calls?.length) {
|
this.emit("data", item.text);
|
||||||
for (const toolCall of token.tool_calls) {
|
|
||||||
this.emit("toolCall", { agent: "main", name: toolCall.name, args: toolCall.args });
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
this.history.push({
|
||||||
|
role: "assistant",
|
||||||
|
content: fullResponse,
|
||||||
|
});
|
||||||
|
|
||||||
this.history.push(["assistant", fullResponse]);
|
|
||||||
this.emit("response", fullResponse);
|
this.emit("response", fullResponse);
|
||||||
|
|
||||||
return fullResponse;
|
return fullResponse;
|
||||||
|
|||||||
@ -16,25 +16,34 @@ export default router.post(
|
|||||||
}),
|
}),
|
||||||
async (req, res) => {
|
async (req, res) => {
|
||||||
const { modelName, apiKey, baseURL, manufacturer } = req.body;
|
const { modelName, apiKey, baseURL, manufacturer } = req.body;
|
||||||
try {
|
|
||||||
const contentStr = await u.ai.generateImage(
|
const image =await u.ai.image({
|
||||||
{
|
prompt: "2D cat",
|
||||||
prompt: "2D cat",
|
imageBase64: [],
|
||||||
imageBase64: [],
|
aspectRatio: "16:9",
|
||||||
aspectRatio: "16:9",
|
size: "1K",
|
||||||
size: "1K",
|
});
|
||||||
},
|
res.status(200).send(success(image));
|
||||||
{
|
|
||||||
model: modelName,
|
// try {
|
||||||
apiKey,
|
// const contentStr = await u.ai.generateImage(
|
||||||
baseURL,
|
// {
|
||||||
manufacturer,
|
// prompt: "2D cat",
|
||||||
},
|
// imageBase64: [],
|
||||||
);
|
// aspectRatio: "16:9",
|
||||||
res.status(200).send(success(contentStr));
|
// size: "1K",
|
||||||
} catch (err: any) {
|
// },
|
||||||
const message = err?.response?.data?.error?.message || err?.error?.message || "模型调用失败";
|
// {
|
||||||
res.status(500).send(error(message));
|
// model: modelName,
|
||||||
}
|
// apiKey,
|
||||||
|
// baseURL,
|
||||||
|
// manufacturer,
|
||||||
|
// },
|
||||||
|
// );
|
||||||
|
// res.status(200).send(success(contentStr));
|
||||||
|
// } catch (err: any) {
|
||||||
|
// const message = err?.response?.data?.error?.message || err?.error?.message || "模型调用失败";
|
||||||
|
// res.status(500).send(error(message));
|
||||||
|
// }
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|||||||
@ -6,18 +6,24 @@ import number2Chinese from "@/utils/number2Chinese";
|
|||||||
import deleteOutline from "@/utils/deleteOutline";
|
import deleteOutline from "@/utils/deleteOutline";
|
||||||
import getConfig from "./utils/getConfig";
|
import getConfig from "./utils/getConfig";
|
||||||
import { v4 as uuid } from "uuid";
|
import { v4 as uuid } from "uuid";
|
||||||
|
import error from "@/utils/error";
|
||||||
|
import * as imageTools from "@/utils/imageTools";
|
||||||
|
|
||||||
import AIText from "@/utils/ai/text";
|
import AIText from "@/utils/ai/text/index";
|
||||||
|
import AIImage from "@/utils/ai/image/index";
|
||||||
|
|
||||||
export default {
|
export default {
|
||||||
db,
|
db,
|
||||||
oss,
|
oss,
|
||||||
ai: {
|
ai: {
|
||||||
text: AIText,
|
text: AIText,
|
||||||
|
image: AIImage,
|
||||||
},
|
},
|
||||||
editImage,
|
editImage,
|
||||||
number2Chinese,
|
number2Chinese,
|
||||||
deleteOutline,
|
deleteOutline,
|
||||||
getConfig,
|
getConfig,
|
||||||
uuid,
|
uuid,
|
||||||
|
error,
|
||||||
|
imageTools,
|
||||||
};
|
};
|
||||||
|
|||||||
44
src/utils/ai/image/index.ts
Normal file
44
src/utils/ai/image/index.ts
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
import "./type";
|
||||||
|
import u from "@/utils";
|
||||||
|
import modelList from "./modelList";
|
||||||
|
import axios from "axios";
|
||||||
|
|
||||||
|
import volcengine from "./owned/volcengine";
|
||||||
|
import kling from "./owned/kling";
|
||||||
|
|
||||||
|
|
||||||
|
interface AIConfig {
|
||||||
|
model?: string;
|
||||||
|
apiKey?: string;
|
||||||
|
baseURL?: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const urlToBase64 = async (url: string): Promise<string> => {
|
||||||
|
const res = await axios.get(url, { responseType: "arraybuffer" });
|
||||||
|
const base64 = Buffer.from(res.data).toString("base64");
|
||||||
|
const mimeType = res.headers["content-type"] || "image/png";
|
||||||
|
return `data:${mimeType};base64,${base64}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
const modelInstance = {
|
||||||
|
gemini: null,
|
||||||
|
volcengine: volcengine,
|
||||||
|
kling: kling,
|
||||||
|
vidu: null,
|
||||||
|
runninghub: null,
|
||||||
|
apimart: null,
|
||||||
|
} as const;
|
||||||
|
|
||||||
|
export default async (input: ImageConfig, config?: AIConfig) => {
|
||||||
|
const sqlTextModelConfig = await u.getConfig("image");
|
||||||
|
const { model, apiKey, baseURL, manufacturer } = { ...sqlTextModelConfig, ...config };
|
||||||
|
const manufacturerFn = modelInstance[manufacturer as keyof typeof modelInstance];
|
||||||
|
if (!manufacturerFn) if (!manufacturerFn) throw new Error("不支持的图片厂商");
|
||||||
|
const owned = modelList.find((m) => m.model === model);
|
||||||
|
if (!owned) throw new Error("不支持的模型");
|
||||||
|
|
||||||
|
let imageUrl = await manufacturerFn(input, { model, apiKey, baseURL });
|
||||||
|
if (!input.resType) input.resType = "b64";
|
||||||
|
if (input.resType === "b64" && imageUrl.startsWith("http")) imageUrl = await urlToBase64(imageUrl);
|
||||||
|
return input;
|
||||||
|
};
|
||||||
77
src/utils/ai/image/modelList.ts
Normal file
77
src/utils/ai/image/modelList.ts
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
interface Owned {
|
||||||
|
manufacturer: string;
|
||||||
|
model: string;
|
||||||
|
grid: boolean;
|
||||||
|
type: "t2i" | "ti2i" | "i2i";
|
||||||
|
}
|
||||||
|
|
||||||
|
const modelList: Owned[] = [
|
||||||
|
// 火山引擎
|
||||||
|
{
|
||||||
|
manufacturer: "volcengine",
|
||||||
|
model: "doubao-seedream-4-5-251128",
|
||||||
|
grid: false,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
manufacturer: "volcengine",
|
||||||
|
model: "doubao-seedream-4-0-250828",
|
||||||
|
grid: false,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
//可灵
|
||||||
|
{
|
||||||
|
manufacturer: "kling",
|
||||||
|
model: "kling-image-o1",
|
||||||
|
grid: false,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
//gemini
|
||||||
|
{
|
||||||
|
manufacturer: "gemini",
|
||||||
|
model: "gemini-2.5-flash-image",
|
||||||
|
grid: true,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
manufacturer: "gemini",
|
||||||
|
model: "gemini-2.5-flash-image-preview",
|
||||||
|
grid: true,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
manufacturer: "gemini",
|
||||||
|
model: "gemini-2.5-flash-image-preview-all",
|
||||||
|
grid: true,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
manufacturer: "gemini",
|
||||||
|
model: "gemini-3-pro-image-preview",
|
||||||
|
grid: true,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
//Vidu
|
||||||
|
{
|
||||||
|
manufacturer: "vidu",
|
||||||
|
model: "viduq2",
|
||||||
|
grid: false,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
//RunningHub
|
||||||
|
{
|
||||||
|
manufacturer: "runninghub",
|
||||||
|
model: "nanobanana",
|
||||||
|
grid: true,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
//ApiMart
|
||||||
|
{
|
||||||
|
manufacturer: "apimart",
|
||||||
|
model: "nanobanana",
|
||||||
|
grid: true,
|
||||||
|
type: "ti2i",
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
export default modelList;
|
||||||
34
src/utils/ai/image/owned/gemini.ts
Normal file
34
src/utils/ai/image/owned/gemini.ts
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import "../type";
|
||||||
|
import { createGoogleGenerativeAI } from "@ai-sdk/google";
|
||||||
|
import { generateImage } from "ai";
|
||||||
|
|
||||||
|
export default async (input: ImageConfig, config: AIConfig): Promise<string> => {
|
||||||
|
if (!config.model) throw new Error("缺少Model名称");
|
||||||
|
if (!config.apiKey) throw new Error("缺少API Key");
|
||||||
|
if (!input.prompt) throw new Error("缺少提示词");
|
||||||
|
|
||||||
|
const google = createGoogleGenerativeAI({
|
||||||
|
apiKey: config.apiKey,
|
||||||
|
baseURL: config.baseURL,
|
||||||
|
});
|
||||||
|
|
||||||
|
// 构建完整的提示词
|
||||||
|
const fullPrompt = input.systemPrompt ? `${input.systemPrompt}\n\n${input.prompt}` : input.prompt;
|
||||||
|
|
||||||
|
// 根据 size 配置映射到具体尺寸
|
||||||
|
const sizeMap: Record<string, `${number}x${number}`> = {
|
||||||
|
"1K": "1024x1024",
|
||||||
|
"2K": "2048x2048",
|
||||||
|
"4K": "4096x4096",
|
||||||
|
};
|
||||||
|
|
||||||
|
const { image } = await generateImage({
|
||||||
|
model: google.image(config.model),
|
||||||
|
prompt: fullPrompt,
|
||||||
|
aspectRatio: input.aspectRatio as "1:1" | "3:4" | "4:3" | "9:16" | "16:9",
|
||||||
|
size: sizeMap[input.size] ?? "1024x1024",
|
||||||
|
});
|
||||||
|
|
||||||
|
// 返回生成的图片 base64
|
||||||
|
return image.base64;
|
||||||
|
};
|
||||||
107
src/utils/ai/image/owned/kling.ts
Normal file
107
src/utils/ai/image/owned/kling.ts
Normal file
@ -0,0 +1,107 @@
|
|||||||
|
import "../type";
|
||||||
|
import axios from "axios";
|
||||||
|
import jwt from "jsonwebtoken";
|
||||||
|
import u from "@/utils";
|
||||||
|
import { pollTask } from "@/utils/ai/utils";
|
||||||
|
|
||||||
|
function generateJwtToken(ak: string, sk: string): string {
|
||||||
|
const now = Math.floor(Date.now() / 1000);
|
||||||
|
const payload = {
|
||||||
|
iss: ak,
|
||||||
|
exp: now + 1800,
|
||||||
|
nbf: now - 5,
|
||||||
|
};
|
||||||
|
return jwt.sign(payload, sk, {
|
||||||
|
algorithm: "HS256",
|
||||||
|
header: { alg: "HS256", typ: "JWT" },
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
function getApiToken(apiKey: string): string {
|
||||||
|
const trimmedKey = apiKey.replace(/^Bearer\s+/i, "").trim();
|
||||||
|
|
||||||
|
if (trimmedKey.includes("|")) {
|
||||||
|
const parts = trimmedKey.split("|");
|
||||||
|
if (parts.length !== 2 || !parts[0].trim() || !parts[1].trim()) {
|
||||||
|
throw new Error("API Key格式错误,请使用 ak|sk 格式");
|
||||||
|
}
|
||||||
|
return generateJwtToken(parts[0].trim(), parts[1].trim());
|
||||||
|
}
|
||||||
|
|
||||||
|
return trimmedKey;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function processImages(imageBase64: string[]): Promise<Array<{ image: string }>> {
|
||||||
|
let images = imageBase64.filter((img) => img?.trim());
|
||||||
|
if (images.length === 0) return [];
|
||||||
|
|
||||||
|
// 压缩所有图片到10MB以内
|
||||||
|
images = await Promise.all(images.map((img) => u.imageTools.compressImage(img, "10mb")));
|
||||||
|
|
||||||
|
// 参考主体数量和参考图片数量之和不得超过10
|
||||||
|
if (images.length > 10) {
|
||||||
|
const mergeImageList = images.splice(9);
|
||||||
|
const mergedImage = await u.imageTools.mergeImages(mergeImageList, "10mb");
|
||||||
|
images.push(mergedImage);
|
||||||
|
}
|
||||||
|
|
||||||
|
return images.map((img) => ({
|
||||||
|
image: img.replace(/^data:image\/[a-z]+;base64,/i, ""),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
export default async (input: ImageConfig, config: AIConfig): Promise<string> => {
|
||||||
|
if (!config.apiKey) throw new Error("缺少API Key");
|
||||||
|
if (!input.prompt) throw new Error("缺少提示词,prompt为必填项");
|
||||||
|
|
||||||
|
const authorization = `Bearer ${getApiToken(config.apiKey)}`;
|
||||||
|
const baseURL = (config.baseURL ?? "https://api-beijing.klingai.com/v1/images/omni-image").replace(/\/+$/, "");
|
||||||
|
const imageList = await processImages(input.imageBase64);
|
||||||
|
|
||||||
|
const body: Record<string, any> = {
|
||||||
|
model_name: config.model || "kling-image-o1",
|
||||||
|
prompt: input.prompt,
|
||||||
|
n: 1,
|
||||||
|
...(input.size !== "4K" && { resolution: input.size.toLowerCase() }),
|
||||||
|
...(imageList.length > 0 && { image_list: imageList }),
|
||||||
|
};
|
||||||
|
|
||||||
|
const headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Authorization: authorization,
|
||||||
|
};
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { data: createData } = await axios.post(baseURL, body, { headers });
|
||||||
|
|
||||||
|
if (createData.code !== 0) {
|
||||||
|
throw new Error(createData.message || "创建任务失败");
|
||||||
|
}
|
||||||
|
|
||||||
|
const taskId = createData.data?.task_id;
|
||||||
|
if (!taskId) throw new Error("未获取到任务ID");
|
||||||
|
|
||||||
|
const queryUrl = `${baseURL}/${taskId}`;
|
||||||
|
return await pollTask(async () => {
|
||||||
|
const { data: queryData } = await axios.get(queryUrl, { headers });
|
||||||
|
|
||||||
|
if (queryData.code !== 0) {
|
||||||
|
return { completed: false, error: queryData.message || "查询任务失败" };
|
||||||
|
}
|
||||||
|
|
||||||
|
const { task_status, task_status_msg, task_result } = queryData.data || {};
|
||||||
|
|
||||||
|
if (task_status === "failed") {
|
||||||
|
return { completed: false, error: task_status_msg || "图片生成失败" };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (task_status === "succeed") {
|
||||||
|
return { completed: true, imageUrl: task_result?.images?.[0]?.url };
|
||||||
|
}
|
||||||
|
|
||||||
|
return { completed: false };
|
||||||
|
});
|
||||||
|
} catch (error) {
|
||||||
|
throw new Error(u.error(error).message || "可灵图片生成失败");
|
||||||
|
}
|
||||||
|
}
|
||||||
31
src/utils/ai/image/owned/volcengine.ts
Normal file
31
src/utils/ai/image/owned/volcengine.ts
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
import "../type";
|
||||||
|
import axios from "axios";
|
||||||
|
import u from "@/utils";
|
||||||
|
|
||||||
|
export default async (input: ImageConfig, config: AIConfig): Promise<string> => {
|
||||||
|
if (!config.model) throw new Error("缺少Model名称");
|
||||||
|
if (!config.apiKey) throw new Error("缺少API Key");
|
||||||
|
|
||||||
|
const apiKey = "Bearer " + config.apiKey.replace(/Bearer\s+/g, "").trim();
|
||||||
|
const size = input.size === "1K" ? "2K" : input.size;
|
||||||
|
|
||||||
|
const body: Record<string, any> = {
|
||||||
|
model: config.model,
|
||||||
|
prompt: input.prompt,
|
||||||
|
size,
|
||||||
|
response_format: "url",
|
||||||
|
sequential_image_generation: "disabled",
|
||||||
|
stream: false,
|
||||||
|
watermark: false,
|
||||||
|
...(input.imageBase64 && { image: input.imageBase64 }),
|
||||||
|
};
|
||||||
|
|
||||||
|
const url = config.baseURL ?? "https://ark.cn-beijing.volces.com/api/v3/images/generations";
|
||||||
|
try {
|
||||||
|
const { data } = await axios.post(url, body, { headers: { Authorization: apiKey } });
|
||||||
|
return data.data[0]?.url;
|
||||||
|
} catch (error) {
|
||||||
|
const msg = u.error(error).message || "Volcengine 图片生成失败";
|
||||||
|
throw new Error(msg);
|
||||||
|
}
|
||||||
|
}
|
||||||
14
src/utils/ai/image/type.ts
Normal file
14
src/utils/ai/image/type.ts
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
interface ImageConfig {
|
||||||
|
systemPrompt?: string;
|
||||||
|
prompt: string;
|
||||||
|
imageBase64: string[];
|
||||||
|
size: "1K" | "2K" | "4K";
|
||||||
|
aspectRatio: string;
|
||||||
|
resType?: "url" | "b64";
|
||||||
|
}
|
||||||
|
|
||||||
|
interface AIConfig {
|
||||||
|
model?: string;
|
||||||
|
apiKey?: string;
|
||||||
|
baseURL?: string;
|
||||||
|
}
|
||||||
13
src/utils/ai/utils.ts
Normal file
13
src/utils/ai/utils.ts
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
export const pollTask = async (
|
||||||
|
queryFn: () => Promise<{ completed: boolean; imageUrl?: string; error?: string }>,
|
||||||
|
maxAttempts = 500,
|
||||||
|
interval = 2000,
|
||||||
|
): Promise<string> => {
|
||||||
|
for (let i = 0; i < maxAttempts; i++) {
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, interval));
|
||||||
|
const { completed, imageUrl, error } = await queryFn();
|
||||||
|
if (error) throw new Error(error);
|
||||||
|
if (completed && imageUrl) return imageUrl;
|
||||||
|
}
|
||||||
|
throw new Error(`任务轮询超时,已尝试 ${maxAttempts} 次`);
|
||||||
|
};
|
||||||
70
src/utils/ai/video/owned/volcengine.ts
Normal file
70
src/utils/ai/video/owned/volcengine.ts
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import "../type";
|
||||||
|
import axios from "axios";
|
||||||
|
import u from "@/utils";
|
||||||
|
|
||||||
|
interface DoubaoVideoConfig {
|
||||||
|
prompt: string;
|
||||||
|
savePath: string;
|
||||||
|
imageBase64?: string[]; // 单张参考图片 base64
|
||||||
|
duration: 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12; // 支持 2~12 秒
|
||||||
|
aspectRatio: "16:9" | "9:16" | "1:1" | "4:3" | "3:4" | "21:9" | "adaptive";
|
||||||
|
audio?: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
const pollTask = async (
|
||||||
|
queryFn: () => Promise<{ completed: boolean; imageUrl?: string; error?: string }>,
|
||||||
|
maxAttempts = 500,
|
||||||
|
interval = 2000,
|
||||||
|
): Promise<string> => {
|
||||||
|
for (let i = 0; i < maxAttempts; i++) {
|
||||||
|
await new Promise((resolve) => setTimeout(resolve, interval));
|
||||||
|
const { completed, imageUrl, error } = await queryFn();
|
||||||
|
if (error) throw new Error(error);
|
||||||
|
if (completed && imageUrl) return imageUrl;
|
||||||
|
}
|
||||||
|
throw new Error(`任务轮询超时,已尝试 ${maxAttempts} 次`);
|
||||||
|
};
|
||||||
|
|
||||||
|
export default async (input: ImageConfig, config: AIConfig) => {
|
||||||
|
console.log("%c Line:5 🍓 input", "background:#7f2b82", input);
|
||||||
|
console.log("%c Line:5 🍎 config", "background:#93c0a4", config);
|
||||||
|
if (!config.model) throw new Error("缺少Model名称");
|
||||||
|
if (!config.apiKey) throw new Error("缺少API Key");
|
||||||
|
|
||||||
|
const key = "Bearer " + config.apiKey.replaceAll("Bearer ", "").trim();
|
||||||
|
|
||||||
|
const doubaoConfig = config as DoubaoVideoConfig;
|
||||||
|
const createRes = await axios.post(
|
||||||
|
config.baseURL ?? "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks",
|
||||||
|
{
|
||||||
|
model: "doubao-seedance-1-5-pro-251215",
|
||||||
|
content: [
|
||||||
|
{ type: "text", text: input.prompt },
|
||||||
|
...(doubaoConfig.imageBase64
|
||||||
|
? doubaoConfig.imageBase64.map((base64, i) => ({
|
||||||
|
type: "image_url",
|
||||||
|
image_url: { url: base64 },
|
||||||
|
role: i === 0 ? "first_frame" : "last_frame",
|
||||||
|
}))
|
||||||
|
: []),
|
||||||
|
],
|
||||||
|
generate_audio: doubaoConfig.audio ?? false,
|
||||||
|
duration: doubaoConfig.duration,
|
||||||
|
resolution: doubaoConfig.aspectRatio,
|
||||||
|
watermark: false,
|
||||||
|
},
|
||||||
|
{ headers: { "Content-Type": "application/json", Authorization: key } },
|
||||||
|
);
|
||||||
|
const taskId = createRes.data.id;
|
||||||
|
if (!taskId) throw new Error("视频任务创建失败");
|
||||||
|
return await pollTask(async () => {
|
||||||
|
const res = await axios.get(`${config.baseURL ?? "https://ark.cn-beijing.volces.com/api/v3/contents/generations/tasks"}/${taskId}`, {
|
||||||
|
headers: { Authorization: key },
|
||||||
|
});
|
||||||
|
const { status, content } = res.data;
|
||||||
|
if (status === "succeeded") return { completed: true, imageUrl: content?.video_url };
|
||||||
|
if (["failed", "cancelled", "expired"].includes(status)) return { completed: false, error: `任务${status}` };
|
||||||
|
if (["queued", "running"].includes(status)) return { completed: false };
|
||||||
|
return { completed: false, error: `未知状态: ${status}` };
|
||||||
|
});
|
||||||
|
};
|
||||||
68
src/utils/error.ts
Normal file
68
src/utils/error.ts
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
// utils/error.ts
|
||||||
|
import { serializeError } from "serialize-error";
|
||||||
|
import { isAxiosError } from "axios";
|
||||||
|
|
||||||
|
export interface NormalizedError {
|
||||||
|
name: string;
|
||||||
|
message: string;
|
||||||
|
code?: string;
|
||||||
|
status?: number;
|
||||||
|
stack?: string;
|
||||||
|
cause?: NormalizedError;
|
||||||
|
responseData?: unknown;
|
||||||
|
meta?: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function normalizeError(error: unknown): NormalizedError {
|
||||||
|
// Axios 特殊处理
|
||||||
|
if (isAxiosError(error)) {
|
||||||
|
return {
|
||||||
|
name: "AxiosError",
|
||||||
|
message: error.response?.data?.error?.message || error.response?.data?.message || error.message,
|
||||||
|
code: error.code,
|
||||||
|
status: error.response?.status,
|
||||||
|
stack: error.stack,
|
||||||
|
responseData: error.response?.data,
|
||||||
|
meta: {
|
||||||
|
url: error.config?.url,
|
||||||
|
method: error.config?.method,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 普通 Error,用 serialize-error 处理
|
||||||
|
if (error instanceof Error) {
|
||||||
|
const serialized = serializeError(error);
|
||||||
|
return {
|
||||||
|
name: serialized.name || "Error",
|
||||||
|
message: serialized.message || "未知错误",
|
||||||
|
code: (serialized as any).code,
|
||||||
|
stack: serialized.stack,
|
||||||
|
cause: error.cause ? normalizeError(error.cause) : undefined,
|
||||||
|
meta: extractMeta(serialized),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 非 Error
|
||||||
|
return {
|
||||||
|
name: "UnknownError",
|
||||||
|
message: String(error),
|
||||||
|
meta: { raw: serializeError(error) },
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提取自定义属性
|
||||||
|
function extractMeta(obj: Record<string, unknown>): Record<string, unknown> | undefined {
|
||||||
|
const standardKeys = ["name", "message", "stack", "cause"];
|
||||||
|
const meta: Record<string, unknown> = {};
|
||||||
|
|
||||||
|
for (const [key, value] of Object.entries(obj)) {
|
||||||
|
if (!standardKeys.includes(key) && value !== undefined) {
|
||||||
|
meta[key] = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Object.keys(meta).length > 0 ? meta : undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default normalizeError;
|
||||||
@ -13,8 +13,9 @@ interface TextResData extends BaseConfig {
|
|||||||
manufacturer: "deepseek" | "openAi" | "doubao";
|
manufacturer: "deepseek" | "openAi" | "doubao";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 图像模型配置接口
|
||||||
interface ImageResData extends BaseConfig {
|
interface ImageResData extends BaseConfig {
|
||||||
manufacturer: "openAi" | "gemini" | "volcengine" | "runninghub" | "apimart";
|
manufacturer: "gemini" | "volcengine" | "kling" | "vidu" | "runninghub" | "apimart";
|
||||||
}
|
}
|
||||||
|
|
||||||
interface VideoResData extends BaseConfig {
|
interface VideoResData extends BaseConfig {
|
||||||
|
|||||||
122
src/utils/imageTools.ts
Normal file
122
src/utils/imageTools.ts
Normal file
@ -0,0 +1,122 @@
|
|||||||
|
import sharp from "sharp";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析大小字符串为字节数
|
||||||
|
*/
|
||||||
|
function parseSize(size: string): number {
|
||||||
|
const match = size.toLowerCase().match(/^(\d+(?:\.\d+)?)\s*(kb|mb|gb|b)?$/);
|
||||||
|
if (!match) {
|
||||||
|
throw new Error(`无效的大小格式: ${size}`);
|
||||||
|
}
|
||||||
|
const value = parseFloat(match[1]);
|
||||||
|
const unit = match[2] || "b";
|
||||||
|
const multipliers: Record<string, number> = {
|
||||||
|
b: 1,
|
||||||
|
kb: 1024,
|
||||||
|
mb: 1024 * 1024,
|
||||||
|
gb: 1024 * 1024 * 1024,
|
||||||
|
};
|
||||||
|
return Math.floor(value * multipliers[unit]);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将base64字符串转换为Buffer
|
||||||
|
*/
|
||||||
|
function base64ToBuffer(base64: string): Buffer {
|
||||||
|
const base64Data = base64.replace(/^data:image\/\w+;base64,/, "");
|
||||||
|
return Buffer.from(base64Data, "base64");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 压缩Buffer到指定大小以内
|
||||||
|
*/
|
||||||
|
async function compressToSize(imageBuffer: Buffer, maxBytes: number, originalWidth: number, originalHeight: number): Promise<Buffer> {
|
||||||
|
let quality = 90;
|
||||||
|
let scale = 1;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
const targetWidth = Math.round(originalWidth * scale);
|
||||||
|
const targetHeight = Math.round(originalHeight * scale);
|
||||||
|
|
||||||
|
const resultBuffer = await sharp(imageBuffer).resize(targetWidth, targetHeight, { fit: "fill" }).jpeg({ quality }).toBuffer();
|
||||||
|
|
||||||
|
if (resultBuffer.length <= maxBytes) {
|
||||||
|
return resultBuffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (quality > 10) {
|
||||||
|
quality -= 10;
|
||||||
|
} else {
|
||||||
|
quality = 90;
|
||||||
|
scale *= 0.8;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 压缩单张图片到指定大小以内
|
||||||
|
* @param imageBase64 - base64编码的图片
|
||||||
|
* @param maxSize - 最大输出大小,支持格式如 "10mb", "5MB", "1024kb" 等
|
||||||
|
* @returns 压缩后的图片base64字符串
|
||||||
|
*/
|
||||||
|
export async function compressImage(imageBase64: string, maxSize = "10mb"): Promise<string> {
|
||||||
|
const maxBytes = parseSize(maxSize);
|
||||||
|
const imageBuffer = base64ToBuffer(imageBase64);
|
||||||
|
const metadata = await sharp(imageBuffer).metadata();
|
||||||
|
const resultBuffer = await compressToSize(imageBuffer, maxBytes, metadata.width || 1, metadata.height || 1);
|
||||||
|
return resultBuffer.toString("base64");
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将多张图片横向拼接为一张,并确保输出大小不超过指定限制
|
||||||
|
* @param imageBase64List - base64编码的图片数组
|
||||||
|
* @param maxSize - 最大输出大小,支持格式如 "10mb", "5MB", "1024kb" 等
|
||||||
|
* @returns 拼接后的图片base64字符串
|
||||||
|
*/
|
||||||
|
export async function mergeImages(imageBase64List: string[], maxSize = "10mb"): Promise<string> {
|
||||||
|
if (imageBase64List.length === 0) {
|
||||||
|
throw new Error("图片列表不能为空");
|
||||||
|
}
|
||||||
|
|
||||||
|
const maxBytes = parseSize(maxSize);
|
||||||
|
const imageBuffers = imageBase64List.map(base64ToBuffer);
|
||||||
|
const imageMetadatas = await Promise.all(imageBuffers.map((buffer) => sharp(buffer).metadata()));
|
||||||
|
const maxHeight = Math.max(...imageMetadatas.map((m) => m.height || 0));
|
||||||
|
|
||||||
|
// 计算各图片调整后的宽度
|
||||||
|
const imageWidths = imageMetadatas.map((metadata) => {
|
||||||
|
const aspectRatio = (metadata.width || 1) / (metadata.height || 1);
|
||||||
|
return Math.round(maxHeight * aspectRatio);
|
||||||
|
});
|
||||||
|
const totalWidth = imageWidths.reduce((sum, w) => sum + w, 0);
|
||||||
|
|
||||||
|
// 拼接图片
|
||||||
|
const resizedImages = await Promise.all(
|
||||||
|
imageBuffers.map(async (buffer, index) => {
|
||||||
|
return sharp(buffer).resize(imageWidths[index], maxHeight, { fit: "cover" }).toBuffer();
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
let currentX = 0;
|
||||||
|
const compositeInputs = resizedImages.map((buffer, index) => {
|
||||||
|
const input = { input: buffer, left: currentX, top: 0 };
|
||||||
|
currentX += imageWidths[index];
|
||||||
|
return input;
|
||||||
|
});
|
||||||
|
|
||||||
|
const mergedBuffer = await sharp({
|
||||||
|
create: {
|
||||||
|
width: totalWidth,
|
||||||
|
height: maxHeight,
|
||||||
|
channels: 4,
|
||||||
|
background: { r: 255, g: 255, b: 255, alpha: 1 },
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.composite(compositeInputs)
|
||||||
|
.jpeg({ quality: 90 })
|
||||||
|
.toBuffer();
|
||||||
|
|
||||||
|
// 复用压缩逻辑
|
||||||
|
const resultBuffer = await compressToSize(mergedBuffer, maxBytes, totalWidth, maxHeight);
|
||||||
|
return resultBuffer.toString("base64");
|
||||||
|
}
|
||||||
25
yarn.lock
25
yarn.lock
@ -4564,6 +4564,11 @@ nodemon@^3.1.11:
|
|||||||
touch "^3.1.0"
|
touch "^3.1.0"
|
||||||
undefsafe "^2.0.5"
|
undefsafe "^2.0.5"
|
||||||
|
|
||||||
|
non-error@^0.1.0:
|
||||||
|
version "0.1.0"
|
||||||
|
resolved "https://registry.npmmirror.com/non-error/-/non-error-0.1.0.tgz#b78b7d9a67ccb03ac979f9758813336ca7521cf2"
|
||||||
|
integrity sha512-TMB1uHiGsHRGv1uYclfhivcnf0/PdFp2pNqRxXjncaAsjYMoisaQJI+SSZCqRq+VliwRTC8tsMQfmrWjDMhkPQ==
|
||||||
|
|
||||||
nopt@^4.0.1:
|
nopt@^4.0.1:
|
||||||
version "4.0.3"
|
version "4.0.3"
|
||||||
resolved "https://registry.npmmirror.com/nopt/-/nopt-4.0.3.tgz#a375cad9d02fd921278d954c2254d5aa57e15e48"
|
resolved "https://registry.npmmirror.com/nopt/-/nopt-4.0.3.tgz#a375cad9d02fd921278d954c2254d5aa57e15e48"
|
||||||
@ -5330,6 +5335,14 @@ send@^1.1.0, send@^1.2.0:
|
|||||||
range-parser "^1.2.1"
|
range-parser "^1.2.1"
|
||||||
statuses "^2.0.2"
|
statuses "^2.0.2"
|
||||||
|
|
||||||
|
serialize-error@^13.0.1:
|
||||||
|
version "13.0.1"
|
||||||
|
resolved "https://registry.npmmirror.com/serialize-error/-/serialize-error-13.0.1.tgz#dd1e1bf6d3e3d01037d126bd95e919f48b0c8ec0"
|
||||||
|
integrity sha512-bBZaRwLH9PN5HbLCjPId4dP5bNGEtumcErgOX952IsvOhVPrm3/AeK1y0UHA/QaPG701eg0yEnOKsCOC6X/kaA==
|
||||||
|
dependencies:
|
||||||
|
non-error "^0.1.0"
|
||||||
|
type-fest "^5.4.1"
|
||||||
|
|
||||||
serialize-error@^7.0.1:
|
serialize-error@^7.0.1:
|
||||||
version "7.0.1"
|
version "7.0.1"
|
||||||
resolved "https://registry.npmmirror.com/serialize-error/-/serialize-error-7.0.1.tgz#f1360b0447f61ffb483ec4157c737fab7d778e18"
|
resolved "https://registry.npmmirror.com/serialize-error/-/serialize-error-7.0.1.tgz#f1360b0447f61ffb483ec4157c737fab7d778e18"
|
||||||
@ -5760,6 +5773,11 @@ supports-preserve-symlinks-flag@^1.0.0:
|
|||||||
resolved "https://registry.npmmirror.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09"
|
resolved "https://registry.npmmirror.com/supports-preserve-symlinks-flag/-/supports-preserve-symlinks-flag-1.0.0.tgz#6eda4bd344a3c94aea376d4cc31bc77311039e09"
|
||||||
integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==
|
integrity sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==
|
||||||
|
|
||||||
|
tagged-tag@^1.0.0:
|
||||||
|
version "1.0.0"
|
||||||
|
resolved "https://registry.npmmirror.com/tagged-tag/-/tagged-tag-1.0.0.tgz#a0b5917c2864cba54841495abfa3f6b13edcf4d6"
|
||||||
|
integrity sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng==
|
||||||
|
|
||||||
tar-fs@^2.0.0:
|
tar-fs@^2.0.0:
|
||||||
version "2.1.4"
|
version "2.1.4"
|
||||||
resolved "https://registry.npmmirror.com/tar-fs/-/tar-fs-2.1.4.tgz#800824dbf4ef06ded9afea4acafe71c67c76b930"
|
resolved "https://registry.npmmirror.com/tar-fs/-/tar-fs-2.1.4.tgz#800824dbf4ef06ded9afea4acafe71c67c76b930"
|
||||||
@ -5929,6 +5947,13 @@ type-fest@^0.13.1:
|
|||||||
resolved "https://registry.npmmirror.com/type-fest/-/type-fest-0.13.1.tgz#0172cb5bce80b0bd542ea348db50c7e21834d934"
|
resolved "https://registry.npmmirror.com/type-fest/-/type-fest-0.13.1.tgz#0172cb5bce80b0bd542ea348db50c7e21834d934"
|
||||||
integrity sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==
|
integrity sha512-34R7HTnG0XIJcBSn5XhDd7nNFPRcXYRZrBB2O2jdKqYODldSzBAqzsWoZYYvduky73toYS/ESqxPvkDf/F0XMg==
|
||||||
|
|
||||||
|
type-fest@^5.4.1:
|
||||||
|
version "5.4.3"
|
||||||
|
resolved "https://registry.npmmirror.com/type-fest/-/type-fest-5.4.3.tgz#b4c7e028da129098911ee2162a0c30df8a1be904"
|
||||||
|
integrity sha512-AXSAQJu79WGc79/3e9/CR77I/KQgeY1AhNvcShIH4PTcGYyC4xv6H4R4AUOwkPS5799KlVDAu8zExeCrkGquiA==
|
||||||
|
dependencies:
|
||||||
|
tagged-tag "^1.0.0"
|
||||||
|
|
||||||
type-is@^2.0.1:
|
type-is@^2.0.1:
|
||||||
version "2.0.1"
|
version "2.0.1"
|
||||||
resolved "https://registry.npmmirror.com/type-is/-/type-is-2.0.1.tgz#64f6cf03f92fce4015c2b224793f6bdd4b068c97"
|
resolved "https://registry.npmmirror.com/type-is/-/type-is-2.0.1.tgz#64f6cf03f92fce4015c2b224793f6bdd4b068c97"
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user