Skip to content

Commit

Permalink
Merge pull request elizaOS#1379 from ryanleecode/fix/postgres-adapter…
Browse files Browse the repository at this point in the history
…-settings

fix: postgres adapter settings not being applied
  • Loading branch information
monilpat authored Dec 23, 2024
2 parents 0e61d05 + 69929ce commit 9dd14db
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
14 changes: 14 additions & 0 deletions packages/adapter-postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import {
elizaLogger,
getEmbeddingConfig,
DatabaseAdapter,
EmbeddingProvider,
} from "@elizaos/core";
import fs from "fs";
import { fileURLToPath } from "url";
Expand Down Expand Up @@ -189,6 +190,19 @@ export class PostgresDatabaseAdapter
try {
await client.query("BEGIN");

// Set application settings for embedding dimension
const embeddingConfig = getEmbeddingConfig();
if (embeddingConfig.provider === EmbeddingProvider.OpenAI) {
await client.query("SET app.use_openai_embedding = 'true'");
await client.query("SET app.use_ollama_embedding = 'false'");
} else if (embeddingConfig.provider === EmbeddingProvider.Ollama) {
await client.query("SET app.use_openai_embedding = 'false'");
await client.query("SET app.use_ollama_embedding = 'true'");
} else {
await client.query("SET app.use_openai_embedding = 'false'");
await client.query("SET app.use_ollama_embedding = 'false'");
}

// Check if schema already exists (check for a core table)
const { rows } = await client.query(`
SELECT EXISTS (
Expand Down
40 changes: 32 additions & 8 deletions packages/core/src/embedding.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,36 @@ interface EmbeddingOptions {
provider?: string;
}

// Add the embedding configuration
export const getEmbeddingConfig = () => ({
export const EmbeddingProvider = {
OpenAI: "OpenAI",
Ollama: "Ollama",
GaiaNet: "GaiaNet",
BGE: "BGE",
} as const;

export type EmbeddingProvider =
(typeof EmbeddingProvider)[keyof typeof EmbeddingProvider];

export namespace EmbeddingProvider {
export type OpenAI = typeof EmbeddingProvider.OpenAI;
export type Ollama = typeof EmbeddingProvider.Ollama;
export type GaiaNet = typeof EmbeddingProvider.GaiaNet;
export type BGE = typeof EmbeddingProvider.BGE;
}

export type EmbeddingConfig = {
readonly dimensions: number;
readonly model: string;
readonly provider: EmbeddingProvider;
};

export const getEmbeddingConfig = (): EmbeddingConfig => ({
dimensions:
settings.USE_OPENAI_EMBEDDING?.toLowerCase() === "true"
? 1536 // OpenAI
: settings.USE_OLLAMA_EMBEDDING?.toLowerCase() === "true"
? 1024 // Ollama mxbai-embed-large
:settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true"
: settings.USE_GAIANET_EMBEDDING?.toLowerCase() === "true"
? 768 // GaiaNet
: 384, // BGE
model:
Expand Down Expand Up @@ -171,7 +193,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
const isNode = typeof process !== "undefined" && process.versions?.node;

// Determine which embedding path to use
if (config.provider === "OpenAI") {
if (config.provider === EmbeddingProvider.OpenAI) {
return await getRemoteEmbedding(input, {
model: config.model,
endpoint: "https://api.openai.com/v1",
Expand All @@ -180,7 +202,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
});
}

if (config.provider === "Ollama") {
if (config.provider === EmbeddingProvider.Ollama) {
return await getRemoteEmbedding(input, {
model: config.model,
endpoint:
Expand All @@ -191,7 +213,7 @@ export async function embed(runtime: IAgentRuntime, input: string) {
});
}

if (config.provider=="GaiaNet") {
if (config.provider == EmbeddingProvider.GaiaNet) {
return await getRemoteEmbedding(input, {
model: config.model,
endpoint:
Expand Down Expand Up @@ -252,9 +274,11 @@ export async function embed(runtime: IAgentRuntime, input: string) {
return await import("fastembed");
} catch {
elizaLogger.error("Failed to load fastembed.");
throw new Error("fastembed import failed, falling back to remote embedding");
throw new Error(
"fastembed import failed, falling back to remote embedding"
);
}
})()
})(),
]);

const [fs, { fileURLToPath }, fastEmbed] = moduleImports;
Expand Down

0 comments on commit 9dd14db

Please sign in to comment.