diff --git a/circuits/http/extractor.circom b/circuits/http/extractor.circom index d6b3dd6..03ad0c2 100644 --- a/circuits/http/extractor.circom +++ b/circuits/http/extractor.circom @@ -72,7 +72,7 @@ template ExtractResponse(DATA_BYTES, maxContentLength) { signal isPrevStartingIndex[DATA_BYTES]; valueStartingIndex[0] <== 0; isZeroMask[0] <== IsZero()(dataMask[0]); - for (var i=1 ; i(filename: string): T { const filePath = join(__dirname, "..", "..", "..", "examples", "lockfile", filename); const jsonString = readFileSync(filePath, 'utf-8'); const jsonData = JSON.parse(jsonString); return jsonData; } -interface HttpData { - request: Request; - response: Response; +function getHeaders(data: Request | Response): [string, string][] { + const headers: [string, string][] = []; + let i = 1; + while (true) { + const nameKey = `headerName${i}`; + const valueKey = `headerValue${i}`; + if (nameKey in data && valueKey in data) { + headers.push([data[nameKey], data[valueKey]]); + i++; + } else { + break; + } + } + return headers; } interface Request { method: string, target: string, version: string, - headers: [string, string][], + [key: string]: string, } interface Response { version: string, status: string, message: string, - headers: [string, string][], + [key: string]: string, } @@ -54,86 +65,90 @@ function executeCodegen(inputFilename: string, outputFilename: string) { }); } -describe("HTTP :: Codegen", async () => { - let circuit: WitnessTester<["data", "beginning", "middle", "final", "header1", "value1", "header2", "value2"], []>; +describe("HTTP :: Codegen :: Request", async () => { + let circuit: WitnessTester<["data", "method", "target", "version", "header1", "value1", "header2", "value2"], []>; - it("(valid) get_request:", async () => { - let lockfile = "test.lock"; + it("(valid) GET:", async () => { + let lockfile = "request.lock"; let inputfile = "get_request.http"; // generate extractor circuit using codegen await executeCodegen(`${lockfile}.json`, lockfile); - const lockData = await readLockFile(`${lockfile}.json`); + const lockData = await readLockFile(`${lockfile}.json`); console.log("lockData: ", JSON.stringify(lockData)); const input = await readHTTPInputFile(`${inputfile}`).input - const params = [input.length, lockData.request.method.length, lockData.request.target.length, lockData.request.version.length]; - lockData.request.headers.forEach(header => { - params.push(header[0].length); // Header name length - params.push(header[1].length); // Header value length - console.log("header: ", header[0]); - console.log("value: ", header[1]); + const headers = getHeaders(lockData); + const params = [input.length, lockData.method.length, lockData.target.length, lockData.version.length]; + headers.forEach(header => { + params.push(header[0].length); + params.push(header[1].length); }); circuit = await circomkit.WitnessTester(`Extract`, { file: `circuits/main/${lockfile}`, - template: "LockHTTP", + template: "LockHTTPRequest", params: params, }); console.log("#constraints:", await circuit.getConstraintCount()); // match circuit output to original JSON value - await circuit.expectPass({ + const circuitInput: any = { data: input, - beginning: toByte(lockData.request.method), - middle: toByte(lockData.request.target), - final: toByte(lockData.request.version), - header1: toByte(lockData.request.headers[0][0]), - value1: toByte(lockData.request.headers[0][1]), - header2: toByte(lockData.request.headers[1][0]), - value2: toByte(lockData.request.headers[1][1]) - }, - {} - ); + method: toByte(lockData.method), + target: toByte(lockData.target), + version: toByte(lockData.version), + }; + + headers.forEach((header, index) => { + circuitInput[`header${index + 1}`] = toByte(header[0]); + circuitInput[`value${index + 1}`] = toByte(header[1]); + }); + await circuit.expectPass(circuitInput, {}); }); - it("(invalid) get_request:", async () => { - let lockfile = "test.lock"; + it("(invalid) GET:", async () => { + let lockfile = "request.lock"; let inputfile = "get_request.http"; // generate extractor circuit using codegen await executeCodegen(`${lockfile}.json`, lockfile); - const lockData = await readLockFile(`${lockfile}.json`); + const lockData = await readLockFile(`${lockfile}.json`); const input = await readHTTPInputFile(`${inputfile}`).input - const params = [input.length, lockData.request.method.length, lockData.request.target.length, lockData.request.version.length]; - lockData.request.headers.forEach(header => { - params.push(header[0].length); // Header name length - params.push(header[1].length); // Header value length + const headers = getHeaders(lockData); + const params = [input.length, lockData.method.length, lockData.target.length, lockData.version.length]; + headers.forEach(header => { + params.push(header[0].length); + params.push(header[1].length); }); circuit = await circomkit.WitnessTester(`Extract`, { file: `circuits/main/${lockfile}`, - template: "LockHTTP", + template: "LockHTTPRequest", params: params, }); console.log("#constraints:", await circuit.getConstraintCount()); - await circuit.expectFail({ - data: input.slice(0), - beginning: toByte(lockData.request.method), - middle: toByte(lockData.request.target), - final: toByte(lockData.request.version), - header1: toByte(lockData.request.headers[0][0]), - value1: toByte("/aip"), - header2: toByte(lockData.request.headers[1][0]), - value2: toByte(lockData.request.headers[1][1]) + const circuitInput: any = { + data: input, + method: toByte(lockData.method), + target: toByte(lockData.target), + version: toByte(lockData.version), + }; + + headers.forEach((header, index) => { + circuitInput[`header${index + 1}`] = toByte(header[0]); + circuitInput[`value${index + 1}`] = toByte(header[1]); }); + + circuitInput.value1 = toByte("/aip"); + await circuit.expectFail(circuitInput); }); }); diff --git a/src/http_lock.rs b/src/http_lock.rs index 881d425..4238b88 100644 --- a/src/http_lock.rs +++ b/src/http_lock.rs @@ -63,6 +63,7 @@ const PRAGMA: &str = "pragma circom 2.1.9;\n\n"; fn locker_circuit( data: HttpData, + debug: bool, output_filename: String, ) -> Result<(), Box> { let mut circuit_buffer = String::new(); @@ -74,7 +75,6 @@ fn locker_circuit( circuit_buffer += "include \"circomlib/circuits/gates.circom\";\n"; circuit_buffer += "include \"@zk-email/circuits/utils/array.circom\";\n\n"; - // template LockHTTP(DATA_BYTES, beginningLen, middleLen, finalLen, headerNameLen1, headerValueLen1, ...) { { match data { HttpData::Request(_) => { @@ -93,24 +93,38 @@ fn locker_circuit( circuit_buffer += ") {"; } - /* - signal input data[DATA_BYTES]; - - signal input key1[keyLen1]; - signal input key3[keyLen3]; - */ { circuit_buffer += r#" + // Raw HTTP bytestream signal input data[DATA_BYTES]; +"#; - // Start line signals - signal input beginning[beginningLen]; - signal input middle[middleLen]; - signal input final[finalLen]; + // Start line signals + { + match data { + HttpData::Request(_) => { + circuit_buffer += r#" + // Request line attributes + signal input method[methodLen]; + signal input target[targetLen]; + signal input version[versionLen]; - // Header signals "#; + } + HttpData::Response(_) => { + circuit_buffer += r#" + // Status line attributes + signal input version[versionLen]; + signal input status[statusLen]; + signal input message[messageLen]; + +"#; + } + } + } + // Header signals + circuit_buffer += " // Header names and values to lock\n"; for (i, _header) in data.headers().iter().enumerate() { circuit_buffer += &format!( " signal input header{}[headerNameLen{}];\n", @@ -127,33 +141,54 @@ fn locker_circuit( // Setup for parsing the start line { - circuit_buffer += r#" - // Check first beginning byte - signal beginningIsEqual[beginningLen]; - beginningIsEqual[0] <== IsEqual()([data[0],beginning[0]]); - beginningIsEqual[0] === 1; + match data { + HttpData::Request(_) => { + circuit_buffer += r#" + // Check first method byte + signal methodIsEqual[methodLen]; + methodIsEqual[0] <== IsEqual()([data[0],method[0]]); + methodIsEqual[0] === 1; - // Setup to check middle bytes + // Setup to check target and version bytes signal startLineMask[DATA_BYTES]; - signal middleMask[DATA_BYTES]; - signal finalMask[DATA_BYTES]; + signal targetMask[DATA_BYTES]; + signal versionMask[DATA_BYTES]; - var middle_start_counter = 1; - var middle_end_counter = 1; - var final_end_counter = 1; + var target_start_counter = 1; + var target_end_counter = 1; + var version_end_counter = 1; "#; + } + HttpData::Response(_) => { + circuit_buffer += r#" + // Check first version byte + signal versionIsEqual[versionLen]; + versionIsEqual[0] <== IsEqual()([data[0],version[0]]); + versionIsEqual[0] === 1; + + // Setup to check status and message bytes + signal startLineMask[DATA_BYTES]; + signal statusMask[DATA_BYTES]; + signal messageMask[DATA_BYTES]; + + var status_start_counter = 1; + var status_end_counter = 1; + var message_end_counter = 1; +"#; + } + } } circuit_buffer += r#" component State[DATA_BYTES]; - State[0] = StateUpdate(); - State[0].byte <== data[0]; - State[0].parsing_start <== 1; - State[0].parsing_header <== 0; - State[0].parsing_field_name <== 0; + State[0] = StateUpdate(); + State[0].byte <== data[0]; + State[0].parsing_start <== 1; + State[0].parsing_header <== 0; + State[0].parsing_field_name <== 0; State[0].parsing_field_value <== 0; - State[0].parsing_body <== 0; - State[0].line_status <== 0; + State[0].parsing_body <== 0; + State[0].line_status <== 0; "#; @@ -166,37 +201,64 @@ fn locker_circuit( } } - circuit_buffer += r#" + // Intro loop + { + circuit_buffer += r#" for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) { - State[data_idx] = StateUpdate(); - State[data_idx].byte <== data[data_idx]; - State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start; - State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header; - State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name; + State[data_idx] = StateUpdate(); + State[data_idx].byte <== data[data_idx]; + State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start; + State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header; + State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name; State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value; - State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body; - State[data_idx].line_status <== State[data_idx - 1].next_line_status; + State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body; + State[data_idx].line_status <== State[data_idx - 1].next_line_status; "#; + } + // Start line matches { - circuit_buffer += r#" - // Check remaining beginning bytes - if(data_idx < beginningLen) { - beginningIsEqual[data_idx] <== IsEqual()([data[data_idx], beginning[data_idx]]); - beginningIsEqual[data_idx] === 1; + match data { + HttpData::Request(_) => { + circuit_buffer += r#" + // Check remaining method bytes + if(data_idx < methodLen) { + methodIsEqual[data_idx] <== IsEqual()([data[data_idx], method[data_idx]]); + methodIsEqual[data_idx] === 1; + } + + // Get the target bytes + startLineMask[data_idx] <== inStartLine()(State[data_idx].parsing_start); + targetMask[data_idx] <== inStartMiddle()(State[data_idx].parsing_start); + versionMask[data_idx] <== inStartEnd()(State[data_idx].parsing_start); + target_start_counter += startLineMask[data_idx] - targetMask[data_idx] - versionMask[data_idx]; + + // Get the version bytes + target_end_counter += startLineMask[data_idx] - versionMask[data_idx]; + version_end_counter += startLineMask[data_idx]; +"#; + } + HttpData::Response(_) => { + circuit_buffer += r#" + // Check remaining version bytes + if(data_idx < versionLen) { + versionIsEqual[data_idx] <== IsEqual()([data[data_idx], version[data_idx]]); + versionIsEqual[data_idx] === 1; } - // Middle - startLineMask[data_idx] <== inStartLine()(State[data_idx].parsing_start); - middleMask[data_idx] <== inStartMiddle()(State[data_idx].parsing_start); - finalMask[data_idx] <== inStartEnd()(State[data_idx].parsing_start); - middle_start_counter += startLineMask[data_idx] - middleMask[data_idx] - finalMask[data_idx]; - // The end of middle is the start of the final - middle_end_counter += startLineMask[data_idx] - finalMask[data_idx]; - final_end_counter += startLineMask[data_idx]; + // Get the status bytes + startLineMask[data_idx] <== inStartLine()(State[data_idx].parsing_start); + statusMask[data_idx] <== inStartMiddle()(State[data_idx].parsing_start); + messageMask[data_idx] <== inStartEnd()(State[data_idx].parsing_start); + status_start_counter += startLineMask[data_idx] - statusMask[data_idx] - messageMask[data_idx]; + // Get the message bytes + status_end_counter += startLineMask[data_idx] - messageMask[data_idx]; + message_end_counter += startLineMask[data_idx]; "#; + } + } } // Header matches @@ -212,7 +274,8 @@ fn locker_circuit( } // debugging - circuit_buffer += r#" + if debug { + circuit_buffer += r#" // Debugging log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start); log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header); @@ -222,11 +285,13 @@ fn locker_circuit( log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status); log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); "#; + } circuit_buffer += " }"; // debugging - circuit_buffer += r#" + if debug { + circuit_buffer += r#" // Debugging log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start); log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header); @@ -237,25 +302,50 @@ fn locker_circuit( log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"); "#; + } + // Verify all start line has matched { - circuit_buffer += r#" - // Additionally verify beginning had correct length - beginningLen === middle_start_counter - 1; + match data { + HttpData::Request(_) => { + circuit_buffer += r#" + // Verify method had correct length + methodLen === target_start_counter - 1; - // Check middle is correct by substring match and length check + // Check target is correct by substring match and length check // TODO: change r - signal middleMatch <== SubstringMatchWithIndex(DATA_BYTES, middleLen)(data, middle, 100, middle_start_counter); - middleMatch === 1; - middleLen === middle_end_counter - middle_start_counter - 1; + signal targetMatch <== SubstringMatchWithIndex(DATA_BYTES, targetLen)(data, target, 100, target_start_counter); + targetMatch === 1; + targetLen === target_end_counter - target_start_counter - 1; - // Check final is correct by substring match and length check + // Check version is correct by substring match and length check // TODO: change r - signal finalMatch <== SubstringMatchWithIndex(DATA_BYTES, finalLen)(data, final, 100, middle_end_counter); - finalMatch === 1; + signal versionMatch <== SubstringMatchWithIndex(DATA_BYTES, versionLen)(data, version, 100, target_end_counter); + versionMatch === 1; // -2 here for the CRLF - finalLen === final_end_counter - middle_end_counter - 2; + versionLen === version_end_counter - target_end_counter - 2; "#; + } + HttpData::Response(_) => { + circuit_buffer += r#" + // Verify version had correct length + versionLen === target_start_counter - 1; + + // Check status is correct by substring match and length check + // TODO: change r + signal statusMatch <== SubstringMatchWithIndex(DATA_BYTES, statusLen)(data, status, 100, status_start_counter); + statusMatch === 1; + statusLen === status_end_counter - status_start_counter - 1; + + // Check message is correct by substring match and length check + // TODO: change r + signal messageMatch <== SubstringMatchWithIndex(DATA_BYTES, messageLen)(data, message, 100, status_end_counter); + messageMatch === 1; + // -2 here for the CRLF + messageLen === message_end_counter - status_end_counter - 2; +"#; + } + } } // Verify all headers have matched @@ -290,7 +380,7 @@ pub fn http_lock(args: HttpLockArgs) -> Result<(), Box> { let http_data: HttpData = serde_json::from_slice(&data)?; - locker_circuit(http_data, args.output_filename)?; + locker_circuit(http_data, args.debug, args.output_filename)?; Ok(()) } diff --git a/src/main.rs b/src/main.rs index 2312c5e..f5518a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -53,6 +53,9 @@ pub struct ExtractorArgs { /// Output circuit file name #[arg(long, default_value = "extractor")] output_filename: String, + + #[arg(long, default_value = "false")] + debug: bool, } #[derive(Parser, Debug)] @@ -64,6 +67,9 @@ pub struct HttpLockArgs { /// Output circuit file name #[arg(long, default_value = "extractor")] output_filename: String, + + #[arg(long, default_value = "false")] + debug: bool, } pub fn main() -> Result<(), Box> {