Skip to content

Commit

Permalink
keys_to_use is now inferred from architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
RichJackson committed Oct 24, 2024
1 parent d29a6c2 commit 64f4600
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
1 change: 0 additions & 1 deletion kazu/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class TrainingConfig:
device: str
#: number of workers for dataloader
workers: int
keys_to_use: list[str]
#: architecture to use. Currently supports bert, deberta, distilbert
architecture: str = "bert"
#: fraction of epoch to complete before evaluations begin
Expand Down
10 changes: 6 additions & 4 deletions kazu/training/train_multilabel_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,6 @@ def __init__(
working_dir: Path,
summary_writer: Optional[SummaryWriter] = None,
ls_wrapper: Optional[LSManagerViewWrapper] = None,
keys_to_use: Optional[list[str]] = None,
):

self.ls_wrapper = ls_wrapper
Expand All @@ -320,9 +319,12 @@ def __init__(
self.test_dataset = test_dataset
self.label_list = label_list
self.pretrained_model_name_or_path = pretrained_model_name_or_path
self.keys_to_use = (
keys_to_use if keys_to_use else ["input_ids", "attention_mask", "token_type_ids"]
)
self.keys_to_use = self._select_keys_to_use()

def _select_keys_to_use(self) -> list[str]:
if self.training_config.architecture == "distilbert":
return ["input_ids", "attention_mask"]
return ["input_ids", "attention_mask", "token_type_ids"]

def _write_to_tensorboard(
self, global_step: int, main_tag: str, tag_scalar_dict: dict[str, NumericMetric]
Expand Down
1 change: 0 additions & 1 deletion kazu/training/train_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,6 @@ def run(cfg: DictConfig) -> None:
working_dir=output_dir,
summary_writer=SummaryWriter(log_dir=str(tb_path.absolute())),
ls_wrapper=wrapper,
keys_to_use=training_config.keys_to_use,
)
trainer.train_model()

Expand Down

0 comments on commit 64f4600

Please sign in to comment.