From 64f46005d62fc3a33a46485d92deccc83d774f28 Mon Sep 17 00:00:00 2001
From: Richard Jackson
Date: Thu, 24 Oct 2024 15:12:13 +0100
Subject: [PATCH] keys_to_use is now inferred from architecture

---
 kazu/training/config.py               |  1 -
 kazu/training/train_multilabel_ner.py | 10 ++++++----
 kazu/training/train_script.py         |  1 -
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/kazu/training/config.py b/kazu/training/config.py
index 90b74503..c5b663b9 100644
--- a/kazu/training/config.py
+++ b/kazu/training/config.py
@@ -41,7 +41,6 @@ class TrainingConfig:
     device: str
     #: number of workers for dataloader
     workers: int
-    keys_to_use: list[str]
     #: architecture to use. Currently supports bert, deberta, distilbert
     architecture: str = "bert"
     #: fraction of epoch to complete before evaluations begin
diff --git a/kazu/training/train_multilabel_ner.py b/kazu/training/train_multilabel_ner.py
index fabe962e..a5bdcec9 100644
--- a/kazu/training/train_multilabel_ner.py
+++ b/kazu/training/train_multilabel_ner.py
@@ -306,7 +306,6 @@ def __init__(
         working_dir: Path,
         summary_writer: Optional[SummaryWriter] = None,
         ls_wrapper: Optional[LSManagerViewWrapper] = None,
-        keys_to_use: Optional[list[str]] = None,
     ):
 
         self.ls_wrapper = ls_wrapper
@@ -320,9 +319,12 @@ def __init__(
         self.test_dataset = test_dataset
         self.label_list = label_list
         self.pretrained_model_name_or_path = pretrained_model_name_or_path
-        self.keys_to_use = (
-            keys_to_use if keys_to_use else ["input_ids", "attention_mask", "token_type_ids"]
-        )
+        self.keys_to_use = self._select_keys_to_use()
+
+    def _select_keys_to_use(self) -> list[str]:
+        if self.training_config.architecture == "distilbert":
+            return ["input_ids", "attention_mask"]
+        return ["input_ids", "attention_mask", "token_type_ids"]
 
     def _write_to_tensorboard(
         self, global_step: int, main_tag: str, tag_scalar_dict: dict[str, NumericMetric]
diff --git a/kazu/training/train_script.py b/kazu/training/train_script.py
index a87855e0..d0df5e64 100644
--- a/kazu/training/train_script.py
+++ b/kazu/training/train_script.py
@@ -125,7 +125,6 @@ def run(cfg: DictConfig) -> None:
         working_dir=output_dir,
         summary_writer=SummaryWriter(log_dir=str(tb_path.absolute())),
         ls_wrapper=wrapper,
-        keys_to_use=training_config.keys_to_use,
     )
 
     trainer.train_model()
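
Note (editorial, not part of the patch): the new Trainer._select_keys_to_use method maps the configured architecture to the batch keys forwarded to the model, so callers such as train_script.py no longer need to thread keys_to_use through TrainingConfig. DistilBERT has no token-type embeddings, which is why token_type_ids is dropped for that architecture. Below is a minimal, self-contained sketch of that behaviour; select_keys_to_use and the toy batch are hypothetical stand-ins for illustration, not Kazu APIs.

# Illustrative sketch only: mirrors the logic of Trainer._select_keys_to_use above.
def select_keys_to_use(architecture: str) -> list[str]:
    # DistilBERT's forward pass does not accept token_type_ids, so only two
    # tensors are passed through to the model.
    if architecture == "distilbert":
        return ["input_ids", "attention_mask"]
    # BERT- and DeBERTa-style models also consume token_type_ids.
    return ["input_ids", "attention_mask", "token_type_ids"]


# A toy tokenised batch, shaped like typical tokenizer output.
batch = {
    "input_ids": [[101, 7592, 102]],
    "attention_mask": [[1, 1, 1]],
    "token_type_ids": [[0, 0, 0]],
    "labels": [[0, 0, 0]],
}

model_inputs = {key: batch[key] for key in select_keys_to_use("distilbert")}
print(sorted(model_inputs))  # ['attention_mask', 'input_ids']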