Commit 87ca4b4 · 1 parent: f7aa613 · 6 changed files with 340 additions and 0 deletions.
README.md
@@ -0,0 +1,27 @@
# Role of files

* Code Snippets (a minimal usage sketch follows this list)

  * ```cs.py```: a Python implementation of Connectivity Sharpness

  * ```lanczos.py```: a Python implementation of Lanczos iteration

  * ```rto.py```: a Python implementation of the Randomize-Then-Optimize (RTO) method for LL and CL

* Helper scripts
  * ```tool.py```: auxiliary code for managing parameters in Haiku
  * ```mp.py```: auxiliary code for data parallelization in JAX

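As a quick orientation (an editorial sketch, not part of the original files), here is how ```cs.py``` could be called. The toy ```forward``` function, parameters, and batch below are placeholders; the real code only assumes a Haiku-style apply function that accepts ```state=``` and ```input_=``` keywords and returns per-example logits.

```python
import jax
import jax.numpy as jnp
import cs  # cs.py from this repository

# toy stand-in for a network forward pass (placeholder, not the real model)
def forward(params, state, input_):
    return input_ @ params['w']  # logits of shape (batch, num_classes)

params = {'w': jnp.ones((16, 10))}
state = {}                        # unused by the toy forward
batch = {'x': jnp.ones((8, 16))}
rng = jax.random.PRNGKey(0)

# Hutchinson-style estimate of the connectivity sharpness (trace of the CTK)
cs_value = cs.eval_ctk_trace(forward, params, state, batch, rng,
                             max_iter=10, num_classes=10)
```
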
# Requirements

```bash
tensorflow            2.9.1
tensorflow-datasets   4.6.0
jax                   0.3.17
jaxlib                0.3.15+cuda11.cudnn82
jaxline               0.0.5
flax                  0.5.3
dm-haiku              0.0.9.dev0
```
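These pins could be installed with pip roughly as follows (an editorial sketch: the exact ```jaxlib``` CUDA wheel depends on the local CUDA/cuDNN setup, and the ```dm-haiku``` dev pin may need an install from source rather than PyPI):

```bash
pip install tensorflow==2.9.1 tensorflow-datasets==4.6.0 jaxline==0.0.5 flax==0.5.3
pip install jax==0.3.17 "jaxlib==0.3.15+cuda11.cudnn82" \
  -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install dm-haiku  # pinned above as 0.0.9.dev0; dev builds come from the GitHub repo
```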
cs.py
@@ -0,0 +1,24 @@
from functools import partial
import jax
import jax.numpy as jnp
import tool


# evaluate connectivity sharpness

def eval_ctk_trace(forward, params, state, batch, rng, max_iter, num_classes):
    vec_params = tool.params_to_vec(params)
    f = partial(forward, state=state, input_=batch['x'])
    _, f_vjp = jax.vjp(f, params)

    def body_fn(_, a):
        res, rng = a
        _, rng = jax.random.split(rng)
        v = jax.random.rademacher(rng, (batch['x'].shape[0], num_classes), jnp.float32)
        j_p = tool.params_to_vec(f_vjp(v)) * vec_params
        tr_ctk = jnp.sum(jnp.square(j_p)) / batch['x'].shape[0]
        res += tr_ctk / max_iter
        return (res, rng)

    res, _ = jax.lax.fori_loop(0, max_iter, body_fn, (0., rng))
    return res
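

# Editorial note (not in the original file): the loop above is a Hutchinson
# trace estimator. For Rademacher vectors v, E ||diag(theta) J^T v||^2 equals
# tr(J diag(theta)^2 J^T), the trace of the connectivity tangent kernel (CTK),
# here normalized by the batch size and averaged over `max_iter` samples.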
lanczos.py
@@ -0,0 +1,78 @@
import tensorflow as tf
tf.config.experimental.set_visible_devices([], 'GPU')
from absl import flags
import jax
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm

import tool, mp

flags.DEFINE_float('scale_factor', 1.0, help='scale factor')
flags.DEFINE_bool('use_connect', False, help='use connectivity or parameter')
flags.DEFINE_integer('sharpness_rand_proj_dim', 10, help='random projection dimension for sharpness estimation')
FLAGS = flags.FLAGS


@jax.pmap
def mvp_batch(v, trainer, batch):
    # apply JVP -> VJP to compute J^T J v (matrix-vector product)
    vec_params, unravel_fn = tool.params_to_vec(trainer.params, True)

    if FLAGS.use_connect:
        multiplier = vec_params
    else:
        multiplier = jnp.ones_like(vec_params)

    f = lambda x: tool.forward(x, trainer, batch['x'], train=True)
    _, res = jax.jvp(f, [unravel_fn(vec_params)], [unravel_fn(v * multiplier)])
    _, f_vjp = jax.vjp(f, unravel_fn(vec_params))
    res = f_vjp(res)
    res = tool.params_to_vec(res) * multiplier
    return res


def mvp(v, trainer, ds_train):
    res = 0.
    for batch in ds_train:
        res += mvp_batch(mp.replicate(v), trainer, batch).sum(axis=0)
    return res


def lanczos(trainer, ds_train, rng):
    # Modified Lanczos algorithm of
    # https://github.com/google/spectral-density/blob/master/jax/lanczos.py for recent JAX versions.
    rand_proj_dim = FLAGS.sharpness_rand_proj_dim
    vec_params, unravel_fn = tool.params_to_vec(mp.unreplicate(trainer).params, True)

    tridiag = jnp.zeros((rand_proj_dim, rand_proj_dim))
    vecs = jnp.zeros((rand_proj_dim, len(vec_params)))

    init_vec = jax.random.normal(rng, shape=vec_params.shape)
    init_vec = init_vec / jnp.linalg.norm(init_vec)
    vecs = vecs.at[0].set(init_vec)

    beta = 0
    for i in tqdm(range(rand_proj_dim)):
        v = vecs[i, :]
        if i == 0:
            v_old = 0
        else:
            v_old = vecs[i - 1, :]

        w = mvp(v, trainer, ds_train)
        w = w - beta * v_old

        alpha = jnp.dot(w, v)
        tridiag = tridiag.at[i, i].set(alpha)
        w = w - alpha * v

        # full reorthogonalization against previous Lanczos vectors
        for j in range(i):
            tau = vecs[j, :]
            coef = np.dot(w, tau)
            w += -coef * tau

        beta = jnp.linalg.norm(w)

        if (i + 1) < rand_proj_dim:
            tridiag = tridiag.at[i, i + 1].set(beta)
            tridiag = tridiag.at[i + 1, i].set(beta)
            vecs = vecs.at[i + 1].set(w / beta)

    return tridiag, vecs
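

# Editorial note (not in the original file): the eigenvalues (Ritz values) of the
# returned `tridiag` approximate the extreme eigenvalues of the implicit matrix
# applied by `mvp` (J^T J, or its connectivity-weighted version when --use_connect
# is set). A typical read-out of the sharpness estimate would be:
#
#   tridiag, _ = lanczos(trainer, ds_train, rng)
#   sharpness = jnp.linalg.eigvalsh(tridiag)[-1]   # largest Ritz value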
mp.py
@@ -0,0 +1,14 @@
import jax
from flax.jax_utils import replicate  # re-exported: other modules call mp.replicate


def cross_replica_mean(replicated):
    return jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')(replicated)


def unreplicate(tree, i=0):
    """Returns a single instance of a replicated array."""
    return jax.tree_util.tree_map(lambda x: x[i], tree)


def sync_bn(state):
    batch_stats = cross_replica_mean(state.batch_stats)
    return state.replace(batch_stats=batch_stats)
rto.py
@@ -0,0 +1,135 @@
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
from typing import Any
from collections import OrderedDict
from tqdm import tqdm
from absl import flags
from time import time

import tool, mp

FLAGS = flags.FLAGS


def loss_fn(params, num_train, trainer, batch, rng, param_noise, batch_noise):
    vec_params = tool.params_to_vec(params)
    if FLAGS.use_connect:
        new_params = jax.tree_util.tree_map(
            lambda x, y: x * y, params, trainer.offset)
    else:
        new_params = params

    f = lambda x: tool.forward(x, trainer, batch['x'], train=True)
    logit, jvp = jax.jvp(f, [trainer.offset], [new_params])

    acc = (jnp.argmax(logit + jvp, axis=-1) == jnp.argmax(batch['y'], axis=-1)).astype(int)

    loss = 0.5 * ((jvp - batch_noise * batch['y'])**2).sum(axis=-1)
    wd = 0.5 * ((vec_params - param_noise)**2).sum()
    loss_ = loss.mean() + wd / num_train * (FLAGS.sigma / FLAGS.alpha)**2
    param_norm = (vec_params**2).sum()
    return loss_, (loss, acc, wd, param_norm)


@partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(3, 4))
def new_opt_step(trainer, batch, rng, sync_grad, num_train, param_noise, data_noise):
    vec_params, unravel_fn = tool.params_to_vec(trainer.params, True)
    batch_noise = data_noise[batch['idx']]
    grad_fn = jax.grad(loss_fn, has_aux=True)
    # compute grad
    grad, (loss_b, acc_b, wd_b, param_norm) = grad_fn(
        trainer.params,
        num_train,
        trainer,
        batch,
        rng,
        param_noise,
        batch_noise,
    )
    grad = tool.params_to_vec(grad)
    grad_norm = jnp.sqrt((grad**2).sum())

    log = [
        ('loss_sgd', loss_b),
        ('acc_sgd', acc_b),
        ('wd_sgd', wd_b),
        ('grad_sgd', grad_norm),
        ('param_sgd', param_norm),
    ]
    log = OrderedDict(log)

    # update NN
    if sync_grad:
        grad = jax.lax.pmean(grad, axis_name='batch')
    trainer = trainer.apply_gradients(grads=unravel_fn(grad))
    return log, trainer


def get_posterior_rto(trainer, opt_step, dataset_opt, rng, *args, **kwargs):
    num_devices = jax.device_count()
    num_train = kwargs['num_train']
    num_class = kwargs['num_class']
    sync_grad = not FLAGS.ft_local
    vec_params = tool.params_to_vec(mp.unreplicate(trainer).params)
    init_p = jax.pmap(tool.init_trainer_ft_lin)

    if FLAGS.ft_local:
        num_stage = int(np.ceil(float(FLAGS.ens_num) / num_devices))
    else:
        num_stage = FLAGS.ens_num

    print(f'Start {FLAGS.ens_num} ensemble training')
    posterior = []
    for i in range(num_stage):
        # randomize
        rng, rng_ = jax.random.split(rng)
        if sync_grad:
            param_noise = mp.replicate(jax.random.normal(rng_, vec_params.shape)) * FLAGS.alpha
            rng, rng_ = jax.random.split(rng)
            data_noise = mp.replicate(jax.random.normal(rng_, (num_train, num_class))) * FLAGS.sigma
        else:
            param_noise = jax.random.normal(rng_, (num_devices, *vec_params.shape)) * FLAGS.alpha
            rng, rng_ = jax.random.split(rng)
            data_noise = jax.random.normal(rng_, (num_devices, num_train, num_class)) * FLAGS.sigma

        trainer_ft = init_p(trainer, param_noise)

        # optimize
        pbar = tqdm(range(FLAGS.ft_step))
        for step in pbar:
            if i == 0 and step == 1:
                # skip the first iteration when timing, to exclude compile time
                start_time = time()
            batch_tr = next(dataset_opt)
            rng, rng_ = jax.random.split(rng)
            log, trainer_ft = new_opt_step(
                trainer_ft,
                batch_tr,
                jax.random.split(rng_, num_devices),
                sync_grad,
                num_train,
                param_noise,
                data_noise,
            )

            log = OrderedDict([(k, f'{np.mean(v):.2f}') for k, v in log.items()])
            log.update({'stage': i})
            log.move_to_end('stage', last=False)
            pbar.set_postfix(log)

        end_time = time()
        print(f'Training time excluding compilation: {end_time - start_time:.4} s')

        print('Post-processing members')
        if FLAGS.ft_local:
            # add multiple local models
            for ens_mem in tqdm(range(num_devices)):
                member = mp.unreplicate(trainer_ft, ens_mem)
                if FLAGS.use_connect:
                    member = member.replace(params=jax.tree_util.tree_map(lambda x, y: x * y, member.params, member.offset))
                posterior.append(member)
        else:
            # add single global model
            member = mp.unreplicate(trainer_ft)
            if FLAGS.use_connect:
                member = member.replace(params=jax.tree_util.tree_map(lambda x, y: x * y, member.params, member.offset))
            posterior.append(member)
    return posterior
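

# Editorial note (not in the original file): each member returned above follows the
# randomize-then-optimize recipe visible in `loss_fn`. With fresh Gaussian noise
# eps_theta ~ N(0, alpha^2 I) and eps_y ~ N(0, sigma^2 I), it minimizes
#   0.5 * mean_batch ||J theta - eps_y * y||^2 + (sigma/alpha)^2 * 0.5 * ||theta - eps_theta||^2 / num_train
# i.e. each optimized member is one posterior sample of the linearized model
# (the LL and CL cases named in the README, depending on --use_connect).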
tool.py
@@ -0,0 +1,62 @@
from typing import Any, Callable
from functools import partial
import jax
import jax.numpy as jnp
from flax import struct
import optax

from jax.flatten_util import ravel_pytree

from absl import flags
FLAGS = flags.FLAGS


class Trainer(struct.PyTreeNode):
    step: int
    apply_fn: Callable = struct.field(pytree_node=False)
    tx: Callable = struct.field(pytree_node=False)
    params: Any = None
    state: Any = None
    opt_state: Any = None

    @classmethod
    def create(cls, *, apply_fn, params, tx, **kwargs):
        opt_state = tx.init(params)
        return cls(step=0, apply_fn=apply_fn, params=params, tx=tx, opt_state=opt_state, **kwargs)

    def apply_gradients(self, *, grads, **kwargs):
        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)
        return self.replace(step=self.step + 1, params=new_params, opt_state=new_opt_state, **kwargs)


def select_tree(pred: jnp.ndarray, a, b):
    assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"
    return jax.tree_map(partial(jax.lax.select, pred), a, b)


class TrainerPert(Trainer):
    offset: Any = None


def params_to_vec(param, unravel=False):
    vec_param, unravel_fn = ravel_pytree(param)
    if unravel:
        return vec_param, unravel_fn
    else:
        return vec_param


def forward(params, trainer, input_, rng=None, train=True):
    res, _ = trainer.apply_fn(params, trainer.state, rng, input_, train)
    return res


def init_trainer_ft_lin(trainer, init_params=None):
    vec_params, unravel_fn = params_to_vec(trainer.params, True)
    if init_params is None:
        init_params = jnp.zeros_like(vec_params)
    tx = optax.chain(optax.adam(learning_rate=FLAGS.ft_lr))
    trainer_ft = TrainerPert.create(
        apply_fn=trainer.apply_fn,
        state=trainer.state,
        offset=trainer.params,
        params=unravel_fn(init_params),
        tx=tx,
    )
    return trainer_ft
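

# Editorial note (not in the original file): tool.py and rto.py read several absl
# flags that are expected to be defined by the surrounding training harness. A
# caller would need definitions along these lines (default values and help strings
# are placeholders, not taken from the source):
#
#   flags.DEFINE_float('alpha', 1.0, help='scale of the parameter-space RTO noise')
#   flags.DEFINE_float('sigma', 1.0, help='scale of the output-space RTO noise')
#   flags.DEFINE_bool('ft_local', False, help='train one independent member per device')
#   flags.DEFINE_integer('ens_num', 8, help='number of ensemble members')
#   flags.DEFINE_integer('ft_step', 1000, help='optimization steps per member')
#   flags.DEFINE_float('ft_lr', 1e-3, help='learning rate for linearized fine-tuning')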