🐛 fix cache paths when they become too long to handle (#643)
Signed-off-by: Pranav Gaikwad <[email protected]>
pranavgaikwad authored Feb 12, 2025
1 parent e4ea2de commit 03f9298
Showing 3 changed files with 94 additions and 25 deletions.
76 changes: 59 additions & 17 deletions kai/cache.py
@@ -1,3 +1,4 @@
import platform
import re
from abc import ABC, abstractmethod
from pathlib import Path
@@ -225,41 +226,82 @@ def __init__(self, task: Task, request_type: str = "llm_request"):
        self.task = task
        self.request_type = request_type
        self._req_count = 0
        self._limit = {
            "Windows": 190,
            "Linux": 3986,
            "Darwin": 914,
        }.get(platform.system(), 190)

    def _path_with_limit(self, root: Optional[Task], path: Path) -> Path:
        if len(str(path)) > self._limit and root is not None:
            root_path = Path(root.__class__.__name__)
            stem = Path(f"depth_{self.task.depth}") / self.task.__class__.__name__
            if isinstance(root, ValidationError):
                root_path /= self._clean_filename(root.file)
            if isinstance(self.task, ValidationError):
                if not isinstance(root, ValidationError) or (
                    isinstance(root, ValidationError) and root.file != self.task.file
                ):
                    stem /= self._clean_filename(self.task.file)
            if isinstance(self.task, AnalyzerRuleViolation):
                stem /= self.task.violation.id
            return root_path / stem / path.name
        return path

    def _parent_file_path(self, task: Task) -> Optional[str]:
        if isinstance(task.parent, ValidationError):
            return task.parent.file
        return None

    def _clean_filename(self, name: str) -> str:
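        # Illustrative example (comment added for clarity, not part of the commit):
        # "test/src/main/java/io/konveyor/main.java" -> "konveyor_main_java"
        # (separators collapsed to "_", last three segments kept, capped at 50 chars).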
        filename = re.sub(r"[\\/:\.]", "_", name)
        filename = re.sub(r"\_+", "_", filename)
        segments = filename.split("_")
        filename = "_".join(segments[-min(3, len(segments)) :])
        return filename[-min(50, len(filename)) :]

    def _dfs(self, task: Optional[Task]) -> Path:
    def _dfs(self, task: Optional[Task]) -> tuple[Optional[Task], Path]:
        """Recursively traverses the task all the way up to the root parent to generate a unique cache file path.
        Returns two components: the root task and a unique path mimicking the entire tree.
        """
        if not task:
            return Path(".")
            return None, Path(".")
        root_node, root_path = self._dfs(task.parent)
        if root_node is None:
            root_node = task
        if isinstance(task, ValidationError):
            filename = re.sub(r"[\\/:\.]", "_", task.file)
            filename = re.sub(r"\_+", "_", filename)
            segments = filename.split("_")
            filename = "_".join(segments[-min(3, len(segments)) :])
            filename = filename[-min(50, len(filename)) :]
            base_path = self._dfs(task.parent) / task.__class__.__name__ / filename
            stem = Path(task.__class__.__name__)
            # To minimize path length, only add the file path at the root OR
            # when it differs from the immediate parent task's file
            parent_file = self._parent_file_path(task)
            if root_node == task or (
                parent_file is not None and parent_file != task.file
            ):
                stem /= self._clean_filename(task.file)
            if isinstance(task, AnalyzerRuleViolation):
                base_path = base_path / task.violation.id
            return base_path
                stem /= task.violation.id
        else:
            return (
                self._dfs(task.parent)
                / task.__class__.__name__
                / f"prio_{task.priority}_depth_{task.depth}"
            )
            stem = Path(task.__class__.__name__) / f"depth_{task.depth}"
        return root_node, root_path / stem

def cache_meta(self) -> dict[str, str]:
meta = {
"taskType": self.task.__class__.__name__,
"taskString": str(self.task),
"parent": self.task.parent.__class__.__name__,
}
if isinstance(self.task, ValidationError):
meta["file"] = self.task.file
meta["message"] = self.task.message
if isinstance(self.task.parent, ValidationError):
meta["parentFile"] = self.task.parent.file
return meta

    def cache_path(self) -> Path:
        path = self._dfs(self.task) / f"{self._req_count}_{self.request_type}.json"
        root_node, path = self._dfs(self.task)
        path /= f"{self._req_count}_{self.request_type}.json"
        self._req_count += 1
        return path
        return self._path_with_limit(root_node, path)


class SimplePathResolver(CachePathResolver):
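Taken together, the cache.py changes keep the old behaviour as the default: a cache path mirrors the full task tree. Only when its string form exceeds a per-OS ceiling does _path_with_limit collapse it to a compact form built from the root task's class and file, the leaf task's depth, class and (when it differs from its parent's) file, plus the original request file name. Below is a minimal, self-contained sketch of that fallback idea; the helper names (clean_filename, path_with_limit) and the example values are illustrative stand-ins, not the kai implementation.

import platform
import re
from pathlib import Path

# Per-OS ceilings mirroring the values added in kai/cache.py above.
_LIMIT = {"Windows": 190, "Linux": 3986, "Darwin": 914}.get(platform.system(), 190)

def clean_filename(name: str) -> str:
    # Replace separators, colons and dots with "_", collapse repeats,
    # then keep only the last three segments and the last 50 characters.
    name = re.sub(r"[\\/:\.]", "_", name)
    name = re.sub(r"_+", "_", name)
    segments = name.split("_")
    name = "_".join(segments[-min(3, len(segments)) :])
    return name[-min(50, len(name)) :]

def path_with_limit(full_path: Path, fallback_stem: Path) -> Path:
    # Keep the tree-shaped path when it fits; otherwise collapse to the
    # compact stem while preserving the request file name at the end.
    if len(str(full_path)) > _LIMIT:
        return fallback_stem / full_path.name
    return full_path

# Example: a deeply nested tree path collapses to a short, stable one.
deep = Path("AnalyzerRuleViolation") / ("x" * 4000) / "0_llm_request.json"
stem = (
    Path("AnalyzerRuleViolation")
    / clean_filename("test/src/main/java/io/konveyor/main.java")
    / "depth_3"
    / "SymbolNotFoundError"
)
print(path_with_limit(deep, stem))
# -> AnalyzerRuleViolation/konveyor_main_java/depth_3/SymbolNotFoundError/0_llm_request.json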
17 changes: 11 additions & 6 deletions kai/llm_interfacing/model_provider.py
@@ -247,11 +247,16 @@ def invoke(

        response = invoke_llm.invoke(input, config, stop=stop, **kwargs)

        self.cache.put(
            path=cache_path,
            input=input,
            output=response,
            cache_meta=cache_meta,
        )
        try:
            self.cache.put(
                path=cache_path,
                input=input,
                output=response,
                cache_meta=cache_meta,
            )
        except Exception as e:
            # only raise an exception when we are in demo mode
            if self.demo_mode:
                raise e

        return response
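The model_provider.py change makes the cache write best-effort: if persisting the request/response pair fails (for example, a path that is still too long for the OS), the LLM response is returned anyway, and the error is only re-raised in demo mode, which presumably relies on the cached data being written. A hedged sketch of the same pattern with a stand-in cache interface (the Cache protocol and put_best_effort name below are illustrative, not kai's API):

from typing import Any, Protocol

class Cache(Protocol):
    # Stand-in for the cache interface; only the single call used here is assumed.
    def put(self, **entry: Any) -> None: ...

def put_best_effort(cache: Cache, demo_mode: bool, **entry: Any) -> None:
    # Persist if possible; swallow failures unless demo mode depends on
    # the cached data actually being written.
    try:
        cache.put(**entry)
    except Exception:
        if demo_mode:
            raise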
26 changes: 24 additions & 2 deletions tests/test_cache.py
@@ -2,6 +2,7 @@
import tempfile
import unittest
from pathlib import Path
from unittest.mock import patch

from langchain_core.messages import AIMessage, HumanMessage

@@ -11,6 +12,7 @@
from kai.reactive_codeplanner.task_runner.compiler.maven_validator import (
    DependencyResolutionError,
    PackageDoesNotExistError,
    SymbolNotFoundError,
)


@@ -44,17 +46,27 @@ def setUp(self) -> None:
        self.t2 = PackageDoesNotExistError(
            file="file://test/pom.xml",
            line=10,
            depth=1,
            column=2,
            message="package not found",
            parent=self.t1,
        )
        self.t3 = DependencyResolutionError(
            file="file://test/pom.xml",
            line=10,
            depth=2,
            column=2,
            message="package not found",
            parent=self.t2,
        )
        self.t4 = SymbolNotFoundError(
            file="test/src/main/java/io/konveyor/main.java",
            line=10,
            depth=3,
            column=2,
            message="cannot find symbol",
            parent=self.t3,
        )

        self.t1_cache_expected_path = Path(
            "AnalyzerRuleViolation",
@@ -77,7 +89,13 @@ def setUp(self) -> None:
"PackageDoesNotExistError",
"test_pom_xml",
"DependencyResolutionError",
"test_pom_xml",
"0_analyzerfix.json",
)
self.t4_cache_expected_path = Path(
"AnalyzerRuleViolation",
"konveyor_main_java",
"depth_3",
"SymbolNotFoundError",
"0_analyzerfix.json",
)

@@ -119,11 +137,15 @@ def test_task_based_path_resolver(self) -> None:
            t2_cache_path.parent
            / Path(
                "DependencyResolutionError",
                "test_pom_xml",
                "0_analyzerfix.json",
            ),
        )

        path_resolver = TaskBasedPathResolver(task=self.t4, request_type="analyzerfix")
        with patch.object(path_resolver, "_limit", 50):
            t4_cache_path = path_resolver.cache_path()
        self.assertEqual(t4_cache_path, self.t4_cache_expected_path)

    def test_json_cache(self) -> None:
        cache = JSONCacheWithTrace(
            cache_dir=Path(self.cache_dir),
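As a quick sanity check on the new expectation for t4, the shortened path form stays well under every per-OS ceiling introduced in cache.py. A small illustrative check follows; it is not part of the kai test suite.

from pathlib import Path

# Ceilings from kai/cache.py; the path mirrors t4_cache_expected_path above.
limits = {"Windows": 190, "Linux": 3986, "Darwin": 914}
shortened = Path(
    "AnalyzerRuleViolation", "konveyor_main_java", "depth_3",
    "SymbolNotFoundError", "0_analyzerfix.json",
)
for os_name, limit in limits.items():
    assert len(str(shortened)) <= limit, f"too long for {os_name}"
print(len(str(shortened)))  # 87, comfortably below even the Windows ceiling of 190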
