Skip to content

Commit

Permalink
Merge changes for embedding caching, change dim size to 1536
Browse files Browse the repository at this point in the history
  • Loading branch information
lalalune committed Mar 11, 2024
1 parent 3e1b6ea commit 9af6602
Show file tree
Hide file tree
Showing 14 changed files with 145 additions and 33 deletions.
16 changes: 16 additions & 0 deletions docs/docs/classes/BgentRuntime.md
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,19 @@ Register an evaluator to assess and guide the agent's responses.
#### Returns

`void`

___

### retriveCachedEmbedding

**retriveCachedEmbedding**(`input`): `Promise`\<``null`` \| `number`[]\>

#### Parameters

| Name | Type |
| :------ | :------ |
| `input` | `string` |

#### Returns

`Promise`\<``null`` \| `number`[]\>
22 changes: 22 additions & 0 deletions docs/docs/classes/DatabaseAdapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,28 @@ ___

___

### getMemoryByContent

**getMemoryByContent**(`«destructured»`): `Promise`\<`SimilaritySearch`[]\>

#### Parameters

| Name | Type |
| :------ | :------ |
| `«destructured»` | `Object` |
| › `query_field_name` | `string` |
| › `query_field_sub_name` | `string` |
| › `query_input` | `string` |
| › `query_match_count` | `number` |
| › `query_table_name` | `string` |
| › `query_threshold` | `number` |

#### Returns

`Promise`\<`SimilaritySearch`[]\>

___

### getRelationship

**getRelationship**(`params`): `Promise`\<``null`` \| [`Relationship`](../interfaces/Relationship.md)\>
Expand Down
16 changes: 16 additions & 0 deletions docs/docs/classes/MemoryManager.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,22 @@ A Promise resolving to an array of Memory objects.

___

### getMemoryByContent

**getMemoryByContent**(`content`): `Promise`\<`SimilaritySearch`[]\>

#### Parameters

| Name | Type |
| :------ | :------ |
| `content` | `string` |

#### Returns

`Promise`\<`SimilaritySearch`[]\>

___

### removeAllMemoriesByUserIds

**removeAllMemoriesByUserIds**(`userIds`): `Promise`\<`void`\>
Expand Down
26 changes: 26 additions & 0 deletions docs/docs/classes/SupabaseDatabaseAdapter.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,32 @@ ___

___

### getMemoryByContent

**getMemoryByContent**(`opts`): `Promise`\<`SimilaritySearch`[]\>

#### Parameters

| Name | Type |
| :------ | :------ |
| `opts` | `Object` |
| `opts.query_field_name` | `string` |
| `opts.query_field_sub_name` | `string` |
| `opts.query_input` | `string` |
| `opts.query_match_count` | `number` |
| `opts.query_table_name` | `string` |
| `opts.query_threshold` | `number` |

#### Returns

`Promise`\<`SimilaritySearch`[]\>

#### Overrides

[DatabaseAdapter](DatabaseAdapter.md).[getMemoryByContent](DatabaseAdapter.md#getmemorybycontent)

___

### getRelationship

**getRelationship**(`params`): `Promise`\<``null`` \| [`Relationship`](../interfaces/Relationship.md)\>
Expand Down
2 changes: 1 addition & 1 deletion docs/docs/variables/embeddingDimension.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ sidebar_position: 0
custom_edit_url: null
---

`Const` **embeddingDimension**: ``3072``
`Const` **embeddingDimension**: ``1536``
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "bgent",
"version": "0.0.45",
"version": "0.0.46",
"private": false,
"description": "bgent. because agent was taken.",
"type": "module",
Expand Down
2 changes: 1 addition & 1 deletion scripts/shell.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ async function startApplication() {
Authorization: 'Bearer ' + session.access_token
},
body: JSON.stringify({
content,
content: { content, action: "WAIT"},
agentId: agentUUID,
room_id
})
Expand Down
11 changes: 2 additions & 9 deletions src/agents/simple/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ const onMessage = async (
logger.warn("Sender content null, skipping");
return;
}

const data = (await handleMessage(runtime, message, state)) as Content;
return data;
};
Expand Down Expand Up @@ -183,14 +183,7 @@ const routes: Route[] = [
}

// parse the body from the request
const message = await req.json() as {content: any} & Message
// Validate the message content
if (message.hasOwnProperty("content") === false) {
return new Response("content is required", { status: 400 });
} else {
const content = message.content
message.content = {content: content, action: "null"} as Content
}
const message = (await req.json()) as Message;

const databaseAdapter = new SupabaseDatabaseAdapter(
env.SUPABASE_URL,
Expand Down
2 changes: 1 addition & 1 deletion src/lib/__tests__/memory.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ describe("Memory", () => {
content: { content: dissimilarMemoryContent },
user_ids: [user?.id as UUID, zeroUuid],
room_id: room_id as UUID,
embedding: getCachedEmbedding(dissimilarMemoryContent),
embedding,
});
if (!embedding) {
writeCachedEmbedding(
Expand Down
21 changes: 21 additions & 0 deletions src/lib/adapters/supabase.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
Actor,
GoalStatus,
Account,
SimilaritySearch,
} from "../types";
import { DatabaseAdapter } from "../database";

Expand Down Expand Up @@ -62,6 +63,11 @@ export class SupabaseDatabaseAdapter extends DatabaseAdapter {
match_count: number;
unique: boolean;
}): Promise<Memory[]> {
console.log(
"searching memories",
params.tableName,
params.embedding.length,
);
const result = await this.supabase.rpc("search_memories", {
query_table_name: params.tableName,
query_user_ids: params.userIds,
Expand All @@ -76,6 +82,21 @@ export class SupabaseDatabaseAdapter extends DatabaseAdapter {
return result.data;
}

async getMemoryByContent(opts: {
query_table_name: string;
query_threshold: number;
query_input: string;
query_field_name: string;
query_field_sub_name: string;
query_match_count: number;
}): Promise<SimilaritySearch[]> {
const result = await this.supabase.rpc("get_embedding_list", opts);
if (result.error) {
throw new Error(JSON.stringify(result.error));
}
return result.data;
}

async updateGoalStatus(params: {
goalId: UUID;
status: GoalStatus;
Expand Down
17 changes: 17 additions & 0 deletions src/lib/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
Actor,
GoalStatus,
Account,
SimilaritySearch,
} from "./types";

export abstract class DatabaseAdapter {
Expand All @@ -19,6 +20,22 @@ export abstract class DatabaseAdapter {
tableName: string;
}): Promise<Memory[]>;

abstract getMemoryByContent({
query_table_name,
query_threshold,
query_input,
query_field_name,
query_field_sub_name,
query_match_count,
}: {
query_table_name: string;
query_threshold: number;
query_input: string;
query_field_name: string;
query_field_sub_name: string;
query_match_count: number;
}): Promise<SimilaritySearch[]>;

abstract log(params: {
body: { [key: string]: unknown };
user_id: UUID;
Expand Down
33 changes: 16 additions & 17 deletions src/lib/memory.ts
Original file line number Diff line number Diff line change
Expand Up @@ -84,23 +84,15 @@ export class MemoryManager {
}

async getMemoryByContent(content: string): Promise<SimilaritySearch[]> {
const opts = {
query_table_name: this.tableName,
query_threshold: 2,
query_input: content,
query_field_name: 'content',
query_field_sub_name: 'content',
query_match_count: 10,
};

if (!this.runtime || 'undefined' === typeof this.runtime.supabase) {
return [];
}
const result = await this.runtime.supabase.rpc("get_embedding_list", opts);
if (result.error) {
throw new Error(JSON.stringify(result.error));
}
return result.data;
const result = await this.runtime.databaseAdapter.getMemoryByContent({
query_table_name: this.tableName,
query_threshold: 2,
query_input: content,
query_field_name: "content",
query_field_sub_name: "content",
query_match_count: 10,
});
return result;
}

/**
Expand Down Expand Up @@ -129,6 +121,11 @@ export class MemoryManager {
unique,
} = opts;

console.log("embedding length to search is", embedding.length);

console.log("opts are", opts);
console.log(opts);

const result = await this.runtime.databaseAdapter.searchMemories({
tableName: this.tableName,
userIds: userIds,
Expand All @@ -138,6 +135,8 @@ export class MemoryManager {
unique: !!unique,
});

console.log("result.embedding.length", result[0]?.embedding?.length);

return result;
}

Expand Down
6 changes: 5 additions & 1 deletion src/lib/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ export class BgentRuntime {
body: JSON.stringify({
input,
model: embeddingModel,
length: 1536,
}),
};
try {
Expand All @@ -319,6 +320,8 @@ export class BgentRuntime {

const data: OpenAIEmbeddingResponse = await response.json();

console.log("*** EMBEDDING LENGTH IS", data?.data?.[0].embedding.length);

return data?.data?.[0].embedding;
} catch (e) {
console.error(e);
Expand All @@ -327,7 +330,8 @@ export class BgentRuntime {
}

async retriveCachedEmbedding(input: string) {
const similaritySearchResult = await this.messageManager.getMemoryByContent(input);
const similaritySearchResult =
await this.messageManager.getMemoryByContent(input);
if (similaritySearchResult.length > 0) {
return similaritySearchResult[0].embedding;
}
Expand Down
2 changes: 0 additions & 2 deletions src/lib/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ export interface Memory {
room_id: UUID; // The room or conversation ID associated with the memory.
}


/**
* Represents a similarity search result, including the embedding vector and the Levenshtein score for a given search query.
*/
Expand All @@ -58,7 +57,6 @@ export interface SimilaritySearch {
levenshtein_score: number;
}


/**
* Represents an objective within a goal, detailing what needs to be achieved and whether it has been completed.
*/
Expand Down

0 comments on commit 9af6602

Please sign in to comment.