Skip to content

Commit

Permalink
feat: NIVC hash chain (#36)
Browse files Browse the repository at this point in the history
* init aes hashchain

* Update package.json

* add template `DataHasher`

* NIVC through HTTP start line

* reduce http start line

* json object extract

* working full NIVC proof

* fix aes-nivc tests and http nivc tests

* aes takes in cipher text and validates it

* fix: http lock header and parse tests

* extract final test passes

* move dataHasher to common

* circuit fmt

* fix json masker test

* fix the formatting

* fmt

---------

Co-authored-by: Waylon Jepsen <[email protected]>
Co-authored-by: lonerapier <[email protected]>
  • Loading branch information
3 people authored Nov 12, 2024
1 parent 8035a0d commit 2dd1558
Show file tree
Hide file tree
Showing 19 changed files with 690 additions and 388 deletions.
2 changes: 1 addition & 1 deletion builds/target_1024b/aes_gctr_nivc_1024b.circom
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ pragma circom 2.1.9;

include "../../circuits/aes-gcm/nivc/aes-gctr-nivc.circom";

component main { public [step_in] } = AESGCTRFOLD(1024);
component main { public [step_in] } = AESGCTRFOLD();
2 changes: 1 addition & 1 deletion builds/target_512b/aes_gctr_nivc_512b.circom
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ pragma circom 2.1.9;

include "../../circuits/aes-gcm/nivc/aes-gctr-nivc.circom";

component main { public [step_in] } = AESGCTRFOLD(512);
component main { public [step_in] } = AESGCTRFOLD();
47 changes: 47 additions & 0 deletions circuits.json
Original file line number Diff line number Diff line change
Expand Up @@ -212,5 +212,52 @@
12,
16
]
},
"nivc_aes": {
"file": "aes-gcm/nivc/aes-gctr-nivc",
"template": "AESGCTRFOLD"
},
"nivc_start_line": {
"file": "http/nivc/parse_and_lock_start_line",
"template": "ParseAndLockStartLine",
"params": [
1024,
50,
200,
50
]
},
"nivc_lock_header": {
"file": "http/nivc/lock_header",
"template": "LockHeader",
"params": [
1024,
50,
100
]
},
"nivc_body_mask": {
"file": "http/nivc/body_mask",
"template": "HTTPMaskBodyNIVC",
"params": [
1024
]
},
"nivc_json_object": {
"file": "json/nivc/masker",
"template": "JsonMaskObjectNIVC",
"params": [
1024,
10,
10
]
},
"nivc_json_array": {
"file": "json/nivc/masker",
"template": "JsonMaskArrayIndexNIVC",
"params": [
1024,
10
]
}
}
76 changes: 16 additions & 60 deletions circuits/aes-gcm/nivc/aes-gctr-nivc.circom
Original file line number Diff line number Diff line change
Expand Up @@ -2,76 +2,32 @@ pragma circom 2.1.9;

include "gctr-nivc.circom";
include "../../utils/array.circom";

include "../../utils/hash.circom";

// Compute AES-GCTR
template AESGCTRFOLD(DATA_BYTES) {
// ------------------------------------------------------------------------------------------------------------------ //
// ~~ Set sizes at compile time ~~
assert(DATA_BYTES % 16 == 0);
// Value for accumulating both packed plaintext and ciphertext as well as counter
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;
// ------------------------------------------------------------------------------------------------------------------ //


template AESGCTRFOLD() {
signal input key[16];
signal input iv[12];
signal input aad[16];
signal input ctr[4];
signal input plainText[16];

// step_in[0..DATA_BYTES] => accumulate plaintext blocks
// step_in[DATA_BYTES..DATA_BYTES*2] => accumulate ciphertext blocks
// step_in[DATA_BYTES_LEN*2..DATA_BYTES*2+4] => accumulate counter
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
signal input cipherText[16];

signal input step_in[1];
signal output step_out[1];

// We extract the number from the 4 byte word counter
component last_counter_bits = BytesToBits(4);
for(var i = 0; i < 4; i ++) {
last_counter_bits.in[i] <== step_in[DATA_BYTES + i];
}
component last_counter_num = Bits2Num(32);
// pass in reverse order
for (var i = 0; i< 32; i++){
last_counter_num.in[i] <== last_counter_bits.out[31 - i];
}
signal index <== last_counter_num.out - 1;

// folds one block
component aes = AESGCTRFOLDABLE();
aes.key <== key;
aes.iv <== iv;
aes.aad <== aad;
aes.plainText <== plainText;
component aes = AESGCTRFOLDABLE();
aes.key <== key;
aes.iv <== iv;
aes.aad <== aad;
aes.plainText <== plainText;
aes.lastCounter <== ctr;

for(var i = 0; i < 4; i++) {
aes.lastCounter[i] <== step_in[DATA_BYTES + i];
}

// Write out the plaintext and ciphertext to our accumulation arrays, both at once.
signal textToPack[16][2];
aes.cipherText === cipherText;
var packedPlaintext = 0;
for(var i = 0 ; i < 16 ; i++) {
textToPack[i][0] <== plainText[i];
textToPack[i][1] <== aes.cipherText[i];
}
signal nextPackedChunk[16] <== GenericBytePackArray(16,2)(textToPack);

signal prevAccumulatedPackedText[DATA_BYTES];
for(var i = 0 ; i < DATA_BYTES ; i++) {
prevAccumulatedPackedText[i] <== step_in[i];
}
component nextAccumulatedPackedText = WriteToIndex(DATA_BYTES, 16);
nextAccumulatedPackedText.array_to_write_to <== prevAccumulatedPackedText;
nextAccumulatedPackedText.array_to_write_at_index <== nextPackedChunk;
nextAccumulatedPackedText.index <== index * 16;

for(var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
if(i < DATA_BYTES) {
step_out[i] <== nextAccumulatedPackedText.out[i];
} else {
step_out[i] <== aes.counter[i - DATA_BYTES];
}
packedPlaintext += plainText[i] * 2**(8*i);
}
step_out[0] <== PoseidonChainer()([step_in[0],packedPlaintext]);
}

38 changes: 15 additions & 23 deletions circuits/http/nivc/body_mask.circom
Original file line number Diff line number Diff line change
@@ -1,24 +1,16 @@
pragma circom 2.1.9;

include "../parser/machine.circom";
include "../../utils/hash.circom";

template HTTPMaskBodyNIVC(DATA_BYTES) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
// ~ Unravel from previous NIVC step ~
// Read in from previous NIVC step (HttpParseAndLockStartLine or HTTPLockHeader)
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];

signal data[DATA_BYTES];
// signal parsing_body[DATA_BYTES];
for (var i = 0 ; i < DATA_BYTES ; i++) {
data[i] <== step_in[i];
// parsing_body[i] <== step_in[DATA_BYTES + i * 5 + 4]; // `parsing_body` stored in every 5th slot of step_in/out
}
signal input step_in[1];
signal output step_out[1];

// Authenticate the plaintext we are passing in
signal input data[DATA_BYTES];
signal data_hash <== DataHasher(DATA_BYTES)(data);
data_hash === step_in[0];

// ------------------------------------------------------------------------------------------------------------------ //
// PARSE
Expand Down Expand Up @@ -46,13 +38,13 @@ template HTTPMaskBodyNIVC(DATA_BYTES) {
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
// ~ Write out to next NIVC step
for (var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
if(i < DATA_BYTES) {
step_out[i] <== data[i] * State[i].next_parsing_body;
} else {
step_out[i] <== 0;
}
// Mask out just the JSON body
signal bodyMasked[DATA_BYTES];
for (var i = 0 ; i < DATA_BYTES ; i++) {
bodyMasked[i] <== data[i] * State[i].next_parsing_body;
}

// Hash the new data so this can now be used in the chain later
step_out[0] <== DataHasher(DATA_BYTES)(bodyMasked);
}

37 changes: 14 additions & 23 deletions circuits/http/nivc/lock_header.circom
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,21 @@ include "../parser/machine.circom";
include "../interpreter.circom";
include "../../utils/array.circom";
include "circomlib/circuits/comparators.circom";
include "@zk-email/circuits/utils/array.circom";

// TODO: should use a MAX_HEADER_NAME_LENGTH and a MAX_HEADER_VALUE_LENGTH
template LockHeader(DATA_BYTES, MAX_HEADER_NAME_LENGTH, MAX_HEADER_VALUE_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes pt/ct + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
assert(DATA_BYTES >= MAX_HEADER_NAME_LENGTH + MAX_HEADER_VALUE_LENGTH);

// get the plaintext
signal data[DATA_BYTES];
for (var i = 0 ; i < DATA_BYTES ; i++) {
data[i] <== step_in[i];
}
signal input step_in[1];
signal output step_out[1];

// Authenticate the plaintext we are passing in
signal input data[DATA_BYTES];
signal data_hash <== DataHasher(DATA_BYTES)(data);
data_hash === step_in[0];
step_out[0] <== step_in[0];

// ------------------------------------------------------------------------------------------------------------------ //
// PARSE
Expand Down Expand Up @@ -74,26 +73,18 @@ template LockHeader(DATA_BYTES, MAX_HEADER_NAME_LENGTH, MAX_HEADER_VALUE_LENGTH)
// find header location
signal headerNameLocation <== FirstStringMatch(DATA_BYTES, MAX_HEADER_NAME_LENGTH)(data, header);

// TODO (autoparallel): This could probably be optimized by selecting a subarray of length `MAX_HEADER_NAME_LENGTH + MAX_HEADER_VALUE_LENGTH` at `headerNameLocation`
// This is the assertion that we have locked down the correct header

// signal dataSubArray[MAX_HEADER_NAME_LENGTH + MAX_HEADER_VALUE_LENGTH] <== SelectSubArray(DATA_BYTES, MAX_HEADER_NAME_LENGTH + MAX_HEADER_VALUE_LENGTH)(data, headerNameLocation, MAX_HEADER_NAME_LENGTH + MAX_HEADER_VALUE_LENGTH);
// signal headerFieldNameValueMatch <== HeaderFieldNameValueMatchPadded(MAX_HEADER_NAME_LENGTH + MAX_HEADER_VALUE_LENGTH, MAX_HEADER_NAME_LENGTH, MAX_HEADER_VALUE_LENGTH)(dataSubArray, header, headerNameLength, value, headerValueLength, headerNameLocation);
signal headerFieldNameValueMatch <== HeaderFieldNameValueMatchPadded(DATA_BYTES, MAX_HEADER_NAME_LENGTH, MAX_HEADER_VALUE_LENGTH)(data, header, headerNameLength, value, headerValueLength, headerNameLocation);
headerFieldNameValueMatch === 1;

// parser state should be parsing header upto 2^10 max headers
signal isParsingHeader <== IndexSelector(DATA_BYTES * 5)(httpParserState, headerNameLocation * 5 + 1);
signal parsingHeader <== GreaterThan(10)([isParsingHeader, 0]);
parsingHeader === 1;

// ------------------------------------------------------------------------------------------------------------------ //
// write out the pt again
for (var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
// add plaintext http input to step_out and ignore the ciphertext
if(i < DATA_BYTES) {
step_out[i] <== step_in[i];
} else {
step_out[i] <== 0;
}
}

}

// TODO: Handrolled template that I haven't tested YOLO.
Expand Down
68 changes: 28 additions & 40 deletions circuits/http/nivc/parse_and_lock_start_line.circom
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,23 @@ include "../parser/machine.circom";
include "../interpreter.circom";
include "../../utils/bytes.circom";

// TODO: Note that TOTAL_BYTES will match what we have for AESGCMFOLD step_out
// I have not gone through to double check the sizes of everything yet.
template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENGTH, MAX_FINAL_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
// ~~ Set sizes at compile time ~~
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // AES ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];

// Get the plaintext
signal packedData[DATA_BYTES];
for (var i = 0 ; i < DATA_BYTES ; i++) {
packedData[i] <== step_in[i];
var MINIMUM_PARSE_LENGTH = MAX_BEGINNING_LENGTH + MAX_MIDDLE_LENGTH + MAX_FINAL_LENGTH;
assert(DATA_BYTES >= MINIMUM_PARSE_LENGTH);

signal input step_in[1];
signal output step_out[1];

// Authenticate the plaintext we are passing in
signal input data[DATA_BYTES];
signal data_hash <== DataHasher(DATA_BYTES)(data);
data_hash === step_in[0];
step_out[0] <== step_in[0];

signal dataToParse[MINIMUM_PARSE_LENGTH];
for(var i = 0 ; i < MINIMUM_PARSE_LENGTH ; i++) {
dataToParse[i] <== data[i];
}
component unpackData = UnpackDoubleByteArray(DATA_BYTES);
unpackData.in <== packedData;
signal data[DATA_BYTES] <== unpackData.lower;

signal input beginning[MAX_BEGINNING_LENGTH];
signal input beginning_length;
Expand All @@ -32,10 +29,11 @@ template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENG
signal input final[MAX_FINAL_LENGTH];
signal input final_length;

// Initialze the parser
component State[DATA_BYTES];
// Initialze the parser, note that we only need to parse as much as the `MINIMUM_PARSE_LENGTH`
// since the start line could not possibly go past this point, or else this would fail anyway
component State[MINIMUM_PARSE_LENGTH];
State[0] = HttpStateUpdate();
State[0].byte <== data[0];
State[0].byte <== dataToParse[0];
State[0].parsing_start <== 1;
State[0].parsing_header <== 0;
State[0].parsing_field_name <== 0;
Expand All @@ -50,9 +48,9 @@ template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENG
*/

// Setup to check middle bytes
signal startLineMask[DATA_BYTES];
signal middleMask[DATA_BYTES];
signal finalMask[DATA_BYTES];
signal startLineMask[MINIMUM_PARSE_LENGTH];
signal middleMask[MINIMUM_PARSE_LENGTH];
signal finalMask[MINIMUM_PARSE_LENGTH];
startLineMask[0] <== inStartLine()(State[0].parsing_start);
middleMask[0] <== inStartMiddle()(State[0].parsing_start);
finalMask[0] <== inStartEnd()(State[0].parsing_start);
Expand All @@ -61,9 +59,9 @@ template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENG
var middle_start_counter = 1;
var middle_end_counter = 1;
var final_end_counter = 1;
for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) {
for(var data_idx = 1; data_idx < MINIMUM_PARSE_LENGTH; data_idx++) {
State[data_idx] = HttpStateUpdate();
State[data_idx].byte <== data[data_idx];
State[data_idx].byte <== dataToParse[data_idx];
State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start;
State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header;
State[data_idx].parsing_field_name <== State[data_idx - 1].next_parsing_field_name;
Expand All @@ -85,27 +83,17 @@ template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENG
// Additionally verify beginning had correct length
beginning_length === middle_start_counter - 1;

signal beginningMatch <== SubstringMatchWithIndexPadded(DATA_BYTES, MAX_BEGINNING_LENGTH)(data, beginning, beginning_length, 0);
signal beginningMatch <== SubstringMatchWithIndexPadded(MINIMUM_PARSE_LENGTH, MAX_BEGINNING_LENGTH)(dataToParse, beginning, beginning_length, 0);

// Check middle is correct by substring match and length check
signal middleMatch <== SubstringMatchWithIndexPadded(DATA_BYTES, MAX_MIDDLE_LENGTH)(data, middle, middle_length, middle_start_counter);
signal middleMatch <== SubstringMatchWithIndexPadded(MINIMUM_PARSE_LENGTH, MAX_MIDDLE_LENGTH)(dataToParse, middle, middle_length, middle_start_counter);
middleMatch === 1;
middle_length === middle_end_counter - middle_start_counter - 1;

// Check final is correct by substring match and length check
signal finalMatch <== SubstringMatchWithIndexPadded(DATA_BYTES, MAX_FINAL_LENGTH)(data, final, final_length, middle_end_counter);
signal finalMatch <== SubstringMatchWithIndexPadded(MINIMUM_PARSE_LENGTH, MAX_FINAL_LENGTH)(dataToParse, final, final_length, middle_end_counter);
finalMatch === 1;
// -2 here for the CRLF
final_length === final_end_counter - middle_end_counter - 2;

// ------------------------------------------------------------------------------------------------------------------ //
// write out the pt again
for (var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
// add plaintext http input to step_out and ignore the ciphertext
if(i < DATA_BYTES) {
step_out[i] <== data[i]; // PASS OUT JUST THE PLAINTEXT DATA
} else {
step_out[i] <== 0;
}
}
}

Loading

0 comments on commit 2dd1558

Please sign in to comment.