Skip to content

Commit

Permalink
chore: correct pairwise euclidean_distances
Browse files Browse the repository at this point in the history
mistake in the dims
  • Loading branch information
kcelia committed Sep 1, 2023
1 parent 2791ee9 commit 8045f47
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions src/concrete/ml/sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1891,13 +1891,11 @@ def _inference(self, q_X: numpy.ndarray) -> numpy.ndarray:
# @ is used for matrices quand c'est une matrice @ -> matmul

distance_matrix = (
np.sum(q_X**2).reshape(1)
numpy.sum(q_X**2, axis=1, keepdims=True)
- 2 * q_X @ self._q_X_fit.T
+ np.sum(self._q_X_fit**2, axis=1).reshape(1, -1)
+ numpy.sum(self._q_X_fit**2, axis=1, keepdims=True)
)

# distance_matrix = np.sum(self._q_X_fit **2 + q_X**2 - 2 * self._q_X_fit * q_X, axis=1)

return distance_matrix

def predict(self, X: Data, fhe: Union[FheMode, str] = FheMode.DISABLE) -> numpy.ndarray:
Expand Down

0 comments on commit 8045f47

Please sign in to comment.