Commit c6a82a4
rename verbose_client
make no network attempts in unit test
mkhludnev committed Mar 3, 2024
1 parent 7d2c7f7 commit c6a82a4
Showing 2 changed files with 11 additions and 14 deletions.
libs/partners/nvidia-trt/langchain_nvidia_trt/llms.py (6 changes: 3 additions & 3 deletions)
@@ -40,7 +40,7 @@ class TritonTensorRTLLM(BaseLLM):
         length_penalty: (float) The penalty to apply repeated tokens
         tokens: (int) The maximum number of tokens to generate.
         client: The client object used to communicate with the inference server
-        verbose: flag to pass to the client on creation
+        verbose_client: flag to pass to the client on creation
 
     Example:
         .. code-block:: python
@@ -74,7 +74,7 @@ class TritonTensorRTLLM(BaseLLM):
         description="Request the inference server to load the specified model.\
         Certain Triton configurations do not allow for this operation.",
     )
-    verbose: bool = False
+    verbose_client: bool = False
 
     def __del__(self):
         """Ensure the client streaming connection is properly shutdown"""
@@ -85,7 +85,7 @@ def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate that python package exists in environment."""
         if not values.get("client"):
             values["client"] = grpcclient.InferenceServerClient(
-                values["server_url"], verbose=values.get("verbose", False)
+                values["server_url"], verbose=values.get("verbose_client", False)
             )
         return values
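
The rename also keeps this client-logging flag distinct from the verbose field that LangChain's BaseLLM already defines for chain-level logging (a motivation inferred from the code, not stated in the commit message). Below is a minimal sketch of the renamed flag in use; the server URL and model name are placeholders, not values taken from this commit:

# Hypothetical usage sketch; the URL and model name are placeholders.
from langchain_nvidia_trt import TritonTensorRTLLM

# verbose_client is forwarded to tritonclient's
# grpcclient.InferenceServerClient(..., verbose=...), so the client
# prints each RPC it issues (e.g. "is_server_live") to stdout.
llm = TritonTensorRTLLM(
    server_url="localhost:8001",
    model_name="ensemble",
    verbose_client=True,
)

Note that validate_environment only builds a client when none is injected, so a pre-built client passed via client= keeps whatever verbosity it was created with.
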
libs/partners/nvidia-trt/tests/unit_tests/test_llms.py (19 changes: 8 additions & 11 deletions)
@@ -1,9 +1,7 @@
 """Test TritonTensorRT Chat API wrapper."""
 import sys
 from io import StringIO
-
-import pytest
-from tritonclient.utils import InferenceServerException
+from unittest.mock import patch
 
 from langchain_nvidia_trt import TritonTensorRTLLM
 
@@ -13,23 +11,22 @@ def test_initialization() -> None:
     TritonTensorRTLLM(model_name="ensemble", server_url="http://localhost:8001")
 
 
-def test_default_verbose() -> None:
+@patch('tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub')
+def test_default_verbose(ignore) -> None:
     llm = TritonTensorRTLLM(server_url="http://localhost:8001", model_name="ensemble")
     captured = StringIO()
     sys.stdout = captured
-    with pytest.raises(InferenceServerException):
-        llm.client.is_server_live()
+    llm.client.is_server_live()
     sys.stdout = sys.__stdout__
     assert "is_server_live" not in captured.getvalue()
 
 
-def test_verbose() -> None:
+@patch('tritonclient.grpc.service_pb2_grpc.GRPCInferenceServiceStub')
+def test_verbose(ignore) -> None:
     llm = TritonTensorRTLLM(
-        server_url="http://localhost:8001", model_name="ensemble", verbose=True
+        server_url="http://localhost:8001", model_name="ensemble", verbose_client=True
     )
     captured = StringIO()
     sys.stdout = captured
-    with pytest.raises(InferenceServerException):
-        llm.client.is_server_live()
+    llm.client.is_server_live()
     sys.stdout = sys.__stdout__
     assert "is_server_live" in captured.getvalue()
