From 1c3e2d5c3b2fd20e8d1ed83acfce58e10c6b7dca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 2 Jan 2024 17:35:41 +0100 Subject: [PATCH] added fp64 version of clu `Average` --- apax/train/metrics.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/apax/train/metrics.py b/apax/train/metrics.py index 151cfe68..e5067b0f 100644 --- a/apax/train/metrics.py +++ b/apax/train/metrics.py @@ -9,7 +9,13 @@ log = logging.getLogger(__name__) -class RootAverage(metrics.Average): +class Averagefp64(metrics.Average): + @classmethod + def empty(cls) -> metrics.Metric: + return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64)) + + +class RootAverage(Averagefp64): """ Modifies the `compute` method of `metrics.Average` to obtain the root of the average. Meant to be used with `mse_fn`. @@ -59,7 +65,7 @@ def make_single_metric(key: str, reduction: str) -> metrics.Average: if reduction == "rmse": metric = RootAverage else: - metric = metrics.Average + metric = Averagefp64 reduction_fn = reduction_fns[reduction] reduction_fn = partial(reduction_fn, key=key)