Skip to content

Commit

Permalink
make llm caching implementation less ugly
Browse files Browse the repository at this point in the history
  • Loading branch information
ayulockin committed Dec 14, 2024
1 parent 7f5d399 commit 0a79a04
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/ragas/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ def is_multiple_completion_supported(llm: BaseLanguageModel) -> bool:
class BaseRagasLLM(ABC):
run_config: RunConfig = field(default_factory=RunConfig, repr=False)
multiple_completion_supported: bool = field(default=False, repr=False)
cache: t.Optional[CacheInterface] = field(default=None, repr=False)

def __post_init__(self):
# If a cache_backend is provided, wrap the implementation methods at construction time.
if self.cache is not None:
self.generate_text = cacher(cache_backend=self.cache)(self.generate_text)
self.agenerate_text = cacher(cache_backend=self.cache)(self.agenerate_text)

def set_run_config(self, run_config: RunConfig):
self.run_config = run_config
Expand Down Expand Up @@ -125,18 +132,13 @@ def __init__(
langchain_llm: BaseLanguageModel,
run_config: t.Optional[RunConfig] = None,
is_finished_parser: t.Optional[t.Callable[[LLMResult], bool]] = None,
cache: t.Optional[CacheInterface] = None,
):
self.langchain_llm = langchain_llm
if run_config is None:
run_config = RunConfig()
self.set_run_config(run_config)
self.is_finished_parser = is_finished_parser

if cache is not None:
self.generate_text = cacher(cache_backend=cache)(self.generate_text)
self.agenerate_text = cacher(cache_backend=cache)(self.agenerate_text)

def is_finished(self, response: LLMResult) -> bool:
"""
Parse the response to check if the LLM finished by checking the finish_reason
Expand Down Expand Up @@ -281,6 +283,7 @@ def __init__(
run_config: t.Optional[RunConfig] = None,
cache: t.Optional[CacheInterface] = None,
):
super().__init__(cache=cache)
self.llm = llm

try:
Expand All @@ -292,10 +295,6 @@ def __init__(
run_config = RunConfig()
self.set_run_config(run_config)

if cache is not None:
self.generate_text = cacher(cache_backend=cache)(self.generate_text)
self.agenerate_text = cacher(cache_backend=cache)(self.agenerate_text)

def check_args(
self,
n: int,
Expand Down

0 comments on commit 0a79a04

Please sign in to comment.