Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/replace rounding by truncate [ON HOLD - NOT URGENT] #472

Closed
wants to merge 75 commits into from
Closed
Show file tree
Hide file tree
Changes from 74 commits
Commits
Show all changes
75 commits
Select commit Hold shift + click to select a range
a3649b6
chore: Version 1
kcelia Oct 25, 2023
e12f13a
chore: rebase on main
kcelia Dec 11, 2023
d88d603
chore: add new rounded operators
kcelia Nov 29, 2023
e630ab5
chore: add rounded operators to ONNX_OPS_TO_NUMPY_IMPL_BOOL and remov…
kcelia Nov 29, 2023
3ea13c8
chore: update convert.py
kcelia Nov 29, 2023
de51201
chore: add conversion to our rounded ops in onnx graph for trees
jfrery Nov 29, 2023
d79127f
chore: update conversion to rounded ops to use list of lsbs
jfrery Nov 29, 2023
dddead6
chore: fix lsb to remove according to op comparison
kcelia Nov 29, 2023
6a33e62
chore: reconcial with Jordan's version
kcelia Nov 29, 2023
9bf30ce
chore: update
kcelia Nov 30, 2023
c0107a6
chore: fix dump tests
kcelia Nov 30, 2023
6c9a475
chore: rename variable
kcelia Nov 30, 2023
3d858fe
chore: restore previous version
kcelia Dec 1, 2023
aac906f
chore: remove replace_operator_with_rounded_version function
kcelia Dec 1, 2023
992f230
chore: update
kcelia Dec 4, 2023
0243dc7
chore: add comments
kcelia Dec 4, 2023
dc08fb0
chore: fix divergence in graphs after serialization
kcelia Dec 4, 2023
78f3185
chore: update
kcelia Dec 5, 2023
4f9859b
chore: coverage
kcelia Dec 5, 2023
251e91d
chore: update lsbs 2
kcelia Dec 5, 2023
4d68c95
chore: update
kcelia Dec 6, 2023
d186ec8
chore: update
kcelia Dec 6, 2023
c1fcf5f
chore: update with new version
kcelia Dec 6, 2023
012ddb6
chore: update lsb computation
kcelia Dec 7, 2023
399ccc7
chore: fix assert c_r1 < c_r0 + 2
kcelia Dec 8, 2023
6956996
chore: update serialized test
kcelia Dec 8, 2023
a9b2b06
chore: version for Jordan
kcelia Dec 8, 2023
ffa4577
chore: version with MSB=1
kcelia Dec 8, 2023
9031531
chore: update
kcelia Dec 11, 2023
71b56b2
chore: remove rounding in serialization
kcelia Dec 11, 2023
ae51470
chore: update v3
kcelia Dec 11, 2023
0ad84a6
chore: update v4
kcelia Dec 11, 2023
586d7db
chore: remove the assert that checks that :
kcelia Dec 13, 2023
094df3e
chore: first test with truncate
kcelia Dec 13, 2023
51dd5f6
chore: first test with truncate
kcelia Dec 15, 2023
e284bd2
chore: run ensemble model aggregation in FHE
jfrery Dec 18, 2023
8c1e99a
chore: refresh notebooks
jfrery Dec 18, 2023
70f1775
chore: update celia
kcelia Dec 25, 2023
6262f12
chore: add op_input and op_leaves
kcelia Jan 11, 2024
67e3572
chore: restore non fhe computation
kcelia Jan 11, 2024
8e2e056
chore: update dump test
kcelia Jan 11, 2024
3e2289a
chore: update test dump
kcelia Jan 11, 2024
a9c8385
chore: fix pipeline test
kcelia Jan 12, 2024
a43f04c
chore: fix rounding test by decreasing the n_bits value because no cr…
kcelia Jan 12, 2024
ecf5c66
chore: reduce n_bits in simulation test to 4 bits otherwise OOM
kcelia Jan 12, 2024
9007362
chore: add a test for fhe sum
kcelia Jan 12, 2024
8fab929
chore: update
kcelia Jan 15, 2024
ba26a5c
chore: update
kcelia Jan 15, 2024
05256b2
chore: update
kcelia Jan 16, 2024
03d498b
chore: remove useless prints
kcelia Jan 16, 2024
cf879d3
chore: update get_n_bits_dict_trees
kcelia Jan 17, 2024
ff1c6b1
chore: update
kcelia Jan 17, 2024
d4ca140
chore: update comment
kcelia Jan 17, 2024
f96333b
chore: update simulated p_error test
kcelia Jan 17, 2024
cc4781f
chore: update coverage
kcelia Jan 17, 2024
a652001
chore: update tests
kcelia Jan 18, 2024
9839ce9
chore: update assert
kcelia Jan 18, 2024
7cb13e0
chore: update comment
kcelia Jan 22, 2024
7d93575
chore: update comment
kcelia Jan 22, 2024
70adfd5
chore: test dump in both cases (sum_fhe enabled and disabled)
kcelia Jan 22, 2024
783e7af
chore: remove env var
kcelia Jan 22, 2024
39b2972
chore: restore knn notebook
kcelia Jan 23, 2024
07b2f2a
chore: restore exp notebotebook
kcelia Jan 23, 2024
7fddece
chore: update v1
kcelia Jan 23, 2024
14d9dc0
chore: update v2
kcelia Jan 23, 2024
3248216
chore: update v3
kcelia Jan 23, 2024
ab45587
chore: update
kcelia Jan 23, 2024
0ad0f6d
chore: update comments
kcelia Jan 24, 2024
9b58948
chore: update
kcelia Jan 24, 2024
25bf839
chore: fix test dump
kcelia Jan 24, 2024
5be7255
chore: update comments
kcelia Jan 24, 2024
ee696ec
chore: remove comment
kcelia Jan 25, 2024
92e3fb9
chore: rebase
kcelia Jan 29, 2024
4e2fcda
chore: update
kcelia Jan 29, 2024
b17cfd8
chore: merge with public main
kcelia Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/concrete/ml/onnx/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import tempfile
import warnings
from pathlib import Path
from typing import Callable, Optional, Tuple, Union
from typing import Callable, Tuple, Union

import numpy
import onnx
Expand Down Expand Up @@ -255,7 +255,7 @@ def get_equivalent_numpy_forward_from_onnx(
def get_equivalent_numpy_forward_from_onnx_tree(
onnx_model: onnx.ModelProto,
check_model: bool = True,
lsbs_to_remove_for_trees: Optional[Tuple[int, int]] = None,
auto_truncate=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure about the type of this, should it be a bool?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't take care of the typing in this PR (it's still in draft) and as I explained it's on hold.

auto_truncate is the concrete.fhe.auto_truncate object and not a bool, it can be None, if the feature is disabled.

In our test, we check that the feature is stable, i.e. we get the same result with and without truncate.
(in case the concrete team makes an update that changes the behavior of truncate).

) -> Tuple[Callable[..., Tuple[numpy.ndarray, ...]], onnx.ModelProto]:
"""Get the numpy equivalent forward of the provided ONNX model for tree-based models only.

Expand All @@ -264,7 +264,7 @@ def get_equivalent_numpy_forward_from_onnx_tree(
forward.
check_model (bool): set to True to run the onnx checker on the model.
Defaults to True.
lsbs_to_remove_for_trees (Optional[Tuple[int, int]]): This parameter is exclusively used for
auto_truncate (TODO): This parameter is exclusively used for
optimizing tree-based models. It contains the values of the least significant bits to
remove during the tree traversal, where the first value refers to the first comparison
(either "less" or "less_or_equal"), while the second value refers to the "Equal"
Expand All @@ -280,6 +280,6 @@ def get_equivalent_numpy_forward_from_onnx_tree(
# Return lambda of numpy equivalent of onnx execution
return (
lambda *args: execute_onnx_with_numpy_trees(
equivalent_onnx_model.graph, lsbs_to_remove_for_trees, *args
equivalent_onnx_model.graph, auto_truncate, *args
)
), equivalent_onnx_model
23 changes: 8 additions & 15 deletions src/concrete/ml/onnx/onnx_impl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy
from concrete.fhe import conv as fhe_conv
from concrete.fhe import ones as fhe_ones
from concrete.fhe import round_bit_pattern
from concrete.fhe import truncate_bit_pattern
from concrete.fhe.tracing import Tracer

from ..common.debugging import assert_true
Expand Down Expand Up @@ -238,36 +238,29 @@ def onnx_avgpool_compute_norm_const(
# - Adjust the typing
# FIXME: https://github.com/zama-ai/concrete-ml-internal/issues/4143
def rounded_comparison(
x: numpy.ndarray, y: numpy.ndarray, lsbs_to_remove: int, operation: ComparisonOperationType
x: numpy.ndarray, y: numpy.ndarray, auto_truncate, operation: ComparisonOperationType
) -> Tuple[bool]:
"""Comparison operation using `round_bit_pattern` function.
"""Comparison operation using `truncate_bit_pattern` function.

`round_bit_pattern` rounds the bit pattern of an integer to the closer
`truncate_bit_pattern` rounds the bit pattern of an integer to the closer
It also checks for any potential overflow. If so, it readjusts the LSBs accordingly.

The parameter `lsbs_to_remove` in `round_bit_pattern` can either be an integer specifying the
The parameter `lsbs_to_remove` in `truncate_bit_pattern` can either be an integer specifying the
number of LSBS to remove, or an `AutoRounder` object that determines the required number of LSBs
based on the specified number of MSBs to retain. But in our case, we choose to compute the LSBs
manually.

Args:
x (numpy.ndarray): Input tensor
y (numpy.ndarray): Input tensor
lsbs_to_remove (int): Number of the least significant bits to remove
operation (ComparisonOperationType): Comparison operation, which can `<`, `<=` and `==`
auto_truncate: TODO
operation: TODO

Returns:
Tuple[bool]: If x and y satisfy the comparison operator.
"""

assert isinstance(lsbs_to_remove, int)

# Workaround: in this context, `round_bit_pattern` is used as a truncate operation.
# Consequently, we subtract a term, called `half` that will subsequently be re-added during the
# `round_bit_pattern` process.
half = 1 << (lsbs_to_remove - 1)

# To determine if 'x' 'operation' 'y' (operation being <, >, >=, <=), we evaluate 'x - y'
rounded_subtraction = round_bit_pattern((x - y) - half, lsbs_to_remove=lsbs_to_remove)
rounded_subtraction = truncate_bit_pattern(x - y, lsbs_to_remove=auto_truncate)

return (operation(rounded_subtraction),)
19 changes: 10 additions & 9 deletions src/concrete/ml/onnx/onnx_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@

# Original file:
# https://github.com/google/jax/blob/f6d329b2d9b5f83c6a59e5739aa1ca8d4d1ffa1c/examples/onnx2xla.py
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Any, Callable, Dict, Tuple

import numpy
import onnx
Expand Down Expand Up @@ -413,7 +413,7 @@
}
# All numpy operators used for tree-based models that support auto rounding
ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL = {
"Less": rounded_numpy_less_for_trees,
"Less": rounded_numpy_less_for_trees, # type: ignore[dict-item]
"Equal": rounded_numpy_equal_for_trees,
"LessOrEqual": rounded_numpy_less_or_equal_for_trees,
}
Expand Down Expand Up @@ -483,14 +483,14 @@ def execute_onnx_with_numpy(

def execute_onnx_with_numpy_trees(
graph: onnx.GraphProto,
lsbs_to_remove_for_trees: Optional[Tuple[int, int]],
auto_truncate,
*inputs: numpy.ndarray,
) -> Tuple[numpy.ndarray, ...]:
"""Execute the provided ONNX graph on the given inputs for tree-based models only.

Args:
graph (onnx.GraphProto): The ONNX graph to execute.
lsbs_to_remove_for_trees (Optional[Tuple[int, int]]): This parameter is exclusively used for
auto_truncate: This parameter is exclusively used for
optimizing tree-based models. It contains the values of the least significant bits to
remove during the tree traversal, where the first value refers to the first comparison
(either "less" or "less_or_equal"), while the second value refers to the "Equal"
Expand All @@ -505,7 +505,7 @@ def execute_onnx_with_numpy_trees(
op_type: Callable[..., Tuple[numpy.ndarray[Any, Any], ...]]

# If no tree-based optimization is specified, return standard execution
if lsbs_to_remove_for_trees is None:
if auto_truncate is None:
return execute_onnx_with_numpy(graph, *inputs)

node_results: Dict[str, numpy.ndarray] = dict(
Expand All @@ -521,11 +521,12 @@ def execute_onnx_with_numpy_trees(
attributes = {attribute.name: get_attribute(attribute) for attribute in node.attribute}

if node.op_type in ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL:
attributes["auto_truncate"] = auto_truncate

# The first LSB refers to `Less` or `LessOrEqual` comparisons
# The second LSB refers to `Equal` comparison
stage = 0 if node.op_type != "Equal" else 1
attributes["lsbs_to_remove_for_trees"] = lsbs_to_remove_for_trees[stage]
# # The first LSB refers to `Less` or `LessOrEqual` comparisons
# # The second LSB refers to `Equal` comparison
# stage = 0 if node.op_type != "Equal" else 1
# attributes["lsbs_to_remove_for_trees"] = lsbs_to_remove_for_trees[stage]

# Use rounded numpy operation to relevant comparison nodes
op_type = ONNX_COMPARISON_OPS_TO_ROUNDED_TREES_NUMPY_IMPL_BOOL[node.op_type]
Expand Down
31 changes: 16 additions & 15 deletions src/concrete/ml/onnx/ops_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,7 +892,7 @@ def rounded_numpy_equal_for_trees(
x: numpy.ndarray,
y: numpy.ndarray,
*,
lsbs_to_remove_for_trees: Optional[int] = None,
auto_truncate=None,
) -> Tuple[numpy.ndarray]:
"""Compute rounded equal in numpy according to ONNX spec for tree-based models only.

Expand All @@ -901,7 +901,7 @@ def rounded_numpy_equal_for_trees(
Args:
x (numpy.ndarray): Input tensor
y (numpy.ndarray): Input tensor
lsbs_to_remove_for_trees (Optional[int]): Number of the least significant bits to remove
auto_truncate: Number of the least significant bits to remove
for tree-based models only.

Returns:
Expand All @@ -916,9 +916,9 @@ def rounded_numpy_equal_for_trees(

# Option 2 is selected because it adheres to the established pattern in `rounded_comparison`
# which does: (a - b) - half.
if lsbs_to_remove_for_trees is not None and lsbs_to_remove_for_trees > 0:
if auto_truncate is not None:
return rounded_comparison(
y, x, lsbs_to_remove_for_trees, operation=lambda x: x >= 0
y, x, auto_truncate, operation=lambda x: x >= 0
) # pragma: no cover

# Else, default numpy_equal operator
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def rounded_numpy_less_for_trees(
x: numpy.ndarray,
y: numpy.ndarray,
*,
lsbs_to_remove_for_trees: Optional[int] = None,
auto_truncate,
) -> Tuple[numpy.ndarray]:
"""Compute rounded less in numpy according to ONNX spec for tree-based models only.

Expand All @@ -1085,7 +1085,7 @@ def rounded_numpy_less_for_trees(
Args:
x (numpy.ndarray): Input tensor
y (numpy.ndarray): Input tensor
lsbs_to_remove_for_trees (Optional[int]): Number of the least significant bits to remove
auto_truncate: Number of the least significant bits to remove
for tree-based models only.

Returns:
Expand All @@ -1094,8 +1094,8 @@ def rounded_numpy_less_for_trees(

# numpy.less(x, y) is equivalent to :
# x - y <= 0 => round_bit_pattern(x - y - half) < 0
if lsbs_to_remove_for_trees is not None and lsbs_to_remove_for_trees > 0:
return rounded_comparison(x, y, lsbs_to_remove_for_trees, operation=lambda x: x < 0)
if auto_truncate is not None:
return rounded_comparison(x, y, auto_truncate, operation=lambda x: x < 0)

# Else, default numpy_less operator
return numpy_less(x, y)
Expand Down Expand Up @@ -1143,7 +1143,7 @@ def rounded_numpy_less_or_equal_for_trees(
x: numpy.ndarray,
y: numpy.ndarray,
*,
lsbs_to_remove_for_trees: Optional[int] = None,
auto_truncate=None,
) -> Tuple[numpy.ndarray]:
"""Compute rounded less or equal in numpy according to ONNX spec for tree-based models only.

Expand All @@ -1152,21 +1152,22 @@ def rounded_numpy_less_or_equal_for_trees(
Args:
x (numpy.ndarray): Input tensor
y (numpy.ndarray): Input tensor
lsbs_to_remove_for_trees (Optional[int]): Number of the least significant bits to remove
auto_truncate: Number of the least significant bits to remove
for tree-based models only.

Returns:
Tuple[numpy.ndarray]: Output tensor
"""

# numpy.less_equal(x, y) <= y is equivalent to :
# option 1: x - y <= 0 => round_bit_pattern(x - y + half) <= 0 or
# option 2: y - x >= 0 => round_bit_pattern(y - x - half) >= 0
# option 1: x - y <= 0 => truncate_bit_pattern(x - y + half) <= 0 or
# option 2: y - x >= 0 => truncate_bit_pattern(y - x - half) >= 0

# Option 2 is selected because it adheres to the established pattern in `rounded_comparison`
# which does: (a - b) - half.
if lsbs_to_remove_for_trees is not None and lsbs_to_remove_for_trees > 0:
return rounded_comparison(y, x, lsbs_to_remove_for_trees, operation=lambda x: x >= 0)
# which does: (a - b).

if auto_truncate is not None:
return rounded_comparison(y, x, auto_truncate, operation=lambda x: x >= 0)

# Else, default numpy_less_or_equal operator
return numpy_less_or_equal(x, y)
Expand Down
8 changes: 7 additions & 1 deletion src/concrete/ml/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Modules for quantization."""
from .base_quantized_op import QuantizedOp
from .post_training import PostTrainingAffineQuantization, PostTrainingQATImporter, get_n_bits_dict
from .post_training import (
PostTrainingAffineQuantization,
PostTrainingQATImporter,
_get_n_bits_dict_trees,
_inspect_tree_n_bits,
get_n_bits_dict,
)
from .quantized_module import QuantizedModule
from .quantized_ops import (
QuantizedAbs,
Expand Down
93 changes: 93 additions & 0 deletions src/concrete/ml/quantization/post_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,99 @@
from .quantized_ops import QuantizedBrevitasQuant
from .quantizers import QuantizationOptions, QuantizedArray, UniformQuantizer

# pylint: disable=too-many-lines


def _inspect_tree_n_bits(n_bits):
"""Validate the 'n_bits' parameter for tree-based models.

This function checks whether 'n_bits' is a valid integer or dictionary.
- If 'n_bits' is an integer, it must be a non-null positive, its value is assigned to
'op_inputs' and 'op_leaves' bits
- If it is a dictionary, it should contain integer values for keys 'op_leaves' and 'op_inputs',
where 'op_leaves' should not exceed 'op_inputs'.

The function raises a ValueError with a descriptive message if 'n_bits' does not meet
these criteria.

Args:
n_bits (int, Dict[str, int]): number of bits for quantization, can be a single value or
a dictionary with the following keys :
- "op_inputs" (mandatory): number of bits to quantize the input values
- "op_leaves" (optional): number of bits to quantize the leaves, must be less than or
equal to 'op_inputs. defaults to the value of 'op_inputs if not specified.

Raises:
ValueError: If 'n_bits' does not conform to the required format or value constraints.
"""

detailed_message = (
"Invalid 'n_bits', either pass a strictly positive integer or a dictionary containing "
"integer values for the following keys:\n"
"- 'op_inputs' (mandatory): number of bits to quantize the input values\n"
"- 'op_leaves' (optional): number of bits to quantize the leaves, must be less than or "
"equal to 'op_inputs'. Defaults to the value of 'op_inputs' if not specified."
"When using a single integer for n_bits, its value is assigned to 'op_inputs' and "
"'op_leaves' bits.\n"
)

error_message = ""

if isinstance(n_bits, int):
if n_bits <= 0:
error_message = "n_bits must be a strictly positive integer"
elif isinstance(n_bits, dict):
if "op_inputs" not in n_bits.keys():
error_message = "Invalid keys in `n_bits` dictionary. The key 'op_inputs' is mandatory"
elif set(n_bits.keys()) - {"op_leaves", "op_inputs"}:
error_message = (
"Invalid keys in 'n_bits' dictionary. Only 'op_inputs' (mandatory) and 'op_leaves' "
"(optional) are allowed"
)
elif not all(isinstance(value, int) and value > 0 for value in n_bits.values()):
error_message = "All values in 'n_bits' dictionary must be strictly positive integers"

elif n_bits.get("op_leaves", 0) > n_bits.get("op_inputs", 0):
error_message = "'op_leaves' must be less than or equal to 'op_inputs'"
else:
error_message = "n_bits must be either an integer or a dictionary"

if len(error_message) > 0:
raise ValueError(
f"{error_message}. Got '{type(n_bits)}' and '{n_bits}' value.\n{detailed_message}"
)


def _get_n_bits_dict_trees(n_bits: Union[int, Dict[str, int]]) -> Dict[str, int]:
"""Convert the n_bits parameter into a proper dictionary for tree based-models.

Args:
n_bits (int, Dict[str, int]): number of bits for quantization, can be a single value or
a dictionary with the following keys :
- "op_inputs" (mandatory): number of bits to quantize the input values
- "op_leaves" (optional): number of bits to quantize the leaves, must be less than or
equal to 'op_inputs'. defaults to the value of "op_inputs" if not specified.

When using a single integer for n_bits, its value is assigned to "op_inputs" and
"op_leaves" bits.

Returns:
n_bits_dict (Dict[str, int]): A dictionary properly representing the number of bits to use
for quantization.
"""

_inspect_tree_n_bits(n_bits)

# If a single integer is passed, we use a default value for the model's input and leaves
if isinstance(n_bits, int):
return {"op_inputs": n_bits, "op_leaves": n_bits}

# Default 'op_leaves' to 'op_inputs' if not specified
if "op_leaves" not in n_bits:
n_bits["op_leaves"] = n_bits["op_inputs"]

return n_bits


def get_n_bits_dict(n_bits: Union[int, Dict[str, int]]) -> Dict[str, int]:
"""Convert the n_bits parameter into a proper dictionary.
Expand Down
Loading
Loading