Commit f8e9e61

cache mixin

ayulockin committed Dec 14, 2024
1 parent 28fa1fb commit f8e9e61
Showing 4 changed files with 50 additions and 29 deletions.
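
In short: the inline `if cache is not None: method = cacher(cache_backend=cache)(method)` wiring that was repeated in every wrapper class is replaced by a reusable CacherMixin, which owns the cache backend and rebinds methods through wrap_method_with_cache.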
src/ragas/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -1,4 +1,4 @@
-from ragas.cache import cacher, CacheInterface, DiskCacheBackend
+from ragas.cache import CacheInterface, DiskCacheBackend, CacherMixin
 from ragas.dataset_schema import EvaluationDataset, MultiTurnSample, SingleTurnSample
 from ragas.evaluation import evaluate
 from ragas.run_config import RunConfig
@@ -16,7 +16,7 @@
     "SingleTurnSample",
     "MultiTurnSample",
     "EvaluationDataset",
-    "cacher",
     "CacheInterface",
     "DiskCacheBackend",
+    "CacherMixin",
 ]
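
A minimal sketch of the new import surface after this change (assumes a local install of this branch; not part of the diff itself). Note that the low-level decorator is no longer re-exported from the package root, but it still lives in the submodule, as the mixin's implementation below shows:

# New public API: the mixin replaces the bare decorator at the package root.
from ragas import CacheInterface, DiskCacheBackend, CacherMixin

# The decorator remains importable from the submodule for direct use:
from ragas.cache import cacher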
src/ragas/cache.py (19 changes: 19 additions & 0 deletions)
@@ -110,3 +110,22 @@ def sync_wrapper(*args, **kwargs):
         return async_wrapper if is_async else sync_wrapper

     return decorator
+
+
+class CacherMixin:
+    """
+    A mixin that provides a method to wrap functions with the cacher decorator.
+    Instances of classes inheriting this mixin can specify a cache backend.
+    """
+
+    def __init__(self, cache: Optional[CacheInterface] = None):
+        self.cache_backend = cache
+
+    def wrap_method_with_cache(self, func):
+        """
+        Wrap the given function with the cacher decorator if a cache_backend is available.
+        Otherwise, return the original function.
+        """
+        if self.cache_backend is None:
+            return func
+        return cacher(cache_backend=self.cache_backend)(func)
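
For context, a minimal usage sketch of the new mixin. TextStats is a hypothetical class invented for illustration; only CacherMixin, CacheInterface, and DiskCacheBackend come from ragas.cache, and DiskCacheBackend is assumed to be constructible with its defaults:

from typing import Optional

from ragas.cache import CacheInterface, CacherMixin, DiskCacheBackend


class TextStats(CacherMixin):  # hypothetical example class
    def __init__(self, cache: Optional[CacheInterface] = None):
        CacherMixin.__init__(self, cache)
        # Rebind the method on this instance; with cache=None this is a no-op,
        # which is what lets callers opt out of caching entirely.
        self.char_count = self.wrap_method_with_cache(self.char_count)

    def char_count(self, text: str) -> int:
        return len(text)


stats = TextStats(cache=DiskCacheBackend())  # results cached on disk
plain = TextStats()                          # original, uncached method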
src/ragas/embeddings/base.py (38 changes: 20 additions & 18 deletions)

@@ -13,7 +13,7 @@
 from pydantic_core import CoreSchema, core_schema

 from ragas.run_config import RunConfig, add_async_retry, add_retry
-from ragas.cache import CacheInterface, cacher
+from ragas.cache import CacheInterface, CacherMixin

 if t.TYPE_CHECKING:
     from llama_index.core.base.embeddings.base import BaseEmbedding
@@ -80,7 +80,7 @@ def __get_pydantic_core_schema__(
     )


-class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
+class LangchainEmbeddingsWrapper(BaseRagasEmbeddings, CacherMixin):
     """
     Wrapper for any embeddings from langchain.
     """
@@ -90,16 +90,17 @@ def __init__(
         run_config: t.Optional[RunConfig] = None,
         cache: t.Optional[CacheInterface] = None
     ):
+        CacherMixin.__init__(self, cache)
+
         self.embeddings = embeddings
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)

-        if cache is not None:
-            self.embed_query = cacher(cache_backend=cache)(self.embed_query)
-            self.embed_documents = cacher(cache_backend=cache)(self.embed_documents)
-            self.aembed_query = cacher(cache_backend=cache)(self.aembed_query)
-            self.aembed_documents = cacher(cache_backend=cache)(self.aembed_documents)
+        self.embed_query = self.wrap_method_with_cache(self.embed_query)
+        self.embed_documents = self.wrap_method_with_cache(self.embed_documents)
+        self.aembed_query = self.wrap_method_with_cache(self.aembed_query)
+        self.aembed_documents = self.wrap_method_with_cache(self.aembed_documents)

     def embed_query(self, text: str) -> List[float]:
         """
@@ -147,7 +148,7 @@ def __repr__(self) -> str:


 @dataclass
-class HuggingfaceEmbeddings(BaseRagasEmbeddings):
+class HuggingfaceEmbeddings(BaseRagasEmbeddings, CacherMixin):
     """
     Hugging Face embeddings class for generating embeddings using pre-trained models.
@@ -204,6 +205,7 @@ def __post_init__(self):
         """
         Initialize the model after the object is created.
         """
+        CacherMixin.__init__(self, self.cache)
         try:
             import sentence_transformers
             from transformers import AutoConfig
@@ -236,10 +238,9 @@ def __post_init__(self):
         if "convert_to_tensor" not in self.encode_kwargs:
             self.encode_kwargs["convert_to_tensor"] = True

-        if self.cache is not None:
-            self.embed_query = cacher(cache_backend=self.cache)(self.embed_query)
-            self.embed_documents = cacher(cache_backend=self.cache)(self.embed_documents)
-            self.predict = cacher(cache_backend=self.cache)(self.predict)
+        self.embed_query = self.wrap_method_with_cache(self.embed_query)
+        self.embed_documents = self.wrap_method_with_cache(self.embed_documents)
+        self.predict = self.wrap_method_with_cache(self.predict)

     def embed_query(self, text: str) -> List[float]:
         """
@@ -281,7 +282,7 @@ def predict(self, texts: List[List[str]]) -> List[List[float]]:
         return predictions.tolist()


-class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings):
+class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings, CacherMixin):
     """
     Wrapper for any embeddings from llama-index.
@@ -314,16 +315,17 @@ class LlamaIndexEmbeddingsWrapper(BaseRagasEmbeddings):
     def __init__(
         self, embeddings: BaseEmbedding, run_config: t.Optional[RunConfig] = None, cache: t.Optional[CacheInterface] = None
     ):
+        CacherMixin.__init__(self, cache)
+
        self.embeddings = embeddings
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)

-        if cache is not None:
-            self.embed_query = cacher(cache_backend=cache)(self.embed_query)
-            self.embed_documents = cacher(cache_backend=cache)(self.embed_documents)
-            self.aembed_query = cacher(cache_backend=cache)(self.aembed_query)
-            self.aembed_documents = cacher(cache_backend=cache)(self.aembed_documents)
+        self.embed_query = self.wrap_method_with_cache(self.embed_query)
+        self.embed_documents = self.wrap_method_with_cache(self.embed_documents)
+        self.aembed_query = self.wrap_method_with_cache(self.aembed_query)
+        self.aembed_documents = self.wrap_method_with_cache(self.aembed_documents)

     def embed_query(self, text: str) -> t.List[float]:
         return self.embeddings.get_query_embedding(text)
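
The practical effect for embeddings, as a hedged sketch: OpenAIEmbeddings is only an example backend from the separate langchain-openai package (not part of this commit); any langchain Embeddings implementation would do.

from langchain_openai import OpenAIEmbeddings  # example backend, not from this commit

from ragas.cache import DiskCacheBackend
from ragas.embeddings import LangchainEmbeddingsWrapper

embedder = LangchainEmbeddingsWrapper(OpenAIEmbeddings(), cache=DiskCacheBackend())
vec = embedder.embed_query("hello")  # computed once, then served from the disk cache

uncached = LangchainEmbeddingsWrapper(OpenAIEmbeddings())  # cache=None: methods left untouched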
src/ragas/llms/base.py (18 changes: 9 additions & 9 deletions)

@@ -13,7 +13,7 @@
 from langchain_openai.llms import AzureOpenAI, OpenAI
 from langchain_openai.llms.base import BaseOpenAI

-from ragas.cache import cacher, CacheInterface
+from ragas.cache import CacheInterface, CacherMixin
 from ragas.exceptions import LLMDidNotFinishException
 from ragas.integrations.helicone import helicone_config
 from ragas.run_config import RunConfig, add_async_retry
@@ -112,7 +112,7 @@ async def generate(
         return result


-class LangchainLLMWrapper(BaseRagasLLM):
+class LangchainLLMWrapper(BaseRagasLLM, CacherMixin):
     """
     A simple base class for RagasLLMs that is based on Langchain's BaseLanguageModel
     interface. it implements 2 functions:
@@ -127,15 +127,15 @@ def __init__(
         is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
         cache: t.Optional[CacheInterface] = None,
     ):
+        CacherMixin.__init__(self, cache)
         self.langchain_llm = langchain_llm
         if run_config is None:
             run_config = RunConfig()
         self.set_run_config(run_config)
         self.is_finished_parser = is_finished_parser

-        if cache is not None:
-            self.generate_text = cacher(cache_backend=cache)(self.generate_text)
-            self.agenerate_text = cacher(cache_backend=cache)(self.agenerate_text)
+        self.generate_text = self.wrap_method_with_cache(self.generate_text)
+        self.agenerate_text = self.wrap_method_with_cache(self.agenerate_text)

     def is_finished(self, response: LLMResult) -> bool:
         """
@@ -270,7 +270,7 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}(langchain_llm={self.langchain_llm.__class__.__name__}(...))"


-class LlamaIndexLLMWrapper(BaseRagasLLM):
+class LlamaIndexLLMWrapper(BaseRagasLLM, CacherMixin):
     """
     A Adaptor for LlamaIndex LLMs
     """
@@ -281,6 +281,7 @@ def __init__(
         run_config: t.Optional[RunConfig] = None,
         cache: t.Optional[CacheInterface] = None,
     ):
+        CacherMixin.__init__(self, cache)
         self.llm = llm

         try:
@@ -292,9 +293,8 @@ def __init__(
             run_config = RunConfig()
         self.set_run_config(run_config)

-        if cache is not None:
-            self.generate_text = cacher(cache_backend=cache)(self.generate_text)
-            self.agenerate_text = cacher(cache_backend=cache)(self.agenerate_text)
+        self.generate_text = self.wrap_method_with_cache(self.generate_text)
+        self.agenerate_text = self.wrap_method_with_cache(self.agenerate_text)

     def check_args(
         self,
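
And the same pattern for the LLM wrappers, again as a sketch: ChatOpenAI (and its model parameter) is an assumed example from langchain-openai, not something this commit defines.

from langchain_openai import ChatOpenAI  # example backend, not part of this commit

from ragas.cache import DiskCacheBackend
from ragas.llms import LangchainLLMWrapper

llm = LangchainLLMWrapper(ChatOpenAI(model="gpt-4o-mini"), cache=DiskCacheBackend())
# generate_text / agenerate_text are now rebound per instance via
# CacherMixin.wrap_method_with_cache instead of an inline `if cache is not None` block.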