Merge pull request caikit#468 from anhuong/fix-datamodel-classes
test: update tests to use datamodel classes
gabe-l-hart authored Sep 15, 2023 · 2 parents e048be4 + 933901c · commit faf973a
Showing 7 changed files with 433 additions and 234 deletions.
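
The change applied across all seven files follows one pattern: rather than building request messages from the generated classes on `<service>.messages`, the tests resolve the corresponding caikit data model class by name with `DataBase.get_class_for_name` and convert to protobuf only at the RPC boundary via `.to_proto()`; pre-built proto fixtures like `HAPPY_PATH_INPUT` give way to data model fixtures like `HAPPY_PATH_INPUT_DM`. A minimal before/after sketch of the pattern, using fixture names taken from the diffs below:

    # Before: construct the generated gRPC message class directly
    request = sample_inference_service.messages.SampleTaskRequest(
        sample_input=HAPPY_PATH_INPUT  # a protobuf fixture
    )

    # After: resolve the registered data model class by name and convert
    # to proto only at the call site
    from caikit.core.data_model.base import DataBase

    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
    request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()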
6 changes: 4 additions & 2 deletions examples/text-sentiment/client.py
@@ -21,6 +21,7 @@
 
 # Local
 from caikit.config.config import get_config
+from caikit.core.data_model.base import DataBase
 from caikit.runtime.service_factory import ServicePackageFactory
 import caikit
 
@@ -50,9 +51,10 @@
 
 # Run inference for two sample prompts
 for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
-    request = inference_service.messages.HuggingFaceSentimentTaskRequest(
-        text_input=text
+    predict_class = DataBase.get_class_for_name(
+        "HuggingFaceSentimentTaskRequest"
     )
+    request = predict_class(text_input=text).to_proto()
     response = client_stub.HuggingFaceSentimentTaskPredict(
         request, metadata=[("mm-model-id", model_id)], timeout=1
     )
@@ -18,6 +18,7 @@
 import grpc
 
 # Local
+from caikit.core.data_model.base import DataBase
 from caikit.runtime.interceptors.caikit_runtime_server_wrapper import (
     CaikitRuntimeServerWrapper,
 )
@@ -47,9 +48,8 @@ def predict(request, context, caikit_rpc):
     client = sample_inference_service.stub_class(
         grpc.insecure_channel(f"localhost:{open_port}")
     )
-    _ = client.SampleTaskPredict(
-        sample_inference_service.messages.SampleTaskRequest(), timeout=3
-    )
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+    _ = client.SampleTaskPredict(predict_class().to_proto(), timeout=3)
     assert len(calls) == 1
     assert isinstance(calls[0], TaskPredictRPC)
     assert calls[0].name == "SampleTaskPredict"
64 changes: 28 additions & 36 deletions tests/runtime/servicers/test_global_predict_servicer_impl.py
@@ -34,6 +34,7 @@
 import pytest
 
 # Local
+from caikit.core.data_model.base import DataBase
 from caikit.runtime.servicers.global_predict_servicer import GlobalPredictServicer
 from caikit.runtime.types.aborted_exception import AbortedException
 from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
@@ -56,9 +57,8 @@ def test_calling_predict_should_raise_if_module_raises(
 ):
     with pytest.raises(CaikitRuntimeException) as context:
         # SampleModules will raise a RuntimeError if the throw flag is set
-        request = sample_inference_service.messages.SampleTaskRequest(
-            sample_input=HAPPY_PATH_INPUT, throw=True
-        )
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+        request = predict_class(sample_input=HAPPY_PATH_INPUT_DM, throw=True).to_proto()
         sample_predict_servicer.Predict(
             request,
             Fixtures.build_context(sample_task_model_id),
@@ -77,9 +77,10 @@ def test_invalid_input_to_a_valid_caikit_core_class_method_raises(
     """Test that a caikit.core module that gets an unexpected input value errors in an expected way"""
     with pytest.raises(CaikitRuntimeException) as context:
         # SampleModules will raise a ValueError if the poison pill name is given
-        request = sample_inference_service.messages.SampleTaskRequest(
-            sample_input=SampleInputType(name=SampleModule.POISON_PILL_NAME).to_proto(),
-        )
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+        request = predict_class(
+            sample_input=SampleInputType(name=SampleModule.POISON_PILL_NAME)
+        ).to_proto()
         sample_predict_servicer.Predict(
             request,
             Fixtures.build_context(sample_task_model_id),
@@ -96,10 +97,9 @@ def test_global_predict_works_for_unary_rpcs(
     sample_task_unary_rpc,
 ):
     """Global predict of SampleTaskRequest returns a prediction"""
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
     response = sample_predict_servicer.Predict(
-        sample_inference_service.messages.SampleTaskRequest(
-            sample_input=HAPPY_PATH_INPUT
-        ),
+        predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
         Fixtures.build_context(sample_task_model_id),
         caikit_rpc=sample_task_unary_rpc,
     )
@@ -111,13 +111,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs(
 ):
     """Simple test that our SampleModule's bidirectional stream inference fn is supported"""
 
-    def req_iterator() -> Iterator[
-        sample_inference_service.messages.BidiStreamingSampleTaskRequest
-    ]:
+    predict_class = DataBase.get_class_for_name("BidiStreamingSampleTaskRequest")
+
+    def req_iterator() -> Iterator[predict_class]:
         for i in range(100):
-            yield sample_inference_service.messages.BidiStreamingSampleTaskRequest(
-                sample_inputs=HAPPY_PATH_INPUT
-            )
+            yield predict_class(sample_inputs=HAPPY_PATH_INPUT_DM).to_proto()
 
     response_stream = sample_predict_servicer.Predict(
         req_iterator(),
@@ -141,13 +139,11 @@ def test_global_predict_works_on_bidirectional_streaming_rpcs_with_multiple_stre
     mock_manager = MagicMock()
     mock_manager.retrieve_model.return_value = GeoStreamingModule()
 
-    def req_iterator() -> Iterator[
-        sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest
-    ]:
+    predict_class = DataBase.get_class_for_name("BidiStreamingGeoSpatialTaskRequest")
+
+    def req_iterator() -> Iterator[predict_class]:
         for i in range(100):
-            yield sample_inference_service.messages.BidiStreamingGeoSpatialTaskRequest(
-                lats=i, lons=100 - i, name="Gabe"
-            )
+            yield predict_class(lats=i, lons=100 - i, name="Gabe").to_proto()
 
     with patch.object(sample_predict_servicer, "_model_manager", mock_manager):
         response_stream = sample_predict_servicer.Predict(
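
Worth noting in the two streaming hunks above: `predict_class` is a variable bound at runtime, yet it appears inside the `Iterator[...]` return annotation of `req_iterator`. This is legal because Python evaluates annotations as ordinary expressions when the `def` statement executes; a tiny standalone sketch (names here are illustrative, not from the repo):

    from typing import Iterator

    class SampleRequest:
        """Stand-in for a class looked up at runtime."""

    # Any expression may appear in an annotation, including a runtime name
    predict_class = SampleRequest

    def req_iterator() -> Iterator[predict_class]:
        # The annotation is evaluated once, when this def statement runs
        yield predict_class()

    assert isinstance(next(req_iterator()), SampleRequest)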
@@ -198,12 +194,11 @@ def run(self, *args, **kwargs):
     mock_manager.retrieve_model.return_value = dummy_model
 
     context = Fixtures.build_context("test-any-unresponsive-model")
+    predict_class = DataBase.get_class_for_name("SampleTaskRequest")
     predict_thread = threading.Thread(
         target=sample_predict_servicer.Predict,
         args=(
-            sample_inference_service.messages.SampleTaskRequest(
-                sample_input=HAPPY_PATH_INPUT
-            ),
+            predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
             context,
         ),
         kwargs={"caikit_rpc": sample_task_unary_rpc},
@@ -234,9 +229,10 @@ def test_metering_ignore_unsuccessful_calls(
     gps = GlobalPredictServicer(sample_inference_service)
     try:
         with patch.object(gps.rpc_meter, "update_metrics") as mock_update_func:
-            request = sample_inference_service.messages.SampleTaskRequest(
-                sample_input=HAPPY_PATH_INPUT, throw=True
-            )
+            predict_class = DataBase.get_class_for_name("SampleTaskRequest")
+            request = predict_class(
+                sample_input=HAPPY_PATH_INPUT_DM, throw=True
+            ).to_proto()
             with pytest.raises(CaikitRuntimeException):
                 gps.Predict(
                     request,
@@ -257,11 +253,10 @@ def test_metering_predict_rpc_counter(
     sample_predict_servicer = GlobalPredictServicer(sample_inference_service)
    try:
         # Making 20 requests
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
         for i in range(20):
             sample_predict_servicer.Predict(
-                sample_inference_service.messages.SampleTaskRequest(
-                    sample_input=HAPPY_PATH_INPUT
-                ),
+                predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
                 Fixtures.build_context(sample_task_model_id),
                 caikit_rpc=sample_task_unary_rpc,
             )
@@ -299,10 +294,9 @@ def test_metering_write_to_metrics_file_twice(
     # need a new servicer to get a fresh new RPC meter
     sample_predict_servicer = GlobalPredictServicer(sample_inference_service)
     try:
+        predict_class = DataBase.get_class_for_name("SampleTaskRequest")
         sample_predict_servicer.Predict(
-            sample_inference_service.messages.SampleTaskRequest(
-                sample_input=HAPPY_PATH_INPUT
-            ),
+            predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
             Fixtures.build_context(sample_task_model_id),
             caikit_rpc=sample_task_unary_rpc,
         )
@@ -311,9 +305,7 @@
         sample_predict_servicer.rpc_meter.flush_metrics()
 
         sample_predict_servicer.Predict(
-            sample_inference_service.messages.SampleTaskRequest(
-                sample_input=HAPPY_PATH_INPUT
-            ),
+            predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
             Fixtures.build_context(sample_task_model_id),
             caikit_rpc=sample_task_unary_rpc,
         )
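
For context on why every request now ends in `.to_proto()`: `Predict` and the gRPC stubs still consume protobuf messages, while the resolved classes are caikit data model wrappers. A hedged round-trip sketch, assuming the test environment where the sample fixtures are registered, and assuming caikit's usual `from_proto` counterpart to `to_proto`:

    from caikit.core.data_model.base import DataBase

    # Resolve the data model class registered under the fixture's name
    predict_class = DataBase.get_class_for_name("SampleTaskRequest")

    # Data model object -> protobuf message, as the tests pass to Predict(...)
    proto_request = predict_class(throw=True).to_proto()

    # Protobuf message -> data model object, e.g. on the receiving side
    # (from_proto is assumed here as the standard inverse of to_proto)
    roundtrip = predict_class.from_proto(proto_request)
    assert roundtrip.throw is True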