feat: http locking and circuit codegen (#76)
* fix: circuits.json

* remove unneeded import

* feat: `MethodMatch`

* check method

* validating method and target

* feat: lock a request line

* a few more tests

* request and status line working

* WIP: progress towards header locking

* lock with a header

* codegen for request locking

* working codegen!

* Update codegen.test.ts

* address Mr. Sambhav's feedback :)
Autoparallel authored Sep 9, 2024
1 parent 1d814b3 commit 12a6651
Showing 12 changed files with 765 additions and 83 deletions.
9 changes: 7 additions & 2 deletions circuits.json
@@ -74,12 +74,17 @@
"get_request": {
"file": "http/parser/parser",
"template": "Parser",
"params": [60]
"params": [
60
]
},
"get_response": {
"file": "http/parser/parser",
"template": "Parser",
"params": [89]
"params": [
89
]
},
"json_extract_value_string": {
"file": "main/value_string",
"template": "ExtractStringValue",
3 changes: 1 addition & 2 deletions circuits/http/extractor.circom
@@ -4,7 +4,6 @@ include "interpreter.circom";
include "parser/machine.circom";
include "../utils/bytes.circom";
include "../utils/search.circom";
include "circomlib/circuits/mux1.circom";
include "circomlib/circuits/gates.circom";
include "@zk-email/circuits/utils/array.circom";

@@ -163,4 +162,4 @@ template ExtractHeaderValue(DATA_BYTES, headerNameLength, maxValueLength) {
}

value <== SelectSubArray(DATA_BYTES, maxValueLength)(valueMask, valueStartingIndex[DATA_BYTES-1]+1, maxValueLength);
}
}
66 changes: 36 additions & 30 deletions circuits/http/interpreter.circom
@@ -4,35 +4,41 @@ include "parser/language.circom";
include "../utils/search.circom";
include "../utils/array.circom";

/* TODO:
Notes --
- This is a pretty efficient way to simply check what the method used in a request is by checking
the first `DATA_LENGTH` number of bytes.
- Could probably change this to a template that checks if it is one of the given methods
so we don't check them all in one
*/
template YieldMethod(DATA_LENGTH) {
signal input bytes[DATA_LENGTH];
signal output MethodTag;

component RequestMethod = RequestMethod();
component RequestMethodTag = RequestMethodTag();

component IsGet = IsEqualArray(3);
for(var byte_idx = 0; byte_idx < 3; byte_idx++) {
IsGet.in[0][byte_idx] <== bytes[byte_idx];
IsGet.in[1][byte_idx] <== RequestMethod.GET[byte_idx];
}
signal TagGet <== IsGet.out * RequestMethodTag.GET;

component IsPost = IsEqualArray(4);
for(var byte_idx = 0; byte_idx < 4; byte_idx++) {
IsPost.in[0][byte_idx] <== bytes[byte_idx];
IsPost.in[1][byte_idx] <== RequestMethod.POST[byte_idx];
}
signal TagPost <== IsPost.out * RequestMethodTag.POST;

MethodTag <== TagGet + TagPost;
template inStartLine() {
signal input parsing_start;
signal output out;

signal isBeginning <== IsEqual()([parsing_start, 1]);
signal isMiddle <== IsEqual()([parsing_start, 2]);
signal isEnd <== IsEqual()([parsing_start, 3]);

out <== isBeginning + isMiddle + isEnd;
}

template inStartMiddle() {
signal input parsing_start;
signal output out;

out <== IsEqual()([parsing_start, 2]);
}

template inStartEnd() {
signal input parsing_start;
signal output out;

out <== IsEqual()([parsing_start, 3]);
}

// TODO: This likely isn't really an "Interpreter" thing
template MethodMatch(dataLen, methodLen) {
signal input data[dataLen];
signal input method[methodLen];

signal input r;
signal input index;

signal isMatch <== SubstringMatchWithIndex(dataLen, methodLen)(data, method, r, index);
isMatch === 1;
}

// https://www.rfc-editor.org/rfc/rfc9112.html#name-field-syntax
@@ -45,7 +51,7 @@ template HeaderFieldNameValueMatch(dataLen, nameLen, valueLen) {

component syntax = Syntax();

signal output value[valueLen];
// signal output value[valueLen];

// check whether the name matches
signal headerNameMatch <== SubstringMatchWithIndex(dataLen, nameLen)(data, headerName, r, index);
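The new MethodMatch template constrains that `method` appears in `data` starting at the claimed `index`, by asserting that `SubstringMatchWithIndex` returns 1 under a challenge value `r`. Below is a minimal sketch of exercising it, assuming a circomkit-style test harness; the circuit file path, sizes, and request bytes are illustrative, not part of this commit.

// Hypothetical test sketch; assumes circomkit is the test harness.
import { Circomkit } from "circomkit";

// Mirrors the repo's toByte helper from circuits/test/common/index.ts.
function toByte(data: string): number[] {
  return Array.from(data, (c) => c.charCodeAt(0));
}

async function methodMatchSketch() {
  const circomkit = new Circomkit({ verbose: false });
  const circuit = await circomkit.WitnessTester("MethodMatch", {
    file: "http/interpreter",   // assumed path relative to the circuits directory
    template: "MethodMatch",
    params: [14, 3],            // dataLen, methodLen
  });

  const data = toByte("GET / HTTP/1.1"); // 14 bytes of request-line data
  const method = toByte("GET");

  // index is where the method starts in data; r is the substring-match challenge.
  await circuit.expectPass({ data, method, r: 100, index: 0 });
}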
185 changes: 185 additions & 0 deletions circuits/http/locker.circom
@@ -0,0 +1,185 @@
pragma circom 2.1.9;

include "interpreter.circom";
include "parser/machine.circom";
include "../utils/bytes.circom";
include "../utils/search.circom";
include "circomlib/circuits/gates.circom";
include "@zk-email/circuits/utils/array.circom";

template LockStartLine(DATA_BYTES, beginningLen, middleLen, finalLen) {
signal input data[DATA_BYTES];
signal input beginning[beginningLen];
signal input middle[middleLen];
signal input final[finalLen];

//--------------------------------------------------------------------------------------------//
//-CONSTRAINTS--------------------------------------------------------------------------------//
//--------------------------------------------------------------------------------------------//
component dataASCII = ASCII(DATA_BYTES);
dataASCII.in <== data;
//--------------------------------------------------------------------------------------------//

// Initialize the parser
component State[DATA_BYTES];
State[0] = StateUpdate();
State[0].byte <== data[0];
State[0].parsing_start <== 1;
State[0].parsing_header <== 0;
State[0].parsing_field_name <== 0;
State[0].parsing_field_value <== 0;
State[0].parsing_body <== 0;
State[0].line_status <== 0;

/*
Note: because the `beginning` is always the very first thing in a request,
we can check it more efficiently by comparing the first `beginningLen` bytes
of the data directly against the `beginning` bytes themselves.
*/
// Check first beginning byte
signal beginningIsEqual[beginningLen];
beginningIsEqual[0] <== IsEqual()([data[0],beginning[0]]);
beginningIsEqual[0] === 1;

// Setup to check middle bytes
signal startLineMask[DATA_BYTES];
signal middleMask[DATA_BYTES];
signal finalMask[DATA_BYTES];

var middle_start_counter = 1;
var middle_end_counter = 1;
var final_end_counter = 1;
for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) {
State[data_idx] = StateUpdate();
State[data_idx].byte <== data[data_idx];
State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start;
State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header;
State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name;
State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value;
State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body;
State[data_idx].line_status <== State[data_idx - 1].next_line_status;

// Check remaining beginning bytes
if(data_idx < beginningLen) {
beginningIsEqual[data_idx] <== IsEqual()([data[data_idx], beginning[data_idx]]);
beginningIsEqual[data_idx] === 1;
}

// Middle
startLineMask[data_idx] <== inStartLine()(State[data_idx].parsing_start);
middleMask[data_idx] <== inStartMiddle()(State[data_idx].parsing_start);
finalMask[data_idx] <== inStartEnd()(State[data_idx].parsing_start);
middle_start_counter += startLineMask[data_idx] - middleMask[data_idx] - finalMask[data_idx];
// The end of middle is the start of the final
middle_end_counter += startLineMask[data_idx] - finalMask[data_idx];
final_end_counter += startLineMask[data_idx];

// Debugging
log("State[", data_idx, "].parsing_start = ", State[data_idx].parsing_start);
log("State[", data_idx, "].parsing_header = ", State[data_idx].parsing_header);
log("State[", data_idx, "].parsing_field_name = ", State[data_idx].parsing_field_name);
log("State[", data_idx, "].parsing_field_value = ", State[data_idx].parsing_field_value);
log("State[", data_idx, "].parsing_body = ", State[data_idx].parsing_body);
log("State[", data_idx, "].line_status = ", State[data_idx].line_status);
log("------------------------------------------------");
log("middle_start_counter = ", middle_start_counter);
log("middle_end_counter = ", middle_end_counter);
log("final_end_counter = ", final_end_counter);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}

// Debugging
log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start);
log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header);
log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name);
log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value);
log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body);
log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");

// Additionally verify beginning had correct length
beginningLen === middle_start_counter - 1;

// Check middle is correct by substring match and length check
// TODO: change r
signal middleMatch <== SubstringMatchWithIndex(DATA_BYTES, middleLen)(data, middle, 100, middle_start_counter);
middleMatch === 1;
middleLen === middle_end_counter - middle_start_counter - 1;

// Check final is correct by substring match and length check
// TODO: change r
signal finalMatch <== SubstringMatchWithIndex(DATA_BYTES, finalLen)(data, final, 100, middle_end_counter);
finalMatch === 1;
// -2 here for the CRLF
finalLen === final_end_counter - middle_end_counter - 2;
}

template LockHeader(DATA_BYTES, headerNameLen, headerValueLen) {
signal input data[DATA_BYTES];
signal input header[headerNameLen];
signal input value[headerValueLen];

//--------------------------------------------------------------------------------------------//
//-CONSTRAINTS--------------------------------------------------------------------------------//
//--------------------------------------------------------------------------------------------//
component dataASCII = ASCII(DATA_BYTES);
dataASCII.in <== data;
//--------------------------------------------------------------------------------------------//

// Initialize the parser
component State[DATA_BYTES];
State[0] = StateUpdate();
State[0].byte <== data[0];
State[0].parsing_start <== 1;
State[0].parsing_header <== 0;
State[0].parsing_field_name <== 0;
State[0].parsing_field_value <== 0;
State[0].parsing_body <== 0;
State[0].line_status <== 0;

component headerFieldNameValueMatch[DATA_BYTES];
signal isHeaderFieldNameValueMatch[DATA_BYTES];

isHeaderFieldNameValueMatch[0] <== 0;
var hasMatched = 0;

for(var data_idx = 1; data_idx < DATA_BYTES; data_idx++) {
State[data_idx] = StateUpdate();
State[data_idx].byte <== data[data_idx];
State[data_idx].parsing_start <== State[data_idx - 1].next_parsing_start;
State[data_idx].parsing_header <== State[data_idx - 1].next_parsing_header;
State[data_idx].parsing_field_name <== State[data_idx-1].next_parsing_field_name;
State[data_idx].parsing_field_value <== State[data_idx-1].next_parsing_field_value;
State[data_idx].parsing_body <== State[data_idx - 1].next_parsing_body;
State[data_idx].line_status <== State[data_idx - 1].next_line_status;

// TODO: change r
headerFieldNameValueMatch[data_idx] = HeaderFieldNameValueMatch(DATA_BYTES, headerNameLen, headerValueLen);
headerFieldNameValueMatch[data_idx].data <== data;
headerFieldNameValueMatch[data_idx].headerName <== header;
headerFieldNameValueMatch[data_idx].headerValue <== value;
headerFieldNameValueMatch[data_idx].r <== 100;
headerFieldNameValueMatch[data_idx].index <== data_idx;
isHeaderFieldNameValueMatch[data_idx] <== isHeaderFieldNameValueMatch[data_idx-1] + headerFieldNameValueMatch[data_idx].out;

// Debugging
log("State[", data_idx, "].parsing_start ", "= ", State[data_idx].parsing_start);
log("State[", data_idx, "].parsing_header ", "= ", State[data_idx].parsing_header);
log("State[", data_idx, "].parsing_field_name ", "= ", State[data_idx].parsing_field_name);
log("State[", data_idx, "].parsing_field_value", "= ", State[data_idx].parsing_field_value);
log("State[", data_idx, "].parsing_body ", "= ", State[data_idx].parsing_body);
log("State[", data_idx, "].line_status ", "= ", State[data_idx].line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");
}

// Debugging
log("State[", DATA_BYTES, "].parsing_start ", "= ", State[DATA_BYTES-1].next_parsing_start);
log("State[", DATA_BYTES, "].parsing_header ", "= ", State[DATA_BYTES-1].next_parsing_header);
log("State[", DATA_BYTES, "].parsing_field_name ", "= ", State[DATA_BYTES-1].parsing_field_name);
log("State[", DATA_BYTES, "].parsing_field_value", "= ", State[DATA_BYTES-1].parsing_field_value);
log("State[", DATA_BYTES, "].parsing_body ", "= ", State[DATA_BYTES-1].next_parsing_body);
log("State[", DATA_BYTES, "].line_status ", "= ", State[DATA_BYTES-1].next_line_status);
log("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx");

isHeaderFieldNameValueMatch[DATA_BYTES - 1] === 1;
}
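LockStartLine pins the request line in three pieces: the `beginning` bytes are compared directly against the head of the data, while the parser's `parsing_start` state is used to locate and length-check the `middle` and `final` segments. LockHeader slides HeaderFieldNameValueMatch across every byte offset and requires exactly one match. A sketch of driving LockStartLine, again assuming a circomkit-style harness; the request bytes and parameter sizes are illustrative only.

// Hypothetical sketch; the request bytes and sizes below are illustrative.
import { Circomkit } from "circomkit";
import { toByte } from "./common"; // import path assumed; toByte is exported from circuits/test/common/index.ts

async function lockStartLineSketch() {
  const circomkit = new Circomkit();

  const request = "GET /api HTTP/1.1\r\nHost: example.com\r\n\r\n";
  const data = toByte(request);

  const beginning = toByte("GET");   // method
  const middle = toByte("/api");     // request target
  const final = toByte("HTTP/1.1");  // version

  const circuit = await circomkit.WitnessTester("LockStartLine", {
    file: "http/locker",             // assumed path relative to the circuits directory
    template: "LockStartLine",
    params: [data.length, beginning.length, middle.length, final.length],
  });

  // Constraints should be satisfied only if the request line is exactly "GET /api HTTP/1.1".
  await circuit.expectPass({ data, beginning, middle, final });
}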
25 changes: 0 additions & 25 deletions circuits/http/parser/language.circom
@@ -24,29 +24,4 @@ template Syntax() {
//-Escape-------------------------------------------------------------------------------------//
// - ASCII char: `\`
signal output ESCAPE <== 92;
}

template RequestMethod() {
signal output GET[3] <== [71, 69, 84];
// signal output HEAD[4] <== [72, 69, 65, 68];
signal output POST[4] <== [80, 79, 83, 84];
// signal output PUT <== 3;
// signal output DELETE <== 4;
// signal output CONNECT <== 5;
// signal output OPTIONS <== 6;
// signal output TRACE <== 7;
// signal output PATCH <== 8;
}

// NOTE: Starting at 1 to avoid a false positive with a 0.
template RequestMethodTag() {
signal output GET <== 1;
// signal output HEAD <== 2;
signal output POST <== 3;
// signal output PUT <== 4;
// signal output DELETE <== 5;
// signal output CONNECT <== 6;
// signal output OPTIONS <== 7;
// signal output TRACE <== 8;
// signal output PATCH <== 9;
}
22 changes: 13 additions & 9 deletions circuits/test/common/index.ts
@@ -67,22 +67,22 @@ export function toByte(data: string): number[] {

export function readHTTPInputFile(filename: string) {
const filePath = join(__dirname, "..", "..", "..", "examples", "http", filename);
let input: number[] = [];

let data = readFileSync(filePath, 'utf-8');

input = toByte(data);
let input = toByte(data);

// Split headers and body
const [headerSection, bodySection] = data.split('\r\n\r\n');
// Split headers and body, accounting for possible lack of body
const parts = data.split('\r\n\r\n');
const headerSection = parts[0];
const bodySection = parts.length > 1 ? parts[1] : '';

// Function to parse headers into a dictionary
function parseHeaders(headerLines: string[]) {
const headers: { [id: string]: string } = {};

headerLines.forEach(line => {
const [key, value] = line.split(/:\s(.+)/);
headers[key] = value ? value : '';
if (key) headers[key] = value ? value : '';
});

return headers;
@@ -95,8 +95,12 @@ export function readHTTPInputFile(filename: string) {

// Parse the body, if JSON response
let responseBody = {};
if (headers["Content-Type"] == "application/json") {
responseBody = JSON.parse(bodySection);
if (headers["Content-Type"] == "application/json" && bodySection) {
try {
responseBody = JSON.parse(bodySection);
} catch (e) {
console.error("Failed to parse JSON body:", e);
}
}

// Combine headers and body into an object
@@ -105,6 +109,6 @@
initialLine: initialLine,
headers: headers,
body: responseBody,
bodyBytes: toByte(bodySection),
bodyBytes: toByte(bodySection || ''),
};
}
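With these changes, readHTTPInputFile tolerates requests that have no body as well as bodies that are not valid JSON. A small usage sketch; the fixture filename is hypothetical but would live under examples/http/ per the path built above.

// Usage sketch; "get_request.http" is a hypothetical fixture under examples/http/.
import { readHTTPInputFile } from "./common"; // import path assumed

const parsed = readHTTPInputFile("get_request.http");

console.log(parsed.initialLine);      // e.g. "GET /api HTTP/1.1"
console.log(parsed.headers["Host"]);  // headers parsed into a { name: value } map
console.log(parsed.body);             // {} unless Content-Type is application/json and parseable
console.log(parsed.bodyBytes.length); // 0 when the request has no body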