diff --git a/src/concrete/ml/quantization/quantized_module_passes.py b/src/concrete/ml/quantization/quantized_module_passes.py index 14ca0a5f8..c96c0edf4 100644 --- a/src/concrete/ml/quantization/quantized_module_passes.py +++ b/src/concrete/ml/quantization/quantized_module_passes.py @@ -1,5 +1,6 @@ """Optimization passes for QuantizedModules.""" from typing import Dict, List, Optional, Tuple +from collections import defaultdict import numpy @@ -59,6 +60,8 @@ def process(self) -> PatternDict: another Gemm/conv node. Optionally a Relu can be placed before this input quantization node. + Nothing will be done if rounding is already specified. + Returns: result (PatternDict): a dictionary containing for each Conv/Gemm node for which round PBS can be applied based on power-of-two scaling factors @@ -92,7 +95,7 @@ def compute_op_predecessors(self) -> PredecessorsType: """ # Initialize the list of predecessors with tensors that are graph inputs - predecessors: PredecessorsType = {} + predecessors: PredecessorsType = defaultdict(list) for (node_inputs, node_op) in self._qmodule.quant_layers_dict.values(): # The first input node contains the encrypted data @@ -105,10 +108,7 @@ def compute_op_predecessors(self) -> PredecessorsType: pred = self._qmodule.quant_layers_dict.get(enc_input_node, (None, None)) # Get the quantized op that produces the current op's input pred_with_output = (pred[1], enc_input_node) - if node_op not in predecessors: - predecessors[node_op] = [pred_with_output] - else: - predecessors[node_op].append(pred_with_output) + predecessors[node_op].append(pred_with_output) return predecessors def match_path_pattern( @@ -116,7 +116,7 @@ def match_path_pattern( predecessors: PredecessorsType, nodes_in_path: List[QuantizedOp], input_producer_of_path: Optional[QuantizedOp], - ): + ) -> bool: """Determine if a pattern has the structure that makes it viable for roundPBS. Args: @@ -182,6 +182,8 @@ def detect_patterns(self, predecessors: PredecessorsType) -> PatternDict: "Power of Two adapter: Error during graph traversal", ) # If multiple ops produced this node, the pattern is not matched + + assert back_node is not None if len(predecessors[back_node]) > 1: break @@ -192,6 +194,7 @@ def detect_patterns(self, predecessors: PredecessorsType) -> PatternDict: # The Gemm/Conv op that produces this integer node is the one # onto which we apply the roundPBS optimization nodes_in_path.append(back_node) + assert back_node is not None list_pred_of_path = predecessors.get(back_node, None) if list_pred_of_path is not None and len(list_pred_of_path) == 1: integer_node_input_quant = list_pred_of_path[0][0] @@ -227,7 +230,9 @@ def integer_log2(value: float) -> Tuple[int, bool]: invalid_paths: List[QuantizedMixingOp] = [] for path_start_node, (path, path_input_quant) in valid_paths.items(): + # Placeholders scale_input, scale_output, scale_weights = None, None, None + # Populate placeholders for node in path: if isinstance(node, self.SUPPORTED_ROUND_PBS_OPS): # Get the scale of the input of the Gemm/Conv node @@ -241,6 +246,7 @@ def integer_log2(value: float) -> Tuple[int, bool]: # node that will apply roundPBS scale_output = node.constant_inputs[1] + # Check placeholders assert scale_input is not None, ( "Power of two adapter: Can not determine input scale of pattern", ) @@ -251,10 +257,12 @@ def integer_log2(value: float) -> Tuple[int, bool]: "Power of two adapter: Can not determine output scale of pattern", ) + # Check if power of two log2_input, ok_input = integer_log2(scale_input) log2_weights, ok_weights = integer_log2(scale_weights) log2_output, ok_output = integer_log2(scale_output) + # Modify rounding if ok_input and ok_weights and ok_output: assert_true( path_start_node.rounding_threshold_bits is None,