Skip to content

Commit

Permalink
Merge pull request #2293 from bbopar/1186-add-get-memories-by-ids
Browse files Browse the repository at this point in the history
feat: add getMemoryByIds to database adapters
  • Loading branch information
wtfsayo authored Jan 15, 2025
2 parents 89b6a19 + e5403ec commit 4d42de6
Show file tree
Hide file tree
Showing 8 changed files with 190 additions and 8 deletions.
27 changes: 27 additions & 0 deletions packages/adapter-pglite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,33 @@ export class PGLiteDatabaseAdapter
}, "getMemoryById");
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
return this.withDatabase(async () => {
if (memoryIds.length === 0) return [];
const placeholders = memoryIds.map((_, i) => `$${i + 1}`).join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
const queryParams: any[] = [...memoryIds];

if (tableName) {
sql += ` AND type = $${memoryIds.length + 1}`;
queryParams.push(tableName);
}

const { rows } = await this.query<Memory>(sql, queryParams);

return rows.map((row) => ({
...row,
content:
typeof row.content === "string"
? JSON.parse(row.content)
: row.content,
}));
}, "getMemoriesByIds");
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
return this.withDatabase(async () => {
elizaLogger.debug("PostgresAdapter createMemory:", {
Expand Down
27 changes: 27 additions & 0 deletions packages/adapter-postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,33 @@ export class PostgresDatabaseAdapter
}, "getMemoryById");
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
return this.withDatabase(async () => {
if (memoryIds.length === 0) return [];
const placeholders = memoryIds.map((_, i) => `$${i + 1}`).join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
const queryParams: any[] = [...memoryIds];

if (tableName) {
sql += ` AND type = $${memoryIds.length + 1}`;
queryParams.push(tableName);
}

const { rows } = await this.pool.query(sql, queryParams);

return rows.map((row) => ({
...row,
content:
typeof row.content === "string"
? JSON.parse(row.content)
: row.content,
}));
}, "getMemoriesByIds");
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
return this.withDatabase(async () => {
elizaLogger.debug("PostgresAdapter createMemory:", {
Expand Down
27 changes: 27 additions & 0 deletions packages/adapter-sqlite/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,33 @@ export class SqliteDatabaseAdapter
return null;
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
if (memoryIds.length === 0) return [];
const queryParams: any[] = [];
const placeholders = memoryIds.map(() => "?").join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
queryParams.push(...memoryIds);

if (tableName) {
sql += ` AND type = ?`;
queryParams.push(tableName);
}

const memories = this.db.prepare(sql).all(...queryParams) as Memory[];

return memories.map((memory) => ({
...memory,
createdAt:
typeof memory.createdAt === "string"
? Date.parse(memory.createdAt as string)
: memory.createdAt,
content: JSON.parse(memory.content as unknown as string),
}));
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
// Delete any existing memory with the same ID first
// const deleteSql = `DELETE FROM memories WHERE id = ? AND type = ?`;
Expand Down
29 changes: 29 additions & 0 deletions packages/adapter-sqljs/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,35 @@ export class SqlJsDatabaseAdapter
return memory || null;
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
if (memoryIds.length === 0) return [];
const placeholders = memoryIds.map(() => "?").join(",");
let sql = `SELECT * FROM memories WHERE id IN (${placeholders})`;
const queryParams: any[] = [...memoryIds];

if (tableName) {
sql += ` AND type = ?`;
queryParams.push(tableName);
}

const stmt = this.db.prepare(sql);
stmt.bind(queryParams);

const memories: Memory[] = [];
while (stmt.step()) {
const memory = stmt.getAsObject() as unknown as Memory;
memories.push({
...memory,
content: JSON.parse(memory.content as unknown as string),
});
}
stmt.free();
return memories;
}

async createMemory(memory: Memory, tableName: string): Promise<void> {
let isUnique = true;
if (memory.embedding) {
Expand Down
25 changes: 25 additions & 0 deletions packages/adapter-supabase/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,31 @@ export class SupabaseDatabaseAdapter extends DatabaseAdapter {
return data as Memory;
}

async getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]> {
if (memoryIds.length === 0) return [];

let query = this.supabase
.from("memories")
.select("*")
.in("id", memoryIds);

if (tableName) {
query = query.eq("type", tableName);
}

const { data, error } = await query;

if (error) {
console.error("Error retrieving memories by IDs:", error);
return [];
}

return data as Memory[];
}

async createMemory(
memory: Memory,
tableName: string,
Expand Down
12 changes: 12 additions & 0 deletions packages/core/__tests__/database.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@ class MockDatabaseAdapter extends DatabaseAdapter {
getMemoryById(_id: UUID): Promise<Memory | null> {
throw new Error("Method not implemented.");
}
async getMemoriesByIds(
memoryIds: UUID[],
_tableName?: string
): Promise<Memory[]> {
return memoryIds.map((id) => ({
id: id,
content: { text: "Test Memory" },
roomId: "room-id" as UUID,
userId: "user-id" as UUID,
agentId: "agent-id" as UUID,
})) as Memory[];
}
log(_params: {
body: { [key: string]: unknown };
userId: UUID;
Expand Down
15 changes: 13 additions & 2 deletions packages/core/src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,17 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {

abstract getMemoryById(id: UUID): Promise<Memory | null>;

/**
* Retrieves multiple memories by their IDs
* @param memoryIds Array of UUIDs of the memories to retrieve
* @param tableName Optional table name to filter memories by type
* @returns Promise resolving to array of Memory objects
*/
abstract getMemoriesByIds(
memoryIds: UUID[],
tableName?: string
): Promise<Memory[]>;

/**
* Retrieves cached embeddings based on the specified query parameters.
* @param params An object containing parameters for the embedding retrieval.
Expand Down Expand Up @@ -382,12 +393,12 @@ export abstract class DatabaseAdapter<DB = any> implements IDatabaseAdapter {
userId: UUID;
}): Promise<Relationship[]>;

/**
/**
* Retrieves knowledge items based on specified parameters.
* @param params Object containing search parameters
* @returns Promise resolving to array of knowledge items
*/
abstract getKnowledge(params: {
abstract getKnowledge(params: {
id?: UUID;
agentId: UUID;
limit?: number;
Expand Down
36 changes: 30 additions & 6 deletions packages/core/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ export enum ModelProviderName {
AKASH_CHAT_API = "akash_chat_api",
LIVEPEER = "livepeer",
LETZAI = "letzai",
DEEPSEEK="deepseek",
INFERA="infera"
DEEPSEEK = "deepseek",
INFERA = "infera",
}

/**
Expand Down Expand Up @@ -909,6 +909,8 @@ export interface IDatabaseAdapter {

getMemoryById(id: UUID): Promise<Memory | null>;

getMemoriesByIds(ids: UUID[], tableName?: string): Promise<Memory[]>;

getMemoriesByRoomIds(params: {
tableName: string;
agentId: UUID;
Expand Down Expand Up @@ -1087,7 +1089,10 @@ export interface IMemoryManager {
): Promise<{ embedding: number[]; levenshtein_score: number }[]>;

getMemoryById(id: UUID): Promise<Memory | null>;
getMemoriesByRoomIds(params: { roomIds: UUID[], limit?: number }): Promise<Memory[]>;
getMemoriesByRoomIds(params: {
roomIds: UUID[];
limit?: number;
}): Promise<Memory[]>;
searchMemoriesByEmbedding(
embedding: number[],
opts: {
Expand Down Expand Up @@ -1378,9 +1383,28 @@ export interface IrysTimestamp {
}

export interface IIrysService extends Service {
getDataFromAnAgent(agentsWalletPublicKeys: string[], tags: GraphQLTag[], timestamp: IrysTimestamp): Promise<DataIrysFetchedFromGQL>;
workerUploadDataOnIrys(data: any, dataType: IrysDataType, messageType: IrysMessageType, serviceCategory: string[], protocol: string[], validationThreshold: number[], minimumProviders: number[], testProvider: boolean[], reputation: number[]): Promise<UploadIrysResult>;
providerUploadDataOnIrys(data: any, dataType: IrysDataType, serviceCategory: string[], protocol: string[]): Promise<UploadIrysResult>;
getDataFromAnAgent(
agentsWalletPublicKeys: string[],
tags: GraphQLTag[],
timestamp: IrysTimestamp
): Promise<DataIrysFetchedFromGQL>;
workerUploadDataOnIrys(
data: any,
dataType: IrysDataType,
messageType: IrysMessageType,
serviceCategory: string[],
protocol: string[],
validationThreshold: number[],
minimumProviders: number[],
testProvider: boolean[],
reputation: number[]
): Promise<UploadIrysResult>;
providerUploadDataOnIrys(
data: any,
dataType: IrysDataType,
serviceCategory: string[],
protocol: string[]
): Promise<UploadIrysResult>;
}

export interface ITeeLogService extends Service {
Expand Down

0 comments on commit 4d42de6

Please sign in to comment.