forked from caikit/caikit
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Martin Hickey <[email protected]>
- Loading branch information
Showing
11 changed files
with
249 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Text Sentiment Analysis Example | ||
|
||
This example uses the [HuggingFace DistilBERT base uncased finetuned SST-2](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english) AI model to perform text sentiment analysis. The Caikit runtime loads the model and serves it so that clients can call it for inference. | ||
|
||
## Before Starting | ||
|
||
The following tools are required: | ||
|
||
- [python](https://www.python.org) (v3.8+) | ||
- [pip](https://pypi.org/project/pip/) (v23.0+) | ||
|
||
**Note: Before installing the dependencies, and to avoid conflicts in your environment, it is advisable to use a [virtual environment (venv)](https://docs.python.org/3/library/venv.html).** | ||
|
||
Install the dependencies: `pip install -r requirements.txt` | ||
|
||
## Running the Caikit runtime | ||
|
||
In one terminal, start the runtime server: | ||
|
||
```shell | ||
python3 start_runtime.py | ||
``` | ||
|
||
You should see output similar to the following: | ||
|
||
```command | ||
$ python3 start_runtime.py | ||
|
||
<function register_backend_type at 0x7fce0064b5e0> is still in the BETA phase and subject to change! | ||
{"channel": "COM-LIB-INIT", "exception": null, "level": "info", "log_code": "<RUN11997772I>", "message": "Loading service module: text_sentiment", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:52.808812"} | ||
{"channel": "COM-LIB-INIT", "exception": null, "level": "info", "log_code": "<RUN11997772I>", "message": "Loading service module: caikit.interfaces.common", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:52.809406"} | ||
{"channel": "COM-LIB-INIT", "exception": null, "level": "info", "log_code": "<RUN11997772I>", "message": "Loading service module: caikit.interfaces.runtime", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:52.809565"} | ||
[…] | ||
{"channel": "MODEL-LOADER", "exception": null, "level": "info", "log_code": "<RUN89711114I>", "message": "Loading model 'text_sentiment'", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:52.826657"} | ||
{"channel": "MDLMNG", "exception": null, "level": "warning", "log_code": "<COR56759744W>", "message": "No backend configured! Trying to configure using default config file.", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:52.827742"} | ||
No model was supplied, defaulted to distilbert-base-uncased-finetuned-sst-2-english and revision af0f99b (https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english). | ||
Using a pipeline without specifying a model name and revision in production is not recommended. | ||
[…] | ||
{"channel": "COM-LIB-INIT", "exception": null, "level": "info", "log_code": "<RUN11997772I>", "message": "Loading service module: text_sentiment", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.929756"} | ||
{"channel": "COM-LIB-INIT", "exception": null, "level": "info", "log_code": "<RUN11997772I>", "message": "Loading service module: caikit.interfaces.common", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.929814"} | ||
{"channel": "COM-LIB-INIT", "exception": null, "level": "info", "log_code": "<RUN11997772I>", "message": "Loading service module: caikit.interfaces.runtime", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.929858"} | ||
{"channel": "GP-SERVICR-I", "exception": null, "level": "info", "log_code": "<RUN76773778I>", "message": "Validated Caikit Library CDM successfully", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.929942"} | ||
{"channel": "GP-SERVICR-I", "exception": null, "level": "info", "log_code": "<RUN76884779I>", "message": "Constructed inference service for library: text_sentiment, version: unknown", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.930734"} | ||
{"channel": "SERVER-WRAPR", "exception": null, "level": "info", "log_code": "<RUN81194024I>", "message": "Intercepting RPC method /caikit.runtime.HfTextsentiment.HfTextsentimentService/HfBlockPredict", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.930786"} | ||
{"channel": "SERVER-WRAPR", "exception": null, "level": "info", "log_code": "<RUN33333123I>", "message": "Wrapping safe rpc for Predict", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.931424"} | ||
{"channel": "SERVER-WRAPR", "exception": null, "level": "info", "log_code": "<RUN30032825I>", "message": "Re-routing RPC /caikit.runtime.HfTextsentiment.HfTextsentimentService/HfBlockPredict from <function _ServiceBuilder._GenerateNonImplementedMethod.<locals>.<lambda> at 0x7fce01f660d0> to <function CaikitRuntimeServerWrapper.safe_rpc_wrapper.<locals>.safe_rpc_call at 0x7fce02144670>", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.931479"} | ||
{"channel": "SERVER-WRAPR", "exception": null, "level": "info", "log_code": "<RUN24924908I>", "message": "Interception of service caikit.runtime.HfTextsentiment.HfTextsentimentService complete", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.931530"} | ||
[…] | ||
|
||
{"channel": "GRPC-SERVR", "exception": null, "level": "info", "log_code": "<RUN10001807I>", "message": "Running in insecure mode", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.936511"} | ||
{"channel": "GRPC-SERVR", "exception": null, "level": "info", "log_code": "<RUN10001001I>", "message": "Caikit Runtime is serving on port: 8085 with thread pool size: 5", "num_indent": 0, "thread_id": 8605140480, "timestamp": "2023-05-02T11:42:53.938054"} | ||
``` | ||
|
||
## Inferring the Served Model | ||
|
||
In another terminal, run the client code: | ||
|
||
```shell | ||
python3 client.py | ||
``` | ||
|
||
The client code calls the model and queries it for sentiment analysis on two different pieces of text. | ||
|
||
You should see output similar to the following: | ||
|
||
```command | ||
$ python3 client.py | ||
|
||
<function register_backend_type at 0x7fe930bdbdc0> is still in the BETA phase and subject to change! | ||
Text: I am not feeling well today! | ||
RESPONSE: classes { | ||
class_name: "NEGATIVE" | ||
confidence: 0.99977594614028931 | ||
} | ||
|
||
Text: Today is a nice sunny day | ||
RESPONSE: classes { | ||
class_name: "POSITIVE" | ||
confidence: 0.999869704246521 | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import grpc | ||
from caikit.runtime.service_factory import ServicePackageFactory | ||
|
||
from text_sentiment.data_model import TextInput | ||
|
||
# Look up the generated inference service package; it provides the gRPC stub
# class and the request message types used below.
inference_service = ServicePackageFactory().get_service_package(
    ServicePackageFactory.ServiceType.INFERENCE,
    ServicePackageFactory.ServiceSource.GENERATED,
)

port = 8085

# Use the channel as a context manager so the connection is closed once all
# requests have completed, instead of being left open until interpreter exit.
with grpc.insecure_channel(f"localhost:{port}") as channel:
    client_stub = inference_service.stub_class(channel)

    for text in ["I am not feeling well today!", "Today is a nice sunny day"]:
        input_text_proto = TextInput(text=text).to_proto()
        request = inference_service.messages.HfBlockRequest(
            text_input=input_text_proto
        )
        # The "mm-model-id" metadata entry names the model the request targets.
        response = client_stub.HfBlockPredict(
            request, metadata=[("mm-model-id", "text_sentiment")]
        )
        print("Text:", text)
        print("RESPONSE:", response)
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
block_id: 8f72161-c0e4-49b0-8fd0-7587b3017a35 | ||
name: HuggingFaceSentimentBlock | ||
version: 0.0.1 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
caikit | ||
|
||
# Only needed for HuggingFace | ||
scipy | ||
torch | ||
transformers~=4.27.2 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
from os import path | ||
import sys | ||
import alog | ||
|
||
# Append the directory above this script to the import path, presumably so the
# `text_sentiment` package imported below can be found.
# NOTE: the original author assumes this file sits at the same level as
# requirements.txt.
_script_dir = path.dirname(__file__)
sys.path.append(path.abspath(path.join(_script_dir, "../")))
import text_sentiment

# Turn on debug-level logging before the server module is imported and started.
alog.configure(default_level="debug")

from caikit.runtime import grpc_server

grpc_server.main()
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import os | ||
|
||
from . import data_model, runtime_model | ||
import caikit | ||
|
||
# Resolved location of the `config.yml` that lives in the same directory as
# this module.
CONFIG_PATH = os.path.realpath(os.path.join(os.path.dirname(__file__), "config.yml"))

# Hand the configuration file to caikit.
caikit.configure(CONFIG_PATH)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
runtime: | ||
library: text_sentiment | ||
service_generation: | ||
primitive_data_model_types: | ||
- "text_sentiment.data_model.classification.TextInput" | ||
|
2 changes: 2 additions & 0 deletions
2
examples/text-sentiment/text_sentiment/data_model/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .classification import ClassificationPrediction, ClassInfo, TextInput | ||
|
25 changes: 25 additions & 0 deletions
25
examples/text-sentiment/text_sentiment/data_model/classification.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from caikit.core.data_model import dataobject | ||
|
||
@dataobject(
    package="text_sentiment.data_model",
    schema={
        # (required) Predicted relevant class name
        "class_name": str,
        # (required) The confidence-like score of this prediction in [0, 1]
        "confidence": float,
    },
)
class ClassInfo:
    """One predicted class name together with its confidence score."""
|
||
@dataobject(
    package="text_sentiment.data_model",
    schema={
        # Each element is a ClassInfo entry.
        "classes": {"elements": ClassInfo},
    },
)
class ClassificationPrediction:
    """Outcome of a classification: the predicted classes as `ClassInfo` entries."""
|
||
@dataobject(
    package="text_sentiment.data_model",
    schema={"text": str},
)
class TextInput:
    """Example `domain primitive` input type for this library.

    Plays the role that a `Raw Document` plays in the
    `Natural Language Processing` domain.
    """
|
||
|
2 changes: 2 additions & 0 deletions
2
examples/text-sentiment/text_sentiment/runtime_model/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .hf_block import HuggingFaceSentimentBlock | ||
|
71 changes: 71 additions & 0 deletions
71
examples/text-sentiment/text_sentiment/runtime_model/hf_block.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import os | ||
|
||
from caikit.core import BlockBase, ModuleLoader, ModuleSaver, block | ||
from transformers import pipeline | ||
|
||
from text_sentiment.data_model.classification import ClassificationPrediction, ClassInfo, TextInput | ||
|
||
@block("8f72161-c0e4-49b0-8fd0-7587b3017a35", "HuggingFaceSentimentBlock", "0.0.1")
class HuggingFaceSentimentBlock(BlockBase):
    """Caikit block wrapping the HuggingFace ``sentiment-analysis`` pipeline."""

    def __init__(self, model_path) -> None:
        """Build the sentiment pipeline for the caikit model at ``model_path``.

        Args:
            model_path: str
                Path handed to ``ModuleLoader``. The loaded config's
                ``hf_artifact_path`` entry selects the HuggingFace model
                given to ``pipeline``.
        """
        super().__init__()
        loader = ModuleLoader(model_path)
        hf_artifact = loader.config.hf_artifact_path

        # `save` records `hf_artifact_path` relative to the model directory and
        # writes the artifacts to `os.path.join(model_path, rel_path)`. Resolve
        # the value the same way here so a saved model can be reloaded from any
        # working directory. Any other value is passed through untouched.
        if isinstance(hf_artifact, str) and isinstance(model_path, (str, os.PathLike)):
            saved_dir = os.path.join(model_path, hf_artifact)
            if os.path.isdir(saved_dir):
                hf_artifact = saved_dir

        self.sentiment_pipeline = pipeline(
            model=hf_artifact, task="sentiment-analysis"
        )

    def run(self, text_input: TextInput) -> ClassificationPrediction:
        """Run HF sentiment analysis

        Args:
            text_input: TextInput
                Input whose ``text`` attribute is passed to the pipeline.
        Returns:
            ClassificationPrediction: predicted classes with their confidence score.
        """
        # The pipeline is called with a one-element list; each result's
        # "label" and "score" entries become one ClassInfo.
        raw_results = self.sentiment_pipeline([text_input.text])
        class_info = [
            ClassInfo(class_name=result["label"], confidence=result["score"])
            for result in raw_results
        ]
        return ClassificationPrediction(class_info)

    @classmethod
    def bootstrap(cls, model_path="distilbert-base-uncased-finetuned-sst-2-english"):
        """Create a block instance from ``model_path``

        Args:
            model_path: str
                Path to HuggingFace model
                NOTE(review): the value is forwarded unchanged to the
                constructor, which hands it to ``ModuleLoader`` -- confirm that
                the default model name is accepted there.
        Returns:
            HuggingFaceSentimentBlock
        """
        return cls(model_path)

    def save(self, model_path, **kwargs):
        """Save the block's pipeline and config under ``model_path``

        Args:
            model_path: str
                Destination passed to ``ModuleSaver``; the pipeline is written
                to the ``hf_model`` directory joined onto this path.
            **kwargs:
                Accepted but not used by this method.
        """
        module_saver = ModuleSaver(
            self,
            model_path=model_path,
        )

        # Extract object to be saved
        with module_saver:
            # Make the directory to save model artifacts
            rel_path, _ = module_saver.add_dir("hf_model")
            save_path = os.path.join(model_path, rel_path)
            self.sentiment_pipeline.save_pretrained(save_path)
            # Record the relative location so the constructor can find the
            # artifacts again when the model is loaded.
            module_saver.update_config({"hf_artifact_path": rel_path})

    # this is how you load the model, if you have a caikit model
    @classmethod
    def load(cls, model_path):
        """Load a HuggingFace based caikit model

        Args:
            model_path: str
                Path to the caikit model to load
        Returns:
            HuggingFaceSentimentBlock
        """
        return cls(model_path)
|