-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutils.py
141 lines (112 loc) · 4.65 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: skip-file
"""Utility code for generating and saving image grids and checkpointing.
The `save_image` code is copied from
https://github.com/google/flax/blob/master/examples/vae/utils.py,
which is a JAX equivalent to the same function in TorchVision
(https://github.com/pytorch/vision/blob/master/torchvision/utils.py)
"""
import math
from typing import Any, Dict, Optional, TypeVar
import flax
import jax
import jax.numpy as jnp
from PIL import Image
import tensorflow as tf
from jax import numpy as jnp
# Generic type variable; not referenced by the functions in this module.
T = TypeVar("T")
def batch_add(a, b):
  """Add `a` and `b` example-by-example along their shared leading axis.

  The addition is mapped with `jax.vmap` over axis 0 of both inputs, so only
  the per-example slices need to broadcast against each other. For example, a
  per-example scalar of shape (B,) can be added to a batch of shape (B, ...),
  which a plain `a + b` would not broadcast correctly.
  """
  def _add_one(lhs, rhs):
    return lhs + rhs

  return jax.vmap(_add_one)(a, b)
def batch_mul(a, b):
  """Multiply `a` and `b` example-by-example along their shared leading axis.

  The product is mapped with `jax.vmap` over axis 0 of both inputs, so only
  the per-example slices need to broadcast against each other. For example, a
  per-example scale of shape (B,) can multiply a batch of shape (B, ...).
  """
  def _mul_one(lhs, rhs):
    return lhs * rhs

  return jax.vmap(_mul_one)(a, b)
def load_training_state(filepath, state):
  """Restore a training state from a serialized checkpoint file.

  Args:
    filepath: Path readable by `tf.io.gfile.GFile`; the file is opened in
      binary mode.
    state: Object passed as the target to `flax.serialization.from_bytes`;
      presumably a state with the same structure as the one that was saved.

  Returns:
    The result of `flax.serialization.from_bytes(state, <file contents>)`.
  """
  with tf.io.gfile.GFile(filepath, "rb") as checkpoint_file:
    serialized = checkpoint_file.read()
  return flax.serialization.from_bytes(state, serialized)
def save_image(ndarray, fp, nrow=8, padding=2, pad_value=0.0, format=None):
  """Make a grid of images and save it into an image file.

  Pixel values are assumed to be within [0, 1].

  Args:
    ndarray (array_like): 4D mini-batch images of shape (B x H x W x C), given
      as a JAX array or a list of JAX arrays. Single-channel images (C == 1)
      are replicated to three channels.
    fp: A filename (string) or file object.
    nrow (int, optional): Number of images displayed in each row of the grid.
      The final grid size is ``(ceil(B / nrow), nrow)``. Default: ``8``.
    padding (int, optional): Amount of padding. Default: ``2``.
    pad_value (float, optional): Value for the padded pixels. Default: ``0``.
    format (Optional): If omitted, the format to use is determined from the
      filename extension. If a file object was used instead of a filename,
      this parameter should always be used.

  Raises:
    TypeError: If `ndarray` is not a JAX array or a list of JAX arrays.
    ValueError: If the images are not a non-empty 4D batch, or `nrow` < 1.
  """
  # Local import: this module does not import NumPy at top level.
  import numpy as np

  if not (isinstance(ndarray, jnp.ndarray) or
          (isinstance(ndarray, list) and
           all(isinstance(t, jnp.ndarray) for t in ndarray))):
    raise TypeError("array_like of tensors expected, got {}".format(
        type(ndarray)))

  ndarray = jnp.asarray(ndarray)
  # Fail early with a clear message instead of an IndexError on `shape[3]`
  # or a ZeroDivisionError when computing the number of grid rows.
  if ndarray.ndim != 4 or ndarray.shape[0] == 0:
    raise ValueError(
        "expected a non-empty 4D batch of shape (B, H, W, C), got shape "
        "{}".format(ndarray.shape))
  if nrow < 1:
    raise ValueError("nrow must be at least 1, got {}".format(nrow))
  if ndarray.shape[-1] == 1:  # single-channel images -> replicate to 3 channels
    ndarray = jnp.concatenate((ndarray, ndarray, ndarray), -1)

  # Make the mini-batch of images into a grid.
  nmaps = ndarray.shape[0]
  xmaps = min(nrow, nmaps)
  ymaps = int(math.ceil(float(nmaps) / xmaps))
  # Each grid cell holds one image plus a padding strip on its top/left side.
  height = int(ndarray.shape[1] + padding)
  width = int(ndarray.shape[2] + padding)
  num_channels = ndarray.shape[3]
  # The extra `padding` closes the bottom/right border of the grid.
  grid = jnp.full(
      (height * ymaps + padding, width * xmaps + padding, num_channels),
      pad_value).astype(jnp.float32)
  for k in range(nmaps):
    y, x = divmod(k, xmaps)  # row-major placement: row y, column x
    # `jax.ops.index_update` no longer exists in JAX; `.at[...].set(...)` is
    # the functional indexed-update replacement.
    grid = grid.at[y * height + padding:(y + 1) * height,
                   x * width + padding:(x + 1) * width].set(ndarray[k])

  # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer.
  ndarr = jnp.clip(grid * 255.0 + 0.5, 0, 255).astype(jnp.uint8)
  # PIL needs a host NumPy array; `jax.Array.copy()` returns a JAX array.
  im = Image.fromarray(np.asarray(ndarr))
  im.save(fp, format=format)
def flatten_dict(config):
  """Collapse a nested dict into a single-level dict.

  Keys of nested dicts are joined to their parent key with "/" (recursively),
  tuple values are converted to their `str` form, and all other values are
  copied unchanged. Entries keep the order in which they are encountered; a
  nested dict that is empty contributes no keys.
  """
  flat = {}
  for name, item in config.items():
    if isinstance(item, dict):
      flat.update(
          (name + "/" + inner_name, inner_item)
          for inner_name, inner_item in flatten_dict(item).items())
    else:
      flat[name] = str(item) if isinstance(item, tuple) else item
  return flat
def get_div_fn(fn):
  """Create the divergence function of `fn` using the Hutchinson-Skilling trace estimator.

  Args:
    fn: Callable `fn(x, t)` whose output can be multiplied elementwise by the
      probe `eps`.

  Returns:
    A function `div_fn(x, t, eps)` that returns, for each example along the
    leading axis of `x`, the sum over the remaining axes of `eps * (J^T eps)`,
    where `J^T eps` is the gradient of `sum(fn(x, t) * eps)` with respect to
    `x`. This is an unbiased estimate of the Jacobian trace (divergence) when
    `eps` has zero mean and identity covariance, assuming `fn` acts on each
    example independently.
  """
  def div_fn(x, t, eps):
    def projected(inputs):
      # Contract fn's output with the probe so that jax.grad yields J^T eps.
      return jnp.sum(fn(inputs, t) * eps)

    jt_eps = jax.grad(projected)(x)
    per_example_axes = tuple(range(1, x.ndim))
    return jnp.sum(jt_eps * eps, axis=per_example_axes)

  return div_fn
def get_value_div_fn(fn):
  """Return both the function value and its estimated divergence via Hutchinson's trace estimator.

  Args:
    fn: Callable `fn(x, t)` whose output can be multiplied elementwise by the
      probe `eps`.

  Returns:
    A function `value_div_fn(x, t, eps)` returning the pair
    `(fn(x, t), divergence)`. The divergence estimate is computed in the same
    forward pass: for each example along the leading axis of `x`, it is the sum
    over the remaining axes of `eps * (J^T eps)`, where `J^T eps` is the
    gradient of `sum(fn(x, t) * eps)` with respect to `x`.
  """
  def value_div_fn(x, t, eps):
    def projected_with_value(inputs):
      output = fn(inputs, t)
      # Scalar to differentiate, plus the raw output carried as auxiliary data.
      return jnp.sum(output * eps), output

    jt_eps, output = jax.grad(projected_with_value, has_aux=True)(x)
    per_example_axes = tuple(range(1, x.ndim))
    divergence = jnp.sum(jt_eps * eps, axis=per_example_axes)
    return output, divergence

  return value_div_fn