# keras_rs/src/layers/retrieval/remove_accidental_hits.py
# Reconstructed from patch 8bc5d6c ("Add Remove Accidental Hits layer",
# Fabien Hertschuh, 2025-03-07). The same patch also re-exports
# `RemoveAccidentalHits` from `keras_rs/api/layers/__init__.py`.
import keras
import numpy as np
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils import keras_utils

# A very large *negative* float. Adding it to a logit drives that candidate's
# softmax probability to ~0.
# BUG FIX: the patch defined this as `np.finfo(np.float32).smallest_normal /
# 100.0`, a tiny *positive* denormal (~1e-40) whose addition leaves every
# logit unchanged — the layer would silently do nothing. "Extremely small
# logits" in the comment below means extremely negative, i.e. float32 min.
SMALLEST_FLOAT = np.finfo(np.float32).min / 100.0


@keras_rs_export("keras_rs.layers.RemoveAccidentalHits")
class RemoveAccidentalHits(keras.layers.Layer):
    """Zeroes the logits of accidental negatives.

    Zeroes the logits of negative candidates that have the same ID as the
    positive candidate in that row.
    """

    def call(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        candidate_ids: types.Tensor,
    ) -> types.Tensor:
        """Zeroes selected logits.

        For each row in the batch, zeroes the logits of negative candidates
        that have the same ID as the positive candidate in that row.

        Args:
            labels: one-hot labels tensor, typically
                `[batch_size, num_candidates]` but can have more dimensions or
                be 1D as `[num_candidates]`.
            logits: logits tensor. Must have the same shape as `labels`.
            candidate_ids: candidate identifiers tensor, can be
                `[num_candidates]` or `[batch_size, num_candidates]` or have
                more dimensions as long as they match the last dimensions of
                `labels`.

        Returns:
            logits: Modified logits.

        Raises:
            ValueError: if `logits` does not have the same shape as `labels`,
                or if `candidate_ids` does not match the last dimensions of
                `labels`.
        """
        # A more principled way is to implement
        # `softmax_cross_entropy_with_logits` with an input mask. Here we
        # approximate so by letting accidental hits have extremely small
        # logits (SMALLEST_FLOAT) for ease-of-implementation.

        labels_shape = ops.shape(labels)
        labels_rank = len(labels_shape)
        logits_shape = ops.shape(logits)
        candidate_ids_shape = ops.shape(candidate_ids)
        candidate_ids_rank = len(candidate_ids_shape)

        if not keras_utils.check_shapes_compatible(labels_shape, logits_shape):
            raise ValueError(
                "The shape of `labels` and `logits` must match. Got "
                f"labels.shape = {labels_shape} and "
                f"logits.shape = {logits_shape}."
            )

        if not keras_utils.check_shapes_compatible(
            labels_shape[-candidate_ids_rank:], candidate_ids_shape
        ):
            raise ValueError(
                "The shape of `candidate_ids` must match the last dimensions "
                f"of `labels`. Got labels.shape = {labels_shape} and "
                f"candidate_ids.shape = {candidate_ids_shape}."
            )

        # Add dimensions to `candidate_ids` to have the same rank as `labels`.
        if candidate_ids_rank < labels_rank:
            candidate_ids = ops.expand_dims(
                candidate_ids, list(range(labels_rank - candidate_ids_rank))
            )

        # Index of the positive candidate in each row, kept as a trailing
        # size-1 axis so it lines up with `candidate_ids`.
        positive_indices = ops.expand_dims(ops.argmax(labels, axis=-1), -1)
        # BUG FIX: the patch used `ops.take(candidate_ids, positive_indices)`,
        # which flattens `candidate_ids` (axis=None semantics) and therefore
        # gathers IDs from the first row only whenever `candidate_ids` is
        # batched (rank >= 2). `take_along_axis` gathers per row along the
        # candidates axis and broadcasts size-1 leading dims.
        positive_candidate_ids = ops.take_along_axis(
            candidate_ids, positive_indices, axis=-1
        )

        # 1.0 where a candidate shares the positive's ID, 0.0 elsewhere;
        # subtracting `labels` clears the positive itself so that only
        # *accidental* hits get penalized.
        duplicate = ops.cast(
            ops.equal(positive_candidate_ids, candidate_ids), labels.dtype
        )
        duplicate = ops.subtract(duplicate, labels)

        return ops.add(logits, ops.multiply(duplicate, SMALLEST_FLOAT))
"candidate_ids_rank": 1}, + {"testcase_name": "3_2", "logits_rank": 3, "candidate_ids_rank": 2}, + {"testcase_name": "3_3", "logits_rank": 3, "candidate_ids_rank": 3}, + ] + ) + def test_call(self, logits_rank, candidate_ids_rank): + labels, logits, candidate_ids = self.create_inputs( + logits_rank=logits_rank, candidate_ids_rank=candidate_ids_rank + ) + + out_logits = remove_accidental_hits.RemoveAccidentalHits()( + labels, logits, candidate_ids + ) + + # Logits of labels are unchanged. + self.assertAllClose( + ops.sum(ops.multiply(out_logits, labels), axis=-1), + ops.sum(ops.multiply(logits, labels), axis=-1), + ) + + # Instead of having nested loops, which we can't do becasue they depend + # on the rank, we unroll the index combinations. + shape = ops.shape(logits) + if logits_rank == 1: + indices = [ + (), + ] + elif logits_rank == 2: + indices = [(i,) for i in range(shape[0])] + elif logits_rank == 3: + indices = [(i, j) for i in range(shape[0]) for j in range(shape[1])] + + for index_tuple in indices: + sub_labels = labels + sub_logits = logits + sub_out_logits = out_logits + sub_candidate_ids = candidate_ids + # This loop applies multiple indices to go deep several dimensions. + for i in index_tuple: + sub_labels = sub_labels[i] + sub_logits = sub_logits[i] + sub_out_logits = sub_out_logits[i] + if len(ops.shape(sub_candidate_ids)) > 1: + sub_candidate_ids = sub_candidate_ids[i] + + row_positive_idx = ops.argmax(sub_labels) + positive_candidate_id = sub_candidate_ids[row_positive_idx] + + for col_idx in range(sub_out_logits.shape[0]): + same_candidate_as_positive = ops.equal( + positive_candidate_id, sub_candidate_ids[col_idx] + ) + is_positive = ops.equal(col_idx, row_positive_idx) + + if ops.convert_to_numpy( + same_candidate_as_positive + ) and not ops.convert_to_numpy(is_positive): + # We zeroed the logits. 
+ self.assertAllClose( + sub_out_logits[col_idx], + ops.add( + sub_logits[col_idx], + remove_accidental_hits.SMALLEST_FLOAT, + ), + ) + else: + # We left the logits unchanged. + self.assertAllClose( + sub_out_logits[col_idx], + sub_logits[col_idx], + ) + + def test_mismatched_labels_logits_shapes(self): + layer = remove_accidental_hits.RemoveAccidentalHits() + + with self.assertRaisesRegex( + ValueError, "shape of `labels` and `logits` must match" + ): + layer(ops.zeros((10, 20)), ops.zeros((10, 30)), ops.zeros((20,))) + + def test_mismatched_labels_candidates_shapes(self): + layer = remove_accidental_hits.RemoveAccidentalHits() + + with self.assertRaisesRegex( + ValueError, + "shape of `candidate_ids` must match .* `labels`", + ): + layer(ops.zeros((10, 20)), ops.zeros((10, 20)), ops.zeros((30,))) + + def test_predict(self): + # Note: for predict, we test with probabilities that have a batch dim. + labels, logits, candidate_ids = self.create_inputs(candidate_ids_rank=2) + + layer = remove_accidental_hits.RemoveAccidentalHits() + in_labels = keras.layers.Input(labels.shape[1:]) + in_logits = keras.layers.Input(logits.shape[1:]) + in_candidate_ids = keras.layers.Input(labels.shape[1:]) + out_logits = layer(in_labels, in_logits, in_candidate_ids) + model = keras.Model( + [in_labels, in_logits, in_candidate_ids], out_logits + ) + + model.predict([labels, logits, candidate_ids], batch_size=8) + + def test_serialization(self): + layer = remove_accidental_hits.RemoveAccidentalHits() + restored = deserialize(serialize(layer)) + self.assertDictEqual(layer.get_config(), restored.get_config()) + + def test_model_saving(self): + labels, logits, candidate_ids = self.create_inputs() + + layer = remove_accidental_hits.RemoveAccidentalHits() + in_labels = keras.layers.Input(labels.shape[1:]) + in_logits = keras.layers.Input(logits.shape[1:]) + in_candidate_ids = keras.layers.Input(batch_shape=candidate_ids.shape) + out_logits = layer(in_labels, in_logits, in_candidate_ids) + 
model = keras.Model( + [in_labels, in_logits, in_candidate_ids], out_logits + ) + + self.run_model_saving_test( + model=model, input_data=[labels, logits, candidate_ids] + )