feat: add power of two scaling adapter for roundPBS #118
Conversation
@@ -88,10 +93,11 @@ def __init__(
            quant_name = f"quant{idx}"
            quantizer = qnn.QuantIdentity(
-               bit_width=n_a_bits,
+               bit_width=8 if idx == 0 else n_a_bits,
Inputs are quantized to 8 bits. This should only be done by default when power-of-two scaling is used (roundPBS helps with this).
Thanks for the explanation; maybe add a comment in the code to explain this!
use_case_examples/llm/utils.py
Outdated
@@ -32,7 +32,7 @@ def max_fhe_relu(q_x, axis=-1, keepdims=True):
    if keepdims:
        shape = list(result.shape)
        shape.insert(axis, 1)
-       result = result.reshape(shape)
+       result = result.reshape(tuple(shape))
Concrete Python changed the semantics: reshape only accepts tuples now.
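For reference, a minimal numpy-only sketch of the pattern in the diff (plain numpy accepts a list here; the tuple() conversion is what the Concrete Python tracer needs, per the comment above):

```python
import numpy

x = numpy.arange(6)
shape = list(x.shape)  # e.g. [6]
shape.insert(0, 1)     # keepdims-style axis re-insertion -> [1, 6]

# numpy itself accepts a list as the new shape, but (per the comment above)
# the Concrete Python tracer only accepts tuples, hence tuple(shape).
x = x.reshape(tuple(shape))
assert x.shape == (1, 6)
```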
Force-pushed from 379a5f8 to ff87d8b
Force-pushed from ff87d8b to 226b37c
@@ -695,6 +697,7 @@ def __init__(self, n_classes, n_bits, n_active, signed, narrow) -> None:
        n_active (int): number of active (non-zero weight) neurons to keep
        signed (bool): whether quantized integer values are signed
        narrow (bool): whether the range of quantized integer values is narrow/symmetric
+       power_of_two_scaling (bool): whether to use power-of-two scaling quantizers
maybe add an additional sentence explaining what "power-of-two scaling" is / can be used for ?
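For intuition, a hedged numpy-only sketch (not the Concrete ML implementation) of why power-of-two scales matter: when the ratio of consecutive quantization scales is an exact power of two, requantization reduces to dropping low-order bits, which roundPBS can do while keeping the PBS precision low.

```python
import numpy

# Illustrative scales of two consecutive quantizers (not real learned values).
scale_in = 2.0 ** -6
scale_out = 2.0 ** -3

ratio = scale_out / scale_in            # 8.0 == 2**3, an exact power of two
lsbs_to_round = int(numpy.log2(ratio))  # 3 low bits are redundant

q_int = numpy.array([100, -57, 33])
# Requantizing by 2**3 is just rounding off 3 LSBs of the integer values.
q_requant = numpy.round(q_int / 2 ** lsbs_to_round).astype(numpy.int64)
print(q_requant)  # [12 -7  4]
```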
-       curr_inputs = {
-           input_name: node_results.get(input_name, None) for input_name in node.input
-       }
+       curr_inputs = [
I think this PR could be a good opportunity to add additional comments for the following section, as it's an important part of the code that still remains somewhat obscure to whoever stumbles on it! There are already some comments, but I believe some steps are missing and others could be a bit more detailed.
If you could pinpoint what is not clear, it would help. I don't really know which parts to explain.
yes sure !
@@ -604,6 +607,9 @@ def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
            onnx_model=self.numpy_model.onnx_model,
        )

+       adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
+       adapter.process()
what is "process" ?
ok just saw the method below, but maybe add a comment here to briefly say what we do here ?
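A possible wording for the suggested comment, based on what this thread says the adapter does (not taken from the PR itself):

```python
# Detect layer patterns (e.g. Gemm/Conv -> ReLU -> quantize) whose quantization
# scales are powers of two and replace their requantization with a rounded PBS,
# so that the redundant low accumulator bits are dropped cheaply.
adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
adapter.process()
```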
@@ -0,0 +1,25 @@
+"""Custom Quantiation Aware Training Brevitas quantizers."""
typo "Quantiation" -> "Quantization"
        the input value was an integer power of two
    """
    log2_value = int(numpy.rint(numpy.log2(value)))
    if numpy.isclose(numpy.power(2.0, log2_value), value, atol=0.01):
Is it not enough to just check that numpy.rint(numpy.log2(value)) == numpy.log2(value)? Or am I missing something?
If not, then maybe make atol a parameter for integer_log2 (or at least say why it was set to 0.01).
The best practice is to never compare floats using equality. But you're right about the 0.01: it's quite arbitrary and not a good value for low powers of two (it's too big when the power is, say, -7). I'll use rtol instead.
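A minimal sketch of what the rtol-based check could look like (the signature and the default tolerance are assumptions, not the merged code):

```python
from typing import Optional

import numpy

def integer_log2(value: float, rtol: float = 1e-5) -> Optional[int]:
    """Return round(log2(value)) if value is an integer power of two, else None.

    A relative tolerance scales with the magnitude of the value, so it stays
    meaningful for small scales such as 2**-7, unlike a fixed atol of 0.01.
    """
    log2_value = int(numpy.rint(numpy.log2(value)))
    if numpy.isclose(numpy.power(2.0, log2_value), value, rtol=rtol):
        return log2_value
    return None

assert integer_log2(2.0 ** -7) == -7
assert integer_log2(0.0081) is None  # near 2**-7 ~= 0.0078 but not a power of two
```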
tests/torch/test_brevitas_qat.py
Outdated
    else:
        pass

    # y_pred_clear_round = model.predict(x_test, fhe="disable")
is the following supposed to be removed ?
I added it back, but correctness is not achieved; I'd rather push it as is and work on it later. I created an issue for it.
Huge work, thanks a lot! I have several comments (mostly about detailing the steps a bit with comments).
Also, is it expected that apidocs were generated in this PR? We'll update them before releasing anyway!
        # Constant inputs
        curr_cst_inputs: Dict[int, ONNXOpInputOutputType] = {}
-       for input_idx, (input_name, value) in enumerate(curr_inputs.items()):
+       for input_idx, (input_name, value) in enumerate(curr_inputs):
What does this section (the for loop) do overall? Also, it's not very clear what each if/else in the loop does or why they are structured this way.
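For reference, the mechanical change the loop has to follow (the reason for switching from a dict to a list of pairs is not stated in this thread; the snippet only shows the shape difference):

```python
# Before: curr_inputs was a dict, so name/value pairs came from .items().
curr_inputs_dict = {"input_0": 1.0, "weight": 2.0}
for input_idx, (input_name, value) in enumerate(curr_inputs_dict.items()):
    pass

# After: curr_inputs is a list of (name, value) pairs, iterated directly.
curr_inputs_list = [("input_0", 1.0), ("weight", 2.0)]
for input_idx, (input_name, value) in enumerate(curr_inputs_list):
    pass
```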
-           curr_inputs[input_name] for input_name in variable_input_names
+           input_data
+           for input_name, input_data in curr_inputs
+           if input_name in variable_input_names
        )

        # For mypy
Below this (I can't comment further here):
- why the cast is needed: curr_calibration_data = cast(Tuple[numpy.ndarray], curr_calibration_data)
- about "# Find the unique integer producers of the current's op output tensor": what is a producer? How are they used?
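On the cast question, a generic illustration (not the PR's actual code path) of why typing.cast is sometimes needed: mypy cannot narrow the element type on its own, even when the runtime values are guaranteed not to be None.

```python
from typing import Optional, Tuple, cast

import numpy

def drop_optional(data: Tuple[Optional[numpy.ndarray], ...]) -> Tuple[numpy.ndarray, ...]:
    # The caller guarantees no element is None, but mypy cannot infer that from
    # the annotations alone, so the type is narrowed explicitly with cast().
    assert all(x is not None for x in data)
    return cast(Tuple[numpy.ndarray, ...], data)

calibration = drop_optional((numpy.ones(3), numpy.zeros(2)))
```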
@@ -455,10 +456,12 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray):
        has_variable_inputs = (len(curr_inputs) - len(curr_cst_inputs)) > 0

        variable_input_names = [
-           input_name for input_name in curr_inputs if input_name not in constants
+           input_name for input_name, _ in curr_inputs if input_name not in constants
        ]
        curr_calibration_data = tuple(
Basically we are only interested in data from the variable inputs, right?
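A toy illustration of the filtering being discussed (names are made up; in the real code the constants come from the ONNX graph's initializers):

```python
# ONNX node inputs mix graph variables (activations) and constants (weights,
# biases). Only the variable inputs carry calibration data.
curr_inputs = [("act_in", [1.0, 2.0]), ("weight", [[0.5]]), ("bias", [0.1])]
constants = {"weight", "bias"}

variable_input_names = [name for name, _ in curr_inputs if name not in constants]
curr_calibration_data = tuple(
    data for name, data in curr_inputs if name in variable_input_names
)
assert variable_input_names == ["act_in"]
assert curr_calibration_data == ([1.0, 2.0],)
```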
Conformance issue.
Unblocking my request for changes; fine for me if you take care of the remaining comments in a following PR!
Should we also update a notebook to make sure there is a speed-up? It would be great to add a definition of this quantizer, i.e. what it does and what speed-up is expected. I suppose the gain would come from using rounding instead of a standard PBS to re-quantize?
That's a very long CI, unsure what to do here 🤔
CI is taking way too long.
Let's not merge it to main as is, as it would slow down other PRs.
Coverage passed ✅
-    " \"module__activation_function\": nn.Sigmoid,\n",
+    " \"module__activation_function\": nn.ReLU,\n",
Why was this needed? Changing the architecture might have a positive impact on the accuracy, which is not really what we want in this PR.
I believe this is because the power of two scaling feature only works with ReLU
Is there an assert that, if power-of-two scaling is activated, the activation is a ReLU?
Got it. But I don't see anything that prevents the use of sigmoid or other non-linear functions. What would happen then?
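If such a guard were added, a minimal hedged sketch could look like this (check_pot_activation is a hypothetical helper, not something in the PR):

```python
import torch.nn as nn

def check_pot_activation(power_of_two_scaling: bool, activation_function: type) -> None:
    """Hypothetical guard: per the discussion above, power-of-two scaling targets
    the Gemm/Conv -> ReLU pattern, so reject other activations up front instead
    of silently falling back later."""
    if power_of_two_scaling and activation_function is not nn.ReLU:
        raise ValueError(
            "power_of_two_scaling=True currently requires activation_function=nn.ReLU"
        )

check_pot_activation(True, nn.ReLU)       # passes
# check_pot_activation(True, nn.Sigmoid)  # would raise ValueError
```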
looks great, thanks a lot for this feature !!
        n_accum_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
        n_w_bits: int = 4,
        n_a_bits: int = 4,
        # No pruning by default as roundPBS keeps the PBS precision low
What is the chosen bit width to round down to?
it is determined from the learned quantization scales
So these hyper-parameters are useless? I am trying to understand what number of bits is used for the rounding.
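A hedged numeric example of how the rounded bit count could follow from the learned scales (the values and the exact formula are illustrative; the PR's adapter may differ in the details):

```python
import numpy

# Illustrative learned scales around one Gemm -> ReLU -> quantize pattern.
scale_input = 2.0 ** -5    # activation quantizer before the layer
scale_weight = 2.0 ** -4   # weight quantizer
scale_output = 2.0 ** -3   # activation quantizer after the layer

# The accumulator scale is the product of the input and weight scales; the
# ratio to the output scale says how many low accumulator bits are redundant.
scale_acc = scale_input * scale_weight             # 2**-9
ratio = scale_output / scale_acc                   # 2**6
lsbs_to_round = int(numpy.rint(numpy.log2(ratio)))
assert lsbs_to_round == 6  # roundPBS would drop 6 low bits here
```

So under this scheme the rounding width is not a separate hyper-parameter; it falls out of the scales the quantizers learn during training.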
        n_prune_neurons_percentage: float = 0.0,
        activation_function: Type = nn.ReLU,
        quant_narrow: bool = False,
        quant_signed: bool = True,
        power_of_two_scaling: bool = True,  # Default to true: use roundPBS to speed up the NNs
It's a bit odd how rounding is implicitly contained in the power-of-two feature here. Shouldn't there be two distinct features?
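For completeness, a hedged sketch of the knobs involved (the parameter names come from the diff above; how they are passed, e.g. through skorch's module__ prefix as in the notebook change, depends on the estimator used):

```python
from torch import nn

qat_params = {
    "n_w_bits": 4,
    "n_a_bits": 4,
    "activation_function": nn.ReLU,   # the thread above says PoT scaling targets ReLU
    # Defaults to True after this PR; set to False to keep the previous behaviour
    # (standard PBS requantization without rounding).
    "power_of_two_scaling": True,
}
```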
Approved but some experiments/documentation should be added to explain to the user why and how this works.
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3947
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3946