continued switch to use datamodel classes
Signed-off-by: Anh-Uong <[email protected]>
anhuong committed Sep 13, 2023
1 parent 73e718d commit 9c2b046
Showing 2 changed files with 46 additions and 26 deletions.
10 changes: 8 additions & 2 deletions examples/text-sentiment/client.py
@@ -21,6 +21,7 @@
 
 # Local
 from caikit.config.config import get_config
+from caikit.core.data_model.base import DataBase
 from caikit.runtime.service_factory import ServicePackageFactory
 import caikit
 
@@ -50,9 +51,14 @@
 
     # Run inference for two sample prompts
    for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
-        request = inference_service.messages.HuggingFaceSentimentTaskRequest(
+        # TODO: is this the recommended approach for setting up client and request?
+        predict_class = DataBase.get_class_for_name("HuggingFaceSentimentTaskRequest")
+        request = predict_class(
             text_input=text
-        )
+        ).to_proto()
+        # request = inference_service.messages.HuggingFaceSentimentTaskRequest(
+        #     text_input=text
+        # )
         response = client_stub.HuggingFaceSentimentTaskPredict(
             request, metadata=[("mm-model-id", model_id)], timeout=1
         )
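
Taken together, the client-side pattern this commit moves to is: resolve the generated request class by name through DataBase, construct it from data model values, and call to_proto() once on the finished request before the stub call. Below is a minimal sketch of that flow, not a verbatim copy of the example file: the configure step, port, and model id are illustrative assumptions, and the TODO in the diff leaves open whether this is the final recommended approach.

# Minimal sketch of the pattern adopted above; port and model id are assumed.
import grpc

import caikit
from caikit.core.data_model.base import DataBase
from caikit.runtime.service_factory import ServicePackageFactory

caikit.configure()  # assumes config comes from defaults / the environment
inference_service = ServicePackageFactory.get_service_package(
    ServicePackageFactory.ServiceType.INFERENCE
)
client_stub = inference_service.stub_class(grpc.insecure_channel("localhost:8085"))

# Resolve the generated request class by name, build it from data model
# values, and convert the whole request to protobuf for the stub call.
predict_class = DataBase.get_class_for_name("HuggingFaceSentimentTaskRequest")
request = predict_class(text_input="Today is a nice sunny day").to_proto()
response = client_stub.HuggingFaceSentimentTaskPredict(
    request, metadata=[("mm-model-id", "text_sentiment")], timeout=1
)
print(response)
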
62 changes: 38 additions & 24 deletions tests/runtime/servicers/test_global_predict_servicer_impl.py
@@ -34,6 +34,7 @@
 import pytest
 
 # Local
+from caikit.core.data_model.base import DataBase
 from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer
 from caikit.runtime.types.aborted_exception import AbortedException
 from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
@@ -56,9 +57,10 @@ def test_calling_predict_should_raise_if_module_raises(
 ):
     with pytest.raises(CaikitRuntimeException) as context:
         # SampleModules will raise a RuntimeError if the throw flag is set
-        request = sample_inference_service.messages.SampleTaskRequest(
-            sample_input=HAPPY_PATH_INPUT, throw=True
-        )
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+        request = predict_class(
+            sample_input=HAPPY_PATH_INPUT_DM, throw=True
+        ).to_proto()
         sample_predict_servicer.Predict(
             request,
             Fixtures.build_context(sample_task_model_id),
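
The fixture swap from HAPPY_PATH_INPUT to HAPPY_PATH_INPUT_DM follows from the constructor change: the data model class takes data model values rather than protobuf messages, and the trailing to_proto() still produces the protobuf the servicer expects. The two fixtures are presumably a data model object and its protobuf counterpart; their definitions are not in this diff, so the sketch below is an assumption, with a placeholder name value.

# Assumed relationship between the two test fixtures (not shown in this diff):
HAPPY_PATH_INPUT_DM = SampleInputType(name="placeholder")  # data model object
HAPPY_PATH_INPUT = HAPPY_PATH_INPUT_DM.to_proto()          # protobuf message
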
@@ -77,9 +79,10 @@ def test_invalid_input_to_a_valid_caikit_core_class_method_raises(
     """Test that a caikit.core module that gets an unexpected input value errors in an expected way"""
     with pytest.raises(CaikitRuntimeException) as context:
         # SampleModules will raise a ValueError if the poison pill name is given
-        request = sample_inference_service.messages.SampleTaskRequest(
-            sample_input=SampleInputType(name=SampleModule.POISON_PILL_NAME).to_proto(),
-        )
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+        request = predict_class(
+            sample_input=SampleInputType(name=SampleModule.POISON_PILL_NAME)
+        ).to_proto()
         sample_predict_servicer.Predict(
             request,
             Fixtures.build_context(sample_task_model_id),
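
Worth noting in this hunk is where to_proto() lands: the old code converted the inner SampleInputType to protobuf so it could feed the protobuf request constructor, while the new code passes the data model object straight in and converts the finished request once. A side-by-side sketch, assuming the SampleInputType and request classes from this test module:

# Old path: convert the field first, then build a protobuf request
old_request = sample_inference_service.messages.SampleTaskRequest(
    sample_input=SampleInputType(name="x").to_proto(),
)
# New path: build a data model request from data model values, convert once
new_request = DataBase.get_class_for_name("SampleTaskRequest")(
    sample_input=SampleInputType(name="x")
).to_proto()
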
@@ -96,10 +99,11 @@ def test_global_predict_works_for_unary_rpcs(
     sample_task_unary_rpc,
 ):
     """Global predict of SampleTaskRequest returns a prediction"""
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
     response = sample_predict_servicer.Predict(
-        sample_inference_service.messages.SampleTaskRequest(
-            sample_input=HAPPY_PATH_INPUT
-        ),
+        predict_class(
+            sample_input=HAPPY_PATH_INPUT_DM
+        ).to_proto(),
         Fixtures.build_context(sample_task_model_id),
         caikit_rpc=sample_task_unary_rpc,
     )
@@ -114,6 +118,12 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs(
     def req_iterator() -> Iterator[
         sample_inference_service.messages.BidiStreamingSampleTaskRequest
     ]:
+        # TODO: can this be replaced?
+        # Guessing not since got error: caikit.runtime.types.caikit_runtime_exception.CaikitRuntimeException: Exception raised during inference. This may be a problem with your input: BidiStreamingSampleTaskRequest.__init__() got an unexpected keyword argument 'sample_input'
+        # predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest")
+        # predict_class(
+        #     sample_input=HAPPY_PATH_INPUT_DM
+        # ).to_proto()
         for i in range(100):
             yield sample_inference_service.messages.BidiStreamingSampleTaskRequest(
                 sample_inputs=HAPPY_PATH_INPUT
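
The error quoted in the TODO may point at a field-name mismatch rather than a limitation of the data model lookup: the working protobuf constructor below it passes sample_inputs (plural), while the commented-out data model attempt passes sample_input. A sketch of the same lookup with the plural keyword; that BidiStreamingSampleTaskRequest resolves through DataBase like the unary request classes is an unverified assumption:

# Unverified guess at the TODO: keep the DataBase lookup, but use the plural
# field name that the protobuf constructor in this test already uses.
predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest")
request = predict_class(sample_inputs=HAPPY_PATH_INPUT_DM).to_proto()
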
@@ -198,12 +208,13 @@ def run(self, *args, **kwargs):
         mock_manager.retrieve_model.return_value = dummy_model
 
         context = Fixtures.build_context("test-any-unresponsive-model")
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
         predict_thread = threading.Thread(
             target=sample_predict_servicer.Predict,
             args=(
-                sample_inference_service.messages.SampleTaskRequest(
-                    sample_input=HAPPY_PATH_INPUT
-                ),
+                predict_class(
+                    sample_input=HAPPY_PATH_INPUT_DM
+                ).to_proto(),
                 context,
             ),
             kwargs={"caikit_rpc": sample_task_unary_rpc},
@@ -234,9 +245,10 @@ def test_metering_ignore_unsuccessful_calls(
     gps = GlobalPredictServicer(sample_inference_service)
     try:
         with patch.object(gps.rpc_meter, "update_metrics") as mock_update_func:
-            request = sample_inference_service.messages.SampleTaskRequest(
-                sample_input=HAPPY_PATH_INPUT, throw=True
-            )
+            predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+            request = predict_class(
+                sample_input=HAPPY_PATH_INPUT_DM, throw=True
+            ).to_proto()
             with pytest.raises(CaikitRuntimeException):
                 gps.Predict(
                     request,
@@ -257,11 +269,12 @@ def test_metering_predict_rpc_counter(
     sample_predict_servicer = GlobalPredictServicer(sample_inference_service)
     try:
         # Making 20 requests
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
         for i in range(20):
             sample_predict_servicer.Predict(
-                sample_inference_service.messages.SampleTaskRequest(
-                    sample_input=HAPPY_PATH_INPUT
-                ),
+                predict_class(
+                    sample_input=HAPPY_PATH_INPUT_DM
+                ).to_proto(),
                 Fixtures.build_context(sample_task_model_id),
                 caikit_rpc=sample_task_unary_rpc,
             )
@@ -299,10 +312,11 @@ def test_metering_write_to_metrics_file_twice(
     # need a new servicer to get a fresh new RPC meter
     sample_predict_servicer = GlobalPredictServicer(sample_inference_service)
     try:
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
         sample_predict_servicer.Predict(
-            sample_inference_service.messages.SampleTaskRequest(
-                sample_input=HAPPY_PATH_INPUT
-            ),
+            predict_class(
+                sample_input=HAPPY_PATH_INPUT_DM
+            ).to_proto(),
             Fixtures.build_context(sample_task_model_id),
             caikit_rpc=sample_task_unary_rpc,
         )
@@ -311,9 +325,9 @@ def test_metering_write_to_metrics_file_twice(
         sample_predict_servicer.rpc_meter.flush_metrics()
 
         sample_predict_servicer.Predict(
-            sample_inference_service.messages.SampleTaskRequest(
-                sample_input=HAPPY_PATH_INPUT
-            ),
+            predict_class(
+                sample_input=HAPPY_PATH_INPUT_DM
+            ).to_proto(),
             Fixtures.build_context(sample_task_model_id),
             caikit_rpc=sample_task_unary_rpc,
         )