diff --git a/doc/library/compile/mode.rst b/doc/library/compile/mode.rst index 4a977b7b8c..21c4240f4f 100644 --- a/doc/library/compile/mode.rst +++ b/doc/library/compile/mode.rst @@ -20,6 +20,9 @@ PyTensor defines the following modes by name: - ``'FAST_COMPILE'``: Apply just a few graph rewrites and only use Python implementations. - ``'FAST_RUN'``: Apply all rewrites, and use C implementations where possible. +- ``NUMBA``: Apply all relevant related rewrites and compile the whole graph using Numba. +- ``JAX``: Apply all relevant rewrites and compile the whole graph using JAX. +- ``PYTORCH`` Apply all relevant rewrites and compile the whole graph using PyTorch compile. - ``'DebugMode'``: A mode for debugging. See :ref:`DebugMode ` for details. - ``'NanGuardMode``: :ref:`Nan detector ` - ``'DEBUG_MODE'``: Deprecated. Use the string DebugMode. @@ -28,6 +31,12 @@ The default mode is typically ``FAST_RUN``, but it can be controlled via the configuration variable :attr:`config.mode`, which can be overridden by passing the keyword argument to :func:`pytensor.function`. +For Numba, JAX, and PyTorch, we exclude rewrites that introduce C-only Ops, +as well as BLAS optimizations, as those are done automatically by the respective backends. + +For JAX we also exclude fusion and inplace optimizations, as JAX does not support them +at the user level. They are performed automatically by JAX. + .. TODO:: For a finer level of control over which rewrites are applied, and whether