-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathutils.py
121 lines (100 loc) · 3.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import numpy as np
import logging
import os
def count_params(model):
    """Return the number of parameters in ``model``, expressed in millions.

    ``model`` must provide ``parameters()`` yielding objects that have a
    ``numel()`` method (presumably a ``torch.nn.Module`` -- verify at callers).
    Every parameter is counted, trainable or not.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total / 1e6
def color_map(dataset='pascal'):
    """Build a 256-entry RGB palette for colouring segmentation label maps.

    Args:
        dataset: ``'pascal'`` or ``'coco'`` select the bit-interleaved
            PASCAL VOC palette covering all 256 labels; ``'cityscapes'``
            fills labels 0-18 with fixed Cityscapes colours. Any other value
            yields an all-zero palette.

    Returns:
        ``np.ndarray`` of shape ``(256, 3)`` and dtype ``uint8``. Rows not
        assigned by the selected scheme remain ``[0, 0, 0]``.
    """
    palette = np.zeros((256, 3), dtype='uint8')

    if dataset in ('pascal', 'coco'):
        # Bits 3j, 3j+1, 3j+2 of each label become bit (7 - j) of the R, G
        # and B channels respectively, so low label bits map to high colour
        # bits and neighbouring labels get visually distinct colours.
        labels = np.arange(256)
        channels = np.zeros((256, 3), dtype=np.int64)
        for j in range(8):
            for k in range(3):
                channels[:, k] |= ((labels >> (3 * j + k)) & 1) << (7 - j)
        palette[:] = channels
    elif dataset == 'cityscapes':
        cityscapes_colors = [
            (128, 64, 128),
            (244, 35, 232),
            (70, 70, 70),
            (102, 102, 156),
            (190, 153, 153),
            (153, 153, 153),
            (250, 170, 30),
            (220, 220, 0),
            (107, 142, 35),
            (152, 251, 152),
            (70, 130, 180),
            (220, 20, 60),
            (255, 0, 0),
            (0, 0, 142),
            (0, 0, 70),
            (0, 60, 100),
            (0, 80, 100),
            (0, 0, 230),
            (119, 11, 32),
        ]
        palette[:len(cityscapes_colors)] = cityscapes_colors

    return palette
class AverageMeter(object):
    """Computes and stores the average and current value.

    The mode is chosen by ``length``:

    * ``length > 0`` -- sliding window: ``avg`` is the mean of the last
      ``length`` values passed to :meth:`update` (kept in ``history``).
    * ``length <= 0`` -- running mean: ``avg`` is ``sum / count`` over all
      values since the last :meth:`reset`, each weighted by its ``num``.

    ``val`` always holds the most recent value (``0.0`` before any update).
    """

    def __init__(self, length=0):
        self.length = length
        self.reset()

    def reset(self):
        """Discard all accumulated values and zero ``val``/``avg``."""
        if self.length > 0:
            self.history = []
        else:
            self.count = 0
            self.sum = 0.0
        # Set in both modes: a windowed meter previously had no ``val``/``avg``
        # until its first update (AttributeError on read) and kept stale values
        # after reset().
        self.val = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        """Record a new observation.

        Args:
            val: the observed value.
            num: weight of ``val`` in running-mean mode (e.g. a batch size).
                Windowed mode only supports ``num == 1``.

        Raises:
            ValueError: if ``num != 1`` while in windowed mode.
        """
        if self.length > 0:
            # Weighted entries are not supported by the window yet; reject them
            # explicitly (an ``assert`` would be stripped under ``python -O``).
            if num != 1:
                raise ValueError(
                    "AverageMeter with length > 0 only supports num == 1, got %r" % (num,)
                )
            self.history.append(val)
            if len(self.history) > self.length:
                del self.history[0]  # keep only the last `length` values
            self.val = self.history[-1]
            self.avg = np.mean(self.history)
        else:
            self.val = val
            self.sum += val * num
            self.count += num
            self.avg = self.sum / self.count
def intersectionAndUnion(output, target, K, ignore_index=255):
    """Per-class intersection, union and target pixel counts for segmentation.

    Args:
        output: NumPy array of predicted labels with 1, 2 or 3 dimensions
            (N, N * L or N * H * W); valid values are 0 .. K - 1.
        target: NumPy array of ground-truth labels, same shape as ``output``.
        K: number of classes.
        ignore_index: label in ``target`` whose pixels are excluded from
            every count.

    Returns:
        Tuple ``(area_intersection, area_union, area_target)``, each a
        length-``K`` array of pixel counts per class.

    Raises:
        AssertionError: if ``output`` is not 1-3 dimensional or the shapes of
            ``output`` and ``target`` differ.

    Inputs are not modified (boolean indexing produces copies).
    """
    assert output.ndim in [1, 2, 3]
    assert output.shape == target.shape
    output = output.reshape(-1)
    target = target.reshape(-1)
    # Remove ignored pixels outright rather than relabelling them to
    # ignore_index: np.histogram's last bin is closed ([K-1, K]), so an
    # ignore_index of K (or any value in 0..K-1) would otherwise be counted
    # as a real class in all three areas.
    valid = target != ignore_index
    output = output[valid]
    target = target[valid]
    intersection = output[output == target]
    bins = np.arange(K + 1)
    area_intersection, _ = np.histogram(intersection, bins=bins)
    area_output, _ = np.histogram(output, bins=bins)
    area_target, _ = np.histogram(target, bins=bins)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target
# (name, level) pairs already configured by init_log, so repeated calls do not
# attach duplicate handlers.
logs = set()


def init_log(name, level=logging.INFO):
    """Configure a console logger once and return it.

    The first call for a given ``(name, level)`` sets the logger's level and
    attaches a ``StreamHandler`` with a timestamped format. Later calls with
    the same pair return the same logger without adding another handler.
    A call with the same name but a different level attaches an additional
    handler.

    When ``SLURM_PROCID`` is set, a filter is added that drops records logged
    through this logger on every rank other than 0.

    Args:
        name: logger name passed to ``logging.getLogger``.
        level: level applied to both the logger and its handler.

    Returns:
        The ``logging.Logger`` for ``name``.
    """
    logger = logging.getLogger(name)
    if (name, level) in logs:
        # Already configured: return the logger (previously this returned None).
        return logger
    logs.add((name, level))
    logger.setLevel(level)
    ch = logging.StreamHandler()
    ch.setLevel(level)
    if "SLURM_PROCID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        logger.addFilter(lambda record: rank == 0)
    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
    formatter = logging.Formatter(format_str)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger