Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: byte pack NIVC #35

Merged
merged 9 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ all: build
build:
@for circuit in $(CIRCOM_FILES); do \
echo "Processing $${circuit}..."; \
circom "$${circuit}" --r1cs -o "$$(dirname $${circuit})/artifacts" -l node_modules; \
circom "$${circuit}" --r1cs --wasm -o "$$(dirname $${circuit})/artifacts" -l node_modules; \
build-circuit "$${circuit}" "$$(dirname $${circuit})/artifacts/$$(basename $${circuit} .circom).bin" -l node_modules; \
done

Expand Down
130 changes: 19 additions & 111 deletions circuits/aes-gcm/nivc/aes-gctr-nivc.circom
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ template AESGCTRFOLD(DATA_BYTES) {
// ------------------------------------------------------------------------------------------------------------------ //
// ~~ Set sizes at compile time ~~
assert(DATA_BYTES % 16 == 0);
// Value for accumulating both plaintext and ciphertext as well as counter
var TOTAL_BYTES_ACROSS_NIVC = 2 * DATA_BYTES + 4;
// Value for accumulating both packed plaintext and ciphertext as well as counter
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;
// ------------------------------------------------------------------------------------------------------------------ //


Expand All @@ -29,7 +29,7 @@ template AESGCTRFOLD(DATA_BYTES) {
// We extract the number from the 4 byte word counter
component last_counter_bits = BytesToBits(4);
for(var i = 0; i < 4; i ++) {
last_counter_bits.in[i] <== step_in[DATA_BYTES * 2 + i];
last_counter_bits.in[i] <== step_in[DATA_BYTES + i];
}
component last_counter_num = Bits2Num(32);
// pass in reverse order
Expand All @@ -46,124 +46,32 @@ template AESGCTRFOLD(DATA_BYTES) {
aes.plainText <== plainText;

for(var i = 0; i < 4; i++) {
aes.lastCounter[i] <== step_in[DATA_BYTES * 2 + i];
aes.lastCounter[i] <== step_in[DATA_BYTES + i];
}


// Write out the plaintext and ciphertext to our accumulation arrays, both at once.
signal prevAccumulatedPlaintext[DATA_BYTES];
for(var i = 0 ; i < DATA_BYTES ; i++) {
prevAccumulatedPlaintext[i] <== step_in[i];
signal textToPack[16][2];
for(var i = 0 ; i < 16 ; i++) {
textToPack[i][0] <== plainText[i];
textToPack[i][1] <== aes.cipherText[i];
}
signal prevAccumulatedCiphertext[DATA_BYTES];
signal nextPackedChunk[16] <== GenericBytePackArray(16,2)(textToPack);

signal prevAccumulatedPackedText[DATA_BYTES];
for(var i = 0 ; i < DATA_BYTES ; i++) {
prevAccumulatedCiphertext[i] <== step_in[DATA_BYTES + i];
prevAccumulatedPackedText[i] <== step_in[i];
}
component nextTexts = WriteToIndexForTwoArrays(DATA_BYTES, 16);
nextTexts.first_array_to_write_to <== prevAccumulatedPlaintext;
nextTexts.second_array_to_write_to <== prevAccumulatedCiphertext;
nextTexts.first_array_to_write_at_index <== plainText;
nextTexts.second_array_to_write_at_index <== aes.cipherText;
nextTexts.index <== index * 16;

component nextAccumulatedPackedText = WriteToIndex(DATA_BYTES, 16);
nextAccumulatedPackedText.array_to_write_to <== prevAccumulatedPackedText;
nextAccumulatedPackedText.array_to_write_at_index <== nextPackedChunk;
nextAccumulatedPackedText.index <== index * 16;

for(var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
if(i < DATA_BYTES) {
step_out[i] <== nextTexts.outFirst[i];
} else if(i < 2 * DATA_BYTES) {
step_out[i] <== nextTexts.outSecond[i - DATA_BYTES];
} else if(i < 2 * DATA_BYTES + 4) {
step_out[i] <== aes.counter[i - (2 * DATA_BYTES)];
step_out[i] <== nextAccumulatedPackedText.out[i];
} else {
step_out[i] <== aes.counter[i - DATA_BYTES];
}
}
}



template WriteToIndexForTwoArrays(m, n) {
signal input first_array_to_write_to[m];
signal input second_array_to_write_to[m];
signal input first_array_to_write_at_index[n];
signal input second_array_to_write_at_index[n];
signal input index;

signal output outFirst[m];
signal output outSecond[m];

assert(m >= n);

// Note: this is underconstrained, we need to constrain that index + n <= m
// Need to constrain that index + n <= m -- can't be an assertion, because uses a signal
// ------------------------- //

// Here, we get an array of ALL zeros, except at the `index` AND `index + n`
// beginning-------^^^^^ end---^^^^^^^^^
signal indexMatched[m];
component indexBegining[m];
component indexEnding[m];
for(var i = 0 ; i < m ; i++) {
indexBegining[i] = IsZero();
indexBegining[i].in <== i - index;
indexEnding[i] = IsZero();
indexEnding[i].in <== i - (index + n);
indexMatched[i] <== indexBegining[i].out + indexEnding[i].out;
}

// E.g., index == 31, m == 160, n == 16
// => indexMatch[31] == 1;
// => indexMatch[47] == 1;
// => otherwise, all 0.

signal accum[m];
accum[0] <== indexMatched[0];

component writeAt = IsZero();
writeAt.in <== accum[0] - 1;

component orFirst = OR();
orFirst.a <== (writeAt.out * first_array_to_write_at_index[0]);
orFirst.b <== (1 - writeAt.out) * first_array_to_write_to[0];
outFirst[0] <== orFirst.out;

component orSecond = OR();
orSecond.a <== (writeAt.out * second_array_to_write_at_index[0]);
orSecond.b <== (1 - writeAt.out) * second_array_to_write_to[0];
outSecond[0] <== orSecond.out;
// IF accum == 1 then { array_to_write_at } ELSE IF accum != 1 then { array to write_to }
signal accum_index[m];
accum_index[0] <== accum[0];

component writeSelector[m - 1];
component indexSelectorFirst[m - 1];
component indexSelectorSecond[m - 1];
component orsFirst[m-1];
component orsSecond[m-1];
for(var i = 1 ; i < m ; i++) {
// accum will be 1 at all indices where we want to write the new array
accum[i] <== accum[i-1] + indexMatched[i];
writeSelector[i-1] = IsZero();
writeSelector[i-1].in <== accum[i] - 1;
// IsZero(accum[i] - 1); --> tells us we are in the range where we want to write the new array

indexSelectorFirst[i-1] = IndexSelector(n);
indexSelectorFirst[i-1].index <== accum_index[i-1];
indexSelectorFirst[i-1].in <== first_array_to_write_at_index;

indexSelectorSecond[i-1] = IndexSelector(n);
indexSelectorSecond[i-1].index <== accum_index[i-1];
indexSelectorSecond[i-1].in <== second_array_to_write_at_index;
// When accum is not zero, out is array_to_write_at_index, otherwise it is array_to_write_to

orsFirst[i-1] = OR();
orsFirst[i-1].a <== (writeSelector[i-1].out * indexSelectorFirst[i-1].out);
orsFirst[i-1].b <== (1 - writeSelector[i-1].out) * first_array_to_write_to[i];
outFirst[i] <== orsFirst[i-1].out;

orsSecond[i-1] = OR();
orsSecond[i-1].a <== (writeSelector[i-1].out * indexSelectorSecond[i-1].out);
orsSecond[i-1].b <== (1 - writeSelector[i-1].out) * second_array_to_write_to[i];
outSecond[i] <== orsSecond[i-1].out;

accum_index[i] <== accum_index[i-1] + writeSelector[i-1].out;
}
}
2 changes: 1 addition & 1 deletion circuits/http/nivc/body_mask.circom
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ include "../parser/machine.circom";

template HTTPMaskBodyNIVC(DATA_BYTES) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
Expand Down
2 changes: 1 addition & 1 deletion circuits/http/nivc/lock_header.circom
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ include "circomlib/circuits/comparators.circom";
// TODO: should use a MAX_HEADER_NAME_LENGTH and a MAX_HEADER_VALUE_LENGTH
template LockHeader(DATA_BYTES, MAX_HEADER_NAME_LENGTH, MAX_HEADER_VALUE_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes pt/ct + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes pt/ct + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
Expand Down
11 changes: 7 additions & 4 deletions circuits/http/nivc/parse_and_lock_start_line.circom
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@ include "../../utils/bytes.circom";
template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENGTH, MAX_FINAL_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
// ~~ Set sizes at compile time ~~
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // AES ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // AES ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //

// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];

// Get the plaintext
signal data[DATA_BYTES];
signal packedData[DATA_BYTES];
for (var i = 0 ; i < DATA_BYTES ; i++) {
data[i] <== step_in[i];
packedData[i] <== step_in[i];
}
component unpackData = UnpackDoubleByteArray(DATA_BYTES);
unpackData.in <== packedData;
signal data[DATA_BYTES] <== unpackData.lower;

signal input beginning[MAX_BEGINNING_LENGTH];
signal input beginning_length;
Expand Down Expand Up @@ -100,7 +103,7 @@ template ParseAndLockStartLine(DATA_BYTES, MAX_BEGINNING_LENGTH, MAX_MIDDLE_LENG
for (var i = 0 ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
// add plaintext http input to step_out and ignore the ciphertext
if(i < DATA_BYTES) {
step_out[i] <== step_in[i];
step_out[i] <== data[i]; // PASS OUT JUST THE PLAINTEXT DATA
} else {
step_out[i] <== 0;
}
Expand Down
2 changes: 1 addition & 1 deletion circuits/json/nivc/extractor.circom
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ include "@zk-email/circuits/utils/array.circom";

template MaskExtractFinal(DATA_BYTES, MAX_VALUE_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes pt/ct + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes pt/ct + ctr
// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
Expand Down
8 changes: 4 additions & 4 deletions circuits/json/nivc/masker.circom
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ include "../interpreter.circom";
template JsonMaskObjectNIVC(DATA_BYTES, MAX_STACK_HEIGHT, MAX_KEY_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
assert(MAX_STACK_HEIGHT >= 2); // TODO (autoparallel): idk if we need this now
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
Expand Down Expand Up @@ -87,15 +87,15 @@ template JsonMaskObjectNIVC(DATA_BYTES, MAX_STACK_HEIGHT, MAX_KEY_LENGTH) {
// mask = currently parsing value and all subsequent keys matched
step_out[data_idx] <== data[data_idx] * or[data_idx - 1];
}
for(var i = DATA_BYTES - MAX_KEY_LENGTH; i < 2 * DATA_BYTES + 4; i ++) {
for(var i = DATA_BYTES - MAX_KEY_LENGTH; i < DATA_BYTES + 4; i ++) {
step_out[i] <== 0;
}
}

template JsonMaskArrayIndexNIVC(DATA_BYTES, MAX_STACK_HEIGHT) {
// ------------------------------------------------------------------------------------------------------------------ //
assert(MAX_STACK_HEIGHT >= 2); // TODO (autoparallel): idk if we need this now
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4; // aes ct/pt + ctr
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4; // aes ct/pt + ctr
// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
Expand Down Expand Up @@ -136,7 +136,7 @@ template JsonMaskArrayIndexNIVC(DATA_BYTES, MAX_STACK_HEIGHT) {
or[data_idx - 1] <== OR()(parsing_array[data_idx], parsing_array[data_idx - 1]);
step_out[data_idx] <== data[data_idx] * or[data_idx - 1];
}
for(var i = DATA_BYTES ; i < 2 * DATA_BYTES + 4; i++) {
for(var i = DATA_BYTES ; i < TOTAL_BYTES_ACROSS_NIVC; i++) {
step_out[i] <== 0;
}
}
35 changes: 19 additions & 16 deletions circuits/test/aes-gcm/nivc/aes-gctr-nivc.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ describe("aes-gctr-nivc", () => {


const DATA_BYTES_0 = 16;
const TOTAL_BYTES_ACROSS_NIVC_0 = 2 * DATA_BYTES_0 + 4;
const TOTAL_BYTES_ACROSS_NIVC_0 = DATA_BYTES_0 + 4;

it("all correct for self generated single zero pt block case", async () => {
circuit_one_block = await circomkit.WitnessTester("aes-gcm-fold", {
Expand All @@ -25,12 +25,13 @@ describe("aes-gctr-nivc", () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const step_in = new Array(TOTAL_BYTES_ACROSS_NIVC_0).fill(0x00);
counter.forEach((value, index) => {
step_in[2 * DATA_BYTES_0 + index] = value;
step_in[DATA_BYTES_0 + index] = value;
});

let expected = plainText.concat(ct).concat([0x00, 0x00, 0x00, 0x02]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_0 - expected.length).fill(0));
const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText, aad: aad, step_in: step_in }, ["step_out"])

let packed = plainText.map((x, i) => x + (ct[i] * 256));
let expected = [...packed, 0x00, 0x00, 0x00, 0x02];
assert.deepEqual(witness.step_out, expected.map(BigInt));
});

Expand All @@ -50,18 +51,18 @@ describe("aes-gctr-nivc", () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const step_in = new Array(TOTAL_BYTES_ACROSS_NIVC_0).fill(0x00);
counter.forEach((value, index) => {
step_in[2 * DATA_BYTES_0 + index] = value;
step_in[DATA_BYTES_0 + index] = value;
});

let expected = plainText.concat(ct).concat([0x00, 0x00, 0x00, 0x02]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_0 - expected.length).fill(0));

const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText, aad: aad, step_in: step_in }, ["step_out"])

let packed = plainText.map((x, i) => x + (ct[i] * 256));
let expected = [...packed, 0x00, 0x00, 0x00, 0x02];
assert.deepEqual(witness.step_out, expected.map(BigInt));
});

const DATA_BYTES_1 = 32;
const TOTAL_BYTES_ACROSS_NIVC_1 = DATA_BYTES_1 * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC_1 = DATA_BYTES_1 + 4;


let zero_block = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
Expand All @@ -83,12 +84,13 @@ describe("aes-gctr-nivc", () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const step_in = new Array(TOTAL_BYTES_ACROSS_NIVC_1).fill(0x00);
counter.forEach((value, index) => {
step_in[2 * DATA_BYTES_1 + index] = value;
step_in[DATA_BYTES_1 + index] = value;
});
let expected = plainText1.concat(zero_block).concat(ct_part1).concat(zero_block).concat([0x00, 0x00, 0x00, 0x02]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_1 - expected.length).fill(0));

const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText1, aad: aad, step_in: step_in }, ["step_out"])

let packed1 = plainText1.map((x, i) => x + (ct_part1[i] * 256));
let expected = packed1.concat(zero_block).concat([0x00, 0x00, 0x00, 0x02]);
assert.deepEqual(witness.step_out, expected.map(BigInt));
});

Expand All @@ -99,12 +101,13 @@ describe("aes-gctr-nivc", () => {
params: [DATA_BYTES_1], // input len is 32 bytes
});

const counter = [0x00, 0x00, 0x00, 0x02];
let step_in = plainText1.concat(zero_block).concat(ct_part1).concat(zero_block).concat(counter);
let packed1 = plainText1.map((x, i) => x + (ct_part1[i] * 256));
let packed2 = plainText2.map((x, i) => x + (ct_part2[i] * 256));
let step_in = packed1.concat(zero_block).concat([0x00, 0x00, 0x00, 0x02]);
step_in = step_in.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_1 - step_in.length).fill(0));

let expected = plainText1.concat(plainText2).concat(ct_part1).concat(ct_part2).concat([0x00, 0x00, 0x00, 0x03]);
expected = expected.concat(new Array(TOTAL_BYTES_ACROSS_NIVC_1 - expected.length).fill(0));

let expected = packed1.concat(packed2).concat([0x00, 0x00, 0x00, 0x03]);

const witness = await circuit_one_block.compute({ key: key, iv: iv, plainText: plainText2, aad: aad, step_in: step_in }, ["step_out"])
assert.deepEqual(witness.step_out, expected.map(BigInt));
Expand Down
6 changes: 2 additions & 4 deletions circuits/test/full/full.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ describe("NIVC_FULL", async () => {

const DATA_BYTES = 320;
const MAX_STACK_HEIGHT = 5;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;

const MAX_HEADER_NAME_LENGTH = 20;
const MAX_HEADER_VALUE_LENGTH = 35;
Expand Down Expand Up @@ -132,7 +132,7 @@ describe("NIVC_FULL", async () => {
const counter = [0x00, 0x00, 0x00, 0x01];
const init_nivc_input = new Array(TOTAL_BYTES_ACROSS_NIVC).fill(0x00);
counter.forEach((value, index) => {
init_nivc_input[2 * DATA_BYTES + index] = value;
init_nivc_input[DATA_BYTES + index] = value;
});
let pt = http_response_plaintext.slice(0, 16);
aes_gcm = await aesCircuit.compute({ key: Array(16).fill(0), iv: Array(12).fill(0), plainText: pt, aad: Array(16).fill(0), step_in: init_nivc_input }, ["step_out"]);
Expand All @@ -154,8 +154,6 @@ describe("NIVC_FULL", async () => {
let maskedInput = extendedJsonInput.fill(0, 0, idx);
maskedInput = maskedInput.fill(0, 320);



let key0 = [100, 97, 116, 97, 0, 0, 0, 0]; // "data"
let key0Len = 4;
let key1 = [105, 116, 101, 109, 115, 0, 0, 0]; // "items"
Expand Down
2 changes: 1 addition & 1 deletion circuits/test/http/nivc/body_mask.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ describe("NIVC_HTTP", async () => {
let bodyMaskCircuit: WitnessTester<["step_in"], ["step_out"]>;

const DATA_BYTES = 320;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;

const MAX_HEADER_NAME_LENGTH = 20;
const MAX_HEADER_VALUE_LENGTH = 35;
Expand Down
2 changes: 1 addition & 1 deletion circuits/test/http/nivc/lock_header.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ describe("HTTPLockHeader", async () => {
let lockHeaderCircuit: WitnessTester<["step_in", "header", "headerNameLength", "value", "headerValueLength"], ["step_out"]>;

const DATA_BYTES = 320;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * 2 + 4;
const TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES + 4;

const MAX_BEGINNING_LENGTH = 10;
const MAX_MIDDLE_LENGTH = 50;
Expand Down
Loading