From e4d9813dc1832753cedd57d856f699cfa489271c Mon Sep 17 00:00:00 2001 From: Sami Virpioja Date: Wed, 29 May 2024 17:33:46 +0300 Subject: [PATCH] add collect_steps option to SwagUpdateCallback --- examples/bert_snli.py | 14 +++++---- examples/marian_mt.py | 15 ++++++---- src/swag_transformers/trainer_utils.py | 39 ++++++++++++++++++++++++-- tests/test_swag_marian.py | 12 ++++---- 4 files changed, 61 insertions(+), 19 deletions(-) diff --git a/examples/bert_snli.py b/examples/bert_snli.py index 97ba8ae..3719919 100644 --- a/examples/bert_snli.py +++ b/examples/bert_snli.py @@ -23,6 +23,10 @@ def main(): parser.add_argument("--limit-training", type=int, help="limit training data to N first samples") parser.add_argument("--data-cache-dir", type=str, help="folder to cache HF datasets") parser.add_argument("--model-cache-dir", type=str, help="folder to cache HF models and tokenizers") + parser.add_argument("--batch-size", type=int, default=4, help="batch size") + parser.add_argument("--epochs", type=int, default=3, help="number of training epochs") + parser.add_argument("--collect-steps", type=int, default=100, help="number of steps between collecting parameters") + parser.add_argument("--learning-rate", type=float, default=2e-5, help="learning rate") args = parser.parse_args() if args.device: @@ -49,10 +53,10 @@ def tokenize_dataset(dataset): training_args = transformers.TrainingArguments( output_dir=args.save_folder, - learning_rate=2e-5, - per_device_train_batch_size=4, - per_device_eval_batch_size=4, - num_train_epochs=3, + learning_rate=args.learning_rate, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, + num_train_epochs=args.epochs, use_cpu=True if device == "cpu" else False ) @@ -66,7 +70,7 @@ def tokenize_dataset(dataset): train_dataset=processed_dataset, tokenizer=tokenizer, data_collator=data_collator, - callbacks=[SwagUpdateCallback(swag_model)] + callbacks=[SwagUpdateCallback(swag_model, collect_steps=args.collect_steps)] ) trainer.train() trainer.save_model(os.path.join(args.save_folder, "final_base")) diff --git a/examples/marian_mt.py b/examples/marian_mt.py index 10075ef..59830fd 100644 --- a/examples/marian_mt.py +++ b/examples/marian_mt.py @@ -21,7 +21,10 @@ def main(): parser.add_argument("--device", type=str, help="set device (default: cuda if available, otherwise cpu)") parser.add_argument("--save-folder", type=str, default="save_folder") parser.add_argument("--limit-training", type=int, help="limit training data to N first samples") + parser.add_argument("--batch-size", type=int, default=4, help="batch size") parser.add_argument("--epochs", type=int, default=3, help="number of training epochs") + parser.add_argument("--collect-steps", type=int, default=100, help="number of steps between collecting parameters") + parser.add_argument("--learning-rate", type=float, default=2e-5, help="learning rate") args = parser.parse_args() if args.device: @@ -60,22 +63,22 @@ def tokenize_function(example): tokenizer=tokenizer, model=model ) - training_args = transformers.TrainingArguments( + training_args = transformers.Seq2SeqTrainingArguments( output_dir=args.save_folder, - learning_rate=2e-5, - per_device_train_batch_size=4, - per_device_eval_batch_size=4, + learning_rate=args.learning_rate, + per_device_train_batch_size=args.batch_size, + per_device_eval_batch_size=args.batch_size, num_train_epochs=args.epochs, use_cpu=True if device == "cpu" else False ) - trainer = transformers.Trainer( + trainer = transformers.Seq2SeqTrainer( model=model, args=training_args, train_dataset=tokenized_datasets, tokenizer=tokenizer, data_collator=data_collator, - callbacks=[SwagUpdateCallback(swag_model)] + callbacks=[SwagUpdateCallback(swag_model, collect_steps=args.collect_steps)] ) trainer.train() trainer.save_model(os.path.join(args.save_folder, "final_base")) diff --git a/src/swag_transformers/trainer_utils.py b/src/swag_transformers/trainer_utils.py index 78a2132..bae41c8 100644 --- a/src/swag_transformers/trainer_utils.py +++ b/src/swag_transformers/trainer_utils.py @@ -23,15 +23,50 @@ class SwagUpdateCallback(TrainerCallback): ) trainer.train() + Two possible schedules for the updates are currently supported: If + collect_steps > 0 is provided, the parameters are collected after + each collect_steps training steps. Otherwise, and as default, the + parameters are collected on the end of each training epoch. + """ - def __init__(self, swag_model): + def __init__(self, swag_model, collect_steps=None): self.main_model = swag_model + self.collect_steps = collect_steps + self.last_collect_step = None + + def on_train_end(self, args, state, control, model=None, **kwargs): + if self.last_collect_step == state.global_step: + return + if model is None: + logger.error("No model provided for SWAG update") + return + logger.debug("Updating SWAG parameters from %s after train end (steps %s)", type(model).__name__, state.global_step) + self.main_model.swag.collect_model(model) + self.main_model.config.update_internal_config(model.config) def on_epoch_end(self, args, state, control, model=None, **kwargs): + if self.collect_steps: + return + if model is None: + logger.error("No model provided for SWAG update") + return + logger.debug("Updating SWAG parameters from %s after epoch end (steps %s)", type(model).__name__, state.global_step) + self.main_model.swag.collect_model(model) + self.main_model.config.update_internal_config(model.config) + self.last_collect_step = state.global_step + + def on_step_end(self, args, state, control, model=None, **kwargs): + if not self.collect_steps: + return + if not state.global_step: + return + if state.global_step % self.collect_steps != 0: + return if model is None: logger.error("No model provided for SWAG update") return - logger.debug("Updating SWAG parameters from %s", type(model).__name__) + logger.debug("Updating SWAG parameters from %s after step %s", type(model).__name__, state.global_step) self.main_model.swag.collect_model(model) self.main_model.config.update_internal_config(model.config) + self.last_collect_step = state.global_step diff --git a/tests/test_swag_marian.py b/tests/test_swag_marian.py index ee0e3cf..0997fd7 100644 --- a/tests/test_swag_marian.py +++ b/tests/test_swag_marian.py @@ -150,7 +150,7 @@ def tokenize_function(example): data_collator = DataCollatorForSeq2Seq( tokenizer=tokenizer, model=model ) - train_epochs = 5 + train_epochs = 60 logging.debug(model.lm_head.weight) logging.debug(model.model.encoder.embed_tokens.weight) logging.debug(model.model.decoder.embed_tokens.weight) @@ -166,14 +166,14 @@ def tokenize_function(example): train_dataset=tokenized_datasets["train"], data_collator=data_collator, tokenizer=tokenizer, - callbacks=[SwagUpdateCallback(swag_model)] + callbacks=[SwagUpdateCallback(swag_model, collect_steps=2)] ) trainer.train() - self.assertEqual(swag_model.swag.n_models, train_epochs) + logging.info("N models: %s", swag_model.swag.n_models.item()) + # self.assertEqual(swag_model.swag.n_models, train_epochs) swag_model.swag.sample() - sample_text = "what is so great ?" - - batch = tokenizer([sample_text], return_tensors="pt").to(device) + sample_text = "India and Japan prime ministers meet in Tokyo" + batch = tokenizer([sample_text], return_tensors="pt") generated_ids = model.generate(**batch, max_new_tokens=10) base_output = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0] logging.debug(base_output)