import express from "express"; import pLimit from "p-limit"; import u from "@/utils"; import { z } from "zod"; import { v4 as uuidv4 } from "uuid"; import { error, success } from "@/lib/responseFormat"; import { validateFields } from "@/middleware/middleware"; const router = express.Router(); type AssetType = "role" | "scene" | "tool"; interface AssetTypeConfig { label: string; taskClass: string; dir: string; promptTitle: string; promptEnd: string; } const assetTypeConfig: Record = { role: { label: "角色", taskClass: "角色图生成", dir: "role", promptTitle: "角色标准四视图", promptEnd: "人物角色四视图", }, scene: { label: "场景", taskClass: "场景图生成", dir: "scene", promptTitle: "标准场景图", promptEnd: "标准场景图", }, tool: { label: "道具", taskClass: "道具图生成", dir: "props", promptTitle: "标准道具图", promptEnd: "标准道具图", }, }; function buildPrompt(cfg: AssetTypeConfig, artStyle: string, name: string, prompt: string): string { return ` 请根据以下参数生成${cfg.promptTitle}: **基础参数:** - 画风风格: ${artStyle || "未指定"} **${cfg.label}设定:** - 名称:${name}, - 提示词:${prompt}, 请严格按照系统规范生成${cfg.promptEnd}。 `; } const requestSchema = { projectId: z.number(), model: z.string(), resolution: z.string(), concurrentCount: z.number().int().min(1).optional(), items: z.array( z.object({ id: z.number(), type: z.enum(["role", "scene", "tool", "storyboard"]), name: z.string(), prompt: z.string(), base64: z.string().optional().nullable(), }), ), }; export default router.post("/", validateFields(requestSchema), async (req, res) => { const { projectId, model, resolution, concurrentCount, items } = req.body; // 1. 查询项目 const project = await u.db("o_project").where("id", projectId).select("artStyle", "type", "intro").first(); if (!project) return res.status(500).send(error("项目为空")); // 2. 逐条插入 o_image 占位记录,收集 imageId 列表 const totalNovelId: number[] = []; for (const item of items) { const [imageId] = await u.db("o_image").insert({ type: item.type, state: "生成中", assetsId: item.id, }); await u.db("o_assets").where("id", item.id).update({ imageId }); totalNovelId.push(imageId); } // 3. 后台异步并发生成,不阻塞响应 const limit = pLimit(concurrentCount ?? 1); const tasks = items.map((item: { id: number; type: string; name: string; prompt: string; base64: string | null | undefined }, index: number) => limit(async () => { const imageId = totalNovelId[index]; const cfg = assetTypeConfig[item.type as AssetType]; if (!cfg) return; await u.db("o_assets").where("id", item.id).update({ imageId }); const imagePath = `/${projectId}/${cfg.dir}/${uuidv4()}.jpg`; const userPrompt = buildPrompt(cfg, project.artStyle ?? "", item.name, item.prompt); const describe = `生成${cfg.label}图,名称:${item.name},提示词:${item.prompt}`; const relatedObjects = { id: item.id, projectId, type: cfg.label }; try { const aiImage = u.Ai.Image(model); await aiImage.run({ prompt: userPrompt, imageBase64: item.base64 ? [item.base64] : [], size: resolution, aspectRatio: "16:9", taskClass: cfg.taskClass, describe, projectId, relatedObjects: JSON.stringify(relatedObjects), }); aiImage.save(imagePath); const imageData = await u.db("o_image").where("id", imageId).select("*").first(); if (!imageData) return; await u .db("o_image") .where("id", imageId) .update({ state: "已完成", filePath: imagePath, type: item.type, model: model.split(":")[1], resolution, }); await u.db("o_assets").where("id", item.id).update({ imageId }); } catch (e: any) { await u.db("o_image").where("id", imageId).update({ state: "生成失败" }); } }), ); // 后台执行,不等待结果 Promise.all(tasks).catch(() => {}); return res.status(200).send(success({ total: items.length })); });