Skip to content

Commit

Permalink
docs: fixed documentation mistakes
Browse files Browse the repository at this point in the history
  • Loading branch information
yanomaly committed Nov 5, 2024
1 parent ff35152 commit 212c6d1
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 59 deletions.
34 changes: 29 additions & 5 deletions docs/docs/integrations/chat/writer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,7 @@
"cell_type": "code",
"id": "2113471c-75d7-45df-b784-d78da4ef7aba",
"metadata": {},
"source": [
"%pip install -qU langchain-community writer-sdk"
],
"source": "%pip install -qU langchain-community writer-sdk",
"outputs": [],
"execution_count": null
},
Expand All @@ -102,13 +100,14 @@
},
"source": [
"from langchain_community.chat_models.writer import ChatWriter\n",
"from writerai import AsyncWriter, Writer\n",
"\n",
"llm = ChatWriter(\n",
" client=Writer(),\n",
" async_client=AsyncWriter(),\n",
" model=\"palmyra-x-004\",\n",
" temperature=0.7,\n",
" max_tokens=1000,\n",
"    # api_key=\"...\", # if you prefer to pass api key in directly instead of using env vars\n",
" # base_url=\"...\",\n",
" # other params...\n",
")"
],
Expand Down Expand Up @@ -152,6 +151,31 @@
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## Streaming",
"id": "35b3a5b3dabef65"
},
{
"metadata": {},
"cell_type": "code",
"source": "ai_stream = llm.stream(messages)",
"id": "2725770182bf96dc",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"for chunk in ai_stream:\n",
" print(chunk.content, end=\"\")"
],
"id": "a48410d9488162e3",
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "778f912a-66ea-4a5d-b3de-6c7db4baba26",
Expand Down
14 changes: 11 additions & 3 deletions libs/community/langchain_community/chat_models/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,20 @@ class ChatWriter(BaseChatModel):
.. code-block:: python
from langchain_community.chat_models import ChatWriter
from writerai import Writer, AsyncWriter
chat = ChatWriter(model="palmyra-x-004")
client = Writer()
async_client = AsyncWriter()
chat = ChatWriter(
client=client,
async_client=async_client,
model="palmyra-x-004"
)
"""

client: Any = Field(default=None, exclude=True) #: :meta private:
async_client: Any = Field(default=None, exclude=True) #: :meta private:
client: Any = Field(exclude=True) #: :meta private:
async_client: Any = Field(exclude=True) #: :meta private:
model_name: str = Field(default="palmyra-x-004", alias="model")
"""Model name to use."""
temperature: float = 0.7
Expand Down
114 changes: 63 additions & 51 deletions libs/community/tests/unit_tests/chat_models/test_writer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Unit tests for Writer chat model integration."""

import json
from typing import Any, Dict, List, Optional
from unittest.mock import AsyncMock, MagicMock, patch
from unittest.mock import AsyncMock, MagicMock

import pytest
from langchain_core.callbacks.manager import CallbackManager
Expand All @@ -11,6 +9,8 @@
from langchain_community.chat_models.writer import ChatWriter
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler

"""Classes for mocking Writer responses."""


class ChoiceDelta:
def __init__(self, content: str):
Expand Down Expand Up @@ -104,16 +104,33 @@ def __init__(
self.choices = choices


"""Unit tests for Writer chat model integration."""


class TestChatWriter:
def test_writer_model_param(self) -> None:
"""Test different ways to initialize the chat model."""
test_cases: List[dict] = [
{"model_name": "palmyra-x-004"},
{"model": "palmyra-x-004"},
{"model_name": "palmyra-x-004"},
{
"model_name": "palmyra-x-004",
"client": MagicMock(),
"async_client": AsyncMock(),
},
{
"model": "palmyra-x-004",
"client": MagicMock(),
"async_client": AsyncMock(),
},
{
"model_name": "palmyra-x-004",
"client": MagicMock(),
"async_client": AsyncMock(),
},
{
"model": "palmyra-x-004",
"temperature": 0.5,
"client": MagicMock(),
"async_client": AsyncMock(),
},
]

Expand Down Expand Up @@ -183,7 +200,6 @@ def test_convert_writer_to_langchain_with_tool_calls(self) -> None:
@pytest.fixture(autouse=True)
def mock_unstreaming_completion(self) -> Chat:
"""Fixture providing a mock API response."""

return Chat(
id="chat-12345",
object="chat.completion",
Expand Down Expand Up @@ -270,29 +286,29 @@ def test_sync_completion(
self, mock_unstreaming_completion: List[ChatCompletionChunk]
) -> None:
"""Test basic chat completion with mocked response."""
chat = ChatWriter()
mock_client = MagicMock()
mock_client.chat.chat.return_value = mock_unstreaming_completion

with patch.object(chat, "client", mock_client):
message = HumanMessage(content="Hi there!")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert response.content == "Hello! How can I help you?"
chat = ChatWriter(client=mock_client, async_client=AsyncMock())

message = HumanMessage(content="Hi there!")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert response.content == "Hello! How can I help you?"

async def test_async_completion(
self, mock_unstreaming_completion: List[ChatCompletionChunk]
) -> None:
"""Test async chat completion with mocked response."""
chat = ChatWriter()
mock_client = AsyncMock()
mock_client.chat.chat.return_value = mock_unstreaming_completion

with patch.object(chat, "async_client", mock_client):
message = HumanMessage(content="Hi there!")
response = await chat.ainvoke([message])
assert isinstance(response, AIMessage)
assert response.content == "Hello! How can I help you?"
chat = ChatWriter(client=MagicMock(), async_client=mock_client)

message = HumanMessage(content="Hi there!")
response = await chat.ainvoke([message])
assert isinstance(response, AIMessage)
assert response.content == "Hello! How can I help you?"

def test_sync_streaming(
self, mock_streaming_chunks: List[ChatCompletionChunk]
Expand All @@ -301,27 +317,25 @@ def test_sync_streaming(
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])

chat = ChatWriter(
callback_manager=callback_manager,
max_tokens=10,
)

mock_client = MagicMock()
mock_response = MagicMock()
mock_response.__iter__.return_value = mock_streaming_chunks
mock_client.chat.chat.return_value = mock_response

with patch.object(chat, "client", mock_client):
message = HumanMessage(content="Hi")
response = chat.stream([message])

response_message = ""

for chunk in response:
response_message += str(chunk.content)
chat = ChatWriter(
client=mock_client,
async_client=AsyncMock(),
callback_manager=callback_manager,
max_tokens=10,
)

assert callback_handler.llm_streams > 0
assert response_message == "Hello! How can I help you?"
message = HumanMessage(content="Hi")
response = chat.stream([message])
response_message = ""
for chunk in response:
response_message += str(chunk.content)
assert callback_handler.llm_streams > 0
assert response_message == "Hello! How can I help you?"

async def test_async_streaming(
self, mock_streaming_chunks: List[ChatCompletionChunk]
Expand All @@ -330,27 +344,25 @@ async def test_async_streaming(
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])

chat = ChatWriter(
callback_manager=callback_manager,
max_tokens=10,
)

mock_client = AsyncMock()
mock_response = AsyncMock()
mock_response.__aiter__.return_value = mock_streaming_chunks
mock_client.chat.chat.return_value = mock_response

with patch.object(chat, "async_client", mock_client):
message = HumanMessage(content="Hi")
response = chat.astream([message])

response_message = ""

async for chunk in response:
response_message += str(chunk.content)
chat = ChatWriter(
client=MagicMock(),
async_client=mock_client,
callback_manager=callback_manager,
max_tokens=10,
)

assert callback_handler.llm_streams > 0
assert response_message == "Hello! How can I help you?"
message = HumanMessage(content="Hi")
response = chat.astream([message])
response_message = ""
async for chunk in response:
response_message += str(chunk.content)
assert callback_handler.llm_streams > 0
assert response_message == "Hello! How can I help you?"

def test_sync_tool_calling(
self, mock_tool_call_choice_response: Dict[str, Any]
Expand All @@ -366,7 +378,7 @@ class GetWeather(BaseModel):
mock_client = MagicMock()
mock_client.chat.chat.return_value = mock_tool_call_choice_response

chat = ChatWriter(client=mock_client)
chat = ChatWriter(client=mock_client, async_client=AsyncMock())

chat_with_tools = chat.bind_tools(
tools=[GetWeather],
Expand All @@ -393,7 +405,7 @@ class GetWeather(BaseModel):
mock_client = AsyncMock()
mock_client.chat.chat.return_value = mock_tool_call_choice_response

chat = ChatWriter(async_client=mock_client)
chat = ChatWriter(client=MagicMock(), async_client=mock_client)

chat_with_tools = chat.bind_tools(
tools=[GetWeather],
Expand Down

0 comments on commit 212c6d1

Please sign in to comment.