__all__ = ['eigs', 'eigsh']
import jax
import jax.numpy as jnp
import warnings
from typing import Optional, Type, Any, List, Tuple, Callable, Sequence, Text
import functools
import types
import numpy as np
Tensor = Any
_CACHED_MATVECS = {}
_CACHED_FUNCTIONS = {}
def randn(shape: Tuple[int, ...],
dtype: Optional[np.dtype] = None,
seed: Optional[int] = None) -> Tensor:
if not seed:
seed = np.random.randint(0, 2**63)
key = jax.random.PRNGKey(seed)
dtype = dtype if dtype is not None else np.dtype(np.float64)
def cmplx_randn(complex_dtype, real_dtype):
real_dtype = np.dtype(real_dtype)
complex_dtype = np.dtype(complex_dtype)
key_2 = jax.random.PRNGKey(seed + 1)
real_part = jax.random.normal(key, shape, dtype=real_dtype)
complex_part = jax.random.normal(key_2, shape, dtype=real_dtype)
unit = (
np.complex64(1j)
if complex_dtype == np.dtype(np.complex64) else np.complex128(1j))
return real_part + unit * complex_part
if np.dtype(dtype) is np.dtype(jnp.complex128):
return cmplx_randn(dtype, jnp.float64)
if np.dtype(dtype) is np.dtype(jnp.complex64):
return cmplx_randn(dtype, jnp.float32)
return jax.random.normal(key, shape).astype(dtype)
def random_uniform(shape: Tuple[int, ...],
boundaries: Optional[Tuple[float, float]] = (0.0, 1.0),
dtype: Optional[np.dtype] = None,
seed: Optional[int] = None) -> Tensor:
if not seed:
seed = np.random.randint(0, 2**63)
key = jax.random.PRNGKey(seed)
dtype = dtype if dtype is not None else np.dtype(np.float64)
def cmplx_random_uniform(complex_dtype, real_dtype):
real_dtype = np.dtype(real_dtype)
complex_dtype = np.dtype(complex_dtype)
key_2 = jax.random.PRNGKey(seed + 1)
real_part = jax.random.uniform(
key,
shape,
dtype=real_dtype,
minval=boundaries[0],
maxval=boundaries[1])
complex_part = jax.random.uniform(
key_2,
shape,
dtype=real_dtype,
minval=boundaries[0],
maxval=boundaries[1])
unit = (
np.complex64(1j)
if complex_dtype == np.dtype(np.complex64) else np.complex128(1j))
return real_part + unit * complex_part
if np.dtype(dtype) is np.dtype(jnp.complex128):
return cmplx_random_uniform(dtype, jnp.float64)
if np.dtype(dtype) is np.dtype(jnp.complex64):
return cmplx_random_uniform(dtype, jnp.float32)
return jax.random.uniform(
key, shape, minval=boundaries[0], maxval=boundaries[1]).astype(dtype)
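# Illustrative usage sketch (values below are arbitrary): `randn` and
# `random_uniform` mirror a numpy-style interface on top of `jax.random` and
# are used further down to draw random initial states when none is provided:
#
#   v0 = randn((16,), dtype=np.complex128, seed=42)
#   u0 = random_uniform((16,), boundaries=(-1.0, 1.0), dtype=np.float64, seed=7)
#
# Complex128/float64 outputs require 64-bit mode, e.g.
# `jax.config.update("jax_enable_x64", True)`.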
"""
Implicitly restarted Arnoldi method for finding the lowest
eigenvector-eigenvalue pairs of a linear operator `A`.
`A` is a function implementing the matrix-vector
product.
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
at the first invocation of `eigs`, and on any subsequent calls
if the python `id` of `A` changes, even if the formal definition of `A`
stays the same.
Example: the following will jit once at the beginning, and then never again:
```python
import jax
import numpy as np
def A(H,x):
return jax.np.dot(H,x)
for n in range(100):
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigs(A, [H],x) #jitting is triggerd only at `n=0`
```
The following code triggers jitting at every iteration, which
results in considerably reduced performance
```python
import jax
import numpy as np
for n in range(100):
def A(H,x):
return jax.np.dot(H,x)
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigs(A, [H],x) #jitting is triggerd at every step `n`
```
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
args: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
initial_state: An initial vector for the algorithm. If `None`,
a random initial `Tensor` is created using the `backend.randn` method
shape: The shape of the input-dimension of `A`.
dtype: The dtype of the input `A`. If no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
numeig: The number of eigenvector-eigenvalue pairs to be computed.
tol: The desired precision of the eigenvalues. For the jax backend
this has currently no effect, and precision of eigenvalues is not
guaranteed. This feature may be added at a later point. To increase
precision the caller can either increase `maxiter` or `num_krylov_vecs`.
which: Flag for targetting different types of eigenvalues. Currently
supported are `which = 'LR'` (larges real part) and `which = 'LM'`
(larges magnitude).
maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
equivalent to a simple Arnoldi method.
Returns:
(eigvals, eigvecs)
eigvals: A list of `numeig` eigenvalues
eigvecs: A list of `numeig` eigenvectors
"""
def eigs(A: Callable,
args: Optional[List] = None,
initial_state: Optional[Tensor] = None,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[Type[np.number]] = None,
num_krylov_vecs: int = 50,
numeig: int = 6,
tol: float = 1E-8,
which: Text = 'LM',
maxiter: int = 20) -> Tuple[Tensor, List]:
if args is None:
args = []
if which not in ('LR', 'LM'):
raise ValueError(f'which = {which} is currently not supported.')
if numeig > num_krylov_vecs:
raise ValueError('`num_krylov_vecs` >= `numeig` required!')
if initial_state is None:
if (shape is None) or (dtype is None):
raise ValueError("if no `initial_state` is passed, then `shape` and"
"`dtype` have to be provided")
initial_state = randn(shape, dtype)
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
raise TypeError("Expected a `jax.array`. Got {}".format(
type(initial_state)))
if A not in _CACHED_MATVECS:
_CACHED_MATVECS[A] = jax.tree_util.Partial(jax.jit(A))
if "imp_arnoldi" not in _CACHED_FUNCTIONS:
imp_arnoldi = _implicitly_restarted_arnoldi(jax)
_CACHED_FUNCTIONS["imp_arnoldi"] = imp_arnoldi
eta, U, numits = _CACHED_FUNCTIONS["imp_arnoldi"](_CACHED_MATVECS[A], args,
initial_state,
num_krylov_vecs, numeig,
which, tol, maxiter,
jax.lax.Precision.DEFAULT)
if numeig > numits:
    warnings.warn(
        f"Arnoldi terminated early after numits = {numits}"
        f" < numeig = {numeig} steps. For this value of `numeig` "
        "the routine will return spurious eigenvalues of value 0.0. "
        "Use a smaller value of `numeig`, or a smaller value for `tol`.")
return eta, U
def eigsh(A: Callable,
args: Optional[List] = None,
initial_state: Optional[Tensor] = None,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[Type[np.number]] = None,
num_krylov_vecs: int = 50,
numeig: int = 6,
tol: float = 1E-8,
which: Text = 'SA',
maxiter: int = 20) -> Tuple[Tensor, List]:
"""
Implicitly restarted Lanczos method for finding the lowest
eigenvector-eigenvalue pairs of a symmetric (hermitian) linear operator `A`.
`A` is a function implementing the matrix-vector
product.
WARNING: This routine uses jax.jit to reduce runtimes. jitting is triggered
at the first invocation of `eigsh`, and on any subsequent calls
if the python `id` of `A` changes, even if the formal definition of `A`
stays the same.
Example: the following will jit once at the beginning, and then never again:
```python
import jax
import numpy as np
def A(H,x):
return jax.np.dot(H,x)
for n in range(100):
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigsh(A, [H],x) #jitting is triggerd only at `n=0`
```
The following code triggers jitting at every iteration, which
results in considerably reduced performance
```python
import jax
import numpy as np
for n in range(100):
def A(H,x):
return jax.np.dot(H,x)
H = jax.np.array(np.random.rand(10,10))
x = jax.np.array(np.random.rand(10,10))
res = eigsh(A, [H],x) #jitting is triggerd at every step `n`
```
Args:
A: A (sparse) implementation of a linear operator.
Call signature of `A` is `res = A(vector, *args)`, where `vector`
can be an arbitrary `Tensor`, and `res.shape` has to be `vector.shape`.
args: A list of arguments to `A`. `A` will be called as
`res = A(initial_state, *args)`.
initial_state: An initial vector for the algorithm. If `None`,
a random initial `Tensor` is created using the `backend.randn` method
shape: The shape of the input-dimension of `A`.
dtype: The dtype of the input `A`. If no `initial_state` is provided,
a random initial state with shape `shape` and dtype `dtype` is created.
num_krylov_vecs: The number of iterations (number of krylov vectors).
numeig: The number of eigenvector-eigenvalue pairs to be computed.
tol: The desired precision of the eigenvalues. For the jax backend
this has currently no effect, and precision of eigenvalues is not
guaranteed. This feature may be added at a later point. To increase
precision the caller can either increase `maxiter` or `num_krylov_vecs`.
which: Flag for targetting different types of eigenvalues. Currently
supported are `which = 'LR'` (larges real part) and `which = 'LM'`
(larges magnitude).
maxiter: Maximum number of restarts. For `maxiter=0` the routine becomes
equivalent to a simple Arnoldi method.
Returns:
(eigvals, eigvecs)
eigvals: A list of `numeig` eigenvalues
eigvecs: A list of `numeig` eigenvectors
"""
if args is None:
args = []
if which not in ('SA', 'LA', 'LM'):
raise ValueError(f'which = {which} is currently not supported.')
if numeig > num_krylov_vecs:
raise ValueError('`num_krylov_vecs` >= `numeig` required!')
if initial_state is None:
if (shape is None) or (dtype is None):
raise ValueError("if no `initial_state` is passed, then `shape` and"
"`dtype` have to be provided")
initial_state = randn(shape, dtype)
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
raise TypeError("Expected a `jax.array`. Got {}".format(
type(initial_state)))
if A not in _CACHED_MATVECS:
_CACHED_MATVECS[A] = jax.tree_util.Partial(jax.jit(A))
if "imp_lanczos" not in _CACHED_FUNCTIONS:
imp_lanczos = _implicitly_restarted_lanczos(jax)
_CACHED_FUNCTIONS["imp_lanczos"] = imp_lanczos
eta, U, numits = _CACHED_FUNCTIONS["imp_lanczos"](_CACHED_MATVECS[A], args,
initial_state,
num_krylov_vecs, numeig,
which, tol, maxiter,
jax.lax.Precision.DEFAULT)
if numeig > numits:
    warnings.warn(
        f"Lanczos terminated early after numits = {numits}"
        f" < numeig = {numeig} steps. For this value of `numeig` "
        "the routine will return spurious eigenvalues of value 0.0. "
        "Use a smaller value of `numeig`, or a smaller value for `tol`.")
return eta, U
def cpu_eig_host(H):
  return np.linalg.eig(H)
def cpu_eig(H):
result_shape = (jax.ShapeDtypeStruct(H.shape[0:1], H.dtype),
jax.ShapeDtypeStruct(H.shape, H.dtype))
return jax.pure_callback(cpu_eig_host, result_shape, H)
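# Illustrative sketch (the matrix below is arbitrary): `cpu_eig` routes the
# general (non-hermitian) eigendecomposition to the host CPU through
# `jax.pure_callback`, since `jax.numpy.linalg.eig` is only implemented on
# CPU. A jitted call could look like
#
#   H = jnp.array(np.random.rand(4, 4) + 1j * np.random.rand(4, 4))
#   evals, evecs = jax.jit(cpu_eig)(H)
#
# The declared result shapes reuse `H.dtype`, so `H` should already be
# complex; for a real-valued `H`, `np.linalg.eig` may return complex arrays
# that do not match the declared callback dtypes.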
def _iterative_classical_gram_schmidt(jax: types.ModuleType) -> Callable:
JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
def iterative_classical_gram_schmidt(
vector: jax.Array,
krylov_vectors: jax.Array,
precision: JaxPrecisionType,
iterations: int = 2,
) -> jax.Array:
"""
Orthogonalize `vector` to all rows of `krylov_vectors`.
    Args:
      vector: Initial vector.
      krylov_vectors: Matrix of krylov vectors, each row is treated as a
        vector.
      precision: jax.lax.Precision used for the tensordot contractions.
      iterations: Number of orthogonalization iterations.
    Returns:
      jax.Array: The orthogonalized vector.
      jax.Array: The accumulated overlaps with the rows of `krylov_vectors`.
"""
i1 = list(range(1, len(krylov_vectors.shape)))
i2 = list(range(len(vector.shape)))
vec = vector
overlaps = 0
for _ in range(iterations):
ov = jax.numpy.tensordot(
krylov_vectors.conj(), vec, (i1, i2), precision=precision)
vec = vec - jax.numpy.tensordot(
ov, krylov_vectors, ([0], [0]), precision=precision)
overlaps = overlaps + ov
return vec, overlaps
return iterative_classical_gram_schmidt
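# Illustrative sketch (toy shapes, not from the algorithm below): the function
# returned by `_iterative_classical_gram_schmidt` orthogonalizes a vector
# against the rows of a stack of Krylov vectors and also returns the
# accumulated overlaps:
#
#   icgs = _iterative_classical_gram_schmidt(jax)
#   kvs = jnp.eye(3, 5)                # three orthonormal rows
#   v = jnp.arange(5, dtype=kvs.dtype)
#   v_orth, overlaps = icgs(v, kvs, jax.lax.Precision.DEFAULT)
#   # v_orth is (numerically) orthogonal to every row of kvs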
def _generate_jitted_eigsh_lanczos(jax: types.ModuleType) -> Callable:
"""
Helper function to generate jitted lanczos function used
in JaxBackend.eigsh_lanczos. The function `jax_lanczos`
returned by this higher-order function has the following
call signature:
```
eigenvalues, eigenvectors = jax_lanczos(matvec:Callable,
arguments: List[Tensor],
init: Tensor,
ncv: int,
neig: int,
landelta: float,
reortho: bool)
```
`matvec`: A callable implementing the matrix-vector product of a
linear operator. `arguments`: Arguments to `matvec` additional to
an input vector. `matvec` will be called as `matvec(init, *args)`.
`init`: An initial input vector to `matvec`.
`ncv`: Number of krylov iterations (i.e. dimension of the Krylov space).
`neig`: Number of eigenvalue-eigenvector pairs to be computed.
  `landelta`: Convergence parameter: if the norm of the current Lanczos vector
    falls below `landelta`, the iteration is stopped.
`reortho`: If `True`, reorthogonalize all krylov vectors at each step.
This should be used if `neig>1`.
Args:
jax: The `jax` module.
Returns:
Callable: A jitted function that does a lanczos iteration.
"""
JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
@functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7))
def jax_lanczos(matvec: Callable, arguments: List, init: jax.Array,
ncv: int, neig: int, landelta: float, reortho: bool,
precision: JaxPrecisionType) -> Tuple[jax.Array, List]:
"""
    Lanczos iteration for symmetric eigenvalue problems. If reortho = False,
the Krylov basis is constructed without explicit re-orthogonalization.
In infinite precision, all Krylov vectors would be orthogonal. Due to
finite precision arithmetic, orthogonality is usually quickly lost.
For reortho=True, the Krylov basis is explicitly reorthogonalized.
Args:
matvec: A callable implementing the matrix-vector product of a
linear operator.
arguments: Arguments to `matvec` additional to an input vector.
`matvec` will be called as `matvec(init, *args)`.
init: An initial input vector to `matvec`.
ncv: Number of krylov iterations (i.e. dimension of the Krylov space).
neig: Number of eigenvalue-eigenvector pairs to be computed.
landelta: Convergence parameter: if the norm of the current Lanczos vector
falls below `landelta`, iteration is stopped.
reortho: If `True`, reorthogonalize all krylov vectors at each step.
This should be used if `neig>1`.
precision: jax.lax.Precision type used in jax.numpy.vdot
Returns:
jax.Array: Eigenvalues
List: Eigenvectors
int: Number of iterations
"""
shape = init.shape
dtype = init.dtype
iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax)
mask_slice = (slice(ncv + 2), ) + (None,) * len(shape)
def scalar_product(a, b):
i1 = list(range(len(a.shape)))
i2 = list(range(len(b.shape)))
return jax.numpy.tensordot(a.conj(), b, (i1, i2), precision=precision)
def norm(a):
return jax.numpy.sqrt(scalar_product(a, a))
def body_lanczos(vals):
krylov_vectors, alphas, betas, i = vals
previous_vector = krylov_vectors[i]
def body_while(vals):
pv, kv, _ = vals
pv = iterative_classical_gram_schmidt(
pv, (i > jax.numpy.arange(ncv + 2))[mask_slice] * kv, precision)[0]
return [pv, kv, False]
def cond_while(vals):
return vals[2]
previous_vector, krylov_vectors, _ = jax.lax.while_loop(
cond_while, body_while,
[previous_vector, krylov_vectors, reortho])
beta = norm(previous_vector)
normalized_vector = previous_vector / beta
Av = matvec(normalized_vector, *arguments)
alpha = scalar_product(normalized_vector, Av)
alphas = alphas.at[i - 1].set(alpha)
betas = betas.at[i].set(beta)
def while_next(vals):
Av, _ = vals
res = Av - normalized_vector * alpha - krylov_vectors[i - 1] * beta
return [res, False]
def cond_next(vals):
return vals[1]
next_vector, _ = jax.lax.while_loop(
cond_next, while_next,
[Av, jax.numpy.logical_not(reortho)])
next_vector = jax.numpy.reshape(next_vector, shape)
krylov_vectors = krylov_vectors.at[i].set(normalized_vector)
krylov_vectors = krylov_vectors.at[i + 1].set(next_vector)
return [krylov_vectors, alphas, betas, i + 1]
def cond_fun(vals):
betas, i = vals[-2], vals[-1]
norm = betas[i - 1]
return jax.lax.cond(i <= ncv, lambda x: x[0] > x[1], lambda x: False,
[norm, landelta])
# note: ncv + 2 because the first vector is all zeros, and the
# last is the unnormalized residual.
krylov_vecs = jax.numpy.zeros((ncv + 2,) + shape, dtype=dtype)
# NOTE (mganahl): initial vector is normalized inside the loop
krylov_vecs = krylov_vecs.at[1].set(init)
# betas are the upper and lower diagonal elements
# of the projected linear operator
# the first two beta-values can be discarded
# set betas[0] to 1.0 for initialization of loop
# betas[2] is set to the norm of the initial vector.
betas = jax.numpy.zeros(ncv + 1, dtype=dtype)
betas = betas.at[0].set(1.0)
# diagonal elements of the projected linear operator
alphas = jax.numpy.zeros(ncv, dtype=dtype)
initvals = [krylov_vecs, alphas, betas, 1]
krylov_vecs, alphas, betas, numits = jax.lax.while_loop(
cond_fun, body_lanczos, initvals)
    # FIXME (mganahl): if the while_loop stops early at iteration i, alphas
# and betas are 0.0 at positions n >= i - 1. eigh will then wrongly give
# degenerate eigenvalues 0.0. JAX does currently not support
# dynamic slicing with variable slice sizes, so these beta values
# can't be truncated. Thus, if numeig >= i - 1, jitted_lanczos returns
    # a set of spurious eigenvectors and eigenvalues.
# If algebraically small EVs are desired, one can initialize `alphas` with
# large positive values, thus pushing the spurious eigenvalues further
# away from the desired ones (similar for algebraically large EVs)
#FIXME: replace with eigh_banded once JAX supports it
A_tridiag = jax.numpy.diag(alphas) + jax.numpy.diag(
betas[2:], 1) + jax.numpy.diag(jax.numpy.conj(betas[2:]), -1)
eigvals, U = jax.numpy.linalg.eigh(A_tridiag)
eigvals = eigvals.astype(dtype)
# expand eigenvectors in krylov basis
def body_vector(i, vals):
krv, unitary, vectors = vals
dim = unitary.shape[1]
n, m = jax.numpy.divmod(i, dim)
vectors = vectors.at[n, :].set(vectors[n, :] + krv[m + 1] * unitary[m, n])
return [krv, unitary, vectors]
_vectors = jax.numpy.zeros((neig,) + shape, dtype=dtype)
_, _, vectors = jax.lax.fori_loop(0, neig * (krylov_vecs.shape[0] - 1),
body_vector,
[krylov_vecs, U, _vectors])
return jax.numpy.array(eigvals[0:neig]), [
vectors[n] / norm(vectors[n]) for n in range(neig)
], numits
return jax_lanczos
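# Illustrative sketch (toy matrix, names chosen for demonstration only): the
# generated Lanczos routine can be applied to a dense hermitian matrix through
# a simple matvec wrapper:
#
#   lanczos = _generate_jitted_eigsh_lanczos(jax)
#   matvec = jax.tree_util.Partial(jax.jit(lambda x, M: M @ x))
#   M = jnp.array(np.random.rand(8, 8))
#   M = M + M.T  # symmetrize
#   v0 = jnp.array(np.random.rand(8))
#   evals, evecs, numits = lanczos(matvec, [M], v0, 10, 2, 1E-8, True,
#                                  jax.lax.Precision.DEFAULT)
#
# i.e. with `ncv=10`, `neig=2`, `landelta=1E-8` and full reorthogonalization.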
def _generate_lanczos_factorization(jax: types.ModuleType) -> Callable:
"""
  Helper function to generate a jitted function that
  computes a lanczos factorization of a linear operator.
Returns:
Callable: A jitted function that does a lanczos factorization.
"""
JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
@functools.partial(jax.jit, static_argnums=(6, 7, 8, 9))
def _lanczos_fact(
matvec: Callable, args: List, v0: jax.Array,
Vm: jax.Array, alphas: jax.Array, betas: jax.Array,
start: int, num_krylov_vecs: int, tol: float, precision: JaxPrecisionType
):
"""
Compute an m-step lanczos factorization of `matvec`, with
m <=`num_krylov_vecs`. The factorization will
do at most `num_krylov_vecs` steps, and terminate early
    if an invariant subspace is encountered. The returned arrays
`alphas`, `betas` and `Vm` will satisfy the Lanczos recurrence relation
```
matrix @ Vm - Vm @ Hm - fm * em = 0
```
with `matrix` the matrix representation of `matvec`,
`Hm = jnp.diag(alphas) + jnp.diag(betas, -1) + jnp.diag(betas.conj(), 1)`
`fm=residual * norm`, and `em` a cartesian basis vector of shape
`(1, kv.shape[1])` with `em[0, -1] == 1` and 0 elsewhere.
Note that the caller is responsible for dtype consistency between
the inputs, i.e. dtypes between all input arrays have to match.
Args:
matvec: The matrix vector product.
args: List of arguments to `matvec`.
v0: Initial state to `matvec`.
Vm: An array for storing the krylov vectors. The individual
vectors are stored as columns.
The shape of `krylov_vecs` has to be
(num_krylov_vecs + 1, np.ravel(v0).shape[0]).
alphas: An array for storing the diagonal elements of the reduced
operator.
betas: An array for storing the lower diagonal elements of the
reduced operator.
start: Integer denoting the start position where the first
produced krylov_vector should be inserted into `Vm`
num_krylov_vecs: Number of krylov iterations, should be identical to
`Vm.shape[0] + 1`
tol: Convergence parameter. Iteration is terminated if the norm of a
krylov-vector falls below `tol`.
Returns:
jax.Array: An array of shape
`(num_krylov_vecs, np.prod(initial_state.shape))` of krylov vectors.
jax.Array: The diagonal elements of the tridiagonal reduced
operator ("alphas")
jax.Array: The lower-diagonal elements of the tridiagonal reduced
operator ("betas")
jax.Array: The unnormalized residual of the Lanczos process.
float: The norm of the residual.
int: The number of performed iterations.
bool: if `True`: iteration hit an invariant subspace.
if `False`: iteration terminated without encountering
an invariant subspace.
"""
shape = v0.shape
iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax)
Z = jax.numpy.linalg.norm(v0)
#only normalize if norm > tol, else return zero vector
v = jax.lax.cond(Z > tol, lambda x: v0 / Z, lambda x: v0 * 0.0, None)
Vm = Vm.at[start, :].set(jax.numpy.ravel(v))
betas = jax.lax.cond(
start > 0,
lambda x: betas.at[start - 1].set(Z),
lambda x: betas, start)
# body of the arnoldi iteration
def body(vals):
Vm, alphas, betas, previous_vector, _, i = vals
Av = matvec(previous_vector, *args)
Av, overlaps = iterative_classical_gram_schmidt(
Av.ravel(),
(i >= jax.numpy.arange(Vm.shape[0]))[:, None] * Vm, precision)
alphas = alphas.at[i].set(overlaps[i])
norm = jax.numpy.linalg.norm(Av)
Av = jax.numpy.reshape(Av, shape)
# only normalize if norm is larger than threshold,
# otherwise return zero vector
Av = jax.lax.cond(norm > tol, lambda x: Av/norm, lambda x: Av * 0.0, None)
Vm, betas = jax.lax.cond(
i < num_krylov_vecs - 1,
lambda x: (Vm.at[i + 1, :].set(Av.ravel()), betas.at[i].set(norm)),
lambda x: (Vm, betas),
None)
return [Vm, alphas, betas, Av, norm, i + 1]
def cond_fun(vals):
# Continue loop while iteration < num_krylov_vecs and norm > tol
norm, iteration = vals[4], vals[5]
counter_done = (iteration >= num_krylov_vecs)
norm_not_too_small = norm > tol
continue_iteration = jax.lax.cond(counter_done, lambda x: False,
lambda x: norm_not_too_small, None)
return continue_iteration
initial_values = [Vm, alphas, betas, v, Z, start]
final_values = jax.lax.while_loop(cond_fun, body, initial_values)
Vm, alphas, betas, residual, norm, it = final_values
return Vm, alphas, betas, residual, norm, it, norm < tol
return _lanczos_fact
def _generate_arnoldi_factorization(jax: types.ModuleType) -> Callable:
"""
Helper function to create a jitted arnoldi factorization.
The function returns a function `_arnoldi_fact` which
performs an m-step arnoldi factorization.
`_arnoldi_fact` computes an m-step arnoldi factorization
of an input callable `matvec`, with m = min(`it`,`num_krylov_vecs`).
`_arnoldi_fact` will do at most `num_krylov_vecs` steps.
`_arnoldi_fact` returns arrays `kv` and `H` which satisfy
the Arnoldi recurrence relation
```
matrix @ Vm - Vm @ Hm - fm * em = 0
```
with `matrix` the matrix representation of `matvec` and
`Vm = jax.numpy.transpose(kv[:it, :])`,
  `Hm = H[:it, :it]`, `fm = np.expand_dims(kv[it, :] * H[it, it - 1], 1)`
  and `em` a cartesian basis vector of shape `(1, kv.shape[1])`
with `em[0, -1] == 1` and 0 elsewhere.
Note that the caller is responsible for dtype consistency between
the inputs, i.e. dtypes between all input arrays have to match.
Args:
matvec: The matrix vector product. This function has to be wrapped into
`jax.tree_util.Partial`. `matvec` will be called as `matvec(x, *args)`
for an input vector `x`.
args: List of arguments to `matvec`.
v0: Initial state to `matvec`.
Vm: An array for storing the krylov vectors. The individual
vectors are stored as columns. The shape of `krylov_vecs` has to be
(num_krylov_vecs + 1, np.ravel(v0).shape[0]).
H: Matrix of overlaps. The shape has to be
(num_krylov_vecs + 1,num_krylov_vecs + 1).
start: Integer denoting the start position where the first
produced krylov_vector should be inserted into `Vm`
num_krylov_vecs: Number of krylov iterations, should be identical to
`Vm.shape[0] + 1`
tol: Convergence parameter. Iteration is terminated if the norm of a
krylov-vector falls below `tol`.
Returns:
kv: An array of krylov vectors
H: A matrix of overlaps
it: The number of performed iterations.
converged: Whether convergence was achieved.
"""
JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
iterative_classical_gram_schmidt = _iterative_classical_gram_schmidt(jax)
@functools.partial(jax.jit, static_argnums=(5, 6, 7, 8))
def _arnoldi_fact(
matvec: Callable, args: List, v0: jax.Array,
Vm: jax.Array, H: jax.Array, start: int,
num_krylov_vecs: int, tol: float, precision: JaxPrecisionType
) -> Tuple[jax.Array, jax.Array, jax.Array, float, int,
bool]:
"""
Compute an m-step arnoldi factorization of `matvec`, with
m = min(`it`,`num_krylov_vecs`). The factorization will
do at most `num_krylov_vecs` steps. The returned arrays
`kv` and `H` will satisfy the Arnoldi recurrence relation
```
matrix @ Vm - Vm @ Hm - fm * em = 0
```
with `matrix` the matrix representation of `matvec` and
`Vm = jax.numpy.transpose(kv[:it, :])`,
    `Hm = H[:it, :it]`, `fm = np.expand_dims(kv[it, :] * H[it, it - 1], 1)`
    and `em` a cartesian basis vector of shape `(1, kv.shape[1])`
with `em[0, -1] == 1` and 0 elsewhere.
Note that the caller is responsible for dtype consistency between
the inputs, i.e. dtypes between all input arrays have to match.
Args:
matvec: The matrix vector product.
args: List of arguments to `matvec`.
v0: Initial state to `matvec`.
Vm: An array for storing the krylov vectors. The individual
vectors are stored as columns.
The shape of `krylov_vecs` has to be
(num_krylov_vecs + 1, np.ravel(v0).shape[0]).
H: Matrix of overlaps. The shape has to be
(num_krylov_vecs + 1,num_krylov_vecs + 1).
start: Integer denoting the start position where the first
produced krylov_vector should be inserted into `Vm`
num_krylov_vecs: Number of krylov iterations, should be identical to
`Vm.shape[0] + 1`
tol: Convergence parameter. Iteration is terminated if the norm of a
krylov-vector falls below `tol`.
Returns:
jax.Array: An array of shape
`(num_krylov_vecs, np.prod(initial_state.shape))` of krylov vectors.
jax.Array: Upper Hessenberg matrix of shape
        `(num_krylov_vecs, num_krylov_vecs)` of the Arnoldi process.
jax.Array: The unnormalized residual of the Arnoldi process.
      float: The norm of the residual.
int: The number of performed iterations.
bool: if `True`: iteration hit an invariant subspace.
if `False`: iteration terminated without encountering
an invariant subspace.
"""
# Note (mganahl): currently unused, but is very convenient to have
# for further development and tests (it's usually more accurate than
# classical gs)
# Call signature:
#```python
# initial_vals = [Av.ravel(), Vm, i, H]
# Av, Vm, _, H = jax.lax.fori_loop(
# 0, i + 1, modified_gram_schmidt_step_arnoldi, initial_vals)
#```
def modified_gram_schmidt_step_arnoldi(j, vals): #pylint: disable=unused-variable
"""
Single step of a modified gram-schmidt orthogonalization.
Substantially more accurate than classical gram schmidt
Args:
j: Integer value denoting the vector to be orthogonalized.
vals: A list of variables:
`vector`: The current vector to be orthogonalized
to all previous ones
`Vm`: jax.array of collected krylov vectors
`n`: integer denoting the column-position of the overlap
<`krylov_vector`|`vector`> within `H`.
Returns:
updated vals.
"""
vector, krylov_vectors, n, H = vals
v = krylov_vectors[j, :]
h = jax.numpy.vdot(v, vector, precision=precision)
H = H.at[j, n].set(h)
vector = vector - h * v
return [vector, krylov_vectors, n, H]
shape = v0.shape
Z = jax.numpy.linalg.norm(v0)
#only normalize if norm > tol, else return zero vector
v = jax.lax.cond(Z > tol, lambda x: v0 / Z, lambda x: v0 * 0.0, None)
Vm = Vm.at[start, :].set(jax.numpy.ravel(v))
H = jax.lax.cond(
start > 0,
lambda x: H.at[x, x - 1].set(Z),
lambda x: H, start)
# body of the arnoldi iteration
def body(vals):
Vm, H, previous_vector, _, i = vals
Av = matvec(previous_vector, *args)
Av, overlaps = iterative_classical_gram_schmidt(
Av.ravel(),
(i >= jax.numpy.arange(Vm.shape[0]))[:, None] *
Vm, precision)
H = H.at[:, i].set(overlaps)
norm = jax.numpy.linalg.norm(Av)
Av = jax.numpy.reshape(Av, shape)
# only normalize if norm is larger than threshold,
# otherwise return zero vector
Av = jax.lax.cond(norm > tol, lambda x: Av/norm, lambda x: Av * 0.0, None)
Vm, H = jax.lax.cond(
i < num_krylov_vecs - 1,
lambda x: (Vm.at[i + 1, :].set(Av.ravel()), H.at[i + 1, i].set(norm)), #pylint: disable=line-too-long
lambda x: (x[0], x[1]),
(Vm, H, Av, i, norm))
return [Vm, H, Av, norm, i + 1]
def cond_fun(vals):
# Continue loop while iteration < num_krylov_vecs and norm > tol
norm, iteration = vals[3], vals[4]
counter_done = (iteration >= num_krylov_vecs)
norm_not_too_small = norm > tol
continue_iteration = jax.lax.cond(counter_done, lambda x: False,
lambda x: norm_not_too_small, None)
return continue_iteration
initial_values = [Vm, H, v, Z, start]
final_values = jax.lax.while_loop(cond_fun, body, initial_values)
Vm, H, residual, norm, it = final_values
return Vm, H, residual, norm, it, norm < tol
return _arnoldi_fact
# ######################################################
# ####### NEW SORTING FUNCTIONS INSERTED HERE #########
# ######################################################
def _LR_sort(jax):
@functools.partial(jax.jit, static_argnums=(0,))
def sorter(
p: int,
evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
inds = jax.numpy.argsort(jax.numpy.real(evals), stable=True)[::-1]
shifts = evals[inds][-p:]
return shifts, inds
return sorter
def _SA_sort(jax):
@functools.partial(jax.jit, static_argnums=(0,))
def sorter(
p: int,
evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
inds = jax.numpy.argsort(evals, stable=True)
shifts = evals[inds][-p:]
return shifts, inds
return sorter
def _LA_sort(jax):
@functools.partial(jax.jit, static_argnums=(0,))
def sorter(
p: int,
evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
    inds = jax.numpy.argsort(evals, stable=True)[::-1]
shifts = evals[inds][-p:]
return shifts, inds
return sorter
def _LM_sort(jax):
@functools.partial(jax.jit, static_argnums=(0,))
def sorter(
p: int,
evals: jax.Array) -> Tuple[jax.Array, jax.Array]:
inds = jax.numpy.argsort(jax.numpy.abs(evals), stable=True)[::-1]
shifts = evals[inds][-p:]
return shifts, inds
return sorter
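# Illustrative sketch (toy eigenvalues): each sorter returns the `p`
# eigenvalues that are least wanted under the requested ordering (used as
# shifts in the implicit restart) together with the full sorting permutation.
# For `which = 'LM'`:
#
#   sorter = _LM_sort(jax)
#   evals = jnp.array([0.1, -3.0, 2.0, 0.5])
#   shifts, inds = sorter(2, evals)
#   # evals[inds] -> [-3.0, 2.0, 0.5, 0.1]  (decreasing magnitude)
#   # shifts      -> [0.5, 0.1]             (the 2 smallest in magnitude)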
# ####################################################
# ####################################################
def _shifted_QR(jax):
@functools.partial(jax.jit, static_argnums=(4,))
def shifted_QR(
Vm: jax.Array, Hm: jax.Array, fm: jax.Array,
shifts: jax.Array,
numeig: int) -> Tuple[jax.Array, jax.Array, jax.Array]:
# compress arnoldi factorization
q = jax.numpy.zeros(Hm.shape[0], dtype=Hm.dtype)
q = q.at[-1].set(1.0)
def body(i, vals):
Vm, Hm, q = vals
shift = shifts[i] * jax.numpy.eye(Hm.shape[0], dtype=Hm.dtype)
Qj, R = jax.numpy.linalg.qr(Hm - shift)
Hm = R @ Qj + shift
Vm = Qj.T @ Vm
q = q @ Qj
return Vm, Hm, q
Vm, Hm, q = jax.lax.fori_loop(0, shifts.shape[0], body,
(Vm, Hm, q))
fk = Vm[numeig, :] * Hm[numeig, numeig - 1] + fm * q[numeig - 1]
return Vm, Hm, fk
return shifted_QR
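# Summary note for `shifted_QR` (restating the code above): given an m-step
# Arnoldi factorization (Vm, Hm, fm) and the unwanted eigenvalue shifts, it
# performs the implicit-restart compression by looping over the shifts mu_j,
#
#   Hm - mu_j * I = Q_j @ R_j
#   Hm <- R_j @ Q_j + mu_j * I
#   Vm <- Q_j.T @ Vm
#   q  <- q @ Q_j
#
# and then forms the new residual
#   fk = Vm[numeig, :] * Hm[numeig, numeig - 1] + fm * q[numeig - 1]
# so that the leading `numeig` Krylov vectors again satisfy an Arnoldi
# recurrence and the factorization can be re-expanded to full size.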
def _get_vectors(jax):
@functools.partial(jax.jit, static_argnums=(3,))
def get_vectors(Vm: jax.Array, unitary: jax.Array,
inds: jax.Array, numeig: int) -> jax.Array:
def body_vector(i, states):
dim = unitary.shape[1]
n, m = jax.numpy.divmod(i, dim)
states = states.at[n, :].set(states[n,:] + Vm[m, :] * unitary[m, inds[n]])
return states
state_vectors = jax.numpy.zeros([numeig, Vm.shape[1]], dtype=Vm.dtype)
state_vectors = jax.lax.fori_loop(0, numeig * Vm.shape[0], body_vector,
state_vectors)
state_norms = jax.numpy.linalg.norm(state_vectors, axis=1)
state_vectors = state_vectors / state_norms[:, None]
return state_vectors
return get_vectors
def _check_eigvals_convergence_eigh(jax):
@functools.partial(jax.jit, static_argnums=(3,))
def check_eigvals_convergence(beta_m: float, Hm: jax.Array,
Hm_norm: float,
tol: float) -> bool:
eigvals, eigvecs = jax.numpy.linalg.eigh(Hm)
    # TODO (mganahl): confirm that this is a valid matrix norm
thresh = jax.numpy.maximum(
jax.numpy.finfo(eigvals.dtype).eps * Hm_norm,
jax.numpy.abs(eigvals) * tol)
vals = jax.numpy.abs(eigvecs[-1, :])
return jax.numpy.all(beta_m * vals < thresh)
return check_eigvals_convergence
def _check_eigvals_convergence_eig(jax):
@functools.partial(jax.jit, static_argnums=(2, 3))
def check_eigvals_convergence(beta_m: float, Hm: jax.Array,
tol: float, numeig: int) -> bool:
eigvals, eigvecs = cpu_eig(Hm)
    # TODO (mganahl): confirm that this is a valid matrix norm
Hm_norm = jax.numpy.linalg.norm(Hm)
thresh = jax.numpy.maximum(
jax.numpy.finfo(eigvals.dtype).eps * Hm_norm,
jax.numpy.abs(eigvals[:numeig]) * tol)
vals = jax.numpy.abs(eigvecs[numeig - 1, :numeig])
return jax.numpy.all(beta_m * vals < thresh)
return check_eigvals_convergence
def _implicitly_restarted_arnoldi(jax: types.ModuleType) -> Callable:
"""
Helper function to generate a jitted function to do an
implicitly restarted arnoldi factorization of `matvec`. The
returned routine finds the lowest `numeig`
eigenvector-eigenvalue pairs of `matvec`
by alternating between compression and re-expansion of an initial
`num_krylov_vecs`-step Arnoldi factorization.
Note: The caller has to ensure that the dtype of the return value
of `matvec` matches the dtype of the initial state. Otherwise jax
will raise a TypeError.
  The call signature of the returned function is:
Args:
matvec: A callable representing the linear operator.
args: Arguments to `matvec`. `matvec` is called with
`matvec(x, *args)` with `x` the input array on which
`matvec` should act.
    initial_state: A starting vector for the iteration.
num_krylov_vecs: Number of krylov vectors of the arnoldi factorization.
numeig: The number of desired eigenvector-eigenvalue pairs.
    which: Which eigenvalues to target. Currently supported:
      `which = 'LR'` and `which = 'LM'`.
tol: Convergence flag. If the norm of a krylov vector drops below `tol`
the iteration is terminated.
maxiter: Maximum number of (outer) iteration steps.
Returns:
eta, U: Two lists containing eigenvalues and eigenvectors.
Args:
jax: The jax module.
Returns:
Callable: A function performing an implicitly restarted
Arnoldi factorization
"""
JaxPrecisionType = type(jax.lax.Precision.DEFAULT)
arnoldi_fact = _generate_arnoldi_factorization(jax)
@functools.partial(jax.jit, static_argnums=(3, 4, 5, 6, 7, 8))
def implicitly_restarted_arnoldi_method(
matvec: Callable, args: List, initial_state: jax.Array,
num_krylov_vecs: int, numeig: int, which: Text, tol: float, maxiter: int,
precision: JaxPrecisionType