memory接入setting表
This commit is contained in:
parent
06b6569c8e
commit
a80ef39f66
@ -5,13 +5,22 @@ import type { memories as MemoryRow } from "@/types/database";
|
|||||||
import { tool } from "ai";
|
import { tool } from "ai";
|
||||||
import { z } from "zod";
|
import { z } from "zod";
|
||||||
|
|
||||||
// ── 可调配置 ──
|
// ── 可调配置默认值 ──
|
||||||
const messagesPerSummary = 3; // 每累积多少条message触发一次summary生成
|
const DEFAULTS: {
|
||||||
const summaryMaxLength = 500; // summary最大字符长度
|
messagesPerSummary: number;
|
||||||
const shortTermLimit = 5; // get()返回的近期未总结message条数
|
summaryMaxLength: number;
|
||||||
const summaryLimit = 10; // get()返回的summary条数
|
shortTermLimit: number;
|
||||||
const ragLimit = 3; // get()向量相似搜索返回的message条数
|
summaryLimit: number;
|
||||||
const deepRetrieveSummaryLimit = 5; // deepRetrieve()向量召回summary的条数
|
ragLimit: number;
|
||||||
|
deepRetrieveSummaryLimit: number;
|
||||||
|
} = {
|
||||||
|
messagesPerSummary: 3, // 每累积多少条message触发一次summary生成
|
||||||
|
summaryMaxLength: 500, // summary最大字符长度
|
||||||
|
shortTermLimit: 5, // get()返回的近期未总结message条数
|
||||||
|
summaryLimit: 10, // get()返回的summary条数
|
||||||
|
ragLimit: 3, // get()向量相似搜索返回的message条数
|
||||||
|
deepRetrieveSummaryLimit: 5, // deepRetrieve()向量召回summary的条数
|
||||||
|
};
|
||||||
|
|
||||||
// ── 向量搜索辅助 ──
|
// ── 向量搜索辅助 ──
|
||||||
function vectorSearch(rows: MemoryRow[], queryEmbedding: number[], limit: number) {
|
function vectorSearch(rows: MemoryRow[], queryEmbedding: number[], limit: number) {
|
||||||
@ -34,11 +43,12 @@ class Memory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async generateSummary(contents: string[]): Promise<string> {
|
private async generateSummary(contents: string[]): Promise<string> {
|
||||||
|
const { summaryMaxLength } = await this.getConfigData({ summaryMaxLength: DEFAULTS.summaryMaxLength });
|
||||||
const { text } = await u.Ai.Text(this.agentType as any).invoke({
|
const { text } = await u.Ai.Text(this.agentType as any).invoke({
|
||||||
system: `你是一个记忆压缩助手。请将以下多条记忆内容压缩为一段简洁的摘要,不超过${summaryMaxLength}个字符。只输出摘要内容,不要加任何前缀或解释。`,
|
system: `你是一个记忆压缩助手。请将以下多条记忆内容压缩为一段简洁的摘要,不超过${summaryMaxLength}个字符。只输出摘要内容,不要加任何前缀或解释。`,
|
||||||
messages: [{ role: "user", content: contents.map((c, i) => `${i + 1}. ${c}`).join("\n") }],
|
messages: [{ role: "user", content: contents.map((c, i) => `${i + 1}. ${c}`).join("\n") }],
|
||||||
});
|
});
|
||||||
return text.slice(0, summaryMaxLength);
|
return text.slice(0, Number(summaryMaxLength));
|
||||||
}
|
}
|
||||||
|
|
||||||
private async judgeSummaryRelevance(keyword: string, summaries: { id: string; content: string }[]): Promise<string[]> {
|
private async judgeSummaryRelevance(keyword: string, summaries: { id: string; content: string }[]): Promise<string[]> {
|
||||||
@ -54,8 +64,26 @@ class Memory {
|
|||||||
} catch {}
|
} catch {}
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
private async getConfigData<T extends Record<string, string | number>>(defaults: T): Promise<T> {
|
||||||
|
const keys = Object.keys(defaults) as (keyof T & string)[];
|
||||||
|
const rows = await u.db("o_setting").whereIn("key", keys);
|
||||||
|
|
||||||
async add( role: string = "user",content: string) {
|
const dbMap: Record<string, string | null> = {};
|
||||||
|
for (const row of rows) {
|
||||||
|
if (row.key != null) dbMap[row.key] = row.value ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = { ...defaults };
|
||||||
|
for (const key of keys) {
|
||||||
|
const raw = dbMap[key];
|
||||||
|
if (raw == null) continue; // null / undefined 使用默认值
|
||||||
|
const num = Number(raw);
|
||||||
|
(result as Record<string, string | number>)[key] = Number.isNaN(num) ? raw : num;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
async add(role: string = "user", content: string) {
|
||||||
|
const { messagesPerSummary } = await this.getConfigData({ messagesPerSummary: DEFAULTS.messagesPerSummary });
|
||||||
const id = uuidv4();
|
const id = uuidv4();
|
||||||
const embedding = await getEmbedding(content);
|
const embedding = await getEmbedding(content);
|
||||||
const isolationKey = this.isolationKey;
|
const isolationKey = this.isolationKey;
|
||||||
@ -69,14 +97,14 @@ class Memory {
|
|||||||
embedding: JSON.stringify(embedding),
|
embedding: JSON.stringify(embedding),
|
||||||
relatedMessageIds: null,
|
relatedMessageIds: null,
|
||||||
summarized: 0,
|
summarized: 0,
|
||||||
createTime: Date.now(),
|
createdAt: Date.now(),
|
||||||
} as any);
|
} as any);
|
||||||
|
|
||||||
// 检查未总结消息数量
|
// 检查未总结消息数量
|
||||||
const unsummarized = await u.db("memories").where({ isolationKey, type: "message", summarized: 0 }).orderBy("createTime", "asc");
|
const unsummarized = await u.db("memories").where({ isolationKey, type: "message", summarized: 0 }).orderBy("createdAt", "asc");
|
||||||
|
|
||||||
if (unsummarized.length >= messagesPerSummary) {
|
if (unsummarized.length >= Number(messagesPerSummary)) {
|
||||||
const batch = unsummarized.slice(0, messagesPerSummary);
|
const batch = unsummarized.slice(0, Number(messagesPerSummary));
|
||||||
const batchIds = batch.map((m) => m.id);
|
const batchIds = batch.map((m) => m.id);
|
||||||
const batchContents = batch.map((m) => m.content);
|
const batchContents = batch.map((m) => m.content);
|
||||||
|
|
||||||
@ -92,8 +120,8 @@ class Memory {
|
|||||||
embedding: JSON.stringify(summaryEmbedding),
|
embedding: JSON.stringify(summaryEmbedding),
|
||||||
relatedMessageIds: JSON.stringify(batchIds),
|
relatedMessageIds: JSON.stringify(batchIds),
|
||||||
summarized: 0,
|
summarized: 0,
|
||||||
createTime: Date.now(),
|
createdAt: Date.now(),
|
||||||
});
|
} as any);
|
||||||
|
|
||||||
// 标记已总结
|
// 标记已总结
|
||||||
await u.db("memories").whereIn("id", batchIds).update({ summarized: 1 });
|
await u.db("memories").whereIn("id", batchIds).update({ summarized: 1 });
|
||||||
@ -101,42 +129,50 @@ class Memory {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async get(text: string) {
|
async get(text: string) {
|
||||||
|
const { shortTermLimit, summaryLimit, ragLimit } = await this.getConfigData({
|
||||||
|
shortTermLimit: DEFAULTS.shortTermLimit,
|
||||||
|
summaryLimit: DEFAULTS.summaryLimit,
|
||||||
|
ragLimit: DEFAULTS.ragLimit,
|
||||||
|
});
|
||||||
|
|
||||||
const isolationKey = this.isolationKey;
|
const isolationKey = this.isolationKey;
|
||||||
// shortTerm: 最近未被总结的 messages
|
// shortTerm: 最近未被总结的 messages
|
||||||
const shortTerm = await u
|
const shortTerm = await u
|
||||||
.db("memories")
|
.db("memories")
|
||||||
.where({ isolationKey, type: "message", summarized: 0 })
|
.where({ isolationKey, type: "message", summarized: 0 })
|
||||||
.orderBy("createTime", "desc")
|
.orderBy("createdAt", "desc")
|
||||||
.limit(shortTermLimit);
|
.limit(Number(shortTermLimit));
|
||||||
shortTerm.reverse(); // 最旧在前
|
shortTerm.reverse(); // 最旧在前
|
||||||
|
|
||||||
// summaries: 最近的 summary
|
// summaries: 最近的 summary
|
||||||
const summaries = await u.db("memories").where({ isolationKey, type: "summary" }).orderBy("createTime", "desc").limit(summaryLimit);
|
const summaries = await u.db("memories").where({ isolationKey, type: "summary" }).orderBy("createdAt", "desc").limit(Number(summaryLimit));
|
||||||
summaries.reverse();
|
summaries.reverse();
|
||||||
|
|
||||||
// rag: 向量搜索所有 messages
|
// rag: 向量搜索所有 messages
|
||||||
const queryEmbedding = await getEmbedding(text);
|
const queryEmbedding = await getEmbedding(text);
|
||||||
const allMessages = await u.db("memories").where({ isolationKey, type: "message" });
|
const allMessages = await u.db("memories").where({ isolationKey, type: "message" });
|
||||||
const ragResults = vectorSearch(allMessages, queryEmbedding, ragLimit);
|
const ragResults = vectorSearch(allMessages, queryEmbedding, Number(ragLimit));
|
||||||
|
|
||||||
return {
|
return {
|
||||||
shortTerm: shortTerm.map((m: any) => ({ id: m.id, role: m.role, content: m.content, createTime: m.createTime })),
|
shortTerm: shortTerm.map((m: any) => ({ id: m.id, role: m.role, content: m.content, createdAt: m.createdAt })),
|
||||||
summaries: summaries.map((s) => ({
|
summaries: summaries.map((s) => ({
|
||||||
id: s.id,
|
id: s.id,
|
||||||
content: s.content,
|
content: s.content,
|
||||||
relatedMessageIds: JSON.parse(s.relatedMessageIds || "[]"),
|
relatedMessageIds: JSON.parse(s.relatedMessageIds || "[]"),
|
||||||
createTime: s.createTime,
|
createdAt: (s as any).createdAt,
|
||||||
})),
|
})),
|
||||||
rag: ragResults.map((r) => ({ id: r.id, content: r.content, similarity: r.similarity })),
|
rag: ragResults.map((r) => ({ id: r.id, content: r.content, similarity: r.similarity })),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
async deepRetrieve(keyword: string) {
|
async deepRetrieve(keyword: string) {
|
||||||
|
const { deepRetrieveSummaryLimit } = await this.getConfigData({ deepRetrieveSummaryLimit: DEFAULTS.deepRetrieveSummaryLimit });
|
||||||
|
|
||||||
const isolationKey = this.isolationKey;
|
const isolationKey = this.isolationKey;
|
||||||
// 步骤1: 向量搜索 summary
|
// 步骤1: 向量搜索 summary
|
||||||
const queryEmbedding = await getEmbedding(keyword);
|
const queryEmbedding = await getEmbedding(keyword);
|
||||||
const allSummaries = await u.db("memories").where({ isolationKey, type: "summary" });
|
const allSummaries = await u.db("memories").where({ isolationKey, type: "summary" });
|
||||||
const topSummaries = vectorSearch(allSummaries, queryEmbedding, deepRetrieveSummaryLimit);
|
const topSummaries = vectorSearch(allSummaries, queryEmbedding, Number(deepRetrieveSummaryLimit));
|
||||||
|
|
||||||
if (topSummaries.length === 0) return [];
|
if (topSummaries.length === 0) return [];
|
||||||
|
|
||||||
@ -154,9 +190,9 @@ class Memory {
|
|||||||
|
|
||||||
if (messageIds.length === 0) return [];
|
if (messageIds.length === 0) return [];
|
||||||
|
|
||||||
const messages = await u.db("memories").whereIn("id", messageIds).orderBy("createTime", "asc");
|
const messages = await u.db("memories").whereIn("id", messageIds).orderBy("createdAt", "asc");
|
||||||
|
|
||||||
return messages.map((m) => ({ id: m.id, content: m.content, createTime: m.createTime }));
|
return messages.map((m) => ({ id: m.id, content: m.content, createdAt: m.createdAt }));
|
||||||
}
|
}
|
||||||
|
|
||||||
getTools() {
|
getTools() {
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user