From 459ad025e7e777aaee41bcb5533f272fb687c3d4 Mon Sep 17 00:00:00 2001 From: Ryan Date: Mon, 23 Dec 2024 02:39:58 +0700 Subject: [PATCH] fix: postgres adapter settings not being applied --- packages/adapter-postgres/src/index.ts | 14 +++++++++ packages/core/src/embedding.ts | 40 ++++++++++++++++++++------ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/packages/adapter-postgres/src/index.ts b/packages/adapter-postgres/src/index.ts index b65addda9b4..f1942b9fef9 100644 --- a/packages/adapter-postgres/src/index.ts +++ b/packages/adapter-postgres/src/index.ts @@ -23,6 +23,7 @@ import { elizaLogger, getEmbeddingConfig, DatabaseAdapter, + EmbeddingProvider, } from "@elizaos/core"; import fs from "fs"; import { fileURLToPath } from "url"; @@ -189,6 +190,19 @@ export class PostgresDatabaseAdapter try { await client.query("BEGIN"); + // Set application settings for embedding dimension + const embeddingConfig = getEmbeddingConfig(); + if (embeddingConfig.provider === EmbeddingProvider.OpenAI) { + await client.query("SET app.use_openai_embedding = 'true'"); + await client.query("SET app.use_ollama_embedding = 'false'"); + } else if (embeddingConfig.provider === EmbeddingProvider.Ollama) { + await client.query("SET app.use_openai_embedding = 'false'"); + await client.query("SET app.use_ollama_embedding = 'true'"); + } else { + await client.query("SET app.use_openai_embedding = 'false'"); + await client.query("SET app.use_ollama_embedding = 'false'"); + } + // Check if schema already exists (check for a core table) const { rows } = await client.query(` SELECT EXISTS ( diff --git a/packages/core/src/embedding.ts b/packages/core/src/embedding.ts index 49c1a4163c2..767b6b5673b 100644 --- a/packages/core/src/embedding.ts +++ b/packages/core/src/embedding.ts @@ -14,14 +14,36 @@ interface EmbeddingOptions { provider?: string; } -// Add the embedding configuration -export const getEmbeddingConfig = () => ({ +export const EmbeddingProvider = { + OpenAI: "OpenAI", + Ollama: "Ollama", + GaiaNet: "GaiaNet", + BGE: "BGE", +} as const; + +export type EmbeddingProvider = + (typeof EmbeddingProvider)[keyof typeof EmbeddingProvider]; + +export namespace EmbeddingProvider { + export type OpenAI = typeof EmbeddingProvider.OpenAI; + export type Ollama = typeof EmbeddingProvider.Ollama; + export type GaiaNet = typeof EmbeddingProvider.GaiaNet; + export type BGE = typeof EmbeddingProvider.BGE; +} + +export type EmbeddingConfig = { + readonly dimensions: number; + readonly model: string; + readonly provider: EmbeddingProvider; +}; + +export const getEmbeddingConfig = (): EmbeddingConfig => ({ dimensions: settings.USE_OPENAI_EMBEDDING?.toLowerCase() === "true" ? 1536 // OpenAI : settings.USE_OLLAMA_EMBEDDING?.toLowerCase() === "true" ? 1024 // Ollama mxbai-embed-large - :settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true" + : settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true" ? 768 // GaiaNet : 384, // BGE model: @@ -171,7 +193,7 @@ export async function embed(runtime: IAgentRuntime, input: string) { const isNode = typeof process !== "undefined" && process.versions?.node; // Determine which embedding path to use - if (config.provider === "OpenAI") { + if (config.provider === EmbeddingProvider.OpenAI) { return await getRemoteEmbedding(input, { model: config.model, endpoint: "https://api.openai.com/v1", @@ -180,7 +202,7 @@ export async function embed(runtime: IAgentRuntime, input: string) { }); } - if (config.provider === "Ollama") { + if (config.provider === EmbeddingProvider.Ollama) { return await getRemoteEmbedding(input, { model: config.model, endpoint: @@ -191,7 +213,7 @@ export async function embed(runtime: IAgentRuntime, input: string) { }); } - if (config.provider=="GaiaNet") { + if (config.provider == EmbeddingProvider.GaiaNet) { return await getRemoteEmbedding(input, { model: config.model, endpoint: @@ -252,9 +274,11 @@ export async function embed(runtime: IAgentRuntime, input: string) { return await import("fastembed"); } catch { elizaLogger.error("Failed to load fastembed."); - throw new Error("fastembed import failed, falling back to remote embedding"); + throw new Error( + "fastembed import failed, falling back to remote embedding" + ); } - })() + })(), ]); const [fs, { fileURLToPath }, fastEmbed] = moduleImports;