From 9027916d289e0d18dd6118afbfc10539a7b75a9e Mon Sep 17 00:00:00 2001 From: ChainYo Date: Sat, 2 Apr 2022 20:28:58 +0200 Subject: [PATCH 1/6] add text2text pipeline --- .../pipelines/text2text_pipeline.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 optimum_transformers/pipelines/text2text_pipeline.py diff --git a/optimum_transformers/pipelines/text2text_pipeline.py b/optimum_transformers/pipelines/text2text_pipeline.py new file mode 100644 index 0000000..d1ff369 --- /dev/null +++ b/optimum_transformers/pipelines/text2text_pipeline.py @@ -0,0 +1,37 @@ +from transformers import Text2TextGenerationPipeline +from transformers.file_utils import is_tf_available + +from ..generation_utils import GenerationMixin +from .base import _warmup_onnx_graph + +if is_tf_available(): + import tensorflow as tf + + +class OptimumText2TextGenerationPipeline(Text2TextGenerationPipeline): + def __init__(self, *args, onnx_model, example, **kwargs): + super().__init__(*args, **kwargs) + self.onnx_model = onnx_model + self.example = example + _warmup_onnx_graph(self) + + + def _forward(self, model_inputs, **generate_kwargs): + if self.framework == "pt": + in_b, input_length = model_inputs["input_ids"].shape + elif self.framework == "tf": + in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy() + + generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length) + generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length) + self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) + generation_matrix = GenerationMixin(self.model, self.onnx_model) + output_ids = generation_matrix(**model_inputs, **generate_kwargs) + out_b = output_ids.shape[0] + + if self.framework == "pt": + output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:]) + elif self.framework == "tf": + output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:])) + return {"output_ids": output_ids} + \ No newline at end of file From 13cc02169736901762b83b6a0b9cd81a9fd371cf Mon Sep 17 00:00:00 2001 From: ChainYo Date: Sat, 2 Apr 2022 20:35:23 +0200 Subject: [PATCH 2/6] rename file --- .../pipelines/{text2text_pipeline.py => text2text_generation.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename optimum_transformers/pipelines/{text2text_pipeline.py => text2text_generation.py} (100%) diff --git a/optimum_transformers/pipelines/text2text_pipeline.py b/optimum_transformers/pipelines/text2text_generation.py similarity index 100% rename from optimum_transformers/pipelines/text2text_pipeline.py rename to optimum_transformers/pipelines/text2text_generation.py From 3237c3bdcaa045a27c3d5357699fa805d104d1b6 Mon Sep 17 00:00:00 2001 From: ChainYo Date: Sat, 2 Apr 2022 20:35:31 +0200 Subject: [PATCH 3/6] add text2text to init --- optimum_transformers/pipelines/__init__.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/optimum_transformers/pipelines/__init__.py b/optimum_transformers/pipelines/__init__.py index 5038c79..dbacb80 100644 --- a/optimum_transformers/pipelines/__init__.py +++ b/optimum_transformers/pipelines/__init__.py @@ -34,6 +34,7 @@ from .question_answering import OptimumQuestionAnsweringPipeline from .text_classification import OptimumTextClassificationPipeline from .text_generation import OptimumTextGenerationPipeline +from .text2text_generation import OptimumText2TextGenerationPipeline from .token_classification import ( OptimumTokenClassificationPipeline, ) @@ -178,6 +179,17 @@ "text_inputs": "HuggingFace is creating a tool that the community uses to solve NLP tasks." }, }, + "text2text-generation": { + "impl": OptimumText2TextGenerationPipeline, + "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (), + "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (), + "default": {"model": {"pt": "t5-base", "tf": "t5-base"}}, + "type": "text", + "feature": "seq2seq-lm", + "example": { + "text_inputs": "HuggingFace is creating a tool that the community uses to solve NLP tasks." + }, + }, } From bb193f14c17a0e34e92b8619a48ec62654e9da26 Mon Sep 17 00:00:00 2001 From: ChainYo Date: Sat, 2 Apr 2022 20:37:21 +0200 Subject: [PATCH 4/6] fix auto model imports --- optimum_transformers/pipelines/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/optimum_transformers/pipelines/__init__.py b/optimum_transformers/pipelines/__init__.py index dbacb80..9b83be4 100644 --- a/optimum_transformers/pipelines/__init__.py +++ b/optimum_transformers/pipelines/__init__.py @@ -56,6 +56,7 @@ TFAutoModelForCausalLM, TFAutoModelForMaskedLM, TFAutoModelForQuestionAnswering, + TFAutoModelForSeq2SeqLM, TFAutoModelForSequenceClassification, TFAutoModelForTokenClassification, ) @@ -68,6 +69,7 @@ AutoModelForCausalLM, AutoModelForMaskedLM, AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoModelForTokenClassification, ) From 6d0a7d844a907f880274bf5feacb1cf70c1231ee Mon Sep 17 00:00:00 2001 From: ChainYo Date: Sat, 2 Apr 2022 20:45:00 +0200 Subject: [PATCH 5/6] fix generationmatrix call --- optimum_transformers/pipelines/text2text_generation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/optimum_transformers/pipelines/text2text_generation.py b/optimum_transformers/pipelines/text2text_generation.py index d1ff369..4b44abc 100644 --- a/optimum_transformers/pipelines/text2text_generation.py +++ b/optimum_transformers/pipelines/text2text_generation.py @@ -26,7 +26,7 @@ def _forward(self, model_inputs, **generate_kwargs): generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length) self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"]) generation_matrix = GenerationMixin(self.model, self.onnx_model) - output_ids = generation_matrix(**model_inputs, **generate_kwargs) + output_ids = generation_matrix.generate(**model_inputs, **generate_kwargs) out_b = output_ids.shape[0] if self.framework == "pt": @@ -34,4 +34,3 @@ def _forward(self, model_inputs, **generate_kwargs): elif self.framework == "tf": output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:])) return {"output_ids": output_ids} - \ No newline at end of file From 741b2e091e6d08195fdb838c05a5d556eaa9d1f6 Mon Sep 17 00:00:00 2001 From: ChainYo Date: Wed, 6 Apr 2022 11:42:25 +0200 Subject: [PATCH 6/6] incorrect fix --- optimum_transformers/generation_utils.py | 34 ++++++++++++++++++------ 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/optimum_transformers/generation_utils.py b/optimum_transformers/generation_utils.py index 4b98c01..7ae0bfd 100644 --- a/optimum_transformers/generation_utils.py +++ b/optimum_transformers/generation_utils.py @@ -1531,13 +1531,22 @@ def greedy_search( # prepare model inputs model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) + print(model_inputs.keys()) # forward pass to get next token if self.use_onnx: - inputs_onnx = { - "input_ids": model_inputs["input_ids"].cpu().detach().numpy(), - "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy() - } + if self.model.config.is_encoder_decoder: + inputs_onnx = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy(), + "decoder_input_ids": model_inputs["decoder_input_ids"].cpu().detach().numpy(), + "decoder_attention_mask": model_inputs["attention_mask"].cpu().detach().numpy(), + } + else: + inputs_onnx = { + "input_ids": model_inputs["input_ids"].cpu().detach().numpy(), + "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy() + } outputs = CausalLMOutputWithCrossAttentions( logits=torch.tensor(self.onnx_model.run(None, inputs_onnx)[0])) else: @@ -1783,10 +1792,19 @@ def sample( # forward pass to get next token if self.use_onnx: - inputs_onnx = { - "input_ids": model_inputs["input_ids"].cpu().detach().numpy(), - "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy() - } + if self.model.config.is_encoder_decoder: + inputs_onnx = { + "input_ids": input_ids.cpu().detach().numpy(), + "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy(), + } + if self.model.use_past: + inputs_onnx["decoder_inputs_ids"] = model_inputs["decoder_inputs_ids"].cpu().detach().numpy() + inputs_onnx["decoder_attention_mask"] = model_inputs["decoder_attention_mask"].cpu().detach().numpy() + else: + inputs_onnx = { + "input_ids": model_inputs["input_ids"].cpu().detach().numpy(), + "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy() + } outputs = CausalLMOutputWithCrossAttentions( logits=torch.tensor(self.onnx_model.run(None, inputs_onnx)[0])) else: