Knn classifier in cml #217

Merged
merged 51 commits into from
Sep 21, 2023
Conversation

kcelia
Collaborator

@kcelia kcelia commented Sep 1, 2023

As of now:

  • Only the predict function is supported; predict_proba is not.
  • What's done on the client side? The majority_vote.
  • What's done on the server side? pairwise_euclidean_distance and the k nearest labels.
    • sqrt is not performed in FHE: since sqrt is a monotonic function, it doesn't change the ordering of the distances, so the selection of the nearest neighbors is unaffected. Removing it therefore reduces the computation.
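
The point about skipping sqrt can be checked in the clear with a minimal sketch. The data, shapes, and variable names below are made up for illustration and are not taken from the PR; the check only shows that ranking by squared distance and ranking by true distance select the same k neighbors.

```python
import numpy as np

# Hypothetical quantized training points and query (illustrative values only)
rng = np.random.default_rng(0)
X_fit = rng.integers(0, 8, size=(20, 3))
x = rng.integers(0, 8, size=(3,))

# Squared Euclidean distance: what would be computed without the sqrt
squared = np.sum((X_fit - x) ** 2, axis=1)
# True Euclidean distance, for comparison only
true_dist = np.sqrt(squared)

k = 3
# A stable sort keeps ties in the same order for both rankings
topk_squared = np.argsort(squared, kind="stable")[:k]
topk_true = np.argsort(true_dist, kind="stable")[:k]

# sqrt is monotonic on non-negative values, so both rankings agree
assert np.array_equal(topk_squared, topk_true)
```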

closes #3818

@kcelia kcelia requested a review from a team as a code owner September 1, 2023 14:17
@cla-bot cla-bot bot added the cla-signed label Sep 1, 2023
@kcelia kcelia changed the title from "Knn classifier in cml v2 3818" to "Knn classifier in cml" Sep 1, 2023
@kcelia kcelia force-pushed the knn_classifier_in_cml_v2_3818 branch 2 times, most recently from 8045f47 to a50ad1d Compare September 1, 2023 14:52
@kcelia kcelia force-pushed the knn_classifier_in_cml_v2_3818 branch from 6afa513 to 2f087c1 Compare September 4, 2023 08:05
Collaborator

@jfrery jfrery left a comment

Great first draft, well done! Could you explain the choices about what is done in FHE and what is done in the clear on the client side?

It looks like we only partly compute the distances in FHE. Why isn't the sqrt done in FHE? Is it just prohibitively expensive to compute the top-k and majority vote in FHE? Currently it seems that we return all distances to the client, which will probably leak quite a lot of information about the training data.

distance_matrix = (
    numpy.sum(q_X**2, axis=1, keepdims=True)
    - 2 * q_X @ self._q_X_fit.T
    + numpy.expand_dims(numpy.sum(self._q_X_fit**2, axis=1), 0)
)
Collaborator

@jfrery jfrery Sep 4, 2023

numpy.expand_dims(numpy.sum(self._q_X_fit**2, axis=1), 0) can be done at training time, I suppose.

Collaborator

It's a constant, no? It will be precomputed by CP (Concrete Python).
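
A minimal sketch, in plain NumPy and in the clear, of the expansion discussed in this thread: the squared norms of the training points depend only on the fitted data, so they can be computed once and reused for every query. The helper names `fit_sq_norms` and `squared_distances` are hypothetical and are not the PR's API.

```python
import numpy as np

def fit_sq_norms(X_fit):
    """Squared norms of the training points, computed once; shape (1, n_fit)."""
    return np.expand_dims(np.sum(X_fit ** 2, axis=1), 0)

def squared_distances(X, X_fit, fit_norms):
    """||x||^2 - 2 x.y + ||y||^2 for every (query, training point) pair."""
    return np.sum(X ** 2, axis=1, keepdims=True) - 2 * X @ X_fit.T + fit_norms

# Illustrative integer data standing in for quantized inputs
rng = np.random.default_rng(0)
X_fit = rng.integers(-4, 4, size=(5, 3))
X = rng.integers(-4, 4, size=(2, 3))

norms = fit_sq_norms(X_fit)                     # done at training time
expanded = squared_distances(X, X_fit, norms)   # done at inference time

# Reference: direct pairwise squared distances via broadcasting
direct = np.sum((X[:, None, :] - X_fit[None, :, :]) ** 2, axis=2)
assert np.array_equal(expanded, direct)
```

Whether the term is stored explicitly at fit time or folded as a constant by the compiler, the result is the same; the sketch only illustrates that the term does not depend on the query.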

@fd0r
Collaborator

fd0r commented Sep 4, 2023

Seems like this PR adds a lot of time to the CI.

@kcelia kcelia marked this pull request as draft September 5, 2023 14:39
@kcelia kcelia force-pushed the knn_classifier_in_cml_v2_3818 branch 3 times, most recently from 45bc34f to 0dc23cf Compare September 11, 2023 13:08
src/concrete/ml/pytest/utils.py (resolved)
src/concrete/ml/search_parameters/p_error_search.py (outdated, resolved)
tests/deployment/test_client_server.py (outdated, resolved)
@kcelia kcelia force-pushed the knn_classifier_in_cml_v2_3818 branch 6 times, most recently from 69f11f8 to dac4e8c Compare September 20, 2023 14:48
# a training phase
self._q_fit_X: numpy.ndarray
# _y: Labels of `_q_fit_X`
self._y: numpy.ndarray
Collaborator

I haven't followed everything, so we keep this _y attribute then?

Collaborator

Yes, we can keep it, but we need to make sure it doesn't exist in the model exported to the client (in client/server).

Collaborator

Ah yes, OK, so basically it's for predict but not for post_processing, right? Got it.
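
To illustrate why post-processing would not need _y: assuming, as the PR description states, that the server returns the k nearest labels and the client performs the majority vote, the client-side step can work from those labels alone. The function below is a hypothetical sketch, not the PR's implementation, and its tie-breaking rule (smallest label wins) is an assumption made for the example.

```python
import numpy as np

def majority_vote(k_nearest_labels):
    """Return the most frequent label among the k nearest ones.

    Hypothetical helper: ties go to the smallest label, because np.unique
    returns labels sorted and np.argmax returns the first maximum.
    """
    labels, counts = np.unique(k_nearest_labels, return_counts=True)
    return labels[np.argmax(counts)]

# Labels of the k=3 nearest neighbors, as the client might receive them
predicted = majority_vote(np.array([1, 0, 1]))
```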

Collaborator

@andrei-stoian-zama andrei-stoian-zama left a comment

Looks good!

tests/deployment/test_client_server.py (resolved)
Collaborator

@jfrery jfrery left a comment

Looks good, just a few comments.

src/concrete/ml/sklearn/base.py (resolved)
x = scatter1d(x, max_x, range_i + d)

# Max index selection
sign = diff <= 0
Collaborator

This is not the sign but a boolean, I guess. Why isn't this done right after computing diff?

Have you tried the CP comparison optimization by replacing

diff = a - b
sign = diff <= 0

with

sign = a <= b

?
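
As a sanity check in the clear, the two forms give the same boolean result. This sketch uses plain NumPy with made-up values; it says nothing about FHE runtime, which depends on the comparison strategy the compiler picks and would need to be measured separately.

```python
import numpy as np

# Illustrative integer inputs, including an equality case
a = np.array([3, 7, 5, 5])
b = np.array([4, 2, 5, 9])

# Form currently in the diff: subtract, then test the difference
diff = a - b
via_diff = diff <= 0

# Suggested form: compare directly
direct = a <= b

# Both produce the same boolean mask
assert np.array_equal(via_diff, direct)
```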

Collaborator Author

At the beginning of this conversation, I thought we were skeptical about letting the compiler choose the best strategy.

But I missed Ruby's last messages, which are:

Yes, it makes the bitwidth compatible with the strategies you asked for, and once bitwidth inference is done, it picks the best strategy based on a heuristic (I think it tries to minimize the number of TLUs without increasing the maximum precision). So if an 8-bit TLU already exists in the circuit, it accepts using that precision; otherwise it will try to stick to lower precisions.
It's not optimal in the sense of the cost model, as that would require solving the crypto-parameters.

So, I think CP's comparison is worth using.

Collaborator Author

It's indeed a boolean that tells us if a is greater than b.
I'll change the naming.

Collaborator

Maybe you can just check on a simple example whether there is any time improvement. Otherwise you can leave it.

Collaborator

If you don't have the time right now, then I think it's worth an issue, because this might be a good thing to try anyway.

Collaborator Author

I'll do it in a separate PR.
Some testing is needed.

Collaborator

Great, can you create an issue for this?

@kcelia kcelia force-pushed the knn_classifier_in_cml_v2_3818 branch 3 times, most recently from 4b03fb9 to 3cf5e7d Compare September 21, 2023 09:52
@kcelia kcelia force-pushed the knn_classifier_in_cml_v2_3818 branch from 3cf5e7d to fd2c1c7 Compare September 21, 2023 11:15
@github-actions

Coverage passed ✅

Coverage details

---------- coverage: platform linux, python 3.8.18-final-0 -----------
Name    Stmts   Miss  Cover   Missing
-------------------------------------
TOTAL    6085      0   100%

51 files skipped due to complete coverage.

Collaborator

@jfrery jfrery left a comment

Looks good to me! Thanks. Please just create an issue to check the CP comparison (see comment)

@kcelia kcelia merged commit 1c33ec8 into main Sep 21, 2023
8 of 9 checks passed
@kcelia kcelia deleted the knn_classifier_in_cml_v2_3818 branch September 21, 2023 13:09
6 participants