diff --git a/mlkem/common.h b/mlkem/common.h index d7ac86d6d..76504b88e 100644 --- a/mlkem/common.h +++ b/mlkem/common.h @@ -6,4 +6,7 @@ #define ALIGN __attribute__((aligned(DEFAULT_ALIGN))) #define ALWAYS_INLINE __attribute__((always_inline)) +#define MLKEM_CONCAT_(left, right) left##right +#define MLKEM_CONCAT(left, right) MLKEM_CONCAT_(left, right) + #endif diff --git a/mlkem/debug/debug.h b/mlkem/debug/debug.h index 13b18418c..03ede7365 100644 --- a/mlkem/debug/debug.h +++ b/mlkem/debug/debug.h @@ -144,7 +144,6 @@ void mlkem_debug_print_error(const char *file, int line, const char *msg); } while (0) // Following AWS-LC to define a C99-compliant static assert -#define MLKEM_CONCAT(left, right) left##right #define MLKEM_STATIC_ASSERT_DEFINE(cond, msg) \ typedef struct { \ unsigned int MLKEM_CONCAT(static_assertion_, msg) : (cond) ? 1 : -1; \ diff --git a/mlkem/ntt.c b/mlkem/ntt.c index ca123ff4b..f31b76498 100644 --- a/mlkem/ntt.c +++ b/mlkem/ntt.c @@ -74,49 +74,128 @@ const int16_t zetas[128] = { **************************************************/ #if !defined(MLKEM_USE_NATIVE_NTT) -// Check that the specific bound for the reference NTT implies -// the bound required by the C<->Native interface. 
-#define NTT_BOUND_REF (5 * MLKEM_Q) -STATIC_ASSERT(NTT_BOUND_REF <= NTT_BOUND, ntt_ref_bound) +// Helper macros to make NTT contracts and invariants more readable +// TODO: Consolidate this with the macros in cbmc.h +#define SCALAR_IN_BOUNDS(a, lb, ub) ((lb) <= (a) && (a) <= (ub)) +#define SCALAR_Q_BOUND(a, k) \ + SCALAR_IN_BOUNDS((a), (-(k) * MLKEM_Q + 1), ((k) * MLKEM_Q - 1)) +#define ARRAY_Q_BOUND(arr, lb, ub, k) \ + ARRAY_IN_BOUNDS(int, MLKEM_CONCAT(i, __LINE__), (lb), (ub), (arr), \ + (-(k) * MLKEM_Q + 1), ((k) * MLKEM_Q - 1)) + +// Compute a block CT butterflies with a fixed twiddle factor +STATIC_TESTABLE // clang-format off +void ntt_butterfly_block(int16_t *r, int16_t root, int stride, int bound) + REQUIRES(1 <= stride && stride <= 128) + REQUIRES(1 <= bound && bound <= 7) + REQUIRES(IS_FRESH(r, 4 * stride)) + REQUIRES(SCALAR_IN_BOUNDS(root, -HALF_Q + 1, HALF_Q - 1)) + REQUIRES(ARRAY_Q_BOUND(r, 0, (2 * stride - 1), bound)) + ASSIGNS(OBJECT_UPTO(r, sizeof(int16_t) * 2 * stride)) + ENSURES(ARRAY_Q_BOUND(r, 0, (2 * stride - 1), bound+1)) +// clang-format on +{ + // Used for the specification only + ((void) bound); + // clang-format off + for (int j = 0; j < stride; j++) + ASSIGNS(j, OBJECT_UPTO(r, sizeof(int16_t) * 2 * stride)) + INVARIANT(0 <= j && j <= stride) + INVARIANT(ARRAY_Q_BOUND(r, 0, j-1, bound+1)) + INVARIANT(ARRAY_Q_BOUND(r, stride, stride + j - 1, bound+1)) + INVARIANT(ARRAY_Q_BOUND(r, j, stride - 1, bound)) + INVARIANT(ARRAY_Q_BOUND(r, stride + j, 2*stride - 1, bound)) + // clang-format on + { + int16_t t; + t = fqmul(r[j + stride], root); + r[j + stride] = r[j] - t; + r[j] = r[j] + t; + } +} + +// Framed version of ntt_butterfly_block +// +// TODO: This only exists because inlining ntt_butterfly_block() leads to +// much longer proof times (in fact, I have not witnessed it finishing so far). +// Even this proof, while seemingly a trivial framing, takes longer than all +// of the 'actual' NTT proof.
+STATIC_TESTABLE +// clang-format off +void ntt_butterfly_block_at(int16_t *p, int16_t root, int sz, int base, int stride, int bound) + REQUIRES(2 <= sz && sz <= 256 && (sz & 1) == 0) + REQUIRES(0 <= base && base < sz && 1 <= stride && stride <= sz / 2 && base + 2 * stride <= sz) + REQUIRES((stride & (stride - 1)) == 0) + REQUIRES((base & (stride - 1)) == 0) + REQUIRES(1 <= bound && bound <= 7) + REQUIRES(IS_FRESH(p, sizeof(int16_t) * sz)) + REQUIRES(SCALAR_IN_BOUNDS(root, -HALF_Q + 1, HALF_Q - 1)) + REQUIRES(ARRAY_Q_BOUND(p, 0, base - 1, bound + 1)) + REQUIRES(ARRAY_Q_BOUND(p, base, sz - 1, bound)) + ASSIGNS(OBJECT_UPTO(p, sizeof(int16_t) * sz)) + ENSURES(ARRAY_Q_BOUND(p, 0, base + 2*stride - 1, bound + 1)) + ENSURES(ARRAY_Q_BOUND(p, base + 2 * stride, sz - 1, bound)) +// clang-format on +{ + // Parameter only used in the CBMC specification + ((void)sz); + ntt_butterfly_block(p + base, root, stride, bound); +} + +STATIC_TESTABLE +// clang-format off +void ntt_layer(poly *p, int layer) + REQUIRES(IS_FRESH(p, sizeof(poly))) + REQUIRES(1 <= layer && layer <= 7) + REQUIRES(ARRAY_Q_BOUND(p->coeffs, 0, MLKEM_N - 1, layer)) + ASSIGNS(OBJECT_UPTO(p, sizeof(poly))) + ENSURES(ARRAY_Q_BOUND(p->coeffs, 0, MLKEM_N - 1, layer + 1)) +// clang-format on +{ + int16_t *r = p->coeffs; + const int len = 1u << (8 - layer); + const int blocks = 1 << (layer - 1); + for (int i = 0; i < blocks; i++) + // clang-format off + ASSIGNS(i, OBJECT_UPTO(r, sizeof(poly))) + INVARIANT(1 <= layer && layer <= 7) + INVARIANT(0 <= i && i <= blocks) + INVARIANT(ARRAY_Q_BOUND(r, 2 * i * len, MLKEM_N - 1, layer)) + INVARIANT(ARRAY_Q_BOUND(r, 0, 2 * i * len - 1, layer + 1)) + // clang-format on + { + int16_t zeta = zetas[blocks + i]; + ntt_butterfly_block_at(r, zeta, MLKEM_N, 2 * i * len, len, layer); + } +} // REF-CHANGE: Removed indirection poly_ntt -> ntt() // and integrated polynomial reduction into the NTT.
void poly_ntt(poly *p) { POLY_BOUND_MSG(p, MLKEM_Q, "ref ntt input"); - unsigned int len, start, j, k; - int16_t t, zeta; - int16_t *r = p->coeffs; - - k = 1; - // Bounds reasoning: // - There are 7 layers // - When passing from layer N to layer N+1, each layer-N value // is modified through the addition/subtraction of a Montgomery // product of a twiddle of absolute value < q/2 and a layer-N value. - // - Recalling that |fqmul(a,t)| < q * (0.0254*C + 1/2) for - // |a| < C*q and |t| < q/2, the coefficients after layer N are - // bounded by q * f^N(1), where f(C) = 0.0254*C + 1.5. - // - For N=7, this gives the bound NTT_BOUND_REF = 5q. - for (len = 128; len >= 2; len >>= 1) { - for (start = 0; start < 256; start = j + len) { - zeta = zetas[k++]; - for (j = start; j < start + len; j++) { - t = fqmul(r[j + len], zeta); - r[j + len] = r[j] - t; - r[j] = r[j] + t; - } + // - Recalling that |fqmul(a,t)| < q for |t| < q/2, each layer + // increases the absolute bound on the coefficients by at most q. + // - After the 7 layers this yields the bound 8q = NTT_BOUND. + // clang-format off + for (int layer = 1; layer <= 7; layer++) + ASSIGNS(layer, OBJECT_UPTO(p, sizeof(poly))) + INVARIANT(1 <= layer && layer <= 8) + INVARIANT(ARRAY_Q_BOUND(p->coeffs, 0, MLKEM_N - 1, layer)) + // clang-format on + { + ntt_layer(p, layer); } - } // Check the stronger bound - POLY_BOUND_MSG(p, NTT_BOUND_REF, "ref ntt output"); + POLY_BOUND_MSG(p, NTT_BOUND, "ref ntt output"); } #else /* MLKEM_USE_NATIVE_NTT */ diff --git a/mlkem/ntt.h b/mlkem/ntt.h index 0add34db2..d308ffbe5 100644 --- a/mlkem/ntt.h +++ b/mlkem/ntt.h @@ -4,6 +4,7 @@ #include <stdint.h> #include "arith_native.h" +#include "cbmc.h" #include "params.h" #include "poly.h" #include "reduce.h" @@ -11,8 +12,15 @@ #define zetas MLKEM_NAMESPACE(zetas) extern const int16_t zetas[128]; +// clang-format off #define poly_ntt MLKEM_NAMESPACE(poly_ntt) -void poly_ntt(poly *r); +void poly_ntt(poly *r) +REQUIRES(IS_FRESH(r, sizeof(poly))) +REQUIRES(ARRAY_IN_BOUNDS(int, t, 0, MLKEM_N - 1, r->coeffs, -MLKEM_Q + 1, MLKEM_Q - 1)) +ASSIGNS(OBJECT_UPTO(r, sizeof(poly))) +ENSURES(ARRAY_IN_BOUNDS(int, t, 0, MLKEM_N - 1, r->coeffs, -NTT_BOUND + 1, NTT_BOUND - 1)) +; +// clang-format on #define poly_invntt_tomont MLKEM_NAMESPACE(poly_invntt_tomont) void poly_invntt_tomont(poly *r);