Skip to content

Commit

Permalink
Update in-code comments, referring to ML-KEM standard
Browse files Browse the repository at this point in the history
Signed-off-by: Anjan Roy <[email protected]>
  • Loading branch information
itzmeanjan committed Sep 2, 2024
1 parent 4621071 commit 0ab30f5
Show file tree
Hide file tree
Showing 17 changed files with 78 additions and 93 deletions.
22 changes: 11 additions & 11 deletions include/ml_kem/internals/k_pke.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace k_pke {

// K-PKE key generation algorithm, generating byte serialized public key and secret keym given a 32 -bytes input seed `d`.
// See algorithm 12 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 13 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, size_t eta1>
constexpr void
keygen(std::span<const uint8_t, 32> d,
Expand All @@ -19,20 +19,20 @@ keygen(std::span<const uint8_t, 32> d,
requires(ml_kem_params::check_keygen_params(k, eta1))
{
std::array<uint8_t, 64> g_out{};
auto _g_out = std::span(g_out);
auto g_out_span = std::span(g_out);

// Repurposing `g_out` (i.e. array for holding output of hash function G),
// for preparing the concatenated input to hash function G.
std::copy(d.begin(), d.end(), _g_out.begin());
_g_out[d.size()] = k; // Domain seperator to prevent misuse of key
std::copy(d.begin(), d.end(), g_out_span.begin());
g_out_span[d.size()] = k; // Domain seperator to prevent misuse of key

sha3_512::sha3_512_t h512;
h512.absorb(_g_out.template first<d.size() + 1>());
h512.absorb(g_out_span.template first<d.size() + 1>());
h512.finalize();
h512.digest(_g_out);
h512.digest(g_out_span);

const auto rho = _g_out.template subspan<0, 32>();
const auto sigma = _g_out.template subspan<rho.size(), 32>();
const auto rho = g_out_span.template subspan<0, 32>();
const auto sigma = g_out_span.template subspan<rho.size(), 32>();

std::array<ml_kem_field::zq_t, k * k * ml_kem_ntt::N> A_prime{};
ml_kem_utils::generate_matrix<k, false>(A_prime, rho);
Expand Down Expand Up @@ -68,9 +68,9 @@ keygen(std::span<const uint8_t, 32> d,
// ( from where all randomness is deterministically sampled ), this routine encrypts message using
// K-PKE encryption algorithm, computing compressed cipher text.
//
// If modulus check, as described in point (2) of section 6.2 of ML-KEM draft standard, fails, it returns false.
// If modulus check, as described in point (2) of section 7.2 of ML-KEM standard, fails, it returns false.
//
// See algorithm 13 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 14 of K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
[[nodiscard("Use result of modulus check on public key")]] constexpr bool
encrypt(std::span<const uint8_t, ml_kem_utils::get_pke_public_key_len(k)> pubkey,
Expand Down Expand Up @@ -147,7 +147,7 @@ encrypt(std::span<const uint8_t, ml_kem_utils::get_pke_public_key_len(k)> pubkey
// Given K-PKE secret key and cipher text, this routine recovers 32 -bytes plain text which
// was encrypted using K-PKE public key i.e. associated with this secret key.
//
// See algorithm 14 defined in K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 15 defined in K-PKE specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, size_t du, size_t dv>
constexpr void
decrypt(std::span<const uint8_t, ml_kem_utils::get_pke_secret_key_len(k)> seckey,
Expand Down
54 changes: 27 additions & 27 deletions include/ml_kem/internals/ml_kem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
namespace ml_kem {

// ML-KEM key generation algorithm, generating byte serialized public key and secret key, given 32 -bytes seed `d` and `z`.
// See algorithm 15 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd
// See algorithm 16 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, size_t eta1>
constexpr void
keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
Expand Down Expand Up @@ -48,7 +48,7 @@ keygen(std::span<const uint8_t, 32> d, // used in CPA-PKE
//
// If invalid ML-KEM public key is input, this function execution will fail, returning false.
//
// See algorithm 16 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd
// See algorithm 17 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
[[nodiscard("Use result, it might fail because of malformed input public key")]] constexpr bool
encapsulate(std::span<const uint8_t, 32> m,
Expand All @@ -60,33 +60,33 @@ encapsulate(std::span<const uint8_t, 32> m,
std::array<uint8_t, m.size() + sha3_256::DIGEST_LEN> g_in{};
std::array<uint8_t, sha3_512::DIGEST_LEN> g_out{};

auto _g_in = std::span(g_in);
auto _g_in0 = _g_in.template first<m.size()>();
auto _g_in1 = _g_in.template last<sha3_256::DIGEST_LEN>();
auto g_in_span = std::span(g_in);
auto g_in_span0 = g_in_span.template first<m.size()>();
auto g_in_span1 = g_in_span.template last<sha3_256::DIGEST_LEN>();

auto _g_out = std::span(g_out);
auto _g_out0 = _g_out.template first<shared_secret.size()>();
auto _g_out1 = _g_out.template last<_g_out.size() - _g_out0.size()>();
auto g_out_span = std::span(g_out);
auto g_out_span0 = g_out_span.template first<shared_secret.size()>();
auto g_out_span1 = g_out_span.template last<g_out_span.size() - g_out_span0.size()>();

std::copy(m.begin(), m.end(), _g_in0.begin());
std::copy(m.begin(), m.end(), g_in_span0.begin());

sha3_256::sha3_256_t h256{};
h256.absorb(pubkey);
h256.finalize();
h256.digest(_g_in1);
h256.digest(g_in_span1);

sha3_512::sha3_512_t h512{};
h512.absorb(_g_in);
h512.absorb(g_in_span);
h512.finalize();
h512.digest(_g_out);
h512.digest(g_out_span);

const auto has_mod_check_passed = k_pke::encrypt<k, eta1, eta2, du, dv>(pubkey, m, _g_out1, cipher);
const auto has_mod_check_passed = k_pke::encrypt<k, eta1, eta2, du, dv>(pubkey, m, g_out_span1, cipher);
if (!has_mod_check_passed) {
// Got an invalid public key
return has_mod_check_passed;
}

std::copy(_g_out0.begin(), _g_out0.end(), shared_secret.begin());
std::copy(g_out_span0.begin(), g_out_span0.end(), shared_secret.begin());
return true;
}

Expand All @@ -96,7 +96,7 @@ encapsulate(std::span<const uint8_t, 32> m,
// Recovered 32 -bytes plain text is used for deriving a 32 -bytes shared secret key, which can now be
// used for encrypting communication between two participating parties, using fast symmetric key algorithms.
//
// See algorithm 17 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 18 defined in ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, size_t eta1, size_t eta2, size_t du, size_t dv>
constexpr void
decapsulate(std::span<const uint8_t, ml_kem_utils::get_kem_secret_key_len(k)> seckey,
Expand All @@ -122,21 +122,21 @@ decapsulate(std::span<const uint8_t, ml_kem_utils::get_kem_secret_key_len(k)> se
std::array<uint8_t, shared_secret.size()> j_out{};
std::array<uint8_t, cipher.size()> c_prime{};

auto _g_in = std::span(g_in);
auto _g_in0 = _g_in.template first<32>();
auto _g_in1 = _g_in.template last<h.size()>();
auto g_in_span = std::span(g_in);
auto g_in_span0 = g_in_span.template first<32>();
auto g_in_span1 = g_in_span.template last<h.size()>();

auto _g_out = std::span(g_out);
auto _g_out0 = _g_out.template first<shared_secret.size()>();
auto _g_out1 = _g_out.template last<32>();
auto g_out_span = std::span(g_out);
auto g_out_span0 = g_out_span.template first<shared_secret.size()>();
auto g_out_span1 = g_out_span.template last<32>();

k_pke::decrypt<k, du, dv>(pke_sk, cipher, _g_in0);
std::copy(h.begin(), h.end(), _g_in1.begin());
k_pke::decrypt<k, du, dv>(pke_sk, cipher, g_in_span0);
std::copy(h.begin(), h.end(), g_in_span1.begin());

sha3_512::sha3_512_t h512{};
h512.absorb(_g_in);
h512.absorb(g_in_span);
h512.finalize();
h512.digest(_g_out);
h512.digest(g_out_span);

shake256::shake256_t xof256{};
xof256.absorb(z);
Expand All @@ -145,12 +145,12 @@ decapsulate(std::span<const uint8_t, ml_kem_utils::get_kem_secret_key_len(k)> se
xof256.squeeze(j_out);

// Explicitly ignore return value, because public key, held as part of secret key is *assumed* to be valid.
(void)k_pke::encrypt<k, eta1, eta2, du, dv>(pubkey, _g_in0, _g_out1, c_prime);
(void)k_pke::encrypt<k, eta1, eta2, du, dv>(pubkey, g_in_span0, g_out_span1, c_prime);

// line 9-12 of algorithm 17, in constant-time
using kdf_t = std::span<const uint8_t, shared_secret.size()>;
const uint32_t cond = ml_kem_utils::ct_memcmp(cipher, std::span<const uint8_t, ctlen>(c_prime));
ml_kem_utils::ct_cond_memcpy(cond, shared_secret, kdf_t(_g_out0), kdf_t(z));
ml_kem_utils::ct_cond_memcpy(cond, shared_secret, kdf_t(g_out_span0), kdf_t(z));
}

}
4 changes: 2 additions & 2 deletions include/ml_kem/internals/poly/compression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace ml_kem_utils {

// Given an element x ∈ Z_q | q = 3329, this routine compresses it by discarding some low-order bits, computing y ∈ [0, 2^d) | d < round(log2(q)).
//
// See formula 4.5 on page 18 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See formula 4.7 on page 21 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
// Following implementation collects inspiration from https://github.com/FiloSottile/mlkem768/blob/cffbfb96/mlkem768.go#L395-L425.
template<size_t d>
forceinline constexpr ml_kem_field::zq_t
Expand All @@ -30,7 +30,7 @@ compress(const ml_kem_field::zq_t x)

// Given an element x ∈ [0, 2^d) | d < round(log2(q)), this routine decompresses it back to y ∈ Z_q | q = 3329.
//
// See formula 4.6 on page 18 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See formula 4.8 on page 21 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t d>
forceinline constexpr ml_kem_field::zq_t
decompress(const ml_kem_field::zq_t x)
Expand Down
9 changes: 5 additions & 4 deletions include/ml_kem/internals/poly/ntt.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ inline constexpr size_t N = 1 << LOG2N;
//
// Meaning, 17 ** 256 == 1 mod q
inline constexpr auto ζ = ml_kem_field::zq_t(17);
static_assert((ζ ^ N) == ml_kem_field::zq_t::one(), "ζ must be 256th root of unity modulo Q");

// Multiplicative inverse of N/ 2 over Z_q | q = 3329 and N = 256
//
Expand Down Expand Up @@ -74,7 +75,7 @@ inline constexpr std::array<ml_kem_field::zq_t, N / 2> POLY_MUL_ζ_EXP = []() ->
// Note, this routine mutates input i.e. it's an in-place NTT implementation.
//
// Implementation inspired from https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L69-L144.
// See algorithm 8 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 9 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
forceinline constexpr void
ntt(std::span<ml_kem_field::zq_t, N> poly)
{
Expand Down Expand Up @@ -110,7 +111,7 @@ ntt(std::span<ml_kem_field::zq_t, N> poly)
// Note, this routine mutates input i.e. it's an in-place iNTT implementation.
//
// Implementation inspired from https://github.com/itzmeanjan/falcon/blob/45b0593/include/ntt.hpp#L146-L224.
// See algorithm 9 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 10 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
forceinline constexpr void
intt(std::span<ml_kem_field::zq_t, N> poly)
{
Expand Down Expand Up @@ -146,7 +147,7 @@ intt(std::span<ml_kem_field::zq_t, N> poly)
}

// Given two degree-1 polynomials, this routine computes resulting degree-1 polynomial h.
// See algorithm 11 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 12 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
forceinline constexpr void
basemul(std::span<const ml_kem_field::zq_t, 2> f, std::span<const ml_kem_field::zq_t, 2> g, std::span<ml_kem_field::zq_t, 2> h, const ml_kem_field::zq_t ζ)
{
Expand Down Expand Up @@ -178,7 +179,7 @@ basemul(std::span<const ml_kem_field::zq_t, 2> f, std::span<const ml_kem_field::
//
// h = f ◦ g
//
// See algorithm 10 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 11 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
constexpr void
polymul(std::span<const ml_kem_field::zq_t, N> f, std::span<const ml_kem_field::zq_t, N> g, std::span<ml_kem_field::zq_t, N> h)
{
Expand Down
4 changes: 2 additions & 2 deletions include/ml_kem/internals/poly/poly_vec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ matrix_multiply(std::span<const ml_kem_field::zq_t, a_rows * a_cols * ml_kem_ntt
using poly_t = std::span<const ml_kem_field::zq_t, ml_kem_ntt::N>;

std::array<ml_kem_field::zq_t, ml_kem_ntt::N> tmp{};
auto _tmp = std::span(tmp);
auto tmp_span = std::span(tmp);

for (size_t i = 0; i < a_rows; i++) {
for (size_t j = 0; j < b_cols; j++) {
Expand All @@ -29,7 +29,7 @@ matrix_multiply(std::span<const ml_kem_field::zq_t, a_rows * a_cols * ml_kem_ntt
const size_t aoff = (i * a_cols + k) * ml_kem_ntt::N;
const size_t boff = (k * b_cols + j) * ml_kem_ntt::N;

ml_kem_ntt::polymul(poly_t(a.subspan(aoff, ml_kem_ntt::N)), poly_t(b.subspan(boff, ml_kem_ntt::N)), _tmp);
ml_kem_ntt::polymul(poly_t(a.subspan(aoff, ml_kem_ntt::N)), poly_t(b.subspan(boff, ml_kem_ntt::N)), tmp_span);

for (size_t l = 0; l < ml_kem_ntt::N; l++) {
c[coff + l] += tmp[l];
Expand Down
8 changes: 4 additions & 4 deletions include/ml_kem/internals/poly/sampling.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace ml_kem_utils {
// If the byte stream is statistically close to uniform random byte stream, produced polynomial coefficients are also
// statiscally close to randomly sampled elements of R_q.
//
// See algorithm 6 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 7 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
forceinline constexpr void
sample_ntt(shake128::shake128_t& hasher, std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
{
Expand Down Expand Up @@ -47,7 +47,7 @@ sample_ntt(shake128::shake128_t& hasher, std::span<ml_kem_field::zq_t, ml_kem_nt
// Generate public matrix A ( consists of degree-255 polynomials ) in NTT domain, by sampling from a XOF ( read SHAKE128 ),
// which is seeded with 32 -bytes key and two nonces ( each of 1 -byte ).
//
// See step (4-8) of algorithm 12/ 13 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See step (3-7) of algorithm 13 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, bool transpose>
constexpr void
generate_matrix(std::span<ml_kem_field::zq_t, k * k * ml_kem_ntt::N> mat, std::span<const uint8_t, 32> rho)
Expand Down Expand Up @@ -81,7 +81,7 @@ generate_matrix(std::span<ml_kem_field::zq_t, k * k * ml_kem_ntt::N> mat, std::s
// Centered Binomial Distribution.
// A degree 255 polynomial deterministically sampled from `64 * eta` -bytes output of a pseudorandom function ( PRF ).
//
// See algorithm 7 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 8 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t eta>
constexpr void
sample_poly_cbd(std::span<const uint8_t, 64 * eta> prf, std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
Expand Down Expand Up @@ -131,7 +131,7 @@ sample_poly_cbd(std::span<const uint8_t, 64 * eta> prf, std::span<ml_kem_field::
}
}

// Sample a polynomial vector from Bη, following step (9-12) of algorithm 12/ 13 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// Sample a polynomial vector from Bη, following step (8-11) of algorithm 13 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t k, size_t eta>
constexpr void
generate_vector(std::span<ml_kem_field::zq_t, k * ml_kem_ntt::N> vec, std::span<const uint8_t, 32> sigma, const uint8_t nonce)
Expand Down
6 changes: 3 additions & 3 deletions include/ml_kem/internals/poly/serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace ml_kem_utils {
// Given a degree-255 polynomial, where significant portion of each ( total 256 of them ) coefficient ∈ [0, 2^l),
// this routine serializes the polynomial to a byte array of length 32 * l -bytes.
//
// See algorithm 4 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 5 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t l>
constexpr void
encode(std::span<const ml_kem_field::zq_t, ml_kem_ntt::N> poly, std::span<uint8_t, 32 * l> arr)
Expand Down Expand Up @@ -142,7 +142,7 @@ encode(std::span<const ml_kem_field::zq_t, ml_kem_ntt::N> poly, std::span<uint8_
// Given a byte array of length 32 * l -bytes this routine deserializes it to a polynomial of degree 255 s.t. significant
// portion of each ( total 256 of them ) coefficient ∈ [0, 2^l).
//
// See algorithm 5 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// See algorithm 6 of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
template<size_t l>
constexpr void
decode(std::span<const uint8_t, 32 * l> arr, std::span<ml_kem_field::zq_t, ml_kem_ntt::N> poly)
Expand Down Expand Up @@ -271,7 +271,7 @@ decode(std::span<const uint8_t, 32 * l> arr, std::span<ml_kem_field::zq_t, ml_ke
const auto t0 = (static_cast<uint16_t>(arr[boff + 1] & mask4) << 8) | static_cast<uint16_t>(arr[boff + 0]);
const auto t1 = (static_cast<uint16_t>(arr[boff + 2]) << 4) | static_cast<uint16_t>(arr[boff + 1] >> 4);

// Read line (786-792) of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.ipd.
// Read line (786-792) of ML-KEM specification https://doi.org/10.6028/NIST.FIPS.203.
poly[poff + 0] = ml_kem_field::zq_t::from_non_reduced(t0);
poly[poff + 1] = ml_kem_field::zq_t::from_non_reduced(t1);
}
Expand Down
8 changes: 4 additions & 4 deletions include/ml_kem/internals/rng/prng.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,20 @@ struct prng_t
forceinline prng_t()
{
std::array<uint8_t, bit_security_level / std::numeric_limits<uint8_t>::digits> seed{};
auto _seed = std::span(seed);
auto seed_span = std::span(seed);

// Read more @ https://en.cppreference.com/w/cpp/numeric/random/random_device/random_device
std::random_device rd{};

size_t off = 0;
while (off < _seed.size()) {
while (off < seed_span.size()) {
const uint32_t v = rd();
std::memcpy(_seed.subspan(off, sizeof(v)).data(), &v, sizeof(v));
std::memcpy(seed_span.subspan(off, sizeof(v)).data(), &v, sizeof(v));

off += sizeof(v);
}

state.absorb(_seed);
state.absorb(seed_span);
state.finalize();
}

Expand Down
Loading

0 comments on commit 0ab30f5

Please sign in to comment.