diff --git a/circuits/fetcher.circom b/circuits/fetcher.circom index a3947a2..03482ce 100644 --- a/circuits/fetcher.circom +++ b/circuits/fetcher.circom @@ -3,7 +3,8 @@ pragma circom 2.1.9; include "extract.circom"; include "parser.circom"; include "language.circom"; -include "utils.circom"; +include "search.circom"; +include "./utils/array.circom"; include "circomlib/circuits/mux1.circom"; include "circomlib/circuits/gates.circom"; include "@zk-email/circuits/utils/functions.circom"; @@ -152,7 +153,7 @@ template KeyMatch(dataLen, keyLen) { signal start_of_key <== IndexSelector(dataLen)(data, index - 1); signal is_start_of_key_equal_to_quote <== IsEqual()([start_of_key, 34]); - signal substring_match <== IsSubstringMatchWithIndex(dataLen, keyLen)(data, key, 100, index); + signal substring_match <== SubstringMatchWithIndex(dataLen, keyLen)(data, key, 100, index); signal is_key_between_quotes <== is_start_of_key_equal_to_quote * is_end_of_key_equal_to_quote; signal is_parsing_correct_key <== is_key_between_quotes * parsing_key; @@ -179,7 +180,7 @@ template KeyMatchAtDepth(dataLen, n, keyLen, depth) { signal start_of_key <== IndexSelector(dataLen)(data, index - 1); signal is_start_of_key_equal_to_quote <== IsEqual()([start_of_key, 34]); - signal substring_match <== IsSubstringMatchWithIndex(dataLen, keyLen)(data, key, 100, index); + signal substring_match <== SubstringMatchWithIndex(dataLen, keyLen)(data, key, 100, index); signal is_key_between_quotes <== is_start_of_key_equal_to_quote * is_end_of_key_equal_to_quote; log("key pointer", pointer, depth); diff --git a/circuits/parser.circom b/circuits/parser.circom index bd157e2..0f63120 100644 --- a/circuits/parser.circom +++ b/circuits/parser.circom @@ -178,8 +178,6 @@ template StateToMask(n) { signal input parsing_number; signal output out[3]; - // `parsing_string` can change: - out[1] <== 1 - 2 * parsing_string; // `read_write_value`can change: IF NOT `parsing_string` out[0] <== (1 - parsing_string); diff --git a/circuits/search.circom b/circuits/search.circom index cfd8c12..1799c01 100644 --- a/circuits/search.circom +++ b/circuits/search.circom @@ -113,6 +113,8 @@ template SubstringMatchWithIndex(dataLen, keyLen) { signal input r; signal input start; + signal output out; + // key end index in `data` signal end; end <== start + keyLen; @@ -198,7 +200,7 @@ template SubstringMatchWithIndex(dataLen, keyLen) { } // final sum for data and key should be equal - hashMaskedData[dataLen - 1] === hashMaskedKey[keyLen - 1]; + out <== IsZero()(hashMaskedData[dataLen-1]-hashMaskedKey[keyLen-1]); } /* @@ -236,7 +238,8 @@ template SubstringMatch(dataLen, keyLen) { // matches a `key` in `data` at `pos` // NOTE: constrained verification assures correctness - SubstringMatchWithIndex(dataLen, keyLen)(data, key, r, start); + signal isMatch <== SubstringMatchWithIndex(dataLen, keyLen)(data, key, r, start); + isMatch === 1; position <== start; } \ No newline at end of file diff --git a/circuits/test/search.test.ts b/circuits/test/search.test.ts index 59692db..39e8143 100644 --- a/circuits/test/search.test.ts +++ b/circuits/test/search.test.ts @@ -81,7 +81,7 @@ describe("search", () => { }); describe("SubstringMatchWithIndex", () => { - let circuit: WitnessTester<["data", "key", "r", "start"]>; + let circuit: WitnessTester<["data", "key", "r", "start"], ["out"]>; before(async () => { circuit = await circomkit.WitnessTester(`SubstringSearch`, { @@ -100,17 +100,19 @@ describe("search", () => { r: PoseidonModular(witness["key"].concat(witness["data"])), start: 6 }, + { out: 1 }, ); }); - it("data = witness.json:data, key = witness.json:key, r = hash(key+data), incorrect position", async () => { - await circuit.expectFail( + it("data = witness.json:data, key = witness.json:key, r = hash(key+data), output false", async () => { + await circuit.expectPass( { data: witness["data"], key: witness["key"], r: PoseidonModular(witness["key"].concat(witness["data"])), start: 98 }, + { out: 0 } ); }); }); diff --git a/circuits/test/utils/utils.test.ts b/circuits/test/utils/utils.test.ts deleted file mode 100644 index c92fd42..0000000 --- a/circuits/test/utils/utils.test.ts +++ /dev/null @@ -1,291 +0,0 @@ -import { circomkit, WitnessTester } from "../common"; - -describe("ASCII", () => { - let circuit: WitnessTester<["in"], ["out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`ASCII`, { - file: "circuits/utils", - template: "ASCII", - params: [13], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("(valid) witness: in = b\"Hello, world!\"", async () => { - await circuit.expectPass( - { in: [72, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33] }, - ); - }); - - it("(invalid) witness: in = [256, ...]", async () => { - await circuit.expectFail( - { in: [256, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100, 33] } - ); - }); -}); - -describe("IsEqualArray", () => { - let circuit: WitnessTester<["in"], ["out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`IsEqualArray`, { - file: "circuits/utils", - template: "IsEqualArray", - params: [3], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("witness: [[0,0,0],[0,0,0]]", async () => { - await circuit.expectPass( - { in: [[0, 0, 0], [0, 0, 0]] }, - { out: 1 } - ); - }); - - it("witness: [[1,420,69],[1,420,69]]", async () => { - await circuit.expectPass( - { in: [[1, 420, 69], [1, 420, 69]] }, - { out: 1 }, - ); - }); - - it("witness: [[0,0,0],[1,420,69]]", async () => { - await circuit.expectPass( - { in: [[0, 0, 0], [1, 420, 69]] }, - { out: 0 }, - ); - }); - - it("witness: [[1,420,0],[1,420,69]]", async () => { - await circuit.expectPass( - { in: [[1, 420, 0], [1, 420, 69]] }, - { out: 0 }, - ); - }); - - it("witness: [[1,0,69],[1,420,69]]", async () => { - await circuit.expectPass( - { in: [[1, 0, 69], [1, 420, 69]] }, - { out: 0 }, - ); - }); - - it("witness: [[0,420,69],[1,420,69]]", async () => { - await circuit.expectPass( - { in: [[0, 420, 69], [1, 420, 69]] }, - { out: 0 }, - ); - }); -}); - -describe("Contains", () => { - let circuit: WitnessTester<["in", "array"], ["out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`Contains`, { - file: "circuits/utils", - template: "Contains", - params: [3], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("witness: in = 0, array = [0,1,2]", async () => { - await circuit.expectPass( - { in: 0, array: [0, 1, 2] }, - { out: 1 } - ); - }); - - it("witness: in = 1, array = [0,1,2]", async () => { - await circuit.expectPass( - { in: 1, array: [0, 1, 2] }, - { out: 1 } - ); - }); - - it("witness: in = 2, array = [0,1,2]", async () => { - await circuit.expectPass( - { in: 2, array: [0, 1, 2] }, - { out: 1 } - ); - }); - - it("witness: in = 42069, array = [0,1,2]", async () => { - await circuit.expectPass( - { in: 42069, array: [0, 1, 2] }, - { out: 0 } - ); - }); - -}); - -describe("ArrayAdd", () => { - let circuit: WitnessTester<["lhs", "rhs"], ["out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`ArrayAdd`, { - file: "circuits/utils", - template: "ArrayAdd", - params: [3], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("witness: lhs = [0,1,2], rhs = [3,5,7]", async () => { - await circuit.expectPass( - { lhs: [0, 1, 2], rhs: [3, 5, 7] }, - { out: [3, 6, 9] } - ); - }); - -}); - -describe("ArrayMul", () => { - let circuit: WitnessTester<["lhs", "rhs"], ["out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`ArrayMul`, { - file: "circuits/utils", - template: "ArrayMul", - params: [3], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("witness: lhs = [0,1,2], rhs = [3,5,7]", async () => { - await circuit.expectPass( - { lhs: [0, 1, 2], rhs: [3, 5, 7] }, - { out: [0, 5, 14] } - ); - }); - -}); - -describe("InRange", () => { - let circuit: WitnessTester<["in", "range"], ["out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`InRange`, { - file: "circuits/utils", - template: "InRange", - params: [8], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("witness: in = 1, range = [0,2]", async () => { - await circuit.expectPass( - { in: 1, range: [0, 2] }, - { out: 1 } - ); - }); - - it("witness: in = 69, range = [128,255]", async () => { - await circuit.expectPass( - { in: 69, range: [128, 255] }, - { out: 0 } - ); - }); - - it("witness: in = 200, range = [128,255]", async () => { - await circuit.expectPass( - { in: 1, range: [0, 2] }, - { out: 1 } - ); - }); -}); - -describe("Switch", () => { - let circuit: WitnessTester<["case", "branches", "vals"], ["match", "out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`Switch`, { - file: "circuits/utils", - template: "Switch", - params: [3], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("witness: case = 0, branches = [0, 1, 2], vals = [69, 420, 1337]", async () => { - await circuit.expectPass( - { case: 0, branches: [0, 1, 2], vals: [69, 420, 1337] }, - { match: 1, out: 69 }, - ); - }); - - it("witness: case = 1, branches = [0, 1, 2], vals = [69, 420, 1337]", async () => { - await circuit.expectPass( - { case: 1, branches: [0, 1, 2], vals: [69, 420, 1337] }, - { match: 1, out: 420 }, - ); - }); - - it("witness: case = 2, branches = [0, 1, 2], vals = [69, 420, 1337]", async () => { - await circuit.expectPass( - { case: 2, branches: [0, 1, 2], vals: [69, 420, 1337] }, - { match: 1, out: 1337 }, - ); - }); - - it("witness: case = 3, branches = [0, 1, 2], vals = [69, 420, 1337]", async () => { - await circuit.expectPass( - { case: 3, branches: [0, 1, 2], vals: [69, 420, 1337] }, - { match: 0, out: 0 }, - ); - }); - - -}); - -describe("SwitchArray", () => { - let circuit: WitnessTester<["case", "branches", "vals"], ["match", "out"]>; - before(async () => { - circuit = await circomkit.WitnessTester(`SwitchArray`, { - file: "circuits/utils", - template: "SwitchArray", - params: [3, 2], - }); - console.log("#constraints:", await circuit.getConstraintCount()); - }); - - it("witness: case = 0, branches = [0, 1, 2], vals = [[69,0], [420,1], [1337,2]]", async () => { - await circuit.expectPass( - { case: 0, branches: [0, 1, 2], vals: [[69, 0], [420, 1], [1337, 2]] }, - { match: 1, out: [69, 0] }, - ); - }); - - it("witness: case = 1, branches = [0, 1, 2], vals = [[69,0], [420,1], [1337,2]]", async () => { - await circuit.expectPass( - { case: 1, branches: [0, 1, 2], vals: [[69, 0], [420, 1], [1337, 2]] }, - { match: 1, out: [420, 1] }, - ); - }); - - it("witness: case = 2, branches = [0, 1, 2], vals = [[69,0], [420,1], [1337,2]]", async () => { - await circuit.expectPass( - { case: 2, branches: [0, 1, 2], vals: [[69, 0], [420, 1], [1337, 2]] }, - { match: 1, out: [1337, 2] }, - ); - }); - - it("witness: case = 3, branches = [0, 1, 2], vals = [[69,0], [420,1], [1337,2]]", async () => { - await circuit.expectPass( - { case: 3, branches: [0, 1, 2], vals: [[69, 0], [420, 1], [1337, 2]] }, - { match: 0, out: [0, 0] } - ); - }); - - it("witness: case = 420, branches = [69, 420, 1337], vals = [[10,3], [20,5], [30,7]]", async () => { - await circuit.expectPass( - { case: 420, branches: [69, 420, 1337], vals: [[10, 3], [20, 5], [30, 7]] }, - { match: 1, out: [20, 5] } - ); - }); - - it("witness: case = 0, branches = [69, 420, 1337], vals = [[10,3], [20,5], [30,7]]", async () => { - await circuit.expectPass( - { case: 0, branches: [69, 420, 1337], vals: [[10, 3], [20, 5], [30, 7]] }, - { match: 0, out: [0, 0] } - ); - }); - -}); - diff --git a/circuits/utils.circom b/circuits/utils.circom deleted file mode 100644 index 735bcc2..0000000 --- a/circuits/utils.circom +++ /dev/null @@ -1,330 +0,0 @@ -pragma circom 2.1.9; - -include "circomlib/circuits/bitify.circom"; -include "circomlib/circuits/comparators.circom"; -include "circomlib/circuits/mux1.circom"; - -/* -All tests for this file are located in: `./test/utils/utils.test.ts` -*/ - -template ASCII(n) { - signal input in[n]; - - component Byte[n]; - for(var i = 0; i < n; i++) { - Byte[i] = Num2Bits(8); - Byte[i].in <== in[i]; - } -} - -/* -This function is an indicator for two equal array inputs. - -# Inputs: -- `n`: the length of arrays to compare -- `in[2][n]`: two arrays of `n` numbers -- `out`: either `0` or `1` - - `1` if `in[0]` is equal to `in[1]` as arrays (i.e., component by component) - - `0` otherwise -*/ -template IsEqualArray(n) { - signal input in[2][n]; - signal output out; - - var accum = 0; - component equalComponent[n]; - - for(var i = 0; i < n; i++) { - equalComponent[i] = IsEqual(); - equalComponent[i].in[0] <== in[0][i]; - equalComponent[i].in[1] <== in[1][i]; - accum += equalComponent[i].out; - } - - component totalEqual = IsEqual(); - totalEqual.in[0] <== n; - totalEqual.in[1] <== accum; - out <== totalEqual.out; -} - - -// TODO: There should be a way to have the below assertion come from the field itself. -/* -This function is an indicator for if an array contains an element. - -# Inputs: -- `n`: the size of the array to search through -- `in`: a number -- `array[n]`: the array we want to search through -- `out`: either `0` or `1` - - `1` if `in` is found inside `array` - - `0` otherwise -*/ -template Contains(n) { - assert(n > 0); - /* - If `n = p` for this large `p`, then it could be that this function - returns the wrong value if every element in `array` was equal to `in`. - This is EXTREMELY unlikely and iterating this high is impossible anyway. - But it is better to check than miss something, so we bound it by `2**254` for now. - */ - assert(n < 2**254); - signal input in; - signal input array[n]; - signal output out; - - var accum = 0; - component equalComponent[n]; - for(var i = 0; i < n; i++) { - equalComponent[i] = IsEqual(); - equalComponent[i].in[0] <== in; - equalComponent[i].in[1] <== array[i]; - accum = accum + equalComponent[i].out; - } - - component someEqual = IsZero(); - someEqual.in <== accum; - - // Apply `not` to this by 1-x - out <== 1 - someEqual.out; -} - -template ArrayAdd(n) { - signal input lhs[n]; - signal input rhs[n]; - signal output out[n]; - - for(var i = 0; i < n; i++) { - out[i] <== lhs[i] + rhs[i]; - } -} - -template ArrayMul(n) { - signal input lhs[n]; - signal input rhs[n]; - signal output out[n]; - - for(var i = 0; i < n; i++) { - out[i] <== lhs[i] * rhs[i]; - } -} - -template InRange(n) { - signal input in; - signal input range[2]; - signal output out; - - component gte = GreaterEqThan(n); - gte.in <== [in, range[0]]; - - component lte = LessEqThan(n); - lte.in <== [in, range[1]]; - - out <== gte.out * lte.out; -} - -/* -This function is creates an exhaustive switch statement from `0` up to `n`. - -# Inputs: -- `m`: the number of switch cases -- `n`: the output array length -- `case`: which case of the switch to select -- `branches[m]`: the values that enable taking different branches in the switch - (e.g., if `branch[i] == 10` then if `case == 10` we set `out == `vals[i]`) -- `vals[m][n]`: the value that is emitted for a given switch case - (e.g., `val[i]` array is emitted on `case == `branch[i]`) - -# Outputs -- `match`: is set to `0` if `case` does not match on any of `branches` -- `out[n]`: the selected output value if one of `branches` is selected (will be `[0,0,...]` otherwise) -*/ -template SwitchArray(m, n) { - assert(m > 0); - assert(n > 0); - signal input case; - signal input branches[m]; - signal input vals[m][n]; - signal output match; - signal output out[n]; - - - // Verify that the `case` is in the possible set of branches - component indicator[m]; - component matchChecker = Contains(m); - signal component_out[m][n]; - var sum[n]; - for(var i = 0; i < m; i++) { - indicator[i] = IsZero(); - indicator[i].in <== case - branches[i]; - matchChecker.array[i] <== 1 - indicator[i].out; - for(var j = 0; j < n; j++) { - component_out[i][j] <== indicator[i].out * vals[i][j]; - sum[j] += component_out[i][j]; - } - } - matchChecker.in <== 0; - match <== matchChecker.out; - - out <== sum; -} - -template Switch(n) { - assert(n > 0); - signal input case; - signal input branches[n]; - signal input vals[n]; - signal output match; - signal output out; - - - // Verify that the `case` is in the possible set of branches - component indicator[n]; - component matchChecker = Contains(n); - signal temp_val[n]; - var sum; - for(var i = 0; i < n; i++) { - indicator[i] = IsZero(); - indicator[i].in <== case - branches[i]; - matchChecker.array[i] <== 1 - indicator[i].out; - temp_val[i] <== indicator[i].out * vals[i]; - sum += temp_val[i]; - } - matchChecker.in <== 0; - match <== matchChecker.out; - - out <== sum; -} - -template IsSubstringMatchWithIndex(dataLen, keyLen) { - signal input data[dataLen]; - signal input key[keyLen]; - signal input r; - signal input start; - - signal output out; - - // key end index in `data` - signal end; - end <== start + keyLen; - - // 2n constraints - // - // create start mask from [pos, dataLen-1] - // | 0 | 0 0 0 0 0 0 |1| 1 1 1 |1| 1 1 |1| - // 0 start end dataLen - signal startMask[dataLen]; - signal startMaskEq[dataLen]; - startMaskEq[0] <== IsEqual()([0, start]); - startMask[0] <== startMaskEq[0]; - for (var i = 1 ; i < dataLen ; i++) { - startMaskEq[i] <== IsEqual()([i, start]); - startMask[i] <== startMask[i-1] + startMaskEq[i]; - } - - // 3n constraints - // - // create end mask from [0, end] - // | 1 | 1 1 1 1 1 1 |1| 1 1 1 |1| 0 0 |0| - // 0 start end dataLen - signal endMask[dataLen]; - signal endMaskEq[dataLen]; - endMaskEq[0] <== IsEqual()([0, end]); - endMask[0] <== 1 - endMaskEq[0]; - for (var i = 1 ; i < dataLen ; i++) { - endMaskEq[i] <== IsEqual()([i, end]); - endMask[i] <== endMask[i-1] * (1 - endMaskEq[i]); - } - - // n constraints - // - // combine start mask and end mask - // | 0 | 0 0 0 0 0 0 |1| 1 1 1 |1| 0 0 |0| - // 0 start end dataLen - signal mask[dataLen]; - for (var i = 0; i < dataLen; i++) { - mask[i] <== startMask[i] * endMask[i]; - } - - // n constraints - // - // masked data from mask - signal maskedData[dataLen]; - for (var i = 0 ; i < dataLen ; i++) { - maskedData[i] <== data[i] * mask[i]; - } - - // n constraints - // - // powers of `r` for masked data - // if (masked data == 1) rDataMasked[i] = rDataMasked[i-1] * r - // else rDataMasked[i] = rDataMasked[i-1] - signal rDataMasked[dataLen]; - rDataMasked[0] <== Mux1()([1, r], mask[0]); - for (var i = 1 ; i < dataLen ; i++) { - rDataMasked[i] <== Mux1()([rDataMasked[i-1], rDataMasked[i-1] * r], mask[i]); - } - - // powers of `r` for key - signal rKeyMasked[keyLen]; - rKeyMasked[0] <== r; - for (var i = 1; i < keyLen ; i++) { - rKeyMasked[i] <== rKeyMasked[i-1] * r; - } - - // n constraints - // - // calculate linear combination with random_num for data: data[i] = data[i-1] + (r^i * data[i]) - signal hashMaskedData[dataLen]; - hashMaskedData[0] <== rDataMasked[0] * maskedData[0]; - for (var i = 1; i < dataLen ; i++) { - hashMaskedData[i] <== hashMaskedData[i-1] + (rDataMasked[i] * maskedData[i]); - } - - // calculate linear combination with random_num for key: key[i] = key[i-1] + (r^i * key[i]) - signal hashMaskedKey[keyLen]; - hashMaskedKey[0] <== rKeyMasked[0] * key[0]; - for (var i = 1; i < keyLen ; i++) { - hashMaskedKey[i] <== hashMaskedKey[i-1] + (rKeyMasked[i] * key[i]); - } - - // final sum for data and key should be equal - // hashMaskedData[dataLen - 1] === hashMaskedKey[keyLen - 1]; - out <== IsZero()(hashMaskedData[dataLen-1]-hashMaskedKey[keyLen-1]); -} - -// from: https://github.com/pluto/aes-proof/blob/main/circuits/aes-gcm/helper_functions.circom - -template SumMultiple(n) { - signal input nums[n]; - signal output sum; - - signal sums[n]; - sums[0] <== nums[0]; - - for(var i=1; i