Skip to content

Commit 5a3ca42

Browse files
committed
Added Pseudo-inverse preconditioner for EqQP.
This allows to precompute a preconditioner, and share it across multiple outer loops, where the inner loop is solving an Equality Constrained QP. This should provide speedups when the parameters of the inner loop QP don't change too much. TODO: modify the implicit diff decorator so that the jvp also uses the preconditioner.
1 parent ced83e9 commit 5a3ca42

6 files changed

+226
-12
lines changed

jaxopt/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from jaxopt._src.polyak_sgd import PolyakSGD
3636
from jaxopt._src.projected_gradient import ProjectedGradient
3737
from jaxopt._src.proximal_gradient import ProximalGradient
38+
from jaxopt._src.eq_qp_preconditioned import PseudoInversePreconditionedEqQP
3839
from jaxopt._src.quadratic_prog import QuadraticProgramming
3940
from jaxopt._src.scipy_wrappers import ScipyBoundedLeastSquares
4041
from jaxopt._src.scipy_wrappers import ScipyBoundedMinimize

jaxopt/_src/eq_qp.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class EqualityConstrainedQP(base.Solver):
103103
implicit_diff_solve: Optional[Callable] = None
104104
jit: bool = True
105105

106-
def _refined_solve(self, matvec, b, init, maxiter, tol):
106+
def _refined_solve(self, matvec, b, init, maxiter, tol, **kwargs):
107107
# Instead of solving S x = b
108108
# We solve \bar{S} x = b
109109
#
@@ -152,13 +152,14 @@ def matvec_regularized_qp(_, x):
152152
maxiter=self.refine_maxiter,
153153
tol=tol,
154154
)
155-
return solver.run(init_params=init, A=None, b=b)[0]
155+
return solver.run(init_params=init, A=None, b=b, **kwargs)[0]
156156

157157
def run(
158158
self,
159159
init_params: Optional[base.KKTSolution] = None,
160160
params_obj: Optional[Any] = None,
161161
params_eq: Optional[Any] = None,
162+
**kwargs,
162163
) -> base.OptStep:
163164
"""Solves 0.5 * x^T Q x + c^T x subject to Ax = b.
164165
@@ -168,6 +169,7 @@ def run(
168169
init_params: ignored.
169170
params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided.
170171
params_eq: (A, b) or (params_A, b) if matvec_A is provided.
172+
**kwargs: Keyword args provided to the solver.
171173
Returns:
172174
(params, state), where params = (primal_var, dual_var_eq, None)
173175
"""
@@ -200,10 +202,11 @@ def matvec(u):
200202
init=init_params,
201203
tol=self.tol,
202204
maxiter=self.maxiter,
205+
**kwargs,
203206
)
204207
else:
205208
primal, dual_eq = self._refined_solve(
206-
matvec, target, init_params, tol=self.tol, maxiter=self.maxiter
209+
matvec, target, init_params, tol=self.tol, maxiter=self.maxiter, **kwargs
207210
)
208211

209212
return base.OptStep(params=base.KKTSolution(primal, dual_eq, None), state=None)

jaxopt/_src/eq_qp_preconditioned.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Preconditioned solvers for equality constrained quadratic programming."""
16+
17+
from typing import Optional, Any
18+
from dataclasses import dataclass
19+
import jax.numpy as jnp
20+
import jaxopt
21+
from jaxopt._src import base
22+
from jaxopt._src import linear_operator
23+
24+
25+
@dataclass
26+
class PseudoInversePreconditionedEqQP(base.Solver):
27+
qp_solver: jaxopt.EqualityConstrainedQP
28+
29+
def init_params(self, params_obj, params_eq):
30+
"""Computes the matvec associated to the pseudo inverse of the KKT matrix."""
31+
Q, p = params_obj
32+
A, b = params_eq
33+
del p, b
34+
35+
kkt_mat = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]])
36+
kkt_mat_pinv = jnp.linalg.pinv(kkt_mat)
37+
38+
d = Q.shape[0]
39+
40+
pinv_blocks = (
41+
(kkt_mat_pinv[:d, :d], kkt_mat_pinv[:d, d:]),
42+
(kkt_mat_pinv[d:, :d], kkt_mat_pinv[d:, d:]),
43+
)
44+
return linear_operator.BlockLinearOperator(pinv_blocks)
45+
46+
def run(
47+
self,
48+
init_params: Optional[base.KKTSolution] = None,
49+
params_obj: Optional[Any] = None,
50+
params_eq: Optional[Any] = None,
51+
params_precond=None,
52+
**kwargs
53+
):
54+
# TODO(gnegiar): the M parameter should be passed to both
55+
# the QP solve and the implicit_diff_solve
56+
return self.qp_solver.run(
57+
init_params, params_obj, params_eq, M=params_precond, **kwargs
58+
)

jaxopt/_src/linear_operator.py

+76-6
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414
"""Interface for linear operators."""
1515

1616
import functools
17+
from dataclasses import dataclass
18+
from typing import Tuple
19+
1720
import jax
1821
import jax.numpy as jnp
19-
import numpy as onp
2022

21-
from jaxopt.tree_util import tree_map, tree_sum, tree_mul
23+
from jaxopt.tree_util import tree_map
2224

2325

2426
class DenseLinearOperator:
25-
2627
def __init__(self, pytree):
2728
self.pytree = pytree
2829

@@ -33,7 +34,7 @@ def matvec(self, x):
3334
return tree_map(jnp.dot, self.pytree, x)
3435

3536
def rmatvec(self, _, y):
36-
return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y)
37+
return tree_map(lambda w, yi: jnp.dot(w.T, yi), self.pytree, y)
3738

3839
def matvec_and_rmatvec(self, x, y):
3940
return self.matvec(x), self.rmatvec(x, y)
@@ -52,11 +53,11 @@ def col_norm(w):
5253
if not squared:
5354
col_norms = jnp.sqrt(col_norms)
5455
return col_norms
56+
5557
return tree_map(col_norm, self.pytree)
5658

5759

5860
class FunctionalLinearOperator:
59-
6061
def __init__(self, fun, params):
6162
self.fun = functools.partial(fun, params)
6263

@@ -71,7 +72,7 @@ def rmatvec(self, x, y):
7172

7273
def matvec_and_rmatvec(self, x, y):
7374
matvec_x, vjp = jax.vjp(self.matvec, x)
74-
rmatvec_y, = vjp(y)
75+
(rmatvec_y,) = vjp(y)
7576
return matvec_x, rmatvec_y
7677

7778
def normal_matvec(self, x):
@@ -85,3 +86,72 @@ def _make_linear_operator(matvec):
8586
return DenseLinearOperator
8687
else:
8788
return functools.partial(FunctionalLinearOperator, matvec)
89+
90+
91+
def block_row_matvec(block, x):
92+
"""Performs a matvec for a row of block matrices.
93+
94+
The following matvec is performed:
95+
[U1, ..., UN] * [x1, ..., xN]
96+
where U1, ..., UN are matrices and x1, ..., xN are vectors
97+
of compatible shapes.
98+
"""
99+
if len(block) != len(x):
100+
raise ValueError(
101+
"We need as many blocks in the matrix as in the vector."
102+
)
103+
return sum(jax.tree_util.tree_map(jnp.dot, block, x))
104+
105+
106+
# TODO(gnegiar): Extend to arbitrary block shapes.
107+
@jax.tree_util.register_pytree_node_class
108+
@dataclass
109+
class BlockLinearOperator:
110+
"""Represents a linear operator defined by blocks over a block pytree.
111+
112+
Attributes:
113+
blocks: a 2x2 block matrix of the form
114+
[[A, B]
115+
[C, D]]
116+
"""
117+
118+
blocks: Tuple[Tuple[jnp.array]]
119+
120+
def __call__(self, x):
121+
return self.matvec(x)
122+
123+
def matvec(self, x):
124+
"""Performs the block matvec with u defined by blocks.
125+
126+
The matvec is of form:
127+
[u1, u2]
128+
[[A, B] *
129+
[C, D]]
130+
131+
"""
132+
return jax.tree_util.tree_map(
133+
lambda row_of_blocks: block_row_matvec(row_of_blocks, x),
134+
self.blocks,
135+
is_leaf=lambda x: x is self.blocks[0] or x is self.blocks[1],
136+
)
137+
138+
def rmatvec(self, x, y):
139+
return self.matvec_and_rmatvec(x, y)[1]
140+
141+
def matvec_and_rmatvec(self, x, y):
142+
matvec_x, vjp = jax.vjp(self.matvec, x)
143+
(rmatvec_y,) = vjp(y)
144+
return matvec_x, rmatvec_y
145+
146+
def normal_matvec(self, x):
147+
"""Computes A^T A x from matvec(x) = A x."""
148+
matvec_x, vjp = jax.vjp(self.matvec, x)
149+
return vjp(matvec_x)[0]
150+
151+
def tree_flatten(self):
152+
return self.blocks, None
153+
154+
@classmethod
155+
def tree_unflatten(cls, aux_data, children):
156+
del aux_data
157+
return cls(children)

tests/eq_qp_preconditioned_test.py

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import jax
16+
from jax import test_util as jtu
17+
import jax.numpy as jnp
18+
19+
from jaxopt import PseudoInversePreconditionedEqQP
20+
from jaxopt import EqualityConstrainedQP
21+
import numpy as onp
22+
23+
24+
class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase):
25+
def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b):
26+
def fun(Q, c, A, b):
27+
Q = 0.5 * (Q + Q.T)
28+
29+
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
30+
# reduce the primal variables to a scalar value for test purpose.
31+
return jnp.sum(solver.run(**hyperparams).params[0])
32+
33+
# Derivative w.r.t. A.
34+
rng = onp.random.RandomState(0)
35+
V = rng.rand(*A.shape)
36+
V /= onp.sqrt(onp.sum(V ** 2))
37+
eps = 1e-4
38+
deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b))
39+
deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps)
40+
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
41+
42+
# Derivative w.r.t. b.
43+
v = rng.rand(*b.shape)
44+
v /= onp.sqrt(onp.sum(v ** 2))
45+
eps = 1e-4
46+
deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b))
47+
deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps)
48+
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
49+
50+
# Derivative w.r.t. Q
51+
W = rng.rand(*Q.shape)
52+
W /= onp.sqrt(onp.sum(W ** 2))
53+
eps = 1e-4
54+
deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b))
55+
deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps)
56+
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
57+
58+
# Derivative w.r.t. c
59+
w = rng.rand(*c.shape)
60+
w /= onp.sqrt(onp.sum(w ** 2))
61+
eps = 1e-4
62+
deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b))
63+
deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps)
64+
self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
65+
66+
def test_pseudoinverse_preconditioner(self):
67+
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
68+
c = jnp.array([1.0, 1.0])
69+
A = jnp.array([[1.0, 1.0]])
70+
b = jnp.array([1.0])
71+
qp = EqualityConstrainedQP(tol=1e-7)
72+
preconditioned_qp = PseudoInversePreconditionedEqQP(qp)
73+
params_obj = (Q, c)
74+
params_eq = (A, b)
75+
params_precond = preconditioned_qp.init_params(params_obj, params_eq)
76+
hyperparams = dict(
77+
params_obj=params_obj,
78+
params_eq=params_eq,
79+
)
80+
sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params
81+
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
82+
self._check_derivative_Q_c_A_b(preconditioned_qp, Q, c, A, b)

tests/eq_qp_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727

2828
class EqualityConstrainedQPTest(jtu.JaxTestCase):
29-
def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b):
29+
def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b):
3030
def fun(Q, c, A, b):
3131
Q = 0.5 * (Q + Q.T)
3232

@@ -77,7 +77,7 @@ def test_qp_eq_only(self):
7777
hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
7878
sol = qp.run(**hyperparams).params
7979
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
80-
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
80+
self._check_derivative_Q_c_A_b(qp, Q, c, A, b)
8181

8282
def test_qp_eq_with_init(self):
8383
Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
@@ -89,7 +89,7 @@ def test_qp_eq_with_init(self):
8989
init_params = KKTSolution(jnp.array([1.0, 1.0]), jnp.array([1.0]))
9090
sol = qp.run(init_params, **hyperparams).params
9191
self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
92-
self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
92+
self._check_derivative_Q_c_A_b(qp, Q, c, A, b)
9393

9494
def test_projection_hyperplane(self):
9595
x = jnp.array([1.0, 2.0])

0 commit comments

Comments
 (0)