Skip to content

Commit

Permalink
Add option to densely mask weighted_mean_absolute_error loss.
Browse files Browse the repository at this point in the history
The same processing was previously applied to `weighted_mean_squared_error`.

PiperOrigin-RevId: 684767761
  • Loading branch information
kmaninis authored and Scenic Authors committed Oct 11, 2024
1 parent 8f58121 commit bfab669
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions scenic/model_lib/base_models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@

def psum_metric_normalizer(
metrics: Tuple[jnp.ndarray, jnp.ndarray],
axis_name: Union[str,
Tuple[str]] = 'batch') -> Tuple[jnp.ndarray, jnp.ndarray]:
axis_name: Union[str, Tuple[str, ...]] = 'batch'
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Applies psum over the given tuple of (metric, normalizer)."""
psumed_metric = jax.lax.psum(jnp.sum(metrics[0]), axis_name=axis_name)
psumed_normalizer = jax.lax.psum(jnp.sum(metrics[1]), axis_name=axis_name)
Expand Down Expand Up @@ -603,15 +603,15 @@ def weighted_squared_error(
Args:
predictions: Output of model in shape shape [batch, ..., n_features].
targets: Array of shape [batch, ..., n_features].
weights: None or array of shape [batch,] This is the weight to apply to the
loss computed for each example in the batch. Can be used to ignore padded
examples in the batch.
weights: None or array of shape [batch, ...]. This is the weight to apply
to the loss computed for each example in the batch. Can be used to ignore
padded examples in the batch.
axis: The axis (or axes) to compute the loss over. If not specified, all
dimensions besides the leading batch dimension are used.
Returns:
The mean squared error for each example in the given batch. The output shape
is [batch,].
depends on axis.
"""
if predictions.ndim != targets.ndim:
raise ValueError(
Expand Down Expand Up @@ -671,7 +671,8 @@ def weighted_mean_squared_error(
def weighted_absolute_error(
predictions: jnp.ndarray,
targets: jnp.ndarray,
weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
weights: Optional[jnp.ndarray] = None,
axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray:
"""Computes weighted absolute error given predictions and targets.
This computes the absolute_error of examples in a single, potentially
Expand All @@ -682,23 +683,28 @@ def weighted_absolute_error(
Args:
predictions: Output of model in shape shape [batch, ..., n_features].
targets: Array of shape [batch, ..., n_features].
weights: None or array of shape [batch,] This is the weight to apply to the
loss computed for each example in the batch. Can be used to ignore padded
examples in the batch.
weights: None or array of shape [batch, ...] This is the weight to apply to
the loss computed for each example in the batch. Can be used to ignore
padded examples in the batch.
axis: The axis (or axes) to compute the loss over. If not specified, all
dimensions besides the leading batch dimension are used.
Returns:
The mean absolute error for each example in the given batch. The output
shape is [batch,].
shape depends on axis.
"""
if predictions.ndim != targets.ndim:
raise ValueError(
'Incorrect shapes. Got shape %s predictions and %s targets' %
(str(predictions.shape), str(targets.shape)))
if axis is None:
# Sum over all features in each example in the batch:
axis = tuple(range(1, predictions.ndim))

error = targets - predictions
loss = jnp.absolute(error)
# Sum over all features in each example in the batch:
loss = jnp.sum(loss, axis=tuple(range(1, predictions.ndim)))
loss = jnp.sum(loss, axis=axis)
if weights is not None:
loss = apply_weights(loss, weights)
return loss
Expand All @@ -707,22 +713,25 @@ def weighted_absolute_error(
def weighted_mean_absolute_error(
predictions: jnp.ndarray,
targets: jnp.ndarray,
weights: Optional[jnp.ndarray] = None) -> jnp.ndarray:
weights: Optional[jnp.ndarray] = None,
axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray:
"""Weighted mean of weighted_unnormalized_mean_absolute_error.
Args:
predictions: Output of model in shape [batch, ..., num_features].
targets: Targets of shape [batch, ..., num_features].
weights: None or array of shape [batch] This is the weight to apply to the
loss computed for each example in the batch. Can be used to ignore padded
examples in the batch.
weights: None or array of shape [batch, ...]. This is the weight to apply
to the loss computed for each example in the batch. Can be used to ignore
padded examples in the batch.
axis: The axis (or axes) to compute the loss over. If not specified, all
dimensions besides the leading batch dimension are used.
Returns:
The averaged mean absolute error of all the examples in the given batch as
a scalar.
"""
unnormalized_mae = weighted_absolute_error(
predictions=predictions, targets=targets, weights=weights)
predictions=predictions, targets=targets, weights=weights, axis=axis)

if weights is not None:
# Divide by sum of weights:
Expand Down

0 comments on commit bfab669

Please sign in to comment.