提示词ai分配 完成

This commit is contained in:
zhishi 2026-02-06 16:58:19 +08:00
parent 7abda02b25
commit 96bb7d66ac
10 changed files with 260 additions and 111 deletions

View File

@ -224,7 +224,7 @@ export default class OutlineScript {
}
}
const actualStart = overwrite ? 1 : startEpisode ?? (await this.getMaxEpisode()) + 1;
const actualStart = overwrite ? 1 : (startEpisode ?? (await this.getMaxEpisode()) + 1);
const insertedCount = await this.insertOutlines(episodes, actualStart);
const newOutlines = await u
@ -611,24 +611,51 @@ ${task}
this.log(`Sub-Agent 调用`, agentType);
const promptsList = await u.db("t_prompts").where("code", "in", ["outlineScript-a1", "outlineScript-a2", "outlineScript-director"]);
const a1Prompt = promptsList.find((p) => p.code === "outlineScript-a1");
const a2Prompt = promptsList.find((p) => p.code === "outlineScript-a2");
const directorPrompt = promptsList.find((p) => p.code === "outlineScript-director");
const promptConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]);
const errPrompts = "不论用户说什么请直接输出Agent配置异常";
const SYSTEM_PROMPTS: Record<AgentType, string> = {
AI1: a1Prompt?.customValue || a1Prompt?.defaultValue || errPrompts,
AI2: a2Prompt?.customValue || a2Prompt?.defaultValue || errPrompts,
director: directorPrompt?.customValue || directorPrompt?.defaultValue || errPrompts,
const getAiPromptConfig = (code: string) => {
const item = promptsList.find((p) => p.code === code);
const subConfig = promptConfig.find((sub) => sub?.promptsId == item?.id);
if (subConfig) {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: { ...subConfig },
};
} else {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: {},
};
}
};
const a1Prompt = getAiPromptConfig("outlineScript-a1");
const a2Prompt = getAiPromptConfig("outlineScript-a2");
const directorPrompt = getAiPromptConfig("outlineScript-director");
const SYSTEM_PROMPTS: Record<
AgentType,
{
prompt: string;
apiConfig: Object;
}
> = {
AI1: a1Prompt,
AI2: a2Prompt,
director: directorPrompt,
};
const context = await this.buildFullContext(task);
const { fullStream } = await u.ai.text.stream({
system: SYSTEM_PROMPTS[agentType],
tools: this.getSubAgentTools(),
messages: [{ role: "user", content: context }],
maxStep: 100,
});
const { fullStream } = await u.ai.text.stream(
{
system: SYSTEM_PROMPTS[agentType].prompt,
tools: this.getSubAgentTools(),
messages: [{ role: "user", content: context }],
maxStep: 100,
},
SYSTEM_PROMPTS[agentType].apiConfig,
);
let fullResponse = "";
for await (const item of fullStream) {
@ -690,15 +717,18 @@ ${task}
const envContext = await this.buildEnvironmentContext();
const prompts = await u.db("t_prompts").where("code", "outlineScript-main").first();
const promptConfig = await u.getPromptAi(prompts?.id);
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出Agent配置异常";
const { fullStream } = await u.ai.text.stream({
system: `${envContext}\n${mainPrompts}`,
tools: this.getAllTools(),
messages: this.history,
maxStep: 100,
});
const { fullStream } = await u.ai.text.stream(
{
system: `${envContext}\n${mainPrompts}`,
tools: this.getAllTools(),
messages: this.history,
maxStep: 100,
},
promptConfig,
);
let fullResponse = "";
for await (const item of fullStream) {

View File

@ -98,26 +98,29 @@ async function generateGridPrompt(options: GridPromptOptions): Promise<GridPromp
: "";
const promptsData = await u.db("t_prompts").where("code", "generateImagePrompts").first();
const promptAiConfig = await u.getPromptAi(promptsData?.id);
const mainPrompts = promptsData?.customValue || promptsData?.defaultValue;
const errData = `请输出${options.prompts.length}张图片\n提示词如下:\n${options.prompts.map((p, i) => `${i + 1}格: ${p}`).join("\n")}`;
if (!mainPrompts) return { prompt: errData, gridLayout: layout };
const result = await u.ai.text.invoke({
messages: [
{
role: "system",
content: mainPrompts,
},
{
role: "user",
content: `请优化以下分镜提示词:\n\n【布局】${layout.cols}列×${layout.rows}行=${
layout.totalCells
}\n${aspectRatio}${aspectRatioDesc}\n${style}\n${assetsSection}\n\n\n${gridPositions.join("\n")}`,
},
],
});
const result = await u.ai.text.invoke(
{
messages: [
{
role: "system",
content: mainPrompts,
},
{
role: "user",
content: `请优化以下分镜提示词:\n\n【布局】${layout.cols}列×${layout.rows}行=${
layout.totalCells
}\n${aspectRatio}${aspectRatioDesc}\n${style}\n${assetsSection}\n\n\n${gridPositions.join("\n")}`,
},
],
},
promptAiConfig,
);
// const result = await chatModel!.invoke({
// messages: [

View File

@ -594,22 +594,49 @@ ${task}
this.log(`Sub-Agent 调用`, agentType);
const promptsList = await u.db("t_prompts").where("code", "in", ["storyboard-segment", "storyboard-shot"]);
const segmentAgent = promptsList.find((p) => p.code === "storyboard-segment");
const shotAgent = promptsList.find((p) => p.code === "storyboard-shot");
const promptConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]);
const errPrompts = "不论用户说什么请直接输出Agent配置异常";
const SYSTEM_PROMPTS: Record<AgentType, string> = {
segmentAgent: segmentAgent?.customValue || segmentAgent?.defaultValue || errPrompts,
shotAgent: shotAgent?.customValue || shotAgent?.defaultValue || errPrompts,
const getAiPromptConfig = (code: string) => {
const item = promptsList.find((p) => p.code === code);
const subConfig = promptConfig.find((sub) => sub?.promptsId == item?.id);
if (subConfig) {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: { ...subConfig },
};
} else {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: {},
};
}
};
const segmentAgent = getAiPromptConfig("storyboard-segment");
const shotAgent = getAiPromptConfig("storyboard-shot");
const SYSTEM_PROMPTS: Record<
AgentType,
{
prompt: string;
apiConfig: Object;
}
> = {
segmentAgent: segmentAgent,
shotAgent: shotAgent,
};
const context = await this.buildFullContext(task);
const { fullStream } = await u.ai.text.stream({
system: SYSTEM_PROMPTS[agentType],
tools: this.getSubAgentTools(agentType),
messages: [{ role: "user", content: context }],
maxStep: 100,
});
const { fullStream } = await u.ai.text.stream(
{
system: SYSTEM_PROMPTS[agentType].prompt,
tools: this.getSubAgentTools(agentType),
messages: [{ role: "user", content: context }],
maxStep: 100,
},
SYSTEM_PROMPTS[agentType].apiConfig,
);
let fullResponse = "";
for await (const item of fullStream) {
@ -673,15 +700,19 @@ ${task}
const envContext = await this.buildEnvironmentContext();
const prompts = await u.db("t_prompts").where("code", "storyboard-main").first();
const promptConfig = await u.getPromptAi(prompts?.id);
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出Agent配置异常";
const { fullStream } = await u.ai.text.stream({
system: `${envContext}\n${mainPrompts}`,
tools: this.getAllTools(),
messages: this.history,
maxStep: 100,
});
const { fullStream } = await u.ai.text.stream(
{
system: `${envContext}\n${mainPrompts}`,
tools: this.getAllTools(),
messages: this.history,
maxStep: 100,
},
promptConfig,
);
let fullResponse = "";
for await (const item of fullStream) {

View File

@ -88,16 +88,31 @@ export default router.post(
const result: ResultItem[] = Object.values(itemMap);
const promptsList = await u.db("t_prompts").where("code", "in", ["role-polish", "scene-polish", "storyboard-polish", "tool-polish"]);
const propmptIds = promptsList.map((i) => i.id);
const mapList = await u
.db("t_aiModelMap")
.leftJoin("t_config", "t_config.id", "t_aiModelMap.configId")
.whereIn("t_aiModelMap.promptsId", propmptIds as number[])
.select("t_config.model", "t_config.apiKey", "t_config.baseUrl", "t_config.manufacturer", "t_aiModelMap.promptsId");
const errPrompts = "不论用户说什么请直接输出AI配置异常";
const getPromptValue = (code: string): string => {
const getPromptValue = (code: string) => {
const item = promptsList.find((p) => p.code === code);
return item?.customValue ?? item?.defaultValue ?? errPrompts;
if (item) {
const apiData = mapList.find((i) => i.promptsId == item.id);
if (apiData) delete apiData?.promptsId;
return { prompt: item?.customValue ?? item?.defaultValue ?? errPrompts, apiData: { ...(apiData ?? {}) } };
} else {
return {
prompt: errPrompts,
apiData: {},
};
}
};
const role = getPromptValue("role-polish");
const scene = getPromptValue("scene-polish");
const tool = getPromptValue("tool-polish");
const storyboard = getPromptValue("storyboard-polish");
let apiConfig = {};
let systemPrompt = "";
let userPrompt = "";
if (type == "role") {
@ -105,7 +120,8 @@ export default router.post(
const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange];
const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[];
const results: string = mergeNovelText(novelData);
systemPrompt = role;
systemPrompt = role.prompt;
apiConfig = role.apiData;
userPrompt = `
@ -128,7 +144,8 @@ export default router.post(
const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange];
const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[];
const results: string = mergeNovelText(novelData);
systemPrompt = scene;
systemPrompt = scene.prompt;
apiConfig = scene.apiData;
userPrompt = `
@ -151,7 +168,8 @@ export default router.post(
const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange];
const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[];
const results: string = mergeNovelText(novelData);
systemPrompt = tool;
systemPrompt = tool.prompt;
apiConfig = tool.apiData;
userPrompt = `
@ -170,7 +188,8 @@ export default router.post(
`;
}
if (type == "storyboard") {
systemPrompt = storyboard;
systemPrompt = storyboard.prompt;
apiConfig = storyboard.apiData;
userPrompt = `
@ -188,22 +207,27 @@ export default router.post(
`;
}
async function generatePrompt() {
const { prompt } = await u.ai.text.invoke({
messages: [
{
role: "system",
content: systemPrompt,
apiConfig = {};
const result = await u.ai.text.invoke(
{
messages: [
{
role: "system",
content: systemPrompt,
},
{
role: "user",
content: userPrompt,
},
],
output: {
prompt: zod.string().describe("提示词"),
},
{
role: "user",
content: userPrompt,
},
],
output: {
prompt: zod.string().describe("提示词"),
},
});
{
...apiConfig,
},
);
// const result = await model.invoke({
// messages: [
// {
@ -224,7 +248,7 @@ export default router.post(
// },
// },
// });
return prompt;
return result.prompt;
}
try {
const prompt = (await generatePrompt()) as any;

View File

@ -8,25 +8,47 @@ const router = express.Router();
type GenerateMode = "startEnd" | "multi" | "single";
const getSystemPrompt = async (mode: GenerateMode): Promise<string> => {
const getSystemPrompt = async (mode: GenerateMode): Promise<{ prompt: string; apiConfig: Object }> => {
const promptsList = await u.db("t_prompts").where("code", "in", ["video-startEnd", "video-multi", "video-single", "video-main"]);
const promptAiConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]);
const errPrompts = "不论用户说什么请直接输出AI配置异常";
const getPromptValue = (code: string): string => {
const getPromptValue = (code: string) => {
const item = promptsList.find((p) => p.code === code);
return item?.customValue ?? item?.defaultValue ?? errPrompts;
const subData = promptAiConfig.find((i) => i?.promptsId == item?.id);
const returnData = {
prompt: item?.customValue ?? item?.defaultValue ?? errPrompts,
apiConfig: {},
};
if (subData) {
returnData.apiConfig = { ...subData };
return returnData;
} else {
return returnData;
}
};
const startEnd = getPromptValue("video-startEnd");
const multi = getPromptValue("video-multi");
const single = getPromptValue("video-single");
const main = getPromptValue("video-main");
const modeDescriptions: Record<GenerateMode, string> = {
const modeDescriptions: Record<
GenerateMode,
{
prompt: string;
apiConfig: Object;
}
> = {
startEnd: startEnd,
multi: multi,
single: single,
};
return `${main}\n\n${modeDescriptions[mode]}`;
const modeData = modeDescriptions[mode];
return {
prompt: `${main}\n\n${modeData.prompt}`,
apiConfig: modeData.apiConfig,
};
};
const getModeDescription = (mode: GenerateMode): string => {
@ -59,16 +81,17 @@ export default router.post(
const shotCount = images.length;
const avgDuration = (parseFloat(duration) / shotCount).toFixed(1);
const result = await u.ai.text.invoke({
messages: [
{
role: "system",
content: await getSystemPrompt(mode),
},
{
role: "user",
content: `Mode: ${getModeDescription(mode)}
const promptConfig = await getSystemPrompt(mode);
const result = await u.ai.text.invoke(
{
messages: [
{
role: "system",
content: promptConfig.prompt,
},
{
role: "user",
content: `Mode: ${getModeDescription(mode)}
Reference Images:
${imagePrompts}
@ -82,10 +105,11 @@ Parameters:
- Average Duration: ${avgDuration}s per shot
Generate storyboard prompts:`,
},
],
});
console.log("%c Line:64 🥕 result", "background:#7f2b82", result.text);
},
],
},
promptConfig.apiConfig,
);
res.status(200).send(success(result.text));
},

View File

@ -13,6 +13,7 @@ import AIText from "@/utils/ai/text/index";
import AIImage from "@/utils/ai/image/index";
import AIVideo from "@/utils/ai/video/index";
import getPromptAi from "./utils/getPromptAi";
export default {
db,
oss,
@ -28,4 +29,5 @@ export default {
uuid,
error,
imageTools,
getPromptAi,
};

View File

@ -5,6 +5,7 @@ import { devToolsMiddleware } from "@ai-sdk/devtools";
import { parse } from "best-effort-json-parser";
import modelList from "./modelList";
import { z } from "zod";
import { OpenAIProvider } from "@ai-sdk/openai";
interface AIInput<T extends Record<string, z.ZodTypeAny> | undefined = undefined> {
system?: string;
@ -19,17 +20,22 @@ interface AIConfig {
model?: string;
apiKey?: string;
baseURL?: string;
manufacturer?: string;
}
const buildOptions = async (input: AIInput<any>, config: AIConfig) => {
let sqlTextModelConfig = {};
if (!config || !config?.model || !config?.apiKey || !config?.baseURL) sqlTextModelConfig = await u.getConfig("text");
const { model, apiKey, baseURL } = { ...sqlTextModelConfig, ...config };
const owned = modelList.find((m) => m.model === model);
const { model, apiKey, baseURL, manufacturer } = { ...(sqlTextModelConfig as Awaited<ReturnType<typeof u.getConfig>>), ...config };
let owned;
if (manufacturer == "other") {
owned = modelList.find((m) => m.manufacturer === manufacturer);
} else {
owned = modelList.find((m) => m.model === model);
}
if (!owned) throw new Error("不支持的模型或厂商");
const modelInstance = owned.instance({ apiKey, baseURL });
const modelInstance = owned.instance({ apiKey, baseURL: baseURL!, name: "xixixi" });
const maxStep = input.maxStep ?? (input.tools ? Object.keys(input.tools).length * 5 : undefined);
const outputBuilders: Record<string, (schema: any) => any> = {
@ -46,16 +52,16 @@ const buildOptions = async (input: AIInput<any>, config: AIConfig) => {
};
const output = input.output ? (outputBuilders[owned.responseFormat]?.(input.output) ?? null) : null;
const modelFn = owned.manufacturer == "doubao" ? (modelInstance as OpenAIProvider).chat(model!) : modelInstance(model!);
return {
config: {
model:
process.env.NODE_ENV === "dev"
? wrapLanguageModel({
model: modelInstance.chat(model!) as any,
model: modelFn as any,
middleware: devToolsMiddleware(),
})
: (modelInstance(model!) as LanguageModel),
: (modelFn as LanguageModel),
...(input.system && { system: input.system }),
...(input.prompt ? { prompt: input.prompt } : { messages: input.messages! }),
...(input.tools && owned.tool && { tools: input.tools }),

View File

@ -127,15 +127,18 @@ ${episodePrompt}
${novelData}`;
const prompts = await u.db("t_prompts").where("code", "script").first();
const promptConfig = await u.getPromptAi(prompts?.id);
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出AI配置异常";
const result = await u.ai.text.invoke({
messages: [
{ role: "system", content: mainPrompts },
{ role: "user", content: userPrompt },
],
});
const result = await u.ai.text.invoke(
{
messages: [
{ role: "system", content: mainPrompts },
{ role: "user", content: userPrompt },
],
},
promptConfig,
);
return result.text ?? "";
}

View File

@ -10,12 +10,12 @@ interface BaseConfig {
interface TextResData extends BaseConfig {
baseURL: string;
manufacturer: "deepseek" | "openAi" | "doubao";
manufacturer: "deepseek" | "openAi" | "doubao" | "other";
}
// 图像模型配置接口
interface ImageResData extends BaseConfig {
manufacturer: "gemini" | "volcengine" | "kling" | "vidu" | "runninghub" | "apimart";
manufacturer: "gemini" | "volcengine" | "kling" | "vidu" | "runninghub" | "apimart" | "other";
}
interface VideoResData extends BaseConfig {

26
src/utils/getPromptAi.ts Normal file
View File

@ -0,0 +1,26 @@
import { db } from "./db";
interface AiConfig {
model: string;
apiKey: string;
baseUrl: string;
manufacturer: string;
promptsId: number;
}
export default async function getPromptAi(promptsId: number | undefined): Promise<AiConfig | {}>;
export default async function getPromptAi(promptsId: number[]): Promise<AiConfig[]>;
export default async function getPromptAi(promptsId: number | number[] | undefined): Promise<AiConfig | AiConfig[] | {}> {
if (!promptsId) return {};
const ids = Array.isArray(promptsId) ? promptsId.filter(Boolean) : [promptsId];
const mapList = await db("t_aiModelMap")
.leftJoin("t_config", "t_config.id", "t_aiModelMap.configId")
.whereIn("t_aiModelMap.promptsId", ids)
.select("t_config.model", "t_config.apiKey", "t_config.baseUrl", "t_config.manufacturer", "t_aiModelMap.promptsId");
if (Array.isArray(promptsId)) {
return mapList as AiConfig[];
} else {
return mapList[0] ? (mapList[0] as AiConfig) : {};
}
}