diff --git a/README.md b/README.md
index fa2c7a7..ea11ede 100644
--- a/README.md
+++ b/README.md
@@ -136,7 +136,7 @@
 There are a few reasons for having losses in a metrics library:
 
 1. Most code from this library was originally written for and will still be consumed by Elegy. Since Elegy needs support for calculating cumulative losses, as you will see later, a Metric abstraction called `Losses` was created for this.
 2. A couple of API design decisions are shared between the `Loss` and `Metric` APIs. This includes:
     * `__call__` and `update` both accept any number of keyword-only arguments. This is used to facilitate composition (see [Combinators](#combinators) section).
-    * Both classes have the `index_into` and `map_arg` methods that allow them to modify how arguments are consumed.
+    * Both classes have the `index_into` and `rename_arguments` methods that allow them to modify how arguments are consumed.
     * Argument names are standardized to be consistent whenever possible, e.g. both `metrics.Accuracy` and `losses.Crossentropy` use the `target` and `preds` arguments.
diff --git a/docs/index.md b/docs/index.md
index d802091..2050c32 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -131,7 +131,7 @@
 There are a few reasons for having losses in a metrics library:
 
 1. Most code from this library was originally written for and will still be consumed by Elegy. Since Elegy needs support for calculating cumulative losses, as you will see later, a Metric abstraction called `Losses` was created for this.
 2. A couple of API design decisions are shared between the `Loss` and `Metric` APIs. This includes:
     * `__call__` and `update` both accept any number of keyword-only arguments. This is used to facilitate composition (see [Combinators](#combinators) section).
-    * Both classes have the `index_into` and `map_arg` methods that allow them to modify how arguments are consumed.
+    * Both classes have the `index_into` and `rename_arguments` methods that allow them to modify how arguments are consumed.
     * Argument names are standardized to be consistent whenever possible, e.g. both `metrics.Accuracy` and `losses.Crossentropy` use the `target` and `preds` arguments.
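For readers skimming the patch, here is a minimal sketch of how the renamed method reads in user code, adapted from the docstring examples later in this patch (the `jm` import alias and the toy shapes are assumptions):

```python
import jax.numpy as jnp
import jax_metrics as jm

# Expose the standardized `target`/`preds` arguments under the names
# an existing training loop already uses.
crossentropy_loss = jm.losses.Crossentropy().rename_arguments(
    target="y_true", preds="y_pred"
)

y = jnp.array([0, 1, 2])     # integer class labels, shape (batch,)
logits = jnp.zeros((3, 10))  # per-class logits, shape (batch, n_classes)
loss = crossentropy_loss(y_true=y, y_pred=logits)
```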
diff --git a/jax_metrics/losses/crossentropy.py b/jax_metrics/losses/crossentropy.py
index 2f634fd..6ad378a 100644
--- a/jax_metrics/losses/crossentropy.py
+++ b/jax_metrics/losses/crossentropy.py
@@ -4,19 +4,10 @@
 import jax.numpy as jnp
 import optax
 
-from jax_metrics import types, utils
+from jax_metrics import types
 from jax_metrics.losses.loss import Loss, Reduction
 
 
-def smooth_labels(
-    target: jax.Array,
-    smoothing: jax.Array,
-) -> jax.Array:
-    smooth_positives = 1.0 - smoothing
-    smooth_negatives = smoothing / target.shape[-1]
-    return smooth_positives * target + smooth_negatives
-
-
 def crossentropy(
     target: jax.Array,
     preds: jax.Array,
@@ -27,43 +18,43 @@
     check_bounds: bool = True,
 ) -> jax.Array:
     n_classes = preds.shape[-1]
+    integer_labels = False
 
     if target.ndim == preds.ndim - 1:
         if target.shape != preds.shape[:-1]:
             raise ValueError(
                 f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
             )
-        target = jax.nn.one_hot(target, n_classes)
-    else:
-        if target.ndim != preds.ndim:
-            raise ValueError(
-                f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
-            )
+        if label_smoothing is not None or not from_logits:
+            target = jax.nn.one_hot(target, n_classes)
+        else:
+            integer_labels = True
+    elif target.ndim != preds.ndim:
+        raise ValueError(
+            f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
+        )
 
     if label_smoothing is not None:
         target = optax.smooth_labels(target, label_smoothing)
 
+    loss: jax.Array
     if from_logits:
         if binary:
             loss = optax.sigmoid_binary_cross_entropy(preds, target).mean(axis=-1)
+        elif integer_labels:
+            loss = optax.softmax_cross_entropy_with_integer_labels(preds, target)
         else:
             loss = optax.softmax_cross_entropy(preds, target)
     else:
         preds = jnp.clip(preds, types.EPSILON, 1.0 - types.EPSILON)
 
         if binary:
-            loss = target * jnp.log(preds)  # + types.EPSILON)
-            loss += (1 - target) * jnp.log(1 - preds)  # + types.EPSILON)
-            loss = -loss.mean(axis=-1)
+            loss = -jnp.mean(
+                target * jnp.log(preds) + (1 - target) * jnp.log(1 - preds), axis=-1
+            )
         else:
             loss = -(target * jnp.log(preds)).sum(axis=-1)
 
-    # TODO: implement check_bounds
-    # if check_bounds:
-    #     # set NaN where target is negative or larger/equal to the number of preds channels
-    #     loss = jnp.where(target < 0, jnp.nan, loss)
-    #     loss = jnp.where(target >= n_classes, jnp.nan, loss)
-
     return loss
diff --git a/jax_metrics/losses/loss.py b/jax_metrics/losses/loss.py
index 54ad06e..de5a3a3 100644
--- a/jax_metrics/losses/loss.py
+++ b/jax_metrics/losses/loss.py
@@ -166,14 +166,14 @@ def index_into(self, **kwargs: types.IndexLike) -> "IndexedLoss":
         """
         return IndexedLoss(self, kwargs)
 
-    def map_arg(self, **kwargs: str) -> "MapArgsLoss":
+    def rename_arguments(self, **kwargs: str) -> "MapArgsLoss":
         """
         Returns a loss that renames the keyword arguments expected by `__call__`.
 
         Example:
 
         ```python
-        crossentropy_loss = jm.losses.Crossentropy().map_arg(target="y_true", preds="y_pred")
+        crossentropy_loss = jm.losses.Crossentropy().rename_arguments(target="y_true", preds="y_pred")
         ...
         loss = crossentropy_loss(y_true=y, y_pred=logits)
         ```
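The main behavioral change in `crossentropy` above: when targets are integer class indices, `from_logits=True`, and no label smoothing is requested, the one-hot encoding is skipped in favor of `optax.softmax_cross_entropy_with_integer_labels`. A small sketch of the equivalence this relies on (toy values; assumes an optax version that provides the integer-labels variant):

```python
import jax
import jax.numpy as jnp
import optax

labels = jnp.array([2, 0])                 # integer targets, shape (batch,)
logits = jnp.array([[0.1, 0.2, 0.7],
                    [0.9, 0.05, 0.05]])    # shape (batch, n_classes)

# Fast path taken by the new code: no intermediate one-hot array.
fast = optax.softmax_cross_entropy_with_integer_labels(logits, labels)

# Equivalent dense path: one-hot encode, then standard cross entropy.
dense = optax.softmax_cross_entropy(logits, jax.nn.one_hot(labels, 3))

assert jnp.allclose(fast, dense)
```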
diff --git a/jax_metrics/metrics/metric.py b/jax_metrics/metrics/metric.py
index 56c0a63..71baaac 100644
--- a/jax_metrics/metrics/metric.py
+++ b/jax_metrics/metrics/metric.py
@@ -1,3 +1,4 @@
+import dataclasses
 import typing as tp
 from abc import abstractmethod
 
@@ -8,7 +9,7 @@
 from jax_metrics import types
 
 M = tp.TypeVar("M", bound="Metric")
-MA = tp.TypeVar("MA", bound="MapArgsMetric")
+MA = tp.TypeVar("MA", bound="RenameArguments")
 
 Slice = tp.Tuple[tp.Union[int, str], ...]
 
@@ -145,14 +146,14 @@
         """
         return IndexedMetric(self, kwargs)
 
-    def map_arg(self, **kwargs: str) -> "MapArgsMetric":
+    def rename_arguments(self, **kwargs: str) -> "RenameArguments":
         """
         Returns a metric that renames the keyword arguments expected by `.update()`.
 
         Example:
 
         ```python
-        mean = jm.metrics.Mean().map_arg(values="loss").reset()
+        mean = jm.metrics.Mean().rename_arguments(values="loss")
         ...
         loss = loss_fn(x, y)
         mean = mean.update(loss=loss)
@@ -162,9 +163,9 @@ def map_arg(self, **kwargs: str) -> "MapArgsMetric":
             **kwargs: keyword arguments to be renamed
 
         Returns:
-            A MapArgsMetric instance
+            A RenameArguments instance
         """
-        return MapArgsMetric(self, kwargs)
+        return RenameArguments(self, kwargs)
 
 
 class SumMetric(Metric):
@@ -215,34 +216,32 @@ def compute(self) -> tp.Any:
         return self.metric.compute()
 
 
-class MapArgsMetric(Metric):
-    metric: Metric = field()
-    args_map: tp.Dict[str, str] = static_field()
+Real = str
+Expected = str
 
-    def __init__(self, metric: Metric, args_map: tp.Dict[str, str]):
-        self.metric = metric
-        self.args_map = args_map
+
+@dataclasses.dataclass
+class RenameArguments(Metric):
+    metric: Metric = field()
+    real_to_expected: tp.Dict[Real, Expected] = static_field()
 
     def reset(self: MA) -> MA:
         return self.replace(metric=self.metric.reset())
 
-    def update(self: MA, **kwargs: tp.Any) -> MA:
-        for arg in self.args_map:
-            if arg not in kwargs:
-                raise KeyError(f"'{arg}' expected but not given")
-
-        kwarg_updates = {
-            next_arg: kwargs[prev_arg] for prev_arg, next_arg in self.args_map.items()
-        }
-
-        # delete previous kwargs
-        for arg in self.args_map:
-            del kwargs[arg]
+    def update(self, **updates: tp.Any) -> "RenameArguments":
+        for expected in self.real_to_expected.values():
+            if expected not in updates:
+                raise KeyError(f"'{expected}' expected but not given")
 
         # add new kwargs
-        kwargs.update(kwarg_updates)
-
-        return self.replace(metric=self.metric.update(**kwargs))
+        updates.update(
+            {
+                real: updates[expected]
+                for real, expected in self.real_to_expected.items()
+            }
+        )
+
+        return self.replace(metric=self.metric.update(**updates))
 
     def compute(self) -> tp.Any:
         return self.metric.compute()
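The renamed wrapper is driven the same way as before. A sketch based on the updated docstring (the `jm` alias and the constant stand-in for a real training-step loss are assumptions, and the metric is assumed to initialize on construction, as the removal of `.reset()` from the docstring suggests):

```python
import jax.numpy as jnp
import jax_metrics as jm

# `Mean.update` normally expects `values=`; expose it as `loss=` instead.
mean = jm.metrics.Mean().rename_arguments(values="loss")

loss = jnp.array(0.3)  # stand-in for a real training-step loss
mean = mean.update(loss=loss)
print(mean.compute())  # 0.3
```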
""" - if self.total is None: - raise self._not_initialized_error() - # perform update if sample_weight is not None: if sample_weight.ndim > values.ndim: @@ -138,10 +135,10 @@ def update( return self.replace(total=total, count=count) def compute(self) -> jax.Array: - if self.total is None: - raise self._not_initialized_error() - if self.reduction == Reduction.sum: return self.total else: return self.total / self.count + + def from_argument(self, argument: str) -> RenameArguments: + return self.rename_arguments(values=argument)