Skip to content

Commit

Permalink
[13]sharing/Tie weight recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Nov 25, 2023
1 parent 42dee36 commit e5cf591
Showing 1 changed file with 113 additions and 34 deletions.
147 changes: 113 additions & 34 deletions docs/notebooks/common_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -105,20 +105,6 @@
"print(sk.tree_diagram(optim_state))"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"class Tree(sk.TreeClass):\n",
"    \"\"\"Add a stored buffer to the input while blocking gradients through the buffer.\"\"\"\n",
"\n",
"    def __init__(self, buffer: jax.Array):\n",
"        self.buffer = buffer\n",
"\n",
"    def __call__(self, x: jax.Array) -> jax.Array:\n",
"        # `stop_gradient` makes the buffer act as a constant under differentiation\n",
"        constant = jax.lax.stop_gradient(self.buffer)\n",
"        return x + constant"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -164,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -220,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -275,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -301,7 +287,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -345,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -394,7 +380,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -464,7 +450,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -600,7 +586,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -626,7 +612,7 @@
" [1.]], dtype=float32)"
]
},
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -683,7 +669,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -772,7 +758,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -796,7 +782,7 @@
" )]"
]
},
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -852,7 +838,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -969,7 +955,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"metadata": {},
"outputs": [
{
Expand All @@ -978,7 +964,7 @@
"25"
]
},
"execution_count": 15,
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1033,7 +1019,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1084,7 +1070,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"metadata": {},
"outputs": [
{
Expand All @@ -1110,7 +1096,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 17,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -1175,6 +1161,99 @@
" ...\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [13] Sharing/Tie Weights\n",
"\n",
"This example demonstrates a simple autoencoder whose encoder and decoder share (tie) a single `weight`: the decoder reuses the transpose of the encoder's weight."
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"TiedAutoEncoder(\n",
" encoder=Linear(\n",
" in_features=(#1), \n",
" out_features=#10, \n",
" weight_init=#glorot_uniform, \n",
" bias_init=#zeros, \n",
" weight=f32[1,10](μ=-0.78, σ=1.11, ∈[-2.58,0.00]), \n",
" bias=f32[10](μ=-0.39, σ=0.55, ∈[-1.29,0.00])\n",
" ), \n",
" decoder=Linear(\n",
" in_features=(#10), \n",
" out_features=#1, \n",
" weight_init=#glorot_uniform, \n",
" bias_init=#zeros, \n",
" weight=None, \n",
" bias=f32[1](μ=-2.40, σ=0.00, ∈[-2.40,-2.40])\n",
" )\n",
")"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import serket as sk\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import jax.random as jr\n",
"\n",
"\n",
"class TiedAutoEncoder(sk.TreeClass):\n",
"    \"\"\"Autoencoder whose decoder reuses the transposed encoder weight.\"\"\"\n",
"\n",
"    def __init__(self, *, key: jax.Array):\n",
"        encoder_key, decoder_key = jr.split(key)\n",
"        self.encoder = sk.nn.Linear(1, 10, key=encoder_key)\n",
"        # the decoder's own weight is unused, so it is replaced with `None`\n",
"        # to avoid memory usage\n",
"        decoder = sk.nn.Linear(10, 1, key=decoder_key)\n",
"        self.decoder = decoder.at[\"weight\"].set(None)\n",
"\n",
"    def _call(self, x):\n",
"        # tie the weights: the decoder uses the transpose of the encoder weight.\n",
"        # this assignment mutates the instance, so `_call` only works when it is\n",
"        # invoked through `.at`; otherwise it throws `AttributeError`\n",
"        self.decoder.weight = self.encoder.weight.T\n",
"        hidden = jax.nn.relu(self.encoder(x))\n",
"        return self.decoder(hidden)\n",
"\n",
"    def __call__(self, x):\n",
"        # `.at[\"_call\"](x)` runs the mutating method without mutating in place:\n",
"        # it returns a tuple of the method result and a new instance of the class\n",
"        # that carries the mutated state. only the result is needed here, so the\n",
"        # new instance is ignored\n",
"        output, _ = self.at[\"_call\"](x)\n",
"        return output\n",
"\n",
"\n",
"# mask the tree once (masked fields show up with a `#` prefix in the output)\n",
"# before handing it to the `jax.jit`/`jax.grad`-wrapped function below\n",
"tree = sk.tree_mask(TiedAutoEncoder(key=jr.PRNGKey(0)))\n",
"\n",
"\n",
"@jax.jit\n",
"@jax.grad\n",
"def loss_func(net, x, y):\n",
"    \"\"\"Mean squared error between the batched output of `net` on `x` and `y`.\n",
"\n",
"    Because of the `jax.grad` decorator, calling `loss_func` returns the\n",
"    gradient of this loss with respect to `net` (the first argument),\n",
"    not the loss value itself.\n",
"    \"\"\"\n",
"    # restore the masked fields before calling the network\n",
"    net = sk.tree_unmask(net)\n",
"    return jnp.mean((jax.vmap(net)(x) - y) ** 2)\n",
"\n",
"\n",
"x = jnp.ones([10, 1])\n",
"y = jnp.ones([10, 1]) * 2.0\n",
"grads: TiedAutoEncoder = loss_func(tree, x, y)\n",
"\n",
"grads"
]
}
],
"metadata": {
Expand Down

0 comments on commit e5cf591

Please sign in to comment.