Add a distillation experiment #44

Status: Open. Wants to merge 39 commits into base: main.

Changes from 35 commits (39 commits total)

Commits
06dc83f  soba + pytrees (pierreablin, Jun 4, 2024)
2af84b0  add dataset (pierreablin, Jun 4, 2024)
3a8f231  FIX flake8 (MatDag, Jun 4, 2024)
c5f6af2  ENH put update_sgd_fn in utils file (MatDag, Jul 18, 2024)
4830cef  ENH create tree_utils.py (MatDag, Jul 18, 2024)
f6c7bc1  WIP pytree (MatDag, Jul 22, 2024)
c7b7a1d  FIX bug amigo (MatDag, Jul 22, 2024)
b041359  WIP bome (MatDag, Jul 22, 2024)
0c41828  WIP pytrees (MatDag, Jul 22, 2024)
854559a  WIP fsla pytree (MatDag, Jul 22, 2024)
f5d7be1  WIP jaxopt pytree (MatDag, Jul 22, 2024)
28857c0  WIP memory trees (MatDag, Jul 22, 2024)
1672b23  WIP mrbo (MatDag, Jul 22, 2024)
7ad80cd  WIP pytrees (MatDag, Jul 23, 2024)
aeaa115  ENH tree_diff (MatDag, Jul 23, 2024)
ac218d4  WIP saba pytree (MatDag, Jul 23, 2024)
9ca7bab  FIX saba vr (MatDag, Jul 23, 2024)
e1d84ff  FIX sustain select_memory (MatDag, Jul 23, 2024)
947c5dd  FIX sustain select_memory (MatDag, Jul 23, 2024)
2cc800a  FIX hia (MatDag, Jul 23, 2024)
fba548e  FIX sustain (MatDag, Jul 23, 2024)
3e9ab13  ENH distillation (MatDag, Jul 25, 2024)
28926f7  ENH enables to save distilled images (MatDag, Jul 25, 2024)
fa0aba8  FIX requirement flax (MatDag, Jul 25, 2024)
3fddd1c  FIX requirement optax (MatDag, Jul 25, 2024)
c0edadb  WIP init (MatDag, Aug 6, 2024)
d06dfe7  WIP making it work (MatDag, Oct 16, 2024)
1f47feb  WIP comment cnn (MatDag, Oct 16, 2024)
faa4bc8  ENH init inner_var (MatDag, Oct 16, 2024)
514356b  WIP (MatDag, Oct 16, 2024)
d30a685  FIX revert soba (MatDag, Oct 16, 2024)
9e04793  WIP (MatDag, Oct 16, 2024)
cbebe9f  FIX flatten (MatDag, Oct 16, 2024)
5e6280b  CLN jax.tree_map -> jax.tree_util.tree_map (MatDag, Oct 17, 2024)
d020e4d  FIX test (MatDag, Oct 17, 2024)
e4e2c96  WIP (MatDag, Oct 17, 2024)
eb8ae64  ENH allow several achitectures (MatDag, Oct 18, 2024)
96590a9  FIX model (MatDag, Oct 18, 2024)
8e319a8  FIX accuracy (MatDag, Oct 18, 2024)
10 changes: 6 additions & 4 deletions benchmark_utils/gd_inner.py
@@ -1,5 +1,6 @@
import jax
from functools import partial
from benchmark_utils.tree_utils import update_sgd_fn


@partial(jax.jit, static_argnames=('grad_inner', 'n_steps'))
@@ -12,9 +13,9 @@ def gd_inner_jax(inner_var, outer_var, step_size, grad_inner=None,
----------
grad_inner : callable
Gradient of the inner oracle with respect to the inner variable.
inner_var : array
inner_var : pytree
Initial value of the inner variable.
outer_var : array
outer_var : pytree
Value of the outer variable.
step_size : float
Step size of the gradient descent.
@@ -26,8 +27,9 @@ def gd_inner_jax(inner_var, outer_var, step_size, grad_inner=None,
inner_var : array
Value of the inner variable after n_steps of gradient descent.
"""
def iter(i, inner_var):
inner_var -= step_size * grad_inner(inner_var, outer_var)
def iter(_, inner_var):
inner_var = update_sgd_fn(inner_var, grad_inner(inner_var, outer_var),
step_size)
return inner_var
inner_var = jax.lax.fori_loop(0, n_steps, iter, inner_var)
return inner_var
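
Note: the move from in-place array updates to update_sgd_fn is what lets inner_var be an arbitrary pytree (for instance a Flax parameter dict) instead of a flat array. Below is a minimal sketch of the intended usage with a toy quadratic loss; the loss and variable shapes are illustrative only, not part of the benchmark.

import jax
import jax.numpy as jnp
from benchmark_utils.tree_utils import update_sgd_fn

# Toy inner variable stored as a pytree (a dict of arrays), as a Flax
# parameter tree would be.
inner_var = {"w": jnp.ones(3), "b": jnp.zeros(1)}
outer_var = jnp.array(2.0)

# Hypothetical inner loss; the real oracles live in the benchmark datasets.
def inner_loss(inner_var, outer_var):
    return (0.5 * jnp.sum((inner_var["w"] - outer_var) ** 2)
            + 0.5 * jnp.sum(inner_var["b"] ** 2))

grad_inner = jax.grad(inner_loss)

# One gradient step on the whole pytree: each leaf gets x - step_size * g.
inner_var = update_sgd_fn(inner_var, grad_inner(inner_var, outer_var), 0.1)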
55 changes: 34 additions & 21 deletions benchmark_utils/hessian_approximation.py
@@ -1,4 +1,6 @@
import jax
from benchmark_utils.tree_utils import update_sgd_fn
from benchmark_utils.tree_utils import tree_scalar_mult, tree_add


def hia_jax(
@@ -54,10 +56,11 @@ def hvp(v, start_idx):
def iter(_, args):
state_sampler, v = args
start_idx, *_, state_sampler = sampler(state_sampler)
v -= step_size * hvp(v, start_idx)
v = update_sgd_fn(v, hvp(v, start_idx), step_size)
return state_sampler, v
state_sampler, v = jax.lax.fori_loop(0, p[0], iter, (state_sampler, v))
return n_steps * step_size * v, jax.random.split(key, 1)[0], state_sampler
v = tree_scalar_mult(n_steps * step_size, v)
return v, jax.random.split(key, 1)[0], state_sampler


def shia_jax(
@@ -75,13 +78,13 @@ def shia_jax(

Parameters
----------
inner_var : array
inner_var : pytree
Inner variable.

outer_var : array
outer_var : pytree
Outer variable.

v : array
v : pytree
Right hand side of the linear system.

state_sampler : dict
@@ -113,12 +116,12 @@ def hvp(v, start_idx):
def iter(_, args):
state_sampler, v, s = args
start_idx, *_, state_sampler = sampler(state_sampler)
v -= step_size * hvp(v, start_idx)
s += v
v = update_sgd_fn(v, hvp(v, start_idx), step_size)
s = update_sgd_fn(s, v, -1) # s += v
return state_sampler, v, s
state_sampler, _, s = jax.lax.fori_loop(0, n_steps, iter,
(state_sampler, v, s))
return step_size * s, state_sampler
return tree_scalar_mult(step_size, s), state_sampler
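
A note on the "# s += v" comments above: update_sgd_fn computes s - step_size * v, so passing step_size = -1 accumulates v into s. A quick sanity check of that equivalence, written here for illustration and not part of the diff:

import jax.numpy as jnp
from benchmark_utils.tree_utils import update_sgd_fn, tree_add

s = {"w": jnp.array([1.0, 2.0])}
v = {"w": jnp.array([0.25, 0.75])}

# With step_size = -1, s - step_size * v becomes s + v, matching tree_add.
assert jnp.allclose(update_sgd_fn(s, v, -1)["w"], tree_add(s, v)["w"])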


def shia_fb_jax(inner_var, outer_var, v, step_size, n_steps=1,
@@ -162,11 +165,11 @@ def hvp(v):

def iter(_, args):
v, s = args
v -= step_size * hvp(v)
s += v
v = update_sgd_fn(v, hvp(v), step_size)
s = update_sgd_fn(s, v, -1) # s += v
return v, s
_, s = jax.lax.fori_loop(0, n_steps, iter, (v, s))
return step_size * s
return tree_scalar_mult(step_size, s)


def sgd_v_jax(inner_var, outer_var, v, grad_out, state_sampler,
@@ -220,7 +223,10 @@ def hvp(v, start_idx):
def iter(_, args):
state_sampler, v = args
start_idx, *_, state_sampler = sampler(state_sampler)
v -= step_size * (hvp(v, start_idx) - grad_out)
v = update_sgd_fn(v,
tree_add(hvp(v, start_idx),
tree_scalar_mult(-1, grad_out)),
step_size)
return state_sampler, v
state_sampler, v = jax.lax.fori_loop(0, n_steps, iter, (state_sampler, v))
return v, state_sampler
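
The update above forms hvp(v) - grad_out by combining tree_add with tree_scalar_mult(-1, ...). The tree_diff helper introduced in this PR computes the same leaf-wise difference directly; a small illustrative check of that equivalence:

import jax.numpy as jnp
from benchmark_utils.tree_utils import tree_add, tree_scalar_mult, tree_diff

a = {"w": jnp.array([1.0, 2.0])}
b = {"w": jnp.array([0.5, 1.5])}

# tree_add(a, tree_scalar_mult(-1, b)) builds a - b leaf by leaf, which is
# exactly what tree_diff computes.
assert jnp.allclose(tree_add(a, tree_scalar_mult(-1, b))["w"],
                    tree_diff(a, b)["w"])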
@@ -290,16 +296,19 @@ def hvp_old(v, start_idx):
def iter(_, args):
state_sampler, v, s, v_old, s_old = args
start_idx, *_, state_sampler = sampler(state_sampler)
v -= step_size * hvp(v, start_idx)
s += v
v_old -= step_size * hvp_old(v_old, start_idx)
s_old += v_old
v = update_sgd_fn(v, hvp(v, start_idx), step_size)
s = update_sgd_fn(s, v, -1) # s += v
v_old = update_sgd_fn(v_old, hvp_old(v_old, start_idx), step_size)
s_old = update_sgd_fn(s_old, v_old, -1) # s_old += v_old
return state_sampler, v, s, v_old, s_old

state_sampler, _, s, _, s_old = jax.lax.fori_loop(
0, n_steps, iter, (state_sampler, v, s, v_old, s_old)
)

return step_size * s, step_size * s_old, state_sampler
return (
tree_scalar_mult(step_size, s), tree_scalar_mult(step_size, s_old),
state_sampler
)


def joint_hia_jax(
Expand Down Expand Up @@ -375,12 +384,16 @@ def hvp_old(v, start_idx):
def iter(_, args):
state_sampler, v, v_old = args
start_idx, *_, state_sampler = sampler(state_sampler)
v -= step_size * hvp(v, start_idx)
v_old -= step_size * hvp_old(v_old, start_idx)
v = update_sgd_fn(v, hvp(v, start_idx), step_size)
v_old = update_sgd_fn(v_old, hvp_old(v_old, start_idx), step_size)
return state_sampler, v, v_old

state_sampler, v, v_old = jax.lax.fori_loop(
0, p[0], iter, (state_sampler, v, v_old)
)

return n_steps * step_size * v, n_steps * step_size * v_old, \
return (
tree_scalar_mult(n_steps * step_size, v),
tree_scalar_mult(n_steps * step_size, v_old),
jax.random.split(key, 1)[0], state_sampler
)
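
The hvp closures themselves are collapsed in this view. The sketch below shows one way a pytree-valued Hessian-vector product can be built with forward-over-reverse differentiation and plugged into an sgd_v_jax-style update; the toy loss, variables, and grad_out are placeholders, not the benchmark's oracles.

import jax
import jax.numpy as jnp
from benchmark_utils.tree_utils import update_sgd_fn, tree_add, tree_scalar_mult

# Toy inner loss over a pytree variable (illustrative only).
def inner_loss(inner_var, outer_var):
    return 0.5 * jnp.sum((inner_var["w"] * outer_var) ** 2)

inner_var = {"w": jnp.ones(4)}
outer_var = jnp.array(3.0)
grad_out = {"w": jnp.ones(4)}  # stand-in for the outer gradient

def hvp(v):
    # Hessian-vector product via jvp of the gradient; the result is a pytree
    # with the same structure as inner_var.
    return jax.jvp(lambda x: jax.grad(inner_loss)(x, outer_var),
                   (inner_var,), (v,))[1]

# One step of the linear-system solver, mirroring sgd_v_jax:
# v <- v - step_size * (hvp(v) - grad_out)
step_size = 0.1
v = {"w": jnp.zeros(4)}
v = update_sgd_fn(v, tree_add(hvp(v), tree_scalar_mult(-1, grad_out)),
                  step_size)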
11 changes: 7 additions & 4 deletions benchmark_utils/sgd_inner.py
@@ -1,4 +1,5 @@
import jax
from benchmark_utils.tree_utils import update_sgd_fn


def sgd_inner_jax(inner_var, outer_var, state_sampler, step_size,
@@ -8,9 +9,9 @@ def sgd_inner_jax(inner_var, outer_var, state_sampler, step_size,

Parameters
----------
inner_var : array
inner_var : pytree
Initial value of the inner variable.
outer_var : array
outer_var : pytree
Value of the outer variable.
state_sampler : dict
State of the sampler.
@@ -23,10 +24,12 @@ def sgd_inner_jax(inner_var, outer_var, state_sampler, step_size,
grad_inner : callable
Gradient of the inner oracle with respect to the inner variable.
"""
def iter(i, args):
def iter(_, args):
state_sampler, inner_var = args
start_idx, *_, state_sampler = sampler(state_sampler)
inner_var -= step_size * grad_inner(inner_var, outer_var, start_idx)
inner_var = update_sgd_fn(inner_var,
grad_inner(inner_var, outer_var, start_idx),
step_size)
return state_sampler, inner_var
state_sampler, inner_var = jax.lax.fori_loop(0, n_steps, iter,
(state_sampler, inner_var))
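
For completeness, here is a self-contained sketch of the carry pattern used in sgd_inner_jax: the sampler state and the pytree inner variable travel together through jax.lax.fori_loop. The sampler and the gradient oracle below are hypothetical stand-ins; only their calling convention mimics the benchmark.

import jax
import jax.numpy as jnp
from benchmark_utils.tree_utils import update_sgd_fn

# Hypothetical minibatch sampler: first returned value is the start index,
# last is the updated sampler state.
def sampler(state):
    start_idx = state["i"] * state["batch_size"]
    state = {**state, "i": (state["i"] + 1) % state["n_batches"]}
    return start_idx, state

# Stand-in minibatch gradient oracle on a pytree inner variable.
def grad_inner(inner_var, outer_var, start_idx):
    return jax.tree_util.tree_map(lambda x: x - outer_var, inner_var)

inner_var = {"w": jnp.ones(5)}
outer_var = jnp.array(0.5)
step_size = 0.1
state_sampler = {"i": jnp.array(0), "batch_size": jnp.array(32),
                 "n_batches": jnp.array(10)}

def one_step(_, args):
    state_sampler, inner_var = args
    start_idx, *_, state_sampler = sampler(state_sampler)
    grad = grad_inner(inner_var, outer_var, start_idx)
    return state_sampler, update_sgd_fn(inner_var, grad, step_size)

state_sampler, inner_var = jax.lax.fori_loop(0, 10, one_step,
                                             (state_sampler, inner_var))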
131 changes: 131 additions & 0 deletions benchmark_utils/tree_utils.py
@@ -0,0 +1,131 @@
import jax
import jax.numpy as jnp


def update_sgd_fn(var, grad, step_size):
"""
Helper function that updates the variable with a gradient step.

Parameters
----------
var : pytree
Variable to update.

grad : pytree
Gradient of the variable.

step_size : float
Step size of the gradient step.
"""
return jax.tree_util.tree_map(lambda x, y: x - step_size * y,
var, grad)


def tree_add(a, b):
"""
Helper function that adds two pytrees.

Parameters
----------
a : pytree
First pytree to add.

b : pytree
Second pytree to add.
"""
return jax.tree_util.tree_map(jnp.add, a, b)


def tree_diff(a, b):
"""
Helper function that computes the leaf-wise difference a - b of two pytrees.

Parameters
----------
a : pytree
Pytree to subtract from.

b : pytree
Pytree to subtract.
"""
return jax.tree_util.tree_map(jnp.subtract, a, b)


def tree_scalar_mult(scalar, tree):
"""
Helper function that multiplies a pytree by a scalar.

Parameters
----------
scalar : float
Scalar factor.

tree : pytree
Pytree whose leaves are multiplied by the scalar.
"""
return jax.tree_util.tree_map(lambda x: scalar*x, tree)


def tree_inner_product(a, b):
"""
Helper function that computes the inner product of two pytrees.

Parameters
----------
a : pytree
First pytree.

b : pytree
Second pytree.
"""
return jax.tree_util.tree_reduce(jnp.add, jax.tree_util.tree_map(
lambda x, y: jnp.sum(x * y), a, b))


def init_memory_of_trees(n_memories, tree):
"""
Helper function that initializes a zero-filled memory buffer holding n_memories copies of a pytree.

Parameters
----------
n_memories : int
Number of memories to initialize.

tree : pytree
Pytree whose structure and leaf shapes define each memory slot.
"""
return jax.tree_util.tree_map(lambda x: jnp.zeros((n_memories, *x.shape)),
tree)


def select_memory(memory, idx):
"""
Helper function that selects a memory from a memory pytree.

Parameters
----------
memory : pytree
Memory pytree.

idx : int
Index of the memory to select.
"""
return jax.tree_util.tree_map(lambda x: x[idx], memory)


def update_memory(memory, idx, value):
"""
Helper function that updates a memory from a memory pytree.

Parameters
----------
memory : pytree
Memory pytree.

idx : int
Index of the memory to update.

value : pytree
Value to update the memory with.
"""
# Map over memory and value together so slot idx of each memory leaf is set
# to the matching leaf of value.
return jax.tree_util.tree_map(lambda x, v: x.at[idx].set(v), memory, value)
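
A short round trip through the memory helpers, with illustrative shapes only: init_memory_of_trees allocates a zero buffer with a leading memory axis on every leaf, update_memory writes a pytree into one slot, and select_memory reads it back.

import jax.numpy as jnp
from benchmark_utils.tree_utils import (
    init_memory_of_trees, select_memory, update_memory
)

var = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
memory = init_memory_of_trees(5, var)  # leaves of shape (5, 2, 3) and (5, 3)

# Write the current variable into slot 2, then read it back.
memory = update_memory(memory, 2, var)
slot = select_memory(memory, 2)
assert jnp.allclose(slot["w"], var["w"])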