
Commit

Add files via upload
sungyubkim authored Mar 1, 2023
1 parent f7aa613 commit 87ca4b4
Showing 6 changed files with 340 additions and 0 deletions.
27 changes: 27 additions & 0 deletions methods/README.md
@@ -0,0 +1,27 @@
# Role of files

* Code Snippets

  * ```cs.py```: a Python implementation of Connectivity Sharpness (see the usage sketch after this list)

  * ```lanczos.py```: a Python implementation of the Lanczos iteration

  * ```rto.py```: a Python implementation of the Randomize-Then-Optimize (RTO) method for LL and CL

* Helper scripts
  * ```tool.py```: auxiliary code for managing parameters in Haiku
  * ```mp.py```: auxiliary code for data parallelization in JAX
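
A minimal usage sketch for connectivity sharpness (```cs.py```), assuming a toy stand-in model; the ```forward``` function, parameters, and batch below are hypothetical placeholders rather than part of this repository:

```python
import jax
import jax.numpy as jnp

import cs  # methods/cs.py

# Hypothetical stand-in for a real network: cs.eval_ctk_trace expects a callable
# of the form forward(params, state=..., input_=...) returning per-example scores.
def forward(params, state, input_):
    return input_ @ params['w'] + params['b']

params = {'w': 0.1 * jnp.ones((8, 10)), 'b': jnp.zeros(10)}
state = None
batch = {'x': jnp.ones((16, 8))}
rng = jax.random.PRNGKey(0)

sharpness = cs.eval_ctk_trace(forward, params, state, batch, rng,
                              max_iter=10, num_classes=10)
print('connectivity sharpness estimate:', float(sharpness))
```

In practice, ```forward``` would be the apply function of the trained network and ```batch``` a batch drawn from the training set.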



# Requirements

```bash
tensorflow 2.9.1
tensorflow-datasets 4.6.0
jax 0.3.17
jaxlib 0.3.15+cuda11.cudnn82
jaxline 0.0.5
flax 0.5.3
dm-haiku 0.0.9.dev0
```
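
One possible way to install these requirements, assuming a CUDA 11 / cuDNN 8.2 machine (the CUDA build of ```jaxlib``` is distributed through the JAX release index rather than PyPI, and the ```dm-haiku``` dev version may need to be installed from its GitHub repository instead):

```bash
pip install tensorflow==2.9.1 tensorflow-datasets==4.6.0 flax==0.5.3 jaxline==0.0.5 dm-haiku==0.0.9.dev0
pip install jax==0.3.17 jaxlib==0.3.15+cuda11.cudnn82 \
    -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```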
24 changes: 24 additions & 0 deletions methods/cs.py
@@ -0,0 +1,24 @@
from functools import partial
import jax
import jax.numpy as jnp
import tool

# Evaluate connectivity sharpness: the trace of the Connectivity Tangent Kernel (CTK),
# estimated with Hutchinson-style Rademacher probes averaged over max_iter iterations.

def eval_ctk_trace(forward, params, state, batch, rng, max_iter, num_classes):
    vec_params = tool.params_to_vec(params)
    f = partial(forward, state=state, input_=batch['x'])
    _, f_vjp = jax.vjp(f, params)

    def body_fn(_, carry):
        res, rng = carry
        _, rng = jax.random.split(rng)
        # Rademacher probe in output space; pulling it back through the VJP and
        # rescaling by the parameters gives one sample of the trace estimator.
        v = jax.random.rademacher(rng, (batch['x'].shape[0], num_classes), jnp.float32)
        j_p = tool.params_to_vec(f_vjp(v)) * vec_params
        tr_ctk = jnp.sum(jnp.square(j_p)) / batch['x'].shape[0]
        res += tr_ctk / max_iter
        return (res, rng)

    res, _ = jax.lax.fori_loop(0, max_iter, body_fn, (0., rng))
    return res
78 changes: 78 additions & 0 deletions methods/lanczos.py
@@ -0,0 +1,78 @@
import tensorflow as tf
tf.config.experimental.set_visible_devices([], 'GPU')
from absl import flags
import jax
import jax.numpy as jnp
import numpy as np
from tqdm import tqdm

import tool, mp

flags.DEFINE_float('scale_factor', 1.0, help='scale factor')
flags.DEFINE_bool('use_connect', False, help='use connectivity or parameter')
flags.DEFINE_integer('sharpness_rand_proj_dim', 10, help='random projection dimension for sharpness estimation')
FLAGS = flags.FLAGS

@jax.pmap
def mvp_batch(v, trainer, batch):
    # One JVP followed by a VJP gives res = M J^T J M v, where J is the Jacobian of the
    # forward map and M = diag(params) (connectivity) or the identity (parameter space):
    # the matrix-vector product required by the Lanczos iteration.
vec_params, unravel_fn = tool.params_to_vec(trainer.params, True)

if FLAGS.use_connect:
multiplier = vec_params
else:
multiplier = jnp.ones_like(vec_params)

f = lambda x: tool.forward(x, trainer, batch['x'], train=True)
_, res = jax.jvp(f, [unravel_fn(vec_params)], [unravel_fn(v * multiplier)])
_, f_vjp = jax.vjp(f, unravel_fn(vec_params))
res = f_vjp(res)
res = tool.params_to_vec(res) * multiplier
return res

def mvp(v, trainer, ds_train):
res = 0.
for batch in ds_train:
res += mvp_batch(mp.replicate(v), trainer, batch).sum(axis=0)
return res

def lanczos(trainer, ds_train, rng):
    # Modified Lanczos algorithm from https://github.com/google/spectral-density/blob/master/jax/lanczos.py, adapted for recent JAX versions.
rand_proj_dim = FLAGS.sharpness_rand_proj_dim
vec_params, unravel_fn = tool.params_to_vec(mp.unreplicate(trainer).params, True)

tridiag = jnp.zeros((rand_proj_dim, rand_proj_dim))
vecs = jnp.zeros((rand_proj_dim, len(vec_params)))

init_vec = jax.random.normal(rng, shape=vec_params.shape)
init_vec = init_vec / jnp.linalg.norm(init_vec)
vecs = vecs.at[0].set(init_vec)

beta = 0
for i in tqdm(range(rand_proj_dim)):
v = vecs[i, :]
if i == 0:
v_old = 0
else:
v_old = vecs[i -1, :]

w = mvp(v, trainer, ds_train)
w = w - beta * v_old

alpha = jnp.dot(w, v)
tridiag = tridiag.at[i, i].set(alpha)
w = w - alpha * v

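        # Full reorthogonalization of w against all previous Lanczos vectors.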
for j in range(i):
tau = vecs[j, :]
coef = np.dot(w, tau)
w += - coef * tau

beta = jnp.linalg.norm(w)

if (i + 1) < rand_proj_dim:
tridiag = tridiag.at[i, i+1].set(beta)
tridiag = tridiag.at[i+1, i].set(beta)
vecs = vecs.at[i+1].set(w/beta)

return tridiag, vecs
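
# Example (hypothetical) usage: the eigenvalues of the returned tridiagonal matrix
# (the Ritz values) approximate the extreme eigenvalues of the underlying operator:
#   tridiag, _ = lanczos(trainer, ds_train, jax.random.PRNGKey(0))
#   print(jnp.linalg.eigvalsh(tridiag)[-1])   # largest Ritz value (sharpness estimate)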
14 changes: 14 additions & 0 deletions methods/mp.py
@@ -0,0 +1,14 @@
import jax
from flax.jax_utils import replicate


def cross_replica_mean(replicated):
    """Averages a replicated pytree across devices with jax.lax.pmean."""
    return jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')(replicated)

def unreplicate(tree, i=0):
"""Returns a single instance of a replicated array."""
return jax.tree_util.tree_map(lambda x: x[i], tree)

def sync_bn(state):
    """Synchronizes batch-norm statistics across devices."""
    batch_stats = cross_replica_mean(state.batch_stats)
    return state.replace(batch_stats=batch_stats)
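
# Typical (hypothetical) flow: `replicate` broadcasts a host-side pytree to every local
# device before pmapped steps, and `unreplicate` pulls a single copy back afterwards:
#   trainer = replicate(trainer)            # before jax.pmap'ed training steps
#   params = unreplicate(trainer).params    # single-device view for evaluation/saving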
135 changes: 135 additions & 0 deletions methods/rto.py
@@ -0,0 +1,135 @@
import numpy as np
import jax
import jax.numpy as jnp
from functools import partial
from collections import OrderedDict
from tqdm import tqdm
from absl import flags
from time import time

import tool, mp

FLAGS = flags.FLAGS

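# Randomize-Then-Optimize (RTO) objective for one ensemble member: fit the linearized
# prediction (a JVP around trainer.offset) to the noise-perturbed targets
# batch_noise * batch['y'], with an L2 anchor pulling the parameters toward the random
# draw param_noise, weighted by (sigma / alpha)**2 / num_train.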
def loss_fn(params, num_train, trainer, batch, rng, param_noise, batch_noise):
vec_params = tool.params_to_vec(params)
if FLAGS.use_connect:
new_params = jax.tree_util.tree_map(
lambda x, y: x * y, params, trainer.offset)
else:
new_params = params

f = lambda x: tool.forward(x, trainer, batch['x'], train=True)
logit, jvp = jax.jvp(f, [trainer.offset], [new_params])

acc = (jnp.argmax(logit + jvp, axis=-1) == jnp.argmax(batch['y'], axis=-1)).astype(int)

loss = 0.5 * ((jvp - batch_noise * batch['y'])**2).sum(axis=-1)
wd = 0.5 * ((vec_params - param_noise)**2).sum()
loss_ = loss.mean() + wd / num_train * (FLAGS.sigma/FLAGS.alpha)**2
param_norm = (vec_params**2).sum()
return loss_, (loss, acc, wd, param_norm)

@partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=(3,4))
def new_opt_step(trainer, batch, rng, sync_grad, num_train, param_noise, data_noise):
vec_params, unravel_fn = tool.params_to_vec(trainer.params, True)
batch_noise = data_noise[batch['idx']]
grad_fn = jax.grad(loss_fn, has_aux=True)
# compute grad
grad, (loss_b, acc_b, wd_b, param_norm) = grad_fn(
trainer.params,
num_train,
trainer,
batch,
rng,
param_noise,
batch_noise,
)
grad = tool.params_to_vec(grad)
grad_norm = jnp.sqrt((grad**2).sum())

log = [
('loss_sgd', loss_b),
('acc_sgd', acc_b),
('wd_sgd', wd_b),
('grad_sgd', grad_norm),
('param_sgd', param_norm),
]
log = OrderedDict(log)

# update NN
if sync_grad:
grad = jax.lax.pmean(grad, axis_name='batch')
trainer = trainer.apply_gradients(grads=unravel_fn(grad))
return log, trainer

def get_posterior_rto(trainer, opt_step, dataset_opt, rng, *args, **kwargs):
num_devices = jax.device_count()
num_train = kwargs['num_train']
num_class = kwargs['num_class']
sync_grad = not(FLAGS.ft_local)
vec_params= tool.params_to_vec(mp.unreplicate(trainer).params)
init_p = jax.pmap(tool.init_trainer_ft_lin)

if FLAGS.ft_local:
num_stage = int(np.ceil(float(FLAGS.ens_num)/num_devices))
else:
num_stage = FLAGS.ens_num

print(f'Start {FLAGS.ens_num} ensemble training')
posterior = []
for i in range(num_stage):
# randomize
rng, rng_ = jax.random.split(rng)
if sync_grad:
param_noise = mp.replicate(jax.random.normal(rng_, vec_params.shape)) * FLAGS.alpha
rng, rng_ = jax.random.split(rng)
data_noise = mp.replicate(jax.random.normal(rng_, (num_train, num_class))) * FLAGS.sigma
else:
param_noise = jax.random.normal(rng_, (num_devices, *vec_params.shape)) * FLAGS.alpha
rng, rng_ = jax.random.split(rng)
data_noise = jax.random.normal(rng_, (num_devices, num_train, num_class)) * FLAGS.sigma

trainer_ft = init_p(trainer, param_noise)

# optimize
pbar = tqdm(range(FLAGS.ft_step))
for step in pbar:
if i==0 and step==1:
# remove first iteration to exclude compile time
start_time = time()
batch_tr = next(dataset_opt)
rng, rng_ = jax.random.split(rng)
log, trainer_ft = new_opt_step(
trainer_ft,
batch_tr,
jax.random.split(rng_, num_devices),
sync_grad,
num_train,
param_noise,
data_noise,
)

log = OrderedDict([(k,f'{np.mean(v):.2f}') for k,v in log.items()])
log.update({'stage': i})
log.move_to_end('stage', last=False)
pbar.set_postfix(log)

end_time = time()
print(f'Cache time except compile : {end_time - start_time:.4} s')

print(f'Post-processing members')
if FLAGS.ft_local:
# add multiple local models
for ens_mem in tqdm(range(num_devices)):
member = mp.unreplicate(trainer_ft, ens_mem)
if FLAGS.use_connect:
member = member.replace(params=jax.tree_util.tree_map(lambda x, y : x * y, member.params, member.offset))
posterior.append(member)
else:
# add single global model
member = mp.unreplicate(trainer_ft)
if FLAGS.use_connect:
member = member.replace(params=jax.tree_util.tree_map(lambda x, y : x * y, member.params, member.offset))
posterior.append(member)
return posterior
62 changes: 62 additions & 0 deletions methods/tool.py
@@ -0,0 +1,62 @@
from typing import Any, Callable
from functools import partial
import jax
import jax.numpy as jnp
from flax import struct
import optax

from absl import flags
from jax.flatten_util import ravel_pytree

FLAGS = flags.FLAGS

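# A minimal train state in the style of flax.training.train_state.TrainState: it bundles
# the step counter, apply_fn, optimizer (tx), parameters, model state, and optimizer state.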
class Trainer(struct.PyTreeNode):
step: int
apply_fn: Callable = struct.field(pytree_node=False)
tx: Callable = struct.field(pytree_node=False)
params: Any = None
state: Any = None
opt_state: Any = None

@classmethod
def create(cls, *, apply_fn, params, tx, **kwargs):
opt_state = tx.init(params)
return cls(step=0, apply_fn=apply_fn, params=params, tx=tx, opt_state=opt_state, **kwargs)

def apply_gradients(self, *, grads, **kwargs):
updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
new_params = optax.apply_updates(self.params, updates)
return self.replace(step=self.step+1, params=new_params, opt_state=new_opt_state, **kwargs)

def select_tree(pred: jnp.ndarray, a, b):
    """Selects pytree `a` if `pred` is True, else `b` (leaf-wise lax.select)."""
    assert pred.ndim == 0 and pred.dtype == jnp.bool_, "expected boolean scalar"
    return jax.tree_util.tree_map(partial(jax.lax.select, pred), a, b)

class TrainerPert(Trainer):
    """Trainer with a frozen `offset` (the linearization point, e.g. the pretrained parameters)."""
    offset: Any = None

def params_to_vec(param, unravel=False):
vec_param, unravel_fn = ravel_pytree(param)
if unravel:
return vec_param, unravel_fn
else:
return vec_param

def forward(params, trainer, input_, rng=None, train=True):
res, _ = trainer.apply_fn(params, trainer.state, rng, input_, train)
return res

def init_trainer_ft_lin(trainer, init_params=None):
vec_params, unravel_fn = params_to_vec(trainer.params, True)
if init_params is None:
init_params = jnp.zeros_like(vec_params)
tx = optax.chain(optax.adam(learning_rate=FLAGS.ft_lr))
trainer_ft = TrainerPert.create(
apply_fn=trainer.apply_fn,
state=trainer.state,
offset=trainer.params,
params=unravel_fn(init_params),
tx=tx,
)
return trainer_ft
