diff --git a/libs/ibm/langchain_ibm/llms.py b/libs/ibm/langchain_ibm/llms.py
index 5ed9874..f7c3a7b 100644
--- a/libs/ibm/langchain_ibm/llms.py
+++ b/libs/ibm/langchain_ibm/llms.py
@@ -1,9 +1,10 @@
 import logging
 import os
-from typing import Any, Dict, Iterator, List, Mapping, Optional, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Tuple, Union
 
 from ibm_watsonx_ai import Credentials  # type: ignore
 from ibm_watsonx_ai.foundation_models import Model, ModelInference  # type: ignore
+from ibm_watsonx_ai.metanames import GenTextParamsMetaNames  # type: ignore
 from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models.llms import BaseLLM
 from langchain_core.outputs import Generation, GenerationChunk, LLMResult
@@ -11,6 +12,9 @@ from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
 
 logger = logging.getLogger(__name__)
 
+textgen_valid_params = [
+    value for key, value in GenTextParamsMetaNames.__dict__.items() if key.isupper()
+]
 
 
 class WatsonxLLM(BaseLLM):
@@ -265,18 +269,50 @@ def get_count_value(key: str, result: Dict[str, Any]) -> int:
             "input_token_count": input_token_count,
         }
 
+    @staticmethod
+    def _validate_chat_params(
+        params: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """Validate the chat parameters."""
+        for param in params.keys():
+            if param.lower() not in textgen_valid_params:
+                raise Exception(
+                    f"Parameter {param} is not valid. "
+                    f"Valid parameters are: {textgen_valid_params}"
+                )
+        return params
+
+    @staticmethod
+    def _override_chat_params(
+        params: Dict[str, Any], **kwargs: Any
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        """
+        Override class parameters with those provided in the invoke method.
+        Updates 'params' with matching generation-parameter keys from kwargs
+        and removes those keys from kwargs.
+        """
+        for key in list(kwargs.keys()):
+            if key.lower() in textgen_valid_params:
+                params[key] = kwargs.pop(key)
+        return params, kwargs
+
     def _get_chat_params(
         self, stop: Optional[List[str]] = None, **kwargs: Any
-    ) -> Optional[Dict[str, Any]]:
-        params = {**self.params} if self.params else {}
-        params = params | {**kwargs.get("params", {})}
+    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+        params = (
+            {**self.params, **kwargs.pop("params", {})}
+            if self.params
+            else kwargs.pop("params", {})
+        )
+        params, kwargs = self._override_chat_params(params, **kwargs)
         if stop is not None:
             if params and "stop_sequences" in params:
                 raise ValueError(
                     "`stop_sequences` found in both the input and default params."
                 )
             params = (params or {}) | {"stop_sequences": stop}
-        return params
+        return params, kwargs
 
     def _create_llm_result(self, response: List[dict]) -> LLMResult:
         """Create the LLMResult from the choices and prompts."""
@@ -360,7 +396,8 @@ def _generate(
 
             response = watsonx_llm.generate(["What is a molecule"])
         """
-        params = self._get_chat_params(stop=stop, **kwargs)
+        params, kwargs = self._get_chat_params(stop=stop, **kwargs)
+        params = self._validate_chat_params(params)
         should_stream = stream if stream is not None else self.streaming
         if should_stream:
             if len(prompts) > 1:
@@ -383,7 +420,7 @@
             return LLMResult(generations=[[generation]])
         else:
             response = self.watsonx_model.generate(
-                prompt=prompts, **(kwargs | {"params": params})
+                prompt=prompts, params=params, **kwargs
             )
             return self._create_llm_result(response)
@@ -406,11 +443,12 @@ def _stream(
 
             response = watsonx_llm.stream("What is a molecule")
             for chunk in response:
-                print(chunk, end='')
+                print(chunk, end='', flush=True)
         """
-        params = self._get_chat_params(stop=stop, **kwargs)
+        params, kwargs = self._get_chat_params(stop=stop, **kwargs)
+        params = self._validate_chat_params(params)
         for stream_resp in self.watsonx_model.generate_text_stream(
-            prompt=prompt, raw_response=True, **(kwargs | {"params": params})
+            prompt=prompt, params=params, **(kwargs | {"raw_response": True})
         ):
             if not isinstance(stream_resp, dict):
                 stream_resp = stream_resp.dict()
diff --git a/libs/ibm/tests/integration_tests/test_llms.py b/libs/ibm/tests/integration_tests/test_llms.py
index cf0a39e..329c799 100644
--- a/libs/ibm/tests/integration_tests/test_llms.py
+++ b/libs/ibm/tests/integration_tests/test_llms.py
@@ -91,6 +91,27 @@ def test_watsonxllm_invoke_with_params_3() -> None:
     assert len(response) > 0
 
 
+def test_watsonxllm_invoke_with_params_4() -> None:
+    parameters_1 = {
+        GenTextParamsMetaNames.DECODING_METHOD: "sample",
+        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
+    }
+    parameters_2 = {
+        "temperature": 0.6,
+    }
+
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+        params=parameters_1,
+    )
+    response = watsonxllm.invoke("What color sunflower is?", **parameters_2)  # type: ignore[arg-type]
+    print(f"\nResponse: {response}")
+    assert isinstance(response, str)
+    assert len(response) > 0
+
+
 def test_watsonxllm_generate() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
@@ -140,6 +161,18 @@ def test_watsonxllm_generate_with_multiple_prompts() -> None:
         assert len(response_text) > 0
 
 
+def test_watsonxllm_invoke_with_guardrails() -> None:
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+    )
+    response = watsonxllm.invoke("What color sunflower is?", guardrails=True)
+    print(f"\nResponse: {response}")
+    assert isinstance(response, str)
+    assert len(response) > 0
+
+
 def test_watsonxllm_generate_stream() -> None:
     watsonxllm = WatsonxLLM(
         model_id=MODEL_ID,
@@ -177,6 +210,115 @@ def test_watsonxllm_stream() -> None:
     ), "Linked text stream are not the same as generated text"
 
 
+def test_watsonxllm_stream_with_kwargs() -> None:
+    watsonxllm = WatsonxLLM(
+        model_id=MODEL_ID,
+        url="https://us-south.ml.cloud.ibm.com",  # type: ignore[arg-type]
+        project_id=WX_PROJECT_ID,
+    )
+    stream_response = watsonxllm.stream("What color sunflower is?", raw_response=True)
+
+    for chunk in stream_response:
+        assert isinstance(
+            chunk, str
+        ), f"chunk expected to be of type '{str}', actual '{type(chunk)}'"
'{type(chunk)}'" + + +def test_watsonxllm_stream_with_params() -> None: + parameters = { + GenTextParamsMetaNames.DECODING_METHOD: "greedy", + GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, + GenTextParamsMetaNames.MIN_NEW_TOKENS: 5, + } + watsonxllm = WatsonxLLM( + model_id=MODEL_ID, + url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + params=parameters, + ) + response = watsonxllm.invoke("What color sunflower is?") + print(f"\nResponse: {response}") + + stream_response = watsonxllm.stream("What color sunflower is?") + + linked_text_stream = "" + for chunk in stream_response: + assert isinstance( + chunk, str + ), f"chunk expect type '{str}', actual '{type(chunk)}'" + linked_text_stream += chunk + print(f"Linked text stream: {linked_text_stream}") + assert ( + response == linked_text_stream + ), "Linked text stream are not the same as generated text" + + +def test_watsonxllm_stream_with_params_2() -> None: + parameters = { + GenTextParamsMetaNames.DECODING_METHOD: "greedy", + GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, + GenTextParamsMetaNames.MIN_NEW_TOKENS: 5, + } + watsonxllm = WatsonxLLM( + model_id=MODEL_ID, + url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + ) + stream_response = watsonxllm.stream("What color sunflower is?", params=parameters) + + for chunk in stream_response: + assert isinstance( + chunk, str + ), f"chunk expect type '{str}', actual '{type(chunk)}'" + print(chunk) + + +def test_watsonxllm_stream_with_params_3() -> None: + parameters_1 = { + GenTextParamsMetaNames.DECODING_METHOD: "sample", + GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, + } + parameters_2 = { + GenTextParamsMetaNames.MIN_NEW_TOKENS: 5, + } + watsonxllm = WatsonxLLM( + model_id=MODEL_ID, + url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + params=parameters_1, + ) + stream_response = watsonxllm.stream("What color sunflower is?", params=parameters_2) + + for chunk in stream_response: + assert isinstance( + chunk, str + ), f"chunk expect type '{str}', actual '{type(chunk)}'" + print(chunk) + + +def test_watsonxllm_stream_with_params_4() -> None: + parameters_1 = { + GenTextParamsMetaNames.DECODING_METHOD: "sample", + GenTextParamsMetaNames.MAX_NEW_TOKENS: 10, + } + parameters_2 = { + "temperature": 0.6, + } + watsonxllm = WatsonxLLM( + model_id=MODEL_ID, + url="https://us-south.ml.cloud.ibm.com", # type: ignore[arg-type] + project_id=WX_PROJECT_ID, + params=parameters_1, + ) + stream_response = watsonxllm.stream("What color sunflower is?", **parameters_2) # type: ignore[arg-type] + + for chunk in stream_response: + assert isinstance( + chunk, str + ), f"chunk expect type '{str}', actual '{type(chunk)}'" + print(chunk) + + def test_watsonxllm_invoke_from_wx_model() -> None: model = Model( model_id=MODEL_ID,