diff --git a/include/opendht/crypto.h b/include/opendht/crypto.h index 1033c04d9..1f14771f2 100644 --- a/include/opendht/crypto.h +++ b/include/opendht/crypto.h @@ -793,7 +793,7 @@ OPENDHT_PUBLIC void hash(const uint8_t* data, size_t data_length, uint8_t* hash, * that can be transmitted in clear, and will be generated if * not provided (32 bytes). */ -OPENDHT_PUBLIC Blob stretchKey(const std::string& password, Blob& salt, size_t key_length = 512/8); +OPENDHT_PUBLIC Blob stretchKey(std::string_view password, Blob& salt, size_t key_length = 512/8); /** * AES-GCM encryption. Key must be 128, 192 or 256 bits long (16, 24 or 32 bytes). @@ -802,15 +802,37 @@ OPENDHT_PUBLIC Blob aesEncrypt(const uint8_t* data, size_t data_length, const Bl OPENDHT_PUBLIC inline Blob aesEncrypt(const Blob& data, const Blob& key) { return aesEncrypt(data.data(), data.size(), key); } -OPENDHT_PUBLIC Blob aesEncrypt(const Blob& data, const std::string& password); +OPENDHT_PUBLIC Blob aesEncrypt(const Blob& data, std::string_view password); /** * AES-GCM decryption. */ OPENDHT_PUBLIC Blob aesDecrypt(const uint8_t* data, size_t data_length, const Blob& key); OPENDHT_PUBLIC inline Blob aesDecrypt(const Blob& data, const Blob& key) { return aesDecrypt(data.data(), data.size(), key); } -OPENDHT_PUBLIC Blob aesDecrypt(const uint8_t* data, size_t data_length, const std::string& password); -OPENDHT_PUBLIC inline Blob aesDecrypt(const Blob& data, const std::string& password) { return aesDecrypt(data.data(), data.size(), password); } +OPENDHT_PUBLIC inline Blob aesDecrypt(std::string_view data, const Blob& key) { return aesDecrypt((uint8_t*)data.data(), data.size(), key); } + +OPENDHT_PUBLIC Blob aesDecrypt(const uint8_t* data, size_t data_length, std::string_view password); +OPENDHT_PUBLIC inline Blob aesDecrypt(const Blob& data, std::string_view password) { return aesDecrypt(data.data(), data.size(), password); } +OPENDHT_PUBLIC inline Blob aesDecrypt(std::string_view data, std::string_view password) { return aesDecrypt((uint8_t*)data.data(), data.size(), password); } + +/** + * Get raw AES key from password and salt stored with the encrypted data. + */ +OPENDHT_PUBLIC Blob aesGetKey(const uint8_t* data, size_t data_length, std::string_view password); +OPENDHT_PUBLIC Blob inline aesGetKey(const Blob& data, std::string_view password) { + return aesGetKey(data.data(), data.size(), password); +} +/** Get the salt part of data password-encrypted with `aesEncrypt(data, password)` */ +OPENDHT_PUBLIC Blob aesGetSalt(const uint8_t* data, size_t data_length); +OPENDHT_PUBLIC Blob inline aesGetSalt(const Blob& data) { + return aesGetSalt(data.data(), data.size()); +} +/** Get the salt part of data password-encrypted with `aesEncrypt(data, password)` */ +OPENDHT_PUBLIC std::string_view aesGetEncrypted(const uint8_t* data, size_t data_length); +OPENDHT_PUBLIC std::string_view inline aesGetEncrypted(const Blob& data) { + return aesGetEncrypted(data.data(), data.size()); +} + } } diff --git a/python/opendht.pyx b/python/opendht.pyx index 39f3d8efd..26bac2af3 100644 --- a/python/opendht.pyx +++ b/python/opendht.pyx @@ -462,6 +462,26 @@ cdef class Identity(object): k._key = self._id.first return k +def aesEncrypt(bytes data, str password) -> bytes : + cdef size_t d_len = len(data) + cdef cpp.uint8_t* d_ptr = data + cdef cpp.Blob indat + indat.assign(d_ptr, (d_ptr + d_len)) + cdef cpp.Blob encrypted = cpp.aesEncrypt(indat, password.encode()) + cdef char* encrypted_c_str = encrypted.data() + cdef Py_ssize_t length = encrypted.size() + return encrypted_c_str[:length] + +def aesDecrypt(bytes data, str password) -> bytes : + cdef size_t d_len = len(data) + cdef cpp.uint8_t* d_ptr = data + cdef cpp.Blob indat + indat.assign(d_ptr, (d_ptr + d_len)) + cdef cpp.Blob decrypted = cpp.aesDecrypt(indat, password.encode()) + cdef char* decrypted_c_str = decrypted.data() + cdef Py_ssize_t length = decrypted.size() + return decrypted_c_str[:length] + cdef class DhtConfig(object): cdef cpp.DhtRunnerConfig _config def __init__(self): diff --git a/python/opendht_cpp.pxd b/python/opendht_cpp.pxd index 2d374d149..7b73f16c1 100644 --- a/python/opendht_cpp.pxd +++ b/python/opendht_cpp.pxd @@ -94,6 +94,8 @@ ctypedef vector[uint8_t] Blob cdef extern from "opendht/crypto.h" namespace "dht::crypto": ctypedef pair[shared_ptr[PrivateKey], shared_ptr[Certificate]] Identity cdef Identity generateIdentity(string name, Identity ca, unsigned bits) + cdef Blob aesEncrypt(Blob data, string password) except + + cdef Blob aesDecrypt(Blob encrypted, string password) except + cdef cppclass PrivateKey: PrivateKey() diff --git a/src/crypto.cpp b/src/crypto.cpp index 578003e28..49cbc9dea 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -103,7 +103,7 @@ Blob aesEncrypt(const uint8_t* data, size_t data_length, const Blob& key) return ret; } -Blob aesEncrypt(const Blob& data, const std::string& password) +Blob aesEncrypt(const Blob& data, std::string_view password) { Blob salt; Blob key = stretchKey(password, salt, 256 / 8); @@ -152,16 +152,35 @@ Blob aesDecrypt(const uint8_t* data, size_t data_length, const Blob& key) return ret; } -Blob aesDecrypt(const uint8_t* data, size_t data_len, const std::string& password) +Blob aesDecrypt(const uint8_t* data, size_t data_length, std::string_view password) { - if (data_len <= PASSWORD_SALT_LENGTH) + return aesDecrypt( + aesGetEncrypted(data, data_length), + aesGetKey(data, data_length, password) + ); +} + +Blob aesGetSalt(const uint8_t* data, size_t data_length) +{ + if (data_length <= PASSWORD_SALT_LENGTH) + throw DecryptError("Wrong data size"); + return Blob {data, data+PASSWORD_SALT_LENGTH}; +} + +std::string_view aesGetEncrypted(const uint8_t* data, size_t data_length) +{ + if (data_length <= PASSWORD_SALT_LENGTH) throw DecryptError("Wrong data size"); - Blob salt {data, data+PASSWORD_SALT_LENGTH}; - Blob key = stretchKey(password, salt, 256/8); - return aesDecrypt(data+PASSWORD_SALT_LENGTH, data_len - PASSWORD_SALT_LENGTH, key); + return std::string_view((const char*)(data+PASSWORD_SALT_LENGTH), data_length - PASSWORD_SALT_LENGTH); +} + +Blob aesGetKey(const uint8_t* data, size_t data_length, std::string_view password) +{ + Blob salt = aesGetSalt(data, data_length); + return stretchKey(password, salt, 256/8); } -Blob stretchKey(const std::string& password, Blob& salt, size_t key_length) +Blob stretchKey(std::string_view password, Blob& salt, size_t key_length) { if (salt.empty()) { salt.resize(PASSWORD_SALT_LENGTH); diff --git a/tests/cryptotester.cpp b/tests/cryptotester.cpp index 4f1b2f98a..81b2c5c54 100644 --- a/tests/cryptotester.cpp +++ b/tests/cryptotester.cpp @@ -173,6 +173,54 @@ void CryptoTester::testOcsp() { CPPUNIT_ASSERT(ocspRequest.second == req.getNonce()); } +void CryptoTester::testAesEncryption() { + auto password = "this is a password 123414!@#%@#$?" + std::to_string(rand()); + + std::vector data1 {5, 10}; + std::vector data2(128 * 1024 + 13, 10); + + auto encrypted1 = dht::crypto::aesEncrypt(data1, password); + auto encrypted2 = dht::crypto::aesEncrypt(data2, password); + + auto decrypted1 = dht::crypto::aesDecrypt(encrypted1, password); + auto decrypted2 = dht::crypto::aesDecrypt(encrypted2, password); + + CPPUNIT_ASSERT(data1 != encrypted1); + CPPUNIT_ASSERT(data2 != encrypted2); + CPPUNIT_ASSERT(data1 == decrypted1); + CPPUNIT_ASSERT(data2 == decrypted2); + + auto key1 = dht::crypto::aesGetKey(encrypted1, password); + auto key2 = dht::crypto::aesGetKey(encrypted2, password); + auto encrypted1_data = dht::crypto::aesGetEncrypted(encrypted1); + auto encrypted2_data = dht::crypto::aesGetEncrypted(encrypted2); + + CPPUNIT_ASSERT(key1 != key2); + + decrypted1 = dht::crypto::aesDecrypt(encrypted1_data, key1); + decrypted2 = dht::crypto::aesDecrypt(encrypted2_data, key2); + + CPPUNIT_ASSERT(data1 == decrypted1); + CPPUNIT_ASSERT(data2 == decrypted2); + + auto salt1 = dht::crypto::aesGetSalt(encrypted1); + auto salt2 = dht::crypto::aesGetSalt(encrypted2); + + CPPUNIT_ASSERT(salt1 != salt2); + + auto key12 = dht::crypto::stretchKey(password, salt1, 256/8); + auto key22 = dht::crypto::stretchKey(password, salt2, 256/8); + + CPPUNIT_ASSERT(key1 == key12); + CPPUNIT_ASSERT(key2 == key22); + + decrypted1 = dht::crypto::aesDecrypt(encrypted1_data, key12); + decrypted2 = dht::crypto::aesDecrypt(encrypted2_data, key22); + + CPPUNIT_ASSERT(data1 == decrypted1); + CPPUNIT_ASSERT(data2 == decrypted2); +} + void CryptoTester::tearDown() { diff --git a/tests/cryptotester.h b/tests/cryptotester.h index e56d15a53..890190571 100644 --- a/tests/cryptotester.h +++ b/tests/cryptotester.h @@ -33,6 +33,7 @@ class CryptoTester : public CppUnit::TestFixture { CPPUNIT_TEST(testCertificateRequest); CPPUNIT_TEST(testCertificateSerialNumber); CPPUNIT_TEST(testOcsp); + CPPUNIT_TEST(testAesEncryption); CPPUNIT_TEST_SUITE_END(); public: @@ -64,6 +65,10 @@ class CryptoTester : public CppUnit::TestFixture { * Test OCSP */ void testOcsp(); + /** + * Test key streching and aes encryption/decryption + */ + void testAesEncryption(); }; } // namespace test