所有ai 分模块接入,待测试

This commit is contained in:
zhishi 2026-02-07 17:51:18 +08:00
parent 981b6fe863
commit c064fdb679
35 changed files with 391 additions and 373 deletions

View File

@ -611,35 +611,18 @@ ${task}
this.log(`Sub-Agent 调用`, agentType); this.log(`Sub-Agent 调用`, agentType);
const promptsList = await u.db("t_prompts").where("code", "in", ["outlineScript-a1", "outlineScript-a2", "outlineScript-director"]); const promptsList = await u.db("t_prompts").where("code", "in", ["outlineScript-a1", "outlineScript-a2", "outlineScript-director"]);
const promptConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]); const promptConfig = await u.getPromptAi("outlineScriptAgent");
const errPrompts = "不论用户说什么请直接输出Agent配置异常"; const errPrompts = "不论用户说什么请直接输出Agent配置异常";
const getAiPromptConfig = (code: string) => { const getAiPromptConfig = (code: string) => {
const item = promptsList.find((p) => p.code === code); const item = promptsList.find((p) => p.code === code);
const subConfig = promptConfig.find((sub) => sub?.promptsId == item?.id); return item?.customValue || item?.defaultValue || errPrompts;
if (subConfig) {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: { ...subConfig },
};
} else {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: {},
};
}
}; };
const a1Prompt = getAiPromptConfig("outlineScript-a1"); const a1Prompt = getAiPromptConfig("outlineScript-a1");
const a2Prompt = getAiPromptConfig("outlineScript-a2"); const a2Prompt = getAiPromptConfig("outlineScript-a2");
const directorPrompt = getAiPromptConfig("outlineScript-director"); const directorPrompt = getAiPromptConfig("outlineScript-director");
const SYSTEM_PROMPTS: Record< const SYSTEM_PROMPTS = {
AgentType,
{
prompt: string;
apiConfig: Object;
}
> = {
AI1: a1Prompt, AI1: a1Prompt,
AI2: a2Prompt, AI2: a2Prompt,
director: directorPrompt, director: directorPrompt,
@ -649,12 +632,12 @@ ${task}
const { fullStream } = await u.ai.text.stream( const { fullStream } = await u.ai.text.stream(
{ {
system: SYSTEM_PROMPTS[agentType].prompt, system: SYSTEM_PROMPTS[agentType],
tools: this.getSubAgentTools(), tools: this.getSubAgentTools(),
messages: [{ role: "user", content: context }], messages: [{ role: "user", content: context }],
maxStep: 100, maxStep: 100,
}, },
SYSTEM_PROMPTS[agentType].apiConfig, promptConfig,
); );
let fullResponse = ""; let fullResponse = "";
@ -717,7 +700,9 @@ ${task}
const envContext = await this.buildEnvironmentContext(); const envContext = await this.buildEnvironmentContext();
const prompts = await u.db("t_prompts").where("code", "outlineScript-main").first(); const prompts = await u.db("t_prompts").where("code", "outlineScript-main").first();
const promptConfig = await u.getPromptAi(prompts?.id); console.log("%c Line:703 🍭 prompts", "background:#f5ce50", prompts);
const promptConfig = await u.getPromptAi("outlineScriptAgent");
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出Agent配置异常"; const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出Agent配置异常";
const { fullStream } = await u.ai.text.stream( const { fullStream } = await u.ai.text.stream(

View File

@ -98,7 +98,7 @@ async function generateGridPrompt(options: GridPromptOptions): Promise<GridPromp
: ""; : "";
const promptsData = await u.db("t_prompts").where("code", "generateImagePrompts").first(); const promptsData = await u.db("t_prompts").where("code", "generateImagePrompts").first();
const promptAiConfig = await u.getPromptAi(promptsData?.id); const promptAiConfig = await u.getPromptAi("storyboardAgent");
const mainPrompts = promptsData?.customValue || promptsData?.defaultValue; const mainPrompts = promptsData?.customValue || promptsData?.defaultValue;
const errData = `请输出${options.prompts.length}张图片\n提示词如下:\n${options.prompts.map((p, i) => `${i + 1}格: ${p}`).join("\n")}`; const errData = `请输出${options.prompts.length}张图片\n提示词如下:\n${options.prompts.map((p, i) => `${i + 1}格: ${p}`).join("\n")}`;

View File

@ -215,11 +215,13 @@ async function filterRelevantAssets(prompts: string[], allResources: ResourceIte
return availableImages; return availableImages;
} }
const { relevantAssets } = await u.ai.text.invoke({ const apiConfig = await u.getPromptAi("storyboardAgent");
messages: [ const { relevantAssets } = await u.ai.text.invoke(
{ {
role: "user", messages: [
content: `请分析以下分镜描述,从可用资产中筛选出与分镜内容直接相关的资产。 {
role: "user",
content: `请分析以下分镜描述,从可用资产中筛选出与分镜内容直接相关的资产。
${prompts.map((p, i) => `${i + 1}. ${p}`).join("\n")} ${prompts.map((p, i) => `${i + 1}. ${p}`).join("\n")}
@ -228,45 +230,21 @@ ${prompts.map((p, i) => `${i + 1}. ${p}`).join("\n")}
${availableResources.map((r) => `- ${r.name}${r.intro}`).join("\n")} ${availableResources.map((r) => `- ${r.name}${r.intro}`).join("\n")}
`, `,
},
],
output: {
relevantAssets: z
.array(
z.object({
name: z.string().describe("资产名称"),
reason: z.string().describe("选择该资产的原因"),
}),
)
.describe("与分镜内容相关的资产列表"),
}, },
],
output: {
relevantAssets: z
.array(
z.object({
name: z.string().describe("资产名称"),
reason: z.string().describe("选择该资产的原因"),
}),
)
.describe("与分镜内容相关的资产列表"),
}, },
}); apiConfig,
// const result = await chatModel!.invoke({ );
// messages: [
// {
// role: "user",
// content: `请分析以下分镜描述,从可用资产中筛选出与分镜内容直接相关的资产。
// 分镜描述:
// ${prompts.map((p, i) => `${i + 1}. ${p}`).join("\n")}
// 可用资产列表:
// ${availableResources.map((r) => `- ${r.name}${r.intro}`).join("\n")}
// 请仅选择在分镜中明确出现或被提及的角色、场景、道具。不要选择与分镜内容无关的资产。`,
// },
// ],
// responseFormat: {
// type: "json_schema",
// jsonSchema: {
// name: "filteredAssets",
// strict: true,
// schema: z.toJSONSchema(filteredAssetsSchema),
// },
// },
// });
// const data = result?.json as z.infer<typeof filteredAssetsSchema>;
if (!relevantAssets || relevantAssets.length === 0) { if (!relevantAssets || relevantAssets.length === 0) {
return availableImages; return availableImages;
@ -342,14 +320,18 @@ export default async (cells: { prompt: string }[], scriptId: number, projectId:
console.log("====润色后:", prompts); console.log("====润色后:", prompts);
const processedImages = await processImages(filteredImages); const processedImages = await processImages(filteredImages);
const apiConfig = await u.getPromptAi("storyboardImage");
const contentStr = await u.ai.image({ const contentStr = await u.ai.image(
systemPrompt: resourcesMapPrompts, {
prompt: prompts, systemPrompt: resourcesMapPrompts,
size: "4K", prompt: prompts,
aspectRatio: projectInfo?.videoRatio ? (projectInfo.videoRatio as any) : "16:9", size: "4K",
imageBase64: processedImages.map((buf) => buf.toString("base64")), aspectRatio: projectInfo?.videoRatio ? (projectInfo.videoRatio as any) : "16:9",
}); imageBase64: processedImages.map((buf) => buf.toString("base64")),
},
apiConfig,
);
const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/); const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/);
const base64Str = match?.[1] ?? contentStr; const base64Str = match?.[1] ?? contentStr;

View File

@ -63,9 +63,6 @@ export default class Storyboard {
// 更新shopts // 更新shopts
public updatePreShots(segmentId: number, cellId: number, cell: { src?: string; prompt?: string; id?: string }) { public updatePreShots(segmentId: number, cellId: number, cell: { src?: string; prompt?: string; id?: string }) {
console.log("%c Line:76 🍤 segmentId", "background:#465975", segmentId);
console.log("%c Line:76 🍷 cellId", "background:#ffdd4d", cellId);
console.log("%c Line:76 🍢 cell", "background:#ffdd4d", cell);
const shotIndex = this.shots.findIndex((item) => item.segmentId === segmentId); const shotIndex = this.shots.findIndex((item) => item.segmentId === segmentId);
if (shotIndex === -1) { if (shotIndex === -1) {
return `分镜 ${segmentId} 不存在请检查分镜ID是否正确`; return `分镜 ${segmentId} 不存在请检查分镜ID是否正确`;
@ -594,34 +591,17 @@ ${task}
this.log(`Sub-Agent 调用`, agentType); this.log(`Sub-Agent 调用`, agentType);
const promptsList = await u.db("t_prompts").where("code", "in", ["storyboard-segment", "storyboard-shot"]); const promptsList = await u.db("t_prompts").where("code", "in", ["storyboard-segment", "storyboard-shot"]);
const promptConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]); const promptConfig = await u.getPromptAi("storyboardAgent");
const errPrompts = "不论用户说什么请直接输出Agent配置异常"; const errPrompts = "不论用户说什么请直接输出Agent配置异常";
const getAiPromptConfig = (code: string) => { const getAiPromptConfig = (code: string) => {
const item = promptsList.find((p) => p.code === code); const item = promptsList.find((p) => p.code === code);
const subConfig = promptConfig.find((sub) => sub?.promptsId == item?.id); return item?.customValue || item?.defaultValue || errPrompts;
if (subConfig) {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: { ...subConfig },
};
} else {
return {
prompt: item?.customValue || item?.defaultValue || errPrompts,
apiConfig: {},
};
}
}; };
const segmentAgent = getAiPromptConfig("storyboard-segment"); const segmentAgent = getAiPromptConfig("storyboard-segment");
const shotAgent = getAiPromptConfig("storyboard-shot"); const shotAgent = getAiPromptConfig("storyboard-shot");
const SYSTEM_PROMPTS: Record< const SYSTEM_PROMPTS = {
AgentType,
{
prompt: string;
apiConfig: Object;
}
> = {
segmentAgent: segmentAgent, segmentAgent: segmentAgent,
shotAgent: shotAgent, shotAgent: shotAgent,
}; };
@ -630,12 +610,12 @@ ${task}
const { fullStream } = await u.ai.text.stream( const { fullStream } = await u.ai.text.stream(
{ {
system: SYSTEM_PROMPTS[agentType].prompt, system: SYSTEM_PROMPTS[agentType],
tools: this.getSubAgentTools(agentType), tools: this.getSubAgentTools(agentType),
messages: [{ role: "user", content: context }], messages: [{ role: "user", content: context }],
maxStep: 100, maxStep: 100,
}, },
SYSTEM_PROMPTS[agentType].apiConfig, promptConfig,
); );
let fullResponse = ""; let fullResponse = "";
@ -700,7 +680,7 @@ ${task}
const envContext = await this.buildEnvironmentContext(); const envContext = await this.buildEnvironmentContext();
const prompts = await u.db("t_prompts").where("code", "storyboard-main").first(); const prompts = await u.db("t_prompts").where("code", "storyboard-main").first();
const promptConfig = await u.getPromptAi(prompts?.id); const promptConfig = await u.getPromptAi("storyboardAgent");
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出Agent配置异常"; const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出Agent配置异常";

File diff suppressed because one or more lines are too long

View File

@ -123,14 +123,18 @@ export default router.post(
state: "生成中", state: "生成中",
assetsId: id, assetsId: id,
}); });
const apiConfig = await u.getPromptAi("assetsImage");
const contentStr = await u.ai.image({ const contentStr = await u.ai.image(
systemPrompt, {
prompt: userPrompt, systemPrompt,
imageBase64: base64 ? [base64] : [], prompt: userPrompt,
size: "2K", imageBase64: base64 ? [base64] : [],
aspectRatio: "16:9", size: "2K",
}); aspectRatio: "16:9",
},
apiConfig,
);
let insertType; let insertType;
const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/); const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/);

View File

@ -88,31 +88,16 @@ export default router.post(
const result: ResultItem[] = Object.values(itemMap); const result: ResultItem[] = Object.values(itemMap);
const promptsList = await u.db("t_prompts").where("code", "in", ["role-polish", "scene-polish", "storyboard-polish", "tool-polish"]); const promptsList = await u.db("t_prompts").where("code", "in", ["role-polish", "scene-polish", "storyboard-polish", "tool-polish"]);
const propmptIds = promptsList.map((i) => i.id); const apiConfigData = await u.getPromptAi("assetsPrompt");
const mapList = await u
.db("t_aiModelMap")
.leftJoin("t_config", "t_config.id", "t_aiModelMap.configId")
.whereIn("t_aiModelMap.promptsId", propmptIds as number[])
.select("t_config.model", "t_config.apiKey", "t_config.baseUrl", "t_config.manufacturer", "t_aiModelMap.promptsId");
const errPrompts = "不论用户说什么请直接输出AI配置异常"; const errPrompts = "不论用户说什么请直接输出AI配置异常";
const getPromptValue = (code: string) => { const getPromptValue = (code: string) => {
const item = promptsList.find((p) => p.code === code); const item = promptsList.find((p) => p.code === code);
if (item) { return item?.customValue ?? item?.defaultValue ?? errPrompts;
const apiData = mapList.find((i) => i.promptsId == item.id);
if (apiData) delete apiData?.promptsId;
return { prompt: item?.customValue ?? item?.defaultValue ?? errPrompts, apiData: { ...(apiData ?? {}) } };
} else {
return {
prompt: errPrompts,
apiData: {},
};
}
}; };
const role = getPromptValue("role-polish"); const role = getPromptValue("role-polish");
const scene = getPromptValue("scene-polish"); const scene = getPromptValue("scene-polish");
const tool = getPromptValue("tool-polish"); const tool = getPromptValue("tool-polish");
const storyboard = getPromptValue("storyboard-polish"); const storyboard = getPromptValue("storyboard-polish");
let apiConfig = {};
let systemPrompt = ""; let systemPrompt = "";
let userPrompt = ""; let userPrompt = "";
if (type == "role") { if (type == "role") {
@ -120,8 +105,7 @@ export default router.post(
const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange]; const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange];
const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[]; const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[];
const results: string = mergeNovelText(novelData); const results: string = mergeNovelText(novelData);
systemPrompt = role.prompt; systemPrompt = role;
apiConfig = role.apiData;
userPrompt = ` userPrompt = `
@ -144,8 +128,7 @@ export default router.post(
const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange]; const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange];
const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[]; const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[];
const results: string = mergeNovelText(novelData); const results: string = mergeNovelText(novelData);
systemPrompt = scene.prompt; systemPrompt = scene;
apiConfig = scene.apiData;
userPrompt = ` userPrompt = `
@ -168,8 +151,7 @@ export default router.post(
const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange]; const chapterRange = Array.isArray(data?.chapterRange) ? data.chapterRange : [data?.chapterRange];
const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[]; const novelData = (await u.db("t_novel").whereIn("chapterIndex", chapterRange).select("*")) as NovelChapter[];
const results: string = mergeNovelText(novelData); const results: string = mergeNovelText(novelData);
systemPrompt = tool.prompt; systemPrompt = tool;
apiConfig = tool.apiData;
userPrompt = ` userPrompt = `
@ -188,8 +170,7 @@ export default router.post(
`; `;
} }
if (type == "storyboard") { if (type == "storyboard") {
systemPrompt = storyboard.prompt; systemPrompt = storyboard;
apiConfig = storyboard.apiData;
userPrompt = ` userPrompt = `
@ -207,7 +188,6 @@ export default router.post(
`; `;
} }
async function generatePrompt() { async function generatePrompt() {
apiConfig = {};
const result = await u.ai.text.invoke( const result = await u.ai.text.invoke(
{ {
messages: [ messages: [
@ -224,9 +204,7 @@ export default router.post(
prompt: zod.string().describe("提示词"), prompt: zod.string().describe("提示词"),
}, },
}, },
{ apiConfigData,
...apiConfig,
},
); );
// const result = await model.invoke({ // const result = await model.invoke({
// messages: [ // messages: [
@ -256,7 +234,6 @@ export default router.post(
res.status(200).send(success({ prompt: prompt, assetsId })); res.status(200).send(success({ prompt: prompt, assetsId }));
} catch (e: any) { } catch (e: any) {
console.log("%c Line:235 🥚 e", "background:#33a5ff", e);
return res.status(500).send(error(e?.data?.error?.message ?? e?.message ?? "生成失败")); return res.status(500).send(error(e?.data?.error?.message ?? e?.message ?? "生成失败"));
} }
}, },

View File

@ -13,12 +13,12 @@ export default router.post(
modelName: z.string(), modelName: z.string(),
apiKey: z.string(), apiKey: z.string(),
baseURL: z.string().optional(), baseURL: z.string().optional(),
manufacturer: z.string(),
}), }),
async (req, res) => { async (req, res) => {
const { modelName, apiKey, baseURL } = req.body; const { modelName, apiKey, baseURL, manufacturer } = req.body;
const getWeatherTool = tool({ const getWeatherTool = tool({
// strict: true,
description: "Get the weather in a location", description: "Get the weather in a location",
inputSchema: z.object({ inputSchema: z.object({
location: z.string().describe("The location to get the weather for"), location: z.string().describe("The location to get the weather for"),
@ -43,6 +43,7 @@ export default router.post(
model: modelName, model: modelName,
apiKey, apiKey,
baseURL, baseURL,
manufacturer,
}, },
); );
res.status(200).send(success(reply)); res.status(200).send(success(reply));

View File

@ -17,13 +17,21 @@ export default router.post(
async (req, res) => { async (req, res) => {
const { modelName, apiKey, baseURL, manufacturer } = req.body; const { modelName, apiKey, baseURL, manufacturer } = req.body;
try { try {
const image = await u.ai.image({ const image = await u.ai.image(
prompt: {
"一张16:9比例的图片完美等分为2x2四宫格布局各区域无缝衔接\n左上宫格一只可爱的猫毛发蓬松眼睛明亮姿态俏皮\n右上宫格一只友善的狗金毛犬表情愉悦摇着尾巴\n左下宫格一头健壮的牛田园背景目光温和皮毛光泽\n右下宫格一匹骏马姿态优雅鬃毛飘逸肌肉健美\n风格要求四个宫格风格统一色彩鲜艳饱和高清画质细节清晰锐利专业插画风格线条干净统一的左上方光源柔和阴影和谐配色卡通/半写实风格,宫格间用白色或浅灰细线分隔", prompt:
imageBase64: [], "一张16:9比例的图片完美等分为2x2四宫格布局各区域无缝衔接\n左上宫格一只可爱的猫毛发蓬松眼睛明亮姿态俏皮\n右上宫格一只友善的狗金毛犬表情愉悦摇着尾巴\n左下宫格一头健壮的牛田园背景目光温和皮毛光泽\n右下宫格一匹骏马姿态优雅鬃毛飘逸肌肉健美\n风格要求四个宫格风格统一色彩鲜艳饱和高清画质细节清晰锐利专业插画风格线条干净统一的左上方光源柔和阴影和谐配色卡通/半写实风格,宫格间用白色或浅灰细线分隔",
aspectRatio: "16:9", imageBase64: [],
size: "1K", aspectRatio: "16:9",
}); size: "1K",
},
{
model: modelName,
apiKey,
baseURL,
manufacturer,
},
);
res.status(200).send(success(image)); res.status(200).send(success(image));
} catch (err) { } catch (err) {
const msg = u.error(err).message; const msg = u.error(err).message;

View File

@ -12,20 +12,28 @@ export default router.post(
modelName: z.string().optional(), modelName: z.string().optional(),
apiKey: z.string(), apiKey: z.string(),
baseURL: z.string().optional(), baseURL: z.string().optional(),
manufacturer: z.enum(["runninghub", "volcengine", "apimart", "gemini", "openAi"]), manufacturer: z.string(),
}), }),
async (req, res) => { async (req, res) => {
const { modelName, apiKey, baseURL, manufacturer } = req.body; const { modelName, apiKey, baseURL, manufacturer } = req.body;
try { try {
const videoPath = await u.ai.video({ const videoPath = await u.ai.video(
imageBase64: [], {
savePath: "test.mp4", imageBase64: [],
prompt: "stickman Dances", savePath: "test.mp4",
duration: 4, prompt: "stickman Dances",
resolution: "720p", duration: 4,
aspectRatio: "16:9", resolution: "720p",
audio: false, aspectRatio: "16:9",
}); audio: false,
},
{
model: modelName,
apiKey,
baseURL,
manufacturer,
},
);
const url = await u.oss.getFileUrl(videoPath); const url = await u.oss.getFileUrl(videoPath);
res.status(200).send(success(url)); res.status(200).send(success(url));
} catch (err: any) { } catch (err: any) {

View File

@ -8,23 +8,23 @@ const router = express.Router();
export default router.post( export default router.post(
"/", "/",
validateFields({ validateFields({
type: z.string(), type: z.enum(["text", "video", "image"]),
name: z.string(),
model: z.string(), model: z.string(),
baseUrl: z.string(), baseUrl: z.string(),
apiKey: z.string(), apiKey: z.string(),
modelType: z.string(),
manufacturer: z.string(), manufacturer: z.string(),
}), }),
async (req, res) => { async (req, res) => {
const { type, name, model, baseUrl, apiKey, manufacturer } = req.body; const { type, model, baseUrl, apiKey, manufacturer, modelType } = req.body;
await u.db("t_config").insert({ await u.db("t_config").insert({
type, type,
name,
model, model,
baseUrl, baseUrl,
apiKey, apiKey,
manufacturer, manufacturer,
modelType,
createTime: Date.now(), createTime: Date.now(),
userId: 1, userId: 1,
}); });

View File

@ -8,20 +8,13 @@ const router = express.Router();
export default router.post( export default router.post(
"/", "/",
validateFields({ validateFields({
id: z.number().optional(), id: z.number(),
promptsId: z.number(),
configId: z.number(), configId: z.number(),
}), }),
async (req, res) => { async (req, res) => {
const { id, promptsId, configId } = req.body; const { id, configId } = req.body;
if (id) { if (id) {
await u.db("t_aiModelMap").where("id", id).update({ await u.db("t_aiModelMap").where("id", id).update({
promptsId,
configId,
});
} else {
await u.db("t_aiModelMap").insert({
promptsId,
configId, configId,
}); });
} }

View File

@ -1,13 +1,13 @@
import express from "express"; import express from "express";
import u from "@/utils"; import u from "@/utils";
import { success } from "@/lib/responseFormat"; import { success } from "@/lib/responseFormat";
const router = express.Router(); const router = express.Router();
export default router.post("/", async (req, res) => { export default router.post("/", async (req, res) => {
const configData = await u const configData = await u
.db("t_prompts") .db("t_aiModelMap")
.leftJoin("t_aiModelMap", "t_prompts.id", "t_aiModelMap.promptsId") .leftJoin("t_config", "t_aiModelMap.configId", "t_config.id")
.leftJoin("t_config", "t_config.id", "t_aiModelMap.configId") .select("t_aiModelMap.name", "t_config.model", "t_aiModelMap.id");
.select("t_prompts.id as promptsId", "t_prompts.code", "t_prompts.name", "t_config.model", "t_aiModelMap.id");
res.status(200).send(success(configData)); res.status(200).send(success(configData));
}); });

View File

@ -9,23 +9,23 @@ export default router.post(
"/", "/",
validateFields({ validateFields({
id: z.number(), id: z.number(),
type: z.string(), type: z.enum(["text", "video", "image"]),
name: z.string(),
model: z.string(), model: z.string(),
baseUrl: z.string(), baseUrl: z.string(),
modelType: z.string(),
apiKey: z.string(), apiKey: z.string(),
manufacturer: z.string(), manufacturer: z.string(),
}), }),
async (req, res) => { async (req, res) => {
const { id, type, name, model, baseUrl, apiKey, manufacturer } = req.body; const { id, type, model, baseUrl, apiKey, manufacturer, modelType } = req.body;
await u.db("t_config").where("id", id).update({ await u.db("t_config").where("id", id).update({
type, type,
name,
model, model,
baseUrl, baseUrl,
apiKey, apiKey,
manufacturer, manufacturer,
modelType,
}); });
res.status(200).send(success("编辑成功")); res.status(200).send(success("编辑成功"));
}, },

View File

@ -17,19 +17,19 @@ async function urlToBase64(imageUrl: string): Promise<string> {
} }
// 超分并保存到 oss // 超分并保存到 oss
async function superResolutionAndSave( async function superResolutionAndSave(src: string, projectId: number, videoRatio: string): Promise<{ ossPath: string; base64: string }> {
src: string, const apiConfig = await u.getPromptAi("storyboardImage");
projectId: number, const contentStr = await u.ai.image(
videoRatio: string, {
): Promise<{ ossPath: string; base64: string }> { aspectRatio: videoRatio,
const contentStr = await u.ai.image({ size: "1K",
aspectRatio: videoRatio, resType: "b64",
size: "1K", systemPrompt: "你的核心任务是将所给的图片超分到 1K ,不改变图片任何内容,仅改变分辨率",
resType: "b64", prompt: "你的核心任务是将所给的图片超分到 1K ,不改变图片任何内容,仅改变分辨率",
systemPrompt: "你的核心任务是将所给的图片超分到 1K ,不改变图片任何内容,仅改变分辨率", imageBase64: [await urlToBase64(src)],
prompt: "你的核心任务是将所给的图片超分到 1K ,不改变图片任何内容,仅改变分辨率", },
imageBase64: [await urlToBase64(src)], apiConfig,
}); );
const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/); const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/);
const base64Str = match ? match[1] : contentStr; const base64Str = match ? match[1] : contentStr;
const buffer = Buffer.from(base64Str, "base64"); const buffer = Buffer.from(base64Str, "base64");
@ -50,9 +50,9 @@ export default router.post(
id: z.string(), id: z.string(),
prompt: z.string().optional(), prompt: z.string().optional(),
src: z.string(), src: z.string(),
}) }),
), ),
}) }),
), ),
}), }),
async (req, res) => { async (req, res) => {
@ -63,9 +63,7 @@ export default router.post(
if (!projectData) return res.status(500).send(error("项目不存在")); if (!projectData) return res.status(500).send(error("项目不存在"));
// 遍历处理每个分镜段 // 遍历处理每个分镜段
const processSegment = async ( const processSegment = async (segment: { cells: { id: string; src: string }[] }) => {
segment: { cells: { id: string; src: string }[] }
) => {
// 超分所有 cell // 超分所有 cell
const cellsWithSuperscore = await Promise.all( const cellsWithSuperscore = await Promise.all(
segment.cells.map(async (cell) => { segment.cells.map(async (cell) => {
@ -76,9 +74,9 @@ export default router.post(
scriptId, scriptId,
filePath: ossPath, // oss 路径(未签名) filePath: ossPath, // oss 路径(未签名)
src: cell.src, src: cell.src,
type: "分镜" type: "分镜",
}; };
}) }),
); );
return cellsWithSuperscore; return cellsWithSuperscore;
}; };
@ -92,9 +90,9 @@ export default router.post(
(item.value as any[]).map(async (cell) => ({ (item.value as any[]).map(async (cell) => ({
...cell, ...cell,
filePath: await u.oss.getFileUrl(cell.filePath ?? ""), filePath: await u.oss.getFileUrl(cell.filePath ?? ""),
})) })),
) ),
); );
res.status(200).send(success(flatList)); res.status(200).send(success(flatList));
} },
); );

View File

@ -4,6 +4,7 @@ import { error, success } from "@/lib/responseFormat";
import { validateFields } from "@/middleware/middleware"; import { validateFields } from "@/middleware/middleware";
import { z } from "zod"; import { z } from "zod";
import path from "path"; import path from "path";
import axios from "axios";
const router = express.Router(); const router = express.Router();
@ -103,7 +104,12 @@ const prompt = `
Motion Prompt JSON Motion Prompt JSON
`; `;
async function urlToBase64(imageUrl: string): Promise<string> {
const response = await axios.get(imageUrl, { responseType: "arraybuffer" });
const contentType = response.headers["content-type"] || "image/png";
const base64 = Buffer.from(response.data, "binary").toString("base64");
return `data:${contentType};base64,${base64}`;
}
// 生成单个分镜提示 // 生成单个分镜提示
async function generateSingleVideoPrompt({ async function generateSingleVideoPrompt({
scriptText, scriptText,
@ -114,19 +120,6 @@ async function generateSingleVideoPrompt({
storyboardPrompt: string; storyboardPrompt: string;
ossPath: string; ossPath: string;
}): Promise<{ content: string; time: number; name: string }> { }): Promise<{ content: string; time: number; name: string }> {
let rootDir: string;
if (typeof process.versions?.electron !== "undefined") {
const { app } = require("electron");
const userDataDir: string = app.getPath("userData");
rootDir = path.join(userDataDir, "uploads");
} else {
rootDir = path.join(process.cwd(), "uploads");
}
let imagePath = ossPath;
if (ossPath.includes("http")) {
imagePath = new URL(ossPath).pathname;
}
const messages: any[] = [ const messages: any[] = [
{ {
role: "system", role: "system",
@ -140,24 +133,27 @@ async function generateSingleVideoPrompt({
text: `剧本内容:${scriptText}\n分镜提示词:${storyboardPrompt}`, text: `剧本内容:${scriptText}\n分镜提示词:${storyboardPrompt}`,
}, },
{ {
type: "local", type: "image",
path: path.join(rootDir, imagePath), image: await urlToBase64(ossPath),
}, },
], ],
}, },
]; ];
try { try {
const result = await u.ai.text.invoke({ const apiConfig = await u.getPromptAi("videoPrompt");
messages,
output: {
time: z.number().describe("时长,镜头时长 1-15"),
content: z.string().describe("提示词内容"),
name: z.string().describe("分镜名称"),
},
});
console.log("%c Line:156 🍩 result", "background:#33a5ff", result);
const result = await u.ai.text.invoke(
{
messages,
output: {
time: z.number().describe("时长,镜头时长 1-15"),
content: z.string().describe("提示词内容"),
name: z.string().describe("分镜名称"),
},
},
apiConfig,
);
if (!result) { if (!result) {
console.error("AI 返回结果为空:", result); console.error("AI 返回结果为空:", result);
throw new Error("AI 返回结果为空"); throw new Error("AI 返回结果为空");

View File

@ -42,5 +42,5 @@ export default router.post(
}); });
res.status(200).send(success({ message: "新增视频成功" })); res.status(200).send(success({ message: "新增视频成功" }));
} },
); );

View File

@ -1,6 +1,6 @@
import express from "express"; import express from "express";
import u from "@/utils"; import u from "@/utils";
import { success } from "@/lib/responseFormat"; import { error, success } from "@/lib/responseFormat";
import { validateFields } from "@/middleware/middleware"; import { validateFields } from "@/middleware/middleware";
import { z } from "zod"; import { z } from "zod";
const router = express.Router(); const router = express.Router();
@ -20,8 +20,8 @@ export default router.post(
validateFields({ validateFields({
scriptId: z.number(), scriptId: z.number(),
projectId: z.number(), projectId: z.number(),
manufacturer: z.string(), configId: z.number(),
mode: z.enum(["startEnd", "multi", "single"]), mode: z.enum(["startEnd", "multi", "single",'text','']),
startFrame: imageItemSchema.optional(), startFrame: imageItemSchema.optional(),
endFrame: imageItemSchema.optional(), endFrame: imageItemSchema.optional(),
images: z images: z
@ -38,19 +38,21 @@ export default router.post(
prompt: z.string().optional(), prompt: z.string().optional(),
}), }),
async (req, res) => { async (req, res) => {
const { scriptId, projectId, manufacturer, mode, startFrame, endFrame, images, resolution, duration, prompt } = req.body; const { scriptId, projectId, configId, mode, startFrame, endFrame, images, resolution, duration, prompt } = req.body;
// 生成新ID // 生成新ID
const maxIdResult: any = await u.db("t_videoConfig").max("id as maxId").first(); const maxIdResult: any = await u.db("t_videoConfig").max("id as maxId").first();
const newId = (maxIdResult?.maxId || 0) + 1; const newId = (maxIdResult?.maxId || 0) + 1;
const now = Date.now(); const now = Date.now();
const configData = await u.db("t_config").where("id", configId).first();
if (!configData) return res.status(500).send(error("不存在的模型"));
// 插入数据 // 插入数据
await u.db("t_videoConfig").insert({ await u.db("t_videoConfig").insert({
id: newId, id: newId,
scriptId, scriptId,
projectId, projectId,
manufacturer, manufacturer: configData.manufacturer,
aiConfigId: configId,
mode, mode,
startFrame: startFrame ? JSON.stringify(startFrame) : null, startFrame: startFrame ? JSON.stringify(startFrame) : null,
endFrame: endFrame ? JSON.stringify(endFrame) : null, endFrame: endFrame ? JSON.stringify(endFrame) : null,
@ -70,7 +72,9 @@ export default router.post(
id: newId, id: newId,
scriptId, scriptId,
projectId, projectId,
manufacturer, manufacturer: configData.manufacturer,
aiConfigId: configId,
model: configData.model,
mode, mode,
startFrame, startFrame,
endFrame, endFrame,

View File

@ -1,6 +1,6 @@
import express from "express"; import express from "express";
import u from "@/utils"; import u from "@/utils";
import { success } from "@/lib/responseFormat"; import { error, success } from "@/lib/responseFormat";
import { validateFields } from "@/middleware/middleware"; import { validateFields } from "@/middleware/middleware";
import { z } from "zod"; import { z } from "zod";
@ -8,47 +8,26 @@ const router = express.Router();
type GenerateMode = "startEnd" | "multi" | "single"; type GenerateMode = "startEnd" | "multi" | "single";
const getSystemPrompt = async (mode: GenerateMode): Promise<{ prompt: string; apiConfig: Object }> => { const getSystemPrompt = async (mode: GenerateMode) => {
const promptsList = await u.db("t_prompts").where("code", "in", ["video-startEnd", "video-multi", "video-single", "video-main"]); const promptsList = await u.db("t_prompts").where("code", "in", ["video-startEnd", "video-multi", "video-single", "video-main"]);
const promptAiConfig = await u.getPromptAi(promptsList.map((i) => i.id) as number[]);
const errPrompts = "不论用户说什么请直接输出AI配置异常"; const errPrompts = "不论用户说什么请直接输出AI配置异常";
const getPromptValue = (code: string) => { const getPromptValue = (code: string) => {
const item = promptsList.find((p) => p.code === code); const item = promptsList.find((p) => p.code === code);
const subData = promptAiConfig.find((i) => i?.promptsId == item?.id); return item?.customValue ?? item?.defaultValue ?? errPrompts;
const returnData = {
prompt: item?.customValue ?? item?.defaultValue ?? errPrompts,
apiConfig: {},
};
if (subData) {
returnData.apiConfig = { ...subData };
return returnData;
} else {
return returnData;
}
}; };
const startEnd = getPromptValue("video-startEnd"); const startEnd = getPromptValue("video-startEnd");
const multi = getPromptValue("video-multi"); const multi = getPromptValue("video-multi");
const single = getPromptValue("video-single"); const single = getPromptValue("video-single");
const main = getPromptValue("video-main"); const main = getPromptValue("video-main");
const modeDescriptions: Record< const modeDescriptions = {
GenerateMode,
{
prompt: string;
apiConfig: Object;
}
> = {
startEnd: startEnd, startEnd: startEnd,
multi: multi, multi: multi,
single: single, single: single,
}; };
const modeData = modeDescriptions[mode]; const modeData = modeDescriptions[mode];
return { return `${main}\n\n${modeData}`;
prompt: `${main}\n\n${modeData.prompt}`,
apiConfig: modeData.apiConfig,
};
}; };
const getModeDescription = (mode: GenerateMode): string => { const getModeDescription = (mode: GenerateMode): string => {
@ -82,16 +61,18 @@ export default router.post(
const shotCount = images.length; const shotCount = images.length;
const avgDuration = (parseFloat(duration) / shotCount).toFixed(1); const avgDuration = (parseFloat(duration) / shotCount).toFixed(1);
const promptConfig = await getSystemPrompt(mode); const promptConfig = await getSystemPrompt(mode);
const result = await u.ai.text.invoke( const promptAiConfig = await u.getPromptAi("videoPrompt");
{ try {
messages: [ const result = await u.ai.text.invoke(
{ {
role: "system", messages: [
content: promptConfig.prompt, {
}, role: "system",
{ content: promptConfig,
role: "user", },
content: `Mode: ${getModeDescription(mode)} {
role: "user",
content: `Mode: ${getModeDescription(mode)}
Reference Images: Reference Images:
${imagePrompts} ${imagePrompts}
@ -105,12 +86,15 @@ Parameters:
- Average Duration: ${avgDuration}s per shot - Average Duration: ${avgDuration}s per shot
Generate storyboard prompts:`, Generate storyboard prompts:`,
}, },
], ],
}, },
promptConfig.apiConfig, promptAiConfig,
); );
res.status(200).send(success(result.text)); res.status(200).send(success(result.text));
} catch (e) {
return res.status(500).send(error(u.error(e).message));
}
}, },
); );

View File

@ -4,6 +4,7 @@ import { z } from "zod";
import { v4 as uuidv4 } from "uuid"; import { v4 as uuidv4 } from "uuid";
import { error, success } from "@/lib/responseFormat"; import { error, success } from "@/lib/responseFormat";
import { validateFields } from "@/middleware/middleware"; import { validateFields } from "@/middleware/middleware";
import { t_config } from "@/types/database";
const router = express.Router(); const router = express.Router();
@ -13,35 +14,52 @@ export default router.post(
validateFields({ validateFields({
projectId: z.number(), projectId: z.number(),
scriptId: z.number(), scriptId: z.number(),
configId: z.number().optional(), // 关联的视频配置ID configId: z.number().optional(), // 关联的视频配 置ID
type: z.string().optional(), type: z.string().optional(),
resolution: z.string(), resolution: z.string(),
aiConfigId: z.number(),
filePath: z.array(z.string()), filePath: z.array(z.string()),
duration: z.number(), duration: z.number(),
prompt: z.string(), prompt: z.string(),
}), }),
async (req, res) => { async (req, res) => {
const { type, scriptId, projectId, configId, resolution, filePath, duration, prompt } = req.body; const { type, scriptId, projectId, configId, aiConfigId, resolution, filePath, duration, prompt } = req.body;
// 参数校验 // // 参数校验
if (type === "volcengine") { // if (type === "volcengine") {
if (duration < 4 || duration > 12) { // if (duration < 4 || duration > 12) {
return res.status(400).send(error("视频时长需在4-12秒之间")); // return res.status(400).send(error("视频时长需在4-12秒之间"));
} // }
if (!["480p", "720p", "1080p"].includes(resolution)) { // if (!["480p", "720p", "1080p"].includes(resolution)) {
return res.status(400).send(error("视频分辨率不正确")); // return res.status(400).send(error("视频分辨率不正确"));
} // }
// }
// if (type === "runninghub") {
// if (duration !== 10 && duration !== 15) {
// return res.status(400).send(error("视频时长只能是10秒或15秒"));
// }
// if (resolution !== "9:16" && resolution !== "16:9") {
// return res.status(400).send(error("视频分辨率不正确"));
// }
// }
const configData = await u.db("t_videoConfig").where("id", configId).first();
if (!configData) {
return res.status(500).send(error("视频配置不存在"));
} }
if (type === "runninghub") { // 优先使用视频配置中的AI配置ID查询,查不到再使用传入的aiConfigId
if (duration !== 10 && duration !== 15) { let aiConfigData = null;
return res.status(400).send(error("视频时长只能是10秒或15秒")); if (configData.aiConfigId) {
} aiConfigData = await u.db("t_config").where("id", configData.aiConfigId).first();
if (resolution !== "9:16" && resolution !== "16:9") { }
return res.status(400).send(error("视频分辨率不正确")); if (!aiConfigData) {
} aiConfigData = await u.db("t_config").where("id", aiConfigId).first();
} }
if (!aiConfigData) {
return res.status(500).send(error("模型配置不存在"));
}
// 过滤掉空值 // 过滤掉空值
let fileUrl = filePath.filter((p: string) => p && p.trim() !== ""); let fileUrl = filePath.filter((p: string) => p && p.trim() !== "");
@ -103,7 +121,7 @@ export default router.post(
res.status(200).send(success({ id: videoId, configId: configId || null })); res.status(200).send(success({ id: videoId, configId: configId || null }));
// 异步生成视频 // 异步生成视频
generateVideoAsync(videoId, projectId, fileUrl, savePath, prompt, duration, resolution, type); generateVideoAsync(videoId, projectId, fileUrl, savePath, prompt, duration, resolution, aiConfigData);
}, },
); );
@ -116,7 +134,7 @@ async function generateVideoAsync(
prompt: string, prompt: string,
duration: number, duration: number,
resolution: string, resolution: string,
type?: string, aiConfigData: t_config,
) { ) {
try { try {
const projectData = await u.db("t_project").where("id", projectId).select("artStyle").first(); const projectData = await u.db("t_project").where("id", projectId).select("artStyle").first();
@ -149,14 +167,22 @@ ${prompt}
3. 3.
4. logo 4. logo
`; `;
const videoPath = await u.ai.video({ const videoPath = await u.ai.video(
imageBase64, {
savePath, imageBase64,
prompt: inputPrompt, savePath,
duration: duration as any, prompt: inputPrompt,
aspectRatio: resolution as any, duration: duration as any,
resolution: resolution as any, aspectRatio: resolution as any,
}); resolution: resolution as any,
},
{
baseURL: aiConfigData?.baseUrl!,
model: aiConfigData?.model!,
apiKey: aiConfigData?.apiKey!,
manufacturer: aiConfigData?.manufacturer!,
},
);
if (videoPath) { if (videoPath) {
// 生成成功,更新状态为 1 // 生成成功,更新状态为 1

View File

@ -14,8 +14,8 @@ export default router.post(
async (req, res) => { async (req, res) => {
const { userId } = req.body; const { userId } = req.body;
const data = await u.db("t_config").where("userId", userId).select("manufacturer", "model"); const data = await u.db("t_config").where("type", "video").where("userId", userId).select("manufacturer", "model", "id");
res.status(200).send(success(data)); res.status(200).send(success(data));
} },
); );

View File

@ -15,16 +15,20 @@ export default router.post(
const { scriptId } = req.body; const { scriptId } = req.body;
// 查询该脚本下的所有视频配置 // 查询该脚本下的所有视频配置
const configs = await u.db("t_videoConfig") const configs = await u
.db("t_videoConfig")
.leftJoin("t_config", "t_config.id", "t_videoConfig.aiConfigId")
.where({ scriptId }) .where({ scriptId })
.orderBy("createTime", "desc"); .orderBy("createTime", "desc")
.select("t_videoConfig.*", "t_config.manufacturer as manufacturer", "t_config.model");
// 解析 JSON 字段 // 解析 JSON 字段
const result = configs.map((config: any) => ({ const result = configs.map((config: any) => ({
id: config.id, id: config.id,
scriptId: config.scriptId, scriptId: config.scriptId,
projectId: config.projectId, projectId: config.projectId,
aiConfigId: config.aiConfigId,
manufacturer: config.manufacturer, manufacturer: config.manufacturer,
model: config.model,
mode: config.mode, mode: config.mode,
startFrame: config.startFrame ? JSON.parse(config.startFrame) : null, startFrame: config.startFrame ? JSON.parse(config.startFrame) : null,
endFrame: config.endFrame ? JSON.parse(config.endFrame) : null, endFrame: config.endFrame ? JSON.parse(config.endFrame) : null,

View File

@ -1,10 +1,11 @@
// @db-hash e1460b0ace03f6aaed458653a32b6ffb // @db-hash 4cd44aef6bb6ffb02c4619525966496d
//该文件由脚本自动生成,请勿手动修改 //该文件由脚本自动生成,请勿手动修改
export interface t_aiModelMap { export interface t_aiModelMap {
'configId'?: number | null; 'configId'?: number | null;
'id'?: number; 'id'?: number;
'promptsId'?: number | null; 'key'?: string | null;
'name'?: string | null;
} }
export interface t_assets { export interface t_assets {
'duration'?: string | null; 'duration'?: string | null;
@ -37,7 +38,7 @@ export interface t_config {
'id'?: number; 'id'?: number;
'manufacturer'?: string | null; 'manufacturer'?: string | null;
'model'?: string | null; 'model'?: string | null;
'name'?: string | null; 'modelType'?: string | null;
'type'?: string | null; 'type'?: string | null;
'userId'?: number | null; 'userId'?: number | null;
} }
@ -135,6 +136,7 @@ export interface t_video {
'time'?: number | null; 'time'?: number | null;
} }
export interface t_videoConfig { export interface t_videoConfig {
'aiConfigId'?: number | null;
'createTime'?: number | null; 'createTime'?: number | null;
'duration'?: number | null; 'duration'?: number | null;
'endFrame'?: string | null; 'endFrame'?: string | null;

View File

@ -28,13 +28,16 @@ const modelInstance = {
other, other,
} as const; } as const;
export default async (input: ImageConfig, config?: AIConfig) => { export default async (input: ImageConfig, config: AIConfig) => {
const sqlTextModelConfig = await u.getConfig("image"); const { model, apiKey, baseURL, manufacturer } = { ...config };
const { model, apiKey, baseURL, manufacturer } = { ...sqlTextModelConfig, ...config }; if (!config || !config?.model || !config?.apiKey || !config?.manufacturer) throw new Error("请检查模型配置是否正确");
const manufacturerFn = modelInstance[manufacturer as keyof typeof modelInstance]; const manufacturerFn = modelInstance[manufacturer as keyof typeof modelInstance];
if (!manufacturerFn) if (!manufacturerFn) throw new Error("不支持的图片厂商"); if (!manufacturerFn) if (!manufacturerFn) throw new Error("不支持的图片厂商");
const owned = modelList.find((m) => m.model === model); if (manufacturer !== "other") {
if (!owned) throw new Error("不支持的模型"); const owned = modelList.find((m) => m.model === model);
if (!owned) throw new Error("不支持的模型");
}
// 补充图片的 base64 内容类型字符串 // 补充图片的 base64 内容类型字符串
if (input.imageBase64 && input.imageBase64.length > 0) { if (input.imageBase64 && input.imageBase64.length > 0) {

View File

@ -3,14 +3,13 @@ import { createGoogleGenerativeAI } from "@ai-sdk/google";
import { generateText } from "ai"; import { generateText } from "ai";
export default async (input: ImageConfig, config: AIConfig): Promise<string> => { export default async (input: ImageConfig, config: AIConfig): Promise<string> => {
console.log("%c Line:6 🌰 config", "background:#ffdd4d", config);
if (!config.model) throw new Error("缺少Model名称"); if (!config.model) throw new Error("缺少Model名称");
if (!config.apiKey) throw new Error("缺少API Key"); if (!config.apiKey) throw new Error("缺少API Key");
if (!input.prompt) throw new Error("缺少提示词"); if (!input.prompt) throw new Error("缺少提示词");
const google = createGoogleGenerativeAI({ const google = createGoogleGenerativeAI({
apiKey: config.apiKey, apiKey: config.apiKey,
baseURL: config.baseURL, baseURL: config?.baseURL ?? "https://generativelanguage.googleapis.com/v1beta",
}); });
// 构建完整的提示词 // 构建完整的提示词

View File

@ -1,5 +1,5 @@
import "../type"; import "../type";
import { generateImage, generateText } from "ai"; import { generateImage, generateText, ModelMessage } from "ai";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
export default async (input: ImageConfig, config: AIConfig): Promise<string> => { export default async (input: ImageConfig, config: AIConfig): Promise<string> => {
@ -27,9 +27,24 @@ export default async (input: ImageConfig, config: AIConfig): Promise<string> =>
const fullPrompt = input.systemPrompt ? `${input.systemPrompt}\n\n${input.prompt}` : input.prompt; const fullPrompt = input.systemPrompt ? `${input.systemPrompt}\n\n${input.prompt}` : input.prompt;
const model = config.model; const model = config.model;
if (model.includes("gemini") || model.includes("nano")) { if (model.includes("gemini") || model.includes("nano")) {
let promptData;
if (input.imageBase64 && input.imageBase64.length) {
promptData = [{ role: "system", content: fullPrompt + `请直接输出图片` }];
(promptData as ModelMessage[]).push({
role: "user",
content: input.imageBase64.map((i) => ({
type: "image",
image: i,
})),
});
} else {
promptData = fullPrompt + `请直接输出图片`;
}
console.log("%c Line:31 🍅 promptData", "background:#2eafb0", promptData);
const result = await generateText({ const result = await generateText({
model: otherProvider.languageModel(model), model: otherProvider.languageModel(model),
prompt: fullPrompt + `请直接输出图片`, prompt: promptData as string | ModelMessage[],
providerOptions: { providerOptions: {
google: { google: {
imageConfig: { imageConfig: {

View File

@ -20,7 +20,6 @@ function template(replaceObj: Record<string, any>, url: string) {
export default async (input: ImageConfig, config: AIConfig): Promise<string> => { export default async (input: ImageConfig, config: AIConfig): Promise<string> => {
if (!config.model) throw new Error("缺少Model名称"); if (!config.model) throw new Error("缺少Model名称");
if (!config.apiKey) throw new Error("缺少API Key"); if (!config.apiKey) throw new Error("缺少API Key");
const apiKey = "Token " + config.apiKey.replace(/Token\s+/g, "").trim(); const apiKey = "Token " + config.apiKey.replace(/Token\s+/g, "").trim();
const viduq2Ratio = ["16:9", "9:16", "1:1", "3:4", "4:3", "21:9", "2:3", "3:2"]; const viduq2Ratio = ["16:9", "9:16", "1:1", "3:4", "4:3", "21:9", "2:3", "3:2"];
const viduq1Ratio = ["16:9", "9:16", "1:1", "3:4", "4:3"]; const viduq1Ratio = ["16:9", "9:16", "1:1", "3:4", "4:3"];
@ -60,7 +59,8 @@ export default async (input: ImageConfig, config: AIConfig): Promise<string> =>
...(images.length && { images: images }), ...(images.length && { images: images }),
}; };
const urlObj = getApiUrl(config.baseURL!); const urlObj = getApiUrl(config.baseURL! ?? "https://api.vidu.cn/ent/v2/reference2image|https://api.vidu.cn/ent/v2/tasks/{id}/creations");
try { try {
const { data } = await axios.post(urlObj.requestUrl, body, { headers: { Authorization: apiKey } }); const { data } = await axios.post(urlObj.requestUrl, body, { headers: { Authorization: apiKey } });
@ -69,17 +69,13 @@ export default async (input: ImageConfig, config: AIConfig): Promise<string> =>
return await pollTask(async () => { return await pollTask(async () => {
const { data: queryData } = await axios.get(queryUrl, { headers: { Authorization: apiKey } }); const { data: queryData } = await axios.get(queryUrl, { headers: { Authorization: apiKey } });
if (queryData.state !== 0) { const { state, err_code, creations } = queryData || {};
return { completed: false, error: queryData.message || "查询任务失败" };
}
const { state, err_code, creations } = queryData.data || {};
if (state === "failed") { if (state === "failed") {
return { completed: false, error: err_code || "图片生成失败" }; return { completed: false, error: err_code || "图片生成失败" };
} }
if (state === "succeed") { if (state === "success") {
return { completed: true, url: creations?.[0]?.url }; return { completed: true, url: creations?.[0]?.url };
} }

View File

@ -6,7 +6,6 @@ import { parse } from "best-effort-json-parser";
import modelList from "./modelList"; import modelList from "./modelList";
import { z } from "zod"; import { z } from "zod";
import { OpenAIProvider } from "@ai-sdk/openai"; import { OpenAIProvider } from "@ai-sdk/openai";
interface AIInput<T extends Record<string, z.ZodTypeAny> | undefined = undefined> { interface AIInput<T extends Record<string, z.ZodTypeAny> | undefined = undefined> {
system?: string; system?: string;
tools?: Record<string, Tool>; tools?: Record<string, Tool>;
@ -23,10 +22,9 @@ interface AIConfig {
manufacturer?: string; manufacturer?: string;
} }
const buildOptions = async (input: AIInput<any>, config: AIConfig) => { const buildOptions = async (input: AIInput<any>, config: AIConfig = {}) => {
let sqlTextModelConfig = {}; if (!config || !config?.model || !config?.apiKey || !config?.baseURL || !config?.manufacturer) throw new Error("请检查模型配置是否正确");
if (!config || !config?.model || !config?.apiKey || !config?.baseURL) sqlTextModelConfig = await u.getConfig("text"); const { model, apiKey, baseURL, manufacturer } = { ...config };
const { model, apiKey, baseURL, manufacturer } = { ...(sqlTextModelConfig as Awaited<ReturnType<typeof u.getConfig>>), ...config };
let owned; let owned;
if (manufacturer == "other") { if (manufacturer == "other") {
owned = modelList.find((m) => m.manufacturer === manufacturer); owned = modelList.find((m) => m.manufacturer === manufacturer);
@ -39,7 +37,9 @@ const buildOptions = async (input: AIInput<any>, config: AIConfig) => {
const maxStep = input.maxStep ?? (input.tools ? Object.keys(input.tools).length * 5 : undefined); const maxStep = input.maxStep ?? (input.tools ? Object.keys(input.tools).length * 5 : undefined);
const outputBuilders: Record<string, (schema: any) => any> = { const outputBuilders: Record<string, (schema: any) => any> = {
schema: (s) => Output.object({ schema: z.object(s) }), schema: (s) => {
return Output.object({ schema: z.object(s) });
},
object: () => { object: () => {
const jsonSchemaPrompt = `\n请按照以下 JSON Schema 格式返回结果:\n${JSON.stringify( const jsonSchemaPrompt = `\n请按照以下 JSON Schema 格式返回结果:\n${JSON.stringify(
z.toJSONSchema(z.object(input.output)), z.toJSONSchema(z.object(input.output)),
@ -52,16 +52,11 @@ const buildOptions = async (input: AIInput<any>, config: AIConfig) => {
}; };
const output = input.output ? (outputBuilders[owned.responseFormat]?.(input.output) ?? null) : null; const output = input.output ? (outputBuilders[owned.responseFormat]?.(input.output) ?? null) : null;
const modelFn = owned.manufacturer == "doubao" ? (modelInstance as OpenAIProvider).chat(model!) : modelInstance(model!); const chatModelManufacturer = ["doubao", "other", "openai"];
const modelFn = chatModelManufacturer.includes(owned.manufacturer) ? (modelInstance as OpenAIProvider).chat(model!) : modelInstance(model!);
return { return {
config: { config: {
model: model: modelFn as LanguageModel,
process.env.NODE_ENV === "dev"
? wrapLanguageModel({
model: modelFn as any,
middleware: devToolsMiddleware(),
})
: (modelFn as LanguageModel),
...(input.system && { system: input.system }), ...(input.system && { system: input.system }),
...(input.prompt ? { prompt: input.prompt } : { messages: input.messages! }), ...(input.prompt ? { prompt: input.prompt } : { messages: input.messages! }),
...(input.tools && owned.tool && { tools: input.tools }), ...(input.tools && owned.tool && { tools: input.tools }),
@ -79,7 +74,7 @@ const ai = Object.create({}) as {
stream(input: AIInput, config?: AIConfig): Promise<ReturnType<typeof streamText>>; stream(input: AIInput, config?: AIConfig): Promise<ReturnType<typeof streamText>>;
}; };
ai.invoke = async (input: AIInput<any>, config: AIConfig = {}) => { ai.invoke = async (input: AIInput<any>, config: AIConfig) => {
const options = await buildOptions(input, config); const options = await buildOptions(input, config);
const result = await generateText(options.config); const result = await generateText(options.config);
if (options.responseFormat === "object" && input.output) { if (options.responseFormat === "object" && input.output) {
@ -95,7 +90,7 @@ ai.invoke = async (input: AIInput<any>, config: AIConfig = {}) => {
return result; return result;
}; };
ai.stream = async (input: AIInput, config: AIConfig = {}) => { ai.stream = async (input: AIInput, config: AIConfig) => {
const options = await buildOptions(input, config); const options = await buildOptions(input, config);
return streamText(options.config); return streamText(options.config);
}; };

View File

@ -417,7 +417,7 @@ const modelList: Owned[] = [
responseFormat: "schema", responseFormat: "schema",
image: true, image: true,
think: false, think: false,
instance: createOpenAICompatible, instance: createOpenAI,
tool: true, tool: true,
}, },
]; ];

View File

@ -22,8 +22,10 @@ const modelInstance = {
} as const; } as const;
export default async (input: VideoConfig, config?: AIConfig) => { export default async (input: VideoConfig, config?: AIConfig) => {
const sqlTextModelConfig = await u.getConfig("video"); console.log("%c Line:25 🥛 config", "background:#2eafb0", config);
const { model, apiKey, baseURL, manufacturer } = { ...sqlTextModelConfig, ...config }; const { model, apiKey, baseURL, manufacturer } = { ...config };
if (!config || !config?.model || !config?.apiKey) throw new Error("请检查模型配置是否正确");
const manufacturerFn = modelInstance[manufacturer as keyof typeof modelInstance]; const manufacturerFn = modelInstance[manufacturer as keyof typeof modelInstance];
if (!manufacturerFn) if (!manufacturerFn) throw new Error("不支持的视频厂商"); if (!manufacturerFn) if (!manufacturerFn) throw new Error("不支持的视频厂商");
const owned = modelList.find((m) => m.model === model); const owned = modelList.find((m) => m.model === model);

View File

@ -47,6 +47,7 @@ export default async (input: VideoConfig, config: AIConfig) => {
}); });
const taskId = createResponse.data.id; const taskId = createResponse.data.id;
if (!taskId) throw new Error("视频任务创建失败"); if (!taskId) throw new Error("视频任务创建失败");
// 轮询任务状态 // 轮询任务状态

View File

@ -12,4 +12,5 @@ interface AIConfig {
model?: string; model?: string;
apiKey?: string; apiKey?: string;
baseURL?: string; baseURL?: string;
manufacturer?: string;
} }

View File

@ -79,13 +79,18 @@ async function convertDirectiveAndImages(images: Record<string, string>, directi
*/ */
export default async (images: Record<string, string>, directive: string, projectId: number) => { export default async (images: Record<string, string>, directive: string, projectId: number) => {
const { prompt, images: base64Images } = await convertDirectiveAndImages(images, directive); const { prompt, images: base64Images } = await convertDirectiveAndImages(images, directive);
const contentStr = await u.ai.image({ const apiConfig = await u.getPromptAi("editImage");
systemPrompt: "根据用户提供的具体修改指令,对上传的图片进行智能编辑。",
prompt: prompt, const contentStr = await u.ai.image(
imageBase64: base64Images, {
aspectRatio: "16:9", systemPrompt: "根据用户提供的具体修改指令,对上传的图片进行智能编辑。",
size: "1K", prompt: prompt,
}); imageBase64: base64Images,
aspectRatio: "16:9",
size: "1K",
},
apiConfig,
);
const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/); const match = contentStr.match(/base64,([A-Za-z0-9+/=]+)/);
const buffer = Buffer.from(match && match.length >= 1 ? match[1]! : contentStr, "base64"); const buffer = Buffer.from(match && match.length >= 1 ? match[1]! : contentStr, "base64");
const filePath = `/${projectId}/storyboard/${uuid()}.jpg`; const filePath = `/${projectId}/storyboard/${uuid()}.jpg`;

View File

@ -127,7 +127,7 @@ ${episodePrompt}
${novelData}`; ${novelData}`;
const prompts = await u.db("t_prompts").where("code", "script").first(); const prompts = await u.db("t_prompts").where("code", "script").first();
const promptConfig = await u.getPromptAi(prompts?.id); const promptConfig = await u.getPromptAi("generateScript");
const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出AI配置异常"; const mainPrompts = prompts?.customValue || prompts?.defaultValue || "不论用户说什么请直接输出AI配置异常";
const result = await u.ai.text.invoke( const result = await u.ai.text.invoke(

View File

@ -1,26 +1,19 @@
import { db } from "./db"; import { db } from "./db";
interface AiConfig { interface AiConfig {
model: string; model?: string;
apiKey: string; apiKey: string;
baseUrl: string; baseURL?: string;
manufacturer: string; manufacturer: string;
promptsId: number;
} }
export default async function getPromptAi(promptsId: number | undefined): Promise<AiConfig | {}>; export default async function getPromptAi(key: string): Promise<AiConfig | {}> {
export default async function getPromptAi(promptsId: number[]): Promise<AiConfig[]>; const aiConfigData = await db("t_aiModelMap")
export default async function getPromptAi(promptsId: number | number[] | undefined): Promise<AiConfig | AiConfig[] | {}> {
if (!promptsId) return {};
const ids = Array.isArray(promptsId) ? promptsId.filter(Boolean) : [promptsId];
const mapList = await db("t_aiModelMap")
.leftJoin("t_config", "t_config.id", "t_aiModelMap.configId") .leftJoin("t_config", "t_config.id", "t_aiModelMap.configId")
.whereIn("t_aiModelMap.promptsId", ids) .where("t_aiModelMap.key", key)
.select("t_config.model", "t_config.apiKey", "t_config.baseUrl", "t_config.manufacturer", "t_aiModelMap.promptsId"); .select("t_config.model", "t_config.apiKey", "t_config.baseUrl as baseURL", "t_config.manufacturer")
.first();
if (Array.isArray(promptsId)) { if (aiConfigData) {
return mapList as AiConfig[]; return aiConfigData as AiConfig;
} else { } else return {};
return mapList[0] ? (mapList[0] as AiConfig) : {};
}
} }