Skip to content

Commit

Permalink
Merge pull request #358 from caikit/add_tgis_timeout
Browse files Browse the repository at this point in the history
Add tgis timeout
  • Loading branch information
gkumbhat authored May 17, 2024
2 parents 769812f + fc5ebda commit 98578c9
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
3 changes: 3 additions & 0 deletions caikit_nlp/config/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,6 @@ embedding:

runtime:
library: caikit_nlp

# Configure request timeout for TGIS backend (in seconds)
tgis_request_timeout: 60
15 changes: 12 additions & 3 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import grpc

# First Party
from caikit import get_config
from caikit.core.exceptions import error_handler
from caikit.core.exceptions.caikit_core_exception import (
CaikitCoreException,
Expand Down Expand Up @@ -326,6 +327,8 @@ def __init__(
self.producer_id = producer_id
self.prefix_id = prefix_id

self.tgis_req_timeout = get_config().tgis_request_timeout

def unary_generate(
self,
text,
Expand Down Expand Up @@ -432,7 +435,9 @@ def unary_generate(
# Currently, we send a batch request of len(x)==1, so we expect one response back
with alog.ContextTimer(log.trace, "TGIS request duration: "):
try:
batch_response = self.tgis_client.Generate(request)
batch_response = self.tgis_client.Generate(
request, timeout=self.tgis_req_timeout
)
except grpc.RpcError as err:
raise_caikit_core_exception(err)

Expand Down Expand Up @@ -576,7 +581,9 @@ def stream_generate(

# stream GenerationResponse
try:
stream_response = self.tgis_client.GenerateStream(request)
stream_response = self.tgis_client.GenerateStream(
request, timeout=self.tgis_req_timeout
)

for stream_part in stream_response:
details = TokenStreamDetails(
Expand Down Expand Up @@ -645,7 +652,9 @@ def unary_tokenize(
# Currently, we send a batch request of len(x)==1, so we expect one response back
with alog.ContextTimer(log.trace, "TGIS request duration: "):
try:
batch_response = self.tgis_client.Tokenize(request)
batch_response = self.tgis_client.Tokenize(
request, timeout=self.tgis_req_timeout
)
except grpc.RpcError as err:
raise_caikit_core_exception(err)

Expand Down
12 changes: 6 additions & 6 deletions tests/fixtures/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,17 @@ class StubTGISClient:
def __init__(self, base_model_name):
    """Stub constructor: accepts and ignores base_model_name (no connection is made)."""
    pass

def Generate(self, request, **kwargs):
    """Stub for the unary Generate RPC.

    Extra keyword arguments (e.g. ``timeout``) are accepted and ignored so
    the stub stays call-compatible with the real client.
    """
    response = StubTGISClient.unary_generate(request)
    return response

def GenerateStream(self, request, **kwargs):
    """Stub for the streaming GenerateStream RPC.

    Extra keyword arguments (e.g. ``timeout``) are accepted and ignored.
    """
    stream = StubTGISClient.stream_generate(request)
    return stream

def Tokenize(self, request, **kwargs):
    """Stub for the Tokenize RPC.

    Extra keyword arguments (e.g. ``timeout``) are accepted and ignored.
    """
    response = StubTGISClient.tokenize(request)
    return response

@staticmethod
def unary_generate(request):
def unary_generate(request, **kwargs):
fake_response = mock.Mock()
fake_result = mock.Mock()
fake_result.stop_reason = 5
Expand All @@ -229,7 +229,7 @@ def unary_generate(request):
return fake_response

@staticmethod
def stream_generate(request):
def stream_generate(request, **kwargs):
fake_stream = mock.Mock()
fake_stream.stop_reason = 5
fake_stream.generated_token_count = 1
Expand All @@ -250,7 +250,7 @@ def stream_generate(request):
yield fake_stream

@staticmethod
def tokenize(request):
def tokenize(request, **kwargs):
fake_response = mock.Mock()
fake_result = mock.Mock()
fake_result.token_count = 1
Expand Down
9 changes: 3 additions & 6 deletions tests/toolkit/text_generation/test_tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,19 @@ def _maybe_raise(self, error_type: Type[grpc.RpcError], *args):
)

def Generate(
    self, request: generation_pb2.BatchedGenerationRequest, **kwargs
) -> generation_pb2.BatchedGenerationResponse:
    """Fake unary Generate: may raise per _maybe_raise, else returns an empty response.

    ``**kwargs`` absorbs call options such as ``timeout`` without using them.
    """
    self._maybe_raise(grpc._channel._InactiveRpcError)
    response = generation_pb2.BatchedGenerationResponse()
    return response

def GenerateStream(
    self, request: generation_pb2.SingleGenerationRequest, **kwargs
) -> Iterable[generation_pb2.GenerationResponse]:
    """Fake streaming generate: may raise per _maybe_raise, else yields one empty response.

    ``**kwargs`` absorbs call options such as ``timeout`` without using them.
    """
    self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None)
    response = generation_pb2.GenerationResponse()
    yield response

def Tokenize(
    self, request: generation_pb2.BatchedTokenizeRequest, **kwargs
) -> generation_pb2.BatchedTokenizeResponse:
    """Fake Tokenize: may raise per _maybe_raise, else returns an empty response.

    ``**kwargs`` absorbs call options such as ``timeout`` without using them.
    """
    self._maybe_raise(grpc._channel._InactiveRpcError)
    response = generation_pb2.BatchedTokenizeResponse()
    return response
Expand Down

0 comments on commit 98578c9

Please sign in to comment.