This layer zeroes the logits of accidental negatives.
Showing 3 changed files with 245 additions and 0 deletions.
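As a rough usage sketch (not part of the commit; it assumes the `keras_rs.layers.RemoveAccidentalHits` export declared in the diff below and a plain categorical cross-entropy loss, and the input tensors here are made up), the layer slots in between computing retrieval logits and computing the loss:

import keras
from keras import ops
import keras_rs

batch_size, num_candidates = 8, 16

# Made-up inputs: one-hot labels, raw similarity logits, and the IDs of the
# sampled candidates (which may accidentally repeat a positive's ID).
labels = ops.one_hot(
    keras.random.randint((batch_size,), minval=0, maxval=num_candidates),
    num_candidates,
)
logits = keras.random.normal((batch_size, num_candidates))
candidate_ids = keras.random.randint((num_candidates,), minval=0, maxval=100)

logits = keras_rs.layers.RemoveAccidentalHits()(labels, logits, candidate_ids)
loss = keras.losses.CategoricalCrossentropy(from_logits=True)(labels, logits)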
keras_rs/src/layers/retrieval/remove_accidental_hits.py (83 additions, 0 deletions)
import keras
import numpy as np
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export
from keras_rs.src.utils import keras_utils

# The most negative float32, scaled down so that adding it to ordinary logits
# cannot overflow. A candidate whose logit is shifted by this amount gets
# essentially zero probability under a softmax.
SMALLEST_FLOAT = np.finfo(np.float32).min / 100.0


@keras_rs_export("keras_rs.layers.RemoveAccidentalHits")
class RemoveAccidentalHits(keras.layers.Layer):
    """Zeroes the logits of accidental negatives.

    Zeroes the logits of negative candidates that have the same ID as the
    positive candidate in that row.
    """

    def call(
        self,
        labels: types.Tensor,
        logits: types.Tensor,
        candidate_ids: types.Tensor,
    ) -> types.Tensor:
        """Zeroes selected logits.

        For each row in the batch, zeroes the logits of negative candidates
        that have the same ID as the positive candidate in that row.

        Args:
            labels: one-hot labels tensor, typically
                `[batch_size, num_candidates]` but can have more dimensions or
                be 1D as `[num_candidates]`.
            logits: logits tensor. Must have the same shape as `labels`.
            candidate_ids: candidate identifiers tensor, can be
                `[num_candidates]` or `[batch_size, num_candidates]` or have
                more dimensions as long as they match the last dimensions of
                `labels`.

        Returns:
            logits: Modified logits.
        """
        # A more principled way would be to implement
        # `softmax_cross_entropy_with_logits` with an input mask. Here we
        # approximate this by giving accidental hits extremely negative logits
        # (SMALLEST_FLOAT) for ease of implementation.

        labels_shape = ops.shape(labels)
        labels_rank = len(labels_shape)
        logits_shape = ops.shape(logits)
        candidate_ids_shape = ops.shape(candidate_ids)
        candidate_ids_rank = len(candidate_ids_shape)

        if not keras_utils.check_shapes_compatible(labels_shape, logits_shape):
            raise ValueError(
                "The shape of `labels` and `logits` must match. Got "
                f"labels.shape = {labels_shape} and "
                f"logits.shape = {logits_shape}."
            )

        if not keras_utils.check_shapes_compatible(
            labels_shape[-candidate_ids_rank:], candidate_ids_shape
        ):
            raise ValueError(
                "The shape of `candidate_ids` must match the last dimensions "
                f"of `labels`. Got labels.shape = {labels_shape} and "
                f"candidate_ids.shape = {candidate_ids_shape}."
            )

        # Add dimensions to `candidate_ids` to have the same rank as `labels`.
        if candidate_ids_rank < labels_rank:
            candidate_ids = ops.expand_dims(
                candidate_ids, list(range(labels_rank - candidate_ids_rank))
            )
        # Gather, for each row, the candidate ID of the positive candidate.
        positive_indices = ops.expand_dims(ops.argmax(labels, axis=-1), -1)
        positive_candidate_ids = ops.take_along_axis(
            candidate_ids, positive_indices, axis=-1
        )

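        # Mark with 1 the negatives whose candidate ID equals the positive's
        # ID (the accidental hits); subtracting `labels` removes the positive
        # itself from the mask.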
        duplicate = ops.cast(
            ops.equal(positive_candidate_ids, candidate_ids), labels.dtype
        )
        duplicate = ops.subtract(duplicate, labels)

        return ops.add(logits, ops.multiply(duplicate, SMALLEST_FLOAT))
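To make the behavior concrete, here is a small hand-worked sketch (not part of the commit) that calls the layer directly. Candidates 0 and 2 share ID 7 and candidate 0 is the positive, so only column 2 is pushed far negative:

from keras import ops
from keras_rs.src.layers.retrieval import remove_accidental_hits

labels = ops.convert_to_tensor([[1.0, 0.0, 0.0, 0.0]])
logits = ops.convert_to_tensor([[2.0, 1.0, 0.5, -1.0]])
candidate_ids = ops.convert_to_tensor([7, 3, 7, 9])

out = remove_accidental_hits.RemoveAccidentalHits()(labels, logits, candidate_ids)
# Column 2 (the accidental hit) receives SMALLEST_FLOAT and so contributes
# roughly zero probability after a softmax; columns 0, 1 and 3 are unchanged.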
keras_rs/src/layers/retrieval/remove_accidental_hits_test.py (159 additions, 0 deletions)
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.retrieval import remove_accidental_hits


class RemoveAccidentalHitsTest(testing.TestCase, parameterized.TestCase):
    def create_inputs(self, logits_rank=2, candidate_ids_rank=1):
        num_candidates = 10
        shape_3d = (15, 20, num_candidates)
        shape = shape_3d[-logits_rank:]
        candidate_ids_shape = shape_3d[-candidate_ids_rank:]
        rng = keras.random.SeedGenerator(42)

        labels = keras.ops.one_hot(
            keras.random.randint(
                shape[:-1], minval=0, maxval=num_candidates, seed=rng
            ),
            num_candidates,
        )
        logits = keras.random.uniform(shape, seed=rng)
        candidate_ids = keras.random.randint(
            candidate_ids_shape, minval=0, maxval=num_candidates, seed=rng
        )
        return labels, logits, candidate_ids

    @parameterized.named_parameters(
        [
            {"testcase_name": "1_1", "logits_rank": 1, "candidate_ids_rank": 1},
            {"testcase_name": "2_1", "logits_rank": 2, "candidate_ids_rank": 1},
            {"testcase_name": "2_2", "logits_rank": 2, "candidate_ids_rank": 2},
            {"testcase_name": "3_1", "logits_rank": 3, "candidate_ids_rank": 1},
            {"testcase_name": "3_2", "logits_rank": 3, "candidate_ids_rank": 2},
            {"testcase_name": "3_3", "logits_rank": 3, "candidate_ids_rank": 3},
        ]
    )
    def test_call(self, logits_rank, candidate_ids_rank):
        labels, logits, candidate_ids = self.create_inputs(
            logits_rank=logits_rank, candidate_ids_rank=candidate_ids_rank
        )

        out_logits = remove_accidental_hits.RemoveAccidentalHits()(
            labels, logits, candidate_ids
        )

        # Logits of labels are unchanged.
        self.assertAllClose(
            ops.sum(ops.multiply(out_logits, labels), axis=-1),
            ops.sum(ops.multiply(logits, labels), axis=-1),
        )

        # Instead of writing nested loops, which we cannot do because the
        # nesting depth depends on the rank, we unroll the index combinations.
        shape = ops.shape(logits)
        if logits_rank == 1:
            indices = [
                (),
            ]
        elif logits_rank == 2:
            indices = [(i,) for i in range(shape[0])]
        elif logits_rank == 3:
            indices = [(i, j) for i in range(shape[0]) for j in range(shape[1])]

        for index_tuple in indices:
            sub_labels = labels
            sub_logits = logits
            sub_out_logits = out_logits
            sub_candidate_ids = candidate_ids
            # Apply the indices one at a time to go several dimensions deep.
            for i in index_tuple:
                sub_labels = sub_labels[i]
                sub_logits = sub_logits[i]
                sub_out_logits = sub_out_logits[i]
                if len(ops.shape(sub_candidate_ids)) > 1:
                    sub_candidate_ids = sub_candidate_ids[i]

            row_positive_idx = ops.argmax(sub_labels)
            positive_candidate_id = sub_candidate_ids[row_positive_idx]

            for col_idx in range(sub_out_logits.shape[0]):
                same_candidate_as_positive = ops.equal(
                    positive_candidate_id, sub_candidate_ids[col_idx]
                )
                is_positive = ops.equal(col_idx, row_positive_idx)

                if ops.convert_to_numpy(
                    same_candidate_as_positive
                ) and not ops.convert_to_numpy(is_positive):
                    # We zeroed the logits.
                    self.assertAllClose(
                        sub_out_logits[col_idx],
                        ops.add(
                            sub_logits[col_idx],
                            remove_accidental_hits.SMALLEST_FLOAT,
                        ),
                    )
                else:
                    # We left the logits unchanged.
                    self.assertAllClose(
                        sub_out_logits[col_idx],
                        sub_logits[col_idx],
                    )

    def test_mismatched_labels_logits_shapes(self):
        layer = remove_accidental_hits.RemoveAccidentalHits()

        with self.assertRaisesRegex(
            ValueError, "shape of `labels` and `logits` must match"
        ):
            layer(ops.zeros((10, 20)), ops.zeros((10, 30)), ops.zeros((20,)))

    def test_mismatched_labels_candidates_shapes(self):
        layer = remove_accidental_hits.RemoveAccidentalHits()

        with self.assertRaisesRegex(
            ValueError,
            "shape of `candidate_ids` must match .* `labels`",
        ):
            layer(ops.zeros((10, 20)), ops.zeros((10, 20)), ops.zeros((30,)))

    def test_predict(self):
        # Note: for predict, we test with candidate IDs that have a batch dim.
        labels, logits, candidate_ids = self.create_inputs(candidate_ids_rank=2)

        layer = remove_accidental_hits.RemoveAccidentalHits()
        in_labels = keras.layers.Input(labels.shape[1:])
        in_logits = keras.layers.Input(logits.shape[1:])
        in_candidate_ids = keras.layers.Input(candidate_ids.shape[1:])
        out_logits = layer(in_labels, in_logits, in_candidate_ids)
        model = keras.Model(
            [in_labels, in_logits, in_candidate_ids], out_logits
        )

        model.predict([labels, logits, candidate_ids], batch_size=8)

    def test_serialization(self):
        layer = remove_accidental_hits.RemoveAccidentalHits()
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

    def test_model_saving(self):
        labels, logits, candidate_ids = self.create_inputs()

        layer = remove_accidental_hits.RemoveAccidentalHits()
        in_labels = keras.layers.Input(labels.shape[1:])
        in_logits = keras.layers.Input(logits.shape[1:])
        in_candidate_ids = keras.layers.Input(batch_shape=candidate_ids.shape)
        out_logits = layer(in_labels, in_logits, in_candidate_ids)
        model = keras.Model(
            [in_labels, in_logits, in_candidate_ids], out_logits
        )

        self.run_model_saving_test(
            model=model, input_data=[labels, logits, candidate_ids]
        )
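For reference, the implementation comment in the layer above mentions that a more principled alternative would be a masked `softmax_cross_entropy_with_logits`. A hypothetical sketch of that idea (the function name and mask argument are illustrative only, not part of this commit):

from keras import ops

def masked_softmax_cross_entropy(labels, logits, keep_mask):
    # keep_mask is 1.0 for candidates that should participate in the softmax
    # and 0.0 for accidental hits that should be excluded.
    masked_logits = ops.where(
        ops.cast(keep_mask, "bool"), logits, ops.full_like(logits, -1e9)
    )
    # Stable log-softmax over the last axis, then the usual cross-entropy.
    log_probs = masked_logits - ops.logsumexp(
        masked_logits, axis=-1, keepdims=True
    )
    return -ops.sum(ops.multiply(labels, log_probs), axis=-1)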