# Utility functions for saving/loading checkpoints, managing CUDA usage, and
# initializing optimizers for neural network training.
import torch
# Save the model's state dict to the path given by `model_dir`.
def save_checkpoint(model, model_dir):
    torch.save(model.state_dict(), model_dir)
# Load the state dict stored at `model_dir` into the given model.
# map_location remaps every loaded tensor onto the CUDA device with the given
# device_id, so a checkpoint saved on any device is restored directly onto that GPU.
def resume_checkpoint(model, model_dir, device_id):
    state_dict = torch.load(
        model_dir,
        map_location=lambda storage, loc: storage.cuda(device=device_id),
    )
    model.load_state_dict(state_dict)
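# Minimal usage sketch (hypothetical model and path; resuming maps the
# checkpoint onto a GPU, so an available CUDA device is assumed):
#
#     model = torch.nn.Linear(4, 2)
#     save_checkpoint(model, "checkpoints/demo.pt")
#     resume_checkpoint(model, "checkpoints/demo.pt", device_id=0)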
# CUDA setup: if enabled, verify that CUDA is available and select the given device.
def use_cuda(enabled, device_id=0):
    if enabled:
        assert torch.cuda.is_available(), "CUDA is not available"
        torch.cuda.set_device(device_id)
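# Usage sketch (assumes at least one CUDA device is present):
#
#     use_cuda(True, device_id=0)   # asserts CUDA support and selects GPU 0
#     use_cuda(False)               # no-op; computation stays on the CPU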
# Build an optimizer for `network` from the params dict: the "optimizer" key
# selects SGD, Adam, or RMSprop, and the remaining keys supply its hyperparameters.
def use_optimizer(network, params):
    if params["optimizer"] == "sgd":
        optimizer = torch.optim.SGD(
            network.parameters(),
            lr=params["lr"],
            momentum=params["momentum"],
            weight_decay=params["l2_regularization"],
        )
    elif params["optimizer"] == "adam":
        optimizer = torch.optim.Adam(
            network.parameters(),
            lr=params["lr"],
            weight_decay=params["l2_regularization"],
        )
    elif params["optimizer"] == "rmsprop":
        optimizer = torch.optim.RMSprop(
            network.parameters(),
            lr=params["lr"],
            alpha=params["rmsprop_alpha"],
            momentum=params["momentum"],
        )
    else:
        raise ValueError(
            f'{params.get("optimizer", None)} is not allowed as optimizer value'
        )
    return optimizer
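# Minimal demo, run only when the file is executed directly (the model and
# hyperparameter values below are illustrative, not part of the module):
if __name__ == "__main__":
    demo_network = torch.nn.Linear(in_features=4, out_features=2)
    demo_params = {
        "optimizer": "sgd",
        "lr": 0.01,
        "momentum": 0.9,
        "l2_regularization": 1e-5,
    }
    demo_optimizer = use_optimizer(demo_network, demo_params)
    print(type(demo_optimizer).__name__)  # -> SGD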