Skip to content

Commit

Permalink
ok
Browse files Browse the repository at this point in the history
  • Loading branch information
johndpope committed Oct 4, 2024
1 parent 7940ef1 commit cfc1ee9
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ __pycache__/vit.cpython-311.pyc
__pycache__/helper.cpython-311.pyc
actions-runner/*
imf_encoder.onnx
imf_encoder_web.onnx
45 changes: 39 additions & 6 deletions onnxconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import torch.nn as nn
from PIL import Image
from torchvision import transforms

from onnxconverter_common import float16
import onnx

class IMFDecoder(nn.Module):
def __init__(self, model):
Expand Down Expand Up @@ -63,7 +64,22 @@ def print_model_structure(model):
if hasattr(module, 'bias') and module.bias is not None:
print(f" Bias shape: {module.bias.shape}")

# Adjusted export_to_onnx function

def convert_int64_to_int32(model):
for tensor in model.graph.initializer:
if tensor.data_type == onnx.TensorProto.INT64:
tensor.data_type = onnx.TensorProto.INT32
tensor.int64_data = tensor.int64_data.astype(np.int32)

for node in model.graph.node:
for attr in node.attribute:
if attr.type == onnx.AttributeProto.INT:
attr.i = int(attr.i)
elif attr.type == onnx.AttributeProto.INTS:
attr.ints[:] = [int(i) for i in attr.ints]

return model

def export_to_onnx(model, x_current, x_reference, file_name):
try:
print("Model structure before tracing:")
Expand Down Expand Up @@ -97,7 +113,7 @@ def export_to_onnx(model, x_current, x_reference, file_name):
opset_version=11,
do_constant_folding=True,
input_names=['x_current', 'x_reference'],
output_names=['f_r', 't_r', 't_c'], # Adjusted output names
output_names=['f_r', 't_r', 't_c'],
dynamic_axes={
'x_current': {0: 'batch_size'},
'x_reference': {0: 'batch_size'},
Expand All @@ -108,6 +124,23 @@ def export_to_onnx(model, x_current, x_reference, file_name):
verbose=True
)
print(f"Model exported successfully to {file_name}")

# Load the ONNX model
onnx_model = onnx.load(file_name)

# Convert int64 to int32
print("Converting int64 to int32...")
onnx_model = convert_int64_to_int32(onnx_model)

# Optionally, convert float32 to float16 to reduce model size
print("Converting float32 to float16...")
onnx_model = float16.convert_float_to_float16(onnx_model)

# Save the converted model
web_compatible_file = file_name.replace('.onnx', '_web.onnx')
onnx.save(onnx_model, web_compatible_file)
print(f"Web-compatible model saved as {web_compatible_file}")

except Exception as e:
print(f"Error during ONNX export: {str(e)}")
import traceback
Expand All @@ -128,9 +161,9 @@ def export_to_onnx(model, x_current, x_reference, file_name):
# Load images and preprocess
def load_image(image_path):
transform = transforms.Compose([
transforms.Resize((256, 256)), # Adjust as per your model's requirements
transforms.Resize((256, 256)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], # Adjust as per your model's requirements
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
image = Image.open(image_path).convert('RGB')
Expand All @@ -141,4 +174,4 @@ def load_image(image_path):
x_reference = load_image("x_reference.png")

# Export the model
export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")
export_to_onnx(encoder_model, x_current, x_reference, "imf_encoder.onnx")
5 changes: 4 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,7 @@ onnxruntime
opencv-python
pymatting
decord
mediapipe
mediapipe

onnx
onnxconverter_common

0 comments on commit cfc1ee9

Please sign in to comment.