Skip to content

Commit

Permalink
Update common_recipes.ipynb
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Nov 25, 2023
1 parent 88d5bc0 commit 3317ffb
Showing 1 changed file with 25 additions and 37 deletions.
62 changes: 25 additions & 37 deletions docs/notebooks/common_recipes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install git+https://github.com/ASEM000/serket --quiet"
"# !pip install git+https://github.com/ASEM000/serket --quiet"
]
},
{
Expand Down Expand Up @@ -1275,11 +1275,13 @@
"name": "stdout",
"output_type": "stream",
"text": [
"autojit\n",
"(Array([[5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.],\n",
" [5., 5., 5., 5., 5.]], dtype=float32), 'hello')\n",
"autovmap\n",
"(Array([5., 5., 5., 5., 5.], dtype=float32), 'hello')\n"
]
}
Expand All @@ -1291,44 +1293,28 @@
"import jax.numpy as jnp\n",
"\n",
"\n",
"def autojit(func, **jit_kwargs):\n",
" \"\"\"Allow non-jax types to be passed to jitted functions by masking/unmasking\"\"\"\n",
"def automask(jax_transform):\n",
" # takes a jax transformation and returns the same transformation\n",
" # but with the ability to apply it to arbitrary pytrees of non-jax types\n",
"\n",
" @ft.partial(jax.jit, **jit_kwargs)\n",
" def jit_boundary(*args, **kwargs):\n",
" # unmask the inputs before pasing to the actual function\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" # mask the outputs after calling the actual function\n",
" # because all outputs from jitted function should return jax-types\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
" def out_transform(func, **transformation_kwargs):\n",
" @ft.partial(jax_transform, **transformation_kwargs)\n",
" def jax_boundary(*args, **kwargs):\n",
" # unmask the inputs before pasing to the actual function\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" # mask the outputs after calling the actual function\n",
" # because outputs from `jax` transformation should return jax-types\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
"\n",
" @ft.wraps(func)\n",
" def outer_wrapper(*args, **kwargs):\n",
" # mask the inputs before jit boundary\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
" return sk.tree_unmask(jit_boundary(*args, **kwargs))\n",
" @ft.wraps(func)\n",
" def outer_wrapper(*args, **kwargs):\n",
" # mask the inputs before the `jax` boundary\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
" return sk.tree_unmask(jax_boundary(*args, **kwargs))\n",
"\n",
" return outer_wrapper\n",
" return outer_wrapper\n",
"\n",
"\n",
"def autovmap(func, **vmap_kwargs):\n",
" \"\"\"Allow non-jax types to be passed to vmaped functions by masking/unmasking\"\"\"\n",
"\n",
" @ft.partial(jax.vmap, **vmap_kwargs)\n",
" def vmap_boundary(*args, **kwargs):\n",
" # unmask the inputs before pasing to the actual function\n",
" args, kwargs = sk.tree_unmask((args, kwargs))\n",
" # mask the outputs after calling the actual function\n",
" # because all outputs from vmap function should return jax-types\n",
" return sk.tree_mask(func(*args, **kwargs))\n",
"\n",
" @ft.wraps(func)\n",
" def outer_wrapper(*args, **kwargs):\n",
" # mask the inputs before jit boundary\n",
" args, kwargs = sk.tree_mask((args, kwargs))\n",
" return sk.tree_unmask(vmap_boundary(*args, **kwargs))\n",
"\n",
" return outer_wrapper\n",
" return out_transform\n",
"\n",
"\n",
"x, y = jnp.ones([5, 5]), jnp.ones([5, 5])\n",
Expand All @@ -1337,21 +1323,23 @@
"# test masked transformations\n",
"\n",
"\n",
"@autojit\n",
"@automask(jax.jit)\n",
"def func(x: jax.Array, y: jax.Array, name: str):\n",
" # name is not a jax type, with normal jit this will throw an error\n",
" return x @ y, name\n",
"\n",
"\n",
"print(\"autojit\")\n",
"print(func(x, y, \"hello\"))\n",
"\n",
"\n",
"@autovmap\n",
"@automask(jax.vmap)\n",
"def func(x: jax.Array, y: jax.Array, name: str):\n",
" # name is not a jax type, with normal vmap this will throw an error\n",
" return x @ y, name\n",
"\n",
"\n",
"print(\"autovmap\")\n",
"print(func(x, y, \"hello\"))"
]
}
Expand Down

0 comments on commit 3317ffb

Please sign in to comment.