@@ -40,17 +40,25 @@ which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot ||w||_1`
corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
We can therefore write::

-  from jaxopt import ProximalGradient
-  from jaxopt.prox import prox_lasso
+.. doctest::
+  >>> import jax
+  >>> import jax.numpy as jnp
+  >>> from jaxopt import ProximalGradient
+  >>> from jaxopt.prox import prox_lasso
+  >>> from sklearn import datasets
+  >>> X, y = datasets.make_regression()
+  >>> n_features = X.shape[1]

-  def least_squares(w, data):
-    X, y = data
-    residuals = jnp.dot(X, w) - y
-    return jnp.mean(residuals ** 2)
+  >>> def least_squares(w, data):
+  ...   inputs, targets = data
+  ...   residuals = jnp.dot(inputs, w) - targets
+  ...   return jnp.mean(residuals ** 2)
+
+  >>> l1reg = 1.0
+  >>> w_init = jnp.zeros(n_features)
+  >>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
+  >>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

-  l1reg = 1.0
-  pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
-  pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
``l1reg``, which controls the :math:`L_1` regularization strength. As shown
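
For intuition, :func:`prox_lasso <jaxopt.prox.prox_lasso>` is elementwise
soft-thresholding; a minimal sketch of its effect, assuming the imports from
the doctest above::

  # Soft-thresholding with threshold l1reg=1.0: entries with magnitude <= 1
  # are zeroed, the remaining entries shrink toward zero by 1.
  print(prox_lasso(jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0]), 1.0))
  # expected: [-1.  0.  0.  0.  1.]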
@@ -65,13 +73,15 @@ Differentiation

In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
-now differentiate the solution w.r.t. ``l1reg``::
+now differentiate the solution w.r.t. ``l1reg``:
+

-  def solution(l1reg):
-    pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
-    return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+.. doctest::
+  >>> def solution(l1reg):
+  ...   pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
+  ...   return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

-  print(jax.jacobian(solution)(l1reg))
+  >>> print(jax.jacobian(solution)(l1reg))

Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
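
As a rough cross-check, the same Jacobian can be obtained by unrolling; a
minimal sketch, assuming the objects from the doctests above are in scope::

  # Same derivative via autodiff of unrolled iterations instead of implicit
  # differentiation; the two should agree up to solver tolerance.
  def solution_unrolled(l1reg):
    pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=False)
    return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

  print(jax.jacobian(solution_unrolled)(l1reg))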
@@ -95,15 +105,17 @@ Block coordinate descent
Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
:ref:`composite linear objective functions <composite_linear_functions>`.
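
Here, composite linear means that the smooth term depends on ``w`` only through
the linear predictions ``X @ w``; written by hand, an objective of that form
looks like the following sketch (not the library's own definition)::

  # Composite linear structure: a smooth function applied to X @ w.
  def composite_least_squares(w, data):
    X, y = data
    preds = jnp.dot(X, w)              # linear map of the parameters
    return jnp.mean((preds - y) ** 2)  # smooth function of the predictions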

-Example::
+Example:

-  from jaxopt import objective
-  from jaxopt import prox
+.. doctest::
+  >>> from jaxopt import BlockCoordinateDescent
+  >>> from jaxopt import objective
+  >>> from jaxopt import prox

-  l1reg = 1.0
-  w_init = jnp.zeros(n_features)
-  bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
-  lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+  >>> l1reg = 1.0
+  >>> w_init = jnp.zeros(n_features)
+  >>> bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
+  >>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
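
As with the proximal gradient example, the resulting Lasso solution is
typically sparse; a quick check, assuming the doctest above has run::

  # Count how many coefficients the L1 penalty drove exactly to zero.
  num_nonzero = int(jnp.sum(lasso_sol != 0))
  print(num_nonzero, "nonzero coefficients out of", lasso_sol.shape[0])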

.. topic:: Examples