-
Notifications
You must be signed in to change notification settings - Fork 2
/
block_tridiag_solve.h
142 lines (116 loc) · 4.09 KB
/
block_tridiag_solve.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/**
* block_tridiag_solve.h
*
* Custom-written solver routines for block-tridiagonal matrices.
*/
#ifndef _BLOCK_TRIDIAG_SOLVE_H
#define _BLOCK_TRIDIAG_SOLVE_H
#include <stdexcept>
#include <vector>
#include <Eigen/Dense>
/* Abstract interface for block tri-diagonal solvers.
 *
 * Gives every BlockTriDiagSolver<N> instantiation a common base type, so a
 * solver can be held and used through a BlockTriDiagSolverBase pointer or
 * reference regardless of its compile-time block size.
 *
 * Implementations are expected to be used as: factor_matrix(l, d, u) once,
 * then solve(rhs, x) for each right-hand side.
 */
class BlockTriDiagSolverBase
{
public:
    // Factorise the matrix given by its lower (l), diagonal (d) and upper (u)
    // block arrays.
    virtual void factor_matrix(double *l, double *d, double *u) = 0;

    // Solve the previously factorised system for right-hand side rhs into x.
    virtual void solve(double *rhs, double *x) = 0;

    // Virtual so that deleting a derived solver through a base pointer is safe.
    virtual ~BlockTriDiagSolverBase() = default;
};
/* Solver a tri-diagonal systems of equations where each
* of the blocks have the same size (N x N).
*/
template <int fixed_block_size = Eigen::Dynamic>
class BlockTriDiagSolver
: public BlockTriDiagSolverBase
{
// Inernal vector/matrix views
typedef Eigen::Map<Eigen::Matrix<double, fixed_block_size, 1>> vec_t;
typedef Eigen::Map<Eigen::Matrix<double,
fixed_block_size, fixed_block_size,
Eigen::RowMajor>> mat_t;
public:
BlockTriDiagSolver() {} ;
BlockTriDiagSolver(int num_blocks)
: _sLU(num_blocks),
_l(num_blocks * fixed_block_size * fixed_block_size),
_u(num_blocks * fixed_block_size * fixed_block_size),
_n_blocks(num_blocks), _block_size(fixed_block_size){};
BlockTriDiagSolver(int num_blocks, int block_size)
: _sLU(num_blocks),
_l(num_blocks * block_size * block_size),
_u(num_blocks * block_size * block_size),
_n_blocks(num_blocks), _block_size(block_size)
{
if (fixed_block_size != Eigen::Dynamic)
throw std::invalid_argument("block_size parameter should not be passed "
"to the constructor unless dynamic memory is "
"used.");
};
/* Pre-compute the matrix factorization for later solution */
void factor_matrix(double *l, double *d, double *u)
{
int step = _block_size * _block_size;
Eigen::Matrix<double, fixed_block_size, fixed_block_size> S_i =
mat_t(d, _block_size, _block_size);
// Compute the LU decomposition
_sLU[0].compute(S_i);
for (int i = 1; i < _n_blocks; i++)
{
// Load the required matrices
mat_t l_0(l + i * step, _block_size, _block_size);
mat_t d_0(d + i * step, _block_size, _block_size);
mat_t u_m(u + (i - 1) * step, _block_size, _block_size);
// Update d_0 to give the next denominator:
S_i.noalias() = d_0 - l_0 * _sLU[i-1].solve(u_m);
// Compute the LU decomposition
_sLU[i].compute(S_i);
}
// Save the l/u matrices for later
for (int i = 0; i < _n_blocks * step; i++)
{
_l[i] = l[i];
_u[i] = u[i];
}
}
void solve(double *rhs, double *x)
{
int step = _block_size * _block_size;
double *l = &_l.front();
double *u = &_u.front();
// Do the forward-pass:
// Replace the vectors on the RHS with the vectors
// A_i = [S_i]^-1 * (rhs_i - l_i * A_i-1),
// Note l_0 = 0.
vec_t rhs_0(rhs, _block_size, 1) ;
rhs_0 = _sLU[0].solve(rhs_0).eval() ;
for (int i = 1; i < _n_blocks; i++)
{
// Load the required matrix/vectors
vec_t rhs_m(rhs + (i-1)*_block_size, _block_size, 1);
vec_t rhs_0(rhs + i *_block_size, _block_size, 1);
mat_t l_0(l + i*step, _block_size, _block_size);
rhs_0 = _sLU[i].solve(rhs_0 - l_0*rhs_m).eval() ;
}
// Now do the back-subsitution for x@
// x_i = A_i - [S_i]^-1 * u_i * x_i+1
vec_t x_n(x + (_n_blocks - 1) * _block_size, _block_size, 1);
vec_t rhs_n(rhs + (_n_blocks - 1) * _block_size, _block_size, 1);
x_n = rhs_n ;
for (int i = _n_blocks - 2; i >= 0; i--)
{
vec_t x_p(x + (i + 1) * _block_size, _block_size, 1);
vec_t x_0(x + i * _block_size, _block_size, 1);
vec_t rhs_0(rhs + i * _block_size, _block_size, 1);
mat_t u_0(u + i * step, _block_size, _block_size);
x_0.noalias() = rhs_0 - _sLU[i].solve(u_0 * x_p);
}
}
private:
typedef Eigen::PartialPivLU<Eigen::Matrix<double, fixed_block_size,
fixed_block_size>>
LU;
std::vector<LU> _sLU;
std::vector<double> _l, _u;
int _n_blocks, _block_size;
};
#endif//_BLOCK_TRIDIAG_SOLVE_H