From b9feeeb240ddf867da85198de7d59e73cba4b008 Mon Sep 17 00:00:00 2001 From: lonerapier Date: Thu, 10 Oct 2024 17:43:02 +0530 Subject: [PATCH] add optimisations from web-prover --- circuits/json/extractor.circom | 38 +++--- circuits/json/interpreter.circom | 111 ++++++++++++------ .../test/json/extractor/extractor.test.ts | 7 +- .../test/json/extractor/interpreter.test.ts | 81 +++++++++---- circuits/utils/array.circom | 23 ++++ src/codegen/json.rs | 4 +- 6 files changed, 177 insertions(+), 87 deletions(-) diff --git a/circuits/json/extractor.circom b/circuits/json/extractor.circom index 3a37942..0b75ab1 100644 --- a/circuits/json/extractor.circom +++ b/circuits/json/extractor.circom @@ -13,13 +13,13 @@ template ObjectExtractor(DATA_BYTES, MAX_STACK_HEIGHT, maxKeyLen, maxValueLen) { signal output value[maxValueLen]; // Constraints. - signal value_starting_index[DATA_BYTES]; + signal value_starting_index[DATA_BYTES - maxKeyLen]; // flag determining whether this byte is matched value - signal is_value_match[DATA_BYTES]; + signal is_value_match[DATA_BYTES - maxKeyLen]; // final mask - signal mask[DATA_BYTES]; + signal mask[DATA_BYTES - maxKeyLen]; - component State[DATA_BYTES]; + component State[DATA_BYTES - maxKeyLen]; State[0] = StateUpdate(MAX_STACK_HEIGHT); State[0].byte <== data[0]; for(var i = 0; i < MAX_STACK_HEIGHT; i++) { @@ -28,29 +28,29 @@ template ObjectExtractor(DATA_BYTES, MAX_STACK_HEIGHT, maxKeyLen, maxValueLen) { State[0].parsing_string <== 0; State[0].parsing_number <== 0; - signal parsing_key[DATA_BYTES]; - signal parsing_value[DATA_BYTES]; - signal parsing_object_value[DATA_BYTES]; - signal is_key_match[DATA_BYTES]; - signal is_key_match_for_value[DATA_BYTES+1]; + signal parsing_key[DATA_BYTES - maxKeyLen]; + signal parsing_value[DATA_BYTES - maxKeyLen]; + signal parsing_object_value[DATA_BYTES - maxKeyLen]; + signal is_key_match[DATA_BYTES - maxKeyLen]; + signal is_key_match_for_value[DATA_BYTES+1 - maxKeyLen]; is_key_match_for_value[0] <== 0; - 
signal is_next_pair_at_depth[DATA_BYTES]; - signal or[DATA_BYTES]; + signal is_next_pair_at_depth[DATA_BYTES - maxKeyLen]; + signal or[DATA_BYTES - maxKeyLen]; // initialise first iteration // check inside key or value - parsing_key[0] <== InsideKey(MAX_STACK_HEIGHT)(State[0].next_stack, State[0].next_parsing_string, State[0].next_parsing_number); + parsing_key[0] <== InsideKey()(State[0].next_stack[0], State[0].next_parsing_string, State[0].next_parsing_number); parsing_value[0] <== InsideValueObject()(State[0].next_stack[0], State[0].next_stack[1], State[0].next_parsing_string, State[0].next_parsing_number); - is_key_match[0] <== KeyMatchAtDepthWithIndex(DATA_BYTES, MAX_STACK_HEIGHT, maxKeyLen, 0)(data, key, keyLen, 0, parsing_key[0], State[0].next_stack); + is_key_match[0] <== 0; is_next_pair_at_depth[0] <== NextKVPairAtDepth(MAX_STACK_HEIGHT, 0)(State[0].next_stack, data[0]); is_key_match_for_value[1] <== Mux1()([is_key_match_for_value[0] * (1-is_next_pair_at_depth[0]), is_key_match[0] * (1-is_next_pair_at_depth[0])], is_key_match[0]); is_value_match[0] <== parsing_value[0] * is_key_match_for_value[1]; mask[0] <== data[0] * is_value_match[0]; - for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) { + for(var data_idx = 1; data_idx < DATA_BYTES - maxKeyLen; data_idx++) { State[data_idx] = StateUpdate(MAX_STACK_HEIGHT); State[data_idx].byte <== data[data_idx]; State[data_idx].stack <== State[data_idx - 1].next_stack; @@ -66,7 +66,7 @@ template ObjectExtractor(DATA_BYTES, MAX_STACK_HEIGHT, maxKeyLen, maxValueLen) { // - mask // check if inside key or not - parsing_key[data_idx] <== InsideKey(MAX_STACK_HEIGHT)(State[data_idx].next_stack, State[data_idx].next_parsing_string, State[data_idx].next_parsing_number); + parsing_key[data_idx] <== InsideKey()(State[data_idx].next_stack[0], State[data_idx].next_parsing_string, State[data_idx].next_parsing_number); // check if inside value parsing_value[data_idx] <== InsideValueObject()(State[data_idx].next_stack[0], 
State[data_idx].next_stack[1], State[data_idx].next_parsing_string, State[data_idx].next_parsing_number); @@ -74,7 +74,7 @@ template ObjectExtractor(DATA_BYTES, MAX_STACK_HEIGHT, maxKeyLen, maxValueLen) { // - key matches at current index and depth of key is as specified // - whether next KV pair starts // - whether key matched for a value (propogate key match until new KV pair of lower depth starts) - is_key_match[data_idx] <== KeyMatchAtDepthWithIndex(DATA_BYTES, MAX_STACK_HEIGHT, maxKeyLen, 0)(data, key, keyLen, data_idx, parsing_key[data_idx], State[data_idx].next_stack); + is_key_match[data_idx] <== KeyMatchAtIndex(DATA_BYTES, maxKeyLen, data_idx)(data, key, keyLen, parsing_key[data_idx]); is_next_pair_at_depth[data_idx] <== NextKVPairAtDepth(MAX_STACK_HEIGHT, 0)(State[data_idx].next_stack, data[data_idx]); is_key_match_for_value[data_idx+1] <== Mux1()([is_key_match_for_value[data_idx] * (1-is_next_pair_at_depth[data_idx]), is_key_match[data_idx] * (1-is_next_pair_at_depth[data_idx])], is_key_match[data_idx]); is_value_match[data_idx] <== is_key_match_for_value[data_idx+1] * parsing_value[data_idx]; @@ -91,14 +91,14 @@ template ObjectExtractor(DATA_BYTES, MAX_STACK_HEIGHT, maxKeyLen, maxValueLen) { value_starting_index[0] <== 0; is_prev_starting_index[0] <== 0; is_zero_mask[0] <== IsZero()(mask[0]); - for (var i=1 ; i 34 - // end of key equals `"` - signal end_of_key <== IndexSelector(dataLen)(data, index + keyLen); - signal is_end_of_key_equal_to_quote <== IsEqual()([end_of_key, 34]); - - // start of key equals `"` - signal start_of_key <== IndexSelector(dataLen)(data, index - 1); - signal is_start_of_key_equal_to_quote <== IsEqual()([start_of_key, 34]); + // start of key equal to quote + signal startOfKeyEqualToQuote <== IsEqual()([data[index - 1], 34]); + signal isParsingCorrectKey <== parsing_key * startOfKeyEqualToQuote; // key matches - signal substring_match <== SubstringMatchWithIndexx(dataLen, maxKeyLen)(data, key, keyLen, index); - - // key should be 
a string - signal is_key_between_quotes <== is_start_of_key_equal_to_quote * is_end_of_key_equal_to_quote; - - // is the index given correct? - signal is_parsing_correct_key <== is_key_between_quotes * parsing_key; - // is the key given by index at correct depth? - signal is_key_at_depth <== IsEqual()([pointer-1, depth]); - - signal is_parsing_correct_key_at_depth <== is_parsing_correct_key * is_key_at_depth; - - signal output out <== substring_match * is_parsing_correct_key_at_depth; + component isSubstringMatch = MatchPaddedKey(maxKeyLen); + isSubstringMatch.in[0] <== key; + isSubstringMatch.keyLen <== keyLen; + for(var matcher_idx = 0; matcher_idx < maxKeyLen; matcher_idx++) { + isSubstringMatch.in[1][matcher_idx] <== data[index + matcher_idx]; + } + + signal output out <== isSubstringMatch.out * isParsingCorrectKey; } \ No newline at end of file diff --git a/circuits/test/json/extractor/extractor.test.ts b/circuits/test/json/extractor/extractor.test.ts index 74576fa..30738e4 100644 --- a/circuits/test/json/extractor/extractor.test.ts +++ b/circuits/test/json/extractor/extractor.test.ts @@ -294,7 +294,8 @@ describe("object-extractor", async () => { let circuit: WitnessTester<["data", "key", "keyLen"], ["value"]>; let jsonFilename = "value_object"; let jsonFile: number[] = []; - let maxKeyLen = 10; + let maxDataLen = 200; + let maxKeyLen = 3; let maxValueLen = 30; before(async () => { @@ -304,12 +305,12 @@ describe("object-extractor", async () => { "a" ] ); - jsonFile = inputJson; + jsonFile = inputJson.concat(Array(maxDataLen - inputJson.length).fill(0)); circuit = await circomkit.WitnessTester(`Extract`, { file: `json/extractor`, template: "ObjectExtractor", - params: [inputJson.length, 3, maxKeyLen, maxValueLen], + params: [maxDataLen, 3, maxKeyLen, maxValueLen], }); console.log("#constraints:", await circuit.getConstraintCount()); }); diff --git a/circuits/test/json/extractor/interpreter.test.ts b/circuits/test/json/extractor/interpreter.test.ts index 
44318c8..219beac 100644 --- a/circuits/test/json/extractor/interpreter.test.ts +++ b/circuits/test/json/extractor/interpreter.test.ts @@ -2,13 +2,13 @@ import { circomkit, WitnessTester, generateDescription, readJSONInputFile } from import { PoseidonModular } from "../../common/poseidon"; describe("Interpreter", async () => { - describe("InsideKey", async () => { + describe("InsideKeyAtTop", async () => { let circuit: WitnessTester<["stack", "parsing_string", "parsing_number"], ["out"]>; before(async () => { - circuit = await circomkit.WitnessTester(`InsideKey`, { + circuit = await circomkit.WitnessTester(`InsideKeyAtTop`, { file: "json/interpreter", - template: "InsideKey", + template: "InsideKeyAtTop", params: [4], }); console.log("#constraints:", await circuit.getConstraintCount()); @@ -41,6 +41,38 @@ describe("Interpreter", async () => { generatePassCase(input5, { out: 0 }, "parsing number as a key"); }); + describe("InsideKey", async () => { + let circuit: WitnessTester<["stack", "parsing_string", "parsing_number"], ["out"]>; + + before(async () => { + circuit = await circomkit.WitnessTester(`InsideKey`, { + file: "json/interpreter", + template: "InsideKey", + }); + console.log("#constraints:", await circuit.getConstraintCount()); + }); + + function generatePassCase(input: any, expected: any, desc: string) { + const description = generateDescription(input); + + it(`(valid) witness: ${description} ${desc}`, async () => { + await circuit.expectPass(input, expected); + }); + } + + let input1 = { stack: [1, 0], parsing_string: 1, parsing_number: 0 }; + let output = { out: 1 }; + generatePassCase(input1, output, ""); + + // fail cases + + let input2 = { stack: [1, 1], parsing_string: 1, parsing_number: 0 }; + generatePassCase(input2, { out: 0 }, "invalid stack"); + + let input3 = { stack: [1, 0], parsing_string: 1, parsing_number: 1 }; + generatePassCase(input3, { out: 0 }, "parsing number as a key"); + }); + describe("InsideValueAtTop", async () => { let circuit: 
WitnessTester<["stack", "parsing_string", "parsing_number"], ["out"]>; @@ -355,11 +387,11 @@ describe("Interpreter", async () => { generatePassCase(input6, { out: 0 }, 1, "wrong depth"); }); - describe("KeyMatchAtDepthWithIndex", async () => { - let circuit: WitnessTester<["data", "key", "keyLen", "index", "parsing_key", "stack"], ["out"]>; - let maxKeyLen = 10; + describe("KeyMatchAtIndex", async () => { + let circuit: WitnessTester<["data", "key", "keyLen", "parsing_key"], ["out"]>; + let maxKeyLen = 3; - function generatePassCase(input: any, expected: any, depth: number, desc: string) { + function generatePassCase(input: any, expected: any, index: number, desc: string) { const description = generateDescription(input); it(`(valid) witness: ${description} ${desc}`, async () => { @@ -367,10 +399,10 @@ describe("Interpreter", async () => { let padded_key = input.key.concat(Array(maxKeyLen - input.key.length).fill(0)); input.key = padded_key; - circuit = await circomkit.WitnessTester(`KeyMatchAtDepthWithIndex`, { + circuit = await circomkit.WitnessTester(`KeyMatchAtIndex`, { file: "json/interpreter", - template: "KeyMatchAtDepthWithIndex", - params: [input.data.length, 4, maxKeyLen, depth], + template: "KeyMatchAtIndex", + params: [input.data.length, maxKeyLen, index], }); console.log("#constraints:", await circuit.getConstraintCount()); @@ -383,31 +415,32 @@ describe("Interpreter", async () => { let output = { out: 1 }; let key1 = input[1][0]; - let input1 = { data: input[0], key: key1, keyLen: key1.length, index: 2, parsing_key: 1, stack: [[1, 0], [0, 0], [0, 0], [0, 0]] }; - generatePassCase(input1, output, 0, ""); + let input1 = { data: input[0], key: key1, keyLen: key1.length, parsing_key: 1 }; + generatePassCase(input1, output, 2, ""); let key2 = input[1][2]; - let input2 = { data: input[0], key: key2, keyLen: key2.length, index: 8, parsing_key: 1, stack: [[1, 1], [2, 0], [1, 0], [0, 0]] }; - generatePassCase(input2, output, 2, ""); + let input2 = { data: 
input[0], key: key2, keyLen: key2.length, parsing_key: 1 }; + generatePassCase(input2, output, 8, ""); - let input3 = { data: input[0], key: [99], keyLen: 1, index: 20, parsing_key: 1, stack: [[1, 1], [2, 1], [1, 1], [0, 0]] }; - generatePassCase(input3, output, 2, "wrong stack"); + let input3 = { data: input[0], key: [99], keyLen: 1, parsing_key: 1 }; + generatePassCase(input3, output, 20, "wrong stack"); // fail cases + let failOutput = { out: 0 }; let key4 = input[1][1]; - let input4 = { data: input[0], key: key4, keyLen: key4.length, index: 3, parsing_key: 1, stack: [[1, 0], [2, 0], [1, 0], [0, 0]] }; - generatePassCase(input4, { out: 0 }, 2, "wrong key"); + let input4 = { data: input[0], key: key4, keyLen: key4.length, parsing_key: 1 }; + generatePassCase(input4, failOutput, 3, "wrong key"); - let input5 = { data: input[0], key: [97], keyLen: 1, index: 12, parsing_key: 0, stack: [[1, 1], [2, 0], [1, 1], [0, 0]] }; - generatePassCase(input5, { out: 0 }, 3, "not parsing key"); + let input5 = { data: input[0], key: [97], keyLen: 1, parsing_key: 0 }; + generatePassCase(input5, failOutput, 12, "not parsing key"); let input6Data = input[0].slice(0); input6Data.splice(1, 1, 35); - let input6 = { data: input6Data, key: input[1][0], keyLen: input[1][0].length, index: 2, parsing_key: 1, stack: [[1, 0], [0, 0], [0, 0], [0, 0]] }; - generatePassCase(input6, { out: 0 }, 0, "invalid key (not surrounded by quotes)"); + let input6 = { data: input6Data, key: input[1][0], keyLen: input[1][0].length, parsing_key: 1 }; + generatePassCase(input6, failOutput, 2, "invalid key (not surrounded by quotes)"); - let input7 = { data: input[0], key: input[1][0], keyLen: input[1][0].length, index: 2, parsing_key: 1, stack: [[1, 0], [0, 0], [0, 0], [0, 0]] }; - generatePassCase(input6, { out: 0 }, 1, "wrong depth"); + let input7 = { data: input[0], key: input[1][0], keyLen: input[1][0].length, parsing_key: 1 }; + generatePassCase(input7, output, 2, "same input as first case (depth is no longer checked)"); }); }); \ No newline at 
end of file diff --git a/circuits/utils/array.circom b/circuits/utils/array.circom index ff45809..d19b4d6 100644 --- a/circuits/utils/array.circom +++ b/circuits/utils/array.circom @@ -36,6 +36,29 @@ template IsEqualArray(n) { out <== totalEqual.out; } +template IsEqualArrayPaddedLHS(n) { + signal input in[2][n]; + signal output out; + + var accum = 0; + component equalComponent[n]; + component isPaddedElement[n]; + + for(var i = 0; i < n; i++) { + isPaddedElement[i] = IsZero(); + isPaddedElement[i].in <== in[0][i]; + equalComponent[i] = IsEqual(); + equalComponent[i].in[0] <== in[0][i]; + equalComponent[i].in[1] <== in[1][i] * (1-isPaddedElement[i].out); + accum += equalComponent[i].out; + } + + component totalEqual = IsEqual(); + totalEqual.in[0] <== n; + totalEqual.in[1] <== accum; + out <== totalEqual.out; +} + // TODO: There should be a way to have the below assertion come from the field itself. /* This template is an indicator for if an array contains an element. diff --git a/src/codegen/json.rs b/src/codegen/json.rs index a90d545..4fc726f 100644 --- a/src/codegen/json.rs +++ b/src/codegen/json.rs @@ -383,7 +383,7 @@ fn build_json_circuit( // parsing_key and parsing_object{i}_value circuit_buffer += r#" // initialise first iteration - parsing_key[0] <== InsideKey(MAX_STACK_HEIGHT)(State[0].next_stack, State[0].next_parsing_string, State[0].next_parsing_number); + parsing_key[0] <== InsideKeyAtTop(MAX_STACK_HEIGHT)(State[0].next_stack, State[0].next_parsing_string, State[0].next_parsing_number); "#; @@ -487,7 +487,7 @@ fn build_json_circuit( // - mask // check if inside key or not - parsing_key[data_idx] <== InsideKey(MAX_STACK_HEIGHT)(State[data_idx].next_stack, State[data_idx].next_parsing_string, State[data_idx].next_parsing_number); + parsing_key[data_idx] <== InsideKeyAtTop(MAX_STACK_HEIGHT)(State[data_idx].next_stack, State[data_idx].next_parsing_string, State[data_idx].next_parsing_number); "#;