diff --git a/optimum_transformers/generation_utils.py b/optimum_transformers/generation_utils.py
index 4b98c01..7ae0bfd 100644
--- a/optimum_transformers/generation_utils.py
+++ b/optimum_transformers/generation_utils.py
@@ -1531,13 +1531,21 @@ def greedy_search(
 
             # prepare model inputs
             model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
 
             # forward pass to get next token
             if self.use_onnx:
-                inputs_onnx = {
-                    "input_ids": model_inputs["input_ids"].cpu().detach().numpy(),
-                    "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy()
-                }
+                if self.model.config.is_encoder_decoder:
+                    inputs_onnx = {
+                        "input_ids": input_ids.cpu().detach().numpy(),
+                        "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy(),
+                        "decoder_input_ids": model_inputs["decoder_input_ids"].cpu().detach().numpy(),
+                        "decoder_attention_mask": model_inputs["attention_mask"].cpu().detach().numpy(),
+                    }
+                else:
+                    inputs_onnx = {
+                        "input_ids": model_inputs["input_ids"].cpu().detach().numpy(),
+                        "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy()
+                    }
                 outputs = CausalLMOutputWithCrossAttentions(
                     logits=torch.tensor(self.onnx_model.run(None, inputs_onnx)[0]))
             else:
@@ -1783,10 +1791,19 @@ def sample(
 
             # forward pass to get next token
             if self.use_onnx:
-                inputs_onnx = {
-                    "input_ids": model_inputs["input_ids"].cpu().detach().numpy(),
-                    "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy()
-                }
+                if self.model.config.is_encoder_decoder:
+                    inputs_onnx = {
+                        "input_ids": input_ids.cpu().detach().numpy(),
+                        "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy(),
+                    }
+                    if self.model.use_past:
+                        inputs_onnx["decoder_input_ids"] = model_inputs["decoder_input_ids"].cpu().detach().numpy()
+                        inputs_onnx["decoder_attention_mask"] = model_inputs["decoder_attention_mask"].cpu().detach().numpy()
+                else:
+                    inputs_onnx = {
+                        "input_ids": model_inputs["input_ids"].cpu().detach().numpy(),
+                        "attention_mask": model_inputs["attention_mask"].cpu().detach().numpy()
+                    }
                 outputs = CausalLMOutputWithCrossAttentions(
                     logits=torch.tensor(self.onnx_model.run(None, inputs_onnx)[0]))
             else:
diff --git a/optimum_transformers/pipelines/__init__.py b/optimum_transformers/pipelines/__init__.py
index 5038c79..9b83be4 100644
--- a/optimum_transformers/pipelines/__init__.py
+++ b/optimum_transformers/pipelines/__init__.py
@@ -34,6 +34,7 @@
 from .question_answering import OptimumQuestionAnsweringPipeline
 from .text_classification import OptimumTextClassificationPipeline
 from .text_generation import OptimumTextGenerationPipeline
+from .text2text_generation import OptimumText2TextGenerationPipeline
 from .token_classification import (
     OptimumTokenClassificationPipeline,
 )
@@ -55,6 +56,7 @@
         TFAutoModelForCausalLM,
         TFAutoModelForMaskedLM,
         TFAutoModelForQuestionAnswering,
+        TFAutoModelForSeq2SeqLM,
         TFAutoModelForSequenceClassification,
         TFAutoModelForTokenClassification,
     )
@@ -67,6 +69,7 @@
         AutoModelForCausalLM,
         AutoModelForMaskedLM,
         AutoModelForQuestionAnswering,
+        AutoModelForSeq2SeqLM,
         AutoModelForSequenceClassification,
         AutoModelForTokenClassification,
     )
@@ -178,6 +181,17 @@
             "text_inputs": "HuggingFace is creating a tool that the community uses to solve NLP tasks."
         },
     },
+    "text2text-generation": {
+        "impl": OptimumText2TextGenerationPipeline,
+        "tf": (TFAutoModelForSeq2SeqLM,) if is_tf_available() else (),
+        "pt": (AutoModelForSeq2SeqLM,) if is_torch_available() else (),
+        "default": {"model": {"pt": "t5-base", "tf": "t5-base"}},
+        "type": "text",
+        "feature": "seq2seq-lm",
+        "example": {
+            "text_inputs": "HuggingFace is creating a tool that the community uses to solve NLP tasks."
+        },
+    },
 }
diff --git a/optimum_transformers/pipelines/text2text_generation.py b/optimum_transformers/pipelines/text2text_generation.py
new file mode 100644
index 0000000..4b44abc
--- /dev/null
+++ b/optimum_transformers/pipelines/text2text_generation.py
@@ -0,0 +1,36 @@
+from transformers import Text2TextGenerationPipeline
+from transformers.file_utils import is_tf_available
+
+from ..generation_utils import GenerationMixin
+from .base import _warmup_onnx_graph
+
+if is_tf_available():
+    import tensorflow as tf
+
+
+class OptimumText2TextGenerationPipeline(Text2TextGenerationPipeline):
+    def __init__(self, *args, onnx_model, example, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.onnx_model = onnx_model
+        self.example = example
+        _warmup_onnx_graph(self)
+
+
+    def _forward(self, model_inputs, **generate_kwargs):
+        if self.framework == "pt":
+            in_b, input_length = model_inputs["input_ids"].shape
+        elif self.framework == "tf":
+            in_b, input_length = tf.shape(model_inputs["input_ids"]).numpy()
+
+        generate_kwargs["min_length"] = generate_kwargs.get("min_length", self.model.config.min_length)
+        generate_kwargs["max_length"] = generate_kwargs.get("max_length", self.model.config.max_length)
+        self.check_inputs(input_length, generate_kwargs["min_length"], generate_kwargs["max_length"])
+        generation_matrix = GenerationMixin(self.model, self.onnx_model)
+        output_ids = generation_matrix.generate(**model_inputs, **generate_kwargs)
+        out_b = output_ids.shape[0]
+
+        if self.framework == "pt":
+            output_ids = output_ids.reshape(in_b, out_b // in_b, *output_ids.shape[1:])
+        elif self.framework == "tf":
+            output_ids = tf.reshape(output_ids, (in_b, out_b // in_b, *output_ids.shape[1:]))
+        return {"output_ids": output_ids}
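
For reference, a minimal usage sketch of the task registered above, assuming the package's `pipeline` factory dispatches on the new "text2text-generation" task name and accepts a `use_onnx` flag like the existing Optimum pipelines (the factory itself is not part of this diff):

    # Hypothetical usage sketch: the `pipeline` factory and its `use_onnx` flag are
    # assumed from the package's existing API and are not confirmed by this diff.
    from optimum_transformers import pipeline

    # Falls back to t5-base per the "default" entry in the registry above.
    text2text = pipeline("text2text-generation", use_onnx=True)
    result = text2text("translate English to French: HuggingFace is creating a tool that the community uses to solve NLP tasks.")
    print(result)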