From bfab669beb04841f9621b8ff9ee7b2125c6135cb Mon Sep 17 00:00:00 2001 From: Kevis-Kokitsi Maninis Date: Fri, 11 Oct 2024 02:59:39 -0700 Subject: [PATCH] Add option to densely mask `weighted_mean_absolute_error` loss. The same processing was previously applied to `weighted_mean_squared_error`. PiperOrigin-RevId: 684767761 --- scenic/model_lib/base_models/model_utils.py | 43 +++++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/scenic/model_lib/base_models/model_utils.py b/scenic/model_lib/base_models/model_utils.py index 6c290ac4..3c7faa71 100644 --- a/scenic/model_lib/base_models/model_utils.py +++ b/scenic/model_lib/base_models/model_utils.py @@ -30,8 +30,8 @@ def psum_metric_normalizer( metrics: Tuple[jnp.ndarray, jnp.ndarray], - axis_name: Union[str, - Tuple[str]] = 'batch') -> Tuple[jnp.ndarray, jnp.ndarray]: + axis_name: Union[str, Tuple[str, ...]] = 'batch' +) -> Tuple[jnp.ndarray, jnp.ndarray]: """Applies psum over the given tuple of (metric, normalizer).""" psumed_metric = jax.lax.psum(jnp.sum(metrics[0]), axis_name=axis_name) psumed_normalizer = jax.lax.psum(jnp.sum(metrics[1]), axis_name=axis_name) @@ -603,15 +603,15 @@ def weighted_squared_error( Args: predictions: Output of model in shape shape [batch, ..., n_features]. targets: Array of shape [batch, ..., n_features]. - weights: None or array of shape [batch,] This is the weight to apply to the - loss computed for each example in the batch. Can be used to ignore padded - examples in the batch. + weights: None or array of shape [batch, ...]. This is the weight to apply + to the loss computed for each example in the batch. Can be used to ignore + padded examples in the batch. axis: The axis (or axes) to compute the loss over. If not specified, all dimensions besides the leading batch dimension are used. Returns: The mean squared error for each example in the given batch. The output shape - is [batch,]. + depends on axis. 
""" if predictions.ndim != targets.ndim: raise ValueError( @@ -671,7 +671,8 @@ def weighted_mean_squared_error( def weighted_absolute_error( predictions: jnp.ndarray, targets: jnp.ndarray, - weights: Optional[jnp.ndarray] = None) -> jnp.ndarray: + weights: Optional[jnp.ndarray] = None, + axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray: """Computes weighted absolute error given predictions and targets. This computes the absolute_error of examples in a single, potentially @@ -682,23 +683,28 @@ def weighted_absolute_error( Args: predictions: Output of model in shape shape [batch, ..., n_features]. targets: Array of shape [batch, ..., n_features]. - weights: None or array of shape [batch,] This is the weight to apply to the - loss computed for each example in the batch. Can be used to ignore padded - examples in the batch. + weights: None or array of shape [batch, ...] This is the weight to apply to + the loss computed for each example in the batch. Can be used to ignore + padded examples in the batch. + axis: The axis (or axes) to compute the loss over. If not specified, all + dimensions besides the leading batch dimension are used. Returns: The mean absolute error for each example in the given batch. The output - shape is [batch,]. + shape depends on axis. """ if predictions.ndim != targets.ndim: raise ValueError( 'Incorrect shapes. 
Got shape %s predictions and %s targets' % (str(predictions.shape), str(targets.shape))) + if axis is None: + # Sum over all features in each example in the batch: + axis = tuple(range(1, predictions.ndim)) error = targets - predictions loss = jnp.absolute(error) # Sum over all features in each example in the batch: - loss = jnp.sum(loss, axis=tuple(range(1, predictions.ndim))) + loss = jnp.sum(loss, axis=axis) if weights is not None: loss = apply_weights(loss, weights) return loss @@ -707,22 +713,25 @@ def weighted_absolute_error( def weighted_mean_absolute_error( predictions: jnp.ndarray, targets: jnp.ndarray, - weights: Optional[jnp.ndarray] = None) -> jnp.ndarray: + weights: Optional[jnp.ndarray] = None, + axis: Optional[Union[int, Tuple[int, ...]]] = None) -> jnp.ndarray: """Weighted mean of weighted_unnormalized_mean_absolute_error. Args: predictions: Output of model in shape [batch, ..., num_features]. targets: Targets of shape [batch, ..., num_features]. - weights: None or array of shape [batch] This is the weight to apply to the - loss computed for each example in the batch. Can be used to ignore padded - examples in the batch. + weights: None or array of shape [batch, ...]. This is the weight to apply + to the loss computed for each example in the batch. Can be used to ignore + padded examples in the batch. + axis: The axis (or axes) to compute the loss over. If not specified, all + dimensions besides the leading batch dimension are used. Returns: The averaged mean absolute error of all the examples in the given batch as a scalar. """ unnormalized_mae = weighted_absolute_error( - predictions=predictions, targets=targets, weights=weights) + predictions=predictions, targets=targets, weights=weights, axis=axis) if weights is not None: # Divide by sum of weights: