Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adaptable foldable AES circuit #47

Merged
merged 5 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 39 additions & 21 deletions circuits/aes-gcm/nivc/aes-gctr-nivc.circom
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,47 @@ template AESGCTRFOLD(NUM_CHUNKS) {
aes[i].aad <== aad;
}

signal ciphertext_equal_check[NUM_CHUNKS][16];
for(var i = 0 ; i < NUM_CHUNKS; i++) {
for(var j = 0 ; j < 16 ; j++) {
ciphertext_equal_check[i][j] <== IsEqual()([aes[i].cipherText[j], cipherText[i][j]]);
ciphertext_equal_check[i][j] === 1;
}
// Regroup the plaintext and ciphertext into byte packed form
var computedCipherText[NUM_CHUNKS][16];
for(var i = 0 ; i < NUM_CHUNKS ; i++) {
computedCipherText[i] = aes[i].cipherText;
}
signal packedCiphertext[NUM_CHUNKS] <== GenericBytePackArray(NUM_CHUNKS, 16)(cipherText);
signal packedComputedCiphertext[NUM_CHUNKS] <== GenericBytePackArray(NUM_CHUNKS, 16)(computedCipherText);
signal packedPlaintext[NUM_CHUNKS] <== GenericBytePackArray(NUM_CHUNKS, 16)(plainText);


var packedPlaintext[NUM_CHUNKS];
for(var i = 0 ; i < NUM_CHUNKS ; i++) {
packedPlaintext[i] = 0;
for(var j = 0 ; j < 16 ; j++) {
packedPlaintext[i] += plainText[i][j] * 2**(8*j);
}
signal plaintext_input_was_zero_chunk[NUM_CHUNKS];
signal ciphertext_input_was_zero_chunk[NUM_CHUNKS];
signal both_input_chunks_were_zero[NUM_CHUNKS];
signal ciphertext_option[NUM_CHUNKS];
signal ciphertext_equal_check[NUM_CHUNKS];
for(var i = 0 ; i < NUM_CHUNKS; i++) {
plaintext_input_was_zero_chunk[i] <== IsZero()(packedPlaintext[i]);
ciphertext_input_was_zero_chunk[i] <== IsZero()(packedCiphertext[i]);
both_input_chunks_were_zero[i] <== plaintext_input_was_zero_chunk[i] * ciphertext_input_was_zero_chunk[i];
ciphertext_option[i] <== (1 - both_input_chunks_were_zero[i]) * packedComputedCiphertext[i];
ciphertext_equal_check[i] <== IsEqual()([packedCiphertext[i], ciphertext_option[i]]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice okay i see whats going on here. This is clever. Do we know how much perf it gives us? How many changes will this require downstream in the web-prover? Is it worth spending time on performance for aes when we know chacha20 is more performant over all?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is about allowing for doing less folds with AES which should boost performance especially in WASM. Every fold adds time, every fold adds memory pressure, and every fold makes compression slower.

Yes, we will likely move to chacha20, but it is never a bad idea to have a backup!

ciphertext_equal_check[i] === 1;
}
signal hash[NUM_CHUNKS];
step_out[0] <== AESHasher(NUM_CHUNKS)(packedPlaintext, step_in[0]);
}

// TODO (autoparallel): Could probably just have datahasher take in an initial hash as an input, but this was quicker to try first.
template AESHasher(NUM_CHUNKS) {
// TODO: add this assert back after witnesscalc supports
// assert(DATA_BYTES % 16 == 0);
signal input in[NUM_CHUNKS];
signal input initial_hash;
signal output out;

signal not_to_hash[NUM_CHUNKS];
signal option_hash[NUM_CHUNKS];
signal hashes[NUM_CHUNKS + 1];
hashes[0] <== initial_hash;
for(var i = 0 ; i < NUM_CHUNKS ; i++) {
if(i == 0) {
hash[i] <== PoseidonChainer()([step_in[0],packedPlaintext[i]]);
} else {
hash[i] <== PoseidonChainer()([hash[i-1], packedPlaintext[i]]);
}
not_to_hash[i] <== IsZero()(in[i]);
option_hash[i] <== PoseidonChainer()([hashes[i],in[i]]);
hashes[i+1] <== not_to_hash[i] * (hashes[i] - option_hash[i]) + option_hash[i]; // same as: (1 - not_to_hash[i]) * option_hash[i] + not_to_hash[i] * hash[i];
}
step_out[0] <== hash[NUM_CHUNKS - 1];
}
out <== hashes[NUM_CHUNKS];
}
23 changes: 20 additions & 3 deletions circuits/test/aes-gcm/nivc/aes-gctr-nivc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ describe("aes-gctr-nivc", () => {
let plainText = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let iv = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let aad = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
let ct = [0x03, 0x88, 0xda, 0xce, 0x60, 0xb6, 0xa3, 0x92, 0xf3, 0x28, 0xc2, 0xb9, 0x71, 0xb2, 0xfe, 0x78];
// let ct = [0x03, 0x88, 0xda, 0xce, 0x60, 0xb6, 0xa3, 0x92, 0xf3, 0x28, 0xc2, 0xb9, 0x71, 0xb2, 0xfe, 0x78];

const ctr = [0x00, 0x00, 0x00, 0x01];
const step_in = 0;

const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText, aad: aad, ctr: ctr, cipherText: ct, step_in: step_in }, ["step_out"])
assert.deepEqual(witness.step_out, PoseidonModular([step_in, bytesToBigInt(plainText)]));
const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText, aad: aad, ctr: ctr, cipherText: plainText, step_in: step_in }, ["step_out"])
0xJepsen marked this conversation as resolved.
Show resolved Hide resolved
console.log(witness.step_out);
assert.deepEqual(witness.step_out, BigInt(0));
});

it("all correct for self generated single non zero pt block", async () => {
Expand Down Expand Up @@ -110,4 +111,20 @@ describe("aes-gctr-nivc", () => {
let hash_0 = PoseidonModular([step_in_0, bytesToBigInt(plainText1)]);
assert.deepEqual(witness.step_out, PoseidonModular([hash_0, bytesToBigInt(plainText2)]));
});

it("all correct for two folds at once one zero chunk", async () => {
circuit_two_block = await circomkit.WitnessTester("aes-gcm-fold", {
file: "aes-gcm/nivc/aes-gctr-nivc",
template: "AESGCTRFOLD",
params: [2]
});

const ctr_0 = [0x00, 0x00, 0x00, 0x01];
const step_in_0 = 0;
let zero_chunk = Array(16).fill(0);

const witness = await circuit_two_block.compute({ key: key, iv: iv, aad: aad, ctr: ctr_0, plainText: [plainText1, zero_chunk], cipherText: [ct_part1, zero_chunk], step_in: step_in_0 }, ["step_out"])
let hash_0 = PoseidonModular([step_in_0, bytesToBigInt(plainText1)]);
assert.deepEqual(witness.step_out, hash_0);
});
});
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"name": "web-prover-circuits",
"description": "ZK Circuits for WebProofs",
"version": "0.5.4",
"version": "0.5.5",
"license": "Apache-2.0",
"repository": {
"type": "git",
Expand Down