Skip to content

Commit

Permalink
Fix issue #237
Browse files Browse the repository at this point in the history
  • Loading branch information
liyin2015 committed Oct 27, 2024
1 parent e1f943f commit d15022b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 11 deletions.
35 changes: 26 additions & 9 deletions adalflow/adalflow/core/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
It is a pipeline that consists of three subcomponents."""

import os
import json
import re
from pathlib import Path

from typing import Any, Dict, Optional, Union, Callable, Tuple, List
import logging
Expand Down Expand Up @@ -114,16 +115,14 @@ def __init__(

template = template or DEFAULT_LIGHTRAG_SYSTEM_PROMPT

# Cache
model_str = (
f"{model_client.__class__.__name__}_{model_kwargs.get('model', 'default')}"
)
_cache_path = (
get_adalflow_default_root_path() if cache_path is None else cache_path
# create the cache path and initialize the cache engine

self.set_cache_path(
cache_path, model_client, model_kwargs.get("model", "default")
)
self.cache_path = os.path.join(_cache_path, f"cache_{model_str}.db")

CachedEngine.__init__(self, cache_path=self.cache_path)

Component.__init__(self)
GradComponent.__init__(self)
CallbackManager.__init__(self)
Expand All @@ -148,7 +147,6 @@ def __init__(
self.mock_output_data: str = "mock data"
self.data_map_func: Callable = None
self.set_data_map_func()
self.model_str = model_str
self._use_cache = use_cache

self._kwargs = {
Expand All @@ -163,6 +161,25 @@ def __init__(
}
self._teacher: Optional["Generator"] = None

def set_cache_path(self, cache_path: str, model_client: object, model: str):
    """Set the cache path for the generator.

    Builds a filesystem-safe cache filename from the model client's class
    name and the model identifier, then stores the resulting path on
    ``self.cache_path`` (a ``pathlib.Path``).
    """

    # Compose the raw identifier, then replace every character that is not
    # alphanumeric, underscore, or hyphen. Windows in particular forbids
    # characters such as :<>?/\|* in file names.
    raw_model_str = f"{model_client.__class__.__name__}_{model}"
    self.model_str = re.sub(r"[^a-zA-Z0-9_\-]", "_", raw_model_str)

    # Fall back to the default adalflow root when no explicit path is given.
    root = cache_path if cache_path is not None else get_adalflow_default_root_path()

    # pathlib keeps the join portable across operating systems.
    self.cache_path = Path(root) / f"cache_{self.model_str}.db"

    log.debug(f"Cache path set to: {self.cache_path}")

def get_cache_path(self) -> Path:
    r"""Get the cache path for the generator.

    Returns:
        Path: location of the on-disk cache database as built by
        ``set_cache_path`` — a ``pathlib.Path``, not a ``str``; callers
        that need a string should wrap it in ``str(...)``.
    """
    return self.cache_path
Expand Down
8 changes: 6 additions & 2 deletions adalflow/adalflow/utils/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import hashlib
import diskcache as dc
from pathlib import Path
from typing import Union


def hash_text(text: str):
Expand All @@ -15,9 +17,11 @@ def direct(text: str):


class CachedEngine:
def __init__(self, cache_path: Union[str, Path]):
    """Initialize the on-disk cache backing this engine.

    Args:
        cache_path: Location of the cache database file; accepts either a
            string or a ``pathlib.Path``.
    """
    super().__init__()
    # Normalize to Path so parent-directory handling is cross-platform.
    self.cache_path = Path(cache_path)
    # Ensure the containing directory exists before diskcache opens the db.
    self.cache_path.parent.mkdir(parents=True, exist_ok=True)

    # Open the cache at the normalized path (stringified for compatibility
    # with older diskcache releases) rather than the raw argument, so
    # self.cache_path and self.cache always refer to the same location.
    self.cache = dc.Cache(str(self.cache_path))

def _check_cache(self, prompt: str):
Expand Down
27 changes: 27 additions & 0 deletions adalflow/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import unittest
import os
import shutil
from pathlib import Path

from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
Expand Down Expand Up @@ -55,6 +56,32 @@ def test_generator_call(self):
print(f"output: {output}")
# self.assertEqual(output.data, "Generated text response")

def test_cache_path(self):
    """Verify a cache-enabled Generator builds a sanitized cache path that exists on disk."""
    generator = Generator(
        model_client=self.mock_api_client,
        prompt_kwargs={"input_str": "Hello, world!"},
        model_kwargs={"model": "phi3.5:latest"},
        use_cache=True,
    )
    self.test_generator = generator

    # get_cache_path returns a pathlib.Path; stringify for substring checks.
    cache_path = generator.get_cache_path()
    cache_path_str = str(cache_path)

    print(f"cache path: {cache_path}")

    # "phi3.5:latest" must be sanitized to "phi3_5_latest" in the filename.
    self.assertIn("phi3_5_latest", cache_path_str)

    # The cache database must have been created on disk.
    self.assertTrue(
        Path(cache_path).exists(), f"Cache path {cache_path_str} does not exist"
    )

def test_generator_prompt_logger_first_record(self):
# prompt_kwargs = {"input_str": "Hello, world!"}
# model_kwargs = {"model": "gpt-3.5-turbo"}
Expand Down

0 comments on commit d15022b

Please sign in to comment.