Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Query input from URL #1508

Merged
2 changes: 2 additions & 0 deletions ipa-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ web-app = [
"rustls",
"rustls-pemfile",
"time",
"tiny_http",
"tokio-rustls",
"toml",
"tower",
Expand Down Expand Up @@ -140,6 +141,7 @@ thiserror = "1.0"
tikv-jemallocator = { version = "0.6", optional = true, features = ["profiling"] }
tikv-jemalloc-ctl = { version = "0.6", optional = true, features = ["stats"] }
time = { version = "0.3", optional = true }
tiny_http = { version = "0.12", optional = true }
tokio = { version = "1.42", features = ["fs", "rt", "rt-multi-thread", "macros"] }
tokio-rustls = { version = "0.26", optional = true }
tokio-stream = "0.1.14"
Expand Down
24 changes: 17 additions & 7 deletions ipa-core/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,24 @@
///
/// ## Errors
/// Propagates errors from the helper.
/// ## Panics
/// If `input` asks to obtain query input from a remote URL.
pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> {
let mpc_transport = self.inner.mpc_transport.clone_ref();
let shard_transport = self.inner.shard_transport.clone_ref();
self.inner
.query_processor
.receive_inputs(mpc_transport, shard_transport, input)?;
let QueryInput::Inline {
query_id,
input_stream,
} = input
else {
panic!("this client does not support pulling query input from a URL");
};
self.inner.query_processor.receive_inputs(
mpc_transport,
shard_transport,

Check warning on line 153 in ipa-core/src/app.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/app.rs#L153

Added line #L153 was not covered by tests
query_id,
input_stream,
)?;
Ok(())
}

Expand Down Expand Up @@ -254,10 +266,8 @@
HelperResponse::from(qp.receive_inputs(
Transport::clone_ref(&self.mpc_transport),
Transport::clone_ref(&self.shard_transport),
QueryInput {
query_id,
input_stream: data,
},
query_id,
data,
)?)
}
RouteId::QueryStatus => {
Expand Down
80 changes: 79 additions & 1 deletion ipa-core/src/bin/test_mpc.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
use std::{error::Error, fmt::Debug, ops::Add, path::PathBuf};
use std::{
error::Error,
fmt::Debug,
fs::File,
io::ErrorKind,
net::TcpListener,
ops::Add,
os::fd::{FromRawFd, RawFd},
path::PathBuf,
};

use clap::{Parser, Subcommand};
use generic_array::ArrayLength;
Expand All @@ -21,6 +30,8 @@
net::{Helper, IpaHttpClient},
secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares},
};
use tiny_http::{Response, ResponseBox, Server, StatusCode};
use tracing::{error, info};

#[derive(Debug, Parser)]
#[clap(
Expand Down Expand Up @@ -95,6 +106,23 @@
/// This is exactly what shuffle does and that's why it is picked
/// for this purpose.
ShardedShuffle,
ServeInput(ServeInputArgs),
}

// Command-line options for the `ServeInput` subcommand of this binary.
// Plain `//` comments are used for these notes on purpose: clap turns `///`
// comments on derive types into help text, so the existing `///` lines below
// are user-facing strings and are left untouched.
#[derive(Debug, clap::Args)]
#[clap(about = "Run a simple HTTP server to serve query input files")]
pub struct ServeInputArgs {
/// Port to listen on
#[arg(short, long)]
port: Option<u16>,

// Mutually exclusive with `port` (see `conflicts_with`). `serve_input` hands this
// descriptor to `TcpListener::from_raw_fd` without any validation.
/// Listen on the supplied prebound socket instead of binding a new socket
#[arg(long, conflicts_with = "port")]
fd: Option<RawFd>,

// Exposed on the command line as `--dir` (and `-d` via `short`).
/// Directory with input files to serve
#[arg(short, long = "dir")]
directory: PathBuf,

Check warning on line 125 in ipa-core/src/bin/test_mpc.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/test_mpc.rs#L125

Added line #L125 was not covered by tests
}

#[tokio::main]
Expand Down Expand Up @@ -129,6 +157,7 @@
.await;
sharded_shuffle(&args, clients).await
}
TestAction::ServeInput(options) => serve_input(options),
};

Ok(())
Expand Down Expand Up @@ -204,3 +233,52 @@
assert_eq!(shuffled.len(), input_rows.len());
assert_ne!(shuffled, input_rows);
}

/// Builds the boxed HTTP 404 response (plain-text body `not found`) that the
/// input server returns for every request it cannot satisfy.
fn not_found() -> ResponseBox {
    let status = StatusCode(404);
    let reply = Response::from_string("not found").with_status_code(status);
    reply.boxed()
}

Check warning on line 241 in ipa-core/src/bin/test_mpc.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/test_mpc.rs#L237-L241

Added lines #L237 - L241 were not covered by tests

#[tracing::instrument("serve_input", skip_all)]

Check warning on line 243 in ipa-core/src/bin/test_mpc.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/test_mpc.rs#L243

Added line #L243 was not covered by tests
/// Runs a blocking HTTP server that serves files from `args.directory`.
///
/// The listening socket is chosen in this order:
/// 1. `args.port` is set: bind `localhost:<port>`.
/// 2. `args.fd` is set: adopt the already-bound socket behind that descriptor.
/// 3. otherwise: bind `localhost:0`.
///
/// The request loop below has no exit condition, so this function does not
/// return normally.
///
/// ## Panics
/// If the server cannot be created, if the listening address is not an IP
/// address, or if receiving a request fails.
fn serve_input(args: ServeInputArgs) {
let server = if let Some(port) = args.port {
Server::http(("localhost", port)).unwrap()
} else if let Some(fd) = args.fd {
// NOTE(review): `from_raw_fd` is unsafe and nothing here validates `fd`. This relies
// on the caller passing a descriptor that is an open, listening TCP socket which is
// not owned by anything else in this process -- confirm against the launcher.
Server::from_listener(unsafe { TcpListener::from_raw_fd(fd) }, None).unwrap()
} else {
Server::http("localhost:0").unwrap()
};

// The port is logged whenever it was not given explicitly on the command line,
// which covers both the `fd` case and the `localhost:0` case.
if args.port.is_none() {
info!(
"Listening on :{}",
server.server_addr().to_ip().unwrap().port()
);
}

loop {
let request = server.recv().unwrap();
tracing::info!(target: "request_url", "{}", request.url());

// Drops the first byte of the request URL (presumably the leading `/` -- the slice
// would panic on an empty URL). Anything still containing a `/` is rejected, so only
// a single path component is ever joined onto the served directory.
let url = request.url()[1..].to_owned();
let response = if url.contains('/') {
error!(target: "error", "Request URL contains a slash");
not_found()
} else {
match File::open(args.directory.join(&url)) {
Ok(file) => Response::from_file(file).boxed(),
Err(err) => {
// A missing file is an expected outcome and is not logged; every other
// open failure is logged. Both cases answer with the same 404.
if err.kind() != ErrorKind::NotFound {
error!(target: "error", "{err}");
}
not_found()
}
}
};

// Failing to send the response is logged and otherwise ignored so that the
// server keeps handling subsequent requests.
let _ = request.respond(response).map_err(|err| {
error!(target: "error", "{err}");
});

Check warning on line 282 in ipa-core/src/bin/test_mpc.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/bin/test_mpc.rs#L280-L282

Added lines #L280 - L282 were not covered by tests
}
}
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/add.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ where
.into_iter()
.zip(clients)
.map(|(input_stream, client)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {
query_id,
input_stream,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/hybrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
|(shard_clients, shard_inputs)| {
try_join_all(shard_clients.iter().zip(shard_inputs.into_iter()).map(
|(client, input)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {

Check warning on line 47 in ipa-core/src/cli/playbook/hybrid.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/cli/playbook/hybrid.rs#L47

Added line #L47 was not covered by tests
query_id,
input_stream: input,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/ipa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ where
.into_iter()
.zip(clients)
.map(|(input_stream, client)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {
query_id,
input_stream,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/multiply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ where
.into_iter()
.zip(clients)
.map(|(input_stream, client)| {
client.query_input(QueryInput {
client.query_input(QueryInput::Inline {
query_id,
input_stream,
})
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/cli/playbook/sharded_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ where
let shared = chunk.iter().copied().share();
try_join_all(mpc_clients.each_ref().iter().zip(shared).map(
|(mpc_client, input)| {
mpc_client.query_input(QueryInput {
mpc_client.query_input(QueryInput::Inline {
query_id,
input_stream: BodyStream::from_serializable_iter(input),
})
Expand Down
52 changes: 48 additions & 4 deletions ipa-core/src/helpers/transport/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,58 @@
}
}

pub struct QueryInput {
pub query_id: QueryId,
pub input_stream: BodyStream,
pub enum QueryInput {
FromUrl {
query_id: QueryId,
url: String,
},
Inline {
query_id: QueryId,
input_stream: BodyStream,
},
}

impl QueryInput {
    /// Returns the id of the query this input belongs to, regardless of variant.
    #[must_use]
    pub fn query_id(&self) -> QueryId {
        // Both variants carry a `query_id`, so this or-pattern is irrefutable.
        let (Self::FromUrl { query_id, .. } | Self::Inline { query_id, .. }) = self;
        *query_id
    }

    /// Consumes the input and returns its body stream.
    ///
    /// Yields `None` for the `FromUrl` variant, which carries no stream.
    #[must_use]
    pub fn input_stream(self) -> Option<BodyStream> {
        if let Self::Inline { input_stream, .. } = self {
            Some(input_stream)
        } else {
            None
        }
    }

    /// Returns the URL stored in the `FromUrl` variant.
    ///
    /// Yields `None` for the `Inline` variant.
    #[must_use]
    pub fn url(&self) -> Option<&str> {
        if let Self::FromUrl { url, .. } = self {
            Some(url.as_str())
        } else {
            None
        }
    }
}

impl Debug for QueryInput {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "query_inputs[{:?}]", self.query_id)
match self {

Check warning on line 225 in ipa-core/src/helpers/transport/query/mod.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/helpers/transport/query/mod.rs#L225

Added line #L225 was not covered by tests
QueryInput::Inline {
query_id,
input_stream: _,
} => f
.debug_struct("QueryInput::Inline")
.field("query_id", query_id)
.finish(),
QueryInput::FromUrl { query_id, url } => f
.debug_struct("QueryInput::FromUrl")
.field("query_id", query_id)
.field("url", url)
.finish(),

Check warning on line 237 in ipa-core/src/helpers/transport/query/mod.rs

View check run for this annotation

Codecov / codecov/patch

ipa-core/src/helpers/transport/query/mod.rs#L227-L237

Added lines #L227 - L237 were not covered by tests
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/net/client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -726,7 +726,7 @@ pub(crate) mod tests {
};
test_query_command(
|client| async move {
let data = QueryInput {
let data = QueryInput::Inline {
query_id: expected_query_id,
input_stream: expected_input.to_vec().into(),
};
Expand Down
64 changes: 56 additions & 8 deletions ipa-core/src/net/http_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -312,12 +312,23 @@ pub mod query {
}

pub mod input {
use axum::{body::Body, http::uri};
use hyper::header::CONTENT_TYPE;
use axum::{
async_trait,
body::Body,
extract::FromRequestParts,
http::{request::Parts, uri},
};
use hyper::{
header::{HeaderValue, CONTENT_TYPE},
Uri,
};

use crate::{
helpers::query::QueryInput,
net::{http_serde::query::BASE_AXUM_PATH, APPLICATION_OCTET_STREAM},
net::{
http_serde::query::BASE_AXUM_PATH, Error, APPLICATION_OCTET_STREAM,
HTTP_QUERY_INPUT_URL_HEADER,
},
};

#[derive(Debug)]
Expand All @@ -341,17 +352,54 @@ pub mod query {
.path_and_query(format!(
"{}/{}/input",
BASE_AXUM_PATH,
self.query_input.query_id.as_ref(),
self.query_input.query_id().as_ref(),
))
.build()?;
let body = Body::from_stream(self.query_input.input_stream);
Ok(hyper::Request::post(uri)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(body)?)
let query_input_url = self.query_input.url().map(ToOwned::to_owned);
let body = self
.query_input
.input_stream()
.map_or_else(Body::empty, Body::from_stream);
let mut request =
hyper::Request::post(uri).header(CONTENT_TYPE, APPLICATION_OCTET_STREAM);
if let Some(url) = query_input_url {
request.headers_mut().unwrap().insert(
&HTTP_QUERY_INPUT_URL_HEADER,
HeaderValue::try_from(url).unwrap(),
);
}
Ok(request.body(body)?)
}
}

pub const AXUM_PATH: &str = "/:query_id/input";

/// Extractor for the optional `HTTP_QUERY_INPUT_URL_HEADER` request header.
///
/// Wraps `Some(uri)` when the header is present and parses as a [`Uri`], and
/// `None` when the header is absent.
pub struct QueryInputUrl(Option<Uri>);

#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for QueryInputUrl {
type Rejection = Error;

async fn from_request_parts(
req: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
match req.headers.get(&HTTP_QUERY_INPUT_URL_HEADER) {
None => Ok(QueryInputUrl(None)),
Some(value) => {
let value_str = value.to_str()?;
let uri = value_str.parse()?;
Ok(QueryInputUrl(Some(uri)))
}
}
}
}

impl From<QueryInputUrl> for Option<Uri> {
fn from(value: QueryInputUrl) -> Self {
value.0
}
}
}

pub mod step {
Expand Down
2 changes: 2 additions & 0 deletions ipa-core/src/net/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{
mod client;
mod error;
mod http_serde;
pub mod query_input;
mod server;
#[cfg(all(test, not(feature = "shuttle")))]
pub mod test;
Expand All @@ -32,6 +33,7 @@ const APPLICATION_JSON: &str = "application/json";
const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";
static HTTP_HELPER_ID_HEADER: HeaderName = HeaderName::from_static("x-unverified-helper-identity");
static HTTP_SHARD_INDEX_HEADER: HeaderName = HeaderName::from_static("x-unverified-shard-index");
static HTTP_QUERY_INPUT_URL_HEADER: HeaderName = HeaderName::from_static("x-query-input-url");

/// This has the same meaning as const defined in h2 crate, but we don't import it directly.
/// According to the [`spec`] it cannot exceed 2^31 - 1.
Expand Down
Loading
Loading