diff --git a/examples/text-sentiment/client.py b/examples/text-sentiment/client.py index b339d5942..7e13d7e09 100644 --- a/examples/text-sentiment/client.py +++ b/examples/text-sentiment/client.py @@ -21,6 +21,7 @@ # Local from caikit.config.config import get_config +from caikit.core.data_model.base import DataBase from caikit.runtime.service_factory import ServicePackageFactory import caikit @@ -50,9 +51,10 @@ # Run inference for two sample prompts for text in ["I am not feeling well today!", "Today is a nice sunny day"]: - request = inference_service.messages.HuggingFaceSentimentTaskRequest( - text_input=text + predict_class = DataBase.get_class_for_name( + "HuggingFaceSentimentTaskRequest" ) + request = predict_class(text_input=text).to_proto() response = client_stub.HuggingFaceSentimentTaskPredict( request, metadata=[("mm-model-id", model_id)], timeout=1 ) diff --git a/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py b/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py index ebfe89446..03eb2b457 100644 --- a/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py +++ b/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py @@ -18,6 +18,7 @@ import grpc # Local +from caikit.core.data_model.base import DataBase from caikit.runtime.interceptors.caikit_runtime_server_wrapper import ( CaikitRuntimeServerWrapper, ) @@ -47,9 +48,8 @@ def predict(request, context, caikit_rpc): client = sample_inference_service.stub_class( grpc.insecure_channel(f"localhost:{open_port}") ) - _ = client.SampleTaskPredict( - sample_inference_service.messages.SampleTaskRequest(), timeout=3 - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + _ = client.SampleTaskPredict(predict_class().to_proto(), timeout=3) assert len(calls) == 1 assert isinstance(calls[0], TaskPredictRPC) assert calls[0].name == "SampleTaskPredict" diff --git a/tests/runtime/servicers/test_global_predict_servicer_impl.py b/tests/runtime/servicers/test_global_predict_servicer_impl.py index cfef456d4..644f3f51d 100644 --- a/tests/runtime/servicers/test_global_predict_servicer_impl.py +++ b/tests/runtime/servicers/test_global_predict_servicer_impl.py @@ -34,6 +34,7 @@ import pytest # Local +from caikit.core.data_model.base import DataBase from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer from caikit.runtime.types.aborted_exception import AbortedException from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException @@ -56,9 +57,8 @@ def test_calling_predict_should_raise_if_module_raises( ): with pytest.raises(CaikitRuntimeException) as context: # SampleModules will raise a RuntimeError if the throw flag is set - request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT, throw=True - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + request = predict_class(sample_input=HAPPY_PATH_INPUT_DM, throw=True).to_proto() sample_predict_servicer.Predict( request, Fixtures.build_context(sample_task_model_id), @@ -77,9 +77,10 @@ def test_invalid_input_to_a_valid_caikit_core_class_method_raises( """Test that a caikit.core module that gets an unexpected input value errors in an expected way""" with pytest.raises(CaikitRuntimeException) as context: # SampleModules will raise a ValueError if the poison pill name is given - request = sample_inference_service.messages.SampleTaskRequest( - sample_input=SampleInputType(name=SampleModule.POISON_PILL_NAME).to_proto(), - ) + predict_class = 
DataBase.get_class_for_name("SampleTaskRequest") + request = predict_class( + sample_input=SampleInputType(name=SampleModule.POISON_PILL_NAME) + ).to_proto() sample_predict_servicer.Predict( request, Fixtures.build_context(sample_task_model_id), @@ -96,10 +97,9 @@ def test_global_predict_works_for_unary_rpcs( sample_task_unary_rpc, ): """Global predict of SampleTaskRequest returns a prediction""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") response = sample_predict_servicer.Predict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), Fixtures.build_context(sample_task_model_id), caikit_rpc=sample_task_unary_rpc, ) @@ -111,13 +111,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs( ): """Simple test that our SampleModule's bidirectional stream inference fn is supported""" - def req_iterator() -> Iterator[ - sample_inference_service.messages.BidiStreamingSampleTaskRequest - ]: + predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest") + + def req_iterator() -> Iterator[predict_class]: for i in range(100): - yield sample_inference_service.messages.BidiStreamingSampleTaskRequest( - sample_inputs=HAPPY_PATH_INPUT - ) + yield predict_class(sample_inputs=HAPPY_PATH_INPUT_DM).to_proto() response_stream = sample_predict_servicer.Predict( req_iterator(), @@ -141,13 +139,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs_with_multiple_stre mock_manager = MagicMock() mock_manager.retrieve_model.return_value = GeoStreamingModule() - def req_iterator() -> Iterator[ - sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest - ]: + predict_class = DataBase.get_class_for_name("BidiStreamingGeoSpatialTaskRequest") + + def req_iterator() -> Iterator[predict_class]: for i in range(100): - yield sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest( - lats=i, lons=100 - i, name="Gabe" - ) + yield predict_class(lats=i, lons=100 - i, name="Gabe").to_proto() with patch.object(sample_predict_servicer, "_model_manager", mock_manager): response_stream = sample_predict_servicer.Predict( @@ -198,12 +194,11 @@ def run(self, *args, **kwargs): mock_manager.retrieve_model.return_value = dummy_model context = Fixtures.build_context("test-any-unresponsive-model") + predict_class = DataBase.get_class_for_name("SampleTaskRequest") predict_thread = threading.Thread( target=sample_predict_servicer.Predict, args=( - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), context, ), kwargs={"caikit_rpc": sample_task_unary_rpc}, @@ -234,9 +229,10 @@ def test_metering_ignore_unsuccessful_calls( gps = GlobalPredictServicer(sample_inference_service) try: with patch.object(gps.rpc_meter, "update_metrics") as mock_update_func: - request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT, throw=True - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + request = predict_class( + sample_input=HAPPY_PATH_INPUT_DM, throw=True + ).to_proto() with pytest.raises(CaikitRuntimeException): gps.Predict( request, @@ -257,11 +253,10 @@ def test_metering_predict_rpc_counter( sample_predict_servicer = GlobalPredictServicer(sample_inference_service) try: # Making 20 requests + predict_class = DataBase.get_class_for_name("SampleTaskRequest") for i in range(20): sample_predict_servicer.Predict( - 
sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), Fixtures.build_context(sample_task_model_id), caikit_rpc=sample_task_unary_rpc, ) @@ -299,10 +294,9 @@ def test_metering_write_to_metrics_file_twice( # need a new servicer to get a fresh new RPC meter sample_predict_servicer = GlobalPredictServicer(sample_inference_service) try: + predict_class = DataBase.get_class_for_name("SampleTaskRequest") sample_predict_servicer.Predict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), Fixtures.build_context(sample_task_model_id), caikit_rpc=sample_task_unary_rpc, ) @@ -311,9 +305,7 @@ def test_metering_write_to_metrics_file_twice( sample_predict_servicer.rpc_meter.flush_metrics() sample_predict_servicer.Predict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), Fixtures.build_context(sample_task_model_id), caikit_rpc=sample_task_unary_rpc, ) diff --git a/tests/runtime/servicers/test_global_train_servicer_impl.py b/tests/runtime/servicers/test_global_train_servicer_impl.py index c42f14c27..fbf5fa7cd 100644 --- a/tests/runtime/servicers/test_global_train_servicer_impl.py +++ b/tests/runtime/servicers/test_global_train_servicer_impl.py @@ -27,6 +27,7 @@ # Local from caikit.config import get_config from caikit.core import MODEL_MANAGER +from caikit.core.data_model.base import DataBase from caikit.core.data_model.producer import ProducerId from caikit.interfaces.common.data_model.stream_sources import S3Path from caikit.runtime.servicers.global_train_servicer import GlobalTrainServicer @@ -45,6 +46,8 @@ ## Helpers ##################################################################### +HAPPY_PATH_INPUT_DM = SampleInputType(name="Gabe") + @contextmanager def set_use_subprocess(use_subprocess: bool): @@ -94,15 +97,19 @@ def test_global_train_sample_task( stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( jsondata=stream_type.JsonData(data=[SampleTrainingType(1)]) - ).to_proto() + ) model_name = random_test_id() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=model_name, - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( batch_size=42, training_data=training_data, ), - ) + ).to_proto() training_response = sample_train_servicer.Train( train_request, Fixtures.build_context("foo") @@ -131,10 +138,9 @@ def test_global_train_sample_task( == "sample_lib.modules.sample_task.sample_implementation.SampleModule" ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") inference_response = sample_predict_servicer.Predict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=SampleInputType(name="Gabe").to_proto() - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), Fixtures.build_context(training_response.model_name), caikit_rpc=sample_task_unary_rpc, ) @@ -158,15 +164,19 @@ def test_global_train_other_task( """ batch_size = 42 stream_type = 
caikit.interfaces.common.data_model.DataStreamSourceInt - training_data = stream_type(jsondata=stream_type.JsonData(data=[1])).to_proto() - train_request = sample_train_service.messages.OtherTaskOtherModuleTrainRequest( + training_data = stream_type(jsondata=stream_type.JsonData(data=[1])) + train_class = DataBase.get_class_for_name("OtherTaskOtherModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "OtherTaskOtherModuleTrainParameters" + ) + train_request = train_class( model_name="Other module Training", - parameters=sample_train_service.messages.OtherTaskOtherModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, - sample_input_sampleinputtype=SampleInputType(name="Gabe").to_proto(), + sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM, batch_size=batch_size, ), - ) + ).to_proto() training_response = sample_train_servicer.Train( train_request, Fixtures.build_context("foo") @@ -187,10 +197,9 @@ def test_global_train_other_task( == "sample_lib.modules.other_task.other_implementation.OtherModule" ) + predict_class = DataBase.get_class_for_name("OtherTaskRequest") inference_response = sample_predict_servicer.Predict( - sample_inference_service.messages.OtherTaskRequest( - sample_input=SampleInputType(name="Gabe").to_proto() - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), Fixtures.build_context(training_response.model_name), caikit_rpc=sample_inference_service.caikit_rpcs["OtherTaskPredict"], ) @@ -214,14 +223,16 @@ def test_global_train_Another_Widget_that_requires_SampleWidget_loaded_should_no """Global train of TrainRequest returns a training job with the correct model name, and some training id for a train function that requires another loaded model""" sample_model = caikit.interfaces.runtime.data_model.ModelPointer( model_id=sample_task_model_id - ).to_proto() + ) - training_request = sample_train_service.messages.SampleTaskCompositeModuleTrainRequest( - model_name="AnotherWidget_Training", - parameters=sample_train_service.messages.SampleTaskCompositeModuleTrainParameters( - sample_block=sample_model, - ), + train_class = DataBase.get_class_for_name("SampleTaskCompositeModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskCompositeModuleTrainParameters" ) + training_request = train_class( + model_name="AnotherWidget_Training", + parameters=train_request_params_class(sample_block=sample_model), + ).to_proto() training_response = sample_train_servicer.Train( training_request, Fixtures.build_context("foo") @@ -244,10 +255,9 @@ def test_global_train_Another_Widget_that_requires_SampleWidget_loaded_should_no ) # make sure the trained model can run inference + predict_class = DataBase.get_class_for_name("SampleTaskRequest") inference_response = sample_predict_servicer.Predict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=SampleInputType(name="Gabe").to_proto() - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), Fixtures.build_context(training_response.model_name), caikit_rpc=sample_task_unary_rpc, ) @@ -269,14 +279,18 @@ def test_run_train_job_works_with_wait( stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( jsondata=stream_type.JsonData(data=[SampleTrainingType(1)]) - ).to_proto() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + ) + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + 
train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( batch_size=42, training_data=training_data, ), - ) + ).to_proto() servicer = GlobalTrainServicer(training_service=sample_train_service) with TemporaryDirectory() as tmp_dir: training_response = servicer.run_training_job( @@ -292,10 +306,9 @@ def test_run_train_job_works_with_wait( training_response.training_id, ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") inference_response = sample_predict_servicer.Predict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=SampleInputType(name="Test").to_proto() - ), + predict_class(sample_input=SampleInputType(name="Test")).to_proto(), Fixtures.build_context(training_response.model_name), caikit_rpc=sample_task_unary_rpc, ) @@ -318,15 +331,15 @@ def test_global_train_Another_Widget_that_requires_SampleWidget_but_not_loaded_s """Global train of TrainRequest raises when calling a train function that requires another loaded model, but model is not loaded""" model_id = random_test_id() - sample_model = caikit.interfaces.runtime.data_model.ModelPointer( - model_id=model_id - ).to_proto() - request = sample_train_service.messages.SampleTaskCompositeModuleTrainRequest( - model_name="AnotherWidget_Training", - parameters=sample_train_service.messages.SampleTaskCompositeModuleTrainParameters( - sample_block=sample_model, - ), + sample_model = caikit.interfaces.runtime.data_model.ModelPointer(model_id=model_id) + train_class = DataBase.get_class_for_name("SampleTaskCompositeModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskCompositeModuleTrainParameters" ) + request = train_class( + model_name="AnotherWidget_Training", + parameters=train_request_params_class(sample_block=sample_model), + ).to_proto() with pytest.raises(CaikitRuntimeException) as context: sample_train_servicer.Train(request, Fixtures.build_context("foo")) @@ -341,15 +354,19 @@ def test_global_train_Edge_Case_Widget_should_raise_when_error_surfaces_from_mod stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( jsondata=stream_type.JsonData(data=[SampleTrainingType(1)]) - ).to_proto() + ) - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( batch_size=999, training_data=training_data, ), - ) + ).to_proto() training_response = sample_train_servicer.Train( train_request, Fixtures.build_context("foo") @@ -369,15 +386,19 @@ def test_global_train_returns_exit_code_with_oom( stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( jsondata=stream_type.JsonData(data=[SampleTrainingType(1)]) - ).to_proto() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + ) + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( 
+ "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( batch_size=42, training_data=training_data, oom_exit=True, ), - ) + ).to_proto() # Enable sub-processing for test with set_use_subprocess(True): @@ -399,16 +420,20 @@ def test_local_trainer_rejects_s3_output_paths( stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( jsondata=stream_type.JsonData(data=[SampleTrainingType(1)]) - ).to_proto() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + ) + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - output_path=S3Path(path="foo").to_proto(), - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + output_path=S3Path(path="foo"), + parameters=train_request_params_class( batch_size=42, training_data=training_data, oom_exit=True, ), - ) + ).to_proto() with pytest.raises( CaikitRuntimeException, match=".*S3 output path not supported by this runtime" @@ -429,17 +454,21 @@ def test_global_train_aborts_long_running_trains( stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( jsondata=stream_type.JsonData(data=[SampleTrainingType(1)]) - ).to_proto() + ) training_id = random_test_id() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=training_id, - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( batch_size=42, training_data=training_data, oom_exit=False, ), - ) + ).to_proto() if sample_train_servicer.use_subprocess: start_method = ( diff --git a/tests/runtime/test_grpc_server.py b/tests/runtime/test_grpc_server.py index c11af4e4d..1c0a3de59 100644 --- a/tests/runtime/test_grpc_server.py +++ b/tests/runtime/test_grpc_server.py @@ -59,6 +59,7 @@ process_pb2_grpc, ) from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory +from caikit.runtime.utils.servicer_util import build_caikit_library_request_dict from sample_lib import InnerModule, SamplePrimitiveModule from sample_lib.data_model import ( OtherOutputType, @@ -81,7 +82,8 @@ log = alog.use_channel("TEST-SERVE-I") -HAPPY_PATH_INPUT = SampleInputType(name="Gabe").to_proto() +HAPPY_PATH_INPUT_DM = SampleInputType(name="Gabe") +HAPPY_PATH_INPUT = HAPPY_PATH_INPUT_DM.to_proto() HAPPY_PATH_RESPONSE = SampleOutputType(greeting="Hello Gabe").to_proto() HAPPY_PATH_TRAIN_RESPONSE = TrainingJob( model_name="dummy name", training_id="dummy id" @@ -210,25 +212,57 @@ def test_predict_sample_module_ok_response( ): """Test RPC CaikitRuntime.SampleTaskPredict successful response""" stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel()) - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = 
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
+
     actual_response = stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", sample_task_model_id)]
     )
+
     assert actual_response == HAPPY_PATH_RESPONSE


+def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_run_kwargs(
+    sample_inference_service,
+):
+    """Compare proto versus pythonic data model inference requests and the run kwargs each produces"""
+    # protobuf request
+    proto_request = sample_inference_service.messages.SampleTaskRequest(
+        sample_input=HAPPY_PATH_INPUT_DM.to_proto(),
+    )
+    proto_request_dict = build_caikit_library_request_dict(
+        proto_request,
+        sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE,
+    )
+
+    # unset fields are not included
+    proto_expected_arguments = {"sample_input"}
+    assert proto_request.HasField("throw") is False
+    assert proto_expected_arguments == set(proto_request_dict.keys())
+
+    # pythonic data model request
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    python_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
+
+    python_sample_module_request_dict = build_caikit_library_request_dict(
+        python_request,
+        sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE,
+    )
+
+    # unset fields are included if they have defaults set
+    python_expected_arguments = {"sample_input", "throw"}
+    assert python_request.HasField("throw") is True
+    assert python_expected_arguments == set(python_sample_module_request_dict.keys())
+
+
 def test_predict_streaming_module(
     streaming_task_model_id, runtime_grpc_server, sample_inference_service
 ):
     """Test RPC CaikitRuntime.StreamingTaskPredict successful response"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_request = (
-        sample_inference_service.messages.ServerStreamingStreamingTaskRequest(
-            sample_input=HAPPY_PATH_INPUT
-        )
-    )
+    predict_class = DataBase.get_class_for_name("ServerStreamingStreamingTaskRequest")
+    predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
+
     stream = stub.ServerStreamingStreamingTaskPredict(
         predict_request, metadata=[("mm-model-id", streaming_task_model_id)]
     )
@@ -248,9 +282,9 @@ def test_predict_sample_module_error_response(
     stub = sample_inference_service.stub_class(
         runtime_grpc_server.make_local_channel()
     )
-    predict_request = sample_inference_service.messages.SampleTaskRequest(
-        sample_input=HAPPY_PATH_INPUT
-    )
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
+
     stub.SampleTaskPredict(
         predict_request, metadata=[("mm-model-id", "random_model_id")]
     )
@@ -263,9 +297,11 @@ def test_rpc_validation_on_predict(
 ):
     """Check that the server catches models sent to the wrong task RPCs"""
     stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-    predict_request = sample_inference_service.messages.OtherTaskRequest(
-        sample_input_sampleinputtype=HAPPY_PATH_INPUT
-    )
+    predict_class = DataBase.get_class_for_name("OtherTaskRequest")
+    predict_request = predict_class(
+        sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM
+    ).to_proto()
+
     with pytest.raises(
         grpc.RpcError,
         match="Wrong inference RPC invoked for model class .* Use SampleTaskPredict instead of OtherTaskPredict",
@@ -292,9 +328,8 @@ def test_rpc_validation_on_predict_for_unsupported_model(
     stub = sample_inference_service.stub_class(
         runtime_grpc_server.make_local_channel()
     )
-    predict_request = 
sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() with pytest.raises(grpc.RpcError) as context: stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", model_id)] @@ -325,9 +360,8 @@ def test_rpc_validation_on_predict_for_wrong_streaming_flavor( stub = sample_inference_service.stub_class( runtime_grpc_server.make_local_channel() ) - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() with pytest.raises(grpc.RpcError) as context: response = stub.ServerStreamingSampleTaskPredict( predict_request, metadata=[("mm-model-id", model_id)] @@ -387,9 +421,8 @@ def test_train_fake_module_ok_response_and_can_predict_with_trained_model( ) # make sure the trained model can run inference - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() inference_response = inference_stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", actual_response.model_name)] ) @@ -408,14 +441,17 @@ def test_train_fake_module_ok_response_with_loaded_model_can_predict_with_traine """Test RPC CaikitRuntime.WorkflowsSampleTaskSampleWorkflowTrain successful response with a loaded model""" sample_model = caikit.interfaces.runtime.data_model.ModelPointer( model_id=sample_task_model_id - ).to_proto() + ) model_name = random_test_id() - train_request = sample_train_service.messages.SampleTaskCompositeModuleTrainRequest( - model_name=model_name, - parameters=sample_train_service.messages.SampleTaskCompositeModuleTrainParameters( - sample_block=sample_model - ), + train_class = DataBase.get_class_for_name("SampleTaskCompositeModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskCompositeModuleTrainParameters" ) + train_request = train_class( + model_name=model_name, + parameters=train_request_params_class(sample_block=sample_model), + ).to_proto() + actual_response = train_stub.SampleTaskCompositeModuleTrain(train_request) assert_training_successful( actual_response, HAPPY_PATH_TRAIN_RESPONSE, model_name, training_management_stub @@ -425,9 +461,8 @@ def test_train_fake_module_ok_response_with_loaded_model_can_predict_with_traine ) # make sure the trained model can run inference - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() inference_response = inference_stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", actual_response.model_name)] ) @@ -451,18 +486,20 @@ def test_train_fake_module_does_not_change_another_instance_model_of_block( # Train an OtherModule with batch size 100 stream_type = caikit.interfaces.common.data_model.DataStreamSourceInt - training_data = stream_type( - file=stream_type.File(filename=sample_int_file) - ).to_proto() + training_data = stream_type(file=stream_type.File(filename=sample_int_file)) - train_request = 
sample_train_service.messages.OtherTaskOtherModuleTrainRequest( + train_class = DataBase.get_class_for_name("OtherTaskOtherModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "OtherTaskOtherModuleTrainParameters" + ) + train_request = train_class( model_name="Bar Training", - parameters=sample_train_service.messages.OtherTaskOtherModuleTrainParameters( - sample_input_sampleinputtype=SampleInputType(name="Gabe").to_proto(), + parameters=train_request_params_class( + sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM, batch_size=100, training_data=training_data, ), - ) + ).to_proto() actual_response = train_stub.OtherTaskOtherModuleTrain(train_request) assert_training_successful( actual_response, @@ -475,9 +512,8 @@ def test_train_fake_module_does_not_change_another_instance_model_of_block( ) # make sure the trained model can run inference, and the batch size 100 was used - predict_request = sample_inference_service.messages.OtherTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("OtherTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() trained_inference_response = inference_stub.OtherTaskPredict( predict_request, metadata=[("mm-model-id", actual_response.model_name)] ) @@ -517,7 +553,7 @@ def test_train_primitive_model( train_request = train_request_class( model_name=model_name, parameters=train_request_params_class( - sample_input=SampleInputType(name="Gabe"), + sample_input=HAPPY_PATH_INPUT_DM, simple_list=["hello", "world"], union_list=["str", "sequence"], union_list2=[1, 2], @@ -542,9 +578,8 @@ def test_train_primitive_model( ) # make sure the trained model can run inference - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() inference_response = inference_stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", training_response.model_name)] @@ -570,15 +605,19 @@ def test_train_fake_module_ok_response_with_datastream_jsondata( jsondata=stream_type.JsonData( data=[SampleTrainingType(1), SampleTrainingType(2)] ) - ).to_proto() + ) model_name = random_test_id() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=model_name, - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( batch_size=42, training_data=training_data, ), - ) + ).to_proto() actual_response = train_stub.SampleTaskSampleModuleTrain(train_request) assert_training_successful( @@ -589,9 +628,8 @@ def test_train_fake_module_ok_response_with_datastream_jsondata( ) # make sure the trained model can run inference - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() inference_response = inference_stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", actual_response.model_name)] ) @@ -609,16 +647,18 @@ def test_train_fake_module_ok_response_with_datastream_csv_file( ): 
"""Test RPC CaikitRuntime.SampleTaskSampleModuleTrainRequest successful response with training data file type""" stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType - training_data = stream_type( - file=stream_type.File(filename=sample_csv_file) - ).to_proto() + training_data = stream_type(file=stream_type.File(filename=sample_csv_file)) model_name = random_test_id() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=model_name, - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, ), - ) + ).to_proto() actual_response = train_stub.SampleTaskSampleModuleTrain(train_request) assert_training_successful( @@ -629,9 +669,8 @@ def test_train_fake_module_ok_response_with_datastream_csv_file( ) # make sure the trained model can run inference - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() inference_response = inference_stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", actual_response.model_name)] ) @@ -650,12 +689,16 @@ def test_train_and_successfully_cancel_training( ) model_name = random_test_id() # start a training that sleeps for a long time, so I can cancel - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=model_name, - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( - training_data=training_data.to_proto(), sleep_time=10 + parameters=train_request_params_class( + training_data=training_data, sleep_time=10 ), - ) + ).to_proto() train_response = train_stub.SampleTaskSampleModuleTrain(train_request) training_id = train_response.training_id @@ -687,12 +730,16 @@ def test_cancel_does_not_affect_other_models( ) model_name = random_test_id() # start a training that sleeps for a long time, so I can cancel - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=model_name, - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( - training_data=training_data.to_proto(), sleep_time=10 + parameters=train_request_params_class( + training_data=training_data, sleep_time=10 ), - ) + ).to_proto() train_response = train_stub.SampleTaskSampleModuleTrain(train_request) assert dir(train_response) == dir(HAPPY_PATH_TRAIN_RESPONSE) @@ -710,12 +757,10 @@ def test_cancel_does_not_affect_other_models( # train another model model_name2 = random_test_id() - train_request2 = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + train_request2 = train_class( model_name=model_name2, - 
parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( - training_data=training_data.to_proto() - ), - ) + parameters=train_request_params_class(training_data=training_data), + ).to_proto() train_response2 = train_stub.SampleTaskSampleModuleTrain(train_request2) # cancel the first training @@ -744,14 +789,18 @@ def test_train_fake_module_error_response_with_unloaded_model( with pytest.raises(grpc.RpcError) as context: sample_model = caikit.interfaces.runtime.data_model.ModelPointer( model_id=random_test_id() + ) + train_class = DataBase.get_class_for_name( + "SampleTaskCompositeModuleTrainRequest" + ) + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskCompositeModuleTrainParameters" + ) + train_request = train_class( + model_name=random_test_id(), + parameters=train_request_params_class(sample_block=sample_model), ).to_proto() - train_request = sample_train_service.messages.SampleTaskCompositeModuleTrainRequest( - model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskCompositeModuleTrainParameters( - sample_block=sample_model - ), - ) train_stub.SampleTaskCompositeModuleTrain(train_request) assert context.value.code() == grpc.StatusCode.NOT_FOUND @@ -1098,9 +1147,8 @@ def test_metrics_stored_after_server_interrupt( with temp_config({"runtime": {"metering": {"enabled": True}}}, "merge"): with runtime_grpc_test_server(open_port) as server: stub = sample_inference_service.stub_class(server.make_local_channel()) - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() _ = stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", sample_task_model_id)] ) diff --git a/tests/runtime/test_service_factory.py b/tests/runtime/test_service_factory.py index 1ecefaf83..36da14285 100644 --- a/tests/runtime/test_service_factory.py +++ b/tests/runtime/test_service_factory.py @@ -23,6 +23,7 @@ # Local from caikit.core.data_model import render_dataobject_protos +from caikit.core.data_model.base import DataBase from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory from sample_lib import SampleModule from sample_lib.data_model import SampleInputType, SampleOutputType @@ -359,7 +360,8 @@ def run( inference_service = ServicePackageFactory.get_service_package( ServicePackageFactory.ServiceType.INFERENCE ) - sample_task_request = inference_service.messages.SampleTaskRequest + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + sample_task_request = predict_class().to_proto() # Check that the new parameter defined in this backend module exists in the service assert "backend_param" in sample_task_request.DESCRIPTOR.fields_by_name.keys() diff --git a/tests/runtime/utils/test_servicer_util.py b/tests/runtime/utils/test_servicer_util.py index c7a6adacd..56932996e 100644 --- a/tests/runtime/utils/test_servicer_util.py +++ b/tests/runtime/utils/test_servicer_util.py @@ -20,6 +20,7 @@ import pytest # Local +from caikit.core.data_model.base import DataBase from caikit.runtime.protobufs import model_runtime_pb2 from caikit.runtime.service_generation.data_stream_source import DataStreamSourceBase from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException @@ -124,10 +125,9 @@ def test_servicer_util_is_protobuf_primitive_returns_true_for_primitive_types( sample_inference_service, ): 
"""Test that is_protobuf_primitive_field is True when considering primitive types""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") assert is_protobuf_primitive_field( - sample_inference_service.messages.SampleTaskRequest().DESCRIPTOR.fields_by_name[ - "int_type" - ] + predict_class().to_proto().DESCRIPTOR.fields_by_name["int_type"] ) @@ -137,10 +137,9 @@ def test_servicer_util_is_protobuf_primitive_returns_false_for_custom_types( """Test that is_protobuf_primitive_field is False when considering message and enum types. This is essential for handling Caikit library CDM objects, which are generally defined in terms of messages""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") assert not is_protobuf_primitive_field( - sample_inference_service.messages.SampleTaskRequest().DESCRIPTOR.fields_by_name[ - "sample_input" - ] + predict_class().to_proto().DESCRIPTOR.fields_by_name["sample_input"] ) @@ -198,22 +197,21 @@ def test_servicer_util_will_not_validate_arbitrary_service_descriptor(): # ---------------- Tests for build_caikit_library_request_dict -------------------- -HAPPY_PATH_INPUT = SampleInputType(name="Gabe").to_proto() +HAPPY_PATH_INPUT_DM = SampleInputType(name="Gabe") def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_run_kwargs( sample_inference_service, ): """Test that build_caikit_library_request_dict creates module run kwargs from RPC msg""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") request_dict = build_caikit_library_request_dict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE, ) - # No self or "throw", throw was not set and the throw parameter contains a default value - expected_arguments = {"sample_input"} + # Since using pythonic data model and throw has a default parameter, it is an expected argument + expected_arguments = {"sample_input", "throw"} assert expected_arguments == set(request_dict.keys()) assert isinstance(request_dict["sample_input"], SampleInputType) @@ -224,16 +222,17 @@ def test_global_predict_build_caikit_library_request_dict_strips_invalid_run_kwa ): """Global predict build_caikit_library_request_dict strips invalid run kwargs from request""" # Sample module doesn't take the `int_type` or `bool_type` params + predict_class = DataBase.get_class_for_name("SampleTaskRequest") request_dict = build_caikit_library_request_dict( - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT, + predict_class( + sample_input=HAPPY_PATH_INPUT_DM, int_type=5, bool_type=True, - ), + ).to_proto(), sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE, ) - expected_arguments = {"sample_input"} + expected_arguments = {"sample_input", "throw"} assert expected_arguments == set(request_dict.keys()) assert "int_type" not in request_dict.keys() @@ -242,8 +241,9 @@ def test_global_predict_build_caikit_library_request_dict_strips_empty_list_from sample_inference_service, ): """Global predict build_caikit_library_request_dict strips empty list from request""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") request_dict = build_caikit_library_request_dict( - sample_inference_service.messages.SampleTaskRequest(int_type=5, list_type=[]), + predict_class(int_type=5, list_type=[]).to_proto(), sample_lib.modules.sample_task.SamplePrimitiveModule.RUN_SIGNATURE, ) @@ -251,10 +251,12 @@ def 
test_global_predict_build_caikit_library_request_dict_strips_empty_list_from assert "int_type" in request_dict.keys() -def test_global_predict_build_caikit_library_request_dict_works_for_unset_primitives( +def test_global_predict_build_caikit_library_request_dict_with_proto_does_not_include_unset_primitives( sample_inference_service, ): """Global predict build_caikit_library_request_dict works for primitives""" + # When using protobuf message for request, if params unset, does not include primitive args + # that are unset even if default values present request = sample_inference_service.messages.SampleTaskRequest() request_dict = build_caikit_library_request_dict( @@ -264,10 +266,31 @@ def test_global_predict_build_caikit_library_request_dict_works_for_unset_primit assert len(request_dict.keys()) == 0 -def test_global_predict_build_caikit_library_request_dict_works_for_set_primitives( +def test_global_predict_build_caikit_library_request_dict_works_for_unset_primitives( + sample_inference_service, +): + """Global predict build_caikit_library_request_dict works for primitives""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + request = predict_class().to_proto() + + request_dict = build_caikit_library_request_dict( + request, SamplePrimitiveModule.RUN_SIGNATURE + ) + # When using pythonic data model for request, primitive args found + # because default values set + assert len(request_dict.keys()) == 5 + assert request_dict["bool_type"] is True + assert request_dict["int_type"] == 42 + assert request_dict["float_type"] == 34.0 + assert request_dict["str_type"] == "moose" + assert request_dict["bytes_type"] == b"" + + +def test_global_predict_build_caikit_library_request_dict_with_proto_for_set_primitives( sample_inference_service, ): """Global predict build_caikit_library_request_dict works for primitives""" + # When using protobuf message for request, only set params included request = sample_inference_service.messages.SampleTaskRequest( int_type=5, float_type=4.2, @@ -286,12 +309,77 @@ def test_global_predict_build_caikit_library_request_dict_works_for_set_primitiv assert request_dict["list_type"] == ["1", "2", "3"] +def test_global_predict_build_caikit_library_request_dict_works_for_set_primitives( + sample_inference_service, +): + """Global predict build_caikit_library_request_dict works for primitives""" + # When using pythonic data model for request, primitive args found that are set and + # ones with default values set + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + request = predict_class( + int_type=5, + float_type=4.2, + str_type="moose", + bytes_type=b"foo", + list_type=["1", "2", "3"], + ).to_proto() + + request_dict = build_caikit_library_request_dict( + request, SamplePrimitiveModule.RUN_SIGNATURE + ) + # bool_type also included because default value set + assert request_dict["bool_type"] is True + assert request_dict["int_type"] == 5 + assert request_dict["float_type"] == 4.2 + assert request_dict["str_type"] == "moose" + assert request_dict["bytes_type"] == b"foo" + assert request_dict["list_type"] == ["1", "2", "3"] + + +@pytest.mark.skip( + "Skipping until bug fixes for unset Union fields - https://github.com/caikit/caikit/issues/471" +) def test_global_train_build_caikit_library_request_dict_strips_empty_list_from_request( sample_train_service, ): """Global train build_caikit_library_request_dict strips empty list from request""" - # NOTE: not sure this test is relevant anymore, since nothing effectively gets removed? 
-    # the datastream is empty but it's not removed from request, which is expected
     stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType
+    training_data = stream_type(jsondata=stream_type.JsonData(data=[]))
+    train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest")
+    train_request_params_class = DataBase.get_class_for_name(
+        "SampleTaskSampleModuleTrainParameters"
+    )
+    train_request = train_class(
+        model_name=random_test_id(),
+        parameters=train_request_params_class(training_data=training_data),
+    ).to_proto()
+
+    caikit.core_request = build_caikit_library_request_dict(
+        train_request.parameters,
+        sample_lib.modules.sample_task.SampleModule.TRAIN_SIGNATURE,
+    )
+
+    # model_name is not expected to be passed through
+    # since the pythonic data model is used, params with default values are kept and empty ones are removed
+    expected_arguments = {
+        "sleep_time",
+        "oom_exit",
+        "sleep_increment",
+        "batch_size",
+    }
+
+    assert expected_arguments == set(caikit.core_request.keys())
+
+
+@pytest.mark.skip(
+    "Skipping until bug fixes for unset Union fields - https://github.com/caikit/caikit/issues/471"
+)
+def test_global_train_build_caikit_library_request_dict_with_proto_keeps_empty_params_from_request(
+    sample_train_service,
+):
+    """Global train build_caikit_library_request_dict keeps empty params in a proto-built request"""
+    # NOTE: When using protobuf to create the request and explicitly passing in training_data,
+    # it is not removed from the request even though the datastream is empty, which is expected
     stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType
     training_data = stream_type(jsondata=stream_type.JsonData(data=[])).to_proto()
     train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest(
@@ -307,7 +395,7 @@
     )

     # model_name is not expected to be passed through
-    expected_arguments = {"training_data", "union_list"}
+    expected_arguments = {"training_data"}

     assert expected_arguments == set(caikit.core_request.keys())
     assert isinstance(caikit.core_request["training_data"], DataStreamSourceBase)
@@ -319,14 +407,18 @@ def test_global_train_build_caikit_library_request_dict_works_for_repeated_field
     """Global train build_caikit_library_request_dict works for repeated fields"""
     stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType
-    training_data = stream_type(jsondata=stream_type.JsonData(data=[])).to_proto()
-    train_request = sample_train_service.messages.SampleTaskListModuleTrainRequest(
+    training_data = stream_type(jsondata=stream_type.JsonData(data=[]))
+    train_class = DataBase.get_class_for_name("SampleTaskListModuleTrainRequest")
+    train_request_params_class = DataBase.get_class_for_name(
+        "SampleTaskListModuleTrainParameters"
+    )
+    train_request = train_class(
         model_name=random_test_id(),
-        parameters=sample_train_service.messages.SampleTaskListModuleTrainParameters(
+        parameters=train_request_params_class(
             training_data=training_data,
             poison_pills=["Bob Marley", "Bunny Livingston"],
         ),
-    )
+    ).to_proto()

     caikit.core_request = build_caikit_library_request_dict(
         train_request.parameters,
@@ -334,7 +426,7 @@
     )

     # model_name is not expected to be passed through
-    expected_arguments = {"training_data", "poison_pills"}
+    expected_arguments = {"poison_pills", "batch_size"}

     assert expected_arguments == 
set(caikit.core_request.keys()) assert len(caikit.core_request.keys()) == 2 @@ -346,18 +438,21 @@ def test_global_train_build_caikit_library_request_dict_ok_with_DataStreamSource sample_train_service, ): stream_type = caikit.interfaces.common.data_model.DataStreamSourceInt - training_data = stream_type( - jsondata=stream_type.JsonData(data=[100, 120]) - ).to_proto() + training_data = stream_type(jsondata=stream_type.JsonData(data=[100, 120])) - train_request = sample_train_service.messages.OtherTaskOtherModuleTrainRequest( + train_class = DataBase.get_class_for_name("OtherTaskOtherModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "OtherTaskOtherModuleTrainParameters" + ) + train_request = train_class( model_name="Bar Training", - parameters=sample_train_service.messages.OtherTaskOtherModuleTrainParameters( - sample_input_sampleinputtype=SampleInputType(name="Gabe").to_proto(), + parameters=train_request_params_class( + sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM, batch_size=100, training_data=training_data, ), - ) + ).to_proto() + caikit.core_request = build_caikit_library_request_dict( train_request.parameters, sample_lib.modules.other_task.OtherModule.TRAIN_SIGNATURE, @@ -368,21 +463,26 @@ def test_global_train_build_caikit_library_request_dict_ok_with_DataStreamSource assert expected_arguments == set(caikit.core_request.keys()) +@pytest.mark.skip( + "Skipping until bug fixes for unset Union fields - https://github.com/caikit/caikit/issues/471" +) def test_global_train_build_caikit_library_request_dict_ok_with_data_stream_file_type_csv( sample_train_service, sample_csv_file ): """Global train build_caikit_library_request_dict works for csv training data file""" stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType - training_data = stream_type( - file=stream_type.File(filename=sample_csv_file) - ).to_proto() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + training_data = stream_type(file=stream_type.File(filename=sample_csv_file)) + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, ), - ) + ).to_proto() caikit.core_request = build_caikit_library_request_dict( train_request.parameters, @@ -390,7 +490,13 @@ def test_global_train_build_caikit_library_request_dict_ok_with_data_stream_file ) # model_name is not expected to be passed through - expected_arguments = {"training_data", "union_list"} + expected_arguments = { + "training_data", + "batch_size", + "oom_exit", + "sleep_time", + "sleep_increment", + } assert expected_arguments == set(caikit.core_request.keys()) @@ -402,14 +508,18 @@ def test_global_train_build_caikit_library_request_dict_ok_with_training_data_as stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( listoffiles=stream_type.ListOfFiles(files=[sample_csv_file, sample_json_file]) - ).to_proto() - train_request = sample_train_service.messages.SampleTaskListModuleTrainRequest( + ) + train_class = DataBase.get_class_for_name("SampleTaskListModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskListModuleTrainParameters" + ) + 
train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskListModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, poison_pills=["Bob Marley", "Bunny Livingston"], ), - ) + ).to_proto() caikit.core_request = build_caikit_library_request_dict( train_request.parameters, @@ -417,10 +527,10 @@ def test_global_train_build_caikit_library_request_dict_ok_with_training_data_as ) # model_name is not expected to be passed through - expected_arguments = {"training_data", "poison_pills"} + expected_arguments = {"training_data", "poison_pills", "batch_size"} assert expected_arguments == set(caikit.core_request.keys()) - assert len(caikit.core_request.keys()) == 2 + assert len(caikit.core_request.keys()) == 3 assert "training_data" in caikit.core_request @@ -442,16 +552,32 @@ def test_build_caikit_library_request_dict_works_when_data_stream_directory_incl ) training_data = stream_type( directory=stream_type.Directory(dirname=tempdir, extension="json") - ).to_proto() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + ) + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, ), - ) + ).to_proto() # no error because at least 1 json file exists within the provided dir caikit.core_request = build_caikit_library_request_dict( train_request.parameters, sample_lib.modules.sample_task.SampleModule.TRAIN_SIGNATURE, ) + + expected_arguments = { + "training_data", + "union_list", + "batch_size", + "oom_exit", + "sleep_time", + "sleep_increment", + } + + assert expected_arguments == set(caikit.core_request.keys()) + assert "training_data" in caikit.core_request
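
Reviewer note: the pattern applied throughout this diff replaces direct construction of the
generated protobuf request messages (sample_inference_service.messages.<RequestName>) with a
name-based lookup of the pythonic data model class, converting to protobuf only at the gRPC
boundary. A minimal sketch of the two styles, distilled from the new
test_global_predict_build_caikit_library_request_dict_creates_caikit_core_run_kwargs test above;
it assumes the sample_lib fixtures and an already-built sample_inference_service, exactly as in
these tests:

    # Local
    from caikit.core.data_model.base import DataBase
    from sample_lib.data_model import SampleInputType

    # Old style: build the generated protobuf message directly; unset primitives
    # such as throw stay unset (proto_request.HasField("throw") is False).
    proto_request = sample_inference_service.messages.SampleTaskRequest(
        sample_input=SampleInputType(name="Gabe").to_proto(),
    )

    # New style: look up the pythonic data model class by name, construct it from
    # data model objects, and convert once at the boundary. Defaulted primitives
    # such as throw are populated (predict_request.HasField("throw") is True), so
    # build_caikit_library_request_dict also passes them through as run kwargs.
    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
    predict_request = predict_class(sample_input=SampleInputType(name="Gabe")).to_proto()

This default-populating behavior is why several expected_arguments sets in the servicer_util
tests grow to include defaulted parameters such as throw and batch_size.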