Commit 54a159b
ok
johndpope committed Oct 4, 2024
1 parent cfc1ee9

Showing 1 changed file: onnxconv.py (55 additions, 22 deletions)
@@ -6,6 +6,7 @@
from torchvision import transforms
from onnxconverter_common import float16
import onnx
from onnx import numpy_helper
import numpy as np

class IMFDecoder(nn.Module):
    def __init__(self, model):
@@ -65,20 +66,58 @@ def print_model_structure(model):
print(f" Bias shape: {module.bias.shape}")


def convert_float64_to_float32(tensor):
    # Rebuild a DOUBLE initializer as FLOAT, keeping its name and shape.
    return onnx.helper.make_tensor(
        name=tensor.name,
        data_type=onnx.TensorProto.FLOAT,
        dims=tensor.dims,
        vals=numpy_helper.to_array(tensor).astype(np.float32).tobytes(),
        raw=True,
    )

def convert_int64_to_int32(tensor):
    # Rebuild an INT64 initializer as INT32. Values are assumed to fit in
    # int32; anything larger overflows silently under astype.
    return onnx.helper.make_tensor(
        name=tensor.name,
        data_type=onnx.TensorProto.INT32,
        dims=tensor.dims,
        vals=numpy_helper.to_array(tensor).astype(np.int32).tobytes(),
        raw=True,
    )
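
# A minimal sanity check for the converters above (illustrative only; the
# helper name is hypothetical, not part of the original file): round-trip an
# int64 tensor and confirm the copy comes back as int32 with the same values.
def _check_int32_roundtrip():
    t64 = numpy_helper.from_array(np.array([1, 2, 3], dtype=np.int64), name="idx")
    t32 = convert_int64_to_int32(t64)
    assert t32.data_type == onnx.TensorProto.INT32
    assert numpy_helper.to_array(t32).tolist() == [1, 2, 3]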

def convert_model_to_32bit(model, output_path):
    # Convert initializers. Iterate over a snapshot: removing elements from a
    # protobuf repeated field while iterating over it silently skips entries.
    for initializer in list(model.graph.initializer):
        if initializer.data_type == onnx.TensorProto.DOUBLE:
            new_initializer = convert_float64_to_float32(initializer)
            model.graph.initializer.remove(initializer)
            model.graph.initializer.extend([new_initializer])
        elif initializer.data_type == onnx.TensorProto.INT64:
            new_initializer = convert_int64_to_int32(initializer)
            model.graph.initializer.remove(initializer)
            model.graph.initializer.extend([new_initializer])

    # Convert graph inputs (renamed to avoid shadowing the `input` builtin)
    for graph_input in model.graph.input:
        if graph_input.type.tensor_type.elem_type == onnx.TensorProto.DOUBLE:
            graph_input.type.tensor_type.elem_type = onnx.TensorProto.FLOAT
        elif graph_input.type.tensor_type.elem_type == onnx.TensorProto.INT64:
            graph_input.type.tensor_type.elem_type = onnx.TensorProto.INT32

    # Convert graph outputs
    for graph_output in model.graph.output:
        if graph_output.type.tensor_type.elem_type == onnx.TensorProto.DOUBLE:
            graph_output.type.tensor_type.elem_type = onnx.TensorProto.FLOAT
        elif graph_output.type.tensor_type.elem_type == onnx.TensorProto.INT64:
            graph_output.type.tensor_type.elem_type = onnx.TensorProto.INT32

    # Save the converted model
    onnx.save(model, output_path)
    print(f"Converted model saved to {output_path}")

    return model
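
# Illustrative usage of convert_model_to_32bit (an assumed workflow, not code
# from the commit): convert a model on disk, then run the ONNX checker.
# Downcasting initializers alone can leave the graph inconsistent, since some
# ops are pinned to int64 by spec (Shape outputs int64; ConstantOfShape takes
# an int64 input), so validating the result is worthwhile.
def _convert_and_check(input_path, output_path):
    m = onnx.load(input_path)
    m = convert_model_to_32bit(m, output_path)
    onnx.checker.check_model(m)  # raises if the downcast broke the graph
    return m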

def export_to_onnx(model, x_current, x_reference, file_name):
    try:
@@ -128,18 +167,12 @@
        # Load the ONNX model
        onnx_model = onnx.load(file_name)

        # Convert 64-bit tensors to 32-bit for web compatibility
        print("Converting int64 to int32...")
        web_compatible_file = file_name.replace('.onnx', '_web.onnx')
        onnx_model = convert_model_to_32bit(onnx_model, web_compatible_file)

    except Exception as e:
        print(f"Error during ONNX export: {str(e)}")
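
# Hypothetical smoke test (the toy module, shapes, and file name are
# assumptions for illustration; `torch` is assumed to be imported in the
# file's collapsed header, as `nn` already is): export a tiny stand-in with
# the same two-input signature to exercise the pipeline end to end.
class _ToyModel(nn.Module):
    def forward(self, x_current, x_reference):
        return x_current + x_reference

if __name__ == "__main__":
    toy = _ToyModel().eval()
    x_current = torch.randn(1, 3, 256, 256)
    x_reference = torch.randn(1, 3, 256, 256)
    export_to_onnx(toy, x_current, x_reference, "toy.onnx")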
