From d986efd5490ce7ea96cde57e9e69d580920e79e6 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Mon, 19 Aug 2024 16:14:27 +0200 Subject: [PATCH 01/15] some more test code --- src/nfloat/test/t-mat_mul.c | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/nfloat/test/t-mat_mul.c b/src/nfloat/test/t-mat_mul.c index dad6ad625c..5e7d8b29ac 100644 --- a/src/nfloat/test/t-mat_mul.c +++ b/src/nfloat/test/t-mat_mul.c @@ -27,6 +27,24 @@ TEST_FUNCTION_START(mat_mul, state) slong iter; gr_ptr tol; + for (iter = 0; iter < 100 * flint_test_multiplier(); iter++) + { + prec = 64; + + nfloat_ctx_init(ctx, prec, 0); + + tol = gr_heap_init(ctx); + GR_MUST_SUCCEED(gr_one(tol, ctx)); + GR_MUST_SUCCEED(gr_mul_2exp_si(tol, tol, -prec + 2, ctx)); + + gr_mat_test_approx_mul_max_norm( + (gr_method_mat_binary_op) nfloat_mat_mul_fixed_classical, + tol, state, 10, 10, ctx); + + gr_heap_clear(tol, ctx); + gr_ctx_clear(ctx); + } + for (iter = 0; iter < 10 * flint_test_multiplier(); iter++) { if (n_randint(state, 5)) From 3beb2d7c34920b731cb94fa840d2319f139103be Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Mon, 19 Aug 2024 16:14:43 +0200 Subject: [PATCH 02/15] define add_sssssss... everywhere --- src/crt_helpers.h | 229 +----------------------- src/longlong.h | 75 ++++++++ src/longlong_asm_clang.h | 111 ++++++++++++ src/longlong_asm_gcc.h | 233 ++++++++++++++++++++++++- src/mpn_mod/mat_lu_classical_delayed.c | 8 - src/mpn_mod/poly_divrem_basecase.c | 10 -- src/nfloat/complex.c | 75 -------- src/nfloat/dot.c | 58 ------ src/nfloat/nfloat.c | 58 ------ 9 files changed, 421 insertions(+), 436 deletions(-) diff --git a/src/crt_helpers.h b/src/crt_helpers.h index 4d54633f09..2de185deb7 100644 --- a/src/crt_helpers.h +++ b/src/crt_helpers.h @@ -86,237 +86,14 @@ FLINT_FORCE_INLINE unsigned char _subborrow_ulong(unsigned char cf, ulong x, ulo #if defined(__GNUC__) && defined(__AVX2__) -#define add_sssssaaaaaaaaaa(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("addq %14,%q4\n\tadcq %12,%q3\n\tadcq %10,%q2\n\tadcq %8,%q1\n\tadcq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define add_ssssssaaaaaaaaaaaa(s5,s4,s3,s2,s1,s0, a5,a4,a3,a2,a1,a0, b5,b4,b3,b2,b1,b0) \ - __asm__ ("addq %17,%q5\nadcq %15,%q4\n\tadcq %13,%q3\n\tadcq %11,%q2\n\tadcq %9,%q1\n\tadcq %7,%q0" \ - : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a5)), "rme" ((ulong)(b5)), \ - "1" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "2" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "3" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "4" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "5" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define add_sssssssaaaaaaaaaaaaaa(s6,s5,s4,s3,s2,s1,s0, a6,a5,a4,a3,a2,a1,a0, b6,b5,b4,b3,b2,b1,b0) \ - __asm__ ("addq %20,%q6\nadcq %18,%q5\nadcq %16,%q4\n\tadcq %14,%q3\n\tadcq %12,%q2\n\tadcq %10,%q1\n\tadcq %8,%q0" \ - : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a6)), "rme" ((ulong)(b6)), \ - "1" ((ulong)(a5)), "rme" ((ulong)(b5)), \ - "2" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "3" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "4" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "5" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "6" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define 
add_ssssssssaaaaaaaaaaaaaaaa(s7,s6,s5,s4,s3,s2,s1,s0, a7,a6,a5,a4,a3,a2,a1,a0, b7,b6,b5,b4,b3,b2,b1,b0) \ - __asm__ ("addq %23,%q7\nadcq %21,%q6\nadcq %19,%q5\n\tadcq %17,%q4\n\tadcq %15,%q3\n\tadcq %13,%q2\n\tadcq %11,%q1\n\tadcq %9,%q0" \ - : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a7)), "rme" ((ulong)(b7)), \ - "1" ((ulong)(a6)), "rme" ((ulong)(b6)), \ - "2" ((ulong)(a5)), "rme" ((ulong)(b5)), \ - "3" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "4" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "5" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "6" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "7" ((ulong)(a0)), "rme" ((ulong)(b0))) - - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - __asm__ ("subq %11,%q3\n\tsbbq %9,%q2\n\tsbbq %7,%q1\n\tsbbq %5,%q0" \ - : "=r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "1" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "2" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "3" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_dddddmmmmmsssss(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("subq %14,%q4\n\tsbbq %12,%q3\n\tsbbq %10,%q2\n\tsbbq %8,%q1\n\tsbbq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_ddddddmmmmmmssssss(s5,s4,s3,s2,s1,s0, a5,a4,a3,a2,a1,a0, b5,b4,b3,b2,b1,b0) \ - __asm__ ("subq %17,%q5\nsbbq %15,%q4\n\tsbbq %13,%q3\n\tsbbq %11,%q2\n\tsbbq %9,%q1\n\tsbbq %7,%q0" \ - : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a5)), "rme" ((ulong)(b5)), \ - "1" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "2" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "3" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "4" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "5" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_dddddddmmmmmmmsssssss(s6,s5,s4,s3,s2,s1,s0, a6,a5,a4,a3,a2,a1,a0, b6,b5,b4,b3,b2,b1,b0) \ - __asm__ ("subq %20,%q6\nsbbq %18,%q5\nsbbq %16,%q4\n\tsbbq %14,%q3\n\tsbbq %12,%q2\n\tsbbq %10,%q1\n\tsbbq %8,%q0" \ - : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a6)), "rme" ((ulong)(b6)), \ - "1" ((ulong)(a5)), "rme" ((ulong)(b5)), \ - "2" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "3" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "4" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "5" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "6" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_ddddddddmmmmmmmmssssssss(s7,s6,s5,s4,s3,s2,s1,s0, a7,a6,a5,a4,a3,a2,a1,a0, b7,b6,b5,b4,b3,b2,b1,b0) \ - __asm__ ("subq %23,%q7\nsbbq %21,%q6\nsbbq %19,%q5\n\tsbbq %17,%q4\n\tsbbq %15,%q3\n\tsbbq %13,%q2\n\tsbbq %11,%q1\n\tsbbq %9,%q0" \ - : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a7)), "rme" ((ulong)(b7)), \ - "1" ((ulong)(a6)), "rme" ((ulong)(b6)), \ - "2" ((ulong)(a5)), "rme" ((ulong)(b5)), \ - "3" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "4" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "5" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "6" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "7" ((ulong)(a0)), "rme" ((ulong)(b0))) #elif defined(__GNUC__) && defined(__ARM_NEON) -#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - __asm__ ("adds 
%4,%9,%14\n\tadcs %3,%8,%13\n\tadcs %2,%7,%12\n\tadcs %1,%6,%11\n\tadc %0,%5,%10"\ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - -#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ - __asm__ ("adds %5,%11,%17\n\tadcs %4,%10,%16\n\tadcs %3,%9,%15\n\tadcs %2,%8,%14\n\tadcs %1,%7,%13\n\tadc %0,%6,%12"\ - : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - -#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ - __asm__ ("adds %6,%13,%20\n\tadcs %5,%12,%19\n\tadcs %4,%11,%18\n\tadcs %3,%10,%17\n\tadcs %2,%9,%16\n\tadcs %1,%8,%15\n\tadc %0,%7,%14"\ - : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - -#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ - __asm__ ("adds %7,%15,%23\n\tadcs %6,%14,%22\n\tadcs %5,%13,%21\n\tadcs %4,%12,%20\n\tadcs %3,%11,%19\n\tadcs %2,%10,%18\n\tadcs %1,%9,%17\n\tadc %0,%8,%16"\ - : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a7)), "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b7)), "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - __asm__ ("subs %3,%7,%11\n\tsbcs %2,%6,%10\n\tsbcs %1,%5,%9\n\tsbc %0,%4,%8"\ - : "=r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - -#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - __asm__ ("subs %4,%9,%14\n\tsbcs %3,%8,%13\n\tsbcs %2,%7,%12\n\tsbcs %1,%6,%11\n\tsbc %0,%5,%10"\ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - -#define sub_ddddddmmmmmmssssss(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ - __asm__ ("subs %5,%11,%17\n\tsbcs %4,%10,%16\n\tsbcs %3,%9,%15\n\tsbcs %2,%8,%14\n\tsbcs %1,%7,%13\n\tsbc %0,%6,%12"\ - : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" 
((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - -#define sub_dddddddmmmmmmmsssssss(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ - __asm__ ("subs %6,%13,%20\n\tsbcs %5,%12,%19\n\tsbcs %4,%11,%18\n\tsbcs %3,%10,%17\n\tsbcs %2,%9,%16\n\tsbcs %1,%8,%15\n\tsbc %0,%7,%14"\ - : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") - -#define sub_ddddddddmmmmmmmmssssssss(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ - __asm__ ("subs %7,%15,%23\n\tsbcs %6,%14,%22\n\tsbcs %5,%13,%21\n\tsbcs %4,%12,%20\n\tsbcs %3,%11,%19\n\tsbcs %2,%10,%18\n\tsbcs %1,%9,%17\n\tsbc %0,%8,%16"\ - : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "r" ((ulong)(a7)), "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ - "r" ((ulong)(b7)), "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ - : "cc") + #elif defined(_MSC_VER) && (defined(__AVX2__) || defined(_M_ARM64)) -#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t0 = 0; \ - add_ssssaaaaaaaa(__t0, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - add_ssaaaa(s4, s3, a4, a3, b4, b3); \ - add_ssaaaa(s4, s3, s4, s3, (ulong) 0, __t0); \ - } while (0) - -#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t1 = 0; \ - add_sssssaaaaaaaaaa(__t1, s3, s2, s1, s0, (ulong) 0, a3, a2, a1, a0, (ulong) 0, b3, b2, b1, b0);\ - add_ssaaaa(s5, s4, a5, a4, b5, b4); \ - add_ssaaaa(s5, s4, s5, s4, (ulong) 0, __t1); \ - } while (0) - -#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t2 = 0; \ - add_ssssssaaaaaaaaaaaa(__t2, s4, s3, s2, s1, s0, (ulong) 0, a4, a3, a2, a1, a0, (ulong) 0, b4, b3, b2, b1, b0); \ - add_ssaaaa(s6, s5, a6, a5, b6, b5); \ - add_ssaaaa(s6, s5, s6, s5, (ulong) 0, __t2); \ - } while (0) - -#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t3 = 0; \ - add_sssssssaaaaaaaaaaaaaa(__t3, s5, s4, s3, s2, s1, s0, (ulong) 0, a5, a4, a3, a2, a1, a0, (ulong) 0, b5, b4, b3, b2, b1, b0); \ - add_ssaaaa(s7, s6, a7, a6, b7, b6); \ - add_ssaaaa(s7, s6, s7, s6, (ulong) 0, __t3); \ - } while (0) - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - do { \ - ulong __t1, __u1; \ - sub_dddmmmsss(__t1, s1, s0, (ulong) 0, a1, a0, (ulong) 0, b1, b0); \ - sub_ddmmss(__u1, s2, (ulong) 0, a2, (ulong) 0, b2); \ - sub_ddmmss(s3, s2, (a3) - (b3), s2, -__u1, -__t1); \ - } while (0) - -#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t2, __u2; \ - sub_ddddmmmmssss(__t2, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - sub_ddmmss(__u2, s3, (ulong) 0, a3, (ulong) 0, b3); \ - 
sub_ddmmss(s4, s3, (a4) - (b4), s3, -__u2, -__t2); \ - } while (0) - -#define sub_ddddddmmmmmmssssss(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t3, __u3; \ - sub_dddddmmmmmsssss(__t3, s3, s2, s1, s0, (ulong) 0, a3, a2, a1, a0, (ulong) 0, b3, b2, b1, b0);\ - sub_ddmmss(__u3, s4, (ulong) 0, a4, (ulong) 0, b4); \ - sub_ddmmss(s5, s4, (a5) - (b5), s4, -__u3, -__t3); \ - } while (0) - -#define sub_dddddddmmmmmmmsssssss(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t4, __u4; \ - sub_ddddddmmmmmmssssss(__t4, s4, s3, s2, s1, s0, (ulong) 0, a4, a3, a2, a1, a0, (ulong) 0, b4, b3, b2, b1, b0); \ - sub_ddmmss(__u4, s5, (ulong) 0, a5, (ulong) 0, b5); \ - sub_ddmmss(s6, s5, (a6) - (b6), s5, -__u4, -__t4); \ - } while (0) - -#define sub_ddddddddmmmmmmmmssssssss(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t5, __u5; \ - sub_dddddddmmmmmmmsssssss(__t5, s5, s4, s3, s2, s1, s0, (ulong) 0, a5, a4, a3, a2, a1, a0, (ulong) 0, b5, b4, b3, b2, b1, b0); \ - sub_ddmmss(__u5, s6, (ulong) 0, a6, (ulong) 0, b6); \ - sub_ddmmss(s7, s6, (a7) - (b7), s6, -__u5, -__t5); \ - } while (0) + + #else # error crt_helpers.h requires AVX2 or Neon instructions diff --git a/src/longlong.h b/src/longlong.h index 5704a570b4..ea55421922 100644 --- a/src/longlong.h +++ b/src/longlong.h @@ -108,6 +108,7 @@ flint_bitcnt_t FLINT_BIT_COUNT(ulong x) /* Addition and subtraction */ #if !defined(add_ssaaaa) + # define add_ssaaaa(s1, s0, a1, a0, b1, b0) \ do { \ ulong __t0 = (a0); \ @@ -131,6 +132,38 @@ flint_bitcnt_t FLINT_BIT_COUNT(ulong x) add_ssaaaa(s3, s2, s3, s2, (ulong) 0, __u2); \ } while (0) +#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ + do { \ + ulong __t0 = 0; \ + add_ssssaaaaaaaa(__t0, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ + add_ssaaaa(s4, s3, a4, a3, b4, b3); \ + add_ssaaaa(s4, s3, s4, s3, (ulong) 0, __t0); \ + } while (0) + +#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t1 = 0; \ + add_sssssaaaaaaaaaa(__t1, s3, s2, s1, s0, (ulong) 0, a3, a2, a1, a0, (ulong) 0, b3, b2, b1, b0); \ + add_ssaaaa(s5, s4, a5, a4, b5, b4); \ + add_ssaaaa(s5, s4, s5, s4, (ulong) 0, __t1); \ + } while (0) + +#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t2 = 0; \ + add_ssssssaaaaaaaaaaaa(__t2, s4, s3, s2, s1, s0, (ulong) 0, a4, a3, a2, a1, a0, (ulong) 0, b4, b3, b2, b1, b0); \ + add_ssaaaa(s6, s5, a6, a5, b6, b5); \ + add_ssaaaa(s6, s5, s6, s5, (ulong) 0, __t2); \ + } while (0) + +#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t3 = 0; \ + add_sssssssaaaaaaaaaaaaaa(__t3, s5, s4, s3, s2, s1, s0, (ulong) 0, a5, a4, a3, a2, a1, a0, (ulong) 0, b5, b4, b3, b2, b1, b0); \ + add_ssaaaa(s7, s6, a7, a6, b7, b6); \ + add_ssaaaa(s7, s6, s7, s6, (ulong) 0, __t3); \ + } while (0) + # define sub_ddmmss(s1, s0, a1, a0, b1, b0) \ do { \ ulong __t0 = (a0); \ @@ -145,8 +178,50 @@ flint_bitcnt_t FLINT_BIT_COUNT(ulong x) sub_ddmmss(__t2, d1, (ulong) 0, m1, (ulong) 0, s1); \ sub_ddmmss(d2, d1, (m2) - (s2), d1, -__t2, -__t1); \ } while (0) + +#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ + do { \ + ulong __t1, __u1; \ + 
sub_dddmmmsss(__t1, s1, s0, (ulong) 0, a1, a0, (ulong) 0, b1, b0); \ + sub_ddmmss(__u1, s2, (ulong) 0, a2, (ulong) 0, b2); \ + sub_ddmmss(s3, s2, (a3) - (b3), s2, -__u1, -__t1); \ + } while (0) + +#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ + do { \ + ulong __t2, __u2; \ + sub_ddddmmmmssss(__t2, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ + sub_ddmmss(__u2, s3, (ulong) 0, a3, (ulong) 0, b3); \ + sub_ddmmss(s4, s3, (a4) - (b4), s3, -__u2, -__t2); \ + } while (0) + +#define sub_ddddddmmmmmmssssss(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t3, __u3; \ + sub_dddddmmmmmsssss(__t3, s3, s2, s1, s0, (ulong) 0, a3, a2, a1, a0, (ulong) 0, b3, b2, b1, b0);\ + sub_ddmmss(__u3, s4, (ulong) 0, a4, (ulong) 0, b4); \ + sub_ddmmss(s5, s4, (a5) - (b5), s4, -__u3, -__t3); \ + } while (0) + +#define sub_dddddddmmmmmmmsssssss(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t4, __u4; \ + sub_ddddddmmmmmmssssss(__t4, s4, s3, s2, s1, s0, (ulong) 0, a4, a3, a2, a1, a0, (ulong) 0, b4, b3, b2, b1, b0); \ + sub_ddmmss(__u4, s5, (ulong) 0, a5, (ulong) 0, b5); \ + sub_ddmmss(s6, s5, (a6) - (b6), s5, -__u4, -__t4); \ + } while (0) + +#define sub_ddddddddmmmmmmmmssssssss(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t5, __u5; \ + sub_dddddddmmmmmmmsssssss(__t5, s5, s4, s3, s2, s1, s0, (ulong) 0, a5, a4, a3, a2, a1, a0, (ulong) 0, b5, b4, b3, b2, b1, b0); \ + sub_ddmmss(__u5, s6, (ulong) 0, a6, (ulong) 0, b6); \ + sub_ddmmss(s7, s6, (a7) - (b7), s6, -__u5, -__t5); \ + } while (0) + #endif + #if !defined(MPN_INCR_U) # if FLINT_WANT_ASSERT # define MPN_INCR_U(ptr, size, incr) \ diff --git a/src/longlong_asm_clang.h b/src/longlong_asm_clang.h index f18beb7461..302f80e000 100644 --- a/src/longlong_asm_clang.h +++ b/src/longlong_asm_clang.h @@ -55,6 +55,57 @@ do \ (s3) = _FLINT_ADC(a3, b3, _carry, &_carry); \ } while (0) +#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_ADC(a0, b0, 0, &_carry); \ + (s1) = _FLINT_ADC(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_ADC(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_ADC(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_ADC(a4, b4, _carry, &_carry); \ +} while (0) + +#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_ADC(a0, b0, 0, &_carry); \ + (s1) = _FLINT_ADC(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_ADC(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_ADC(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_ADC(a4, b4, _carry, &_carry); \ + (s5) = _FLINT_ADC(a5, b5, _carry, &_carry); \ +} while (0) + +#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_ADC(a0, b0, 0, &_carry); \ + (s1) = _FLINT_ADC(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_ADC(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_ADC(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_ADC(a4, b4, _carry, &_carry); \ + (s5) = _FLINT_ADC(a5, b5, _carry, &_carry); \ + (s6) = _FLINT_ADC(a6, b6, _carry, &_carry); \ +} while (0) + +#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_ADC(a0, b0, 0, 
&_carry); \ + (s1) = _FLINT_ADC(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_ADC(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_ADC(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_ADC(a4, b4, _carry, &_carry); \ + (s5) = _FLINT_ADC(a5, b5, _carry, &_carry); \ + (s6) = _FLINT_ADC(a6, b6, _carry, &_carry); \ + (s7) = _FLINT_ADC(a7, b7, _carry, &_carry); \ +} while (0) + + #define sub_ddmmss(s1, s0, a1, a0, b1, b0) \ do \ { \ @@ -72,6 +123,66 @@ do \ (s2) = _FLINT_SBB(a2, b2, _carry, &_carry); \ } while (0) +#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_SBB(a0, b0, 0, &_carry); \ + (s1) = _FLINT_SBB(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_SBB(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_SBB(a3, b3, _carry, &_carry); \ +} while (0) + +#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_SBB(a0, b0, 0, &_carry); \ + (s1) = _FLINT_SBB(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_SBB(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_SBB(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_SBB(a4, b4, _carry, &_carry); \ +} while (0) + +#define sub_ddddddmmmmmmssssss(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_SBB(a0, b0, 0, &_carry); \ + (s1) = _FLINT_SBB(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_SBB(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_SBB(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_SBB(a4, b4, _carry, &_carry); \ + (s5) = _FLINT_SBB(a5, b5, _carry, &_carry); \ +} while (0) + +#define sub_dddddddmmmmmmmsssssss(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_SBB(a0, b0, 0, &_carry); \ + (s1) = _FLINT_SBB(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_SBB(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_SBB(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_SBB(a4, b4, _carry, &_carry); \ + (s5) = _FLINT_SBB(a5, b5, _carry, &_carry); \ + (s6) = _FLINT_SBB(a6, b6, _carry, &_carry); \ +} while (0) + +#define sub_ddddddddmmmmmmmmssssssss(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + ulong _carry; \ + (s0) = _FLINT_SBB(a0, b0, 0, &_carry); \ + (s1) = _FLINT_SBB(a1, b1, _carry, &_carry); \ + (s2) = _FLINT_SBB(a2, b2, _carry, &_carry); \ + (s3) = _FLINT_SBB(a3, b3, _carry, &_carry); \ + (s4) = _FLINT_SBB(a4, b4, _carry, &_carry); \ + (s5) = _FLINT_SBB(a5, b5, _carry, &_carry); \ + (s6) = _FLINT_SBB(a6, b6, _carry, &_carry); \ + (s7) = _FLINT_SBB(a7, b7, _carry, &_carry); \ +} while (0) + #define _mul_ppmm(big_type, small_type, r1, r0, u, v) \ do \ { \ diff --git a/src/longlong_asm_gcc.h b/src/longlong_asm_gcc.h index 15786cbdaf..5f4925775e 100644 --- a/src/longlong_asm_gcc.h +++ b/src/longlong_asm_gcc.h @@ -57,7 +57,7 @@ "2" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) # define add_ssssaaaaaaaa(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - __asm__(_ASM_ADD " %11,%" _ASM_PRE"3\n" \ + __asm__(_ASM_ADD " %11,%" _ASM_PRE "3\n" \ "\t" _ASM_ADC " %9,%" _ASM_PRE "2\n" \ "\t" _ASM_ADC " %7,%" _ASM_PRE "1\n" \ "\t" _ASM_ADC " %5,%" _ASM_PRE "0" \ @@ -67,6 +67,70 @@ "2" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ "3" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) +# define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ + __asm__(_ASM_ADD " %14,%" _ASM_PRE "4\n" \ + "\t" _ASM_ADC " %12,%" _ASM_PRE "3\n" \ + "\t" _ASM_ADC " %10,%" _ASM_PRE "2\n" 
\ + "\t" _ASM_ADC " %8,%" _ASM_PRE "1\n" \ + "\t" _ASM_ADC " %6,%" _ASM_PRE "0" \ + : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "0" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ + "1" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ + "2" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ + "3" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ + "4" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) + +# define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ + __asm__(_ASM_ADD " %17,%" _ASM_PRE "5\n" \ + "\t" _ASM_ADC " %15,%" _ASM_PRE "4\n" \ + "\t" _ASM_ADC " %13,%" _ASM_PRE "3\n" \ + "\t" _ASM_ADC " %11,%" _ASM_PRE "2\n" \ + "\t" _ASM_ADC " %9,%" _ASM_PRE "1\n" \ + "\t" _ASM_ADC " %7,%" _ASM_PRE "0" \ + : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "0" ((ulong)(a5)), _ASM_RME ((ulong)(b5)), \ + "1" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ + "2" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ + "3" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ + "4" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ + "5" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) + +# define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + __asm__(_ASM_ADD " %20,%" _ASM_PRE "6\n" \ + "\t" _ASM_ADC " %18,%" _ASM_PRE "5\n" \ + "\t" _ASM_ADC " %16,%" _ASM_PRE "4\n" \ + "\t" _ASM_ADC " %14,%" _ASM_PRE "3\n" \ + "\t" _ASM_ADC " %12,%" _ASM_PRE "2\n" \ + "\t" _ASM_ADC " %10,%" _ASM_PRE "1\n" \ + "\t" _ASM_ADC " %8,%" _ASM_PRE "0" \ + : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "0" ((ulong)(a6)), _ASM_RME ((ulong)(b6)), \ + "1" ((ulong)(a5)), _ASM_RME ((ulong)(b5)), \ + "2" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ + "3" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ + "4" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ + "5" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ + "6" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) + +# define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + __asm__(_ASM_ADD " %23,%" _ASM_PRE "7\n" \ + "\t" _ASM_ADC " %21,%" _ASM_PRE "6\n" \ + "\t" _ASM_ADC " %19,%" _ASM_PRE "5\n" \ + "\t" _ASM_ADC " %17,%" _ASM_PRE "4\n" \ + "\t" _ASM_ADC " %15,%" _ASM_PRE "3\n" \ + "\t" _ASM_ADC " %13,%" _ASM_PRE "2\n" \ + "\t" _ASM_ADC " %11,%" _ASM_PRE "1\n" \ + "\t" _ASM_ADC " %9,%" _ASM_PRE "0" \ + : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "0" ((ulong)(a7)), _ASM_RME ((ulong)(b7)), \ + "1" ((ulong)(a6)), _ASM_RME ((ulong)(b6)), \ + "2" ((ulong)(a5)), _ASM_RME ((ulong)(b5)), \ + "3" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ + "4" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ + "5" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ + "6" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ + "7" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) + # define sub_ddmmss(d1, d0, m1, m0, s1, s0) \ __asm__(_ASM_SUB " %5,%" _ASM_PRE "1\n" \ "\t" _ASM_SBB " %3,%" _ASM_PRE "0" \ @@ -83,6 +147,81 @@ "1" ((ulong)(m1)), _ASM_RME ((ulong)(s1)), \ "2" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) +# define sub_ddddmmmmssss(d3, d2, d1, d0, m3, m2, m1, m0, s3, s2, s1, s0) \ + __asm__(_ASM_SUB " %11,%" _ASM_PRE "3\n" \ + "\t" _ASM_SBB " %9,%" _ASM_PRE "2\n" \ + "\t" _ASM_SBB " %7,%" _ASM_PRE "1\n" \ + "\t" _ASM_SBB " %5,%" _ASM_PRE "0" \ + : "=r" (d3), "=&r" (d2), "=&r" (d1), "=&r" (d0) \ + : "0" ((ulong)(m3)), _ASM_RME ((ulong)(s3)), \ + "1" ((ulong)(m2)), _ASM_RME ((ulong)(s2)), \ + "2" ((ulong)(m1)), _ASM_RME 
((ulong)(s1)), \ + "3" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) + +# define sub_dddddmmmmmsssss(d4, d3, d2, d1, d0, m4, m3, m2, m1, m0, s4, s3, s2, s1, s0) \ + __asm__(_ASM_SUB " %14,%" _ASM_PRE "4\n" \ + "\t" _ASM_SBB " %12,%" _ASM_PRE "3\n" \ + "\t" _ASM_SBB " %10,%" _ASM_PRE "2\n" \ + "\t" _ASM_SBB " %8,%" _ASM_PRE "1\n" \ + "\t" _ASM_SBB " %6,%" _ASM_PRE "0" \ + : "=r" (d4), "=&r" (d3), "=&r" (d2), "=&r" (d1), "=&r" (d0) \ + : "0" ((ulong)(m4)), _ASM_RME ((ulong)(s4)), \ + "1" ((ulong)(m3)), _ASM_RME ((ulong)(s3)), \ + "2" ((ulong)(m2)), _ASM_RME ((ulong)(s2)), \ + "3" ((ulong)(m1)), _ASM_RME ((ulong)(s1)), \ + "4" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) + +# define sub_ddddddmmmmmmssssss(d5, d4, d3, d2, d1, d0, m5, m4, m3, m2, m1, m0, s5, s4, s3, s2, s1, s0) \ + __asm__(_ASM_SUB " %17,%" _ASM_PRE "5\n" \ + "\t" _ASM_SBB " %15,%" _ASM_PRE "4\n" \ + "\t" _ASM_SBB " %13,%" _ASM_PRE "3\n" \ + "\t" _ASM_SBB " %11,%" _ASM_PRE "2\n" \ + "\t" _ASM_SBB " %9,%" _ASM_PRE "1\n" \ + "\t" _ASM_SBB " %7,%" _ASM_PRE "0" \ + : "=r" (d5), "=&r" (d4), "=&r" (d3), "=&r" (d2), "=&r" (d1), "=&r" (d0) \ + : "0" ((ulong)(m5)), _ASM_RME ((ulong)(s5)), \ + "1" ((ulong)(m4)), _ASM_RME ((ulong)(s4)), \ + "2" ((ulong)(m3)), _ASM_RME ((ulong)(s3)), \ + "3" ((ulong)(m2)), _ASM_RME ((ulong)(s2)), \ + "4" ((ulong)(m1)), _ASM_RME ((ulong)(s1)), \ + "5" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) + +# define sub_dddddddmmmmmmmsssssss(d6, d5, d4, d3, d2, d1, d0, m6, m5, m4, m3, m2, m1, m0, s6, s5, s4, s3, s2, s1, s0) \ + __asm__(_ASM_SUB " %20,%" _ASM_PRE "6\n" \ + "\t" _ASM_SBB " %18,%" _ASM_PRE "5\n" \ + "\t" _ASM_SBB " %16,%" _ASM_PRE "4\n" \ + "\t" _ASM_SBB " %14,%" _ASM_PRE "3\n" \ + "\t" _ASM_SBB " %12,%" _ASM_PRE "2\n" \ + "\t" _ASM_SBB " %10,%" _ASM_PRE "1\n" \ + "\t" _ASM_SBB " %8,%" _ASM_PRE "0" \ + : "=r" (d6), "=&r" (d5), "=&r" (d4), "=&r" (d3), "=&r" (d2), "=&r" (d1), "=&r" (d0) \ + : "0" ((ulong)(m6)), _ASM_RME ((ulong)(s6)), \ + "1" ((ulong)(m5)), _ASM_RME ((ulong)(s5)), \ + "2" ((ulong)(m4)), _ASM_RME ((ulong)(s4)), \ + "3" ((ulong)(m3)), _ASM_RME ((ulong)(s3)), \ + "4" ((ulong)(m2)), _ASM_RME ((ulong)(s2)), \ + "5" ((ulong)(m1)), _ASM_RME ((ulong)(s1)), \ + "6" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) + +# define sub_ddddddddmmmmmmmmssssssss(d7, d6, d5, d4, d3, d2, d1, d0, m7, m6, m5, m4, m3, m2, m1, m0, s7, s6, s5, s4, s3, s2, s1, s0) \ + __asm__(_ASM_SUB " %23,%" _ASM_PRE "7\n" \ + "\t" _ASM_SBB " %21,%" _ASM_PRE "6\n" \ + "\t" _ASM_SBB " %19,%" _ASM_PRE "5\n" \ + "\t" _ASM_SBB " %17,%" _ASM_PRE "4\n" \ + "\t" _ASM_SBB " %15,%" _ASM_PRE "3\n" \ + "\t" _ASM_SBB " %13,%" _ASM_PRE "2\n" \ + "\t" _ASM_SBB " %11,%" _ASM_PRE "1\n" \ + "\t" _ASM_SBB " %9,%" _ASM_PRE "0" \ + : "=r" (d7), "=&r" (d6), "=&r" (d5), "=&r" (d4), "=&r" (d3), "=&r" (d2), "=&r" (d1), "=&r" (d0) \ + : "0" ((ulong)(m7)), _ASM_RME ((ulong)(s7)), \ + "1" ((ulong)(m6)), _ASM_RME ((ulong)(s6)), \ + "2" ((ulong)(m5)), _ASM_RME ((ulong)(s5)), \ + "3" ((ulong)(m4)), _ASM_RME ((ulong)(s4)), \ + "4" ((ulong)(m3)), _ASM_RME ((ulong)(s3)), \ + "5" ((ulong)(m2)), _ASM_RME ((ulong)(s2)), \ + "6" ((ulong)(m1)), _ASM_RME ((ulong)(s1)), \ + "7" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) + # if defined(__BMI2__) && defined(__amd64__) # define umul_ppmm(w1, w0, u, v) \ __asm__("mulx\t%3, %q0, %q1" \ @@ -129,6 +268,34 @@ "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ : "cc") +#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ + __asm__ ("adds %4,%9,%14\n\tadcs %3,%8,%13\n\tadcs %2,%7,%12\n\tadcs 
%1,%6,%11\n\tadc %0,%5,%10"\ + : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ + __asm__ ("adds %5,%11,%17\n\tadcs %4,%10,%16\n\tadcs %3,%9,%15\n\tadcs %2,%8,%14\n\tadcs %1,%7,%13\n\tadc %0,%6,%12"\ + : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + __asm__ ("adds %6,%13,%20\n\tadcs %5,%12,%19\n\tadcs %4,%11,%18\n\tadcs %3,%10,%17\n\tadcs %2,%9,%16\n\tadcs %1,%8,%15\n\tadc %0,%7,%14"\ + : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + __asm__ ("adds %7,%15,%23\n\tadcs %6,%14,%22\n\tadcs %5,%13,%21\n\tadcs %4,%12,%20\n\tadcs %3,%11,%19\n\tadcs %2,%10,%18\n\tadcs %1,%9,%17\n\tadc %0,%8,%16"\ + : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a7)), "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b7)), "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + # define sub_ddmmss(s1, s0, a1, a0, b1, b0) \ __asm__("subs %1,%3,%5\n" \ "\tsbc %0,%2,%4" \ @@ -146,6 +313,70 @@ "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ : "cc") +#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ + __asm__ ("adds %4,%9,%14\n\tadcs %3,%8,%13\n\tadcs %2,%7,%12\n\tadcs %1,%6,%11\n\tadc %0,%5,%10"\ + : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ + __asm__ ("adds %5,%11,%17\n\tadcs %4,%10,%16\n\tadcs %3,%9,%15\n\tadcs %2,%8,%14\n\tadcs %1,%7,%13\n\tadc %0,%6,%12"\ + : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + __asm__ ("adds 
%6,%13,%20\n\tadcs %5,%12,%19\n\tadcs %4,%11,%18\n\tadcs %3,%10,%17\n\tadcs %2,%9,%16\n\tadcs %1,%8,%15\n\tadc %0,%7,%14"\ + : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + __asm__ ("adds %7,%15,%23\n\tadcs %6,%14,%22\n\tadcs %5,%13,%21\n\tadcs %4,%12,%20\n\tadcs %3,%11,%19\n\tadcs %2,%10,%18\n\tadcs %1,%9,%17\n\tadc %0,%8,%16"\ + : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a7)), "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b7)), "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + + +#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ + __asm__ ("subs %3,%7,%11\n\tsbcs %2,%6,%10\n\tsbcs %1,%5,%9\n\tsbc %0,%4,%8"\ + : "=r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ + __asm__ ("subs %4,%9,%14\n\tsbcs %3,%8,%13\n\tsbcs %2,%7,%12\n\tsbcs %1,%6,%11\n\tsbc %0,%5,%10"\ + : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define sub_ddddddmmmmmmssssss(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ + __asm__ ("subs %5,%11,%17\n\tsbcs %4,%10,%16\n\tsbcs %3,%9,%15\n\tsbcs %2,%8,%14\n\tsbcs %1,%7,%13\n\tsbc %0,%6,%12"\ + : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define sub_dddddddmmmmmmmsssssss(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + __asm__ ("subs %6,%13,%20\n\tsbcs %5,%12,%19\n\tsbcs %4,%11,%18\n\tsbcs %3,%10,%17\n\tsbcs %2,%9,%16\n\tsbcs %1,%8,%15\n\tsbc %0,%7,%14"\ + : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + +#define sub_ddddddddmmmmmmmmssssssss(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + __asm__ ("subs %7,%15,%23\n\tsbcs %6,%14,%22\n\tsbcs %5,%13,%21\n\tsbcs %4,%12,%20\n\tsbcs %3,%11,%19\n\tsbcs %2,%10,%18\n\tsbcs %1,%9,%17\n\tsbc 
%0,%8,%16"\ + : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "r" ((ulong)(a7)), "r" ((ulong)(a6)), "r" ((ulong)(a5)), "r" ((ulong)(a4)), "r" ((ulong)(a3)), "r" ((ulong)(a2)), "r" ((ulong)(a1)), "r" ((ulong)(a0)), \ + "r" ((ulong)(b7)), "r" ((ulong)(b6)), "r" ((ulong)(b5)), "r" ((ulong)(b4)), "r" ((ulong)(b3)), "r" ((ulong)(b2)), "r" ((ulong)(b1)), "rI" ((ulong)(b0)) \ + : "cc") + # if defined(__arm__) # define umul_ppmm(xh, xl, a, b) \ __asm__("umull %0,%1,%2,%3" \ diff --git a/src/mpn_mod/mat_lu_classical_delayed.c b/src/mpn_mod/mat_lu_classical_delayed.c index eeb132b249..1a489003f4 100644 --- a/src/mpn_mod/mat_lu_classical_delayed.c +++ b/src/mpn_mod/mat_lu_classical_delayed.c @@ -11,12 +11,6 @@ #include "mpn_mod.h" -/* for wide add_ssss.... macros. todo; these ought to be provided - everywhere */ -#if FLINT_BITS == 64 && defined(__AVX2__) -#include "crt_helpers.h" -#endif - /* todo: optimize for when 2n rather than 2n+1 limbs suffice */ int mpn_mod_mat_lu_classical_delayed(slong * res_rank, slong * P, gr_mat_t A, const gr_mat_t A_in, int rank_check, gr_ctx_t ctx) @@ -143,7 +137,6 @@ mpn_mod_mat_lu_classical_delayed(slong * res_rank, slong * P, gr_mat_t A, const mpn_mod_mul(e, REDUCED(i, col), d, ctx); mpn_mod_neg(f, e, ctx); -#if defined(add_sssssaaaaaaaaaa) if (n == 2) { for (j = col + 1; j < ncols; j++) @@ -161,7 +154,6 @@ mpn_mod_mat_lu_classical_delayed(slong * res_rank, slong * P, gr_mat_t A, const REDUCED(i, rank - 1)[1] = e[1]; } else -#endif { if (col + 1 < ncols) { diff --git a/src/mpn_mod/poly_divrem_basecase.c b/src/mpn_mod/poly_divrem_basecase.c index f8e4300846..ce88958020 100644 --- a/src/mpn_mod/poly_divrem_basecase.c +++ b/src/mpn_mod/poly_divrem_basecase.c @@ -11,12 +11,6 @@ #include "mpn_mod.h" -/* for wide add_ssss.... macros. 
todo; these ought to be provided - everywhere */ -#if FLINT_BITS == 64 && defined(__AVX2__) -#include "crt_helpers.h" -#endif - static void mpn_mod_set_mpn2(nn_ptr res, nn_srcptr s, slong l, gr_ctx_t ctx) { @@ -103,7 +97,6 @@ static int _mpn_mod_poly_divrem_q1_preinv1(nn_ptr Q, nn_ptr R, mpn_mod_set(Q + nlimbs, q1, ctx); mpn_mod_neg(q1, q1, ctx); -#if defined(add_sssssaaaaaaaaaa) if (nlimbs == 2) { slong bits = 2 * MPN_MOD_CTX_MODULUS_BITS(ctx) + 1; @@ -141,7 +134,6 @@ static int _mpn_mod_poly_divrem_q1_preinv1(nn_ptr Q, nn_ptr R, } } else -#endif { for (i = 1; i < lenB - 1; i++) { @@ -240,7 +232,6 @@ _mpn_mod_poly_divrem_basecase_preinv1(nn_ptr Q, nn_ptr R, /* todo: consider writing all products to a temporary buffer and doing a single big mpn_add_n */ -#if defined(add_sssssaaaaaaaaaa) if (nlimbs == 2) { ulong t[4]; @@ -277,7 +268,6 @@ _mpn_mod_poly_divrem_basecase_preinv1(nn_ptr Q, nn_ptr R, } } else -#endif { if (slimbs == 2 * nlimbs + 1) { diff --git a/src/nfloat/complex.c b/src/nfloat/complex.c index e9fff75bbd..f925ed178e 100644 --- a/src/nfloat/complex.c +++ b/src/nfloat/complex.c @@ -37,81 +37,6 @@ _flint_mpn_signed_add_n(nn_ptr res, nn_srcptr x, int xsgnbit, nn_srcptr y, int y return xsgnbit; } -/* todo: define in longlong.h */ -#if FLINT_BITS == 64 && defined(__GNUC__) && defined(__AVX2__) - -#define add_sssssaaaaaaaaaa(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("addq %14,%q4\n\tadcq %12,%q3\n\tadcq %10,%q2\n\tadcq %8,%q1\n\tadcq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define add_ssssssaaaaaaaaaaaa(s5,s4,s3,s2,s1,s0, a5,a4,a3,a2,a1,a0, b5,b4,b3,b2,b1,b0) \ - __asm__ ("addq %17,%q5\nadcq %15,%q4\n\tadcq %13,%q3\n\tadcq %11,%q2\n\tadcq %9,%q1\n\tadcq %7,%q0" \ - : "=r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a5)), "rme" ((ulong)(b5)), \ - "1" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "2" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "3" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "4" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "5" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - __asm__ ("subq %11,%q3\n\tsbbq %9,%q2\n\tsbbq %7,%q1\n\tsbbq %5,%q0" \ - : "=r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "1" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "2" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "3" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_dddddmmmmmsssss(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("subq %14,%q4\n\tsbbq %12,%q3\n\tsbbq %10,%q2\n\tsbbq %8,%q1\n\tsbbq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) -#else - -#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t0 = 0; \ - add_ssssaaaaaaaa(__t0, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - add_ssaaaa(s4, s3, a4, a3, b4, b3); \ - add_ssaaaa(s4, s3, s4, s3, (ulong) 0, __t0); \ - } while (0) - -#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, 
b3, b2, b1, b0) \ - do { \ - ulong __t1 = 0; \ - add_sssssaaaaaaaaaa(__t1, s3, s2, s1, s0, (ulong) 0, a3, a2, a1, a0, (ulong) 0, b3, b2, b1, b0);\ - add_ssaaaa(s5, s4, a5, a4, b5, b4); \ - add_ssaaaa(s5, s4, s5, s4, (ulong) 0, __t1); \ - } while (0) - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - do { \ - ulong __t1, __u1; \ - sub_dddmmmsss(__t1, s1, s0, (ulong) 0, a1, a0, (ulong) 0, b1, b0); \ - sub_ddmmss(__u1, s2, (ulong) 0, a2, (ulong) 0, b2); \ - sub_ddmmss(s3, s2, (a3) - (b3), s2, -__u1, -__t1); \ - } while (0) - -#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t2, __u2; \ - sub_ddddmmmmssss(__t2, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - sub_ddmmss(__u2, s3, (ulong) 0, a3, (ulong) 0, b3); \ - sub_ddmmss(s4, s3, (a4) - (b4), s3, -__u2, -__t2); \ - } while (0) - -#endif - - int nfloat_complex_get_acf(acf_t res, nfloat_complex_srcptr x, gr_ctx_t ctx) { diff --git a/src/nfloat/dot.c b/src/nfloat/dot.c index 4d2a0c0970..db25936c7b 100644 --- a/src/nfloat/dot.c +++ b/src/nfloat/dot.c @@ -35,64 +35,6 @@ (r1) = __r1; (r2) = __r2; (r3) = __r3; \ } while (0) -/* todo: define in longlong.h */ -#if FLINT_BITS == 64 && defined(__GNUC__) && defined(__AVX2__) - -#define add_sssssaaaaaaaaaa(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("addq %14,%q4\n\tadcq %12,%q3\n\tadcq %10,%q2\n\tadcq %8,%q1\n\tadcq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) - - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - __asm__ ("subq %11,%q3\n\tsbbq %9,%q2\n\tsbbq %7,%q1\n\tsbbq %5,%q0" \ - : "=r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "1" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "2" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "3" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_dddddmmmmmsssss(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("subq %14,%q4\n\tsbbq %12,%q3\n\tsbbq %10,%q2\n\tsbbq %8,%q1\n\tsbbq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) -#else - -#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t0 = 0; \ - add_ssssaaaaaaaa(__t0, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - add_ssaaaa(s4, s3, a4, a3, b4, b3); \ - add_ssaaaa(s4, s3, s4, s3, (ulong) 0, __t0); \ - } while (0) - - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - do { \ - ulong __t1, __u1; \ - sub_dddmmmsss(__t1, s1, s0, (ulong) 0, a1, a0, (ulong) 0, b1, b0); \ - sub_ddmmss(__u1, s2, (ulong) 0, a2, (ulong) 0, b2); \ - sub_ddmmss(s3, s2, (a3) - (b3), s2, -__u1, -__t1); \ - } while (0) - -#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t2, __u2; \ - sub_ddddmmmmssss(__t2, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - sub_ddmmss(__u2, s3, (ulong) 0, a3, (ulong) 0, b3); \ - sub_ddmmss(s4, s3, (a4) - (b4), s3, -__u2, -__t2); \ - } while (0) - -#endif - void 
_nfloat_vec_dot_set_initial(ulong * s, slong sexp, nfloat_srcptr x, int subtract, slong nlimbs) { diff --git a/src/nfloat/nfloat.c b/src/nfloat/nfloat.c index 10cb533916..4a47995326 100644 --- a/src/nfloat/nfloat.c +++ b/src/nfloat/nfloat.c @@ -17,64 +17,6 @@ #include "gr_generic.h" #include "gr_special.h" -/* todo: define in longlong.h */ -#if FLINT_BITS == 64 && defined(__GNUC__) && defined(__AVX2__) - -#define add_sssssaaaaaaaaaa(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("addq %14,%q4\n\tadcq %12,%q3\n\tadcq %10,%q2\n\tadcq %8,%q1\n\tadcq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) - - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - __asm__ ("subq %11,%q3\n\tsbbq %9,%q2\n\tsbbq %7,%q1\n\tsbbq %5,%q0" \ - : "=r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "1" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "2" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "3" ((ulong)(a0)), "rme" ((ulong)(b0))) - -#define sub_dddddmmmmmsssss(s4,s3,s2,s1,s0, a4,a3,a2,a1,a0, b4,b3,b2,b1,b0) \ - __asm__ ("subq %14,%q4\n\tsbbq %12,%q3\n\tsbbq %10,%q2\n\tsbbq %8,%q1\n\tsbbq %6,%q0" \ - : "=r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a4)), "rme" ((ulong)(b4)), \ - "1" ((ulong)(a3)), "rme" ((ulong)(b3)), \ - "2" ((ulong)(a2)), "rme" ((ulong)(b2)), \ - "3" ((ulong)(a1)), "rme" ((ulong)(b1)), \ - "4" ((ulong)(a0)), "rme" ((ulong)(b0))) -#else - -#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t0 = 0; \ - add_ssssaaaaaaaa(__t0, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - add_ssaaaa(s4, s3, a4, a3, b4, b3); \ - add_ssaaaa(s4, s3, s4, s3, (ulong) 0, __t0); \ - } while (0) - - -#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ - do { \ - ulong __t1, __u1; \ - sub_dddmmmsss(__t1, s1, s0, (ulong) 0, a1, a0, (ulong) 0, b1, b0); \ - sub_ddmmss(__u1, s2, (ulong) 0, a2, (ulong) 0, b2); \ - sub_ddmmss(s3, s2, (a3) - (b3), s2, -__u1, -__t1); \ - } while (0) - -#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ - do { \ - ulong __t2, __u2; \ - sub_ddddmmmmssss(__t2, s2, s1, s0, (ulong) 0, a2, a1, a0, (ulong) 0, b2, b1, b0); \ - sub_ddmmss(__u2, s3, (ulong) 0, a3, (ulong) 0, b3); \ - sub_ddmmss(s4, s3, (a4) - (b4), s3, -__u2, -__t2); \ - } while (0) - -#endif - int nfloat_write(gr_stream_t out, nfloat_srcptr x, gr_ctx_t ctx) { From 80df284e3abcc0fd2ea55637d901913157a15244 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Mon, 19 Aug 2024 16:51:35 +0200 Subject: [PATCH 03/15] msc fix --- src/longlong_msc_x86.h | 112 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 112 insertions(+) diff --git a/src/longlong_msc_x86.h b/src/longlong_msc_x86.h index 79aa353b64..f23e57f5da 100644 --- a/src/longlong_msc_x86.h +++ b/src/longlong_msc_x86.h @@ -95,6 +95,57 @@ do \ _FLINT_ADC(_carry, a3, b3, &(s3)); \ } while (0) +#define add_sssssaaaaaaaaaa(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_ADC(0, a0, b0, &(s0)); \ + _carry = _FLINT_ADC(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_ADC(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_ADC(_carry, a3, b3, &(s3)); \ + 
_FLINT_ADC(_carry, a4, b4, &(s4)); \ +} while (0) + +#define add_ssssssaaaaaaaaaaaa(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_ADC(0, a0, b0, &(s0)); \ + _carry = _FLINT_ADC(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_ADC(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_ADC(_carry, a3, b3, &(s3)); \ + _carry = _FLINT_ADC(_carry, a4, b4, &(s4)); \ + _FLINT_ADC(_carry, a5, b5, &(s5)); \ +} while (0) + +#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_ADC(0, a0, b0, &(s0)); \ + _carry = _FLINT_ADC(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_ADC(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_ADC(_carry, a3, b3, &(s3)); \ + _carry = _FLINT_ADC(_carry, a4, b4, &(s4)); \ + _carry = _FLINT_ADC(_carry, a5, b5, &(s5)); \ + _FLINT_ADC(_carry, a6, b6, &(s6)); \ +} while (0) + +#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_ADC(0, a0, b0, &(s0)); \ + _carry = _FLINT_ADC(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_ADC(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_ADC(_carry, a3, b3, &(s3)); \ + _carry = _FLINT_ADC(_carry, a4, b4, &(s4)); \ + _carry = _FLINT_ADC(_carry, a5, b5, &(s5)); \ + _carry = _FLINT_ADC(_carry, a6, b6, &(s6)); \ + _FLINT_ADC(_carry, a7, b7, &(s7)); \ +} while (0) + + #define sub_ddmmss(s1, s0, a1, a0, b1, b0) \ do \ { \ @@ -112,6 +163,67 @@ do \ _FLINT_SBB(_carry, a2, b2, &(s2)); \ } while (0) +#define sub_ddddmmmmssss(s3, s2, s1, s0, a3, a2, a1, a0, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_SBB(0, a0, b0, &(s0)); \ + _carry = _FLINT_SBB(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_SBB(_carry, a2, b2, &(s2)); \ + _FLINT_SBB(_carry, a3, b3, &(s3)); \ +} while (0) + +#define sub_dddddmmmmmsssss(s4, s3, s2, s1, s0, a4, a3, a2, a1, a0, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_SBB(0, a0, b0, &(s0)); \ + _carry = _FLINT_SBB(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_SBB(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_SBB(_carry, a3, b3, &(s3)); \ + _FLINT_SBB(_carry, a4, b4, &(s4)); \ +} while (0) + +#define sub_ddddddmmmmmmssssss(s5, s4, s3, s2, s1, s0, a5, a4, a3, a2, a1, a0, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_SBB(0, a0, b0, &(s0)); \ + _carry = _FLINT_SBB(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_SBB(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_SBB(_carry, a3, b3, &(s3)); \ + _carry = _FLINT_SBB(_carry, a4, b4, &(s4)); \ + _FLINT_SBB(_carry, a5, b5, &(s5)); \ +} while (0) + +#define sub_dddddddmmmmmmmsssssss(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_SBB(0, a0, b0, &(s0)); \ + _carry = _FLINT_SBB(_carry, a1, b1, &(s1)); \ + _carry = _FLINT_SBB(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_SBB(_carry, a3, b3, &(s3)); \ + _carry = _FLINT_SBB(_carry, a4, b4, &(s4)); \ + _carry = _FLINT_SBB(_carry, a5, b5, &(s5)); \ + _FLINT_SBB(_carry, a6, b6, &(s6)); \ +} while (0) + + +#define sub_ddddddddmmmmmmmmssssssss(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ +do \ +{ \ + unsigned char _carry; \ + _carry = _FLINT_SBB(0, a0, b0, &(s0)); \ + _carry = _FLINT_SBB(_carry, a1, b1, &(s1)); \ + _carry = 
_FLINT_SBB(_carry, a2, b2, &(s2)); \ + _carry = _FLINT_SBB(_carry, a3, b3, &(s3)); \ + _carry = _FLINT_SBB(_carry, a4, b4, &(s4)); \ + _carry = _FLINT_SBB(_carry, a5, b5, &(s5)); \ + _carry = _FLINT_SBB(_carry, a6, b6, &(s6)); \ + _FLINT_SBB(_carry, a7, b7, &(s7)); \ +} while (0) + /* Division */ #define udiv_qrnnd(q, r, n1, n0, dx) do { (q) = _FLINT_DIV(n1, n0, dx, &(r)); } while (0) #define sdiv_qrnnd(q, r, n1, n0, dx) do { (q) = _FLINT_IDIV(n1, n0, dx, &(r)); } while (0) From 20b28e8b84ac2db147f5eddad40bdfaf833a3fa0 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Mon, 19 Aug 2024 23:23:56 +0200 Subject: [PATCH 04/15] test code --- src/mpn_extras.h | 41 +++++++++++++++++ src/test/main.c | 2 - src/test/t-add_sssaaaaaa.c | 86 +++++++++++++++++++++++++++-------- src/test/t-add_ssssaaaaaaaa.c | 66 --------------------------- src/test/t-sub_dddmmmsss.c | 85 ++++++++++++++++++++++++++-------- 5 files changed, 176 insertions(+), 104 deletions(-) delete mode 100644 src/test/t-add_ssssaaaaaaaa.c diff --git a/src/mpn_extras.h b/src/mpn_extras.h index fd9b346df9..90fc8e6436 100644 --- a/src/mpn_extras.h +++ b/src/mpn_extras.h @@ -190,6 +190,47 @@ flint_mpn_signed_sub_n(mp_ptr res, mp_srcptr x, mp_srcptr y, mp_size_t n) } } +/* add without carry in or carry out */ +#define NN_ADD_2(r, u, v) add_ssaaaa((r)[1], (r)[0], (u)[1], (u)[0], (v)[1], (v)[0]) +#define NN_ADD_3(r, u, v) add_sssaaaaaa((r)[2], (r)[1], (r)[0], (u)[2], (u)[1], (u)[0], (v)[2], (v)[1], (v)[0]) +#define NN_ADD_4(r, u, v) add_ssssaaaaaaaa((r)[3], (r)[2], (r)[1], (r)[0], (u)[3], (u)[2], (u)[1], (u)[0], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_ADD_5(r, u, v) add_sssssaaaaaaaaaa((r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[4], (u)[3], (u)[2], (u)[1], (u)[0], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_ADD_6(r, u, v) add_ssssssaaaaaaaaaaaa((r)[5], (r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[5], (u)[4], (u)[3], (u)[2], (u)[1], (u)[0], (v)[5], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_ADD_7(r, u, v) add_sssssssaaaaaaaaaaaaaa((r)[6], (r)[5], (r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[6], (u)[5], (u)[4], (u)[3], (u)[2], (u)[1], (u)[0], (v)[6], (v)[5], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_ADD_8(r, u, v) add_ssssssssaaaaaaaaaaaaaaaa((r)[7], (r)[6], (r)[5], (r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[7], (u)[6], (u)[5], (u)[4], (u)[3], (u)[2], (u)[1], (u)[0], (v)[7], (v)[6], (v)[5], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0]) + +#define NN_SUB_2(r, u, v) sub_ddmmss((r)[1], (r)[0], (u)[1], (u)[0], (v)[1], (v)[0]) +#define NN_SUB_3(r, u, v) sub_dddmmmsss((r)[2], (r)[1], (r)[0], (u)[2], (u)[1], (u)[0], (v)[2], (v)[1], (v)[0]) +#define NN_SUB_4(r, u, v) sub_ddddmmmmssss((r)[3], (r)[2], (r)[1], (r)[0], (u)[3], (u)[2], (u)[1], (u)[0], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_SUB_5(r, u, v) sub_dddddmmmmmsssss((r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[4], (u)[3], (u)[2], (u)[1], (u)[0], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_SUB_6(r, u, v) sub_ddddddmmmmmmssssss((r)[5], (r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[5], (u)[4], (u)[3], (u)[2], (u)[1], (u)[0], (v)[5], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_SUB_7(r, u, v) sub_dddddddmmmmmmmsssssss((r)[6], (r)[5], (r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[6], (u)[5], (u)[4], (u)[3], (u)[2], (u)[1], (u)[0], (v)[6], (v)[5], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0]) +#define NN_SUB_8(r, u, v) sub_ddddddddmmmmmmmmssssssss((r)[7], (r)[6], (r)[5], (r)[4], (r)[3], (r)[2], (r)[1], (r)[0], (u)[7], (u)[6], (u)[5], (u)[4], (u)[3], (u)[2], (u)[1], 
(u)[0], (v)[7], (v)[6], (v)[5], (v)[4], (v)[3], (v)[2], (v)[1], (v)[0])
+
+#define DEF_SIGNED_SUB(n) \
+FLINT_FORCE_INLINE int \
+flint_mpn_signed_sub_ ## n(mp_ptr res, mp_srcptr x, mp_srcptr y) \
+{ \
+ if (mpn_cmp(x, y, n) >= 0) \
+ { \
+ NN_SUB_ ## n(res, x, y); \
+ return 0; \
+ } \
+ else \
+ { \
+ NN_SUB_ ## n(res, y, x); \
+ return 1; \
+ } \
+}
+
+DEF_SIGNED_SUB(2)
+DEF_SIGNED_SUB(3)
+DEF_SIGNED_SUB(4)
+DEF_SIGNED_SUB(5)
+DEF_SIGNED_SUB(6)
+DEF_SIGNED_SUB(7)
+DEF_SIGNED_SUB(8)
+
 FLINT_FORCE_INLINE void
 flint_mpn_signed_div2(mp_ptr res, mp_srcptr x, mp_size_t n)
 {
diff --git a/src/test/main.c b/src/test/main.c
index 0243785ac4..c94c7b8415 100644
--- a/src/test/main.c
+++ b/src/test/main.c
@@ -13,7 +13,6 @@
 #include "t-add_ssaaaa.c"
 #include "t-add_sssaaaaaa.c"
-#include "t-add_ssssaaaaaaaa.c"
 #include "t-flint_clz.c"
 #include "t-flint_ctz.c"
 #include "t-io.c"
@@ -32,7 +31,6 @@ test_struct tests[] =
 {
 TEST_FUNCTION(add_ssaaaa),
 TEST_FUNCTION(add_sssaaaaaa),
- TEST_FUNCTION(add_ssssaaaaaaaa),
 TEST_FUNCTION(flint_clz),
 TEST_FUNCTION(flint_ctz),
 TEST_FUNCTION(flint_fprintf),
diff --git a/src/test/t-add_sssaaaaaa.c b/src/test/t-add_sssaaaaaa.c
index e697d9529b..b4b518eddb 100644
--- a/src/test/t-add_sssaaaaaa.c
+++ b/src/test/t-add_sssaaaaaa.c
@@ -12,17 +12,19 @@
 #include 
 #include "ulong_extras.h"
+#include "mpn_extras.h"
 #include "test_helpers.h"
 
 TEST_FUNCTION_START(add_sssaaaaaa, state)
 {
- int i, j, result;
+ int i, j, n, result;
 
 for (i = 0; i < 100000 * flint_test_multiplier(); i++)
 {
- ulong s[3], t[3], a[3], b[3];
+ ulong s[8], t[8], a[8], b[8];
+ int aliasing;
 
- for (j = 0; j < 3; j++)
+ for (j = 0; j < 8; j++)
 {
 s[j] = n_randtest(state);
 t[j] = n_randtest(state);
@@ -30,21 +32,69 @@
 b[j] = n_randtest(state);
 }
 
- add_sssaaaaaa(s[2], s[1], s[0], a[2], a[1], a[0], b[2], b[1], b[0]);
-
- mpn_add_n(t, a, b, 3);
-
- result = ((s[2] == t[2]) && (s[1] == t[1]) && (s[0] == t[0]));
- if (!result)
- TEST_FUNCTION_FAIL(
- "a[2] = %wu, a[1] = %wu, a[0] = %wu\n"
- "b[2] = %wu, b[1] = %wu, b[0] = %wu\n"
- "s[2] = %wu, s[1] = %wu, s[0] = %wu\n"
- "t[2] = %wu, t[1] = %wu, t[0] = %wu\n",
- a[2], a[1], a[0],
- b[2], b[1], b[0],
- s[2], s[1], s[0],
- t[2], t[1], t[0]);
+ aliasing = n_randint(state, 2);
+
+ for (n = 2; n <= 8; n++)
+ {
+ if (aliasing)
+ {
+ for (j = 0; j < 8; j++)
+ s[j] = a[j];
+
+ if (n == 2)
+ NN_ADD_2(s, s, b);
+ else if (n == 3)
+ NN_ADD_3(s, s, b);
+ else if (n == 4)
+ NN_ADD_4(s, s, b);
+ else if (n == 5)
+ NN_ADD_5(s, s, b);
+ else if (n == 6)
+ NN_ADD_6(s, s, b);
+ else if (n == 7)
+ NN_ADD_7(s, s, b);
+ else if (n == 8)
+ NN_ADD_8(s, s, b);
+ }
+ else
+ {
+ if (n == 2)
+ NN_ADD_2(s, a, b);
+ else if (n == 3)
+ NN_ADD_3(s, a, b);
+ else if (n == 4)
+ NN_ADD_4(s, a, b);
+ else if (n == 5)
+ NN_ADD_5(s, a, b);
+ else if (n == 6)
+ NN_ADD_6(s, a, b);
+ else if (n == 7)
+ NN_ADD_7(s, a, b);
+ else if (n == 8)
+ NN_ADD_8(s, a, b);
+ }
+
+ mpn_add_n(t, a, b, n);
+
+ result = flint_mpn_equal_p(s, t, n);
+
+ if (!result)
+ {
+ TEST_FUNCTION_FAIL(
+ "Aliasing: %d\n"
+ "n = %d\n"
+ "a = %{ulong*}\n"
+ "b = %{ulong*}\n"
+ "s = %{ulong*}\n"
+ "t = %{ulong*}\n",
+ aliasing,
+ n,
+ a, n,
+ b, n,
+ s, n,
+ t, n);
+ }
+ }
 }
 
 TEST_FUNCTION_END(state);
diff --git a/src/test/t-add_ssssaaaaaaaa.c b/src/test/t-add_ssssaaaaaaaa.c
deleted file mode 100644
index bcda402d31..0000000000
--- a/src/test/t-add_ssssaaaaaaaa.c
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- Copyright (C) 2019 Daniel Schultz
-
- This file is part of FLINT. 
-
- FLINT is free software: you can redistribute it and/or modify it under
- the terms of the GNU Lesser General Public License (LGPL) as published
- by the Free Software Foundation; either version 3 of the License, or
- (at your option) any later version. See .
-*/
-
-#include "ulong_extras.h"
-#include "test_helpers.h"
-
-TEST_FUNCTION_START(add_ssssaaaaaaaa, state)
-{
- int i, j, result;
-
- for (i = 0; i < 100000 * flint_test_multiplier(); i++)
- {
- ulong s[4], t[4], a[4], b[4];
- int aliasing;
-
- for (j = 0; j < 4; j++)
- {
- s[j] = n_randtest(state);
- t[j] = n_randtest(state);
- a[j] = n_randtest(state);
- b[j] = n_randtest(state);
- }
-
- aliasing = n_randint(state, 2);
-
- if (aliasing)
- {
- for (j = 0; j < 4; j++)
- s[j] = a[j];
-
- add_ssssaaaaaaaa(s[3], s[2], s[1], s[0], s[3], s[2], s[1], s[0],
- b[3], b[2], b[1], b[0]);
- }
- else
- {
- add_ssssaaaaaaaa(s[3], s[2], s[1], s[0], a[3], a[2], a[1], a[0],
- b[3], b[2], b[1], b[0]);
- }
-
- mpn_add_n(t, a, b, 4);
-
- result = ((s[3] == t[3]) && (s[2] == t[2]) && (s[1] == t[1]) && (s[0] == t[0]));
- if (!result)
- TEST_FUNCTION_FAIL(
- "Aliasing: %d\n"
- "a[3] = %wu, a[2] = %wu, a[1] = %wu, a[0] = %wu\n"
- "b[3] = %wu, b[2] = %wu, b[1] = %wu, b[0] = %wu\n"
- "s[3] = %wu, s[2] = %wu, s[1] = %wu, s[0] = %wu\n"
- "t[3] = %wu, t[2] = %wu, t[1] = %wu, t[0] = %wu\n",
- aliasing,
- a[3], a[2], a[1], a[0],
- b[3], b[2], b[1], b[0],
- s[3], s[2], s[1], s[0],
- t[3], t[2], t[1], t[0]);
- }
-
- TEST_FUNCTION_END(state);
-}
diff --git a/src/test/t-sub_dddmmmsss.c b/src/test/t-sub_dddmmmsss.c
index 61bc3838e3..b250becdff 100644
--- a/src/test/t-sub_dddmmmsss.c
+++ b/src/test/t-sub_dddmmmsss.c
@@ -15,13 +15,14 @@
 
 TEST_FUNCTION_START(sub_dddmmmsss, state)
 {
- int i, j, result;
+ int i, j, n, result;
 
 for (i = 0; i < 100000 * flint_test_multiplier(); i++)
 {
- ulong s[3], t[3], a[3], b[3];
+ ulong s[8], t[8], a[8], b[8];
+ int aliasing;
 
- for (j = 0; j < 3; j++)
+ for (j = 0; j < 8; j++)
 {
 s[j] = n_randtest(state);
 t[j] = n_randtest(state);
@@ -29,21 +30,69 @@
 b[j] = n_randtest(state);
 }
 
- sub_dddmmmsss(s[2], s[1], s[0], a[2], a[1], a[0], b[2], b[1], b[0]);
-
- mpn_sub_n(t, a, b, 3);
-
- result = ((s[2] == t[2]) && (s[1] == t[1]) && (s[0] == t[0]));
- if (!result)
- TEST_FUNCTION_FAIL(
- "a[2] = %wu, a[1] = %wu, a[0] = %wu\n"
- "b[2] = %wu, b[1] = %wu, b[0] = %wu\n"
- "s[2] = %wu, s[1] = %wu, s[0] = %wu\n"
- "t[2] = %wu, t[1] = %wu, t[0] = %wu\n",
- a[2], a[1], a[0],
- b[2], b[1], b[0],
- s[2], s[1], s[0],
- t[2], t[1], t[0]);
+ aliasing = n_randint(state, 2);
+
+ for (n = 2; n <= 8; n++)
+ {
+ if (aliasing)
+ {
+ for (j = 0; j < 8; j++)
+ s[j] = a[j];
+
+ if (n == 2)
+ NN_SUB_2(s, s, b);
+ else if (n == 3)
+ NN_SUB_3(s, s, b);
+ else if (n == 4)
+ NN_SUB_4(s, s, b);
+ else if (n == 5)
+ NN_SUB_5(s, s, b);
+ else if (n == 6)
+ NN_SUB_6(s, s, b);
+ else if (n == 7)
+ NN_SUB_7(s, s, b);
+ else if (n == 8)
+ NN_SUB_8(s, s, b);
+ }
+ else
+ {
+ if (n == 2)
+ NN_SUB_2(s, a, b);
+ else if (n == 3)
+ NN_SUB_3(s, a, b);
+ else if (n == 4)
+ NN_SUB_4(s, a, b);
+ else if (n == 5)
+ NN_SUB_5(s, a, b);
+ else if (n == 6)
+ NN_SUB_6(s, a, b);
+ else if (n == 7)
+ NN_SUB_7(s, a, b);
+ else if (n == 8)
+ NN_SUB_8(s, a, b);
+ }
+
+ mpn_sub_n(t, a, b, n);
+
+ result = flint_mpn_equal_p(s, t, n);
+
+ if (!result)
+ {
+ TEST_FUNCTION_FAIL(
+ "Aliasing: %d\n"
+ "n = %d\n"
+ "a = %{ulong*}\n"
+ "b = %{ulong*}\n"
+ "s = %{ulong*}\n"
+ "t = %{ulong*}\n",
+ aliasing,
+ n,
+ a, n,
+ b, n,
+ s, n,
+ t, n);
+ }
+ }
 }
 
TEST_FUNCTION_END(state); From d238c29409dbc875b8789e9a8bcc98e5ea3f5817 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Mon, 19 Aug 2024 23:48:13 +0200 Subject: [PATCH 05/15] try to fix 32-bit x86 --- src/longlong.h | 38 +++++++++++---------- src/longlong_asm_gcc.h | 77 ++++++++++++++++++++++-------------------- 2 files changed, 62 insertions(+), 53 deletions(-) diff --git a/src/longlong.h b/src/longlong.h index ea55421922..375f86dad3 100644 --- a/src/longlong.h +++ b/src/longlong.h @@ -108,7 +108,6 @@ flint_bitcnt_t FLINT_BIT_COUNT(ulong x) /* Addition and subtraction */ #if !defined(add_ssaaaa) - # define add_ssaaaa(s1, s0, a1, a0, b1, b0) \ do { \ ulong __t0 = (a0); \ @@ -148,22 +147,6 @@ flint_bitcnt_t FLINT_BIT_COUNT(ulong x) add_ssaaaa(s5, s4, s5, s4, (ulong) 0, __t1); \ } while (0) -#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t2 = 0; \ - add_ssssssaaaaaaaaaaaa(__t2, s4, s3, s2, s1, s0, (ulong) 0, a4, a3, a2, a1, a0, (ulong) 0, b4, b3, b2, b1, b0); \ - add_ssaaaa(s6, s5, a6, a5, b6, b5); \ - add_ssaaaa(s6, s5, s6, s5, (ulong) 0, __t2); \ - } while (0) - -#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ - do { \ - ulong __t3 = 0; \ - add_sssssssaaaaaaaaaaaaaa(__t3, s5, s4, s3, s2, s1, s0, (ulong) 0, a5, a4, a3, a2, a1, a0, (ulong) 0, b5, b4, b3, b2, b1, b0); \ - add_ssaaaa(s7, s6, a7, a6, b7, b6); \ - add_ssaaaa(s7, s6, s7, s6, (ulong) 0, __t3); \ - } while (0) - # define sub_ddmmss(s1, s0, a1, a0, b1, b0) \ do { \ ulong __t0 = (a0); \ @@ -203,6 +186,27 @@ flint_bitcnt_t FLINT_BIT_COUNT(ulong x) sub_ddmmss(s5, s4, (a5) - (b5), s4, -__u3, -__t3); \ } while (0) +#endif + +/* extra wide variants might not have inline asm if there are not enough registers */ +#if !defined(add_sssssssaaaaaaaaaaaaaa) + +#define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t2 = 0; \ + add_ssssssaaaaaaaaaaaa(__t2, s4, s3, s2, s1, s0, (ulong) 0, a4, a3, a2, a1, a0, (ulong) 0, b4, b3, b2, b1, b0); \ + add_ssaaaa(s6, s5, a6, a5, b6, b5); \ + add_ssaaaa(s6, s5, s6, s5, (ulong) 0, __t2); \ + } while (0) + +#define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + do { \ + ulong __t3 = 0; \ + add_sssssssaaaaaaaaaaaaaa(__t3, s5, s4, s3, s2, s1, s0, (ulong) 0, a5, a4, a3, a2, a1, a0, (ulong) 0, b5, b4, b3, b2, b1, b0); \ + add_ssaaaa(s7, s6, a7, a6, b7, b6); \ + add_ssaaaa(s7, s6, s7, s6, (ulong) 0, __t3); \ + } while (0) + #define sub_dddddddmmmmmmmsssssss(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ do { \ ulong __t4, __u4; \ diff --git a/src/longlong_asm_gcc.h b/src/longlong_asm_gcc.h index 5f4925775e..1deea2f66c 100644 --- a/src/longlong_asm_gcc.h +++ b/src/longlong_asm_gcc.h @@ -95,42 +95,6 @@ "4" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ "5" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) -# define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ - __asm__(_ASM_ADD " %20,%" _ASM_PRE "6\n" \ - "\t" _ASM_ADC " %18,%" _ASM_PRE "5\n" \ - "\t" _ASM_ADC " %16,%" _ASM_PRE "4\n" \ - "\t" _ASM_ADC " %14,%" _ASM_PRE "3\n" \ - "\t" _ASM_ADC " %12,%" _ASM_PRE "2\n" \ - "\t" _ASM_ADC " %10,%" _ASM_PRE "1\n" \ - "\t" _ASM_ADC " %8,%" _ASM_PRE "0" \ - : "=r" (s6), "=&r" (s5), "=&r" (s4), 
"=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a6)), _ASM_RME ((ulong)(b6)), \ - "1" ((ulong)(a5)), _ASM_RME ((ulong)(b5)), \ - "2" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ - "3" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ - "4" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ - "5" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ - "6" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) - -# define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ - __asm__(_ASM_ADD " %23,%" _ASM_PRE "7\n" \ - "\t" _ASM_ADC " %21,%" _ASM_PRE "6\n" \ - "\t" _ASM_ADC " %19,%" _ASM_PRE "5\n" \ - "\t" _ASM_ADC " %17,%" _ASM_PRE "4\n" \ - "\t" _ASM_ADC " %15,%" _ASM_PRE "3\n" \ - "\t" _ASM_ADC " %13,%" _ASM_PRE "2\n" \ - "\t" _ASM_ADC " %11,%" _ASM_PRE "1\n" \ - "\t" _ASM_ADC " %9,%" _ASM_PRE "0" \ - : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ - : "0" ((ulong)(a7)), _ASM_RME ((ulong)(b7)), \ - "1" ((ulong)(a6)), _ASM_RME ((ulong)(b6)), \ - "2" ((ulong)(a5)), _ASM_RME ((ulong)(b5)), \ - "3" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ - "4" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ - "5" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ - "6" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ - "7" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) - # define sub_ddmmss(d1, d0, m1, m0, s1, s0) \ __asm__(_ASM_SUB " %5,%" _ASM_PRE "1\n" \ "\t" _ASM_SBB " %3,%" _ASM_PRE "0" \ @@ -186,6 +150,45 @@ "4" ((ulong)(m1)), _ASM_RME ((ulong)(s1)), \ "5" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) +/* x86 does not have enough registers */ +# if FLINT_BITS == 64 && defined (__amd64__) + +# define add_sssssssaaaaaaaaaaaaaa(s6, s5, s4, s3, s2, s1, s0, a6, a5, a4, a3, a2, a1, a0, b6, b5, b4, b3, b2, b1, b0) \ + __asm__(_ASM_ADD " %20,%" _ASM_PRE "6\n" \ + "\t" _ASM_ADC " %18,%" _ASM_PRE "5\n" \ + "\t" _ASM_ADC " %16,%" _ASM_PRE "4\n" \ + "\t" _ASM_ADC " %14,%" _ASM_PRE "3\n" \ + "\t" _ASM_ADC " %12,%" _ASM_PRE "2\n" \ + "\t" _ASM_ADC " %10,%" _ASM_PRE "1\n" \ + "\t" _ASM_ADC " %8,%" _ASM_PRE "0" \ + : "=r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "0" ((ulong)(a6)), _ASM_RME ((ulong)(b6)), \ + "1" ((ulong)(a5)), _ASM_RME ((ulong)(b5)), \ + "2" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ + "3" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ + "4" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ + "5" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ + "6" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) + +# define add_ssssssssaaaaaaaaaaaaaaaa(s7, s6, s5, s4, s3, s2, s1, s0, a7, a6, a5, a4, a3, a2, a1, a0, b7, b6, b5, b4, b3, b2, b1, b0) \ + __asm__(_ASM_ADD " %23,%" _ASM_PRE "7\n" \ + "\t" _ASM_ADC " %21,%" _ASM_PRE "6\n" \ + "\t" _ASM_ADC " %19,%" _ASM_PRE "5\n" \ + "\t" _ASM_ADC " %17,%" _ASM_PRE "4\n" \ + "\t" _ASM_ADC " %15,%" _ASM_PRE "3\n" \ + "\t" _ASM_ADC " %13,%" _ASM_PRE "2\n" \ + "\t" _ASM_ADC " %11,%" _ASM_PRE "1\n" \ + "\t" _ASM_ADC " %9,%" _ASM_PRE "0" \ + : "=r" (s7), "=&r" (s6), "=&r" (s5), "=&r" (s4), "=&r" (s3), "=&r" (s2), "=&r" (s1), "=&r" (s0) \ + : "0" ((ulong)(a7)), _ASM_RME ((ulong)(b7)), \ + "1" ((ulong)(a6)), _ASM_RME ((ulong)(b6)), \ + "2" ((ulong)(a5)), _ASM_RME ((ulong)(b5)), \ + "3" ((ulong)(a4)), _ASM_RME ((ulong)(b4)), \ + "4" ((ulong)(a3)), _ASM_RME ((ulong)(b3)), \ + "5" ((ulong)(a2)), _ASM_RME ((ulong)(b2)), \ + "6" ((ulong)(a1)), _ASM_RME ((ulong)(b1)), \ + "7" ((ulong)(a0)), _ASM_RME ((ulong)(b0))) + # define sub_dddddddmmmmmmmsssssss(d6, d5, d4, d3, d2, d1, d0, m6, m5, m4, m3, m2, m1, m0, s6, s5, s4, s3, s2, 
s1, s0) \ __asm__(_ASM_SUB " %20,%" _ASM_PRE "6\n" \ "\t" _ASM_SBB " %18,%" _ASM_PRE "5\n" \ @@ -222,6 +225,8 @@ "6" ((ulong)(m1)), _ASM_RME ((ulong)(s1)), \ "7" ((ulong)(m0)), _ASM_RME ((ulong)(s0))) +#endif + # if defined(__BMI2__) && defined(__amd64__) # define umul_ppmm(w1, w0, u, v) \ __asm__("mulx\t%3, %q0, %q1" \ From 46767184b1ca1160cb913978b884d11a18fb5a6f Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Mon, 2 Sep 2024 09:56:10 +0200 Subject: [PATCH 06/15] matrix multiplication wip --- src/nfloat/mat_mul.c | 1021 ++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 979 insertions(+), 42 deletions(-) diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index dd496c9bc6..bd1ec02ac3 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -11,20 +11,15 @@ #include "mpn_extras.h" #include "gr.h" +#include "gr_vec.h" #include "gr_mat.h" #include "gr_generic.h" #include "acf.h" #include "acb.h" #include "nfloat.h" - -#include "gr.h" -#include "nfloat.h" -#include "gr_vec.h" -#include "gr_mat.h" #include "gr_special.h" #include "fmpz_mat.h" - int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); @@ -46,9 +41,59 @@ nfixed_print(nn_srcptr x, slong nlimbs, slong exp) arf_clear(t); } +#define DEF_NFIXED_ADD(n) \ +FLINT_FORCE_INLINE \ +void nfixed_add_ ## n(nn_ptr res, nn_srcptr a, nn_srcptr b) \ +{ \ + int asgn, bsgn; \ + asgn = a[0]; \ + bsgn = b[0]; \ + \ + if (asgn == bsgn) \ + { \ + res[0] = asgn; \ + NN_ADD_ ## n(res + 1, a + 1, b + 1); \ + } \ + else \ + { \ + res[0] = asgn ^ flint_mpn_signed_sub_ ## n(res + 1, a + 1, b + 1); \ + } \ +} + +#define DEF_NFIXED_SUB(n) \ +FLINT_FORCE_INLINE \ +void nfixed_sub_ ## n(nn_ptr res, nn_srcptr a, nn_srcptr b) \ +{ \ + int asgn, bsgn; \ + asgn = a[0]; \ + bsgn = b[0]; \ + \ + if (asgn != bsgn) \ + { \ + res[0] = asgn; \ + NN_ADD_ ## n(res + 1, a + 1, b + 1); \ + } \ + else \ + { \ + res[0] = asgn ^ flint_mpn_signed_sub_ ## n(res + 1, a + 1, b + 1); \ + } \ +} -/* todo: don't do this */ -#define NFIXED_MAX_NLIMBS (2 * NFLOAT_MAX_LIMBS) +DEF_NFIXED_ADD(2) +DEF_NFIXED_ADD(3) +DEF_NFIXED_ADD(4) +DEF_NFIXED_ADD(5) +DEF_NFIXED_ADD(6) +DEF_NFIXED_ADD(7) +DEF_NFIXED_ADD(8) + +DEF_NFIXED_SUB(2) +DEF_NFIXED_SUB(3) +DEF_NFIXED_SUB(4) +DEF_NFIXED_SUB(5) +DEF_NFIXED_SUB(6) +DEF_NFIXED_SUB(7) +DEF_NFIXED_SUB(8) FLINT_FORCE_INLINE void nfixed_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) @@ -86,22 +131,98 @@ void nfixed_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) } } -FLINT_FORCE_INLINE +static void _nfixed_vec_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) { slong i; - for (i = 0; i < len; i++) - nfixed_add(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); + if (nlimbs == 2) + { + for (i = 0; i < len; i++) + nfixed_add_2(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 3) + { + for (i = 0; i < len; i++) + nfixed_add_3(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 4) + { + for (i = 0; i < len; i++) + nfixed_add_4(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 5) + { + for (i = 0; i < len; i++) + nfixed_add_5(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 6) + { + for (i = 0; i < len; i++) + nfixed_add_6(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 7) + { + for (i = 0; i < len; i++) + nfixed_add_7(res + i * 
(nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 8) + { + for (i = 0; i < len; i++) + nfixed_add_8(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else + { + for (i = 0; i < len; i++) + nfixed_add(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); + } } -FLINT_FORCE_INLINE +static void _nfixed_vec_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) { slong i; - for (i = 0; i < len; i++) - nfixed_sub(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); + if (nlimbs == 2) + { + for (i = 0; i < len; i++) + nfixed_sub_2(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 3) + { + for (i = 0; i < len; i++) + nfixed_sub_3(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 4) + { + for (i = 0; i < len; i++) + nfixed_sub_4(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 5) + { + for (i = 0; i < len; i++) + nfixed_sub_5(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 6) + { + for (i = 0; i < len; i++) + nfixed_sub_6(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 7) + { + for (i = 0; i < len; i++) + nfixed_sub_7(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 8) + { + for (i = 0; i < len; i++) + nfixed_sub_8(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else + { + for (i = 0; i < len; i++) + nfixed_sub(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); + } } FLINT_FORCE_INLINE @@ -128,37 +249,385 @@ void nfixed_div2(nn_ptr res, nn_srcptr a, slong nlimbs) mpn_rshift(res + 1, a + 1, nlimbs, 1); } +void _nfixed_dot_2(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + + ulong as, bs, a0, a1, b0, b1, s0, s1, t0, t1, u0, u1, hi, lo; + + s0 = s1 = t0 = t1 = u0 = u1 = 0; + + /* + s0 s1 + |a1 b1----| + u0 u1 + |a1 b0----| + t0 t1 + |a0 b1----| + */ + + for (j = 0; j < len; j++) + { + as = x[j * xstride]; + a0 = x[j * xstride + 1]; + a1 = x[j * xstride + 2]; + bs = y[j * ystride]; + b0 = y[j * ystride + 1]; + b1 = y[j * ystride + 2]; + + if (as == bs) + { + umul_ppmm(hi, lo, a1, b1); + add_ssaaaa(s1, s0, s1, s0, hi, lo); + umul_ppmm(hi, lo, a0, b1); + add_ssaaaa(t1, t0, t1, t0, 0, hi); + umul_ppmm(hi, lo, a1, b0); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + } + else + { + umul_ppmm(hi, lo, a1, b1); + sub_ddmmss(s1, s0, s1, s0, hi, lo); + umul_ppmm(hi, lo, a0, b1); + sub_ddmmss(t1, t0, t1, t0, 0, hi); + umul_ppmm(hi, lo, a1, b0); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + } + } + + add_ssaaaa(s1, s0, s1, s0, t1, t0); + add_ssaaaa(s1, s0, s1, s0, u1, u0); + + if ((slong) s1 < WORD(0)) + { + sub_ddmmss(s1, s0, 0, 0, s1, s0); + res[0] = 1; + } + else + { + res[0] = 0; + } + + res[1] = s0; + res[2] = s1; +} + +void _nfixed_dot_3(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + + ulong as, bs, a0, a1, a2, b0, b1, b2, s0, s1, s2, hi, lo; + ulong u0, u1, v0, v1; + + s0 = s1 = s2 = 0; + u0 = u1 = v0 = v1 = 0; + + for (j = 0; j < len; j++) + { + as = x[j * xstride]; + a0 = x[j * xstride + 1]; + a1 = x[j * xstride + 2]; + a2 = x[j * xstride + 3]; + bs = y[j * ystride]; + b0 = y[j * ystride + 1]; + b1 = y[j * ystride + 2]; + b2 = y[j * ystride + 3]; + + /* + |a2 b2----| + + |a1 b2----| + |a2 
b1----| + + |a0 b2----| + |a1 b1----| + |a2 b0----| + */ + + if (as == bs) + { + umul_ppmm(hi, lo, a0, b2); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b1); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b0); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + + umul_ppmm(hi, lo, a2, b2); + add_ssaaaa(s2, s1, s2, s1, hi, lo); + + umul_ppmm(hi, lo, a1, b2); + add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b1); + add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, hi, lo); + } + else + { + umul_ppmm(hi, lo, a0, b2); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b1); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b0); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + + umul_ppmm(hi, lo, a2, b2); + sub_ddmmss(s2, s1, s2, s1, hi, lo); + + umul_ppmm(hi, lo, a1, b2); + sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b1); + sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, hi, lo); + } + } + + if ((slong) u1 < WORD(0)) + { + sub_ddmmss(u1, u0, 0, 0, u1, u0); + sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, u1, u0); + } + else + { + add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, u1, u0); + } + + if ((slong) s2 < WORD(0)) + { + sub_dddmmmsss(s2, s1, s0, 0, 0, 0, s2, s1, s0); + res[0] = 1; + } + else + { + res[0] = 0; + } + + + res[1] = s0; + res[2] = s1; + res[3] = s2; +} + +void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + + /* + s1 s2 s3 + |a3 b3----| + + |a2 b3----| + |a3 b2----| + + t0 t1 t2 + + |a1 b3----| + |a2 b2----| + |a3 b1----| + + u0 u1 + + |a0 b3----| + |a1 b2----| + |a2 b1----| + |a3 b0----| + + */ + + ulong as, a0, a1, a2, a3; + ulong bs, b0, b1, b2, b3; + ulong s0, s1, s2, s3, t0, t1, t2, u0, u1; + ulong hi, lo; + + s0 = s1 = s2 = s3 = t0 = t1 = t2 = u0 = u1 = 0; + + for (j = 0; j < len; j++) + { + as = x[j * xstride]; + a0 = x[j * xstride + 1]; + a1 = x[j * xstride + 2]; + a2 = x[j * xstride + 3]; + a3 = x[j * xstride + 4]; + + bs = y[j * ystride]; + b0 = y[j * ystride + 1]; + b1 = y[j * ystride + 2]; + b2 = y[j * ystride + 3]; + b3 = y[j * ystride + 4]; + + if (as == bs) + { + umul_ppmm(hi, lo, a3, b3); + add_ssaaaa(s3, s2, s3, s2, hi, lo); + + umul_ppmm(hi, lo, a2, b3); + add_sssaaaaaa(s3, s2, s1, s3, s2, s1, 0, hi, lo); + umul_ppmm(hi, lo, a3, b2); + add_sssaaaaaa(s3, s2, s1, s3, s2, s1, 0, hi, lo); + + umul_ppmm(hi, lo, a1, b3); + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b2); + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a3, b1); + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); + + umul_ppmm(hi, lo, a0, b3); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b2); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b1); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a3, b0); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + } + else + { + umul_ppmm(hi, lo, a3, b3); + sub_ddmmss(s3, s2, s3, s2, hi, lo); + + umul_ppmm(hi, lo, a2, b3); + sub_dddmmmsss(s3, s2, s1, s3, s2, s1, 0, hi, lo); + umul_ppmm(hi, lo, a3, b2); + sub_dddmmmsss(s3, s2, s1, s3, s2, s1, 0, hi, lo); + + umul_ppmm(hi, lo, a1, b3); + sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b2); + sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a3, b1); + sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); + + umul_ppmm(hi, lo, a0, b3); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b2); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b1); + sub_ddmmss(u1, 
u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a3, b0); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + } + } + + if ((slong) u1 < WORD(0)) + { + sub_ddmmss(u1, u0, 0, 0, u1, u0); + sub_dddmmmsss(t2, t1, s0, t2, t1, s0, 0, u1, u0); + } + else + { + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, u1, u0); + } + + if ((slong) t1 < WORD(0)) + { + sub_dddmmmsss(t2, t1, t0, 0, 0, 0, t2, t1, t0); + sub_ddddmmmmssss(s3, s2, s1, s0, s3, s2, s1, s0, 0, t2, t1, t0); + } + else + { + add_ssssaaaaaaaa(s3, s2, s1, s0, s3, s2, s1, s0, 0, t2, t1, t0); + } + + if ((slong) s3 < WORD(0)) + { + sub_ddddmmmmssss(s3, s2, s1, s0, 0, 0, 0, 0, s3, s2, s1, s0); + res[0] = 1; + } + else + { + res[0] = 0; + } + + res[1] = s0; + res[2] = s1; + res[3] = s2; + res[4] = s3; +} + +void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 5; + + ulong tmp[6]; + ulong spos[6] = { 0, 0, 0, 0, 0, 0 }; + ulong sneg[6] = { 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + ystride, nlimbs); + + if (tmp[0]) + NN_ADD_5(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_5(sneg + 1, sneg + 1, tmp + 1); + } + + nfixed_sub_5(res, spos, sneg); +} + /* A is (m x n), B is (n x p), C is (m x p) */ void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) { slong i, j, k; nn_ptr t; - TMP_INIT; - - TMP_START; - - t = TMP_ALLOC((nlimbs + 1) * sizeof(ulong)); #define A_ENTRY(i, j) ((A) + ((i) * n + (j)) * (nlimbs + 1)) #define B_ENTRY(i, j) ((B) + ((i) * p + (j)) * (nlimbs + 1)) #define C_ENTRY(i, j) ((C) + ((i) * p + (j)) * (nlimbs + 1)) - for (i = 0; i < m; i++) + if (nlimbs == 2) { - for (j = 0; j < p; j++) - { - nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_2(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 3) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_3(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 4) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_4(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 5) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else + { + TMP_INIT; + TMP_START; - for (k = 1; k < n; k++) + t = TMP_ALLOC((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m; i++) + { + for (j = 0; j < p; j++) { - nfixed_mul(t, A_ENTRY(i, k), B_ENTRY(k, j), nlimbs); - nfixed_add(C_ENTRY(i, j), C_ENTRY(i, j), t, nlimbs); + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + + for (k = 1; k < n; k++) + { + nfixed_mul(t, A_ENTRY(i, k), B_ENTRY(k, j), nlimbs); + nfixed_add(C_ENTRY(i, j), C_ENTRY(i, j), t, nlimbs); + } } } - } - TMP_END; + TMP_END; + } #undef A_ENTRY #undef B_ENTRY @@ -187,13 +656,61 @@ addmul_subsub(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_ nfixed_add(c, c, val0, nlimbs); } +FLINT_FORCE_INLINE void +addmul_addadd_4(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + nfixed_add_4(val1, a1, b1); + nfixed_add_4(val2, a2, b2); + nfixed_mul(val0, val1, val2, 
5); + nfixed_add_4(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_5(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + nfixed_add_5(val1, a1, b1); + nfixed_add_5(val2, a2, b2); + nfixed_mul(val0, val1, val2, 5); + nfixed_add_5(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_6(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + nfixed_add_6(val1, a1, b1); + nfixed_add_6(val2, a2, b2); + nfixed_mul(val0, val1, val2, 5); + nfixed_add_6(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_7(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + nfixed_add_7(val1, a1, b1); + nfixed_add_7(val2, a2, b2); + nfixed_mul(val0, val1, val2, 5); + nfixed_add_7(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_8(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + nfixed_add_8(val1, a1, b1); + nfixed_add_8(val2, a2, b2); + nfixed_mul(val0, val1, val2, 5); + nfixed_add_8(c, c, val0); +} + void -_nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) +_nfixed_mat_mul_waksman2(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs, slong Cstride, slong Astride, slong Bstride) { slong l, j, k; + slong np = n >> 1; + nn_ptr Ctmp = flint_calloc((nlimbs + 1) * ((p + m) + 5), sizeof(ulong)); - /* Ctmp itself has m * p entries */ + + /* remaining temp space */ nn_ptr Crow = Ctmp; /* Crow has p entries */ nn_ptr Ccol = Crow + (nlimbs + 1) * p; /* Ccol has m entries */ nn_ptr val0 = Ccol + (nlimbs + 1) * m; /* val0 has room for 2 sums */ @@ -201,15 +718,13 @@ _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, sl nn_ptr val2 = val1 + (nlimbs + 1); /* val2 has room for 1 sum */ nn_ptr crow = val2 + (nlimbs + 1); /* crow has room for 1 sum */ -#define A_ENTRY(i, j) ((A) + ((i) * n + (j)) * (nlimbs + 1)) -#define B_ENTRY(i, j) ((B) + ((i) * p + (j)) * (nlimbs + 1)) -#define C_ENTRY(i, j) ((C) + ((i) * p + (j)) * (nlimbs + 1)) +#define A_ENTRY(i, j) ((A) + (i) * Astride + (j) * (nlimbs + 1)) +#define B_ENTRY(i, j) ((B) + (i) * Bstride + (j) * (nlimbs + 1)) +#define C_ENTRY(i, j) ((C) + (i) * Cstride + (j) * (nlimbs + 1)) #define Crow_ENTRY(ii) (Crow + (ii) * (nlimbs + 1)) #define Ccol_ENTRY(ii) (Ccol + (ii) * (nlimbs + 1)) - slong np = n >> 1; - for (j = 1; j <= np; j++) { slong j2 = (j << 1) - 1; @@ -226,12 +741,41 @@ _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, sl addmul_subsub(val0, val1, val2, Ccol_ENTRY(l), A_ENTRY(l, j2-1), B_ENTRY(j2, 0), A_ENTRY(l, j2), B_ENTRY(j2-1, 0), nlimbs); } - for (k = 1; k < p; k++) + /* + Note: for nlimbs <= 4 a significant speedup is possible by reordering the loops + so that the O(n^3) part of the algorithm is performed as dot products. + However, classical multiplication still wins over Waksman in that range, + so we do not bother. 
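+
+ (For reference, that reordering would compute each entry C(l, k) in a
+ single pass over the inner dimension, with a call of the form
+ _nfixed_dot_2(C_ENTRY(l, k), A_ENTRY(l, 0), nlimbs + 1,
+ B_ENTRY(0, k), Bstride, n), as done in the classical multiplication
+ routines above.)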
+ */ + if (nlimbs == 5) { - for (l = 1; l < m; l++) - { - addmul_addadd(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k), nlimbs); - } + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd_5(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else if (nlimbs == 6) + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd_6(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else if (nlimbs == 7) + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd_7(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else if (nlimbs == 8) + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd_8(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k), nlimbs); } } @@ -279,6 +823,392 @@ _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, sl #undef C_ENTRY } +void +_nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) +{ + _nfixed_mat_mul_waksman2(C, A, B, m, n, p, nlimbs, p * (nlimbs + 1), n * (nlimbs + 1), p * (nlimbs + 1)); +} + + +typedef struct +{ + nn_ptr start; + slong r; + slong c; + slong row_stride; +} +_nfixed_mat_struct; + +typedef _nfixed_mat_struct _nfixed_mat_t[1]; + +static void +_nfixed_mat_init(_nfixed_mat_t A, slong r, slong c, slong nlimbs) +{ + A->start = flint_malloc((nlimbs + 1) * (r * c) * sizeof(ulong)); + A->r = r; + A->c = c; + A->row_stride = c * (nlimbs + 1); +} + +static void +_nfixed_mat_clear(_nfixed_mat_t A, slong nlimbs) +{ + flint_free(A->start); +} + +static void +_nfixed_mat_window_init(_nfixed_mat_t A, const _nfixed_mat_t mat, slong r1, slong c1, slong r2, slong c2, slong nlimbs) +{ + A->start = mat->start + (r1 * mat->row_stride) + c1 * (nlimbs + 1); + A->r = r2 - r1; + A->c = c2 - c1; + A->row_stride = mat->row_stride; +} + +static void +_nfixed_mat_window_clear(_nfixed_mat_t A, slong nlimbs) +{ +} + +/* +static void +nfixed_mat_print(nn_ptr A, slong ar, slong ac, slong nlimbs) +{ + slong i, j; + flint_printf("{%wd y %wd : [", ar, ac); + for (i = 0; i < ar; i++) + for (j = 0; j < ac; j++) + { + nfixed_print(A + i * ac * (nlimbs + 1) + j * (nlimbs + 1), nlimbs, 0); + flint_printf(", "); + } + + flint_printf("]}\n"); +} + +static void +_nfixed_mat_print(_nfixed_mat_t A, slong nlimbs) +{ + slong i, j; + flint_printf("{%wd x %wd : [", A->r, A->c); + for (i = 0; i < A->r; i++) + for (j = 0; j < A->c; j++) + { + nfixed_print(A->start + i * A->row_stride + j * (nlimbs + 1), nlimbs, 0); + flint_printf(", "); + } + + flint_printf("]}\n"); +} +*/ + +static void +_nfixed_mat_add(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong i, r = A->r, c = A->c; + + for (i = 0; i < r; i++) + _nfixed_vec_add(Cptr + i * Cstride, Aptr + i * Astride, Bptr + i * Bstride, c, nlimbs); +} + +static void +_nfixed_mat_sub(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong 
nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong i, r = A->r, c = A->c; + + for (i = 0; i < r; i++) + _nfixed_vec_sub(Cptr + i * Cstride, Aptr + i * Astride, Bptr + i * Bstride, c, nlimbs); +} + +void +_nfixed_mat_mul_waksman3(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong m = A->r; + slong n = A->c; + slong p = B->c; + + _nfixed_mat_mul_waksman2(Cptr, Aptr, Bptr, m, n, p, nlimbs, Cstride, Astride, Bstride); +} + +static void +_nfixed_mat_mul_classical2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong m = A->r; + slong n = A->c; + slong p = B->c; + + slong i, j, k; + nn_ptr t; + +#define A_ENTRY(i, j) ((Aptr) + (i) * Astride + (j) * (nlimbs + 1)) +#define B_ENTRY(i, j) ((Bptr) + (i) * Bstride + (j) * (nlimbs + 1)) +#define C_ENTRY(i, j) ((Cptr) + (i) * Cstride + (j) * (nlimbs + 1)) + + if (nlimbs == 2) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_2(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 3) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_3(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 4) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_4(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 5) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else + { + TMP_INIT; + TMP_START; + + t = TMP_ALLOC((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m; i++) + { + for (j = 0; j < p; j++) + { + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + + for (k = 1; k < n; k++) + { + nfixed_mul(t, A_ENTRY(i, k), B_ENTRY(k, j), nlimbs); + nfixed_add(C_ENTRY(i, j), C_ENTRY(i, j), t, nlimbs); + } + } + } + + TMP_END; + } + +#undef A_ENTRY +#undef B_ENTRY +#undef C_ENTRY +} + + +static void +_nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong cutoff, slong nlimbs) +{ + slong ar, ac, bc; + slong anr, anc, bnr, bnc; + + _nfixed_mat_t A11, A12, A21, A22; + _nfixed_mat_t B11, B12, B21, B22; + _nfixed_mat_t C11, C12, C21, C22; + _nfixed_mat_t X1, X2; + + ar = A->r; + ac = A->c; + bc = B->c; + + cutoff = FLINT_MAX(cutoff, 2); + + if (ar < cutoff || ac < cutoff || bc < cutoff) + { + if (nlimbs <= 3) + _nfixed_mat_mul_classical2(C, A, B, nlimbs); + else + _nfixed_mat_mul_waksman3(C, A, B, nlimbs); + return; + } + + anr = ar / 2; + anc = ac / 2; + bnr = anc; + bnc = bc / 2; + + _nfixed_mat_window_init(A11, A, 0, 0, anr, anc, nlimbs); + _nfixed_mat_window_init(A12, A, 0, anc, anr, 2 * anc, nlimbs); + _nfixed_mat_window_init(A21, A, anr, 0, 2 * anr, anc, nlimbs); + _nfixed_mat_window_init(A22, A, anr, anc, 2 * anr, 2 * anc, nlimbs); + + _nfixed_mat_window_init(B11, B, 0, 0, bnr, bnc, nlimbs); + 
_nfixed_mat_window_init(B12, B, 0, bnc, bnr, 2 * bnc, nlimbs); + _nfixed_mat_window_init(B21, B, bnr, 0, 2 * bnr, bnc, nlimbs); + _nfixed_mat_window_init(B22, B, bnr, bnc, 2 * bnr, 2 * bnc, nlimbs); + + _nfixed_mat_window_init(C11, C, 0, 0, anr, bnc, nlimbs); + _nfixed_mat_window_init(C12, C, 0, bnc, anr, 2 * bnc, nlimbs); + _nfixed_mat_window_init(C21, C, anr, 0, 2 * anr, bnc, nlimbs); + _nfixed_mat_window_init(C22, C, anr, bnc, 2 * anr, 2 * bnc, nlimbs); + + _nfixed_mat_init(X1, anr, FLINT_MAX(bnc, anc), nlimbs); + _nfixed_mat_init(X2, anc, bnc, nlimbs); + + X1->c = anc; + + _nfixed_mat_add(X1, A22, A12, nlimbs); + _nfixed_mat_add(X2, B22, B12, nlimbs); + _nfixed_mat_mul_strassen2(C21, X1, X2, cutoff, nlimbs); + + _nfixed_mat_sub(X1, A22, A21, nlimbs); + _nfixed_mat_sub(X2, B22, B21, nlimbs); + _nfixed_mat_mul_strassen2(C22, X1, X2, cutoff, nlimbs); + + _nfixed_mat_add(X1, X1, A12, nlimbs); + _nfixed_mat_add(X2, X2, B12, nlimbs); + _nfixed_mat_mul_strassen2(C11, X1, X2, cutoff, nlimbs); + + _nfixed_mat_sub(X1, X1, A11, nlimbs); + _nfixed_mat_mul_strassen2(C12, X1, B12, cutoff, nlimbs); + + X1->c = bnc; + _nfixed_mat_mul_strassen2(X1, A12, B21, cutoff, nlimbs); + _nfixed_mat_add(C11, C11, X1, nlimbs); + _nfixed_mat_add(C12, C12, C22, nlimbs); + _nfixed_mat_sub(C12, C11, C12, nlimbs); + _nfixed_mat_sub(C11, C21, C11, nlimbs); + _nfixed_mat_sub(X2, X2, B11, nlimbs); + _nfixed_mat_mul_strassen2(C21, A21, X2, cutoff, nlimbs); + + _nfixed_mat_clear(X2, nlimbs); + + _nfixed_mat_sub(C21, C11, C21, nlimbs); + _nfixed_mat_add(C22, C22, C11, nlimbs); + _nfixed_mat_mul_strassen2(C11, A11, B11, cutoff, nlimbs); + + _nfixed_mat_add(C11, X1, C11, nlimbs); + + X1->c = FLINT_MAX(bnc, anc); + _nfixed_mat_clear(X1, nlimbs); + + _nfixed_mat_window_clear(A11, nlimbs); + _nfixed_mat_window_clear(A12, nlimbs); + _nfixed_mat_window_clear(A21, nlimbs); + _nfixed_mat_window_clear(A22, nlimbs); + + _nfixed_mat_window_clear(B11, nlimbs); + _nfixed_mat_window_clear(B12, nlimbs); + _nfixed_mat_window_clear(B21, nlimbs); + _nfixed_mat_window_clear(B22, nlimbs); + + _nfixed_mat_window_clear(C11, nlimbs); + _nfixed_mat_window_clear(C12, nlimbs); + _nfixed_mat_window_clear(C21, nlimbs); + _nfixed_mat_window_clear(C22, nlimbs); + + if (bc > 2 * bnc) + { + _nfixed_mat_t Bc, Cc; + _nfixed_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, nlimbs); + _nfixed_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, nlimbs); + _nfixed_mat_mul_strassen2(Cc, A, Bc, cutoff, nlimbs); + _nfixed_mat_window_clear(Bc, nlimbs); + _nfixed_mat_window_clear(Cc, nlimbs); + } + + if (ar > 2 * anr) + { + _nfixed_mat_t Ar, Cr; + _nfixed_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, nlimbs); + _nfixed_mat_window_init(Cr, C, 2 * anr, 0, ar, bc, nlimbs); + _nfixed_mat_mul_strassen2(Cr, Ar, B, cutoff, nlimbs); + _nfixed_mat_window_clear(Ar, nlimbs); + _nfixed_mat_window_clear(Cr, nlimbs); + } + + if (ac > 2 * anc) + { + _nfixed_mat_t Ac, Br, Cb, tmp; + slong mt, nt; + + _nfixed_mat_window_init(Ac, A, 0, 2 * anc, 2 * anr, ac, nlimbs); + _nfixed_mat_window_init(Br, B, 2 * bnr, 0, ac, 2 * bnc, nlimbs); + _nfixed_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * bnc, nlimbs); + + mt = Ac->r; + nt = Br->c; + + /* todo: faster */ + _nfixed_mat_init(tmp, mt, nt, nlimbs); + _nfixed_mat_mul_strassen2(tmp, Ac, Br, cutoff, nlimbs); + _nfixed_mat_add(Cb, Cb, tmp, nlimbs); + _nfixed_mat_clear(tmp, nlimbs); + _nfixed_mat_window_clear(Ac, nlimbs); + _nfixed_mat_window_clear(Br, nlimbs); + _nfixed_mat_window_clear(Cb, nlimbs); + } +} + +void +_nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, 
nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs) +{ + _nfixed_mat_t CC, AA, BB; + + AA->start = (nn_ptr) A; + AA->r = m; + AA->c = n; + AA->row_stride = n * (nlimbs + 1); + + BB->start = (nn_ptr) B; + BB->r = n; + BB->c = p; + BB->row_stride = p * (nlimbs + 1); + + CC->start = C; + CC->r = m; + CC->c = p; + CC->row_stride = p * (nlimbs + 1); + + _nfixed_mat_mul_strassen2(CC, AA, BB, cutoff, nlimbs); +} + FLINT_FORCE_INLINE void _nfloat_get_nfixed(nn_ptr res, nn_srcptr x, slong exp, slong fix_nlimbs, gr_ctx_t ctx) { @@ -357,10 +1287,12 @@ _nfloat_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, for (j = 0; j < p; j++) _nfloat_get_nfixed(TB + i * fdnlimbs * p + j * fdnlimbs, GR_MAT_ENTRY(B, i, j, sz), Bexp, fnlimbs, ctx); - if (waksman) + if (waksman == 1) _nfixed_mat_mul_waksman(TC, TA, TB, m, n, p, fnlimbs); - else + else if (waksman == 0) _nfixed_mat_mul_classical(TC, TA, TB, m, n, p, fnlimbs); + else + _nfixed_mat_mul_strassen(TC, TA, TB, m, n, p, waksman, fnlimbs); for (i = 0; i < m; i++) for (j = 0; j < p; j++) @@ -437,6 +1369,11 @@ nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t return _nfloat_mat_mul_fixed(C, A, B, 1, 100000, ctx); } +int +nfloat_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong cutoff, gr_ctx_t ctx) +{ + return _nfloat_mat_mul_fixed(C, A, B, cutoff, 100000, ctx); +} static void _nfloat_2exp_get_fmpz(fmpz_t res, nfloat_srcptr x, slong fixexp, gr_ctx_t ctx) From e4291668a2896d6b97372f7e9b1084b59891c901 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Mon, 2 Sep 2024 15:32:46 +0200 Subject: [PATCH 07/15] bug fixes and optimizations --- src/nfloat.h | 1 + src/nfloat/mat_mul.c | 149 +++++++++++++++++- src/nfloat/profile/p-mat_mul.c | 280 +++++++++++++++++++++++++++++++++ src/nfloat/test/main.c | 2 + src/nfloat/test/t-mat_mul.c | 18 +++ src/nfloat/test/t-nfixed_dot.c | 155 ++++++++++++++++++ 6 files changed, 597 insertions(+), 8 deletions(-) create mode 100644 src/nfloat/profile/p-mat_mul.c create mode 100644 src/nfloat/test/t-nfixed_dot.c diff --git a/src/nfloat.h b/src/nfloat.h index 12b7363e66..520ee8b430 100644 --- a/src/nfloat.h +++ b/src/nfloat.h @@ -455,6 +455,7 @@ int _nfloat_vec_dot_rev(nfloat_ptr res, nfloat_srcptr initial, int subtract, nfl int nfloat_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); int nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); +int nfloat_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong cutoff, gr_ctx_t ctx); int nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx); int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index bd1ec02ac3..2db4dc8648 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -517,7 +517,7 @@ void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ys add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, u1, u0); } - if ((slong) t1 < WORD(0)) + if ((slong) t2 < WORD(0)) { sub_dddmmmsss(t2, t1, t0, 0, 0, 0, t2, t1, t0); sub_ddddmmmmssss(s3, s2, s1, s0, s3, s2, s1, s0, 0, t2, t1, t0); @@ -559,9 +559,9 @@ void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ys for (j = 1; j < len; j++) { - nfixed_mul(tmp, x + j * xstride, y + ystride, nlimbs); + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); - if (tmp[0]) + if (tmp[0] == 0) 
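+ /* the sign limb is 0 for a nonnegative product: magnitudes of
+ positive and negative terms accumulate separately in spos and
+ sneg, and the final result is spos - sneg */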
NN_ADD_5(spos + 1, spos + 1, tmp + 1); else NN_ADD_5(sneg + 1, sneg + 1, tmp + 1); @@ -570,6 +570,87 @@ void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ys nfixed_sub_5(res, spos, sneg); } +void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 6; + + ulong tmp[7]; + ulong spos[7] = { 0, 0, 0, 0, 0, 0, 0 }; + ulong sneg[7] = { 0, 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); + + if (tmp[0] == 0) + NN_ADD_6(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_6(sneg + 1, sneg + 1, tmp + 1); + } + + nfixed_sub_6(res, spos, sneg); +} + +void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 7; + + ulong tmp[8]; + ulong spos[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + ulong sneg[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); + + if (tmp[0] == 0) + NN_ADD_7(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_7(sneg + 1, sneg + 1, tmp + 1); + } + + nfixed_sub_7(res, spos, sneg); +} + +void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 8; + + ulong tmp[9]; + ulong spos[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + ulong sneg[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); + + if (tmp[0] == 0) + NN_ADD_8(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_8(sneg + 1, sneg + 1, tmp + 1); + } + + nfixed_sub_8(res, spos, sneg); +} + /* A is (m x n), B is (n x p), C is (m x p) */ void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) @@ -581,6 +662,14 @@ _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, #define B_ENTRY(i, j) ((B) + ((i) * p + (j)) * (nlimbs + 1)) #define C_ENTRY(i, j) ((C) + ((i) * p + (j)) * (nlimbs + 1)) + if (n == 1) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + return; + } + if (nlimbs == 2) { for (i = 0; i < m; i++) @@ -605,6 +694,24 @@ _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, for (j = 0; j < p; j++) _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); } + else if (nlimbs == 6) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_6(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 7) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_7(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 8) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_8(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } else { TMP_INIT; @@ -661,7 +768,7 @@ addmul_addadd_4(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr 
a1, n { nfixed_add_4(val1, a1, b1); nfixed_add_4(val2, a2, b2); - nfixed_mul(val0, val1, val2, 5); + nfixed_mul(val0, val1, val2, 4); nfixed_add_4(c, c, val0); } @@ -679,7 +786,7 @@ addmul_addadd_6(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, n { nfixed_add_6(val1, a1, b1); nfixed_add_6(val2, a2, b2); - nfixed_mul(val0, val1, val2, 5); + nfixed_mul(val0, val1, val2, 6); nfixed_add_6(c, c, val0); } @@ -688,7 +795,7 @@ addmul_addadd_7(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, n { nfixed_add_7(val1, a1, b1); nfixed_add_7(val2, a2, b2); - nfixed_mul(val0, val1, val2, 5); + nfixed_mul(val0, val1, val2, 7); nfixed_add_7(c, c, val0); } @@ -697,7 +804,7 @@ addmul_addadd_8(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, n { nfixed_add_8(val1, a1, b1); nfixed_add_8(val2, a2, b2); - nfixed_mul(val0, val1, val2, 5); + nfixed_mul(val0, val1, val2, 8); nfixed_add_8(c, c, val0); } @@ -988,6 +1095,14 @@ _nfixed_mat_mul_classical2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed #define B_ENTRY(i, j) ((Bptr) + (i) * Bstride + (j) * (nlimbs + 1)) #define C_ENTRY(i, j) ((Cptr) + (i) * Cstride + (j) * (nlimbs + 1)) + if (n == 1) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + return; + } + if (nlimbs == 2) { for (i = 0; i < m; i++) @@ -1012,6 +1127,24 @@ _nfixed_mat_mul_classical2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed for (j = 0; j < p; j++) _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); } + else if (nlimbs == 6) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_6(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 7) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_7(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 8) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_8(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } else { TMP_INIT; @@ -1061,7 +1194,7 @@ _nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_ if (ar < cutoff || ac < cutoff || bc < cutoff) { - if (nlimbs <= 3) + if (nlimbs <= 5) _nfixed_mat_mul_classical2(C, A, B, nlimbs); else _nfixed_mat_mul_waksman3(C, A, B, nlimbs); diff --git a/src/nfloat/profile/p-mat_mul.c b/src/nfloat/profile/p-mat_mul.c new file mode 100644 index 0000000000..172a3c2adb --- /dev/null +++ b/src/nfloat/profile/p-mat_mul.c @@ -0,0 +1,280 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . 
+*/ + +#include +#include "fmpz.h" +#include "gr.h" +#include "gr_special.h" +#include "gr_vec.h" +#include "gr_mat.h" +#include "arf.h" +#include "nfloat.h" +#include "profiler.h" +#include "double_extras.h" + +#define TABN (NFLOAT_MAX_LIMBS + 1) +#define WAKSMAN_MIN_PREC 320 + +#if 1 +#undef TIMEIT_END_REPEAT +#define TIMEIT_END_REPEAT(__timer, __reps) \ + } \ + timeit_stop(__timer); \ + if (__timer->cpu >= 100) \ + break; \ + __reps *= 10; \ + } \ + } while (0); +#endif + +void +randmat(gr_mat_t mat, flint_rand_t state, gr_ctx_t ctx) +{ + slong m = gr_mat_nrows(mat, ctx); + slong n = gr_mat_ncols(mat, ctx); + + slong i, j; + + for (i = 0; i < m; i++) + { + for (j = 0; j < n; j++) + { + gr_ptr v = gr_mat_entry_ptr(mat, i, j, ctx); + + GR_MUST_SUCCEED(gr_set_si(v, 1 + n_randint(state, 1000), ctx)); + GR_MUST_SUCCEED(gr_div_ui(v, v, 1 + n_randint(state, 1000), ctx)); + if (n_randint(state, 2)) + GR_MUST_SUCCEED(gr_neg(v, v, ctx)); + } + } +} + +void tune_fixed_vs_waksman(int * cutoffs) +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (prec = WAKSMAN_MIN_PREC; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + { + flint_printf("prec = %wd\n", prec); + + nfloat_ctx_init(ctx, prec, 0); + + for (n = 2; n <= 128; n++) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_fixed_classical(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_waksman(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + + if (t2 < t1 * 0.99) + { + cutoffs[prec / 64] = n; + + flint_printf("short tab_fixed_classical_vs_waksman[] = {\n"); + for (i = 0; i <= prec / 64; i++) + flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); + flint_printf("}\n"); + + break; + } + } + } + + flint_rand_clear(state); +} + +void tune_strassen(int * cutoffs) +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + int prev_ok = 0; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + { + flint_printf("prec = %wd\n", prec); + + prev_ok = 0; + + nfloat_ctx_init(ctx, prec, 0); + + for (n = 2; n <= 128; n++) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + if (prec < WAKSMAN_MIN_PREC) + GR_MUST_SUCCEED(nfloat_mat_mul_fixed_classical(C, A, B, ctx)); + else + GR_MUST_SUCCEED(nfloat_mat_mul_waksman(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_strassen(C, A, B, n, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + + if (t2 < t1 * 0.99) + { + if (prev_ok) + { + cutoffs[prec / 64] = n; + + flint_printf("short tab_strassen[] = {\n"); + for (i = 0; i <= prec / 64; i++) + flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); + flint_printf("}\n"); + + break; + } + else + { + prev_ok = 1; + } + } 
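+ /* require a win at two consecutive sizes before recording the
+ cutoff, to filter out timing noise */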
+ else + { + prev_ok = 0; + } + } + } + + flint_rand_clear(state); +} + + +short tab_fixed_classical_vs_waksman[] = { + -1, /* prec = 0 */ + -1, /* prec = 64 */ + -1, /* prec = 128 */ + -1, /* prec = 192 */ + -1, /* prec = 256 */ + 16, /* prec = 320 */ + 10, /* prec = 384 */ + 7, /* prec = 448 */ + 7, /* prec = 512 */ + 6, /* prec = 576 */ + 5, /* prec = 640 */ + 4, /* prec = 704 */ + 4, /* prec = 768 */ + 4, /* prec = 832 */ + 4, /* prec = 896 */ + 4, /* prec = 960 */ + 4, /* prec = 1024 */ + 4, /* prec = 1088 */ + 4, /* prec = 1152 */ + 4, /* prec = 1216 */ + 4, /* prec = 1280 */ + 3, /* prec = 1344 */ + 4, /* prec = 1408 */ + 3, /* prec = 1472 */ + 3, /* prec = 1536 */ + 3, /* prec = 1600 */ + 3, /* prec = 1664 */ + 3, /* prec = 1728 */ + 3, /* prec = 1792 */ + 3, /* prec = 1856 */ + 3, /* prec = 1920 */ + 3, /* prec = 1984 */ + 3, /* prec = 2048 */ + 3, /* prec = 2112 */ + 3, /* prec = 2176 */ + 3, /* prec = 2240 */ + 3, /* prec = 2304 */ + 3, /* prec = 2368 */ + 3, /* prec = 2432 */ + 3, /* prec = 2496 */ + 3, /* prec = 2560 */ + 3, /* prec = 2624 */ + 3, /* prec = 2688 */ + 3, /* prec = 2752 */ + 3, /* prec = 2816 */ + 3, /* prec = 2880 */ + 3, /* prec = 2944 */ + 3, /* prec = 3008 */ + 2, /* prec = 3072 */ + 3, /* prec = 3136 */ + 3, /* prec = 3200 */ + 2, /* prec = 3264 */ + 2, /* prec = 3328 */ + 2, /* prec = 3392 */ + 2, /* prec = 3456 */ + 2, /* prec = 3520 */ + 3, /* prec = 3584 */ + 2, /* prec = 3648 */ + 2, /* prec = 3712 */ + 2, /* prec = 3776 */ + 2, /* prec = 3840 */ + 2, /* prec = 3904 */ + 2, /* prec = 3968 */ + 2, /* prec = 4032 */ + 2, /* prec = 4096 */ + 2, /* prec = 4160 */ + 2, /* prec = 4224 */ +}; + + + +int main() +{ + int tab_fixed_classical_vs_waksman[TABN]; + int tab_strassen[TABN]; + + tune_strassen(tab_strassen); + + + tune_fixed_vs_waksman(tab_fixed_classical_vs_waksman); + +} diff --git a/src/nfloat/test/main.c b/src/nfloat/test/main.c index a4473c5693..077f81f91b 100644 --- a/src/nfloat/test/main.c +++ b/src/nfloat/test/main.c @@ -15,6 +15,7 @@ #include "t-addmul_submul.c" #include "t-complex_mat_mul.c" #include "t-mat_mul.c" +#include "t-nfixed_dot.c" #include "t-nfloat.c" #include "t-nfloat_complex.c" @@ -26,6 +27,7 @@ test_struct tests[] = TEST_FUNCTION(addmul_submul), TEST_FUNCTION(complex_mat_mul), TEST_FUNCTION(mat_mul), + TEST_FUNCTION(nfixed_dot), TEST_FUNCTION(nfloat), TEST_FUNCTION(nfloat_complex), }; diff --git a/src/nfloat/test/t-mat_mul.c b/src/nfloat/test/t-mat_mul.c index 5e7d8b29ac..d0849a193a 100644 --- a/src/nfloat/test/t-mat_mul.c +++ b/src/nfloat/test/t-mat_mul.c @@ -45,6 +45,24 @@ TEST_FUNCTION_START(mat_mul, state) gr_ctx_clear(ctx); } + for (iter = 0; iter < 100 * flint_test_multiplier(); iter++) + { + prec = 64; + + nfloat_ctx_init(ctx, prec, 0); + + tol = gr_heap_init(ctx); + GR_MUST_SUCCEED(gr_one(tol, ctx)); + GR_MUST_SUCCEED(gr_mul_2exp_si(tol, tol, -prec + 2, ctx)); + + gr_mat_test_approx_mul_max_norm( + (gr_method_mat_binary_op) nfloat_mat_mul_waksman, + tol, state, 10, 10, ctx); + + gr_heap_clear(tol, ctx); + gr_ctx_clear(ctx); + } + for (iter = 0; iter < 10 * flint_test_multiplier(); iter++) { if (n_randint(state, 5)) diff --git a/src/nfloat/test/t-nfixed_dot.c b/src/nfloat/test/t-nfixed_dot.c new file mode 100644 index 0000000000..9c980d98db --- /dev/null +++ b/src/nfloat/test/t-nfixed_dot.c @@ -0,0 +1,155 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. 
+ + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "test_helpers.h" +#include "fmpq.h" +#include "arf.h" +#include "gr_vec.h" +#include "gr_special.h" +#include "nfloat.h" + +#define MAXLEN 10 +#define MINLIMBS 2 +#define MAXLIMBS 8 + +FLINT_FORCE_INLINE +void nfixed_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) +{ + int asgn, bsgn; + asgn = a[0]; + bsgn = b[0]; + + if (asgn == bsgn) + { + res[0] = asgn; + mpn_add_n(res + 1, a + 1, b + 1, nlimbs); + } + else + { + res[0] = asgn ^ flint_mpn_signed_sub_n(res + 1, a + 1, b + 1, nlimbs); + } +} + +FLINT_FORCE_INLINE +void nfixed_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) +{ + int asgn, bsgn; + asgn = a[0]; + bsgn = b[0]; + + if (asgn != bsgn) + { + res[0] = asgn; + mpn_add_n(res + 1, a + 1, b + 1, nlimbs); + } + else + { + res[0] = asgn ^ flint_mpn_signed_sub_n(res + 1, a + 1, b + 1, nlimbs); + } +} + +FLINT_FORCE_INLINE +void nfixed_mul(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) +{ + int asgn, bsgn; + asgn = a[0]; + bsgn = b[0]; + res[0] = asgn ^ bsgn; + flint_mpn_mulhigh_n(res + 1, a + 1, b + 1, nlimbs); +} + +void _nfixed_dot_2(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_3(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); + +TEST_FUNCTION_START(nfixed_dot, state) +{ + slong iter, len, i, nlimbs; + nn_ptr a; + + ulong A[MAXLEN * (MAXLIMBS + 1)]; + ulong B[MAXLEN * (MAXLIMBS + 1)]; + ulong C[MAXLIMBS + 1]; + ulong D[MAXLIMBS + 1]; + ulong t[MAXLIMBS + 1]; + + for (iter = 0; iter < 10000 * flint_test_multiplier(); iter++) + { + len = 1 + n_randint(state, MAXLEN); + nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1); + + ulong maxerr = (2 * nlimbs - 1) * len; + + for (i = 0; i < len; i++) + { + a = A + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= 10; + + a = B + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= 10; + } + + switch (nlimbs) + { + case 2: + _nfixed_dot_2(C, A, nlimbs + 1, B, nlimbs + 1, len); + break; + case 3: + _nfixed_dot_3(C, A, nlimbs + 1, B, nlimbs + 1, len); + break; + case 4: + _nfixed_dot_4(C, A, nlimbs + 1, B, nlimbs + 1, len); + break; + case 5: + _nfixed_dot_5(C, A, nlimbs + 1, B, nlimbs + 1, len); + break; + case 6: + _nfixed_dot_6(C, A, nlimbs + 1, B, nlimbs + 1, len); + break; + case 7: + _nfixed_dot_7(C, A, nlimbs + 1, B, nlimbs + 1, len); + break; + case 8: + _nfixed_dot_8(C, A, nlimbs + 1, B, nlimbs + 1, len); + break; + default: + flint_abort(); + } + + flint_mpn_zero(D, nlimbs + 1); + + for (i = 0; i < len; i++) + { + nfixed_mul(t, A + i * (nlimbs + 1), B + i * (nlimbs + 1), nlimbs); + nfixed_add(D, D, t, nlimbs); + } + + nfixed_sub(t, C, D, nlimbs); + if 
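+        /* The tested dot product C and the naive reference D may differ
+           only in the bottom limb: each truncated mulhigh product is off
+           by at most on the order of nlimbs ulps, and len products are
+           accumulated on each side, hence the (2 * nlimbs - 1) * len
+           bound computed above. */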
(!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr) + { + TEST_FUNCTION_FAIL("nlimbs = %wd, len = %wd,\n\nA = %{ulong*},\n\nB = %{ulong*},\n\nC = %{ulong*},\n\nD = %{ulong*},\n\nt = %{ulong*}\n", nlimbs, len, + A, len * (nlimbs + 1), B, len * (nlimbs + 1), + C, nlimbs + 1, + D, nlimbs + 1, + t, nlimbs + 1); + } + } + + TEST_FUNCTION_END(state); +} \ No newline at end of file From 9293b5c1bf53a4e9ec9d5d29fc2478eb161fc7a9 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Tue, 3 Sep 2024 17:55:32 +0200 Subject: [PATCH 08/15] more matrix multiplication cleanup and improvements --- doc/source/nfloat.rst | 59 +- src/nfloat.h | 26 +- src/nfloat/mat_mul.c | 1446 +------------------------ src/nfloat/nfixed.c | 1406 ++++++++++++++++++++++++ src/nfloat/profile/p-mat_mul.c | 207 ++-- src/nfloat/profile/p-nfixed_mat_mul.c | 268 +++++ src/nfloat/test/main.c | 2 + src/nfloat/test/t-complex_mat_mul.c | 10 +- src/nfloat/test/t-mat_mul.c | 39 +- src/nfloat/test/t-nfixed_dot.c | 6 +- src/nfloat/test/t-nfixed_mat_mul.c | 101 ++ 11 files changed, 1986 insertions(+), 1584 deletions(-) create mode 100644 src/nfloat/nfixed.c create mode 100644 src/nfloat/profile/p-nfixed_mat_mul.c create mode 100644 src/nfloat/test/t-nfixed_mat_mul.c diff --git a/doc/source/nfloat.rst b/doc/source/nfloat.rst index 39d425ce52..92fb364736 100644 --- a/doc/source/nfloat.rst +++ b/doc/source/nfloat.rst @@ -317,8 +317,7 @@ code for reduced overhead. Matrix functions ------------------------------------------------------------------------------- -.. function:: int nfloat_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) - int nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) +.. function:: int nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_prec, gr_ctx_t ctx) int nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx) int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) @@ -416,11 +415,63 @@ real pairs. 
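A minimal sketch of calling the consolidated entry point documented
above (hypothetical standalone driver, not part of the patch; assumes
``gr_mat_one`` for input setup, and aborts via ``GR_MUST_SUCCEED`` if
the fixed-point path reports failure)::

    #include "gr.h"
    #include "gr_mat.h"
    #include "nfloat.h"

    int main(void)
    {
        gr_ctx_t ctx;
        gr_mat_t A, B, C;

        nfloat_ctx_init(ctx, 256, 0);   /* 256-bit real context */

        gr_mat_init(A, 4, 4, ctx);
        gr_mat_init(B, 4, 4, ctx);
        gr_mat_init(C, 4, 4, ctx);

        GR_MUST_SUCCEED(gr_mat_one(A, ctx));
        GR_MUST_SUCCEED(gr_mat_one(B, ctx));

        /* Allow up to 64 extra working bits for the fixed-point
           representation; nfloat_mat_mul uses the same bound for
           prec < 768. */
        GR_MUST_SUCCEED(nfloat_mat_mul_fixed(C, A, B, 64, ctx));

        gr_mat_clear(A, ctx);
        gr_mat_clear(B, ctx);
        gr_mat_clear(C, ctx);
        gr_ctx_clear(ctx);
        return 0;
    }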
int _nfloat_complex_vec_set(nfloat_complex_ptr res, nfloat_complex_srcptr x, slong len, gr_ctx_t ctx) int _nfloat_complex_vec_add(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx) int _nfloat_complex_vec_sub(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx) - int nfloat_complex_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) - int nfloat_complex_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) + int nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_prec, gr_ctx_t ctx) int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx) int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) int nfloat_complex_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx) int nfloat_complex_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx) int nfloat_complex_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx) + +Packed fixed-point arithmetic +------------------------------------------------------------------------------- + +A fixed-point number in the range `(-1,1)` with `n`-limb precision +is represented as `n+1` contiguous limbs as follows: + + +---------------+ + | sign limb | + +---------------+ + | mantissa[0] | + +---------------+ + | ... | + +---------------+ + | mantissa[n-1] | + +---------------+ + +In the following method signatures, ``nlimbs`` always refers to the +precision ``n`` while the storage is ``nlimbs + 1``. + +There is no overflow handling: all methods assume that inputs have +been scaled to a range `[-\varepsilon,\varepsilon]` so that all +intermediate results (including rounding errors) lie in `(-1,1)`. + +.. function:: void _nfixed_print(nn_srcptr x, slong nlimbs, slong exp) + + Print the fixed-point number + +.. function:: void _nfixed_vec_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) + void _nfixed_vec_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) + + Vectorized addition or subtraction of *len* fixed-point numbers. + +.. function:: void _nfixed_dot_2(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) + void _nfixed_dot_3(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) + void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) + void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) + void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) + void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) + void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) + + Dot product with a fixed number of limbs. The ``xstride`` and ``ystride`` parameters + indicate the offset in number of limbs between consecutive entries + and may be negative. + +.. 
function:: void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) + void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) + void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs) + void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) + + Matrix multiplication using various algorithms. + The *strassen* variant takes a *cutoff* parameter specifying where + to switch from basecase multiplication to Strassen multiplication. diff --git a/src/nfloat.h b/src/nfloat.h index 520ee8b430..d73c32bbc4 100644 --- a/src/nfloat.h +++ b/src/nfloat.h @@ -453,9 +453,7 @@ int _nfloat_vec_submul_scalar(nfloat_ptr res, nfloat_srcptr x, slong len, nfloat int _nfloat_vec_dot(nfloat_ptr res, nfloat_srcptr initial, int subtract, nfloat_srcptr x, nfloat_srcptr y, slong len, gr_ctx_t ctx); int _nfloat_vec_dot_rev(nfloat_ptr res, nfloat_srcptr initial, int subtract, nfloat_srcptr x, nfloat_srcptr y, slong len, gr_ctx_t ctx); -int nfloat_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); -int nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); -int nfloat_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong cutoff, gr_ctx_t ctx); +int nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_prec, gr_ctx_t ctx); int nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx); int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); @@ -569,8 +567,7 @@ int _nfloat_complex_vec_sub(nfloat_complex_ptr res, nfloat_complex_srcptr x, nfl int _nfloat_complex_vec_dot(nfloat_complex_ptr res, nfloat_complex_srcptr initial, int subtract, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx); int _nfloat_complex_vec_dot_rev(nfloat_complex_ptr res, nfloat_complex_srcptr initial, int subtract, nfloat_complex_srcptr x, nfloat_complex_srcptr y, slong len, gr_ctx_t ctx); -int nfloat_complex_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); -int nfloat_complex_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); +int nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_prec, gr_ctx_t ctx); int nfloat_complex_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_block_size, gr_ctx_t ctx); int nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); @@ -579,6 +576,25 @@ int nfloat_complex_mat_nonsingular_solve_tril(gr_mat_t X, const gr_mat_t L, cons int nfloat_complex_mat_nonsingular_solve_triu(gr_mat_t X, const gr_mat_t L, const gr_mat_t B, int unit, gr_ctx_t ctx); int nfloat_complex_mat_lu(slong * rank, slong * P, gr_mat_t LU, const gr_mat_t A, int rank_check, gr_ctx_t ctx); +/* Fixed-point arithmetic */ + +void _nfixed_print(nn_srcptr x, slong nlimbs, slong exp); +void _nfixed_vec_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs); +void _nfixed_vec_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs); + +void _nfixed_dot_2(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_3(nn_ptr res, nn_srcptr x, 
slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); +void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len); + +void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs); +void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs); +void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs); +void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs); + #ifdef __cplusplus } #endif diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index 2db4dc8648..99d2393f4e 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -20,1327 +20,7 @@ #include "gr_special.h" #include "fmpz_mat.h" -int -nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx); - - -/* For printing */ -#include "arf.h" - -/* Arithmetic on fixed-point numbers in (-1,1) */ -/* x[0] stores the sign bit, x[1], ..., x[n] store the absolute value */ - -void -nfixed_print(nn_srcptr x, slong nlimbs, slong exp) -{ - arf_t t; - arf_init(t); - _arf_set_mpn_fixed(t, x + 1, nlimbs, nlimbs, x[0], nlimbs * FLINT_BITS, ARF_RND_DOWN); - arf_mul_2exp_si(t, t, exp); - arf_printd(t, nlimbs * FLINT_BITS / 3.321928 + 1); - arf_clear(t); -} - -#define DEF_NFIXED_ADD(n) \ -FLINT_FORCE_INLINE \ -void nfixed_add_ ## n(nn_ptr res, nn_srcptr a, nn_srcptr b) \ -{ \ - int asgn, bsgn; \ - asgn = a[0]; \ - bsgn = b[0]; \ - \ - if (asgn == bsgn) \ - { \ - res[0] = asgn; \ - NN_ADD_ ## n(res + 1, a + 1, b + 1); \ - } \ - else \ - { \ - res[0] = asgn ^ flint_mpn_signed_sub_ ## n(res + 1, a + 1, b + 1); \ - } \ -} - -#define DEF_NFIXED_SUB(n) \ -FLINT_FORCE_INLINE \ -void nfixed_sub_ ## n(nn_ptr res, nn_srcptr a, nn_srcptr b) \ -{ \ - int asgn, bsgn; \ - asgn = a[0]; \ - bsgn = b[0]; \ - \ - if (asgn != bsgn) \ - { \ - res[0] = asgn; \ - NN_ADD_ ## n(res + 1, a + 1, b + 1); \ - } \ - else \ - { \ - res[0] = asgn ^ flint_mpn_signed_sub_ ## n(res + 1, a + 1, b + 1); \ - } \ -} - -DEF_NFIXED_ADD(2) -DEF_NFIXED_ADD(3) -DEF_NFIXED_ADD(4) -DEF_NFIXED_ADD(5) -DEF_NFIXED_ADD(6) -DEF_NFIXED_ADD(7) -DEF_NFIXED_ADD(8) - -DEF_NFIXED_SUB(2) -DEF_NFIXED_SUB(3) -DEF_NFIXED_SUB(4) -DEF_NFIXED_SUB(5) -DEF_NFIXED_SUB(6) -DEF_NFIXED_SUB(7) -DEF_NFIXED_SUB(8) - -FLINT_FORCE_INLINE -void nfixed_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) -{ - int asgn, bsgn; - asgn = a[0]; - bsgn = b[0]; - - if (asgn == bsgn) - { - res[0] = asgn; - mpn_add_n(res + 1, a + 1, b + 1, nlimbs); - } - else - { - res[0] = asgn ^ flint_mpn_signed_sub_n(res + 1, a + 1, b + 1, nlimbs); - } -} - -FLINT_FORCE_INLINE -void nfixed_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) -{ - int asgn, bsgn; - asgn = a[0]; - bsgn = b[0]; - - if (asgn != bsgn) - { - res[0] = asgn; - mpn_add_n(res + 1, a + 1, b + 1, nlimbs); - } - else - { - res[0] = asgn ^ flint_mpn_signed_sub_n(res + 1, a + 1, b + 1, nlimbs); - } -} - -static -void _nfixed_vec_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) -{ - slong i; - - if (nlimbs 
== 2) - { - for (i = 0; i < len; i++) - nfixed_add_2(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 3) - { - for (i = 0; i < len; i++) - nfixed_add_3(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 4) - { - for (i = 0; i < len; i++) - nfixed_add_4(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 5) - { - for (i = 0; i < len; i++) - nfixed_add_5(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 6) - { - for (i = 0; i < len; i++) - nfixed_add_6(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 7) - { - for (i = 0; i < len; i++) - nfixed_add_7(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 8) - { - for (i = 0; i < len; i++) - nfixed_add_8(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else - { - for (i = 0; i < len; i++) - nfixed_add(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); - } -} - -static -void _nfixed_vec_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) -{ - slong i; - - if (nlimbs == 2) - { - for (i = 0; i < len; i++) - nfixed_sub_2(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 3) - { - for (i = 0; i < len; i++) - nfixed_sub_3(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 4) - { - for (i = 0; i < len; i++) - nfixed_sub_4(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 5) - { - for (i = 0; i < len; i++) - nfixed_sub_5(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 6) - { - for (i = 0; i < len; i++) - nfixed_sub_6(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 7) - { - for (i = 0; i < len; i++) - nfixed_sub_7(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else if (nlimbs == 8) - { - for (i = 0; i < len; i++) - nfixed_sub_8(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); - } - else - { - for (i = 0; i < len; i++) - nfixed_sub(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); - } -} - -FLINT_FORCE_INLINE -void nfixed_mul(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) -{ - int asgn, bsgn; - asgn = a[0]; - bsgn = b[0]; - res[0] = asgn ^ bsgn; - flint_mpn_mulhigh_n(res + 1, a + 1, b + 1, nlimbs); -} - -FLINT_FORCE_INLINE -void nfixed_sqr(nn_ptr res, nn_srcptr a, slong nlimbs) -{ - res[0] = 0; - flint_mpn_sqrhigh(res + 1, a + 1, nlimbs); -} - -FLINT_FORCE_INLINE -void nfixed_div2(nn_ptr res, nn_srcptr a, slong nlimbs) -{ - res[0] = a[0]; - mpn_rshift(res + 1, a + 1, nlimbs, 1); -} - -void _nfixed_dot_2(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) -{ - slong j; - - ulong as, bs, a0, a1, b0, b1, s0, s1, t0, t1, u0, u1, hi, lo; - - s0 = s1 = t0 = t1 = u0 = u1 = 0; - - /* - s0 s1 - |a1 b1----| - u0 u1 - |a1 b0----| - t0 t1 - |a0 b1----| - */ - - for (j = 0; j < len; j++) - { - as = x[j * xstride]; - a0 = x[j * xstride + 1]; - a1 = x[j * xstride + 2]; - bs = y[j * ystride]; - b0 = y[j * ystride + 1]; - b1 = y[j * ystride + 2]; - - if (as == bs) - { - umul_ppmm(hi, lo, a1, b1); - add_ssaaaa(s1, s0, s1, s0, hi, lo); - umul_ppmm(hi, lo, a0, b1); - add_ssaaaa(t1, t0, t1, t0, 0, hi); - umul_ppmm(hi, lo, a1, b0); - 
add_ssaaaa(u1, u0, u1, u0, 0, hi); - } - else - { - umul_ppmm(hi, lo, a1, b1); - sub_ddmmss(s1, s0, s1, s0, hi, lo); - umul_ppmm(hi, lo, a0, b1); - sub_ddmmss(t1, t0, t1, t0, 0, hi); - umul_ppmm(hi, lo, a1, b0); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - } - } - - add_ssaaaa(s1, s0, s1, s0, t1, t0); - add_ssaaaa(s1, s0, s1, s0, u1, u0); - - if ((slong) s1 < WORD(0)) - { - sub_ddmmss(s1, s0, 0, 0, s1, s0); - res[0] = 1; - } - else - { - res[0] = 0; - } - - res[1] = s0; - res[2] = s1; -} - -void _nfixed_dot_3(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) -{ - slong j; - - ulong as, bs, a0, a1, a2, b0, b1, b2, s0, s1, s2, hi, lo; - ulong u0, u1, v0, v1; - - s0 = s1 = s2 = 0; - u0 = u1 = v0 = v1 = 0; - - for (j = 0; j < len; j++) - { - as = x[j * xstride]; - a0 = x[j * xstride + 1]; - a1 = x[j * xstride + 2]; - a2 = x[j * xstride + 3]; - bs = y[j * ystride]; - b0 = y[j * ystride + 1]; - b1 = y[j * ystride + 2]; - b2 = y[j * ystride + 3]; - - /* - |a2 b2----| - - |a1 b2----| - |a2 b1----| - - |a0 b2----| - |a1 b1----| - |a2 b0----| - */ - - if (as == bs) - { - umul_ppmm(hi, lo, a0, b2); - add_ssaaaa(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a1, b1); - add_ssaaaa(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a2, b0); - add_ssaaaa(u1, u0, u1, u0, 0, hi); - - umul_ppmm(hi, lo, a2, b2); - add_ssaaaa(s2, s1, s2, s1, hi, lo); - - umul_ppmm(hi, lo, a1, b2); - add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, hi, lo); - umul_ppmm(hi, lo, a2, b1); - add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, hi, lo); - } - else - { - umul_ppmm(hi, lo, a0, b2); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a1, b1); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a2, b0); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - - umul_ppmm(hi, lo, a2, b2); - sub_ddmmss(s2, s1, s2, s1, hi, lo); - - umul_ppmm(hi, lo, a1, b2); - sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, hi, lo); - umul_ppmm(hi, lo, a2, b1); - sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, hi, lo); - } - } - - if ((slong) u1 < WORD(0)) - { - sub_ddmmss(u1, u0, 0, 0, u1, u0); - sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, u1, u0); - } - else - { - add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, u1, u0); - } - - if ((slong) s2 < WORD(0)) - { - sub_dddmmmsss(s2, s1, s0, 0, 0, 0, s2, s1, s0); - res[0] = 1; - } - else - { - res[0] = 0; - } - - - res[1] = s0; - res[2] = s1; - res[3] = s2; -} - -void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) -{ - slong j; - - /* - s1 s2 s3 - |a3 b3----| - - |a2 b3----| - |a3 b2----| - - t0 t1 t2 - - |a1 b3----| - |a2 b2----| - |a3 b1----| - - u0 u1 - - |a0 b3----| - |a1 b2----| - |a2 b1----| - |a3 b0----| - - */ - - ulong as, a0, a1, a2, a3; - ulong bs, b0, b1, b2, b3; - ulong s0, s1, s2, s3, t0, t1, t2, u0, u1; - ulong hi, lo; - - s0 = s1 = s2 = s3 = t0 = t1 = t2 = u0 = u1 = 0; - - for (j = 0; j < len; j++) - { - as = x[j * xstride]; - a0 = x[j * xstride + 1]; - a1 = x[j * xstride + 2]; - a2 = x[j * xstride + 3]; - a3 = x[j * xstride + 4]; - - bs = y[j * ystride]; - b0 = y[j * ystride + 1]; - b1 = y[j * ystride + 2]; - b2 = y[j * ystride + 3]; - b3 = y[j * ystride + 4]; - - if (as == bs) - { - umul_ppmm(hi, lo, a3, b3); - add_ssaaaa(s3, s2, s3, s2, hi, lo); - - umul_ppmm(hi, lo, a2, b3); - add_sssaaaaaa(s3, s2, s1, s3, s2, s1, 0, hi, lo); - umul_ppmm(hi, lo, a3, b2); - add_sssaaaaaa(s3, s2, s1, s3, s2, s1, 0, hi, lo); - - umul_ppmm(hi, lo, a1, b3); - add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); - umul_ppmm(hi, lo, a2, b2); - add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); 
- umul_ppmm(hi, lo, a3, b1); - add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); - - umul_ppmm(hi, lo, a0, b3); - add_ssaaaa(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a1, b2); - add_ssaaaa(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a2, b1); - add_ssaaaa(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a3, b0); - add_ssaaaa(u1, u0, u1, u0, 0, hi); - } - else - { - umul_ppmm(hi, lo, a3, b3); - sub_ddmmss(s3, s2, s3, s2, hi, lo); - - umul_ppmm(hi, lo, a2, b3); - sub_dddmmmsss(s3, s2, s1, s3, s2, s1, 0, hi, lo); - umul_ppmm(hi, lo, a3, b2); - sub_dddmmmsss(s3, s2, s1, s3, s2, s1, 0, hi, lo); - - umul_ppmm(hi, lo, a1, b3); - sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); - umul_ppmm(hi, lo, a2, b2); - sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); - umul_ppmm(hi, lo, a3, b1); - sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); - - umul_ppmm(hi, lo, a0, b3); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a1, b2); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a2, b1); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - umul_ppmm(hi, lo, a3, b0); - sub_ddmmss(u1, u0, u1, u0, 0, hi); - } - } - - if ((slong) u1 < WORD(0)) - { - sub_ddmmss(u1, u0, 0, 0, u1, u0); - sub_dddmmmsss(t2, t1, s0, t2, t1, s0, 0, u1, u0); - } - else - { - add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, u1, u0); - } - - if ((slong) t2 < WORD(0)) - { - sub_dddmmmsss(t2, t1, t0, 0, 0, 0, t2, t1, t0); - sub_ddddmmmmssss(s3, s2, s1, s0, s3, s2, s1, s0, 0, t2, t1, t0); - } - else - { - add_ssssaaaaaaaa(s3, s2, s1, s0, s3, s2, s1, s0, 0, t2, t1, t0); - } - - if ((slong) s3 < WORD(0)) - { - sub_ddddmmmmssss(s3, s2, s1, s0, 0, 0, 0, 0, s3, s2, s1, s0); - res[0] = 1; - } - else - { - res[0] = 0; - } - - res[1] = s0; - res[2] = s1; - res[3] = s2; - res[4] = s3; -} - -void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) -{ - slong j; - slong nlimbs = 5; - - ulong tmp[6]; - ulong spos[6] = { 0, 0, 0, 0, 0, 0 }; - ulong sneg[6] = { 0, 0, 0, 0, 0, 0 }; - - if (x[0] == y[0]) - flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); - else - flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); - - for (j = 1; j < len; j++) - { - nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); - - if (tmp[0] == 0) - NN_ADD_5(spos + 1, spos + 1, tmp + 1); - else - NN_ADD_5(sneg + 1, sneg + 1, tmp + 1); - } - - nfixed_sub_5(res, spos, sneg); -} - -void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) -{ - slong j; - slong nlimbs = 6; - - ulong tmp[7]; - ulong spos[7] = { 0, 0, 0, 0, 0, 0, 0 }; - ulong sneg[7] = { 0, 0, 0, 0, 0, 0, 0 }; - - if (x[0] == y[0]) - flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); - else - flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); - - for (j = 1; j < len; j++) - { - nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); - - if (tmp[0] == 0) - NN_ADD_6(spos + 1, spos + 1, tmp + 1); - else - NN_ADD_6(sneg + 1, sneg + 1, tmp + 1); - } - - nfixed_sub_6(res, spos, sneg); -} - -void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) -{ - slong j; - slong nlimbs = 7; - - ulong tmp[8]; - ulong spos[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; - ulong sneg[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; - - if (x[0] == y[0]) - flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); - else - flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); - - for (j = 1; j < len; j++) - { - nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); - - if (tmp[0] == 0) - NN_ADD_7(spos + 1, spos + 1, tmp + 1); - else - 
NN_ADD_7(sneg + 1, sneg + 1, tmp + 1); - } - - nfixed_sub_7(res, spos, sneg); -} - -void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) -{ - slong j; - slong nlimbs = 8; - - ulong tmp[9]; - ulong spos[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - ulong sneg[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; - - if (x[0] == y[0]) - flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); - else - flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); - - for (j = 1; j < len; j++) - { - nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); - - if (tmp[0] == 0) - NN_ADD_8(spos + 1, spos + 1, tmp + 1); - else - NN_ADD_8(sneg + 1, sneg + 1, tmp + 1); - } - - nfixed_sub_8(res, spos, sneg); -} - -/* A is (m x n), B is (n x p), C is (m x p) */ -void -_nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) -{ - slong i, j, k; - nn_ptr t; - -#define A_ENTRY(i, j) ((A) + ((i) * n + (j)) * (nlimbs + 1)) -#define B_ENTRY(i, j) ((B) + ((i) * p + (j)) * (nlimbs + 1)) -#define C_ENTRY(i, j) ((C) + ((i) * p + (j)) * (nlimbs + 1)) - - if (n == 1) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); - return; - } - - if (nlimbs == 2) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_2(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); - } - else if (nlimbs == 3) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_3(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); - } - else if (nlimbs == 4) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_4(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); - } - else if (nlimbs == 5) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); - } - else if (nlimbs == 6) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_6(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); - } - else if (nlimbs == 7) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_7(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); - } - else if (nlimbs == 8) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_8(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); - } - else - { - TMP_INIT; - TMP_START; - - t = TMP_ALLOC((nlimbs + 1) * sizeof(ulong)); - - for (i = 0; i < m; i++) - { - for (j = 0; j < p; j++) - { - nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); - - for (k = 1; k < n; k++) - { - nfixed_mul(t, A_ENTRY(i, k), B_ENTRY(k, j), nlimbs); - nfixed_add(C_ENTRY(i, j), C_ENTRY(i, j), t, nlimbs); - } - } - } - - TMP_END; - } - -#undef A_ENTRY -#undef B_ENTRY -#undef C_ENTRY -} - -/* compute c += (a1 + b1) * (a2 + b2) */ -/* val0, val1, val2 are scratch space */ -FLINT_FORCE_INLINE void -addmul_addadd(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2, slong nlimbs) -{ - nfixed_add(val1, a1, b1, nlimbs); - nfixed_add(val2, a2, b2, nlimbs); - nfixed_mul(val0, val1, val2, nlimbs); - nfixed_add(c, c, val0, nlimbs); -} - -/* compute c += (a1 - b1) * (a2 - b2) */ -/* val0, val1, val2 are scratch space */ -FLINT_FORCE_INLINE void -addmul_subsub(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, 
nn_srcptr b2, slong nlimbs) -{ - nfixed_sub(val1, a1, b1, nlimbs); - nfixed_sub(val2, a2, b2, nlimbs); - nfixed_mul(val0, val1, val2, nlimbs); - nfixed_add(c, c, val0, nlimbs); -} - -FLINT_FORCE_INLINE void -addmul_addadd_4(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) -{ - nfixed_add_4(val1, a1, b1); - nfixed_add_4(val2, a2, b2); - nfixed_mul(val0, val1, val2, 4); - nfixed_add_4(c, c, val0); -} - -FLINT_FORCE_INLINE void -addmul_addadd_5(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) -{ - nfixed_add_5(val1, a1, b1); - nfixed_add_5(val2, a2, b2); - nfixed_mul(val0, val1, val2, 5); - nfixed_add_5(c, c, val0); -} - -FLINT_FORCE_INLINE void -addmul_addadd_6(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) -{ - nfixed_add_6(val1, a1, b1); - nfixed_add_6(val2, a2, b2); - nfixed_mul(val0, val1, val2, 6); - nfixed_add_6(c, c, val0); -} - -FLINT_FORCE_INLINE void -addmul_addadd_7(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) -{ - nfixed_add_7(val1, a1, b1); - nfixed_add_7(val2, a2, b2); - nfixed_mul(val0, val1, val2, 7); - nfixed_add_7(c, c, val0); -} - -FLINT_FORCE_INLINE void -addmul_addadd_8(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) -{ - nfixed_add_8(val1, a1, b1); - nfixed_add_8(val2, a2, b2); - nfixed_mul(val0, val1, val2, 8); - nfixed_add_8(c, c, val0); -} - -void -_nfixed_mat_mul_waksman2(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs, slong Cstride, slong Astride, slong Bstride) -{ - slong l, j, k; - - slong np = n >> 1; - - nn_ptr Ctmp = flint_calloc((nlimbs + 1) * ((p + m) + 5), sizeof(ulong)); - - /* remaining temp space */ - nn_ptr Crow = Ctmp; /* Crow has p entries */ - nn_ptr Ccol = Crow + (nlimbs + 1) * p; /* Ccol has m entries */ - nn_ptr val0 = Ccol + (nlimbs + 1) * m; /* val0 has room for 2 sums */ - nn_ptr val1 = val0 + (nlimbs + 1) * 2; /* val1 has room for 1 sum */ - nn_ptr val2 = val1 + (nlimbs + 1); /* val2 has room for 1 sum */ - nn_ptr crow = val2 + (nlimbs + 1); /* crow has room for 1 sum */ - -#define A_ENTRY(i, j) ((A) + (i) * Astride + (j) * (nlimbs + 1)) -#define B_ENTRY(i, j) ((B) + (i) * Bstride + (j) * (nlimbs + 1)) -#define C_ENTRY(i, j) ((C) + (i) * Cstride + (j) * (nlimbs + 1)) - -#define Crow_ENTRY(ii) (Crow + (ii) * (nlimbs + 1)) -#define Ccol_ENTRY(ii) (Ccol + (ii) * (nlimbs + 1)) - - for (j = 1; j <= np; j++) - { - slong j2 = (j << 1) - 1; - - for (k = 0; k < p; k++) - { - addmul_addadd(val0, val1, val2, C_ENTRY(0, k), A_ENTRY(0, j2-1), B_ENTRY(j2, k), A_ENTRY(0, j2), B_ENTRY(j2-1, k), nlimbs); - addmul_subsub(val0, val1, val2, Crow_ENTRY(k), A_ENTRY(0, j2-1), B_ENTRY(j2, k), A_ENTRY(0, j2), B_ENTRY(j2-1, k), nlimbs); - } - - for (l = 1; l < m; l++) - { - addmul_addadd(val0, val1, val2, C_ENTRY(l, 0), A_ENTRY(l, j2-1), B_ENTRY(j2, 0), A_ENTRY(l, j2), B_ENTRY(j2-1, 0), nlimbs); - addmul_subsub(val0, val1, val2, Ccol_ENTRY(l), A_ENTRY(l, j2-1), B_ENTRY(j2, 0), A_ENTRY(l, j2), B_ENTRY(j2-1, 0), nlimbs); - } - - /* - Note: for nlimbs <= 4 a significant speedup is possible by reordering the loops - so that the O(n^3) part of the algorithm is performed as dot products. - However, classical multiplication still wins over Waksman in that range, - so we do not bother. 
- */ - if (nlimbs == 5) - { - for (k = 1; k < p; k++) - for (l = 1; l < m; l++) - addmul_addadd_5(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); - } - else if (nlimbs == 6) - { - for (k = 1; k < p; k++) - for (l = 1; l < m; l++) - addmul_addadd_6(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); - } - else if (nlimbs == 7) - { - for (k = 1; k < p; k++) - for (l = 1; l < m; l++) - addmul_addadd_7(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); - } - else if (nlimbs == 8) - { - for (k = 1; k < p; k++) - for (l = 1; l < m; l++) - addmul_addadd_8(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); - } - else - { - for (k = 1; k < p; k++) - for (l = 1; l < m; l++) - addmul_addadd(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k), nlimbs); - } - } - - for (l = 1; l < m; l++) - { - nfixed_add(val1, Ccol_ENTRY(l), C_ENTRY(l, 0), nlimbs); - nfixed_div2(Ccol_ENTRY(l), val1, nlimbs); - nfixed_sub(C_ENTRY(l, 0), C_ENTRY(l, 0), Ccol_ENTRY(l), nlimbs); - } - - nfixed_add(val1, Crow, C_ENTRY(0, 0), nlimbs); - nfixed_div2(val0, val1, nlimbs); - nfixed_sub(C_ENTRY(0, 0), C_ENTRY(0, 0), val0, nlimbs); - - for (k = 1; k < p; k++) - { - nfixed_add(crow, Crow_ENTRY(k), C_ENTRY(0, k), nlimbs); - nfixed_div2(val1, crow, nlimbs); - nfixed_sub(C_ENTRY(0, k), C_ENTRY(0, k), val1, nlimbs); - nfixed_sub(crow, val1, val0, nlimbs); - - for (l = 1; l < m; l++) - { - nfixed_sub(val2, C_ENTRY(l, k), crow, nlimbs); - nfixed_sub(C_ENTRY(l, k), val2, Ccol_ENTRY(l), nlimbs); - } - } - - if ((n & 1) == 1) - { - for (l = 0; l < m; l++) - { - for (k = 0; k < p; k++) - { - nfixed_mul(val0, A_ENTRY(l, n-1), B_ENTRY(n-1, k), nlimbs); - nfixed_add(C_ENTRY(l, k), C_ENTRY(l, k), val0, nlimbs); - } - } - } - - flint_free(Ctmp); - -#undef A_ENTRY -#undef B_ENTRY -#undef C_ENTRY -} - -void -_nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) -{ - _nfixed_mat_mul_waksman2(C, A, B, m, n, p, nlimbs, p * (nlimbs + 1), n * (nlimbs + 1), p * (nlimbs + 1)); -} - - -typedef struct -{ - nn_ptr start; - slong r; - slong c; - slong row_stride; -} -_nfixed_mat_struct; - -typedef _nfixed_mat_struct _nfixed_mat_t[1]; - -static void -_nfixed_mat_init(_nfixed_mat_t A, slong r, slong c, slong nlimbs) -{ - A->start = flint_malloc((nlimbs + 1) * (r * c) * sizeof(ulong)); - A->r = r; - A->c = c; - A->row_stride = c * (nlimbs + 1); -} - -static void -_nfixed_mat_clear(_nfixed_mat_t A, slong nlimbs) -{ - flint_free(A->start); -} - -static void -_nfixed_mat_window_init(_nfixed_mat_t A, const _nfixed_mat_t mat, slong r1, slong c1, slong r2, slong c2, slong nlimbs) -{ - A->start = mat->start + (r1 * mat->row_stride) + c1 * (nlimbs + 1); - A->r = r2 - r1; - A->c = c2 - c1; - A->row_stride = mat->row_stride; -} - -static void -_nfixed_mat_window_clear(_nfixed_mat_t A, slong nlimbs) -{ -} - -/* -static void -nfixed_mat_print(nn_ptr A, slong ar, slong ac, slong nlimbs) -{ - slong i, j; - flint_printf("{%wd y %wd : [", ar, ac); - for (i = 0; i < ar; i++) - for (j = 0; j < ac; j++) - { - nfixed_print(A + i * ac * (nlimbs + 1) + j * (nlimbs + 1), nlimbs, 0); - flint_printf(", "); - } - - flint_printf("]}\n"); -} - -static void -_nfixed_mat_print(_nfixed_mat_t A, slong nlimbs) -{ - slong i, j; - flint_printf("{%wd x %wd : [", A->r, A->c); - for (i = 0; i < A->r; 
i++) - for (j = 0; j < A->c; j++) - { - nfixed_print(A->start + i * A->row_stride + j * (nlimbs + 1), nlimbs, 0); - flint_printf(", "); - } - - flint_printf("]}\n"); -} -*/ - -static void -_nfixed_mat_add(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) -{ - nn_srcptr Aptr, Bptr; - nn_ptr Cptr; - - Aptr = A->start; - Bptr = B->start; - Cptr = C->start; - - slong Astride = A->row_stride; - slong Bstride = B->row_stride; - slong Cstride = C->row_stride; - - slong i, r = A->r, c = A->c; - - for (i = 0; i < r; i++) - _nfixed_vec_add(Cptr + i * Cstride, Aptr + i * Astride, Bptr + i * Bstride, c, nlimbs); -} - -static void -_nfixed_mat_sub(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) -{ - nn_srcptr Aptr, Bptr; - nn_ptr Cptr; - - Aptr = A->start; - Bptr = B->start; - Cptr = C->start; - - slong Astride = A->row_stride; - slong Bstride = B->row_stride; - slong Cstride = C->row_stride; - - slong i, r = A->r, c = A->c; - - for (i = 0; i < r; i++) - _nfixed_vec_sub(Cptr + i * Cstride, Aptr + i * Astride, Bptr + i * Bstride, c, nlimbs); -} - -void -_nfixed_mat_mul_waksman3(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) -{ - nn_srcptr Aptr, Bptr; - nn_ptr Cptr; - - Aptr = A->start; - Bptr = B->start; - Cptr = C->start; - - slong Astride = A->row_stride; - slong Bstride = B->row_stride; - slong Cstride = C->row_stride; - - slong m = A->r; - slong n = A->c; - slong p = B->c; - - _nfixed_mat_mul_waksman2(Cptr, Aptr, Bptr, m, n, p, nlimbs, Cstride, Astride, Bstride); -} - -static void -_nfixed_mat_mul_classical2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) -{ - nn_srcptr Aptr, Bptr; - nn_ptr Cptr; - - Aptr = A->start; - Bptr = B->start; - Cptr = C->start; - - slong Astride = A->row_stride; - slong Bstride = B->row_stride; - slong Cstride = C->row_stride; - - slong m = A->r; - slong n = A->c; - slong p = B->c; - - slong i, j, k; - nn_ptr t; - -#define A_ENTRY(i, j) ((Aptr) + (i) * Astride + (j) * (nlimbs + 1)) -#define B_ENTRY(i, j) ((Bptr) + (i) * Bstride + (j) * (nlimbs + 1)) -#define C_ENTRY(i, j) ((Cptr) + (i) * Cstride + (j) * (nlimbs + 1)) - - if (n == 1) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); - return; - } - - if (nlimbs == 2) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_2(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); - } - else if (nlimbs == 3) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_3(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); - } - else if (nlimbs == 4) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_4(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); - } - else if (nlimbs == 5) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); - } - else if (nlimbs == 6) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_6(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); - } - else if (nlimbs == 7) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_7(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); - } - else if (nlimbs == 8) - { - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfixed_dot_8(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); - } - else - { - TMP_INIT; - 
TMP_START; - - t = TMP_ALLOC((nlimbs + 1) * sizeof(ulong)); - - for (i = 0; i < m; i++) - { - for (j = 0; j < p; j++) - { - nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); - - for (k = 1; k < n; k++) - { - nfixed_mul(t, A_ENTRY(i, k), B_ENTRY(k, j), nlimbs); - nfixed_add(C_ENTRY(i, j), C_ENTRY(i, j), t, nlimbs); - } - } - } - - TMP_END; - } - -#undef A_ENTRY -#undef B_ENTRY -#undef C_ENTRY -} - - -static void -_nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong cutoff, slong nlimbs) -{ - slong ar, ac, bc; - slong anr, anc, bnr, bnc; - - _nfixed_mat_t A11, A12, A21, A22; - _nfixed_mat_t B11, B12, B21, B22; - _nfixed_mat_t C11, C12, C21, C22; - _nfixed_mat_t X1, X2; - - ar = A->r; - ac = A->c; - bc = B->c; - - cutoff = FLINT_MAX(cutoff, 2); - - if (ar < cutoff || ac < cutoff || bc < cutoff) - { - if (nlimbs <= 5) - _nfixed_mat_mul_classical2(C, A, B, nlimbs); - else - _nfixed_mat_mul_waksman3(C, A, B, nlimbs); - return; - } - - anr = ar / 2; - anc = ac / 2; - bnr = anc; - bnc = bc / 2; - - _nfixed_mat_window_init(A11, A, 0, 0, anr, anc, nlimbs); - _nfixed_mat_window_init(A12, A, 0, anc, anr, 2 * anc, nlimbs); - _nfixed_mat_window_init(A21, A, anr, 0, 2 * anr, anc, nlimbs); - _nfixed_mat_window_init(A22, A, anr, anc, 2 * anr, 2 * anc, nlimbs); - - _nfixed_mat_window_init(B11, B, 0, 0, bnr, bnc, nlimbs); - _nfixed_mat_window_init(B12, B, 0, bnc, bnr, 2 * bnc, nlimbs); - _nfixed_mat_window_init(B21, B, bnr, 0, 2 * bnr, bnc, nlimbs); - _nfixed_mat_window_init(B22, B, bnr, bnc, 2 * bnr, 2 * bnc, nlimbs); - - _nfixed_mat_window_init(C11, C, 0, 0, anr, bnc, nlimbs); - _nfixed_mat_window_init(C12, C, 0, bnc, anr, 2 * bnc, nlimbs); - _nfixed_mat_window_init(C21, C, anr, 0, 2 * anr, bnc, nlimbs); - _nfixed_mat_window_init(C22, C, anr, bnc, 2 * anr, 2 * bnc, nlimbs); - - _nfixed_mat_init(X1, anr, FLINT_MAX(bnc, anc), nlimbs); - _nfixed_mat_init(X2, anc, bnc, nlimbs); - - X1->c = anc; - - _nfixed_mat_add(X1, A22, A12, nlimbs); - _nfixed_mat_add(X2, B22, B12, nlimbs); - _nfixed_mat_mul_strassen2(C21, X1, X2, cutoff, nlimbs); - - _nfixed_mat_sub(X1, A22, A21, nlimbs); - _nfixed_mat_sub(X2, B22, B21, nlimbs); - _nfixed_mat_mul_strassen2(C22, X1, X2, cutoff, nlimbs); - - _nfixed_mat_add(X1, X1, A12, nlimbs); - _nfixed_mat_add(X2, X2, B12, nlimbs); - _nfixed_mat_mul_strassen2(C11, X1, X2, cutoff, nlimbs); - - _nfixed_mat_sub(X1, X1, A11, nlimbs); - _nfixed_mat_mul_strassen2(C12, X1, B12, cutoff, nlimbs); - - X1->c = bnc; - _nfixed_mat_mul_strassen2(X1, A12, B21, cutoff, nlimbs); - _nfixed_mat_add(C11, C11, X1, nlimbs); - _nfixed_mat_add(C12, C12, C22, nlimbs); - _nfixed_mat_sub(C12, C11, C12, nlimbs); - _nfixed_mat_sub(C11, C21, C11, nlimbs); - _nfixed_mat_sub(X2, X2, B11, nlimbs); - _nfixed_mat_mul_strassen2(C21, A21, X2, cutoff, nlimbs); - - _nfixed_mat_clear(X2, nlimbs); - - _nfixed_mat_sub(C21, C11, C21, nlimbs); - _nfixed_mat_add(C22, C22, C11, nlimbs); - _nfixed_mat_mul_strassen2(C11, A11, B11, cutoff, nlimbs); - - _nfixed_mat_add(C11, X1, C11, nlimbs); - - X1->c = FLINT_MAX(bnc, anc); - _nfixed_mat_clear(X1, nlimbs); - - _nfixed_mat_window_clear(A11, nlimbs); - _nfixed_mat_window_clear(A12, nlimbs); - _nfixed_mat_window_clear(A21, nlimbs); - _nfixed_mat_window_clear(A22, nlimbs); - - _nfixed_mat_window_clear(B11, nlimbs); - _nfixed_mat_window_clear(B12, nlimbs); - _nfixed_mat_window_clear(B21, nlimbs); - _nfixed_mat_window_clear(B22, nlimbs); - - _nfixed_mat_window_clear(C11, nlimbs); - _nfixed_mat_window_clear(C12, nlimbs); - 
_nfixed_mat_window_clear(C21, nlimbs); - _nfixed_mat_window_clear(C22, nlimbs); - - if (bc > 2 * bnc) - { - _nfixed_mat_t Bc, Cc; - _nfixed_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, nlimbs); - _nfixed_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, nlimbs); - _nfixed_mat_mul_strassen2(Cc, A, Bc, cutoff, nlimbs); - _nfixed_mat_window_clear(Bc, nlimbs); - _nfixed_mat_window_clear(Cc, nlimbs); - } - - if (ar > 2 * anr) - { - _nfixed_mat_t Ar, Cr; - _nfixed_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, nlimbs); - _nfixed_mat_window_init(Cr, C, 2 * anr, 0, ar, bc, nlimbs); - _nfixed_mat_mul_strassen2(Cr, Ar, B, cutoff, nlimbs); - _nfixed_mat_window_clear(Ar, nlimbs); - _nfixed_mat_window_clear(Cr, nlimbs); - } - - if (ac > 2 * anc) - { - _nfixed_mat_t Ac, Br, Cb, tmp; - slong mt, nt; - - _nfixed_mat_window_init(Ac, A, 0, 2 * anc, 2 * anr, ac, nlimbs); - _nfixed_mat_window_init(Br, B, 2 * bnr, 0, ac, 2 * bnc, nlimbs); - _nfixed_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * bnc, nlimbs); - - mt = Ac->r; - nt = Br->c; - - /* todo: faster */ - _nfixed_mat_init(tmp, mt, nt, nlimbs); - _nfixed_mat_mul_strassen2(tmp, Ac, Br, cutoff, nlimbs); - _nfixed_mat_add(Cb, Cb, tmp, nlimbs); - _nfixed_mat_clear(tmp, nlimbs); - _nfixed_mat_window_clear(Ac, nlimbs); - _nfixed_mat_window_clear(Br, nlimbs); - _nfixed_mat_window_clear(Cb, nlimbs); - } -} - -void -_nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs) -{ - _nfixed_mat_t CC, AA, BB; - - AA->start = (nn_ptr) A; - AA->r = m; - AA->c = n; - AA->row_stride = n * (nlimbs + 1); - - BB->start = (nn_ptr) B; - BB->r = n; - BB->c = p; - BB->row_stride = p * (nlimbs + 1); - - CC->start = C; - CC->r = m; - CC->c = p; - CC->row_stride = p * (nlimbs + 1); - - _nfixed_mat_mul_strassen2(CC, AA, BB, cutoff, nlimbs); -} +/* todo: retune classical -> fixed -> block cutoffs */ FLINT_FORCE_INLINE void _nfloat_get_nfixed(nn_ptr res, nn_srcptr x, slong exp, slong fix_nlimbs, gr_ctx_t ctx) @@ -1392,7 +72,7 @@ _nfloat_mat_exp_range(slong * _Amin, slong * _Amax, const gr_mat_t A, gr_ctx_t c } int -_nfloat_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong Aexp, slong Bexp, slong fnlimbs, int waksman, gr_ctx_t ctx) +_nfloat_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong Aexp, slong Bexp, slong fnlimbs, gr_ctx_t ctx) { nn_ptr T, TA, TB, TC; slong i, j; @@ -1420,12 +100,7 @@ _nfloat_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, for (j = 0; j < p; j++) _nfloat_get_nfixed(TB + i * fdnlimbs * p + j * fdnlimbs, GR_MAT_ENTRY(B, i, j, sz), Bexp, fnlimbs, ctx); - if (waksman == 1) - _nfixed_mat_mul_waksman(TC, TA, TB, m, n, p, fnlimbs); - else if (waksman == 0) - _nfixed_mat_mul_classical(TC, TA, TB, m, n, p, fnlimbs); - else - _nfixed_mat_mul_strassen(TC, TA, TB, m, n, p, waksman, fnlimbs); + _nfixed_mat_mul(TC, TA, TB, m, n, p, fnlimbs); for (i = 0; i < m; i++) for (j = 0; j < p; j++) @@ -1437,7 +112,7 @@ _nfloat_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, } int -_nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, int waksman, slong max_extra_bits, gr_ctx_t ctx) +nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_bits, gr_ctx_t ctx) { slong Amax, Amin, Bmax, Bmin, Adelta, Bdelta, Aexp, Bexp; slong prec; @@ -1487,25 +162,7 @@ _nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, int waksma fbits = prec + extra_bits; fnlimbs = (fbits + FLINT_BITS - 1) / 
FLINT_BITS; - return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, waksman, ctx); -} - -int -nfloat_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) -{ - return _nfloat_mat_mul_fixed(C, A, B, 0, 100000, ctx); -} - -int -nfloat_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) -{ - return _nfloat_mat_mul_fixed(C, A, B, 1, 100000, ctx); -} - -int -nfloat_mat_mul_strassen(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong cutoff, gr_ctx_t ctx) -{ - return _nfloat_mat_mul_fixed(C, A, B, cutoff, 100000, ctx); + return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx); } static void @@ -2273,7 +930,6 @@ int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) { slong cutoff1, cutoff2, dim; - int use_waksman = 0; slong prec; slong max_extra_prec; @@ -2306,10 +962,9 @@ nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) if (dim < cutoff1) return gr_mat_mul_classical(C, A, B, ctx); - use_waksman = (dim >= cutoff2); max_extra_prec = (prec < 768) ? 64 : prec / 4; - return _nfloat_mat_mul_fixed(C, A, B, use_waksman, max_extra_prec, ctx); + return nfloat_mat_mul_fixed(C, A, B, max_extra_prec, ctx); } else { @@ -2351,7 +1006,7 @@ _nfloat_complex_mat_exp_range(slong * _Amin, slong * _Amax, const gr_mat_t A, gr } int -_nfloat_complex_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong Aexp, slong Bexp, slong fnlimbs, int waksman, gr_ctx_t ctx) +_nfloat_complex_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong Aexp, slong Bexp, slong fnlimbs, gr_ctx_t ctx) { nn_ptr T, Aa, Ab, Ba, Bb, AaBa, AbBb, Cb; slong i, j; @@ -2397,46 +1052,23 @@ _nfloat_complex_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_m /* (Aa Ba - Ab Bb) + ((Aa + Ab)(Ba + Bb) - Aa Ba - Ab Bb) i */ - if (waksman) - { - _nfixed_mat_mul_waksman(AaBa, Aa, Ba, m, n, p, fnlimbs); - _nfixed_mat_mul_waksman(AbBb, Ab, Bb, m, n, p, fnlimbs); - _nfixed_vec_add(Aa, Aa, Ab, m * n, fnlimbs); - _nfixed_vec_add(Ba, Ba, Bb, n * p, fnlimbs); - _nfixed_mat_mul_waksman(Cb, Aa, Ba, m, n, p, fnlimbs); - _nfixed_vec_sub(Cb, Cb, AaBa, m * p, fnlimbs); - _nfixed_vec_sub(Cb, Cb, AbBb, m * p, fnlimbs); - - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfloat_set_nfixed(NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx); - - _nfixed_vec_sub(Cb, AaBa, AbBb, m * p, fnlimbs); - - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfloat_set_nfixed(NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx); - } - else - { - _nfixed_mat_mul_classical(AaBa, Aa, Ba, m, n, p, fnlimbs); - _nfixed_mat_mul_classical(AbBb, Ab, Bb, m, n, p, fnlimbs); - _nfixed_vec_add(Aa, Aa, Ab, m * n, fnlimbs); - _nfixed_vec_add(Ba, Ba, Bb, n * p, fnlimbs); - _nfixed_mat_mul_classical(Cb, Aa, Ba, m, n, p, fnlimbs); - _nfixed_vec_sub(Cb, Cb, AaBa, m * p, fnlimbs); - _nfixed_vec_sub(Cb, Cb, AbBb, m * p, fnlimbs); - - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfloat_set_nfixed(NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx); - - _nfixed_vec_sub(Cb, AaBa, AbBb, m * p, fnlimbs); - - for (i = 0; i < m; i++) - for (j = 0; j < p; j++) - _nfloat_set_nfixed(NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx); - } + 
_nfixed_mat_mul(AaBa, Aa, Ba, m, n, p, fnlimbs); + _nfixed_mat_mul(AbBb, Ab, Bb, m, n, p, fnlimbs); + _nfixed_vec_add(Aa, Aa, Ab, m * n, fnlimbs); + _nfixed_vec_add(Ba, Ba, Bb, n * p, fnlimbs); + _nfixed_mat_mul(Cb, Aa, Ba, m, n, p, fnlimbs); + _nfixed_vec_sub(Cb, Cb, AaBa, m * p, fnlimbs); + _nfixed_vec_sub(Cb, Cb, AbBb, m * p, fnlimbs); + + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfloat_set_nfixed(NFLOAT_COMPLEX_IM(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx); + + _nfixed_vec_sub(Cb, AaBa, AbBb, m * p, fnlimbs); + + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfloat_set_nfixed(NFLOAT_COMPLEX_RE(GR_MAT_ENTRY(C, i, j, sz), ctx), Cb + i * fdnlimbs * p + j * fdnlimbs, Aexp + Bexp, fnlimbs, ctx); flint_free(T); @@ -2444,7 +1076,7 @@ _nfloat_complex_mat_mul_fixed_given_exp(gr_mat_t C, const gr_mat_t A, const gr_m } int -_nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, int waksman, slong max_extra_bits, gr_ctx_t ctx) +nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_extra_bits, gr_ctx_t ctx) { slong Amax, Amin, Bmax, Bmin, Adelta, Bdelta, Aexp, Bexp; slong prec; @@ -2494,19 +1126,7 @@ _nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, in fbits = prec + extra_bits; fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS; - return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, waksman, ctx); -} - -int -nfloat_complex_mat_mul_fixed_classical(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) -{ - return _nfloat_complex_mat_mul_fixed(C, A, B, 0, 100000, ctx); -} - -int -nfloat_complex_mat_mul_waksman(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) -{ - return _nfloat_complex_mat_mul_fixed(C, A, B, 1, 100000, ctx); + return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx); } FLINT_FORCE_INLINE slong @@ -2772,7 +1392,6 @@ nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t { slong dim; slong block_cutoff; - int use_waksman = 0; slong prec; slong max_extra_prec; int A_real = 0, B_real = 0; @@ -2811,20 +1430,9 @@ nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t if (prec <= 128 || (prec <= 256 && n <= 4) || (prec == 576 && n <= 6)) return gr_mat_mul_classical(C, A, B, ctx); - if (prec <= 320) - use_waksman = 0; - else if (prec <= 448) - use_waksman = (dim >= 20); - else if (prec <= 640) - use_waksman = (dim >= 6); - else if (prec <= 1536) - use_waksman = (dim >= 4); - else - use_waksman = (dim >= 3); - max_extra_prec = (prec < 768) ? 64 : prec / 4; - return _nfloat_complex_mat_mul_fixed(C, A, B, use_waksman, max_extra_prec, ctx); + return nfloat_complex_mat_mul_fixed(C, A, B, max_extra_prec, ctx); } else { diff --git a/src/nfloat/nfixed.c b/src/nfloat/nfixed.c new file mode 100644 index 0000000000..cba1d53da1 --- /dev/null +++ b/src/nfloat/nfixed.c @@ -0,0 +1,1406 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . 
+*/ + +#include "mpn_extras.h" +#include "gr.h" +#include "gr_vec.h" +#include "gr_mat.h" +#include "gr_generic.h" +#include "acf.h" +#include "acb.h" +#include "nfloat.h" +#include "gr_special.h" +#include "fmpz_mat.h" + +/* For printing */ +#include "arf.h" + +static int nfixed_mat_mul_use_waksman(slong n, slong nlimbs) +{ + if (nlimbs <= 8) + return 0; + if (nlimbs == 9) + return (n >= 6); + if (nlimbs == 10) + return (n >= 5); + if (nlimbs <= 24) + return (n >= 4); + if (nlimbs <= 46) + return (n >= 3); + return (n >= 2); +} + +static slong nfixed_mat_mul_strassen_cutoff(slong n, int parity, slong nlimbs) +{ + if (nlimbs <= 3) + return parity ? 57 : 50; + else + return parity ? 37 : 26; +} + +/* Arithmetic on fixed-point numbers in (-1,1) */ +/* x[0] stores the sign bit, x[1], ..., x[n] store the absolute value */ + +void +_nfixed_print(nn_srcptr x, slong nlimbs, slong exp) +{ + arf_t t; + arf_init(t); + _arf_set_mpn_fixed(t, x + 1, nlimbs, nlimbs, x[0], nlimbs * FLINT_BITS, ARF_RND_DOWN); + arf_mul_2exp_si(t, t, exp); + arf_printd(t, nlimbs * FLINT_BITS / 3.321928 + 1); + arf_clear(t); +} + +#define DEF_NFIXED_ADD(n) \ +FLINT_FORCE_INLINE \ +void _nfixed_add_ ## n(nn_ptr res, nn_srcptr a, nn_srcptr b) \ +{ \ + int asgn, bsgn; \ + asgn = a[0]; \ + bsgn = b[0]; \ + \ + if (asgn == bsgn) \ + { \ + res[0] = asgn; \ + NN_ADD_ ## n(res + 1, a + 1, b + 1); \ + } \ + else \ + { \ + res[0] = asgn ^ flint_mpn_signed_sub_ ## n(res + 1, a + 1, b + 1); \ + } \ +} + +#define DEF_NFIXED_SUB(n) \ +FLINT_FORCE_INLINE \ +void _nfixed_sub_ ## n(nn_ptr res, nn_srcptr a, nn_srcptr b) \ +{ \ + int asgn, bsgn; \ + asgn = a[0]; \ + bsgn = b[0]; \ + \ + if (asgn != bsgn) \ + { \ + res[0] = asgn; \ + NN_ADD_ ## n(res + 1, a + 1, b + 1); \ + } \ + else \ + { \ + res[0] = asgn ^ flint_mpn_signed_sub_ ## n(res + 1, a + 1, b + 1); \ + } \ +} + +DEF_NFIXED_ADD(2) +DEF_NFIXED_ADD(3) +DEF_NFIXED_ADD(4) +DEF_NFIXED_ADD(5) +DEF_NFIXED_ADD(6) +DEF_NFIXED_ADD(7) +DEF_NFIXED_ADD(8) + +DEF_NFIXED_SUB(2) +DEF_NFIXED_SUB(3) +DEF_NFIXED_SUB(4) +DEF_NFIXED_SUB(5) +DEF_NFIXED_SUB(6) +DEF_NFIXED_SUB(7) +DEF_NFIXED_SUB(8) + +FLINT_FORCE_INLINE +void _nfixed_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) +{ + int asgn, bsgn; + asgn = a[0]; + bsgn = b[0]; + + if (asgn == bsgn) + { + res[0] = asgn; + mpn_add_n(res + 1, a + 1, b + 1, nlimbs); + } + else + { + res[0] = asgn ^ flint_mpn_signed_sub_n(res + 1, a + 1, b + 1, nlimbs); + } +} + +FLINT_FORCE_INLINE +void _nfixed_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) +{ + int asgn, bsgn; + asgn = a[0]; + bsgn = b[0]; + + if (asgn != bsgn) + { + res[0] = asgn; + mpn_add_n(res + 1, a + 1, b + 1, nlimbs); + } + else + { + res[0] = asgn ^ flint_mpn_signed_sub_n(res + 1, a + 1, b + 1, nlimbs); + } +} + +void _nfixed_vec_add(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) +{ + slong i; + + if (nlimbs == 2) + { + for (i = 0; i < len; i++) + _nfixed_add_2(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 3) + { + for (i = 0; i < len; i++) + _nfixed_add_3(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 4) + { + for (i = 0; i < len; i++) + _nfixed_add_4(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 5) + { + for (i = 0; i < len; i++) + _nfixed_add_5(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 6) + { + for (i = 0; i < len; i++) + _nfixed_add_6(res + i * (nlimbs + 1), a + i * 
(nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 7) + { + for (i = 0; i < len; i++) + _nfixed_add_7(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 8) + { + for (i = 0; i < len; i++) + _nfixed_add_8(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else + { + for (i = 0; i < len; i++) + _nfixed_add(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); + } +} + +void _nfixed_vec_sub(nn_ptr res, nn_srcptr a, nn_srcptr b, slong len, slong nlimbs) +{ + slong i; + + if (nlimbs == 2) + { + for (i = 0; i < len; i++) + _nfixed_sub_2(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 3) + { + for (i = 0; i < len; i++) + _nfixed_sub_3(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 4) + { + for (i = 0; i < len; i++) + _nfixed_sub_4(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 5) + { + for (i = 0; i < len; i++) + _nfixed_sub_5(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 6) + { + for (i = 0; i < len; i++) + _nfixed_sub_6(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 7) + { + for (i = 0; i < len; i++) + _nfixed_sub_7(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else if (nlimbs == 8) + { + for (i = 0; i < len; i++) + _nfixed_sub_8(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1)); + } + else + { + for (i = 0; i < len; i++) + _nfixed_sub(res + i * (nlimbs + 1), a + i * (nlimbs + 1), b + i * (nlimbs + 1), nlimbs); + } +} + +FLINT_FORCE_INLINE +void nfixed_mul(nn_ptr res, nn_srcptr a, nn_srcptr b, slong nlimbs) +{ + int asgn, bsgn; + asgn = a[0]; + bsgn = b[0]; + res[0] = asgn ^ bsgn; + flint_mpn_mulhigh_n(res + 1, a + 1, b + 1, nlimbs); +} + +FLINT_FORCE_INLINE +void nfixed_sqr(nn_ptr res, nn_srcptr a, slong nlimbs) +{ + res[0] = 0; + flint_mpn_sqrhigh(res + 1, a + 1, nlimbs); +} + +FLINT_FORCE_INLINE +void nfixed_div2(nn_ptr res, nn_srcptr a, slong nlimbs) +{ + res[0] = a[0]; + mpn_rshift(res + 1, a + 1, nlimbs, 1); +} + +void _nfixed_dot_2(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + + ulong as, bs, a0, a1, b0, b1, s0, s1, t0, t1, u0, u1, hi, lo; + + s0 = s1 = t0 = t1 = u0 = u1 = 0; + + /* + s0 s1 + |a1 b1----| + u0 u1 + |a1 b0----| + t0 t1 + |a0 b1----| + */ + + for (j = 0; j < len; j++) + { + as = x[j * xstride]; + a0 = x[j * xstride + 1]; + a1 = x[j * xstride + 2]; + bs = y[j * ystride]; + b0 = y[j * ystride + 1]; + b1 = y[j * ystride + 2]; + + if (as == bs) + { + umul_ppmm(hi, lo, a1, b1); + add_ssaaaa(s1, s0, s1, s0, hi, lo); + umul_ppmm(hi, lo, a0, b1); + add_ssaaaa(t1, t0, t1, t0, 0, hi); + umul_ppmm(hi, lo, a1, b0); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + } + else + { + umul_ppmm(hi, lo, a1, b1); + sub_ddmmss(s1, s0, s1, s0, hi, lo); + umul_ppmm(hi, lo, a0, b1); + sub_ddmmss(t1, t0, t1, t0, 0, hi); + umul_ppmm(hi, lo, a1, b0); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + } + } + + add_ssaaaa(s1, s0, s1, s0, t1, t0); + add_ssaaaa(s1, s0, s1, s0, u1, u0); + + if ((slong) s1 < WORD(0)) + { + sub_ddmmss(s1, s0, 0, 0, s1, s0); + res[0] = 1; + } + else + { + res[0] = 0; + } + + res[1] = s0; + res[2] = s1; +} + +void _nfixed_dot_3(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + + ulong as, bs, a0, a1, a2, b0, b1, b2, s0, s1, s2, 
hi, lo; + ulong u0, u1, v0, v1; + + s0 = s1 = s2 = 0; + u0 = u1 = v0 = v1 = 0; + + for (j = 0; j < len; j++) + { + as = x[j * xstride]; + a0 = x[j * xstride + 1]; + a1 = x[j * xstride + 2]; + a2 = x[j * xstride + 3]; + bs = y[j * ystride]; + b0 = y[j * ystride + 1]; + b1 = y[j * ystride + 2]; + b2 = y[j * ystride + 3]; + + /* + |a2 b2----| + + |a1 b2----| + |a2 b1----| + + |a0 b2----| + |a1 b1----| + |a2 b0----| + */ + + if (as == bs) + { + umul_ppmm(hi, lo, a0, b2); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b1); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b0); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + + umul_ppmm(hi, lo, a2, b2); + add_ssaaaa(s2, s1, s2, s1, hi, lo); + + umul_ppmm(hi, lo, a1, b2); + add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b1); + add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, hi, lo); + } + else + { + umul_ppmm(hi, lo, a0, b2); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b1); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b0); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + + umul_ppmm(hi, lo, a2, b2); + sub_ddmmss(s2, s1, s2, s1, hi, lo); + + umul_ppmm(hi, lo, a1, b2); + sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b1); + sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, hi, lo); + } + } + + if ((slong) u1 < WORD(0)) + { + sub_ddmmss(u1, u0, 0, 0, u1, u0); + sub_dddmmmsss(s2, s1, s0, s2, s1, s0, 0, u1, u0); + } + else + { + add_sssaaaaaa(s2, s1, s0, s2, s1, s0, 0, u1, u0); + } + + if ((slong) s2 < WORD(0)) + { + sub_dddmmmsss(s2, s1, s0, 0, 0, 0, s2, s1, s0); + res[0] = 1; + } + else + { + res[0] = 0; + } + + + res[1] = s0; + res[2] = s1; + res[3] = s2; +} + +void _nfixed_dot_4(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + + /* + s1 s2 s3 + |a3 b3----| + + |a2 b3----| + |a3 b2----| + + t0 t1 t2 + + |a1 b3----| + |a2 b2----| + |a3 b1----| + + u0 u1 + + |a0 b3----| + |a1 b2----| + |a2 b1----| + |a3 b0----| + + */ + + ulong as, a0, a1, a2, a3; + ulong bs, b0, b1, b2, b3; + ulong s0, s1, s2, s3, t0, t1, t2, u0, u1; + ulong hi, lo; + + s0 = s1 = s2 = s3 = t0 = t1 = t2 = u0 = u1 = 0; + + for (j = 0; j < len; j++) + { + as = x[j * xstride]; + a0 = x[j * xstride + 1]; + a1 = x[j * xstride + 2]; + a2 = x[j * xstride + 3]; + a3 = x[j * xstride + 4]; + + bs = y[j * ystride]; + b0 = y[j * ystride + 1]; + b1 = y[j * ystride + 2]; + b2 = y[j * ystride + 3]; + b3 = y[j * ystride + 4]; + + if (as == bs) + { + umul_ppmm(hi, lo, a3, b3); + add_ssaaaa(s3, s2, s3, s2, hi, lo); + + umul_ppmm(hi, lo, a2, b3); + add_sssaaaaaa(s3, s2, s1, s3, s2, s1, 0, hi, lo); + umul_ppmm(hi, lo, a3, b2); + add_sssaaaaaa(s3, s2, s1, s3, s2, s1, 0, hi, lo); + + umul_ppmm(hi, lo, a1, b3); + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b2); + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a3, b1); + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, hi, lo); + + umul_ppmm(hi, lo, a0, b3); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b2); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b1); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a3, b0); + add_ssaaaa(u1, u0, u1, u0, 0, hi); + } + else + { + umul_ppmm(hi, lo, a3, b3); + sub_ddmmss(s3, s2, s3, s2, hi, lo); + + umul_ppmm(hi, lo, a2, b3); + sub_dddmmmsss(s3, s2, s1, s3, s2, s1, 0, hi, lo); + umul_ppmm(hi, lo, a3, b2); + sub_dddmmmsss(s3, s2, s1, s3, s2, s1, 0, hi, lo); + + umul_ppmm(hi, lo, a1, b3); + sub_dddmmmsss(t2, t1, t0, t2, 
t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a2, b2); + sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); + umul_ppmm(hi, lo, a3, b1); + sub_dddmmmsss(t2, t1, t0, t2, t1, t0, 0, hi, lo); + + umul_ppmm(hi, lo, a0, b3); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a1, b2); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a2, b1); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + umul_ppmm(hi, lo, a3, b0); + sub_ddmmss(u1, u0, u1, u0, 0, hi); + } + } + + if ((slong) u1 < WORD(0)) + { + sub_ddmmss(u1, u0, 0, 0, u1, u0); + sub_dddmmmsss(t2, t1, s0, t2, t1, s0, 0, u1, u0); + } + else + { + add_sssaaaaaa(t2, t1, t0, t2, t1, t0, 0, u1, u0); + } + + if ((slong) t2 < WORD(0)) + { + sub_dddmmmsss(t2, t1, t0, 0, 0, 0, t2, t1, t0); + sub_ddddmmmmssss(s3, s2, s1, s0, s3, s2, s1, s0, 0, t2, t1, t0); + } + else + { + add_ssssaaaaaaaa(s3, s2, s1, s0, s3, s2, s1, s0, 0, t2, t1, t0); + } + + if ((slong) s3 < WORD(0)) + { + sub_ddddmmmmssss(s3, s2, s1, s0, 0, 0, 0, 0, s3, s2, s1, s0); + res[0] = 1; + } + else + { + res[0] = 0; + } + + res[1] = s0; + res[2] = s1; + res[3] = s2; + res[4] = s3; +} + +void _nfixed_dot_5(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 5; + + ulong tmp[6]; + ulong spos[6] = { 0, 0, 0, 0, 0, 0 }; + ulong sneg[6] = { 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); + + if (tmp[0] == 0) + NN_ADD_5(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_5(sneg + 1, sneg + 1, tmp + 1); + } + + _nfixed_sub_5(res, spos, sneg); +} + +void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 6; + + ulong tmp[7]; + ulong spos[7] = { 0, 0, 0, 0, 0, 0, 0 }; + ulong sneg[7] = { 0, 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); + + if (tmp[0] == 0) + NN_ADD_6(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_6(sneg + 1, sneg + 1, tmp + 1); + } + + _nfixed_sub_6(res, spos, sneg); +} + +void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 7; + + ulong tmp[8]; + ulong spos[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + ulong sneg[8] = { 0, 0, 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); + + if (tmp[0] == 0) + NN_ADD_7(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_7(sneg + 1, sneg + 1, tmp + 1); + } + + _nfixed_sub_7(res, spos, sneg); +} + +void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len) +{ + slong j; + slong nlimbs = 8; + + ulong tmp[9]; + ulong spos[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + ulong sneg[9] = { 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + + if (x[0] == y[0]) + flint_mpn_mulhigh_n(spos + 1, x + 1, y + 1, nlimbs); + else + flint_mpn_mulhigh_n(sneg + 1, x + 1, y + 1, nlimbs); + + for (j = 1; j < len; j++) + { + nfixed_mul(tmp, x + j * xstride, y + j * ystride, nlimbs); + + if (tmp[0] == 0) + NN_ADD_8(spos + 1, spos + 1, tmp + 1); + else + NN_ADD_8(sneg + 1, sneg 
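/* Expository comment, not in the original patch: _nfixed_dot_5 through
   _nfixed_dot_8 avoid per-term sign handling by accumulating positive
   and negative products in two separate unsigned sums, spos and sneg,
   and producing the signed result with a single _nfixed_sub_n at the
   end. Top-limb overflow is the callers' responsibility: the nfloat
   wrappers pad the exponent (pad_top/pad_bot) and the tests shift the
   leading limb so that operands stay well inside (-1,1). */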
+ 1, tmp + 1); + } + + _nfixed_sub_8(res, spos, sneg); +} + +/* A is (m x n), B is (n x p), C is (m x p) */ +void +_nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) +{ + slong i, j, k; + nn_ptr t; + +#define A_ENTRY(i, j) ((A) + ((i) * n + (j)) * (nlimbs + 1)) +#define B_ENTRY(i, j) ((B) + ((i) * p + (j)) * (nlimbs + 1)) +#define C_ENTRY(i, j) ((C) + ((i) * p + (j)) * (nlimbs + 1)) + + if (n == 1) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + return; + } + + if (nlimbs == 2) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_2(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 3) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_3(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 4) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_4(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 5) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 6) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_6(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 7) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_7(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else if (nlimbs == 8) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_8(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), (nlimbs + 1) * p, n); + } + else + { + TMP_INIT; + TMP_START; + + t = TMP_ALLOC((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m; i++) + { + for (j = 0; j < p; j++) + { + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + + for (k = 1; k < n; k++) + { + nfixed_mul(t, A_ENTRY(i, k), B_ENTRY(k, j), nlimbs); + _nfixed_add(C_ENTRY(i, j), C_ENTRY(i, j), t, nlimbs); + } + } + } + + TMP_END; + } + +#undef A_ENTRY +#undef B_ENTRY +#undef C_ENTRY +} + +/* compute c += (a1 + b1) * (a2 + b2) */ +/* val0, val1, val2 are scratch space */ +FLINT_FORCE_INLINE void +addmul_addadd(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2, slong nlimbs) +{ + _nfixed_add(val1, a1, b1, nlimbs); + _nfixed_add(val2, a2, b2, nlimbs); + nfixed_mul(val0, val1, val2, nlimbs); + _nfixed_add(c, c, val0, nlimbs); +} + +/* compute c += (a1 - b1) * (a2 - b2) */ +/* val0, val1, val2 are scratch space */ +FLINT_FORCE_INLINE void +addmul_subsub(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2, slong nlimbs) +{ + _nfixed_sub(val1, a1, b1, nlimbs); + _nfixed_sub(val2, a2, b2, nlimbs); + nfixed_mul(val0, val1, val2, nlimbs); + _nfixed_add(c, c, val0, nlimbs); +} + +/* + Inlining speeds up Waksman multiplication with small nlimbs. + Further speedups are possible by reordering the loops so that + the O(n^3) part is done using dot products. + However, these tricks currently do not suffice to beat + classical multiplications in the relevant ranges, so we + do not bother here. 
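
For context, since the comment above mentions Waksman multiplication: the
addmul_addadd and addmul_subsub kernels implement the two product forms in
Waksman's algorithm, which relies on the identity
(a1+b1)(a2+b2) + (a1-b1)(a2-b2) = 2(a1*a2 + b1*b2). Sharing such products
across row/column pairs is what brings the multiplication count down to
roughly half of the classical n^3, and the factor 2 is what the
nfixed_div2 halving steps later in _nfixed_mat_mul_waksman2 remove. A
self-contained check of the identity, assuming nothing beyond standard C:

    #include <stdio.h>

    int main(void)
    {
        long a1 = 7, b1 = -3, a2 = 11, b2 = 5;

        long p = (a1 + b1) * (a2 + b2);   // one multiplication
        long q = (a1 - b1) * (a2 - b2);   // another multiplication

        // two dot-product terms recovered from two multiplications,
        // shared with other entries in the real algorithm
        printf("%ld == %ld\n", (p + q) / 2, a1 * a2 + b1 * b2);
        return 0;
    }
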
+*/ + +#define WAKSMAN_WANT_INLINING 0 + +#if WAKSMAN_WANT_INLINING + +FLINT_FORCE_INLINE void +addmul_addadd_4(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + _nfixed_add_4(val1, a1, b1); + _nfixed_add_4(val2, a2, b2); + nfixed_mul(val0, val1, val2, 4); + _nfixed_add_4(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_5(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + _nfixed_add_5(val1, a1, b1); + _nfixed_add_5(val2, a2, b2); + nfixed_mul(val0, val1, val2, 5); + _nfixed_add_5(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_6(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + _nfixed_add_6(val1, a1, b1); + _nfixed_add_6(val2, a2, b2); + nfixed_mul(val0, val1, val2, 6); + _nfixed_add_6(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_7(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + _nfixed_add_7(val1, a1, b1); + _nfixed_add_7(val2, a2, b2); + nfixed_mul(val0, val1, val2, 7); + _nfixed_add_7(c, c, val0); +} + +FLINT_FORCE_INLINE void +addmul_addadd_8(nn_ptr val0, nn_ptr val1, nn_ptr val2, nn_ptr c, nn_srcptr a1, nn_srcptr b1, nn_srcptr a2, nn_srcptr b2) +{ + _nfixed_add_8(val1, a1, b1); + _nfixed_add_8(val2, a2, b2); + nfixed_mul(val0, val1, val2, 8); + _nfixed_add_8(c, c, val0); +} + +#endif + +void +_nfixed_mat_mul_waksman2(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs, slong Cstride, slong Astride, slong Bstride) +{ + slong l, j, k; + + slong np = n >> 1; + + nn_ptr Ctmp = flint_calloc((nlimbs + 1) * ((p + m) + 5), sizeof(ulong)); + + /* remaining temp space */ + nn_ptr Crow = Ctmp; /* Crow has p entries */ + nn_ptr Ccol = Crow + (nlimbs + 1) * p; /* Ccol has m entries */ + nn_ptr val0 = Ccol + (nlimbs + 1) * m; /* val0 has room for 2 sums */ + nn_ptr val1 = val0 + (nlimbs + 1) * 2; /* val1 has room for 1 sum */ + nn_ptr val2 = val1 + (nlimbs + 1); /* val2 has room for 1 sum */ + nn_ptr crow = val2 + (nlimbs + 1); /* crow has room for 1 sum */ + +#define A_ENTRY(i, j) ((A) + (i) * Astride + (j) * (nlimbs + 1)) +#define B_ENTRY(i, j) ((B) + (i) * Bstride + (j) * (nlimbs + 1)) +#define C_ENTRY(i, j) ((C) + (i) * Cstride + (j) * (nlimbs + 1)) + +#define Crow_ENTRY(ii) (Crow + (ii) * (nlimbs + 1)) +#define Ccol_ENTRY(ii) (Ccol + (ii) * (nlimbs + 1)) + + /* todo: zero only where needed */ + for (j = 0; j < m; j++) + flint_mpn_zero(C_ENTRY(j, 0), p * (nlimbs + 1)); + + for (j = 1; j <= np; j++) + { + slong j2 = (j << 1) - 1; + + for (k = 0; k < p; k++) + { + addmul_addadd(val0, val1, val2, C_ENTRY(0, k), A_ENTRY(0, j2-1), B_ENTRY(j2, k), A_ENTRY(0, j2), B_ENTRY(j2-1, k), nlimbs); + addmul_subsub(val0, val1, val2, Crow_ENTRY(k), A_ENTRY(0, j2-1), B_ENTRY(j2, k), A_ENTRY(0, j2), B_ENTRY(j2-1, k), nlimbs); + } + + for (l = 1; l < m; l++) + { + addmul_addadd(val0, val1, val2, C_ENTRY(l, 0), A_ENTRY(l, j2-1), B_ENTRY(j2, 0), A_ENTRY(l, j2), B_ENTRY(j2-1, 0), nlimbs); + addmul_subsub(val0, val1, val2, Ccol_ENTRY(l), A_ENTRY(l, j2-1), B_ENTRY(j2, 0), A_ENTRY(l, j2), B_ENTRY(j2-1, 0), nlimbs); + } + +#if WAKSMAN_WANT_INLINING + if (nlimbs == 5) + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd_5(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else if (nlimbs == 6) + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + 
addmul_addadd_6(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else if (nlimbs == 7) + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd_7(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else if (nlimbs == 8) + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd_8(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k)); + } + else +#endif + { + for (k = 1; k < p; k++) + for (l = 1; l < m; l++) + addmul_addadd(val0, val1, val2, C_ENTRY(l, k), A_ENTRY(l, j2-1), B_ENTRY(j2, k), A_ENTRY(l, j2), B_ENTRY(j2-1, k), nlimbs); + } + } + + for (l = 1; l < m; l++) + { + _nfixed_add(val1, Ccol_ENTRY(l), C_ENTRY(l, 0), nlimbs); + nfixed_div2(Ccol_ENTRY(l), val1, nlimbs); + _nfixed_sub(C_ENTRY(l, 0), C_ENTRY(l, 0), Ccol_ENTRY(l), nlimbs); + } + + _nfixed_add(val1, Crow, C_ENTRY(0, 0), nlimbs); + nfixed_div2(val0, val1, nlimbs); + _nfixed_sub(C_ENTRY(0, 0), C_ENTRY(0, 0), val0, nlimbs); + + for (k = 1; k < p; k++) + { + _nfixed_add(crow, Crow_ENTRY(k), C_ENTRY(0, k), nlimbs); + nfixed_div2(val1, crow, nlimbs); + _nfixed_sub(C_ENTRY(0, k), C_ENTRY(0, k), val1, nlimbs); + _nfixed_sub(crow, val1, val0, nlimbs); + + for (l = 1; l < m; l++) + { + _nfixed_sub(val2, C_ENTRY(l, k), crow, nlimbs); + _nfixed_sub(C_ENTRY(l, k), val2, Ccol_ENTRY(l), nlimbs); + } + } + + if ((n & 1) == 1) + { + for (l = 0; l < m; l++) + { + for (k = 0; k < p; k++) + { + nfixed_mul(val0, A_ENTRY(l, n-1), B_ENTRY(n-1, k), nlimbs); + _nfixed_add(C_ENTRY(l, k), C_ENTRY(l, k), val0, nlimbs); + } + } + } + + flint_free(Ctmp); + +#undef A_ENTRY +#undef B_ENTRY +#undef C_ENTRY +} + +void +_nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) +{ + _nfixed_mat_mul_waksman2(C, A, B, m, n, p, nlimbs, p * (nlimbs + 1), n * (nlimbs + 1), p * (nlimbs + 1)); +} + + +typedef struct +{ + nn_ptr start; + slong r; + slong c; + slong row_stride; +} +_nfixed_mat_struct; + +typedef _nfixed_mat_struct _nfixed_mat_t[1]; + +static void +_nfixed_mat_init(_nfixed_mat_t A, slong r, slong c, slong nlimbs) +{ + A->start = flint_malloc((nlimbs + 1) * (r * c) * sizeof(ulong)); + A->r = r; + A->c = c; + A->row_stride = c * (nlimbs + 1); +} + +static void +_nfixed_mat_clear(_nfixed_mat_t A, slong nlimbs) +{ + flint_free(A->start); +} + +static void +_nfixed_mat_window_init(_nfixed_mat_t A, const _nfixed_mat_t mat, slong r1, slong c1, slong r2, slong c2, slong nlimbs) +{ + A->start = mat->start + (r1 * mat->row_stride) + c1 * (nlimbs + 1); + A->r = r2 - r1; + A->c = c2 - c1; + A->row_stride = mat->row_stride; +} + +static void +_nfixed_mat_window_clear(_nfixed_mat_t A, slong nlimbs) +{ +} + +/* +static void +nfixed_mat_print(nn_ptr A, slong ar, slong ac, slong nlimbs) +{ + slong i, j; + flint_printf("{%wd y %wd : [", ar, ac); + for (i = 0; i < ar; i++) + for (j = 0; j < ac; j++) + { + _nfixed_print(A + i * ac * (nlimbs + 1) + j * (nlimbs + 1), nlimbs, 0); + flint_printf(", "); + } + + flint_printf("]}\n"); +} + +static void +_nfixed_mat_print(_nfixed_mat_t A, slong nlimbs) +{ + slong i, j; + flint_printf("{%wd x %wd : [", A->r, A->c); + for (i = 0; i < A->r; i++) + for (j = 0; j < A->c; j++) + { + _nfixed_print(A->start + i * A->row_stride + j * (nlimbs + 1), nlimbs, 0); + flint_printf(", "); + } + + flint_printf("]}\n"); +} +*/ + +static void +_nfixed_mat_add(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, 
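/* Expository comment, not in the original patch: the window helpers
   above create O(1) submatrix views. A window shares the parent's
   buffer, differing only in start, r and c while inheriting the
   parent's row_stride, so the Strassen code below splits matrices into
   quadrants without copying entries; _nfixed_mat_window_clear
   accordingly frees nothing. */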
slong nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong i, r = A->r, c = A->c; + + for (i = 0; i < r; i++) + _nfixed_vec_add(Cptr + i * Cstride, Aptr + i * Astride, Bptr + i * Bstride, c, nlimbs); +} + +static void +_nfixed_mat_sub(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong i, r = A->r, c = A->c; + + for (i = 0; i < r; i++) + _nfixed_vec_sub(Cptr + i * Cstride, Aptr + i * Astride, Bptr + i * Bstride, c, nlimbs); +} + +void +_nfixed_mat_mul_waksman3(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong m = A->r; + slong n = A->c; + slong p = B->c; + + _nfixed_mat_mul_waksman2(Cptr, Aptr, Bptr, m, n, p, nlimbs, Cstride, Astride, Bstride); +} + +static void +_nfixed_mat_mul_classical2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong nlimbs) +{ + nn_srcptr Aptr, Bptr; + nn_ptr Cptr; + + Aptr = A->start; + Bptr = B->start; + Cptr = C->start; + + slong Astride = A->row_stride; + slong Bstride = B->row_stride; + slong Cstride = C->row_stride; + + slong m = A->r; + slong n = A->c; + slong p = B->c; + + slong i, j, k; + nn_ptr t; + +#define A_ENTRY(i, j) ((Aptr) + (i) * Astride + (j) * (nlimbs + 1)) +#define B_ENTRY(i, j) ((Bptr) + (i) * Bstride + (j) * (nlimbs + 1)) +#define C_ENTRY(i, j) ((Cptr) + (i) * Cstride + (j) * (nlimbs + 1)) + + if (n == 1) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + return; + } + + if (nlimbs == 2) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_2(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 3) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_3(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 4) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_4(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 5) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_5(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 6) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_6(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 7) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_7(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else if (nlimbs == 8) + { + for (i = 0; i < m; i++) + for (j = 0; j < p; j++) + _nfixed_dot_8(C_ENTRY(i, j), A_ENTRY(i, 0), nlimbs + 1, B_ENTRY(0, j), Bstride, n); + } + else + { + TMP_INIT; + TMP_START; + + t = TMP_ALLOC((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m; i++) + { + for (j = 0; j < p; j++) + { + nfixed_mul(C_ENTRY(i, j), A_ENTRY(i, 0), B_ENTRY(0, j), nlimbs); + + for (k = 1; k < n; k++) + { + nfixed_mul(t, A_ENTRY(i, k), B_ENTRY(k, j), 
nlimbs); + _nfixed_add(C_ENTRY(i, j), C_ENTRY(i, j), t, nlimbs); + } + } + } + + TMP_END; + } + +#undef A_ENTRY +#undef B_ENTRY +#undef C_ENTRY +} + + +static void +_nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_mat_t B, slong cutoff, slong nlimbs) +{ + slong ar, ac, bc, nn; + slong anr, anc, bnr, bnc; + + _nfixed_mat_t A11, A12, A21, A22; + _nfixed_mat_t B11, B12, B21, B22; + _nfixed_mat_t C11, C12, C21, C22; + _nfixed_mat_t X1, X2; + + ar = A->r; + ac = A->c; + bc = B->c; + + nn = FLINT_MIN(ar, ac); + nn = FLINT_MIN(nn, bc); + + if (cutoff < 0) + cutoff = nfixed_mat_mul_strassen_cutoff(nn, ac & 1, nlimbs); + else + cutoff = FLINT_MAX(cutoff, 2); + + if (nn < cutoff) + { + if (nfixed_mat_mul_use_waksman(nn, nlimbs)) + _nfixed_mat_mul_waksman3(C, A, B, nlimbs); + else + _nfixed_mat_mul_classical2(C, A, B, nlimbs); + return; + } + + anr = ar / 2; + anc = ac / 2; + bnr = anc; + bnc = bc / 2; + + _nfixed_mat_window_init(A11, A, 0, 0, anr, anc, nlimbs); + _nfixed_mat_window_init(A12, A, 0, anc, anr, 2 * anc, nlimbs); + _nfixed_mat_window_init(A21, A, anr, 0, 2 * anr, anc, nlimbs); + _nfixed_mat_window_init(A22, A, anr, anc, 2 * anr, 2 * anc, nlimbs); + + _nfixed_mat_window_init(B11, B, 0, 0, bnr, bnc, nlimbs); + _nfixed_mat_window_init(B12, B, 0, bnc, bnr, 2 * bnc, nlimbs); + _nfixed_mat_window_init(B21, B, bnr, 0, 2 * bnr, bnc, nlimbs); + _nfixed_mat_window_init(B22, B, bnr, bnc, 2 * bnr, 2 * bnc, nlimbs); + + _nfixed_mat_window_init(C11, C, 0, 0, anr, bnc, nlimbs); + _nfixed_mat_window_init(C12, C, 0, bnc, anr, 2 * bnc, nlimbs); + _nfixed_mat_window_init(C21, C, anr, 0, 2 * anr, bnc, nlimbs); + _nfixed_mat_window_init(C22, C, anr, bnc, 2 * anr, 2 * bnc, nlimbs); + + _nfixed_mat_init(X1, anr, FLINT_MAX(bnc, anc), nlimbs); + _nfixed_mat_init(X2, anc, bnc, nlimbs); + + X1->c = anc; + + _nfixed_mat_add(X1, A22, A12, nlimbs); + _nfixed_mat_add(X2, B22, B12, nlimbs); + _nfixed_mat_mul_strassen2(C21, X1, X2, cutoff, nlimbs); + + _nfixed_mat_sub(X1, A22, A21, nlimbs); + _nfixed_mat_sub(X2, B22, B21, nlimbs); + _nfixed_mat_mul_strassen2(C22, X1, X2, cutoff, nlimbs); + + _nfixed_mat_add(X1, X1, A12, nlimbs); + _nfixed_mat_add(X2, X2, B12, nlimbs); + _nfixed_mat_mul_strassen2(C11, X1, X2, cutoff, nlimbs); + + _nfixed_mat_sub(X1, X1, A11, nlimbs); + _nfixed_mat_mul_strassen2(C12, X1, B12, cutoff, nlimbs); + + X1->c = bnc; + _nfixed_mat_mul_strassen2(X1, A12, B21, cutoff, nlimbs); + _nfixed_mat_add(C11, C11, X1, nlimbs); + _nfixed_mat_add(C12, C12, C22, nlimbs); + _nfixed_mat_sub(C12, C11, C12, nlimbs); + _nfixed_mat_sub(C11, C21, C11, nlimbs); + _nfixed_mat_sub(X2, X2, B11, nlimbs); + _nfixed_mat_mul_strassen2(C21, A21, X2, cutoff, nlimbs); + + _nfixed_mat_clear(X2, nlimbs); + + _nfixed_mat_sub(C21, C11, C21, nlimbs); + _nfixed_mat_add(C22, C22, C11, nlimbs); + _nfixed_mat_mul_strassen2(C11, A11, B11, cutoff, nlimbs); + + _nfixed_mat_add(C11, X1, C11, nlimbs); + + X1->c = FLINT_MAX(bnc, anc); + _nfixed_mat_clear(X1, nlimbs); + + _nfixed_mat_window_clear(A11, nlimbs); + _nfixed_mat_window_clear(A12, nlimbs); + _nfixed_mat_window_clear(A21, nlimbs); + _nfixed_mat_window_clear(A22, nlimbs); + + _nfixed_mat_window_clear(B11, nlimbs); + _nfixed_mat_window_clear(B12, nlimbs); + _nfixed_mat_window_clear(B21, nlimbs); + _nfixed_mat_window_clear(B22, nlimbs); + + _nfixed_mat_window_clear(C11, nlimbs); + _nfixed_mat_window_clear(C12, nlimbs); + _nfixed_mat_window_clear(C21, nlimbs); + _nfixed_mat_window_clear(C22, nlimbs); + + if (bc > 2 * bnc) + { + _nfixed_mat_t Bc, Cc; + 
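/* Expository comment, not in the original patch: this and the next two
   blocks peel off the parts that the even-dimension Strassen split
   above does not cover: a trailing column strip of B, a trailing row
   strip of A, and finally the trailing inner-dimension strip (last
   columns of A times last rows of B), each multiplied directly and
   accumulated into C. */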
_nfixed_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, nlimbs); + _nfixed_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, nlimbs); + _nfixed_mat_mul_strassen2(Cc, A, Bc, cutoff, nlimbs); + _nfixed_mat_window_clear(Bc, nlimbs); + _nfixed_mat_window_clear(Cc, nlimbs); + } + + if (ar > 2 * anr) + { + _nfixed_mat_t Ar, Cr; + _nfixed_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, nlimbs); + _nfixed_mat_window_init(Cr, C, 2 * anr, 0, ar, bc, nlimbs); + _nfixed_mat_mul_strassen2(Cr, Ar, B, cutoff, nlimbs); + _nfixed_mat_window_clear(Ar, nlimbs); + _nfixed_mat_window_clear(Cr, nlimbs); + } + + if (ac > 2 * anc) + { + _nfixed_mat_t Ac, Br, Cb, tmp; + slong mt, nt; + + _nfixed_mat_window_init(Ac, A, 0, 2 * anc, 2 * anr, ac, nlimbs); + _nfixed_mat_window_init(Br, B, 2 * bnr, 0, ac, 2 * bnc, nlimbs); + _nfixed_mat_window_init(Cb, C, 0, 0, 2 * anr, 2 * bnc, nlimbs); + + mt = Ac->r; + nt = Br->c; + + /* todo: faster */ + _nfixed_mat_init(tmp, mt, nt, nlimbs); + _nfixed_mat_mul_strassen2(tmp, Ac, Br, cutoff, nlimbs); + _nfixed_mat_add(Cb, Cb, tmp, nlimbs); + _nfixed_mat_clear(tmp, nlimbs); + _nfixed_mat_window_clear(Ac, nlimbs); + _nfixed_mat_window_clear(Br, nlimbs); + _nfixed_mat_window_clear(Cb, nlimbs); + } +} + +void +_nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs) +{ + _nfixed_mat_t CC, AA, BB; + + AA->start = (nn_ptr) A; + AA->r = m; + AA->c = n; + AA->row_stride = n * (nlimbs + 1); + + BB->start = (nn_ptr) B; + BB->r = n; + BB->c = p; + BB->row_stride = p * (nlimbs + 1); + + CC->start = C; + CC->r = m; + CC->c = p; + CC->row_stride = p * (nlimbs + 1); + + _nfixed_mat_mul_strassen2(CC, AA, BB, cutoff, nlimbs); +} + +void +_nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs) +{ + slong d, cutoff; + + d = FLINT_MIN(m, n); + d = FLINT_MIN(d, p); + + if (d > 10) + { + cutoff = nfixed_mat_mul_strassen_cutoff(d, n & 1, nlimbs); + + if (n >= cutoff) + { + _nfixed_mat_mul_strassen(C, A, B, m, n, p, -1, nlimbs); + return; + } + } + + if (nfixed_mat_mul_use_waksman(d, nlimbs)) + _nfixed_mat_mul_waksman(C, A, B, m, n, p, nlimbs); + else + _nfixed_mat_mul_classical(C, A, B, m, n, p, nlimbs); +} diff --git a/src/nfloat/profile/p-mat_mul.c b/src/nfloat/profile/p-mat_mul.c index 172a3c2adb..19bd0183ec 100644 --- a/src/nfloat/profile/p-mat_mul.c +++ b/src/nfloat/profile/p-mat_mul.c @@ -21,7 +21,6 @@ #include "double_extras.h" #define TABN (NFLOAT_MAX_LIMBS + 1) -#define WAKSMAN_MIN_PREC 320 #if 1 #undef TIMEIT_END_REPEAT @@ -57,11 +56,11 @@ randmat(gr_mat_t mat, flint_rand_t state, gr_ctx_t ctx) } } -void tune_fixed_vs_waksman(int * cutoffs) +void tune_classical_vs_fixed(int * cutoffs) { gr_ctx_t ctx; gr_mat_t A, B, C; - slong i, n; + slong i, n, nn; slong prec; double FLINT_SET_BUT_UNUSED(__), t1, t2; @@ -71,14 +70,16 @@ void tune_fixed_vs_waksman(int * cutoffs) for (i = 0; i < TABN; i++) cutoffs[i] = -1; - for (prec = WAKSMAN_MIN_PREC; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) { flint_printf("prec = %wd\n", prec); nfloat_ctx_init(ctx, prec, 0); - for (n = 2; n <= 128; n++) + for (nn = 1; nn <= 128; nn++) { + n = (nn == 1) ? 
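/* expository comment, not in the original patch: the first pass
   (nn == 1) probes the largest size n = 128; if the second algorithm
   does not win even there, the nn == 1 branch below gives up on this
   precision entirely instead of timing every dimension */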
128 : nn; + gr_mat_init(A, n, n, ctx); gr_mat_init(B, n, n, ctx); gr_mat_init(C, n, n, ctx); @@ -87,11 +88,11 @@ void tune_fixed_vs_waksman(int * cutoffs) randmat(B, state, ctx); TIMEIT_START - GR_MUST_SUCCEED(nfloat_mat_mul_fixed_classical(C, A, B, ctx)); + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); TIMEIT_STOP_VALUES(__, t1) TIMEIT_START - GR_MUST_SUCCEED(nfloat_mat_mul_waksman(C, A, B, ctx)); + GR_MUST_SUCCEED(nfloat_mat_mul_fixed(C, A, B, 1000, ctx)); TIMEIT_STOP_VALUES(__, t2) flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); @@ -100,11 +101,19 @@ void tune_fixed_vs_waksman(int * cutoffs) gr_mat_clear(B, ctx); gr_mat_clear(C, ctx); - if (t2 < t1 * 0.99) + if (nn == 1) + { + if (t2 < t1) + continue; + else + break; + } + + if (t2 < t1) { cutoffs[prec / 64] = n; - flint_printf("short tab_fixed_classical_vs_waksman[] = {\n"); + flint_printf("short tab_classical_vs_fixed[] = {\n"); for (i = 0; i <= prec / 64; i++) flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); flint_printf("}\n"); @@ -117,30 +126,26 @@ void tune_fixed_vs_waksman(int * cutoffs) flint_rand_clear(state); } -void tune_strassen(int * cutoffs) +slong ns[] = { 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256, 288, 320, 352, 0 }; + +void prof_classical_vs_fixed() { gr_ctx_t ctx; gr_mat_t A, B, C; - slong i, n; + slong i, ni, n; slong prec; double FLINT_SET_BUT_UNUSED(__), t1, t2; - int prev_ok = 0; flint_rand_t state; flint_rand_init(state); - for (i = 0; i < TABN; i++) - cutoffs[i] = -1; - - for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec = (prec < 1024) ? prec + 64 : prec + 256) { - flint_printf("prec = %wd\n", prec); - - prev_ok = 0; + flint_printf("%wd ", prec); nfloat_ctx_init(ctx, prec, 0); - for (n = 2; n <= 128; n++) + for (ni = 8; (n = ns[ni]) != 0; ni++) { gr_mat_init(A, n, n, ctx); gr_mat_init(B, n, n, ctx); @@ -150,131 +155,85 @@ void tune_strassen(int * cutoffs) randmat(B, state, ctx); TIMEIT_START - if (prec < WAKSMAN_MIN_PREC) - GR_MUST_SUCCEED(nfloat_mat_mul_fixed_classical(C, A, B, ctx)); - else - GR_MUST_SUCCEED(nfloat_mat_mul_waksman(C, A, B, ctx)); + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); TIMEIT_STOP_VALUES(__, t1) TIMEIT_START - GR_MUST_SUCCEED(nfloat_mat_mul_strassen(C, A, B, n, ctx)); + GR_MUST_SUCCEED(nfloat_mat_mul_fixed(C, A, B, 1000, ctx)); TIMEIT_STOP_VALUES(__, t2) - flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + flint_printf("%.3f ", t1 / t2); + fflush(stdout); gr_mat_clear(A, ctx); gr_mat_clear(B, ctx); gr_mat_clear(C, ctx); - - if (t2 < t1 * 0.99) - { - if (prev_ok) - { - cutoffs[prec / 64] = n; - - flint_printf("short tab_strassen[] = {\n"); - for (i = 0; i <= prec / 64; i++) - flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); - flint_printf("}\n"); - - break; - } - else - { - prev_ok = 1; - } - } - else - { - prev_ok = 0; - } } + + flint_printf("\n"); } flint_rand_clear(state); } +void prof_fixed_vs_block() +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, ni, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; -short tab_fixed_classical_vs_waksman[] = { - -1, /* prec = 0 */ - -1, /* prec = 64 */ - -1, /* prec = 128 */ - -1, /* prec = 192 */ - -1, /* prec = 256 */ - 16, /* prec = 320 */ - 10, /* prec = 384 */ - 7, /* prec = 448 */ - 7, /* prec = 512 */ - 6, /* prec = 576 */ - 5, /* prec = 640 */ - 4, /* prec = 704 */ - 4, /* prec = 768 */ - 4, /* prec = 832 */ - 4, /* prec = 896 */ - 4, /* prec = 
960 */ - 4, /* prec = 1024 */ - 4, /* prec = 1088 */ - 4, /* prec = 1152 */ - 4, /* prec = 1216 */ - 4, /* prec = 1280 */ - 3, /* prec = 1344 */ - 4, /* prec = 1408 */ - 3, /* prec = 1472 */ - 3, /* prec = 1536 */ - 3, /* prec = 1600 */ - 3, /* prec = 1664 */ - 3, /* prec = 1728 */ - 3, /* prec = 1792 */ - 3, /* prec = 1856 */ - 3, /* prec = 1920 */ - 3, /* prec = 1984 */ - 3, /* prec = 2048 */ - 3, /* prec = 2112 */ - 3, /* prec = 2176 */ - 3, /* prec = 2240 */ - 3, /* prec = 2304 */ - 3, /* prec = 2368 */ - 3, /* prec = 2432 */ - 3, /* prec = 2496 */ - 3, /* prec = 2560 */ - 3, /* prec = 2624 */ - 3, /* prec = 2688 */ - 3, /* prec = 2752 */ - 3, /* prec = 2816 */ - 3, /* prec = 2880 */ - 3, /* prec = 2944 */ - 3, /* prec = 3008 */ - 2, /* prec = 3072 */ - 3, /* prec = 3136 */ - 3, /* prec = 3200 */ - 2, /* prec = 3264 */ - 2, /* prec = 3328 */ - 2, /* prec = 3392 */ - 2, /* prec = 3456 */ - 2, /* prec = 3520 */ - 3, /* prec = 3584 */ - 2, /* prec = 3648 */ - 2, /* prec = 3712 */ - 2, /* prec = 3776 */ - 2, /* prec = 3840 */ - 2, /* prec = 3904 */ - 2, /* prec = 3968 */ - 2, /* prec = 4032 */ - 2, /* prec = 4096 */ - 2, /* prec = 4160 */ - 2, /* prec = 4224 */ -}; + flint_rand_t state; + flint_rand_init(state); + flint_printf(" "); + for (ni = 8; (n = ns[ni]) != 0; ni++) + flint_printf("%5wd ", n); + flint_printf("\n"); + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec = (prec < 1024) ? prec + 64 : prec + 256) + { + flint_printf("%4wd ", prec); -int main() -{ - int tab_fixed_classical_vs_waksman[TABN]; - int tab_strassen[TABN]; + nfloat_ctx_init(ctx, prec, 0); + + for (ni = 8; (n = ns[ni]) != 0; ni++) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_fixed(C, A, B, 1000, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_block(C, A, B, 1, ctx)); + TIMEIT_STOP_VALUES(__, t2) - tune_strassen(tab_strassen); + flint_printf("%.3f ", t1 / t2); + fflush(stdout); + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + } - tune_fixed_vs_waksman(tab_fixed_classical_vs_waksman); + flint_printf("\n"); + } + + flint_rand_clear(state); +} + +int main() +{ + int tab_classical_vs_fixed[TABN]; + //tune_classical_vs_fixed(tab_classical_vs_fixed); + //prof_classical_vs_fixed(); + prof_fixed_vs_block(); } diff --git a/src/nfloat/profile/p-nfixed_mat_mul.c b/src/nfloat/profile/p-nfixed_mat_mul.c new file mode 100644 index 0000000000..dc6cabd5b7 --- /dev/null +++ b/src/nfloat/profile/p-nfixed_mat_mul.c @@ -0,0 +1,268 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . 
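
The table deleted above illustrates the shape of all tuning tables in this
series: one short per 64-bit precision step, holding the smallest matrix
dimension at which the second algorithm starts to win (-1 meaning never).
A minimal sketch of how such a table is consumed, modeled on the TAB_INDEX
clamping that a later patch in this series introduces; the names and
sample entries here are illustrative only, plain C:

    #include <stdio.h>

    // Hypothetical cutoff data: entry i covers precision 64*i bits.
    static const short cutoff_tab[] = { -1, -1, -1, -1, -1, 16, 10, 7 };
    #define NTAB ((long) (sizeof(cutoff_tab) / sizeof(cutoff_tab[0])))

    // Use the second algorithm iff the dimension reaches the cutoff.
    static int use_second_algorithm(long prec, long n)
    {
        long i = prec / 64;
        if (i >= NTAB)
            i = NTAB - 1;   // clamp to the last entry, like TAB_INDEX
        return cutoff_tab[i] != -1 && n >= cutoff_tab[i];
    }

    int main(void)
    {
        printf("%d\n", use_second_algorithm(320, 20));  // 1: 20 >= 16
        printf("%d\n", use_second_algorithm(128, 20));  // 0: never wins
        return 0;
    }
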
+*/ + +#include +#include "fmpz.h" +#include "gr.h" +#include "gr_special.h" +#include "gr_vec.h" +#include "gr_mat.h" +#include "arf.h" +#include "nfloat.h" +#include "profiler.h" +#include "double_extras.h" + +#define TABN (NFLOAT_MAX_LIMBS + 1) + +#if 1 +#undef TIMEIT_END_REPEAT +#define TIMEIT_END_REPEAT(__timer, __reps) \ + } \ + timeit_stop(__timer); \ + if (__timer->cpu >= 100) \ + break; \ + __reps *= 10; \ + } \ + } while (0); +#endif + +void +randmat(gr_mat_t mat, flint_rand_t state, gr_ctx_t ctx) +{ + slong m = gr_mat_nrows(mat, ctx); + slong n = gr_mat_ncols(mat, ctx); + + slong i, j; + + for (i = 0; i < m; i++) + { + for (j = 0; j < n; j++) + { + gr_ptr v = gr_mat_entry_ptr(mat, i, j, ctx); + + GR_MUST_SUCCEED(gr_set_si(v, 1 + n_randint(state, 1000), ctx)); + GR_MUST_SUCCEED(gr_div_ui(v, v, 1 + n_randint(state, 1000), ctx)); + if (n_randint(state, 2)) + GR_MUST_SUCCEED(gr_neg(v, v, ctx)); + } + } +} + +void +nfixed_rand(nn_ptr a, flint_rand_t state, slong nlimbs) +{ + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= 10; +} + +void +nfixed_randmat(nn_ptr a, slong m, slong n, flint_rand_t state, slong nlimbs) +{ + slong i; + for (i = 0; i < m * n; i++) + nfixed_rand(a + i * (nlimbs + 1), state, nlimbs); +} + +void tune_fixed_vs_waksman(int * cutoffs) +{ + nn_ptr A, B, C; + slong i, n, nlimbs, nn; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (nlimbs = 2; nlimbs <= NFLOAT_MAX_LIMBS; nlimbs++) + { + flint_printf("nlimbs = %wd\n", nlimbs); + + for (nn = 1; nn <= 64; nn++) + { + n = (nn == 1) ? 128 : nn; + + A = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + B = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + C = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + + nfixed_randmat(A, n, n, state, nlimbs); + nfixed_randmat(B, n, n, state, nlimbs); + + TIMEIT_START + _nfixed_mat_mul_classical(C, A, B, n, n, n, nlimbs); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + _nfixed_mat_mul_waksman(C, A, B, n, n, n, nlimbs); + TIMEIT_STOP_VALUES(__, t2) + + flint_free(A); + flint_free(B); + flint_free(C); + + flint_printf("%wd %wd %e %e %.3f\n", nlimbs * FLINT_BITS, n, t1, t2, t1 / t2); + + if (nn == 1) + { + if (t2 < t1) + continue; + else + break; + } + + if (t2 < t1) + { + cutoffs[nlimbs] = n; + + flint_printf("short tab_fixed_classical_vs_waksman[] = {\n"); + for (i = 0; i <= nlimbs; i++) + flint_printf(" %d, /* nlimbs = %wd */\n", cutoffs[i], i); + flint_printf("}\n"); + + break; + } + } + } +} + +void tune_strassen(int * cutoffs) +{ + nn_ptr A, B, C; + slong i, n, nlimbs; + int parity; + int last_ok; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (parity = 0; parity < 2; parity++) + { + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (nlimbs = 2; nlimbs <= NFLOAT_MAX_LIMBS; nlimbs++) + { + flint_printf("nlimbs = %wd\n", nlimbs); + + last_ok = 0; + + for (n = parity ? 
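/* expository comment, not in the original patch: odd and even inner
   dimensions are tuned separately because an odd n forces the peeling
   fix-ups in _nfixed_mat_mul_strassen2, which shifts the break-even
   point; compare nfixed_mat_mul_strassen_cutoff(n, parity, nlimbs) */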
1 : 2; ; n += 2) + { + A = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + B = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + C = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + + nfixed_randmat(A, n, n, state, nlimbs); + nfixed_randmat(B, n, n, state, nlimbs); + + TIMEIT_START + _nfixed_mat_mul_strassen(C, A, B, n, n, n, n + 1, nlimbs); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + _nfixed_mat_mul_strassen(C, A, B, n, n, n, n, nlimbs); + TIMEIT_STOP_VALUES(__, t2) + + flint_free(A); + flint_free(B); + flint_free(C); + + flint_printf("%wd %wd %e %e %.3f\n", nlimbs * FLINT_BITS, n, t1, t2, t1 / t2); + + if (t2 < t1) + { + if (!last_ok) + { + last_ok = 1; + continue; + } + + cutoffs[nlimbs] = n; + + if (parity) + flint_printf("short tab_strassen_odd[] = {\n"); + else + flint_printf("short tab_strassen_even[] = {\n"); + for (i = 0; i <= nlimbs; i++) + flint_printf(" %d, /* nlimbs = %wd */\n", cutoffs[i], i); + flint_printf("}\n"); + + break; + } + else + { + last_ok = 0; + } + } + } + } +} + +void prof_strassen_1() +{ + nn_ptr A, B, C; + slong i, n, nlimbs; + int parity; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (nlimbs = 2; nlimbs <= NFLOAT_MAX_LIMBS; nlimbs++) + { + for (parity = 0; parity < 2; parity++) + { + flint_printf("nlimbs = %wd ", nlimbs); + + if (nlimbs <= 3) + n = parity ? 57 : 50; + else + n = parity ? 37 : 26; + + A = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + B = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + C = flint_malloc((n * n) * (nlimbs + 1) * sizeof(ulong)); + + nfixed_randmat(A, n, n, state, nlimbs); + nfixed_randmat(B, n, n, state, nlimbs); + + TIMEIT_START + _nfixed_mat_mul_strassen(C, A, B, n, n, n, n + 1, nlimbs); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + _nfixed_mat_mul_strassen(C, A, B, n, n, n, n, nlimbs); + TIMEIT_STOP_VALUES(__, t2) + + flint_free(A); + flint_free(B); + flint_free(C); + + flint_printf("%wd %e %e %.3f ", n, t1, t2, t1 / t2); + } + + flint_printf("\n"); + } +} + +int main() +{ + int tab_fixed_classical_vs_waksman[TABN]; + int tab_strassen[TABN]; + + //tune_fixed_vs_waksman(tab_fixed_classical_vs_waksman); + //tune_strassen(tab_strassen); + prof_strassen_1(); +} diff --git a/src/nfloat/test/main.c b/src/nfloat/test/main.c index 077f81f91b..88f9a6c362 100644 --- a/src/nfloat/test/main.c +++ b/src/nfloat/test/main.c @@ -16,6 +16,7 @@ #include "t-complex_mat_mul.c" #include "t-mat_mul.c" #include "t-nfixed_dot.c" +#include "t-nfixed_mat_mul.c" #include "t-nfloat.c" #include "t-nfloat_complex.c" @@ -28,6 +29,7 @@ test_struct tests[] = TEST_FUNCTION(complex_mat_mul), TEST_FUNCTION(mat_mul), TEST_FUNCTION(nfixed_dot), + TEST_FUNCTION(nfixed_mat_mul), TEST_FUNCTION(nfloat), TEST_FUNCTION(nfloat_complex), }; diff --git a/src/nfloat/test/t-complex_mat_mul.c b/src/nfloat/test/t-complex_mat_mul.c index 8d5e38a978..aafc4aecf6 100644 --- a/src/nfloat/test/t-complex_mat_mul.c +++ b/src/nfloat/test/t-complex_mat_mul.c @@ -20,6 +20,12 @@ nfloat_complex_mat_mul_block1(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr return nfloat_complex_mat_mul_block(C, A, B, 2, ctx); } +int +nfloat_complex_mat_mul_fixed1(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) +{ + return nfloat_complex_mat_mul_fixed(C, A, B, 1000, ctx); +} + TEST_FUNCTION_START(complex_mat_mul, state) { gr_ctx_t ctx; @@ -45,7 +51,7 @@ TEST_FUNCTION_START(complex_mat_mul, state) tol, state, 10, 4, ctx); gr_mat_test_approx_mul_max_norm( - (gr_method_mat_binary_op) 
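/* expository comment, not in the original patch: the *_mat_mul_fixed1
   wrappers defined above exist only to bind the new max_extra_bits
   argument (1000 here) so that the function matches the two-operand
   gr_method_mat_binary_op signature that the test harness expects */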
nfloat_complex_mat_mul_waksman, + (gr_method_mat_binary_op) nfloat_complex_mat_mul_fixed1, tol, state, (prec <= 256) ? 100 : 1, 10, ctx); gr_mat_test_approx_mul_max_norm( @@ -54,7 +60,7 @@ TEST_FUNCTION_START(complex_mat_mul, state) (prec <= 256) ? 40 : 20, ctx); gr_mat_test_approx_mul_max_norm( - (gr_method_mat_binary_op) nfloat_complex_mat_mul_fixed_classical, + (gr_method_mat_binary_op) nfloat_complex_mat_mul_fixed1, tol, state, (prec <= 256) ? 10 : 1, (prec <= 256) ? 40 : 20, ctx); diff --git a/src/nfloat/test/t-mat_mul.c b/src/nfloat/test/t-mat_mul.c index d0849a193a..4a5115ba3e 100644 --- a/src/nfloat/test/t-mat_mul.c +++ b/src/nfloat/test/t-mat_mul.c @@ -20,6 +20,13 @@ nfloat_mat_mul_block1(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t c return nfloat_mat_mul_block(C, A, B, 1, ctx); } +int +nfloat_mat_mul_fixed1(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) +{ + return nfloat_mat_mul_fixed(C, A, B, 1000, ctx); +} + + TEST_FUNCTION_START(mat_mul, state) { gr_ctx_t ctx; @@ -38,25 +45,7 @@ TEST_FUNCTION_START(mat_mul, state) GR_MUST_SUCCEED(gr_mul_2exp_si(tol, tol, -prec + 2, ctx)); gr_mat_test_approx_mul_max_norm( - (gr_method_mat_binary_op) nfloat_mat_mul_fixed_classical, - tol, state, 10, 10, ctx); - - gr_heap_clear(tol, ctx); - gr_ctx_clear(ctx); - } - - for (iter = 0; iter < 100 * flint_test_multiplier(); iter++) - { - prec = 64; - - nfloat_ctx_init(ctx, prec, 0); - - tol = gr_heap_init(ctx); - GR_MUST_SUCCEED(gr_one(tol, ctx)); - GR_MUST_SUCCEED(gr_mul_2exp_si(tol, tol, -prec + 2, ctx)); - - gr_mat_test_approx_mul_max_norm( - (gr_method_mat_binary_op) nfloat_mat_mul_waksman, + (gr_method_mat_binary_op) nfloat_mat_mul_fixed1, tol, state, 10, 10, ctx); gr_heap_clear(tol, ctx); @@ -76,17 +65,13 @@ TEST_FUNCTION_START(mat_mul, state) GR_MUST_SUCCEED(gr_one(tol, ctx)); GR_MUST_SUCCEED(gr_mul_2exp_si(tol, tol, -prec + 2, ctx)); - gr_mat_test_approx_mul_max_norm( - (gr_method_mat_binary_op) nfloat_mat_mul_waksman, - tol, state, (prec <= 256) ? 10 : 1, 10, ctx); - gr_mat_test_approx_mul_max_norm( (gr_method_mat_binary_op) nfloat_mat_mul_block1, tol, state, (prec <= 256) ? 10 : 1, (prec <= 256) ? 40 : 20, ctx); gr_mat_test_approx_mul_max_norm( - (gr_method_mat_binary_op) nfloat_mat_mul_fixed_classical, + (gr_method_mat_binary_op) nfloat_mat_mul_fixed1, tol, state, (prec <= 256) ? 10 : 1, (prec <= 256) ? 40 : 20, ctx); @@ -112,17 +97,13 @@ TEST_FUNCTION_START(mat_mul, state) GR_MUST_SUCCEED(gr_one(tol, ctx)); GR_MUST_SUCCEED(gr_mul_2exp_si(tol, tol, -prec + 6, ctx)); - gr_mat_test_approx_mul_pos_entrywise_accurate( - (gr_method_mat_binary_op) nfloat_mat_mul_waksman, - tol, state, (prec <= 256) ? 10 : 1, 10, ctx); - gr_mat_test_approx_mul_pos_entrywise_accurate( (gr_method_mat_binary_op) nfloat_mat_mul_block1, tol, state, (prec <= 256) ? 10 : 1, (prec <= 256) ? 40 : 20, ctx); gr_mat_test_approx_mul_pos_entrywise_accurate( - (gr_method_mat_binary_op) nfloat_mat_mul_fixed_classical, + (gr_method_mat_binary_op) nfloat_mat_mul_fixed1, tol, state, (prec <= 256) ? 10 : 1, (prec <= 256) ? 
40 : 20, ctx); diff --git a/src/nfloat/test/t-nfixed_dot.c b/src/nfloat/test/t-nfixed_dot.c index 9c980d98db..9e29a0c86b 100644 --- a/src/nfloat/test/t-nfixed_dot.c +++ b/src/nfloat/test/t-nfixed_dot.c @@ -152,4 +152,8 @@ TEST_FUNCTION_START(nfixed_dot, state) } TEST_FUNCTION_END(state); -} \ No newline at end of file +} + +#undef MAXLEN +#undef MINLIMBS +#undef MAXLIMBS diff --git a/src/nfloat/test/t-nfixed_mat_mul.c b/src/nfloat/test/t-nfixed_mat_mul.c new file mode 100644 index 0000000000..5de8d8de6f --- /dev/null +++ b/src/nfloat/test/t-nfixed_mat_mul.c @@ -0,0 +1,101 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See . +*/ + +#include "test_helpers.h" +#include "fmpq.h" +#include "arf.h" +#include "gr_vec.h" +#include "gr_special.h" +#include "nfloat.h" + +TEST_FUNCTION_START(nfixed_mat_mul, state) +{ + slong iter, m, n, p, i, nlimbs; + nn_ptr A, B, C, D, t; + nn_ptr a; + int which; + + slong MAXN = 10; + slong MINLIMBS = 2; + slong MAXLIMBS = 12; + + for (iter = 0; iter < 10000 * flint_test_multiplier(); iter++) + { + which = n_randint(state, 4); + + m = 1 + n_randint(state, MAXN); + n = 1 + n_randint(state, MAXN); + p = 1 + n_randint(state, MAXN); + + nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1); + + ulong maxerr = 2 * (2 * nlimbs - 1) * n; + + A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong)); + B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong)); + C = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + D = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong)); + t = flint_malloc((nlimbs + 1) * sizeof(ulong)); + + for (i = 0; i < m * n; i++) + { + a = A + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= 10; + } + + for (i = 0; i < n * p; i++) + { + a = B + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + a[nlimbs] >>= 10; + } + + for (i = 0; i < m * p; i++) + { + a = C + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + + a = D + i * (nlimbs + 1); + a[0] = n_randint(state, 2); + flint_mpn_rrandom(a + 1, state, nlimbs); + } + + _nfixed_mat_mul_classical(C, A, B, m, n, p, nlimbs); + + if (which == 0) + _nfixed_mat_mul_waksman(D, A, B, m, n, p, nlimbs); + else + _nfixed_mat_mul_strassen(D, A, B, m, n, p, which, nlimbs); + + for (i = 0; i < m * p; i++) + { + nfixed_sub(t, C + i * (nlimbs + 1), D + i * (nlimbs + 1), nlimbs); + + if (!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr) + { + TEST_FUNCTION_FAIL("nlimbs = %wd, m = %wd, n = %wd, p = %wd\n\nt = %{ulong*}, maxerr = %wu\n\nA = %{ulong*}\n\nB = %{ulong*}\n\nC = %{ulong*}\n\nD = %{ulong*}\n\n", + nlimbs, m, n, p, + t, nlimbs + 1, maxerr, A, m * n * (nlimbs + 1), B, n * p * (nlimbs + 1), C, m * p * (nlimbs + 1), D, m * p * (nlimbs + 1)); + } + } + + flint_free(A); + flint_free(B); + flint_free(C); + flint_free(D); + } + + TEST_FUNCTION_END(state); +} \ No newline at end of file From d247cc1896d5a2b72fc5514b366f0b2aa142da09 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Wed, 4 Sep 2024 10:28:05 +0200 Subject: [PATCH 09/15] retune real matrix mul --- src/nfloat/mat_mul.c | 220 ++++++++++++++++++++++++--------- src/nfloat/profile/p-mat_mul.c | 128 
++++++++++++++++++- 2 files changed, 290 insertions(+), 58 deletions(-) diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index 99d2393f4e..d62af20d2a 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -20,7 +20,155 @@ #include "gr_special.h" #include "fmpz_mat.h" -/* todo: retune classical -> fixed -> block cutoffs */ +/* todo: check errors which depend on nlimbs */ + +/* cutoffs classical -> block */ +#define CUTOFF_CLASSICAL_BLOCK 70 + +#define TAB_INDEX(prec) (FLINT_MIN(prec, 4224) / 64) + +/* cutoffs for classical -> fixed */ +static short tab_classical_vs_fixed[] = { + -1, /* prec = 0 */ + 14, /* prec = 64 */ + 16, /* prec = 128 */ + 5, /* prec = 192 */ + 7, /* prec = 256 */ + 3, /* prec = 320 */ + 3, /* prec = 384 */ + 3, /* prec = 448 */ + 3, /* prec = 512 */ + 10, /* prec = 576 */ + 5, /* prec = 640 */ + 4, /* prec = 704 */ + 4, /* prec = 768 */ + 4, /* prec = 832 */ + 4, /* prec = 896 */ + 4, /* prec = 960 */ + 4, /* prec = 1024 */ + 4, /* prec = 1088 */ + 3, /* prec = 1152 */ + 4, /* prec = 1216 */ + 4, /* prec = 1280 */ + 4, /* prec = 1344 */ + 4, /* prec = 1408 */ + 4, /* prec = 1472 */ + 4, /* prec = 1536 */ + 4, /* prec = 1600 */ + 3, /* prec = 1664 */ + 3, /* prec = 1728 */ + 3, /* prec = 1792 */ + 3, /* prec = 1856 */ + 3, /* prec = 1920 */ + 3, /* prec = 1984 */ + 3, /* prec = 2048 */ + 3, /* prec = 2112 */ + 3, /* prec = 2176 */ + 3, /* prec = 2240 */ + 3, /* prec = 2304 */ + 3, /* prec = 2368 */ + 3, /* prec = 2432 */ + 3, /* prec = 2496 */ + 3, /* prec = 2560 */ + 3, /* prec = 2624 */ + 3, /* prec = 2688 */ + 3, /* prec = 2752 */ + 3, /* prec = 2816 */ + 3, /* prec = 2880 */ + 3, /* prec = 2944 */ + 3, /* prec = 3008 */ + 3, /* prec = 3072 */ + 3, /* prec = 3136 */ + 3, /* prec = 3200 */ + 2, /* prec = 3264 */ + 3, /* prec = 3328 */ + 3, /* prec = 3392 */ + 3, /* prec = 3456 */ + 2, /* prec = 3520 */ + 2, /* prec = 3584 */ + 2, /* prec = 3648 */ + 3, /* prec = 3712 */ + 2, /* prec = 3776 */ + 2, /* prec = 3840 */ + 2, /* prec = 3904 */ + 2, /* prec = 3968 */ + 2, /* prec = 4032 */ + 3, /* prec = 4096 */ + 2, /* prec = 4160 */ + 2, /* prec = 4224 */ +}; + +/* cutoffs for fixed -> block */ +static short tab_fixed_vs_block[] = { + -1, /* prec = 0 */ + 50, /* prec = 64 */ + 94, /* prec = 128 */ + 124, /* prec = 192 */ + 86, /* prec = 256 */ + 196, /* prec = 320 */ + 215, /* prec = 384 */ + 236, /* prec = 448 */ + 236, /* prec = 512 */ + 196, /* prec = 576 */ + 215, /* prec = 640 */ + 196, /* prec = 704 */ + 196, /* prec = 768 */ + 196, /* prec = 832 */ + 179, /* prec = 896 */ + 179, /* prec = 960 */ + 179, /* prec = 1024 */ + 179, /* prec = 1088 */ + 149, /* prec = 1152 */ + 149, /* prec = 1216 */ + 163, /* prec = 1280 */ + 149, /* prec = 1344 */ + 149, /* prec = 1408 */ + 149, /* prec = 1472 */ + 149, /* prec = 1536 */ + 124, /* prec = 1600 */ + 124, /* prec = 1664 */ + 124, /* prec = 1728 */ + 124, /* prec = 1792 */ + 124, /* prec = 1856 */ + 124, /* prec = 1920 */ + 103, /* prec = 1984 */ + 124, /* prec = 2048 */ + 124, /* prec = 2112 */ + 124, /* prec = 2176 */ + 103, /* prec = 2240 */ + 103, /* prec = 2304 */ + 103, /* prec = 2368 */ + 103, /* prec = 2432 */ + 103, /* prec = 2496 */ + 103, /* prec = 2560 */ + 103, /* prec = 2624 */ + 94, /* prec = 2688 */ + 94, /* prec = 2752 */ + 94, /* prec = 2816 */ + 94, /* prec = 2880 */ + 86, /* prec = 2944 */ + 86, /* prec = 3008 */ + 86, /* prec = 3072 */ + 79, /* prec = 3136 */ + 79, /* prec = 3200 */ + 79, /* prec = 3264 */ + 79, /* prec = 3328 */ + 79, /* prec = 3392 */ + 79, /* prec = 3456 */ + 79, /* 
prec = 3520 */ + 79, /* prec = 3584 */ + 79, /* prec = 3648 */ + 79, /* prec = 3712 */ + 79, /* prec = 3776 */ + 79, /* prec = 3840 */ + 79, /* prec = 3904 */ + 79, /* prec = 3968 */ + 79, /* prec = 4032 */ + 79, /* prec = 4096 */ + 79, /* prec = 4160 */ + 79, /* prec = 4224 */ +}; + FLINT_FORCE_INLINE void _nfloat_get_nfixed(nn_ptr res, nn_srcptr x, slong exp, slong fix_nlimbs, gr_ctx_t ctx) @@ -133,14 +281,14 @@ nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_e /* Currently, we don't handle zeros. (They pose no problem, but zero entries in the output may not be exact. To be done.) */ if (Amin < NFLOAT_MIN_EXP || Bmin < NFLOAT_MIN_EXP) - return gr_mat_mul_classical(C, A, B, ctx); + return GR_UNABLE; Adelta = Amax - Amin; Bdelta = Bmax - Bmin; /* sanity check */ if (Adelta > 10 * prec || Bdelta > 10 * prec) - return gr_mat_mul_classical(C, A, B, ctx); + return GR_UNABLE; /* To double check: for Waksman, @@ -155,7 +303,7 @@ nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_e extra_bits = Adelta + Bdelta + pad_top + pad_bot; if (extra_bits >= max_extra_bits) - return gr_mat_mul_classical(C, A, B, ctx); + return GR_UNABLE; Aexp = Amax + pad_top; Bexp = Bmax + pad_top; @@ -897,41 +1045,13 @@ nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_b return status; } -/* Minimum precision for using fixed-point arithmetic */ - -/* TODO: for *unsigned* matrices, there is a speedup already for - prec = 192. Consider inlining fixed-point additions/subtractions for - 4 and 5 limbs to extend this to general matrices. */ -/* #define NFLOAT_MAT_MUL_FIXED_CUTOFF 192 */ -#define NFLOAT_MAT_MUL_FIXED_CUTOFF 320 - -/* first cutoff: classical -> fixed_classical - second cutoff: fixed_classical -> waksman */ -static const int nfloat_mat_mul_cutoff_tab[][2] = { - {0, 0}, /* prec = 0 */ - {0, 0}, /* prec = 64 */ - {0, 0}, /* prec = 128 */ - {32, 32}, /* prec = 192 */ - {8, 20}, /* prec = 256 */ - {4, 15}, /* prec = 320 */ - {3, 10}, /* prec = 384 */ - {3, 10}, /* prec = 448 */ - {3, 8}, /* prec = 512 */ - {10, 10}, /* prec = 576 */ - {4, 5}, /* prec = 640 */ -}; - -/* {4, 4} from this point */ -#define NFLOAT_MAT_MUL_CUTOFF_4 704 -/* {3, 3} from this point */ -#define NFLOAT_MAT_MUL_CUTOFF_3 1600 - int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) { slong cutoff1, cutoff2, dim; slong prec; slong max_extra_prec; + int status; slong m = A->r; slong n = A->c; @@ -942,34 +1062,24 @@ nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) if (dim <= 2 || NFLOAT_CTX_HAS_INF_NAN(ctx)) return gr_mat_mul_classical(C, A, B, ctx); - if (dim <= 80) - { - prec = NFLOAT_CTX_PREC(ctx); - - if (prec < NFLOAT_MAT_MUL_FIXED_CUTOFF) - return gr_mat_mul_classical(C, A, B, ctx); + cutoff1 = tab_classical_vs_fixed[TAB_INDEX(prec)]; - if (prec >= NFLOAT_MAT_MUL_CUTOFF_3) - cutoff1 = cutoff2 = 3; - else if (prec >= NFLOAT_MAT_MUL_CUTOFF_4) - cutoff1 = cutoff2 = 4; - else - { - cutoff1 = nfloat_mat_mul_cutoff_tab[prec / 64][0]; - cutoff2 = nfloat_mat_mul_cutoff_tab[prec / 64][1]; - } + if (dim < cutoff1) + return gr_mat_mul_classical(C, A, B, ctx); - if (dim < cutoff1) - return gr_mat_mul_classical(C, A, B, ctx); + cutoff2 = tab_fixed_vs_block[TAB_INDEX(prec)]; + if (dim < cutoff2) + { max_extra_prec = (prec < 768) ? 
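/* expository comment, not in the original patch: max_extra_prec caps
   the number of guard bits the fixed-point path may spend absorbing
   the exponent spread of the inputs; when the spread is larger,
   nfloat_mat_mul_fixed now returns GR_UNABLE and the dispatch below
   falls back to classical or block multiplication */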
64 : prec / 4; - return nfloat_mat_mul_fixed(C, A, B, max_extra_prec, ctx); - } - else - { - return nfloat_mat_mul_block(C, A, B, 70, ctx); + status = nfloat_mat_mul_fixed(C, A, B, max_extra_prec, ctx); + + if (status == GR_UNABLE && dim < CUTOFF_CLASSICAL_BLOCK) + return gr_mat_mul_classical(C, A, B, ctx); } + + return nfloat_mat_mul_block(C, A, B, CUTOFF_CLASSICAL_BLOCK, ctx); } static void diff --git a/src/nfloat/profile/p-mat_mul.c b/src/nfloat/profile/p-mat_mul.c index 19bd0183ec..b29800946b 100644 --- a/src/nfloat/profile/p-mat_mul.c +++ b/src/nfloat/profile/p-mat_mul.c @@ -176,6 +176,66 @@ void prof_classical_vs_fixed() flint_rand_clear(state); } +void tune_fixed_vs_block(int * cutoffs) +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + { + flint_printf("prec = %wd\n", prec); + + nfloat_ctx_init(ctx, prec, 0); + + for (n = 16; ; n = FLINT_MAX(n + 1, n * 1.1)) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_fixed(C, A, B, 1000, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_block(C, A, B, 1, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + + if (t2 < t1) + { + cutoffs[prec / 64] = n; + + flint_printf("short tab_fixed_vs_block[] = {\n"); + for (i = 0; i <= prec / 64; i++) + flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); + flint_printf("}\n"); + + break; + } + } + } + + flint_rand_clear(state); +} + void prof_fixed_vs_block() { gr_ctx_t ctx; @@ -229,11 +289,73 @@ void prof_fixed_vs_block() flint_rand_clear(state); } +void tune_classical_vs_block(int * cutoffs) +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + { + flint_printf("prec = %wd\n", prec); + + nfloat_ctx_init(ctx, prec, 0); + + for (n = 16; ; n = FLINT_MAX(n + 1, n * 1.1)) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul_block(C, A, B, 1, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + + if (t2 < t1) + { + cutoffs[prec / 64] = n; + + flint_printf("short tab_classical_vs_block[] = {\n"); + for (i = 0; i <= prec / 64; i++) + flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); + flint_printf("}\n"); + + break; + } + } + } + + flint_rand_clear(state); +} + int main() { - int tab_classical_vs_fixed[TABN]; + int tab[TABN]; - //tune_classical_vs_fixed(tab_classical_vs_fixed); + //tune_classical_vs_fixed(tab); + //tune_fixed_vs_block(tab); //prof_classical_vs_fixed(); - prof_fixed_vs_block(); + //prof_fixed_vs_block(); + tune_classical_vs_block(tab); } From 
e591c83cacbbb2a6d9b3030f21d78df612da402b Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Wed, 4 Sep 2024 15:59:17 +0200 Subject: [PATCH 10/15] modified tunings --- src/nfloat/mat_mul.c | 335 ++++++++++++++++++----- src/nfloat/profile/p-complex_mat_mul.c | 365 +++++++++++++++++++++++++ 2 files changed, 636 insertions(+), 64 deletions(-) create mode 100644 src/nfloat/profile/p-complex_mat_mul.c diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index d62af20d2a..84b686fc34 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -169,6 +169,216 @@ static short tab_fixed_vs_block[] = { 79, /* prec = 4224 */ }; +static short tab_complex_classical_vs_fixed[] = { + -1, /* prec = 0 */ + 6, /* prec = 64 */ + 6, /* prec = 128 */ + 3, /* prec = 192 */ + 4, /* prec = 256 */ + 2, /* prec = 320 */ + 2, /* prec = 384 */ + 2, /* prec = 448 */ + 2, /* prec = 512 */ + 6, /* prec = 576 */ + 2, /* prec = 640 */ + 2, /* prec = 704 */ + 2, /* prec = 768 */ + 2, /* prec = 832 */ + 2, /* prec = 896 */ + 2, /* prec = 960 */ + 2, /* prec = 1024 */ + 2, /* prec = 1088 */ + 2, /* prec = 1152 */ + 2, /* prec = 1216 */ + 2, /* prec = 1280 */ + 2, /* prec = 1344 */ + 2, /* prec = 1408 */ + 3, /* prec = 1472 */ + 3, /* prec = 1536 */ + 2, /* prec = 1600 */ + 2, /* prec = 1664 */ + 2, /* prec = 1728 */ + 2, /* prec = 1792 */ + 2, /* prec = 1856 */ + 2, /* prec = 1920 */ + 2, /* prec = 1984 */ + 2, /* prec = 2048 */ + 2, /* prec = 2112 */ + 2, /* prec = 2176 */ + 2, /* prec = 2240 */ + 2, /* prec = 2304 */ + 2, /* prec = 2368 */ + 2, /* prec = 2432 */ + 2, /* prec = 2496 */ + 2, /* prec = 2560 */ + 2, /* prec = 2624 */ + 2, /* prec = 2688 */ + 2, /* prec = 2752 */ + 2, /* prec = 2816 */ + 2, /* prec = 2880 */ + 2, /* prec = 2944 */ + 2, /* prec = 3008 */ + 2, /* prec = 3072 */ + 2, /* prec = 3136 */ + 2, /* prec = 3200 */ + 2, /* prec = 3264 */ + 2, /* prec = 3328 */ + 2, /* prec = 3392 */ + 2, /* prec = 3456 */ + 2, /* prec = 3520 */ + 2, /* prec = 3584 */ + 2, /* prec = 3648 */ + 2, /* prec = 3712 */ + 2, /* prec = 3776 */ + 2, /* prec = 3840 */ + 2, /* prec = 3904 */ + 2, /* prec = 3968 */ + 2, /* prec = 4032 */ + 2, /* prec = 4096 */ + 2, /* prec = 4160 */ + 2, /* prec = 4224 */ +}; + +static short tab_complex_fixed_vs_block[] = { + -1, /* prec = 0 */ + 66, /* prec = 64 */ + 414, /* prec = 128 */ + 500, /* prec = 192 */ + 215, /* prec = 256 */ + 455, /* prec = 320 */ + 455, /* prec = 384 */ + 414, /* prec = 448 */ + 500, /* prec = 512 */ + 196, /* prec = 576 */ + 215, /* prec = 640 */ + 215, /* prec = 704 */ + 215, /* prec = 768 */ + 196, /* prec = 832 */ + 215, /* prec = 896 */ + 215, /* prec = 960 */ + 196, /* prec = 1024 */ + 196, /* prec = 1088 */ + 179, /* prec = 1152 */ + 163, /* prec = 1216 */ + 149, /* prec = 1280 */ + 163, /* prec = 1344 */ + 149, /* prec = 1408 */ + 149, /* prec = 1472 */ + 149, /* prec = 1536 */ + 124, /* prec = 1600 */ + 149, /* prec = 1664 */ + 124, /* prec = 1728 */ + 124, /* prec = 1792 */ + 124, /* prec = 1856 */ + 124, /* prec = 1920 */ + 103, /* prec = 1984 */ + 124, /* prec = 2048 */ + 124, /* prec = 2112 */ + 124, /* prec = 2176 */ + 103, /* prec = 2240 */ + 103, /* prec = 2304 */ + 103, /* prec = 2368 */ + 103, /* prec = 2432 */ + 103, /* prec = 2496 */ + 103, /* prec = 2560 */ + 103, /* prec = 2624 */ + 103, /* prec = 2688 */ + 103, /* prec = 2752 */ + 94, /* prec = 2816 */ + 103, /* prec = 2880 */ + 94, /* prec = 2944 */ + 94, /* prec = 3008 */ + 86, /* prec = 3072 */ + 86, /* prec = 3136 */ + 79, /* prec = 3200 */ + 86, /* prec = 3264 */ + 79, /* prec = 3328 
*/ + 79, /* prec = 3392 */ + 79, /* prec = 3456 */ + 79, /* prec = 3520 */ + 86, /* prec = 3584 */ + 79, /* prec = 3648 */ + 79, /* prec = 3712 */ + 79, /* prec = 3776 */ + 79, /* prec = 3840 */ + 79, /* prec = 3904 */ + 79, /* prec = 3968 */ + 79, /* prec = 4032 */ + 79, /* prec = 4096 */ + 79, /* prec = 4160 */ + 79, /* prec = 4224 */ +}; + +static short tab_complex_classical_vs_block[] = { + -1, /* prec = 0 */ + 36, /* prec = 64 */ + 79, /* prec = 128 */ + 60, /* prec = 192 */ + 50, /* prec = 256 */ + 50, /* prec = 320 */ + 46, /* prec = 384 */ + 55, /* prec = 448 */ + 60, /* prec = 512 */ + 55, /* prec = 576 */ + 39, /* prec = 640 */ + 39, /* prec = 704 */ + 39, /* prec = 768 */ + 39, /* prec = 832 */ + 28, /* prec = 896 */ + 28, /* prec = 960 */ + 39, /* prec = 1024 */ + 24, /* prec = 1088 */ + 28, /* prec = 1152 */ + 24, /* prec = 1216 */ + 24, /* prec = 1280 */ + 16, /* prec = 1344 */ + 24, /* prec = 1408 */ + 16, /* prec = 1472 */ + 20, /* prec = 1536 */ + 16, /* prec = 1600 */ + 16, /* prec = 1664 */ + 16, /* prec = 1728 */ + 16, /* prec = 1792 */ + 16, /* prec = 1856 */ + 16, /* prec = 1920 */ + 16, /* prec = 1984 */ + 16, /* prec = 2048 */ + 16, /* prec = 2112 */ + 16, /* prec = 2176 */ + 16, /* prec = 2240 */ + 16, /* prec = 2304 */ + 16, /* prec = 2368 */ + 16, /* prec = 2432 */ + 16, /* prec = 2496 */ + 16, /* prec = 2560 */ + 16, /* prec = 2624 */ + 16, /* prec = 2688 */ + 16, /* prec = 2752 */ + 16, /* prec = 2816 */ + 16, /* prec = 2880 */ + 16, /* prec = 2944 */ + 16, /* prec = 3008 */ + 16, /* prec = 3072 */ + 16, /* prec = 3136 */ + 16, /* prec = 3200 */ + 16, /* prec = 3264 */ + 16, /* prec = 3328 */ + 16, /* prec = 3392 */ + 16, /* prec = 3456 */ + 16, /* prec = 3520 */ + 16, /* prec = 3584 */ + 16, /* prec = 3648 */ + 16, /* prec = 3712 */ + 16, /* prec = 3776 */ + 16, /* prec = 3840 */ + 16, /* prec = 3904 */ + 16, /* prec = 3968 */ + 16, /* prec = 4032 */ + 16, /* prec = 4096 */ + 16, /* prec = 4160 */ + 16, /* prec = 4224 */ +}; + FLINT_FORCE_INLINE void _nfloat_get_nfixed(nn_ptr res, nn_srcptr x, slong exp, slong fix_nlimbs, gr_ctx_t ctx) @@ -1045,43 +1255,6 @@ nfloat_mat_mul_block(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong min_b return status; } -int -nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) -{ - slong cutoff1, cutoff2, dim; - slong prec; - slong max_extra_prec; - int status; - - slong m = A->r; - slong n = A->c; - slong p = B->c; - - dim = FLINT_MIN(n, FLINT_MIN(m, p)); - - if (dim <= 2 || NFLOAT_CTX_HAS_INF_NAN(ctx)) - return gr_mat_mul_classical(C, A, B, ctx); - - cutoff1 = tab_classical_vs_fixed[TAB_INDEX(prec)]; - - if (dim < cutoff1) - return gr_mat_mul_classical(C, A, B, ctx); - - cutoff2 = tab_fixed_vs_block[TAB_INDEX(prec)]; - - if (dim < cutoff2) - { - max_extra_prec = (prec < 768) ? 64 : prec / 4; - - status = nfloat_mat_mul_fixed(C, A, B, max_extra_prec, ctx); - - if (status == GR_UNABLE && dim < CUTOFF_CLASSICAL_BLOCK) - return gr_mat_mul_classical(C, A, B, ctx); - } - - return nfloat_mat_mul_block(C, A, B, CUTOFF_CLASSICAL_BLOCK, ctx); -} - static void _nfloat_complex_mat_exp_range(slong * _Amin, slong * _Amax, const gr_mat_t A, gr_ctx_t ctx) { @@ -1207,14 +1380,14 @@ nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slo /* Currently, we don't handle zeros. (They pose no problem, but zero entries in the output may not be exact. To be done.) 
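(With this patch the function reports GR_UNABLE in such cases instead of calling gr_mat_mul_classical itself; the choice of fallback now belongs to the caller.)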
*/ if (Amin < NFLOAT_MIN_EXP || Bmin < NFLOAT_MIN_EXP) - return gr_mat_mul_classical(C, A, B, ctx); + return GR_UNABLE; Adelta = Amax - Amin; Bdelta = Bmax - Bmin; /* sanity check */ if (Adelta > 10 * prec || Bdelta > 10 * prec) - return gr_mat_mul_classical(C, A, B, ctx); + return GR_UNABLE; /* To double check: for Waksman, @@ -1229,7 +1402,7 @@ nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slo extra_bits = Adelta + Bdelta + pad_top + pad_bot; if (extra_bits >= max_extra_bits) - return gr_mat_mul_classical(C, A, B, ctx); + return GR_UNABLE; Aexp = Amax + pad_top; Bexp = Bmax + pad_top; @@ -1497,14 +1670,51 @@ nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, g _nfloat_complex_mat_is_real(B, ctx), ctx); } +int +nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) +{ + slong cutoff1, cutoff2, dim; + slong prec; + slong max_extra_prec; + int status; + + slong m = A->r; + slong n = A->c; + slong p = B->c; + + dim = FLINT_MIN(n, FLINT_MIN(m, p)); + + if (dim <= 2 || NFLOAT_CTX_HAS_INF_NAN(ctx)) + return gr_mat_mul_classical(C, A, B, ctx); + + cutoff1 = tab_classical_vs_fixed[TAB_INDEX(prec)]; + + if (dim < cutoff1) + return gr_mat_mul_classical(C, A, B, ctx); + + cutoff2 = tab_fixed_vs_block[TAB_INDEX(prec)]; + + if (dim < cutoff2) + { + max_extra_prec = (prec < 768) ? 64 : prec / 4; + + status = nfloat_mat_mul_fixed(C, A, B, max_extra_prec, ctx); + + if (status == GR_UNABLE && dim < CUTOFF_CLASSICAL_BLOCK) + return gr_mat_mul_classical(C, A, B, ctx); + } + + return nfloat_mat_mul_block(C, A, B, CUTOFF_CLASSICAL_BLOCK, ctx); +} + int nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) { - slong dim; - slong block_cutoff; + slong cutoff1, cutoff2, cutoff3, dim; slong prec; slong max_extra_prec; int A_real = 0, B_real = 0; + int status; slong m = A->r; slong n = A->c; @@ -1526,34 +1736,31 @@ nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t prec = NFLOAT_CTX_PREC(ctx); - if (prec <= 256) - block_cutoff = 80; - else if (prec <= 512) - block_cutoff = 160; - else if (prec <= 3072) - block_cutoff = 100; - else - block_cutoff = 80; + cutoff1 = tab_complex_classical_vs_fixed[TAB_INDEX(prec)]; - if (dim < block_cutoff) - { - if (prec <= 128 || (prec <= 256 && n <= 4) || (prec == 576 && n <= 6)) - return gr_mat_mul_classical(C, A, B, ctx); + if (dim < cutoff1) + return gr_mat_mul_classical(C, A, B, ctx); + cutoff2 = tab_complex_fixed_vs_block[TAB_INDEX(prec)]; + cutoff3 = tab_complex_classical_vs_block[TAB_INDEX(prec)]; + + if (dim < cutoff2) + { max_extra_prec = (prec < 768) ? 
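/* same extra-precision budget as in the real nfloat_mat_mul */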
64 : prec / 4; - return nfloat_complex_mat_mul_fixed(C, A, B, max_extra_prec, ctx); + status = nfloat_complex_mat_mul_fixed(C, A, B, max_extra_prec, ctx); + + if (status == GR_UNABLE && dim < cutoff3) + return gr_mat_mul_classical(C, A, B, ctx); + } + + if (_nfloat_complex_mat_parts_are_well_scaled(A, ctx) && + _nfloat_complex_mat_parts_are_well_scaled(B, ctx)) + { + return nfloat_complex_mat_mul_block(C, A, B, cutoff3, ctx); } else { - if (_nfloat_complex_mat_parts_are_well_scaled(A, ctx) && - _nfloat_complex_mat_parts_are_well_scaled(B, ctx)) - { - return nfloat_complex_mat_mul_block(C, A, B, block_cutoff - 10, ctx); - } - else - { - return _nfloat_complex_mat_mul_reorder(C, A, B, A_real, B_real, ctx); - } + return _nfloat_complex_mat_mul_reorder(C, A, B, A_real, B_real, ctx); } } diff --git a/src/nfloat/profile/p-complex_mat_mul.c b/src/nfloat/profile/p-complex_mat_mul.c new file mode 100644 index 0000000000..5366e94c8e --- /dev/null +++ b/src/nfloat/profile/p-complex_mat_mul.c @@ -0,0 +1,365 @@ +/* + Copyright (C) 2024 Fredrik Johansson + + This file is part of FLINT. + + FLINT is free software: you can redistribute it and/or modify it under + the terms of the GNU Lesser General Public License (LGPL) as published + by the Free Software Foundation; either version 3 of the License, or + (at your option) any later version. See <https://www.gnu.org/licenses/>. +*/ + +#include <stdio.h> +#include "fmpz.h" +#include "gr.h" +#include "gr_special.h" +#include "gr_vec.h" +#include "gr_mat.h" +#include "arf.h" +#include "nfloat.h" +#include "profiler.h" +#include "double_extras.h" + +#define TABN (NFLOAT_MAX_LIMBS + 1) + +#if 1 +#undef TIMEIT_END_REPEAT +#define TIMEIT_END_REPEAT(__timer, __reps) \ + } \ + timeit_stop(__timer); \ + if (__timer->cpu >= 100) \ + break; \ + __reps *= 10; \ + } \ + } while (0); +#endif + +void +randmat(gr_mat_t mat, flint_rand_t state, gr_ctx_t ctx) +{ + slong m = gr_mat_nrows(mat, ctx); + slong n = gr_mat_ncols(mat, ctx); + + slong i, j; + + for (i = 0; i < m; i++) + { + for (j = 0; j < n; j++) + { + gr_ptr v = gr_mat_entry_ptr(mat, i, j, ctx); + + GR_MUST_SUCCEED(gr_i(v, ctx)); + GR_MUST_SUCCEED(gr_mul_si(v, v, 1 + n_randint(state, 1000), ctx)); + if (n_randint(state, 2)) + GR_MUST_SUCCEED(gr_neg(v, v, ctx)); + GR_MUST_SUCCEED(gr_add_si(v, v, 1 + n_randint(state, 1000), ctx)); + GR_MUST_SUCCEED(gr_div_ui(v, v, 1 + n_randint(state, 1000), ctx)); + if (n_randint(state, 2)) + GR_MUST_SUCCEED(gr_neg(v, v, ctx)); + } + } +} + +void tune_classical_vs_fixed(int * cutoffs) +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, n, nn; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + { + flint_printf("prec = %wd\n", prec); + + nfloat_complex_ctx_init(ctx, prec, 0); + + for (nn = 1; nn <= 128; nn++) + { + n = (nn == 1) ?
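/* nn == 1 is a probe at the largest dimension: if fixed-point does not beat classical even at n = 128, this precision is skipped */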
128 : nn; + + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul_fixed(C, A, B, 1000, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + + if (nn == 1) + { + if (t2 < t1) + continue; + else + break; + } + + if (t2 < t1) + { + cutoffs[prec / 64] = n; + + flint_printf("short tab_classical_vs_fixed[] = {\n"); + for (i = 0; i <= prec / 64; i++) + flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); + flint_printf("}\n"); + + break; + } + } + } + + flint_rand_clear(state); +} + +slong ns[] = { 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256, 288, 320, 352, 0 }; + +void prof_classical_vs_fixed() +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, ni, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec = (prec < 1024) ? prec + 64 : prec + 256) + { + flint_printf("%wd ", prec); + + nfloat_complex_ctx_init(ctx, prec, 0); + + for (ni = 8; (n = ns[ni]) != 0; ni++) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul_fixed(C, A, B, 1000, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%.3f ", t1 / t2); + fflush(stdout); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + } + + flint_printf("\n"); + } + + flint_rand_clear(state); +} + +void tune_fixed_vs_block(int * cutoffs) +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + { + flint_printf("prec = %wd\n", prec); + + nfloat_complex_ctx_init(ctx, prec, 0); + + for (n = 16; ; n = FLINT_MAX(n + 1, n * 1.1)) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul_fixed(C, A, B, 1000, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul_block(C, A, B, 1, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + + if (t2 < t1) + { + cutoffs[prec / 64] = n; + + flint_printf("short tab_fixed_vs_block[] = {\n"); + for (i = 0; i <= prec / 64; i++) + flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); + flint_printf("}\n"); + + break; + } + } + } + + flint_rand_clear(state); +} + +void prof_fixed_vs_block() +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, ni, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + flint_printf(" "); + for (ni = 8; (n = ns[ni]) != 0; ni++) + flint_printf("%5wd ", n); + flint_printf("\n"); + + for (prec = 64; prec <= 
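/* 64-bit precision steps below 1024 bits, 256-bit steps above */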
NFLOAT_MAX_LIMBS * FLINT_BITS; prec = (prec < 1024) ? prec + 64 : prec + 256) + { + flint_printf("%4wd ", prec); + + nfloat_complex_ctx_init(ctx, prec, 0); + + for (ni = 8; (n = ns[ni]) != 0; ni++) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul_fixed(C, A, B, 1000, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul_block(C, A, B, 1, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%.3f ", t1 / t2); + fflush(stdout); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + } + + flint_printf("\n"); + } + + flint_rand_clear(state); +} + +void tune_classical_vs_block(int * cutoffs) +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + for (i = 0; i < TABN; i++) + cutoffs[i] = -1; + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec += 64) + { + flint_printf("prec = %wd\n", prec); + + nfloat_complex_ctx_init(ctx, prec, 0); + + for (n = 16; ; n = FLINT_MAX(n + 1, n * 1.1)) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul_block(C, A, B, 1, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%wd %wd %e %e %.3f\n", prec, n, t1, t2, t1 / t2); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + + if (t2 < t1) + { + cutoffs[prec / 64] = n; + + flint_printf("short tab_classical_vs_block[] = {\n"); + for (i = 0; i <= prec / 64; i++) + flint_printf(" %d, /* prec = %wd */\n", cutoffs[i], i * 64); + flint_printf("}\n"); + + break; + } + } + } + + flint_rand_clear(state); +} + +int main() +{ + int tab[TABN]; + + //tune_classical_vs_fixed(tab); + //tune_fixed_vs_block(tab); + //prof_classical_vs_fixed(); + //prof_fixed_vs_block(); + tune_classical_vs_block(tab); +} From 2ab1b05ad24bf3d65ac139c124b6124acc05611b Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Thu, 5 Sep 2024 11:58:24 +0200 Subject: [PATCH 11/15] tuning improvements --- src/nfloat/mat_mul.c | 38 +++++++++++++---- src/nfloat/profile/p-complex_mat_mul.c | 58 +++++++++++++++++++++++++- src/nfloat/profile/p-mat_mul.c | 2 +- 3 files changed, 87 insertions(+), 11 deletions(-) diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index 84b686fc34..9dc31941d4 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -22,9 +22,6 @@ /* todo: check errors which depend on nlimbs */ -/* cutoffs classical -> block */ -#define CUTOFF_CLASSICAL_BLOCK 70 - #define TAB_INDEX(prec) (FLINT_MIN(prec, 4224) / 64) /* cutoffs for classical -> fixed */ @@ -309,6 +306,8 @@ static short tab_complex_fixed_vs_block[] = { 79, /* prec = 4224 */ }; +#if 0 + static short tab_complex_classical_vs_block[] = { -1, /* prec = 0 */ 36, /* prec = 64 */ @@ -379,6 +378,7 @@ static short tab_complex_classical_vs_block[] = { 16, /* prec = 4224 */ }; +#endif FLINT_FORCE_INLINE void _nfloat_get_nfixed(nn_ptr res, nn_srcptr x, slong exp, slong fix_nlimbs, gr_ctx_t ctx) @@ -1673,7 +1673,7 @@ nfloat_complex_mat_mul_reorder(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, g int nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) { - slong 
cutoff1, cutoff2, dim; + slong cutoff1, cutoff2, cutoff3, dim; slong prec; slong max_extra_prec; int status; @@ -1687,24 +1687,32 @@ nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) if (dim <= 2 || NFLOAT_CTX_HAS_INF_NAN(ctx)) return gr_mat_mul_classical(C, A, B, ctx); + /* classical -> fixed-point */ cutoff1 = tab_classical_vs_fixed[TAB_INDEX(prec)]; if (dim < cutoff1) return gr_mat_mul_classical(C, A, B, ctx); + /* fixed-point -> block */ cutoff2 = tab_fixed_vs_block[TAB_INDEX(prec)]; + /* classical -> block */ + cutoff3 = 80; + if (dim < cutoff2) { max_extra_prec = (prec < 768) ? 64 : prec / 4; status = nfloat_mat_mul_fixed(C, A, B, max_extra_prec, ctx); - if (status == GR_UNABLE && dim < CUTOFF_CLASSICAL_BLOCK) + if (status == GR_SUCCESS) + return status; + + if (status == GR_UNABLE && dim < cutoff3) return gr_mat_mul_classical(C, A, B, ctx); } - return nfloat_mat_mul_block(C, A, B, CUTOFF_CLASSICAL_BLOCK, ctx); + return nfloat_mat_mul_block(C, A, B, cutoff3 - 10, ctx); } int @@ -1742,7 +1750,18 @@ nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t return gr_mat_mul_classical(C, A, B, ctx); cutoff2 = tab_complex_fixed_vs_block[TAB_INDEX(prec)]; - cutoff3 = tab_complex_classical_vs_block[TAB_INDEX(prec)]; + + /* classical -> block */ + /* tuned for uniform matrices, so maybe not accurate in practice */ + /* cutoff3 = tab_complex_classical_vs_block[TAB_INDEX(prec)]; */ + if (prec <= 256) + cutoff3 = 80; + else if (prec <= 512) + cutoff3 = 160; + else if (prec <= 3072) + cutoff3 = 100; + else + cutoff3 = 80; if (dim < cutoff2) { @@ -1750,6 +1769,9 @@ nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t status = nfloat_complex_mat_mul_fixed(C, A, B, max_extra_prec, ctx); + if (status == GR_SUCCESS) + return status; + if (status == GR_UNABLE && dim < cutoff3) return gr_mat_mul_classical(C, A, B, ctx); } @@ -1757,7 +1779,7 @@ nfloat_complex_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t if (_nfloat_complex_mat_parts_are_well_scaled(A, ctx) && _nfloat_complex_mat_parts_are_well_scaled(B, ctx)) { - return nfloat_complex_mat_mul_block(C, A, B, cutoff3, ctx); + return nfloat_complex_mat_mul_block(C, A, B, cutoff3 - 10, ctx); } else { diff --git a/src/nfloat/profile/p-complex_mat_mul.c b/src/nfloat/profile/p-complex_mat_mul.c index 5366e94c8e..508529fcf7 100644 --- a/src/nfloat/profile/p-complex_mat_mul.c +++ b/src/nfloat/profile/p-complex_mat_mul.c @@ -130,7 +130,7 @@ void tune_classical_vs_fixed(int * cutoffs) flint_rand_clear(state); } -slong ns[] = { 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256, 288, 320, 352, 0 }; +slong ns[] = { 2, 3, 4, 8, 16, 24, 32, 48, 64, 80, 96, 128, 144, 256, 512, 1024, 0 }; void prof_classical_vs_fixed() { @@ -353,6 +353,59 @@ void tune_classical_vs_block(int * cutoffs) flint_rand_clear(state); } +void prof_mul() +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, ni, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + flint_printf(" "); + for (ni = 0; (n = ns[ni]) != 0; ni++) + flint_printf("%5wd ", n); + flint_printf("\n"); + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec = (prec < 1024) ? 
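/* e.g. prec runs 960 -> 1024 -> 1280 -> 1536 -> ... */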
prec + 64 : prec + 256) + { + flint_printf("%4wd ", prec); + + nfloat_complex_ctx_init(ctx, prec, 0); + + for (ni = 0; (n = ns[ni]) != 0; ni++) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_complex_mat_mul(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%.3f ", t1 / t2); + fflush(stdout); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + } + + flint_printf("\n"); + } + + flint_rand_clear(state); +} + int main() { int tab[TABN]; @@ -361,5 +414,6 @@ int main() //tune_fixed_vs_block(tab); //prof_classical_vs_fixed(); //prof_fixed_vs_block(); - tune_classical_vs_block(tab); + //tune_classical_vs_block(tab); + prof_mul(); } diff --git a/src/nfloat/profile/p-mat_mul.c b/src/nfloat/profile/p-mat_mul.c index b29800946b..4cae5b6bf8 100644 --- a/src/nfloat/profile/p-mat_mul.c +++ b/src/nfloat/profile/p-mat_mul.c @@ -126,7 +126,7 @@ void tune_classical_vs_fixed(int * cutoffs) flint_rand_clear(state); } -slong ns[] = { 2, 3, 4, 6, 8, 12, 16, 24, 32, 48, 64, 80, 96, 112, 128, 144, 160, 192, 224, 256, 288, 320, 352, 0 }; +slong ns[] = { 2, 3, 4, 8, 16, 24, 32, 48, 64, 80, 96, 128, 144, 256, 512, 1024, 0 }; void prof_classical_vs_fixed() { From 6c414336380d5b59d0303d5ccf57abcdd20ceafd Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Thu, 5 Sep 2024 12:46:17 +0200 Subject: [PATCH 12/15] tuning fix --- src/nfloat/mat_mul.c | 12 +++--- src/nfloat/profile/p-complex_mat_mul.c | 4 +- src/nfloat/profile/p-mat_mul.c | 56 +++++++++++++++++++++++++- 3 files changed, 64 insertions(+), 8 deletions(-) diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c index 9dc31941d4..6643015a91 100644 --- a/src/nfloat/mat_mul.c +++ b/src/nfloat/mat_mul.c @@ -26,7 +26,7 @@ /* cutoffs for classical -> fixed */ static short tab_classical_vs_fixed[] = { - -1, /* prec = 0 */ + 14, /* prec = 0 */ 14, /* prec = 64 */ 16, /* prec = 128 */ 5, /* prec = 192 */ @@ -97,7 +97,7 @@ static short tab_classical_vs_fixed[] = { /* cutoffs for fixed -> block */ static short tab_fixed_vs_block[] = { - -1, /* prec = 0 */ + 50, /* prec = 0 */ 50, /* prec = 64 */ 94, /* prec = 128 */ 124, /* prec = 192 */ @@ -167,7 +167,7 @@ static short tab_fixed_vs_block[] = { }; static short tab_complex_classical_vs_fixed[] = { - -1, /* prec = 0 */ + 6, /* prec = 0 */ 6, /* prec = 64 */ 6, /* prec = 128 */ 3, /* prec = 192 */ @@ -237,7 +237,7 @@ static short tab_complex_classical_vs_fixed[] = { }; static short tab_complex_fixed_vs_block[] = { - -1, /* prec = 0 */ + 66, /* prec = 0 */ 66, /* prec = 64 */ 414, /* prec = 128 */ 500, /* prec = 192 */ @@ -309,7 +309,7 @@ static short tab_complex_fixed_vs_block[] = { #if 0 static short tab_complex_classical_vs_block[] = { - -1, /* prec = 0 */ + 36, /* prec = 0 */ 36, /* prec = 64 */ 79, /* prec = 128 */ 60, /* prec = 192 */ @@ -1687,6 +1687,8 @@ nfloat_mat_mul(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, gr_ctx_t ctx) if (dim <= 2 || NFLOAT_CTX_HAS_INF_NAN(ctx)) return gr_mat_mul_classical(C, A, B, ctx); + prec = NFLOAT_CTX_PREC(ctx); + /* classical -> fixed-point */ cutoff1 = tab_classical_vs_fixed[TAB_INDEX(prec)]; diff --git a/src/nfloat/profile/p-complex_mat_mul.c b/src/nfloat/profile/p-complex_mat_mul.c index 508529fcf7..11e2e437f6 100644 --- a/src/nfloat/profile/p-complex_mat_mul.c +++ b/src/nfloat/profile/p-complex_mat_mul.c 
@@ -366,7 +366,7 @@ void prof_mul() flint_printf(" "); for (ni = 0; (n = ns[ni]) != 0; ni++) - flint_printf("%5wd ", n); + flint_printf("%5wd ", n); flint_printf("\n"); for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec = (prec < 1024) ? prec + 64 : prec + 256) @@ -392,7 +392,7 @@ void prof_mul() GR_MUST_SUCCEED(nfloat_complex_mat_mul(C, A, B, ctx)); TIMEIT_STOP_VALUES(__, t2) - flint_printf("%.3f ", t1 / t2); + flint_printf("%.3f ", t1 / t2); fflush(stdout); gr_mat_clear(A, ctx); diff --git a/src/nfloat/profile/p-mat_mul.c b/src/nfloat/profile/p-mat_mul.c index 4cae5b6bf8..670f5b87f3 100644 --- a/src/nfloat/profile/p-mat_mul.c +++ b/src/nfloat/profile/p-mat_mul.c @@ -349,6 +349,59 @@ void tune_classical_vs_block(int * cutoffs) flint_rand_clear(state); } +void prof_mul() +{ + gr_ctx_t ctx; + gr_mat_t A, B, C; + slong i, ni, n; + slong prec; + double FLINT_SET_BUT_UNUSED(__), t1, t2; + + flint_rand_t state; + flint_rand_init(state); + + flint_printf(" "); + for (ni = 0; (n = ns[ni]) != 0; ni++) + flint_printf("%5wd ", n); + flint_printf("\n"); + + for (prec = 64; prec <= NFLOAT_MAX_LIMBS * FLINT_BITS; prec = (prec < 1024) ? prec + 64 : prec + 256) + { + flint_printf("%4wd ", prec); + + nfloat_ctx_init(ctx, prec, 0); + + for (ni = 0; (n = ns[ni]) != 0; ni++) + { + gr_mat_init(A, n, n, ctx); + gr_mat_init(B, n, n, ctx); + gr_mat_init(C, n, n, ctx); + + randmat(A, state, ctx); + randmat(B, state, ctx); + + TIMEIT_START + GR_MUST_SUCCEED(gr_mat_mul_classical(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t1) + + TIMEIT_START + GR_MUST_SUCCEED(nfloat_mat_mul(C, A, B, ctx)); + TIMEIT_STOP_VALUES(__, t2) + + flint_printf("%.3f ", t1 / t2); + fflush(stdout); + + gr_mat_clear(A, ctx); + gr_mat_clear(B, ctx); + gr_mat_clear(C, ctx); + } + + flint_printf("\n"); + } + + flint_rand_clear(state); +} + int main() { int tab[TABN]; @@ -357,5 +410,6 @@ int main() //tune_fixed_vs_block(tab); //prof_classical_vs_fixed(); //prof_fixed_vs_block(); - tune_classical_vs_block(tab); + //tune_classical_vs_block(tab); + prof_mul(); } From 016cb3e58b82ad31cdabaf60c09d9c1f48b1b96b Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Thu, 5 Sep 2024 13:23:19 +0200 Subject: [PATCH 13/15] profiling code tweak --- src/nfloat/profile/p-complex_mat_mul.c | 3 +++ src/nfloat/profile/p-mat_mul.c | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/nfloat/profile/p-complex_mat_mul.c b/src/nfloat/profile/p-complex_mat_mul.c index 11e2e437f6..6dbd4a1d8a 100644 --- a/src/nfloat/profile/p-complex_mat_mul.c +++ b/src/nfloat/profile/p-complex_mat_mul.c @@ -398,6 +398,9 @@ void prof_mul() gr_mat_clear(A, ctx); gr_mat_clear(B, ctx); gr_mat_clear(C, ctx); + + if (t1 > 3.0) + break; } flint_printf("\n"); diff --git a/src/nfloat/profile/p-mat_mul.c b/src/nfloat/profile/p-mat_mul.c index 670f5b87f3..5a0163d373 100644 --- a/src/nfloat/profile/p-mat_mul.c +++ b/src/nfloat/profile/p-mat_mul.c @@ -394,6 +394,9 @@ void prof_mul() gr_mat_clear(A, ctx); gr_mat_clear(B, ctx); gr_mat_clear(C, ctx); + + if (t1 > 3.0) + break; } flint_printf("\n"); From 2fd37991c3fb90b47074c3edaf3330276c80c9fa Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Thu, 5 Sep 2024 15:43:42 +0200 Subject: [PATCH 14/15] test code tweak --- src/nfloat/test/t-nfixed_mat_mul.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nfloat/test/t-nfixed_mat_mul.c b/src/nfloat/test/t-nfixed_mat_mul.c index 5de8d8de6f..4ef678057d 100644 --- a/src/nfloat/test/t-nfixed_mat_mul.c +++ b/src/nfloat/test/t-nfixed_mat_mul.c @@ -23,13 +23,13 @@ 
TEST_FUNCTION_START(nfixed_mat_mul, state) nn_ptr a; int which; - slong MAXN = 10; + slong MAXN = 12; slong MINLIMBS = 2; slong MAXLIMBS = 12; for (iter = 0; iter < 10000 * flint_test_multiplier(); iter++) { - which = n_randint(state, 4); + which = n_randint(state, 6); m = 1 + n_randint(state, MAXN); n = 1 + n_randint(state, MAXN); From 2737830f5fc7def37b209fb98e1924d9a383a8e2 Mon Sep 17 00:00:00 2001 From: Fredrik Johansson Date: Fri, 6 Sep 2024 17:43:07 +0200 Subject: [PATCH 15/15] _nfixed_mat_mul_strassen: avoid repeated operations in odd dimension; call classical directly --- src/nfloat/nfixed.c | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/nfloat/nfixed.c b/src/nfloat/nfixed.c index cba1d53da1..2c8eaef16e 100644 --- a/src/nfloat/nfixed.c +++ b/src/nfloat/nfixed.c @@ -1319,18 +1319,22 @@ _nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_ _nfixed_mat_t Bc, Cc; _nfixed_mat_window_init(Bc, B, 0, 2 * bnc, ac, bc, nlimbs); _nfixed_mat_window_init(Cc, C, 0, 2 * bnc, ar, bc, nlimbs); - _nfixed_mat_mul_strassen2(Cc, A, Bc, cutoff, nlimbs); + + _nfixed_mat_mul_classical2(Cc, A, Bc, nlimbs); _nfixed_mat_window_clear(Bc, nlimbs); _nfixed_mat_window_clear(Cc, nlimbs); } if (ar > 2 * anr) { - _nfixed_mat_t Ar, Cr; + _nfixed_mat_t Ar, Bc, Cr; _nfixed_mat_window_init(Ar, A, 2 * anr, 0, ar, ac, nlimbs); - _nfixed_mat_window_init(Cr, C, 2 * anr, 0, ar, bc, nlimbs); - _nfixed_mat_mul_strassen2(Cr, Ar, B, cutoff, nlimbs); + _nfixed_mat_window_init(Bc, B, 0, 0, ac, 2 * bnc, nlimbs); + _nfixed_mat_window_init(Cr, C, 2 * anr, 0, ar, 2 * bnc, nlimbs); + + _nfixed_mat_mul_classical2(Cr, Ar, Bc, nlimbs); _nfixed_mat_window_clear(Ar, nlimbs); + _nfixed_mat_window_clear(Bc, nlimbs); _nfixed_mat_window_clear(Cr, nlimbs); } @@ -1348,7 +1352,7 @@ _nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_ /* todo: faster */ _nfixed_mat_init(tmp, mt, nt, nlimbs); - _nfixed_mat_mul_strassen2(tmp, Ac, Br, cutoff, nlimbs); + _nfixed_mat_mul_classical2(tmp, Ac, Br, nlimbs); _nfixed_mat_add(Cb, Cb, tmp, nlimbs); _nfixed_mat_clear(tmp, nlimbs); _nfixed_mat_window_clear(Ac, nlimbs);
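The dispatch logic these patches converge on reduces to a clamped, 64-bit-quantized table lookup. Below is a minimal standalone sketch of that idea, not one of the patches above: the table holds the first eleven entries of tab_classical_vs_fixed, while DEMO_TAB_INDEX and choose_algorithm are hypothetical stand-ins (the real TAB_INDEX clamps at 4224 bits and the real nfloat_mat_mul also consults the fixed-vs-block table).

/* Standalone illustration of table-driven algorithm dispatch. */
#include <stdio.h>

/* quantize the precision to the 64-bit grid, clamping to the demo range */
#define DEMO_TAB_INDEX(prec) (((prec) < 640 ? (prec) : 640) / 64)

/* dimension cutoffs below which classical multiplication wins
   (first entries of tab_classical_vs_fixed) */
static const short demo_classical_vs_fixed[] = {
    14, /* prec = 0 (padding) */
    14, /* prec = 64 */
    16, /* prec = 128 */
    5,  /* prec = 192 */
    7,  /* prec = 256 */
    3,  /* prec = 320 */
    3,  /* prec = 384 */
    3,  /* prec = 448 */
    3,  /* prec = 512 */
    10, /* prec = 576 */
    5,  /* prec = 640 */
};

static const char *
choose_algorithm(long dim, long prec)
{
    return (dim < demo_classical_vs_fixed[DEMO_TAB_INDEX(prec)])
            ? "classical" : "fixed-point";
}

int main(void)
{
    /* prec = 256 has cutoff 7, so an 8x8 product goes to fixed-point */
    printf("%s\n", choose_algorithm(8, 256));
    /* prec = 128 has cutoff 16, so a 4x4 product stays classical */
    printf("%s\n", choose_algorithm(4, 128));
    return 0;
}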