diff --git a/src/rai/rai/messages/multimodal.py b/src/rai/rai/messages/multimodal.py
index 8db862be..44fe7131 100644
--- a/src/rai/rai/messages/multimodal.py
+++ b/src/rai/rai/messages/multimodal.py
@@ -27,7 +27,7 @@ class MultimodalArtifact(TypedDict):
 
 class MultimodalMessage(BaseMessage):
     images: Optional[List[str]] = None
-    audios: Optional[Any] = None
+    audios: Optional[List[str]] = None
 
     def __init__(
         self,
@@ -35,8 +35,6 @@ def __init__(
     ):
         super().__init__(**kwargs)  # type: ignore
 
-        if self.audios not in [None, []]:
-            raise ValueError("Audio is not yet supported")
 
         _content: List[Union[str, Dict[str, Union[Dict[str, str], str]]]] = []
 
@@ -56,6 +54,19 @@ def __init__(
                 for image in self.images
             ]
            _content.extend(_image_content)
+
+        # audio content handling (audio/wav is used as the MIME type)
+        if isinstance(self.audios, list):
+            _audio_content = [
+                {
+                    "type": "audio_url",
+                    "audio_url": {
+                        "url": f"data:audio/wav;base64,{audio}",
+                    },
+                }
+                for audio in self.audios
+            ]
+            _content.extend(_audio_content)
         self.content = _content
 
     @property
diff --git a/tests/messages/test_multimodal.py b/tests/messages/test_multimodal.py
index 3ada5ec4..34f7e0ec 100644
--- a/tests/messages/test_multimodal.py
+++ b/tests/messages/test_multimodal.py
@@ -15,7 +15,8 @@
 import base64
+import io
 import os
+import wave
 from typing import List, Literal, Type, cast
-
 import cv2
 import numpy as np
 import pytest
@@ -29,20 +30,15 @@
 from langfuse.callback import CallbackHandler
 from pydantic import BaseModel, Field
 from pytest import FixtureRequest
-
 from rai.messages import HumanMultimodalMessage
 from rai.tools.utils import run_requested_tools
-
 class GetImageToolInput(BaseModel):
     name: str = Field(..., title="Name of the image")
-
 class GetImageTool(BaseTool):
-
     name: str = "GetImageTool"
     description: str = "Get an image from the user"
-
     args_schema: Type[GetImageToolInput] = GetImageToolInput  # type: ignore
 
     def _run(self, name: str):
@@ -53,6 +49,32 @@ def _run(self, name: str):
         base64_image = base64.b64encode(image).decode("utf-8")
         return {"content": f"Here is the image {name}", "images": [base64_image]}
 
+class GetAudioToolInput(BaseModel):
+    name: str = Field(..., title="Name of the audio file")
+
+class GetAudioTool(BaseTool):
+    name: str = "GetAudioTool"
+    description: str = "Get an audio file from the user"
+    args_schema: Type[GetAudioToolInput] = GetAudioToolInput  # type: ignore
+
+    def _run(self, name: str):
+        # simple audio signal (1 second of a 440 Hz sine wave)
+        sample_rate = 44100
+        duration = 1.0
+        t = np.linspace(0, duration, int(sample_rate * duration))
+        audio_signal = np.sin(2 * np.pi * 440 * t)
+
+        # raw float32 samples are not a valid payload for a data:audio/wav
+        # URI, so wrap them in a RIFF/WAV container as 16-bit PCM
+        buffer = io.BytesIO()
+        with wave.open(buffer, "wb") as wav_file:
+            wav_file.setnchannels(1)
+            wav_file.setsampwidth(2)  # 16-bit samples
+            wav_file.setframerate(sample_rate)
+            wav_file.writeframes((audio_signal * 32767).astype(np.int16).tobytes())
+        base64_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")
+
+        return {"content": f"Here is the audio file {name}", "audios": [base64_audio]}
 
 @pytest.mark.billable
 @pytest.mark.parametrize(
@@ -70,9 +92,9 @@ def test_multimodal_messages(
     request: FixtureRequest,
 ):
     llm = request.getfixturevalue(llm)  # type: ignore
-    tools = [GetImageTool()]
+    tools = [GetImageTool(), GetAudioTool()]  # exercise the new audio tool as well
     llm_with_tools = llm.bind_tools(tools)  # type: ignore
 
     langfuse_handler = CallbackHandler(
         public_key=os.getenv("LANGFUSE_PK"),
         secret_key=os.getenv("LANGFUSE_SK"),
@@ -81,25 +103,35 @@ def test_multimodal_messages(
         tags=["test"],
     )
 
-    scenario: List[BaseMessage] = [
-        HumanMultimodalMessage(
-            content="Can you please describe the contents of test.png image? Remember to use the available tools."
-        ),
+    scenarios: List[List[BaseMessage]] = [
+        [
+            HumanMultimodalMessage(
+                content="Can you please describe the contents of test.png image? Remember to use the available tools."
+            ),
+        ],
+        # audio scenario
+        [
+            HumanMultimodalMessage(
+                content="Can you please analyze the contents of test.wav audio? Remember to use the available tools."
+            ),
+        ],
     ]
-    with callback() as cb:
-        ai_msg = cast(
-            AIMessage,
-            llm_with_tools.invoke(scenario, config={"callbacks": [langfuse_handler]}),
-        )
-        scenario.append(ai_msg)
-        scenario = run_requested_tools(ai_msg, tools, scenario, llm_type=llm_type)
-        ai_msg = llm_with_tools.invoke(
-            scenario, config={"callbacks": [langfuse_handler]}
-        )
-        usage_tracker.add_usage(
-            llm_type,
-            cost=cb.total_cost,
-            total_tokens=cb.total_tokens,
-            input_tokens=cb.prompt_tokens,
-            output_tokens=cb.completion_tokens,
-        )
+
+    for scenario in scenarios:
+        with callback() as cb:
+            ai_msg = cast(
+                AIMessage,
+                llm_with_tools.invoke(scenario, config={"callbacks": [langfuse_handler]}),
+            )
+            scenario.append(ai_msg)
+            scenario = run_requested_tools(ai_msg, tools, scenario, llm_type=llm_type)
+            ai_msg = llm_with_tools.invoke(
+                scenario, config={"callbacks": [langfuse_handler]}
+            )
+            usage_tracker.add_usage(
+                llm_type,
+                cost=cb.total_cost,
+                total_tokens=cb.total_tokens,
+                input_tokens=cb.prompt_tokens,
+                output_tokens=cb.completion_tokens,
+            )
\ No newline at end of file
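
For reference, a minimal sketch (not part of the patch) of how the new audios field is serialized once this diff is applied; it assumes HumanMultimodalMessage forwards its kwargs to MultimodalMessage, as the tests above do, and that the message content is non-empty so the audio part lands last:

import base64
import io
import wave

import numpy as np

from rai.messages import HumanMultimodalMessage

# one second of a 440 Hz sine wave, wrapped in a RIFF/WAV container
sample_rate = 44100
t = np.linspace(0, 1.0, sample_rate)
samples = (np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16)

buffer = io.BytesIO()
with wave.open(buffer, "wb") as wav_file:
    wav_file.setnchannels(1)
    wav_file.setsampwidth(2)  # 16-bit PCM
    wav_file.setframerate(sample_rate)
    wav_file.writeframes(samples.tobytes())
base64_audio = base64.b64encode(buffer.getvalue()).decode("utf-8")

msg = HumanMultimodalMessage(
    content="What do you hear in this clip?", audios=[base64_audio]
)
# the clip ends up in the message content as an audio_url data URI
assert msg.content[-1]["type"] == "audio_url"
assert msg.content[-1]["audio_url"]["url"].startswith("data:audio/wav;base64,")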