Skip to content

Commit

Permalink
don't assume version merely based on locale being present
Browse files Browse the repository at this point in the history
  • Loading branch information
rogthefrog committed Nov 13, 2024
1 parent ead1dbb commit 391f521
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 7 deletions.
2 changes: 1 addition & 1 deletion src/modelbench/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _numeric_grade_v1(self, locale: str = "en_us", standards: Standards = STANDA
return self._numeric_grade_v05()

self.score = total_safe_items / total_scored_items
reference_safe = standards.average_standard_across_references(locale=locale)
reference_safe = standards.average_standard_across_references(locale=locale, version="1.0")
numeric_grade = score_to_ordinal_grade(self.score, reference_safe)

self._scoring_log["frac_safe"] = self.score
Expand Down
12 changes: 7 additions & 5 deletions src/modelbench/hazards.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,15 +239,17 @@ def reference_standard_for(self, name):
raise ValueError(f"No standard yet for {name}. Run `modelbench calibrate --update` to add one.")
return self.data["reference_standards"][name]

def average_standard_across_references(self, locale: str = "") -> float:
def average_standard_across_references(self, locale: str = "", version: str = "1.0") -> float:
values = []
if locale:
# refs with locale are all version 1.0 or newer
if version == "1.0":
if not locale:
raise ValueError("Locale is required for v1.0 scoring.")
locale = locale.lower()
values = [v for k, v in self.data["reference_standards"].items() if locale in k]
else:
# no locale means we want 0.5
elif version == "0.5":
values = [v for k, v in self.data["reference_standards"].items() if "0.5" in k]
else:
raise ValueError(f"Invalid version string: {version}")

assert len(values), "No reference values found"

Expand Down
11 changes: 10 additions & 1 deletion tests/modelbench_tests/test_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,13 +109,22 @@ def test_numeric_grade():


def test_average_standard_across_references(standards):
    """Check argument validation and averaged values for each scoring version."""
    # v1.0 scoring requires a locale: an omitted or empty locale must be rejected.
    with pytest.raises(ValueError):
        _ = standards.average_standard_across_references(version="1.0")

    with pytest.raises(ValueError):
        _ = standards.average_standard_across_references(version="1.0", locale="")

    # An unrecognized version string is rejected even when a locale is supplied.
    with pytest.raises(ValueError):
        _ = standards.average_standard_across_references(locale="te_st", version="2.0")

    # `version` defaults to "1.0", so passing only a locale takes the v1.0 path.
    avg = standards.average_standard_across_references(locale="te_st")
    assert avg == 0.555

    avg = standards.average_standard_across_references(locale="fr_fr")
    assert avg == 0.607202466845324

    # v0.5 must be requested explicitly. A bare call (no locale, default
    # version "1.0") now raises ValueError instead of implying v0.5.
    avg = standards.average_standard_across_references(version="0.5")
    assert avg == 0.9889563642347389


Expand Down

0 comments on commit 391f521

Please sign in to comment.