From ede4f9e8257da1bd74bdb01037972b8308e38afe Mon Sep 17 00:00:00 2001
From: JonahSussman
Date: Wed, 5 Feb 2025 12:30:59 -0500
Subject: [PATCH] Fixed review comments

Signed-off-by: JonahSussman
---
 kai/cache.py                          | 16 ++++++++++++++--
 kai/llm_interfacing/model_provider.py | 25 ++++++++++++++-----------
 kai/rpc_server/server.py              |  2 --
 3 files changed, 28 insertions(+), 15 deletions(-)

diff --git a/kai/cache.py b/kai/cache.py
index 89a0f898..bf23fd62 100644
--- a/kai/cache.py
+++ b/kai/cache.py
@@ -22,10 +22,22 @@ class CachePathResolver(ABC):
     """
 
     @abstractmethod
-    def cache_path(self) -> Path: ...
+    def cache_path(self) -> Path:
+        """
+        Generates a path to store cache
+
+        NOTE: This method should only be called once per desired path! You
+        should store the result in a variable if you need to use it multiple
+        times.
+        """
+        ...
 
     @abstractmethod
-    def cache_meta(self) -> dict[str, str]: ...
+    def cache_meta(self) -> dict[str, str]:
+        """
+        Generates metadata to store with cache
+        """
+        ...
 
 
 class Cache(ABC):
diff --git a/kai/llm_interfacing/model_provider.py b/kai/llm_interfacing/model_provider.py
index 81cc55e8..11348292 100644
--- a/kai/llm_interfacing/model_provider.py
+++ b/kai/llm_interfacing/model_provider.py
@@ -205,22 +205,25 @@ def invoke(
         else:
             invoke_llm = self.llm
 
-        if self.demo_mode and self.cache and cache_path_resolver:
-            cache_entry = self.cache.get(
-                path=cache_path_resolver.cache_path(), input=input
-            )
+        if not (self.cache and cache_path_resolver):
+            return invoke_llm.invoke(input, config, stop=stop, **kwargs)
+
+        cache_path = cache_path_resolver.cache_path()
+        cache_meta = cache_path_resolver.cache_meta()
+
+        if self.demo_mode:
+            cache_entry = self.cache.get(path=cache_path, input=input)
 
             if cache_entry:
                 return cache_entry
 
         response = invoke_llm.invoke(input, config, stop=stop, **kwargs)
 
-        if self.cache and cache_path_resolver:
-            self.cache.put(
-                path=cache_path_resolver.cache_path(),
-                input=input,
-                output=response,
-                cache_meta=cache_path_resolver.cache_meta(),
-            )
+        self.cache.put(
+            path=cache_path,
+            input=input,
+            output=response,
+            cache_meta=cache_meta,
+        )
 
         return response
diff --git a/kai/rpc_server/server.py b/kai/rpc_server/server.py
index 34212e09..a5308b01 100644
--- a/kai/rpc_server/server.py
+++ b/kai/rpc_server/server.py
@@ -202,7 +202,6 @@ def initialize(
 
         model_provider.validate_environment()
     except Exception as e:
-        server.shutdown_flag = True
         server.send_response(
             id=id,
             error=JsonRpcError(
@@ -226,7 +225,6 @@
             excluded_paths=app.config.analyzer_lsp_excluded_paths,
         )
     except Exception as e:
-        server.shutdown_flag = True
         server.send_response(
            id=id,
            error=JsonRpcError(
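
Below is a minimal, hypothetical sketch (not part of this patch or the kai codebase) of why the new
cache_path() docstring tells callers to resolve the path once and store it: a resolver may derive a
fresh path on every call, which is why the reworked ModelProvider.invoke() now saves cache_path and
cache_meta in local variables before the get/put calls. The CountingPathResolver class and its paths
are invented for illustration only.

    from pathlib import Path


    class CountingPathResolver:
        """Toy CachePathResolver-style object whose cache_path() returns a
        different path on every call (hypothetical, for illustration only)."""

        def __init__(self, root: Path) -> None:
            self.root = root
            self.counter = 0

        def cache_path(self) -> Path:
            # Each call derives a fresh, numbered path.
            self.counter += 1
            return self.root / f"request_{self.counter}.json"

        def cache_meta(self) -> dict[str, str]:
            return {"request": str(self.counter)}


    resolver = CountingPathResolver(Path("/tmp/kai_cache"))

    # Correct usage: resolve once, reuse the same path for both get() and put().
    cache_path = resolver.cache_path()
    cache_meta = resolver.cache_meta()
    assert cache_path == Path("/tmp/kai_cache/request_1.json")

    # Calling cache_path() again would point at a different file, so a cache
    # lookup and the later write would no longer refer to the same entry.
    assert resolver.cache_path() != cache_path

The early return added to invoke() also removes the duplicated "self.cache and cache_path_resolver"
check, so the cache path and metadata are only ever resolved on the code path that actually uses them.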