diff --git a/beir/retrieval/evaluation.py b/beir/retrieval/evaluation.py
index fd40d72..18c98fb 100644
--- a/beir/retrieval/evaluation.py
+++ b/beir/retrieval/evaluation.py
@@ -11,7 +11,7 @@ class EvaluateRetrieval:
 
-    def __init__(self, retriever: Union[Type[DRES], Type[DRFS], Type[BM25], Type[SS]] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
+    def __init__(self, retriever: Union[DRES, DRFS, BM25, SS] = None, k_values: List[int] = [1,3,5,10,100,1000], score_function: str = "cos_sim"):
         self.k_values = k_values
         self.top_k = max(k_values)
         self.retriever = retriever
diff --git a/beir/retrieval/train.py b/beir/retrieval/train.py
index 60af1c2..d1e7d44 100644
--- a/beir/retrieval/train.py
+++ b/beir/retrieval/train.py
@@ -15,12 +15,12 @@ class TrainRetriever:
 
-    def __init__(self, model: Type[SentenceTransformer], batch_size: int = 64):
+    def __init__(self, model: SentenceTransformer, batch_size: int = 64):
         self.model = model
         self.batch_size = batch_size
 
     def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
-                   qrels: Dict[str, Dict[str, int]]) -> List[Type[InputExample]]:
+                   qrels: Dict[str, Dict[str, int]]) -> List[InputExample]:
 
         query_ids = list(queries.keys())
         train_samples = []
@@ -40,7 +40,7 @@ def load_train(self, corpus: Dict[str, Dict[str, str]], queries: Dict[str, str],
         logger.info("Loaded {} training pairs.".format(len(train_samples)))
         return train_samples
 
-    def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[Type[InputExample]]:
+    def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[InputExample]:
 
         train_samples = []
@@ -53,7 +53,7 @@ def load_train_triplets(self, triplets: List[Tuple[str, str, str]]) -> List[Type
         logger.info("Loaded {} training pairs.".format(len(train_samples)))
         return train_samples
 
-    def prepare_train(self, train_dataset: List[Type[InputExample]], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:
+    def prepare_train(self, train_dataset: List[InputExample], shuffle: bool = True, dataset_present: bool = False) -> DataLoader:
 
         if not dataset_present:
             train_dataset = SentencesDataset(train_dataset, model=self.model)
@@ -61,7 +61,7 @@ def prepare_train(self, train_dataset: List[Type[InputExample]], shuffle: bool =
         train_dataloader = DataLoader(train_dataset, shuffle=shuffle, batch_size=self.batch_size)
         return train_dataloader
 
-    def prepare_train_triplets(self, train_dataset: List[Type[InputExample]]) -> DataLoader:
+    def prepare_train_triplets(self, train_dataset: List[InputExample]) -> DataLoader:
 
         train_dataloader = datasets.NoDuplicatesDataLoader(train_dataset, batch_size=self.batch_size)
         return train_dataloader
@@ -117,7 +117,7 @@ def fit(self,
                 steps_per_epoch = None,
                 scheduler: str = 'WarmupLinear',
                 warmup_steps: int = 10000,
-                optimizer_class: Type[Optimizer] = AdamW,
+                optimizer_class: Optimizer = AdamW,
                 optimizer_params : Dict[str, object]= {'lr': 2e-5, 'eps': 1e-6, 'correct_bias': False},
                 weight_decay: float = 0.01,
                 evaluation_steps: int = 0,
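
For reference, a minimal sketch of the typing convention these hunks apply (the `Retriever`, `evaluate`, and `build` names below are hypothetical stand-ins, not BEIR classes): annotating a parameter with a bare class name declares that an instance is expected, while `Type[...]` declares that the class object itself is passed.

```python
from typing import Type


class Retriever:
    """Hypothetical stand-in for a retriever class."""


def evaluate(retriever: Retriever) -> None:
    # Expects an already-constructed instance, e.g. evaluate(Retriever()).
    ...


def build(retriever_cls: Type[Retriever]) -> Retriever:
    # Expects the class itself and instantiates it, e.g. build(Retriever).
    return retriever_cls()
```

Under that reading, parameters such as `model`, `retriever`, and `train_dataset` hold instances, which is why the `Type[...]` wrappers are dropped in this diff.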