@@ -40,17 +40,23 @@ which corresponds to the choice :math:`g(w, \text{l1reg}) = \text{l1reg} \cdot |
corresponding ``prox`` operator is :func:`prox_lasso <jaxopt.prox.prox_lasso>`.
We can therefore write::

-   from jaxopt import ProximalGradient
-   from jaxopt.prox import prox_lasso
+ .. doctest::
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> from jaxopt import ProximalGradient
+    >>> from jaxopt.prox import prox_lasso
+    >>> from sklearn import datasets
+    >>> X, y = datasets.make_regression()

-   def least_squares(w, data):
-     X, y = data
-     residuals = jnp.dot(X, w) - y
-     return jnp.mean(residuals ** 2)
+    >>> def least_squares(w, data):
+    ...   inputs, targets = data
+    ...   residuals = jnp.dot(inputs, w) - targets
+    ...   return jnp.mean(residuals ** 2)
+
+    >>> l1reg = 1.0
+    >>> n_features = X.shape[1]
+    >>> w_init = jnp.zeros(n_features)
+    >>> pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
+    >>> pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

-   l1reg = 1.0
-   pg = ProximalGradient(fun=least_squares, prox=prox_lasso)
-   pg_sol = pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

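For intuition, the prox operator passed to the solver can also be applied
directly to an array: :func:`prox_lasso <jaxopt.prox.prox_lasso>` performs soft
thresholding, shrinking every entry toward zero. A minimal sketch, assuming the
imports above and that the regularization strength is accepted as the second
positional argument:

.. doctest::
   >>> v = jnp.array([3.0, -0.5, 1.5])
   >>> shrunk = prox_lasso(v, 1.0)  # each entry should shrink toward zero by up to 1.0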
Note that :func:`prox_lasso <jaxopt.prox.prox_lasso>` has a hyperparameter
``l1reg``, which controls the :math:`L_1` regularization strength. As shown
@@ -65,13 +71,15 @@ Differentiation

In some applications, it is useful to differentiate the solution of the solver
with respect to some hyperparameters. Continuing the previous example, we can
- now differentiate the solution w.r.t. ``l1reg``::
+ now differentiate the solution w.r.t. ``l1reg``:
+

-   def solution(l1reg):
-     pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
-     return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+ .. doctest::
+    >>> def solution(l1reg):
+    ...   pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=True)
+    ...   return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

-   print(jax.jacobian(solution)(l1reg))
+    >>> print(jax.jacobian(solution)(l1reg))

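Because ``solution`` is an ordinary JAX-transformable function, the same
mechanism composes with scalar objectives. For example, one could take the
gradient of a downstream loss of the solution with respect to ``l1reg``; a
sketch reusing the names defined above (in practice the loss would be
evaluated on held-out data):

.. doctest::
   >>> def downstream_loss(l1reg):
   ...   w = solution(l1reg)
   ...   return jnp.mean((jnp.dot(X, w) - y) ** 2)

   >>> grad_wrt_l1reg = jax.grad(downstream_loss)(l1reg)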
Under the hood, we use the implicit function theorem if ``implicit_diff=True``
and autodiff of unrolled iterations if ``implicit_diff=False``. See the
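For comparison, the unrolled mode mentioned above differs only in the flag
passed to the solver; assuming the solver converges, its Jacobian should
roughly match the implicit-differentiation one, at the cost of differentiating
through every iteration. A sketch reusing the definitions above:

.. doctest::
   >>> def solution_unrolled(l1reg):
   ...   pg = ProximalGradient(fun=least_squares, prox=prox_lasso, implicit_diff=False)
   ...   return pg.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

   >>> jac_unrolled = jax.jacobian(solution_unrolled)(l1reg)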
@@ -95,15 +103,16 @@ Block coordinate descent
Contrary to other solvers, :class:`jaxopt.BlockCoordinateDescent` only works with
:ref:`composite linear objective functions <composite_linear_functions>`.

- Example::
+ Example:

-   from jaxopt import objective
-   from jaxopt import prox
+ .. doctest::
+    >>> from jaxopt import BlockCoordinateDescent
+    >>> from jaxopt import objective
+    >>> from jaxopt import prox

-   l1reg = 1.0
-   w_init = jnp.zeros(n_features)
-   bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
-   lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params
+    >>> l1reg = 1.0
+    >>> w_init = jnp.zeros(n_features)
+    >>> bcd = BlockCoordinateDescent(fun=objective.least_squares, block_prox=prox.prox_lasso)
+    >>> lasso_sol = bcd.run(w_init, hyperparams_prox=l1reg, data=(X, y)).params

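As a quick sanity check of the fitted parameters, one might inspect the
training error and the sparsity that the :math:`L_1` penalty is expected to
induce; a sketch reusing ``X``, ``y`` and ``lasso_sol`` from above (exact
values depend on the randomly generated regression problem):

.. doctest::
   >>> train_mse = jnp.mean((jnp.dot(X, lasso_sol) - y) ** 2)
   >>> num_zero_coefs = jnp.sum(lasso_sol == 0)  # coefficients the penalty may drive exactly to zero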
.. topic:: Examples
