diff --git a/libs/community/tests/unit_tests/chat_models/test_reka.py b/libs/community/tests/unit_tests/chat_models/test_reka.py index 6c61d9a069fe4a..eca1a948b8b221 100644 --- a/libs/community/tests/unit_tests/chat_models/test_reka.py +++ b/libs/community/tests/unit_tests/chat_models/test_reka.py @@ -1,4 +1,4 @@ -"""Test Reka Chat API wrapper.""" +"""Test Reka Chat wrapper.""" import os from typing import List @@ -7,7 +7,10 @@ from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_community.chat_models import ChatReka -from langchain_community.chat_models.reka import process_messages_for_reka +from langchain_community.chat_models.reka import ( + convert_to_reka_messages, + process_content, +) os.environ["REKA_API_KEY"] = "dummy_key" @@ -52,34 +55,87 @@ def test_reka_initialization() -> None: @pytest.mark.parametrize( - ("messages", "expected"), + ("content", "expected"), [ - ([HumanMessage(content="Hello")], [{"role": "user", "content": "Hello"}]), + ("Hello", [{"type": "text", "text": "Hello"}]), ( - [HumanMessage(content="Hello"), AIMessage(content="Hi there!")], [ - {"role": "user", "content": "Hello"}, - {"role": "assistant", "content": "Hi there!"}, + {"type": "text", "text": "Hello"}, + {"type": "image_url", "image_url": "https://example.com/image.jpg"}, + ], + [ + {"type": "text", "text": "Hello"}, + {"type": "image_url", "image_url": "https://example.com/image.jpg"}, ], ), ( [ - SystemMessage(content="You're an assistant"), - HumanMessage(content="Hello"), - AIMessage(content="Hi there!"), + {"type": "text", "text": "Hello"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.jpg"}, + }, ], [ - {"role": "user", "content": "You're an assistant\nHello"}, - {"role": "assistant", "content": "Hi there!"}, + {"type": "text", "text": "Hello"}, + {"type": "image_url", "image_url": "https://example.com/image.jpg"}, ], ), ], ) -def test_message_processing(messages: List[BaseMessage], expected: List[dict]) -> None: - result = process_messages_for_reka(messages) +def test_process_content(content, expected) -> None: + result = process_content(content) assert result == expected +@pytest.mark.parametrize( + ("messages", "expected"), + [ + ( + [HumanMessage(content="Hello")], + [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + ), + ( + [ + HumanMessage( + content=[ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ] + ), + AIMessage(content="It's a beautiful landscape."), + ], + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": "https://example.com/image.jpg", + }, + ], + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "It's a beautiful landscape."} + ], + }, + ], + ), + ], +) +def test_convert_to_reka_messages( + messages: List[BaseMessage], expected: List[dict] +) -> None: + result = convert_to_reka_messages(messages) + assert [message.dict() for message in result] == expected + + @pytest.mark.requires("reka") def test_reka_streaming() -> None: llm = ChatReka(streaming=True)