Commit be3ab3c

Switched from a three-embedding system to a single embedding system (#1847)

* Migrated get_embeddings from litellm service to kairon and added corresponding test cases

* Added fastembed library

* Updated Test Cases and loading of models

* Switched from a three-embedding system to a single embedding system
himanshugt16 authored Mar 10, 2025
1 parent 4760807 commit be3ab3c
Showing 10 changed files with 969 additions and 2,213 deletions.
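Taken together, the hunks below collapse the three parallel embedding models (dense, sparse, rerank) into the single dense `text-embedding-3-small` model. The most visible contract change is the return shape of `LLMProcessor.get_embedding`. A toy sketch of that shape change, inferred from the hunks (the vector values are made up):

```python
# Before: one set of embeddings per model, keyed by model name.
old_style = {
    "dense":  [[0.12, 0.05], [0.33, 0.81]],    # text-embedding-3-small vectors
    "sparse": ["<sparse vector per text>"],     # local sparse model (fastembed)
    "rerank": ["<rerank vector per text>"],     # local rerank model
}

# After: just the dense vectors, one per input text...
new_style = [[0.12, 0.05], [0.33, 0.81]]

# ...and a bare vector (not a one-entry dict) for a single input string.
new_style_single = [0.12, 0.05]
```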
2 changes: 0 additions & 2 deletions kairon/actions/definitions/prompt.py
@@ -14,8 +14,6 @@
 from kairon.shared.data.constant import DEFAULT_NLU_FALLBACK_RESPONSE
 from kairon.shared.models import LlmPromptType, LlmPromptSource
 from kairon.shared.llm.processor import LLMProcessor
-LLMProcessor.load_sparse_embedding_model()
-LLMProcessor.load_rerank_embedding_model()



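This file and three others below (cognition/processor.py, vector_embeddings/db/qdrant.py, train.py) delete the same two lines: previously, merely importing these modules loaded the local sparse and rerank embedding models as a side effect, and with only the hosted dense model left, that warm-up is unnecessary. A stub sketch of the removed pattern; the real method bodies are not shown in this diff, so the comments are assumptions:

```python
class LLMProcessorSketch:
    """Stand-in for kairon's LLMProcessor, illustrating the removed calls."""

    @classmethod
    def load_sparse_embedding_model(cls):
        pass  # presumably loaded a local sparse model (e.g. via fastembed)

    @classmethod
    def load_rerank_embedding_model(cls):
        pass  # presumably loaded a local rerank model


# Removed pattern: executed at import time in four modules before this commit.
LLMProcessorSketch.load_sparse_embedding_model()
LLMProcessorSketch.load_rerank_embedding_model()
```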
5 changes: 1 addition & 4 deletions kairon/shared/cognition/processor.py
@@ -585,8 +585,6 @@ async def upsert_data(self, primary_key_col: str, collection_name: str, event_ty
"""

from kairon.shared.llm.processor import LLMProcessor
LLMProcessor.load_sparse_embedding_model()
LLMProcessor.load_rerank_embedding_model()
llm_processor = LLMProcessor(bot, DEFAULT_LLM)
suffix = "_faq_embd"
qdrant_collection = f"{bot}_{collection_name}{suffix}" if collection_name else f"{bot}{suffix}"
@@ -656,8 +654,7 @@ async def sync_with_qdrant(self, llm_processor, collection_name, bot, document,
         search_payload, embedding_payload = Utility.retrieve_search_payload_and_embedding_payload(
             document['data'], metadata)
         embeddings = await llm_processor.get_embedding(embedding_payload, user, invocation='knowledge_vault_sync')
-        embeddings_formatted = {key: value[0] for key, value in embeddings.items()}
-        points = [{'id': document['vector_id'], 'vector': embeddings_formatted, 'payload': search_payload}]
+        points = [{'id': document['vector_id'], 'vector': embeddings, 'payload': search_payload}]
         await llm_processor.__collection_upsert__(collection_name, {'points': points},
                                                   err_msg="Unable to train FAQ! Contact support")
         logger.info(f"Row with {primary_key_col}: {document['data'].get(primary_key_col)} upserted in Qdrant.")
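Because `get_embedding` now returns a bare vector rather than a dict keyed by model name, the `embeddings_formatted` unwrapping step disappears and the result feeds the Qdrant point directly. A minimal sketch of the point now built per document, assuming `embedding_payload` resolves to a single text here, with hypothetical id, vector, and payload values:

```python
# Hypothetical values standing in for document['vector_id'], the returned
# dense embedding, and search_payload.
point = {
    'id': 42,
    'vector': [0.12, -0.03, 0.47],  # truncated dense embedding for one text
    'payload': {'country': 'India', 'city': 'Pune'},
}
points = [point]  # upserted via __collection_upsert__
```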
79 changes: 35 additions & 44 deletions kairon/shared/llm/processor.py
@@ -43,8 +43,9 @@ def __init__(self, bot: Text, llm_type: str):
         self.headers = {"api-key": Utility.environment['vector']['key']}
         self.suffix = "_faq_embd"
         self.llm_type = llm_type
-        self.vectors_config = {}
-        self.sparse_vectors_config = {}
+        self.vector_config = {'size': self.__embedding__, 'distance': 'Cosine'}
+        # self.vectors_config = {}
+        # self.sparse_vectors_config = {}
         self.llm_secret = Sysadmin.get_llm_secret(llm_type, bot)
         if llm_type != DEFAULT_LLM:
             self.llm_secret_embedding = Sysadmin.get_llm_secret(DEFAULT_LLM, bot)
@@ -95,19 +96,8 @@ async def train(self, user, *args, **kwargs) -> Dict:
                     vector_ids.append(content['vector_id'])

                 embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)
-                points = []
-
-                for idx, vector_id in enumerate(vector_ids):
-                    vector_data = {}
-                    for model_name, model_embeddings in embeddings.items():
-                        vector_data[model_name] = model_embeddings[idx]
-                    point = {
-                        "id": vector_id,
-                        "payload": search_payloads[idx],
-                        "vector": vector_data
-                    }
-                    points.append(point)
-
+                points = [{'id': vector_ids[idx], 'vector': embeddings[idx], 'payload': search_payloads[idx]}
+                          for idx in range(len(vector_ids))]
                 await self.__collection_upsert__(collection, {'points': points},
                                                  err_msg="Unable to train FAQ! Contact support")
                 count += len(batch_contents)
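The replacement one-liner pairs each vector id with its dense embedding and payload by index. An equivalent formulation with `zip`, shown purely as clarification with toy data:

```python
vector_ids = [101, 102]
embeddings = [[0.1, 0.2], [0.3, 0.4]]  # toy 2-d dense vectors
search_payloads = [{'q': 'refund policy'}, {'q': 'delivery time'}]

points = [
    {'id': vid, 'vector': emb, 'payload': payload}
    for vid, emb, payload in zip(vector_ids, embeddings, search_payloads)
]
assert points[0] == {'id': 101, 'vector': [0.1, 0.2], 'payload': {'q': 'refund policy'}}
```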
@@ -158,32 +148,26 @@ async def get_embedding(self, texts: Union[Text, List[Text]], user, **kwargs):
"""
Get embeddings for a batch of texts.
"""
try:
is_single_text = isinstance(texts, str)
if is_single_text:
texts = [texts]

truncated_texts = self.truncate_text(texts)

embeddings = {}
is_single_text = isinstance(texts, str)
if is_single_text:
texts = [texts]

truncated_texts = self.truncate_text(texts)

result = await litellm.aembedding(
model="text-embedding-3-small",
input=truncated_texts,
metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
api_key=self.llm_secret_embedding.get('api_key'),
num_retries=3
)

result = await litellm.aembedding(
model="text-embedding-3-small",
input=truncated_texts,
metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
api_key=self.llm_secret_embedding.get('api_key'),
num_retries=3
)
embeddings["dense"] = [embedding["embedding"] for embedding in result["data"]]
embeddings["sparse"] = self.get_sparse_embedding(truncated_texts)
embeddings["rerank"] = self.get_rerank_embedding(truncated_texts)
embeddings = [embedding["embedding"] for embedding in result["data"]]

if is_single_text:
return {model: embedding[0] for model, embedding in embeddings.items()}
if is_single_text:
return embeddings[0]

return embeddings
except Exception as e:
raise Exception(f"Failed to fetch embeddings: {str(e)}")
return embeddings

async def __parse_completion_response(self, response, **kwargs):
if kwargs.get("stream"):
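Besides dropping the sparse and rerank calls, the rewrite removes the old `try/except` that re-raised everything as a generic `Exception("Failed to fetch embeddings: ...")`, so `litellm` errors now propagate as-is (transient failures are still covered by `num_retries=3`). A toy sketch of the extraction step, using a hand-built dict shaped like the OpenAI-style embedding response the code indexes into:

```python
# Mimics the shape of a litellm.aembedding / OpenAI embeddings response.
result = {
    "data": [
        {"embedding": [0.11, 0.22]},  # toy vector for texts[0]
        {"embedding": [0.33, 0.44]},  # toy vector for texts[1]
    ]
}

embeddings = [embedding["embedding"] for embedding in result["data"]]
assert embeddings == [[0.11, 0.22], [0.33, 0.44]]
```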
@@ -289,13 +273,10 @@ async def __delete_collections(self):
             await client.cleanup()

     async def __create_collection__(self, collection_name: Text):
-        await self.initialize_vector_configs()
         await AioRestClient().request(http_url=urljoin(self.db_url, f"/collections/{collection_name}"),
                                       request_method="PUT",
                                       headers=self.headers,
-                                      request_body={'name': collection_name, 'vectors': self.vectors_config,
-                                                    'sparse_vectors': self.sparse_vectors_config
-                                                    },
+                                      request_body={'name': collection_name, 'vectors': self.vector_config},
                                       return_json=False,
                                       timeout=5)
@@ -329,6 +310,16 @@ async def __collection_exists__(self, collection_name: Text) -> bool:
             logging.info(e)
             return False

+    async def __collection_search__(self, collection_name: Text, vector: List, limit: int, score_threshold: float):
+        client = AioRestClient()
+        response = await client.request(
+            http_url=urljoin(self.db_url, f"/collections/{collection_name}/points/search"),
+            request_method="POST",
+            headers=self.headers,
+            request_body={'vector': vector, 'limit': limit, 'with_payload': True, 'score_threshold': score_threshold},
+            return_json=True,
+            timeout=5)
+        return response
+
     async def __collection_hybrid_query__(self, collection_name: Text, embeddings: Dict, limit: int, score_threshold: float):
         client = AioRestClient()
collection_name = f"{self.bot}{self.suffix}"
else:
collection_name = f"{self.bot}_{similarity_context_prompt.get('collection')}{self.suffix}"
search_result = await self.__collection_hybrid_query__(collection_name, embeddings=query_embedding, limit=limit,
search_result = await self.__collection_search__(collection_name, vector=query_embedding, limit=limit,
score_threshold=score_threshold)

for entry in search_result['result']['points']:
for entry in search_result['result']:
if 'content' not in entry['payload']:
extracted_payload = {}
for key, value in entry['payload'].items():
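The plain `/points/search` endpoint returns matches as a flat list under `result`, rather than the `result.points` nesting of the hybrid query API, which is why the loop changes. A toy sketch of the response shape the loop now consumes (all values invented):

```python
search_result = {
    'status': 'ok',
    'result': [
        {'id': 7, 'score': 0.91, 'payload': {'content': 'How do I reset my password?'}},
        {'id': 3, 'score': 0.84, 'payload': {'article': 'FAQ', 'answer': 'Use the reset link.'}},
    ],
}

for entry in search_result['result']:
    print(entry['payload'])  # payload keys drive the similarity prompt text
```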
21 changes: 1 addition & 20 deletions kairon/shared/vector_embeddings/db/qdrant.py
@@ -9,8 +9,6 @@
 from kairon.shared.actions.models import DbActionOperationType
 from kairon.shared.actions.exception import ActionFailure
 from kairon.shared.llm.processor import LLMProcessor
-LLMProcessor.load_sparse_embedding_model()
-LLMProcessor.load_rerank_embedding_model()


class Qdrant(DatabaseBase, ABC):
@@ -39,24 +37,7 @@ async def perform_operation(self, data: Dict, user: str, **kwargs):
         user_msg = data.get(DbActionOperationType.embedding_search)
         if user_msg and isinstance(user_msg, str):
             vector = await self.__get_embedding(user_msg, user, **kwargs)
-            request['prefetch'] = [
-                {
-                    "query": vector.get("dense", []),
-                    "using": "dense",
-                    "limit": 20
-                },
-                {
-                    "query": vector.get("rerank", []),
-                    "using": "rerank",
-                    "limit": 20
-                },
-                {
-                    "query": vector.get("sparse", {}),
-                    "using": "sparse",
-                    "limit": 20
-                }
-            ]
-            request.update({"query": {"fusion": "rrf"}})
+            request['query'] = vector

         if DbActionOperationType.payload_search in data:
             payload = data.get(DbActionOperationType.payload_search)
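For the embedding-search action, the removed code issued a three-way hybrid query (dense, rerank, and sparse prefetches fused with reciprocal-rank fusion, per `{"fusion": "rrf"}`); the new code sends the dense vector as the query directly. A contrast of the two request bodies with made-up toy vectors (the sparse vector shape is an assumption, as it is not visible in this diff):

```python
dense = [0.1, 0.2]  # toy dense embedding

# Before: three prefetch branches fused with RRF.
old_request = {
    'prefetch': [
        {'query': dense, 'using': 'dense', 'limit': 20},
        {'query': [0.3, 0.4], 'using': 'rerank', 'limit': 20},
        {'query': {'indices': [5], 'values': [0.9]}, 'using': 'sparse', 'limit': 20},
    ],
    'query': {'fusion': 'rrf'},
}

# After: the vector itself is the query.
new_request = {'query': dense}
```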
2 changes: 0 additions & 2 deletions kairon/train.py
@@ -17,8 +17,6 @@
 from kairon.shared.metering.metering_processor import MeteringProcessor
 from kairon.shared.utils import Utility
 from kairon.shared.llm.processor import LLMProcessor
-LLMProcessor.load_sparse_embedding_model()
-LLMProcessor.load_rerank_embedding_model()


def train_model_for_bot(bot: str):
