diff --git a/examples/text-sentiment/client.py b/examples/text-sentiment/client.py index 74b95f6b6..7e13d7e09 100644 --- a/examples/text-sentiment/client.py +++ b/examples/text-sentiment/client.py @@ -51,14 +51,10 @@ # Run inference for two sample prompts for text in ["I am not feeling well today!", "Today is a nice sunny day"]: - # TODO: is this the recommended approach for setting up client and request? predict_class = DataBase.get_class_for_name( "HuggingFaceSentimentTaskRequest" ) request = predict_class(text_input=text).to_proto() - # request = inference_service.messages.HuggingFaceSentimentTaskRequest( - # text_input=text - # ) response = client_stub.HuggingFaceSentimentTaskPredict( request, metadata=[("mm-model-id", model_id)], timeout=1 ) diff --git a/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py b/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py index ebfe89446..03eb2b457 100644 --- a/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py +++ b/tests/runtime/interceptors/test_caikit_runtime_server_wrapper.py @@ -18,6 +18,7 @@ import grpc # Local +from caikit.core.data_model.base import DataBase from caikit.runtime.interceptors.caikit_runtime_server_wrapper import ( CaikitRuntimeServerWrapper, ) @@ -47,9 +48,8 @@ def predict(request, context, caikit_rpc): client = sample_inference_service.stub_class( grpc.insecure_channel(f"localhost:{open_port}") ) - _ = client.SampleTaskPredict( - sample_inference_service.messages.SampleTaskRequest(), timeout=3 - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + _ = client.SampleTaskPredict(predict_class().to_proto(), timeout=3) assert len(calls) == 1 assert isinstance(calls[0], TaskPredictRPC) assert calls[0].name == "SampleTaskPredict" diff --git a/tests/runtime/servicers/test_global_predict_servicer_impl.py b/tests/runtime/servicers/test_global_predict_servicer_impl.py index 406aa40a1..644f3f51d 100644 --- 
a/tests/runtime/servicers/test_global_predict_servicer_impl.py +++ b/tests/runtime/servicers/test_global_predict_servicer_impl.py @@ -111,19 +111,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs( ): """Simple test that our SampleModule's bidirectional stream inference fn is supported""" - def req_iterator() -> Iterator[ - sample_inference_service.messages.BidiStreamingSampleTaskRequest - ]: - # TODO: can this be replaced? - # Guessing not since got error: caikit.runtime.types.caikit_runtime_exception.CaikitRuntimeException: Exception raised during inference. This may be a problem with your input: BidiStreamingSampleTaskRequest.__init__() got an unexpected keyword argument 'sample_input' - # predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest") - # predict_class( - # sample_input=HAPPY_PATH_INPUT_DM - # ).to_proto() + predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest") + + def req_iterator() -> Iterator[predict_class]: for i in range(100): - yield sample_inference_service.messages.BidiStreamingSampleTaskRequest( - sample_inputs=HAPPY_PATH_INPUT - ) + yield predict_class(sample_inputs=HAPPY_PATH_INPUT_DM).to_proto() response_stream = sample_predict_servicer.Predict( req_iterator(), @@ -147,13 +139,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs_with_multiple_stre mock_manager = MagicMock() mock_manager.retrieve_model.return_value = GeoStreamingModule() - def req_iterator() -> Iterator[ - sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest - ]: + predict_class = DataBase.get_class_for_name("BidiStreamingGeoSpatialTaskRequest") + + def req_iterator() -> Iterator[predict_class]: for i in range(100): - yield sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest( - lats=i, lons=100 - i, name="Gabe" - ) + yield predict_class(lats=i, lons=100 - i, name="Gabe").to_proto() with patch.object(sample_predict_servicer, "_model_manager", mock_manager): 
response_stream = sample_predict_servicer.Predict( diff --git a/tests/runtime/test_grpc_server.py b/tests/runtime/test_grpc_server.py index 9913e09c6..1c90252ad 100644 --- a/tests/runtime/test_grpc_server.py +++ b/tests/runtime/test_grpc_server.py @@ -59,6 +59,7 @@ process_pb2_grpc, ) from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory +from caikit.runtime.utils.servicer_util import build_caikit_library_request_dict from sample_lib import InnerModule, SamplePrimitiveModule from sample_lib.data_model import ( OtherOutputType, @@ -206,34 +207,52 @@ def test_components_preinitialized(reset_globals, open_port): assert MODEL_MANAGER._initializers -def test_predict_sample_module_proto_ok_response( +def test_predict_sample_module_ok_response( sample_task_model_id, runtime_grpc_server, sample_inference_service ): """Test RPC CaikitRuntime.SampleTaskPredict successful response""" stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel()) - predict_request = sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ) + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() + actual_response = stub.SampleTaskPredict( predict_request, metadata=[("mm-model-id", sample_task_model_id)] ) + assert actual_response == HAPPY_PATH_RESPONSE -def test_predict_sample_module_ok_response( - sample_task_model_id, runtime_grpc_server, sample_inference_service +def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_run_kwargs( + sample_inference_service, ): - """Test RPC CaikitRuntime.SampleTaskPredict successful response""" - stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel()) + """Test using proto versus pythonic data model for inference requests to compare diffs""" + # Protobuf input + proto_request = sample_inference_service.messages.SampleTaskRequest( + 
sample_input=HAPPY_PATH_INPUT_DM.to_proto(), + ) + proto_request_dict = build_caikit_library_request_dict( + proto_request, + sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE, + ) + + # unset fields not included + proto_expected_arguments = {"sample_input"} + assert proto_request.HasField("throw") is False + assert proto_expected_arguments == set(proto_request_dict.keys()) + + # Pythonic data model predict_class = DataBase.get_class_for_name("SampleTaskRequest") - # TODO: is this getting predict_class = SampleModule() ?? - predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() + python_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto() - actual_response = stub.SampleTaskPredict( - predict_request, metadata=[("mm-model-id", sample_task_model_id)] + python_sample_module_request_dict = build_caikit_library_request_dict( + python_request, + sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE, ) - assert actual_response == HAPPY_PATH_RESPONSE + # unset fields are included if they have defaults set + python_expected_arguments = {"sample_input", "throw"} + assert python_request.HasField("throw") is True + assert python_expected_arguments == set(python_sample_module_request_dict.keys()) def test_predict_streaming_module( diff --git a/tests/runtime/test_service_factory.py b/tests/runtime/test_service_factory.py index 1ecefaf83..36da14285 100644 --- a/tests/runtime/test_service_factory.py +++ b/tests/runtime/test_service_factory.py @@ -23,6 +23,7 @@ # Local from caikit.core.data_model import render_dataobject_protos +from caikit.core.data_model.base import DataBase from caikit.runtime.service_factory import ServicePackage, ServicePackageFactory from sample_lib import SampleModule from sample_lib.data_model import SampleInputType, SampleOutputType @@ -359,7 +360,8 @@ def run( inference_service = ServicePackageFactory.get_service_package( ServicePackageFactory.ServiceType.INFERENCE ) - sample_task_request = 
inference_service.messages.SampleTaskRequest + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + sample_task_request = predict_class().to_proto() # Check that the new parameter defined in this backend module exists in the service assert "backend_param" in sample_task_request.DESCRIPTOR.fields_by_name.keys() diff --git a/tests/runtime/utils/test_servicer_util.py b/tests/runtime/utils/test_servicer_util.py index 921cb13e2..a4c3d8a07 100644 --- a/tests/runtime/utils/test_servicer_util.py +++ b/tests/runtime/utils/test_servicer_util.py @@ -111,7 +111,9 @@ def test_service_util_validate_caikit_library_class_method_exists_does_raise(): # ---------------- Tests for build_proto_response -------------------------- -def test_servicer_util_build_proto_response_raises_on_garbage_response_type(): +def test_servicer_util_build_proto_response_raises_on_garbage_response_type( + sample_inference_service, +): class FooResponse: def __init__(self, foo) -> None: self.foo = foo @@ -125,23 +127,19 @@ def test_servicer_util_is_protobuf_primitive_returns_true_for_primitive_types( sample_inference_service, ): """Test that is_protobuf_primitive_field is True when considering primitive types""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") assert is_protobuf_primitive_field( - sample_inference_service.messages.SampleTaskRequest().DESCRIPTOR.fields_by_name[ - "int_type" - ] + predict_class().to_proto().DESCRIPTOR.fields_by_name["int_type"] ) -def test_servicer_util_is_protobuf_primitive_returns_false_for_custom_types( - sample_inference_service, -): +def test_servicer_util_is_protobuf_primitive_returns_false_for_custom_types(): """Test that is_protobuf_primitive_field is False when considering message and enum types. 
This is essential for handling Caikit library CDM objects, which are generally defined in terms of messages""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") assert not is_protobuf_primitive_field( - sample_inference_service.messages.SampleTaskRequest().DESCRIPTOR.fields_by_name[ - "sample_input" - ] + predict_class().to_proto().DESCRIPTOR.fields_by_name["sample_input"] ) @@ -199,7 +197,7 @@ def test_servicer_util_will_not_validate_arbitrary_service_descriptor(): # ---------------- Tests for build_caikit_library_request_dict -------------------- -HAPPY_PATH_INPUT = SampleInputType(name="Gabe").to_proto() +HAPPY_PATH_INPUT_DM = SampleInputType(name="Gabe") def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_run_kwargs( @@ -207,19 +205,13 @@ def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_ru ): """Test that build_caikit_library_request_dict creates module run kwargs from RPC msg""" predict_class = DataBase.get_class_for_name("SampleTaskRequest") - # TODO: this change caused error: where output != expected_arguments -- {'sample_input', 'throw'} request_dict = build_caikit_library_request_dict( - # predict_class( - # sample_input=HAPPY_PATH_INPUT_DM - # ).to_proto(), - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT - ), + predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(), sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE, ) - # No self or "throw", throw was not set and the throw parameter contains a default value - expected_arguments = {"sample_input"} + # Since using pythonic data model and throw has a default parameter, it is an expected argument + expected_arguments = {"sample_input", "throw"} assert expected_arguments == set(request_dict.keys()) assert isinstance(request_dict["sample_input"], SampleInputType) @@ -232,17 +224,15 @@ def test_global_predict_build_caikit_library_request_dict_strips_invalid_run_kwa # Sample module doesn't 
take the `int_type` or `bool_type` params predict_class = DataBase.get_class_for_name("SampleTaskRequest") request_dict = build_caikit_library_request_dict( - # predict_class( - # sample_input=HAPPY_PATH_INPUT_DM, - sample_inference_service.messages.SampleTaskRequest( - sample_input=HAPPY_PATH_INPUT, + predict_class( + sample_input=HAPPY_PATH_INPUT_DM, int_type=5, bool_type=True, - ), + ).to_proto(), sample_lib.modules.sample_task.SampleModule.RUN_SIGNATURE, ) - expected_arguments = {"sample_input"} + expected_arguments = {"sample_input", "throw"} assert expected_arguments == set(request_dict.keys()) assert "int_type" not in request_dict.keys() @@ -251,7 +241,6 @@ def test_global_predict_build_caikit_library_request_dict_strips_empty_list_from sample_inference_service, ): """Global predict build_caikit_library_request_dict strips empty list from request""" - # TODO: are these ones worth replacing...? predict_class = DataBase.get_class_for_name("SampleTaskRequest") request_dict = build_caikit_library_request_dict( predict_class(int_type=5, list_type=[]).to_proto(), @@ -262,10 +251,12 @@ def test_global_predict_build_caikit_library_request_dict_strips_empty_list_from assert "int_type" in request_dict.keys() -def test_global_predict_build_caikit_library_request_dict_works_for_unset_primitives( +def test_global_predict_build_caikit_library_request_dict_with_proto_does_not_include_unset_primitives( sample_inference_service, ): """Global predict build_caikit_library_request_dict works for primitives""" + # When using protobuf message for request, if params unset, does not include primitive args + # that are unset even if default values present request = sample_inference_service.messages.SampleTaskRequest() request_dict = build_caikit_library_request_dict( @@ -275,10 +266,31 @@ def test_global_predict_build_caikit_library_request_dict_works_for_unset_primit assert len(request_dict.keys()) == 0 -def 
test_global_predict_build_caikit_library_request_dict_works_for_set_primitives( +def test_global_predict_build_caikit_library_request_dict_works_for_unset_primitives( sample_inference_service, ): """Global predict build_caikit_library_request_dict works for primitives""" + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + request = predict_class().to_proto() + + request_dict = build_caikit_library_request_dict( + request, SamplePrimitiveModule.RUN_SIGNATURE + ) + # When using pythonic data model for request, primitive args found + # because default values set + assert len(request_dict.keys()) == 5 + assert request_dict["bool_type"] is True + assert request_dict["int_type"] == 42 + assert request_dict["float_type"] == 34.0 + assert request_dict["str_type"] == "moose" + assert request_dict["bytes_type"] == b"" + + +def test_global_predict_build_caikit_library_request_dict_with_proto_for_set_primitives( + sample_inference_service, +): + """Global predict build_caikit_library_request_dict works for primitives""" + # When using protobuf message for request, only set params included request = sample_inference_service.messages.SampleTaskRequest( int_type=5, float_type=4.2, @@ -297,23 +309,39 @@ def test_global_predict_build_caikit_library_request_dict_works_for_set_primitiv assert request_dict["list_type"] == ["1", "2", "3"] +def test_global_predict_build_caikit_library_request_dict_works_for_set_primitives( + sample_inference_service, +): + """Global predict build_caikit_library_request_dict works for primitives""" + # When using pythonic data model for request, primitive args found that are set and + # ones with default values set + predict_class = DataBase.get_class_for_name("SampleTaskRequest") + request = predict_class( + int_type=5, + float_type=4.2, + str_type="moose", + bytes_type=b"foo", + list_type=["1", "2", "3"], + ).to_proto() + + request_dict = build_caikit_library_request_dict( + request, SamplePrimitiveModule.RUN_SIGNATURE + ) + # 
bool_type also included because default value set + assert request_dict["bool_type"] is True + assert request_dict["int_type"] == 5 + assert request_dict["float_type"] == 4.2 + assert request_dict["str_type"] == "moose" + assert request_dict["bytes_type"] == b"foo" + assert request_dict["list_type"] == ["1", "2", "3"] + + def test_global_train_build_caikit_library_request_dict_strips_empty_list_from_request( sample_train_service, ): """Global train build_caikit_library_request_dict strips empty list from request""" - # NOTE: not sure this test is relevant anymore, since nothing effectively gets removed? - # the datastream is empty but it's not removed from request, which is expected stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type(jsondata=stream_type.JsonData(data=[])) - # TODO: fails with diffs.... - # AssertionError: assert {'training_da... 'union_list'} == {'batch_size'... 'union_list'} - # E Extra items in the left set: - # E 'training_data' - # E Extra items in the right set: - # E 'sleep_time' - # E 'oom_exit' - # E 'sleep_increment' - # E 'batch_size'... 
train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") train_request_params_class = DataBase.get_class_for_name( "SampleTaskSampleModuleTrainParameters" @@ -329,6 +357,41 @@ def test_global_train_build_caikit_library_request_dict_strips_empty_list_from_r ) # model_name is not expected to be passed through + # since using pythonic data model, keeps params that have default value and removes ones that are empty + # TODO: create issue for union_list that should be removed + expected_arguments = { + "union_list", + "sleep_time", + "oom_exit", + "sleep_increment", + "batch_size", + } + + assert expected_arguments == set(caikit.core_request.keys()) + + +def test_global_train_build_caikit_library_request_dict_with_proto_keeps_empty_params_from_request( + sample_train_service, +): + """Global train build_caikit_library_request_dict keeps the explicitly passed empty param when using protobuf request""" + # NOTE: Using protobuf to create request, by explicitly passing in training_data, even though + # the datastream is empty, it's not removed from request, which is expected + stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType + training_data = stream_type(jsondata=stream_type.JsonData(data=[])).to_proto() + train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + model_name=random_test_id(), + parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + training_data=training_data + ), + ) + + caikit.core_request = build_caikit_library_request_dict( + train_request.parameters, + sample_lib.modules.sample_task.SampleModule.TRAIN_SIGNATURE, + ) + + # model_name is not expected to be passed through + # TODO: create issue for union_list that should be removed expected_arguments = {"training_data", "union_list"} assert expected_arguments == set(caikit.core_request.keys()) @@ -341,14 +404,18 @@ def test_global_train_build_caikit_library_request_dict_works_for_repeated_field """Global train 
build_caikit_library_request_dict works for repeated fields""" stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType - training_data = stream_type(jsondata=stream_type.JsonData(data=[])).to_proto() - train_request = sample_train_service.messages.SampleTaskListModuleTrainRequest( + training_data = stream_type(jsondata=stream_type.JsonData(data=[])) + train_class = DataBase.get_class_for_name("SampleTaskListModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskListModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskListModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, poison_pills=["Bob Marley", "Bunny Livingston"], ), - ) + ).to_proto() caikit.core_request = build_caikit_library_request_dict( train_request.parameters, @@ -356,7 +423,7 @@ def test_global_train_build_caikit_library_request_dict_works_for_repeated_field ) # model_name is not expected to be passed through - expected_arguments = {"training_data", "poison_pills"} + expected_arguments = {"poison_pills", "batch_size"} assert expected_arguments == set(caikit.core_request.keys()) assert len(caikit.core_request.keys()) == 2 @@ -368,18 +435,21 @@ def test_global_train_build_caikit_library_request_dict_ok_with_DataStreamSource sample_train_service, ): stream_type = caikit.interfaces.common.data_model.DataStreamSourceInt - training_data = stream_type( - jsondata=stream_type.JsonData(data=[100, 120]) - ).to_proto() + training_data = stream_type(jsondata=stream_type.JsonData(data=[100, 120])) - train_request = sample_train_service.messages.OtherTaskOtherModuleTrainRequest( + train_class = DataBase.get_class_for_name("OtherTaskOtherModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "OtherTaskOtherModuleTrainParameters" + ) + train_request = train_class( model_name="Bar Training", - 
parameters=sample_train_service.messages.OtherTaskOtherModuleTrainParameters( - sample_input_sampleinputtype=SampleInputType(name="Gabe").to_proto(), + parameters=train_request_params_class( + sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM, batch_size=100, training_data=training_data, ), - ) + ).to_proto() + caikit.core_request = build_caikit_library_request_dict( train_request.parameters, sample_lib.modules.other_task.OtherModule.TRAIN_SIGNATURE, @@ -396,15 +466,17 @@ def test_global_train_build_caikit_library_request_dict_ok_with_data_stream_file """Global train build_caikit_library_request_dict works for csv training data file""" stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType - training_data = stream_type( - file=stream_type.File(filename=sample_csv_file) - ).to_proto() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + training_data = stream_type(file=stream_type.File(filename=sample_csv_file)) + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, ), - ) + ).to_proto() caikit.core_request = build_caikit_library_request_dict( train_request.parameters, @@ -412,7 +484,15 @@ def test_global_train_build_caikit_library_request_dict_ok_with_data_stream_file ) # model_name is not expected to be passed through - expected_arguments = {"training_data", "union_list"} + # TODO: union_list should not be there + expected_arguments = { + "training_data", + "union_list", + "batch_size", + "oom_exit", + "sleep_time", + "sleep_increment", + } assert expected_arguments == set(caikit.core_request.keys()) @@ -424,14 +504,18 @@ def 
test_global_train_build_caikit_library_request_dict_ok_with_training_data_as stream_type = caikit.interfaces.common.data_model.DataStreamSourceSampleTrainingType training_data = stream_type( listoffiles=stream_type.ListOfFiles(files=[sample_csv_file, sample_json_file]) - ).to_proto() - train_request = sample_train_service.messages.SampleTaskListModuleTrainRequest( + ) + train_class = DataBase.get_class_for_name("SampleTaskListModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskListModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - parameters=sample_train_service.messages.SampleTaskListModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, poison_pills=["Bob Marley", "Bunny Livingston"], ), - ) + ).to_proto() caikit.core_request = build_caikit_library_request_dict( train_request.parameters, @@ -439,10 +523,10 @@ def test_global_train_build_caikit_library_request_dict_ok_with_training_data_as ) # model_name is not expected to be passed through - expected_arguments = {"training_data", "poison_pills"} + expected_arguments = {"training_data", "poison_pills", "batch_size"} assert expected_arguments == set(caikit.core_request.keys()) - assert len(caikit.core_request.keys()) == 2 + assert len(caikit.core_request.keys()) == 3 assert "training_data" in caikit.core_request @@ -464,16 +548,32 @@ def test_build_caikit_library_request_dict_works_when_data_stream_directory_incl ) training_data = stream_type( directory=stream_type.Directory(dirname=tempdir, extension="json") - ).to_proto() - train_request = sample_train_service.messages.SampleTaskSampleModuleTrainRequest( + ) + train_class = DataBase.get_class_for_name("SampleTaskSampleModuleTrainRequest") + train_request_params_class = DataBase.get_class_for_name( + "SampleTaskSampleModuleTrainParameters" + ) + train_request = train_class( model_name=random_test_id(), - 
parameters=sample_train_service.messages.SampleTaskSampleModuleTrainParameters( + parameters=train_request_params_class( training_data=training_data, ), - ) + ).to_proto() # no error because at least 1 json file exists within the provided dir caikit.core_request = build_caikit_library_request_dict( train_request.parameters, sample_lib.modules.sample_task.SampleModule.TRAIN_SIGNATURE, ) + + expected_arguments = { + "training_data", + "union_list", + "batch_size", + "oom_exit", + "sleep_time", + "sleep_increment", + } + + assert expected_arguments == set(caikit.core_request.keys()) + assert "training_data" in caikit.core_request