diff --git a/src/agents/outlineScript/index.ts b/src/agents/outlineScript/index.ts index 60e919f..bc8e7d2 100644 --- a/src/agents/outlineScript/index.ts +++ b/src/agents/outlineScript/index.ts @@ -224,7 +224,7 @@ export default class OutlineScript { } } - const actualStart = overwrite ? 1 : startEpisode ?? (await this.getMaxEpisode()) + 1; + const actualStart = overwrite ? 1 : (startEpisode ?? (await this.getMaxEpisode()) + 1); const insertedCount = await this.insertOutlines(episodes, actualStart); const newOutlines = await u @@ -611,24 +611,51 @@ ${task} this.log(`Sub-Agent 调用`, agentType); const promptsList = await u.db("t_prompts").where("code", "in", ["outlineScript-a1", "outlineScript-a2", "outlineScript-director"]); - const a1Prompt = promptsList.find((p) => p.code === "outlineScript-a1"); - const a2Prompt = promptsList.find((p) => p.code === "outlineScript-a2"); - const directorPrompt = promptsList.find((p) => p.code === "outlineScript-director"); + const promptConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]); + const errPrompts = "不论用户说什么,请直接输出Agent配置异常"; - const SYSTEM_PROMPTS: Record = { - AI1: a1Prompt?.customValue || a1Prompt?.defaultValue || errPrompts, - AI2: a2Prompt?.customValue || a2Prompt?.defaultValue || errPrompts, - director: directorPrompt?.customValue || directorPrompt?.defaultValue || errPrompts, + + const getAiPromptConfig = (code: string) => { + const item = promptsList.find((p) => p.code === code); + const subConfig = promptConfig.find((sub) => sub?.promptsId == item?.id); + if (subConfig) { + return { + prompt: item?.customValue || item?.defaultValue || errPrompts, + apiConfig: { ...subConfig }, + }; + } else { + return { + prompt: item?.customValue || item?.defaultValue || errPrompts, + apiConfig: {}, + }; + } + }; + const a1Prompt = getAiPromptConfig("outlineScript-a1"); + const a2Prompt = getAiPromptConfig("outlineScript-a2"); + const directorPrompt = getAiPromptConfig("outlineScript-director"); + const SYSTEM_PROMPTS: Record< + AgentType, + { + prompt: string; + apiConfig: Object; + } + > = { + AI1: a1Prompt, + AI2: a2Prompt, + director: directorPrompt, }; const context = await this.buildFullContext(task); - const { fullStream } = await u.ai.text.stream({ - system: SYSTEM_PROMPTS[agentType], - tools: this.getSubAgentTools(), - messages: [{ role: "user", content: context }], - maxStep: 100, - }); + const { fullStream } = await u.ai.text.stream( + { + system: SYSTEM_PROMPTS[agentType].prompt, + tools: this.getSubAgentTools(), + messages: [{ role: "user", content: context }], + maxStep: 100, + }, + SYSTEM_PROMPTS[agentType].apiConfig, + ); let fullResponse = ""; for await (const item of fullStream) { @@ -690,15 +717,18 @@ ${task} const envContext = await this.buildEnvironmentContext(); const prompts = await u.db("t_prompts").where("code", "outlineScript-main").first(); - + const promptConfig = await u.getPromptAi(prompts?.id); const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么,请直接输出Agent配置异常"; - const { fullStream } = await u.ai.text.stream({ - system: `${envContext}\n${mainPrompts}`, - tools: this.getAllTools(), - messages: this.history, - maxStep: 100, - }); + const { fullStream } = await u.ai.text.stream( + { + system: `${envContext}\n${mainPrompts}`, + tools: this.getAllTools(), + messages: this.history, + maxStep: 100, + }, + promptConfig, + ); let fullResponse = ""; for await (const item of fullStream) { diff --git a/src/agents/storyboard/generateImagePromptsTool.ts b/src/agents/storyboard/generateImagePromptsTool.ts index cbff68c..8762778 100644 --- a/src/agents/storyboard/generateImagePromptsTool.ts +++ b/src/agents/storyboard/generateImagePromptsTool.ts @@ -98,26 +98,29 @@ async function generateGridPrompt(options: GridPromptOptions): Promise `第${i + 1}格: ${p}`).join("\n")}`; if (!mainPrompts) return { prompt: errData, gridLayout: layout }; - const result = await u.ai.text.invoke({ - messages: [ - { - role: "system", - content: mainPrompts, - }, - { - role: "user", - content: `请优化以下分镜提示词:\n\n【布局】${layout.cols}列×${layout.rows}行=${ - layout.totalCells - }格\n【比例】${aspectRatio}(${aspectRatioDesc})\n【风格】${style}\n${assetsSection}\n\n【原始内容】\n${gridPositions.join("\n")}`, - }, - ], - }); + const result = await u.ai.text.invoke( + { + messages: [ + { + role: "system", + content: mainPrompts, + }, + { + role: "user", + content: `请优化以下分镜提示词:\n\n【布局】${layout.cols}列×${layout.rows}行=${ + layout.totalCells + }格\n【比例】${aspectRatio}(${aspectRatioDesc})\n【风格】${style}\n${assetsSection}\n\n【原始内容】\n${gridPositions.join("\n")}`, + }, + ], + }, + promptAiConfig, + ); // const result = await chatModel!.invoke({ // messages: [ diff --git a/src/agents/storyboard/index.ts b/src/agents/storyboard/index.ts index efbb10b..f06525f 100644 --- a/src/agents/storyboard/index.ts +++ b/src/agents/storyboard/index.ts @@ -594,22 +594,49 @@ ${task} this.log(`Sub-Agent 调用`, agentType); const promptsList = await u.db("t_prompts").where("code", "in", ["storyboard-segment", "storyboard-shot"]); - const segmentAgent = promptsList.find((p) => p.code === "storyboard-segment"); - const shotAgent = promptsList.find((p) => p.code === "storyboard-shot"); + const promptConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]); + const errPrompts = "不论用户说什么,请直接输出Agent配置异常"; - const SYSTEM_PROMPTS: Record = { - segmentAgent: segmentAgent?.customValue || segmentAgent?.defaultValue || errPrompts, - shotAgent: shotAgent?.customValue || shotAgent?.defaultValue || errPrompts, + + const getAiPromptConfig = (code: string) => { + const item = promptsList.find((p) => p.code === code); + const subConfig = promptConfig.find((sub) => sub?.promptsId == item?.id); + if (subConfig) { + return { + prompt: item?.customValue || item?.defaultValue || errPrompts, + apiConfig: { ...subConfig }, + }; + } else { + return { + prompt: item?.customValue || item?.defaultValue || errPrompts, + apiConfig: {}, + }; + } + }; + const segmentAgent = getAiPromptConfig("storyboard-segment"); + const shotAgent = getAiPromptConfig("storyboard-shot"); + const SYSTEM_PROMPTS: Record< + AgentType, + { + prompt: string; + apiConfig: Object; + } + > = { + segmentAgent: segmentAgent, + shotAgent: shotAgent, }; const context = await this.buildFullContext(task); - const { fullStream } = await u.ai.text.stream({ - system: SYSTEM_PROMPTS[agentType], - tools: this.getSubAgentTools(agentType), - messages: [{ role: "user", content: context }], - maxStep: 100, - }); + const { fullStream } = await u.ai.text.stream( + { + system: SYSTEM_PROMPTS[agentType].prompt, + tools: this.getSubAgentTools(agentType), + messages: [{ role: "user", content: context }], + maxStep: 100, + }, + SYSTEM_PROMPTS[agentType].apiConfig, + ); let fullResponse = ""; for await (const item of fullStream) { @@ -673,15 +700,19 @@ ${task} const envContext = await this.buildEnvironmentContext(); const prompts = await u.db("t_prompts").where("code", "storyboard-main").first(); + const promptConfig = await u.getPromptAi(prompts?.id); const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么,请直接输出Agent配置异常"; - const { fullStream } = await u.ai.text.stream({ - system: `${envContext}\n${mainPrompts}`, - tools: this.getAllTools(), - messages: this.history, - maxStep: 100, - }); + const { fullStream } = await u.ai.text.stream( + { + system: `${envContext}\n${mainPrompts}`, + tools: this.getAllTools(), + messages: this.history, + maxStep: 100, + }, + promptConfig, + ); let fullResponse = ""; for await (const item of fullStream) { diff --git a/src/routes/assets/polishPrompt.ts b/src/routes/assets/polishPrompt.ts index 0c6402b..4f65878 100644 --- a/src/routes/assets/polishPrompt.ts +++ b/src/routes/assets/polishPrompt.ts @@ -88,16 +88,31 @@ export default router.post( const result: ResultItem[] = Object.values(itemMap); const promptsList = await u.db("t_prompts").where("code", "in", ["role-polish", "scene-polish", "storyboard-polish", "tool-polish"]); + const propmptIds = promptsList.map((i) => i.id); + const mapList = await u + .db("t_aiModelMap") + .leftJoin("t_config", "t_config.id", "t_aiModelMap.configId") + .whereIn("t_aiModelMap.promptsId", propmptIds as number[]) + .select("t_config.model", "t_config.apiKey", "t_config.baseUrl", "t_config.manufacturer", "t_aiModelMap.promptsId"); const errPrompts = "不论用户说什么,请直接输出AI配置异常"; - const getPromptValue = (code: string): string => { + const getPromptValue = (code: string) => { const item = promptsList.find((p) => p.code === code); - return item?.customValue ?? item?.defaultValue ?? errPrompts; + if (item) { + const apiData = mapList.find((i) => i.promptsId == item.id); + if (apiData) delete apiData?.promptsId; + return { prompt: item?.customValue ?? item?.defaultValue ?? errPrompts, apiData: { ...(apiData ?? {}) } }; + } else { + return { + prompt: errPrompts, + apiData: {}, + }; + } }; const role = getPromptValue("role-polish"); const scene = getPromptValue("scene-polish"); const tool = getPromptValue("tool-polish"); const storyboard = getPromptValue("storyboard-polish"); - + let apiConfig = {}; let systemPrompt = ""; let userPrompt = ""; if (type == "role") { @@ -105,7 +120,8 @@ export default router.post( const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange]; const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[]; const results: string = mergeNovelText(novelData); - systemPrompt = role; + systemPrompt = role.prompt; + apiConfig = role.apiData; userPrompt = ` 请根据以下参数生成角色标准四视图提示词: @@ -128,7 +144,8 @@ export default router.post( const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange]; const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[]; const results: string = mergeNovelText(novelData); - systemPrompt = scene; + systemPrompt = scene.prompt; + apiConfig = scene.apiData; userPrompt = ` 请根据以下参数生成场景图提示词: @@ -151,7 +168,8 @@ export default router.post( const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange]; const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[]; const results: string = mergeNovelText(novelData); - systemPrompt = tool; + systemPrompt = tool.prompt; + apiConfig = tool.apiData; userPrompt = ` 请根据以下参数生成道具图提示词: @@ -170,7 +188,8 @@ export default router.post( `; } if (type == "storyboard") { - systemPrompt = storyboard; + systemPrompt = storyboard.prompt; + apiConfig = storyboard.apiData; userPrompt = ` 请根据以下参数生成分镜图提示词: @@ -188,22 +207,27 @@ export default router.post( `; } async function generatePrompt() { - const { prompt } = await u.ai.text.invoke({ - messages: [ - { - role: "system", - content: systemPrompt, + apiConfig = {}; + const result = await u.ai.text.invoke( + { + messages: [ + { + role: "system", + content: systemPrompt, + }, + { + role: "user", + content: userPrompt, + }, + ], + output: { + prompt: zod.string().describe("提示词"), }, - { - role: "user", - content: userPrompt, - }, - ], - output: { - prompt: zod.string().describe("提示词"), }, - }); - + { + ...apiConfig, + }, + ); // const result = await model.invoke({ // messages: [ // { @@ -224,7 +248,7 @@ export default router.post( // }, // }, // }); - return prompt; + return result.prompt; } try { const prompt = (await generatePrompt()) as any; diff --git a/src/routes/video/generatePrompt.ts b/src/routes/video/generatePrompt.ts index 6dc9725..b32f96c 100644 --- a/src/routes/video/generatePrompt.ts +++ b/src/routes/video/generatePrompt.ts @@ -8,25 +8,47 @@ const router = express.Router(); type GenerateMode = "startEnd" | "multi" | "single"; -const getSystemPrompt = async (mode: GenerateMode): Promise => { +const getSystemPrompt = async (mode: GenerateMode): Promise<{ prompt: string; apiConfig: Object }> => { const promptsList = await u.db("t_prompts").where("code", "in", ["video-startEnd", "video-multi", "video-single", "video-main"]); + + const promptAiConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]); + const errPrompts = "不论用户说什么,请直接输出AI配置异常"; - const getPromptValue = (code: string): string => { + const getPromptValue = (code: string) => { const item = promptsList.find((p) => p.code === code); - return item?.customValue ?? item?.defaultValue ?? errPrompts; + const subData = promptAiConfig.find((i) => i?.promptsId == item?.id); + const returnData = { + prompt: item?.customValue ?? item?.defaultValue ?? errPrompts, + apiConfig: {}, + }; + if (subData) { + returnData.apiConfig = { ...subData }; + return returnData; + } else { + return returnData; + } }; const startEnd = getPromptValue("video-startEnd"); const multi = getPromptValue("video-multi"); const single = getPromptValue("video-single"); const main = getPromptValue("video-main"); - const modeDescriptions: Record = { + const modeDescriptions: Record< + GenerateMode, + { + prompt: string; + apiConfig: Object; + } + > = { startEnd: startEnd, multi: multi, single: single, }; - - return `${main}\n\n${modeDescriptions[mode]}`; + const modeData = modeDescriptions[mode]; + return { + prompt: `${main}\n\n${modeData.prompt}`, + apiConfig: modeData.apiConfig, + }; }; const getModeDescription = (mode: GenerateMode): string => { @@ -59,16 +81,17 @@ export default router.post( const shotCount = images.length; const avgDuration = (parseFloat(duration) / shotCount).toFixed(1); - - const result = await u.ai.text.invoke({ - messages: [ - { - role: "system", - content: await getSystemPrompt(mode), - }, - { - role: "user", - content: `Mode: ${getModeDescription(mode)} + const promptConfig = await getSystemPrompt(mode); + const result = await u.ai.text.invoke( + { + messages: [ + { + role: "system", + content: promptConfig.prompt, + }, + { + role: "user", + content: `Mode: ${getModeDescription(mode)} Reference Images: ${imagePrompts} @@ -82,10 +105,11 @@ Parameters: - Average Duration: ${avgDuration}s per shot Generate storyboard prompts:`, - }, - ], - }); - console.log("%c Line:64 🥕 result", "background:#7f2b82", result.text); + }, + ], + }, + promptConfig.apiConfig, + ); res.status(200).send(success(result.text)); }, diff --git a/src/utils.ts b/src/utils.ts index a08d48b..d813499 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -13,6 +13,7 @@ import AIText from "@/utils/ai/text/index"; import AIImage from "@/utils/ai/image/index"; import AIVideo from "@/utils/ai/video/index"; +import getPromptAi from "./utils/getPromptAi"; export default { db, oss, @@ -28,4 +29,5 @@ export default { uuid, error, imageTools, + getPromptAi, }; diff --git a/src/utils/ai/text/index.ts b/src/utils/ai/text/index.ts index 433f84b..09a987e 100644 --- a/src/utils/ai/text/index.ts +++ b/src/utils/ai/text/index.ts @@ -5,6 +5,7 @@ import { devToolsMiddleware } from "@ai-sdk/devtools"; import { parse } from "best-effort-json-parser"; import modelList from "./modelList"; import { z } from "zod"; +import { OpenAIProvider } from "@ai-sdk/openai"; interface AIInput | undefined = undefined> { system?: string; @@ -19,17 +20,22 @@ interface AIConfig { model?: string; apiKey?: string; baseURL?: string; + manufacturer?: string; } const buildOptions = async (input: AIInput, config: AIConfig) => { let sqlTextModelConfig = {}; if (!config || !config?.model || !config?.apiKey || !config?.baseURL) sqlTextModelConfig = await u.getConfig("text"); - const { model, apiKey, baseURL } = { ...sqlTextModelConfig, ...config }; - - const owned = modelList.find((m) => m.model === model); + const { model, apiKey, baseURL, manufacturer } = { ...(sqlTextModelConfig as Awaited>), ...config }; + let owned; + if (manufacturer == "other") { + owned = modelList.find((m) => m.manufacturer === manufacturer); + } else { + owned = modelList.find((m) => m.model === model); + } if (!owned) throw new Error("不支持的模型或厂商"); - const modelInstance = owned.instance({ apiKey, baseURL }); + const modelInstance = owned.instance({ apiKey, baseURL: baseURL!, name: "xixixi" }); const maxStep = input.maxStep ?? (input.tools ? Object.keys(input.tools).length * 5 : undefined); const outputBuilders: Record any> = { @@ -46,16 +52,16 @@ const buildOptions = async (input: AIInput, config: AIConfig) => { }; const output = input.output ? (outputBuilders[owned.responseFormat]?.(input.output) ?? null) : null; - + const modelFn = owned.manufacturer == "doubao" ? (modelInstance as OpenAIProvider).chat(model!) : modelInstance(model!); return { config: { model: process.env.NODE_ENV === "dev" ? wrapLanguageModel({ - model: modelInstance.chat(model!) as any, + model: modelFn as any, middleware: devToolsMiddleware(), }) - : (modelInstance(model!) as LanguageModel), + : (modelFn as LanguageModel), ...(input.system && { system: input.system }), ...(input.prompt ? { prompt: input.prompt } : { messages: input.messages! }), ...(input.tools && owned.tool && { tools: input.tools }), diff --git a/src/utils/generateScript.ts b/src/utils/generateScript.ts index ae40ab7..77235b5 100644 --- a/src/utils/generateScript.ts +++ b/src/utils/generateScript.ts @@ -127,15 +127,18 @@ ${episodePrompt} ${novelData}`; const prompts = await u.db("t_prompts").where("code", "script").first(); - + const promptConfig = await u.getPromptAi(prompts?.id); const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么,请直接输出AI配置异常"; - const result = await u.ai.text.invoke({ - messages: [ - { role: "system", content: mainPrompts }, - { role: "user", content: userPrompt }, - ], - }); + const result = await u.ai.text.invoke( + { + messages: [ + { role: "system", content: mainPrompts }, + { role: "user", content: userPrompt }, + ], + }, + promptConfig, + ); return result.text ?? ""; } diff --git a/src/utils/getConfig.ts b/src/utils/getConfig.ts index cb27c09..fce2884 100644 --- a/src/utils/getConfig.ts +++ b/src/utils/getConfig.ts @@ -10,12 +10,12 @@ interface BaseConfig { interface TextResData extends BaseConfig { baseURL: string; - manufacturer: "deepseek" | "openAi" | "doubao"; + manufacturer: "deepseek" | "openAi" | "doubao" | "other"; } // 图像模型配置接口 interface ImageResData extends BaseConfig { - manufacturer: "gemini" | "volcengine" | "kling" | "vidu" | "runninghub" | "apimart"; + manufacturer: "gemini" | "volcengine" | "kling" | "vidu" | "runninghub" | "apimart" | "other"; } interface VideoResData extends BaseConfig { diff --git a/src/utils/getPromptAi.ts b/src/utils/getPromptAi.ts new file mode 100644 index 0000000..1618cba --- /dev/null +++ b/src/utils/getPromptAi.ts @@ -0,0 +1,26 @@ +import { db } from "./db"; +interface AiConfig { + model: string; + apiKey: string; + baseUrl: string; + manufacturer: string; + promptsId: number; +} + +export default async function getPromptAi(promptsId: number | undefined): Promise; +export default async function getPromptAi(promptsId: number[]): Promise; + +export default async function getPromptAi(promptsId: number | number[] | undefined): Promise { + if (!promptsId) return {}; + const ids = Array.isArray(promptsId) ? promptsId.filter(Boolean) : [promptsId]; + const mapList = await db("t_aiModelMap") + .leftJoin("t_config", "t_config.id", "t_aiModelMap.configId") + .whereIn("t_aiModelMap.promptsId", ids) + .select("t_config.model", "t_config.apiKey", "t_config.baseUrl", "t_config.manufacturer", "t_aiModelMap.promptsId"); + + if (Array.isArray(promptsId)) { + return mapList as AiConfig[]; + } else { + return mapList[0] ? (mapList[0] as AiConfig) : {}; + } +}