Commit ada18ab
fix: make sure structured pruning and unstructured pruning work well together
andrei-stoian-zama authored Sep 8, 2023
1 parent cafd8d1 commit ada18ab
Showing 2 changed files with 9 additions and 3 deletions.
src/concrete/ml/sklearn/qnn_module.py (4 additions, 1 deletion)
@@ -193,7 +193,9 @@ def make_pruning_permanent(self) -> None:
                 keep_idxs = numpy.setdiff1d(idx, neurons_removed_idx)

                 # Remove the pruning hooks on this layer
-                pruning.remove(layer, "weight")
+                if layer in self.pruned_layers:
+                    pruning.remove(layer, "weight")
+                    self.pruned_layers.remove(layer)
             else:
                 keep_idxs = numpy.arange(weights.shape[0])

@@ -283,6 +285,7 @@ def enable_pruning(self) -> None:
                 # Use L2-norm structured pruning, using the torch ln_structured
                 # function, with norm=2 and axis=0 (output/neuron axis)
                 pruning.ln_structured(layer, "weight", self.n_prune_neurons_percentage, 2, 0)
+                self.pruned_layers.add(layer)

             # Note this is counting only Linear layers
             layer_idx += 1
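
The guard introduced above matters because torch.nn.utils.prune.remove raises a ValueError when called on a module whose "weight" parameter has no pruning hooks attached. Below is a minimal standalone sketch of the pattern (a toy model, not Concrete ML's qnn_module code; the model shape and pruning amount are illustrative):

# Minimal standalone sketch of the guard pattern from this commit (toy model,
# not Concrete ML code). torch.nn.utils.prune.remove() raises a ValueError if
# the parameter was never pruned, so a set records which layers have hooks.
import torch.nn as nn
import torch.nn.utils.prune as pruning

model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
pruned_layers = set()

# Structured pruning pass: prune only the hidden Linear layer, and record it.
for layer in model.modules():
    if isinstance(layer, nn.Linear) and layer.out_features != 2:
        # L2-norm structured pruning on the output/neuron axis (dim=0)
        pruning.ln_structured(layer, "weight", amount=0.5, n=2, dim=0)
        pruned_layers.add(layer)

# Make pruning permanent: only layers that actually have hooks are touched,
# so the unpruned output layer no longer triggers a ValueError.
for layer in model.modules():
    if layer in pruned_layers:
        pruning.remove(layer, "weight")
        pruned_layers.remove(layer)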
tests/sklearn/test_qnn.py (5 additions, 2 deletions)
@@ -299,7 +299,10 @@ def test_failure_bad_data_types(model_classes, container, bad_types, expected_er

 @pytest.mark.parametrize("activation_function", [pytest.param(nn.ReLU)])
 @pytest.mark.parametrize("model_class", get_sklearn_neural_net_models())
-def test_structured_pruning(activation_function, model_class, load_data, default_configuration):
+@pytest.mark.parametrize("accum_bits", [5, 8])
+def test_structured_pruning(
+    activation_function, model_class, accum_bits, load_data, default_configuration
+):
     """Test whether the sklearn quantized NN wrappers compile to FHE and execute well on encrypted
     inputs"""
     n_features = 10
@@ -358,7 +361,7 @@ def test_structured_pruning(activation_function, model_class, load_data, default
         "module__n_layers": 2,
         "module__n_w_bits": 2,
         "module__n_a_bits": 2,
-        "module__n_accum_bits": 8,
+        "module__n_accum_bits": accum_bits,
         "module__activation_function": activation_function,
         "max_epochs": 2,
         "verbose": 0,
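
The test change parametrizes the existing structured-pruning test over two accumulator bit-widths instead of the hard-coded 8. A hypothetical, stripped-down analogue of the mechanism (the real test also trains the model and compiles it to FHE):

# Hypothetical stripped-down analogue of the new parametrization; the real
# test_structured_pruning also trains and FHE-compiles the model. pytest runs
# the test once per accum_bits value, and the value is forwarded to the
# skorch-style "module__n_accum_bits" parameter.
import pytest

@pytest.mark.parametrize("accum_bits", [5, 8])
def test_accum_bits_forwarding(accum_bits):
    params = {
        "module__n_w_bits": 2,
        "module__n_a_bits": 2,
        "module__n_accum_bits": accum_bits,
    }
    assert params["module__n_accum_bits"] in (5, 8)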
