Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
Signed-off-by: Anh-Uong <[email protected]>
  • Loading branch information
anhuong committed Sep 13, 2023
1 parent 9c2b046 commit 71d3cb6
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 116 deletions.
8 changes: 4 additions & 4 deletions examples/text-sentiment/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@
# Run inference for two sample prompts
for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
# TODO: is this the recommended approach for setting up client and request?
predict_class = DataBase.get_class_for_name("HuggingFaceSentimentTaskRequest")
request = predict_class(
text_input=text
).to_proto()
predict_class = DataBase.get_class_for_name(
"HuggingFaceSentimentTaskRequest"
)
request = predict_class(text_input=text).to_proto()
# request = inference_service.messages.HuggingFaceSentimentTaskRequest(
# text_input=text
# )
Expand Down
24 changes: 6 additions & 18 deletions tests/runtime/servicers/test_global_predict_servicer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ def test_calling_predict_should_raise_if_module_raises(
with pytest.raises(CaikitRuntimeException) as context:
# SampleModules will raise a RuntimeError if the throw flag is set
predict_class = DataBase.get_class_for_name("SampleTaskRequest")
request = predict_class(
sample_input=HAPPY_PATH_INPUT_DM, throw=True
).to_proto()
request = predict_class(sample_input=HAPPY_PATH_INPUT_DM, throw=True).to_proto()
sample_predict_servicer.Predict(
request,
Fixtures.build_context(sample_task_model_id),
Expand Down Expand Up @@ -101,9 +99,7 @@ def test_global_predict_works_for_unary_rpcs(
"""Global predict of SampleTaskRequest returns a prediction"""
predict_class = DataBase.get_class_for_name("SampleTaskRequest")
response = sample_predict_servicer.Predict(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
Fixtures.build_context(sample_task_model_id),
caikit_rpc=sample_task_unary_rpc,
)
Expand Down Expand Up @@ -212,9 +208,7 @@ def run(self, *args, **kwargs):
predict_thread = threading.Thread(
target=sample_predict_servicer.Predict,
args=(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
context,
),
kwargs={"caikit_rpc": sample_task_unary_rpc},
Expand Down Expand Up @@ -272,9 +266,7 @@ def test_metering_predict_rpc_counter(
predict_class = DataBase.get_class_for_name("SampleTaskRequest")
for i in range(20):
sample_predict_servicer.Predict(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
Fixtures.build_context(sample_task_model_id),
caikit_rpc=sample_task_unary_rpc,
)
Expand Down Expand Up @@ -314,9 +306,7 @@ def test_metering_write_to_metrics_file_twice(
try:
predict_class = DataBase.get_class_for_name("SampleTaskRequest")
sample_predict_servicer.Predict(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
Fixtures.build_context(sample_task_model_id),
caikit_rpc=sample_task_unary_rpc,
)
Expand All @@ -325,9 +315,7 @@ def test_metering_write_to_metrics_file_twice(
sample_predict_servicer.rpc_meter.flush_metrics()

sample_predict_servicer.Predict(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
Fixtures.build_context(sample_task_model_id),
caikit_rpc=sample_task_unary_rpc,
)
Expand Down
45 changes: 16 additions & 29 deletions tests/runtime/servicers/test_global_train_servicer_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
# Local
from caikit.config import get_config
from caikit.core import MODEL_MANAGER
from caikit.core.data_model.producer import ProducerId
from caikit.core.data_model.base import DataBase
from caikit.core.data_model.producer import ProducerId
from caikit.interfaces.common.data_model.stream_sources import S3Path
from caikit.runtime.servicers.global_train_servicer import GlobalTrainServicer
from caikit.runtime.types.caikit_runtime_exception import CaikitRuntimeException
Expand All @@ -48,6 +48,7 @@

HAPPY_PATH_INPUT_DM = SampleInputType(name="Gabe")


@contextmanager
def set_use_subprocess(use_subprocess: bool):
with temp_config(
Expand Down Expand Up @@ -107,7 +108,7 @@ def test_global_train_sample_task(
parameters=train_request_params_class(
batch_size=42,
training_data=training_data,
)
),
).to_proto()

training_response = sample_train_servicer.Train(
Expand Down Expand Up @@ -139,9 +140,7 @@ def test_global_train_sample_task(

predict_class = DataBase.get_class_for_name("SampleTaskRequest")
inference_response = sample_predict_servicer.Predict(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
Fixtures.build_context(training_response.model_name),
caikit_rpc=sample_task_unary_rpc,
)
Expand Down Expand Up @@ -176,7 +175,7 @@ def test_global_train_other_task(
training_data=training_data,
sample_input_sampleinputtype=HAPPY_PATH_INPUT_DM,
batch_size=batch_size,
)
),
).to_proto()

training_response = sample_train_servicer.Train(
Expand All @@ -200,9 +199,7 @@ def test_global_train_other_task(

predict_class = DataBase.get_class_for_name("OtherTaskRequest")
inference_response = sample_predict_servicer.Predict(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
Fixtures.build_context(training_response.model_name),
caikit_rpc=sample_inference_service.caikit_rpcs["OtherTaskPredict"],
)
Expand Down Expand Up @@ -234,9 +231,7 @@ def test_global_train_Another_Widget_that_requires_SampleWidget_loaded_should_no
)
training_request = train_class(
model_name="AnotherWidget_Training",
parameters=train_request_params_class(
sample_block=sample_model
)
parameters=train_request_params_class(sample_block=sample_model),
).to_proto()

training_response = sample_train_servicer.Train(
Expand All @@ -262,9 +257,7 @@ def test_global_train_Another_Widget_that_requires_SampleWidget_loaded_should_no
# make sure the trained model can run inference
predict_class = DataBase.get_class_for_name("SampleTaskRequest")
inference_response = sample_predict_servicer.Predict(
predict_class(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto(),
predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto(),
Fixtures.build_context(training_response.model_name),
caikit_rpc=sample_task_unary_rpc,
)
Expand Down Expand Up @@ -296,7 +289,7 @@ def test_run_train_job_works_with_wait(
parameters=train_request_params_class(
batch_size=42,
training_data=training_data,
)
),
).to_proto()
servicer = GlobalTrainServicer(training_service=sample_train_service)
with TemporaryDirectory() as tmp_dir:
Expand All @@ -315,9 +308,7 @@ def test_run_train_job_works_with_wait(

predict_class = DataBase.get_class_for_name("SampleTaskRequest")
inference_response = sample_predict_servicer.Predict(
predict_class(
sample_input=SampleInputType(name="Test")
).to_proto(),
predict_class(sample_input=SampleInputType(name="Test")).to_proto(),
Fixtures.build_context(training_response.model_name),
caikit_rpc=sample_task_unary_rpc,
)
Expand All @@ -340,18 +331,14 @@ def test_global_train_Another_Widget_that_requires_SampleWidget_but_not_loaded_s
"""Global train of TrainRequest raises when calling a train function that requires another loaded model, but model is not loaded"""
model_id = random_test_id()

sample_model = caikit.interfaces.runtime.data_model.ModelPointer(
model_id=model_id
)
sample_model = caikit.interfaces.runtime.data_model.ModelPointer(model_id=model_id)
train_class = DataBase.get_class_for_name("SampleTaskCompositeModuleTrainRequest")
train_request_params_class = DataBase.get_class_for_name(
"SampleTaskCompositeModuleTrainParameters"
)
request = train_class(
model_name="AnotherWidget_Training",
parameters=train_request_params_class(
sample_block=sample_model
)
parameters=train_request_params_class(sample_block=sample_model),
).to_proto()

with pytest.raises(CaikitRuntimeException) as context:
Expand All @@ -378,7 +365,7 @@ def test_global_train_Edge_Case_Widget_should_raise_when_error_surfaces_from_mod
parameters=train_request_params_class(
batch_size=999,
training_data=training_data,
)
),
).to_proto()

training_response = sample_train_servicer.Train(
Expand Down Expand Up @@ -410,7 +397,7 @@ def test_global_train_returns_exit_code_with_oom(
batch_size=42,
training_data=training_data,
oom_exit=True,
)
),
).to_proto()

# Enable sub-processing for test
Expand Down Expand Up @@ -445,7 +432,7 @@ def test_local_trainer_rejects_s3_output_paths(
batch_size=42,
training_data=training_data,
oom_exit=True,
)
),
).to_proto()

with pytest.raises(
Expand Down Expand Up @@ -480,7 +467,7 @@ def test_global_train_aborts_long_running_trains(
batch_size=42,
training_data=training_data,
oom_exit=False,
)
),
).to_proto()

if sample_train_servicer.use_subprocess:
Expand Down
Loading

0 comments on commit 71d3cb6

Please sign in to comment.