Commit

pass tokenizer path through args, add optim lr scheduler to config, add lr argparse
MenuaB committed Mar 27, 2024
1 parent 4881ddc commit f939b39
Showing 6 changed files with 29 additions and 13 deletions.
2 changes: 2 additions & 0 deletions chemlactica/config/default_train_config.py
@@ -21,6 +21,8 @@ class TrainConfig:
     max_learning_rate: float = 6.0e-4
     warmup_steps: int = 500
     weight_decay: float = 0.1
+    optimizer: str = "adamw_torch"
+    lr_scheduler_type: str = "linear"
 
 
 @dataclass
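
These two fields replace strings that were previously hard-coded in train.py (see the train.py hunks below). A minimal, runnable sketch of the affected part of the dataclass, showing only the fields visible in this hunk (the real TrainConfig has more):

from dataclasses import dataclass


@dataclass
class TrainConfig:
    max_learning_rate: float = 6.0e-4
    warmup_steps: int = 500
    weight_decay: float = 0.1
    optimizer: str = "adamw_torch"      # forwarded as TrainingArguments(optim=...)
    lr_scheduler_type: str = "linear"   # forwarded as TrainingArguments(lr_scheduler_type=...)


print(TrainConfig().optimizer, TrainConfig().lr_scheduler_type)  # adamw_torch linear
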
7 changes: 5 additions & 2 deletions chemlactica/custom_trainer.py
@@ -27,14 +27,17 @@ class CustomArguments(TrainingArguments):
     )
     command: str = field(default=None)
     experiment_name: str = field(default=None)
+    tokenizer_path: str = field(
+        default="/auto/home/menuab/code/ChemLactica/chemlactica/tokenizer/ChemLacticaTokenizer66"
+    )
     # train_config: dict = field(default=None)
 
 
 class CustomTrainer(Trainer):
-    def __init__(self, tokenizer_path, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
         # the number of samples to print when the training begins, for debugging purposes
         self.num_samples_to_print = 5
-        self.tokenizer_path = tokenizer_path
+        self.tokenizer_path = kwargs["args"].tokenizer_path
         super().__init__(*args, **kwargs)
 
     def training_step(self, model: Module, inputs: Dict[str, Tensor | Any]) -> Tensor:
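
The trainer now reads the path off the args object (the CustomArguments / TrainingArguments instance) instead of taking a dedicated constructor parameter. A standalone sketch of that lookup with a mock object (not the repository's classes):

from types import SimpleNamespace

# Stand-in for the CustomArguments instance that get_trainer passes as args=
mock_args = SimpleNamespace(tokenizer_path="chemlactica/tokenizer/ChemLacticaTokenizer66")
kwargs = {"args": mock_args}
tokenizer_path = kwargs["args"].tokenizer_path  # what CustomTrainer.__init__ now does
print(tokenizer_path)

Note that this lookup relies on args being passed as a keyword argument, which is how get_trainer calls CustomTrainer (args=training_args) in the hunk below.
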
8 changes: 3 additions & 5 deletions chemlactica/get_trainer.py
@@ -7,13 +7,11 @@
 from config.default_train_config import SFTTrainConfig
 
 
-def get_trainer(
-    train_type, model, model_config, dataset, training_args, evaluate_only, slurm_eval
-):
+def get_trainer(train_type, model, dataset, training_args, evaluate_only, slurm_eval):
     if train_type == "pretrain":
         trainer = CustomTrainer(
             model=model,
-            tokenizer_path=model_config.tokenizer_path,
+            # tokenizer_path=model_config.tokenizer_path,
             args=training_args,
             # compute_metrics=compute_metrics,
             train_dataset=dataset["train"] if not evaluate_only else None,
Expand All @@ -26,7 +24,7 @@ def get_trainer(

elif train_type == "sft":
sft_config = SFTTrainConfig()
tokenizer = get_tokenizer(model_config.tokenizer_path)
tokenizer = get_tokenizer(training_args.tokenizer_path)
response_template = "[PROPERTY]activity "
collator = DataCollatorForCompletionOnlyLM(
response_template, tokenizer=tokenizer
Expand Down
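
The call shape is now one argument shorter, since the tokenizer path rides on training_args. A signature-only stub (illustrative, not the project's function) showing how callers invoke it after the change:

def get_trainer_stub(train_type, model, dataset, training_args, evaluate_only, slurm_eval):
    # Same parameter list as the updated get_trainer; model_config is no longer threaded through.
    return train_type

print(get_trainer_stub("pretrain", None, None, None, False, False))  # pretrain
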
11 changes: 7 additions & 4 deletions chemlactica/train.py
@@ -64,6 +64,7 @@ def train(
     training_data_dirs,
     dir_data_types,
     valid_data_dir,
+    learning_rate,
     scheduler_max_steps,
     eval_steps,
     save_steps,
@@ -218,6 +219,7 @@ def train(
         slurm_eval=slurm_eval,
         experiment_name=experiment_name,
         # train_config=train_config,
+        tokenizer_path=model_config.tokenizer_path,
         do_train=not evaluate_only,
         output_dir=checkpoints_dir,
         per_device_train_batch_size=train_batch_size,
@@ -228,7 +230,9 @@
         bf16_full_eval=True,
         fp16=False,
         logging_dir=track_dir,
-        learning_rate=train_config.max_learning_rate,
+        learning_rate=learning_rate
+        if learning_rate
+        else train_config.max_learning_rate,
         weight_decay=train_config.weight_decay,
         adam_beta1=train_config.adam_beta1,
         adam_beta2=train_config.adam_beta2,
@@ -251,8 +255,8 @@ def train(
         # gradient_accumulation_steps=gradient_accumulation_steps,
         # save_total_limit=4, in order for offline eval to work, we keep all of them for now
         resume_from_checkpoint=resume_from_checkpoint,
-        lr_scheduler_type="linear",
-        optim="adamw_torch",
+        lr_scheduler_type=train_config.lr_scheduler_type,
+        optim=train_config.optimizer,
         # load_best_model=True
     )
 
@@ -271,7 +275,6 @@ def train(
     trainer = get_trainer(
         train_type,
         model,
-        model_config,
         dataset,
         training_args,
         evaluate_only,
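
The resulting precedence: a CLI-supplied learning rate wins, otherwise the config default is used. The override in isolation (variable names are illustrative; the values are this commit's defaults):

cli_learning_rate = None           # argparse default when --learning_rate is omitted
config_max_learning_rate = 6.0e-4  # TrainConfig.max_learning_rate
effective_lr = cli_learning_rate if cli_learning_rate else config_max_learning_rate
print(effective_lr)  # 0.0006

Because the check is on truthiness rather than "is not None", an explicit value of 0 would also fall back to the config default.
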
9 changes: 9 additions & 0 deletions chemlactica/utils/parseargs.py
@@ -53,6 +53,15 @@ def init_parser():
         required=True,
         help="path to directory containing validation data",
     )
+    parser.add_argument(
+        "--learning_rate",
+        type=int,
+        metavar="LR",
+        dest="learning_rate",
+        required=False,
+        default=None,
+        help="learning rate",
+    )
     parser.add_argument(
         "--max_steps",
         type=int,
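
A standalone sketch of the pattern added above, using a throwaway parser rather than the project's init_parser so it runs on its own. Note that the flag is declared with type=int in this commit, so only integer values parse:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--learning_rate",
    type=int,
    metavar="LR",
    dest="learning_rate",
    required=False,
    default=None,
    help="learning rate",
)

print(parser.parse_args(["--learning_rate", "5"]).learning_rate)  # 5
print(parser.parse_args([]).learning_rate)  # None -> train() falls back to the config value
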
5 changes: 3 additions & 2 deletions submit_run_galactica_pre.py
@@ -5,7 +5,7 @@
 
 use_accelerate = True
 rsync_enabled = False
-executor_name = "local" # options are ["slurm", "local"]
+executor_name = "slurm" # options are ["slurm", "local"]
 root_path = ""
 num_gpus = 2
 model_name = "galactica"
@@ -40,9 +40,10 @@
     "valid_data_dir": "/nfs/ap/mnt/sxtn/rdkit_computed_rel+form/valid_rdkit_computed_rel+form",
     "max_steps": 120000,
     # "num_train_epochs": 15,
+    # "learning_rate": 5,
     "eval_steps": 1000,
     "save_steps": 1000,
-    "train_batch_size": 16,
+    "train_batch_size": 2,
     # "valid_batch_size": 16,
     "dataloader_num_workers": 30,
     "experiment_name": "freesolv_30e",
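
Submit scripts like this one typically flatten such a dict into CLI flags for train.py's parser; the helper below is a hypothetical illustration of that mapping (make_cli_args is not a function from the repository), showing how an uncommented "learning_rate": 5 entry would reach the new --learning_rate flag:

def make_cli_args(params: dict) -> list:
    # Illustrative only: turn {"max_steps": 120000, ...} into ["--max_steps", "120000", ...]
    flags = []
    for key, value in params.items():
        flags += [f"--{key}", str(value)]
    return flags

print(make_cli_args({"max_steps": 120000, "learning_rate": 5, "train_batch_size": 2}))
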
