community: Fix ChatLiteLLMRouter runtime issues (#28163)
**Description:** Fix ChatLiteLLMRouter constructor validation and the model_name parameter
**Issue:** #19356, #27455, #28077
**Twitter handle:** @bburgin_0
bburgin authored Dec 16, 2024
1 parent 234d496 commit 27a9056
Showing 3 changed files with 115 additions and 71 deletions.
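Before the per-file diffs, a minimal sketch of the usage pattern this change enables, assuming `langchain-community` and `litellm` are installed (the API key, endpoint, and deployment values are placeholders):

from langchain_community.chat_models import ChatLiteLLMRouter
from langchain_core.messages import HumanMessage
from litellm import Router

# One Azure deployment registered under the model group "gpt-35-turbo".
model_list = [
    {
        "model_name": "gpt-35-turbo",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_key": "<your-api-key>",
            "api_version": "2023-05-15",
            "api_base": "https://<your-endpoint>.openai.azure.com/",
        },
    },
]

litellm_router = Router(model_list=model_list)
# After this fix, model_name selects which model group the router uses.
chat = ChatLiteLLMRouter(router=litellm_router, model_name="gpt-35-turbo")
print(chat.invoke([HumanMessage(content="Say hello.")]))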
9 changes: 5 additions & 4 deletions docs/docs/integrations/chat/litellm_router.ipynb
@@ -63,17 +63,17 @@
     "        },\n",
     "    },\n",
     "    {\n",
-    "        \"model_name\": \"gpt-4\",\n",
+    "        \"model_name\": \"gpt-35-turbo\",\n",
     "        \"litellm_params\": {\n",
-    "            \"model\": \"azure/gpt-4-1106-preview\",\n",
+    "            \"model\": \"azure/gpt-35-turbo\",\n",
     "            \"api_key\": \"<your-api-key>\",\n",
     "            \"api_version\": \"2023-05-15\",\n",
     "            \"api_base\": \"https://<your-endpoint>.openai.azure.com/\",\n",
     "        },\n",
     "    },\n",
     "]\n",
     "litellm_router = Router(model_list=model_list)\n",
-    "chat = ChatLiteLLMRouter(router=litellm_router)"
+    "chat = ChatLiteLLMRouter(router=litellm_router, model_name=\"gpt-35-turbo\")"
    ]
   },
   {
@@ -177,6 +177,7 @@
    "source": [
     "chat = ChatLiteLLMRouter(\n",
     "    router=litellm_router,\n",
+    "    model_name=\"gpt-35-turbo\",\n",
     "    streaming=True,\n",
     "    verbose=True,\n",
     "    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n",
@@ -209,7 +210,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.13"
+   "version": "3.11.9"
   }
  },
  "nbformat": 4,
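Continuing the sketch above, the notebook's second hunk wires the same `model_name` into a streaming configuration; the snippet below mirrors it, reusing the placeholder `litellm_router` from the previous snippet:

from langchain_community.chat_models import ChatLiteLLMRouter
from langchain_core.callbacks import CallbackManager, StreamingStdOutCallbackHandler
from langchain_core.messages import HumanMessage

chat = ChatLiteLLMRouter(
    router=litellm_router,  # placeholder router from the previous snippet
    model_name="gpt-35-turbo",
    streaming=True,
    verbose=True,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
)
chat.invoke([HumanMessage(content="Write me a song about sparkling water.")])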
59 changes: 25 additions & 34 deletions libs/community/langchain_community/chat_models/litellm_router.py
@@ -1,13 +1,6 @@
 """LiteLLM Router as LangChain Model."""

-from typing import (
-    Any,
-    AsyncIterator,
-    Iterator,
-    List,
-    Mapping,
-    Optional,
-)
+from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional

 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -17,24 +10,17 @@
     agenerate_from_stream,
     generate_from_stream,
 )
-from langchain_core.messages import (
-    AIMessageChunk,
-    BaseMessage,
-)
-from langchain_core.outputs import (
-    ChatGeneration,
-    ChatGenerationChunk,
-    ChatResult,
-)
+from langchain_core.messages import AIMessageChunk, BaseMessage
+from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

 from langchain_community.chat_models.litellm import (
     ChatLiteLLM,
     _convert_delta_to_message_chunk,
     _convert_dict_to_message,
 )

-token_usage_key_name = "token_usage"
-model_extra_key_name = "model_extra"
+token_usage_key_name = "token_usage"  # nosec # incorrectly flagged as password
+model_extra_key_name = "model_extra"  # nosec # incorrectly flagged as password


 def get_llm_output(usage: Any, **params: Any) -> dict:
@@ -56,21 +42,14 @@ class ChatLiteLLMRouter(ChatLiteLLM):

     def __init__(self, *, router: Any, **kwargs: Any) -> None:
         """Construct Chat LiteLLM Router."""
-        super().__init__(**kwargs)
+        super().__init__(router=router, **kwargs)  # type: ignore
         self.router = router

     @property
     def _llm_type(self) -> str:
         return "LiteLLMRouter"

-    def _set_model_for_completion(self) -> None:
-        # use first model name (aka: model group),
-        # since we can only pass one to the router completion functions
-        self.model = self.router.model_list[0]["model_name"]
-
     def _prepare_params_for_router(self, params: Any) -> None:
-        params["model"] = self.model
-
         # allow the router to set api_base based on its model choice
         api_base_key_name = "api_base"
         if api_base_key_name in params and params[api_base_key_name] is None:
@@ -79,6 +58,22 @@ def _prepare_params_for_router(self, params: Any) -> None:
         # add metadata so router can fill it below
         params.setdefault("metadata", {})

+    def set_default_model(self, model_name: str) -> None:
+        """Set the default model to use for completion calls.
+
+        Sets `self.model` to `model_name` if it is in the litellm router's
+        (`self.router`) model list. This provides the default model to use
+        for completion calls if no `model` kwarg is provided.
+        """
+        model_list = self.router.model_list
+        if not model_list:
+            raise ValueError("model_list is None or empty.")
+        for entry in model_list:
+            if entry["model_name"] == model_name:
+                self.model = model_name
+                return
+        raise ValueError(f"Model {model_name} not found in model_list.")
+
     def _generate(
         self,
         messages: List[BaseMessage],
@@ -96,7 +91,6 @@ def _generate(

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         response = self.router.completion(
@@ -115,7 +109,6 @@ def _stream(
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         for chunk in self.router.completion(messages=message_dicts, **params):
@@ -139,7 +132,6 @@ async def _astream(
         default_chunk_class = AIMessageChunk
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs, "stream": True}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         async for chunk in await self.router.acompletion(
@@ -174,7 +166,6 @@ async def _agenerate(

         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
-        self._set_model_for_completion()
         self._prepare_params_for_router(params)

         response = await self.router.acompletion(
@@ -196,14 +187,14 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
             token_usage = output["token_usage"]
             if token_usage is not None:
                 # get dict from LiteLLM Usage class
-                for k, v in token_usage.dict().items():
-                    if k in overall_token_usage:
+                for k, v in token_usage.model_dump().items():
+                    if k in overall_token_usage and overall_token_usage[k] is not None:
                         overall_token_usage[k] += v
                     else:
                         overall_token_usage[k] = v
             if system_fingerprint is None:
                 system_fingerprint = output.get("system_fingerprint")
-        combined = {"token_usage": overall_token_usage, "model_name": self.model_name}
+        combined = {"token_usage": overall_token_usage, "model_name": self.model}
         if system_fingerprint:
             combined["system_fingerprint"] = system_fingerprint
         return combined
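A short sketch of how the new set_default_model helper behaves, reusing the placeholder `litellm_router` from the first snippet (the second, unregistered model name is purely illustrative):

from langchain_community.chat_models import ChatLiteLLMRouter

chat = ChatLiteLLMRouter(router=litellm_router, model_name="gpt-35-turbo")

# Switch the default model group; the name must exist in router.model_list.
chat.set_default_model("gpt-35-turbo")

try:
    chat.set_default_model("gpt-4o")  # not registered in model_list above
except ValueError as err:
    print(err)  # Model gpt-4o not found in model_list.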