Add Hard Negative Mining Layer. #24

Merged 1 commit on Feb 27, 2025

3 changes: 3 additions & 0 deletions keras_rs/api/layers/__init__.py
@@ -9,3 +9,6 @@
 from keras_rs.src.layers.retrieval.brute_force_retrieval import (
     BruteForceRetrieval,
 )
+from keras_rs.src.layers.retrieval.hard_negative_mining import (
+    HardNegativeMining,
+)
1 change: 1 addition & 0 deletions keras_rs/src/layers/modeling/dot_interaction.py
@@ -28,6 +28,7 @@ class DotInteraction(keras.layers.Layer):
             entries will be zeros. Otherwise, the output will be only the lower
             triangular part of the interaction matrix. The latter saves space
             but is much slower.
+        **kwargs: Args to pass to the base class.
 
     References:
         - [M. Naumov et al.](https://arxiv.org/abs/1906.00091)
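For orientation, here is a minimal usage sketch of `DotInteraction` (illustrative only, not part of this PR; the `self_interaction` and `skip_gather` arguments and the list-of-features call convention are inferred from the tests below):

import keras

from keras_rs.layers import DotInteraction

# Two feature embeddings with the same dimension, for a batch of 8.
feature1 = keras.random.uniform((8, 16))
feature2 = keras.random.uniform((8, 16))

# Computes pairwise dot products between the input feature vectors.
layer = DotInteraction(self_interaction=False, skip_gather=False)
output = layer([feature1, feature2])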
12 changes: 6 additions & 6 deletions keras_rs/src/layers/modeling/dot_interaction_test.py
@@ -95,11 +95,6 @@ def test_invalid_input_different_shapes(self):
         with self.assertRaises(ValueError):
             layer(unequal_shape_input)
 
-    def test_serialization(self):
-        layer = DotInteraction()
-        restored = deserialize(serialize(layer))
-        self.assertDictEqual(layer.get_config(), restored.get_config())
-
     @parameterized.named_parameters(
         (
             "self_interaction_false_skip_gather_false",
@@ -132,7 +127,12 @@ def test_predict(self, self_interaction, skip_gather):
         x = keras.layers.Dense(units=1)(x)
         model = keras.Model([feature1, feature2, feature3], x)
 
-        model.predict(self.input)
+        model.predict(self.input, batch_size=2)
 
+    def test_serialization(self):
+        layer = DotInteraction()
+        restored = deserialize(serialize(layer))
+        self.assertDictEqual(layer.get_config(), restored.get_config())
+
     def test_model_saving(self):
         feature1 = keras.layers.Input(shape=(5,))
1 change: 1 addition & 0 deletions keras_rs/src/layers/modeling/feature_cross.py
@@ -52,6 +52,7 @@ class FeatureCross(keras.layers.Layer):
             Regularizer to use for the kernel matrix.
         bias_regularizer: string or `keras.regularizer` regularizer.
             Regularizer to use for the bias vector.
+        **kwargs: Args to pass to the base class.
 
     Example:
 
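The docstring's example is truncated in this diff; as a sketch adapted from `test_predict` below, two stacked crosses feeding a dense head look like:

import keras

from keras_rs.layers import FeatureCross

x0 = keras.layers.Input(shape=(3,))
x1 = FeatureCross(projection_dim=None)(x0, x0)  # cross the input with itself
x2 = FeatureCross(projection_dim=None)(x0, x1)  # cross the input with x1
logits = keras.layers.Dense(units=1)(x2)
model = keras.Model(x0, logits)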
33 changes: 20 additions & 13 deletions keras_rs/src/layers/modeling/feature_cross_test.py
@@ -62,11 +62,6 @@ def test_invalid_diag_scale(self):
         with self.assertRaises(ValueError):
             FeatureCross(diag_scale=-1.0)
 
-    def test_serialization(self):
-        sampler = FeatureCross(projection_dim=None, pre_activation="swish")
-        restored = deserialize(serialize(sampler))
-        self.assertDictEqual(sampler.get_config(), restored.get_config())
-
     def test_diag_scale(self):
         layer = FeatureCross(
             projection_dim=None, diag_scale=1.0, kernel_initializer="ones"
@@ -81,16 +76,28 @@ def test_pre_activation(self):
 
         self.assertAllClose(self.x, output)
 
+    def test_predict(self):
+        x0 = keras.layers.Input(shape=(3,))
+        x1 = FeatureCross(projection_dim=None)(x0, x0)
+        x2 = FeatureCross(projection_dim=None)(x0, x1)
+        logits = keras.layers.Dense(units=1)(x2)
+        model = keras.Model(x0, logits)
+
+        model.predict(self.x0, batch_size=2)
+
+    def test_serialization(self):
+        sampler = FeatureCross(projection_dim=None, pre_activation="swish")
+        restored = deserialize(serialize(sampler))
+        self.assertDictEqual(sampler.get_config(), restored.get_config())
+
     def test_model_saving(self):
-        def get_model():
-            x0 = keras.layers.Input(shape=(3,))
-            x1 = FeatureCross(projection_dim=None)(x0, x0)
-            x2 = FeatureCross(projection_dim=None)(x0, x1)
-            logits = keras.layers.Dense(units=1)(x2)
-            model = keras.Model(x0, logits)
-            return model
+        x0 = keras.layers.Input(shape=(3,))
+        x1 = FeatureCross(projection_dim=None)(x0, x0)
+        x2 = FeatureCross(projection_dim=None)(x0, x1)
+        logits = keras.layers.Dense(units=1)(x2)
+        model = keras.Model(x0, logits)
 
         self.run_model_saving_test(
-            model=get_model(),
+            model=model,
             input_data=self.x0,
         )
111 changes: 111 additions & 0 deletions keras_rs/src/layers/retrieval/hard_negative_mining.py
@@ -0,0 +1,111 @@
from typing import Any

import keras
import numpy as np
from keras import ops

from keras_rs.src import types
from keras_rs.src.api_export import keras_rs_export

MAX_FLOAT = np.finfo(np.float32).max / 100.0
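# Note (added for clarity): MAX_FLOAT is large enough to dominate any
# realistic logit when added to the positives below, while staying well
# under the float32 maximum so the addition cannot overflow.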


def _gather_elements_along_row(
    data: types.Tensor, column_indices: types.Tensor
) -> types.Tensor:
    """Gathers elements from a 2D tensor given the column indices of each row.

    First, gets the flat 1D indices to gather from. Then flattens the data to
    1D, uses `ops.take()` to generate the 1D output, and finally reshapes the
    output back to 2D.

    Args:
        data: A [N, M] 2D `Tensor`.
        column_indices: A [N, K] 2D `Tensor` denoting, for each row, the K
            column indices to gather elements from the data `Tensor`.

    Returns:
        A [N, K] `Tensor` including output elements gathered from the data
        `Tensor`.

    Raises:
        ValueError: if the first dimensions of data and column_indices don't
            match.
    """
    num_row, num_column, *_ = ops.shape(data)
    num_gathered = ops.shape(column_indices)[1]
    row_indices = ops.tile(
        ops.expand_dims(ops.arange(num_row), -1), [1, num_gathered]
    )
    flat_data = ops.reshape(data, [-1])
    flat_indices = ops.reshape(
        ops.add(ops.multiply(row_indices, num_column), column_indices), [-1]
    )
    return ops.reshape(
        ops.take(flat_data, flat_indices), [num_row, num_gathered]
    )
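
# Worked example (values are illustrative, not from the source): with
# data = [[1, 2, 3], [4, 5, 6]] and column_indices = [[0, 2], [1, 0]],
# the flat indices are [0, 2, 4, 3], so the result is [[1, 3], [5, 4]].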


@keras_rs_export("keras_rs.layers.HardNegativeMining")
class HardNegativeMining(keras.layers.Layer):
    """Transforms logits and labels to return hard negatives.

    Args:
        num_hard_negatives: How many hard negatives to return.
        **kwargs: Args to pass to the base class.
    """

    def __init__(self, num_hard_negatives: int, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self._num_hard_negatives = num_hard_negatives
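        # The layer creates no weights, so it can be marked as built up front.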
        self.built = True

    def call(
        self, logits: types.Tensor, labels: types.Tensor
    ) -> tuple[types.Tensor, types.Tensor]:
        """Filters logits and labels with per-query hard negative mining.

        The result will include logits and labels for `num_hard_negatives`
        negatives as well as the positive candidate.

        Args:
            logits: `[batch_size, number_of_candidates]` tensor of logits.
            labels: `[batch_size, number_of_candidates]` one-hot tensor of
                labels.

        Returns:
            tuple containing:
            - logits: `[batch_size, num_hard_negatives + 1]` tensor of logits.
            - labels: `[batch_size, num_hard_negatives + 1]` one-hot tensor of
              labels.
        """

        # Number of sampled logits, i.e., the number of hard negatives to be
        # sampled (k) plus the one true logit per query, capped by the number
        # of candidates.
        num_logits = ops.shape(logits)[1]
        if isinstance(num_logits, int):
            num_sampled = min(self._num_hard_negatives + 1, num_logits)
        else:
            num_sampled = ops.minimum(self._num_hard_negatives + 1, num_logits)
        # To gather indices of the top k negative logits per row (query),
        # true logits need to be excluded. First replace the true logits
        # (corresponding to positive labels) with a large score value, then
        # select the top k + 1 logits from each row so that the selected
        # indices include the index of the true logit plus the top k negative
        # logits. This avoids inefficient masking when excluding true logits.

        # For each query, get the indices of the logits which have the highest
        # k + 1 logit values, including the highest k negative logits and one
        # true logit.
        _, col_indices = ops.top_k(
            ops.add(logits, ops.multiply(labels, MAX_FLOAT)),
            k=num_sampled,
            sorted=False,
        )
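
        # Illustration (added for clarity): with logits = [[0.1, 0.9, 0.5,
        # 0.2]], labels = [[0, 0, 1, 0]] and num_hard_negatives = 1, the
        # boosted scores are [0.1, 0.9, MAX_FLOAT + 0.5, 0.2], so col_indices
        # holds column 2 (the positive) and column 1 (the hardest negative).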

        # Gather sampled logits and corresponding labels.
        logits = _gather_elements_along_row(logits, col_indices)
        labels = _gather_elements_along_row(labels, col_indices)

        return logits, labels
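Beyond the unit tests below, here is a hedged sketch of how the layer might slot into a retrieval loss (the loss-wrapper pattern is an assumption, not part of this PR; `keras_rs.layers.HardNegativeMining` is the public export added above):

import keras

from keras_rs.layers import HardNegativeMining

miner = HardNegativeMining(num_hard_negatives=3)

def mined_softmax_loss(labels, logits):
    # Keep only the positive and the three hardest negatives per query,
    # then apply softmax cross-entropy over the reduced candidate set.
    mined_logits, mined_labels = miner(logits, labels)
    return keras.losses.categorical_crossentropy(
        mined_labels, mined_logits, from_logits=True
    )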
97 changes: 97 additions & 0 deletions keras_rs/src/layers/retrieval/hard_negative_mining_test.py
@@ -0,0 +1,97 @@
import keras
from absl.testing import parameterized
from keras import ops
from keras.layers import deserialize
from keras.layers import serialize

from keras_rs.src import testing
from keras_rs.src.layers.retrieval import hard_negative_mining


class HardNegativeMiningTest(testing.TestCase, parameterized.TestCase):
    @parameterized.parameters(42, 123, 8391, 12390, 1230)
    def test_call(self, random_seed):
        num_hard_negatives = 3
        # (num_queries, num_candidates)
        shape = (2, 20)
        rng = keras.random.SeedGenerator(random_seed)

        logits = keras.random.uniform(shape, dtype="float32", seed=rng)
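        # The next statement builds one-hot labels with one positive per
        # query: shuffling the rows of the transposed rectangular identity
        # matrix moves each query's single 1 to a random candidate column.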
        labels = ops.transpose(
            keras.random.shuffle(
                ops.transpose(ops.eye(*shape, dtype="float32")), seed=rng
            )
        )

        out_logits, out_labels = hard_negative_mining.HardNegativeMining(
            num_hard_negatives
        )(logits, labels)

        self.assertEqual(out_logits.shape[-1], num_hard_negatives + 1)

        # Logits for positives are always returned.
        self.assertAllClose(
            ops.sum(out_logits * out_labels, axis=1),
            ops.sum(logits * labels, axis=1),
        )

        # Boost the positive logits so they are the highest; the layer should
        # then return exactly the overall top k + 1 logits.
        logits = logits + labels * 1000.0

        out_logits, _ = hard_negative_mining.HardNegativeMining(
            num_hard_negatives
        )(logits, labels)

        # Highest K logits are always returned.
        self.assertAllClose(
            ops.sort(logits, axis=1)[:, -num_hard_negatives - 1 :],
            ops.sort(out_logits),
        )

    def test_predict(self):
        num_candidates = 20
        in_logits = keras.layers.Input(shape=(num_candidates,))
        in_labels = keras.layers.Input(shape=(num_candidates,))
        out_logits, out_labels = hard_negative_mining.HardNegativeMining(
            num_hard_negatives=3
        )(in_logits, in_labels)
        model = keras.Model([in_logits, in_labels], [out_logits, out_labels])

        shape = (25, num_candidates)
        rng = keras.random.SeedGenerator(42)
        logits = keras.random.uniform(shape, dtype="float32", seed=rng)
        labels = ops.transpose(
            keras.random.shuffle(
                ops.transpose(ops.eye(*shape, dtype="float32")), seed=rng
            )
        )

        model.predict([logits, labels], batch_size=10)

    def test_serialization(self):
        layer = hard_negative_mining.HardNegativeMining(num_hard_negatives=3)
        restored = deserialize(serialize(layer))
        self.assertDictEqual(layer.get_config(), restored.get_config())

    def test_model_saving(self):
        num_candidates = 20
        in_logits = keras.layers.Input(shape=(num_candidates,))
        in_labels = keras.layers.Input(shape=(num_candidates,))
        out_logits, out_labels = hard_negative_mining.HardNegativeMining(
            num_hard_negatives=3
        )(in_logits, in_labels)
        model = keras.Model([in_logits, in_labels], [out_logits, out_labels])

        shape = (2, num_candidates)
        rng = keras.random.SeedGenerator(42)
        logits = keras.random.uniform(shape, dtype="float32", seed=rng)
        labels = ops.transpose(
            keras.random.shuffle(
                ops.transpose(ops.eye(*shape, dtype="float32")), seed=rng
            )
        )

        self.run_model_saving_test(
            model=model,
            input_data=[logits, labels],
        )