This commit is contained in:
ACT丶流星雨 2026-03-19 20:13:09 +08:00
commit 5cbf95ac62
4 changed files with 86 additions and 27 deletions

View File

@ -140,16 +140,36 @@ export default async (knex: Knex, forceInit: boolean = false): Promise<void> =>
value: uuid().slice(0, 8), value: uuid().slice(0, 8),
}, },
{ {
key: "shortTermMemoryLength", key: "messagesPerSummary",
value: 10,
},
{
key: "searchTopK",
value: 3, value: 3,
}, },
{ {
key: "similarityThreshold", key: "shortTermLimit",
value: 0.3, value: 5,
},
{
key: "summaryMaxLength",
value: 500,
},
{
key: "summaryLimit",
value: 10,
},
{
key: "ragLimit",
value: 3,
},
{
key: "deepRetrieveSummaryLimit",
value: 5,
},
{
key: "modelOnnxFile",
value: '["all-MiniLM-L6-v2", "onnx", "model_fp16.onnx"]',
},
{
key: "modelDtype",
value: "fp16",
}, },
]); ]);
}, },

View File

@ -4,14 +4,31 @@ import u from "@/utils";
const router = express.Router(); const router = express.Router();
export default router.get("/", async (req, res) => { export default router.get("/", async (req, res) => {
const settingData = await u.db("o_setting").whereIn("key", ["shortTermMemoryLength", "searchTopK", "similarityThreshold"]); const settingData = await u
.db("o_setting")
.whereIn("key", [
"messagesPerSummary",
"shortTermLimit",
"summaryMaxLength",
"summaryLimit",
"ragLimit",
"deepRetrieveSummaryLimit",
"modelOnnxFile",
"modelDtype",
]);
if (!settingData) return res.status(400).send(error(`获取记忆配置失败`)); if (!settingData) return res.status(400).send(error(`获取记忆配置失败`));
const memoryObj: Record<string, number> = {}; const memoryObj: Record<string, number | string | string[]> = {};
settingData.forEach((i) => { settingData.forEach((i) => {
if (i.key && i.value) { if (i.key && i.value) {
memoryObj[i.key] = Number(i.value); let value: number | string | string[] = i.value;
if (i.key == "modelOnnxFile") {
value = JSON.parse(i.value);
} else if (i.key != "modelDtype") {
value = Number(value);
}
memoryObj[i.key] = value;
} }
}); });

View File

@ -9,21 +9,37 @@ const router = express.Router();
export default router.post( export default router.post(
"/", "/",
validateFields({ validateFields({
shortTermMemoryLength: z.number(), //短期记忆长度 messagesPerSummary: z.number(),
searchTopK: z.number(), //搜索记忆条数 shortTermLimit: z.number(),
similarityThreshold: z.number(), //记忆相似度阈值 summaryMaxLength: z.number(),
summaryLimit: z.number(),
ragLimit: z.number(),
deepRetrieveSummaryLimit: z.number(),
modelOnnxFile: z.array(z.string()),
modelDtype: z.string(),
}), }),
async (req, res) => { async (req, res) => {
const { shortTermMemoryLength, searchTopK, similarityThreshold } = req.body; const { messagesPerSummary, shortTermLimit, summaryMaxLength, summaryLimit, ragLimit, deepRetrieveSummaryLimit, modelOnnxFile, modelDtype } =
await u.db("o_setting").where("key", "shortTermMemoryLength").update({ req.body;
value: shortTermMemoryLength,
}); const upsert = async (key: string, value: string) => {
await u.db("o_setting").where("key", "searchTopK").update({ const exists = await u.db("o_setting").where("key", key).first();
value: searchTopK, if (exists) {
}); await u.db("o_setting").where("key", key).update({ value });
await u.db("o_setting").where("key", "similarityThreshold").update({ } else {
value: similarityThreshold, await u.db("o_setting").insert({ key, value });
}); }
};
await upsert("messagesPerSummary", messagesPerSummary);
await upsert("shortTermLimit", shortTermLimit);
await upsert("summaryMaxLength", summaryMaxLength);
await upsert("summaryLimit", summaryLimit);
await upsert("ragLimit", ragLimit);
await upsert("deepRetrieveSummaryLimit", deepRetrieveSummaryLimit);
await upsert("modelOnnxFile", JSON.stringify(modelOnnxFile));
await upsert("modelDtype", modelDtype);
res.status(200).send(success("保存设置成功")); res.status(200).send(success("保存设置成功"));
}, },
); );

View File

@ -2,18 +2,24 @@ import { pipeline, env as transformersEnv, FeatureExtractionPipeline } from "@hu
import path from "path"; import path from "path";
import fs from "fs"; import fs from "fs";
import getPath from "@/utils/getPath"; import getPath from "@/utils/getPath";
import db from "@/utils/db";
// ── 模型配置 ── // ── 模型配置 ──
const modelOnnxFile = ["all-MiniLM-L6-v2", "onnx", "model_fp16.onnx"]; // 模型文件路径 // const modelOnnxFile = ["all-MiniLM-L6-v2", "onnx", "model_fp16.onnx"]; // 模型文件路径
const modelDtype = "fp16" as const; // 量化类型fp32 // const modelDtype = "fp16" as const; // 量化类型fp32
let extractor: FeatureExtractionPipeline | null = null; let extractor: FeatureExtractionPipeline | null = null;
export async function initEmbedding(): Promise<void> { export async function initEmbedding(): Promise<void> {
if (extractor) return; if (extractor) return;
//todo 模型配置放到这里 //todo 模型配置放到这里
const modelConfigData = await db("o_setting").whereIn("key", ["modelOnnxFile", "modelDtype"]);
const modelObj: Record<string, string> = {};
Object.entries(modelConfigData).forEach(([key, value]) => {
modelObj[key] = value as string;
});
let modelOnnxFile = modelObj?.modelOnnxFile ? JSON.parse(modelObj.modelOnnxFile) : ["all-MiniLM-L6-v2", "onnx", "model_fp16.onnx"]; // 模型文件路径
let modelDtype = modelObj?.modelDtype ?? ("fp16" as const); // 量化类型fp32
const onnxPath = path.join(getPath("models"), ...modelOnnxFile); const onnxPath = path.join(getPath("models"), ...modelOnnxFile);
if (!fs.existsSync(onnxPath)) { if (!fs.existsSync(onnxPath)) {
throw new Error(`Embedding 模型文件不存在: ${onnxPath}`); throw new Error(`Embedding 模型文件不存在: ${onnxPath}`);