Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis committed Feb 23, 2024
1 parent d900521 commit a325132
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 23 deletions.
10 changes: 8 additions & 2 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def _stream(
params = self._format_params(messages=messages, stop=stop, **kwargs)
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk

async def _astream(
self,
Expand All @@ -172,7 +175,10 @@ async def _astream(
params = self._format_params(messages=messages, stop=stop, **kwargs)
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
await run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk

def _generate(
self,
Expand Down
34 changes: 13 additions & 21 deletions libs/partners/anthropic/tests/integration_tests/test_chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@

from typing import List

import pytest
from langchain_core.callbacks import CallbackManager
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_core.prompts import ChatPromptTemplate

from langchain_anthropic.chat_models import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic import ChatAnthropic, ChatAnthropicMessages
from tests.unit_tests._utils import FakeCallbackHandler


Expand Down Expand Up @@ -94,7 +93,6 @@ def test_system_invoke() -> None:
assert isinstance(result.content, str)


@pytest.mark.scheduled
def test_anthropic_call() -> None:
"""Test valid call to anthropic."""
chat = ChatAnthropic(model="test")
Expand All @@ -104,7 +102,6 @@ def test_anthropic_call() -> None:
assert isinstance(response.content, str)


@pytest.mark.scheduled
def test_anthropic_generate() -> None:
"""Test generate method of anthropic."""
chat = ChatAnthropic(model="test")
Expand All @@ -121,50 +118,45 @@ def test_anthropic_generate() -> None:
assert chat_messages == messages_copy


@pytest.mark.scheduled
def test_anthropic_streaming() -> None:
"""Test streaming tokens from anthropic."""
chat = ChatAnthropic(model="test", streaming=True)
chat = ChatAnthropic(model="test")
message = HumanMessage(content="Hello")
response = chat([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
response = chat.stream([message])
for token in response:
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)


@pytest.mark.scheduled
def test_anthropic_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
model="test",
streaming=True,
callback_manager=callback_manager,
verbose=True,
)
message = HumanMessage(content="Write me a sentence with 10 words.")
chat([message])
for token in chat.stream([message]):
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
assert callback_handler.llm_streams > 1


@pytest.mark.scheduled
async def test_anthropic_async_streaming_callback() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
chat = ChatAnthropic(
model="test",
streaming=True,
callback_manager=callback_manager,
verbose=True,
)
chat_messages: List[BaseMessage] = [
HumanMessage(content="How many toes do dogs have?")
]
result: LLMResult = await chat.agenerate([chat_messages])
async for token in chat.astream(chat_messages):
assert isinstance(token, AIMessageChunk)
assert isinstance(token.content, str)
assert callback_handler.llm_streams > 1
assert isinstance(result, LLMResult)
for response in result.generations[0]:
assert isinstance(response, ChatGeneration)
assert isinstance(response.text, str)
assert response.text == response.message.content

0 comments on commit a325132

Please sign in to comment.