
feat: add power of two scaling adapter for roundPBS #118

Merged · 19 commits merged into main from feat/add_qnn_power_of_two on Sep 12, 2023

Conversation

andrei-stoian-zama (Collaborator) commented on Aug 1, 2023:

  • Adds a pattern matcher for detecting power-of-two scaling usage (the underlying idea is sketched below)
  • Fixes a bug in the ONNX parser when an initializer node is used by several nodes
  • Fixes a bug in the LLM notebook for reshape semantics
  • Adds tests with custom networks using compile_brevitas_qat
  • Adds tests for when rounding is set manually and should not be overwritten
  • Adds tests for correctness -> not working well yet: https://github.com/zama-ai/concrete-ml-internal/issues/3946

Closes https://github.com/zama-ai/concrete-ml-internal/issues/3947
Closes https://github.com/zama-ai/concrete-ml-internal/issues/3946
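
For context, the numerical idea behind the pattern matcher can be illustrated with plain NumPy (a sketch under the power-of-two assumption; the variable names are illustrative, not Concrete ML APIs). When the input, weight and output quantization scales are all powers of two, the re-quantization factor is itself a power of two, so re-quantizing the integer accumulator reduces to a rounded bit shift, which is the kind of operation roundPBS accelerates:

import numpy

# Illustrative scales, all powers of two
s_in, s_w, s_out = 2.0**-6, 2.0**-4, 2.0**-5
requant_factor = s_in * s_w / s_out                      # == 2**-5
shift_bits = -int(numpy.rint(numpy.log2(requant_factor)))
acc = numpy.array([1234, -987, 4096])                    # integer accumulator values
requantized = numpy.rint(acc * requant_factor).astype(numpy.int64)
shifted = numpy.rint(acc / 2**shift_bits).astype(numpy.int64)
# Re-quantization and a rounded right shift by shift_bits give the same result
assert numpy.array_equal(requantized, shifted)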

cla-bot added the cla-signed label on Aug 1, 2023
@@ -88,10 +93,11 @@ def __init__(
     quant_name = f"quant{idx}"
     quantizer = qnn.QuantIdentity(
-        bit_width=n_a_bits,
+        bit_width=8 if idx == 0 else n_a_bits,
andrei-stoian-zama (author) commented:

Inputs are quantized to 8 bits. This should only be used by default for power-of-two scaling (roundPBS helps with this).

Collaborator:

Thanks for the explanation; maybe add a comment in the code to explain this!
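
For illustration, a comment of the kind suggested here could look like the following (a sketch only, not the actual diff; the loop body and values are reconstructed from the hunk above):

import brevitas.nn as qnn

n_a_bits = 4  # example value for the hidden-layer activation bit width
for idx in range(3):
    quantizer = qnn.QuantIdentity(
        # Quantize the network input (idx == 0) to 8 bits: with power-of-two
        # scaling, roundPBS keeps the PBS precision low, so the wider input
        # quantization is affordable. Hidden layers keep the requested n_a_bits.
        bit_width=8 if idx == 0 else n_a_bits,
    )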

@@ -32,7 +32,7 @@ def max_fhe_relu(q_x, axis=-1, keepdims=True):
     if keepdims:
         shape = list(result.shape)
         shape.insert(axis, 1)
-        result = result.reshape(shape)
+        result = result.reshape(tuple(shape))
andrei-stoian-zama (author) commented:

CP (Concrete Python) changed the semantics: reshape only accepts tuples now.
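
A minimal illustration of the constraint (plain NumPy accepts both forms; the tuple conversion is only needed for the traced code handled by the compiler referred to as "CP" above):

import numpy

result = numpy.arange(6)
shape = [2, 3]
# NumPy itself is happy with a list, but the traced reshape now only accepts a
# tuple, hence the explicit tuple(shape) conversion in the diff above.
result = result.reshape(tuple(shape))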

andrei-stoian-zama marked this pull request as ready for review on August 8, 2023 13:11
andrei-stoian-zama requested a review from a team as a code owner on August 8, 2023 13:11
fd0r marked this pull request as draft on August 21, 2023 13:46
fd0r force-pushed the feat/add_qnn_power_of_two branch 3 times, most recently from 379a5f8 to ff87d8b on August 25, 2023 08:11
andrei-stoian-zama force-pushed the feat/add_qnn_power_of_two branch from ff87d8b to 226b37c on August 28, 2023 09:44
andrei-stoian-zama marked this pull request as ready for review on August 30, 2023 07:44
@@ -695,6 +697,7 @@ def __init__(self, n_classes, n_bits, n_active, signed, narrow) -> None:
     n_active (int): number of active (non-zero weight) neurons to keep
     signed (bool): whether quantized integer values are signed
     narrow (bool): whether the range of quantized integer values is narrow/symmetric
+    power_of_two_scaling (bool): whether to use power-of-two scaling quantizers
Collaborator:

Maybe add an additional sentence explaining what "power-of-two scaling" is / can be used for?

-curr_inputs = {
-    input_name: node_results.get(input_name, None) for input_name in node.input
-}
+curr_inputs = [
Collaborator:

I think this PR could be a good opportunity to add additional comments for the following section, as it's an important part of the code that still remains somewhat obscure to whoever stumbles on it! There are already some comments, but I believe some steps are missing and others could be a bit more detailed.

andrei-stoian-zama (author) replied:

If you could pinpoint what is not clear, it would help; I don't really know which parts to explain.

Collaborator:

Yes, sure!

@@ -604,6 +607,9 @@ def quantize_module(self, *calibration_data: numpy.ndarray) -> QuantizedModule:
         onnx_model=self.numpy_model.onnx_model,
     )

+    adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
+    adapter.process()
Collaborator:

What is "process"?

Collaborator:

OK, just saw the method below, but maybe add a comment here to briefly say what we do here?
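
A brief comment of the kind requested could read as follows (a sketch; the exact behaviour of process() is defined by the adapter module added in this PR):

# Detect the power-of-two scaling pattern in the freshly quantized module and,
# where it applies, configure the affected layers to re-quantize with rounding
# (roundPBS) instead of a full-precision PBS.
adapter = PowerOfTwoScalingRoundPBSAdapter(quantized_module)
adapter.process()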

@@ -0,0 +1,25 @@
"""Custom Quantiation Aware Training Brevitas quantizers."""
Collaborator:

Typo: "Quantiation" -> "Quantization".

the input value was an integer power of two
"""
log2_value = int(numpy.rint(numpy.log2(value)))
if numpy.isclose(numpy.power(2.0, log2_value), value, atol=0.01):
RomanBredehoft (Collaborator) commented on Aug 30, 2023:

Is it not enough to just check that numpy.rint(numpy.log2(value)) == numpy.log2(value)? Or am I missing something?

If not, then maybe make "atol" a parameter of integer_log2 (or at least say why it was set to 0.01).

andrei-stoian-zama (author) replied:

Best practice is to never compare floats using equality. But you're right about the 0.01: it's quite arbitrary and not a good value for low powers of two (it's too big when the power is, say, -7). I'll use rtol instead.
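
For illustration, the relative-tolerance version could look like this (a sketch; the exact signature, return values and tolerance used in the PR may differ):

import numpy

def integer_log2(value: float):
    """Return (log2(value), True) if value is an integer power of two, else (value, False)."""
    log2_value = int(numpy.rint(numpy.log2(value)))
    # Never compare floats with ==; a relative tolerance also behaves better than
    # a fixed atol for small powers of two such as 2**-7.
    if numpy.isclose(numpy.power(2.0, log2_value), value, rtol=0.01):
        return log2_value, True
    return value, False

integer_log2(0.0078125)  # -> (-7, True)
integer_log2(0.1)        # -> (0.1, False)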

else:
pass

# y_pred_clear_round = model.predict(x_test, fhe="disable")
Collaborator:

Is the following supposed to be removed?

andrei-stoian-zama (author) replied:

I added it back, but correctness is not achieved. I'd rather push it as is and work on it later; I created an issue for it.

RomanBredehoft (Collaborator) left a comment:

Huge work, thanks a lot! I have several comments (mostly about detailing the steps a bit with comments).

Also, is it expected that apidocs were generated in this PR? We'll update them before releasing anyway!


 # Constant inputs
 curr_cst_inputs: Dict[int, ONNXOpInputOutputType] = {}
-for input_idx, (input_name, value) in enumerate(curr_inputs.items()):
+for input_idx, (input_name, value) in enumerate(curr_inputs):
RomanBredehoft (Collaborator) commented on Aug 31, 2023:

What does this section (the for loop) do overall?

Also, it's not very clear what each if/else in the loop does or why they are written this way.

-    curr_inputs[input_name] for input_name in variable_input_names
+    input_data
+    for input_name, input_data in curr_inputs
+    if input_name in variable_input_names
 )

# For mypy
RomanBredehoft (Collaborator) commented on Aug 31, 2023:

Below this (I can't comment further here):

  • Why is the cast needed: curr_calibration_data = cast(Tuple[numpy.ndarray], curr_calibration_data)?
  • "# Find the unique integer producers of the current's op output tensor": what is a producer? How are they used?

@@ -455,10 +456,12 @@ def _quantize_layers(self, *input_calibration_data: numpy.ndarray):
     has_variable_inputs = (len(curr_inputs) - len(curr_cst_inputs)) > 0

     variable_input_names = [
-        input_name for input_name in curr_inputs if input_name not in constants
+        input_name for input_name, _ in curr_inputs if input_name not in constants
     ]
     curr_calibration_data = tuple(
Collaborator:

Basically we are only interested in data from variables, right?

fd0r (Collaborator) commented on Aug 31, 2023:

Conformance issue.

RomanBredehoft previously approved these changes on Aug 31, 2023.

RomanBredehoft (Collaborator) left a comment:

Unblocking my request for changes; fine for me if you take care of the remaining comments in a following PR!

jfrery (Collaborator) commented on Sep 4, 2023:

Should we also update a notebook to make sure there is a speed-up? It would be great to add a definition of this quantizer, i.e. what it does and what speed-up is expected. I suppose the gain comes from using rounding instead of a standard PBS to re-quantize?

fd0r (Collaborator) commented on Sep 4, 2023:

That's a very long CI, unsure what to do here 🤔

fd0r self-requested a review on September 4, 2023 08:53.

fd0r (Collaborator) left a comment:

CI is taking way too long. Let's not merge this into main as is, since it would slow down other PRs.

github-actions bot commented on Sep 8, 2023:

Coverage passed ✅

Coverage details

---------- coverage: platform linux, python 3.8.18-final-0 -----------
Name    Stmts   Miss  Cover   Missing
-------------------------------------
TOTAL    5901      0   100%

50 files skipped due to complete coverage.

Comment on lines -74 to +71:
-    " \"module__activation_function\": nn.Sigmoid,\n",
+    " \"module__activation_function\": nn.ReLU,\n",
Collaborator:

Why was this needed? Changing the architecture might have a positive impact on the accuracy, which is not really what we want in this PR.

Collaborator:

I believe this is because the power-of-two scaling feature only works with ReLU.

Collaborator:

Is there an assert, if power-of-two scaling is activated, that the activation is a ReLU?

Collaborator:

Got it. But I don't see anything that prevents the use of sigmoid or another non-linear function. What would happen then?

RomanBredehoft (Collaborator) left a comment:

Looks great, thanks a lot for this feature!!

n_accum_bits: int = MAX_BITWIDTH_BACKWARD_COMPATIBLE,
n_w_bits: int = 4,
n_a_bits: int = 4,
# No pruning by default as roundPBS keeps the PBS precision low
Collaborator:

What is the chosen bit width to round down to?

andrei-stoian-zama (author) replied:

It is determined from the learned quantization scales.
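
To make this concrete, under the power-of-two assumption the number of least-significant bits to round away is not a new hyper-parameter: it follows from the learned scales. A hypothetical sketch (illustrative values, not necessarily the exact Concrete ML computation):

import numpy

# Scales learned by Brevitas for one linear layer (hypothetical values)
s_input, s_weight, s_output = 2.0**-7, 2.0**-3, 2.0**-4
requant_factor = s_input * s_weight / s_output           # == 2**-6
# Under this assumption, the accumulator can be rounded by this many bits
# before the activation, which is what roundPBS exploits.
rounded_bits = -int(numpy.rint(numpy.log2(requant_factor)))  # 6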

Collaborator:

So are these hyper-parameters useless? I am trying to understand what number of bits is used for the rounding.

n_prune_neurons_percentage: float = 0.0,
activation_function: Type = nn.ReLU,
quant_narrow: bool = False,
quant_signed: bool = True,
power_of_two_scaling: bool = True, # Default to true: use roundPBS to speed up the NNs
Collaborator:

It's a bit weird here how rounding is implicitly contained in the power-of-two feature. Shouldn't there be two distinct features?

fd0r (Collaborator) left a comment:

Approved, but some experiments/documentation should be added to explain to the user why and how this works.

andrei-stoian-zama merged commit 546fac9 into main on Sep 12, 2023
andrei-stoian-zama deleted the feat/add_qnn_power_of_two branch on September 12, 2023 10:24