Commit
caching support for embedding
ayulockin committed Dec 14, 2024
1 parent 1c12986 commit 28fa1fb
Showing 1 changed file with 23 additions and 2 deletions.
src/ragas/embeddings/base.py
@@ -13,6 +13,7 @@
 from pydantic_core import CoreSchema, core_schema
 
 from ragas.run_config import RunConfig, add_async_retry, add_retry
+from ragas.cache import CacheInterface, cacher
 
 if t.TYPE_CHECKING:
     from llama_index.core.base.embeddings.base import BaseEmbedding
@@ -85,13 +86,21 @@ class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
     """
 
     def __init__(
-        self, embeddings: Embeddings, run_config: t.Optional[RunConfig] = None
+        self, embeddings: Embeddings,
+        run_config: t.Optional[RunConfig] = None,
+        cache: t.Optional[CacheInterface] = None
     ):
         self.embeddings = embeddings
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)
 
+        if cache is not None:
+            self.embed_query = cacher(cache_backend=cache)(self.embed_query)
+            self.embed_documents = cacher(cache_backend=cache)(self.embed_documents)
+            self.aembed_query = cacher(cache_backend=cache)(self.aembed_query)
+            self.aembed_documents = cacher(cache_backend=cache)(self.aembed_documents)
+
     def embed_query(self, text: str) -> List[float]:
         """
         Embed a single query text.
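
Worth noting for readers of this hunk: cacher(cache_backend=cache) is a decorator factory, and applying it to the bound methods rebinds them as instance attributes, so only instances constructed with a cache are affected. A simplified sketch of the pattern follows (it assumes CacheInterface exposes has_key/get/set, and it is an illustration, not the actual ragas.cache implementation, which must additionally wrap the async aembed_* coroutines and derive stabler keys):

import functools

def cacher(cache_backend):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapped(*args, **kwargs):
            # Key derivation here is illustrative only.
            key = repr((fn.__qualname__, args, sorted(kwargs.items())))
            if cache_backend.has_key(key):  # assumed CacheInterface method
                return cache_backend.get(key)
            result = fn(*args, **kwargs)
            cache_backend.set(key, result)
            return result
        return wrapped
    return decorator
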
@@ -189,6 +198,7 @@ class HuggingfaceEmbeddings(BaseRagasEmbeddings):
     cache_folder: t.Optional[str] = None
     model_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
     encode_kwargs: t.Dict[str, t.Any] = field(default_factory=dict)
+    cache: t.Optional[CacheInterface] = None
 
     def __post_init__(self):
         """
@@ -226,6 +236,11 @@ def __post_init__(self):
         if "convert_to_tensor" not in self.encode_kwargs:
             self.encode_kwargs["convert_to_tensor"] = True
 
+        if self.cache is not None:
+            self.embed_query = cacher(cache_backend=self.cache)(self.embed_query)
+            self.embed_documents = cacher(cache_backend=self.cache)(self.embed_documents)
+            self.predict = cacher(cache_backend=self.cache)(self.predict)
+
     def embed_query(self, text: str) -> List[float]:
         """
         Embed a single query text.
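
Because HuggingfaceEmbeddings is a dataclass, the cache backend arrives as a field and the wrapping happens in __post_init__; predict is wrapped too, so the cross-encoder path should benefit as well. A usage sketch (the model name is illustrative, and DiskCacheBackend plus its cache_dir argument are assumptions about the companion ragas.cache module; check your installed version):

from ragas.cache import DiskCacheBackend
from ragas.embeddings import HuggingfaceEmbeddings

hf = HuggingfaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",  # illustrative model
    cache=DiskCacheBackend(cache_dir=".cache"),  # assumed backend and constructor arg
)
embedding = hf.embed_query("hello world")  # a second identical call is served from the cache
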
@@ -297,13 +312,19 @@ class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings):
     """
 
     def __init__(
-        self, embeddings: BaseEmbedding, run_config: t.Optional[RunConfig] = None
+        self, embeddings: BaseEmbedding, run_config: t.Optional[RunConfig] = None, cache: t.Optional[CacheInterface] = None
     ):
         self.embeddings = embeddings
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)
 
+        if cache is not None:
+            self.embed_query = cacher(cache_backend=cache)(self.embed_query)
+            self.embed_documents = cacher(cache_backend=cache)(self.embed_documents)
+            self.aembed_query = cacher(cache_backend=cache)(self.aembed_query)
+            self.aembed_documents = cacher(cache_backend=cache)(self.aembed_documents)
+
     def embed_query(self, text: str) -> t.List[float]:
         return self.embeddings.get_query_embedding(text)
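
The same pattern covers all three wrappers, so opting in is a one-argument change at construction time. A minimal end-to-end sketch with the LangChain wrapper (again assuming DiskCacheBackend from ragas.cache; any CacheInterface implementation should work identically):

from langchain_openai import OpenAIEmbeddings
from ragas.cache import DiskCacheBackend
from ragas.embeddings import LangchainEmbeddingsWrapper

cache = DiskCacheBackend(cache_dir=".cache")  # assumed constructor argument
embedder = LangchainEmbeddingsWrapper(OpenAIEmbeddings(), cache=cache)

v1 = embedder.embed_query("What is retrieval-augmented generation?")  # computed, then stored
v2 = embedder.embed_query("What is retrieval-augmented generation?")  # served from the cache
assert v1 == v2
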
