fix failing test and add notes on change in outputs
Signed-off-by: Anh-Uong <[email protected]>
anhuong committed Sep 13, 2023
1 parent 71d3cb6 commit 80a5157
Showing 6 changed files with 215 additions and 108 deletions.
4 changes: 0 additions & 4 deletions examples/text-sentiment/client.py
@@ -51,14 +51,10 @@
 
     # Run inference for two sample prompts
     for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
-        # TODO: is this the recommended approach for setting up client and request?
         predict_class = DataBase.get_class_for_name(
             "HuggingFaceSentimentTaskRequest"
         )
         request = predict_class(text_input=text).to_proto()
-        # request = inference_service.messages.HuggingFaceSentimentTaskRequest(
-        #     text_input=text
-        # )
         response = client_stub.HuggingFaceSentimentTaskPredict(
             request, metadata=[("mm-model-id", model_id)], timeout=1
         )
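
For context, the new client pattern resolves the request class through the shared data model rather than the generated service messages, then serializes it for the RPC call. A minimal end-to-end sketch of that pattern follows; the channel address, model id, and service-package wiring are illustrative assumptions, not part of this diff:

import grpc

from caikit.core.data_model.base import DataBase
from caikit.runtime.service_factory import ServicePackageFactory

# Assumed setup: a running caikit runtime with a loaded text-sentiment model
inference_service = ServicePackageFactory.get_service_package(
    ServicePackageFactory.ServiceType.INFERENCE
)
client_stub = inference_service.stub_class(grpc.insecure_channel("localhost:8085"))

# Resolve the request class by name, build it pythonically, then
# convert it to the protobuf message the stub expects
predict_class = DataBase.get_class_for_name("HuggingFaceSentimentTaskRequest")
request = predict_class(text_input="Today is a nice sunny day").to_proto()
response = client_stub.HuggingFaceSentimentTaskPredict(
    request, metadata=[("mm-model-id", "text-sentiment")], timeout=1
)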
@@ -18,6 +18,7 @@
 import grpc
 
 # Local
+from caikit.core.data_model.base import DataBase
 from caikit.runtime.interceptors.caikit_runtime_server_wrapper import (
     CaikitRuntimeServerWrapper,
 )
@@ -47,9 +48,8 @@ def predict(request, context, caikit_rpc):
     client = sample_inference_service.stub_class(
         grpc.insecure_channel(f"localhost:{open_port}")
     )
-    _ = client.SampleTaskPredict(
-        sample_inference_service.messages.SampleTaskRequest(), timeout=3
-    )
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    _ = client.SampleTaskPredict(predict_class().to_proto(), timeout=3)
     assert len(calls) == 1
     assert isinstance(calls[0], TaskPredictRPC)
     assert calls[0].name == "SampleTaskPredict"
26 changes: 8 additions & 18 deletions tests/runtime/servicers/test_global_predict_servicer_impl.py
@@ -111,19 +111,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs(
 ):
     """Simple test that our SampleModule's bidirectional stream inference fn is supported"""
 
-    def req_iterator() -> Iterator[
-        sample_inference_service.messages.BidiStreamingSampleTaskRequest
-    ]:
-        # TODO: can this be replaced?
-        # Guessing not since got error: caikit.runtime.types.caikit_runtime_exception.CaikitRuntimeException: Exception raised during inference. This may be a problem with your input: BidiStreamingSampleTaskRequest.__init__() got an unexpected keyword argument 'sample_input'
-        # predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest")
-        # predict_class(
-        #     sample_input=HAPPY_PATH_INPUT_DM
-        # ).to_proto()
+    predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest")
+
+    def req_iterator() -> Iterator[predict_class]:
         for i in range(100):
-            yield sample_inference_service.messages.BidiStreamingSampleTaskRequest(
-                sample_inputs=HAPPY_PATH_INPUT
-            )
+            yield predict_class(sample_inputs=HAPPY_PATH_INPUT_DM).to_proto()
 
     response_stream = sample_predict_servicer.Predict(
         req_iterator(),
@@ -147,13 +139,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs_with_multiple_stre
     mock_manager = MagicMock()
     mock_manager.retrieve_model.return_value = GeoStreamingModule()
 
-    def req_iterator() -> Iterator[
-        sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest
-    ]:
+    predict_class = DataBase.get_class_for_name("BidiStreamingGeoSpatialTaskRequest")
+
+    def req_iterator() -> Iterator[predict_class]:
         for i in range(100):
-            yield sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest(
-                lats=i, lons=100 - i, name="Gabe"
-            )
+            yield predict_class(lats=i, lons=100 - i, name="Gabe").to_proto()
 
     with patch.object(sample_predict_servicer, "_model_manager", mock_manager):
         response_stream = sample_predict_servicer.Predict(
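
The same resolution works for the bidirectional-streaming request types: build each element as a pythonic data model object, then convert it to its protobuf form before it enters the stream. A condensed sketch of the iterator pattern above; SampleInputType stands in for the HAPPY_PATH_INPUT_DM fixture, and the element count is illustrative:

from typing import Iterator

from caikit.core.data_model.base import DataBase
from sample_lib.data_model import SampleInputType

predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest")

def req_iterator() -> Iterator[predict_class]:
    # Note the plural field name: passing sample_input (singular) is the
    # keyword error quoted in the TODO comment removed above
    for _ in range(10):
        yield predict_class(sample_inputs=SampleInputType(name="Gabe")).to_proto()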
45 changes: 32 additions & 13 deletions tests/runtime/test_grpc_server.py
@@ -59,6 +59,7 @@
     process_pb2_grpc,
 )
 from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory
+from caikit.runtime.utils.servicer_util import build_caikit_library_request_dict
 from sample_lib import InnerModule, SamplePrimitiveModule
 from sample_lib.data_model import (
     OtherOutputType,
@@ -206,34 +207,52 @@ def test_components_preinitialized(reset_globals, open_port):
     assert MODEL_MANAGER._initializers
 
 
-def test_predict_sample_module_proto_ok_response(
+def test_predict_sample_module_ok_response(
     sample_task_model_id, runtime_grpc_server, sample_inference_service
 ):
     """Test RPC CaikitRuntime.SampleTaskPredict successful response"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_request = sample_inference_service.messages.SampleTaskRequest(
-        sample_input=HAPPY_PATH_INPUT
-    )
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
     actual_response = stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", sample_task_model_id)]
     )
 
     assert actual_response == HAPPY_PATH_RESPONSE
 
 
-def test_predict_sample_module_ok_response(
-    sample_task_model_id, runtime_grpc_server, sample_inference_service
+def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_run_kwargs(
+    sample_inference_service,
 ):
-    """Test RPC CaikitRuntime.SampleTaskPredict successful response"""
-    stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
+    """Test using proto versus pythonic data model for inference requests to compare diffs"""
+    # Protobuf input
+    proto_request = sample_inference_service.messages.SampleTaskRequest(
+        sample_input=HAPPY_PATH_INPUT_DM.to_proto(),
+    )
+    proto_request_dict = build_caikit_library_request_dict(
+        proto_request,
+        sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE,
+    )
+
+    # unset fields not included
+    proto_expected_arguments = {"sample_input"}
+    assert proto_request.HasField("throw") is False
+    assert proto_expected_arguments == set(proto_request_dict.keys())
+
+    # Pythonic data model
     predict_class = DataBase.get_class_for_name("SampleTaskRequest")
-    # TODO: is this getting predict_class = SampleModule() ??
-    predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
+    python_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
-    actual_response = stub.SampleTaskPredict(
-        predict_request, metadata=[("mm-model-id", sample_task_model_id)]
+    python_sample_module_request_dict = build_caikit_library_request_dict(
+        python_request,
+        sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE,
     )
 
-    assert actual_response == HAPPY_PATH_RESPONSE
+    # unset fields are included if they have defaults set
+    python_expected_arguments = {"sample_input", "throw"}
+    assert python_request.HasField("throw") is True
+    assert python_expected_arguments == set(python_sample_module_request_dict.keys())
 
 
 def test_predict_streaming_module(
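
The asserts above capture the output change this commit documents: a request built through the pythonic data model serializes fields that carry defaults, so build_caikit_library_request_dict forwards them as run kwargs, whereas a hand-built protobuf message leaves them unset and they are dropped from the request dict. A minimal sketch of the data-model side, with SampleInputType standing in for the HAPPY_PATH_INPUT_DM fixture:

from caikit.core.data_model.base import DataBase
from sample_lib.data_model import SampleInputType

predict_class = DataBase.get_class_for_name("SampleTaskRequest")
python_request = predict_class(sample_input=SampleInputType(name="Gabe")).to_proto()

# "throw" was never passed explicitly, but the data model applied its
# default, so the serialized message reports the field as set
assert python_request.HasField("throw") is True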
4 changes: 3 additions & 1 deletion tests/runtime/test_service_factory.py
@@ -23,6 +23,7 @@
 
 # Local
 from caikit.core.data_model import render_dataobject_protos
+from caikit.core.data_model.base import DataBase
 from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory
 from sample_lib import SampleModule
 from sample_lib.data_model import SampleInputType, SampleOutputType
@@ -359,7 +360,8 @@ def run(
     inference_service = ServicePackageFactory.get_service_package(
         ServicePackageFactory.ServiceType.INFERENCE
     )
-    sample_task_request = inference_service.messages.SampleTaskRequest
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    sample_task_request = predict_class().to_proto()
 
     # Check that the new parameter defined in this backend module exists in the service
     assert "backend_param" in sample_task_request.DESCRIPTOR.fields_by_name.keys()