diff --git a/jaxopt/__init__.py b/jaxopt/__init__.py
index 96aef470..fc7f3063 100644
--- a/jaxopt/__init__.py
+++ b/jaxopt/__init__.py
@@ -35,6 +35,7 @@
 from jaxopt._src.polyak_sgd import PolyakSGD
 from jaxopt._src.projected_gradient import ProjectedGradient
 from jaxopt._src.proximal_gradient import ProximalGradient
+from jaxopt._src.eq_qp_preconditioned import PseudoInversePreconditionedEqQP
 from jaxopt._src.quadratic_prog import QuadraticProgramming
 from jaxopt._src.scipy_wrappers import ScipyBoundedLeastSquares
 from jaxopt._src.scipy_wrappers import ScipyBoundedMinimize
diff --git a/jaxopt/_src/eq_qp.py b/jaxopt/_src/eq_qp.py
index 65fc3936..fcb8a934 100644
--- a/jaxopt/_src/eq_qp.py
+++ b/jaxopt/_src/eq_qp.py
@@ -57,7 +57,8 @@ def eq_fun(primal_var, params_eq):
 
   # It is required to post_process the output of `idf.make_kkt_optimality_fun`
   # to make the signatures of optimality_fun() and run() agree.
-  def optimality_fun(params, params_obj, params_eq):
+  # The M argument is needed for using preconditioners.
+  def optimality_fun(params, params_obj, params_eq, M=None):
     return optimality_fun_with_ineq(params, params_obj, params_eq, None)
 
   return optimality_fun
@@ -103,7 +104,7 @@ class EqualityConstrainedQP(base.Solver):
   implicit_diff_solve: Optional[Callable] = None
   jit: bool = True
 
-  def _refined_solve(self, matvec, b, init, maxiter, tol):
+  def _refined_solve(self, matvec, b, init, maxiter, tol, **kwargs):
     # Instead of solving S x = b
     # We solve     \bar{S} x = b
     #
@@ -152,13 +153,14 @@ def matvec_regularized_qp(_, x):
       maxiter=self.refine_maxiter,
       tol=tol,
     )
-    return solver.run(init_params=init, A=None, b=b)[0]
+    return solver.run(init_params=init, A=None, b=b, **kwargs)[0]
 
   def run(
     self,
     init_params: Optional[base.KKTSolution] = None,
     params_obj: Optional[Any] = None,
     params_eq: Optional[Any] = None,
+    **kwargs,
   ) -> base.OptStep:
     """Solves 0.5 * x^T Q x + c^T x subject to Ax = b.
 
@@ -168,6 +170,7 @@ def run(
       init_params: ignored.
       params_obj: (Q, c) or (params_Q, c) if matvec_Q is provided.
       params_eq: (A, b) or (params_A, b) if matvec_A is provided.
+      **kwargs: Keyword args provided to the solver.
     Returns:
       (params, state),  where params = (primal_var, dual_var_eq, None)
     """
@@ -200,10 +203,11 @@ def matvec(u):
         init=init_params,
         tol=self.tol,
         maxiter=self.maxiter,
+        **kwargs,
       )
     else:
       primal, dual_eq = self._refined_solve(
-        matvec, target, init_params, tol=self.tol, maxiter=self.maxiter
+        matvec, target, init_params, tol=self.tol, maxiter=self.maxiter, **kwargs
       )
 
     return base.OptStep(params=base.KKTSolution(primal, dual_eq, None), state=None)
diff --git a/jaxopt/_src/eq_qp_preconditioned.py b/jaxopt/_src/eq_qp_preconditioned.py
new file mode 100644
index 00000000..3490f6f6
--- /dev/null
+++ b/jaxopt/_src/eq_qp_preconditioned.py
@@ -0,0 +1,58 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Preconditioned solvers for equality constrained quadratic programming."""
+
+from typing import Optional, Any
+from dataclasses import dataclass
+import jax.numpy as jnp
+import jaxopt
+from jaxopt._src import base
+from jaxopt._src import linear_operator
+
+
+@dataclass
+class PseudoInversePreconditionedEqQP(base.Solver):
+  qp_solver: jaxopt.EqualityConstrainedQP
+
+  def init_params(self, params_obj, params_eq):
+    """Computes the matvec associated to the pseudo inverse of the KKT matrix."""
+    Q, p = params_obj
+    A, b = params_eq
+    del p, b
+
+    kkt_mat = jnp.block([[Q, A.T], [A, jnp.zeros((A.shape[0], A.shape[0]))]])
+    kkt_mat_pinv = jnp.linalg.pinv(kkt_mat)
+
+    d = Q.shape[0]
+
+    pinv_blocks = (
+      (kkt_mat_pinv[:d, :d], kkt_mat_pinv[:d, d:]),
+      (kkt_mat_pinv[d:, :d], kkt_mat_pinv[d:, d:]),
+    )
+    return linear_operator.BlockLinearOperator(pinv_blocks)
+
+  def run(
+    self,
+    init_params: Optional[base.KKTSolution] = None,
+    params_obj: Optional[Any] = None,
+    params_eq: Optional[Any] = None,
+    params_precond=None,
+    **kwargs
+  ):
+    # TODO(gnegiar): the M parameter should be passed to both
+    # the QP solve and the implicit_diff_solve
+    return self.qp_solver.run(
+      init_params, params_obj, params_eq, M=params_precond, **kwargs
+    )
diff --git a/jaxopt/_src/linear_operator.py b/jaxopt/_src/linear_operator.py
index 34e4941e..0bd0ce2f 100644
--- a/jaxopt/_src/linear_operator.py
+++ b/jaxopt/_src/linear_operator.py
@@ -14,15 +14,16 @@
 """Interface for linear operators."""
 
 import functools
+from dataclasses import dataclass
+from typing import Tuple
+
 import jax
 import jax.numpy as jnp
-import numpy as onp
 
-from jaxopt.tree_util import tree_map, tree_sum, tree_mul
+from jaxopt.tree_util import tree_map
 
 
 class DenseLinearOperator:
-
   def __init__(self, pytree):
     self.pytree = pytree
 
@@ -33,7 +34,7 @@ def matvec(self, x):
     return tree_map(jnp.dot, self.pytree, x)
 
   def rmatvec(self, _, y):
-    return tree_map(lambda w,yi: jnp.dot(w.T, yi), self.pytree, y)
+    return tree_map(lambda w, yi: jnp.dot(w.T, yi), self.pytree, y)
 
   def matvec_and_rmatvec(self, x, y):
     return self.matvec(x), self.rmatvec(x, y)
@@ -52,11 +53,11 @@ def col_norm(w):
       if not squared:
         col_norms = jnp.sqrt(col_norms)
       return col_norms
+
     return tree_map(col_norm, self.pytree)
 
 
 class FunctionalLinearOperator:
-
   def __init__(self, fun, params):
     self.fun = functools.partial(fun, params)
 
@@ -71,7 +72,7 @@ def rmatvec(self, x, y):
 
   def matvec_and_rmatvec(self, x, y):
     matvec_x, vjp = jax.vjp(self.matvec, x)
-    rmatvec_y, = vjp(y)
+    (rmatvec_y,) = vjp(y)
     return matvec_x, rmatvec_y
 
   def normal_matvec(self, x):
@@ -85,3 +86,72 @@ def _make_linear_operator(matvec):
     return DenseLinearOperator
   else:
     return functools.partial(FunctionalLinearOperator, matvec)
+
+
+def block_row_matvec(block, x):
+  """Performs a matvec for a row of block matrices.
+  
+  The following matvec is performed: 
+  [U1, ..., UN] * [x1, ..., xN]
+  where U1, ..., UN are matrices and x1, ..., xN are vectors
+  of compatible shapes.
+  """
+  if len(block) != len(x):
+    raise ValueError(
+      "We need as many blocks in the matrix as in the vector."
+      )
+  return sum(jax.tree_util.tree_map(jnp.dot, block, x))
+
+
+# TODO(gnegiar): Extend to arbitrary block shapes.
+@jax.tree_util.register_pytree_node_class
+@dataclass
+class BlockLinearOperator:
+  """Represents a linear operator defined by blocks over a block pytree.
+
+  Attributes:
+    blocks: a 2x2 block matrix of the form
+      [[A, B]
+       [C, D]]
+  """
+
+  blocks: Tuple[Tuple[jnp.array]]
+
+  def __call__(self, x):
+    return self.matvec(x)
+
+  def matvec(self, x):
+    """Performs the block matvec with u defined by blocks.
+
+    The matvec is of form:
+               [u1, u2]
+    [[A, B]  *
+     [C, D]]
+
+    """
+    return jax.tree_util.tree_map(
+      lambda row_of_blocks: block_row_matvec(row_of_blocks, x),
+      self.blocks,
+      is_leaf=lambda x: x is self.blocks[0] or x is self.blocks[1],
+    )
+
+  def rmatvec(self, x, y):
+    return self.matvec_and_rmatvec(x, y)[1]
+
+  def matvec_and_rmatvec(self, x, y):
+    matvec_x, vjp = jax.vjp(self.matvec, x)
+    (rmatvec_y,) = vjp(y)
+    return matvec_x, rmatvec_y
+
+  def normal_matvec(self, x):
+    """Computes A^T A x from matvec(x) = A x."""
+    matvec_x, vjp = jax.vjp(self.matvec, x)
+    return vjp(matvec_x)[0]
+
+  def tree_flatten(self):
+    return self.blocks, None
+
+  @classmethod
+  def tree_unflatten(cls, aux_data, children):
+    del aux_data
+    return cls(children)
diff --git a/tests/eq_qp_preconditioned_test.py b/tests/eq_qp_preconditioned_test.py
new file mode 100644
index 00000000..f05f7750
--- /dev/null
+++ b/tests/eq_qp_preconditioned_test.py
@@ -0,0 +1,82 @@
+# Copyright 2021 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import jax
+from jax import test_util as jtu
+import jax.numpy as jnp
+
+from jaxopt import PseudoInversePreconditionedEqQP
+from jaxopt import EqualityConstrainedQP
+import numpy as onp
+
+
+class PreconditionedEqualityConstrainedQPTest(jtu.JaxTestCase):
+  def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b):
+    def fun(Q, c, A, b):
+      Q = 0.5 * (Q + Q.T)
+
+      hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
+      # reduce the primal variables to a scalar value for test purpose.
+      return jnp.sum(solver.run(**hyperparams).params[0])
+
+    # Derivative w.r.t. A.
+    rng = onp.random.RandomState(0)
+    V = rng.rand(*A.shape)
+    V /= onp.sqrt(onp.sum(V ** 2))
+    eps = 1e-4
+    deriv_jax = jnp.vdot(V, jax.grad(fun, argnums=2)(Q, c, A, b))
+    deriv_num = (fun(Q, c, A + eps * V, b) - fun(Q, c, A - eps * V, b)) / (2 * eps)
+    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
+
+    # Derivative w.r.t. b.
+    v = rng.rand(*b.shape)
+    v /= onp.sqrt(onp.sum(v ** 2))
+    eps = 1e-4
+    deriv_jax = jnp.vdot(v, jax.grad(fun, argnums=3)(Q, c, A, b))
+    deriv_num = (fun(Q, c, A, b + eps * v) - fun(Q, c, A, b - eps * v)) / (2 * eps)
+    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
+
+    # Derivative w.r.t. Q
+    W = rng.rand(*Q.shape)
+    W /= onp.sqrt(onp.sum(W ** 2))
+    eps = 1e-4
+    deriv_jax = jnp.vdot(W, jax.grad(fun, argnums=0)(Q, c, A, b))
+    deriv_num = (fun(Q + eps * W, c, A, b) - fun(Q - eps * W, c, A, b)) / (2 * eps)
+    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
+
+    # Derivative w.r.t. c
+    w = rng.rand(*c.shape)
+    w /= onp.sqrt(onp.sum(w ** 2))
+    eps = 1e-4
+    deriv_jax = jnp.vdot(w, jax.grad(fun, argnums=1)(Q, c, A, b))
+    deriv_num = (fun(Q, c + eps * w, A, b) - fun(Q, c - eps * w, A, b)) / (2 * eps)
+    self.assertAllClose(deriv_jax, deriv_num, atol=1e-3)
+
+  def test_pseudoinverse_preconditioner(self):
+    Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
+    c = jnp.array([1.0, 1.0])
+    A = jnp.array([[1.0, 1.0]])
+    b = jnp.array([1.0])
+    qp = EqualityConstrainedQP(tol=1e-7)
+    preconditioned_qp = PseudoInversePreconditionedEqQP(qp)
+    params_obj = (Q, c)
+    params_eq = (A, b)
+    params_precond = preconditioned_qp.init_params(params_obj, params_eq)
+    hyperparams = dict(
+      params_obj=params_obj,
+      params_eq=params_eq,
+    )
+    sol = preconditioned_qp.run(**hyperparams, params_precond=params_precond).params
+    self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
+    self._check_derivative_Q_c_A_b(preconditioned_qp, Q, c, A, b)
diff --git a/tests/eq_qp_test.py b/tests/eq_qp_test.py
index 071c3d64..11aeae32 100644
--- a/tests/eq_qp_test.py
+++ b/tests/eq_qp_test.py
@@ -26,7 +26,7 @@
 
 
 class EqualityConstrainedQPTest(jtu.JaxTestCase):
-  def _check_derivative_Q_c_A_b(self, solver, params, Q, c, A, b):
+  def _check_derivative_Q_c_A_b(self, solver, Q, c, A, b):
     def fun(Q, c, A, b):
       Q = 0.5 * (Q + Q.T)
 
@@ -77,7 +77,7 @@ def test_qp_eq_only(self):
     hyperparams = dict(params_obj=(Q, c), params_eq=(A, b))
     sol = qp.run(**hyperparams).params
     self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
-    self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
+    self._check_derivative_Q_c_A_b(qp, Q, c, A, b)
 
   def test_qp_eq_with_init(self):
     Q = 2 * jnp.array([[2.0, 0.5], [0.5, 1]])
@@ -89,7 +89,7 @@ def test_qp_eq_with_init(self):
     init_params = KKTSolution(jnp.array([1.0, 1.0]), jnp.array([1.0]))
     sol = qp.run(init_params, **hyperparams).params
     self.assertAllClose(qp.l2_optimality_error(sol, **hyperparams), 0.0)
-    self._check_derivative_Q_c_A_b(qp, hyperparams, Q, c, A, b)
+    self._check_derivative_Q_c_A_b(qp, Q, c, A, b)
 
   def test_projection_hyperplane(self):
     x = jnp.array([1.0, 2.0])