Skip to content

Commit

Permalink
Merge pull request #364 from gabe-l-hart/RouteInfoFromBackend
Browse files Browse the repository at this point in the history
Forward get_route_info and ROUTE_INFO_HEADER_KEY from backend
  • Loading branch information
gabe-l-hart authored Jun 19, 2024
2 parents 3b2f8fd + ff7f056 commit 4ce8435
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 57 deletions.
13 changes: 1 addition & 12 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
get_route_info,
)
from ...toolkit.verbalizer_utils import render_verbalizer
from . import PeftPromptTuning
Expand Down Expand Up @@ -354,15 +353,5 @@ def _register_model_connection_with_context(
a context override provided.
"""
if self._tgis_backend:
if route_info := get_route_info(context):
log.debug(
"<NLP10705560D> Registering remote model connection with context "
"override: 'hostname: %s'",
route_info,
)
self._tgis_backend.register_model_connection(
self.base_model_name,
{"hostname": route_info},
fill_with_defaults=True,
)
self._tgis_backend.handle_runtime_context(self.base_model_name, context)
self._model_loaded = True
11 changes: 1 addition & 10 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@
from ...toolkit.text_generation.tgis_utils import (
GENERATE_FUNCTION_TGIS_ARGS,
TGISGenerationClient,
get_route_info,
)
from .text_generation_local import TextGeneration

Expand Down Expand Up @@ -362,13 +361,5 @@ def _register_model_connection_with_context(
a context override provided.
"""
if self._tgis_backend:
if route_info := get_route_info(context):
log.debug(
"<NLP15770311D> Registering remote model connection with context "
"override: 'hostname: %s'",
route_info,
)
self._tgis_backend.register_model_connection(
self.model_name, {"hostname": route_info}, fill_with_defaults=True
)
self._tgis_backend.handle_runtime_context(self.model_name, context)
self._model_loaded = True
37 changes: 5 additions & 32 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
"""This file is for helper functions related to TGIS."""

# Standard
from typing import Iterable, Optional
from typing import Iterable

# Third Party
import fastapi
import grpc

# First Party
Expand All @@ -34,7 +33,7 @@
TokenizationResults,
TokenStreamDetails,
)
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
from caikit_tgis_backend.protobufs import generation_pb2
import alog

Expand Down Expand Up @@ -87,7 +86,9 @@
}

# HTTP Header / gRPC Metadata key used to identify a route override
ROUTE_INFO_HEADER_KEY = "x-route-info"
# (forwarded for API compatibility)
ROUTE_INFO_HEADER_KEY = TGISBackend.ROUTE_INFO_HEADER_KEY
get_route_info = TGISBackend.get_route_info


def raise_caikit_core_exception(rpc_error: grpc.RpcError):
Expand Down Expand Up @@ -688,31 +689,3 @@ def unary_tokenize(
return TokenizationResults(
token_count=response.token_count,
)


def get_route_info(
    context: Optional[RuntimeServerContextType],
) -> Optional[str]:
    """Look up the route override carried by a request context.

    The value is read from the entry keyed by ``ROUTE_INFO_HEADER_KEY``: the
    invocation metadata for a gRPC context, or the headers for an HTTP
    (fastapi) request.

    Args:
        context: The request context, or ``None`` when no context is
            available. Supported types are ``grpc.ServicerContext`` and
            ``fastapi.Request``.

    Returns:
        The route info string when the key is present with a non-empty value;
        ``None`` when ``context`` is ``None`` or the key is missing/empty.

    Raises:
        ValueError: Passed to ``error.log_raise`` when ``context`` is neither
            ``None`` nor one of the supported context types.
    """
    if context is None:
        return None

    # Default covers the case where no supported branch assigns a value.
    route_info = None
    if isinstance(context, grpc.ServicerContext):
        # gRPC metadata is a sequence of (key, value) pairs
        route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_HEADER_KEY)
    elif isinstance(context, fastapi.Request):
        route_info = context.headers.get(ROUTE_INFO_HEADER_KEY)
    else:
        error.log_raise(
            "<NLP92615097E>",
            ValueError(f"context is of an unsupported type: {type(context)}"),
        )

    # An empty value is treated the same as a missing one.
    return route_info or None
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0",
"caikit-tgis-backend>=0.1.33,<0.2.0",
"caikit[runtime-grpc,runtime-http]>=0.26.34,<0.27.0",
"caikit-tgis-backend>=0.1.34,<0.2.0",
# TODO: loosen dependencies
"grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
"grpcio-reflection>=1.62.2",
Expand Down
5 changes: 4 additions & 1 deletion tests/toolkit/text_generation/test_tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def test_TGISGenerationClient_rpc_errors(status_code, method):
assert isinstance(rpc_err, grpc.RpcError)


# NOTE: This test is preserved in caikit-nlp despite being duplicated in
# caikit-tgis-backend so that we guarantee that the functionality is accessible
# in a version-compatible way here.
@pytest.mark.parametrize(
argnames=["context", "route_info"],
argvalues=[
Expand Down Expand Up @@ -168,7 +171,7 @@ def test_TGISGenerationClient_rpc_errors(status_code, method):
)
def test_get_route_info(context: RuntimeServerContextType, route_info: Optional[str]):
if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))):
with pytest.raises(ValueError):
with pytest.raises(TypeError):
tgis_utils.get_route_info(context)
else:
actual_route_info = tgis_utils.get_route_info(context)
Expand Down

0 comments on commit 4ce8435

Please sign in to comment.