Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: parser improvs #75

Merged
merged 9 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 104 additions & 9 deletions circuits/http/extractor.circom
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
pragma circom 2.1.9;

include "../utils/bytes.circom";
include "interpreter.circom";
include "parser/machine.circom";
include "../utils/bytes.circom";
include "../utils/search.circom";
include "circomlib/circuits/mux1.circom";
include "circomlib/circuits/gates.circom";
include "@zk-email/circuits/utils/array.circom";

// TODO:
Expand All @@ -24,6 +28,8 @@ template ExtractResponse(DATA_BYTES, maxContentLength) {
State[0].byte <== data[0];
State[0].parsing_start <== 1;
State[0].parsing_header <== 0;
State[0].parsing_field_name <== 0;
State[0].parsing_field_value <== 0;
State[0].parsing_body <== 0;
State[0].line_status <== 0;

Expand All @@ -35,25 +41,31 @@ template ExtractResponse(DATA_BYTES, maxContentLength) {
State[data_idx].byte <== data[data_idx];
State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start;
State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header;
State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name;
State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value;
State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body;
State[data_idx].line_status <== State[data_idx - 1].next_line_status;

// apply body mask to data
dataMask[data_idx] <== data[data_idx] * State[data_idx].next_parsing_body;

// Debugging
log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start);
log("State[", data_idx, "].parsing_header", "= ", State[data_idx].parsing_header);
log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body);
log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status);
log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start);
log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header);
log("State[", data_idx, "].parsing_field_name ", "= ", State[data_idx].parsing_field_name);
log("State[", data_idx, "].parsing_field_value", "= ", State[data_idx].parsing_field_value);
log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body);
log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}

// Debugging
log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start);
log("State[", DATA_BYTES, "].parsing_header", "= ", State[DATA_BYTES-1].next_parsing_header);
log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body);
log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status);
log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start);
log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header);
log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name);
log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value);
log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body);
log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");

signal valueStartingIndex[DATA_BYTES];
Expand All @@ -68,4 +80,87 @@ template ExtractResponse(DATA_BYTES, maxContentLength) {
}

response <== SelectSubArray(DATA_BYTES, maxContentLength)(dataMask, valueStartingIndex[DATA_BYTES-1]+1, DATA_BYTES - valueStartingIndex[DATA_BYTES-1]);
}

template ExtractHeaderValue(DATA_BYTES, headerNameLength, maxValueLength) {
signal input data[DATA_BYTES];
signal input header[headerNameLength];

signal output value[maxValueLength];

//--------------------------------------------------------------------------------------------//
//-CONSTRAINTS--------------------------------------------------------------------------------//
//--------------------------------------------------------------------------------------------//
component dataASCII = ASCII(DATA_BYTES);
dataASCII.in <== data;
//--------------------------------------------------------------------------------------------//

// Initialze the parser
component State[DATA_BYTES];
State[0] = StateUpdate();
State[0].byte <== data[0];
State[0].parsing_start <== 1;
State[0].parsing_header <== 0;
State[0].parsing_field_name <== 0;
State[0].parsing_field_value <== 0;
State[0].parsing_body <== 0;
State[0].line_status <== 0;

signal headerMatch[DATA_BYTES];
headerMatch[0] <== 0;
signal isHeaderNameMatch[DATA_BYTES];
isHeaderNameMatch[0] <== 0;
signal readCRLF[DATA_BYTES];
readCRLF[0] <== 0;
signal valueMask[DATA_BYTES];
valueMask[0] <== 0;

for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) {
State[data_idx] = StateUpdate();
State[data_idx].byte <== data[data_idx];
State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start;
State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header;
State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name;
State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value;
State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body;
State[data_idx].line_status <== State[data_idx - 1].next_line_status;

// apply value mask to data
// TODO: change r
headerMatch[data_idx] <== HeaderFieldNameMatch(DATA_BYTES, headerNameLength)(data, header, 100, data_idx);
readCRLF[data_idx] <== IsEqual()([State[data_idx].line_status, 2]);
isHeaderNameMatch[data_idx] <== Mux1()([isHeaderNameMatch[data_idx-1] * (1-readCRLF[data_idx]), 1], headerMatch[data_idx]);
valueMask[data_idx] <== MultiAND(3)([data[data_idx], isHeaderNameMatch[data_idx], State[data_idx].parsing_field_value]);

// Debugging
log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start);
log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header);
log("State[", data_idx, "].parsing_field_name ", "= ", State[data_idx].parsing_field_name);
log("State[", data_idx, "].parsing_field_value", "= ", State[data_idx].parsing_field_value);
log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body);
log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}

// Debugging
log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start);
log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header);
log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name);
log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value);
log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body);
log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");

signal valueStartingIndex[DATA_BYTES];
signal isZeroMask[DATA_BYTES];
signal isPrevStartingIndex[DATA_BYTES];
valueStartingIndex[0] <== 0;
isZeroMask[0] <== IsZero()(valueMask[0]);
for (var i=1 ; i<DATA_BYTES ; i++) {
isZeroMask[i] <== IsZero()(valueMask[i]);
isPrevStartingIndex[i] <== IsZero()(valueStartingIndex[i-1]);
valueStartingIndex[i] <== valueStartingIndex[i-1] + i * (1-isZeroMask[i]) * isPrevStartingIndex[i];
}

value <== SelectSubArray(DATA_BYTES, maxValueLength)(valueMask, valueStartingIndex[DATA_BYTES-1]+1, maxValueLength);
}
54 changes: 52 additions & 2 deletions circuits/http/interpreter.circom
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
pragma circom 2.1.9;

include "parser/language.circom";
include "../utils/search.circom";
include "../utils/array.circom";

/* TODO:
/* TODO:
Notes --
- This is a pretty efficient way to simply check what the method used in a request is by checking
the first `DATA_LENGTH` number of bytes.
the first `DATA_LENGTH` number of bytes.
- Could probably change this to a template that checks if it is one of the given methods
so we don't check them all in one
*/
Expand All @@ -32,4 +33,53 @@ template YieldMethod(DATA_LENGTH) {
signal TagPost <== IsPost.out * RequestMethodTag.POST;

MethodTag <== TagGet + TagPost;
}

// https://www.rfc-editor.org/rfc/rfc9112.html#name-field-syntax
template HeaderFieldNameValueMatch(dataLen, nameLen, valueLen) {
signal input data[dataLen];
signal input headerName[nameLen];
signal input headerValue[valueLen];
signal input r;
signal input index;

component syntax = Syntax();

signal output value[valueLen];

// is name matches
signal headerNameMatch <== SubstringMatchWithIndex(dataLen, nameLen)(data, headerName, r, index);

// next byte to name should be COLON
signal endOfHeaderName <== IndexSelector(dataLen)(data, index + nameLen);
signal isNextByteColon <== IsEqual()([endOfHeaderName, syntax.COLON]);

signal headerNameMatchAndNextByteColon <== headerNameMatch * isNextByteColon;

// field-name: SP field-value
signal headerValueMatch <== SubstringMatchWithIndex(dataLen, valueLen)(data, headerValue, r, index + nameLen + 2);

// header name matches + header value matches
signal output out <== headerNameMatchAndNextByteColon * headerValueMatch;
}

// https://www.rfc-editor.org/rfc/rfc9112.html#name-field-syntax
template HeaderFieldNameMatch(dataLen, nameLen) {
signal input data[dataLen];
signal input headerName[nameLen];
signal input r;
signal input index;

component syntax = Syntax();

// is name matches
signal headerNameMatch <== SubstringMatchWithIndex(dataLen, nameLen)(data, headerName, r, index);

// next byte to name should be COLON
signal endOfHeaderName <== IndexSelector(dataLen)(data, index + nameLen);
signal isNextByteColon <== IsEqual()([endOfHeaderName, syntax.COLON]);

// header name matches
signal output out;
out <== headerNameMatch * isNextByteColon;
}
64 changes: 53 additions & 11 deletions circuits/http/parser/machine.circom
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,37 @@ include "language.circom";
include "../../utils/array.circom";

template StateUpdate() {
signal input parsing_start; // Bool flag for if we are in the start line
signal input parsing_start; // flag that counts up to 3 for each value in the start line
signal input parsing_header; // Flag + Counter for what header line we are in
signal input parsing_body;
signal input parsing_field_name; // flag that tells if parsing header field name
signal input parsing_field_value; // flag that tells if parsing header field value
signal input parsing_body; // Flag when we are inside body
signal input line_status; // Flag that counts up to 4 to read a double CLRF
signal input byte;

signal output next_parsing_start;
signal output next_parsing_header;
signal output next_parsing_field_name;
signal output next_parsing_field_value;
signal output next_parsing_body;
signal output next_line_status;

component Syntax = Syntax();

//---------------------------------------------------------------------------------//
//---------------------------------------------------------------------------------//
// check if we read space or colon
component readSP = IsEqual();
readSP.in <== [byte, Syntax.SPACE];
component readColon = IsEqual();
readColon.in <== [byte, Syntax.COLON];

// Check if what we just read is a CR / LF
component readCR = IsEqual();
readCR.in <== [byte, Syntax.CR];
component readLF = IsEqual();
readLF.in <== [byte, Syntax.LF];

signal notCRAndLF <== (1 - readCR.out) * (1 - readLF.out);
signal notCRAndLF <== (1 - readCR.out) * (1 - readLF.out);
//---------------------------------------------------------------------------------//

//---------------------------------------------------------------------------------//
Expand All @@ -42,32 +52,64 @@ template StateUpdate() {

//---------------------------------------------------------------------------------//
// Take current state and CRLF info to update state
signal state[3] <== [parsing_start, parsing_header, parsing_body];
signal state[5] <== [parsing_start, parsing_header, parsing_field_name, parsing_field_value, parsing_body];
component stateChange = StateChange();
stateChange.readCRLF <== readCRLF;
stateChange.readCRLFCRLF <== readCRLFCRLF;
stateChange.readSP <== readSP.out;
stateChange.readColon <== readColon.out;
stateChange.state <== state;

component nextState = ArrayAdd(3);
component nextState = ArrayAdd(5);
nextState.lhs <== state;
nextState.rhs <== stateChange.out;
//---------------------------------------------------------------------------------//

next_parsing_start <== nextState.out[0];
next_parsing_header <== nextState.out[1];
next_parsing_body <== nextState.out[2];
next_parsing_field_name <== nextState.out[2];
next_parsing_field_value <== nextState.out[3];
next_parsing_body <== nextState.out[4];
next_line_status <== line_status + readCR.out + readCRLF + readCRLFCRLF - line_status * notCRAndLF;

}

// TODO:
// - multiple space between start line values
// - handle incrementParsingHeader being incremented for header -> body CRLF
// - header value parsing doesn't handle SPACE between colon and actual value
template StateChange() {
signal input readCRLF;
signal input readCRLFCRLF;
signal input state[3];
signal output out[3];
signal input readSP;
signal input readColon;
signal input state[5];
signal output out[5];

// GreaterEqThan(2) because start line can have at most 3 values for request or response
signal isParsingStart <== GreaterEqThan(2)([state[0], 1]);
// increment parsing start counter on reading SP
signal incrementParsingStart <== readSP * isParsingStart;
// disable parsing start on reading CRLF
signal disableParsingStart <== readCRLF * state[0];

// enable parsing header on reading CRLF
signal enableParsingHeader <== readCRLF * isParsingStart;
// check if we are parsing header
signal isParsingHeader <== GreaterEqThan(10)([state[1], 1]);
// increment parsing header counter on CRLF and parsing header
signal incrementParsingHeader <== readCRLF * isParsingHeader;
// disable parsing header on reading CRLF-CRLF
signal disableParsingHeader <== readCRLFCRLF * state[1];
// parsing field value when parsing header and read Colon `:`
signal isParsingFieldValue <== isParsingHeader * readColon;

// parsing body when reading CRLF-CRLF and parsing header
signal enableParsingBody <== readCRLFCRLF * isParsingHeader;

out <== [-disableParsingStart, disableParsingStart - disableParsingHeader, disableParsingHeader];
// parsing_start = out[0] = enable header (default 1) + increment start - disable start
// parsing_header = out[1] = enable header + increment header - disable header
// parsing_field_name = out[2] = enable header + increment header - parsing field value - parsing body
// parsing_field_value = out[3] = parsing field value - increment parsing header (zeroed every time new header starts)
// parsing_body = out[4] = enable body
out <== [incrementParsingStart - disableParsingStart, enableParsingHeader + incrementParsingHeader - disableParsingHeader, enableParsingHeader + incrementParsingHeader - isParsingFieldValue - enableParsingBody, isParsingFieldValue - incrementParsingHeader, enableParsingBody];
}
2 changes: 1 addition & 1 deletion circuits/test/common/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ export function readJSONInputFile(filename: string, key: any[]): [number[], numb
return [input, keyUnicode, output];
}

function toByte(data: string): number[] {
export function toByte(data: string): number[] {
const byteArray = [];
for (let i = 0; i < data.length; i++) {
byteArray.push(data.charCodeAt(i));
Expand Down
39 changes: 37 additions & 2 deletions circuits/test/http/extractor.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { circomkit, WitnessTester, generateDescription, readHTTPInputFile } from "../common";
import { circomkit, WitnessTester, generateDescription, readHTTPInputFile, toByte } from "../common";

describe("HTTP :: Extractor", async () => {
describe("HTTP :: body Extractor", async () => {
let circuit: WitnessTester<["data"], ["response"]>;


Expand Down Expand Up @@ -50,4 +50,39 @@ describe("HTTP :: Extractor", async () => {
output3.pop();
generatePassCase(parsedHttp.input, output3, "output length less than actual length");
});
});

describe("HTTP :: header Extractor", async () => {
let circuit: WitnessTester<["data", "header"], ["value"]>;

function generatePassCase(input: number[], headerName: number[], headerValue: number[], desc: string) {
const description = generateDescription(input);

it(`(valid) witness: ${description} ${desc}`, async () => {
circuit = await circomkit.WitnessTester(`ExtractHeaderValue`, {
file: "circuits/http/extractor",
template: "ExtractHeaderValue",
params: [input.length, headerName.length, headerValue.length],
});
console.log("#constraints:", await circuit.getConstraintCount());

await circuit.expectPass({ data: input, header: headerName }, { value: headerValue });
});
}

describe("response", async () => {

let parsedHttp = readHTTPInputFile("get_response.http");

generatePassCase(parsedHttp.input, toByte("Content-Length"), toByte(parsedHttp.headers["Content-Length"]), "");

// let output2 = parsedHttp.bodyBytes.slice(0);
// output2.push(0, 0, 0, 0);
// generatePassCase(parsedHttp.input, output2, "output length more than actual length");

// let output3 = parsedHttp.bodyBytes.slice(0);
// output3.pop();
// // output3.pop(); // TODO: fails due to shift subarray bug
// generatePassCase(parsedHttp.input, output3, "output length less than actual length");
});
});