Migrated get_embeddings from litellm service to kairon #1844

Merged
Changes from 1 commit
kairon/shared/llm/processor.py (120 changes: 99 additions & 21 deletions)
@@ -1,10 +1,12 @@
+import os
 import time
 import urllib.parse
 from secrets import randbelow, choice
 from typing import Text, Dict, List, Tuple, Union
 from urllib.parse import urljoin

 import litellm
+from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
 from loguru import logger as logging
 from mongoengine.base import BaseList
 from tiktoken import get_encoding
@@ -44,17 +46,16 @@ def __init__(self, bot: Text, llm_type: str):
         self.llm_type = llm_type
         self.vectors_config = {}
         self.sparse_vectors_config = {}

         self.llm_secret = Sysadmin.get_llm_secret(llm_type, bot)
         if llm_type != DEFAULT_LLM:
             self.llm_secret_embedding = Sysadmin.get_llm_secret(DEFAULT_LLM, bot)
         else:
             self.llm_secret_embedding = self.llm_secret

         self.tokenizer = get_encoding("cl100k_base")
         self.EMBEDDING_CTX_LENGTH = 8191
         self.__logs = []
+        self.load_sparse_embedding_model()
+        self.load_rerank_embedding_model()

     async def train(self, user, *args, **kwargs) -> Dict:
         invocation = kwargs.pop('invocation', None)

@@ -144,29 +145,48 @@ async def predict(self, query: Text, user, *args, **kwargs) -> Tuple:
             elapsed_time = end_time - start_time
             return response, elapsed_time

+    def truncate_text(self, texts: List[Text]) -> List[Text]:
+        """
+        Truncate multiple texts to 8191 tokens for openai
+        """
+        truncated_texts = []
+
+        for text in texts:
+            tokens = self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH]
+            truncated_texts.append(self.tokenizer.decode(tokens))
+
+        return truncated_texts
+
-    async def get_embedding(self, texts: Union[Text, List[Text]], user: Text, **kwargs):
+    async def get_embedding(self, texts: Union[Text, List[Text]], user, **kwargs):
         """
-        Get embeddings for a batch of texts by making an API call.
+        Get embeddings for a batch of texts.
         """
-        body = {
-            'texts': texts,
-            'user': user,
-            'invocation': kwargs.get("invocation")
-        }
-        timeout = Utility.environment['llm'].get('request_timeout', 30)
-        http_response, status_code, _, _ = await ActionUtility.execute_request_async(
-            http_url=f"{Utility.environment['llm']['url']}/{urllib.parse.quote(self.bot)}/embedding/{self.llm_type}",
-            request_method="POST",
-            request_body=body,
-            timeout=timeout)
-
-        if status_code == 200:
-            embeddings = http_response.get('embedding', {})
-            return embeddings
-        else:
-            raise Exception(f"Failed to fetch embeddings: {http_response.get('message', 'Unknown error')}")
+        try:
+            is_single_text = isinstance(texts, str)
+            if is_single_text:
+                texts = [texts]
+
+            truncated_texts = self.truncate_text(texts)
+
+            embeddings = {}
+
+            result = await litellm.aembedding(
+                model="text-embedding-3-small",
+                input=truncated_texts,
+                metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
+                api_key=self.llm_secret_embedding.get('api_key'),
+                num_retries=3
+            )
+            embeddings["dense"] = [embedding["embedding"] for embedding in result["data"]]
+            embeddings["sparse"] = self.get_sparse_embedding(truncated_texts)
+            embeddings["rerank"] = self.get_rerank_embedding(truncated_texts)
+
+            if is_single_text:
+                return {model: embedding[0] for model, embedding in embeddings.items()}
+
+            return embeddings
+        except Exception as e:
+            raise Exception(f"Failed to fetch embeddings: {str(e)}")

     async def __parse_completion_response(self, response, **kwargs):
         if kwargs.get("stream"):
@@ -467,4 +487,62 @@ async def initialize_vector_configs(self):
             self.vectors_config = response_data.get('vectors_config', {})
             self.sparse_vectors_config = response_data.get('sparse_vectors_config', {})
         else:
-            raise Exception(f"Failed to fetch vector configs: {http_response.get('message', 'Unknown error')}")
\ No newline at end of file
+            raise Exception(f"Failed to fetch vector configs: {http_response.get('message', 'Unknown error')}")

+    @classmethod
+    def load_sparse_embedding_model(cls):
+        hf_cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
+        kairon_cache_dir = "./kairon/pre-trained-models/"
+
+        cache_dir = hf_cache_dir if os.path.exists(hf_cache_dir) else kairon_cache_dir
+
+        if cls._sparse_embedding is None:
+            cls._sparse_embedding = SparseTextEmbedding("Qdrant/bm25", cache_dir=cache_dir)
+            logging.info("SPARSE MODEL LOADED")
+
+    @classmethod
+    def load_rerank_embedding_model(cls):
+        hf_cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
+        kairon_cache_dir = "./kairon/pre-trained-models/"
+
+        cache_dir = hf_cache_dir if os.path.exists(hf_cache_dir) else kairon_cache_dir
+
+        if cls._rerank_embedding is None:
+            cls._rerank_embedding = LateInteractionTextEmbedding("colbert-ir/colbertv2.0", cache_dir=cache_dir)
+            logging.info("RERANK MODEL LOADED")
+
+    def get_sparse_embedding(self, sentences):
+        """
+        Generate sparse embeddings for a list of sentences.
+
+        Args:
+            sentences (list): A list of sentences to be encoded.
+
+        Returns:
+            list: A list of embeddings.
+        """
+        try:
+            embeddings = list(self._sparse_embedding.passage_embed(sentences))
+
+            return [
+                {"values": emb.values.tolist(), "indices": emb.indices.tolist()}
+                for emb in embeddings
+            ]
+        except Exception as e:
+            raise Exception(f"Error processing sparse embeddings: {str(e)}")
+
+    def get_rerank_embedding(self, sentences):
+        """
+        Generate embeddings for a list of sentences.
+
+        Args:
+            sentences (list): A list of sentences to be encoded.
+
+        Returns:
+            list: A list of embedding vectors.
+        """
+        try:
+            embeddings = list(self._rerank_embedding.passage_embed(sentences))
+            return [emb.tolist() for emb in embeddings]
+        except Exception as e:
+            raise Exception(f"Error processing rerank embeddings: {str(e)}")