-
Notifications
You must be signed in to change notification settings - Fork 82
/
utils.py
executable file
·85 lines (65 loc) · 2.17 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# Copyright (c) 2019 NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
from pathlib import Path
def get_rank():
    """Return the rank reported by ``torch.distributed``, or 0 as a fallback.

    Returns:
        The value of ``dist.get_rank()`` when the distributed package is
        available and initialized; otherwise 0, so a non-distributed run
        behaves like rank 0.
    """
    # ``and`` short-circuits, so is_initialized() is only queried when the
    # distributed package is available.
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0
def get_world_size():
    """Return the world size reported by ``torch.distributed``, or 1.

    Returns:
        The value of ``dist.get_world_size()`` when the distributed package
        is available and initialized; otherwise 1, so a non-distributed run
        counts as a single process.
    """
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    # No usable process group: treat this as a one-process job.
    return 1
def is_main_process():
    """Tell whether the calling process is the main one.

    Returns:
        bool: True when ``get_rank()`` reports 0, False otherwise.
    """
    current_rank = get_rank()
    return current_rank == 0
def barrier():
    """Call ``dist.barrier()`` when a distributed run is active.

    This is a no-op when the distributed package is unavailable or has not
    been initialized, so it is safe to call from single-process code.
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    dist.barrier()
def format_step(step):
    """Render a step identifier as a human-readable label string.

    Args:
        step: Either a string, which is returned unchanged, or a sized,
            indexable sequence. Its first three items are labelled, in
            order, as training epoch, training iteration and validation
            iteration; any further items are ignored.

    Returns:
        str: The concatenated labels, each followed by a single space.
        An empty sequence yields an empty string.
    """
    if isinstance(step, str):
        return step

    # One template per supported position in ``step``.
    templates = (
        "Training Epoch: {} ",
        "Training Iteration: {} ",
        "Validation Iteration: {} ",
    )
    pieces = []
    for position, template in enumerate(templates):
        if len(step) > position:
            pieces.append(template.format(step[position]))
    return "".join(pieces)
def mkdir(path):
    """Create the directory ``path`` together with any missing parents.

    An already existing directory is accepted silently.

    Args:
        path: Location of the directory to create; anything ``Path``
            accepts.
    """
    directory = Path(path)
    directory.mkdir(exist_ok=True, parents=True)
def mkdir_by_main_process(path):
    """Create ``path`` from the main process only, then call ``barrier()``.

    Every caller reaches the ``barrier()`` call, whether or not it was the
    one that created the directory.

    Args:
        path: Location of the directory to create.
    """
    should_create = is_main_process()
    if should_create:
        mkdir(path)
    barrier()
def remove_prefix(text, prefix):
    """Strip ``prefix`` from the front of ``text`` when it is present.

    Args:
        text: String to examine.
        prefix: Leading substring to remove.

    Returns:
        ``text`` without its leading ``prefix``, or ``text`` unchanged when
        it does not start with ``prefix``. At most one occurrence is removed.
    """
    if not text.startswith(prefix):
        return text
    return text[len(prefix):]
def repair_checkpoint(model_ckpt):
    """Return a state dict whose keys carry no leading ``"_orig_mod."``.

    Args:
        model_ckpt: Mapping with string keys (presumably a model state dict
            saved from a wrapped module -- confirm against callers).

    Returns:
        The very same ``model_ckpt`` object when no key starts with the
        prefix, so callers can tell that nothing needed repairing.
        Otherwise a new plain ``dict`` holding the same value objects under
        the stripped keys, in the original key order.
    """
    prefix = "_orig_mod."
    # Map every existing key to the name it should have after the repair.
    renames = {key: remove_prefix(key, prefix) for key in model_ckpt.keys()}

    nothing_to_fix = not any(old != new for old, new in renames.items())
    if nothing_to_fix:
        return model_ckpt  # Hand back the input untouched.

    return {new: model_ckpt[old] for old, new in renames.items()}