Skip to content

Commit

Permalink
✨ introduce request dm fetching in runtime
Browse files Browse the repository at this point in the history
Signed-off-by: Prashant Gupta <[email protected]>
  • Loading branch information
prashantgupta24 committed Sep 18, 2023
1 parent 377a6be commit bd0ab26
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 86 deletions.
3 changes: 3 additions & 0 deletions caikit/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Local
from .service_factory import get_request, get_train_params, get_train_request
36 changes: 36 additions & 0 deletions caikit/runtime/service_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
# Local
from caikit import get_config
from caikit.core import LocalBackend, ModuleBase, registries
from caikit.core.data_model.base import DataBase
from caikit.core.data_model.dataobject import _AUTO_GEN_PROTO_CLASSES
from caikit.interfaces.runtime.data_model import (
TrainingInfoRequest,
Expand Down Expand Up @@ -266,3 +267,38 @@ def _get_and_filter_modules(
excluded_modules,
)
return clean_modules


def get_request(
module_class: Type[ModuleBase],
input_streaming: bool = False,
output_streaming: bool = False,
) -> Type[DataBase]:
"""Helper function to return the request DataModel for the Module Class"""
if input_streaming and output_streaming:
request_class_name = f"BidiStreaming{module_class.TASK_CLASS.__name__}Request"
elif input_streaming:
request_class_name = f"ClientStreaming{module_class.TASK_CLASS.__name__}Request"
elif output_streaming:
request_class_name = f"ServerStreaming{module_class.TASK_CLASS.__name__}Request"
else:
request_class_name = f"{module_class.TASK_CLASS.__name__}Request"
return DataBase.get_class_for_name(request_class_name)


def get_train_request(module_class: Type[ModuleBase]) -> Type[DataBase]:
"""Helper function to return the train request DataModel for the Module Class"""
request_class_name = (
f"{module_class.TASK_CLASS.__name__}{module_class.__name__}TrainRequest"
)
print(request_class_name)
return DataBase.get_class_for_name(request_class_name)


def get_train_params(module_class: Type[ModuleBase]) -> Type[DataBase]:
"""Helper function to return the train parameters DataModel for the Module Class"""
request_class_name = (
f"{module_class.TASK_CLASS.__name__}{module_class.__name__}TrainParameters"
)
print(request_class_name)
return DataBase.get_class_for_name(request_class_name)
3 changes: 2 additions & 1 deletion examples/text-sentiment/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

# Local
from caikit.config.config import get_config
from caikit.runtime import get_request
from caikit.runtime.service_factory import ServicePackageFactory
from text_sentiment.runtime_model import HuggingFaceSentimentModule
import caikit
Expand Down Expand Up @@ -51,7 +52,7 @@

# Run inference for two sample prompts
for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
request = HuggingFaceSentimentModule.TASK_CLASS.UNARY_REQUEST_DATA_MODEL(
request = get_request(HuggingFaceSentimentModule)(
text_input=text
).to_proto()
response = client_stub.HuggingFaceSentimentTaskPredict(
Expand Down
9 changes: 8 additions & 1 deletion tests/fixtures/sample_lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@

# Local
from . import data_model, modules
from .modules import InnerModule, OtherModule, SampleModule, SamplePrimitiveModule
from .modules import (
CompositeModule,
InnerModule,
OtherModule,
SampleModule,
SamplePrimitiveModule,
StreamingModule,
)
from caikit.config import configure

# Run configure for sample_lib configuration
Expand Down
1 change: 1 addition & 0 deletions tests/fixtures/sample_lib/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@
InnerModule,
SampleModule,
SamplePrimitiveModule,
StreamingModule,
)
Loading

0 comments on commit bd0ab26

Please sign in to comment.