Skip to content

Commit

Permalink
Backport fix from #633 for KNNShap
Browse files: browse the repository at this point in the history
  • Loading branch information
mdbenito committed Jan 12, 2025
1 parent ce35efd commit 262197f
Showing 1 changed file with 13 additions and 15 deletions.
28 changes: 13 additions & 15 deletions src/pydvl/value/shapley/knn.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -73,28 +73,26 @@ def knn_shapley(u: Utility, *, progress: bool = True) -> ValuationResult:
# closest to farthest
_, indices = nns.kneighbors(u.data.x_test)

values: NDArray[np.float64] = np.zeros_like(u.data.indices, dtype=np.float64)
res = np.zeros_like(u.data.indices, dtype=np.float64)
n = len(u.data)
yt = u.data.y_train
iterator = enumerate(zip(u.data.y_test, indices), start=1)
for j, (y, ii) in tqdm(iterator, disable=not progress):
value_at_x = int(yt[ii[-1]] == y) / n
values[ii[-1]] += (value_at_x - values[ii[-1]]) / j
for i in range(n - 2, n_neighbors, -1): # farthest to closest
value_at_x = (
values[ii[i + 1]] + (int(yt[ii[i]] == y) - int(yt[ii[i + 1]] == y)) / i
)
values[ii[i]] += (value_at_x - values[ii[i]]) / j
for i in range(n_neighbors, -1, -1): # farthest to closest
value_at_x = (
values[ii[i + 1]]
+ (int(yt[ii[i]] == y) - int(yt[ii[i + 1]] == y)) / n_neighbors
)
values[ii[i]] += (value_at_x - values[ii[i]]) / j
values = np.zeros_like(u.data.indices, dtype=np.float64)
idx = ii[-1]
values[idx] = int(yt[idx] == y) / n

for i in range(n - 1, 0, -1):
prev_idx = idx
idx = ii[i - 1]
values[idx] = values[prev_idx] + (
int(yt[idx] == y) - int(yt[prev_idx] == y)
) / max(n_neighbors, i)
res += values

return ValuationResult(
algorithm="knn_shapley",
status=Status.Converged,
values=values,
values=res,
data_names=u.data.data_names,
)

0 comments on commit 262197f

Please sign in to comment.