Skip to content

Commit

Permalink
Add Remove Accidental Hits layer.
Browse files Browse the repository at this point in the history
This layer zeroes the logits of accidental negatives.
  • Loading branch information
hertschuh committed Mar 7, 2025
1 parent 23cffe9 commit 8bc5d6c
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 0 deletions.
3 changes: 3 additions & 0 deletions keras_rs/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from keras_rs.src.layers.retrieval.hard_negative_mining import (
HardNegativeMining,
)
from keras_rs.src.layers.retrieval.remove_accidental_hits import (
RemoveAccidentalHits,
)
from keras_rs.src.layers.retrieval.sampling_probability_correction import (
SamplingProbabilityCorrection,
)
83 changes: 83 additions & 0 deletions keras_rs/src/layers/retrieval/remove_accidental_hits.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import keras
import numpy as np
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils import keras_utils

# Large-magnitude negative value added to the logits of accidental hits so that
# they effectively drop out of a subsequent softmax. `finfo.min` is divided by
# 100 to leave headroom: adding it to ordinary float32 logits stays finite.
# NOTE(review): this magnitude is outside the float16 range; behavior with
# half-precision logits should be confirmed.
SMALLEST_FLOAT = np.finfo(np.float32).min / 100.0


@keras_rs_export("keras_rs.layers.RemoveAccidentalHits")
class RemoveAccidentalHits(keras.layers.Layer):
    """Zeroes the logits of accidental negatives.

    Zeroes the logits of negative candidates that have the same ID as the
    positive candidate in that row. "Zeroing" is implemented by adding
    `SMALLEST_FLOAT` to the affected logits rather than by applying a mask.
    """

    def call(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        candidate_ids: types.Tensor,
    ) -> types.Tensor:
        """Zeroes selected logits.

        For each row in the batch, zeroes the logits of negative candidates that
        have the same ID as the positive candidate in that row.

        Args:
            labels: one-hot labels tensor, typically
                `[batch_size, num_candidates]` but can have more dimensions or
                be 1d as `[num_candidates]`. The position of the positive
                candidate in each row is found with `argmax` over the last
                axis.
            logits: logits tensor. Must have the same shape as `labels`.
            candidate_ids: candidate identifiers tensor, can be
                `[num_candidates]` or `[batch_size, num_candidates]` or have
                more dimensions as long as they match the last dimensions of
                `labels`.

        Returns:
            logits: Modified logits, with the same shape as the input `logits`.

        Raises:
            ValueError: if the shapes of `labels` and `logits` are not
                compatible, or if the shape of `candidate_ids` is not
                compatible with the last dimensions of `labels`.
        """
        # A more principled way is to implement
        # `softmax_cross_entropy_with_logits` with a input mask. Here we
        # approximate so by adding `SMALLEST_FLOAT` to the logits of accidental
        # hits, for ease-of-implementation.

        labels_shape = ops.shape(labels)
        labels_rank = len(labels_shape)
        logits_shape = ops.shape(logits)
        candidate_ids_shape = ops.shape(candidate_ids)
        candidate_ids_rank = len(candidate_ids_shape)

        if not keras_utils.check_shapes_compatible(labels_shape, logits_shape):
            raise ValueError(
                "The shape of `labels` and `logits` must match. Got "
                f"labels.shape = {labels_shape} and "
                f"logits.shape = {logits_shape}."
            )

        if not keras_utils.check_shapes_compatible(
            labels_shape[-candidate_ids_rank:], candidate_ids_shape
        ):
            raise ValueError(
                "The shape of `candidate_ids` must match the last dimensions "
                f"of `labels`. Got labels.shape = {labels_shape} and "
                f"candidate_ids.shape = {candidate_ids_shape}."
            )

        # Add dimensions to `candidate_ids` to have the same rank as `labels`.
        if candidate_ids_rank < labels_rank:
            candidate_ids = ops.expand_dims(
                candidate_ids, list(range(labels_rank - candidate_ids_rank))
            )

        # Index of the positive candidate in each row, kept as a trailing
        # dimension of size 1 so that it lines up with `candidate_ids`.
        positive_indices = ops.expand_dims(ops.argmax(labels, axis=-1), -1)
        # Gather along the candidates axis only. `ops.take` without an `axis`
        # would index into the flattened `candidate_ids`, which picks IDs from
        # the wrong row as soon as `candidate_ids` has a batch dimension.
        # `take_along_axis` broadcasts the leading dimensions added above.
        positive_candidate_ids = ops.take_along_axis(
            candidate_ids, positive_indices, axis=-1
        )

        # 1 wherever a candidate shares its ID with the row's positive.
        duplicate = ops.cast(
            ops.equal(positive_candidate_ids, candidate_ids), labels.dtype
        )
        # The positive always matches itself; subtracting the one-hot labels
        # leaves only the accidental hits flagged.
        duplicate = ops.subtract(duplicate, labels)

        return ops.add(logits, ops.multiply(duplicate, SMALLEST_FLOAT))
159 changes: 159 additions & 0 deletions keras_rs/src/layers/retrieval/remove_accidental_hits_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.retrieval import remove_accidental_hits


class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase):
    """Tests for the `RemoveAccidentalHits` layer."""

    def create_inputs(self, logits_rank=2, candidate_ids_rank=1):
        """Builds random `(labels, logits, candidate_ids)` tensors.

        Shapes are the trailing `logits_rank` (for `labels` and `logits`) and
        trailing `candidate_ids_rank` (for `candidate_ids`) dimensions of
        `(15, 20, 10)`. Candidate IDs are drawn with `maxval=num_candidates`
        for `num_candidates` slots, so repeated IDs within a row are very
        likely.

        Args:
            logits_rank: rank of the generated `labels` and `logits`.
            candidate_ids_rank: rank of the generated `candidate_ids`.

        Returns:
            A `(labels, logits, candidate_ids)` tuple, with `labels` one-hot
            over the last axis.
        """
        num_candidates = 10
        shape_3d = (15, 20, num_candidates)
        shape = shape_3d[-logits_rank:]
        candidate_ids_shape = shape_3d[-candidate_ids_rank:]
        rng = keras.random.SeedGenerator(42)

        labels = keras.ops.one_hot(
            keras.random.randint(
                shape[:-1], minval=0, maxval=num_candidates, seed=rng
            ),
            num_candidates,
        )
        logits = keras.random.uniform(shape, seed=rng)
        candidate_ids = keras.random.randint(
            candidate_ids_shape, minval=0, maxval=num_candidates, seed=rng
        )
        return labels, logits, candidate_ids

    @parameterized.named_parameters(
        [
            {"testcase_name": "1_1", "logits_rank": 1, "candidate_ids_rank": 1},
            {"testcase_name": "2_1", "logits_rank": 2, "candidate_ids_rank": 1},
            {"testcase_name": "2_2", "logits_rank": 2, "candidate_ids_rank": 2},
            {"testcase_name": "3_1", "logits_rank": 3, "candidate_ids_rank": 1},
            {"testcase_name": "3_2", "logits_rank": 3, "candidate_ids_rank": 2},
            {"testcase_name": "3_3", "logits_rank": 3, "candidate_ids_rank": 3},
        ]
    )
    def test_call(self, logits_rank, candidate_ids_rank):
        """Checks the layer output element by element against a reference."""
        labels, logits, candidate_ids = self.create_inputs(
            logits_rank=logits_rank, candidate_ids_rank=candidate_ids_rank
        )

        out_logits = remove_accidental_hits.RemoveAccidentalHits()(
            labels, logits, candidate_ids
        )

        # Logits of labels are unchanged.
        self.assertAllClose(
            ops.sum(ops.multiply(out_logits, labels), axis=-1),
            ops.sum(ops.multiply(logits, labels), axis=-1),
        )

        # Instead of having nested loops, which we can't do because they depend
        # on the rank, we enumerate every combination of leading indices (all
        # dimensions except the last, candidates, one). This works for any
        # rank; a rank 1 input yields the single empty index tuple.
        indices = [()]
        for dim in ops.shape(logits)[:-1]:
            indices = [prefix + (i,) for prefix in indices for i in range(dim)]

        # `candidate_ids` lines up with the *trailing* dimensions of `labels`,
        # so only the last `candidate_ids_rank - 1` leading indices select the
        # row of candidate IDs that applies to a given row of logits.
        num_id_indices = candidate_ids_rank - 1

        for index_tuple in indices:
            sub_labels = labels
            sub_logits = logits
            sub_out_logits = out_logits
            # This loop applies multiple indices to go deep several dimensions.
            for i in index_tuple:
                sub_labels = sub_labels[i]
                sub_logits = sub_logits[i]
                sub_out_logits = sub_out_logits[i]

            sub_candidate_ids = candidate_ids
            for i in index_tuple[len(index_tuple) - num_id_indices :]:
                sub_candidate_ids = sub_candidate_ids[i]

            row_positive_idx = ops.argmax(sub_labels)
            positive_candidate_id = sub_candidate_ids[row_positive_idx]

            for col_idx in range(sub_out_logits.shape[0]):
                same_candidate_as_positive = ops.equal(
                    positive_candidate_id, sub_candidate_ids[col_idx]
                )
                is_positive = ops.equal(col_idx, row_positive_idx)

                if ops.convert_to_numpy(
                    same_candidate_as_positive
                ) and not ops.convert_to_numpy(is_positive):
                    # We zeroed the logits.
                    self.assertAllClose(
                        sub_out_logits[col_idx],
                        ops.add(
                            sub_logits[col_idx],
                            remove_accidental_hits.SMALLEST_FLOAT,
                        ),
                    )
                else:
                    # We left the logits unchanged.
                    self.assertAllClose(
                        sub_out_logits[col_idx],
                        sub_logits[col_idx],
                    )

    def test_mismatched_labels_logits_shapes(self):
        """`labels` and `logits` with different shapes raise `ValueError`."""
        layer = remove_accidental_hits.RemoveAccidentalHits()

        with self.assertRaisesRegex(
            ValueError, "shape of `labels` and `logits` must match"
        ):
            layer(ops.zeros((10, 20)), ops.zeros((10, 30)), ops.zeros((20,)))

    def test_mismatched_labels_candidates_shapes(self):
        """`candidate_ids` not matching `labels` raises `ValueError`."""
        layer = remove_accidental_hits.RemoveAccidentalHits()

        with self.assertRaisesRegex(
            ValueError,
            "shape of `candidate_ids` must match .* `labels`",
        ):
            layer(ops.zeros((10, 20)), ops.zeros((10, 20)), ops.zeros((30,)))

    def test_predict(self):
        """Runs the layer inside a functional model through `predict`."""
        # Note: for predict, we test with candidate IDs that have a batch dim.
        labels, logits, candidate_ids = self.create_inputs(candidate_ids_rank=2)

        layer = remove_accidental_hits.RemoveAccidentalHits()
        in_labels = keras.layers.Input(labels.shape[1:])
        in_logits = keras.layers.Input(logits.shape[1:])
        in_candidate_ids = keras.layers.Input(candidate_ids.shape[1:])
        out_logits = layer(in_labels, in_logits, in_candidate_ids)
        model = keras.Model(
            [in_labels, in_logits, in_candidate_ids], out_logits
        )

        model.predict([labels, logits, candidate_ids], batch_size=8)

    def test_serialization(self):
        """The layer config survives a serialize / deserialize round trip."""
        layer = remove_accidental_hits.RemoveAccidentalHits()
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

    def test_model_saving(self):
        """Runs the shared model saving test on a model using the layer."""
        labels, logits, candidate_ids = self.create_inputs()

        layer = remove_accidental_hits.RemoveAccidentalHits()
        in_labels = keras.layers.Input(labels.shape[1:])
        in_logits = keras.layers.Input(logits.shape[1:])
        # The candidate IDs have no batch dimension here, so the full shape is
        # passed as `batch_shape`.
        in_candidate_ids = keras.layers.Input(batch_shape=candidate_ids.shape)
        out_logits = layer(in_labels, in_logits, in_candidate_ids)
        model = keras.Model(
            [in_labels, in_logits, in_candidate_ids], out_logits
        )

        self.run_model_saving_test(
            model=model, input_data=[labels, logits, candidate_ids]
        )

0 comments on commit 8bc5d6c

Please sign in to comment.