Skip to content

Commit

Permalink
Query input from URL (#1508)
Browse files Browse the repository at this point in the history
Support for helper to pull query input from a URL, rather than receiving it directly from the client in an HTTP request body.

Also adds a simple HTTP server in test_mpc to serve local files, for testing purposes.
  • Loading branch information
andyleiserson authored Dec 20, 2024
1 parent 1cdc2f7 commit fc61e53
Show file tree
Hide file tree
Showing 18 changed files with 373 additions and 78 deletions.
2 changes: 2 additions & 0 deletions ipa-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ web-app = [
"rustls",
"rustls-pemfile",
"time",
"tiny_http",
"tokio-rustls",
"toml",
"tower",
Expand Down Expand Up @@ -141,6 +142,7 @@ thiserror = "1.0"
tikv-jemallocator = { version = "0.6", optional = true, features = ["profiling"] }
tikv-jemalloc-ctl = { version = "0.6", optional = true, features = ["stats"] }
time = { version = "0.3", optional = true }
tiny_http = { version = "0.12", optional = true }
tokio = { version = "1.42", features = ["fs", "rt", "rt-multi-thread", "macros"] }
tokio-rustls = { version = "0.26", optional = true }
tokio-stream = "0.1.14"
Expand Down
24 changes: 17 additions & 7 deletions ipa-core/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,24 @@ impl HelperApp {
///
/// ## Errors
/// Propagates errors from the helper.
/// ## Panics
/// If `input` asks to obtain query input from a remote URL.
pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> {
let mpc_transport = self.inner.mpc_transport.clone_ref();
let shard_transport = self.inner.shard_transport.clone_ref();
self.inner
.query_processor
.receive_inputs(mpc_transport, shard_transport, input)?;
let QueryInput::Inline {
query_id,
input_stream,
} = input
else {
panic!("this client does not support pulling query input from a URL");
};
self.inner.query_processor.receive_inputs(
mpc_transport,
shard_transport,
query_id,
input_stream,
)?;
Ok(())
}

Expand Down Expand Up @@ -258,10 +270,8 @@ impl RequestHandler<HelperIdentity> for Inner {
HelperResponse::from(qp.receive_inputs(
Transport::clone_ref(&self.mpc_transport),
Transport::clone_ref(&self.shard_transport),
QueryInput {
query_id,
input_stream: data,
},
query_id,
data,
)?)
}
RouteId::QueryStatus => {
Expand Down
80 changes: 79 additions & 1 deletion ipa-core/src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
use std::{error::Error, fmt::Debug, ops::Add, path::PathBuf};
use std::{
error::Error,
fmt::Debug,
fs::File,
io::ErrorKind,
net::TcpListener,
ops::Add,
os::fd::{FromRawFd, RawFd},
path::PathBuf,
};

use clap::{Parser, Subcommand};
use generic_array::ArrayLength;
Expand All @@ -21,6 +30,8 @@ use ipa_core::{
net::{Helper, IpaHttpClient},
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
};
use tiny_http::{Response, ResponseBox, Server, StatusCode};
use tracing::{error, info};

#[derive(Debug, Parser)]
#[clap(
Expand Down Expand Up @@ -95,6 +106,23 @@ enum TestAction {
/// This is exactly what shuffle does and that's why it is picked
/// for this purpose.
ShardedShuffle,
ServeInput(ServeInputArgs),
}

// Command-line arguments for the `ServeInput` test action.
// The notes added below are plain `//` comments (not `///`) so that the help
// text clap derives from doc comments stays exactly as written.
#[derive(Debug, clap::Args)]
#[clap(about = "Run a simple HTTP server to serve query input files")]
pub struct ServeInputArgs {
/// Port to listen on
// Optional. Mutually exclusive with `fd` (see `conflicts_with` on that field).
#[arg(short, long)]
port: Option<u16>,

/// Listen on the supplied prebound socket instead of binding a new socket
// The raw descriptor is turned into a `TcpListener` by `serve_input`.
#[arg(long, conflicts_with = "port")]
fd: Option<RawFd>,

/// Directory with input files to serve
// Required. Exposed on the command line as `-d` / `--dir`.
#[arg(short, long = "dir")]
directory: PathBuf,
}

#[tokio::main]
Expand Down Expand Up @@ -129,6 +157,7 @@ async fn main() -> Result<(), Box<dyn Error>> {
.await;
sharded_shuffle(&args, clients).await
}
TestAction::ServeInput(options) => serve_input(options),
};

Ok(())
Expand Down Expand Up @@ -204,3 +233,52 @@ async fn sharded_shuffle(args: &Args, helper_clients: Vec<[IpaHttpClient<Helper>
assert_eq!(shuffled.len(), input_rows.len());
assert_ne!(shuffled, input_rows);
}

/// Builds the boxed plain-text reply used for every request that cannot be
/// served: body `not found`, HTTP status 404.
fn not_found() -> ResponseBox {
    let reply = Response::from_string("not found");
    reply.with_status_code(StatusCode(404)).boxed()
}

/// Runs a blocking HTTP server that serves the files found directly inside
/// `args.directory`. This function never returns.
///
/// The listening socket is chosen as follows:
/// * `args.port` is set: bind `localhost:<port>`;
/// * otherwise, `args.fd` is set: adopt that prebound listening socket;
/// * otherwise: bind `localhost:0`, i.e. an OS-assigned port.
///
/// Whenever the port was not given explicitly, the actual port is logged.
///
/// A request for `/<name>` is answered with the contents of
/// `<directory>/<name>`. Anything else gets a 404: a target that does not
/// start with `/`, a name containing a further `/`, a missing file, or a name
/// that does not resolve to a regular file (empty name, `.`, `..`, a
/// subdirectory).
///
/// ## Panics
/// If the socket cannot be bound or adopted, if the server address is not an
/// IP address, or if receiving a request fails.
#[tracing::instrument("serve_input", skip_all)]
fn serve_input(args: ServeInputArgs) {
    let server = if let Some(port) = args.port {
        Server::http(("localhost", port)).unwrap()
    } else if let Some(fd) = args.fd {
        // SAFETY: `from_raw_fd` takes ownership of the descriptor. Whoever
        // passes `--fd` must hand over an open, listening TCP socket that is
        // not used or closed anywhere else in this process.
        let listener = unsafe { TcpListener::from_raw_fd(fd) };
        Server::from_listener(listener, None).unwrap()
    } else {
        Server::http("localhost:0").unwrap()
    };

    if args.port.is_none() {
        info!(
            "Listening on :{}",
            server.server_addr().to_ip().unwrap().port()
        );
    }

    loop {
        let request = server.recv().unwrap();
        tracing::info!(target: "request_url", "{}", request.url());

        // `strip_prefix` instead of `[1..]`: slicing would panic on an empty
        // request target and would silently drop the first character of a
        // target that does not start with `/`.
        let response = match request.url().strip_prefix('/') {
            Some(name) if name.contains('/') => {
                error!(target: "error", "Request URL contains a slash");
                not_found()
            }
            Some(name) => match File::open(args.directory.join(name)) {
                // Opening a directory succeeds on Unix, so make sure we really
                // got a regular file before streaming it back.
                Ok(file) if matches!(file.metadata(), Ok(meta) if meta.is_file()) => {
                    Response::from_file(file).boxed()
                }
                Ok(_) => not_found(),
                Err(err) => {
                    if err.kind() != ErrorKind::NotFound {
                        error!(target: "error", "{err}");
                    }
                    not_found()
                }
            },
            None => not_found(),
        };

        // Best effort: a client that went away must not take the server down.
        if let Err(err) = request.respond(response) {
            error!(target: "error", "{err}");
        }
    }
}
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
.into_iter()
.zip(clients)
.map(|(input_stream, client)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {
query_id,
input_stream,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ where
|(shard_clients, shard_inputs)| {
try_join_all(shard_clients.iter().zip(shard_inputs.into_iter()).map(
|(client, input)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {
query_id,
input_stream: input,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ where
.into_iter()
.zip(clients)
.map(|(input_stream, client)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {
query_id,
input_stream,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/multiply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ where
.into_iter()
.zip(clients)
.map(|(input_stream, client)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {
query_id,
input_stream,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/sharded_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ where
let shared = chunk.iter().copied().share();
try_join_all(mpc_clients.each_ref().iter().zip(shared).map(
|(mpc_client, input)| {
mpc_client.query_input(QueryInput {
mpc_client.query_input(QueryInput::Inline {
query_id,
input_stream: BodyStream::from_serializable_iter(input),
})
Expand Down
52 changes: 48 additions & 4 deletions ipa-core/src/helpers/transport/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,58 @@ impl RouteParams<RouteId, QueryId, NoStep> for &PrepareQuery {
}
}

pub struct QueryInput {
pub query_id: QueryId,
pub input_stream: BodyStream,
/// Input for a query. Both variants carry the `query_id` the input belongs to.
pub enum QueryInput {
/// The input is referenced by `url` instead of being carried in this value.
FromUrl {
query_id: QueryId,
url: String,
},
/// The input is carried directly, as a `BodyStream`.
Inline {
query_id: QueryId,
input_stream: BodyStream,
},
}

impl QueryInput {
    /// Returns the query identifier, which is present in every variant.
    #[must_use]
    pub fn query_id(&self) -> QueryId {
        let (Self::FromUrl { query_id, .. } | Self::Inline { query_id, .. }) = self;
        *query_id
    }

    /// Consumes the value and yields the inline input stream, or `None` for
    /// the `FromUrl` variant.
    #[must_use]
    pub fn input_stream(self) -> Option<BodyStream> {
        if let Self::Inline { input_stream, .. } = self {
            Some(input_stream)
        } else {
            None
        }
    }

    /// Returns the input URL, or `None` for the `Inline` variant.
    #[must_use]
    pub fn url(&self) -> Option<&str> {
        if let Self::FromUrl { url, .. } = self {
            Some(url.as_str())
        } else {
            None
        }
    }
}

impl Debug for QueryInput {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "query_inputs[{:?}]", self.query_id)
match self {
QueryInput::Inline {
query_id,
input_stream: _,
} => f
.debug_struct("QueryInput::Inline")
.field("query_id", query_id)
.finish(),
QueryInput::FromUrl { query_id, url } => f
.debug_struct("QueryInput::FromUrl")
.field("query_id", query_id)
.field("url", url)
.finish(),
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ pub(crate) mod tests {
};
test_query_command(
|client| async move {
let data = QueryInput {
let data = QueryInput::Inline {
query_id: expected_query_id,
input_stream: expected_input.to_vec().into(),
};
Expand Down
64 changes: 56 additions & 8 deletions ipa-core/src/net/http_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,12 +322,23 @@ pub mod query {
}

pub mod input {
use axum::{body::Body, http::uri};
use hyper::header::CONTENT_TYPE;
use axum::{
async_trait,
body::Body,
extract::FromRequestParts,
http::{request::Parts, uri},
};
use hyper::{
header::{HeaderValue, CONTENT_TYPE},
Uri,
};

use crate::{
helpers::query::QueryInput,
net::{http_serde::query::BASE_AXUM_PATH, APPLICATION_OCTET_STREAM},
net::{
http_serde::query::BASE_AXUM_PATH, Error, APPLICATION_OCTET_STREAM,
HTTP_QUERY_INPUT_URL_HEADER,
},
};

#[derive(Debug)]
Expand All @@ -351,17 +362,54 @@ pub mod query {
.path_and_query(format!(
"{}/{}/input",
BASE_AXUM_PATH,
self.query_input.query_id.as_ref(),
self.query_input.query_id().as_ref(),
))
.build()?;
let body = Body::from_stream(self.query_input.input_stream);
Ok(hyper::Request::post(uri)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(body)?)
let query_input_url = self.query_input.url().map(ToOwned::to_owned);
let body = self
.query_input
.input_stream()
.map_or_else(Body::empty, Body::from_stream);
let mut request =
hyper::Request::post(uri).header(CONTENT_TYPE, APPLICATION_OCTET_STREAM);
if let Some(url) = query_input_url {
request.headers_mut().unwrap().insert(
&HTTP_QUERY_INPUT_URL_HEADER,
HeaderValue::try_from(url).unwrap(),
);
}
Ok(request.body(body)?)
}
}

pub const AXUM_PATH: &str = "/:query_id/input";

/// Request-parts extractor for the `HTTP_QUERY_INPUT_URL_HEADER` header:
/// holds the header value parsed as a `Uri`, or `None` when the header is
/// absent from the request.
pub struct QueryInputUrl(Option<Uri>);

#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for QueryInputUrl {
type Rejection = Error;

async fn from_request_parts(
req: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
match req.headers.get(&HTTP_QUERY_INPUT_URL_HEADER) {
None => Ok(QueryInputUrl(None)),
Some(value) => {
let value_str = value.to_str()?;
let uri = value_str.parse()?;
Ok(QueryInputUrl(Some(uri)))
}
}
}
}

impl From<QueryInputUrl> for Option<Uri> {
fn from(value: QueryInputUrl) -> Self {
value.0
}
}
}

pub mod step {
Expand Down
2 changes: 2 additions & 0 deletions ipa-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
mod client;
mod error;
mod http_serde;
pub mod query_input;
mod server;
#[cfg(all(test, not(feature = "shuttle")))]
pub mod test;
Expand All @@ -32,6 +33,7 @@ const APPLICATION_JSON: &str = "application/json";
const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
static HTTP_HELPER_ID_HEADER: HeaderName = HeaderName::from_static("x-unverified-helper-identity");
static HTTP_SHARD_INDEX_HEADER: HeaderName = HeaderName::from_static("x-unverified-shard-index");
static HTTP_QUERY_INPUT_URL_HEADER: HeaderName = HeaderName::from_static("x-query-input-url");

/// This has the same meaning as const defined in h2 crate, but we don't import it directly.
/// According to the [`spec`] it cannot exceed 2^31 - 1.
Expand Down
Loading

0 comments on commit fc61e53

Please sign in to comment.