Commit

♻️ change name to get_inference_request
Signed-off-by: Prashant Gupta <[email protected]>
prashantgupta24 committed Sep 21, 2023
1 parent 67aba55 commit da67d77
Showing 4 changed files with 19 additions and 19 deletions.
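
The rename is mechanical: every call site that built an inference request class via get_request now calls get_inference_request with the same arguments. As a quick reference, here is a minimal sketch of the call pattern after this change, based on the changed call sites below; the text input and streaming flags are illustrative, and the streaming variant assumes the module's task declares that flavor.

from caikit.runtime import get_inference_request
from text_sentiment.runtime_model import HuggingFaceSentimentModule

# Unary flavor: get_inference_request returns the pythonic request data model
# class for the module's inference task; instantiate it and serialize to proto
# before calling the task's *Predict RPC.
request = get_inference_request(HuggingFaceSentimentModule)(
    text_input="Today is a nice sunny day"
).to_proto()

# Streaming flavor: the same helper accepts input_streaming / output_streaming
# flags (illustrative here; only valid if the task defines a streaming flavor).
streaming_request_class = get_inference_request(
    HuggingFaceSentimentModule, input_streaming=False, output_streaming=True
)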
2 changes: 1 addition & 1 deletion caikit/runtime/__init__.py
@@ -13,4 +13,4 @@
# limitations under the License.

# Local
-from .service_factory import get_request, get_train_params, get_train_request
+from .service_factory import get_inference_request, get_train_params, get_train_request
2 changes: 1 addition & 1 deletion caikit/runtime/service_factory.py
@@ -269,7 +269,7 @@ def _get_and_filter_modules(
return clean_modules


-def get_request(
+def get_inference_request(
module_class: Type[ModuleBase],
input_streaming: bool = False,
output_streaming: bool = False,
4 changes: 2 additions & 2 deletions examples/text-sentiment/client.py
@@ -21,7 +21,7 @@

# Local
from caikit.config.config import get_config
-from caikit.runtime import get_request
+from caikit.runtime import get_inference_request
from caikit.runtime.service_factory import ServicePackageFactory
from text_sentiment.runtime_model import HuggingFaceSentimentModule
import caikit
@@ -52,7 +52,7 @@

# Run inference for two sample prompts
for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
-request = get_request(HuggingFaceSentimentModule)(
+request = get_inference_request(HuggingFaceSentimentModule)(
text_input=text
).to_proto()
response = client_stub.HuggingFaceSentimentTaskPredict(
30 changes: 15 additions & 15 deletions tests/runtime/test_grpc_server.py
@@ -49,7 +49,7 @@
TrainingStatus,
TrainingStatusResponse,
)
-from caikit.runtime import get_request, get_train_params, get_train_request
+from caikit.runtime import get_inference_request, get_train_params, get_train_request
from caikit.runtime.grpc_server import RuntimeGRPCServer
from caikit.runtime.model_management.model_manager import ModelManager
from caikit.runtime.protobufs import (
@@ -214,7 +214,7 @@ def test_predict_sample_module_ok_response(
):
"""Test RPC CaikitRuntime.SampleTaskPredict successful response"""
stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-predict_request = get_request(SampleModule)(
+predict_request = get_inference_request(SampleModule)(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto()

@@ -244,7 +244,7 @@ def test_global_predict_build_caikit_library_request_dict_creates_caikit_core_ru
assert proto_expected_arguments == set(proto_request_dict.keys())

# pythonic data model request
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
python_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()

python_sample_module_request_dict = build_caikit_library_request_dict(
@@ -263,7 +263,7 @@ def test_predict_streaming_module(
):
"""Test RPC CaikitRuntime.StreamingTaskPredict successful response"""
stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-predict_class = get_request(
+predict_class = get_inference_request(
StreamingModule, input_streaming=False, output_streaming=True
)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
@@ -287,7 +287,7 @@ def test_predict_sample_module_error_response(
stub = sample_inference_service.stub_class(
runtime_grpc_server.make_local_channel()
)
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()

stub.SampleTaskPredict(
@@ -302,7 +302,7 @@ def test_rpc_validation_on_predict(
):
"""Check that the server catches models sent to the wrong task RPCs"""
stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
-predict_class = get_request(OtherModule)
+predict_class = get_inference_request(OtherModule)
predict_request = predict_class(
sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM
).to_proto()
@@ -333,7 +333,7 @@ def test_rpc_validation_on_predict_for_unsupported_model(
stub = sample_inference_service.stub_class(
runtime_grpc_server.make_local_channel()
)
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
with pytest.raises(grpc.RpcError) as context:
stub.SampleTaskPredict(
@@ -366,7 +366,7 @@ def test_rpc_validation_on_predict_for_wrong_streaming_flavor(
runtime_grpc_server.make_local_channel()
)

-predict_class = get_request(
+predict_class = get_inference_request(
SampleModule,
)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
@@ -423,7 +423,7 @@ def test_train_fake_module_ok_response_and_can_predict_with_trained_model(
)

# make sure the trained model can run inference
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
inference_response = inference_stub.SampleTaskPredict(
predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -459,7 +459,7 @@ def test_train_fake_module_ok_response_with_loaded_model_can_predict_with_traine
)

# make sure the trained model can run inference
-predict_class = get_request(CompositeModule)
+predict_class = get_inference_request(CompositeModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
inference_response = inference_stub.SampleTaskPredict(
predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -506,7 +506,7 @@ def test_train_fake_module_does_not_change_another_instance_model_of_block(
)

# make sure the trained model can run inference, and the batch size 100 was used
-predict_class = get_request(OtherModule)
+predict_class = get_inference_request(OtherModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
trained_inference_response = inference_stub.OtherTaskPredict(
predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -566,7 +566,7 @@ def test_train_primitive_model(
)

# make sure the trained model can run inference
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()

inference_response = inference_stub.SampleTaskPredict(
@@ -613,7 +613,7 @@ def test_train_fake_module_ok_response_with_datastream_jsondata(
)

# make sure the trained model can run inference
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
inference_response = inference_stub.SampleTaskPredict(
predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -651,7 +651,7 @@ def test_train_fake_module_ok_response_with_datastream_csv_file(
)

# make sure the trained model can run inference
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
inference_response = inference_stub.SampleTaskPredict(
predict_request, metadata=[("mm-model-id", actual_response.model_name)]
@@ -1116,7 +1116,7 @@ def test_metrics_stored_after_server_interrupt(
with temp_config({"runtime": {"metering": {"enabled": True}}}, "merge"):
with runtime_grpc_test_server(open_port) as server:
stub = sample_inference_service.stub_class(server.make_local_channel())
-predict_class = get_request(SampleModule)
+predict_class = get_inference_request(SampleModule)
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
_ = stub.SampleTaskPredict(
predict_request, metadata=[("mm-model-id", sample_task_model_id)]
