Skip to content

Commit

Permalink
recipes docs
Browse files Browse the repository at this point in the history
  • Loading branch information
ASEM000 committed Apr 11, 2024
1 parent e5f0066 commit bea703f
Show file tree
Hide file tree
Showing 10 changed files with 2,058 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/API/constructor.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
🏗️ Constructor utils API
🏗️ Constructor API
=============================


Expand Down
1 change: 1 addition & 0 deletions docs/core_guides.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
:caption: Core guides
:maxdepth: 1

notebooks/[guides][core]surgery
notebooks/[guides][core]checkpointing
notebooks/[guides][core]distributed_training
notebooks/[guides][core]mixed_precision
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ Install from github::
core_guides
other_guides
interoperability
recipes

.. currentmodule:: serket

Expand Down
315 changes: 315 additions & 0 deletions docs/notebooks/[guides][core]surgery.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ✂️ Surgery\n",
"\n",
"This notebook provides tree editing (surgery) recipes using `at`. `at` encapsules a pytree and provides a way to edit it in out-of-place manner. Under the hood, `at` uses `jax.tree_util` or `optree` to traverse the tree and apply the provided function to the selected nodes.\n",
"\n",
"```python\n",
"import sepes as sp\n",
"import re\n",
"tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])\n",
"where = re.compile(\"key_.*\") # select all keys starting with \"key_\"\n",
"func = lambda node: sum(map(abs, node)) # sum of absolute values\n",
"sp.at(tree)[where].apply(func)\n",
"# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"!pip install sepes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Out-of-place editing\n",
"\n",
"Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of `at`. The following example demonstrates this behavior:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]\n",
"pytree1 is pytree2 = False\n"
]
}
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = [1, [2, 3], 4]\n",
"pytree2 = sp.at(pytree1)[...].get() # get the whole pytree using ...\n",
"print(f\"{pytree1=}, {pytree2=}\")\n",
"# even though pytree1 and pytree2 are the same, they are not the same object\n",
"# because pytree2 is a copy of pytree1\n",
"print(f\"pytree1 is pytree2 = {pytree1 is pytree2}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Matching keys"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Match all\n",
"\n",
"Use `...` to match all keys. \n",
"\n",
"The following example applies `plus_one` function to all tree nodes. This is equivalent to `tree = tree_map(plus_one, tree)`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2, [3, 4], 5]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = [1, [2, 3], 4]\n",
"plus_one = lambda x: x + 1\n",
"pytree2 = sp.at(pytree1)[...].apply(plus_one)\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Integer indexing\n",
"\n",
"`at` can edit pytrees by integer paths."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[1, [100, 3], 4]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = [1, [2, 3], 4]\n",
"pytree2 = sp.at(pytree1)[1][0].set(100) # equivalent to pytree1[1][0] = 100\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Named path indexing\n",
"`at` can edit pytrees by named paths."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"\n",
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4, \"f\": {\"g\": 7, \"h\": 8}}\n",
"pytree2 = sp.at(pytree1)[\"b\"].set(100) # equivalent to pytree1[\"b\"] = 100\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Regular expressions indexing\n",
"`at` can edit pytrees by regular expressions applied to keys."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"import re\n",
"\n",
"pytree1 = {\n",
" \"key_1\": 1,\n",
" \"key_2\": {\"key_3\": 3, \"key_4\": 4},\n",
" \"key_5\": 5,\n",
" \"key_6\": {\"key_7\": 7, \"key_8\": 8},\n",
"}\n",
"# from 1 - 5, set the value to 100\n",
"pattern = re.compile(r\"key_[1-5]\")\n",
"pytree2 = sp.at(pytree1)[pattern].set(100)\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Mask indexing\n",
"\n",
"`at` can edit pytree entries by a boolean mask, meaning that given a mask of the same structure of the pytree, then nodes marked `True` will be edited, otherwise will not be touched. \n",
"\n",
"\n",
"The following example set all negative tree nodes to zero."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"import jax\n",
"\n",
"pytree1 = {\"a\": -1, \"b\": {\"c\": 2, \"d\": 3}, \"e\": -4}\n",
"# mask defines all desired entries to apply the function\n",
"mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)\n",
"pytree2 = sp.at(pytree1)[mask].set(0)\n",
"pytree2"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Composition\n",
"\n",
"`at` can compose multiple keys, integer paths, named paths, regular expressions, and masks to edit the tree.\n",
"\n",
"The following example demonstrates how to apply a function to all nodes with:\n",
"- Name starting with \"key_\"\n",
"- Positive values"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import sepes as sp\n",
"import re\n",
"import jax\n",
"\n",
"pytree1 = {\"key_1\": 1, \"key_2\": -2, \"value_1\": 1, \"value_2\": 2}\n",
"pattern = re.compile(r\"key_.*\")\n",
"mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)\n",
"pytree2 = sp.at(pytree1)[pattern][mask].set(100)\n",
"pytree2"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "dev-jax",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit bea703f

Please sign in to comment.