refactor rename_arguments + improve Crossentropy
cgarciae committed Mar 4, 2023
1 parent 8cc317e commit 99040b3
Showing 6 changed files with 49 additions and 62 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -136,7 +136,7 @@ There are a few reasons for having losses in a metrics library:
 1. Most code from this library was originally written for and will still be consumed by Elegy. Since Elegy needs support for calculating cumulative losses, as you will see later, a Metric abstraction called `Losses` was created for this.
 2. A couple of API design decisions are shared between the `Loss` and `Metric` APIs. This includes:
     * `__call__` and `update` both accept any number of keyword-only arguments. This is used to facilitate composition (see the [Combinators](#combinators) section).
-    * Both classes have the `index_into` and `map_arg` methods that allow them to modify how arguments are consumed.
+    * Both classes have the `index_into` and `rename_arguments` methods that allow them to modify how arguments are consumed.
     * Argument names are standardized to be consistent whenever possible, e.g. both `metrics.Accuracy` and `losses.Crossentropy` use the `target` and `preds` arguments.
 
 </details>
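To make the renamed method concrete, here is a minimal sketch, assuming the `jm` import alias and the immutable `reset`/`update`/`compute` flow described in the README; the batch keys `label` and `logits` and the argmax-from-logits behavior of `Accuracy` are assumptions, not part of the diff:

```python
import jax.numpy as jnp
import jax_metrics as jm

# Accuracy normally consumes `target` and `preds`; after renaming it can
# be fed straight from a batch that uses different keys.
accuracy = jm.metrics.Accuracy().rename_arguments(target="label", preds="logits")

accuracy = accuracy.reset()
accuracy = accuracy.update(
    label=jnp.array([0, 1, 2]),
    logits=jnp.eye(4)[jnp.array([0, 1, 3])],  # third prediction is wrong
)
print(accuracy.compute())  # ~0.667
```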
2 changes: 1 addition & 1 deletion docs/index.md
@@ -131,7 +131,7 @@ There are a few reasons for having losses in a metrics library:
 1. Most code from this library was originally written for and will still be consumed by Elegy. Since Elegy needs support for calculating cumulative losses, as you will see later, a Metric abstraction called `Losses` was created for this.
 2. A couple of API design decisions are shared between the `Loss` and `Metric` APIs. This includes:
     * `__call__` and `update` both accept any number of keyword-only arguments. This is used to facilitate composition (see the [Combinators](#combinators) section).
-    * Both classes have the `index_into` and `map_arg` methods that allow them to modify how arguments are consumed.
+    * Both classes have the `index_into` and `rename_arguments` methods that allow them to modify how arguments are consumed.
     * Argument names are standardized to be consistent whenever possible, e.g. both `metrics.Accuracy` and `losses.Crossentropy` use the `target` and `preds` arguments.
 
 </details>
41 changes: 16 additions & 25 deletions jax_metrics/losses/crossentropy.py
@@ -4,19 +4,10 @@
 import jax.numpy as jnp
 import optax
 
-from jax_metrics import types, utils
+from jax_metrics import types
 from jax_metrics.losses.loss import Loss, Reduction
 
 
-def smooth_labels(
-    target: jax.Array,
-    smoothing: jax.Array,
-) -> jax.Array:
-    smooth_positives = 1.0 - smoothing
-    smooth_negatives = smoothing / target.shape[-1]
-    return smooth_positives * target + smooth_negatives
-
-
 def crossentropy(
     target: jax.Array,
     preds: jax.Array,
@@ -27,43 +18,43 @@ def crossentropy(
     check_bounds: bool = True,
 ) -> jax.Array:
     n_classes = preds.shape[-1]
+    integer_labels = False
 
     if target.ndim == preds.ndim - 1:
         if target.shape != preds.shape[:-1]:
             raise ValueError(
                 f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
             )
-        target = jax.nn.one_hot(target, n_classes)
-    else:
-        if target.ndim != preds.ndim:
-            raise ValueError(
-                f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
-            )
+        if label_smoothing is not None or not from_logits:
+            target = jax.nn.one_hot(target, n_classes)
+        else:
+            integer_labels = True
+    elif target.ndim != preds.ndim:
+        raise ValueError(
+            f"Target shape '{target.shape}' does not match preds shape '{preds.shape}'"
+        )
 
     if label_smoothing is not None:
-        target = smooth_labels(target, label_smoothing)
+        target = optax.smooth_labels(target, label_smoothing)
 
+    loss: jax.Array
     if from_logits:
         if binary:
             loss = optax.sigmoid_binary_cross_entropy(preds, target).mean(axis=-1)
+        elif integer_labels:
+            loss = optax.softmax_cross_entropy_with_integer_labels(preds, target)
         else:
             loss = optax.softmax_cross_entropy(preds, target)
     else:
         preds = jnp.clip(preds, types.EPSILON, 1.0 - types.EPSILON)
 
         if binary:
-            loss = target * jnp.log(preds)  # + types.EPSILON)
-            loss += (1 - target) * jnp.log(1 - preds)  # + types.EPSILON)
-            loss = -loss.mean(axis=-1)
+            loss = -jnp.mean(
+                target * jnp.log(preds) + (1 - target) * jnp.log(1 - preds), axis=-1
+            )
         else:
             loss = -(target * jnp.log(preds)).sum(axis=-1)
 
+    # TODO: implement check_bounds
+    # if check_bounds:
+    #     # set NaN where target is negative or larger/equal to the number of preds channels
+    #     loss = jnp.where(target < 0, jnp.nan, loss)
+    #     loss = jnp.where(target >= n_classes, jnp.nan, loss)
 
     return loss


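A short sketch of the new integer-labels fast path introduced above. The keyword names come from the diff; passing every argument explicitly sidesteps any assumption about parameter order or defaults:

```python
import jax.numpy as jnp

from jax_metrics.losses.crossentropy import crossentropy

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])
labels = jnp.array([0, 1])  # integer labels: shape == logits.shape[:-1]

# With integer targets, logits, and no smoothing, the refactored code sets
# integer_labels=True and dispatches to the fused optax kernel instead of
# one-hot encoding the targets first.
loss = crossentropy(
    target=labels,
    preds=logits,
    binary=False,
    label_smoothing=None,
    from_logits=True,
)
print(loss.shape)  # (2,): per-example loss; the Loss wrapper applies the reduction
```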
4 changes: 2 additions & 2 deletions jax_metrics/losses/loss.py
@@ -166,14 +166,14 @@ def index_into(self, **kwargs: types.IndexLike) -> "IndexedLoss":
"""
return IndexedLoss(self, kwargs)

def map_arg(self, **kwargs: str) -> "MapArgsLoss":
def rename_arguments(self, **kwargs: str) -> "MapArgsLoss":
"""
Returns a loss that renames the keyword arguments expected by `__call__`.
Example:
```python
crossentropy_loss = jm.losses.Crossentropy().map_arg(target="y_true", preds="y_pred")
crossentropy_loss = jm.losses.Crossentropy().rename_arguments(target="y_true", preds="y_pred")
...
loss = crossentropy_loss(y_true=y, y_pred=logits)
```
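A runnable version of the docstring example above. This is a sketch: it assumes `Crossentropy` defaults to `from_logits=True` and reduces to a scalar under its default reduction:

```python
import jax.numpy as jnp
import jax_metrics as jm

# Keras-style keyword names via rename_arguments.
loss_fn = jm.losses.Crossentropy().rename_arguments(target="y_true", preds="y_pred")

y = jnp.array([0, 1])
logits = jnp.array([[2.0, -1.0], [0.3, 1.2]])
print(loss_fn(y_true=y, y_pred=logits))  # scalar loss averaged over the batch
```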
51 changes: 25 additions & 26 deletions jax_metrics/metrics/metric.py
@@ -1,3 +1,4 @@
+import dataclasses
 import typing as tp
 from abc import abstractmethod

@@ -8,7 +9,7 @@
 from jax_metrics import types
 
 M = tp.TypeVar("M", bound="Metric")
-MA = tp.TypeVar("MA", bound="MapArgsMetric")
+MA = tp.TypeVar("MA", bound="RenameArguments")
 Slice = tp.Tuple[tp.Union[int, str], ...]


@@ -145,14 +146,14 @@ def index_into(self, **kwargs: types.IndexLike) -> "IndexedMetric":
"""
return IndexedMetric(self, kwargs)

def map_arg(self, **kwargs: str) -> "MapArgsMetric":
def rename_arguments(self, **kwargs: str) -> "RenameArguments":
"""
Returns a metric that renames the keyword arguments expected by `.update()`.
Example:
```python
mean = jm.metrics.Mean().map_arg(values="loss").reset()
mean = jm.metrics.Mean().rename_arguments(values="loss")
...
loss = loss_fn(x, y)
mean = mean.update(loss=loss)
@@ -162,9 +163,9 @@ def map_arg(self, **kwargs: str) -> "MapArgsMetric":
             **kwargs: keyword arguments to be renamed
 
         Returns:
-            A MapArgsMetric instance
+            A RenameArguments instance
         """
-        return MapArgsMetric(self, kwargs)
+        return RenameArguments(self, kwargs)
 
 
 class SumMetric(Metric):
Expand Down Expand Up @@ -215,34 +216,32 @@ def compute(self) -> tp.Any:
         return self.metric.compute()
 
 
-class MapArgsMetric(Metric):
-    metric: Metric = field()
-    args_map: tp.Dict[str, str] = static_field()
-
-    def __init__(self, metric: Metric, args_map: tp.Dict[str, str]):
-        self.metric = metric
-        self.args_map = args_map
+Real = str
+Expected = str
+
+
+@dataclasses.dataclass
+class RenameArguments(Metric):
+    metric: Metric = field()
+    real_to_expected: tp.Dict[Real, Expected] = static_field()
 
     def reset(self: MA) -> MA:
         return self.replace(metric=self.metric.reset())
 
-    def update(self: MA, **kwargs: tp.Any) -> MA:
-        for arg in self.args_map:
-            if arg not in kwargs:
-                raise KeyError(f"'{arg}' expected but not given")
-
-        kwarg_updates = {
-            next_arg: kwargs[prev_arg] for prev_arg, next_arg in self.args_map.items()
-        }
-
-        # delete previous kwargs
-        for arg in self.args_map:
-            del kwargs[arg]
+    def update(self, **updates: tp.Any) -> "RenameArguments":
+        for expected in self.real_to_expected.values():
+            if expected not in updates:
+                raise KeyError(f"'{expected}' expected but not given")
 
-        # add new kwargs
-        kwargs.update(kwarg_updates)
-
-        return self.replace(metric=self.metric.update(**kwargs))
+        updates.update(
+            {
+                real: updates[expected]
+                for real, expected in self.real_to_expected.items()
+            }
+        )
+
+        return self.replace(metric=self.metric.update(**updates))
 
     def compute(self) -> tp.Any:
         return self.metric.compute()
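The renamed wrapper in action, mirroring the docstring example above. This is a sketch; per the README, metric `update` methods accept arbitrary keyword arguments, so the extra `loss` key that `RenameArguments.update` forwards alongside `values` is harmless:

```python
import jax.numpy as jnp
import jax_metrics as jm

# Mean expects `values`; RenameArguments lets callers pass `loss` instead.
# Internally update() copies updates["loss"] into updates["values"] before
# delegating to the wrapped metric, as in the class above.
mean = jm.metrics.Mean().rename_arguments(values="loss").reset()

for batch_loss in [jnp.array(0.2), jnp.array(0.4)]:
    mean = mean.update(loss=batch_loss)

print(mean.compute())  # 0.3
```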
11 changes: 4 additions & 7 deletions jax_metrics/metrics/reduce.py
@@ -7,7 +7,7 @@
 from simple_pytree import field, static_field
 
 from jax_metrics import types
-from jax_metrics.metrics.metric import Metric, SumMetric
+from jax_metrics.metrics.metric import Metric, RenameArguments, SumMetric
 
 M = tp.TypeVar("M", bound="Reduce")

@@ -86,9 +86,6 @@ def update(
             Array with the cumulative reduce.
         """
 
-        if self.total is None:
-            raise self._not_initialized_error()
-
         # perform update
         if sample_weight is not None:
             if sample_weight.ndim > values.ndim:
@@ -138,10 +135,10 @@ def update(
         return self.replace(total=total, count=count)
 
     def compute(self) -> jax.Array:
-        if self.total is None:
-            raise self._not_initialized_error()
-
         if self.reduction == Reduction.sum:
             return self.total
         else:
             return self.total / self.count
+
+    def from_argument(self, argument: str) -> RenameArguments:
+        return self.rename_arguments(values=argument)
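`from_argument` is just sugar over the rename machinery; under this change the following two metrics behave the same (a sketch, assuming `Mean` is one of the library's `Reduce`-based metrics):

```python
import jax_metrics as jm

# Equivalent ways to feed a Reduce-based metric from a `loss` keyword:
mean_a = jm.metrics.Mean().from_argument("loss")
mean_b = jm.metrics.Mean().rename_arguments(values="loss")
```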
