From cc8feb06e889916d4067bcb4a21ead1185cf90c2 Mon Sep 17 00:00:00 2001
From: William Webber
Date: Tue, 7 Nov 2023 12:46:39 +1100
Subject: [PATCH 1/5] Add "streaming" parameter to ChatFireworks

---
 .../langchain/chat_models/fireworks.py | 30 +++++++-
 .../chat_models/test_fireworks.py      | 77 ++++++++++++++++++-
 2 files changed, 105 insertions(+), 2 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index 36a7d582369b8..d3e5acbf0cea1 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -15,7 +15,11 @@
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
 )
-from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.base import (
+    BaseChatModel,
+    _agenerate_from_stream,
+    _generate_from_stream
+)
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, SecretStr, root_validator
 from langchain.schema.messages import (
@@ -91,6 +95,10 @@ class ChatFireworks(BaseChatModel):
     fireworks_api_key: Optional[SecretStr] = None
     max_retries: int = 20
     use_retry: bool = True
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    n: int = 1
+    """Number of chat completions to generate for each prompt."""
 
     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -114,6 +122,10 @@ def validate_environment(cls, values: Dict) -> Dict:
             get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
         )
         fireworks.client.api_key = fireworks_api_key.get_secret_value()
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
         return values
 
     @property
@@ -126,8 +138,16 @@ def _generate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._stream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return _generate_from_stream(stream_iter)
+
         message_dicts = self._create_message_dicts(messages)
 
         params = {
@@ -149,9 +169,17 @@ async def _agenerate(
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
+        should_stream = stream if stream is not None else self.streaming
+        if should_stream:
+            stream_iter = self._astream(
+                messages, stop=stop, run_manager=run_manager, **kwargs
+            )
+            return await _agenerate_from_stream(stream_iter)
         message_dicts = self._create_message_dicts(messages)
+
         params = {
             "model": self.model,
             "messages": message_dicts,
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
index 43657cdae3c35..d8759cf163a19 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -1,13 +1,16 @@
 """Test ChatFireworks wrapper."""
 import sys
-from typing import cast
+from typing import cast, Any
 
 import pytest
 
+from langchain.callbacks.manager import CallbackManager
 from langchain.chat_models.fireworks import ChatFireworks
 from langchain.schema import ChatGeneration, ChatResult, LLMResult
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
 
+from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
+
 if sys.version_info < (3, 9):
     pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
 
@@ -72,6 +75,52 @@ def test_chat_fireworks_multiple_completions() -> None:
         assert isinstance(generation.message.content, str)
 
 
+@pytest.mark.scheduled
+def test_chat_fireworks_streaming() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    chat = ChatFireworks(
+        max_tokens=10,
+        streaming=True,
+        temperature=0,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = chat([message])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, BaseMessage)
+
+
+@pytest.mark.scheduled
+def test_chat_fireworks_streaming_generation_info() -> None:
+    """Test that generation info is preserved when streaming."""
+
+    class _FakeCallback(FakeCallbackHandler):
+        saved_things: dict = {}
+
+        def on_llm_end(
+            self,
+            *args: Any,
+            **kwargs: Any,
+        ) -> Any:
+            # Save the generation
+            self.saved_things["generation"] = args[0]
+
+    callback = _FakeCallback()
+    callback_manager = CallbackManager([callback])
+    chat = ChatFireworks(
+        max_tokens=2,
+        temperature=0,
+        callback_manager=callback_manager,
+    )
+    list(chat.stream("say 'Hello!' only"))
+    generation = callback.saved_things["generation"]
+    # `Hello!` is two tokens, assert that that is what is returned
+    assert generation.generations[0][0].text == "Hello!"
+
+
 @pytest.mark.scheduled
 def test_chat_fireworks_llm_output_contains_model_id(chat: ChatFireworks) -> None:
     """Test llm_output contains model_id."""
@@ -98,6 +147,32 @@ async def test_fireworks_ainvoke(chat: ChatFireworks) -> None:
     assert result.content[-1] == ","
 
 
+@pytest.mark.scheduled
+@pytest.mark.asyncio
+async def test_async_chat_fireworks_streaming() -> None:
+    """Test that streaming correctly invokes on_llm_new_token callback."""
+    callback_handler = FakeCallbackHandler()
+    callback_manager = CallbackManager([callback_handler])
+    chat = ChatFireworks(
+        max_tokens=10,
+        streaming=True,
+        temperature=0,
+        callback_manager=callback_manager,
+        verbose=True,
+    )
+    message = HumanMessage(content="Hello")
+    response = await chat.agenerate([[message], [message]])
+    assert callback_handler.llm_streams > 0
+    assert isinstance(response, LLMResult)
+    assert len(response.generations) == 2
+    for generations in response.generations:
+        assert len(generations) == 1
+        for generation in generations:
+            assert isinstance(generation, ChatGeneration)
+            assert isinstance(generation.text, str)
+            assert generation.text == generation.message.content
+
+
 @pytest.mark.scheduled
 def test_fireworks_batch(chat: ChatFireworks) -> None:
     """Test batch tokens from ChatFireworks."""

From c22bb98eab0ae08b2791a21f423820fcd001246d Mon Sep 17 00:00:00 2001
From: William Webber
Date: Tue, 7 Nov 2023 15:21:08 +1100
Subject: [PATCH 2/5] Revert "n" as field; requires full kwargs handling

---
 libs/langchain/langchain/chat_models/fireworks.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index d3e5acbf0cea1..3557c742bf60e 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -97,8 +97,6 @@ class ChatFireworks(BaseChatModel):
     use_retry: bool = True
     streaming: bool = False
     """Whether to stream the results or not."""
-    n: int = 1
-    """Number of chat completions to generate for each prompt."""
 
     @property
     def lc_secrets(self) -> Dict[str, str]:
@@ -122,10 +120,6 @@ def validate_environment(cls, values: Dict) -> Dict:
             get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
         )
         fireworks.client.api_key = fireworks_api_key.get_secret_value()
-        if values["n"] < 1:
-            raise ValueError("n must be at least 1.")
-        if values["n"] > 1 and values["streaming"]:
-            raise ValueError("n must be 1 when streaming.")
         return values
 
     @property

From 4c767f8939c99f442e9e7f4cad058820410bc93e Mon Sep 17 00:00:00 2001
From: wewebber-merlin <138414820+wewebber-merlin@users.noreply.github.com>
Date: Tue, 7 Nov 2023 15:35:24 +1100
Subject: [PATCH 3/5] Update test_fireworks.py

Fix ordering of import.
---
 .../tests/integration_tests/chat_models/test_fireworks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
index d8759cf163a19..71a195b32585f 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -1,6 +1,6 @@
 """Test ChatFireworks wrapper."""
 import sys
-from typing import cast, Any
+from typing import Any, cast
 
 import pytest
 

From 9e535cb2f1f92ff4df89ecea011e62ee10b63560 Mon Sep 17 00:00:00 2001
From: wewebber-merlin <138414820+wewebber-merlin@users.noreply.github.com>
Date: Wed, 8 Nov 2023 09:17:28 +1100
Subject: [PATCH 4/5] Formatting in import

---
 libs/langchain/langchain/chat_models/fireworks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index 3557c742bf60e..7a6687a51a3ee 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -18,7 +18,7 @@
 from langchain.chat_models.base import (
     BaseChatModel,
     _agenerate_from_stream,
-    _generate_from_stream
+    _generate_from_stream,
 )
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, SecretStr, root_validator

From a393a6df07b9438b2ae1a8c71ed778cb000de24c Mon Sep 17 00:00:00 2001
From: William Webber
Date: Wed, 8 Nov 2023 09:25:00 +1100
Subject: [PATCH 5/5] Formatting issues.

---
 libs/langchain/langchain/chat_models/fireworks.py          | 2 +-
 .../tests/integration_tests/chat_models/test_fireworks.py  | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/libs/langchain/langchain/chat_models/fireworks.py b/libs/langchain/langchain/chat_models/fireworks.py
index 3557c742bf60e..7a6687a51a3ee 100644
--- a/libs/langchain/langchain/chat_models/fireworks.py
+++ b/libs/langchain/langchain/chat_models/fireworks.py
@@ -18,7 +18,7 @@
 from langchain.chat_models.base import (
     BaseChatModel,
     _agenerate_from_stream,
-    _generate_from_stream
+    _generate_from_stream,
 )
 from langchain.llms.base import create_base_retry_decorator
 from langchain.pydantic_v1 import Field, SecretStr, root_validator
diff --git a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
index d8759cf163a19..b12726b513add 100644
--- a/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
+++ b/libs/langchain/tests/integration_tests/chat_models/test_fireworks.py
@@ -1,6 +1,6 @@
 """Test ChatFireworks wrapper."""
 import sys
-from typing import cast, Any
+from typing import Any, cast
 
 import pytest
 
@@ -8,7 +8,6 @@
 from langchain.chat_models.fireworks import ChatFireworks
 from langchain.schema import ChatGeneration, ChatResult, LLMResult
 from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
-
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 if sys.version_info < (3, 9):
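
A minimal usage sketch of the streaming behavior added in PATCH 1/5. The
model id below is an assumption for illustration (any chat model enabled for
your Fireworks account works), and FIREWORKS_API_KEY is assumed to be set in
the environment:

    from langchain.chat_models.fireworks import ChatFireworks
    from langchain.schema.messages import HumanMessage

    # streaming=True routes _generate through _generate_from_stream, so
    # callbacks receive on_llm_new_token events even on non-stream calls.
    chat = ChatFireworks(
        model="accounts/fireworks/models/llama-v2-7b-chat",  # assumed model id
        streaming=True,
        max_tokens=16,
    )

    # Token-by-token consumption via the public stream() API, as exercised
    # by the integration tests above.
    for chunk in chat.stream("Count to five."):
        print(chunk.content, end="", flush=True)

    # A plain call still returns a complete message; with streaming=True the
    # result is assembled internally from the streamed chunks.
    message = chat([HumanMessage(content="Hello")])
    print(message.content)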