diff --git a/kazu/training/config.py b/kazu/training/config.py
index 286d78fa..9761dd58 100644
--- a/kazu/training/config.py
+++ b/kazu/training/config.py
@@ -24,5 +24,6 @@ class TrainingConfig:
     test_overfit: bool
     device: str
     workers: int
+    keys_to_use: list[str]
     architecture: str = "bert"
     epoch_completion_fraction_before_evals: float = 0.75
diff --git a/kazu/training/train_multilabel_ner.py b/kazu/training/train_multilabel_ner.py
index d6b444e1..c0a9cc2a 100644
--- a/kazu/training/train_multilabel_ner.py
+++ b/kazu/training/train_multilabel_ner.py
@@ -306,6 +306,7 @@ def __init__(
         working_dir: Path,
         summary_writer: Optional[SummaryWriter] = None,
         ls_wrapper: Optional[LSManagerViewWrapper] = None,
+        keys_to_use: Optional[list[str]] = None,
     ):

         self.ls_wrapper = ls_wrapper
@@ -321,6 +322,9 @@ def __init__(
         self.eval_dataset = eval_dataset
         self.label_list = label_list
         self.pretrained_model_name_or_path = pretrained_model_name_or_path
+        self.keys_to_use = (
+            keys_to_use if keys_to_use else ["input_ids", "attention_mask", "token_type_ids"]
+        )

     def _write_to_tensorboard(
         self, global_step: int, main_tag: str, tag_scalar_dict: dict[str, NumericMetric]
@@ -441,7 +445,7 @@ def _process_docs(self, model: PreTrainedModel) -> list[Document]:
             tokenized_word_processor=TokenizedWordProcessor(
                 labels=self.label_list, use_multilabel=True
             ),
-            keys_to_use=["input_ids", "attention_mask", "token_type_ids"],
+            keys_to_use=self.keys_to_use,
             device=self.training_config.device,
         )

diff --git a/kazu/training/train_script.py b/kazu/training/train_script.py
index f0b72098..2667db7c 100644
--- a/kazu/training/train_script.py
+++ b/kazu/training/train_script.py
@@ -126,6 +126,7 @@ def run(cfg: DictConfig) -> None:
         working_dir=output_dir,
         summary_writer=SummaryWriter(log_dir=str(tb_path.absolute())),
         ls_wrapper=wrapper,
+        keys_to_use=training_config.keys_to_use,
     )
     trainer.train_model()

diff --git a/scripts/examples/conf/multilabel_ner_training/default.yaml b/scripts/examples/conf/multilabel_ner_training/default.yaml
index 133e7722..00133a94 100644
--- a/scripts/examples/conf/multilabel_ner_training/default.yaml
+++ b/scripts/examples/conf/multilabel_ner_training/default.yaml
@@ -20,27 +20,31 @@ training_config:
   test_overfit: false
   device: mps
   architecture: bert
+  keys_to_use: # distilbert for token classification doesn't use token_type_ids
+    - input_ids
+    - attention_mask
+    - token_type_ids
   workers: 2
 css_colors:
-  - '#000000' # Black
-  - '#FF0000' # Red
-  - '#00FF00' # Lime
-  - '#0000FF' # Blue
-  - '#FFFF00' # Yellow
-  - '#00FFFF' # Cyan
-  - '#FF00FF' # Magenta
-  - '#800000' # Maroon
-  - '#808000' # Olive
-  - '#008000' # Green
-  - '#800080' # Purple
-  - '#008080' # Teal
-  - '#FFA500' # Orange
-  - '#A52A2A' # Brown
-  - '#8A2BE2' # BlueViolet
-  - '#5F9EA0' # CadetBlue
-  - '#D2691E' # Chocolate
-  - '#000080' # Navy
-  - '#FFDAB9' # PeachPuff
+  - "#000000" # Black
+  - "#FF0000" # Red
+  - "#00FF00" # Lime
+  - "#0000FF" # Blue
+  - "#FFFF00" # Yellow
+  - "#00FFFF" # Cyan
+  - "#FF00FF" # Magenta
+  - "#800000" # Maroon
+  - "#808000" # Olive
+  - "#008000" # Green
+  - "#800080" # Purple
+  - "#008080" # Teal
+  - "#FFA500" # Orange
+  - "#A52A2A" # Brown
+  - "#8A2BE2" # BlueViolet
+  - "#5F9EA0" # CadetBlue
+  - "#D2691E" # Chocolate
+  - "#000080" # Navy
+  - "#FFDAB9" # PeachPuff
 label_studio_manager:
   _target_: kazu.annotation.label_studio.LabelStudioManager
   project_name: "clean_data"
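
The shipped default.yaml keeps all three tensor keys, which matches the previous hard-coded behaviour and the fallback added in `Trainer.__init__`. For a checkpoint whose forward pass does not accept `token_type_ids` (the DistilBERT case flagged in the YAML comment), the new option lets the list be trimmed. A minimal sketch of such an override, assuming the remaining `training_config` keys stay as in default.yaml:

```yaml
# Hypothetical override for a DistilBERT-style checkpoint (illustrative only):
# token_type_ids is omitted because DistilBERT's forward() does not accept it.
training_config:
  keys_to_use:
    - input_ids
    - attention_mask
```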