diff --git a/apebench/_base_scenario.py b/apebench/_base_scenario.py index 1c9060c..fc9b035 100644 --- a/apebench/_base_scenario.py +++ b/apebench/_base_scenario.py @@ -15,7 +15,7 @@ from ._corrected_stepper import CorrectedStepper from ._extensions import arch_extensions -from .components import metrics_dict +from .components import metric_dict class BaseScenario(eqx.Module, ABC): @@ -822,7 +822,7 @@ def perform_tests( for metric_config in metrics: metric_args = metric_config.split(";") metric_name = metric_args[0] - metric_constructor = metrics_dict[metric_name] + metric_constructor = metric_dict[metric_name] metric_fn = metric_constructor(metric_config) results[metric_config] = metric_fn diff --git a/apebench/components/__init__.py b/apebench/components/__init__.py index be00168..f5afbce 100644 --- a/apebench/components/__init__.py +++ b/apebench/components/__init__.py @@ -1,3 +1,3 @@ -from ._metrics import metrics_dict +from ._metrics import metric_dict -__all__ = ["metrics_dict"] +__all__ = ["metric_dict"] diff --git a/apebench/components/_metrics.py b/apebench/components/_metrics.py index 70f3cce..07d694f 100644 --- a/apebench/components/_metrics.py +++ b/apebench/components/_metrics.py @@ -3,7 +3,7 @@ import exponax as ex from jaxtyping import Array, Float -metrics_dict: Dict[ +metric_dict: Dict[ str, Callable[ [