[ONNX] Part1: update current onnx export (#976)
Update the ONNX export function and introduce large model export on the TF ONNX side
felixdittrich92 authored Jul 7, 2022
1 parent 0921b32 commit 0422234
Showing 6 changed files with 53 additions and 32 deletions.
21 changes: 15 additions & 6 deletions doctr/models/utils/pytorch.py
@@ -11,7 +11,7 @@

 from doctr.utils.data import download_from_url

-__all__ = ['load_pretrained_params', 'conv_sequence_pt', 'export_classification_model_to_onnx']
+__all__ = ['load_pretrained_params', 'conv_sequence_pt', 'export_model_to_onnx']


 def load_pretrained_params(
@@ -92,19 +92,25 @@ def conv_sequence_pt(
     return conv_seq


-def export_classification_model_to_onnx(model: nn.Module, model_name: str, dummy_input: torch.Tensor) -> str:
-    """Export classification model to ONNX format.
+def export_model_to_onnx(
+    model: nn.Module,
+    model_name: str,
+    dummy_input: torch.Tensor,
+    **kwargs: Any
+) -> str:
+    """Export model to ONNX format.

     >>> import torch
     >>> from doctr.models.classification import resnet18
-    >>> from doctr.models.utils import export_classification_model_to_onnx
+    >>> from doctr.models.utils import export_model_to_onnx
     >>> model = resnet18(pretrained=True)
-    >>> export_classification_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))
+    >>> export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

     Args:
         model: the PyTorch model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
+        kwargs: additional arguments to be passed to torch.onnx.export

     Returns:
         the path to the exported model
@@ -116,7 +122,10 @@ def export_classification_model_to_onnx(model: nn.Module, model_name: str, dummy
         input_names=['input'],
         output_names=['logits'],
         dynamic_axes={'input': {0: 'batch_size'}, 'logits': {0: 'batch_size'}},
-        export_params=True, opset_version=13, verbose=False
+        export_params=True,
+        opset_version=14,  # minimum opset that supports all operators we use (v0.5.2)
+        verbose=False,
+        **kwargs,
     )
     logging.info(f"Model exported to {model_name}.onnx")
     return f"{model_name}.onnx"
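For context, a quick sanity check of the renamed helper could look like this (a minimal sketch, not part of the commit; it assumes doctr and onnxruntime are installed and reuses the docstring example above):

import numpy as np
import onnxruntime
import torch

from doctr.models.classification import resnet18
from doctr.models.utils import export_model_to_onnx

model = resnet18(pretrained=True).eval()
# Export with the dynamic batch axis configured by the helper above
model_path = export_model_to_onnx(model, "my_model", dummy_input=torch.randn(1, 3, 32, 32))

# Thanks to dynamic_axes, the batch size may differ from the dummy input's
session = onnxruntime.InferenceSession(model_path)
logits = session.run(["logits"], {"input": np.random.rand(2, 3, 32, 32).astype(np.float32)})[0]
print(logits.shape)  # (2, num_classes)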
32 changes: 22 additions & 10 deletions doctr/models/utils/tensorflow.py
@@ -17,8 +17,7 @@
 logging.getLogger("tensorflow").setLevel(logging.DEBUG)


-__all__ = ['load_pretrained_params', 'conv_sequence', 'IntermediateLayerGetter',
-           'export_classification_model_to_onnx']
+__all__ = ['load_pretrained_params', 'conv_sequence', 'IntermediateLayerGetter', 'export_model_to_onnx']


 def load_pretrained_params(
@@ -123,33 +122,46 @@ def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"


-def export_classification_model_to_onnx(model: Model,
-                                        model_name: str,
-                                        dummy_input: List[tf.TensorSpec]) -> Tuple[str, List[str]]:
-    """Export classification model to ONNX format.
+def export_model_to_onnx(
+    model: Model,
+    model_name: str,
+    dummy_input: List[tf.TensorSpec],
+    **kwargs: Any
+) -> Tuple[str, List[str]]:
+    """Export model to ONNX format.

     >>> import tensorflow as tf
     >>> from doctr.models.classification import resnet18
-    >>> from doctr.models.utils import export_classification_model_to_onnx
+    >>> from doctr.models.utils import export_model_to_onnx
     >>> model = resnet18(pretrained=True, include_top=True)
-    >>> export_classification_model_to_onnx(model, "my_model",
+    >>> export_model_to_onnx(model, "my_model",
     >>>     dummy_input=[tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")])

     Args:
         model: the keras model to be exported
         model_name: the name for the exported model
         dummy_input: the dummy input to the model
+        kwargs: additional arguments to be passed to tf2onnx

     Returns:
         the path to the exported model and a list with the output layer names
     """
+    large_model = kwargs.get('large_model', False)
     model_proto, _ = tf2onnx.convert.from_keras(
         model,
-        opset=13,
+        opset=14,  # minimum opset that supports all operators we use (v0.5.2)
         input_signature=dummy_input,
-        output_path=f"{model_name}.onnx",
+        output_path=f"{model_name}.zip" if large_model else f"{model_name}.onnx",
+        **kwargs,
     )
     # Get the output layer names
     output = [n.name for n in model_proto.graph.output]
-    logging.info(f"Model exported to {model_name}.onnx")
+
+    # Models whose weights exceed 2GB during ONNX conversion need external tensor storage,
+    # where the graph and the weights are stored separately in an archive
+    if large_model:
+        logging.info(f"Model exported to {model_name}.zip")
+        return f"{model_name}.zip", output
+
+    logging.info(f"Model exported to {model_name}.onnx")
     return f"{model_name}.onnx", output
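A usage sketch of the new large_model switch (not part of the commit; the file names are illustrative, and the __MODEL_PROTO.onnx entry inside the archive follows tf2onnx's convention at the time of writing — treat both as assumptions):

import zipfile

import onnxruntime
import tensorflow as tf

from doctr.models.classification import resnet18
from doctr.models.utils import export_model_to_onnx

model = resnet18(pretrained=True, include_top=True)
dummy_input = [tf.TensorSpec([None, 32, 32, 3], tf.float32, name="input")]

# Regular export: a single self-contained .onnx file
model_path, output_names = export_model_to_onnx(model, "my_model", dummy_input=dummy_input)

# Large-model export: graph and external weight tensors packed into a zip archive,
# working around protobuf's 2GB serialization limit
archive_path, output_names = export_model_to_onnx(
    model, "my_model_large", dummy_input=dummy_input, large_model=True
)

# To run the archived model, extract it first so the proto can locate its external tensors
with zipfile.ZipFile(archive_path) as archive:
    archive.extractall("my_model_large")
session = onnxruntime.InferenceSession("my_model_large/__MODEL_PROTO.onnx")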
4 changes: 2 additions & 2 deletions references/classification/train_pytorch.py
@@ -32,7 +32,7 @@
 from doctr import transforms as T
 from doctr.datasets import VOCABS, CharacterGenerator
 from doctr.models import classification, login_to_hub, push_to_hf_hub
-from doctr.models.utils import export_classification_model_to_onnx
+from doctr.models.utils import export_model_to_onnx
 from utils import plot_recorder, plot_samples


@@ -352,7 +352,7 @@ def main(args):
         print("Exporting model to ONNX...")
         dummy_batch = next(iter(val_loader))
         dummy_input = dummy_batch[0].cuda() if torch.cuda.is_available() else dummy_batch[0]
-        model_path = export_classification_model_to_onnx(model, exp_name, dummy_input)
+        model_path = export_model_to_onnx(model, exp_name, dummy_input)
         print(f"Exported model saved in {model_path}")


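If the training script wires the args.export_onnx attribute read above to a matching --export-onnx flag (an assumption; the flag name and arguments below are illustrative), the export path is exercised with an invocation such as:

python references/classification/train_pytorch.py resnet18 --epochs 5 --export-onnx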
4 changes: 2 additions & 2 deletions references/classification/train_tensorflow.py
@@ -27,7 +27,7 @@
 from doctr import transforms as T
 from doctr.datasets import VOCABS, CharacterGenerator, DataLoader
 from doctr.models import classification
-from doctr.models.utils import export_classification_model_to_onnx
+from doctr.models.utils import export_model_to_onnx
 from utils import plot_recorder, plot_samples


@@ -310,7 +310,7 @@ def main(args):
     if args.export_onnx:
         print("Exporting model to ONNX...")
         dummy_input = [tf.TensorSpec([None, args.input_size, args.input_size, 3], tf.float32, name="input")]
-        model_path, _ = export_classification_model_to_onnx(model, exp_name, dummy_input)
+        model_path, _ = export_model_to_onnx(model, exp_name, dummy_input)
         print(f"Exported model saved in {model_path}")


7 changes: 3 additions & 4 deletions tests/pytorch/test_models_classification_pt.py
@@ -9,7 +9,7 @@

 from doctr.models import classification
 from doctr.models.classification.predictor import CropOrientationPredictor
-from doctr.models.utils import export_classification_model_to_onnx
+from doctr.models.utils import export_model_to_onnx


 def _test_classification(model, input_shape, output_size, batch_size=2):
@@ -129,9 +129,8 @@ def test_models_onnx_export(arch_name, input_shape, output_size):
     dummy_input = torch.rand((batch_size, *input_shape), dtype=torch.float32)
     with tempfile.TemporaryDirectory() as tmpdir:
         # Export
-        model_path = export_classification_model_to_onnx(model,
-                                                         model_name=os.path.join(tmpdir, "model"),
-                                                         dummy_input=dummy_input)
+        model_path = export_model_to_onnx(model, model_name=os.path.join(tmpdir, "model"), dummy_input=dummy_input)

         assert os.path.exists(model_path)
         # Inference
         ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"),
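The inference half of this test is cut off at the hunk boundary above; it plausibly continues along these lines (a sketch, not the verbatim test code — the providers argument and shape assertions are assumptions):

ort_session = onnxruntime.InferenceSession(os.path.join(tmpdir, "model.onnx"),
                                           providers=["CPUExecutionProvider"])
ort_outs = ort_session.run(["logits"], {"input": dummy_input.numpy()})
assert isinstance(ort_outs, list) and len(ort_outs) == 1
assert ort_outs[0].shape == (batch_size, *output_size)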
17 changes: 9 additions & 8 deletions tests/tensorflow/test_models_classification_tf.py
@@ -9,7 +9,7 @@

 from doctr.models import classification
 from doctr.models.classification.predictor import CropOrientationPredictor
-from doctr.models.utils import export_classification_model_to_onnx
+from doctr.models.utils import export_model_to_onnx


 @pytest.mark.parametrize(
@@ -94,18 +94,17 @@ def test_crop_orientation_model(mock_text_box):
         ["resnet34", (32, 32, 3), (126,)],
         ["resnet34_wide", (32, 32, 3), (126,)],
         ["resnet50", (32, 32, 3), (126,)],
-        # Name:'res_net_4/magc/transform/conv2d_289/Conv2D:0_nchwc'
-        # Status Message: Input channels C is not equal to kernel channels * group. C: 32 kernel channels: 256 group: 1
-        #["magc_resnet31", (32, 32, 3), (126,)],
+        ["magc_resnet31", (32, 32, 3), (126,)],
         # Disabled for now
         # ["mobilenet_v3_small", (512, 512, 3), (126,)],
         # ["mobilenet_v3_large", (512, 512, 3), (126,)],
         # ["mobilenet_v3_small_orientation", (128, 128, 3), (4,)],
     ],
 )
-def test_models_saved_model_export(arch_name, input_shape, output_size):
+def test_models_onnx_export(arch_name, input_shape, output_size):
     # Model
     batch_size = 2
+    tf.keras.backend.clear_session()
     if arch_name == "mobilenet_v3_small_orientation":
         model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape)
     else:
@@ -115,9 +114,11 @@ def test_models_saved_model_export(arch_name, input_shape, output_size):
     np_dummy_input = np.random.rand(batch_size, *input_shape).astype(np.float32)
     with tempfile.TemporaryDirectory() as tmpdir:
         # Export
-        model_path, output = export_classification_model_to_onnx(model,
-                                                                 model_name=os.path.join(tmpdir, "model"),
-                                                                 dummy_input=dummy_input)
+        model_path, output = export_model_to_onnx(
+            model,
+            model_name=os.path.join(tmpdir, "model"),
+            dummy_input=dummy_input
+        )

         assert os.path.exists(model_path)
         # Inference
