Add a distillation experiment #44

Open

wants to merge 39 commits into base: main

Changes from 1 commit

39 commits
06dc83f  soba + pytrees  (pierreablin, Jun 4, 2024)
2af84b0  add dataset  (pierreablin, Jun 4, 2024)
3a8f231  FIX flake8  (MatDag, Jun 4, 2024)
c5f6af2  ENH put update_sgd_fn in utils file  (MatDag, Jul 18, 2024)
4830cef  ENH create tree_utils.py  (MatDag, Jul 18, 2024)
f6c7bc1  WIP pytree  (MatDag, Jul 22, 2024)
c7b7a1d  FIX bug amigo  (MatDag, Jul 22, 2024)
b041359  WIP bome  (MatDag, Jul 22, 2024)
0c41828  WIP pytrees  (MatDag, Jul 22, 2024)
854559a  WIP fsla pytree  (MatDag, Jul 22, 2024)
f5d7be1  WIP jaxopt pytree  (MatDag, Jul 22, 2024)
28857c0  WIP memory trees  (MatDag, Jul 22, 2024)
1672b23  WIP mrbo  (MatDag, Jul 22, 2024)
7ad80cd  WIP pytrees  (MatDag, Jul 23, 2024)
aeaa115  ENH tree_diff  (MatDag, Jul 23, 2024)
ac218d4  WIP saba pytree  (MatDag, Jul 23, 2024)
9ca7bab  FIX saba vr  (MatDag, Jul 23, 2024)
e1d84ff  FIX sustain select_memory  (MatDag, Jul 23, 2024)
947c5dd  FIX sustain select_memory  (MatDag, Jul 23, 2024)
2cc800a  FIX hia  (MatDag, Jul 23, 2024)
fba548e  FIX sustain  (MatDag, Jul 23, 2024)
3e9ab13  ENH distillation  (MatDag, Jul 25, 2024)
28926f7  ENH enables to save distilled images  (MatDag, Jul 25, 2024)
fa0aba8  FIX requirement flax  (MatDag, Jul 25, 2024)
3fddd1c  FIX requirement optax  (MatDag, Jul 25, 2024)
c0edadb  WIP init  (MatDag, Aug 6, 2024)
d06dfe7  WIP making it work  (MatDag, Oct 16, 2024)
1f47feb  WIP comment cnn  (MatDag, Oct 16, 2024)
faa4bc8  ENH init inner_var  (MatDag, Oct 16, 2024)
514356b  WIP  (MatDag, Oct 16, 2024)
d30a685  FIX revert soba  (MatDag, Oct 16, 2024)
9e04793  WIP  (MatDag, Oct 16, 2024)
cbebe9f  FIX flatten  (MatDag, Oct 16, 2024)
5e6280b  CLN jax.tree_map -> jax.tree_util.tree_map  (MatDag, Oct 17, 2024)
d020e4d  FIX test  (MatDag, Oct 17, 2024)
e4e2c96  WIP  (MatDag, Oct 17, 2024)
eb8ae64  ENH allow several achitectures  (MatDag, Oct 18, 2024)
96590a9  FIX model  (MatDag, Oct 18, 2024)
8e319a8  FIX accuracy  (MatDag, Oct 18, 2024)
FIX test
MatDag committed Oct 17, 2024
commit d020e4d90234f710a5757dd1a62bd70f667f4d9c
37 changes: 18 additions & 19 deletions datasets/distillation.py
@@ -52,23 +52,6 @@ def download_mnist():
     print("Save complete.")
 
 
-class CNN(nn.Module):
-    """A simple CNN model."""
-    @nn.compact
-    def __call__(self, x):
-        # x = nn.Conv(features=32, kernel_size=(3, 3))(x)
-        # x = nn.relu(x)
-        # x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-        # x = nn.Conv(features=64, kernel_size=(3, 3))(x)
-        # x = nn.relu(x)
-        # x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
-        x = x.reshape((x.shape[0], -1))  # flatten
-        # x = nn.Dense(features=256)(x)
-        # x = nn.relu(x)
-        x = nn.Dense(features=10)(x)
-        return x
-
-
 class Dataset(BaseDataset):
     """Datacleaning with MNIST"""
 
@@ -128,6 +111,22 @@ def get_data(self):
         self.n_samples_inner = X_train.shape[0]
         self.n_samples_outer = X_val.shape[0]
 
+        class CNN(nn.Module):
+            """A simple CNN model."""
+            @nn.compact
+            def __call__(self, x):
+                x = nn.Conv(features=32, kernel_size=(3, 3))(x)
+                x = nn.relu(x)
+                x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+                x = nn.Conv(features=64, kernel_size=(3, 3))(x)
+                x = nn.relu(x)
+                x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
+                x = x.reshape((x.shape[0], -1))  # flatten
+                x = nn.Dense(features=256)(x)
+                x = nn.relu(x)
+                x = nn.Dense(features=10)(x)
+                return x
+
         cnn = CNN()
         key = jax.random.PRNGKey(self.random_state)
         inner_params = cnn.init(key, jnp.ones([1, 28, 28, 1]))['params']
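
Note on the Flax module API used in this hunk: `init` traces `__call__` on a dummy input to build the parameter pytree, and `apply` runs the forward pass with a given set of parameters. A minimal sketch of how the `CNN` defined above would be constructed and evaluated; the zero-filled `x_batch` is illustrative only, and `CNN` refers to the class from the diff:

    import jax
    import jax.numpy as jnp

    # Illustrative batch of four MNIST-shaped images, NHWC layout.
    x_batch = jnp.zeros([4, 28, 28, 1])

    cnn = CNN()  # the model defined inside get_data above
    params = cnn.init(jax.random.PRNGKey(0), jnp.ones([1, 28, 28, 1]))['params']
    logits = cnn.apply({'params': params}, x_batch)  # shape (4, 10)
    preds = jnp.argmax(logits, axis=-1)              # predicted classes
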
@@ -147,8 +146,8 @@ def f_inner(inner_var, outer_var, start=0, batch_size=1):
             # distilled dataset
             res = loss(inner_var, outer_var, jnp.eye(10))
             # outer_var is the distilled dataset, we have one sample
-            # per class. Thus the one_hot encoding of labels if the distilled
-            # dataset is jnp.eye(10)
+            # per class. Thus the one_hot encoding of the distilled dataset's
+            # labels is jnp.eye(10)
 
             res += self.reg * tree_inner_product(inner_var, inner_var)
             return res
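
Why `jnp.eye(10)` can stand in for the labels here: the distilled dataset holds exactly one image per class, ordered by class, so its one-hot label matrix is the 10x10 identity. A quick sanity check using only standard JAX calls:

    import jax
    import jax.numpy as jnp

    labels = jnp.arange(10)               # one distilled sample per class, in order
    one_hot = jax.nn.one_hot(labels, 10)  # shape (10, 10)
    assert bool(jnp.all(one_hot == jnp.eye(10)))  # equals the identity matrix
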
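`tree_inner_product` comes from the `tree_utils.py` module introduced earlier in this PR (commit 4830cef) and its implementation is not shown in this diff. A hypothetical sketch of what such a helper would compute, so the regularization term above reads as a squared l2 norm over the parameter pytree:

    import jax
    import jax.numpy as jnp

    def tree_inner_product(tree_a, tree_b):
        # Sum of elementwise products, leaf by leaf, over two pytrees
        # with the same structure (e.g. two sets of network parameters).
        products = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b),
                                          tree_a, tree_b)
        return jax.tree_util.tree_reduce(lambda acc, x: acc + x, products, 0.0)

With `tree_a` and `tree_b` both set to `inner_var`, this is the sum of squares of all network weights, i.e. an l2 penalty on the inner variable.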