Skip to content

Commit

Permalink
adds test case for generator with stream enabled in the model client,…
Browse files Browse the repository at this point in the history
… created a mock class::TestGeneratorWithStream for simulating streamed API response
  • Loading branch information
BalasubramanyamEvani committed Dec 18, 2024
1 parent 90442ed commit 13ab83b
Showing 1 changed file with 68 additions and 0 deletions.
68 changes: 68 additions & 0 deletions adalflow/tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from adalflow.core.model_client import ModelClient
from adalflow.components.model_client.groq_client import GroqAPIClient
from adalflow.tracing import GeneratorStateLogger
from typing import Generator as GeneratorType


class TestGenerator(IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -192,5 +193,72 @@ def test_groq_client_call(self, mock_call):
# self.assertEqual(output.data, "Generated text response")


class TestGeneratorWithStream(unittest.TestCase):

def setUp(self):
"""Set up the mocked environment for the stream test cases."""
self.sent_chunks = [
"Pa",
"ris",
" is",
" the",
" capital",
" of",
" France",
]

with patch(
"adalflow.core.model_client.ModelClient", spec=ModelClient
) as MockAPI:
mock_api_client = Mock(spec=ModelClient)
MockAPI.return_value = mock_api_client

mock_api_client.convert_inputs_to_api_kwargs.return_value = {
"model": "phi3:latest",
"stream": True,
"prompt": (
"<START_OF_SYSTEM_PROMPT>\nYou are a helpful assistant.\n<END_OF_SYSTEM_PROMPT>\n"
"<START_OF_USER_PROMPT>\nWhat is the capital of France?\n<END_OF_USER_PROMPT>"
),
}

mock_api_client.parse_chat_completion.return_value = (
self._mock_stream_generator(self.sent_chunks)
)
self.mock_api_client = mock_api_client

self.generator = Generator(model_client=self.mock_api_client)

def test_generator_call_with_stream(self):
"""Test the generator call with streaming enabled."""
prompt_kwargs = {"input_str": "What is the capital of France?"}
model_kwargs = {"model": "phi3:latest", "stream": True}

output = self.generator.call(
prompt_kwargs=prompt_kwargs, model_kwargs=model_kwargs
)

# Assert that output is of type GeneratorOutput
self.assertIsInstance(output, GeneratorOutput)
# Assert that output.data is a generator type
self.assertIsInstance(output.data, GeneratorType)

received_chunks = []
for chunk in output.data:
# Assert that each chunk is of type GeneratorOutput
self.assertIsInstance(chunk, GeneratorOutput)
received_chunks.append(chunk.raw_response)

# Assert that the received chunks match the sent chunks
self.assertEqual(received_chunks, self.sent_chunks)

def _mock_stream_generator(
self, completion: list[str]
) -> GeneratorType[GeneratorOutput, None, None]:
"""Simulates streamed API responses."""
for chunk in completion:
yield GeneratorOutput(data=None, raw_response=chunk)


if __name__ == "__main__":
unittest.main()

0 comments on commit 13ab83b

Please sign in to comment.