Skip to content

Commit

Permalink
fix: masking tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Autoparallel committed Oct 20, 2024
1 parent 93a6e1d commit acd2f2c
Show file tree
Hide file tree
Showing 9 changed files with 2,909 additions and 199 deletions.
3 changes: 3 additions & 0 deletions circuits/json/interpreter.circom
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,10 @@ template KeyMatchAtDepth(dataLen, n, keyLen, depth) {
signal output out <== substring_match * is_parsing_correct_key_at_depth;
}

// TODO: We do not check that the key starts with a quote. Is that already handled by `parsing_key`?
template MatchPaddedKey(n) {
// TODO: If key is not padded at all, then `in[1]` will not contain an end quote.
// Perhaps we modify this to handle that, or just always pad the key at least once.
signal input in[2][n];
signal input keyLen;
signal output out;
Expand Down
28 changes: 18 additions & 10 deletions circuits/json/nivc/extractor.circom
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,16 @@ pragma circom 2.1.9;
include "circomlib/circuits/gates.circom";
include "@zk-email/circuits/utils/array.circom";

template MaskExtractFinal(TOTAL_BYTES, DATA_BYTES, maxValueLen) {
signal input step_in[TOTAL_BYTES + 1];
signal output step_out[TOTAL_BYTES + 1];
template MaskExtractFinal(DATA_BYTES, MAX_STACK_HEIGHT, MAX_VALUE_LENGTH) {
// ------------------------------------------------------------------------------------------------------------------ //
// ~~ Set sizes at compile time ~~
// Total number of variables in the parser for each byte of data
assert(MAX_STACK_HEIGHT >= 2);
var PER_ITERATION_DATA_LENGTH = MAX_STACK_HEIGHT * 2 + 2;
var TOTAL_BYTES_ACROSS_NIVC = DATA_BYTES * (PER_ITERATION_DATA_LENGTH + 1) + 1;
// ------------------------------------------------------------------------------------------------------------------ //
signal input step_in[TOTAL_BYTES_ACROSS_NIVC];
signal output step_out[TOTAL_BYTES_ACROSS_NIVC];

signal is_zero_mask[DATA_BYTES];
signal is_prev_starting_index[DATA_BYTES];
Expand All @@ -19,21 +26,22 @@ template MaskExtractFinal(TOTAL_BYTES, DATA_BYTES, maxValueLen) {
value_starting_index[0] <== 0;
is_prev_starting_index[0] <== 0;
is_zero_mask[0] <== IsZero()(step_in[0]);
for (var i=1 ; i<DATA_BYTES ; i++) {
for (var i=1 ; i < DATA_BYTES ; i++) {
is_zero_mask[i] <== IsZero()(step_in[i]);
is_prev_starting_index[i] <== IsZero()(value_starting_index[i-1]);
value_starting_index[i] <== value_starting_index[i-1] + i * (1-is_zero_mask[i]) * is_prev_starting_index[i];
}

signal value[maxValueLen] <== SelectSubArray(DATA_BYTES, maxValueLen)(data, value_starting_index[DATA_BYTES-1], maxValueLen);
for (var i = 0 ; i < maxValueLen ; i++) {
// TODO: Clear step out?
signal output value[MAX_VALUE_LENGTH] <== SelectSubArray(DATA_BYTES, MAX_VALUE_LENGTH)(data, value_starting_index[DATA_BYTES-1], MAX_VALUE_LENGTH);
for (var i = 0 ; i < MAX_VALUE_LENGTH ; i++) {
// log(i, value[i]);
step_out[i] <== value[i];
}
for (var i = maxValueLen ; i < TOTAL_BYTES ; i++) {
for (var i = MAX_VALUE_LENGTH ; i < TOTAL_BYTES_ACROSS_NIVC ; i++) {
step_out[i] <== 0;
}
step_out[TOTAL_BYTES] <== 0;
// TODO: Do anything with last depth?
// step_out[TOTAL_BYTES_ACROSS_NIVC - 1] <== 0;
}

component main { public [step_in] } = MaskExtractFinal(4160, 320, 200);
// component main { public [step_in] } = MaskExtractFinal(4160, 320, 200);
106 changes: 64 additions & 42 deletions circuits/json/nivc/masker.circom
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ template JsonMaskObjectNIVC(DATA_BYTES, MAX_STACK_HEIGHT, MAX_KEY_LENGTH) {
}

// Decode the encoded data in `step_in` back into parser variables
signal stack[DATA_BYTES][MAX_STACK_HEIGHT][2];
signal stack[DATA_BYTES][MAX_STACK_HEIGHT + 1][2];
signal parsingData[DATA_BYTES][2];
for (var i = 0 ; i < DATA_BYTES ; i++) {
for (var j = 0 ; j < MAX_STACK_HEIGHT ; j++) {
stack[i][j][0] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2];
stack[i][j][1] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2 + 1];
for (var j = 0 ; j < MAX_STACK_HEIGHT + 1 ; j++) {
if (j < MAX_STACK_HEIGHT) {
stack[i][j][0] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2];
stack[i][j][1] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2 + 1];
} else {
// TODO: Now need to do "if curr_depth == MAX_STACK_HEIGHT", set nextStackSelector <== [0,0]
// Add one extra stack element without doing this while parsing.
stack[i][j][0] <== 0;
stack[i][j][1] <== 0;
}

}
parsingData[i][0] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + MAX_STACK_HEIGHT * 2];
parsingData[i][1] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + MAX_STACK_HEIGHT * 2 + 1];
Expand All @@ -46,15 +54,15 @@ template JsonMaskObjectNIVC(DATA_BYTES, MAX_STACK_HEIGHT, MAX_KEY_LENGTH) {
signal input keyLen;

// Signals to detect if we are parsing a key or value with initial setup
signal parsing_key[DATA_BYTES - MAX_KEY_LENGTH];
signal parsing_value[DATA_BYTES - MAX_KEY_LENGTH];
signal parsing_key[DATA_BYTES];
signal parsing_value[DATA_BYTES];

// Flags at each byte to indicate if we are matching correct key and in subsequent value
signal is_key_match[DATA_BYTES - MAX_KEY_LENGTH];
signal is_value_match[DATA_BYTES - MAX_KEY_LENGTH];
signal is_key_match[DATA_BYTES];
signal is_value_match[DATA_BYTES];

signal is_next_pair_at_depth[DATA_BYTES - MAX_KEY_LENGTH];
signal is_key_match_for_value[DATA_BYTES + 1 - MAX_KEY_LENGTH];
signal is_next_pair_at_depth[DATA_BYTES];
signal is_key_match_for_value[DATA_BYTES + 1];
is_key_match_for_value[0] <== 0;

// Initialize values knowing the 0th byte of data will never be a key/value
Expand All @@ -63,54 +71,60 @@ template JsonMaskObjectNIVC(DATA_BYTES, MAX_STACK_HEIGHT, MAX_KEY_LENGTH) {
is_key_match[0] <== 0;

component stackSelector[DATA_BYTES];
stackSelector[0] = ArraySelector(MAX_STACK_HEIGHT, 2);
stackSelector[0] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
stackSelector[0].in <== stack[0];
stackSelector[0].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1];

is_next_pair_at_depth[0] <== NextKVPairAtDepth(MAX_STACK_HEIGHT)(stack[0], data[0],step_in[TOTAL_BYTES_ACROSS_NIVC - 1]);
component nextStackSelector[DATA_BYTES];
nextStackSelector[0] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
nextStackSelector[0].in <== stack[0];
nextStackSelector[0].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1] + 1;

is_next_pair_at_depth[0] <== NextKVPairAtDepth(MAX_STACK_HEIGHT + 1)(stack[0], data[0],step_in[TOTAL_BYTES_ACROSS_NIVC - 1]);
is_key_match_for_value[1] <== Mux1()([is_key_match_for_value[0] * (1-is_next_pair_at_depth[0]), is_key_match[0] * (1-is_next_pair_at_depth[0])], is_key_match[0]);
is_value_match[0] <== parsing_value[0] * is_key_match_for_value[1];

signal output step_out[TOTAL_BYTES_ACROSS_NIVC];
step_out[0] <== data[0] * is_value_match[0];

for(var data_idx = 1; data_idx < DATA_BYTES - MAX_KEY_LENGTH; data_idx++) {

signal or[DATA_BYTES];
for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) {
// Grab the stack at the indicated height (from `step_in`)
stackSelector[data_idx] = ArraySelector(MAX_STACK_HEIGHT, 2);
stackSelector[data_idx] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
stackSelector[data_idx].in <== stack[data_idx];
stackSelector[data_idx].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1];

log("step_in[", data_idx, "] =", step_in[data_idx]);
log("stackSelector[", data_idx, "].out[0] = ", stackSelector[data_idx].out[0]);
log("stackSelector[", data_idx, "].out[1] = ", stackSelector[data_idx].out[1]);
nextStackSelector[data_idx] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
nextStackSelector[data_idx].in <== stack[data_idx];
nextStackSelector[data_idx].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1] + 1;

// log("stackSelector[", data_idx, "].out[0] = ", stackSelector[data_idx].out[0]);
// log("stackSelector[", data_idx, "].out[1] = ", stackSelector[data_idx].out[1]);

// Detect if we are parsing a key or a value
parsing_key[data_idx] <== InsideKey()(stackSelector[data_idx].out, parsingData[data_idx][0], parsingData[data_idx][1]);
parsing_value[data_idx] <== InsideValueObject()(stackSelector[data_idx].out, stack[data_idx][1], parsingData[data_idx][0], parsingData[data_idx][1]);
parsing_value[data_idx] <== InsideValueObject()(stackSelector[data_idx].out, nextStackSelector[data_idx].out, parsingData[data_idx][0], parsingData[data_idx][1]);

log("parsing_key[", data_idx, "] = ", parsing_key[data_idx]);
log("parsing_value[", data_idx, "] = ", parsing_value[data_idx]);
// log("parsing_key[", data_idx, "] = ", parsing_key[data_idx]);
// log("parsing_value[", data_idx, "] = ", parsing_value[data_idx]);

// to get correct value, check:
// - key matches at current index and depth of key is as specified
// - whether next KV pair starts
// - whether key matched for a value (propagate key match until new KV pair of lower depth starts)
is_key_match[data_idx] <== KeyMatchAtIndex(paddedDataLen, MAX_KEY_LENGTH, data_idx)(data, key, keyLen, parsing_key[data_idx]);
is_next_pair_at_depth[data_idx] <== NextKVPairAtDepth(MAX_STACK_HEIGHT)(stack[data_idx], data[data_idx], step_in[TOTAL_BYTES_ACROSS_NIVC - 1]);
is_next_pair_at_depth[data_idx] <== NextKVPairAtDepth(MAX_STACK_HEIGHT + 1)(stack[data_idx], data[data_idx], step_in[TOTAL_BYTES_ACROSS_NIVC - 1]);

log("is_key_match[", data_idx, "] = ", is_key_match[data_idx]);
log("is_next_pair_at_depth[", data_idx, "] = ", is_next_pair_at_depth[data_idx]);
// log("is_key_match[", data_idx, "] = ", is_key_match[data_idx]);
// log("is_next_pair_at_depth[", data_idx, "] = ", is_next_pair_at_depth[data_idx]);

is_key_match_for_value[data_idx+1] <== Mux1()([is_key_match_for_value[data_idx] * (1-is_next_pair_at_depth[data_idx]), is_key_match[data_idx] * (1-is_next_pair_at_depth[data_idx])], is_key_match[data_idx]);
is_value_match[data_idx] <== is_key_match_for_value[data_idx+1] * parsing_value[data_idx];

// Set the next NIVC step to only have the masked data
log("is_value_match", is_value_match[data_idx]);
step_out[data_idx] <== data[data_idx] * is_value_match[data_idx];
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}
for (var i = 0 ; i < MAX_KEY_LENGTH ; i++) {
step_out[DATA_BYTES - MAX_KEY_LENGTH + i] <== 0;
or[data_idx] <== OR()(is_value_match[data_idx], is_value_match[data_idx -1]);
step_out[data_idx] <== data[data_idx] * or[data_idx];
// log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}
// Append the parser state back on `step_out`
for (var i = DATA_BYTES ; i < TOTAL_BYTES_ACROSS_NIVC - 1 ; i++) {
Expand Down Expand Up @@ -143,12 +157,20 @@ template JsonMaskArrayIndexNIVC(DATA_BYTES, MAX_STACK_HEIGHT) {
}

// Decode the encoded data in `step_in` back into parser variables
signal stack[DATA_BYTES][MAX_STACK_HEIGHT][2];
signal stack[DATA_BYTES][MAX_STACK_HEIGHT + 1][2];
signal parsingData[DATA_BYTES][2];
for (var i = 0 ; i < DATA_BYTES ; i++) {
for (var j = 0 ; j < MAX_STACK_HEIGHT ; j++) {
stack[i][j][0] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2];
stack[i][j][1] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2 + 1];
for (var j = 0 ; j < MAX_STACK_HEIGHT + 1 ; j++) {
if (j < MAX_STACK_HEIGHT) {
stack[i][j][0] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2];
stack[i][j][1] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + j * 2 + 1];
} else {
// TODO: Now need to do "if curr_depth == MAX_STACK_HEIGHT", set nextStackSelector <== [0,0]
// Add one extra stack element without doing this while parsing.
stack[i][j][0] <== 0;
stack[i][j][1] <== 0;
}

}
parsingData[i][0] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + MAX_STACK_HEIGHT * 2];
parsingData[i][1] <== step_in[DATA_BYTES + i * PER_ITERATION_DATA_LENGTH + MAX_STACK_HEIGHT * 2 + 1];
Expand All @@ -167,31 +189,31 @@ template JsonMaskArrayIndexNIVC(DATA_BYTES, MAX_STACK_HEIGHT) {
signal or[DATA_BYTES]; // Maybe don't need

component stackSelector[DATA_BYTES];
stackSelector[0] = ArraySelector(MAX_STACK_HEIGHT, 2);
stackSelector[0] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
stackSelector[0].in <== stack[0];
stackSelector[0].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1];

component nextStackSelector[DATA_BYTES];
nextStackSelector[0] = ArraySelector(MAX_STACK_HEIGHT, 2);
nextStackSelector[0] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
nextStackSelector[0].in <== stack[0];
nextStackSelector[0].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1] + 1;

parsing_array[0] <== InsideArrayIndexObject()(stackSelector[0].out, nextStackSelector[0].out, parsingData[0][0], parsingData[0][1], index);
mask[0] <== data[0] * parsing_array[0];

for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) {
stackSelector[data_idx] = ArraySelector(MAX_STACK_HEIGHT, 2);
stackSelector[data_idx] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
stackSelector[data_idx].in <== stack[data_idx];
stackSelector[data_idx].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1];

nextStackSelector[data_idx] = ArraySelector(MAX_STACK_HEIGHT, 2);
nextStackSelector[data_idx] = ArraySelector(MAX_STACK_HEIGHT + 1, 2);
nextStackSelector[data_idx].in <== stack[data_idx];
nextStackSelector[data_idx].index <== step_in[TOTAL_BYTES_ACROSS_NIVC - 1] + 1;

log("stackSelector[", data_idx, "].out[0] = ", stackSelector[data_idx].out[0]);
log("stackSelector[", data_idx, "].out[1] = ", stackSelector[data_idx].out[1]);
log("nextStackSelector[", data_idx, "].out[0] = ", nextStackSelector[data_idx].out[0]);
log("nextStackSelector[", data_idx, "].out[1] = ", nextStackSelector[data_idx].out[1]);
// log("stackSelector[", data_idx, "].out[0] = ", stackSelector[data_idx].out[0]);
// log("stackSelector[", data_idx, "].out[1] = ", stackSelector[data_idx].out[1]);
// log("nextStackSelector[", data_idx, "].out[0] = ", nextStackSelector[data_idx].out[0]);
// log("nextStackSelector[", data_idx, "].out[1] = ", nextStackSelector[data_idx].out[1]);

parsing_array[data_idx] <== InsideArrayIndexObject()(stackSelector[data_idx].out, nextStackSelector[data_idx].out, parsingData[data_idx][0], parsingData[data_idx][1], index);

Expand Down
Loading

0 comments on commit acd2f2c

Please sign in to comment.