diff --git a/adalflow/adalflow/core/generator.py b/adalflow/adalflow/core/generator.py index d3662063..309f954f 100644 --- a/adalflow/adalflow/core/generator.py +++ b/adalflow/adalflow/core/generator.py @@ -2,8 +2,9 @@ It is a pipeline that consists of three subcomponents.""" -import os import json +import re +from pathlib import Path from typing import Any, Dict, Optional, Union, Callable, Tuple, List import logging @@ -114,16 +115,14 @@ def __init__( template = template or DEFAULT_LIGHTRAG_SYSTEM_PROMPT - # Cache - model_str = ( - f"{model_client.__class__.__name__}_{model_kwargs.get('model', 'default')}" - ) - _cache_path = ( - get_adalflow_default_root_path() if cache_path is None else cache_path + # create the cache path and initialize the cache engine + + self.set_cache_path( + cache_path, model_client, model_kwargs.get("model", "default") ) - self.cache_path = os.path.join(_cache_path, f"cache_{model_str}.db") CachedEngine.__init__(self, cache_path=self.cache_path) + Component.__init__(self) GradComponent.__init__(self) CallbackManager.__init__(self) @@ -148,7 +147,6 @@ def __init__( self.mock_output_data: str = "mock data" self.data_map_func: Callable = None self.set_data_map_func() - self.model_str = model_str self._use_cache = use_cache self._kwargs = { @@ -163,6 +161,25 @@ def __init__( } self._teacher: Optional["Generator"] = None + def set_cache_path(self, cache_path: str, model_client: object, model: str): + """Set the cache path for the generator.""" + + # Construct a valid model string using the client class name and model + self.model_str = f"{model_client.__class__.__name__}_{model}" + + # Remove any characters that are not allowed in file names (cross-platform) + # On Windows, characters like `:<>?/\|*` are prohibited. + self.model_str = re.sub(r"[^a-zA-Z0-9_\-]", "_", self.model_str) + + _cache_path = ( + get_adalflow_default_root_path() if cache_path is None else cache_path + ) + + # Use pathlib to handle paths more safely across OS + self.cache_path = Path(_cache_path) / f"cache_{self.model_str}.db" + + log.debug(f"Cache path set to: {self.cache_path}") + def get_cache_path(self) -> str: r"""Get the cache path for the generator.""" return self.cache_path diff --git a/adalflow/adalflow/utils/cache.py b/adalflow/adalflow/utils/cache.py index 31fccdaa..c330cdc8 100644 --- a/adalflow/adalflow/utils/cache.py +++ b/adalflow/adalflow/utils/cache.py @@ -1,5 +1,7 @@ import hashlib import diskcache as dc +from pathlib import Path +from typing import Union def hash_text(text: str): @@ -15,9 +17,11 @@ def direct(text: str): class CachedEngine: - def __init__(self, cache_path: str): + def __init__(self, cache_path: Union[str, Path]): super().__init__() - self.cache_path = cache_path + self.cache_path = Path(cache_path) + self.cache_path.parent.mkdir(parents=True, exist_ok=True) + self.cache = dc.Cache(cache_path) def _check_cache(self, prompt: str): diff --git a/adalflow/tests/test_generator.py b/adalflow/tests/test_generator.py index 5ea8b76d..a15c302a 100644 --- a/adalflow/tests/test_generator.py +++ b/adalflow/tests/test_generator.py @@ -3,6 +3,7 @@ import unittest import os import shutil +from pathlib import Path from openai.types import CompletionUsage from openai.types.chat import ChatCompletion @@ -55,6 +56,32 @@ def test_generator_call(self): print(f"output: {output}") # self.assertEqual(output.data, "Generated text response") + def test_cache_path(self): + prompt_kwargs = {"input_str": "Hello, world!"} + model_kwargs = {"model": "phi3.5:latest"} + + self.test_generator = Generator( + model_client=self.mock_api_client, + prompt_kwargs=prompt_kwargs, + model_kwargs=model_kwargs, + use_cache=True, + ) + + # Convert the path to a string to avoid the TypeError + cache_path = self.test_generator.get_cache_path() + cache_path_str = str(cache_path) + + print(f"cache path: {cache_path}") + + # Check if the sanitized model string is in the cache path + self.assertIn("phi3_5_latest", cache_path_str) + + # Check if the cache path exists as a file (or directory, depending on your use case) + + self.assertTrue( + Path(cache_path).exists(), f"Cache path {cache_path_str} does not exist" + ) + def test_generator_prompt_logger_first_record(self): # prompt_kwargs = {"input_str": "Hello, world!"} # model_kwargs = {"model": "gpt-3.5-turbo"}