151 lines
4.4 KiB
TypeScript
151 lines
4.4 KiB
TypeScript
import express from "express";
|
||
import pLimit from "p-limit";
|
||
import u from "@/utils";
|
||
import { z } from "zod";
|
||
import { v4 as uuidv4 } from "uuid";
|
||
import { error, success } from "@/lib/responseFormat";
|
||
import { validateFields } from "@/middleware/middleware";
|
||
|
||
const router = express.Router();
|
||
|
||
type AssetType = "role" | "scene" | "tool";
|
||
|
||
interface AssetTypeConfig {
|
||
label: string;
|
||
taskClass: string;
|
||
dir: string;
|
||
promptTitle: string;
|
||
promptEnd: string;
|
||
}
|
||
|
||
const assetTypeConfig: Record<AssetType, AssetTypeConfig> = {
|
||
role: {
|
||
label: "角色",
|
||
taskClass: "角色图生成",
|
||
dir: "role",
|
||
promptTitle: "角色标准四视图",
|
||
promptEnd: "人物角色四视图",
|
||
},
|
||
scene: {
|
||
label: "场景",
|
||
taskClass: "场景图生成",
|
||
dir: "scene",
|
||
promptTitle: "标准场景图",
|
||
promptEnd: "标准场景图",
|
||
},
|
||
tool: {
|
||
label: "道具",
|
||
taskClass: "道具图生成",
|
||
dir: "props",
|
||
promptTitle: "标准道具图",
|
||
promptEnd: "标准道具图",
|
||
},
|
||
};
|
||
|
||
function buildPrompt(cfg: AssetTypeConfig, artStyle: string, name: string, prompt: string): string {
|
||
return `
|
||
请根据以下参数生成${cfg.promptTitle}:
|
||
|
||
**基础参数:**
|
||
- 画风风格: ${artStyle || "未指定"}
|
||
|
||
**${cfg.label}设定:**
|
||
- 名称:${name},
|
||
- 提示词:${prompt},
|
||
|
||
请严格按照系统规范生成${cfg.promptEnd}。
|
||
`;
|
||
}
|
||
|
||
const requestSchema = {
|
||
projectId: z.number(),
|
||
model: z.string(),
|
||
resolution: z.string(),
|
||
concurrentCount: z.number().int().min(1).optional(),
|
||
items: z.array(
|
||
z.object({
|
||
id: z.number(),
|
||
type: z.enum(["role", "scene", "tool", "storyboard"]),
|
||
name: z.string(),
|
||
prompt: z.string(),
|
||
base64: z.string().optional().nullable(),
|
||
}),
|
||
),
|
||
};
|
||
|
||
export default router.post("/", validateFields(requestSchema), async (req, res) => {
|
||
const { projectId, model, resolution, concurrentCount, items } = req.body;
|
||
|
||
// 1. 查询项目
|
||
const project = await u.db("o_project").where("id", projectId).select("artStyle", "type", "intro").first();
|
||
if (!project) return res.status(500).send(error("项目为空"));
|
||
|
||
// 2. 逐条插入 o_image 占位记录,收集 imageId 列表
|
||
const totalNovelId: number[] = [];
|
||
for (const item of items) {
|
||
const [imageId] = await u.db("o_image").insert({
|
||
type: item.type,
|
||
state: "生成中",
|
||
assetsId: item.id,
|
||
});
|
||
await u.db("o_assets").where("id", item.id).update({ imageId });
|
||
totalNovelId.push(imageId);
|
||
}
|
||
|
||
// 3. 后台异步并发生成,不阻塞响应
|
||
const limit = pLimit(concurrentCount ?? 1);
|
||
|
||
const tasks = items.map((item: { id: number; type: string; name: string; prompt: string; base64: string | null | undefined }, index: number) =>
|
||
limit(async () => {
|
||
const imageId = totalNovelId[index];
|
||
const cfg = assetTypeConfig[item.type as AssetType];
|
||
if (!cfg) return;
|
||
|
||
await u.db("o_assets").where("id", item.id).update({ imageId });
|
||
|
||
const imagePath = `/${projectId}/${cfg.dir}/${uuidv4()}.jpg`;
|
||
const userPrompt = buildPrompt(cfg, project.artStyle ?? "", item.name, item.prompt);
|
||
const describe = `生成${cfg.label}图,名称:${item.name},提示词:${item.prompt}`;
|
||
const relatedObjects = { id: item.id, projectId, type: cfg.label };
|
||
|
||
try {
|
||
const aiImage = u.Ai.Image(model);
|
||
await aiImage.run({
|
||
prompt: userPrompt,
|
||
imageBase64: item.base64 ? [item.base64] : [],
|
||
size: resolution,
|
||
aspectRatio: "16:9",
|
||
taskClass: cfg.taskClass,
|
||
describe,
|
||
projectId,
|
||
relatedObjects: JSON.stringify(relatedObjects),
|
||
});
|
||
aiImage.save(imagePath);
|
||
|
||
const imageData = await u.db("o_image").where("id", imageId).select("*").first();
|
||
if (!imageData) return;
|
||
|
||
await u
|
||
.db("o_image")
|
||
.where("id", imageId)
|
||
.update({
|
||
state: "已完成",
|
||
filePath: imagePath,
|
||
type: item.type,
|
||
model: model.split(":")[1],
|
||
resolution,
|
||
});
|
||
|
||
await u.db("o_assets").where("id", item.id).update({ imageId });
|
||
} catch (e: any) {
|
||
await u.db("o_image").where("id", imageId).update({ state: "生成失败" });
|
||
}
|
||
}),
|
||
);
|
||
|
||
// 后台执行,不等待结果
|
||
Promise.all(tasks).catch(() => {});
|
||
|
||
return res.status(200).send(success({ total: items.length }));
|
||
});
|