Skip to content

Commit

Permalink
cleaner embedding
Browse files · Browse the repository at this point in the history
  • Loading branch information
ayulockin committed Dec 14, 2024
1 parent f84f300 commit 7299bd3
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions src/ragas/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio
import typing as t
from abc import ABC
from abc import ABC, abstractmethod
from dataclasses import field
from typing import List

Expand Down Expand Up @@ -36,6 +36,20 @@ class BaseRagasEmbeddings(Embeddings, ABC):
"""

run_config: RunConfig
cache: t.Optional[CacheInterface] = None

def __init__(self, cache: t.Optional[CacheInterface] = None):
    """Record the optional cache backend on the instance.

    When a backend is supplied, every embedding entry point (sync and
    async, query and documents) is rebound through the ``cacher``
    decorator so later calls go via the configured cache backend.
    """
    super().__init__()
    self.cache = cache
    if self.cache is None:
        return
    # Rebind each public embedding method, invoking cacher(...) once
    # per method exactly as the original explicit assignments did.
    for method_name in (
        "embed_query",
        "embed_documents",
        "aembed_query",
        "aembed_documents",
    ):
        wrapped = cacher(cache_backend=self.cache)(getattr(self, method_name))
        setattr(self, method_name, wrapped)

async def embed_text(self, text: str, is_async=True) -> List[float]:
"""
Expand All @@ -62,6 +76,12 @@ async def embed_texts(
)
return await loop.run_in_executor(None, embed_documents_with_retry, texts)

@abstractmethod
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a single query string into a vector.

    Abstract: subclasses must implement this. When a cache backend is
    configured, the base-class ``__init__`` rebinds this method through
    the caching decorator.
    """
    ...

@abstractmethod
async def aembed_documents(self, texts: List[str]) -> t.List[t.List[float]]:
    """Asynchronously embed a list of document strings.

    Abstract: subclasses must implement this. When a cache backend is
    configured, the base-class ``__init__`` rebinds this method through
    the caching decorator.
    """
    ...

def set_run_config(self, run_config: RunConfig):
"""
Set the run configuration for the embedding operations.
Expand Down Expand Up @@ -91,17 +111,12 @@ def __init__(
run_config: t.Optional[RunConfig] = None,
cache: t.Optional[CacheInterface] = None,
):
super().__init__(cache=cache)
self.embeddings = embeddings
if run_config is None:
run_config = RunConfig()
self.set_run_config(run_config)

if cache is not None:
self.embed_query = cacher(cache_backend=cache)(self.embed_query)
self.embed_documents = cacher(cache_backend=cache)(self.embed_documents)
self.aembed_query = cacher(cache_backend=cache)(self.aembed_query)
self.aembed_documents = cacher(cache_backend=cache)(self.aembed_documents)

def embed_query(self, text: str) -> List[float]:
"""
Embed a single query text.
Expand Down Expand Up @@ -205,6 +220,7 @@ def __post_init__(self):
"""
Initialize the model after the object is created.
"""
super().__init__(cache=self.cache)
try:
import sentence_transformers
from transformers import AutoConfig
Expand Down Expand Up @@ -238,10 +254,6 @@ def __post_init__(self):
self.encode_kwargs["convert_to_tensor"] = True

if self.cache is not None:
self.embed_query = cacher(cache_backend=self.cache)(self.embed_query)
self.embed_documents = cacher(cache_backend=self.cache)(
self.embed_documents
)
self.predict = cacher(cache_backend=self.cache)(self.predict)

def embed_query(self, text: str) -> List[float]:
Expand Down Expand Up @@ -320,17 +332,12 @@ def __init__(
run_config: t.Optional[RunConfig] = None,
cache: t.Optional[CacheInterface] = None,
):
super().__init__(cache=cache)
self.embeddings = embeddings
if run_config is None:
run_config = RunConfig()
self.set_run_config(run_config)

if cache is not None:
self.embed_query = cacher(cache_backend=cache)(self.embed_query)
self.embed_documents = cacher(cache_backend=cache)(self.embed_documents)
self.aembed_query = cacher(cache_backend=cache)(self.aembed_query)
self.aembed_documents = cacher(cache_backend=cache)(self.aembed_documents)

def embed_query(self, text: str) -> t.List[float]:
    """Embed a single query string.

    Delegates directly to the wrapped embeddings object's
    ``get_query_embedding`` method.
    """
    return self.embeddings.get_query_embedding(text)

Expand Down

0 comments on commit 7299bd3

Please sign in to comment.