From 27a905672595f01a2321749194f77477071a74c2 Mon Sep 17 00:00:00 2001 From: Brian Burgin Date: Mon, 16 Dec 2024 17:17:39 -0600 Subject: [PATCH] community: Fix ChatLiteLLMRouter runtime issues (#28163) **Description:** Fix ChatLiteLLMRouter ctor validation and model_name parameter **Issue:** #19356, #27455, #28077 **Twitter handle:** @bburgin_0 --- .../integrations/chat/litellm_router.ipynb | 9 +- .../chat_models/litellm_router.py | 59 ++++----- .../chat_models/test_litellm_router.py | 118 +++++++++++++----- 3 files changed, 115 insertions(+), 71 deletions(-) diff --git a/docs/docs/integrations/chat/litellm_router.ipynb b/docs/docs/integrations/chat/litellm_router.ipynb index 7a7d0fd1218c8..af56c657547a0 100644 --- a/docs/docs/integrations/chat/litellm_router.ipynb +++ b/docs/docs/integrations/chat/litellm_router.ipynb @@ -63,9 +63,9 @@ " },\n", " },\n", " {\n", - " \"model_name\": \"gpt-4\",\n", + " \"model_name\": \"gpt-35-turbo\",\n", " \"litellm_params\": {\n", - " \"model\": \"azure/gpt-4-1106-preview\",\n", + " \"model\": \"azure/gpt-35-turbo\",\n", " \"api_key\": \"\",\n", " \"api_version\": \"2023-05-15\",\n", " \"api_base\": \"https://.openai.azure.com/\",\n", @@ -73,7 +73,7 @@ " },\n", "]\n", "litellm_router = Router(model_list=model_list)\n", - "chat = ChatLiteLLMRouter(router=litellm_router)" + "chat = ChatLiteLLMRouter(router=litellm_router, model_name=\"gpt-35-turbo\")" ] }, { @@ -177,6 +177,7 @@ "source": [ "chat = ChatLiteLLMRouter(\n", " router=litellm_router,\n", + " model_name=\"gpt-35-turbo\",\n", " streaming=True,\n", " verbose=True,\n", " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),\n", @@ -209,7 +210,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.13" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/libs/community/langchain_community/chat_models/litellm_router.py b/libs/community/langchain_community/chat_models/litellm_router.py index d1932113cde9c..ee3dee32e1b97 100644 --- a/libs/community/langchain_community/chat_models/litellm_router.py +++ b/libs/community/langchain_community/chat_models/litellm_router.py @@ -1,13 +1,6 @@ """LiteLLM Router as LangChain Model.""" -from typing import ( - Any, - AsyncIterator, - Iterator, - List, - Mapping, - Optional, -) +from typing import Any, AsyncIterator, Iterator, List, Mapping, Optional from langchain_core.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -17,15 +10,8 @@ agenerate_from_stream, generate_from_stream, ) -from langchain_core.messages import ( - AIMessageChunk, - BaseMessage, -) -from langchain_core.outputs import ( - ChatGeneration, - ChatGenerationChunk, - ChatResult, -) +from langchain_core.messages import AIMessageChunk, BaseMessage +from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_community.chat_models.litellm import ( ChatLiteLLM, @@ -33,8 +19,8 @@ _convert_dict_to_message, ) -token_usage_key_name = "token_usage" -model_extra_key_name = "model_extra" +token_usage_key_name = "token_usage" # nosec # incorrectly flagged as password +model_extra_key_name = "model_extra" # nosec # incorrectly flagged as password def get_llm_output(usage: Any, **params: Any) -> dict: @@ -56,21 +42,14 @@ class ChatLiteLLMRouter(ChatLiteLLM): def __init__(self, *, router: Any, **kwargs: Any) -> None: """Construct Chat LiteLLM Router.""" - super().__init__(**kwargs) + super().__init__(router=router, **kwargs) # type: ignore self.router = router @property def _llm_type(self) -> str: return "LiteLLMRouter" - def _set_model_for_completion(self) -> None: - # use first model name (aka: model group), - # since we can only pass one to the router completion functions - self.model = self.router.model_list[0]["model_name"] - def _prepare_params_for_router(self, params: Any) -> None: - params["model"] = self.model - # allow the router to set api_base based on its model choice api_base_key_name = "api_base" if api_base_key_name in params and params[api_base_key_name] is None: @@ -79,6 +58,22 @@ def _prepare_params_for_router(self, params: Any) -> None: # add metadata so router can fill it below params.setdefault("metadata", {}) + def set_default_model(self, model_name: str) -> None: + """Set the default model to use for completion calls. + + Sets `self.model` to `model_name` if it is in the litellm router's + (`self.router`) model list. This provides the default model to use + for completion calls if no `model` kwarg is provided. + """ + model_list = self.router.model_list + if not model_list: + raise ValueError("model_list is None or empty.") + for entry in model_list: + if entry["model_name"] == model_name: + self.model = model_name + return + raise ValueError(f"Model {model_name} not found in model_list.") + def _generate( self, messages: List[BaseMessage], @@ -96,7 +91,6 @@ def _generate( message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} - self._set_model_for_completion() self._prepare_params_for_router(params) response = self.router.completion( @@ -115,7 +109,6 @@ def _stream( default_chunk_class = AIMessageChunk message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} - self._set_model_for_completion() self._prepare_params_for_router(params) for chunk in self.router.completion(messages=message_dicts, **params): @@ -139,7 +132,6 @@ async def _astream( default_chunk_class = AIMessageChunk message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs, "stream": True} - self._set_model_for_completion() self._prepare_params_for_router(params) async for chunk in await self.router.acompletion( @@ -174,7 +166,6 @@ async def _agenerate( message_dicts, params = self._create_message_dicts(messages, stop) params = {**params, **kwargs} - self._set_model_for_completion() self._prepare_params_for_router(params) response = await self.router.acompletion( @@ -196,14 +187,14 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict: token_usage = output["token_usage"] if token_usage is not None: # get dict from LiteLLM Usage class - for k, v in token_usage.dict().items(): - if k in overall_token_usage: + for k, v in token_usage.model_dump().items(): + if k in overall_token_usage and overall_token_usage[k] is not None: overall_token_usage[k] += v else: overall_token_usage[k] = v if system_fingerprint is None: system_fingerprint = output.get("system_fingerprint") - combined = {"token_usage": overall_token_usage, "model_name": self.model_name} + combined = {"token_usage": overall_token_usage, "model_name": self.model} if system_fingerprint: combined["system_fingerprint"] = system_fingerprint return combined diff --git a/libs/community/tests/integration_tests/chat_models/test_litellm_router.py b/libs/community/tests/integration_tests/chat_models/test_litellm_router.py index 23f60398d6f74..c2d8ce85e0156 100644 --- a/libs/community/tests/integration_tests/chat_models/test_litellm_router.py +++ b/libs/community/tests/integration_tests/chat_models/test_litellm_router.py @@ -1,8 +1,20 @@ """Test LiteLLM Router API wrapper.""" import asyncio +import queue +import threading from copy import deepcopy -from typing import Any, AsyncGenerator, Coroutine, Dict, List, Tuple, Union, cast +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Dict, + Generator, + List, + Tuple, + Union, + cast, +) import pytest from langchain_core.messages import AIMessage, BaseMessage, HumanMessage @@ -11,7 +23,8 @@ from langchain_community.chat_models.litellm_router import ChatLiteLLMRouter from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler -model_group = "gpt-4" +model_group_gpt4 = "gpt-4" +model_group_to_test = "gpt-35-turbo" fake_model_prefix = "azure/fake-deployment-name-" fake_models_names = [fake_model_prefix + suffix for suffix in ["1", "2"]] fake_api_key = "fakekeyvalue" @@ -23,7 +36,7 @@ model_list = [ { - "model_name": model_group, + "model_name": model_group_gpt4, "litellm_params": { "model": fake_models_names[0], "api_key": fake_api_key, @@ -32,7 +45,7 @@ }, }, { - "model_name": model_group, + "model_name": model_group_to_test, "litellm_params": { "model": fake_models_names[1], "api_key": fake_api_key, @@ -43,6 +56,39 @@ ] +# from https://stackoverflow.com/a/78573267 +def aiter_to_iter(it: AsyncIterator) -> Generator: + "Convert an async iterator into a regular (sync) iterator." + q_in: queue.SimpleQueue = queue.SimpleQueue() + q_out: queue.SimpleQueue = queue.SimpleQueue() + + async def threadmain() -> None: + try: + # Wait until the sync generator requests an item before continuing + while q_in.get(): + q_out.put((True, await it.__anext__())) + except StopAsyncIteration: + q_out.put((False, None)) + except BaseException as ex: + q_out.put((False, ex)) + + thread = threading.Thread(target=asyncio.run, args=(threadmain(),), daemon=True) + thread.start() + + try: + while True: + q_in.put(True) + cont, result = q_out.get() + if cont: + yield result + elif result is None: + break + else: + raise result + finally: + q_in.put(False) + + class FakeCompletion: def __init__(self) -> None: self.seen_inputs: List[Any] = [] @@ -55,13 +101,6 @@ def _get_new_result_and_choices( choices = cast(List[Dict[str, Any]], result["choices"]) return result, choices - @staticmethod - def _get_next_result( - agen: AsyncGenerator[Dict[str, Any], None], - ) -> Dict[str, Any]: - coroutine = cast(Coroutine, agen.__anext__()) - return asyncio.run(coroutine) - async def _get_fake_results_agenerator( self, **kwargs: Any ) -> AsyncGenerator[Dict[str, Any], None]: @@ -76,7 +115,7 @@ async def _get_fake_results_agenerator( ], "created": 0, "id": "", - "model": model_group, + "model": model_group_to_test, "object": "chat.completion", } if kwargs["stream"]: @@ -115,17 +154,18 @@ async def _get_fake_results_agenerator( def completion(self, **kwargs: Any) -> Union[List, Dict[str, Any]]: agen = self._get_fake_results_agenerator(**kwargs) + synchronous_iter = aiter_to_iter(agen) if kwargs["stream"]: results: List[Dict[str, Any]] = [] while True: try: - results.append(self._get_next_result(agen)) - except StopAsyncIteration: + results.append(synchronous_iter.__next__()) + except StopIteration: break return results else: # there is only one result for non-streaming - return self._get_next_result(agen) + return synchronous_iter.__next__() async def acompletion( self, **kwargs: Any @@ -142,7 +182,7 @@ def check_inputs(self, expected_num_calls: int) -> None: for kwargs in self.seen_inputs: metadata = kwargs["metadata"] - assert metadata["model_group"] == model_group + assert metadata["model_group"] == model_group_to_test # LiteLLM router chooses one model name from the model_list assert kwargs["model"] in fake_models_names @@ -172,17 +212,16 @@ def litellm_router() -> Any: """LiteLLM router for testing.""" from litellm import Router - return Router( - model_list=model_list, - ) + return Router(model_list=model_list) @pytest.mark.scheduled +@pytest.mark.enable_socket def test_litellm_router_call( fake_completion: FakeCompletion, litellm_router: Any ) -> None: """Test valid call to LiteLLM Router.""" - chat = ChatLiteLLMRouter(router=litellm_router) + chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test) message = HumanMessage(content="Hello") response = chat.invoke([message]) @@ -195,13 +234,12 @@ def test_litellm_router_call( @pytest.mark.scheduled +@pytest.mark.enable_socket def test_litellm_router_generate( fake_completion: FakeCompletion, litellm_router: Any ) -> None: """Test generate method of LiteLLM Router.""" - from litellm import Usage - - chat = ChatLiteLLMRouter(router=litellm_router) + chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test) chat_messages: List[List[BaseMessage]] = [ [HumanMessage(content="How many toes do dogs have?")] ] @@ -219,18 +257,25 @@ def test_litellm_router_generate( assert generation.message.content == fake_answer assert chat_messages == messages_copy assert result.llm_output is not None - assert result.llm_output[token_usage_key_name] == Usage( - completion_tokens=1, prompt_tokens=2, total_tokens=3 - ) + assert result.llm_output[token_usage_key_name] == { + "completion_tokens": 1, + "completion_tokens_details": None, + "prompt_tokens": 2, + "prompt_tokens_details": None, + "total_tokens": 3, + } fake_completion.check_inputs(expected_num_calls=1) @pytest.mark.scheduled +@pytest.mark.enable_socket def test_litellm_router_streaming( fake_completion: FakeCompletion, litellm_router: Any ) -> None: """Test streaming tokens from LiteLLM Router.""" - chat = ChatLiteLLMRouter(router=litellm_router, streaming=True) + chat = ChatLiteLLMRouter( + router=litellm_router, model_name=model_group_to_test, streaming=True + ) message = HumanMessage(content="Hello") response = chat.invoke([message]) @@ -243,6 +288,7 @@ def test_litellm_router_streaming( @pytest.mark.scheduled +@pytest.mark.enable_socket def test_litellm_router_streaming_callback( fake_completion: FakeCompletion, litellm_router: Any ) -> None: @@ -250,6 +296,7 @@ def test_litellm_router_streaming_callback( callback_handler = FakeCallbackHandler() chat = ChatLiteLLMRouter( router=litellm_router, + model_name=model_group_to_test, streaming=True, callbacks=[callback_handler], verbose=True, @@ -267,13 +314,12 @@ def test_litellm_router_streaming_callback( @pytest.mark.scheduled +@pytest.mark.enable_socket async def test_async_litellm_router( fake_completion: FakeCompletion, litellm_router: Any ) -> None: """Test async generation.""" - from litellm import Usage - - chat = ChatLiteLLMRouter(router=litellm_router) + chat = ChatLiteLLMRouter(router=litellm_router, model_name=model_group_to_test) message = HumanMessage(content="Hello") response = await chat.agenerate([[message], [message]]) @@ -288,13 +334,18 @@ async def test_async_litellm_router( assert generation.message.content == generation.text assert generation.message.content == fake_answer assert response.llm_output is not None - assert response.llm_output[token_usage_key_name] == Usage( - completion_tokens=2, prompt_tokens=4, total_tokens=6 - ) + assert response.llm_output[token_usage_key_name] == { + "completion_tokens": 2, + "completion_tokens_details": None, + "prompt_tokens": 4, + "prompt_tokens_details": None, + "total_tokens": 6, + } fake_completion.check_inputs(expected_num_calls=2) @pytest.mark.scheduled +@pytest.mark.enable_socket async def test_async_litellm_router_streaming( fake_completion: FakeCompletion, litellm_router: Any ) -> None: @@ -302,6 +353,7 @@ async def test_async_litellm_router_streaming( callback_handler = FakeCallbackHandler() chat = ChatLiteLLMRouter( router=litellm_router, + model_name=model_group_to_test, streaming=True, callbacks=[callback_handler], verbose=True,