Skip to content

Commit

Permalink
Fixed review comments
Browse files Browse the repository at this point in the history
Signed-off-by: JonahSussman <[email protected]>
  • Loading branch information
JonahSussman committed Feb 5, 2025
1 parent 88745c9 commit ede4f9e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
16 changes: 14 additions & 2 deletions kai/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,22 @@ class CachePathResolver(ABC):
"""

@abstractmethod
def cache_path(self) -> Path: ...
def cache_path(self) -> Path:
"""
Generates a path to store cache
NOTE: This method should only be called once per desired path! You
should store the result in a variable if you need to use it multiple
times.
"""
...

@abstractmethod
def cache_meta(self) -> dict[str, str]: ...
def cache_meta(self) -> dict[str, str]:
"""
Generates metadata to store with cache
"""
...


class Cache(ABC):
Expand Down
25 changes: 14 additions & 11 deletions kai/llm_interfacing/model_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,22 +205,25 @@ def invoke(
else:
invoke_llm = self.llm

if self.demo_mode and self.cache and cache_path_resolver:
cache_entry = self.cache.get(
path=cache_path_resolver.cache_path(), input=input
)
if not (self.cache and cache_path_resolver):
return invoke_llm.invoke(input, config, stop=stop, **kwargs)

cache_path = cache_path_resolver.cache_path()
cache_meta = cache_path_resolver.cache_meta()

if self.demo_mode:
cache_entry = self.cache.get(path=cache_path, input=input)

if cache_entry:
return cache_entry

response = invoke_llm.invoke(input, config, stop=stop, **kwargs)

if self.cache and cache_path_resolver:
self.cache.put(
path=cache_path_resolver.cache_path(),
input=input,
output=response,
cache_meta=cache_path_resolver.cache_meta(),
)
self.cache.put(
path=cache_path,
input=input,
output=response,
cache_meta=cache_meta,
)

return response
2 changes: 0 additions & 2 deletions kai/rpc_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def initialize(

model_provider.validate_environment()
except Exception as e:
server.shutdown_flag = True
server.send_response(
id=id,
error=JsonRpcError(
Expand All @@ -226,7 +225,6 @@ def initialize(
excluded_paths=app.config.analyzer_lsp_excluded_paths,
)
except Exception as e:
server.shutdown_flag = True
server.send_response(
id=id,
error=JsonRpcError(
Expand Down

0 comments on commit ede4f9e

Please sign in to comment.