Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: AES decryption #124

Merged
merged 15 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions src/encryption/symmetric/aes/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,20 @@ Unlike DES, it does not use a Feistel network, and most AES calculations are don
## Algorithm

The core encryption algorithm consists of the following routines:
- [KeyExpansion](#KeyExpansion)
- [AddRoundKey](#AddRoundKey)

- [KeyExpansion][keyexp]
- [AddRoundKey][arc]
- [SubBytes](#SubBytes)
- [ShiftRows](#ShiftRows)
- [MixColumns](#MixColumns)

For decryption, we take the inverses of these routines:
For decryption, we take the inverses of these following routines:

- [InvSubBytes](#InvSubBytes)
- [InvShiftRows](#InvShiftRows)
- [InvMixColumns](#InvMixColumns)

TODO
Note that we do not need the inverses of [KeyExpansion][keyexp] or [AddRoundKey][arc], since for decryption we're simply using the round keys from the last to the first, and [AddRoundKey][arc] is its own inverse.

### Encryption

Expand Down Expand Up @@ -75,7 +80,7 @@ Substitutes each byte in the `State` with another byte according to a [substitut

#### ShiftRow

Shift i-th row of i positions, for i ranging from 0 to 3, eg. Row 0: no shift occurs, row 1: a left shift of 1 position occurs.
Do a **left** shift i-th row of i positions, for i ranging from 0 to 3, eg. Row 0: no shift occurs, row 1: a left shift of 1 position occurs.

#### MixColumns

Expand All @@ -87,7 +92,24 @@ More details can be found [here][mixcolumns].

### Decryption

TODO
For decryption, we use the inverses of some of the above routines to decrypt a ciphertext. To reiterate, we do not need the inverses of [KeyExpansion][keyexp] or [AddRoundKey][arc], since for decryption we're simply using the round keys from the last to the first, and [AddRoundKey][arc] is its own inverse.


#### InvSubBytes

Substitutes each byte in the `State` with another byte according to a [substitution box](#substitution-box). Note that the only difference here is that the substitution box used in decryption is derived differently from the version used in encryption.

#### InvShiftRows

Do a **right** shift i-th row of i positions, for i ranging from 0 to 3, eg. Row 0: no shift occurs, row 1: a right shift of 1 position occurs.

#### InvMixColumns

Each column of bytes is treated as a 4-term polynomial, multiplied modulo x^4 + 1 with the inverse of the fixed polynomial
a(x) = 3x^3 + x^2 + x + 2 found in the [MixColumns] step. The inverse of a(x) is a^-1(x) = 11x^3 + 13x^2 + 9x + 14. This multiplication is done using matrix multiplication.

More details can be found [here][mixcolumns].


## Substitution Box

Expand Down Expand Up @@ -117,6 +139,8 @@ In production-level AES code, fast AES software uses special techniques called t
[des]: ../des/README.md
[spn]: https://en.wikipedia.org/wiki/Substitution%E2%80%93permutation_network
[slide attacks]: https://en.wikipedia.org/wiki/Slide_attack
[keyexp]: #KeyExpansion
[arc]: #AddRoundKey
[mixcolumns]: https://en.wikipedia.org/wiki/Rijndael_MixColumns
[Rijndael ff]: https://en.wikipedia.org/wiki/Finite_field_arithmetic#Rijndael's_(AES)_finite_field
[fips197]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.197-upd1.pdf
Expand Down
232 changes: 208 additions & 24 deletions src/encryption/symmetric/aes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,20 @@
//! and decryption.
#![cfg_attr(not(doctest), doc = include_str!("./README.md"))]

use std::ops::Mul;

use itertools::Itertools;

use crate::field::{extension::AESFieldExtension, prime::AESField};

pub mod sbox;
#[cfg(test)] pub mod tests;

use super::SymmetricEncryption;
use crate::encryption::symmetric::aes::sbox::SBOX;
use crate::{
encryption::symmetric::aes::sbox::{INVERSE_SBOX, SBOX},
field::FiniteField,
};

/// A block in AES represents a 128-bit sized message data.
pub type Block = [u8; 16];
Expand Down Expand Up @@ -72,7 +79,34 @@ where [(); N / 8]:
Self::aes_encrypt(plaintext, key, num_rounds)
}

fn decrypt(_key: &Self::Key, _ciphertext: &Self::Block) -> Self::Block { unimplemented!() }
/// Decrypt a ciphertext of size [`Block`] with a [`Key`] of size `N`-bits.
///
/// ## Example
/// ```rust
/// #![feature(generic_const_exprs)]
///
/// use rand::{thread_rng, Rng};
/// use ronkathon::encryption::symmetric::{
/// aes::{Key, AES},
/// SymmetricEncryption,
/// };
///
/// let mut rng = thread_rng();
/// let key = Key::<128>::new(rng.gen());
/// let plaintext = rng.gen();
/// let encrypted = AES::encrypt(&key, &plaintext);
/// let decrypted = AES::decrypt(&key, &encrypted);
/// ```
fn decrypt(key: &Self::Key, ciphertext: &Self::Block) -> Self::Block {
let num_rounds = match N {
128 => 10,
192 => 12,
256 => 14,
_ => panic!("AES only supports key sizes 128, 192 and 256 bits. You provided: {N}"),
};

Self::aes_decrypt(ciphertext, key, num_rounds)
}
}

/// Contains the values given by [x^(i-1), {00}, {00}, {00}], with x^(i-1)
Expand Down Expand Up @@ -109,6 +143,42 @@ pub struct AES<const N: usize> {}
#[derive(Debug, Default, Clone, Copy, PartialEq)]
struct State([[u8; 4]; 4]);

/// Multiplies a 8-bit number in the Galois field GF(2^8).
///
/// This is defined on two bytes in two steps:
///
/// 1) The two polynomials that represent the bytes are multiplied as polynomials,
/// 2) The resulting polynomial is reduced modulo the following fixed polynomial: m(x) = x^8 + x^4 +
/// x^3 + x + 1
///
/// The above steps are implemented in [`AESFieldExtension`], within the operation traits.
///
/// Note that in most AES implementations, this is done using "carry-less" multiplication -
/// to see how this works in more concretely in field arithmetic, this implementation uses an actual
/// polynomial implementation.
fn galois_multiplication(mut col: u8, mut multiplicand: u8) -> u8 {
// Decompose bits into degree-7 polynomials.
let mut col_bits: [AESField; 8] = [AESField::ZERO; 8];
let mut mult_bits: [AESField; 8] = [AESField::ZERO; 8];
for i in 0..8 {
col_bits[i] = AESField::new((col & 1).into());
mult_bits[i] = AESField::new((multiplicand & 1).into());
col >>= 1;
multiplicand >>= 1;
}

let col_poly = AESFieldExtension::new(col_bits);
let mult_poly = AESFieldExtension::new(mult_bits);
let res = col_poly.mul(mult_poly);

let mut product: u8 = 0;
for i in 0..8 {
product += res.coeffs[i].value as u8 * 2_u8.pow(i as u32);
}

product
}

impl<const N: usize> AES<N>
where [(); N / 8]:
{
Expand Down Expand Up @@ -156,6 +226,51 @@ where [(); N / 8]:
state.0.into_iter().flatten().collect::<Vec<_>>().try_into().unwrap()
}

/// Deciphers a given `ciphertext`, with key size of `N` (in bits), as seen in Figure 5 of the
/// document linked in the front-page.
fn aes_decrypt(ciphertext: &[u8; 16], key: &Key<N>, num_rounds: usize) -> Block {
assert!(!key.is_empty(), "Key is not instantiated");

let key_len_words = N / 32;
let mut round_keys_words = Vec::with_capacity(key_len_words * (num_rounds + 1));
Self::key_expansion(*key, &mut round_keys_words, key_len_words, num_rounds);
// For decryption; we use the round keys from the back, so we iterate from the back here.
let mut round_keys = round_keys_words.chunks_exact(4).rev();

let mut state = State(
ciphertext
.chunks(4)
.map(|c| c.try_into().unwrap())
.collect::<Vec<[u8; 4]>>()
.try_into()
.unwrap(),
);
assert!(state != State::default(), "State is not instantiated");

// Round 0 - add round key
Self::add_round_key(&mut state, round_keys.next().unwrap());

// Rounds 1 to N - 1
for _ in 1..num_rounds {
Self::inv_shift_rows(&mut state);
Self::inv_sub_bytes(&mut state);
Self::add_round_key(&mut state, round_keys.next().unwrap());
Self::inv_mix_columns(&mut state);
}

// Last round - we do not mix columns here.
Self::inv_shift_rows(&mut state);
Self::inv_sub_bytes(&mut state);
Self::add_round_key(&mut state, round_keys.next().unwrap());

assert!(
round_keys.next().is_none(),
"Round keys not fully consumed - perhaps check key expansion?"
);

state.0.into_iter().flatten().collect::<Vec<_>>().try_into().unwrap()
}

/// XOR a round key to its internal state.
fn add_round_key(state: &mut State, round_key: &[[u8; 4]]) {
for (col, word) in state.0.iter_mut().zip(round_key.iter()) {
Expand All @@ -175,9 +290,22 @@ where [(); N / 8]:
}
}

/// Substitutes each byte [s_0, s_1, ..., s_15] with another byte according to a substitution box
/// (usually referred to as an S-box).
///
/// Note that the only difference here from [`Self::sub_bytes`] is that we use a different
/// substitution box [`INVERSE_SBOX`], which is derived differently.
fn inv_sub_bytes(state: &mut State) {
for i in 0..4 {
for j in 0..4 {
state.0[i][j] = INVERSE_SBOX[state.0[i][j] as usize];
}
}
}

/// Shift i-th row of i positions, for i ranging from 0 to 3.
///
/// For row 0, no shifting occurs, for row 1, a left shift of 1 index occurs, ..
/// For row 0, no shifting occurs, for row 1, a **left** shift of 1 index occurs, ..
///
/// Note that since our state is in column-major form, we transpose the state to a
/// row-major form to make this step simpler.
Expand All @@ -190,8 +318,7 @@ where [(); N / 8]:
(0..len).map(|_| iters.iter_mut().map(|n| n.next().unwrap()).collect::<Vec<_>>()).collect();

for (r, i) in transposed.iter_mut().zip(0..4) {
let (left, right) = r.split_at(i);
*r = [right.to_vec(), left.to_vec()].concat();
r.rotate_left(i);
}
let mut iters: Vec<_> = transposed.into_iter().map(|n| n.into_iter()).collect();

Expand All @@ -202,35 +329,92 @@ where [(); N / 8]:
.unwrap();
}

/// Applies the same linear transformation to each of the four columns of the state.
/// The inverse of [`Self::shift_rows`].
///
/// Shift i-th row of i positions, for i ranging from 0 to 3.
///
/// Mix columns is done as such:
/// For row 0, no shifting occurs, for row 1, a **right** shift of 1 index occurs, ..
///
/// Each column of bytes is treated as a 4-term polynomial, multiplied modulo x^4 + 1 with a fixed
/// polynomial a(x) = 3x^3 + x^2 + x + 2. This is done using matrix multiplication.
/// Note that since our state is in column-major form, we transpose the state to a
/// row-major form to make this step simpler.
fn inv_shift_rows(state: &mut State) {
let len = state.0.len();
let mut iters: Vec<_> = state.0.into_iter().map(|n| n.into_iter()).collect();

// Transpose to row-major form
let mut transposed: Vec<_> =
(0..len).map(|_| iters.iter_mut().map(|n| n.next().unwrap()).collect::<Vec<_>>()).collect();

for (r, i) in transposed.iter_mut().zip(0..4) {
r.rotate_right(i);
}
let mut iters: Vec<_> = transposed.into_iter().map(|n| n.into_iter()).collect();

state.0 = (0..len)
.map(|_| iters.iter_mut().map(|n| n.next().unwrap()).collect::<Vec<_>>().try_into().unwrap())
.collect::<Vec<_>>()
.try_into()
.unwrap();
}

/// Mixes the data in each of the 4 columns with a single fixed matrix, with its entries taken
/// from the word [a_0, a_1, a_2, a_3] = [{02}, {01}, {01}, {03}] (hex) (or [2, 1, 1, 3] in
/// decimal).
///
/// This is done by interpreting both the byte from the state and the byte from the fixed matrix
/// as degree-7 polynomials and doing multiplication in the GF(2^8) field. For details, see
/// [`galois_multiplication`].
fn mix_columns(state: &mut State) {
for col in state.0.iter_mut() {
let tmp = *col;
let mut col_doubled = *col;

// Perform the matrix multiplication in GF(2^8).
// We process the multiplications first, so we can just do additions later.
for (i, c) in col_doubled.iter_mut().enumerate() {
let hi_bit = col[i] >> 7;
*c = col[i] << 1;
*c ^= hi_bit * 0x1B; // This XOR brings the column back into the field if an
// overflow occurs (ie. hi_bit == 1)
}

// Do all additions (XORs) here.
// 2a0 + 3a1 + a2 + a3
col[0] = col_doubled[0] ^ tmp[3] ^ tmp[2] ^ col_doubled[1] ^ tmp[1];
col[0] =
galois_multiplication(tmp[0], 2) ^ tmp[3] ^ tmp[2] ^ galois_multiplication(tmp[1], 3);
// a0 + 2a1 + 3a2 + a3
col[1] = col_doubled[1] ^ tmp[0] ^ tmp[3] ^ col_doubled[2] ^ tmp[2];
col[1] =
galois_multiplication(tmp[1], 2) ^ tmp[0] ^ tmp[3] ^ galois_multiplication(tmp[2], 3);
// a0 + a1 + 2a2 + 3a3
col[2] = col_doubled[2] ^ tmp[1] ^ tmp[0] ^ col_doubled[3] ^ tmp[3];
col[2] =
galois_multiplication(tmp[2], 2) ^ tmp[1] ^ tmp[0] ^ galois_multiplication(tmp[3], 3);
// 3a0 + a1 + a2 + 2a3
col[3] = col_doubled[3] ^ tmp[2] ^ tmp[1] ^ col_doubled[0] ^ tmp[0];
col[3] =
galois_multiplication(tmp[3], 2) ^ tmp[2] ^ tmp[1] ^ galois_multiplication(tmp[0], 3);
}
}

/// The inverse of [`Self::mix_columns`].
///
/// Mixes the data in each of the 4 columns with a single fixed matrix, with its entries taken
/// from the word [a_0, a_1, a_2, a_3] = [{0e}, {09}, {0d}, {0b}] (or [14, 9, 13, 11] in decimal).
///
/// This is done by interpreting both the byte from the state and the byte from the fixed matrix
/// as degree-7 polynomials and doing multiplication in the GF(2^8) field. For details, see
/// [`galois_multiplication`].
fn inv_mix_columns(state: &mut State) {
for col in state.0.iter_mut() {
let tmp = *col;

// 14a0 + 11a1 + 13a2 + 9a3
col[0] = galois_multiplication(tmp[0], 14)
^ galois_multiplication(tmp[3], 9)
^ galois_multiplication(tmp[2], 13)
^ galois_multiplication(tmp[1], 11);
// 9a0 + 14a1 + 11a2 + 13a3
col[1] = galois_multiplication(tmp[1], 14)
^ galois_multiplication(tmp[0], 9)
^ galois_multiplication(tmp[3], 13)
^ galois_multiplication(tmp[2], 11);
// 13a0 + 9a1 + 14a2 + 11a3
col[2] = galois_multiplication(tmp[2], 14)
^ galois_multiplication(tmp[1], 9)
^ galois_multiplication(tmp[0], 13)
^ galois_multiplication(tmp[3], 11);
// 11a0 + 13a1 + 9a2 + 14a3
col[3] = galois_multiplication(tmp[3], 14)
^ galois_multiplication(tmp[2], 9)
^ galois_multiplication(tmp[1], 13)
^ galois_multiplication(tmp[0], 11);
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/encryption/symmetric/aes/sbox.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,26 @@ pub(crate) const SBOX: [u8; 256] = [
0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16,
];

/// An inverse substitution box for an instance of [`AES`](super::AES).
///
/// Since substitution involves mapping a single byte (m = 8) into another (n = 8), we have a
/// lookup table of size 2^8 = 256 of 8 bits per index, implemented as a linear array.
pub(crate) const INVERSE_SBOX: [u8; 256] = [
0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d,
];
Loading
Loading