Skip to content

Commit

Permalink
wip(fix types)
Browse files Browse the repository at this point in the history
  • Loading branch information
fd0r committed Aug 24, 2023
1 parent 82c30d5 commit 373e81b
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions src/concrete/ml/quantization/quantized_module_passes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Optimization passes for QuantizedModules."""
from typing import Dict, List, Optional, Tuple
from collections import defaultdict

import numpy

Expand Down Expand Up @@ -59,6 +60,8 @@ def process(self) -> PatternDict:
another Gemm/conv node. Optionally a Relu can be placed before this input
quantization node.
Nothing will be done if rounding is already specified.
Returns:
result (PatternDict): a dictionary containing, for each Conv/Gemm node for which
round PBS can be applied, the matched pattern, based on power-of-two scaling factors
Expand Down Expand Up @@ -92,7 +95,7 @@ def compute_op_predecessors(self) -> PredecessorsType:
"""

# Initialize the list of predecessors with tensors that are graph inputs
predecessors: PredecessorsType = {}
predecessors: PredecessorsType = defaultdict(list)

for (node_inputs, node_op) in self._qmodule.quant_layers_dict.values():
# The first input node contains the encrypted data
Expand All @@ -105,18 +108,15 @@ def compute_op_predecessors(self) -> PredecessorsType:
pred = self._qmodule.quant_layers_dict.get(enc_input_node, (None, None))
# Get the quantized op that produces the current op's input
pred_with_output = (pred[1], enc_input_node)
if node_op not in predecessors:
predecessors[node_op] = [pred_with_output]
else:
predecessors[node_op].append(pred_with_output)
predecessors[node_op].append(pred_with_output)
return predecessors

def match_path_pattern(
self,
predecessors: PredecessorsType,
nodes_in_path: List[QuantizedOp],
input_producer_of_path: Optional[QuantizedOp],
):
) -> bool:
"""Determine if a pattern has the structure that makes it viable for roundPBS.
Args:
Expand Down Expand Up @@ -182,6 +182,8 @@ def detect_patterns(self, predecessors: PredecessorsType) -> PatternDict:
"Power of Two adapter: Error during graph traversal",
)
# If multiple ops produced this node, the pattern is not matched

assert back_node is not None
if len(predecessors[back_node]) > 1:
break

Expand All @@ -192,6 +194,7 @@ def detect_patterns(self, predecessors: PredecessorsType) -> PatternDict:
# The Gemm/Conv op that produces this integer node is the one
# onto which we apply the roundPBS optimization
nodes_in_path.append(back_node)
assert back_node is not None
list_pred_of_path = predecessors.get(back_node, None)
if list_pred_of_path is not None and len(list_pred_of_path) == 1:
integer_node_input_quant = list_pred_of_path[0][0]
Expand Down Expand Up @@ -227,7 +230,9 @@ def integer_log2(value: float) -> Tuple[int, bool]:

invalid_paths: List[QuantizedMixingOp] = []
for path_start_node, (path, path_input_quant) in valid_paths.items():
# Placeholders
scale_input, scale_output, scale_weights = None, None, None
# Populate placeholders
for node in path:
if isinstance(node, self.SUPPORTED_ROUND_PBS_OPS):
# Get the scale of the input of the Gemm/Conv node
Expand All @@ -241,6 +246,7 @@ def integer_log2(value: float) -> Tuple[int, bool]:
# node that will apply roundPBS
scale_output = node.constant_inputs[1]

# Check placeholders
assert scale_input is not None, (
"Power of two adapter: Can not determine input scale of pattern",
)
Expand All @@ -251,10 +257,12 @@ def integer_log2(value: float) -> Tuple[int, bool]:
"Power of two adapter: Can not determine output scale of pattern",
)

# Check if power of two
log2_input, ok_input = integer_log2(scale_input)
log2_weights, ok_weights = integer_log2(scale_weights)
log2_output, ok_output = integer_log2(scale_output)

# Modify rounding
if ok_input and ok_weights and ok_output:
assert_true(
path_start_node.rounding_threshold_bits is None,
Expand Down

0 comments on commit 373e81b

Please sign in to comment.