
WIP: Add MultiEURLEX #120

Draft
wants to merge 16 commits into master
1 change: 0 additions & 1 deletion lm_eval/api/model.py
@@ -392,7 +392,6 @@ def __init__(self, lm: LM, cache_db: str):
Path to the `cache` database.
"""
from sqlitedict import SqliteDict

self.lm = lm
if os.path.dirname(cache_db):
os.makedirs(os.path.dirname(cache_db), exist_ok=True)
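For context, the caching wrapper around `LM` shown here persists request/response pairs in an on-disk key-value store via `sqlitedict`. A minimal sketch of that pattern, with an illustrative cache path and a hypothetical key format (the harness's real key layout is not shown in this diff):

```python
import os

from sqlitedict import SqliteDict

cache_db = "eval_cache/model.sqlite"  # illustrative path
if os.path.dirname(cache_db):
    os.makedirs(os.path.dirname(cache_db), exist_ok=True)

# SqliteDict behaves like a regular dict backed by a SQLite file on disk.
with SqliteDict(cache_db, autocommit=True) as cache:
    key = "greedy_until|<prompt text>"  # hypothetical key format
    if key not in cache:
        cache[key] = "model output would be cached here"
    print(cache[key])
```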
4 changes: 2 additions & 2 deletions lm_eval/evaluator.py
@@ -237,12 +237,12 @@ def evaluate(
metrics, example = output
example.update(fewshot_logging_info)
example.update(task.get_logging_info())
example_logger.info(json.dumps(example))
example_logger.info(json.dumps(example, ensure_ascii=False))
else:
metrics = output
example = fewshot_logging_info
example.update(task.get_logging_info())
example_logger.info(json.dumps(example))
example_logger.info(json.dumps(example, ensure_ascii=False))

for metric, value in metrics.items():
vals[(task_template_key, metric)].append(value)
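Switching to `ensure_ascii=False` matters for multilingual tasks such as MultiEURLEX: with the default setting, `json.dumps` escapes every non-ASCII character, so logged Greek or Bulgarian examples become unreadable `\uXXXX` sequences. A small illustration:

```python
import json

example = {"target": "Επίσημη Εφημερίδα της Ευρωπαϊκής Ένωσης"}  # Greek sample text

print(json.dumps(example))                      # {"target": "\u0395\u03c0\u03af\u03c3..."}
print(json.dumps(example, ensure_ascii=False))  # {"target": "Επίσημη Εφημερίδα ..."}
```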
11 changes: 7 additions & 4 deletions lm_eval/tasks/__init__.py
@@ -25,6 +25,7 @@
from . import jigsaw_unintended_bias
from . import lama
from . import lince
from . import multi_eurlex
from . import piaf
from . import race
from . import schema_guided_dstc8
@@ -134,6 +135,8 @@
# WMT
# Format: `wmt{year}_{lang1}_{lang2}`
**wmt.construct_tasks(),
# MultiEURLEX
"multi_eurlex_mt": multi_eurlex.MultiEURLEXMT,
# BLiMP
"blimp_adjunct_island": blimp.BlimpAdjunctIsland,
"blimp_anaphor_gender_agreement": blimp.BlimpAnaphorGenderAgreement,
@@ -209,10 +212,10 @@
# TODO: Not Yet Available in `promptsource/eval-hackathon`
########################################################
# GEM/mlsum
# "mlsum_es": gem_mlsum.GEMMLSUMEs,
# "mlsum_de": gem_mlsum.GEMMLSUMDe,
# "mlsum_es_covid_challenge_set": gem_mlsum.GEMMLSUMEsChallgeTestCovid,
# "mlsum_de_covid_challenge_set": gem_mlsum.GEMMLSUMDeChallgeTestCovid,
"mlsum_es": gem_mlsum.GEMMLSUMEs,
"mlsum_de": gem_mlsum.GEMMLSUMDe,
"mlsum_es_covid_challenge_set": gem_mlsum.GEMMLSUMEsChallgeTestCovid,
"mlsum_de_covid_challenge_set": gem_mlsum.GEMMLSUMDeChallgeTestCovid,
# LAMA
# "bigscience-lama": lama.BigScienceLAMA,
########################################################
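Once registered, the task is addressable by the `multi_eurlex_mt` name like any other entry in the registry. A rough sketch of looking it up, assuming the mapping edited above is exposed as a module-level `TASK_REGISTRY` dict (that name is not visible in this diff):

```python
from lm_eval import tasks

# Assumption: the registry dict edited above is module-level and named TASK_REGISTRY.
task_class = tasks.TASK_REGISTRY["multi_eurlex_mt"]

print(task_class)               # <class 'lm_eval.tasks.multi_eurlex.MultiEURLEXMT'>
print(task_class.DATASET_PATH)  # multi_eurlex
print(task_class.DATASET_NAME)  # all_languages
```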
65 changes: 65 additions & 0 deletions lm_eval/tasks/multi_eurlex.py
@@ -0,0 +1,65 @@
"""
MultiEURLEX
"""
from typing import Optional
from lm_eval.api.task import PromptSourceTask

_CITATION = """
"""

_LANGUAGES = [
"en",
"da",
"de",
"nl",
"sv",
"bg",
"cs",
"hr",
"pl",
"sk",
"sl",
"es",
"fr",
"it",
"pt",
"ro",
"et",
"fi",
"hu",
"lt",
"lv",
"el",
"mt",
]

class MultiEURLEXMT(PromptSourceTask):
DATASET_PATH = "multi_eurlex"
DATASET_NAME = "all_languages"
VERSION = 0

def has_training_docs(self):
return True

def has_validation_docs(self):
return True

def has_test_docs(self):
return True

def training_docs(self):
if self.has_training_docs():
return self.dataset["train"]

def validation_docs(self):
if self.has_validation_docs():
return self.dataset["validation"]

def test_docs(self):
if self.has_test_docs():
return self.dataset["test"]

def max_generation_length(self) -> Optional[int]:
return 1024
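Outside the harness, the dataset this task points at can be sanity-checked directly with `datasets`. A hedged sketch; the `text`/`labels` field names and the parallel-translation layout reflect how the `multi_eurlex` dataset is documented on the Hugging Face Hub, not something this diff guarantees:

```python
from datasets import load_dataset

# "all_languages" matches DATASET_NAME above; each document carries parallel
# translations of the same EU law, keyed by language code.
ds = load_dataset("multi_eurlex", "all_languages", split="validation")

doc = ds[0]
print(doc["text"]["en"][:200])  # English version of the law
print(doc["text"]["fr"][:200])  # French version of the same law
print(doc["labels"])            # EUROVOC label ids
```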


1 change: 0 additions & 1 deletion main.py
@@ -180,7 +180,6 @@ def main():
results = evaluator.cli_evaluate(**evaluate_args)
else:
from codecarbon import OfflineEmissionsTracker

with OfflineEmissionsTracker(country_iso_code="FRA", log_level="error"):
print() # Add newline between emissions tracker and evaluation logging.
results = evaluator.cli_evaluate(**evaluate_args)
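As a side note, the tracker used above can also be driven explicitly when the measured value itself is needed; `stop()` returns the estimated emissions. A small sketch with a stand-in workload:

```python
import time

from codecarbon import OfflineEmissionsTracker

tracker = OfflineEmissionsTracker(country_iso_code="FRA", log_level="error")
tracker.start()
time.sleep(1)               # stand-in for the actual evaluation workload
emissions = tracker.stop()  # estimated emissions in kg CO2-eq
print(f"Estimated emissions: {emissions} kg CO2-eq")
```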