diff --git a/src/concrete/ml/common/preprocessors.py b/src/concrete/ml/common/preprocessors.py index bb75bdadc..0b3434a7a 100644 --- a/src/concrete/ml/common/preprocessors.py +++ b/src/concrete/ml/common/preprocessors.py @@ -1571,12 +1571,12 @@ def add_extra_shape_to_subgraph(tlu_node, extra_dim: Tuple[int, ...]): node.evaluator.properties["constant"], node.evaluator.properties["constant"].shape + tuple([1 for _ in extra_dim]), ) + node.output.shape = node.evaluator.properties["constant"].shape continue node.inputs[0].shape = deepcopy(node.inputs[0].shape) node.inputs[0].shape = node.inputs[0].shape + extra_dim - # print(f"{node.output.shape=}") node.output.shape = deepcopy(node.output.shape) if node.output.shape[-1:] == extra_dim: # BIG HACK