Test/add multimodal audio test #382

Open · wants to merge 2 commits into base: development
17 changes: 14 additions & 3 deletions src/rai/rai/messages/multimodal.py
@@ -27,16 +27,14 @@ class MultimodalArtifact(TypedDict):

class MultimodalMessage(BaseMessage):
images: Optional[List[str]] = None
audios: Optional[Any] = None
audios: Optional[List[str]] = None

def __init__(
self,
**kwargs: Any,
):
super().__init__(**kwargs) # type: ignore

if self.audios not in [None, []]:
raise ValueError("Audio is not yet supported")

_content: List[Union[str, Dict[str, Union[Dict[str, str], str]]]] = []

@@ -56,6 +54,19 @@ def __init__(
for image in self.images
]
_content.extend(_image_content)

# aduio content handling (used audio/wav as MIME type)
Contributor:

Suggested change:
- # aduio content handling (used audio/wav as MIME type)
+ # audio content handling (used audio/wav as MIME type)

Author:

mb, I'll change that

if isinstance(self.audios, list):
_audio_content = [
{
"type": "audio_url",
"audio_url": {
"url": f"data:audio/wav;base64,{audio}",
},
}
for audio in self.audios
]
_content.extend(_audio_content)
self.content = _content

@property
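For reference, a minimal usage sketch of the audios field introduced above, assuming HumanMultimodalMessage forwards its keyword arguments to MultimodalMessage the same way the existing images path does (test.wav is a hypothetical local file, not part of this PR):

import base64

from rai.messages import HumanMultimodalMessage

# base64-encode raw WAV bytes and pass them via `audios`; each entry is expected
# to be expanded into a data-URI "audio_url" content part, mirroring the image handling
with open("test.wav", "rb") as f:  # hypothetical local file
    b64_audio = base64.b64encode(f.read()).decode("utf-8")

msg = HumanMultimodalMessage(
    content="Please describe this audio clip.",
    audios=[b64_audio],
)

# msg.content now contains an entry shaped like:
# {"type": "audio_url", "audio_url": {"url": "data:audio/wav;base64,..."}}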
82 changes: 53 additions & 29 deletions tests/messages/test_multimodal.py
@@ -15,7 +15,6 @@
import base64
import os
from typing import List, Literal, Type, cast

import cv2
import numpy as np
import pytest
@@ -29,20 +28,15 @@
from langfuse.callback import CallbackHandler
from pydantic import BaseModel, Field
from pytest import FixtureRequest

from rai.messages import HumanMultimodalMessage
from rai.tools.utils import run_requested_tools


class GetImageToolInput(BaseModel):
name: str = Field(..., title="Name of the image")


class GetImageTool(BaseTool):

name: str = "GetImageTool"
description: str = "Get an image from the user"

args_schema: Type[GetImageToolInput] = GetImageToolInput # type: ignore

def _run(self, name: str):
@@ -53,6 +47,26 @@ def _run(self, name: str):
base64_image = base64.b64encode(image).decode("utf-8")
return {"content": f"Here is the image {name}", "images": [base64_image]}

class GetAudioToolInput(BaseModel):
name: str = Field(..., title="Name of the audio file")

class GetAudioTool(BaseTool):
name: str = "GetAudioTool"
description: str = "Get an audio file from the user"
args_schema: Type[GetAudioToolInput] = GetAudioToolInput # type: ignore

def _run(self, name: str):
# simple audio signal (1 second of 440Hz sine wave)
sample_rate = 44100
duration = 1.0
t = np.linspace(0, duration, int(sample_rate * duration))
audio_signal = np.sin(2 * np.pi * 440 * t)
Comment on lines +58 to +63

Contributor:

A better solution would be to add an actual test.wav file, instead of mocking the input to the tool like this.
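A hedged sketch of that approach for _run (the resources/test.wav fixture path is hypothetical and the snippet reuses the test module's existing os and base64 imports):

    def _run(self, name: str):
        # load a committed WAV fixture instead of synthesizing the signal here
        fixture_path = os.path.join(os.path.dirname(__file__), "resources", "test.wav")
        with open(fixture_path, "rb") as f:
            base64_audio = base64.b64encode(f.read()).decode("utf-8")
        return {"content": f"Here is the audio file {name}", "audios": [base64_audio]}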


# convert to WAV format
audio_bytes = audio_signal.astype(np.float32).tobytes()
base64_audio = base64.b64encode(audio_bytes).decode("utf-8")

return {"content": f"Here is the audio file {name}", "audios": [base64_audio]}

@pytest.mark.billable
@pytest.mark.parametrize(
Expand All @@ -70,9 +84,9 @@ def test_multimodal_messages(
request: FixtureRequest,
):
llm = request.getfixturevalue(llm) # type: ignore
tools = [GetImageTool()]
tools = [GetImageTool(), GetAudioTool()] # added AudioTool to tools
llm_with_tools = llm.bind_tools(tools) # type: ignore

langfuse_handler = CallbackHandler(
public_key=os.getenv("LANGFUSE_PK"),
secret_key=os.getenv("LANGFUSE_SK"),
Expand All @@ -81,25 +95,35 @@ def test_multimodal_messages(
tags=["test"],
)

scenario: List[BaseMessage] = [
HumanMultimodalMessage(
content="Can you please describe the contents of test.png image? Remember to use the available tools."
),
scenarios = [
[
HumanMultimodalMessage(
content="Can you please describe the contents of test.png image? Remember to use the available tools."
),
],
# audio scenario
[
HumanMultimodalMessage(
content="Can you please analyze the contents of test.wav audio? Remember to use the available tools."
),
]
]
with callback() as cb:
ai_msg = cast(
AIMessage,
llm_with_tools.invoke(scenario, config={"callbacks": [langfuse_handler]}),
)
scenario.append(ai_msg)
scenario = run_requested_tools(ai_msg, tools, scenario, llm_type=llm_type)
ai_msg = llm_with_tools.invoke(
scenario, config={"callbacks": [langfuse_handler]}
)
usage_tracker.add_usage(
llm_type,
cost=cb.total_cost,
total_tokens=cb.total_tokens,
input_tokens=cb.prompt_tokens,
output_tokens=cb.completion_tokens,
)

for scenario in scenarios:
with callback() as cb:
ai_msg = cast(
AIMessage,
llm_with_tools.invoke(scenario, config={"callbacks": [langfuse_handler]}),
)
scenario.append(ai_msg)
scenario = run_requested_tools(ai_msg, tools, scenario, llm_type=llm_type)
ai_msg = llm_with_tools.invoke(
scenario, config={"callbacks": [langfuse_handler]}
)
usage_tracker.add_usage(
llm_type,
cost=cb.total_cost,
total_tokens=cb.total_tokens,
input_tokens=cb.prompt_tokens,
output_tokens=cb.completion_tokens,
)