From 2ed7c7fec449cb6915ca040164d21248ab7427e2 Mon Sep 17 00:00:00 2001
From: davidmezzetti <561939+davidmezzetti@users.noreply.github.com>
Date: Tue, 31 Aug 2021 19:18:21 -0400
Subject: [PATCH] Add ONNX support for Embeddings and Pipelines, closes #109

---
 docs/pipelines/onnx.md | 26 ++----
 setup.py | 5 +-
 src/python/txtai/models/__init__.py | 1 +
 src/python/txtai/models/models.py | 39 +++++++++
 src/python/txtai/models/onnx.py | 104 +++++++++++++++++++++++
 src/python/txtai/models/pooling.py | 11 +--
 src/python/txtai/pipeline/hfonnx.py | 15 ++--
 src/python/txtai/pipeline/hfpipeline.py | 2 +-
 src/python/txtai/vectors/factory.py | 4 +-
 src/python/txtai/vectors/transformers.py | 8 +-
 src/python/txtai/vectors/words.py | 22 +++++
 test/python/testonnx.py | 74 ++++++++--------
 test/python/testoptional.py | 11 +++
 13 files changed, 250 insertions(+), 72 deletions(-)
 create mode 100644 src/python/txtai/models/onnx.py

diff --git a/docs/pipelines/onnx.md b/docs/pipelines/onnx.md
index 08f28b276..c0008c0f4 100644
--- a/docs/pipelines/onnx.md
+++ b/docs/pipelines/onnx.md
@@ -1,33 +1,23 @@
 # HFOnnx
 
-Exports a Hugging Face Transformer model to ONNX.
+Exports a Hugging Face Transformer model to ONNX. Currently, this works best with classification/pooling/qa models. Work is ongoing for sequence to
+sequence models (summarization, transcription, translation). An example of how to use the pipeline is shown below.
 
 ```python
-from onnxruntime import InferenceSession, SessionOptions
-from transformers import AutoTokenizer
-
-from txtai.pipeline import HFOnnx
-
-# Normalize logits using sigmoid function
-sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
+from txtai.pipeline import HFOnnx, Labels
+
+# Model path
+path = "distilbert-base-uncased-finetuned-sst-2-english"
 
 # Export model to ONNX
 onnx = HFOnnx()
-model = onnx("distilbert-base-uncased-finetuned-sst-2-english", "sequence-classification", "model.onnx", True)
-
-# Build ONNX session
-options = SessionOptions()
-session = InferenceSession(model, options)
-
-# Tokenize
-tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-finetuned-sst-2-english")
-tokens = tokenizer(["I am happy"], return_tensors="np")
+model = onnx(path, "text-classification", "model.onnx", True)
 
 # Run inference and validate
-outputs = session.run(None, dict(tokens))
-outputs = sigmoid(outputs[0])
+labels = Labels((model, path), dynamic=False)
+labels("I am happy")
 ```
 
 ::: txtai.pipeline.HFOnnx.__init__
diff --git a/setup.py b/setup.py
index cfba4b2fa..80b99fcc4 100644
--- a/setup.py
+++ b/setup.py
@@ -19,6 +19,8 @@
     "uvicorn>=0.12.1",
 ]
 
+extras["model"] = ["onnxruntime>=1.8.1"]
+
 extras["pipeline"] = [
     "fasttext>=0.9.2",
     "nltk>=3.5",
@@ -33,7 +35,6 @@
     "annoy>=1.16.3",
     "fasttext>=0.9.2",
     "hnswlib>=0.5.0",
-    "onnxruntime>=1.8.1",
     "pymagnitude-lite>=0.1.43",
     "scikit-learn>=0.23.1",
     "sentence-transformers>=2.0.0",
@@ -41,7 +42,7 @@
 
 extras["workflow"] = ["apache-libcloud>=3.3.1", "pillow>=7.2.0", "requests>=2.24.0"]
 
-extras["all"] = extras["api"] + extras["pipeline"] + extras["similarity"] + extras["workflow"]
+extras["all"] = extras["api"] + extras["model"] + extras["pipeline"] + extras["similarity"] + extras["workflow"]
 
 setup(
     name="txtai",
diff --git a/src/python/txtai/models/__init__.py b/src/python/txtai/models/__init__.py
index 9b881a9fd..9d73f17cf 100644
--- a/src/python/txtai/models/__init__.py
+++ b/src/python/txtai/models/__init__.py
@@ -3,4 +3,5 @@
 """
 
 from .models import Models
+from .onnx import OnnxModel
 from .pooling import MeanPooling, Pooling
diff --git
a/src/python/txtai/models/models.py b/src/python/txtai/models/models.py index 92e197f6b..459a78776 100644 --- a/src/python/txtai/models/models.py +++ b/src/python/txtai/models/models.py @@ -2,8 +2,14 @@ Models module """ +import os + import torch +from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification + +from .onnx import OnnxModel + class Models: """ @@ -81,3 +87,36 @@ def reference(deviceid): """ return "cpu" if deviceid < 0 else "cuda:{}".format(deviceid) + + @staticmethod + def load(path, task="default"): + """ + Loads a machine learning model. Handles multiple model frameworks (ONNX, Transformers). + + Args: + path: path to model + task: task name used to lookup model configuration + + Returns: + machine learning model + """ + + # Detect ONNX models + if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)): + return OnnxModel(path) + + # Return path, if path isn't a string + if not isinstance(path, str): + return path + + # Transformer models + config = { + "default": AutoModel.from_pretrained, + "question-answering": AutoModelForQuestionAnswering.from_pretrained, + "summarization": AutoModelForSeq2SeqLM.from_pretrained, + "text-classification": AutoModelForSequenceClassification.from_pretrained, + "zero-shot-classification": AutoModelForSequenceClassification.from_pretrained, + } + + # Load model for supported tasks. Return path for unsupported tasks. + return config[task](path) if task in config else path diff --git a/src/python/txtai/models/onnx.py b/src/python/txtai/models/onnx.py new file mode 100644 index 000000000..4ecc230ba --- /dev/null +++ b/src/python/txtai/models/onnx.py @@ -0,0 +1,104 @@ +""" +ONNX module +""" + +# Conditional import +try: + from onnxruntime import InferenceSession, SessionOptions + + ONNX_RUNTIME = True +except ImportError: + ONNX_RUNTIME = False + +import numpy as np +import torch + +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto.modeling_auto import ( + MODEL_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, +) +from transformers.models.auto.tokenization_auto import TOKENIZER_MAPPING +from transformers.modeling_utils import PreTrainedModel + +# pylint: disable=W0223 +class OnnxModel(PreTrainedModel): + """ + Provides a Transformers/PyTorch compatible interface for ONNX models. Handles casting inputs + and outputs with minimal to no copying of data. + """ + + def __init__(self, model): + """ + Creates a new OnnxModel. + + Args: + model: path to model or InferenceSession + """ + + if not ONNX_RUNTIME: + raise ImportError('onnxruntime is not available - install "model" extra to enable') + + super().__init__(PretrainedConfig()) + + # Create ONNX session + self.model = InferenceSession(model, SessionOptions()) + + # Add references for this class to supported AutoModel classes + name = self.__class__.__name__ + if name not in MODEL_MAPPING: + MODEL_MAPPING[name] = self.__class__ + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[name] = self.__class__ + MODEL_FOR_QUESTION_ANSWERING_MAPPING[name] = self.__class__ + + # Add references for this class to support pipeline AutoTokenizers + if type(self.config) not in TOKENIZER_MAPPING: + TOKENIZER_MAPPING[type(self.config)] = None + + def forward(self, **inputs): + """ + Runs inputs through an ONNX model and returns outputs. This method handles casting inputs + and outputs between torch tensors and numpy arrays as shared memory (no copy). 
+ + Args: + inputs: model inputs + + Returns: + model outputs + """ + + inputs = self.parse(inputs) + + # Run inputs through ONNX model + results = self.model.run(None, inputs) + + # pylint: disable=E1101 + return torch.from_numpy(np.array(results)) + + def parse(self, inputs): + """ + Parse model inputs and handle converting to ONNX compatible inputs. + + Args: + inputs: model inputs + + Returns: + ONNX compatible model inputs + """ + + features = {} + + # Select features from inputs + for key in ["input_ids", "attention_mask", "token_type_ids"]: + if key in inputs: + value = inputs[key] + + # Cast torch tensors to numpy + if hasattr(value, "cpu"): + value = value.cpu().numpy() + + # Cast to numpy array if not already one + features[key] = np.asarray(value) + + return features diff --git a/src/python/txtai/models/pooling.py b/src/python/txtai/models/pooling.py index 072fce5ac..0af85f384 100644 --- a/src/python/txtai/models/pooling.py +++ b/src/python/txtai/models/pooling.py @@ -7,7 +7,7 @@ from torch import nn -from transformers import AutoModel, AutoTokenizer +from transformers import AutoTokenizer from .models import Models @@ -17,21 +17,22 @@ class Pooling(nn.Module): Builds pooled vectors usings outputs from a transformers model. """ - def __init__(self, path, device, batch=32, maxlength=None): + def __init__(self, path, device, tokenizer=None, batch=32, maxlength=None): """ Creates a new Pooling model. Args: - path: path to transformers model + path: path to model, accepts Hugging Face model hub id or local path device: tensor device id + tokenizer: optional path to tokenizer batch: batch size maxlength: max sequence length """ super().__init__() - self.model = AutoModel.from_pretrained(path) - self.tokenizer = AutoTokenizer.from_pretrained(path) + self.model = Models.load(path) + self.tokenizer = AutoTokenizer.from_pretrained(tokenizer if tokenizer else path) self.device = Models.device(device) # Detect unbounded tokenizer typically found in older models diff --git a/src/python/txtai/pipeline/hfonnx.py b/src/python/txtai/pipeline/hfonnx.py index 8a3dfb39c..d97a307a8 100644 --- a/src/python/txtai/pipeline/hfonnx.py +++ b/src/python/txtai/pipeline/hfonnx.py @@ -18,8 +18,7 @@ from torch.onnx import export -from transformers import AutoModel, AutoModelForCausalLM, AutoModelForMultipleChoice, AutoModelForQuestionAnswering -from transformers import AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, AutoTokenizer +from transformers import AutoModel, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoTokenizer from ..models import MeanPooling from .tensors import Tensors @@ -40,6 +39,9 @@ def __call__(self, path, task="default", output=None, quantize=False, opset=12): output: optional output model path, defaults to return byte array if None quantize: if model should be quantized (requires onnx to be installed), defaults to False opset: onnx opset, defaults to 12 + + Returns: + path to model output or model as bytes depending on output parameter """ inputs, outputs, model = self.parameters(task) @@ -141,12 +143,7 @@ def parameters(self, task): config = { "default": (OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), AutoModel.from_pretrained), - "causal-lm": (OrderedDict({"logits": {0: "batch", 1: "sequence"}}), AutoModelForCausalLM.from_pretrained), "pooling": (OrderedDict({"embeddings": {0: "batch", 1: "sequence"}}), lambda x: MeanPoolingOnnx(x, -1)), - "seq2seq-lm": (OrderedDict({"logits": {0: "batch", 1: 
"decoder_sequence"}}), AutoModelForSeq2SeqLM.from_pretrained), - "sequence-classification": (OrderedDict({"logits": {0: "batch"}}), AutoModelForSequenceClassification.from_pretrained), - "token-classification": (OrderedDict({"logits": {0: "batch", 1: "sequence"}}), AutoModelForTokenClassification.from_pretrained), - "multiple-choice": (OrderedDict({"logits": {0: "batch"}}), AutoModelForMultipleChoice.from_pretrained), "question-answering": ( OrderedDict( { @@ -156,8 +153,12 @@ def parameters(self, task): ), AutoModelForQuestionAnswering.from_pretrained, ), + "text-classification": (OrderedDict({"logits": {0: "batch"}}), AutoModelForSequenceClassification.from_pretrained), } + # Aliases + config["zero-shot-classification"] = config["text-classification"] + return (inputs,) + config[task] diff --git a/src/python/txtai/pipeline/hfpipeline.py b/src/python/txtai/pipeline/hfpipeline.py index 0a9689003..61be8e993 100644 --- a/src/python/txtai/pipeline/hfpipeline.py +++ b/src/python/txtai/pipeline/hfpipeline.py @@ -36,7 +36,7 @@ def __init__(self, task, path=None, quantize=False, gpu=False, model=None): # Transformer pipeline task if isinstance(path, tuple): - self.pipeline = pipeline(task, model=path[0], tokenizer=path[1], device=deviceid) + self.pipeline = pipeline(task, model=Models.load(path[0], task), tokenizer=path[1], device=deviceid) else: self.pipeline = pipeline(task, model=path, tokenizer=path, device=deviceid) diff --git a/src/python/txtai/vectors/factory.py b/src/python/txtai/vectors/factory.py index 9569d357f..994bce1d9 100644 --- a/src/python/txtai/vectors/factory.py +++ b/src/python/txtai/vectors/factory.py @@ -2,8 +2,6 @@ Factory module """ -import os - from .transformers import TransformersVectors from .words import WordVectors, WORDS @@ -58,6 +56,6 @@ def method(config): # Infer method from path, if blank if not method and path: - method = "words" if os.path.isfile(path) else "transformers" + method = "words" if WordVectors.isDatabase(path) else "transformers" return method diff --git a/src/python/txtai/vectors/transformers.py b/src/python/txtai/vectors/transformers.py index f72071bd8..4fdd47a27 100644 --- a/src/python/txtai/vectors/transformers.py +++ b/src/python/txtai/vectors/transformers.py @@ -2,6 +2,7 @@ Transformers module """ +import os import pickle import tempfile @@ -14,7 +15,7 @@ SENTENCE_TRANSFORMERS = False from .base import Vectors -from ..models import MeanPooling, Models +from ..models import MeanPooling, Models, Pooling from ..pipeline.tokenizer import Tokenizer @@ -32,7 +33,10 @@ def load(self, path): # Build embeddings with transformers (default) if transformers: - return MeanPooling(path, device=deviceid) + if isinstance(path, bytes) or (isinstance(path, str) and os.path.isfile(path)): + return Pooling(path, device=deviceid, tokenizer=self.config.get("tokenizer")) + + return MeanPooling(path, device=deviceid, tokenizer=self.config.get("tokenizer")) if not SENTENCE_TRANSFORMERS: raise ImportError('sentence-transformers is not available - install "similarity" extra to enable') diff --git a/src/python/txtai/vectors/words.py b/src/python/txtai/vectors/words.py index 6593c2162..816b75926 100644 --- a/src/python/txtai/vectors/words.py +++ b/src/python/txtai/vectors/words.py @@ -153,6 +153,28 @@ def lookup(self, tokens): return self.model.query(tokens) + @staticmethod + def isDatabase(path): + """ + Checks if this is a SQLite database file which is the file format used for word vectors databases. 
+
+        Args:
+            path: path to check
+
+        Returns:
+            True if this is a SQLite database
+        """
+
+        if isinstance(path, str) and os.path.isfile(path) and os.path.getsize(path) >= 100:
+            # Read 100 byte SQLite header
+            with open(path, "rb") as f:
+                header = f.read(100)
+
+            # Check for SQLite header
+            return header.startswith(b"SQLite format 3\000")
+
+        return False
+
     @staticmethod
     def build(data, size, mincount, path):
         """
diff --git a/test/python/testonnx.py b/test/python/testonnx.py
index 3a74f0f82..e53102aa2 100644
--- a/test/python/testonnx.py
+++ b/test/python/testonnx.py
@@ -6,13 +6,8 @@
 import tempfile
 import unittest
 
-import numpy as np
-
-from onnxruntime import InferenceSession, SessionOptions
-
-from transformers import AutoTokenizer
-
-from txtai.pipeline import HFOnnx, HFTrainer
+from txtai.embeddings import Embeddings
+from txtai.pipeline import HFOnnx, HFTrainer, Labels, Questions
 
 
 class TestOnnx(unittest.TestCase):
@@ -45,50 +40,61 @@ def testClassification(self):
         Test exporting a classification model to ONNX and running inference
         """
 
-        # Normalize logits using sigmoid function
-        sigmoid = lambda x: 1.0 / (1.0 + np.exp(-x))
+        path = "google/bert_uncased_L-2_H-128_A-2"
 
         trainer = HFTrainer()
-        model, tokenizer = trainer("google/bert_uncased_L-2_H-128_A-2", self.data)
+        model, tokenizer = trainer(path, self.data)
 
         # Output file path
         output = os.path.join(tempfile.gettempdir(), "onnx")
 
         # Export model to ONNX
         onnx = HFOnnx()
-        model = onnx((model, tokenizer), "sequence-classification", output, True)
-
-        # Build ONNX session
-        options = SessionOptions()
-        session = InferenceSession(model, options)
-
-        # Tokenize and cast to int64 to support all platforms
-        tokens = tokenizer(["cat"], return_tensors="np")
-        tokens = {x: tokens[x].astype(np.int64) for x in tokens}
+        model = onnx((model, tokenizer), "text-classification", output, True)
 
-        # Run inference and validate
-        outputs = session.run(None, dict(tokens))
-        outputs = sigmoid(outputs[0])
-        self.assertEqual(np.argmax(outputs[0]), 1)
+        # Test classification
+        labels = Labels((model, path), dynamic=False)
+        self.assertEqual(labels("cat")[0][0], 1)
 
     def testPooling(self):
         """
         Test exporting a pooling model to ONNX and running inference
         """
 
+        path = "sentence-transformers/paraphrase-MiniLM-L3-v2"
+
+        # Export model to ONNX
+        onnx = HFOnnx()
+        model = onnx(path, "pooling", quantize=True)
+
+        embeddings = Embeddings({"path": model, "tokenizer": path})
+        self.assertEqual(embeddings.similarity("animal", ["dog", "book", "rug"])[0][0], 0)
+
+    def testQA(self):
+        """
+        Test exporting a QA model to ONNX and running inference
+        """
+
+        path = "distilbert-base-cased-distilled-squad"
+
         # Export model to ONNX
         onnx = HFOnnx()
-        model = onnx("sentence-transformers/paraphrase-MiniLM-L3-v2", "pooling", quantize=True)
-        tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
+        model = onnx(path, "question-answering")
+
+        questions = Questions((model, path))
+        self.assertEqual(questions(["What is the price?"], ["The price is $30"])[0], "$30")
 
-        # Build ONNX session
-        options = SessionOptions()
-        session = InferenceSession(model, options)
+    def testZeroShot(self):
+        """
+        Test exporting a zero shot classification model to ONNX and running inference
+        """
+
+        path = "prajjwal1/bert-medium-mnli"
 
-        # Tokenize and cast to int64 to support all platforms
-        tokens = tokenizer(["cat"], return_tensors="np")
-        tokens = {x: tokens[x].astype(np.int64) for x in tokens}
+        # Export model to ONNX
+        onnx = HFOnnx()
+        model = onnx(path, "zero-shot-classification", quantize=True)
 
-        # Run inference and validate
-        outputs = session.run(None, dict(tokens))
-        self.assertEqual(outputs[0].shape, (1, 384))
+        # Test zero shot classification
+        labels = Labels((model, path))
+        self.assertEqual(labels("That is great news", ["negative", "positive"])[0][0], 1)
diff --git a/test/python/testoptional.py b/test/python/testoptional.py
index 2c46f52a2..28c18cd69 100644
--- a/test/python/testoptional.py
+++ b/test/python/testoptional.py
@@ -7,6 +7,7 @@
 import txtai.ann.factory
 
 from txtai.ann import ANNFactory
+from txtai.models import OnnxModel
 from txtai.pipeline import HFOnnx, Segmentation, Textractor, Transcription, Translation
 from txtai.vectors import VectorsFactory
 from txtai.workflow.task.image import ImageTask
@@ -27,6 +28,8 @@ def toggle():
         txtai.ann.factory.ANNOY = not txtai.ann.factory.ANNOY
         txtai.ann.factory.HNSWLIB = not txtai.ann.factory.HNSWLIB
 
+        txtai.models.onnx.ONNX_RUNTIME = not txtai.models.onnx.ONNX_RUNTIME
+        txtai.pipeline.hfonnx.ONNX_RUNTIME = not txtai.pipeline.hfonnx.ONNX_RUNTIME
         txtai.pipeline.segmentation.NLTK = not txtai.pipeline.segmentation.NLTK
         txtai.pipeline.textractor.TIKA = not txtai.pipeline.textractor.TIKA
 
@@ -68,6 +71,14 @@ def testAnn(self):
         with self.assertRaises(ImportError):
             ANNFactory.create({"backend": "hnsw"})
 
+    def testModel(self):
+        """
+        Test missing model dependencies
+        """
+
+        with self.assertRaises(ImportError):
+            OnnxModel(None)
+
     def testPipeline(self):
         """
         Test missing pipeline dependencies