Skip to content

Commit

Permalink
git commit -m "fix(db): add limit param to memory retrieval across ad…
Browse files Browse the repository at this point in the history
…apters

- Add limit parameter to getMemoriesByRoomIds in SQLite adapter
- Add limit parameter to getMemoriesByRoomIds in SQLjs adapter
- Add limit parameter to getMemoriesByRoomIds in PGLite adapter
- Add limit parameter to getMemoriesByRoomIds in Supabase adapter
- Fix query parameter ordering in SQLjs adapter
- Add consistent DESC ordering across all adapters

Closes #2253"
  • Loading branch information
augchan42 committed Jan 13, 2025
1 parent 60bb094 commit 5bcc18b
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 81 deletions.
8 changes: 8 additions & 0 deletions packages/adapter-pglite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ export class PGLiteDatabaseAdapter
roomIds: UUID[];
agentId?: UUID;
tableName: string;
limit?: number;
}): Promise<Memory[]> {
return this.withDatabase(async () => {
if (params.roomIds.length === 0) return [];
Expand All @@ -167,6 +168,13 @@ export class PGLiteDatabaseAdapter
queryParams = [...queryParams, params.agentId];
}

// Add ordering and limit
query += ` ORDER BY "createdAt" DESC`;
if (params.limit) {
query += ` LIMIT $${queryParams.length + 1}`;
queryParams.push(params.limit.toString());
}

const { rows } = await this.query<Memory>(query, queryParams);
return rows.map((row) => ({
...row,
Expand Down
12 changes: 11 additions & 1 deletion packages/adapter-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -154,19 +154,29 @@ export class SqliteDatabaseAdapter
agentId: UUID;
roomIds: UUID[];
tableName: string;
limit?: number;
}): Promise<Memory[]> {
if (!params.tableName) {
// default to messages
params.tableName = "messages";
}

const placeholders = params.roomIds.map(() => "?").join(", ");
const sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId IN (${placeholders})`;
let sql = `SELECT * FROM memories WHERE type = ? AND agentId = ? AND roomId IN (${placeholders})`;

const queryParams = [
params.tableName,
params.agentId,
...params.roomIds,
];

// Add ordering and limit
sql += ` ORDER BY createdAt DESC`;
if (params.limit) {
sql += ` LIMIT ?`;
queryParams.push(params.limit.toString());
}

const stmt = this.db.prepare(sql);
const rows = stmt.all(...queryParams) as (Memory & {
content: string;
Expand Down
67 changes: 46 additions & 21 deletions packages/adapter-sqljs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import {
type Relationship,
type UUID,
RAGKnowledgeItem,
elizaLogger
elizaLogger,
} from "@elizaos/core";
import { v4 } from "uuid";
import { sqliteTables } from "./sqliteTables.ts";
Expand Down Expand Up @@ -81,15 +81,26 @@ export class SqlJsDatabaseAdapter
agentId: UUID;
roomIds: UUID[];
tableName: string;
limit?: number;
}): Promise<Memory[]> {
const placeholders = params.roomIds.map(() => "?").join(", ");
const sql = `SELECT * FROM memories WHERE 'type' = ? AND agentId = ? AND roomId IN (${placeholders})`;
const stmt = this.db.prepare(sql);
let sql = `SELECT * FROM memories WHERE 'type' = ? AND agentId = ? AND roomId IN (${placeholders})`;

const queryParams = [
params.tableName,
params.agentId,
...params.roomIds,
];

// Add ordering and limit
sql += ` ORDER BY createdAt DESC`;
if (params.limit) {
sql += ` LIMIT ?`;
queryParams.push(params.limit.toString());
}

const stmt = this.db.prepare(sql);

elizaLogger.log({ queryParams });
stmt.bind(queryParams);
elizaLogger.log({ queryParams });
Expand Down Expand Up @@ -834,8 +845,10 @@ export class SqlJsDatabaseAdapter
id: row.id,
agentId: row.agentId,
content: JSON.parse(row.content),
embedding: row.embedding ? new Float32Array(row.embedding) : undefined, // Convert Uint8Array back to Float32Array
createdAt: row.createdAt
embedding: row.embedding
? new Float32Array(row.embedding)
: undefined, // Convert Uint8Array back to Float32Array
createdAt: row.createdAt,
});
}
stmt.free();
Expand All @@ -852,7 +865,7 @@ export class SqlJsDatabaseAdapter
const cacheKey = `embedding_${params.agentId}_${params.searchText}`;
const cachedResult = await this.getCache({
key: cacheKey,
agentId: params.agentId
agentId: params.agentId,
});

if (cachedResult) {
Expand Down Expand Up @@ -901,11 +914,11 @@ export class SqlJsDatabaseAdapter
stmt.bind([
new Uint8Array(params.embedding.buffer),
params.agentId,
`%${params.searchText || ''}%`,
`%${params.searchText || ""}%`,
params.agentId,
params.agentId,
params.match_threshold,
params.match_count
params.match_count,
]);

const results: RAGKnowledgeItem[] = [];
Expand All @@ -915,17 +928,19 @@ export class SqlJsDatabaseAdapter
id: row.id,
agentId: row.agentId,
content: JSON.parse(row.content),
embedding: row.embedding ? new Float32Array(row.embedding) : undefined,
embedding: row.embedding
? new Float32Array(row.embedding)
: undefined,
createdAt: row.createdAt,
similarity: row.keyword_score
similarity: row.keyword_score,
});
}
stmt.free();

await this.setCache({
key: cacheKey,
agentId: params.agentId,
value: JSON.stringify(results)
value: JSON.stringify(results),
});

return results;
Expand All @@ -947,31 +962,41 @@ export class SqlJsDatabaseAdapter
knowledge.id,
metadata.isShared ? null : knowledge.agentId,
JSON.stringify(knowledge.content),
knowledge.embedding ? new Uint8Array(knowledge.embedding.buffer) : null,
knowledge.embedding
? new Uint8Array(knowledge.embedding.buffer)
: null,
knowledge.createdAt || Date.now(),
metadata.isMain ? 1 : 0,
metadata.originalId || null,
metadata.chunkIndex || null,
metadata.isShared ? 1 : 0
metadata.isShared ? 1 : 0,
]);
stmt.free();
} catch (error: any) {
const isShared = knowledge.content.metadata?.isShared;
const isPrimaryKeyError = error?.code === 'SQLITE_CONSTRAINT_PRIMARYKEY';
const isPrimaryKeyError =
error?.code === "SQLITE_CONSTRAINT_PRIMARYKEY";

if (isShared && isPrimaryKeyError) {
elizaLogger.info(`Shared knowledge ${knowledge.id} already exists, skipping`);
elizaLogger.info(
`Shared knowledge ${knowledge.id} already exists, skipping`
);
return;
} else if (!isShared && !error.message?.includes('SQLITE_CONSTRAINT_PRIMARYKEY')) {
} else if (
!isShared &&
!error.message?.includes("SQLITE_CONSTRAINT_PRIMARYKEY")
) {
elizaLogger.error(`Error creating knowledge ${knowledge.id}:`, {
error,
embeddingLength: knowledge.embedding?.length,
content: knowledge.content
content: knowledge.content,
});
throw error;
}

elizaLogger.debug(`Knowledge ${knowledge.id} already exists, skipping`);
elizaLogger.debug(
`Knowledge ${knowledge.id} already exists, skipping`
);
}
}

Expand All @@ -983,9 +1008,9 @@ export class SqlJsDatabaseAdapter
}

async clearKnowledge(agentId: UUID, shared?: boolean): Promise<void> {
const sql = shared ?
`DELETE FROM knowledge WHERE ("agentId" = ? OR "isShared" = 1)` :
`DELETE FROM knowledge WHERE "agentId" = ?`;
const sql = shared
? `DELETE FROM knowledge WHERE ("agentId" = ? OR "isShared" = 1)`
: `DELETE FROM knowledge WHERE "agentId" = ?`;

const stmt = this.db.prepare(sql);
stmt.run([agentId]);
Expand Down
Loading

0 comments on commit 5bcc18b

Please sign in to comment.