Skip to content

Commit

Permalink
[14] masked transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Nov 25, 2023
1 parent e5cf591 commit 88d5bc0
Showing 1 changed file with 114 additions and 14 deletions.
128 changes: 114 additions & 14 deletions docs/notebooks/common_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -150,7 +150,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -206,7 +206,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -261,7 +261,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -287,7 +287,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -331,7 +331,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -380,7 +380,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -450,7 +450,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -586,7 +586,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
Expand All @@ -612,7 +612,7 @@
" [1.]], dtype=float32)"
]
},
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -669,7 +669,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -758,7 +758,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand All @@ -782,7 +782,7 @@
" )]"
]
},
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1254,6 +1254,106 @@
"\n",
"grads"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## [14] Masked transformation\n",
"\n",
"As an alternative to using `sk.tree_unmask` on pytrees before calling the function -as seen throughout training examples and recipes- , another approach is to wrap a certain transformation - not pytrees - (e.g. `jit`) to be make the masking/unmasking automatic; however this apporach will incur more overhead than applying `sk.tree_unmask` before the function call.\n",
"\n",
"The following example demonstrate how to wrap `jit`, and `vmap`."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(Array([[5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.]], dtype=float32), 'hello')\n",
"(Array([5., 5., 5., 5., 5.], dtype=float32), 'hello')\n"
]
}
],
"source": [
"import serket as sk\n",
"import functools as ft\n",
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"\n",
"def autojit(func, **jit_kwargs):\n",
" \"\"\"Allow non-jax types to be passed to jitted functions by masking/unmasking\"\"\"\n",
"\n",
" @ft.partial(jax.jit, **jit_kwargs)\n",
" def jit_boundary(*args, **kwargs):\n",
" # unmask the inputs before pasing to the actual function\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" # mask the outputs after calling the actual function\n",
" # because all outputs from jitted function should return jax-types\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
"\n",
" @ft.wraps(func)\n",
" def outer_wrapper(*args, **kwargs):\n",
" # mask the inputs before jit boundary\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
" return sk.tree_unmask(jit_boundary(*args, **kwargs))\n",
"\n",
" return outer_wrapper\n",
"\n",
"\n",
"def autovmap(func, **vmap_kwargs):\n",
" \"\"\"Allow non-jax types to be passed to vmaped functions by masking/unmasking\"\"\"\n",
"\n",
" @ft.partial(jax.vmap, **vmap_kwargs)\n",
" def vmap_boundary(*args, **kwargs):\n",
" # unmask the inputs before pasing to the actual function\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" # mask the outputs after calling the actual function\n",
" # because all outputs from vmap function should return jax-types\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
"\n",
" @ft.wraps(func)\n",
" def outer_wrapper(*args, **kwargs):\n",
" # mask the inputs before jit boundary\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
" return sk.tree_unmask(vmap_boundary(*args, **kwargs))\n",
"\n",
" return outer_wrapper\n",
"\n",
"\n",
"x, y = jnp.ones([5, 5]), jnp.ones([5, 5])\n",
"\n",
"\n",
"# test masked transformations\n",
"\n",
"\n",
"@autojit\n",
"def func(x: jax.Array, y: jax.Array, name: str):\n",
" # name is not a jax type, with normal jit this will throw an error\n",
" return x @ y, name\n",
"\n",
"\n",
"print(func(x, y, \"hello\"))\n",
"\n",
"\n",
"@autovmap\n",
"def func(x: jax.Array, y: jax.Array, name: str):\n",
" # name is not a jax type, with normal vmap this will throw an error\n",
" return x @ y, name\n",
"\n",
"\n",
"print(func(x, y, \"hello\"))"
]
}
],
"metadata": {
Expand Down

0 comments on commit 88d5bc0

Please sign in to comment.