diff --git a/libs/community/tests/unit_tests/chat_models/test_writer.py b/libs/community/tests/unit_tests/chat_models/test_writer.py index c30f4d3ed8b958..87596a0445d49b 100644 --- a/libs/community/tests/unit_tests/chat_models/test_writer.py +++ b/libs/community/tests/unit_tests/chat_models/test_writer.py @@ -13,68 +13,124 @@ from langchain_community.chat_models.writer import ChatWriter from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +"""Classes for mocking Writer responses.""" -@pytest.mark.requires("writerai") + +class ChoiceDelta: + def __init__(self, content: str): + self.content = content + + +class ChunkChoice: + def __init__(self, index: int, finish_reason: str, delta: ChoiceDelta): + self.index = index + self.finish_reason = finish_reason + self.delta = delta + + +class ChatCompletionChunk: + def __init__( + self, + id: str, + object: str, + created: int, + model: str, + choices: List[ChunkChoice], + ): + self.id = id + self.object = object + self.created = created + self.model = model + self.choices = choices + + +class ToolCallFunction: + def __init__(self, name: str, arguments: str): + self.name = name + self.arguments = arguments + + +class ChoiceMessageToolCall: + def __init__(self, id: str, type: str, function: ToolCallFunction): + self.id = id + self.type = type + self.function = function + + +class Usage: + def __init__( + self, + prompt_tokens: int, + completion_tokens: int, + total_tokens: int, + ): + self.prompt_tokens = prompt_tokens + self.completion_tokens = completion_tokens + self.total_tokens = total_tokens + + +class ChoiceMessage: + def __init__( + self, + role: str, + content: str, + tool_calls: Optional[List[ChoiceMessageToolCall]] = None, + ): + self.role = role + self.content = content + self.tool_calls = tool_calls + + +class Choice: + def __init__(self, index: int, finish_reason: str, message: ChoiceMessage): + self.index = index + self.finish_reason = finish_reason + self.message = message + + +class Chat: + def __init__( + self, + id: str, + object: str, + created: int, + system_fingerprint: str, + model: str, + usage: Usage, + choices: List[Choice], + ): + self.id = id + self.object = object + self.created = created + self.system_fingerprint = system_fingerprint + self.model = model + self.usage = usage + self.choices = choices + + +@pytest.mark.requires("writer-sdk") class TestChatWriterCustom: """Test case for ChatWriter""" - - from writerai.types import Chat - from writerai.types.chat import ( - Choice, - ChoiceLogprobs, - ChoiceLogprobsContent, - ChoiceLogprobsContentTopLogprob, - ChoiceLogprobsRefusal, - ChoiceLogprobsRefusalTopLogprob, - ChoiceMessage, - ChoiceMessageToolCall, - ChoiceMessageToolCallFunction, - Usage, - ) - from writerai.types.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta - from writerai.types.chat_completion_chunk import Choice as ChunkChoice + def test(self): + assert 1 == 2, "FAILED MANUALLY" @pytest.fixture(autouse=True) def mock_unstreaming_completion(self) -> Chat: """Fixture providing a mock API response.""" - return self.Chat( + return Chat( id="chat-12345", object="chat.completion", created=1699000000, model="palmyra-x-004", - usage=self.Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18), + system_fingerprint="v1", + usage=Usage(prompt_tokens=10, completion_tokens=8, total_tokens=18), choices=[ - self.Choice( + Choice( index=0, finish_reason="stop", - logprobs=self.ChoiceLogprobs( - content=[ - self.ChoiceLogprobsContent( - token="", - logprob=0, - top_logprobs=[ - self.ChoiceLogprobsContentTopLogprob( - token="", logprob=0 - ) - ], - ) - ], - refusal=[ - self.ChoiceLogprobsRefusal( - token="", - logprob=0, - top_logprobs=[ - self.ChoiceLogprobsRefusalTopLogprob( - token="", logprob=0 - ) - ], - ) - ], - ), - message=self.ChoiceMessage( + message=ChoiceMessage( role="assistant", content="Hello! How can I help you?", - refusal="", ), ) ], @@ -82,48 +138,25 @@ def mock_unstreaming_completion(self) -> Chat: @pytest.fixture(autouse=True) def mock_tool_call_choice_response(self) -> Chat: - return self.Chat( + return Chat( id="chat-12345", object="chat.completion", created=1699000000, model="palmyra-x-004", + system_fingerprint="v1", + usage=Usage(prompt_tokens=29, completion_tokens=32, total_tokens=61), choices=[ - self.Choice( + Choice( index=0, finish_reason="tool_calls", - logprobs=self.ChoiceLogprobs( - content=[ - self.ChoiceLogprobsContent( - token="", - logprob=0, - top_logprobs=[ - self.ChoiceLogprobsContentTopLogprob( - token="", logprob=0 - ) - ], - ) - ], - refusal=[ - self.ChoiceLogprobsRefusal( - token="", - logprob=0, - top_logprobs=[ - self.ChoiceLogprobsRefusalTopLogprob( - token="", logprob=0 - ) - ], - ) - ], - ), - message=self.ChoiceMessage( + message=ChoiceMessage( role="assistant", content="", - refusal="", tool_calls=[ - self.ChoiceMessageToolCall( + ChoiceMessageToolCall( id="call_abc123", type="function", - function=self.ChoiceMessageToolCallFunction( + function=ToolCallFunction( name="GetWeather", arguments='{"location": "London"}', ), @@ -138,29 +171,29 @@ def mock_tool_call_choice_response(self) -> Chat: def mock_streaming_chunks(self) -> List[ChatCompletionChunk]: """Fixture providing mock streaming response chunks.""" return [ - self.ChatCompletionChunk( + ChatCompletionChunk( id="chat-12345", object="chat.completion", created=1699000000, model="palmyra-x-004", choices=[ - self.ChunkChoice( + ChunkChoice( index=0, finish_reason="stop", - delta=self.ChoiceDelta(content="Hello! "), + delta=ChoiceDelta(content="Hello! "), ) ], ), - self.ChatCompletionChunk( + ChatCompletionChunk( id="chat-12345", object="chat.completion", created=1699000000, model="palmyra-x-004", choices=[ - self.ChunkChoice( + ChunkChoice( index=0, finish_reason="stop", - delta=self.ChoiceDelta(content="How can I help you?"), + delta=ChoiceDelta(content="How can I help you?"), ) ], ), @@ -395,7 +428,7 @@ class GetWeather(BaseModel): assert response.tool_calls[0]["args"]["location"] == "London" -@pytest.mark.requires("writerai") +@pytest.mark.requires("writer-sdk") class TestChatWriterStandart(ChatModelUnitTests): """Test case for ChatWriter that inherits from standard LangChain tests.""" diff --git a/libs/community/tests/unit_tests/llms/test_writer.py b/libs/community/tests/unit_tests/llms/test_writer.py index 705718af0a706e..8755e5b4a32af2 100644 --- a/libs/community/tests/unit_tests/llms/test_writer.py +++ b/libs/community/tests/unit_tests/llms/test_writer.py @@ -9,26 +9,40 @@ from langchain_community.llms.writer import Writer from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler +"""Classes for mocking Writer responses.""" -@pytest.mark.requires("writerai") + +class Choice: + def __init__(self, text: str): + self.text = text + + +class Completion: + def __init__(self, choices: List[Choice]): + self.choices = choices + + +class StreamingData: + def __init__(self, value: str): + self.value = value + + +@pytest.mark.requires("writer-sdk") class TestWriterLLM: """Unit tests for Writer LLM integration.""" - from writerai.types import Completion, StreamingData - from writerai.types.completion import Choice - @pytest.fixture(autouse=True) def mock_unstreaming_completion(self) -> Completion: """Fixture providing a mock API response.""" - return self.Completion(choices=[self.Choice(text="Hello! How can I help you?")]) + return Completion(choices=[Choice(text="Hello! How can I help you?")]) @pytest.fixture(autouse=True) def mock_streaming_completion(self) -> List[StreamingData]: """Fixture providing mock streaming response chunks.""" return [ - self.StreamingData(value="Hello! "), - self.StreamingData(value="How can I"), - self.StreamingData(value=" help you?"), + StreamingData(value="Hello! "), + StreamingData(value="How can I"), + StreamingData(value=" help you?"), ] def test_sync_unstream_completion(