Skip to content

Commit

Permalink
added fp64 version of clu Average
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Jan 2, 2024
1 parent b19adfd commit 1c3e2d5
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions apax/train/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
log = logging.getLogger(__name__)


class RootAverage(metrics.Average):
class Averagefp64(metrics.Average):
@classmethod
def empty(cls) -> metrics.Metric:
return cls(total=jnp.array(0, jnp.float64), count=jnp.array(0, jnp.int64))


class RootAverage(Averagefp64):
"""
Modifies the `compute` method of `metrics.Average` to obtain the root of the average.
Meant to be used with `mse_fn`.
Expand Down Expand Up @@ -59,7 +65,7 @@ def make_single_metric(key: str, reduction: str) -> metrics.Average:
if reduction == "rmse":
metric = RootAverage
else:
metric = metrics.Average
metric = Averagefp64

reduction_fn = reduction_fns[reduction]
reduction_fn = partial(reduction_fn, key=key)
Expand Down

0 comments on commit 1c3e2d5

Please sign in to comment.