Describe the bug
The MACs calculation is broken for recurrent layers: with T time steps, the reported MACs are too low by a factor of T.
The code below reproduces the bug for GRU and LSTM layers. Looping over a linear layer gives the expected result (although the MAC numbers from the two profilers are not perfectly identical). For LSTM and GRU layers, thop (correctly) reports MACs roughly 100x higher for the torch model than onnx-tool reports for the exported onnx model.
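As a rough sanity check, the expected count can be worked out by hand. The sketch below assumes that only the weight-matrix multiplications are counted (biases and elementwise gate operations are ignored, so a profiler that includes them will report somewhat more). The helper name `recurrent_macs` is made up for illustration and is not part of thop or onnx-tool. An LSTM has 4 gates and a GRU has 3, and each gate multiplies an input-to-hidden and a hidden-to-hidden matrix at every time step.

```python
def recurrent_macs(num_gates, input_size, hidden_size, steps, batch_size=1):
    """Weight-only MAC estimate for one unidirectional recurrent layer.

    Hypothetical helper: ignores biases and elementwise gate math.
    """
    # Per step, each gate does one (hidden x input) and one (hidden x hidden) matmul.
    per_step = num_gates * hidden_size * (input_size + hidden_size)
    return batch_size * steps * per_step


INPUT_SIZE = HIDDEN_SIZE = 64
STEPS = 100

lstm_one_step = recurrent_macs(4, INPUT_SIZE, HIDDEN_SIZE, steps=1)
lstm_all_steps = recurrent_macs(4, INPUT_SIZE, HIDDEN_SIZE, steps=STEPS)
gru_one_step = recurrent_macs(3, INPUT_SIZE, HIDDEN_SIZE, steps=1)
gru_all_steps = recurrent_macs(3, INPUT_SIZE, HIDDEN_SIZE, steps=STEPS)

print(f"LSTM: {lstm_one_step} MACs per step, {lstm_all_steps} over {STEPS} steps")
print(f"GRU:  {gru_one_step} MACs per step, {gru_all_steps} over {STEPS} steps")
```

Under these assumptions the LSTM needs 32,768 MACs per step and 3,276,800 over 100 steps, and the GRU needs 24,576 per step and 2,457,600 over 100 steps. A profiler that reports a total close to the per-step figure is most likely not multiplying by the sequence length.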
To Reproduce
The code below creates a torch model with a single layer (linear, LSTM, or GRU) and measures MACs for the torch model using thop and for the exported onnx model using onnx-tool.
torch == 1.13.1
onnx-tool == 0.8.5
import torch
import torch.nn as nn
from thop import profile
import onnx_tool
from onnx_tool import create_ndarray_f32
import os
import pandas as pd

# Input size
BATCH_SIZE = 1
INPUT_SIZE = 64
HIDDEN_SIZE = INPUT_SIZE
STEPS = 100


# Define a simple Seq2Seq model
class Seq2SeqModel(nn.Module):
    def __init__(self, layer_type: str):
        super().__init__()
        self.layer_type = layer_type
        if self.layer_type == "linear":
            self.layer = nn.Linear(INPUT_SIZE, HIDDEN_SIZE)
        elif self.layer_type == "lstm":
            self.layer = nn.LSTM(INPUT_SIZE, HIDDEN_SIZE, batch_first=True)
        elif self.layer_type == "gru":
            self.layer = nn.GRU(INPUT_SIZE, HIDDEN_SIZE, batch_first=True)
        else:
            raise ValueError("Invalid layer type")

    def forward(self, x):
        if self.layer_type == "linear":
            out = []
            for s in range(STEPS):
                y = self.layer(x[:, s, :])
                out.append(y)
            out = torch.stack(out, dim=1)
        else:
            out, h = self.layer(x)
        return out


# Test each layer and print results
layer_types = ["linear", "lstm", "gru"]
for layer_type in layer_types:
    model = Seq2SeqModel(layer_type=layer_type)

    # Sample input
    sample_input = torch.randn(BATCH_SIZE, STEPS, INPUT_SIZE)

    # Measure MACs and FLOPs with torch
    macs, params = profile(model, inputs=(sample_input,))

    # Print results
    print("\n")
    print("=" * 30)
    print(f"Layer Type: {layer_type}")
    print("-" * 30)
    print("torch profiler results:")
    print("-" * 30)
    # print(f"Input shape: {sample_input.shape}")
    # print(f"Output shape: {y.shape}")
    print(f"Params: {int(params)}")
    print(f"MACs: {int(macs)}")
    print("-" * 30)

    # export to onnx
    tmpfile = "tmp.onnx"
    with torch.no_grad():
        torch_out = torch.onnx.export(model, sample_input, tmpfile, opset_version=12)

    # profile onnx model with onnx-tool and save results to csv
    onnx_tool.model_profile(
        tmpfile,
        saveshapesmodel=f"{layer_type}.onnx",
        dynamic_shapes={"input_1": create_ndarray_f32(sample_input.shape)},
        savenode=f"{layer_type}_profile.csv",
    )
    os.remove(tmpfile)

    # Print results from onnx-tool profiler
    df = pd.read_csv(f"{layer_type}_profile.csv")
    print("-" * 30)
    print("onnx-tool profiler results:")
    print("-" * 30)
    print(f"Params: {df.iloc[-1, :]['Params']}")
    print(f"MACs: {df.iloc[-1, :]['Forward_MACs']}")
    print("=" * 30)
    print("\n")
Running the code above produces:
Layer Type: linear
torch thop profiler results:
Params: 4160
MACs: 409600
onnx-tool profiler results:
Params: 4160
MACs: 416000
Layer Type: lstm
torch thop profiler results:
Params: 33280
MACs: 3379200
onnx-tool profiler results:
Params: 33280
MACs: 33280
Layer Type: gru
torch thop profiler results:
Params: 24960
MACs: 2540800
onnx-tool profiler results:
Params: 24960
MACs: 24960
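To make the discrepancy explicit, the reported totals can be compared directly. This is only a small sketch over the numbers listed above; the dictionary layout and variable names are chosen here for illustration.

```python
# MAC totals copied from the results above (STEPS = 100).
reported = {
    "linear": {"thop": 409600, "onnx_tool": 416000},
    "lstm": {"thop": 3379200, "onnx_tool": 33280},
    "gru": {"thop": 2540800, "onnx_tool": 24960},
}

# Ratio of thop's total to onnx-tool's total for each layer type.
ratios = {name: r["thop"] / r["onnx_tool"] for name, r in reported.items()}

for name, ratio in ratios.items():
    print(f"{name}: thop / onnx-tool = {ratio:.2f}")
```

The linear case comes out close to 1, while the LSTM and GRU ratios are both a little above 100, which matches the number of time steps. For both recurrent layers the onnx-tool MAC total also equals the parameter count, which is what a single time step would give.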