Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate aes #20

Merged
merged 8 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 145 additions & 0 deletions circuits/aes-gcm/aes-gcm.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
pragma circom 2.1.9;

include "ghash/ghash.circom";
include "aes/cipher.circom";
include "../utils/array.circom";
include "gctr.circom";


/// AES-GCM with 128 bit key authenticated encryption according to: https://nvlpubs.nist.gov/nistpubs/legacy/sp/nistspecialpublication800-38d.pdf
///
/// Parameters:
/// l: length of the plaintext
///
/// Inputs:
/// key: 128-bit key
/// iv: initialization vector
/// plainText: plaintext to be encrypted
/// aad: additional data to be authenticated
///
/// Outputs:
/// cipherText: encrypted ciphertext
/// authTag: authentication tag
///
template AESGCM(l) {
// Inputs
signal input key[16]; // 128-bit key
signal input iv[12]; // IV length is 96 bits (12 bytes)
signal input plainText[l];
signal input aad[16]; // AAD length is 128 bits (16 bytes)

// Outputs
signal output cipherText[l];
signal output authTag[16]; // Authentication tag length is 128 bits (16 bytes)

component zeroBlock = ToBlocks(16);
for (var i = 0; i < 16; i++) {
zeroBlock.stream[i] <== 0;
}

// Step 1: Let H = aes(key, zeroBlock)
component cipherH = Cipher();
cipherH.key <== key;
cipherH.block <== zeroBlock.blocks[0];

// Step 2: Define a block, J0 with 96 bits of iv and 32 bits of 0s
component J0builder = ToBlocks(16);
for (var i = 0; i < 12; i++) {
J0builder.stream[i] <== iv[i];
}
for (var i = 12; i < 16; i++) {
J0builder.stream[i] <== 0;
}
component J0WordIncrementer = IncrementWord();
J0WordIncrementer.in <== J0builder.blocks[0][3];

component J0WordIncrementer2 = IncrementWord();
J0WordIncrementer2.in <== J0WordIncrementer.out;

signal J0[4][4];
for (var i = 0; i < 3; i++) {
J0[i] <== J0builder.blocks[0][i];
}
J0[3] <== J0WordIncrementer2.out;

// Step 3: Let C = GCTRK(inc32(J0), P)
component gctr = GCTR(l);
gctr.key <== key;
gctr.initialCounterBlock <== J0;
gctr.plainText <== plainText;


// Step 4: Let u and v (v is always zero with out key size and aad length)
var blockCount = l\16;
if(l%16 > 0){
blockCount = blockCount + 1;
}
// so the reason there is a plus two is because
// the first block is the aad
// the second is the ciphertext
// the last is the length of the aad and ciphertext
// i.e. S = GHASHH (A || C || [len(A)] || [len(C)]). <- which is always 48 bytes: 3 blocks
var ghashblocks = blockCount + 2;
signal ghashMessage[ghashblocks][4][4];

// set aad as first block
for (var i=0; i < 4; i++) {
for (var j=0; j < 4; j++) {
ghashMessage[0][i][j] <== aad[i*4+j];
}
}
// set cipher text block padded
component ciphertextBlocks = ToBlocks(l);
ciphertextBlocks.stream <== gctr.cipherText;

for (var i=0; i<blockCount; i++) {
ghashMessage[i+1] <== ciphertextBlocks.blocks[i];
}

// length of aad = 128 = 0x80 as 64 bit number
ghashMessage[ghashblocks-1][0] <== [0x00, 0x00, 0x00, 0x00];
ghashMessage[ghashblocks-1][1] <== [0x00, 0x00, 0x00, 0x80];

var len = blockCount * 128;
for (var i=0; i<8; i++) {
var byte_value = 0;
for (var j=0; j<8; j++) {
byte_value += (len >> i*8+j) & 1;
}
ghashMessage[ghashblocks-1][i\4+2][i%4] <== byte_value;
}

// Step 5: Define a block, S
// needs to take in the number of blocks
component ghash = GHASH(ghashblocks);
component hashKeyToStream = ToStream(1, 16);
hashKeyToStream.blocks[0] <== cipherH.cipher;
ghash.HashKey <== hashKeyToStream.stream;
// S = GHASHH (A || 0^v || C || 0^u || [len(A)] || [len(C)]).
component selectedBlocksToStream[ghashblocks];
for (var i = 0 ; i<ghashblocks ; i++) {
ghash.msg[i] <== ToStream(1, 16)([ghashMessage[i]]);
}

signal bytes[16];
signal tagBytes[16 * 8] <== BytesToBits(16)(ghash.tag);
for(var i = 0; i < 16; i++) {
var byteValue = 0;
var sum=1;
for(var j = 0; j<8; j++) {
var bitIndex = i*8+j;
byteValue += tagBytes[bitIndex]*sum;
sum = sum*sum;
}
bytes[i] <== byteValue;
}

// Step 6: Let T = MSBt(GCTRK(J0, S))
component gctrT = GCTR(16);
gctrT.key <== key;
gctrT.initialCounterBlock <== J0;
gctrT.plainText <== bytes;

authTag <== gctrT.cipherText;
cipherText <== gctr.cipherText;
}
159 changes: 159 additions & 0 deletions circuits/aes-gcm/aes/cipher.circom
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// from: https://github.com/crema-labs/aes-circom/tree/main/circuits
pragma circom 2.1.9;

include "key_expansion.circom";
include "mix_columns.circom";
include "../../utils/bytes.circom";

// Cipher Process
// AES 128 keys have 10 rounds.
// Input Block Initial Round Key Round Key Final Round Key
// │ │ │ │
// ▼ ▼ ▼ ▼
// ┌─────────┐ ┌──────────┐ ┌────────┐ ┌──────────┐ ┌────────┐ ┌──────────┐
// │ Block │──► │ Add │ │ Sub │ │ Mix │ │ Sub │ │ Add │
// │ │ │ Round │ │ Bytes │ │ Columns │ │ Bytes │ │ Round │
// │ │ │ Key │ │ │ │ │ │ │ │ Key │
// └─────────┘ └────┬─────┘ └───┬────┘ └────┬─────┘ └───┬────┘ └────┬─────┘
// │ │ │ │ │
// ▼ ▼ ▼ ▼ ▼
// ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐
// │ Round 0 │ │ Round 1 │ │ Round 2 │ │ Round │ │ Final │
// │ │ │ to │ │ to │ │ Nr - 1 │ │ Round │
// │ │ │ Nr - 2 │ │ Nr - 1 │ │ │ │ │
// └─────────┘ └─────────┘ └─────────┘ └─────────┘ └────┬────┘
// │
// ▼
// Ciphertext


// @inputs block: 4x4 matrix representing the input block
// @inputs key: array of 16 bytes representing the key
// @outputs cipher: 4x4 matrix representing the output block
template Cipher(){
signal input block[4][4];
signal input key[16];
signal output cipher[4][4];

component keyExpansion = KeyExpansion();
keyExpansion.key <== key;

component addRoundKey[11];
component subBytes[10];
component shiftRows[10];
component mixColumns[9];

signal interBlock[10][4][4];

addRoundKey[0] = AddRoundKey();
addRoundKey[0].state <== block;
for (var i = 0; i < 4; i++) {
addRoundKey[0].roundKey[i] <== keyExpansion.keyExpanded[i];
}

interBlock[0] <== addRoundKey[0].newState;
// for each round.
for (var i = 1; i < 10; i++) {
// SubBytes
subBytes[i-1] = SubBlock();
subBytes[i-1].state <== interBlock[i-1];

// ShiftRows
shiftRows[i-1] = ShiftRows();
shiftRows[i-1].state <== subBytes[i-1].newState;

// MixColumns
mixColumns[i-1] = MixColumns();
mixColumns[i-1].state <== shiftRows[i-1].newState;

// AddRoundKey
addRoundKey[i] = AddRoundKey();
addRoundKey[i].state <== mixColumns[i-1].out;
for (var j = 0; j < 4; j++) {
addRoundKey[i].roundKey[j] <== keyExpansion.keyExpanded[j + (i * 4)];
}

interBlock[i] <== addRoundKey[i].newState;
}

// Final SubBytes
subBytes[9] = SubBlock();
subBytes[9].state <== interBlock[9];

shiftRows[9] = ShiftRows();
shiftRows[9].state <== subBytes[9].newState;

// Final AddRoundKey
addRoundKey[10] = AddRoundKey();
addRoundKey[10].state <== shiftRows[9].newState;
for (var i = 0; i < 4; i++) {
addRoundKey[10].roundKey[i] <== keyExpansion.keyExpanded[i + (40)];
}

cipher <== addRoundKey[10].newState;
}

// XORs a cipher state: 4x4 byte array
template AddCipher(){
signal input state[4][4];
signal input cipher[4][4];
signal output newState[4][4];

component xorbyte[4][4];

for (var i = 0; i < 4; i++) {
for (var j = 0; j < 4; j++) {
xorbyte[i][j] = XorByte();
xorbyte[i][j].a <== state[i][j];
xorbyte[i][j].b <== cipher[i][j];
newState[i][j] <== xorbyte[i][j].out;
}
}
}

// ShiftRows: Performs circular left shift on each row
// 0, 1, 2, 3 shifts for rows 0, 1, 2, 3 respectively
template ShiftRows(){
signal input state[4][4];
signal output newState[4][4];

component shiftWord[4];

for (var i = 0; i < 4; i++) {
// Rotate: Performs circular left shift on each row
shiftWord[i] = Rotate(i, 4);
shiftWord[i].bytes <== state[i];
newState[i] <== shiftWord[i].rotated;
}
}

// Applies S-box substitution to each byte
template SubBlock(){
signal input state[4][4];
signal output newState[4][4];
component sbox[4];

for (var i = 0; i < 4; i++) {
sbox[i] = SubstituteWord();
sbox[i].bytes <== state[i];
newState[i] <== sbox[i].substituted;
}
}

// AddRoundKey: XORs the state with transposed the round key
template AddRoundKey(){
signal input state[4][4];
signal input roundKey[4][4];
signal output newState[4][4];

component xorbyte[4][4];

for (var i = 0; i < 4; i++) {
for (var j = 0; j < 4; j++) {
xorbyte[i][j] = XorByte();
xorbyte[i][j].a <== state[i][j];
xorbyte[i][j].b <== roundKey[j][i];
newState[i][j] <== xorbyte[i][j].out;
}
}
}
Loading