diff --git a/adalflow/tests/test_generator.py b/adalflow/tests/test_generator.py index a15c302a..38d331ce 100644 --- a/adalflow/tests/test_generator.py +++ b/adalflow/tests/test_generator.py @@ -15,6 +15,7 @@ from adalflow.core.model_client import ModelClient from adalflow.components.model_client.groq_client import GroqAPIClient from adalflow.tracing import GeneratorStateLogger +from typing import Generator as GeneratorType class TestGenerator(IsolatedAsyncioTestCase): @@ -192,5 +193,72 @@ def test_groq_client_call(self, mock_call): # self.assertEqual(output.data, "Generated text response") +class TestGeneratorWithStream(unittest.TestCase): + + def setUp(self): + """Set up the mocked environment for the stream test cases.""" + self.sent_chunks = [ + "Pa", + "ris", + " is", + " the", + " capital", + " of", + " France", + ] + + with patch( + "adalflow.core.model_client.ModelClient", spec=ModelClient + ) as MockAPI: + mock_api_client = Mock(spec=ModelClient) + MockAPI.return_value = mock_api_client + + mock_api_client.convert_inputs_to_api_kwargs.return_value = { + "model": "phi3:latest", + "stream": True, + "prompt": ( + "\nYou are a helpful assistant.\n\n" + "\nWhat is the capital of France?\n" + ), + } + + mock_api_client.parse_chat_completion.return_value = ( + self._mock_stream_generator(self.sent_chunks) + ) + self.mock_api_client = mock_api_client + + self.generator = Generator(model_client=self.mock_api_client) + + def test_generator_call_with_stream(self): + """Test the generator call with streaming enabled.""" + prompt_kwargs = {"input_str": "What is the capital of France?"} + model_kwargs = {"model": "phi3:latest", "stream": True} + + output = self.generator.call( + prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs + ) + + # Assert that output is of type GeneratorOutput + self.assertIsInstance(output, GeneratorOutput) + # Assert that output.data is a generator type + self.assertIsInstance(output.data, GeneratorType) + + received_chunks = [] + for chunk in output.data: + # Assert that each chunk is of type GeneratorOutput + self.assertIsInstance(chunk, GeneratorOutput) + received_chunks.append(chunk.raw_response) + + # Assert that the received chunks match the sent chunks + self.assertEqual(received_chunks, self.sent_chunks) + + def _mock_stream_generator( + self, completion: list[str] + ) -> GeneratorType[GeneratorOutput, None, None]: + """Simulates streamed API responses.""" + for chunk in completion: + yield GeneratorOutput(data=None, raw_response=chunk) + + if __name__ == "__main__": unittest.main()