Skip to content

Commit 66446d6

Browse files
committed
test documentation docstrings
1 parent c488cd9 commit 66446d6

9 files changed

+360
-266
lines changed

.github/workflows/tests.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,8 @@ jobs:
6464
set -xe
6565
pip install --upgrade pip setuptools wheel
6666
pip install -r docs/requirements.txt
67-
- name: Build documentation
67+
- name: Test examples and docstrings
6868
run: |
6969
set -xe
7070
python -VV
71-
cd docs && make clean && make html
71+
make doctest

docs/Makefile

+8-3
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33

44
# You can set these variables from the command line, and also
55
# from the environment for the first two.
6-
SPHINXOPTS ?=
76
SPHINXBUILD ?= sphinx-build
87
SOURCEDIR = .
98
BUILDDIR = _build
9+
SPHINXOPTS = -d $(BUILDDIR)/doctrees -T
1010

1111
# Put it first so that "make" without argument is like "make help".
1212
help:
1313
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
1414

15-
.PHONY: help Makefile
15+
.PHONY: help Makefile doctest
1616

1717
# Catch-all target: route all unknown targets to Sphinx using the new
1818
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
@@ -25,6 +25,11 @@ clean:
2525
rm -rf _autosummary/
2626

2727
html-noplot:
28-
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
28+
$(SPHINXBUILD) -D plot_gallery=0 -D jupyter_execute_notebooks=off -b html $(SPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
2929
@echo
3030
@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
31+
32+
doctest:
33+
$(SPHINXBUILD) -b doctest $(SPHINXOPTS) . $(BUILDDIR)/doctest
34+
@echo "Testing of doctests in the sources finished, look at the " \
35+
"results in $(BUILDDIR)/doctest/output.txt."

docs/conf.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,14 @@
5050
'sphinx.ext.napoleon', # napoleon on top of autodoc: https://stackoverflow.com/a/66930447 might correct some warnings
5151
'sphinx.ext.autodoc',
5252
'sphinx.ext.autosummary',
53+
'sphinx.ext.doctest',
5354
'sphinx.ext.intersphinx',
5455
'sphinx.ext.mathjax',
5556
'sphinx.ext.viewcode',
5657
'matplotlib.sphinxext.plot_directive',
5758
'sphinx_autodoc_typehints',
5859
'myst_nb',
59-
"sphinx_remove_toctrees",
60+
'sphinx_remove_toctrees',
6061
'sphinx_rtd_theme',
6162
'sphinx_gallery.gen_gallery',
6263
'sphinx_copybutton',
@@ -70,7 +71,12 @@
7071
"backreferences_dir": os.path.join("modules", "generated"),
7172
}
7273

74+
# Specify how to identify the prompt when copying code snippets
75+
copybutton_prompt_text = r">>> |\.\.\. "
76+
copybutton_prompt_is_regexp = True
77+
copybutton_exclude = "style"
7378

79+
trim_doctests_flags = True
7480
source_suffix = ['.rst', '.ipynb', '.md']
7581

7682
autosummary_generate = True

docs/non_smooth.rst

+30-21
Original file line numberDiff line numberDiff line change
@@ -40,17 +40,23 @@ which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot |
4040
corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
4141
We can therefore write::
4242

43-
from jaxopt import ProximalGradient
44-
from jaxopt.prox import prox_lasso
43+
.. doctest::
44+
>>> import jax.numpy as jnp
45+
>>> from jaxopt import ProximalGradient
46+
>>> from jaxopt.prox import prox_lasso
47+
>>> from sklearn import datasets
48+
>>> X, y = datasets.make_regression()
4549

46-
def least_squares(w, data):
47-
X, y = data
48-
residuals = jnp.dot(X, w) - y
49-
return jnp.mean(residuals ** 2)
50+
>>> def least_squares(w, data):
51+
... inputs, targets = data
52+
... residuals = jnp.dot(inputs, w) - targets
53+
... return jnp.mean(residuals ** 2)
54+
55+
>>> l1reg = 1.0
56+
>>> w_init = jnp.zeros(n_features)
57+
>>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
58+
>>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
5059

51-
l1reg = 1.0
52-
pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
53-
pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
5460

5561
Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
5662
``l1reg``, which controls the :math:`L_1` regularization strength. As shown
@@ -65,13 +71,15 @@ Differentiation
6571

6672
In some applications, it is useful to differentiate the solution of the solver
6773
with respect to some hyperparameters. Continuing the previous example, we can
68-
now differentiate the solution w.r.t. ``l1reg``::
74+
now differentiate the solution w.r.t. ``l1reg``:
75+
6976

70-
def solution(l1reg):
71-
pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
72-
return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
77+
.. doctest::
78+
>>> def solution(l1reg):
79+
... pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
80+
... return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
7381

74-
print(jax.jacobian(solution)(l1reg))
82+
>>> print(jax.jacobian(solution)(l1reg))
7583

7684
Under the hood, we use the implicit function theorem if ``implicit_diff=True``
7785
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
@@ -95,15 +103,16 @@ Block coordinate descent
95103
Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
96104
:ref:`composite linear objective functions <composite_linear_functions>`.
97105

98-
Example::
106+
Example:
99107

100-
from jaxopt import objective
101-
from jaxopt import prox
108+
.. doctest::
109+
>>> from jaxopt import objective
110+
>>> from jaxopt import prox
102111

103-
l1reg = 1.0
104-
w_init = jnp.zeros(n_features)
105-
bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
106-
lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
112+
>>> l1reg = 1.0
113+
>>> w_init = jnp.zeros(n_features)
114+
>>> bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
115+
>>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
107116

108117
.. topic:: Examples
109118

0 commit comments

Comments
 (0)