Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PyTorch model reader #1213

Merged
merged 3 commits into from
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 19 additions & 14 deletions model_compression_toolkit/core/pytorch/reader/graph_builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,16 +80,19 @@ def _build_input_alloc_and_call_args(n: Node, input_tensors_in_node_kwargs: Dict
tensor_input_alloc = []
op_call_args = list(n.args)
if inputs_as_list:
op_call_args.pop(0)
# input tensors are a list in the first argument -> remove from op_call_args and go over
# the tensors in that list.
_args = op_call_args.pop(0)
elad-c marked this conversation as resolved.
Show resolved Hide resolved
else:
for in_node in n.all_input_nodes:
# The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
for i, arg in enumerate(n.args):
if arg == in_node:
tensor_input_alloc.append(i)
for k, arg in input_tensors_in_node_kwargs.items():
if arg == in_node:
tensor_input_alloc.append(k)
_args = n.args
for in_node in n.all_input_nodes:
# The extra for loop is used to tackle the case of the same input tensor for this node (e.g. torch.add(x, x)).
for i, arg in enumerate(_args):
if arg == in_node:
tensor_input_alloc.append(i)
for k, arg in input_tensors_in_node_kwargs.items():
if arg == in_node:
tensor_input_alloc.append(k)

return op_call_args, tensor_input_alloc

Expand Down Expand Up @@ -253,11 +256,8 @@ def nodes_builder(model: GraphModule,
node_kwargs[k] = v

# Check if node's first input argument is a list of input fx nodes, such as torch.cat:
is_first_input_list_of_nodes = is_instance_first_arg(node, (list, tuple)) and all(
inputs_as_list = is_instance_first_arg(node, (list, tuple)) and all(
[isinstance(n, Node) for n in node.args[0]])
is_placeholder_a_list = is_instance_first_arg(node, Node) and \
node.args[0].op == PLACEHOLDER and node.args[0].meta[TYPE] in (list, tuple)
inputs_as_list = is_first_input_list_of_nodes or is_placeholder_a_list

# Build tensor_input_alloc required for the model builder. All input nodes are received as a list in the builder,
# so tensor_input_alloc is used to allocate each input tensor in the correct place in the node's args & kwargs.
Expand Down Expand Up @@ -333,7 +333,12 @@ def edges_builder(model: GraphModule,
if input_node in fx_node_2_graph_node:
# n_edges_for_input_node is for the case that the input node appears more than
# once as the input of the node, for example add(x, x)
n_edges_for_input_node = sum([1 for a in node.args if input_node == a])
if node in fx_node_2_graph_node and isinstance(fx_node_2_graph_node[node], FunctionalNode) and \
fx_node_2_graph_node[node].inputs_as_list:
_args = node.args[0]
else:
_args = node.args
n_edges_for_input_node = sum([1 for a in _args if input_node == a])
n_edges_for_input_node = max(n_edges_for_input_node, 1)

dst_index = node.all_input_nodes.index(input_node)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import model_compression_toolkit as mct
from model_compression_toolkit.constants import PYTORCH
from model_compression_toolkit.target_platform_capabilities.constants import IMX500_TP_MODEL
from model_compression_toolkit.core.pytorch.constants import CPU
from tests.pytorch_tests.model_tests.base_pytorch_feature_test import BasePytorchFeatureNetworkTest


Expand Down Expand Up @@ -60,6 +61,12 @@ def create_networks(self):
return Activation16BitNet()

def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
x = torch.from_numpy(input_x[0].astype('float32'))
out_f = float_model(x)
quantized_model = quantized_model.to(CPU)
out_q = quantized_model(x.to(CPU))
self.unit_test.assertTrue(out_f.shape == out_q.shape, "Output shape mismatch.")

mul1_act_quant = quantized_model.mul_activation_holder_quantizer
mul2_act_quant = quantized_model.mul_1_activation_holder_quantizer
self.unit_test.assertTrue(mul1_act_quant.activation_holder_quantizer.num_bits == 16,
Expand Down
Loading