diff --git a/caikit/runtime/service_generation/rpcs.py b/caikit/runtime/service_generation/rpcs.py index 438ba3ead..521bfe87c 100644 --- a/caikit/runtime/service_generation/rpcs.py +++ b/caikit/runtime/service_generation/rpcs.py @@ -63,7 +63,21 @@ def name(self) -> str: def create_request_data_model(self, package_name: str) -> Type[DataBase]: """Dynamically create data model for this RPC's request message""" - return self.request.create_data_model(package_name) + request_data_model = self.request.create_data_model(package_name) + if isinstance(self, TaskPredictRPC): + # add the DM to the task class for inference rpcs + if self._input_streaming and self._output_streaming: + setattr(self.task, "BIDI_REQUEST_DATA_MODEL", request_data_model) + elif self._output_streaming: + setattr(self.task, "SERVER_REQUEST_DATA_MODEL", request_data_model) + elif self._input_streaming: + setattr(self.task, "CLIENT_REQUEST_DATA_MODEL", request_data_model) + else: + setattr(self.task, "UNARY_REQUEST_DATA_MODEL", request_data_model) + if isinstance(self, ModuleClassTrainRPC): + # add the DM to the module class directly for train rpcs + setattr(self.clz, "TRAIN_REQUEST_DATA_MODEL", request_data_model) + return request_data_model def create_rpc_json(self, package_name: str) -> Dict: """Return json snippet for the service definition of this RPC""" @@ -124,6 +138,8 @@ def create_request_data_model(self, package_name: str): request data model""" # Build the inner request data model inner_request_data_model = self._inner_request.create_data_model(package_name) + + setattr(self.clz, "TRAINING_PARAMETERS_DATA_MODEL", inner_request_data_model) # Insert the new type into the outer request for triple_index, _ in enumerate(self._req.triples): if self._req.triples[triple_index][1] == "parameters": diff --git a/examples/text-sentiment/client.py b/examples/text-sentiment/client.py index 7e13d7e09..2523d7122 100644 --- a/examples/text-sentiment/client.py +++ b/examples/text-sentiment/client.py @@ -21,8 +21,8 @@ # Local from caikit.config.config import get_config -from caikit.core.data_model.base import DataBase from caikit.runtime.service_factory import ServicePackageFactory +from text_sentiment.runtime_model import HuggingFaceSentimentModule import caikit if __name__ == "__main__": @@ -51,10 +51,9 @@ # Run inference for two sample prompts for text in ["I am not feeling well today!", "Today is a nice sunny day"]: - predict_class = DataBase.get_class_for_name( - "HuggingFaceSentimentTaskRequest" - ) - request = predict_class(text_input=text).to_proto() + request = HuggingFaceSentimentModule.TASK_CLASS.UNARY_REQUEST_DATA_MODEL( + text_input=text + ).to_proto() response = client_stub.HuggingFaceSentimentTaskPredict( request, metadata=[("mm-model-id", model_id)], timeout=1 ) diff --git a/tests/runtime/test_grpc_server.py b/tests/runtime/test_grpc_server.py index 1c0a3de59..0fc6b2fbf 100644 --- a/tests/runtime/test_grpc_server.py +++ b/tests/runtime/test_grpc_server.py @@ -212,8 +212,11 @@ def test_predict_sample_module_ok_response( ): """Test RPC CaikitRuntime.SampleTaskPredict successful response""" stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel()) - predict_class = DataBase.get_class_for_name("SampleTaskRequest") - predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() + predict_request = ( + sample_lib.modules.SampleModule.TASK_CLASS.UNARY_REQUEST_DATA_MODEL( + sample_input=HAPPY_PATH_INPUT_DM + ).to_proto() + ) actual_response = stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", sample_task_model_id)] @@ -397,15 +400,9 @@ def test_train_fake_module_ok_response_and_can_predict_with_trained_model( ) ) model_name = random_test_id() - train_request_class = DataBase.get_class_for_name( - "SampleTaskSampleModuleTrainRequest" - ) - train_request_params_class = DataBase.get_class_for_name( - "SampleTaskSampleModuleTrainParameters" - ) - train_request = train_request_class( + train_request = sample_lib.modules.SampleModule.TRAIN_REQUEST_DATA_MODEL( model_name=model_name, - parameters=train_request_params_class( + parameters=sample_lib.modules.SampleModule.TRAINING_PARAMETERS_DATA_MODEL( training_data=training_data, union_list=["str", "sequence"], ),