Skip to content

Commit

Permalink
migrate aes (#20)
Browse files Browse the repository at this point in the history
* aes-gcm: aes components and tests

aes-gcm: aes components and tests

* aes-gcm: Ghash components and tests

* aes-gcm gctr and gcm

aes-gcm gctr and gcm

* aes-gcm nivc components and tests

* test and document new utils

* final nivc component

final nivc component

* Resolving CI errors

* test path fix
  • Loading branch information
0xJepsen authored Nov 1, 2024
1 parent c2e113e commit adbbf2c
Show file tree
Hide file tree
Showing 27 changed files with 3,094 additions and 110 deletions.
145 changes: 145 additions & 0 deletions circuits/aes-gcm/aes-gcm.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
pragma circom 2.1.9;

include "ghash/ghash.circom";
include "aes/cipher.circom";
include "../utils/array.circom";
include "gctr.circom";


/// AES-GCM with 128 bit key authenticated encryption according to: https://nvlpubs.nist.gov/nistpubs/legacy/sp/nistspecialpublication800-38d.pdf
///
/// Parameters:
/// l: length of the plaintext
///
/// Inputs:
/// key: 128-bit key
/// iv: initialization vector
/// plainText: plaintext to be encrypted
/// aad: additional data to be authenticated
///
/// Outputs:
/// cipherText: encrypted ciphertext
/// authTag: authentication tag
///
template AESGCM(l) {
// Inputs
signal input key[16]; // 128-bit key
signal input iv[12]; // IV length is 96 bits (12 bytes)
signal input plainText[l];
signal input aad[16]; // AAD length is 128 bits (16 bytes)

// Outputs
signal output cipherText[l];
signal output authTag[16]; // Authentication tag length is 128 bits (16 bytes)

component zeroBlock = ToBlocks(16);
for (var i = 0; i < 16; i++) {
zeroBlock.stream[i] <== 0;
}

// Step 1: Let H = aes(key, zeroBlock)
component cipherH = Cipher();
cipherH.key <== key;
cipherH.block <== zeroBlock.blocks[0];

// Step 2: Define a block, J0 with 96 bits of iv and 32 bits of 0s
component J0builder = ToBlocks(16);
for (var i = 0; i < 12; i++) {
J0builder.stream[i] <== iv[i];
}
for (var i = 12; i < 16; i++) {
J0builder.stream[i] <== 0;
}
component J0WordIncrementer = IncrementWord();
J0WordIncrementer.in <== J0builder.blocks[0][3];

component J0WordIncrementer2 = IncrementWord();
J0WordIncrementer2.in <== J0WordIncrementer.out;

signal J0[4][4];
for (var i = 0; i < 3; i++) {
J0[i] <== J0builder.blocks[0][i];
}
J0[3] <== J0WordIncrementer2.out;

// Step 3: Let C = GCTRK(inc32(J0), P)
component gctr = GCTR(l);
gctr.key <== key;
gctr.initialCounterBlock <== J0;
gctr.plainText <== plainText;


// Step 4: Let u and v (v is always zero with out key size and aad length)
var blockCount = l\16;
if(l%16 > 0){
blockCount = blockCount + 1;
}
// so the reason there is a plus two is because
// the first block is the aad
// the second is the ciphertext
// the last is the length of the aad and ciphertext
// i.e. S = GHASHH (A || C || [len(A)] || [len(C)]). <- which is always 48 bytes: 3 blocks
var ghashblocks = blockCount + 2;
signal ghashMessage[ghashblocks][4][4];

// set aad as first block
for (var i=0; i < 4; i++) {
for (var j=0; j < 4; j++) {
ghashMessage[0][i][j] <== aad[i*4+j];
}
}
// set cipher text block padded
component ciphertextBlocks = ToBlocks(l);
ciphertextBlocks.stream <== gctr.cipherText;

for (var i=0; i<blockCount; i++) {
ghashMessage[i+1] <== ciphertextBlocks.blocks[i];
}

// length of aad = 128 = 0x80 as 64 bit number
ghashMessage[ghashblocks-1][0] <== [0x00, 0x00, 0x00, 0x00];
ghashMessage[ghashblocks-1][1] <== [0x00, 0x00, 0x00, 0x80];

var len = blockCount * 128;
for (var i=0; i<8; i++) {
var byte_value = 0;
for (var j=0; j<8; j++) {
byte_value += (len >> i*8+j) & 1;
}
ghashMessage[ghashblocks-1][i\4+2][i%4] <== byte_value;
}

// Step 5: Define a block, S
// needs to take in the number of blocks
component ghash = GHASH(ghashblocks);
component hashKeyToStream = ToStream(1, 16);
hashKeyToStream.blocks[0] <== cipherH.cipher;
ghash.HashKey <== hashKeyToStream.stream;
// S = GHASHH (A || 0^v || C || 0^u || [len(A)] || [len(C)]).
component selectedBlocksToStream[ghashblocks];
for (var i = 0 ; i<ghashblocks ; i++) {
ghash.msg[i] <== ToStream(1, 16)([ghashMessage[i]]);
}

signal bytes[16];
signal tagBytes[16 * 8] <== BytesToBits(16)(ghash.tag);
for(var i = 0; i < 16; i++) {
var byteValue = 0;
var sum=1;
for(var j = 0; j<8; j++) {
var bitIndex = i*8+j;
byteValue += tagBytes[bitIndex]*sum;
sum = sum*sum;
}
bytes[i] <== byteValue;
}

// Step 6: Let T = MSBt(GCTRK(J0, S))
component gctrT = GCTR(16);
gctrT.key <== key;
gctrT.initialCounterBlock <== J0;
gctrT.plainText <== bytes;

authTag <== gctrT.cipherText;
cipherText <== gctr.cipherText;
}
159 changes: 159 additions & 0 deletions circuits/aes-gcm/aes/cipher.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// from: https://github.com/crema-labs/aes-circom/tree/main/circuits
pragma circom 2.1.9;

include "key_expansion.circom";
include "mix_columns.circom";
include "../../utils/bytes.circom";

// Cipher Process
// AES 128 keys have 10 rounds.
// Input Block Initial Round Key Round Key Final Round Key
// │ │ │ │
// ▼ ▼ ▼ ▼
// ┌─────────┐ ┌──────────┐ ┌────────┐ ┌──────────┐ ┌────────┐ ┌──────────┐
// │ Block │──► │ Add │ │ Sub │ │ Mix │ │ Sub │ │ Add │
// │ │ │ Round │ │ Bytes │ │ Columns │ │ Bytes │ │ Round │
// │ │ │ Key │ │ │ │ │ │ │ │ Key │
// └─────────┘ └────┬─────┘ └───┬────┘ └────┬─────┘ └───┬────┘ └────┬─────┘
// │ │ │ │ │
// ▼ ▼ ▼ ▼ ▼
// ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
// │ Round 0 │ │ Round 1 │ │ Round 2 │ │ Round │ │ Final │
// │ │ │ to │ │ to │ │ Nr - 1 │ │ Round │
// │ │ │ Nr - 2 │ │ Nr - 1 │ │ │ │ │
// └─────────┘ └─────────┘ └─────────┘ └─────────┘ └────┬────┘
//
//
// Ciphertext


// @inputs block: 4x4 matrix representing the input block
// @inputs key: array of 16 bytes representing the key
// @outputs cipher: 4x4 matrix representing the output block
template Cipher(){
signal input block[4][4];
signal input key[16];
signal output cipher[4][4];

component keyExpansion = KeyExpansion();
keyExpansion.key <== key;

component addRoundKey[11];
component subBytes[10];
component shiftRows[10];
component mixColumns[9];

signal interBlock[10][4][4];

addRoundKey[0] = AddRoundKey();
addRoundKey[0].state <== block;
for (var i = 0; i < 4; i++) {
addRoundKey[0].roundKey[i] <== keyExpansion.keyExpanded[i];
}

interBlock[0] <== addRoundKey[0].newState;
// for each round.
for (var i = 1; i < 10; i++) {
// SubBytes
subBytes[i-1] = SubBlock();
subBytes[i-1].state <== interBlock[i-1];

// ShiftRows
shiftRows[i-1] = ShiftRows();
shiftRows[i-1].state <== subBytes[i-1].newState;

// MixColumns
mixColumns[i-1] = MixColumns();
mixColumns[i-1].state <== shiftRows[i-1].newState;

// AddRoundKey
addRoundKey[i] = AddRoundKey();
addRoundKey[i].state <== mixColumns[i-1].out;
for (var j = 0; j < 4; j++) {
addRoundKey[i].roundKey[j] <== keyExpansion.keyExpanded[j + (i * 4)];
}

interBlock[i] <== addRoundKey[i].newState;
}

// Final SubBytes
subBytes[9] = SubBlock();
subBytes[9].state <== interBlock[9];

shiftRows[9] = ShiftRows();
shiftRows[9].state <== subBytes[9].newState;

// Final AddRoundKey
addRoundKey[10] = AddRoundKey();
addRoundKey[10].state <== shiftRows[9].newState;
for (var i = 0; i < 4; i++) {
addRoundKey[10].roundKey[i] <== keyExpansion.keyExpanded[i + (40)];
}

cipher <== addRoundKey[10].newState;
}

// XORs a cipher state: 4x4 byte array
template AddCipher(){
signal input state[4][4];
signal input cipher[4][4];
signal output newState[4][4];

component xorbyte[4][4];

for (var i = 0; i < 4; i++) {
for (var j = 0; j < 4; j++) {
xorbyte[i][j] = XorByte();
xorbyte[i][j].a <== state[i][j];
xorbyte[i][j].b <== cipher[i][j];
newState[i][j] <== xorbyte[i][j].out;
}
}
}

// ShiftRows: Performs circular left shift on each row
// 0, 1, 2, 3 shifts for rows 0, 1, 2, 3 respectively
template ShiftRows(){
signal input state[4][4];
signal output newState[4][4];

component shiftWord[4];

for (var i = 0; i < 4; i++) {
// Rotate: Performs circular left shift on each row
shiftWord[i] = Rotate(i, 4);
shiftWord[i].bytes <== state[i];
newState[i] <== shiftWord[i].rotated;
}
}

// Applies S-box substitution to each byte
template SubBlock(){
signal input state[4][4];
signal output newState[4][4];
component sbox[4];

for (var i = 0; i < 4; i++) {
sbox[i] = SubstituteWord();
sbox[i].bytes <== state[i];
newState[i] <== sbox[i].substituted;
}
}

// AddRoundKey: XORs the state with transposed the round key
template AddRoundKey(){
signal input state[4][4];
signal input roundKey[4][4];
signal output newState[4][4];

component xorbyte[4][4];

for (var i = 0; i < 4; i++) {
for (var j = 0; j < 4; j++) {
xorbyte[i][j] = XorByte();
xorbyte[i][j].a <== state[i][j];
xorbyte[i][j].b <== roundKey[j][i];
newState[i][j] <== xorbyte[i][j].out;
}
}
}
Loading

0 comments on commit adbbf2c

Please sign in to comment.