Skip to content

Commit

Permalink
✨ add request DMs to Modules/Tasks
Browse files Browse the repository at this point in the history
Signed-off-by: Prashant Gupta <[email protected]>
  • Loading branch information
prashantgupta24 committed Sep 18, 2023
1 parent faf973a commit f9cc03a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 16 deletions.
18 changes: 17 additions & 1 deletion caikit/runtime/service_generation/rpcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,21 @@ def name(self) -> str:

def create_request_data_model(self, package_name: str) -> Type[DataBase]:
"""Dynamically create data model for this RPC's request message"""
return self.request.create_data_model(package_name)
request_data_model = self.request.create_data_model(package_name)
if isinstance(self, TaskPredictRPC):
# add the DM to the task class for inference rpcs
if self._input_streaming and self._output_streaming:
setattr(self.task, "BIDI_REQUEST_DATA_MODEL", request_data_model)
elif self._output_streaming:
setattr(self.task, "SERVER_REQUEST_DATA_MODEL", request_data_model)
elif self._input_streaming:
setattr(self.task, "CLIENT_REQUEST_DATA_MODEL", request_data_model)
else:
setattr(self.task, "UNARY_REQUEST_DATA_MODEL", request_data_model)
if isinstance(self, ModuleClassTrainRPC):
# add the DM to the module class directly for train rpcs
setattr(self.clz, "TRAIN_REQUEST_DATA_MODEL", request_data_model)
return request_data_model

def create_rpc_json(self, package_name: str) -> Dict:
"""Return json snippet for the service definition of this RPC"""
Expand Down Expand Up @@ -124,6 +138,8 @@ def create_request_data_model(self, package_name: str):
request data model"""
# Build the inner request data model
inner_request_data_model = self._inner_request.create_data_model(package_name)

setattr(self.clz, "TRAINING_PARAMETERS_DATA_MODEL", inner_request_data_model)
# Insert the new type into the outer request
for triple_index, _ in enumerate(self._req.triples):
if self._req.triples[triple_index][1] == "parameters":
Expand Down
9 changes: 4 additions & 5 deletions examples/text-sentiment/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

# Local
from caikit.config.config import get_config
from caikit.core.data_model.base import DataBase
from caikit.runtime.service_factory import ServicePackageFactory
from text_sentiment.runtime_model import HuggingFaceSentimentModule
import caikit

if __name__ == "__main__":
Expand Down Expand Up @@ -51,10 +51,9 @@

# Run inference for two sample prompts
for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
predict_class = DataBase.get_class_for_name(
"HuggingFaceSentimentTaskRequest"
)
request = predict_class(text_input=text).to_proto()
request = HuggingFaceSentimentModule.TASK_CLASS.UNARY_REQUEST_DATA_MODEL(
text_input=text
).to_proto()
response = client_stub.HuggingFaceSentimentTaskPredict(
request, metadata=[("mm-model-id", model_id)], timeout=1
)
Expand Down
17 changes: 7 additions & 10 deletions tests/runtime/test_grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,11 @@ def test_predict_sample_module_ok_response(
):
"""Test RPC CaikitRuntime.SampleTaskPredict successful response"""
stub = sample_inference_service.stub_class(runtime_grpc_server.make_local_channel())
predict_class = DataBase.get_class_for_name("SampleTaskRequest")
predict_request = predict_class(sample_input=HAPPY_PATH_INPUT_DM).to_proto()
predict_request = (
sample_lib.modules.SampleModule.TASK_CLASS.UNARY_REQUEST_DATA_MODEL(
sample_input=HAPPY_PATH_INPUT_DM
).to_proto()
)

actual_response = stub.SampleTaskPredict(
predict_request, metadata=[("mm-model-id", sample_task_model_id)]
Expand Down Expand Up @@ -397,15 +400,9 @@ def test_train_fake_module_ok_response_and_can_predict_with_trained_model(
)
)
model_name = random_test_id()
train_request_class = DataBase.get_class_for_name(
"SampleTaskSampleModuleTrainRequest"
)
train_request_params_class = DataBase.get_class_for_name(
"SampleTaskSampleModuleTrainParameters"
)
train_request = train_request_class(
train_request = sample_lib.modules.SampleModule.TRAIN_REQUEST_DATA_MODEL(
model_name=model_name,
parameters=train_request_params_class(
parameters=sample_lib.modules.SampleModule.TRAINING_PARAMETERS_DATA_MODEL(
training_data=training_data,
union_list=["str", "sequence"],
),
Expand Down

0 comments on commit f9cc03a

Please sign in to comment.