-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add some triangular ring solver and asymptotically fast generics (#1490)
* add some triangular ring solver but what do we want? we have solve_triu for fields Ax = b rings (this PR) Ax = b solve_triu_left rings (this PR) xA = b All three accept large "b", ie. can solve multiple equations in one. However, they require "A" to be non-singular square We also have can_solve_left_reduced_triu which can deal with arbitrary matrices "A" in rref, but can only deal with single row "b" solve_triu for fields also has flags for the diagonal being 1 if b and A are square, this is asymptotically not optimal as it will be a n^3/2 algo I can add test - if we want this "interface" Note: for triangular, one cannot transpose to reduce one case to the other Note: in serious use, should also be supplemented by special implementations in Nemo/ Flint for nmod and/or ZZ and friends * add generic fast matrix support in module Strassen * solve_tril! * sub! and mul! for generic matrices, INPLACE * add test for Strassen * fix typo * fix doctest
- Loading branch information
Showing
5 changed files
with
449 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,309 @@ | ||
""" | ||
Provides generic asymptotically fast matrix methods: | ||
- mul and mul! using the Strassen scheme | ||
- solve_tril! | ||
- lu! | ||
- solve_triu | ||
Just prefix the function by "Strassen." all 4 functions support a keyword | ||
argument "cutoff" to indicate when the base case should be used. | ||
The speedup depends on the ring and the entry sizes. | ||
#Examples: | ||
```jldoctest; setup = :(using AbstractAlgebra) | ||
julia> m = matrix(ZZ, rand(-10:10, 1000, 1000)); | ||
julia> n = similar(m); | ||
julia> mul!(n, m, m); | ||
julia> Strassen.mul!(n, m, m); | ||
julia> Strassen.mul!(n, m, m; cutoff = 100); | ||
``` | ||
""" | ||
module Strassen | ||
using AbstractAlgebra | ||
import AbstractAlgebra:Perm | ||
|
||
const cutoff = 1500 | ||
|
||
function mul(A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T} | ||
C = zero_matrix(base_ring(A), nrows(A), ncols(B)) | ||
mul!(C, A, B; cutoff) | ||
return C | ||
end | ||
|
||
#scheduling copied from the nmod_mat_mul in Flint | ||
function mul!(C::MatElem{T}, A::MatElem{T}, B::MatElem{T}; cutoff::Int = cutoff) where {T} | ||
sA = size(A) | ||
sB = size(B) | ||
sC = size(C) | ||
a = sA[1] | ||
b = sA[2] | ||
c = sB[2] | ||
|
||
@assert a == sC[1] && b == sB[1] && c == sC[2] | ||
|
||
if (a <= cutoff || b <= cutoff || c <= cutoff) | ||
AbstractAlgebra.mul!(C, A, B) | ||
return | ||
end | ||
|
||
anr = div(a, 2) | ||
anc = div(b, 2) | ||
bnr = anc | ||
bnc = div(c, 2) | ||
|
||
#nmod_mat_window_init(A11, A, 0, 0, anr, anc); | ||
#nmod_mat_window_init(A12, A, 0, anc, anr, 2*anc); | ||
#nmod_mat_window_init(A21, A, anr, 0, 2*anr, anc); | ||
#nmod_mat_window_init(A22, A, anr, anc, 2*anr, 2*anc); | ||
A11 = view(A, 1:anr, 1:anc) | ||
A12 = view(A, 1:anr, anc+1:2*anc) | ||
A21 = view(A, anr+1:2*anr, 1:anc) | ||
A22 = view(A, anr+1:2*anr, anc+1:2*anc) | ||
|
||
#nmod_mat_window_init(B11, B, 0, 0, bnr, bnc); | ||
#nmod_mat_window_init(B12, B, 0, bnc, bnr, 2*bnc); | ||
#nmod_mat_window_init(B21, B, bnr, 0, 2*bnr, bnc); | ||
#nmod_mat_window_init(B22, B, bnr, bnc, 2*bnr, 2*bnc); | ||
B11 = view(B, 1:bnr, 1:bnc) | ||
B12 = view(B, 1:bnr, bnc+1:2*bnc) | ||
B21 = view(B, bnr+1:2*bnr, 1:bnc) | ||
B22 = view(B, bnr+1:2*bnr, bnc+1:2*bnc) | ||
|
||
#nmod_mat_window_init(C11, C, 0, 0, anr, bnc); | ||
#nmod_mat_window_init(C12, C, 0, bnc, anr, 2*bnc); | ||
#nmod_mat_window_init(C21, C, anr, 0, 2*anr, bnc); | ||
#nmod_mat_window_init(C22, C, anr, bnc, 2*anr, 2*bnc); | ||
C11 = view(C, 1:anr, 1:bnc) | ||
C12 = view(C, 1:anr, bnc+1:2*bnc) | ||
C21 = view(C, anr+1:2*anr, 1:bnc) | ||
C22 = view(C, anr+1:2*anr, bnc+1:2*bnc) | ||
|
||
#nmod_mat_init(X1, anr, FLINT_MAX(bnc, anc), A->mod.n); | ||
#nmod_mat_init(X2, anc, bnc, A->mod.n); | ||
|
||
#X1->c = anc; | ||
|
||
#= | ||
See Jean-Guillaume Dumas, Clement Pernet, Wei Zhou; "Memory | ||
efficient scheduling of Strassen-Winograd's matrix multiplication | ||
algorithm"; http://arxiv.org/pdf/0707.2347v3 for reference on the | ||
used operation scheduling. | ||
=# | ||
|
||
X1 = A11 - A21 | ||
X2 = B22 - B12 | ||
#nmod_mat_mul(C21, X1, X2); | ||
mul!(C21, X1, X2; cutoff) | ||
|
||
add!(X1, A21, A22); | ||
sub!(X2, B12, B11); | ||
#nmod_mat_mul(C22, X1, X2); | ||
mul!(C22, X1, X2; cutoff) | ||
|
||
sub!(X1, X1, A11); | ||
sub!(X2, B22, X2); | ||
#nmod_mat_mul(C12, X1, X2); | ||
mul!(C12, X1, X2; cutoff) | ||
|
||
sub!(X1, A12, X1); | ||
#nmod_mat_mul(C11, X1, B22); | ||
mul!(C11, X1, B22; cutoff) | ||
|
||
#X1->c = bnc; | ||
#nmod_mat_mul(X1, A11, B11); | ||
mul!(X1, A11, B11; cutoff) | ||
|
||
add!(C12, X1, C12); | ||
add!(C21, C12, C21); | ||
add!(C12, C12, C22); | ||
add!(C22, C21, C22); | ||
add!(C12, C12, C11); | ||
sub!(X2, X2, B21); | ||
#nmod_mat_mul(C11, A22, X2); | ||
mul!(C11, A22, X2; cutoff) | ||
|
||
sub!(C21, C21, C11); | ||
|
||
#nmod_mat_mul(C11, A12, B21); | ||
mul!(C11, A12, B21; cutoff) | ||
|
||
add!(C11, X1, C11); | ||
|
||
if c > 2*bnc #A by last col of B -> last col of C | ||
#nmod_mat_window_init(Bc, B, 0, 2*bnc, b, c); | ||
Bc = view(B, 1:b, 2*bnc+1:c) | ||
#nmod_mat_window_init(Cc, C, 0, 2*bnc, a, c); | ||
Cc = view(C, 1:a, 2*bnc+1:c) | ||
#nmod_mat_mul(Cc, A, Bc); | ||
AbstractAlgebra.mul!(Cc, A, Bc) | ||
end | ||
|
||
if a > 2*anr #last row of A by B -> last row of C | ||
#nmod_mat_window_init(Ar, A, 2*anr, 0, a, b); | ||
Ar = view(A, 2*anr+1:a, 1:b) | ||
#nmod_mat_window_init(Cr, C, 2*anr, 0, a, c); | ||
Cr = view(C, 2*anr+1:a, 1:c) | ||
#nmod_mat_mul(Cr, Ar, B); | ||
AbstractAlgebra.mul!(Cr, Ar, B) | ||
end | ||
|
||
if b > 2*anc # last col of A by last row of B -> C | ||
#nmod_mat_window_init(Ac, A, 0, 2*anc, 2*anr, b); | ||
Ac = view(A, 1:2*anr, 2*anc+1:b) | ||
#nmod_mat_window_init(Br, B, 2*bnr, 0, b, 2*bnc); | ||
Br = view(B, 2*bnr+1:b, 1:2*bnc) | ||
#nmod_mat_window_init(Cb, C, 0, 0, 2*anr, 2*bnc); | ||
Cb = view(C, 1:2*anr, 1:2*bnc) | ||
#nmod_mat_addmul(Cb, Cb, Ac, Br); | ||
AbstractAlgebra.mul!(Cb, Ac, Br, true) | ||
end | ||
end | ||
|
||
#solve_tril fast, recursive | ||
# A * X = U | ||
# B C Y V | ||
# => X = solve(A, U) | ||
# Y = solve(C, V - B*X) | ||
function solve_tril!(A::MatElem{T}, B::MatElem{T}, C::MatElem{T}, f::Int = 0; cutoff::Int = 2) where T | ||
if nrows(A) < cutoff || ncols(A) < cutoff | ||
return AbstractAlgebra.solve_tril!(A, B, C, f) | ||
end | ||
n = nrows(B) | ||
n2 = div(n, 2) | ||
B11 = view(B, 1:n2, 1:n2) | ||
B21 = view(B, n2+1:n, 1:n2) | ||
B22 = view(B, n2+1:n, n2+1:ncols(B)) | ||
X1 = view(A, 1:n2, 1:ncols(A)) | ||
X2 = view(A, n2+1:n, 1:ncols(A)) | ||
C1 = view(C, 1:n2, 1:ncols(A)) | ||
C2 = view(C, n2+1:n, 1:ncols(A)) | ||
solve_tril!(X1, B11, C1, f; cutoff) | ||
x = B21 * X1 # strassen... | ||
sub!(X2, C2, x) | ||
solve_tril!(X2, B22, X2, f; cutoff) | ||
end | ||
|
||
function apply!(A::MatElem, P::Perm{Int}; offset::Int = 0) | ||
n = length(P.d) | ||
Q = copy(inv(P).d) #the inv is experimentally verified with the other apply | ||
cnt = 0 | ||
start = 0 | ||
while cnt < n | ||
ptr = start = findnext(!iszero, Q, start +1)::Int | ||
next = Q[start] | ||
cnt += 1 | ||
while next != start | ||
swap_rows!(A, ptr, next) | ||
Q[ptr] = 0 | ||
next = Q[next] | ||
cnt += 1 | ||
end | ||
Q[ptr] = 0 | ||
end | ||
end | ||
|
||
function apply!(Q::Perm{Int}, P::Perm{Int}; offset::Int = 0) | ||
n = length(P.d) | ||
t = zeros(Int, n-offset) | ||
for i=1:n-offset | ||
t[i] = Q.d[P.d[i] + offset] | ||
end | ||
for i=1:n-offset | ||
Q.d[i + offset] = t[i] | ||
end | ||
end | ||
|
||
function lu!(P::Perm{Int}, A; cutoff::Int = 300) | ||
m = nrows(A) | ||
|
||
@assert length(P.d) == m | ||
n = ncols(A) | ||
if n < cutoff | ||
return AbstractAlgebra.lu!(P, A) | ||
end | ||
n1 = div(n, 2) | ||
for i=1:m | ||
P.d[i] = i | ||
end | ||
P1 = AbstractAlgebra.Perm(m) | ||
A0 = view(A, 1:m, 1:n1) | ||
r1 = lu!(P1, A0; cutoff) | ||
@assert r1 == n1 | ||
if r1 > 0 | ||
apply!(A, P1) | ||
apply!(P, P1) | ||
end | ||
|
||
A00 = view(A, 1:r1, 1:r1) | ||
A10 = view(A, r1+1:m, 1:r1) | ||
A01 = view(A, 1:r1, n1+1:n) | ||
A11 = view(A, r1+1:m, n1+1:n) | ||
|
||
if r1 > 0 | ||
#Note: A00 is a view of A0 thus a view of A | ||
# A0 is lu!, thus implicitly two triangular matrices giving the | ||
# lu decomosition. solve_tril! looks ONLY at the lower part of A00 | ||
solve_tril!(A01, A00, A01, 1) | ||
X = A10 * A01 | ||
sub!(A11, A11, X) | ||
end | ||
|
||
P1 = Perm(nrows(A11)) | ||
r2 = lu!(P1, A11) | ||
apply!(A, P1, offset = r1) | ||
apply!(P, P1, offset = r1) | ||
|
||
if (r1 != n1) | ||
for i=1:m-r1 | ||
for j=1:min(i, r2) | ||
A[r1+i-1, r1+j-1] = A[r1+i-1, n1+j-1] | ||
A[r1+i-1, n1+j-1] = 0 | ||
end | ||
end | ||
end | ||
return r1 + r2 | ||
end | ||
|
||
function solve_triu(T::MatElem, b::MatElem; cutoff::Int = cutoff) | ||
#b*inv(T), thus solves Tx = b for T upper triangular | ||
n = ncols(T) | ||
if n <= cutoff | ||
R = AbstractAlgebra.solve_triu(T, b) | ||
return R | ||
end | ||
|
||
n2 = div(n, 2) + n % 2 | ||
m = nrows(b) | ||
m2 = div(m, 2) + m % 2 | ||
|
||
U = view(b, 1:m2, 1:n2) | ||
V = view(b, 1:m2, n2+1:n) | ||
X = view(b, m2+1:m, 1:n2) | ||
Y = view(b, m2+1:m, n2+1:n) | ||
|
||
A = view(T, 1:n2, 1:n2) | ||
B = view(T, 1:n2, 1+n2:n) | ||
C = view(T, 1+n2:n, 1+n2:n) | ||
|
||
S = solve_triu(A, U; cutoff) | ||
R = solve_triu(A, X; cutoff) | ||
|
||
SS = mul(S, B; cutoff) | ||
sub!(SS, V, SS) | ||
SS = solve_triu(C, SS; cutoff) | ||
|
||
RR = mul(R, B; cutoff) | ||
sub!(RR, Y, RR) | ||
RR = solve_triu(C, RR; cutoff) | ||
|
||
return [S SS; R RR] | ||
end | ||
|
||
end # module |
Oops, something went wrong.