From 9b29b64eb1bb75733b499dc431660183b1a27278 Mon Sep 17 00:00:00 2001 From: Wojciech-Rebisz Date: Wed, 11 Sep 2024 10:41:50 +0200 Subject: [PATCH] Add stream option to _agenerate --- libs/ibm/langchain_ibm/llms.py | 18 +++++++++++++----- libs/ibm/tests/integration_tests/test_llms.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/libs/ibm/langchain_ibm/llms.py b/libs/ibm/langchain_ibm/llms.py index 78a967e..5486157 100644 --- a/libs/ibm/langchain_ibm/llms.py +++ b/libs/ibm/langchain_ibm/llms.py @@ -471,17 +471,25 @@ async def _agenerate( prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, **kwargs: Any, ) -> LLMResult: """Async run the LLM on the given prompt and input.""" params, kwargs = self._get_chat_params(stop=stop, **kwargs) params = self._validate_chat_params(params) - responses = [ - await self.watsonx_model.agenerate(prompt=prompt, params=params, **kwargs) - for prompt in prompts - ] + if stream: + return await super()._agenerate( + prompts=prompts, stop=stop, run_manager=run_manager, **kwargs + ) + else: + responses = [ + await self.watsonx_model.agenerate( + prompt=prompt, params=params, **kwargs + ) + for prompt in prompts + ] - return self._create_llm_result(responses) + return self._create_llm_result(responses) def _stream( self, diff --git a/libs/ibm/tests/integration_tests/test_llms.py b/libs/ibm/tests/integration_tests/test_llms.py index fd0ae32..5a8cfb7 100644 --- a/libs/ibm/tests/integration_tests/test_llms.py +++ b/libs/ibm/tests/integration_tests/test_llms.py @@ -445,6 +445,16 @@ async def test_watsonx_agenerate() -> None: assert response.llm_output["token_usage"]["generated_token_count"] != 0 # type: ignore +async def test_watsonx_agenerate_with_stream() -> None: + watsonxllm = WatsonxLLM( + model_id=MODEL_ID, + url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + response = await watsonxllm.agenerate(["What color sunflower is?"], stream=True) + assert "yellow" in response.generations[0][0].text.lower() + + def test_get_num_tokens() -> None: watsonxllm = WatsonxLLM( model_id=MODEL_ID,