Skip to content

Commit

Permalink
Unit test update
Browse files Browse the repository at this point in the history
  • Loading branch information
findalexli committed Sep 12, 2024
1 parent 07dec05 commit 1065dbc
Showing 1 changed file with 70 additions and 14 deletions.
84 changes: 70 additions & 14 deletions libs/community/tests/unit_tests/chat_models/test_reka.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Test Reka Chat API wrapper."""
"""Test Reka Chat wrapper."""

import os
from typing import List
Expand All @@ -7,7 +7,10 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage

from langchain_community.chat_models import ChatReka
from langchain_community.chat_models.reka import process_messages_for_reka
from langchain_community.chat_models.reka import (
convert_to_reka_messages,
process_content,
)

os.environ["REKA_API_KEY"] = "dummy_key"

Expand Down Expand Up @@ -52,34 +55,87 @@ def test_reka_initialization() -> None:


@pytest.mark.parametrize(
("messages", "expected"),
("content", "expected"),
[
([HumanMessage(content="Hello")], [{"role": "user", "content": "Hello"}]),
("Hello", [{"type": "text", "text": "Hello"}]),
(
[HumanMessage(content="Hello"), AIMessage(content="Hi there!")],
[
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": "https://example.com/image.jpg"},
],
[
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": "https://example.com/image.jpg"},
],
),
(
[
SystemMessage(content="You're an assistant"),
HumanMessage(content="Hello"),
AIMessage(content="Hi there!"),
{"type": "text", "text": "Hello"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.jpg"},
},
],
[
{"role": "user", "content": "You're an assistant\nHello"},
{"role": "assistant", "content": "Hi there!"},
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": "https://example.com/image.jpg"},
],
),
],
)
def test_message_processing(messages: List[BaseMessage], expected: List[dict]) -> None:
result = process_messages_for_reka(messages)
def test_process_content(content, expected) -> None:
result = process_content(content)
assert result == expected


@pytest.mark.parametrize(
("messages", "expected"),
[
(
[HumanMessage(content="Hello")],
[{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
),
(
[
HumanMessage(
content=[
{"type": "text", "text": "Describe this image"},
{
"type": "image_url",
"image_url": "https://example.com/image.jpg",
},
]
),
AIMessage(content="It's a beautiful landscape."),
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image"},
{
"type": "image_url",
"image_url": "https://example.com/image.jpg",
},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "It's a beautiful landscape."}
],
},
],
),
],
)
def test_convert_to_reka_messages(
messages: List[BaseMessage], expected: List[dict]
) -> None:
result = convert_to_reka_messages(messages)
assert [message.dict() for message in result] == expected


@pytest.mark.requires("reka")
def test_reka_streaming() -> None:
llm = ChatReka(streaming=True)
Expand Down

0 comments on commit 1065dbc

Please sign in to comment.