Reorder steps in pre-train and finetune scripts
paololucchino committed Aug 12, 2024
1 parent 0a2e319 commit b11d595
Showing 3 changed files with 41 additions and 54 deletions.
38 changes: 11 additions & 27 deletions SeasonTST/SeasonTST_finetune.py
@@ -39,7 +39,7 @@
datefmt="%m/%d/%Y %I:%M:%S %p",
filename=f'logs/{datetime.datetime.now().strftime("%Y_%m_%d_%I:%M")}_finetune.log',
encoding="utf-8",
level=logging.DEBUG,
level=logging.INFO,
)


@@ -51,10 +51,6 @@
def finetune_func(learner, save_path, args, lr=0.001):
print("end-to-end finetuning")

if not os.path.exists(save_path):
os.makedirs(save_path)

print(save_path)
# fit the data to the model and save
learner.fine_tune(
n_epochs=args.n_epochs_finetune, base_lr=lr, freeze_epochs=args.freeze_epochs
@@ -107,20 +103,6 @@ def save_recorders(learner, args):
)


def test_func(weight_path, learner, args, dls):

out = learner.test(
dls.test, weight_path=weight_path, scores=[mse, mae]
) # out: a list of [pred, targ, score]
print("score:", out[2])
# save results
pd.DataFrame(np.array(out[2]).reshape(1, -1), columns=["mse", "mae"]).to_csv(
args.save_path + args.save_finetuned_model + "_acc.csv",
float_format="%.6f",
index=False,
)
return out


def load_config():

@@ -135,13 +117,14 @@
"revin": 0, # reversible instance normalization
"mask_ratio": 0.4, # masking ratio for the input
"lr": 1e-3,
"batch_size": 128,
"batch_size": 64,
"drop_last": False,
"num_workers": 6,
"prefetch_factor": 3,
"n_epochs_pretrain": 1, # number of pre-training epochs,
"n_epochs_pretrain": 20, # number of pre-training epochs,
"freeze_epochs": 0,
"n_epochs_finetune": 250,
"pretrained_model_id": 2500, # id of the saved pretrained model
"n_epochs_finetune": 10,
"pretrained_model_id": 2, # id of the saved pretrained model
"save_finetuned_model": "./finetuned_d128",
"save_path": "saved_models" + "/masked_patchtst/",
}
Expand Down Expand Up @@ -186,17 +169,18 @@ def main():
# Create dataloader
dls = get_dls(config_obj, SeasonTST_Dataset, data, mask)

# suggested_lr = find_lr(config_obj, dls)
# This is what I got on a small dataset. In case one wants to skip this for testing.
suggested_lr = 0.00017073526474706903
suggested_lr = 0.0002 # 0.000298364724028334
learner = get_learner(config_obj, dls, suggested_lr, model)
suggested_lr = learner.lr_finder()
print(suggested_lr)

learner = get_learner(config_obj, dls, suggested_lr, model)

# This function will save the model weights to config_obj.save_finetuned_model, i.e. it will not overwrite the pretrained model.
# However, there is currently no setup for fine-tuning from the result of a previous fine-tuning.
# To continue training from a previous fine-tuning checkpoint, the path needs to be explicitly fed to the get_model function.
finetune_func(learner, pretrained_model_path, config_obj, suggested_lr)


if __name__ == "__main__":
# PYTHONPATH=$(pwd) python SeasonTST/SeasonTST_finetune.py
main()
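
For reference on the checkpoint comment above: a minimal, hypothetical sketch of how a previous fine-tuning run could be resumed by feeding its checkpoint path explicitly to get_model (signature as in SeasonTST/utils.py below). The checkpoint filename and the reuse of the default head type are assumptions, and config_obj, dls, suggested_lr and pretrained_model_path are taken as already defined in main().

# Hypothetical resume-from-fine-tune sketch; not part of this commit.
# The filename is an assumption based on config_obj.save_finetuned_model.
prev_finetune_ckpt = "saved_models/masked_patchtst/finetuned_d128.pth"
model = get_model(
    config_obj,
    weights_path=prev_finetune_ckpt,  # load the earlier fine-tuned weights
    exclude_head=False,               # keep the fine-tuned head as well
)  # headtype is left at its default; the finetune script's actual head type is not shown in this diff
learner = get_learner(config_obj, dls, suggested_lr, model)
finetune_func(learner, pretrained_model_path, config_obj, suggested_lr)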
52 changes: 29 additions & 23 deletions SeasonTST/SeasonTST_pretrain.py
@@ -29,7 +29,7 @@
datefmt="%m/%d/%Y %I:%M:%S %p",
filename=f'logs/{datetime.datetime.now().strftime("%Y_%m_%d_%I_%M")}_train.log',
encoding="utf-8",
level=logging.DEBUG,
level=logging.INFO,
)


@@ -95,10 +95,11 @@ def load_config():
"mask_value": -99, # Value to assign to masked elements of data input
"lr": 1e-3,
"batch_size": 128,
"drop_last":True,
"prefetch_factor": 3,
"num_workers": 6,
"n_epochs_pretrain": 1, # number of pre-training epochs
"pretrained_model_id": 2500, # id of the saved pretrained model
"n_epochs_pretrain": 20, # number of pre-training epochs
"pretrained_model_id": 2, # id of the saved pretrained model
}

config_obj = SimpleNamespace(**config)
@@ -109,37 +110,42 @@ def main():
data, mask = load_data()
config_obj = load_config()

save_path = "saved_models" + "/masked_patchtst/"
pretrained_model = (
"patchtst_pretrained_cw"
+ str(config_obj.sequence_length)
+ "_patch"
+ str(config_obj.patch_len)
+ "_stride"
+ str(config_obj.stride)
+ "_epochs-pretrain"
+ str(config_obj.n_epochs_pretrain)
+ "_mask"
+ str(config_obj.mask_ratio)
+ "_model"
+ str(config_obj.pretrained_model_id)
)
pretrained_model_path = save_path + pretrained_model + ".pth"

# Creates train, valid and test datasets for one epoch. Notice that they are in different locations!
dls = get_dls(config_obj, SeasonTST_Dataset, data, mask)

model = get_model(config_obj)

model = get_model(
config_obj, headtype="pretrain", weights_path=pretrained_model_path, exclude_head=False
)

# suggested_lr = find_lr(config_obj, dls)
# This is what I got on a small dataset. In case one wants to skip this for testing.
suggested_lr = 0.00020565123083486514

save_pretrained_model = (
"patchtst_pretrained_cw"
+ str(config_obj.sequence_length)
+ "_patch"
+ str(config_obj.patch_len)
+ "_stride"
+ str(config_obj.stride)
+ "_epochs-pretrain"
+ str(config_obj.n_epochs_pretrain)
+ "_mask"
+ str(config_obj.mask_ratio)
+ "_model"
+ str(config_obj.pretrained_model_id)
)
save_path = "saved_models" + "/masked_patchtst/"


pretrain_func(
save_pretrained_model, save_path, config_obj, model, dls, suggested_lr
pretrained_model, save_path, config_obj, model, dls, suggested_lr
)

pretrained_model_name = save_path + save_pretrained_model + ".pth"

model = transfer_weights(pretrained_model_name, model)
model = transfer_weights(pretrained_model_path, model)


if __name__ == "__main__":
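
As a side note on the checkpoint naming above, here is a small, self-contained sketch of how the filename resolves. Only n_epochs_pretrain=20 and pretrained_model_id=2 come from load_config() in this diff; sequence_length, patch_len, stride and mask_ratio are assumed example values (mask_ratio=0.4 mirrors the finetune config).

from types import SimpleNamespace

# Assumed example values for the fields not visible in this hunk
cfg = SimpleNamespace(
    sequence_length=36, patch_len=4, stride=4,    # assumptions
    n_epochs_pretrain=20, pretrained_model_id=2,  # from load_config()
    mask_ratio=0.4,                               # assumption (matches the finetune config)
)
pretrained_model = (
    f"patchtst_pretrained_cw{cfg.sequence_length}"
    f"_patch{cfg.patch_len}_stride{cfg.stride}"
    f"_epochs-pretrain{cfg.n_epochs_pretrain}"
    f"_mask{cfg.mask_ratio}_model{cfg.pretrained_model_id}"
)
pretrained_model_path = "saved_models" + "/masked_patchtst/" + pretrained_model + ".pth"
print(pretrained_model_path)
# saved_models/masked_patchtst/patchtst_pretrained_cw36_patch4_stride4_epochs-pretrain20_mask0.4_model2.pth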
5 changes: 1 addition & 4 deletions SeasonTST/utils.py
@@ -75,17 +75,14 @@ def get_model(config, headtype="pretrain", weights_path=None, exclude_head=True)
return model


def find_lr(config_obj, dls):
def find_lr(model, config_obj, dls):
"""
# This method typically involves training the model for a few epochs with a range of learning rates and recording
the loss at each step. The learning rate that gives the fastest decrease in loss is considered optimal or
near-optimal for the training process.
:param config_obj:
:return:
"""

model = get_model(config_obj)
# get loss
loss_func = torch.nn.MSELoss(reduction="mean")
# get callbacks
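
With the new signature the caller builds the model first and passes it in, rather than find_lr constructing one internally. A rough sketch of the updated call pattern, assuming the config_obj and dls objects from the training scripts:

# Sketch only: find_lr no longer calls get_model itself
model = get_model(config_obj, headtype="pretrain")
suggested_lr = find_lr(model, config_obj, dls)   # model is now the first argument
learner = get_learner(config_obj, dls, suggested_lr, model)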
