From 12a66514ff5ed36f9113b7a33f7ca0c60c6627fb Mon Sep 17 00:00:00 2001 From: Colin Roberts Date: Mon, 9 Sep 2024 15:12:32 -0600 Subject: [PATCH] feat: http locking and circuit codegen (#76) * fix: circuits.json * remove unneeded import * feat: `MethodMatch` * check method * validating method and target * feat: lock a request line * a few more tests * request and status line working * WIP: progress towards header locking * lock with a header * codgen for request locking * working codegen! * Update codegen.test.ts * address Mr. Sambhav's feedback :) --- circuits.json | 9 +- circuits/http/extractor.circom | 3 +- circuits/http/interpreter.circom | 66 ++++---- circuits/http/locker.circom | 185 ++++++++++++++++++++ circuits/http/parser/language.circom | 25 --- circuits/test/common/index.ts | 22 ++- circuits/test/http/codegen.test.ts | 139 +++++++++++++++ circuits/test/http/extractor.test.ts | 3 +- circuits/test/http/interpreter.test.ts | 36 ++-- circuits/test/http/locker.test.ts | 135 +++++++++++++++ examples/lockfile/test.lock.json | 1 + src/http_lock.rs | 224 ++++++++++++++++++++++++- 12 files changed, 765 insertions(+), 83 deletions(-) create mode 100644 circuits/http/locker.circom create mode 100644 circuits/test/http/codegen.test.ts create mode 100644 circuits/test/http/locker.test.ts diff --git a/circuits.json b/circuits.json index 485460c..9139100 100644 --- a/circuits.json +++ b/circuits.json @@ -74,12 +74,17 @@ "get_request": { "file": "http/parser/parser", "template": "Parser", - "params": [60] + "params": [ + 60 + ] }, "get_response": { "file": "http/parser/parser", "template": "Parser", - "params": [89] + "params": [ + 89 + ] + }, "json_extract_value_string": { "file": "main/value_string", "template": "ExtractStringValue", diff --git a/circuits/http/extractor.circom b/circuits/http/extractor.circom index a4fa227..d6b3dd6 100644 --- a/circuits/http/extractor.circom +++ b/circuits/http/extractor.circom @@ -4,7 +4,6 @@ include "interpreter.circom"; include "parser/machine.circom"; include "../utils/bytes.circom"; include "../utils/search.circom"; -include "circomlib/circuits/mux1.circom"; include "circomlib/circuits/gates.circom"; include "@zk-email/circuits/utils/array.circom"; @@ -163,4 +162,4 @@ template ExtractHeaderValue(DATA_BYTES, headerNameLength, maxValueLength) { } value <== SelectSubArray(DATA_BYTES, maxValueLength)(valueMask, valueStartingIndex[DATA_BYTES-1]+1, maxValueLength); -} \ No newline at end of file +} diff --git a/circuits/http/interpreter.circom b/circuits/http/interpreter.circom index b4d8e75..1fce278 100644 --- a/circuits/http/interpreter.circom +++ b/circuits/http/interpreter.circom @@ -4,35 +4,41 @@ include "parser/language.circom"; include "../utils/search.circom"; include "../utils/array.circom"; -/* TODO: -Notes -- -- This is a pretty efficient way to simply check what the method used in a request is by checking - the first `DATA_LENGTH` number of bytes. -- Could probably change this to a template that checks if it is one of the given methods - so we don't check them all in one -*/ -template YieldMethod(DATA_LENGTH) { - signal input bytes[DATA_LENGTH]; - signal output MethodTag; - - component RequestMethod = RequestMethod(); - component RequestMethodTag = RequestMethodTag(); - - component IsGet = IsEqualArray(3); - for(var byte_idx = 0; byte_idx < 3; byte_idx++) { - IsGet.in[0][byte_idx] <== bytes[byte_idx]; - IsGet.in[1][byte_idx] <== RequestMethod.GET[byte_idx]; - } - signal TagGet <== IsGet.out * RequestMethodTag.GET; - - component IsPost = IsEqualArray(4); - for(var byte_idx = 0; byte_idx < 4; byte_idx++) { - IsPost.in[0][byte_idx] <== bytes[byte_idx]; - IsPost.in[1][byte_idx] <== RequestMethod.POST[byte_idx]; - } - signal TagPost <== IsPost.out * RequestMethodTag.POST; - - MethodTag <== TagGet + TagPost; +template inStartLine() { + signal input parsing_start; + signal output out; + + signal isBeginning <== IsEqual()([parsing_start, 1]); + signal isMiddle <== IsEqual()([parsing_start, 2]); + signal isEnd <== IsEqual()([parsing_start, 3]); + + out <== isBeginning + isMiddle + isEnd; +} + +template inStartMiddle() { + signal input parsing_start; + signal output out; + + out <== IsEqual()([parsing_start, 2]); +} + +template inStartEnd() { + signal input parsing_start; + signal output out; + + out <== IsEqual()([parsing_start, 3]); +} + +// TODO: This likely isn't really an "Intepreter" thing +template MethodMatch(dataLen, methodLen) { + signal input data[dataLen]; + signal input method[methodLen]; + + signal input r; + signal input index; + + signal isMatch <== SubstringMatchWithIndex(dataLen, methodLen)(data, method, r, index); + isMatch === 1; } // https://www.rfc-editor.org/rfc/rfc9112.html#name-field-syntax @@ -45,7 +51,7 @@ template HeaderFieldNameValueMatch(dataLen, nameLen, valueLen) { component syntax = Syntax(); - signal output value[valueLen]; + // signal output value[valueLen]; // is name matches signal headerNameMatch <== SubstringMatchWithIndex(dataLen, nameLen)(data, headerName, r, index); diff --git a/circuits/http/locker.circom b/circuits/http/locker.circom new file mode 100644 index 0000000..8680759 --- /dev/null +++ b/circuits/http/locker.circom @@ -0,0 +1,185 @@ +pragma circom 2.1.9; + +include "interpreter.circom"; +include "parser/machine.circom"; +include "../utils/bytes.circom"; +include "../utils/search.circom"; +include "circomlib/circuits/gates.circom"; +include "@zk-email/circuits/utils/array.circom"; + +template LockStartLine(DATA_BYTES, beginningLen, middleLen, finalLen) { + signal input data[DATA_BYTES]; + signal input beginning[beginningLen]; + signal input middle[middleLen]; + signal input final[finalLen]; + + //--------------------------------------------------------------------------------------------// + //-CONSTRAINTS--------------------------------------------------------------------------------// + //--------------------------------------------------------------------------------------------// + component dataASCII = ASCII(DATA_BYTES); + dataASCII.in <== data; + //--------------------------------------------------------------------------------------------// + + // Initialze the parser + component State[DATA_BYTES]; + State[0] = StateUpdate(); + State[0].byte <== data[0]; + State[0].parsing_start <== 1; + State[0].parsing_header <== 0; + State[0].parsing_field_name <== 0; + State[0].parsing_field_value <== 0; + State[0].parsing_body <== 0; + State[0].line_status <== 0; + + /* + Note, because we know a beginning is the very first thing in a request + we can make this more efficient by just comparing the first `beginningLen` bytes + of the data ASCII against the beginning ASCII itself. + */ + // Check first beginning byte + signal beginningIsEqual[beginningLen]; + beginningIsEqual[0] <== IsEqual()([data[0],beginning[0]]); + beginningIsEqual[0] === 1; + + // Setup to check middle bytes + signal startLineMask[DATA_BYTES]; + signal middleMask[DATA_BYTES]; + signal finalMask[DATA_BYTES]; + + var middle_start_counter = 1; + var middle_end_counter = 1; + var final_end_counter = 1; + for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) { + State[data_idx] = StateUpdate(); + State[data_idx].byte <== data[data_idx]; + State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start; + State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header; + State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name; + State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value; + State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body; + State[data_idx].line_status <== State[data_idx - 1].next_line_status; + + // Check remaining beginning bytes + if(data_idx < beginningLen) { + beginningIsEqual[data_idx] <== IsEqual()([data[data_idx], beginning[data_idx]]); + beginningIsEqual[data_idx] === 1; + } + + // Middle + startLineMask[data_idx] <== inStartLine()(State[data_idx].parsing_start); + middleMask[data_idx] <== inStartMiddle()(State[data_idx].parsing_start); + finalMask[data_idx] <== inStartEnd()(State[data_idx].parsing_start); + middle_start_counter += startLineMask[data_idx] - middleMask[data_idx] - finalMask[data_idx]; + // The end of middle is the start of the final + middle_end_counter += startLineMask[data_idx] - finalMask[data_idx]; + final_end_counter += startLineMask[data_idx]; + + // Debugging + log("State[", data_idx, "].parsing_start = ", State[data_idx].parsing_start); + log("State[", data_idx, "].parsing_header = ", State[data_idx].parsing_header); + log("State[", data_idx, "].parsing_field_name = ", State[data_idx].parsing_field_name); + log("State[", data_idx, "].parsing_field_value = ", State[data_idx].parsing_field_value); + log("State[", data_idx, "].parsing_body = ", State[data_idx].parsing_body); + log("State[", data_idx, "].line_status = ", State[data_idx].line_status); + log("------------------------------------------------"); + log("middle_start_counter = ", middle_start_counter); + log("middle_end_counter = ", middle_end_counter); + log("final_end_counter = ", final_end_counter); + log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + } + + // Debugging + log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start); + log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header); + log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name); + log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value); + log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body); + log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status); + log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + + // Additionally verify beginning had correct length + beginningLen === middle_start_counter - 1; + + // Check middle is correct by substring match and length check + // TODO: change r + signal middleMatch <== SubstringMatchWithIndex(DATA_BYTES, middleLen)(data, middle, 100, middle_start_counter); + middleMatch === 1; + middleLen === middle_end_counter - middle_start_counter - 1; + + // Check final is correct by substring match and length check + // TODO: change r + signal finalMatch <== SubstringMatchWithIndex(DATA_BYTES, finalLen)(data, final, 100, middle_end_counter); + finalMatch === 1; + // -2 here for the CRLF + finalLen === final_end_counter - middle_end_counter - 2; +} + +template LockHeader(DATA_BYTES, headerNameLen, headerValueLen) { + signal input data[DATA_BYTES]; + signal input header[headerNameLen]; + signal input value[headerValueLen]; + + //--------------------------------------------------------------------------------------------// + //-CONSTRAINTS--------------------------------------------------------------------------------// + //--------------------------------------------------------------------------------------------// + component dataASCII = ASCII(DATA_BYTES); + dataASCII.in <== data; + //--------------------------------------------------------------------------------------------// + + // Initialze the parser + component State[DATA_BYTES]; + State[0] = StateUpdate(); + State[0].byte <== data[0]; + State[0].parsing_start <== 1; + State[0].parsing_header <== 0; + State[0].parsing_field_name <== 0; + State[0].parsing_field_value <== 0; + State[0].parsing_body <== 0; + State[0].line_status <== 0; + + component headerFieldNameValueMatch[DATA_BYTES]; + signal isHeaderFieldNameValueMatch[DATA_BYTES]; + + isHeaderFieldNameValueMatch[0] <== 0; + var hasMatched = 0; + + for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) { + State[data_idx] = StateUpdate(); + State[data_idx].byte <== data[data_idx]; + State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start; + State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header; + State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name; + State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value; + State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body; + State[data_idx].line_status <== State[data_idx - 1].next_line_status; + + // TODO: change r + headerFieldNameValueMatch[data_idx] = HeaderFieldNameValueMatch(DATA_BYTES, headerNameLen, headerValueLen); + headerFieldNameValueMatch[data_idx].data <== data; + headerFieldNameValueMatch[data_idx].headerName <== header; + headerFieldNameValueMatch[data_idx].headerValue <== value; + headerFieldNameValueMatch[data_idx].r <== 100; + headerFieldNameValueMatch[data_idx].index <== data_idx; + isHeaderFieldNameValueMatch[data_idx] <== isHeaderFieldNameValueMatch[data_idx-1] + headerFieldNameValueMatch[data_idx].out; + + // Debugging + log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start); + log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header); + log("State[", data_idx, "].parsing_field_name ", "= ", State[data_idx].parsing_field_name); + log("State[", data_idx, "].parsing_field_value", "= ", State[data_idx].parsing_field_value); + log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body); + log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status); + log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + } + + // Debugging + log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start); + log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header); + log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name); + log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value); + log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body); + log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status); + log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + + isHeaderFieldNameValueMatch[DATA_BYTES - 1] === 1; +} \ No newline at end of file diff --git a/circuits/http/parser/language.circom b/circuits/http/parser/language.circom index faf958b..ffb48bb 100644 --- a/circuits/http/parser/language.circom +++ b/circuits/http/parser/language.circom @@ -24,29 +24,4 @@ template Syntax() { //-Escape-------------------------------------------------------------------------------------// // - ASCII char: `\` signal output ESCAPE <== 92; -} - -template RequestMethod() { - signal output GET[3] <== [71, 69, 84]; - // signal output HEAD[4] <== [72, 69, 65, 68]; - signal output POST[4] <== [80, 79, 83, 84]; - // signal output PUT <== 3; - // signal output DELETE <== 4; - // signal output CONNECT <== 5; - // signal output OPTIONS <== 6; - // signal output TRACE <== 7; - // signal output PATCH <== 8; -} - -// NOTE: Starting at 1 to avoid a false positive with a 0. -template RequestMethodTag() { - signal output GET <== 1; - // signal output HEAD <== 2; - signal output POST <== 3; - // signal output PUT <== 4; - // signal output DELETE <== 5; - // signal output CONNECT <== 6; - // signal output OPTIONS <== 7; - // signal output TRACE <== 8; - // signal output PATCH <== 9; } \ No newline at end of file diff --git a/circuits/test/common/index.ts b/circuits/test/common/index.ts index 1d4c661..55aa960 100644 --- a/circuits/test/common/index.ts +++ b/circuits/test/common/index.ts @@ -67,14 +67,14 @@ export function toByte(data: string): number[] { export function readHTTPInputFile(filename: string) { const filePath = join(__dirname, "..", "..", "..", "examples", "http", filename); - let input: number[] = []; - let data = readFileSync(filePath, 'utf-8'); - input = toByte(data); + let input = toByte(data); - // Split headers and body - const [headerSection, bodySection] = data.split('\r\n\r\n'); + // Split headers and body, accounting for possible lack of body + const parts = data.split('\r\n\r\n'); + const headerSection = parts[0]; + const bodySection = parts.length > 1 ? parts[1] : ''; // Function to parse headers into a dictionary function parseHeaders(headerLines: string[]) { @@ -82,7 +82,7 @@ export function readHTTPInputFile(filename: string) { headerLines.forEach(line => { const [key, value] = line.split(/:\s(.+)/); - headers[key] = value ? value : ''; + if (key) headers[key] = value ? value : ''; }); return headers; @@ -95,8 +95,12 @@ export function readHTTPInputFile(filename: string) { // Parse the body, if JSON response let responseBody = {}; - if (headers["Content-Type"] == "application/json") { - responseBody = JSON.parse(bodySection); + if (headers["Content-Type"] == "application/json" && bodySection) { + try { + responseBody = JSON.parse(bodySection); + } catch (e) { + console.error("Failed to parse JSON body:", e); + } } // Combine headers and body into an object @@ -105,6 +109,6 @@ export function readHTTPInputFile(filename: string) { initialLine: initialLine, headers: headers, body: responseBody, - bodyBytes: toByte(bodySection), + bodyBytes: toByte(bodySection || ''), }; } \ No newline at end of file diff --git a/circuits/test/http/codegen.test.ts b/circuits/test/http/codegen.test.ts new file mode 100644 index 0000000..bf08ac7 --- /dev/null +++ b/circuits/test/http/codegen.test.ts @@ -0,0 +1,139 @@ +import { circomkit, WitnessTester, readHTTPInputFile, toByte } from "../common"; +import { join } from "path"; +import { spawn } from "child_process"; +import { readFileSync } from "fs"; + +function readLockFile(filename: string): HttpData { + const filePath = join(__dirname, "..", "..", "..", "examples", "lockfile", filename); + const jsonString = readFileSync(filePath, 'utf-8'); + const jsonData = JSON.parse(jsonString); + return jsonData; +} + +interface HttpData { + request: Request; + response: Response; +} + +interface Request { + method: string, + target: string, + version: string, + headers: [string, string][], +} + +interface Response { + version: string, + status: string, + message: string, + headers: [string, string][], +} + + +function executeCodegen(inputFilename: string, outputFilename: string) { + return new Promise((resolve, reject) => { + const inputPath = join(__dirname, "..", "..", "..", "examples", "lockfile", inputFilename); + + const codegen = spawn("cargo", ["run", "http-lock", "--lockfile", inputPath, "--output-filename", outputFilename]); + + codegen.stdout.on('data', (data) => { + console.log(`stdout: ${data}`); + }); + + codegen.stderr.on('data', (data) => { + console.error(`stderr: ${data}`); + }); + + codegen.on('close', (code) => { + if (code === 0) { + resolve(`child process exited with code ${code}`); // Resolve the promise if the process exits successfully + } else { + reject(new Error(`Process exited with code ${code}`)); // Reject if there's an error + } + }); + }); +} + +describe("HTTP :: Codegen", async () => { + let circuit: WitnessTester<["data", "beginning", "middle", "final", "header1", "value1", "header2", "value2"], []>; + + it("(valid) get_request:", async () => { + let lockfile = "test.lock"; + let inputfile = "get_request.http"; + + // generate extractor circuit using codegen + await executeCodegen(`${lockfile}.json`, lockfile); + + const lockData = await readLockFile(`${lockfile}.json`); + console.log("lockData: ", JSON.stringify(lockData)); + + const input = await readHTTPInputFile(`${inputfile}`).input + + const params = [input.length, lockData.request.method.length, lockData.request.target.length, lockData.request.version.length]; + lockData.request.headers.forEach(header => { + params.push(header[0].length); // Header name length + params.push(header[1].length); // Header value length + console.log("header: ", header[0]); + console.log("value: ", header[1]); + }); + + + circuit = await circomkit.WitnessTester(`Extract`, { + file: `circuits/main/${lockfile}`, + template: "LockHTTP", + params: params, + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + // match circuit output to original JSON value + await circuit.expectPass({ + data: input, + beginning: toByte(lockData.request.method), + middle: toByte(lockData.request.target), + final: toByte(lockData.request.version), + header1: toByte(lockData.request.headers[0][0]), + value1: toByte(lockData.request.headers[0][1]), + header2: toByte(lockData.request.headers[1][0]), + value2: toByte(lockData.request.headers[1][1]) + }, + {} + ); + }); + + it("(invalid) get_request:", async () => { + let lockfile = "test.lock"; + let inputfile = "get_request.http"; + + // generate extractor circuit using codegen + await executeCodegen(`${lockfile}.json`, lockfile); + + const lockData = await readLockFile(`${lockfile}.json`); + + const input = await readHTTPInputFile(`${inputfile}`).input + + const params = [input.length, lockData.request.method.length, lockData.request.target.length, lockData.request.version.length]; + lockData.request.headers.forEach(header => { + params.push(header[0].length); // Header name length + params.push(header[1].length); // Header value length + }); + + + circuit = await circomkit.WitnessTester(`Extract`, { + file: `circuits/main/${lockfile}`, + template: "LockHTTP", + params: params, + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectFail({ + data: input.slice(0), + beginning: toByte(lockData.request.method), + middle: toByte(lockData.request.target), + final: toByte(lockData.request.version), + header1: toByte(lockData.request.headers[0][0]), + value1: toByte("/aip"), + header2: toByte(lockData.request.headers[1][0]), + value2: toByte(lockData.request.headers[1][1]) + }); + }); +}); diff --git a/circuits/test/http/extractor.test.ts b/circuits/test/http/extractor.test.ts index cdcfbdc..6619de4 100644 --- a/circuits/test/http/extractor.test.ts +++ b/circuits/test/http/extractor.test.ts @@ -85,4 +85,5 @@ describe("HTTP :: header Extractor", async () => { // // output3.pop(); // TODO: fails due to shift subarray bug // generatePassCase(parsedHttp.input, output3, "output length less than actual length"); }); -}); \ No newline at end of file +}); + diff --git a/circuits/test/http/interpreter.test.ts b/circuits/test/http/interpreter.test.ts index 10cc80c..5c46d95 100644 --- a/circuits/test/http/interpreter.test.ts +++ b/circuits/test/http/interpreter.test.ts @@ -1,25 +1,41 @@ -import { circomkit, WitnessTester, generateDescription } from "../common"; +import { circomkit, WitnessTester, generateDescription, toByte, readHTTPInputFile } from "../common"; describe("HTTP :: Interpreter", async () => { - describe("YieldMethod", async () => { - let circuit: WitnessTester<["bytes"], ["MethodTag"]>; + describe("MethodMatch", async () => { + let circuit: WitnessTester<["data", "method", "r", "index"], []>; - function generatePassCase(input: any, expected: any, depth: number, desc: string) { + function generatePassCase(input: number[], method: number[], index: number, desc: string) { const description = generateDescription(input); it(`(valid) witness: ${description} ${desc}`, async () => { - circuit = await circomkit.WitnessTester(`YieldMethod`, { + circuit = await circomkit.WitnessTester(`LockRequestLineData`, { file: "circuits/http/interpreter", - template: "YieldMethod", - params: [4], + template: "MethodMatch", + params: [input.length, method.length], }); console.log("#constraints:", await circuit.getConstraintCount()); - await circuit.expectPass(input, expected); + await circuit.expectPass({ data: input, method: method, r: 100, index: index }, {}); }); } - // The string `"GET "` - generatePassCase({ bytes: [71, 69, 84, 32] }, { MethodTag: 1 }, 0, ""); + function generateFailCase(input: number[], method: number[], index: number, desc: string) { + const description = generateDescription(input); + + it(`(invalid) witness: ${description} ${desc}`, async () => { + circuit = await circomkit.WitnessTester(`LockRequestLineData`, { + file: "circuits/http/interpreter", + template: "MethodMatch", + params: [input.length, method.length], + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectFail({ data: input, method: method, r: 100, index: index }); + }); + } + + let parsedHttp = readHTTPInputFile("get_request.http"); + generatePassCase(parsedHttp.input, toByte("GET"), 0, ""); + generateFailCase(parsedHttp.input, toByte("POST"), 0, ""); }); }); \ No newline at end of file diff --git a/circuits/test/http/locker.test.ts b/circuits/test/http/locker.test.ts new file mode 100644 index 0000000..1a67fc1 --- /dev/null +++ b/circuits/test/http/locker.test.ts @@ -0,0 +1,135 @@ +import { circomkit, WitnessTester, generateDescription, toByte, readHTTPInputFile } from "../common"; + +describe("HTTP :: Locker :: Request Line", async () => { + let circuit: WitnessTester<["data", "beginning", "middle", "final"], []>; + + function generatePassCase(input: number[], beginning: number[], middle: number[], final: number[], desc: string) { + const description = generateDescription(input); + + it(`(valid) witness: ${description} ${desc}`, async () => { + circuit = await circomkit.WitnessTester(`LockStartLine`, { + file: "circuits/http/locker", + template: "LockStartLine", + params: [input.length, beginning.length, middle.length, final.length], + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectPass({ data: input, beginning: beginning, middle: middle, final: final }, {}); + }); + } + + function generateFailCase(input: number[], beginning: number[], middle: number[], final: number[], desc: string) { + const description = generateDescription(input); + + it(`(invalid) witness: ${description} ${desc}`, async () => { + circuit = await circomkit.WitnessTester(`LockStartLine`, { + file: "circuits/http/locker", + template: "LockStartLine", + params: [input.length, beginning.length, middle.length, final.length], + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectFail({ data: input, beginning: beginning, middle: middle, final: final }); + }); + } + + describe("GET", async () => { + let parsedHttp = readHTTPInputFile("get_request.http"); + generatePassCase(parsedHttp.input, toByte("GET"), toByte("/api"), toByte("HTTP/1.1"), ""); + generateFailCase(parsedHttp.input.slice(0), toByte("POST"), toByte("/api"), toByte("HTTP/1.1"), ""); + generateFailCase(parsedHttp.input.slice(0), toByte("GET"), toByte("/"), toByte("HTTP/1.1"), ""); + generateFailCase(parsedHttp.input.slice(0), toByte("GET"), toByte("/api"), toByte("HTTP"), ""); + }); + + describe("POST", async () => { + let parsedHttp = readHTTPInputFile("post_request.http"); + generatePassCase(parsedHttp.input, toByte("POST"), toByte("/contact_form.php"), toByte("HTTP/1.1"), ""); + generateFailCase(parsedHttp.input.slice(0), toByte("GET"), toByte("/contact_form.php"), toByte("HTTP/1.1"), ""); + generateFailCase(parsedHttp.input.slice(0), toByte("POST"), toByte("/"), toByte("HTTP/1.1"), ""); + generateFailCase(parsedHttp.input.slice(0), toByte("POST"), toByte("/contact_form.php"), toByte("HTTP"), ""); + }); +}); + +describe("HTTP :: Locker :: Status Line", async () => { + let circuit: WitnessTester<["data", "beginning", "middle", "final"], []>; + + function generatePassCase(input: number[], beginning: number[], middle: number[], final: number[], desc: string) { + const description = generateDescription(input); + + it(`(valid) witness: ${description} ${desc}`, async () => { + circuit = await circomkit.WitnessTester(`LockStartLine`, { + file: "circuits/http/locker", + template: "LockStartLine", + params: [input.length, beginning.length, middle.length, final.length], + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectPass({ data: input, beginning: beginning, middle: middle, final: final }, {}); + }); + } + + function generateFailCase(input: number[], beginning: number[], middle: number[], final: number[], desc: string) { + const description = generateDescription(input); + + it(`(invalid) witness: ${description} ${desc}`, async () => { + circuit = await circomkit.WitnessTester(`LockStartLine`, { + file: "circuits/http/locker", + template: "LockStartLine", + params: [input.length, beginning.length, middle.length, final.length], + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectFail({ data: input, beginning: beginning, middle: middle, final: final }); + }); + } + + describe("GET", async () => { + let parsedHttp = readHTTPInputFile("get_response.http"); + generatePassCase(parsedHttp.input, toByte("HTTP/1.1"), toByte("200"), toByte("OK"), ""); + generateFailCase(parsedHttp.input, toByte("HTTP"), toByte("200"), toByte("OK"), ""); + generateFailCase(parsedHttp.input, toByte("HTTP/1.1"), toByte("404"), toByte("OK"), ""); + generateFailCase(parsedHttp.input, toByte("HTTP/1.1"), toByte("200"), toByte("Not Found"), ""); + }); +}); + +describe("HTTP :: Locker :: Header", async () => { + let circuit: WitnessTester<["data", "header", "value"], []>; + + function generatePassCase(input: number[], header: number[], value: number[], desc: string) { + const description = generateDescription(input); + + it(`(valid) witness: ${description} ${desc}`, async () => { + circuit = await circomkit.WitnessTester(`LockHeader`, { + file: "circuits/http/locker", + template: "LockHeader", + params: [input.length, header.length, value.length], + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectPass({ data: input, header: header, value: value }, {}); + }); + } + + function generateFailCase(input: number[], header: number[], value: number[], desc: string) { + const description = generateDescription(input); + + it(`(invalid) witness: ${description} ${desc}`, async () => { + circuit = await circomkit.WitnessTester(`LockHeader`, { + file: "circuits/http/locker", + template: "LockHeader", + params: [input.length, header.length, value.length], + }); + console.log("#constraints:", await circuit.getConstraintCount()); + + await circuit.expectFail({ data: input, header: header, value: value }); + }); + } + + describe("GET", async () => { + let parsedHttp = readHTTPInputFile("get_request.http"); + generatePassCase(parsedHttp.input, toByte("Host"), toByte("localhost"), ""); + generateFailCase(parsedHttp.input, toByte("Accept"), toByte("localhost"), ""); + generateFailCase(parsedHttp.input, toByte("Host"), toByte("venmo.com"), ""); + generateFailCase(parsedHttp.input, toByte("Connection"), toByte("keep-alive"), ""); + }); +}); \ No newline at end of file diff --git a/examples/lockfile/test.lock.json b/examples/lockfile/test.lock.json index a95f84e..e4eada6 100644 --- a/examples/lockfile/test.lock.json +++ b/examples/lockfile/test.lock.json @@ -17,6 +17,7 @@ "response": { "version": "HTTP/1.1", "status": "200", + "message": "OK", "headers": [ [ "Content-Type", diff --git a/src/http_lock.rs b/src/http_lock.rs index 7bf2c24..907c8a3 100644 --- a/src/http_lock.rs +++ b/src/http_lock.rs @@ -12,15 +12,231 @@ struct Request { target: String, version: String, headers: Vec<(String, String)>, - #[serde(rename = "Host")] - host: String, } #[derive(Debug, Serialize, Deserialize)] struct Response { version: String, status: String, - headers: Vec<(String, serde_json::Value)>, + message: String, + headers: Vec<(String, String)>, +} + +use std::fs::{self, create_dir_all}; + +const PRAGMA: &str = "pragma circom 2.1.9;\n\n"; + +fn request_locker_circuit( + data: HttpData, + output_filename: String, +) -> Result<(), Box> { + let mut circuit_buffer = String::new(); + circuit_buffer += PRAGMA; + circuit_buffer += "include \"../http/interpreter.circom\";\n"; + circuit_buffer += "include \"../http/parser/machine.circom\";\n"; + circuit_buffer += "include \"../utils/bytes.circom\";\n"; + circuit_buffer += "include \"../utils/search.circom\";\n"; + circuit_buffer += "include \"circomlib/circuits/gates.circom\";\n"; + circuit_buffer += "include \"@zk-email/circuits/utils/array.circom\";\n\n"; + + // template LockHTTP(DATA_BYTES, beginningLen, middleLen, finalLen, headerNameLen1, headerValueLen1, ...) { + { + circuit_buffer += "template LockHTTP(DATA_BYTES, beginningLen, middleLen, finalLen "; + for (i, _header) in data.request.headers.iter().enumerate() { + circuit_buffer += &format!(", headerNameLen{}, headerValueLen{}", i + 1, i + 1); + } + circuit_buffer += ") {"; + } + + /* + signal input data[DATA_BYTES]; + + signal input key1[keyLen1]; + signal input key3[keyLen3]; + */ + { + circuit_buffer += r#" + signal input data[DATA_BYTES]; + + // Start line signals + signal input beginning[beginningLen]; + signal input middle[middleLen]; + signal input final[finalLen]; + + // Header signals +"#; + + for (i, _header) in data.request.headers.iter().enumerate() { + circuit_buffer += &format!( + " signal input header{}[headerNameLen{}];\n", + i + 1, + i + 1 + ); + circuit_buffer += &format!( + " signal input value{}[headerValueLen{}];\n", + i + 1, + i + 1 + ); + } + } + + // Setup for parsing the start line + { + circuit_buffer += r#" + // Check first beginning byte + signal beginningIsEqual[beginningLen]; + beginningIsEqual[0] <== IsEqual()([data[0],beginning[0]]); + beginningIsEqual[0] === 1; + + // Setup to check middle bytes + signal startLineMask[DATA_BYTES]; + signal middleMask[DATA_BYTES]; + signal finalMask[DATA_BYTES]; + + var middle_start_counter = 1; + var middle_end_counter = 1; + var final_end_counter = 1; +"#; + } + + circuit_buffer += r#" + component State[DATA_BYTES]; + State[0] = StateUpdate(); + State[0].byte <== data[0]; + State[0].parsing_start <== 1; + State[0].parsing_header <== 0; + State[0].parsing_field_name <== 0; + State[0].parsing_field_value <== 0; + State[0].parsing_body <== 0; + State[0].line_status <== 0; + +"#; + + // Create header match signals + { + for (i, _header) in data.request.headers.iter().enumerate() { + circuit_buffer += &format!(" signal headerNameValueMatch{}[DATA_BYTES];\n", i + 1); + circuit_buffer += &format!(" headerNameValueMatch{}[0] <== 0;\n", i + 1); + circuit_buffer += &format!(" var hasMatchedHeaderValue{} = 0;\n\n", i + 1); + } + } + + circuit_buffer += r#" + for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) { + State[data_idx] = StateUpdate(); + State[data_idx].byte <== data[data_idx]; + State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start; + State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header; + State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name; + State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value; + State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body; + State[data_idx].line_status <== State[data_idx - 1].next_line_status; + +"#; + // Start line matches + { + circuit_buffer += r#" + // Check remaining beginning bytes + if(data_idx < beginningLen) { + beginningIsEqual[data_idx] <== IsEqual()([data[data_idx], beginning[data_idx]]); + beginningIsEqual[data_idx] === 1; + } + + // Middle + startLineMask[data_idx] <== inStartLine()(State[data_idx].parsing_start); + middleMask[data_idx] <== inStartMiddle()(State[data_idx].parsing_start); + finalMask[data_idx] <== inStartEnd()(State[data_idx].parsing_start); + middle_start_counter += startLineMask[data_idx] - middleMask[data_idx] - finalMask[data_idx]; + // The end of middle is the start of the final + middle_end_counter += startLineMask[data_idx] - finalMask[data_idx]; + final_end_counter += startLineMask[data_idx]; + +"#; + } + + // Header matches + { + for (i, _header) in data.request.headers.iter().enumerate() { + circuit_buffer += &format!(" headerNameValueMatch{}[data_idx] <== HeaderFieldNameValueMatch(DATA_BYTES, headerNameLen{}, headerValueLen{})(data, header{}, value{}, 100, data_idx);\n", i + 1,i + 1,i + 1,i + 1,i + 1); + circuit_buffer += &format!( + " hasMatchedHeaderValue{} += headerNameValueMatch{}[data_idx];\n", + i + 1, + i + 1 + ); + } + } + + // debugging + circuit_buffer += r#" + // Debugging + log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start); + log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header); + log("State[", data_idx, "].parsing_field_name ", "= ", State[data_idx].parsing_field_name); + log("State[", data_idx, "].parsing_field_value", "= ", State[data_idx].parsing_field_value); + log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body); + log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status); + log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); +"#; + + circuit_buffer += " }"; + + // debugging + circuit_buffer += r#" + // Debugging + log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start); + log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header); + log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name); + log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value); + log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body); + log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status); + log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); + +"#; + // Verify all start line has matched + { + circuit_buffer += r#" + // Additionally verify beginning had correct length + beginningLen === middle_start_counter - 1; + + // Check middle is correct by substring match and length check + // TODO: change r + signal middleMatch <== SubstringMatchWithIndex(DATA_BYTES, middleLen)(data, middle, 100, middle_start_counter); + middleMatch === 1; + middleLen === middle_end_counter - middle_start_counter - 1; + + // Check final is correct by substring match and length check + // TODO: change r + signal finalMatch <== SubstringMatchWithIndex(DATA_BYTES, finalLen)(data, final, 100, middle_end_counter); + finalMatch === 1; + // -2 here for the CRLF + finalLen === final_end_counter - middle_end_counter - 2; +"#; + } + + // Verify all headers have matched + { + for (i, _header) in data.request.headers.iter().enumerate() { + circuit_buffer += &format!(" hasMatchedHeaderValue{} === 1;\n", i + 1); + } + } + // End file + circuit_buffer += "\n}"; + + // write circuits to file + let mut file_path = std::env::current_dir()?; + file_path.push("circuits"); + file_path.push("main"); + + // create dir if doesn't exist + create_dir_all(&file_path)?; + + file_path.push(format!("{}.circom", output_filename)); + + fs::write(&file_path, circuit_buffer)?; + + println!("Code generated at: {}", file_path.display()); + + Ok(()) } // TODO: This needs to codegen a circuit now. @@ -28,7 +244,7 @@ pub fn http_lock(args: HttpLockArgs) -> Result<(), Box> { let data = std::fs::read(&args.lockfile)?; let http_data: HttpData = serde_json::from_slice(&data)?; - dbg!(http_data); + request_locker_circuit(http_data, args.output_filename)?; Ok(()) }