From cfc1ee90e3181ac6bc61d25e7367662870260a86 Mon Sep 17 00:00:00 2001
From: John Pope
Date: Fri, 4 Oct 2024 19:19:37 +1000
Subject: [PATCH] ok

---
 .gitignore       |  1 +
 onnxconv.py      | 45 +++++++++++++++++++++++++++++++++++++++------
 requirements.txt |  5 ++++-
 3 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/.gitignore b/.gitignore
index 1d703f7..5622522 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,3 +11,4 @@ __pycache__/vit.cpython-311.pyc
 __pycache__/helper.cpython-311.pyc
 actions-runner/*
 imf_encoder.onnx
+imf_encoder_web.onnx
diff --git a/onnxconv.py b/onnxconv.py
index cf79a92..e72ddfe 100644
--- a/onnxconv.py
+++ b/onnxconv.py
@@ -4,7 +4,8 @@ import torch.nn as nn
 from PIL import Image
 from torchvision import transforms
 
-
+from onnxconverter_common import float16
+import onnx
 
 class IMFDecoder(nn.Module):
     def __init__(self, model):
@@ -63,7 +64,22 @@ def print_model_structure(model):
         if hasattr(module, 'bias') and module.bias is not None:
             print(f" Bias shape: {module.bias.shape}")
 
-# Adjusted export_to_onnx function
+
+def convert_int64_to_int32(model):
+    for tensor in model.graph.initializer:
+        if tensor.data_type == onnx.TensorProto.INT64:
+            tensor.data_type = onnx.TensorProto.INT32
+            tensor.int64_data = tensor.int64_data.astype(np.int32)
+
+    for node in model.graph.node:
+        for attr in node.attribute:
+            if attr.type == onnx.AttributeProto.INT:
+                attr.i = int(attr.i)
+            elif attr.type == onnx.AttributeProto.INTS:
+                attr.ints[:] = [int(i) for i in attr.ints]
+
+    return model
+
 def export_to_onnx(model, x_current, x_reference, file_name):
     try:
         print("Model structure before tracing:")
@@ -97,7 +113,7 @@ def export_to_onnx(model, x_current, x_reference, file_name):
             opset_version=11,
             do_constant_folding=True,
             input_names=['x_current', 'x_reference'],
-            output_names=['f_r', 't_r', 't_c'], # Adjusted output names
+            output_names=['f_r', 't_r', 't_c'],
             dynamic_axes={
                 'x_current': {0: 'batch_size'},
                 'x_reference': {0: 'batch_size'},
@@ -108,6 +124,23 @@ def export_to_onnx(model, x_current, x_reference, file_name):
             verbose=True
         )
         print(f"Model exported successfully to {file_name}")
+
+        # Load the ONNX model
+        onnx_model = onnx.load(file_name)
+
+        # Convert int64 to int32
+        print("Converting int64 to int32...")
+        onnx_model = convert_int64_to_int32(onnx_model)
+
+        # Optionally, convert float32 to float16 to reduce model size
+        print("Converting float32 to float16...")
+        onnx_model = float16.convert_float_to_float16(onnx_model)
+
+        # Save the converted model
+        web_compatible_file = file_name.replace('.onnx', '_web.onnx')
+        onnx.save(onnx_model, web_compatible_file)
+        print(f"Web-compatible model saved as {web_compatible_file}")
+
     except Exception as e:
         print(f"Error during ONNX export: {str(e)}")
         import traceback
@@ -128,9 +161,9 @@ def export_to_onnx(model, x_current, x_reference, file_name):
 # Load images and preprocess
 def load_image(image_path):
     transform = transforms.Compose([
-        transforms.Resize((256, 256)), # Adjust as per your model's requirements
+        transforms.Resize((256, 256)),
         transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], # Adjust as per your model's requirements
+        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
     ])
     image = Image.open(image_path).convert('RGB')
@@ -141,4 +174,4 @@ def load_image(image_path):
 x_reference = load_image("x_reference.png")
 
 # Export the model
-export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")
+export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 9026aca..ac68006 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -15,4 +15,7 @@ onnxruntime
 opencv-python
 pymatting
 decord
-mediapipe
\ No newline at end of file
+mediapipe
+
+onnx
+onnxconverter_common
\ No newline at end of file
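Review note (not part of the commit above): the new convert_int64_to_int32() helper calls tensor.int64_data.astype(np.int32), but int64_data is a protobuf repeated field rather than a NumPy array, and numpy is never imported in onnxconv.py, so that branch will raise as soon as an INT64 initializer is encountered; retagging data_type while leaving the values in the int64_data field also yields an inconsistent TensorProto. Below is a minimal sketch of one way to rewrite INT64 initializers through onnx.numpy_helper; the function name and the clamping to the int32 range are illustrative assumptions, not code from this patch.

    import numpy as np
    import onnx
    from onnx import numpy_helper

    def convert_int64_initializers_to_int32(model: onnx.ModelProto) -> onnx.ModelProto:
        """Rewrite every INT64 initializer as an INT32 tensor, clamped to the int32 range."""
        for i, tensor in enumerate(model.graph.initializer):
            if tensor.data_type == onnx.TensorProto.INT64:
                # Round-trip through NumPy so the values land in the correct storage field.
                values = numpy_helper.to_array(tensor)
                values = np.clip(values, np.iinfo(np.int32).min, np.iinfo(np.int32).max)
                model.graph.initializer[i].CopyFrom(
                    numpy_helper.from_array(values.astype(np.int32), tensor.name))
        return model

Downgrading initializers alone may still not produce a valid graph: some operators (Reshape, for example) require int64 shape inputs per the ONNX spec, so running onnx.checker.check_model() on the rewritten model before shipping the _web.onnx file seems worthwhile.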