forked from run-llama/LlamaIndexTS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGeminiEmbedding.ts
39 lines (34 loc) · 1.18 KB
/
GeminiEmbedding.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import { GeminiSession, GeminiSessionStore } from "../llm/gemini/base.js";
import { GEMINI_BACKENDS } from "../llm/gemini/types.js";
/**
 * Embedding model identifiers selectable for `GeminiEmbedding`.
 * Each member's value is the model name string handed to the session
 * when a model client is requested.
 */
export enum GEMINI_EMBEDDING_MODEL {
  /** Model name `"embedding-001"`. */
  EMBEDDING_001 = "embedding-001",
  /** Model name `"text-embedding-004"`. */
  TEXT_EMBEDDING_004 = "text-embedding-004",
}
/**
 * Embedding provider that produces vectors through a Gemini session while
 * exposing the BaseEmbedding interface.
 *
 * Note: Vertex SDK currently does not support embeddings. When no session is
 * supplied, one is requested from the session store for the Google backend.
 */
export class GeminiEmbedding extends BaseEmbedding {
  /** Embedding model name passed to the session when requesting a model client. */
  model: GEMINI_EMBEDDING_MODEL;
  /** Session from which the generative model client is obtained. */
  session: GeminiSession;

  /**
   * @param init - Optional overrides. `model` defaults to
   *   `GEMINI_EMBEDDING_MODEL.EMBEDDING_001`; `session` defaults to the
   *   session returned by `GeminiSessionStore` for the Google backend.
   */
  constructor(init?: Partial<GeminiEmbedding>) {
    super();

    const requestedModel = init?.model;
    this.model = requestedModel ?? GEMINI_EMBEDDING_MODEL.EMBEDDING_001;

    const providedSession = init?.session;
    if (providedSession != null) {
      this.session = providedSession;
    } else {
      // Only consult the session store when the caller did not supply a session.
      this.session = GeminiSessionStore.get({
        backend: GEMINI_BACKENDS.GOOGLE,
      }) as GeminiSession;
    }
  }

  /**
   * Requests an embedding for a single prompt from the configured model.
   *
   * @param prompt - Text to embed.
   * @returns The `values` array of the embedding in the response.
   */
  private async getEmbedding(prompt: string): Promise<number[]> {
    const { embedding } = await this.session
      .getGenerativeModel({ model: this.model })
      .embedContent(prompt);
    return embedding.values;
  }

  /**
   * Embeds one piece of text.
   *
   * @param text - Text to embed.
   * @returns A promise resolving to the embedding values for `text`.
   */
  getTextEmbedding(text: string): Promise<number[]> {
    return this.getEmbedding(text);
  }
}