Skip to content

Commit

Permalink
Merge pull request #2112 from fredrik-johansson/nfixed4
Browse files Browse the repository at this point in the history
Precise bounds calculation for nfixed_mat_mul
  • Loading branch information
fredrik-johansson authored Nov 28, 2024
2 parents 9b3d2b6 + b5f6533 commit 2f4594b
Show file tree
Hide file tree
Showing 9 changed files with 667 additions and 47 deletions.
24 changes: 23 additions & 1 deletion doc/source/nfloat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,33 @@ intermediate results (including rounding errors) lie in `(-1,1)`.
indicate the offset in number of limbs between consecutive entries
and may be negative.

.. function:: void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
.. function:: void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs)
void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)

Matrix multiplication using various algorithms.
The *strassen* variant takes a *cutoff* parameter specifying where
to switch from basecase multiplication to Strassen multiplication.
The *classical_precise* version computes with one extra limb of
internal precision; this is only intended for testing purposes.

.. function:: void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs)
void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs)

For the respective matrix multiplication algorithm, computes bounds
for the product of an `m \times n` matrix by an `n \times p` matrix
at precision *nlimbs*, given entrywise bounds *A* and *B* for the
two input matrices.

The *bound* output is set to a bound for the entries in all intermediate
variables of the computation. The value of *bound* should be < 1 to
ensure correctness of the multiplication. The *error* output is set to
a bound for the output error, measured in ulp.
The caller can assume that the computed bounds are nondecreasing
functions of *A* and *B*.

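For complex multiplication, *A* and *B* bound the real and imaginary parts of the entries of the first matrix, and *C* and *D* bound the real and imaginary parts of the entries of the second matrix; that is, the entrywise bounds are for `A+Bi` and `C+Di`.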
8 changes: 8 additions & 0 deletions src/nfloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,11 +590,19 @@ void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ys
void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len);
void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len);

/* Matrix multiplication C = A * B using various algorithms, with operand
   dimensions given by m, n, p and precision given by nlimbs.
   - The strassen variant takes a cutoff parameter specifying where to
     switch from basecase multiplication to Strassen multiplication.
   - The classical_precise variant computes with one extra limb of internal
     precision; it is only intended for testing purposes. */
void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs);
void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);

/* Bounds for the respective matrix multiplication algorithm, for a size
   m x n x p product at precision nlimbs, given entrywise bounds A and B
   on the inputs.
   - *bound is set to a bound for the entries in all intermediate variables
     of the computation; it should be < 1 to ensure correctness.
   - *error is set to a bound for the output error, measured in ulp.
   The computed bounds are documented to be nondecreasing functions of A
   and B. For the complex version, the entrywise bounds are for A+Bi and
   C+Di. */
void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs);
void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs);


#ifdef __cplusplus
}
#endif
Expand Down
89 changes: 57 additions & 32 deletions src/nfloat/mat_mul.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "double_extras.h"
#include "mpn_extras.h"
#include "gr.h"
#include "gr_vec.h"
Expand Down Expand Up @@ -500,27 +501,39 @@ nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_e
if (Adelta > 10 * prec || Bdelta > 10 * prec)
return GR_UNABLE;

/*
To double check: for Waksman,
* The intermediate entries are bounded by 8n max(|A|,|B|)^2.
* The error, including error from converting
the input matrices, is bounded by 8n ulps.
*/
/* We must scale inputs to 2^(-pad_top) so that intermediate
entries satisfy |x| < 1. */
{
double Abound, Bbound, bound, error;

pad_top = 2;
Abound = Bbound = ldexp(1.0, -pad_top);
/* Option: improve accuracy by adding more trailing guard bits. */
/* pad_bot = 3 + FLINT_BIT_COUNT(n); */
pad_bot = 2;

pad_top = 3 + FLINT_BIT_COUNT(n);
pad_bot = 3 + FLINT_BIT_COUNT(n);
while (1)
{
Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
extra_bits = Adelta + Bdelta + pad_top + pad_bot;

extra_bits = Adelta + Bdelta + pad_top + pad_bot;
if (extra_bits >= max_extra_bits)
return GR_UNABLE;

if (extra_bits >= max_extra_bits)
return GR_UNABLE;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;

Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
_nfixed_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Bbound, fnlimbs);

return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
if (bound < 0.999)
return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);

pad_top++;
Abound *= 0.5;
Bbound *= 0.5;
}
}
}

static void
Expand Down Expand Up @@ -1389,27 +1402,39 @@ nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slo
if (Adelta > 10 * prec || Bdelta > 10 * prec)
return GR_UNABLE;

/*
To double check: for Waksman,
* The intermediate entries are bounded by 8n max(|A|,|B|)^2.
* The error, including error from converting
the input matrices, is bounded by 8n ulps.
*/
/* We must scale inputs to 2^(-pad_top) so that intermediate
entries satisfy |x| < 1. */
{
double Abound, Bbound, bound, error;

pad_top = 2;
Abound = Bbound = ldexp(1.0, -pad_top);
/* Option: improve accuracy by adding more trailing guard bits. */
/* pad_bot = 3 + FLINT_BIT_COUNT(n); */
pad_bot = 2;

pad_top = 3 + FLINT_BIT_COUNT(n);
pad_bot = 3 + FLINT_BIT_COUNT(n);
while (1)
{
Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
extra_bits = Adelta + Bdelta + pad_top + pad_bot;

extra_bits = Adelta + Bdelta + pad_top + pad_bot;
if (extra_bits >= max_extra_bits)
return GR_UNABLE;

if (extra_bits >= max_extra_bits)
return GR_UNABLE;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;

Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
_nfixed_complex_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Abound, Bbound, Bbound, fnlimbs);

return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
if (bound < 0.999)
return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);

pad_top++;
Abound *= 0.5;
Bbound *= 0.5;
}
}
}

FLINT_FORCE_INLINE slong
Expand Down
Loading

0 comments on commit 2f4594b

Please sign in to comment.