diff --git a/pyproject.toml b/pyproject.toml
index d3ed84b..6094b50 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
   "boto3",
   "colorama",
   "textual",
-  "litellm",
+  "litellm==1.52.16",
   "numpydoc"
 ]
 
diff --git a/r2ai/auto.py b/r2ai/auto.py
index 2a68377..245ebb6 100644
--- a/r2ai/auto.py
+++ b/r2ai/auto.py
@@ -317,4 +317,4 @@ def chat(interpreter, **kwargs):
     finally:
         signal.signal(signal.SIGINT, original_handler)
         spinner.stop()
-        litellm.in_memory_llm_clients_cache.clear()
\ No newline at end of file
+        litellm.in_memory_llm_clients_cache.flush_cache()
diff --git a/r2ai/interpreter.py b/r2ai/interpreter.py
index 214b9f8..6b71cf1 100644
--- a/r2ai/interpreter.py
+++ b/r2ai/interpreter.py
@@ -88,7 +88,8 @@ def ddg(m):
     return f"Considering:\n```{res}\n```\n"
 
 def is_litellm_model(model):
-    from litellm import models_by_provider
+    import litellm
+    litellm.drop_params = True
     provider = None
     model_name = None
     if model.startswith ("/"):
@@ -97,7 +98,7 @@ def is_litellm_model(model):
         provider, model_name = model.split(":")
     elif "/" in model:
         provider, model_name = model.split("/")
-    if provider in models_by_provider and model_name in models_by_provider[provider]:
+    if provider in litellm.models_by_provider and (model_name in litellm.models_by_provider[provider] or model in litellm.models_by_provider[provider]):
         return True
     return False
 
@@ -423,8 +424,8 @@ def respond(self):
             max_completion_tokens=maxtokens,
             temperature=float(self.env["llm.temperature"]),
             top_p=float(self.env["llm.top_p"]),
-            stop=self.terminator,
         )
+        response = completion.choices[0].message.content
         if "content" in self.messages[-1]:
             last_message = self.messages[-1]["content"]
 
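For reference, a minimal standalone sketch (not part of the patch) of what the widened membership check in is_litellm_model amounts to. It assumes litellm exposes models_by_provider as a dict mapping provider names to lists of model identifiers, as the patched code relies on; the helper name is_known_litellm_model and the example model strings are hypothetical.

```python
# Sketch of the widened provider/model lookup; assumes litellm's
# models_by_provider is a dict of provider -> list of model identifiers.
import litellm

def is_known_litellm_model(model: str) -> bool:
    if "/" not in model:
        return False
    provider, model_name = model.split("/", 1)
    listed = litellm.models_by_provider.get(provider, [])
    # Some provider tables list bare model names ("gpt-4o"), others list the
    # fully qualified form ("openai/gpt-4o"); accept either, as the patched
    # condition does.
    return model_name in listed or model in listed

if __name__ == "__main__":
    print(is_known_litellm_model("openai/gpt-4o"))    # True if listed under "openai"
    print(is_known_litellm_model("nosuch/provider"))  # False
```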