feat: support encrypted mul div #690
Conversation
assert min_non_zero_value is not None and min_non_zero_value > 0
self.min_non_zero_value = min_non_zero_value

q_array_divider = QuantizedArray(self.n_bits, 1 / inputs[1])
Also, how can we be sure that `inputs[1]` is not 0? I feel like `1 / inputs[1]` can fail here, no? By the way, are the values always positive, or are we talking about any floats (including negatives)?
I do this right before:
min_non_zero_value = numpy.min(numpy.abs(inputs[1]))
# mypy
assert min_non_zero_value is not None and min_non_zero_value > 0
Should be good, no?
Ah yes, right, I missed the assert. But does that mean we always expect values to be strictly positive? How can we ensure that?
What I could check is that no value gets quantized/dequantized to 0, maybe.
Yes, but my question is more about "how can you be sure of that?"!
You are right, we do take the abs, so my question is then: how can we be sure that we never get 0 in `inputs[1]`? I understand that we assert, or could add more asserts, but I don't see how we can be sure these assert(s) will never fail under any circumstances (and if we can't, that's an issue imo)!
We can't. If the user provides a 0 it will fail, just as any division fails when there is a 0 in the denominator.
Ah I see. In that case, if it's a user issue, then we should have a `raise ValueError` with an explicit message rather than a bare assert. Unless an error is triggered before reaching this point (and in that case, is it explicit enough?).
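For illustration only, here is a minimal sketch of the kind of explicit check being suggested, assuming `inputs[1]` is the clear-text divisor available at calibration time (the helper name and its placement are hypothetical, not the PR's code):

```python
import numpy

def get_min_non_zero_divisor(divisor: numpy.ndarray) -> float:
    """Return the smallest non-zero magnitude of the divisor, with a clear error message."""
    if numpy.any(divisor == 0):
        raise ValueError(
            "Cannot calibrate an encrypted division: the calibration divisor contains 0 values."
        )
    return float(numpy.min(numpy.abs(divisor)))
```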
The assert here is for mypy. Not sure we should rewrite a basic error message such as division by zero.
Well, I won't push too much on this one, but I just fear that the error will not be a clear "DivisionByZeroError" and might confuse the user; it would be worth a check or even a test.
input_1 = q_input_1.dequant()

# Replace input_1 with min_non_zero_qvalue if input_1 is 0
input_1 = numpy.where(input_1 == 0, self.min_non_zero_value, input_1)
How is this working? Just replacing 0 with the (float) min? I imagine that means the values are expected to be positive! And then still, shouldn't we instead add `self.min_non_zero_value` to `input_1`, instead of replacing 0? In any case we should explain in comments how all of this works.
Hmm, yeah, I don't understand how that works either. I did that to fix the calibration, but it should fail when compiling the circuit, unless the `numpy.where` never does anything? I need to change that indeed, thanks.
So what should we do here, then?
Actually, here we are within a PBS and we assign values that are 0 to self.min_non_zero_value. That should work fine within a PBS.
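To make the "within a PBS" point a bit more concrete: the replacement happens inside the function that ends up tabulated per quantized divisor level, so it is well defined for every possible input. A rough standalone sketch of that per-entry function, with a hypothetical helper name (not the PR's code):

```python
def divisor_table_entry(dequantized_value: float, min_non_zero_value: float) -> float:
    """Value a lookup table would hold for one quantized divisor level."""
    # Map 0 to the smallest non-zero magnitude before inverting, so the
    # tabulated function is defined for every possible input level.
    safe_value = min_non_zero_value if dequantized_value == 0 else dequantized_value
    return 1.0 / safe_value
```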
Yes, but my observation was that, since we replace all zeros with `self.min_non_zero_value`:
- Since this value is computed through `numpy.min(numpy.abs(inputs[1]))`, does that mean we expect `input_1` to always be positive?
  - If not, then for example `[-2, 4, 0]` would be replaced by `[-2, 4, 2]`, and I don't think that makes much sense.
  - If so, how can we be sure that `input_1` is `>= 0`? More than just having asserts, I mean (like my comment in https://github.com/zama-ai/concrete-ml/pull/690/files#r1677747936).
- Also, aren't we breaking the values' distribution by doing something like this? Initially I suggested to instead do something like `input_1 += min_non_zero_value` (or `input_1 +=` if positive), but I don't think that's desirable either. What about re-quantizing on `[1, max_val]`? Anyway, I feel that just replacing 0 by this new value is not exactly right (see the sketch after this list contrasting the two options).
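For what it's worth, a small numpy sketch (made-up values, not the PR's code) contrasting the two options discussed in the list above:

```python
import numpy

divisor = numpy.array([-2.0, 4.0, 0.0])
min_non_zero_value = numpy.min(numpy.abs(divisor[divisor != 0]))  # 2.0

# Option 1 (what the PR does): only the zeros change, the rest of the
# distribution is untouched, but the 0 entry of a signed array becomes +2.0.
replaced = numpy.where(divisor == 0, min_non_zero_value, divisor)  # [-2., 4., 2.]

# Option 2 (adding the offset everywhere): the original zero disappears, but
# every value moves and, for signed inputs, a new zero can even appear.
shifted = divisor + min_non_zero_value  # [0., 6., 2.]

print(replaced, shifted)
```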
I totally understand your worry on this, though! Let's find the best solution.
Oh ok, sorry, I don't know why I kept seeing `argmax` and not `argmin`. So yeah, ok, we just get the closest value to 0; that makes much more sense indeed.
Would it make sense to instead add an epsilon to this 0? Or simply to all values? Not sure it'll change much since we re-quantize after the `1/x`.
Apart from that, I'm not sure I have better ideas here. If you want we could discuss this tomorrow, yes.
> would it make sense to instead add an epsilon to this 0?

Yes, that's what I proposed. An epsilon would make sense if it's representable by the quantized values. The scale is basically the smallest representable floating-point step, so we could use it instead of the `min(abs(x))`. But I am not entirely sure we can get the actual quantizer of the input. Also, I am afraid that for large bit-widths this epsilon is going to be super small and will lead to numerical errors.
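To give a rough sense of the numerical-error worry, a small sketch assuming a plain uniform quantizer (illustrative only, not Concrete ML's actual quantizer API): the scale is the smallest step the quantized grid represents, and it shrinks quickly as the bit-width grows, so `1 / epsilon` blows up accordingly.

```python
import numpy

def uniform_scale(values: numpy.ndarray, n_bits: int) -> float:
    """Smallest representable step of a uniform quantizer over the range of `values`."""
    return float((values.max() - values.min()) / (2**n_bits - 1))

values = numpy.array([0.0, 0.5, 3.0])
for n_bits in (4, 8, 16):
    epsilon = uniform_scale(values, n_bits)
    print(n_bits, epsilon, 1 / epsilon)  # 1 / epsilon grows fast with the bit-width
```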
Yeah, numerical errors could be an issue indeed. But actually this epsilon should maybe be added to all values, not just the 0, in order to keep the same distribution? Not sure how the quant params would handle this though.
@@ -363,8 +363,9 @@ def test_all_arith_ops(
    # Compute the quantized operator result
    quantized_output_vv = q_op(q_inputs_0, q_inputs_1).dequant()

    # Check the R2 of raw output and quantized output
    check_r2_score(raw_output_vv, quantized_output_vv)
    if n_bits > 16:
Are we sure about this change?
We were running that check twice.
Not sure I see where we were doing the second check 🤔
`check_r2_score(raw_output_vv, quantized_output_vv)` is right below the `if`.
But you added that one 😅. My question was about the `if` statement that you added, as mentioned below.
Ah yeah, what am I saying -_-'. Sorry. So yes, the problem is as I said: correctness was fine before with a univariate LUT, but in QuantizedDiv|Mul we have a quant/requant step which impacts the final accuracy.
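As a rough illustration of that quant/requant effect (a toy sketch with a made-up uniform quantizer, not the actual QuantizedMul code): forcing the operands back onto an `n_bits` grid before multiplying adds an error that grows as the bit-width shrinks, which is presumably why the R2 check is only applied for large bit-widths here.

```python
import numpy

def quant_dequant(values: numpy.ndarray, n_bits: int) -> numpy.ndarray:
    """Round values onto a uniform n_bits grid spanning their own range."""
    scale = (values.max() - values.min()) / (2**n_bits - 1)
    return numpy.round((values - values.min()) / scale) * scale + values.min()

rng = numpy.random.default_rng(0)
x, y = rng.uniform(1.0, 4.0, 1000), rng.uniform(1.0, 4.0, 1000)

exact = x * y
for n_bits in (4, 8, 16):
    approx = quant_dequant(x, n_bits) * quant_dequant(y, n_bits)
    print(n_bits, numpy.mean((exact - approx) ** 2))  # error shrinks as n_bits grows
```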
Do we know by how much it can impact? How "strong" do we expect this impact to be? I'm just being a bit suspicious here because of things like https://github.com/zama-ai/concrete-ml/pull/690/files#r1634728461 or https://github.com/zama-ai/concrete-ml/pull/690/files#r1634723604 😅
Thanks for this feature! I have several questions and observations, mostly about adding more comments to make the code clearer.
If the tests pass, this looks good
Can we merge this PR? I am working on compiling a model that involves encrypted mul; currently it prompts "AssertionError: Do not support this type of operation between encrypted tensors". I think this PR could fix this error.
Yes, sure. This PR has been open for too long already. We will merge it soon.
Sorry, I still have some remaining questions!
@RomanBredehoft could you re-review this, please?
I still think something is wrong, sorry!
Coverage passed ✅
closes https://github.com/zama-ai/concrete-ml-internal/issues/4418
closes https://github.com/zama-ai/concrete-ml-internal/issues/4163
ref https://github.com/zama-ai/concrete-internal/issues/716