diff --git a/src/crypto.cpp b/src/crypto.cpp index b578702ca..9c5f2c455 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -93,11 +93,27 @@ Blob aesEncrypt(const uint8_t* data, size_t data_length, const Blob& key) std::random_device rdev; std::generate_n(ret.begin(), GCM_IV_SIZE, std::bind(rand_byte, std::ref(rdev))); } - struct gcm_aes_ctx aes; - gcm_aes_set_key(&aes, key.size(), key.data()); - gcm_aes_set_iv(&aes, GCM_IV_SIZE, ret.data()); - gcm_aes_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data); - gcm_aes_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length); + + if (key.size() == AES_LENGTHS[0]) { + struct gcm_aes128_ctx aes; + gcm_aes128_set_key(&aes, key.data()); + gcm_aes128_set_iv(&aes, GCM_IV_SIZE, ret.data()); + gcm_aes128_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data); + gcm_aes128_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length); + } else if (key.size() == AES_LENGTHS[1]) { + struct gcm_aes192_ctx aes; + gcm_aes192_set_key(&aes, key.data()); + gcm_aes192_set_iv(&aes, GCM_IV_SIZE, ret.data()); + gcm_aes192_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data); + gcm_aes192_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length); + } else if (key.size() == AES_LENGTHS[2]) { + struct gcm_aes256_ctx aes; + gcm_aes256_set_key(&aes, key.data()); + gcm_aes256_set_iv(&aes, GCM_IV_SIZE, ret.data()); + gcm_aes256_encrypt(&aes, data_length, ret.data() + GCM_IV_SIZE, data); + gcm_aes256_digest(&aes, GCM_DIGEST_SIZE, ret.data() + GCM_IV_SIZE + data_length); + } + return ret; } @@ -118,14 +134,28 @@ Blob aesDecrypt(const uint8_t* data, size_t data_length, const Blob& key) std::array digest; - struct gcm_aes_ctx aes; - gcm_aes_set_key(&aes, key.size(), key.data()); - gcm_aes_set_iv(&aes, GCM_IV_SIZE, data); - size_t data_sz = data_length - GCM_IV_SIZE - GCM_DIGEST_SIZE; Blob ret(data_sz); - gcm_aes_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE); - gcm_aes_digest(&aes, GCM_DIGEST_SIZE, digest.data()); + + if (key.size() == AES_LENGTHS[0]) { + struct gcm_aes128_ctx aes; + gcm_aes128_set_key(&aes, key.data()); + gcm_aes128_set_iv(&aes, GCM_IV_SIZE, data); + gcm_aes128_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE); + gcm_aes128_digest(&aes, GCM_DIGEST_SIZE, digest.data()); + } else if (key.size() == AES_LENGTHS[1]) { + struct gcm_aes192_ctx aes; + gcm_aes192_set_key(&aes, key.data()); + gcm_aes192_set_iv(&aes, GCM_IV_SIZE, data); + gcm_aes192_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE); + gcm_aes192_digest(&aes, GCM_DIGEST_SIZE, digest.data()); + } else if (key.size() == AES_LENGTHS[2]) { + struct gcm_aes256_ctx aes; + gcm_aes256_set_key(&aes, key.data()); + gcm_aes256_set_iv(&aes, GCM_IV_SIZE, data); + gcm_aes256_decrypt(&aes, data_sz, ret.data(), data + GCM_IV_SIZE); + gcm_aes256_digest(&aes, GCM_DIGEST_SIZE, digest.data()); + } if (not std::equal(digest.begin(), digest.end(), data + data_length - GCM_DIGEST_SIZE)) { throw DecryptError("Can't decrypt data"); diff --git a/tests/cryptotester.cpp b/tests/cryptotester.cpp index c33a038ff..1673784cd 100644 --- a/tests/cryptotester.cpp +++ b/tests/cryptotester.cpp @@ -239,6 +239,27 @@ void CryptoTester::testAesEncryption() { CPPUNIT_ASSERT(data2 == decrypted2); } +void CryptoTester::testAesEncryptionWithMultipleKeySizes() { + auto data = std::vector(rand(), rand()); + + // Valid key sizes + for (auto key_length : {16, 24, 32}) { + auto key = std::vector(key_length, rand()); + + auto encrypted_data = dht::crypto::aesEncrypt(data, key); + auto decrypted_data = dht::crypto::aesDecrypt(encrypted_data, key); + + CPPUNIT_ASSERT(data == decrypted_data); + } + + // Invalid key sizes + for (auto key_length : {12, 28, 36}) { + auto key = std::vector(key_length, rand()); + + CPPUNIT_ASSERT_THROW(dht::crypto::aesEncrypt(data, key), dht::crypto::DecryptError); + } +} + void CryptoTester::tearDown() { diff --git a/tests/cryptotester.h b/tests/cryptotester.h index 890190571..6cd552c40 100644 --- a/tests/cryptotester.h +++ b/tests/cryptotester.h @@ -34,6 +34,7 @@ class CryptoTester : public CppUnit::TestFixture { CPPUNIT_TEST(testCertificateSerialNumber); CPPUNIT_TEST(testOcsp); CPPUNIT_TEST(testAesEncryption); + CPPUNIT_TEST(testAesEncryptionWithMultipleKeySizes); CPPUNIT_TEST_SUITE_END(); public: @@ -69,6 +70,7 @@ class CryptoTester : public CppUnit::TestFixture { * Test key streching and aes encryption/decryption */ void testAesEncryption(); + void testAesEncryptionWithMultipleKeySizes(); }; } // namespace test