-
Notifications
You must be signed in to change notification settings - Fork 202
/
Copy pathusing_local_embedding_model.py
38 lines (28 loc) · 1.13 KB
/
using_local_embedding_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import sys
sys.path.append("..")
import logging
import numpy as np
from nano_graphrag import GraphRAG, QueryParam
from nano_graphrag._utils import wrap_embedding_func_with_attrs
from sentence_transformers import SentenceTransformer
# Directory reused for the downloaded model files (cache_folder below) and
# passed to GraphRAG as its working_dir later in this script.
WORKING_DIR = "./nano_graphrag_cache_local_embedding_TEST"

# Keep the root logger at WARNING, but let the "nano-graphrag" logger emit
# INFO-level messages.
logging.basicConfig(level=logging.WARNING)
logging.getLogger("nano-graphrag").setLevel(logging.INFO)

# Local embedding model, loaded once at import time and pinned to the CPU.
EMBED_MODEL = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    cache_folder=WORKING_DIR,
    device="cpu",
)
# We're using Sentence Transformers to generate embeddings locally with the all-MiniLM-L6-v2 model
@wrap_embedding_func_with_attrs(
    embedding_dim=EMBED_MODEL.get_sentence_embedding_dimension(),
    max_token_size=EMBED_MODEL.max_seq_length,
)
async def local_embedding(texts: list[str]) -> np.ndarray:
    """Embed ``texts`` with the module-level ``EMBED_MODEL``.

    The decorator is given the model's sentence-embedding dimension and its
    maximum sequence length as ``embedding_dim`` and ``max_token_size``.

    Args:
        texts: Strings to embed, passed to ``EMBED_MODEL.encode`` as-is.

    Returns:
        The result of ``EMBED_MODEL.encode`` called with
        ``normalize_embeddings=True`` (annotated as ``np.ndarray``).
    """
    # NOTE(review): encode() is called without await, so it presumably runs
    # synchronously inside the coroutine -- confirm that is acceptable here.
    embeddings = EMBED_MODEL.encode(texts, normalize_embeddings=True)
    return embeddings
# GraphRAG instance that stores its data in WORKING_DIR and embeds text with
# the local model instead of the default embedding function.
rag = GraphRAG(working_dir=WORKING_DIR, embedding_func=local_embedding)

# The "utf-8-sig" codec drops a leading byte-order mark if the file has one.
with open("../tests/mock_data.txt", encoding="utf-8-sig") as f:
    FAKE_TEXT = f.read()

# The insert call is disabled, so FAKE_TEXT is read but not indexed here.
# NOTE(review): the query below presumably relies on data already stored in
# WORKING_DIR by an earlier run -- uncomment the insert for a fresh directory.
# rag.insert(FAKE_TEXT)
print(
    rag.query(
        "What the main theme of this story?",
        param=QueryParam(mode="local"),
    )
)