diff --git a/apax/train/metrics.py b/apax/train/metrics.py index 151cfe68..e5067b0f 100644 --- a/apax/train/metrics.py +++ b/apax/train/metrics.py @@ -9,7 +9,13 @@ log = logging.getLogger(__name__) -class RootAverage(metrics.Average): +class Averagefp64(metrics.Average): + @classmethod + def empty(cls) -> metrics.Metric: + return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64)) + + +class RootAverage(Averagefp64): """ Modifies the `compute` method of `metrics.Average` to obtain the root of the average. Meant to be used with `mse_fn`. @@ -59,7 +65,7 @@ def make_single_metric(key: str, reduction: str) -> metrics.Average: if reduction == "rmse": metric = RootAverage else: - metric = metrics.Average + metric = Averagefp64 reduction_fn = reduction_fns[reduction] reduction_fn = partial(reduction_fn, key=key)