From b5f653368b585a94f15bbba79d21dd1635fa6220 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Fri, 15 Nov 2024 16:00:31 +0100 Subject: [PATCH] Precise bounds calculation for nfixed_mat_mul --- doc/source/nfloat.rst | 24 +- src/nfloat.h | 8 + src/nfloat/mat_mul.c | 89 +++++--- src/nfloat/nfixed.c | 217 +++++++++++++++++++ src/nfloat/test/main.c | 6 + src/nfloat/test/t-nfixed_mat_mul.c | 37 ++-- src/nfloat/test/t-nfixed_mat_mul_classical.c | 110 ++++++++++ src/nfloat/test/t-nfixed_mat_mul_strassen.c | 113 ++++++++++ src/nfloat/test/t-nfixed_mat_mul_waksman.c | 110 ++++++++++ 9 files changed, 667 insertions(+), 47 deletions(-) create mode 100644 src/nfloat/test/t-nfixed_mat_mul_classical.c create mode 100644 src/nfloat/test/t-nfixed_mat_mul_strassen.c create mode 100644 src/nfloat/test/t-nfixed_mat_mul_waksman.c diff --git a/doc/source/nfloat.rst b/doc/source/nfloat.rst index 5acdaf55da..9ac02883cd 100644 --- a/doc/source/nfloat.rst +++ b/doc/source/nfloat.rst @@ -467,7 +467,8 @@ intermediate results (including rounding errors) lie in `(-1,1)`. indicate the offset in number of limbs between consecutive entries and may be negative. -.. function:: void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) +.. function:: void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) + void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs) void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) @@ -475,3 +476,24 @@ intermediate results (including rounding errors) lie in `(-1,1)`. Matrix multiplication using various algorithms. The *strassen* variant takes a *cutoff* parameter specifying where to switch from basecase multiplication to Strassen multiplication. + The *classical_precise* version computes with one extra limb of + internal precision; this is only intended for testing purposes. + +.. function:: void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs) + void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs) + void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs) + void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs) + void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs) + + For the respective matrix multiplication algorithm, computes bounds + for a size `m \times n \times p` product at precision *nlimbs* + given entrywise bounds *A* and *B*. + + The *bound* output is set to a bound for the entries in all intermediate + variables of the computation. This should be < 1 to + ensure correctness. The *error* output is set to a bound for the + output error, measured in ulp. + The caller can assume that the computed bounds are nondecreasing + functions of *A* and *B*. + + For complex multiplication, the entrywise bounds are for `A+Bi` and `C+Di`. diff --git a/src/nfloat.h b/src/nfloat.h index d73c32bbc4..c60c32e8b5 100644 --- a/src/nfloat.h +++ b/src/nfloat.h @@ -590,11 +590,19 @@ void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ys void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs); void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs); void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs); void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs); void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs); +void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs); +void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs); +void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs); +void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs); +void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs); + + #ifdef __cplusplus } #endif diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index 6643015a91..81d7c11faa 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -9,6 +9,7 @@ (at your option) any later version. See . */ +#include "double_extras.h" #include "mpn_extras.h" #include "gr.h" #include "gr_vec.h" @@ -500,27 +501,39 @@ nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_e if (Adelta > 10 * prec || Bdelta > 10 * prec) return GR_UNABLE; - /* - To double check: for Waksman, - * The intermediate entries are bounded by 8n max(|A|,|B|)^2. - * The error, including error from converting - the input matrices, is bounded by 8n ulps. - */ + /* We must scale inputs to 2^(-pad_top) so that intermediate + entries satisfy |x| < 1. */ + { + double Abound, Bbound, bound, error; + + pad_top = 2; + Abound = Bbound = ldexp(1.0, -pad_top); + /* Option: improve accuracy by adding more trailing guard bits. */ + /* pad_bot = 3 + FLINT_BIT_COUNT(n); */ + pad_bot = 2; - pad_top = 3 + FLINT_BIT_COUNT(n); - pad_bot = 3 + FLINT_BIT_COUNT(n); + while (1) + { + Aexp = Amax + pad_top; + Bexp = Bmax + pad_top; + extra_bits = Adelta + Bdelta + pad_top + pad_bot; - extra_bits = Adelta + Bdelta + pad_top + pad_bot; + if (extra_bits >= max_extra_bits) + return GR_UNABLE; - if (extra_bits >= max_extra_bits) - return GR_UNABLE; + fbits = prec + extra_bits; + fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS; - Aexp = Amax + pad_top; - Bexp = Bmax + pad_top; - fbits = prec + extra_bits; - fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS; + _nfixed_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Bbound, fnlimbs); - return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx); + if (bound < 0.999) + return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx); + + pad_top++; + Abound *= 0.5; + Bbound *= 0.5; + } + } } static void @@ -1389,27 +1402,39 @@ nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slo if (Adelta > 10 * prec || Bdelta > 10 * prec) return GR_UNABLE; - /* - To double check: for Waksman, - * The intermediate entries are bounded by 8n max(|A|,|B|)^2. - * The error, including error from converting - the input matrices, is bounded by 8n ulps. - */ + /* We must scale inputs to 2^(-pad_top) so that intermediate + entries satisfy |x| < 1. */ + { + double Abound, Bbound, bound, error; + + pad_top = 2; + Abound = Bbound = ldexp(1.0, -pad_top); + /* Option: improve accuracy by adding more trailing guard bits. */ + /* pad_bot = 3 + FLINT_BIT_COUNT(n); */ + pad_bot = 2; - pad_top = 3 + FLINT_BIT_COUNT(n); - pad_bot = 3 + FLINT_BIT_COUNT(n); + while (1) + { + Aexp = Amax + pad_top; + Bexp = Bmax + pad_top; + extra_bits = Adelta + Bdelta + pad_top + pad_bot; - extra_bits = Adelta + Bdelta + pad_top + pad_bot; + if (extra_bits >= max_extra_bits) + return GR_UNABLE; - if (extra_bits >= max_extra_bits) - return GR_UNABLE; + fbits = prec + extra_bits; + fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS; - Aexp = Amax + pad_top; - Bexp = Bmax + pad_top; - fbits = prec + extra_bits; - fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS; + _nfixed_complex_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Abound, Bbound, Bbound, fnlimbs); - return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx); + if (bound < 0.999) + return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx); + + pad_top++; + Abound *= 0.5; + Bbound *= 0.5; + } + } } FLINT_FORCE_INLINE slong diff --git a/src/nfloat/nfixed.c b/src/nfloat/nfixed.c index 2c8eaef16e..597bc96a67 100644 --- a/src/nfloat/nfixed.c +++ b/src/nfloat/nfixed.c @@ -9,6 +9,7 @@ (at your option) any later version. See . */ +#include "double_extras.h" #include "mpn_extras.h" #include "gr.h" #include "gr_vec.h" @@ -758,6 +759,51 @@ _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, #undef C_ENTRY } +/* todo: optimize */ +/* A is (m x n), B is (n x p), C is (m x p) */ +void +_nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) +{ + slong i; + nn_ptr t, tA, tB, tC, u; + nn_srcptr s; + + t = flint_malloc(((m * n) + (n * p) + (m * p)) * (nlimbs + 2) * sizeof(ulong)); + tA = t; + tB = tA + (m * n) * (nlimbs + 2); + tC = tB + (n * p) * (nlimbs + 2); + + for (i = 0; i < m * n; i++) + { + s = A + i * (nlimbs + 1); + u = tA + i* (nlimbs + 2); + flint_mpn_copyi(u + 2, s + 1, nlimbs); + u[0] = s[0]; + u[1] = 0; + } + + for (i = 0; i < n * p; i++) + { + s = B + i * (nlimbs + 1); + u = tB + i * (nlimbs + 2); + flint_mpn_copyi(u + 2, s + 1, nlimbs); + u[0] = s[0]; + u[1] = 0; + } + + _nfixed_mat_mul_classical(tC, tA, tB, m, n, p, nlimbs + 1); + + for (i = 0; i < m * p; i++) + { + s = tC + i * (nlimbs + 2); + u = C + i * (nlimbs + 1); + flint_mpn_copyi(u + 1, s + 2, nlimbs); + u[0] = s[0]; + } + + flint_free(t); +} + /* compute c += (a1 + b1) * (a2 + b2) */ /* val0, val1, val2 are scratch space */ FLINT_FORCE_INLINE void @@ -1225,6 +1271,7 @@ _nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_ nn = FLINT_MIN(ar, ac); nn = FLINT_MIN(nn, bc); + /* Important: if the cutoff handling changes, _nfixed_mat_mul_bound_strassen must change too. */ if (cutoff < 0) cutoff = nfixed_mat_mul_strassen_cutoff(nn, ac & 1, nlimbs); else @@ -1392,6 +1439,7 @@ _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, s d = FLINT_MIN(m, n); d = FLINT_MIN(d, p); + /* Important: if the cutoff handling changes, _nfixed_mat_mul_bound must change too. */ if (d > 10) { cutoff = nfixed_mat_mul_strassen_cutoff(d, n & 1, nlimbs); @@ -1408,3 +1456,172 @@ _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, s else _nfixed_mat_mul_classical(C, A, B, m, n, p, nlimbs); } + +/* + Given an m x n x p matrix multiplication with inputs bounded + entrywise by A, B and nlimbs precision: + + - Set *bound* to a bound for the entries in all intermediate + variables of the computation. This should be < 1 to + ensure correctness. + - Set *error* to a bound for the output error, measured in ulp. + + The caller can assume that the bound is a nondecreasing function + of A and B. + + IMPORTANT: when changing the algorithm in _nfixed_mat_mul, this + must be changed to correspond. +*/ + +void +_nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs) +{ + double fixup = 1.0 + 1e-6; + + /* Error bound (in ulp) for naive scalar multiplication, and for dot product */ + double error_mul = (2 * nlimbs - 1); + double error_dot = n * error_mul; + + *bound = (n * A * B) * fixup; + *error = error_dot * fixup; +} + +void +_nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs) +{ + double fixup = 1.0 + 1e-6; + + /* Error bound (in ulp) for naive scalar multiplication */ + double error_mul = (2 * nlimbs - 1); + + *bound = FLINT_MAX(A + B, 6 * (n / 2) * (A + B) * (A + B) + A * B) * fixup; + *error = ((6 * (n / 2) + 1) * error_mul + 5) * fixup; +} + +void +_nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs) +{ + double fixup = 1.0 + 1e-6; + slong d; + + /* Error bound (in ulp) for naive scalar multiplication, and for dot product */ + double error_mul = (2 * nlimbs - 1); + double error_dot = n * error_mul; + + d = FLINT_MIN(m, n); + d = FLINT_MIN(d, p); + + if (cutoff < 0) + cutoff = nfixed_mat_mul_strassen_cutoff(d, n, nlimbs); + else + cutoff = FLINT_MAX(cutoff, 2); + + if (d < cutoff) + { + if (nfixed_mat_mul_use_waksman(d, nlimbs)) + _nfixed_mat_mul_bound_waksman(bound, error, m, n, p, A, B, nlimbs); + else + _nfixed_mat_mul_bound_classical(bound, error, m, n, p, A, B, nlimbs); + return; + } + + slong mm, nn, pp; + double bound_transformed_A, bound_transformed_B; + double bound_everything, bound_recursive, error_recursive; + double bound_recombination, error_recombination; + double ulp; + + /* Bound for entries of transformed block matrices */ + /* S1 = A22 + A12 <= 2A + S2 = A22 - A21 <= 2A + S3 = S2 + A12 <= 3A + S4 = S3 - A11 <= 3A, and similarly Ti for B */ + bound_transformed_A = 3.0 * A; + bound_transformed_B = 3.0 * B; + bound_everything = FLINT_MAX(bound_transformed_A, bound_transformed_B); + + /* Bound intermediate entries and errors for recursive multiplications. */ + mm = m / 2; + nn = n / 2; + pp = p / 2; + _nfixed_mat_mul_bound_strassen(&bound_recursive, &error_recursive, mm, nn, pp, bound_transformed_A, bound_transformed_B, cutoff, nlimbs); + bound_everything = FLINT_MAX(bound_everything, bound_recursive); + + /* Bound for recombinations. We don't use bound_recursive here, + because this can be a huge overestimate if the basecase + is nonclassical multiplication. Instead, we use the + theoretical bounds for the subproducts and add the + bound u = error_recursive (in the event the recursive + multiplications were not rounded down). */ + /* P1 = S1 T1 <= 4 nn A B + u + P2 = S2 T2 <= 4 nn A B + u + P3 = S3 T3 <= 9 nn A B + u + P4 = A11 B11 <= nn A B + u + P5 = A12 B21 <= nn A B + u + P6 = S4 B12 <= 3 nn A B + u + P7 = A21 T4 <= 3 nn A B + u + U1 = P3 + P5 <= 10 nn A B + 2u + U2 = P1 - U1 <= 14 nn A B + 3u + U3 = U1 - P2 <= 14 nn A B + 3u + C11 = P4 + P5 <= 2 nn A B + 2u + C12 = U3 - P6 <= 17 nn A B + 4u + C21 = U2 - P7 <= 17 nn A B + 4u + C22 = P2 + U2 <= 18 nn A B + 4u + */ + + ulp = ldexp(1.0, FLINT_MAX(-128, -nlimbs * FLINT_BITS)); + + error_recombination = 4 * error_recursive; + bound_recombination = 18 * nn * A * B + error_recombination * ulp; + + /* Bound for border corrections when m, n, and/or p is odd. + Todo: could be added conditionally. Assumes border + corrections use classical multiplication. */ + bound_recombination += A * B; + bound_recombination = FLINT_MAX(bound_recombination, n * A * B); + error_recombination += error_mul; + error_recombination = FLINT_MAX(error_recombination, error_dot); + + bound_everything = FLINT_MAX(bound_everything, bound_recombination); + + *bound = bound_everything * fixup; + *error = error_recombination * fixup; +} + +void +_nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs) +{ + slong d, cutoff; + + d = FLINT_MIN(m, n); + d = FLINT_MIN(d, p); + + if (d > 10) + { + cutoff = nfixed_mat_mul_strassen_cutoff(d, n & 1, nlimbs); + + if (n >= cutoff) + { + _nfixed_mat_mul_bound_strassen(bound, error, m, n, p, A, B, -1, nlimbs); + return; + } + } + + if (nfixed_mat_mul_use_waksman(d, nlimbs)) + _nfixed_mat_mul_bound_waksman(bound, error, m, n, p, A, B, nlimbs); + else + _nfixed_mat_mul_bound_classical(bound, error, m, n, p, A, B, nlimbs); +} + +/* Karatsuba formula */ +/* (A C - B D) + ((A + B)(C + D) - A C - B D) i */ +void +_nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs) +{ + double rbound, rerror, fixup = 1.0 + 1e-6; + + _nfixed_mat_mul_bound(&rbound, &rerror, m, n, p, A + B, C + D, nlimbs); + + (*bound) = FLINT_MAX(3.0 * rbound, FLINT_MAX(A + B, C + D)) * fixup; + (*error) = rerror * (3.0 * fixup); +} diff --git a/src/nfloat/test/main.c b/src/nfloat/test/main.c index 88f9a6c362..0054610e07 100644 --- a/src/nfloat/test/main.c +++ b/src/nfloat/test/main.c @@ -17,6 +17,9 @@ #include "t-mat_mul.c" #include "t-nfixed_dot.c" #include "t-nfixed_mat_mul.c" +#include "t-nfixed_mat_mul_classical.c" +#include "t-nfixed_mat_mul_strassen.c" +#include "t-nfixed_mat_mul_waksman.c" #include "t-nfloat.c" #include "t-nfloat_complex.c" @@ -30,6 +33,9 @@ test_struct tests[] = TEST_FUNCTION(mat_mul), TEST_FUNCTION(nfixed_dot), TEST_FUNCTION(nfixed_mat_mul), + TEST_FUNCTION(nfixed_mat_mul_classical), + TEST_FUNCTION(nfixed_mat_mul_strassen), + TEST_FUNCTION(nfixed_mat_mul_waksman), TEST_FUNCTION(nfloat), TEST_FUNCTION(nfloat_complex), }; diff --git a/src/nfloat/test/t-nfixed_mat_mul.c b/src/nfloat/test/t-nfixed_mat_mul.c index 4ef678057d..54788e4f4b 100644 --- a/src/nfloat/test/t-nfixed_mat_mul.c +++ b/src/nfloat/test/t-nfixed_mat_mul.c @@ -10,6 +10,7 @@ */ #include "test_helpers.h" +#include "double_extras.h" #include "fmpq.h" #include "arf.h" #include "gr_vec.h" @@ -21,23 +22,35 @@ TEST_FUNCTION_START(nfixed_mat_mul, state) slong iter, m, n, p, i, nlimbs; nn_ptr A, B, C, D, t; nn_ptr a; - int which; - slong MAXN = 12; + slong MAXN = 20; slong MINLIMBS = 2; slong MAXLIMBS = 12; - for (iter = 0; iter < 10000 * flint_test_multiplier(); iter++) + for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++) { - which = n_randint(state, 6); - m = 1 + n_randint(state, MAXN); n = 1 + n_randint(state, MAXN); p = 1 + n_randint(state, MAXN); nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1); - ulong maxerr = 2 * (2 * nlimbs - 1) * n; + ulong maxerr; + + int top; + double bound, error, classical_precise_error; + + top = 1; + while (1) + { + _nfixed_mat_mul_bound(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), nlimbs); + if (bound < 1.0) + break; + top++; + } + + classical_precise_error = 1.01; + maxerr = (ulong) (error + classical_precise_error + 1.0); A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong)); B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong)); @@ -50,7 +63,7 @@ TEST_FUNCTION_START(nfixed_mat_mul, state) a = A + i * (nlimbs + 1); a[0] = n_randint(state, 2); flint_mpn_rrandom(a + 1, state, nlimbs); - a[nlimbs] >>= 10; + a[nlimbs] >>= top; } for (i = 0; i < n * p; i++) @@ -58,7 +71,7 @@ TEST_FUNCTION_START(nfixed_mat_mul, state) a = B + i * (nlimbs + 1); a[0] = n_randint(state, 2); flint_mpn_rrandom(a + 1, state, nlimbs); - a[nlimbs] >>= 10; + a[nlimbs] >>= top; } for (i = 0; i < m * p; i++) @@ -72,12 +85,8 @@ TEST_FUNCTION_START(nfixed_mat_mul, state) flint_mpn_rrandom(a + 1, state, nlimbs); } - _nfixed_mat_mul_classical(C, A, B, m, n, p, nlimbs); - - if (which == 0) - _nfixed_mat_mul_waksman(D, A, B, m, n, p, nlimbs); - else - _nfixed_mat_mul_strassen(D, A, B, m, n, p, which, nlimbs); + _nfixed_mat_mul_classical_precise(C, A, B, m, n, p, nlimbs); + _nfixed_mat_mul(D, A, B, m, n, p, nlimbs); for (i = 0; i < m * p; i++) { diff --git a/src/nfloat/test/t-nfixed_mat_mul_classical.c b/src/nfloat/test/t-nfixed_mat_mul_classical.c new file mode 100644 index 0000000000..a6c86ab88c --- /dev/null +++ b/src/nfloat/test/t-nfixed_mat_mul_classical.c @@ -0,0 +1,110 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "test_helpers.h" +#include "double_extras.h" +#include "fmpq.h" +#include "arf.h" +#include "gr_vec.h" +#include "gr_special.h" +#include "nfloat.h" + +TEST_FUNCTION_START(nfixed_mat_mul_classical, state) +{ + slong iter, m, n, p, i, nlimbs; + nn_ptr A, B, C, D, t; + nn_ptr a; + + slong MAXN = 20; + slong MINLIMBS = 2; + slong MAXLIMBS = 12; + + for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++) + { + m = 1 + n_randint(state, MAXN); + n = 1 + n_randint(state, MAXN); + p = 1 + n_randint(state, MAXN); + + nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1); + + ulong maxerr; + + int top; + double bound, error, classical_precise_error; + + top = 1; + while (1) + { + _nfixed_mat_mul_bound_classical(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), nlimbs); + if (bound < 1.0) + break; + top++; + } + + classical_precise_error = 1.01; + maxerr = (ulong) (error + classical_precise_error + 1.0); + + A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong)); + B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong)); + C = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + D = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + t = flint_malloc((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m * n; i++) + { + a = A + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= top; + } + + for (i = 0; i < n * p; i++) + { + a = B + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= top; + } + + for (i = 0; i < m * p; i++) + { + a = C + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + + a = D + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + } + + _nfixed_mat_mul_classical_precise(C, A, B, m, n, p, nlimbs); + _nfixed_mat_mul_classical(D, A, B, m, n, p, nlimbs); + + for (i = 0; i < m * p; i++) + { + nfixed_sub(t, C + i * (nlimbs + 1), D + i * (nlimbs + 1), nlimbs); + + if (!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr) + { + TEST_FUNCTION_FAIL("nlimbs = %wd, m = %wd, n = %wd, p = %wd, top = %d\n\nt = %{ulong*}, maxerr = %wu\n\nA = %{ulong*}\n\nB = %{ulong*}\n\nC = %{ulong*}\n\nD = %{ulong*}\n\n", + nlimbs, m, n, p, top, + t, nlimbs + 1, maxerr, A, m * n * (nlimbs + 1), B, n * p * (nlimbs + 1), C, m * p * (nlimbs + 1), D, m * p * (nlimbs + 1)); + } + } + + flint_free(A); + flint_free(B); + flint_free(C); + flint_free(D); + } + + TEST_FUNCTION_END(state); +} \ No newline at end of file diff --git a/src/nfloat/test/t-nfixed_mat_mul_strassen.c b/src/nfloat/test/t-nfixed_mat_mul_strassen.c new file mode 100644 index 0000000000..3e6438db8b --- /dev/null +++ b/src/nfloat/test/t-nfixed_mat_mul_strassen.c @@ -0,0 +1,113 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "test_helpers.h" +#include "double_extras.h" +#include "fmpq.h" +#include "arf.h" +#include "gr_vec.h" +#include "gr_special.h" +#include "nfloat.h" + +TEST_FUNCTION_START(nfixed_mat_mul_strassen, state) +{ + slong iter, m, n, p, i, nlimbs; + nn_ptr A, B, C, D, t; + nn_ptr a; + slong cutoff; + + slong MAXN = 20; + slong MINLIMBS = 2; + slong MAXLIMBS = 12; + + for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++) + { + cutoff = n_randint(state, 6); + + m = 1 + n_randint(state, MAXN); + n = 1 + n_randint(state, MAXN); + p = 1 + n_randint(state, MAXN); + + nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1); + + ulong maxerr; + + int top; + double bound, error, classical_precise_error; + + top = 1; + while (1) + { + _nfixed_mat_mul_bound_strassen(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), cutoff, nlimbs); + if (bound < 1.0) + break; + top++; + } + + classical_precise_error = 1.01; + maxerr = (ulong) (error + classical_precise_error + 1.0); + + A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong)); + B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong)); + C = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + D = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + t = flint_malloc((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m * n; i++) + { + a = A + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= top; + } + + for (i = 0; i < n * p; i++) + { + a = B + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= top; + } + + for (i = 0; i < m * p; i++) + { + a = C + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + + a = D + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + } + + _nfixed_mat_mul_classical(C, A, B, m, n, p, nlimbs); + _nfixed_mat_mul_strassen(D, A, B, m, n, p, cutoff, nlimbs); + + for (i = 0; i < m * p; i++) + { + nfixed_sub(t, C + i * (nlimbs + 1), D + i * (nlimbs + 1), nlimbs); + + if (!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr) + { + TEST_FUNCTION_FAIL("nlimbs = %wd, m = %wd, n = %wd, p = %wd\n\nt = %{ulong*}, maxerr = %wu\n\nA = %{ulong*}\n\nB = %{ulong*}\n\nC = %{ulong*}\n\nD = %{ulong*}\n\n", + nlimbs, m, n, p, + t, nlimbs + 1, maxerr, A, m * n * (nlimbs + 1), B, n * p * (nlimbs + 1), C, m * p * (nlimbs + 1), D, m * p * (nlimbs + 1)); + } + } + + flint_free(A); + flint_free(B); + flint_free(C); + flint_free(D); + } + + TEST_FUNCTION_END(state); +} \ No newline at end of file diff --git a/src/nfloat/test/t-nfixed_mat_mul_waksman.c b/src/nfloat/test/t-nfixed_mat_mul_waksman.c new file mode 100644 index 0000000000..2a1495a43d --- /dev/null +++ b/src/nfloat/test/t-nfixed_mat_mul_waksman.c @@ -0,0 +1,110 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "test_helpers.h" +#include "double_extras.h" +#include "fmpq.h" +#include "arf.h" +#include "gr_vec.h" +#include "gr_special.h" +#include "nfloat.h" + +TEST_FUNCTION_START(nfixed_mat_mul_waksman, state) +{ + slong iter, m, n, p, i, nlimbs; + nn_ptr A, B, C, D, t; + nn_ptr a; + + slong MAXN = 20; + slong MINLIMBS = 2; + slong MAXLIMBS = 12; + + for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++) + { + m = 1 + n_randint(state, MAXN); + n = 1 + n_randint(state, MAXN); + p = 1 + n_randint(state, MAXN); + + nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1); + + ulong maxerr; + + int top; + double bound, error, classical_precise_error; + + top = 1; + while (1) + { + _nfixed_mat_mul_bound_waksman(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), nlimbs); + if (bound < 1.0) + break; + top++; + } + + classical_precise_error = 1.01; + maxerr = (ulong) (error + classical_precise_error + 1.0); + + A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong)); + B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong)); + C = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + D = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + t = flint_malloc((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m * n; i++) + { + a = A + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= top; + } + + for (i = 0; i < n * p; i++) + { + a = B + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= top; + } + + for (i = 0; i < m * p; i++) + { + a = C + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + + a = D + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + } + + _nfixed_mat_mul_classical_precise(C, A, B, m, n, p, nlimbs); + _nfixed_mat_mul_waksman(D, A, B, m, n, p, nlimbs); + + for (i = 0; i < m * p; i++) + { + nfixed_sub(t, C + i * (nlimbs + 1), D + i * (nlimbs + 1), nlimbs); + + if (!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr) + { + TEST_FUNCTION_FAIL("nlimbs = %wd, m = %wd, n = %wd, p = %wd\n\nt = %{ulong*}, maxerr = %wu\n\nA = %{ulong*}\n\nB = %{ulong*}\n\nC = %{ulong*}\n\nD = %{ulong*}\n\n", + nlimbs, m, n, p, + t, nlimbs + 1, maxerr, A, m * n * (nlimbs + 1), B, n * p * (nlimbs + 1), C, m * p * (nlimbs + 1), D, m * p * (nlimbs + 1)); + } + } + + flint_free(A); + flint_free(B); + flint_free(C); + flint_free(D); + } + + TEST_FUNCTION_END(state); +} \ No newline at end of file