Skip to content
This repository was archived by the owner on Dec 15, 2024. It is now read-only.

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
morioka committed Sep 12, 2022
1 parent e02721d commit c3b531a
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions run_qg.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,14 +178,16 @@ def main(args_file=None):
using_tpu=training_args.tpu_num_cores is not None
)

# Prediction Loss
training_args.prediction_loss_only=True

# Initialize our Trainer
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=valid_dataset,
data_collator=data_collator,
prediction_loss_only=True,
label_smoothing=model_args.label_smoothing
)

Expand All @@ -200,8 +202,8 @@ def main(args_file=None):
trainer.save_model()
# For convenience, we also re-save the tokenizer to the same directory,
# so that you can share your model easily on huggingface.co/models =)
if trainer.is_world_master():
tokenizer.save_pretrained(training_args.output_dir)
if trainer.is_world_process_zero():
tokenizer.save_pretrained(training_args.output_dir)

# Evaluation
results = {}
Expand Down Expand Up @@ -233,4 +235,4 @@ def run_qg(args_dict):
main(args_file="args.json")

if __name__ == "__main__":
main()
main()

0 comments on commit c3b531a

Please sign in to comment.