From a50ad1de64f7bf1c23add5dc2c23947aaed1e4a8 Mon Sep 17 00:00:00 2001 From: kcelia Date: Fri, 1 Sep 2023 16:28:06 +0200 Subject: [PATCH] chore: correct pairwise euclidean_distances mistake in the dims --- src/concrete/ml/sklearn/base.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/concrete/ml/sklearn/base.py b/src/concrete/ml/sklearn/base.py index bfe17b2cd..33f896670 100644 --- a/src/concrete/ml/sklearn/base.py +++ b/src/concrete/ml/sklearn/base.py @@ -1891,13 +1891,11 @@ def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray: # @ is used for matrices quand c'est une matrice @ -> matmul distance_matrix = ( - np.sum(q_X**2).reshape(1) + numpy.sum(q_X**2, axis=1, keepdims=True) - 2 * q_X @ self._q_X_fit.T - + np.sum(self._q_X_fit**2, axis=1).reshape(1, -1) + + numpy.expand_dims(numpy.sum(self._q_X_fit**2, axis=1), 0) ) - # distance_matrix = np.sum(self._q_X_fit **2 + q_X**2 - 2 * self._q_X_fit * q_X, axis=1) - return distance_matrix def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray: