From 3501127633d141dce0195a9a46930fd1168f4439 Mon Sep 17 00:00:00 2001 From: Mateusz Szewczyk <139469471+MateuszOssGit@users.noreply.github.com> Date: Wed, 16 Oct 2024 12:37:31 +0200 Subject: [PATCH] Added Unit Standard tests (#32) * Added Unit standard tests * Updated Unit standard tests * update for unit standard test v2 --- libs/ibm/langchain_ibm/chat_models.py | 6 ++++ .../unit_tests/test_chat_models_standard.py | 30 +++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 libs/ibm/tests/unit_tests/test_chat_models_standard.py diff --git a/libs/ibm/langchain_ibm/chat_models.py b/libs/ibm/langchain_ibm/chat_models.py index 48514a4..71edba0 100644 --- a/libs/ibm/langchain_ibm/chat_models.py +++ b/libs/ibm/langchain_ibm/chat_models.py @@ -470,6 +470,9 @@ class ChatWatsonx(BaseChatModel): * True - default path to truststore will be taken * False - no verification will be made""" + validate_model: bool = True + """Model ID validation.""" + streaming: bool = False """ Whether to stream the results or not. """ @@ -532,6 +535,7 @@ def validate_environment(self) -> Self: project_id=self.project_id, space_id=self.space_id, verify=self.verify, + validate=self.validate_model, ) self.watsonx_model = watsonx_model @@ -584,6 +588,8 @@ def validate_environment(self) -> Self: params=self.params, project_id=self.project_id, space_id=self.space_id, + verify=self.verify, + validate=self.validate_model, ) self.watsonx_model = watsonx_chat diff --git a/libs/ibm/tests/unit_tests/test_chat_models_standard.py b/libs/ibm/tests/unit_tests/test_chat_models_standard.py new file mode 100644 index 0000000..7dfa8a8 --- /dev/null +++ b/libs/ibm/tests/unit_tests/test_chat_models_standard.py @@ -0,0 +1,30 @@ +from typing import Type + +from ibm_watsonx_ai import APIClient, Credentials # type: ignore +from ibm_watsonx_ai.service_instance import ServiceInstance # type: ignore +from langchain_core.language_models import BaseChatModel +from langchain_standard_tests.unit_tests import ChatModelUnitTests + +from langchain_ibm import ChatWatsonx + +client = APIClient.__new__(APIClient) +client.CLOUD_PLATFORM_SPACES = True +client.ICP_PLATFORM_SPACES = True +credentials = Credentials(api_key="api_key") +client.credentials = credentials +client.service_instance = ServiceInstance.__new__(ServiceInstance) +client.service_instance._credentials = credentials + + +class TestWatsonxStandard(ChatModelUnitTests): + @property + def chat_model_class(self) -> Type[BaseChatModel]: + return ChatWatsonx + + @property + def chat_model_params(self) -> dict: + return { + "model_id": "ibm/granite-13b-instruct-v2", + "validate_model": False, + "watsonx_client": client, + }