From da67d7700ad4f0f7b4838656bfcb53948750d812 Mon Sep 17 00:00:00 2001
From: Prashant Gupta
Date: Thu, 21 Sep 2023 09:50:29 -0700
Subject: [PATCH] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20change=20name=20to=20get?=
 =?UTF-8?q?=5Finference=5Frequest?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Prashant Gupta
---
 caikit/runtime/__init__.py        |  2 +-
 caikit/runtime/service_factory.py |  2 +-
 examples/text-sentiment/client.py |  4 ++--
 tests/runtime/test_grpc_server.py | 30 +++++++++++++++---------------
 4 files changed, 19 insertions(+), 19 deletions(-)

diff --git a/caikit/runtime/__init__.py b/caikit/runtime/__init__.py
index d15dd1f19..4b57df6b8 100644
--- a/caikit/runtime/__init__.py
+++ b/caikit/runtime/__init__.py
@@ -13,4 +13,4 @@
 # limitations under the License.
 
 # Local
-from .service_factory import get_request, get_train_params, get_train_request
+from .service_factory import get_inference_request, get_train_params, get_train_request
diff --git a/caikit/runtime/service_factory.py b/caikit/runtime/service_factory.py
index 191244c99..b9ee4d43d 100644
--- a/caikit/runtime/service_factory.py
+++ b/caikit/runtime/service_factory.py
@@ -269,7 +269,7 @@ def _get_and_filter_modules(
     return clean_modules
 
 
-def get_request(
+def get_inference_request(
     module_class: Type[ModuleBase],
     input_streaming: bool = False,
     output_streaming: bool = False,
diff --git a/examples/text-sentiment/client.py b/examples/text-sentiment/client.py
index 54d1c6b7a..48c198bb0 100644
--- a/examples/text-sentiment/client.py
+++ b/examples/text-sentiment/client.py
@@ -21,7 +21,7 @@
 
 # Local
 from caikit.config.config import get_config
-from caikit.runtime import get_request
+from caikit.runtime import get_inference_request
 from caikit.runtime.service_factory import ServicePackageFactory
 from text_sentiment.runtime_model import HuggingFaceSentimentModule
 import caikit
@@ -52,7 +52,7 @@
 
     # Run inference for two sample prompts
    for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
-        request = get_request(HuggingFaceSentimentModule)(
+        request = get_inference_request(HuggingFaceSentimentModule)(
             text_input=text
         ).to_proto()
         response = client_stub.HuggingFaceSentimentTaskPredict(
diff --git a/tests/runtime/test_grpc_server.py b/tests/runtime/test_grpc_server.py
index d23dd07fd..004670c56 100644
--- a/tests/runtime/test_grpc_server.py
+++ b/tests/runtime/test_grpc_server.py
@@ -49,7 +49,7 @@
     TrainingStatus,
     TrainingStatusResponse,
 )
-from caikit.runtime import get_request, get_train_params, get_train_request
+from caikit.runtime import get_inference_request, get_train_params, get_train_request
 from caikit.runtime.grpc_server import RuntimeGRPCServer
 from caikit.runtime.model_management.model_manager import ModelManager
 from caikit.runtime.protobufs import (
@@ -214,7 +214,7 @@ def test_predict_sample_module_ok_response(
 ):
     """Test RPC CaikitRuntime.SampleTaskPredict successful response"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_request = get_request(SampleModule)(
+    predict_request = get_inference_request(SampleModule)(
         sample_input=HAPPY_PATH_INPUT_DM
     ).to_proto()
 
@@ -244,7 +244,7 @@ def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_ru
     assert proto_expected_arguments == set(proto_request_dict.keys())
 
     # pythonic data model request
-    predict_class = get_request(SampleModule)
+    predict_class = get_inference_request(SampleModule)
     python_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
     python_sample_module_request_dict = build_caikit_library_request_dict(
@@ -263,7 +263,7 @@ def test_predict_streaming_module(
 ):
     """Test RPC CaikitRuntime.StreamingTaskPredict successful response"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_class = get_request(
+    predict_class = get_inference_request(
         StreamingModule, input_streaming=False, output_streaming=True
     )
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
@@ -287,7 +287,7 @@ def test_predict_sample_module_error_response(
         stub = sample_inference_service.stub_class(
             runtime_grpc_server.make_local_channel()
         )
-        predict_class = get_request(SampleModule)
+        predict_class = get_inference_request(SampleModule)
         predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
         stub.SampleTaskPredict(
@@ -302,7 +302,7 @@ def test_rpc_validation_on_predict(
 ):
     """Check that the server catches models sent to the wrong task RPCs"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_class = get_request(OtherModule)
+    predict_class = get_inference_request(OtherModule)
     predict_request = predict_class(
         sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM
     ).to_proto()
@@ -333,7 +333,7 @@ def test_rpc_validation_on_predict_for_unsupported_model(
     stub = sample_inference_service.stub_class(
         runtime_grpc_server.make_local_channel()
     )
-    predict_class = get_request(SampleModule)
+    predict_class = get_inference_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     with pytest.raises(grpc.RpcError) as context:
         stub.SampleTaskPredict(
@@ -366,7 +366,7 @@ def test_rpc_validation_on_predict_for_wrong_streaming_flavor(
         runtime_grpc_server.make_local_channel()
     )
 
-    predict_class = get_request(
+    predict_class = get_inference_request(
         SampleModule,
     )
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
@@ -423,7 +423,7 @@ def test_train_fake_module_ok_response_and_can_predict_with_trained_model(
     )
 
     # make sure the trained model can run inference
-    predict_class = get_request(SampleModule)
+    predict_class = get_inference_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -459,7 +459,7 @@ def test_train_fake_module_ok_response_with_loaded_model_can_predict_with_traine
     )
 
     # make sure the trained model can run inference
-    predict_class = get_request(CompositeModule)
+    predict_class = get_inference_request(CompositeModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -506,7 +506,7 @@ def test_train_fake_module_does_not_change_another_instance_model_of_block(
     )
 
     # make sure the trained model can run inference, and the batch size 100 was used
-    predict_class = get_request(OtherModule)
+    predict_class = get_inference_request(OtherModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     trained_inference_response = inference_stub.OtherTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -566,7 +566,7 @@ def test_train_primitive_model(
     )
 
     # make sure the trained model can run inference
-    predict_class = get_request(SampleModule)
+    predict_class = get_inference_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
     inference_response = inference_stub.SampleTaskPredict(
@@ -613,7 +613,7 @@ def test_train_fake_module_ok_response_with_datastream_jsondata(
     )
 
     # make sure the trained model can run inference
-    predict_class = get_request(SampleModule)
+    predict_class = get_inference_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -651,7 +651,7 @@ def test_train_fake_module_ok_response_with_datastream_csv_file(
     )
 
     # make sure the trained model can run inference
-    predict_class = get_request(SampleModule)
+    predict_class = get_inference_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -1116,7 +1116,7 @@ def test_metrics_stored_after_server_interrupt(
     with temp_config({"runtime": {"metering": {"enabled": True}}}, "merge"):
         with runtime_grpc_test_server(open_port) as server:
             stub = sample_inference_service.stub_class(server.make_local_channel())
-            predict_class = get_request(SampleModule)
+            predict_class = get_inference_request(SampleModule)
            predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
             _ = stub.SampleTaskPredict(
                 predict_request, metadata=[("mm-model-id", sample_task_model_id)]
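
Note for readers picking up the rename: only the helper's name changes, not the call pattern. get_inference_request(<module class>) still returns the task's request data model class, which is instantiated with the task inputs and serialized with .to_proto() before being sent over the gRPC stub. The sketch below is illustrative only and mirrors the examples/text-sentiment/client.py hunk above; the caikit.configure() call, the localhost:8085 endpoint, the config path, and the "text_sentiment" model id are assumptions standing in for whatever the deployment actually uses, and the service-package/channel wiring simply follows the existing example client rather than anything introduced by this patch.

    # Illustrative usage of the renamed helper (mirrors examples/text-sentiment/client.py).
    import grpc

    import caikit
    from caikit.runtime import get_inference_request
    from caikit.runtime.service_factory import ServicePackageFactory
    from text_sentiment.runtime_model import HuggingFaceSentimentModule

    # Assumption: runtime config is loaded the same way the example client loads it.
    caikit.configure("config.yml")

    # Build the inference service descriptor and a gRPC stub for it,
    # as the example client does.
    inference_service = ServicePackageFactory.get_service_package(
        ServicePackageFactory.ServiceType.INFERENCE
    )
    channel = grpc.insecure_channel("localhost:8085")  # placeholder endpoint
    client_stub = inference_service.stub_class(channel)

    # Same request construction as before the rename; only the helper name differs.
    request = get_inference_request(HuggingFaceSentimentModule)(
        text_input="Today is a nice sunny day"
    ).to_proto()
    response = client_stub.HuggingFaceSentimentTaskPredict(
        request, metadata=[("mm-model-id", "text_sentiment")]
    )
    print(response)

Run against a runtime that has the model loaded, this returns the same sentiment prediction proto as the old get_request spelling did.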