From 20f7d5fc353f13678841c563e449a9bfe924700f Mon Sep 17 00:00:00 2001
From: Mikhail Khludnev
Date: Thu, 1 Feb 2024 00:44:15 +0300
Subject: [PATCH 1/4] add TritonTensorRTLLM(verbose=False)

---
 .../nvidia-trt/langchain_nvidia_trt/llms.py | 4 ++-
 .../nvidia-trt/tests/unit_tests/test_llms.py | 26 +++++++++++++++++++
 2 files changed, 29 insertions(+), 1 deletion(-)

diff --git a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
index 0ea1fca1df8f1..782c727703003 100644
--- a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
+++ b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
@@ -40,6 +40,7 @@ class TritonTensorRTLLM(BaseLLM):
         length_penalty: (float) The penalty to apply repeated tokens
         tokens: (int) The maximum number of tokens to generate.
         client: The client object used to communicate with the inference server
+        verbose: flag to pass to the client on creation

     Example:
         .. code-block:: python
@@ -73,6 +74,7 @@ class TritonTensorRTLLM(BaseLLM):
         description="Request the inference server to load the specified model.\
 Certain Triton configurations do not allow for this operation.",
     )
+    verbose: bool = False

     def __del__(self):
         """Ensure the client streaming connection is properly shutdown"""
@@ -82,7 +84,7 @@ def __del__(self):
     def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate that python package exists in environment."""
         if not values.get("client"):
-            values["client"] = grpcclient.InferenceServerClient(values["server_url"])
+            values["client"] = grpcclient.InferenceServerClient(values["server_url"], verbose=values.get("verbose",False))
         return values

     @property
diff --git a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
index 0113b83507315..945491238c861 100644
--- a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
+++ b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
@@ -1,7 +1,33 @@
 """Test TritonTensorRT Chat API wrapper."""
+import sys
+from io import StringIO
+
+import pytest
+from tritonclient.utils import InferenceServerException
+
 from langchain_nvidia_trt import TritonTensorRTLLM


 def test_initialization() -> None:
     """Test integration initialization."""
     TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")
+
+def test_default_verbose() -> None:
+    llm=TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
+    captured = StringIO()
+    sys.stdout = captured
+    with pytest.raises(InferenceServerException):
+        llm.client.is_server_live()
+    sys.stdout = sys.__stdout__
+    assert not "is_server_live" in captured.getvalue()
+
+def test_verbose() -> None:
+    llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble", verbose=True)
+    captured = StringIO()
+    sys.stdout = captured
+    with pytest.raises(InferenceServerException):
+        llm.client.is_server_live()
+    sys.stdout = sys.__stdout__
+    assert "is_server_live" in captured.getvalue()
+
+

From 7d2c7f79c70dc6d92630b2d6ad76649c1cf0c1ed Mon Sep 17 00:00:00 2001
From: Mikhail Khludnev
Date: Thu, 1 Feb 2024 00:55:29 +0300
Subject: [PATCH 2/4] linter fix

---
 .../partners/nvidia-trt/langchain_nvidia_trt/llms.py | 4 +++-
 .../nvidia-trt/tests/unit_tests/test_llms.py | 12 +++++++-----
 2 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
index 782c727703003..5bb7d08017891 100644
--- a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
+++ b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
@@ -84,7 +84,9 @@ def __del__(self):
     def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate that python package exists in environment."""
         if not values.get("client"):
-            values["client"] = grpcclient.InferenceServerClient(values["server_url"], verbose=values.get("verbose",False))
+            values["client"] = grpcclient.InferenceServerClient(
+                values["server_url"], verbose=values.get("verbose", False)
+            )
         return values

     @property
diff --git a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
index 945491238c861..a5798479e0fa4 100644
--- a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
+++ b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
@@ -12,22 +12,24 @@ def test_initialization() -> None:
     """Test integration initialization."""
     TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")

+
 def test_default_verbose() -> None:
-    llm=TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
+    llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
     captured = StringIO()
     sys.stdout = captured
     with pytest.raises(InferenceServerException):
         llm.client.is_server_live()
     sys.stdout = sys.__stdout__
-    assert not "is_server_live" in captured.getvalue()
+    assert "is_server_live" not in captured.getvalue()
+

 def test_verbose() -> None:
-    llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble", verbose=True)
+    llm = TritonTensorRTLLM(
+        server_url="http://localhost:8001", model_name="ensemble", verbose=True
+    )
     captured = StringIO()
     sys.stdout = captured
     with pytest.raises(InferenceServerException):
         llm.client.is_server_live()
     sys.stdout = sys.__stdout__
     assert "is_server_live" in captured.getvalue()
-
-

From c6a82a45985ebb9ebee0489ca968fd4402663c62 Mon Sep 17 00:00:00 2001
From: Mikhail Khludnev
Date: Tue, 13 Feb 2024 23:33:59 +0300
Subject: [PATCH 3/4] rename verbose to verbose_client; avoid network calls in
 unit tests

---
 .../nvidia-trt/langchain_nvidia_trt/llms.py | 6 +++---
 .../nvidia-trt/tests/unit_tests/test_llms.py | 19 ++++++++-----------
 2 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
index 5bb7d08017891..b4fe74cd94f7f 100644
--- a/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
+++ b/libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py
@@ -40,7 +40,7 @@ class TritonTensorRTLLM(BaseLLM):
         length_penalty: (float) The penalty to apply repeated tokens
         tokens: (int) The maximum number of tokens to generate.
         client: The client object used to communicate with the inference server
-        verbose: flag to pass to the client on creation
+        verbose_client: (bool) Whether to enable verbose logging when creating the client

     Example:
         .. code-block:: python
@@ -74,7 +74,7 @@ class TritonTensorRTLLM(BaseLLM):
         description="Request the inference server to load the specified model.\
 Certain Triton configurations do not allow for this operation.",
     )
-    verbose: bool = False
+    verbose_client: bool = False

     def __del__(self):
         """Ensure the client streaming connection is properly shutdown"""
@@ -85,7 +85,7 @@ def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate that python package exists in environment."""
         if not values.get("client"):
             values["client"] = grpcclient.InferenceServerClient(
-                values["server_url"], verbose=values.get("verbose", False)
+                values["server_url"], verbose=values.get("verbose_client", False)
             )
         return values

diff --git a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
index a5798479e0fa4..c92c0c7d8363a 100644
--- a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
+++ b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
@@ -1,9 +1,7 @@
 """Test TritonTensorRT Chat API wrapper."""
 import sys
 from io import StringIO
-
-import pytest
-from tritonclient.utils import InferenceServerException
+from unittest.mock import patch

 from langchain_nvidia_trt import TritonTensorRTLLM

@@ -13,23 +11,22 @@ def test_initialization() -> None:
     TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")


-def test_default_verbose() -> None:
+@patch('tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub')
+def test_default_verbose(ignore) -> None:
     llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
     captured = StringIO()
     sys.stdout = captured
-    with pytest.raises(InferenceServerException):
-        llm.client.is_server_live()
+    llm.client.is_server_live()
     sys.stdout = sys.__stdout__
     assert "is_server_live" not in captured.getvalue()

-
-def test_verbose() -> None:
+@patch('tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub')
+def test_verbose(ignore) -> None:
     llm = TritonTensorRTLLM(
-        server_url="http://localhost:8001", model_name="ensemble", verbose=True
+        server_url="http://localhost:8001", model_name="ensemble", verbose_client=True
     )
     captured = StringIO()
     sys.stdout = captured
-    with pytest.raises(InferenceServerException):
-        llm.client.is_server_live()
+    llm.client.is_server_live()
     sys.stdout = sys.__stdout__
     assert "is_server_live" in captured.getvalue()

From 9f05293efbcb0d7686da75c74f9395ffb9e2d704 Mon Sep 17 00:00:00 2001
From: Mikhail Khludnev
Date: Tue, 13 Feb 2024 23:43:56 +0300
Subject: [PATCH 4/4] ruff

---
 libs/partners/nvidia-trt/tests/unit_tests/test_llms.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
index c92c0c7d8363a..8ef80cdea4688 100644
--- a/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
+++ b/libs/partners/nvidia-trt/tests/unit_tests/test_llms.py
@@ -11,7 +11,7 @@ def test_initialization() -> None:
     TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")


-@patch('tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub')
+@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
 def test_default_verbose(ignore) -> None:
     llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
     captured = StringIO()
@@ -20,7 +20,8 @@ def test_default_verbose(ignore) -> None:
     sys.stdout = sys.__stdout__
     assert "is_server_live" not in captured.getvalue()

-@patch('tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub')
+
+@patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub")
 def test_verbose(ignore) -> None:
     llm = TritonTensorRTLLM(
         server_url="http://localhost:8001", model_name="ensemble", verbose_client=True
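
Usage note (reviewer addition, not part of the patches): a minimal sketch of the API this series leaves behind. `verbose_client=True` is forwarded to `tritonclient.grpc.InferenceServerClient(..., verbose=...)`, which echoes each RPC it issues to stdout. The mocked gRPC stub mirrors the unit tests above so the sketch runs without a live Triton server; the URL and model name are the placeholder values those tests use.

    from unittest.mock import patch

    from langchain_nvidia_trt import TritonTensorRTLLM

    # Mock the gRPC stub exactly as the unit tests do, so no server is needed.
    with patch("tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub"):
        llm = TritonTensorRTLLM(
            server_url="http://localhost:8001",
            model_name="ensemble",
            verbose_client=True,  # forwarded to InferenceServerClient(verbose=True)
        )
        # The verbose client prints the RPC name to stdout, e.g. "is_server_live ..."
        llm.client.is_server_live()

The rename from `verbose` to `verbose_client` in PATCH 3/4 also keeps the flag distinct from the `verbose` field that LangChain's `BaseLLM` already inherits, which is likely the motivation for the rename.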