Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test Improvements for ML-KEM #1947

Merged
merged 17 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ endif()
add_executable(test_kem test_kem.c)
target_link_libraries(test_kem PRIVATE ${TEST_DEPS})

if(CMAKE_SYSTEM_NAME STREQUAL "Windows" AND BUILD_SHARED_LIBS)
# workaround for Windows .dll
if(MINGW OR MSYS OR CYGWIN OR CMAKE_CROSSCOMPILING)
target_link_options(test_kem PRIVATE -Wl,--allow-multiple-definition)
else()
target_link_options(test_kem PRIVATE "/FORCE:MULTIPLE")
endif()
endif()

add_executable(test_kem_mem test_kem_mem.c)
target_link_libraries(test_kem_mem PRIVATE ${TEST_DEPS})

Expand Down
3 changes: 0 additions & 3 deletions tests/test_acvp_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
ml_dsa_ver = "ACVP_Vectors/ML-DSA-sigVer-FIPS204/internalProjection.json"

@helpers.filtered_test
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows")
@pytest.mark.parametrize('kem_name', helpers.available_kems_by_name())
def test_acvp_vec_kem_keygen(kem_name):
if not(helpers.is_kem_enabled_by_name(kem_name)): pytest.skip('Not enabled')
Expand All @@ -45,7 +44,6 @@ def test_acvp_vec_kem_keygen(kem_name):
assert(variantFound == True)

@helpers.filtered_test
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows")
@pytest.mark.parametrize('kem_name', helpers.available_kems_by_name())
def test_acvp_vec_kem_encdec_aft(kem_name):

Expand Down Expand Up @@ -76,7 +74,6 @@ def test_acvp_vec_kem_encdec_aft(kem_name):
assert(variantFound == True)

@helpers.filtered_test
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Not needed on Windows")
@pytest.mark.parametrize('kem_name', helpers.available_kems_by_name())
def test_acvp_vec_kem_encdec_val(kem_name):

Expand Down
101 changes: 100 additions & 1 deletion tests/test_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
#include <stdbool.h>
#include <stdio.h>
#include <stdlib.h>
#if defined(_WIN32)
#include <string.h>
#define strcasecmp _stricmp
#else
#include <strings.h>
#endif

#include <oqs/oqs.h>

#include <oqs/sha3.h>
#if OQS_USE_PTHREADS
#include <pthread.h>
#endif
Expand All @@ -20,6 +25,10 @@
#define OQS_TEST_CT_DECLASSIFY(addr, len)
#endif

#ifdef OQS_ENABLE_KEM_ML_KEM
#define MLKEM_SECRET_LEN 32
baentsch marked this conversation as resolved.
Show resolved Hide resolved
#endif

#include "system_info.c"

/* Displays hexadecimal strings */
Expand All @@ -31,6 +40,89 @@ static void OQS_print_hex_string(const char *label, const uint8_t *str, size_t l
printf("\n");
}

#ifdef OQS_ENABLE_KEM_ML_KEM
/* mlkem rejection key testcase */
baentsch marked this conversation as resolved.
Show resolved Hide resolved
static bool mlkem_rej_testcase(OQS_KEM *kem, uint8_t *ciphertext, uint8_t *secret_key) {
// sanity checks
if ((kem == NULL) || (ciphertext == NULL) || (secret_key == NULL)) {
fprintf(stderr, "ERROR: inputs NULL!\n");
return false;
}
// Only run tests for ML-KEM
if (!(strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_512) == 0 ||
strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_768) == 0 ||
strcasecmp(kem->method_name, OQS_KEM_alg_ml_kem_1024) == 0)) {
return true;
}
// Buffer to hold z and c. z is always 32 bytes
uint8_t *buff_z_c = NULL;
bool retval = false;
OQS_STATUS rc;
int rv;
size_t length_z_c = 32 + kem->length_ciphertext;
buff_z_c = OQS_MEM_malloc(length_z_c);
if (buff_z_c == NULL) {
fprintf(stderr, "ERROR: OQS_MEM_malloc failed\n");
return false;
}
// Scenario 1: Test rejection key by corrupting the secret key
secret_key[0] += 1;
uint8_t shared_secret_r[MLKEM_SECRET_LEN]; // expected output
uint8_t shared_secret_d[MLKEM_SECRET_LEN]; // calculated output
memcpy(buff_z_c, &secret_key[kem->length_secret_key - 32], 32);
memcpy(&buff_z_c[MLKEM_SECRET_LEN], ciphertext, kem->length_ciphertext);
// Calculate expected secret in case of corrupted cipher: shake256(z || c)
OQS_SHA3_shake256(shared_secret_r, MLKEM_SECRET_LEN, buff_z_c, length_z_c);
rc = OQS_KEM_decaps(kem, shared_secret_d, ciphertext, secret_key);
OQS_TEST_CT_DECLASSIFY(&rc, sizeof rc);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "ERROR: OQS_KEM_decaps failed for rejection testcase scenario 1\n");
goto cleanup;
}
OQS_TEST_CT_DECLASSIFY(shared_secret_d, MLKEM_SECRET_LEN);
OQS_TEST_CT_DECLASSIFY(shared_secret_r, MLKEM_SECRET_LEN);
rv = memcmp(shared_secret_d, shared_secret_r, MLKEM_SECRET_LEN);
if (rv != 0) {
fprintf(stderr, "ERROR: shared secrets are not equal for rejection key in decapsulation scenario 1\n");
OQS_print_hex_string("shared_secret_d", shared_secret_d, MLKEM_SECRET_LEN);
OQS_print_hex_string("shared_secret_r", shared_secret_r, MLKEM_SECRET_LEN);
goto cleanup;
}
secret_key[0] -= 1; // Restore private key
memset(buff_z_c, 0, length_z_c); // Reset buffer

// Scenario 2: Test rejection key by corrupting the ciphertext
ciphertext[0] += 1;
memcpy(buff_z_c, &secret_key[kem->length_secret_key - 32], 32);
memcpy(&buff_z_c[MLKEM_SECRET_LEN], ciphertext, kem->length_ciphertext);

// Calculate expected secret in case of corrupted cipher: shake256(z || c)
OQS_SHA3_shake256(shared_secret_r, MLKEM_SECRET_LEN, buff_z_c, length_z_c);
rc = OQS_KEM_decaps(kem, shared_secret_d, ciphertext, secret_key);
OQS_TEST_CT_DECLASSIFY(&rc, sizeof rc);
if (rc != OQS_SUCCESS) {
fprintf(stderr, "ERROR: OQS_KEM_decaps failed for rejection testcase scenario 2\n");
goto cleanup;
}
OQS_TEST_CT_DECLASSIFY(shared_secret_d, MLKEM_SECRET_LEN);
OQS_TEST_CT_DECLASSIFY(shared_secret_r, MLKEM_SECRET_LEN);
rv = memcmp(shared_secret_d, shared_secret_r, MLKEM_SECRET_LEN);
if (rv != 0) {
fprintf(stderr, "ERROR: shared secrets are not equal for rejection key in decapsulation scenario 2\n");
OQS_print_hex_string("shared_secret_d", shared_secret_d, MLKEM_SECRET_LEN);
OQS_print_hex_string("shared_secret_r", shared_secret_r, MLKEM_SECRET_LEN);
goto cleanup;
}
ciphertext[0] -= 1; // Restore ciphertext
retval = true;
cleanup:
if (buff_z_c) {
OQS_MEM_secure_free(buff_z_c, length_z_c);
}
return retval;
}
#endif //OQS_ENABLE_KEM_ML_KEM

typedef struct magic_s {
uint8_t val[31];
} magic_t;
Expand Down Expand Up @@ -127,6 +219,13 @@ static OQS_STATUS kem_test_correctness(const char *method_name) {
printf("shared secrets are equal\n");
}

#ifdef OQS_ENABLE_KEM_ML_KEM
/* check mlkem rejection testcases. returns true for all other kem algos */
if (false == mlkem_rej_testcase(kem, ciphertext, secret_key)) {
goto err;
}
#endif

// test invalid encapsulation (call should either fail or result in invalid shared secret)
OQS_randombytes(ciphertext, kem->length_ciphertext);
OQS_TEST_CT_DECLASSIFY(ciphertext, kem->length_ciphertext);
Expand Down
10 changes: 5 additions & 5 deletions tests/vectors_kem.c
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ static OQS_STATUS kem_vector_encdec_val(const char *method_name,
ret = OQS_SUCCESS;
} else {
ret = OQS_ERROR;
fprintf(stderr, "[vectors_kem] %s ERROR (AFT): ciphertext or shared secret doesn't match!\n", method_name);
fprintf(stderr, "[vectors_kem] %s ERROR (AFT): shared secret doesn't match!\n", method_name);
}

goto cleanup;
Expand Down Expand Up @@ -358,11 +358,11 @@ int main(int argc, char **argv) {
}

if (!strcmp(test_name, "keyGen")) {
prng_output_stream = argv[3]; // d || z
prng_output_stream = argv[3]; // d || z : both should be 32 bytes each as per FIPS-203
kg_pk = argv[4];
kg_sk = argv[5];

if (strlen(prng_output_stream) % 2 != 0 ||
if (strlen(prng_output_stream) != 128 ||
SWilson4 marked this conversation as resolved.
Show resolved Hide resolved
strlen(kg_pk) != 2 * kem->length_public_key ||
strlen(kg_sk) != 2 * kem->length_secret_key) {
rc = OQS_ERROR;
Expand All @@ -386,12 +386,12 @@ int main(int argc, char **argv) {

rc = kem_kg_vector(alg_name, prng_output_stream_bytes, kg_pk_bytes, kg_sk_bytes);
} else if (!strcmp(test_name, "encDecAFT")) {
prng_output_stream = argv[3]; // m
prng_output_stream = argv[3]; // m : should be 32 bytes as per FIPS-203
encdec_aft_pk = argv[4];
encdec_aft_k = argv[5];
encdec_aft_c = argv[6];

if (strlen(prng_output_stream) % 2 != 0 ||
if (strlen(prng_output_stream) != 64 ||
strlen(encdec_aft_c) != 2 * kem->length_ciphertext ||
strlen(encdec_aft_k) != 2 * kem->length_shared_secret ||
strlen(encdec_aft_pk) != 2 * kem->length_public_key) {
Expand Down
Loading