
Add #12978

Closed

24 changes: 23 additions & 1 deletion libs/langchain/langchain/chat_models/fireworks.py
@@ -15,7 +15,11 @@
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.chat_models.base import (
BaseChatModel,
_agenerate_from_stream,
_generate_from_stream,
)
from langchain.llms.base import create_base_retry_decorator
from langchain.pydantic_v1 import Field, SecretStr, root_validator
from langchain.schema.messages import (
@@ -91,6 +95,8 @@ class ChatFireworks(BaseChatModel):
fireworks_api_key: Optional[SecretStr] = None
max_retries: int = 20
use_retry: bool = True
streaming: bool = False
"""Whether to stream the results or not."""

@property
def lc_secrets(self) -> Dict[str, str]:
@@ -126,8 +132,16 @@ def _generate(
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return _generate_from_stream(stream_iter)

message_dicts = self._create_message_dicts(messages)

params = {
@@ -149,9 +163,17 @@ async def _agenerate(
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await _agenerate_from_stream(stream_iter)
message_dicts = self._create_message_dicts(messages)

params = {
"model": self.model,
"messages": message_dicts,
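Taken together, the fireworks.py changes above let a caller opt into streaming either for the whole instance (streaming=True) or per call (stream=True), with _generate and _agenerate delegating to the shared _generate_from_stream / _agenerate_from_stream helpers instead of making a single blocking request. A minimal usage sketch of the synchronous path, assuming the fireworks-ai package is installed and a Fireworks API key is configured (key handling itself is unchanged by this PR):

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema.messages import HumanMessage

# Stream every call made by this instance.
chat = ChatFireworks(streaming=True, max_tokens=10, temperature=0)

# stream() yields message chunks as they arrive; callbacks see one
# on_llm_new_token event per chunk.
for chunk in chat.stream("Say 'Hello!' only"):
    print(chunk.content, end="", flush=True)

# A plain call still returns a single message; because streaming=True,
# _generate now assembles it from the streamed chunks.
message = chat([HumanMessage(content="Hello")])

The second changed file below is the ChatFireworks test module, which exercises both the sync and async streaming paths.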
@@ -1,12 +1,14 @@
"""Test ChatFireworks wrapper."""
import sys
from typing import cast
from typing import Any, cast

import pytest

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema import ChatGeneration, ChatResult, LLMResult
from langchain.schema.messages import BaseMessage, HumanMessage, SystemMessage
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

if sys.version_info < (3, 9):
pytest.skip("fireworks-ai requires Python > 3.8", allow_module_level=True)
@@ -72,6 +74,52 @@ def test_chat_fireworks_multiple_completions() -> None:
assert isinstance(generation.message.content, str)


@pytest.mark.scheduled
def test_chat_fireworks_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatFireworks(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Hello")
response = chat([message])
assert callback_handler.llm_streams > 0
assert isinstance(response, BaseMessage)


@pytest.mark.scheduled
def test_chat_fireworks_streaming_generation_info() -> None:
"""Test that generation info is preserved when streaming."""

class _FakeCallback(FakeCallbackHandler):
saved_things: dict = {}

def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
# Save the generation
self.saved_things["generation"] = args[0]

callback = _FakeCallback()
callback_manager = CallbackManager([callback])
chat = ChatFireworks(
max_tokens=2,
temperature=0,
callback_manager=callback_manager,
)
list(chat.stream("say 'Hello!' only"))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens; assert that this is what is returned
assert generation.generations[0][0].text == "Hello!"


@pytest.mark.scheduled
def test_chat_fireworks_llm_output_contains_model_id(chat: ChatFireworks) -> None:
"""Test llm_output contains model_id."""
@@ -98,6 +146,32 @@ async def test_fireworks_ainvoke(chat: ChatFireworks) -> None:
assert result.content[-1] == ","


@pytest.mark.scheduled
@pytest.mark.asyncio
async def test_async_chat_fireworks_streaming() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatFireworks(
max_tokens=10,
streaming=True,
temperature=0,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Hello")
response = await chat.agenerate([[message], [message]])
assert callback_handler.llm_streams > 0
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content


@pytest.mark.scheduled
def test_fireworks_batch(chat: ChatFireworks) -> None:
"""Test batch tokens from ChatFireworks."""
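On the async side, _agenerate mirrors the sync logic: when should_stream is true it awaits _agenerate_from_stream over the _astream iterator, which is what test_async_chat_fireworks_streaming above relies on. A rough sketch of driving that path directly, under the same assumptions as the sync example (installed fireworks-ai, configured API key); the main helper and the asyncio.run entry point are illustrative, not part of the PR:

import asyncio

from langchain.chat_models.fireworks import ChatFireworks
from langchain.schema.messages import HumanMessage

async def main() -> None:
    chat = ChatFireworks(streaming=True, max_tokens=10, temperature=0)

    # agenerate() on a streaming-enabled instance goes through _astream and
    # _agenerate_from_stream, so callbacks still receive on_llm_new_token.
    result = await chat.agenerate([[HumanMessage(content="Hello")]])
    print(result.generations[0][0].text)

asyncio.run(main())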