diff --git a/docs/notebooks/summary_tutorial.ipynb b/docs/notebooks/summary_tutorial.ipynb index 41b3064..e321651 100644 --- a/docs/notebooks/summary_tutorial.ipynb +++ b/docs/notebooks/summary_tutorial.ipynb @@ -997,9 +997,9 @@ "id": "jNt9CNJf2HJN" }, "source": [ - "### jax.experimental.host_callback\n", + "### jax external callbacks\n", "\n", - "Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback.\n", + "Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html.\n", "\n", "One can use this to print which is a quick way to get data out of a network." ] @@ -1025,13 +1025,12 @@ } ], "source": [ - "from jax.experimental import host_callback as hcb\n", "\n", "\n", "def loss(parameters):\n", " loss = jnp.mean(parameters**2)\n", " to_look_at = jnp.mean(123.)\n", - " hcb.id_print(to_look_at, name=\"to_look_at\")\n", + " jax.debug.print(\"to_look_at={}\", to_look_at)\n", " return loss\n", "\n", "\n", diff --git a/docs/notebooks/summary_tutorial.md b/docs/notebooks/summary_tutorial.md index 3c4c8a9..abbe421 100644 --- a/docs/notebooks/summary_tutorial.md +++ b/docs/notebooks/summary_tutorial.md @@ -461,9 +461,9 @@ print(to_look_at) +++ {"id": "jNt9CNJf2HJN"} -### jax.experimental.host_callback +### jax external callbacks -Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback. +Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html. One can use this to print which is a quick way to get data out of a network. @@ -474,13 +474,12 @@ colab: id: 1Ih2LxP22MZD outputId: 0dd0b8ec-2c9e-414d-eadf-843122b7b8ab --- -from jax.experimental import host_callback as hcb def loss(parameters): loss = jnp.mean(parameters**2) to_look_at = jnp.mean(123.) - hcb.id_print(to_look_at, name="to_look_at") + jax.debug.print("to_look_at={}", to_look_at) return loss diff --git a/docs/notebooks/summary_tutorial.py b/docs/notebooks/summary_tutorial.py index 9a36604..b4585b7 100644 --- a/docs/notebooks/summary_tutorial.py +++ b/docs/notebooks/summary_tutorial.py @@ -348,20 +348,21 @@ def loss(parameters): print(to_look_at) # + [markdown] id="jNt9CNJf2HJN" -# ### jax.experimental.host_callback +# ### jax external callbacks # -# Jax has some support to send data back from an accelerator back to the host while a ja program is running. This is exposed in jax.experimental.host_callback. +# Jax has some support to send data back from an accelerator back to the host +# while a jax program is running. This is exposed in +# https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html. # # One can use this to print which is a quick way to get data out of a network. # + colab={"base_uri": "https://localhost:8080/"} id="1Ih2LxP22MZD" outputId="0dd0b8ec-2c9e-414d-eadf-843122b7b8ab" -from jax.experimental import host_callback as hcb def loss(parameters): loss = jnp.mean(parameters**2) to_look_at = jnp.mean(123.) - hcb.id_print(to_look_at, name="to_look_at") + jax.debug.print("to_look_at={}", to_look_at) return loss