14
14
"""Interface for linear operators."""
15
15
16
16
import functools
17
+ from dataclasses import dataclass
18
+ from typing import Tuple
19
+
17
20
import jax
18
21
import jax .numpy as jnp
19
- import numpy as onp
20
22
21
- from jaxopt .tree_util import tree_map , tree_sum , tree_mul
23
+ from jaxopt .tree_util import tree_map
22
24
23
25
24
26
class DenseLinearOperator :
25
-
26
27
def __init__ (self , pytree ):
27
28
self .pytree = pytree
28
29
@@ -33,7 +34,7 @@ def matvec(self, x):
33
34
return tree_map (jnp .dot , self .pytree , x )
34
35
35
36
def rmatvec (self , _ , y ):
36
- return tree_map (lambda w ,yi : jnp .dot (w .T , yi ), self .pytree , y )
37
+ return tree_map (lambda w , yi : jnp .dot (w .T , yi ), self .pytree , y )
37
38
38
39
def matvec_and_rmatvec (self , x , y ):
39
40
return self .matvec (x ), self .rmatvec (x , y )
@@ -52,11 +53,11 @@ def col_norm(w):
52
53
if not squared :
53
54
col_norms = jnp .sqrt (col_norms )
54
55
return col_norms
56
+
55
57
return tree_map (col_norm , self .pytree )
56
58
57
59
58
60
class FunctionalLinearOperator :
59
-
60
61
def __init__ (self , fun , params ):
61
62
self .fun = functools .partial (fun , params )
62
63
@@ -71,7 +72,7 @@ def rmatvec(self, x, y):
71
72
72
73
def matvec_and_rmatvec (self , x , y ):
73
74
matvec_x , vjp = jax .vjp (self .matvec , x )
74
- rmatvec_y , = vjp (y )
75
+ ( rmatvec_y ,) = vjp (y )
75
76
return matvec_x , rmatvec_y
76
77
77
78
def normal_matvec (self , x ):
@@ -85,3 +86,72 @@ def _make_linear_operator(matvec):
85
86
return DenseLinearOperator
86
87
else :
87
88
return functools .partial (FunctionalLinearOperator , matvec )
89
+
90
+
91
+ def block_row_matvec (block , x ):
92
+ """Performs a matvec for a row of block matrices.
93
+
94
+ The following matvec is performed:
95
+ [U1, ..., UN] * [x1, ..., xN]
96
+ where U1, ..., UN are matrices and x1, ..., xN are vectors
97
+ of compatible shapes.
98
+ """
99
+ if len (block ) != len (x ):
100
+ raise ValueError (
101
+ "We need as many blocks in the matrix as in the vector."
102
+ )
103
+ return sum (jax .tree_util .tree_map (jnp .dot , block , x ))
104
+
105
+
106
+ # TODO(gnegiar): Extend to arbitrary block shapes.
107
+ @jax .tree_util .register_pytree_node_class
108
+ @dataclass
109
+ class BlockLinearOperator :
110
+ """Represents a linear operator defined by blocks over a block pytree.
111
+
112
+ Attributes:
113
+ blocks: a 2x2 block matrix of the form
114
+ [[A, B]
115
+ [C, D]]
116
+ """
117
+
118
+ blocks : Tuple [Tuple [jnp .array ]]
119
+
120
+ def __call__ (self , x ):
121
+ return self .matvec (x )
122
+
123
+ def matvec (self , x ):
124
+ """Performs the block matvec with u defined by blocks.
125
+
126
+ The matvec is of form:
127
+ [u1, u2]
128
+ [[A, B] *
129
+ [C, D]]
130
+
131
+ """
132
+ return jax .tree_util .tree_map (
133
+ lambda row_of_blocks : block_row_matvec (row_of_blocks , x ),
134
+ self .blocks ,
135
+ is_leaf = lambda x : x is self .blocks [0 ] or x is self .blocks [1 ],
136
+ )
137
+
138
+ def rmatvec (self , x , y ):
139
+ return self .matvec_and_rmatvec (x , y )[1 ]
140
+
141
+ def matvec_and_rmatvec (self , x , y ):
142
+ matvec_x , vjp = jax .vjp (self .matvec , x )
143
+ (rmatvec_y ,) = vjp (y )
144
+ return matvec_x , rmatvec_y
145
+
146
+ def normal_matvec (self , x ):
147
+ """Computes A^T A x from matvec(x) = A x."""
148
+ matvec_x , vjp = jax .vjp (self .matvec , x )
149
+ return vjp (matvec_x )[0 ]
150
+
151
+ def tree_flatten (self ):
152
+ return self .blocks , None
153
+
154
+ @classmethod
155
+ def tree_unflatten (cls , aux_data , children ):
156
+ del aux_data
157
+ return cls (children )
0 commit comments