Skip to content

Commit

Permalink
Waksman multiplication for gr_mat (#2109)
Browse files Browse the repository at this point in the history
  • Loading branch information
fredrik-johansson authored Nov 12, 2024
1 parent ed0f2b0 commit 3c28655
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/source/gr_mat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,17 @@ Arithmetic

.. function:: int gr_mat_mul_classical(gr_mat_t res, const gr_mat_t mat1, const gr_mat_t mat2, gr_ctx_t ctx)
int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int gr_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int gr_mat_mul_generic(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
int gr_mat_mul(gr_mat_t res, const gr_mat_t mat1, const gr_mat_t mat2, gr_ctx_t ctx)

Matrix multiplication. The default function can be overloaded by specific rings;
otherwise, it falls back to :func:`gr_mat_mul_generic` which currently
only performs classical multiplication.

The *Waksman* algorithm assumes a commutative base ring which supports
exact division by two.

.. function:: int gr_mat_sqr(gr_mat_t res, const gr_mat_t mat, gr_ctx_t ctx)

.. function:: int gr_mat_add_scalar(gr_mat_t res, const gr_mat_t mat, gr_srcptr c, gr_ctx_t ctx)
Expand Down
1 change: 1 addition & 0 deletions src/gr_mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ WARN_UNUSED_RESULT int gr_mat_div_scalar(gr_mat_t res, const gr_mat_t mat, gr_sr

WARN_UNUSED_RESULT int gr_mat_mul_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_mat_mul_generic(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);
WARN_UNUSED_RESULT int gr_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx);

Expand Down
140 changes: 140 additions & 0 deletions src/gr_mat/mul_waksman.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
Copyright (C) 2024 Éric Schost
Copyright (C) 2024 Vincent Neiger
Copyright (C) 2024 Fredrik Johansson
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "gr_mat.h"
#include "gr_vec.h"

/* todo: division by two should be divexact by two */
/* todo: avoid redundant additions 0 + ... */

int gr_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx)
{
int status = GR_SUCCESS;
slong m, n, p;
slong sz = ctx->sizeof_elem;

m = A->r;
n = A->c;
p = B->c;

if (m == 0 || n == 0 || p == 0)
{
return gr_mat_zero(C, ctx);
}

if (n != B->r || m != C->r || p != C->c)
{
return GR_DOMAIN;
}

if (A == C || B == C)
{
gr_mat_t T;
gr_mat_init(T, m, p, ctx);
status |= gr_mat_mul_waksman(T, A, B, ctx);
status |= gr_mat_swap_entrywise(T, C, ctx);
gr_mat_clear(T, ctx);
return status;
}

slong i, l, j, k;

gr_ptr tmp, Crow, Ccol, val0, val1, val2, crow;

GR_TMP_INIT_VEC(tmp, p + m + 4, ctx);

Crow = tmp;
Ccol = GR_ENTRY(Crow, p, sz);
val0 = GR_ENTRY(Ccol, m, sz);
val1 = GR_ENTRY(val0, 1, sz);
val2 = GR_ENTRY(val1, 1, sz);
crow = GR_ENTRY(val2, 1, sz);

slong np = n >> 1;

for (i = 0; i < m; i++)
status |= _gr_vec_zero(GR_MAT_ENTRY(C, i, 0, sz), p, ctx);

for (j = 1; j <= np; j++)
{
slong j2 = (j << 1) - 1;

for (k = 0; k < p; k++)
{
status |= gr_add(val1, GR_MAT_ENTRY(A, 0, j2 - 1, sz), GR_MAT_ENTRY(B, j2, k, sz), ctx);
status |= gr_add(val2, GR_MAT_ENTRY(A, 0, j2, sz), GR_MAT_ENTRY(B, j2 - 1, k, sz), ctx);
status |= gr_addmul(GR_MAT_ENTRY(C, 0, k, sz), val1, val2, ctx);

status |= gr_sub(val1, GR_MAT_ENTRY(A, 0, j2 - 1, sz), GR_MAT_ENTRY(B, j2, k, sz), ctx);
status |= gr_sub(val2, GR_MAT_ENTRY(A, 0, j2, sz), GR_MAT_ENTRY(B, j2 - 1, k, sz), ctx);
status |= gr_addmul(GR_ENTRY(Crow, k, sz), val1, val2, ctx);
}

for (l = 1; l < m; l++)
{
status |= gr_add(val1, GR_MAT_ENTRY(A, l, j2 - 1, sz), GR_MAT_ENTRY(B, j2, 0, sz), ctx);
status |= gr_add(val2, GR_MAT_ENTRY(A, l, j2, sz), GR_MAT_ENTRY(B, j2 - 1, 0, sz), ctx);
status |= gr_addmul(GR_MAT_ENTRY(C, l, 0, sz), val1, val2, ctx);

status |= gr_sub(val1, GR_MAT_ENTRY(A, l, j2 - 1, sz), GR_MAT_ENTRY(B, j2, 0, sz), ctx);
status |= gr_sub(val2, GR_MAT_ENTRY(A, l, j2, sz), GR_MAT_ENTRY(B, j2 - 1, 0, sz), ctx);
status |= gr_addmul(GR_ENTRY(Ccol, l, sz), val1, val2, ctx);
}

for (k = 1; k < p; k++)
{
for (l = 1; l < m; l++)
{
status |= gr_add(val1, GR_MAT_ENTRY(A, l, j2 - 1, sz), GR_MAT_ENTRY(B, j2, k, sz), ctx);
status |= gr_add(val2, GR_MAT_ENTRY(A, l, j2, sz), GR_MAT_ENTRY(B, j2 - 1, k, sz), ctx);
status |= gr_addmul(GR_MAT_ENTRY(C, l, k, sz), val1, val2, ctx);
}
}
}

for (l = 1; l < m; l++)
{
status |= gr_add(val1, GR_ENTRY(Ccol, l, sz), GR_MAT_ENTRY(C, l, 0, sz), ctx);
status |= gr_mul_2exp_si(GR_ENTRY(Ccol, l, sz), val1, -1, ctx);
status |= gr_sub(GR_MAT_ENTRY(C, l, 0, sz), GR_MAT_ENTRY(C, l, 0, sz), GR_ENTRY(Ccol, l, sz), ctx);
}

status |= gr_add(val1, Crow, GR_MAT_ENTRY(C, 0, 0, sz), ctx);
status |= gr_mul_2exp_si(val0, val1, -1, ctx);
status |= gr_sub(GR_MAT_ENTRY(C, 0, 0, sz), GR_MAT_ENTRY(C, 0, 0, sz), val0, ctx);

for (k = 1; k < p; k++)
{
status |= gr_add(crow, GR_ENTRY(Crow, k, sz), GR_MAT_ENTRY(C, 0, k, sz), ctx);
status |= gr_mul_2exp_si(val1, crow, -1, ctx);
status |= gr_sub(GR_MAT_ENTRY(C, 0, k, sz), GR_MAT_ENTRY(C, 0, k, sz), val1, ctx);
status |= gr_sub(crow, val1, val0, ctx);

for (l = 1; l < m; l++)
{
status |= gr_sub(val2, GR_MAT_ENTRY(C, l, k, sz), crow, ctx);
status |= gr_sub(GR_MAT_ENTRY(C, l, k, sz), val2, GR_ENTRY(Ccol, l, sz), ctx);
}
}

if ((n & 1) == 1)
for (l = 0; l < m; l++)
for (k = 0; k < p; k++)
status |= gr_addmul(GR_MAT_ENTRY(C, l, k, sz),
GR_MAT_ENTRY(A, l, n - 1, sz), GR_MAT_ENTRY(B, n - 1, k, sz), ctx);

GR_TMP_CLEAR_VEC(tmp, p + m + 4, ctx);

return status;
}

2 changes: 2 additions & 0 deletions src/gr_mat/test/main.c
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "t-lu_recursive.c"
#include "t-minpoly_field.c"
#include "t-mul_strassen.c"
#include "t-mul_waksman.c"
#include "t-nullspace.c"
#include "t-properties.c"
#include "t-randrank.c"
Expand Down Expand Up @@ -82,6 +83,7 @@ test_struct tests[] =
TEST_FUNCTION(gr_mat_lu_recursive),
TEST_FUNCTION(gr_mat_minpoly_field),
TEST_FUNCTION(gr_mat_mul_strassen),
TEST_FUNCTION(gr_mat_mul_waksman),
TEST_FUNCTION(gr_mat_nullspace),
TEST_FUNCTION(gr_mat_properties),
TEST_FUNCTION(gr_mat_randrank),
Expand Down
97 changes: 97 additions & 0 deletions src/gr_mat/test/t-mul_waksman.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright (C) 2022 Fredrik Johansson
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "test_helpers.h"
#include "ulong_extras.h"
#include "gr_mat.h"

TEST_FUNCTION_START(gr_mat_mul_waksman, state)
{
slong iter;

for (iter = 0; iter < 1000; iter++)
{
gr_ctx_t ctx;
gr_mat_t A, B, C, D;
slong a, b, c;
int status = GR_SUCCESS;
int can_div2;

if (n_randint(state, 2))
{
gr_ctx_init_fmpz(ctx);
can_div2 = 1;
}
else
{
ulong m = n_randtest_not_zero(state);
can_div2 = m % 2;
gr_ctx_init_nmod(ctx, m);
}

a = n_randint(state, 8);
b = n_randint(state, 2) ? a : n_randint(state, 8);
c = n_randint(state, 2) ? a : n_randint(state, 8);

gr_mat_init(A, a, b, ctx);
gr_mat_init(B, b, c, ctx);
gr_mat_init(C, a, c, ctx);
gr_mat_init(D, a, c, ctx);

status |= gr_mat_randtest(A, state, ctx);
status |= gr_mat_randtest(B, state, ctx);
status |= gr_mat_randtest(C, state, ctx);
status |= gr_mat_randtest(D, state, ctx);

if (a == b && b == c && n_randint(state, 2))
{
status |= gr_mat_set(B, A, ctx);
status |= gr_mat_mul_waksman(C, A, A, ctx);
}
else if (b == c && n_randint(state, 2))
{
status |= gr_mat_set(C, A, ctx);
status |= gr_mat_mul_waksman(C, C, B, ctx);
}
else if (a == b && n_randint(state, 2))
{
status |= gr_mat_set(C, B, ctx);
status |= gr_mat_mul_waksman(C, A, C, ctx);
}
else
{
status |= gr_mat_mul_waksman(C, A, B, ctx);
}

status |= gr_mat_mul_classical(D, A, B, ctx);

if ((can_div2 && (status != GR_SUCCESS || gr_mat_equal(C, D, ctx) != T_TRUE))
|| (status == GR_SUCCESS && gr_mat_equal(C, D, ctx) != T_TRUE))
{
flint_printf("FAIL:\n");
gr_ctx_println(ctx);
flint_printf("A:\n"); gr_mat_print(A, ctx); flint_printf("\n\n");
flint_printf("B:\n"); gr_mat_print(B, ctx); flint_printf("\n\n");
flint_printf("C:\n"); gr_mat_print(C, ctx); flint_printf("\n\n");
flint_printf("D:\n"); gr_mat_print(D, ctx); flint_printf("\n\n");
flint_abort();
}

gr_mat_clear(A, ctx);
gr_mat_clear(B, ctx);
gr_mat_clear(C, ctx);
gr_mat_clear(D, ctx);

gr_ctx_clear(ctx);
}

TEST_FUNCTION_END(state);
}

0 comments on commit 3c28655

Please sign in to comment.