From bd0ab2669defa76858ea13083d5547c46f3ead5c Mon Sep 17 00:00:00 2001
From: Prashant Gupta
Date: Mon, 18 Sep 2023 12:11:25 -0700
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20introduce=20request=20dm=20fetching?=
 =?UTF-8?q?=20in=20runtime?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Prashant Gupta

---
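Usage notes (review aid; git am ignores text between the --- above and the
diffstat): the new helpers resolve the auto-generated request DataModel
classes from a module class instead of hard-coded name strings like
"SampleTaskRequest". A minimal sketch against the sample_lib test fixtures
touched in this patch; SampleInputType(name="example") and the file name
"train.csv" are illustrative values, not taken from the diff:

    # Local
    from caikit.runtime import get_request, get_train_params, get_train_request
    from sample_lib import SampleModule
    from sample_lib.data_model import SampleInputType
    import caikit.interfaces.common

    # Unary inference request class for SampleModule's task
    predict_request = get_request(SampleModule)(
        sample_input=SampleInputType(name="example")
    ).to_proto()

    # Train request wrapping the nested train parameters, as in the tests below
    stream_type = (
        caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType
    )
    train_request = get_train_request(SampleModule)(
        model_name="example-model",
        parameters=get_train_params(SampleModule)(
            training_data=stream_type(file=stream_type.File(filename="train.csv")),
        ),
    ).to_proto()
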
 caikit/runtime/__init__.py                    |   3 +
 caikit/runtime/service_factory.py             |  34 ++++
 examples/text-sentiment/client.py             |   3 +-
 tests/fixtures/sample_lib/__init__.py         |   9 +-
 tests/fixtures/sample_lib/modules/__init__.py |   1 +
 tests/runtime/test_grpc_server.py             | 140 +++++-----
 6 files changed, 104 insertions(+), 86 deletions(-)

diff --git a/caikit/runtime/__init__.py b/caikit/runtime/__init__.py
index 2068258bf..d15dd1f19 100644
--- a/caikit/runtime/__init__.py
+++ b/caikit/runtime/__init__.py
@@ -11,3 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# Local
+from .service_factory import get_request, get_train_params, get_train_request
diff --git a/caikit/runtime/service_factory.py b/caikit/runtime/service_factory.py
index 1970a82a7..16dd5d3e0 100644
--- a/caikit/runtime/service_factory.py
+++ b/caikit/runtime/service_factory.py
@@ -31,6 +31,7 @@
 # Local
 from caikit import get_config
 from caikit.core import LocalBackend, ModuleBase, registries
+from caikit.core.data_model.base import DataBase
 from caikit.core.data_model.dataobject import _AUTO_GEN_PROTO_CLASSES
 from caikit.interfaces.runtime.data_model import (
     TrainingInfoRequest,
@@ -266,3 +267,36 @@ def _get_and_filter_modules(
         excluded_modules,
     )
     return clean_modules
+
+
+def get_request(
+    module_class: Type[ModuleBase],
+    input_streaming: bool = False,
+    output_streaming: bool = False,
+) -> Type[DataBase]:
+    """Helper function to return the request DataModel for the Module Class"""
+    if input_streaming and output_streaming:
+        request_class_name = f"BidiStreaming{module_class.TASK_CLASS.__name__}Request"
+    elif input_streaming:
+        request_class_name = f"ClientStreaming{module_class.TASK_CLASS.__name__}Request"
+    elif output_streaming:
+        request_class_name = f"ServerStreaming{module_class.TASK_CLASS.__name__}Request"
+    else:
+        request_class_name = f"{module_class.TASK_CLASS.__name__}Request"
+    return DataBase.get_class_for_name(request_class_name)
+
+
+def get_train_request(module_class: Type[ModuleBase]) -> Type[DataBase]:
+    """Helper function to return the train request DataModel for the Module Class"""
+    request_class_name = (
+        f"{module_class.TASK_CLASS.__name__}{module_class.__name__}TrainRequest"
+    )
+    return DataBase.get_class_for_name(request_class_name)
+
+
+def get_train_params(module_class: Type[ModuleBase]) -> Type[DataBase]:
+    """Helper function to return the train parameters DataModel for the Module Class"""
+    request_class_name = (
+        f"{module_class.TASK_CLASS.__name__}{module_class.__name__}TrainParameters"
+    )
+    return DataBase.get_class_for_name(request_class_name)
diff --git a/examples/text-sentiment/client.py b/examples/text-sentiment/client.py
index 2523d7122..54d1c6b7a 100644
--- a/examples/text-sentiment/client.py
+++ b/examples/text-sentiment/client.py
@@ -21,6 +21,7 @@
 
 # Local
 from caikit.config.config import get_config
+from caikit.runtime import get_request
 from caikit.runtime.service_factory import ServicePackageFactory
 from text_sentiment.runtime_model import HuggingFaceSentimentModule
 import caikit
@@ -51,7 +52,7 @@
 
     # Run inference for two sample prompts
     for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
-        request = HuggingFaceSentimentModule.TASK_CLASS.UNARY_REQUEST_DATA_MODEL(
+        request = get_request(HuggingFaceSentimentModule)(
             text_input=text
         ).to_proto()
         response = client_stub.HuggingFaceSentimentTaskPredict(
diff --git a/tests/fixtures/sample_lib/__init__.py b/tests/fixtures/sample_lib/__init__.py
index f218fb0c4..a01314e4d 100644
--- a/tests/fixtures/sample_lib/__init__.py
+++ b/tests/fixtures/sample_lib/__init__.py
@@ -3,7 +3,14 @@
 
 # Local
 from . import data_model, modules
-from .modules import InnerModule, OtherModule, SampleModule, SamplePrimitiveModule
+from .modules import (
+    CompositeModule,
+    InnerModule,
+    OtherModule,
+    SampleModule,
+    SamplePrimitiveModule,
+    StreamingModule,
+)
 from caikit.config import configure
 
 # Run configure for sample_lib configuration
diff --git a/tests/fixtures/sample_lib/modules/__init__.py b/tests/fixtures/sample_lib/modules/__init__.py
index d3eaea6aa..65f60a58f 100644
--- a/tests/fixtures/sample_lib/modules/__init__.py
+++ b/tests/fixtures/sample_lib/modules/__init__.py
@@ -7,4 +7,5 @@
     InnerModule,
     SampleModule,
     SamplePrimitiveModule,
+    StreamingModule,
 )
diff --git a/tests/runtime/test_grpc_server.py b/tests/runtime/test_grpc_server.py
index 0fc6b2fbf..d23dd07fd 100644
--- a/tests/runtime/test_grpc_server.py
+++ b/tests/runtime/test_grpc_server.py
@@ -42,7 +42,6 @@
 # Local
 from caikit import get_config
 from caikit.core import MODEL_MANAGER
-from caikit.core.data_model.base import DataBase
 from caikit.core.data_model.producer import ProducerId
 from caikit.interfaces.runtime.data_model import (
     TrainingInfoRequest,
@@ -50,6 +49,7 @@
     TrainingStatus,
     TrainingStatusResponse,
 )
+from caikit.runtime import get_request, get_train_params, get_train_request
 from caikit.runtime.grpc_server import RuntimeGRPCServer
 from caikit.runtime.model_management.model_manager import ModelManager
 from caikit.runtime.protobufs import (
@@ -60,7 +60,14 @@
 )
 from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory
 from caikit.runtime.utils.servicer_util import build_caikit_library_request_dict
-from sample_lib import InnerModule, SamplePrimitiveModule
+from sample_lib import (
+    CompositeModule,
+    InnerModule,
+    OtherModule,
+    SampleModule,
+    SamplePrimitiveModule,
+    StreamingModule,
+)
 from sample_lib.data_model import (
     OtherOutputType,
     SampleInputType,
@@ -76,7 +83,6 @@
     runtime_grpc_test_server,
 )
 import caikit.interfaces.common
-import sample_lib
 
 ## Helpers #####################################################################
 
@@ -131,11 +137,7 @@ def test_model_train(runtime_grpc_server):
                 "parameters": {
                     "training_data": {
                         "jsondata": {
-                            "data": [
-                                sample_lib.data_model.SampleTrainingType(
-                                    number=1
-                                ).to_dict()
-                            ]
+                            "data": [SampleTrainingType(number=1).to_dict()]
                         },
                     },
                 },
@@ -212,11 +214,9 @@ def test_predict_sample_module_ok_response(
 ):
     """Test RPC CaikitRuntime.SampleTaskPredict successful response"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_request = (
-        sample_lib.modules.SampleModule.TASK_CLASS.UNARY_REQUEST_DATA_MODEL(
-            sample_input=HAPPY_PATH_INPUT_DM
-        ).to_proto()
-    )
+    predict_request = get_request(SampleModule)(
+        sample_input=HAPPY_PATH_INPUT_DM
+    ).to_proto()
 
     actual_response = stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", sample_task_model_id)]
@@ -235,7 +235,7 @@ def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_ru
     )
     proto_request_dict = build_caikit_library_request_dict(
         proto_request,
-        sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE,
+        SampleModule.RUN_SIGNATURE,
     )
 
     # unset fields not included
@@ -244,12 +244,12 @@ def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_ru
     assert proto_expected_arguments == set(proto_request_dict.keys())
 
     # pythonic data model request
-    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_class = get_request(SampleModule)
     python_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
     python_sample_module_request_dict = build_caikit_library_request_dict(
         python_request,
-        sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE,
+        SampleModule.RUN_SIGNATURE,
     )
 
     # unset fields are included if they have defaults set
@@ -263,7 +263,9 @@ def test_predict_streaming_module(
 ):
     """Test RPC CaikitRuntime.StreamingTaskPredict successful response"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_class = DataBase.get_class_for_name("ServerStreamingStreamingTaskRequest")
+    predict_class = get_request(
+        StreamingModule, input_streaming=False, output_streaming=True
+    )
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
     stream = stub.ServerStreamingStreamingTaskPredict(
@@ -285,7 +287,7 @@ def test_predict_sample_module_error_response(
         stub = sample_inference_service.stub_class(
             runtime_grpc_server.make_local_channel()
         )
-        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+        predict_class = get_request(SampleModule)
         predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
         stub.SampleTaskPredict(
@@ -300,7 +302,7 @@ def test_rpc_validation_on_predict(
 ):
     """Check that the server catches models sent to the wrong task RPCs"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_class = DataBase.get_class_for_name("OtherTaskRequest")
+    predict_class = get_request(OtherModule)
     predict_request = predict_class(
         sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM
     ).to_proto()
@@ -331,7 +333,7 @@ def test_rpc_validation_on_predict_for_unsupported_model(
         stub = sample_inference_service.stub_class(
            runtime_grpc_server.make_local_channel()
        )
-        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+        predict_class = get_request(SampleModule)
         predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
         with pytest.raises(grpc.RpcError) as context:
             stub.SampleTaskPredict(
@@ -363,7 +365,10 @@ def test_rpc_validation_on_predict_for_wrong_streaming_flavor(
         stub = sample_inference_service.stub_class(
             runtime_grpc_server.make_local_channel()
         )
-        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+
+        predict_class = get_request(
+            SampleModule,
+        )
         predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
         with pytest.raises(grpc.RpcError) as context:
             response = stub.ServerStreamingSampleTaskPredict(
@@ -400,9 +405,9 @@ def test_train_fake_module_ok_response_and_can_predict_with_trained_model(
         )
     )
     model_name = random_test_id()
-    train_request = sample_lib.modules.SampleModule.TRAIN_REQUEST_DATA_MODEL(
+    train_request = get_train_request(SampleModule)(
         model_name=model_name,
-        parameters=sample_lib.modules.SampleModule.TRAINING_PARAMETERS_DATA_MODEL(
+        parameters=get_train_params(SampleModule)(
             training_data=training_data,
             union_list=["str", "sequence"],
         ),
@@ -418,7 +423,7 @@ def test_train_fake_module_ok_response_and_can_predict_with_trained_model(
     )
 
     # make sure the trained model can run inference
-    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_class = get_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -440,13 +445,9 @@ def test_train_fake_module_ok_response_with_loaded_model_can_predict_with_traine
         model_id=sample_task_model_id
     )
     model_name = random_test_id()
-    train_class = DataBase.get_class_for_name("SampleTaskCompositeModuleTrainRequest")
-    train_request_params_class = DataBase.get_class_for_name(
-        "SampleTaskCompositeModuleTrainParameters"
-    )
-    train_request = train_class(
+    train_request = get_train_request(CompositeModule)(
         model_name=model_name,
-        parameters=train_request_params_class(sample_block=sample_model),
+        parameters=get_train_params(CompositeModule)(sample_block=sample_model),
     ).to_proto()
 
     actual_response = train_stub.SampleTaskCompositeModuleTrain(train_request)
@@ -458,7 +459,7 @@ def test_train_fake_module_ok_response_with_loaded_model_can_predict_with_traine
     )
 
     # make sure the trained model can run inference
-    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_class = get_request(CompositeModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -485,13 +486,9 @@ def test_train_fake_module_does_not_change_another_instance_model_of_block(
     stream_type = caikit.interfaces.common.data_model.DataStreamSourceInt
     training_data = stream_type(file=stream_type.File(filename=sample_int_file))
 
-    train_class = DataBase.get_class_for_name("OtherTaskOtherModuleTrainRequest")
-    train_request_params_class = DataBase.get_class_for_name(
-        "OtherTaskOtherModuleTrainParameters"
-    )
-    train_request = train_class(
+    train_request = get_train_request(OtherModule)(
         model_name="Bar Training",
-        parameters=train_request_params_class(
+        parameters=get_train_params(OtherModule)(
             sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM,
             batch_size=100,
             training_data=training_data,
@@ -509,7 +506,7 @@ def test_train_fake_module_does_not_change_another_instance_model_of_block(
     )
 
     # make sure the trained model can run inference, and the batch size 100 was used
-    predict_class = DataBase.get_class_for_name("OtherTaskRequest")
+    predict_class = get_request(OtherModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     trained_inference_response = inference_stub.OtherTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -540,16 +537,10 @@ def test_train_primitive_model(
     """Test that we can make a successful training and inference call to the
     primitive module using primitive inputs"""
 
     model_name = "primitive_trained_model"
-    train_request_class = DataBase.get_class_for_name(
-        "SampleTaskSamplePrimitiveModuleTrainRequest"
-    )
-    train_request_params_class = DataBase.get_class_for_name(
-        "SampleTaskSamplePrimitiveModuleTrainParameters"
-    )
-    train_request = train_request_class(
+    train_request = get_train_request(SamplePrimitiveModule)(
         model_name=model_name,
-        parameters=train_request_params_class(
+        parameters=get_train_params(SamplePrimitiveModule)(
             sample_input=HAPPY_PATH_INPUT_DM,
             simple_list=["hello", "world"],
             union_list=["str", "sequence"],
@@ -575,7 +566,7 @@ def test_train_primitive_model(
     )
 
     # make sure the trained model can run inference
-    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_class = get_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
 
     inference_response = inference_stub.SampleTaskPredict(
@@ -604,13 +595,10 @@ def test_train_fake_module_ok_response_with_datastream_jsondata(
         )
     )
     model_name = random_test_id()
-    train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest")
-    train_request_params_class = DataBase.get_class_for_name(
-        "SampleTaskSampleModuleTrainParameters"
-    )
-    train_request = train_class(
+
+    train_request = get_train_request(SampleModule)(
         model_name=model_name,
-        parameters=train_request_params_class(
+        parameters=get_train_params(SampleModule)(
             batch_size=42,
             training_data=training_data,
         ),
@@ -625,7 +613,7 @@ def test_train_fake_module_ok_response_with_datastream_jsondata(
     )
 
     # make sure the trained model can run inference
-    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_class = get_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -646,13 +634,10 @@ def test_train_fake_module_ok_response_with_datastream_csv_file(
     stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType
     training_data = stream_type(file=stream_type.File(filename=sample_csv_file))
     model_name = random_test_id()
-    train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest")
-    train_request_params_class = DataBase.get_class_for_name(
-        "SampleTaskSampleModuleTrainParameters"
-    )
-    train_request = train_class(
+
+    train_request = get_train_request(SampleModule)(
         model_name=model_name,
-        parameters=train_request_params_class(
+        parameters=get_train_params(SampleModule)(
             training_data=training_data,
         ),
     ).to_proto()
@@ -666,7 +651,7 @@ def test_train_fake_module_ok_response_with_datastream_csv_file(
     )
 
     # make sure the trained model can run inference
-    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_class = get_request(SampleModule)
     predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
     inference_response = inference_stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -686,13 +671,10 @@ def test_train_and_successfully_cancel_training(
     )
     model_name = random_test_id()
     # start a training that sleeps for a long time, so I can cancel
-    train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest")
-    train_request_params_class = DataBase.get_class_for_name(
-        "SampleTaskSampleModuleTrainParameters"
-    )
-    train_request = train_class(
+
+    train_request = get_train_request(SampleModule)(
         model_name=model_name,
-        parameters=train_request_params_class(
+        parameters=get_train_params(SampleModule)(
             training_data=training_data, sleep_time=10
         ),
     ).to_proto()
@@ -727,13 +709,9 @@ def test_cancel_does_not_affect_other_models(
     )
     model_name = random_test_id()
     # start a training that sleeps for a long time, so I can cancel
-    train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest")
-    train_request_params_class = DataBase.get_class_for_name(
-        "SampleTaskSampleModuleTrainParameters"
-    )
-    train_request = train_class(
+    train_request = get_train_request(SampleModule)(
         model_name=model_name,
-        parameters=train_request_params_class(
+        parameters=get_train_params(SampleModule)(
             training_data=training_data, sleep_time=10
         ),
     ).to_proto()
@@ -754,9 +732,9 @@ def test_cancel_does_not_affect_other_models(
 
     # train another model
     model_name2 = random_test_id()
-    train_request2 = train_class(
+    train_request2 = get_train_request(SampleModule)(
         model_name=model_name2,
-        parameters=train_request_params_class(training_data=training_data),
+        parameters=get_train_params(SampleModule)(training_data=training_data),
     ).to_proto()
 
     train_response2 = train_stub.SampleTaskSampleModuleTrain(train_request2)
@@ -787,15 +765,9 @@ def test_train_fake_module_error_response_with_unloaded_model(
         sample_model = caikit.interfaces.runtime.data_model.ModelPointer(
             model_id=random_test_id()
         )
-        train_class = DataBase.get_class_for_name(
-            "SampleTaskCompositeModuleTrainRequest"
-        )
-        train_request_params_class = DataBase.get_class_for_name(
-            "SampleTaskCompositeModuleTrainParameters"
-        )
-        train_request = train_class(
+        train_request = get_train_request(CompositeModule)(
             model_name=random_test_id(),
-            parameters=train_request_params_class(sample_block=sample_model),
+            parameters=get_train_params(CompositeModule)(sample_block=sample_model),
         ).to_proto()
 
         train_stub.SampleTaskCompositeModuleTrain(train_request)
@@ -1144,7 +1116,7 @@ def test_metrics_stored_after_server_interrupt(
     with temp_config({"runtime": {"metering": {"enabled": True}}}, "merge"):
         with runtime_grpc_test_server(open_port) as server:
             stub = sample_inference_service.stub_class(server.make_local_channel())
-            predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+            predict_class = get_request(SampleModule)
             predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
             _ = stub.SampleTaskPredict(
                 predict_request, metadata=[("mm-model-id", sample_task_model_id)]
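
-- 
Reviewer note, not part of the applied patch: the helpers rebuild exactly the
class names that the removed DataBase.get_class_for_name() calls spelled out
by hand. A quick sanity sketch of the name resolution, assuming the sample_lib
fixtures are importable and configured as in the tests above:

    from caikit.runtime import get_request, get_train_params, get_train_request
    from sample_lib import SampleModule, StreamingModule

    # Unary and output-streaming request classes for the fixture tasks
    assert get_request(SampleModule).__name__ == "SampleTaskRequest"
    assert (
        get_request(StreamingModule, output_streaming=True).__name__
        == "ServerStreamingStreamingTaskRequest"
    )

    # Train request and nested train parameters classes
    assert (
        get_train_request(SampleModule).__name__
        == "SampleTaskSampleModuleTrainRequest"
    )
    assert (
        get_train_params(SampleModule).__name__
        == "SampleTaskSampleModuleTrainParameters"
    )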