-
Notifications
You must be signed in to change notification settings - Fork 11
/
checkpointing.py
114 lines (91 loc) · 3.88 KB
/
checkpointing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import copy
import os
from typing import Any, Dict, Optional, Union, Type
import torch
from torch import nn, optim
class CheckpointManager:
    r"""
    A :class:`CheckpointManager` periodically serializes models as ``.pth`` files during
    training, and keeps track of best performing checkpoint based on an observed metric.

    Extended Summary
    ----------------
    It saves state dicts of models as ``.pth`` files in a specified directory. This
    class closely follows the API of PyTorch optimizers and learning rate schedulers.

    Notes
    -----
    For :class:`~torch.nn.DataParallel` objects, ``.module.state_dict()`` is called instead of
    ``.state_dict()``.

    Parameters
    ----------
    models: Union[torch.nn.Module, Dict[str, torch.nn.Module]]
        Models which need to be serialized as a checkpoint. A single module is stored
        under the key ``"model"``.
    serialization_dir: str
        Path to a directory to save checkpoints. It is created on the first call to
        :meth:`step` if it does not exist yet.
    mode: str, optional (default="max")
        One of ``min``, ``max``. In ``min`` mode, best checkpoint will be recorded when metric
        hits a lower (or equal) value; in ``max`` mode it will be recorded when metric hits a
        higher (or equal) value. With any other value no best checkpoint is ever recorded.
    filename_prefix: str, optional (default="model")
        Prefix of the to-be-saved checkpoint files.

    Raises
    ------
    TypeError
        If any value in ``models`` is not a :class:`torch.nn.Module`.

    Examples
    --------
    >>> model = torch.nn.Linear(10, 2)
    >>> ckpt_manager = CheckpointManager({"model": model}, "/tmp/ckpt", mode="min")
    >>> num_epochs = 20
    >>> for epoch in range(num_epochs):
    ...     train(model)
    ...     val_loss = validate(model)
    ...     ckpt_manager.step(val_loss, epoch)
    """

    def __init__(
        self,
        models: Union[nn.Module, Dict[str, nn.Module]],
        serialization_dir: str,
        mode: str = "max",
        filename_prefix: str = "model",
    ):
        # Convert single model to a dict.
        if isinstance(models, nn.Module):
            models = {"model": models}

        for model in models.values():
            if not isinstance(model, nn.Module):
                # Report the type of the offending entry, not of the container.
                raise TypeError("{} is not a Module".format(type(model).__name__))

        self._models = models
        self._serialization_dir = serialization_dir
        self._mode = mode
        self._filename_prefix = filename_prefix

        # Best metric observed so far; ``None`` until the first call to ``step``.
        self._best_metric: Optional[Union[float, torch.Tensor]] = None

    def step(self, metric: Union[float, torch.Tensor], epoch: Optional[int] = None) -> bool:
        r"""
        Serialize checkpoint and update best checkpoint based on metric and mode.

        Parameters
        ----------
        metric: Union[float, torch.Tensor]
            Observed value of the tracked metric for the current state of the models.
        epoch: int, optional (default=None)
            If given, the state dicts are additionally saved as
            ``{filename_prefix}-epoch-{epoch}.pth``.

        Returns
        -------
        bool
            ``True`` if ``metric`` is at least as good as the best metric seen so far (in which
            case ``{filename_prefix}-best.pth`` is overwritten), ``False`` otherwise.
        """
        # Compare against ``None`` explicitly: a legitimate best metric of 0 is falsy and
        # must not be mistaken for "no metric recorded yet".
        if self._best_metric is None:
            self._best_metric = metric

        # Unwrap DataParallel so that saved keys do not carry the ``module.`` prefix.
        models_state_dict: Dict[str, Any] = {
            key: (model.module if isinstance(model, nn.DataParallel) else model).state_dict()
            for key, model in self._models.items()
        }

        # ``torch.save`` does not create missing parent directories.
        os.makedirs(self._serialization_dir, exist_ok=True)

        # Serialize checkpoint corresponding to current epoch (or iteration).
        if epoch is not None:
            torch.save(
                models_state_dict,
                os.path.join(
                    self._serialization_dir, "%s-epoch-%d.pth" % (self._filename_prefix, epoch)
                ),
            )

        # Update best checkpoint based on metric and metric mode. Ties count as an
        # improvement, so the very first call always records a best checkpoint.
        if (self._mode == "min" and metric <= self._best_metric) or (
            self._mode == "max" and metric >= self._best_metric
        ):
            self._best_metric = metric
            torch.save(
                models_state_dict,
                os.path.join(self._serialization_dir, f"{self._filename_prefix}-best.pth"),
            )
            return True
        return False