Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add CUREv1 retrieval dataset #1459

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions mteb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
MTEB_ENG_CLASSIC,
MTEB_MAIN_RU,
MTEB_RETRIEVAL_LAW,
MTEB_RETRIEVAL_MEDICAL,
MTEB_RETRIEVAL_WITH_INSTRUCTIONS,
CoIR,
)
Expand All @@ -24,6 +25,7 @@
"MTEB_ENG_CLASSIC",
"MTEB_MAIN_RU",
"MTEB_RETRIEVAL_LAW",
"MTEB_RETRIEVAL_MEDICAL",
"MTEB_RETRIEVAL_WITH_INSTRUCTIONS",
"CoIR",
"TASKS_REGISTRY",
Expand Down
18 changes: 18 additions & 0 deletions mteb/benchmarks/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,24 @@ def load_results(
citation=None,
)

# Benchmark grouping the medical-domain retrieval tasks, including the newly
# added CUREv1 collection.
# (Fix: removed stray review-widget text that had been pasted into the middle
# of this call and broke the syntax.)
MTEB_RETRIEVAL_MEDICAL = Benchmark(
    name="MTEB(Medical)",
    tasks=get_tasks(
        tasks=[
            "CUREv1",
            "NFCorpus",
            "SCIDOCS",
            "TRECCOVID",
            "SciFact",
            "MedicalQARetrieval",
            "PublicHealthQA",
        ]
    ),
    description="Medical benchmarks from MTEB",
    reference="",  # NOTE(review): empty — consider linking a benchmark page
    citation=None,
)

MTEB_MINERS_BITEXT_MINING = Benchmark(
name="MINERSBitextMining",
tasks=get_tasks(
Expand Down
1 change: 1 addition & 0 deletions mteb/tasks/Retrieval/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .multilingual.BelebeleRetrieval import *
from .multilingual.CrossLingualSemanticDiscriminationWMT19 import *
from .multilingual.CrossLingualSemanticDiscriminationWMT21 import *
from .multilingual.CUREv1Retrieval import *
from .multilingual.IndicQARetrieval import *
from .multilingual.MintakaRetrieval import *
from .multilingual.MIRACLRetrieval import *
Expand Down
151 changes: 151 additions & 0 deletions mteb/tasks/Retrieval/multilingual/CUREv1Retrieval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
from __future__ import annotations

from enum import Enum

from datasets import DatasetDict, load_dataset

from mteb.abstasks.TaskMetadata import TaskMetadata

from ....abstasks.AbsTaskRetrieval import AbsTaskRetrieval
from ....abstasks.MultilingualTask import MultilingualTask

_LANGUAGES = {
"en": ["eng-Latn", "eng-Latn"],
"es": ["spa-Latn", "eng-Latn"],
"fr": ["fra-Latn", "eng-Latn"],
}


class CUREv1Splits(str, Enum):
    """Evaluation splits of the CUREv1 dataset.

    Each member maps a snake_case identifier (used as the split name) to the
    human-readable medical discipline it covers; ``all`` spans every discipline.
    """

    all = "All"
    dentistry_and_oral_health = "Dentistry and Oral Health"
    dermatology = "Dermatology"
    gastroenterology = "Gastroenterology"
    genetics = "Genetics"
    neuroscience_and_neurology = "Neuroscience and Neurology"
    orthopedic_surgery = "Orthopedic Surgery"
    otorhinolaryngology = "Otorhinolaryngology"
    plastic_surgery = "Plastic Surgery"
    psychiatry_and_psychology = "Psychiatry and Psychology"
    pulmonology = "Pulmonology"

    @classmethod
    def names(cls) -> list[str]:
        """Return every split identifier, sorted alphabetically."""
        return sorted(member.name for member in cls)


class CUREv1Retrieval(MultilingualTask, AbsTaskRetrieval):
    """Cross-lingual medical retrieval over the CUREv1 dataset.

    The corpus and relevance judgements are shared across languages; only the
    queries differ per language (en/es/fr queries against an English corpus).
    """

    metadata = TaskMetadata(
        dataset={
            "path": "clinia/CUREv1",
            "revision": "3bcf51c91e04d04a8a3329dfbe988b964c5cbe83",
        },
        name="CUREv1",
        description="Collection of query-passage pairs curated by medical professionals, across 10 disciplines and 3 cross-lingual settings.",
        type="Retrieval",
        modalities=["text"],
        category="s2p",
        reference="https://huggingface.co/datasets/clinia/CUREv1",
        eval_splits=CUREv1Splits.names(),
        eval_langs=_LANGUAGES,
        main_score="ndcg_at_10",
        date=("2024-01-01", "2024-10-31"),
        domains=["Medical", "Academic"],
        task_subtypes=[],
        license="cc-by-nc-4.0",
        annotations_creators="expert-annotated",
        dialect=[],
        sample_creation="created",
        bibtex_citation="",
        prompt={
            "query": "Given a question by a medical professional, retrieve relevant passages that best answer the question",
        },
    )

    def _load_corpus(self, split: str, cache_dir: str | None = None):
        """Load one split of the corpus as ``{doc_id: {"title", "text"}}``."""
        dataset_info = self.metadata_dict["dataset"]
        rows = load_dataset(
            path=dataset_info["path"],
            revision=dataset_info["revision"],
            name="corpus",
            split=split,
            cache_dir=cache_dir,
        )
        return {
            row["_id"]: {"title": row["title"], "text": row["text"]} for row in rows
        }

    def _load_qrels(self, split: str, cache_dir: str | None = None):
        """Load one split of the qrels as ``{query_id: {doc_id: score}}``."""
        dataset_info = self.metadata_dict["dataset"]
        rows = load_dataset(
            path=dataset_info["path"],
            revision=dataset_info["revision"],
            name="qrels",
            split=split,
            cache_dir=cache_dir,
        )
        relevance: dict = {}
        for row in rows:
            relevance.setdefault(row["query-id"], {})[row["corpus-id"]] = int(
                row["score"]
            )
        return relevance

    def _load_queries(self, split: str, language: str, cache_dir: str | None = None):
        """Load one split of the queries for *language* as ``{query_id: text}``."""
        dataset_info = self.metadata_dict["dataset"]
        rows = load_dataset(
            path=dataset_info["path"],
            revision=dataset_info["revision"],
            name=f"queries-{language}",
            split=split,
            cache_dir=cache_dir,
        )
        return {row["_id"]: row["text"] for row in rows}

    def load_data(self, **kwargs):
        """Populate ``self.corpus``/``self.queries``/``self.relevant_docs``.

        Idempotent: returns immediately when data was already loaded.
        """
        if self.data_loaded:
            return

        splits = kwargs.get("eval_splits", self.metadata.eval_splits)
        languages = kwargs.get("eval_langs", self.metadata.eval_langs)
        cache_dir = kwargs.get("cache_dir", None)

        def placeholders() -> dict:
            # Fresh {language: {split: None}} scaffold per collection.
            return {lang: dict.fromkeys(splits) for lang in languages}

        corpus = placeholders()
        queries = placeholders()
        relevant_docs = placeholders()

        for split in splits:
            # The corpus and the relevance judgements are language-independent
            # in this cross-lingual setup, so load them once per split and
            # share the same objects across all languages.
            shared_corpus = self._load_corpus(split=split, cache_dir=cache_dir)
            shared_qrels = self._load_qrels(split=split, cache_dir=cache_dir)

            for lang in languages:
                corpus[lang][split] = shared_corpus
                relevant_docs[lang][split] = shared_qrels
                # Queries are the only language-dependent component.
                queries[lang][split] = self._load_queries(
                    split=split, language=lang, cache_dir=cache_dir
                )

        self.corpus = DatasetDict(corpus)
        self.queries = DatasetDict(queries)
        self.relevant_docs = DatasetDict(relevant_docs)

        self.data_loaded = True