forked from real-stanford/diffusion_policy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
normalizer.py
353 lines (301 loc) · 11.6 KB
/
normalizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
from typing import Union, Dict
import unittest
import zarr
import numpy as np
import torch
import torch.nn as nn
from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.common.dict_of_tensor_mixin import DictOfTensorMixin
class LinearNormalizer(DictOfTensorMixin):
    """Affine normalizer holding one parameter set per named field.

    Each entry of ``self.params_dict`` is the parameter dict produced by
    ``_fit`` (keys ``'scale'``, ``'offset'`` and ``'input_stats'``). Fitting on
    a dict creates one entry per dict key; fitting on a single array or tensor
    stores the parameters under the reserved key ``'_default'``.
    """
    # NOTE: misspelled name kept as-is because it is a public class attribute.
    avaliable_modes = ['limits', 'gaussian']

    @torch.no_grad()
    def fit(self,
            data: Union[Dict, torch.Tensor, np.ndarray, zarr.Array],
            last_n_dims=1,
            dtype=torch.float32,
            mode='limits',
            output_max=1.,
            output_min=-1.,
            range_eps=1e-4,
            fit_offset=True):
        """Fit normalization parameters to ``data``.

        Args:
            data: either a dict mapping field names to arrays (one parameter
                set is fitted per key), or a single array/tensor (fitted
                into the ``'_default'`` field).
            last_n_dims, dtype, mode, output_max, output_min, range_eps,
            fit_offset: forwarded unchanged to ``_fit`` for every field.
        """
        # Build the shared options once instead of repeating them per branch.
        fit_kwargs = {
            'last_n_dims': last_n_dims,
            'dtype': dtype,
            'mode': mode,
            'output_max': output_max,
            'output_min': output_min,
            'range_eps': range_eps,
            'fit_offset': fit_offset,
        }
        if isinstance(data, dict):
            for key, value in data.items():
                self.params_dict[key] = _fit(value, **fit_kwargs)
        else:
            self.params_dict['_default'] = _fit(data, **fit_kwargs)

    def __call__(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Alias for :meth:`normalize`."""
        return self.normalize(x)

    def __getitem__(self, key: str):
        """Return a ``SingleFieldLinearNormalizer`` built from field ``key``'s parameters."""
        return SingleFieldLinearNormalizer(self.params_dict[key])

    def __setitem__(self, key: str, value: 'SingleFieldLinearNormalizer'):
        """Store the parameters of a ``SingleFieldLinearNormalizer`` as field ``key``."""
        self.params_dict[key] = value.params_dict

    def _normalize_impl(self, x, forward=True):
        """Apply (``forward=True``) or invert (``forward=False``) the affine map.

        Dict inputs are processed key by key, each with that key's fitted
        parameters, and a dict is returned. Any other input uses the
        ``'_default'`` field.

        Raises:
            RuntimeError: ``x`` is not a dict and no ``'_default'`` field
                has been fitted.
        """
        if isinstance(x, dict):
            result = dict()
            for key, value in x.items():
                params = self.params_dict[key]
                result[key] = _normalize(value, params, forward=forward)
            return result
        else:
            if '_default' not in self.params_dict:
                raise RuntimeError("Not initialized")
            params = self.params_dict['_default']
            return _normalize(x, params, forward=forward)

    def normalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Map ``x`` (or each field of a dict ``x``) into normalized space."""
        return self._normalize_impl(x, forward=True)

    def unnormalize(self, x: Union[Dict, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Map normalized ``x`` (or each field of a dict) back to input space."""
        return self._normalize_impl(x, forward=False)

    def get_input_stats(self) -> Dict:
        """Return the statistics recorded at fit time.

        Returns:
            The ``'_default'`` field's stats when it is the only field;
            otherwise a dict mapping each non-``'_default'`` field name to
            its stats.

        Raises:
            RuntimeError: nothing has been fitted yet.
        """
        if len(self.params_dict) == 0:
            raise RuntimeError("Not initialized")
        if len(self.params_dict) == 1 and '_default' in self.params_dict:
            return self.params_dict['_default']['input_stats']
        result = dict()
        for key, value in self.params_dict.items():
            if key != '_default':
                result[key] = value['input_stats']
        return result

    def get_output_stats(self, key='_default'):
        """Return the input statistics mapped through :meth:`normalize`.

        Mirrors the structure of :meth:`get_input_stats`. ``key`` is accepted
        for backward compatibility only and is not used.
        """
        input_stats = self.get_input_stats()
        # Use the same condition get_input_stats uses to decide whether it
        # returned one stats mapping. Probing for a 'min' key misroutes when a
        # field is literally named 'min' or manual stats omit 'min'.
        if len(self.params_dict) == 1 and '_default' in self.params_dict:
            return dict_apply(input_stats, self.normalize)
        result = dict()
        for field, group in input_stats.items():
            result[field] = {
                name: self.normalize({field: value})[field]
                for name, value in group.items()
            }
        return result
class SingleFieldLinearNormalizer(DictOfTensorMixin):
    """Affine normalizer for a single array field.

    ``self.params_dict`` holds ``'scale'``, ``'offset'`` and a nested
    ``'input_stats'`` dict, as produced by ``_fit`` or ``create_manual``.
    """
    # NOTE: misspelled name kept as-is because it is a public class attribute.
    avaliable_modes = ['limits', 'gaussian']

    @torch.no_grad()
    def fit(self,
            data: Union[torch.Tensor, np.ndarray, zarr.Array],
            last_n_dims=1,
            dtype=torch.float32,
            mode='limits',
            output_max=1.,
            output_min=-1.,
            range_eps=1e-4,
            fit_offset=True):
        """Fit parameters to ``data``, replacing ``self.params_dict``.

        All keyword arguments are forwarded unchanged to ``_fit``.
        """
        self.params_dict = _fit(data,
                                last_n_dims=last_n_dims,
                                dtype=dtype,
                                mode=mode,
                                output_max=output_max,
                                output_min=output_min,
                                range_eps=range_eps,
                                fit_offset=fit_offset)

    @classmethod
    def create_fit(cls, data: Union[torch.Tensor, np.ndarray, zarr.Array], **kwargs):
        """Construct an instance and fit it to ``data`` (``kwargs`` go to :meth:`fit`)."""
        obj = cls()
        obj.fit(data, **kwargs)
        return obj

    @classmethod
    def create_manual(cls,
                      scale: Union[torch.Tensor, np.ndarray],
                      offset: Union[torch.Tensor, np.ndarray],
                      input_stats_dict: Dict[str, Union[torch.Tensor, np.ndarray]]):
        """Build a normalizer from explicit parameters.

        Args:
            scale: per-feature multiplier.
            offset: per-feature additive term; same shape and dtype as ``scale``.
            input_stats_dict: named statistics, each with the same shape and
                dtype as ``scale``.

        All values are converted to torch (if numpy) and flattened. The
        resulting parameters are frozen (``requires_grad=False``), matching
        parameters produced by ``_fit``.
        """
        def to_tensor(x):
            if not isinstance(x, torch.Tensor):
                x = torch.from_numpy(x)
            x = x.flatten()
            return x

        # check
        for x in [offset] + list(input_stats_dict.values()):
            assert x.shape == scale.shape
            assert x.dtype == scale.dtype

        params_dict = nn.ParameterDict({
            'scale': to_tensor(scale),
            'offset': to_tensor(offset),
            'input_stats': nn.ParameterDict(
                dict_apply(input_stats_dict, to_tensor))
        })
        # nn.ParameterDict wraps plain tensors in nn.Parameter, which requires
        # grad by default. Freeze them like _fit does, so a manually created
        # normalizer is not picked up as trainable.
        for p in params_dict.parameters():
            p.requires_grad_(False)
        return cls(params_dict)

    @classmethod
    def create_identity(cls, dtype=torch.float32):
        """Build a normalizer with scale 1 and offset 0 (identity map).

        The attached input stats are fixed nominal values
        (min -1, max 1, mean 0, std 1), not derived from data.
        """
        scale = torch.tensor([1], dtype=dtype)
        offset = torch.tensor([0], dtype=dtype)
        input_stats_dict = {
            'min': torch.tensor([-1], dtype=dtype),
            'max': torch.tensor([1], dtype=dtype),
            'mean': torch.tensor([0], dtype=dtype),
            'std': torch.tensor([1], dtype=dtype)
        }
        return cls.create_manual(scale, offset, input_stats_dict)

    def normalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Return ``x * scale + offset`` (applied per feature)."""
        return _normalize(x, self.params_dict, forward=True)

    def unnormalize(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Return ``(x - offset) / scale`` (applied per feature)."""
        return _normalize(x, self.params_dict, forward=False)

    def get_input_stats(self):
        """Return the stored input statistics."""
        return self.params_dict['input_stats']

    def get_output_stats(self):
        """Return the input statistics mapped through :meth:`normalize`."""
        return dict_apply(self.params_dict['input_stats'], self.normalize)

    def __call__(self, x: Union[torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Alias for :meth:`normalize`."""
        return self.normalize(x)
def _fit(data: Union[torch.Tensor, np.ndarray, zarr.Array],
         last_n_dims=1,
         dtype=torch.float32,
         mode='limits',
         output_max=1.,
         output_min=-1.,
         range_eps=1e-4,
         fit_offset=True):
    """Compute per-feature affine normalization parameters for ``data``.

    The trailing ``last_n_dims`` dimensions are flattened into one feature
    axis and all leading dimensions are treated as samples. With
    ``last_n_dims=0`` every element is pooled into a single feature.

    Args:
        data: samples to fit. Zarr arrays are read fully with ``data[:]``;
            numpy arrays are converted with ``torch.from_numpy``.
        last_n_dims: number of trailing dims forming the feature axis (>= 0).
        dtype: dtype to cast ``data`` to before fitting; ``None`` keeps it.
        mode: ``'limits'`` maps each feature's [min, max] onto
            [output_min, output_max]; ``'gaussian'`` subtracts the mean and
            divides by the std.
        output_max, output_min: target range used by ``'limits'`` mode.
        range_eps: features whose range (``'limits'``) or std
            (``'gaussian'``) is below this are treated as constant and get
            scale 1.
        fit_offset: if False the offset is zero. In ``'limits'`` mode the
            features are then only scaled by their max absolute value
            (requires ``output_min < 0 < output_max``).

    Returns:
        ``nn.ParameterDict`` with ``'scale'`` and ``'offset'`` of shape
        ``(dim,)`` and a nested ``'input_stats'`` dict with ``'min'``,
        ``'max'``, ``'mean'`` and ``'std'``; all parameters have
        ``requires_grad=False``.

    Raises:
        AssertionError: on an unknown ``mode``, negative ``last_n_dims``,
            ``output_max <= output_min``, or (``'limits'`` without offset)
            an output range that does not straddle zero.
    """
    assert mode in ['limits', 'gaussian']
    assert last_n_dims >= 0
    assert output_max > output_min

    # convert data to torch and type
    if isinstance(data, zarr.Array):
        data = data[:]
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
    if dtype is not None:
        data = data.type(dtype)

    # flatten to (n_samples, dim). int() guards against np.prod(()) == 1.0,
    # a float that reshape would reject, for 0-d input.
    dim = 1
    if last_n_dims > 0:
        dim = int(np.prod(data.shape[-last_n_dims:]))
    data = data.reshape(-1, dim)

    # per-feature input stats
    input_min, _ = data.min(axis=0)
    input_max, _ = data.max(axis=0)
    input_mean = data.mean(axis=0)
    if data.shape[0] > 1:
        input_std = data.std(axis=0)
    else:
        # The unbiased std of a single sample divides by n - 1 = 0 and yields
        # NaN. NaN < range_eps is False, so 'gaussian' scale would become NaN.
        # A single sample has no spread: record 0 so it is treated as constant.
        input_std = torch.zeros_like(input_mean)

    # compute scale and offset
    if mode == 'limits':
        if fit_offset:
            # affine map [input_min, input_max] -> [output_min, output_max]
            input_range = input_max - input_min
            ignore_dim = input_range < range_eps
            # constant features get unit scale ...
            input_range[ignore_dim] = output_max - output_min
            scale = (output_max - output_min) / input_range
            offset = output_min - scale * input_min
            # ... and are shifted onto the midpoint of the output range
            offset[ignore_dim] = (output_max + output_min) / 2 - input_min[ignore_dim]
        else:
            # use this when data is pre-zero-centered: scale only, no offset
            assert output_max > 0
            assert output_min < 0
            # largest input magnitude maps to the smaller output magnitude
            output_abs = min(abs(output_min), abs(output_max))
            input_abs = torch.maximum(torch.abs(input_min), torch.abs(input_max))
            ignore_dim = input_abs < range_eps
            # don't scale near-zero channels (scale becomes 1)
            input_abs[ignore_dim] = output_abs
            scale = output_abs / input_abs
            offset = torch.zeros_like(input_mean)
    elif mode == 'gaussian':
        ignore_dim = input_std < range_eps
        scale = input_std.clone()
        scale[ignore_dim] = 1
        scale = 1 / scale
        if fit_offset:
            offset = - input_mean * scale
        else:
            offset = torch.zeros_like(input_mean)

    # save
    this_params = nn.ParameterDict({
        'scale': scale,
        'offset': offset,
        'input_stats': nn.ParameterDict({
            'min': input_min,
            'max': input_max,
            'mean': input_mean,
            'std': input_std
        })
    })
    # normalizer parameters are statistics, not trainable weights
    for p in this_params.parameters():
        p.requires_grad_(False)
    return this_params
def _normalize(x, params, forward=True):
    """Apply (or, with ``forward=False``, invert) a per-feature affine map.

    The forward map is ``x * scale + offset``; the inverse is
    ``(x - offset) / scale``. ``x`` is moved to the device and dtype of
    ``params['scale']`` and viewed as ``(-1, len(scale))`` rows, so its total
    size must be a multiple of the feature count. The result has ``x``'s
    original shape.

    Args:
        x: numpy array or torch tensor to transform.
        params: mapping with ``'scale'`` and ``'offset'`` tensors.
        forward: direction of the transform.
    """
    assert 'scale' in params
    tensor = torch.from_numpy(x) if isinstance(x, np.ndarray) else x
    scale, offset = params['scale'], params['offset']
    tensor = tensor.to(device=scale.device, dtype=scale.dtype)
    original_shape = tensor.shape
    rows = tensor.reshape(-1, scale.shape[0])
    rows = rows * scale + offset if forward else (rows - offset) / scale
    return rows.reshape(original_shape)
def test():
    """Smoke test for the normalizers defined in this module.

    Fits on random data and checks output ranges or moments, that
    ``unnormalize`` inverts ``normalize``, and that a ``LinearNormalizer``
    survives a ``state_dict`` round trip. Raises ``AssertionError`` on failure.
    """
    def normalize_keeping_shape(norm, raw):
        # Normalize ``raw`` and check that the output has the input's shape.
        normed = norm.normalize(raw)
        assert normed.shape == raw.shape
        return normed

    def assert_fields_roundtrip(norm, fields):
        # Every field must come back (within 1e-4) after normalize + unnormalize.
        restored = norm.unnormalize(norm.normalize(fields))
        assert all(torch.allclose(fields[name], restored[name], atol=1e-4)
                   for name in fields)

    # Single field, 'limits' mode with offset; channel (0, 0) is constant zero.
    sample = torch.zeros((100, 10, 9, 2)).uniform_()
    sample[..., 0, 0] = 0
    single = SingleFieldLinearNormalizer()
    single.fit(sample, mode='limits', last_n_dims=2)
    normed = normalize_keeping_shape(single, sample)
    assert np.allclose(normed.max(), 1.)
    assert np.allclose(normed.min(), -1.)
    assert torch.allclose(sample, single.unnormalize(normed), atol=1e-7)
    single.get_input_stats()
    single.get_output_stats()

    # Same data, 'limits' mode without offset: expect values in [0, 1].
    single = SingleFieldLinearNormalizer()
    single.fit(sample, mode='limits', last_n_dims=1, fit_offset=False)
    normed = normalize_keeping_shape(single, sample)
    assert np.allclose(normed.max(), 1., atol=1e-3)
    assert np.allclose(normed.min(), 0., atol=1e-3)
    assert torch.allclose(sample, single.unnormalize(normed), atol=1e-7)

    # 'gaussian' mode with every element pooled into one feature.
    sample = torch.zeros((100, 10, 9, 2)).uniform_()
    single = SingleFieldLinearNormalizer()
    single.fit(sample, mode='gaussian', last_n_dims=0)
    normed = normalize_keeping_shape(single, sample)
    assert np.allclose(normed.mean(), 0., atol=1e-3)
    assert np.allclose(normed.std(), 1., atol=1e-3)
    assert torch.allclose(sample, single.unnormalize(normed), atol=1e-7)

    # LinearNormalizer fitted on a bare tensor.
    sample = torch.zeros((100, 10, 9, 2)).uniform_()
    sample[..., 0, 0] = 0
    multi = LinearNormalizer()
    multi.fit(sample, mode='limits', last_n_dims=2)
    normed = normalize_keeping_shape(multi, sample)
    assert np.allclose(normed.max(), 1.)
    assert np.allclose(normed.min(), -1.)
    assert torch.allclose(sample, multi.unnormalize(normed), atol=1e-7)
    multi.get_input_stats()
    multi.get_output_stats()

    # LinearNormalizer fitted on a dict of fields.
    fields = {
        'obs': torch.zeros((1000, 128, 9, 2)).uniform_() * 512,
        'action': torch.zeros((1000, 128, 2)).uniform_() * 512
    }
    multi = LinearNormalizer()
    multi.fit(fields)
    assert_fields_roundtrip(multi, fields)
    multi.get_input_stats()
    multi.get_output_stats()

    # Round trip the fitted parameters through state_dict into a fresh instance.
    saved = multi.state_dict()
    reloaded = LinearNormalizer()
    reloaded.load_state_dict(saved)
    assert_fields_roundtrip(reloaded, fields)