add collect_steps option to SwagUpdateCallback
svirpioj committed May 29, 2024
1 parent 8763f31 · commit e4d9813
Showing 4 changed files with 61 additions and 19 deletions.
examples/bert_snli.py (14 changes: 9 additions & 5 deletions)
@@ -23,6 +23,10 @@ def main():
     parser.add_argument("--limit-training", type=int, help="limit training data to N first samples")
     parser.add_argument("--data-cache-dir", type=str, help="folder to cache HF datasets")
     parser.add_argument("--model-cache-dir", type=str, help="folder to cache HF models and tokenizers")
+    parser.add_argument("--batch-size", type=int, default=4, help="batch size")
+    parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
+    parser.add_argument("--collect-steps", type=int, default=100, help="number of steps between collecting parameters")
+    parser.add_argument("--learning-rate", type=float, default=2e-5, help="learning rate")
     args = parser.parse_args()

     if args.device:

@@ -49,10 +53,10 @@ def tokenize_dataset(dataset):

     training_args = transformers.TrainingArguments(
         output_dir=args.save_folder,
-        learning_rate=2e-5,
-        per_device_train_batch_size=4,
-        per_device_eval_batch_size=4,
-        num_train_epochs=3,
+        learning_rate=args.learning_rate,
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.batch_size,
+        num_train_epochs=args.epochs,
         use_cpu=True if device == "cpu" else False
     )

@@ -66,7 +70,7 @@ def tokenize_dataset(dataset):
         train_dataset=processed_dataset,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=[SwagUpdateCallback(swag_model)]
+        callbacks=[SwagUpdateCallback(swag_model, collect_steps=args.collect_steps)]
     )
     trainer.train()
     trainer.save_model(os.path.join(args.save_folder, "final_base"))
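As a quick reference for the new options: the number of times SWAG parameters get collected in a run now depends on dataset size, batch size, epochs, and collect_steps rather than on epochs alone. A rough sketch of the arithmetic, using the new defaults and a hypothetical dataset size (the callback logic it mirrors is in the trainer_utils.py diff further down):

# Hypothetical values: the defaults added above plus an assumed dataset size.
# Single device, no gradient accumulation.
num_train_samples = 10_000   # assumed; not part of the example script
batch_size = 4               # --batch-size default
epochs = 3                   # --epochs default
collect_steps = 100          # --collect-steps default

steps_per_epoch = -(-num_train_samples // batch_size)  # ceiling division
total_steps = steps_per_epoch * epochs
# One collection every collect_steps optimizer steps, plus one at train end
# if the last step was not itself a collection step.
n_collected = total_steps // collect_steps + (1 if total_steps % collect_steps else 0)
print(steps_per_epoch, total_steps, n_collected)  # 2500 7500 75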
examples/marian_mt.py (15 changes: 9 additions & 6 deletions)
@@ -21,7 +21,10 @@ def main():
     parser.add_argument("--device", type=str, help="set device (default: cuda if available, otherwise cpu)")
     parser.add_argument("--save-folder", type=str, default="save_folder")
     parser.add_argument("--limit-training", type=int, help="limit training data to N first samples")
+    parser.add_argument("--batch-size", type=int, default=4, help="batch size")
     parser.add_argument("--epochs", type=int, default=3, help="number of training epochs")
+    parser.add_argument("--collect-steps", type=int, default=100, help="number of steps between collecting parameters")
+    parser.add_argument("--learning-rate", type=float, default=2e-5, help="learning rate")
     args = parser.parse_args()

     if args.device:

@@ -60,22 +63,22 @@ def tokenize_function(example):
         tokenizer=tokenizer, model=model
     )

-    training_args = transformers.TrainingArguments(
+    training_args = transformers.Seq2SeqTrainingArguments(
         output_dir=args.save_folder,
-        learning_rate=2e-5,
-        per_device_train_batch_size=4,
-        per_device_eval_batch_size=4,
+        learning_rate=args.learning_rate,
+        per_device_train_batch_size=args.batch_size,
+        per_device_eval_batch_size=args.batch_size,
         num_train_epochs=args.epochs,
         use_cpu=True if device == "cpu" else False
     )

-    trainer = transformers.Trainer(
+    trainer = transformers.Seq2SeqTrainer(
         model=model,
         args=training_args,
         train_dataset=tokenized_datasets,
         tokenizer=tokenizer,
         data_collator=data_collator,
-        callbacks=[SwagUpdateCallback(swag_model)]
+        callbacks=[SwagUpdateCallback(swag_model, collect_steps=args.collect_steps)]
     )
     trainer.train()
     trainer.save_model(os.path.join(args.save_folder, "final_base"))
src/swag_transformers/trainer_utils.py (39 changes: 37 additions & 2 deletions)
@@ -23,15 +23,50 @@ class SwagUpdateCallback(TrainerCallback):
         )
         trainer.train()

+    Two possible schedules for the updates are currently supported: if
+    collect_steps > 0 is provided, the parameters are collected every
+    collect_steps training steps. Otherwise, and by default, the
+    parameters are collected at the end of each training epoch.
+
     """

-    def __init__(self, swag_model):
+    def __init__(self, swag_model, collect_steps=None):
         self.main_model = swag_model
+        self.collect_steps = collect_steps
+        self.last_collect_step = None

+    def on_train_end(self, args, state, control, model=None, **kwargs):
+        if self.last_collect_step == state.global_step:
+            return
+        if model is None:
+            logger.error("No model provided for SWAG update")
+            return
+        logger.debug("Updating SWAG parameters from %s after train end (steps %s)", type(model).__name__, state.global_step)
+        self.main_model.swag.collect_model(model)
+        self.main_model.config.update_internal_config(model.config)
+
     def on_epoch_end(self, args, state, control, model=None, **kwargs):
+        if self.collect_steps:
+            return
+        if model is None:
+            logger.error("No model provided for SWAG update")
+            return
+        logger.debug("Updating SWAG parameters from %s after epoch end (steps %s)", type(model).__name__, state.global_step)
+        self.main_model.swag.collect_model(model)
+        self.main_model.config.update_internal_config(model.config)
+        self.last_collect_step = state.global_step
+
+    def on_step_end(self, args, state, control, model=None, **kwargs):
+        if not self.collect_steps:
+            return
+        if not state.global_step:
+            return
+        if state.global_step % self.collect_steps != 0:
+            return
         if model is None:
             logger.error("No model provided for SWAG update")
             return
-        logger.debug("Updating SWAG parameters from %s", type(model).__name__)
+        logger.debug("Updating SWAG parameters from %s after step %s", type(model).__name__, state.global_step)
         self.main_model.swag.collect_model(model)
         self.main_model.config.update_internal_config(model.config)
+        self.last_collect_step = state.global_step
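For context, a minimal sketch of how the two schedules are chosen from user code. This is not runnable on its own: it assumes model, swag_model, training_args, and train_dataset have already been set up, as in the examples above.

import transformers

from swag_transformers.trainer_utils import SwagUpdateCallback

# Default schedule: SWAG parameters are collected at the end of every epoch.
epoch_callback = SwagUpdateCallback(swag_model)

# Step-based schedule: collect every 100 optimizer steps, plus once more at
# the end of training if the final step was not a collection step.
step_callback = SwagUpdateCallback(swag_model, collect_steps=100)

trainer = transformers.Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[step_callback],  # or [epoch_callback]
)
trainer.train()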
tests/test_swag_marian.py (12 changes: 6 additions & 6 deletions)
@@ -150,7 +150,7 @@ def tokenize_function(example):
         data_collator = DataCollatorForSeq2Seq(
             tokenizer=tokenizer, model=model
         )
-        train_epochs = 5
+        train_epochs = 60
         logging.debug(model.lm_head.weight)
         logging.debug(model.model.encoder.embed_tokens.weight)
         logging.debug(model.model.decoder.embed_tokens.weight)

@@ -166,14 +166,14 @@ def tokenize_function(example):
             train_dataset=tokenized_datasets["train"],
             data_collator=data_collator,
             tokenizer=tokenizer,
-            callbacks=[SwagUpdateCallback(swag_model)]
+            callbacks=[SwagUpdateCallback(swag_model, collect_steps=2)]
         )
         trainer.train()
-        self.assertEqual(swag_model.swag.n_models, train_epochs)
+        logging.info("N models: %s", swag_model.swag.n_models.item())
+        # self.assertEqual(swag_model.swag.n_models, train_epochs)
         swag_model.swag.sample()
-        sample_text = "what is so great ?"
-
-        batch = tokenizer([sample_text], return_tensors="pt").to(device)
+        sample_text = "India and Japan prime ministers meet in Tokyo"
+        batch = tokenizer([sample_text], return_tensors="pt")
         generated_ids = model.generate(**batch, max_new_tokens=10)
         base_output = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)[0]
         logging.debug(base_output)
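A note on the commented-out assertion: with collect_steps=2 the number of collected models depends on the number of optimizer steps rather than on train_epochs, which is why the test now only logs n_models. Based on the callback logic added in this commit, the expected count could be computed as in this sketch (a hypothetical helper, not part of the test):

def expected_n_models(total_steps: int, collect_steps: int) -> int:
    """Collections made by SwagUpdateCallback when collect_steps > 0: one every
    collect_steps optimizer steps (on_step_end), plus one more at train end
    (on_train_end) unless the final step already triggered a collection."""
    n, remainder = divmod(total_steps, collect_steps)
    return n + (1 if remainder else 0)

# e.g., after trainer.train():
#     assert swag_model.swag.n_models.item() == expected_n_models(trainer.state.global_step, 2)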
