fix: make parameter configurable for different architectures
paluchasz committed Oct 24, 2024
1 parent ed750c3 commit b560cf0
Showing 4 changed files with 30 additions and 20 deletions.
1 change: 1 addition & 0 deletions kazu/training/config.py
@@ -24,5 +24,6 @@ class TrainingConfig:
     test_overfit: bool
     device: str
     workers: int
+    keys_to_use: list[str]
     architecture: str = "bert"
     epoch_completion_fraction_before_evals: float = 0.75
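
The new field has no default, so in the dataclass it has to sit before the defaulted fields (architecture, epoch_completion_fraction_before_evals). A minimal standalone sketch, trimmed to just the fields visible in this hunk rather than the full kazu TrainingConfig, illustrates the ordering and a DistilBERT-style value:

from dataclasses import dataclass

@dataclass
class TrainingConfigSketch:  # illustrative subset of the real TrainingConfig
    test_overfit: bool
    device: str
    workers: int
    keys_to_use: list[str]  # non-default field, so it precedes the defaulted ones
    architecture: str = "bert"
    epoch_completion_fraction_before_evals: float = 0.75

cfg = TrainingConfigSketch(
    test_overfit=False,
    device="cpu",
    workers=2,
    keys_to_use=["input_ids", "attention_mask"],  # e.g. for a distilbert architecture
)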
6 changes: 5 additions & 1 deletion kazu/training/train_multilabel_ner.py
@@ -306,6 +306,7 @@ def __init__(
         working_dir: Path,
         summary_writer: Optional[SummaryWriter] = None,
         ls_wrapper: Optional[LSManagerViewWrapper] = None,
+        keys_to_use: Optional[list[str]] = None,
     ):

         self.ls_wrapper = ls_wrapper
@@ -321,6 +322,9 @@ def __init__(
         self.eval_dataset = eval_dataset
         self.label_list = label_list
         self.pretrained_model_name_or_path = pretrained_model_name_or_path
+        self.keys_to_use = (
+            keys_to_use if keys_to_use else ["input_ids", "attention_mask", "token_type_ids"]
+        )

     def _write_to_tensorboard(
         self, global_step: int, main_tag: str, tag_scalar_dict: dict[str, NumericMetric]
@@ -441,7 +445,7 @@ def _process_docs(self, model: PreTrainedModel) -> list[Document]:
             tokenized_word_processor=TokenizedWordProcessor(
                 labels=self.label_list, use_multilabel=True
             ),
-            keys_to_use=["input_ids", "attention_mask", "token_type_ids"],
+            keys_to_use=self.keys_to_use,
             device=self.training_config.device,
         )

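The constructor now falls back to the BERT-style keys when none are supplied, and _process_docs forwards self.keys_to_use instead of a hard-coded list. A sketch of the same pattern outside kazu (resolve_keys and filter_batch are hypothetical helpers, not functions from this repository) shows how such a list is typically used to drop tensors an architecture does not accept:

from typing import Optional

def resolve_keys(keys_to_use: Optional[list[str]]) -> list[str]:
    # Mirror the fallback added in __init__: default to the BERT-style inputs.
    return keys_to_use if keys_to_use else ["input_ids", "attention_mask", "token_type_ids"]

def filter_batch(batch: dict[str, list], keys_to_use: list[str]) -> dict[str, list]:
    # Keep only the tensors the model's forward() accepts,
    # e.g. drop token_type_ids for a distilbert checkpoint.
    return {name: tensor for name, tensor in batch.items() if name in keys_to_use}

batch = {
    "input_ids": [[101, 2023, 102]],
    "attention_mask": [[1, 1, 1]],
    "token_type_ids": [[0, 0, 0]],
}
print(filter_batch(batch, resolve_keys(["input_ids", "attention_mask"])))
# -> {'input_ids': [[101, 2023, 102]], 'attention_mask': [[1, 1, 1]]}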
1 change: 1 addition & 0 deletions kazu/training/train_script.py
@@ -126,6 +126,7 @@ def run(cfg: DictConfig) -> None:
         working_dir=output_dir,
         summary_writer=SummaryWriter(log_dir=str(tb_path.absolute())),
         ls_wrapper=wrapper,
+        keys_to_use=training_config.keys_to_use,
     )
     trainer.train_model()

42 changes: 23 additions & 19 deletions scripts/examples/conf/multilabel_ner_training/default.yaml
@@ -20,27 +20,31 @@ training_config:
   test_overfit: false
   device: mps
   architecture: bert
+  keys_to_use: #distilbert for token classification doesn't use token_type_ids
+    - input_ids
+    - attention_mask
+    - token_type_ids
   workers: 2
 css_colors:
-  - '#000000' # Black
-  - '#FF0000' # Red
-  - '#00FF00' # Lime
-  - '#0000FF' # Blue
-  - '#FFFF00' # Yellow
-  - '#00FFFF' # Cyan
-  - '#FF00FF' # Magenta
-  - '#800000' # Maroon
-  - '#808000' # Olive
-  - '#008000' # Green
-  - '#800080' # Purple
-  - '#008080' # Teal
-  - '#FFA500' # Orange
-  - '#A52A2A' # Brown
-  - '#8A2BE2' # BlueViolet
-  - '#5F9EA0' # CadetBlue
-  - '#D2691E' # Chocolate
-  - '#000080' # Navy
-  - '#FFDAB9' # PeachPuff
+  - "#000000" # Black
+  - "#FF0000" # Red
+  - "#00FF00" # Lime
+  - "#0000FF" # Blue
+  - "#FFFF00" # Yellow
+  - "#00FFFF" # Cyan
+  - "#FF00FF" # Magenta
+  - "#800000" # Maroon
+  - "#808000" # Olive
+  - "#008000" # Green
+  - "#800080" # Purple
+  - "#008080" # Teal
+  - "#FFA500" # Orange
+  - "#A52A2A" # Brown
+  - "#8A2BE2" # BlueViolet
+  - "#5F9EA0" # CadetBlue
+  - "#D2691E" # Chocolate
+  - "#000080" # Navy
+  - "#FFDAB9" # PeachPuff
 label_studio_manager:
   _target_: kazu.annotation.label_studio.LabelStudioManager
   project_name: "clean_data"
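
The YAML comment is the motivation for the whole change: different Hugging Face tokenizers emit different input keys, so a hard-coded list breaks on some architectures. A small check, assuming the transformers package and the illustrative checkpoints bert-base-cased and distilbert-base-cased, shows the difference:

from transformers import AutoTokenizer

# BERT-family tokenizers emit token_type_ids; DistilBERT ones do not,
# which is why keys_to_use is now configurable per architecture.
for name in ("bert-base-cased", "distilbert-base-cased"):
    tokenizer = AutoTokenizer.from_pretrained(name)
    print(name, tokenizer.model_input_names)
# bert-base-cased ['input_ids', 'token_type_ids', 'attention_mask']
# distilbert-base-cased ['input_ids', 'attention_mask']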
