Skip to content

Commit

Permalink
differentiate reponse and request
Browse files Browse the repository at this point in the history
  • Loading branch information
Autoparallel committed Sep 10, 2024
1 parent 12a6651 commit d1c6b92
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 44 deletions.
9 changes: 9 additions & 0 deletions examples/lockfile/request.lock.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
{
"method": "GET",
"target": "/api",
"version": "HTTP/1.1",
"headerName1": "Host",
"headerValue1": "localhost",
"headerName2": "Accept",
"headerValue2": "application/json"
}
7 changes: 7 additions & 0 deletions examples/lockfile/response.lock.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"version": "HTTP/1.1",
"status": "200",
"message": "OK",
"headerName1": "Content-Type",
"headerValue1": "application/json"
}
28 changes: 0 additions & 28 deletions examples/lockfile/test.lock.json

This file was deleted.

78 changes: 62 additions & 16 deletions src/http_lock.rs
Original file line number Diff line number Diff line change
@@ -1,32 +1,67 @@
use super::*;
use std::{
collections::HashMap,
fs::{self, create_dir_all},
};

#[derive(Debug, Serialize, Deserialize)]
struct HttpData {
request: Request,
response: Response,
#[serde(untagged)]
pub enum HttpData {
Request(Request),
Response(Response),
}

#[derive(Debug, Serialize, Deserialize)]
struct Request {
pub struct Request {
method: String,
target: String,
version: String,
headers: Vec<(String, String)>,
#[serde(flatten)]
#[serde(deserialize_with = "deserialize_headers")]
headers: HashMap<String, String>,
}

#[derive(Debug, Serialize, Deserialize)]
struct Response {
pub struct Response {
version: String,
status: String,
message: String,
headers: Vec<(String, String)>,
#[serde(flatten)]
#[serde(deserialize_with = "deserialize_headers")]
headers: HashMap<String, String>,
}

use std::fs::{self, create_dir_all};
impl HttpData {
fn headers(&self) -> HashMap<String, String> {
match self {
HttpData::Request(request) => request.headers.clone(),
HttpData::Response(response) => response.headers.clone(),
}
}
}

fn deserialize_headers<'de, D>(deserializer: D) -> Result<HashMap<String, String>, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut map = HashMap::new();
let mut temp_map: HashMap<String, String> = HashMap::deserialize(deserializer)?;

let mut i = 1;
while let (Some(name), Some(value)) = (
temp_map.remove(&format!("headerName{}", i)),
temp_map.remove(&format!("headerValue{}", i)),
) {
map.insert(name, value);
i += 1;
}

Ok(map)
}

const PRAGMA: &str = "pragma circom 2.1.9;\n\n";

fn request_locker_circuit(
fn locker_circuit(
data: HttpData,
output_filename: String,
) -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -41,8 +76,18 @@ fn request_locker_circuit(

// template LockHTTP(DATA_BYTES, beginningLen, middleLen, finalLen, headerNameLen1, headerValueLen1, ...) {
{
circuit_buffer += "template LockHTTP(DATA_BYTES, beginningLen, middleLen, finalLen ";
for (i, _header) in data.request.headers.iter().enumerate() {
match data {
HttpData::Request(_) => {
circuit_buffer +=
"template LockHTTPRequest(DATA_BYTES, methodLen, targetLen, versionLen";
}
HttpData::Response(_) => {
circuit_buffer +=
"template LockHTTPResponse(DATA_BYTES, versionLen, statusLen, messageLen";
}
}

for (i, _header) in data.headers().iter().enumerate() {
circuit_buffer += &format!(", headerNameLen{}, headerValueLen{}", i + 1, i + 1);
}
circuit_buffer += ") {";
Expand All @@ -66,7 +111,7 @@ fn request_locker_circuit(
// Header signals
"#;

for (i, _header) in data.request.headers.iter().enumerate() {
for (i, _header) in data.headers().iter().enumerate() {
circuit_buffer += &format!(
" signal input header{}[headerNameLen{}];\n",
i + 1,
Expand Down Expand Up @@ -114,7 +159,7 @@ fn request_locker_circuit(

// Create header match signals
{
for (i, _header) in data.request.headers.iter().enumerate() {
for (i, _header) in data.headers().iter().enumerate() {
circuit_buffer += &format!(" signal headerNameValueMatch{}[DATA_BYTES];\n", i + 1);
circuit_buffer += &format!(" headerNameValueMatch{}[0] <== 0;\n", i + 1);
circuit_buffer += &format!(" var hasMatchedHeaderValue{} = 0;\n\n", i + 1);
Expand Down Expand Up @@ -156,7 +201,7 @@ fn request_locker_circuit(

// Header matches
{
for (i, _header) in data.request.headers.iter().enumerate() {
for (i, _header) in data.headers().iter().enumerate() {
circuit_buffer += &format!(" headerNameValueMatch{}[data_idx] <== HeaderFieldNameValueMatch(DATA_BYTES, headerNameLen{}, headerValueLen{})(data, header{}, value{}, 100, data_idx);\n", i + 1,i + 1,i + 1,i + 1,i + 1);
circuit_buffer += &format!(
" hasMatchedHeaderValue{} += headerNameValueMatch{}[data_idx];\n",
Expand Down Expand Up @@ -215,7 +260,7 @@ fn request_locker_circuit(

// Verify all headers have matched
{
for (i, _header) in data.request.headers.iter().enumerate() {
for (i, _header) in data.headers().iter().enumerate() {
circuit_buffer += &format!(" hasMatchedHeaderValue{} === 1;\n", i + 1);
}
}
Expand All @@ -242,9 +287,10 @@ fn request_locker_circuit(
// TODO: This needs to codegen a circuit now.
pub fn http_lock(args: HttpLockArgs) -> Result<(), Box<dyn Error>> {
let data = std::fs::read(&args.lockfile)?;

let http_data: HttpData = serde_json::from_slice(&data)?;

request_locker_circuit(http_data, args.output_filename)?;
locker_circuit(http_data, args.output_filename)?;

Ok(())
}

0 comments on commit d1c6b92

Please sign in to comment.