diff --git a/src/utils/agent/memory.ts b/src/utils/agent/memory.ts index e0b6458..94a8081 100644 --- a/src/utils/agent/memory.ts +++ b/src/utils/agent/memory.ts @@ -5,13 +5,22 @@ import type { memories as MemoryRow } from "@/types/database"; import { tool } from "ai"; import { z } from "zod"; -// ── 可调配置 ── -const messagesPerSummary = 3; // 每累积多少条message触发一次summary生成 -const summaryMaxLength = 500; // summary最大字符长度 -const shortTermLimit = 5; // get()返回的近期未总结message条数 -const summaryLimit = 10; // get()返回的summary条数 -const ragLimit = 3; // get()向量相似搜索返回的message条数 -const deepRetrieveSummaryLimit = 5; // deepRetrieve()向量召回summary的条数 +// ── 可调配置默认值 ── +const DEFAULTS: { + messagesPerSummary: number; + summaryMaxLength: number; + shortTermLimit: number; + summaryLimit: number; + ragLimit: number; + deepRetrieveSummaryLimit: number; +} = { + messagesPerSummary: 3, // 每累积多少条message触发一次summary生成 + summaryMaxLength: 500, // summary最大字符长度 + shortTermLimit: 5, // get()返回的近期未总结message条数 + summaryLimit: 10, // get()返回的summary条数 + ragLimit: 3, // get()向量相似搜索返回的message条数 + deepRetrieveSummaryLimit: 5, // deepRetrieve()向量召回summary的条数 +}; // ── 向量搜索辅助 ── function vectorSearch(rows: MemoryRow[], queryEmbedding: number[], limit: number) { @@ -34,11 +43,12 @@ class Memory { } private async generateSummary(contents: string[]): Promise { + const { summaryMaxLength } = await this.getConfigData({ summaryMaxLength: DEFAULTS.summaryMaxLength }); const { text } = await u.Ai.Text(this.agentType as any).invoke({ system: `你是一个记忆压缩助手。请将以下多条记忆内容压缩为一段简洁的摘要,不超过${summaryMaxLength}个字符。只输出摘要内容,不要加任何前缀或解释。`, messages: [{ role: "user", content: contents.map((c, i) => `${i + 1}. ${c}`).join("\n") }], }); - return text.slice(0, summaryMaxLength); + return text.slice(0, Number(summaryMaxLength)); } private async judgeSummaryRelevance(keyword: string, summaries: { id: string; content: string }[]): Promise { @@ -54,8 +64,26 @@ class Memory { } catch {} return []; } + private async getConfigData>(defaults: T): Promise { + const keys = Object.keys(defaults) as (keyof T & string)[]; + const rows = await u.db("o_setting").whereIn("key", keys); - async add( role: string = "user",content: string) { + const dbMap: Record = {}; + for (const row of rows) { + if (row.key != null) dbMap[row.key] = row.value ?? null; + } + + const result = { ...defaults }; + for (const key of keys) { + const raw = dbMap[key]; + if (raw == null) continue; // null / undefined 使用默认值 + const num = Number(raw); + (result as Record)[key] = Number.isNaN(num) ? raw : num; + } + return result; + } + async add(role: string = "user", content: string) { + const { messagesPerSummary } = await this.getConfigData({ messagesPerSummary: DEFAULTS.messagesPerSummary }); const id = uuidv4(); const embedding = await getEmbedding(content); const isolationKey = this.isolationKey; @@ -69,14 +97,14 @@ class Memory { embedding: JSON.stringify(embedding), relatedMessageIds: null, summarized: 0, - createTime: Date.now(), + createdAt: Date.now(), } as any); // 检查未总结消息数量 - const unsummarized = await u.db("memories").where({ isolationKey, type: "message", summarized: 0 }).orderBy("createTime", "asc"); + const unsummarized = await u.db("memories").where({ isolationKey, type: "message", summarized: 0 }).orderBy("createdAt", "asc"); - if (unsummarized.length >= messagesPerSummary) { - const batch = unsummarized.slice(0, messagesPerSummary); + if (unsummarized.length >= Number(messagesPerSummary)) { + const batch = unsummarized.slice(0, Number(messagesPerSummary)); const batchIds = batch.map((m) => m.id); const batchContents = batch.map((m) => m.content); @@ -92,8 +120,8 @@ class Memory { embedding: JSON.stringify(summaryEmbedding), relatedMessageIds: JSON.stringify(batchIds), summarized: 0, - createTime: Date.now(), - }); + createdAt: Date.now(), + } as any); // 标记已总结 await u.db("memories").whereIn("id", batchIds).update({ summarized: 1 }); @@ -101,42 +129,50 @@ class Memory { } async get(text: string) { + const { shortTermLimit, summaryLimit, ragLimit } = await this.getConfigData({ + shortTermLimit: DEFAULTS.shortTermLimit, + summaryLimit: DEFAULTS.summaryLimit, + ragLimit: DEFAULTS.ragLimit, + }); + const isolationKey = this.isolationKey; // shortTerm: 最近未被总结的 messages const shortTerm = await u .db("memories") .where({ isolationKey, type: "message", summarized: 0 }) - .orderBy("createTime", "desc") - .limit(shortTermLimit); + .orderBy("createdAt", "desc") + .limit(Number(shortTermLimit)); shortTerm.reverse(); // 最旧在前 // summaries: 最近的 summary - const summaries = await u.db("memories").where({ isolationKey, type: "summary" }).orderBy("createTime", "desc").limit(summaryLimit); + const summaries = await u.db("memories").where({ isolationKey, type: "summary" }).orderBy("createdAt", "desc").limit(Number(summaryLimit)); summaries.reverse(); // rag: 向量搜索所有 messages const queryEmbedding = await getEmbedding(text); const allMessages = await u.db("memories").where({ isolationKey, type: "message" }); - const ragResults = vectorSearch(allMessages, queryEmbedding, ragLimit); + const ragResults = vectorSearch(allMessages, queryEmbedding, Number(ragLimit)); return { - shortTerm: shortTerm.map((m: any) => ({ id: m.id, role: m.role, content: m.content, createTime: m.createTime })), + shortTerm: shortTerm.map((m: any) => ({ id: m.id, role: m.role, content: m.content, createdAt: m.createdAt })), summaries: summaries.map((s) => ({ id: s.id, content: s.content, relatedMessageIds: JSON.parse(s.relatedMessageIds || "[]"), - createTime: s.createTime, + createdAt: (s as any).createdAt, })), rag: ragResults.map((r) => ({ id: r.id, content: r.content, similarity: r.similarity })), }; } async deepRetrieve(keyword: string) { + const { deepRetrieveSummaryLimit } = await this.getConfigData({ deepRetrieveSummaryLimit: DEFAULTS.deepRetrieveSummaryLimit }); + const isolationKey = this.isolationKey; // 步骤1: 向量搜索 summary const queryEmbedding = await getEmbedding(keyword); const allSummaries = await u.db("memories").where({ isolationKey, type: "summary" }); - const topSummaries = vectorSearch(allSummaries, queryEmbedding, deepRetrieveSummaryLimit); + const topSummaries = vectorSearch(allSummaries, queryEmbedding, Number(deepRetrieveSummaryLimit)); if (topSummaries.length === 0) return []; @@ -154,9 +190,9 @@ class Memory { if (messageIds.length === 0) return []; - const messages = await u.db("memories").whereIn("id", messageIds).orderBy("createTime", "asc"); + const messages = await u.db("memories").whereIn("id", messageIds).orderBy("createdAt", "asc"); - return messages.map((m) => ({ id: m.id, content: m.content, createTime: m.createTime })); + return messages.map((m) => ({ id: m.id, content: m.content, createdAt: m.createdAt })); } getTools() {