diff --git a/ipa-core/Cargo.toml b/ipa-core/Cargo.toml index 1b273521f..c4a769221 100644 --- a/ipa-core/Cargo.toml +++ b/ipa-core/Cargo.toml @@ -41,6 +41,7 @@ web-app = [ "rustls", "rustls-pemfile", "time", + "tiny_http", "tokio-rustls", "toml", "tower", @@ -140,6 +141,7 @@ thiserror = "1.0" tikv-jemallocator = { version = "0.6", optional = true, features = ["profiling"] } tikv-jemalloc-ctl = { version = "0.6", optional = true, features = ["stats"] } time = { version = "0.3", optional = true } +tiny_http = { version = "0.12", optional = true } tokio = { version = "1.42", features = ["fs", "rt", "rt-multi-thread", "macros"] } tokio-rustls = { version = "0.26", optional = true } tokio-stream = "0.1.14" diff --git a/ipa-core/src/app.rs b/ipa-core/src/app.rs index 0d61e9bb6..715cc4d0d 100644 --- a/ipa-core/src/app.rs +++ b/ipa-core/src/app.rs @@ -136,12 +136,24 @@ impl HelperApp { /// /// ## Errors /// Propagates errors from the helper. + /// ## Panics + /// If `input` asks to obtain query input from a remote URL. pub fn execute_query(&self, input: QueryInput) -> Result<(), ApiError> { let mpc_transport = self.inner.mpc_transport.clone_ref(); let shard_transport = self.inner.shard_transport.clone_ref(); - self.inner - .query_processor - .receive_inputs(mpc_transport, shard_transport, input)?; + let QueryInput::Inline { + query_id, + input_stream, + } = input + else { + panic!("this client does not support pulling query input from a URL"); + }; + self.inner.query_processor.receive_inputs( + mpc_transport, + shard_transport, + query_id, + input_stream, + )?; Ok(()) } @@ -254,10 +266,8 @@ impl RequestHandler for Inner { HelperResponse::from(qp.receive_inputs( Transport::clone_ref(&self.mpc_transport), Transport::clone_ref(&self.shard_transport), - QueryInput { - query_id, - input_stream: data, - }, + query_id, + data, )?) } RouteId::QueryStatus => { diff --git a/ipa-core/src/bin/test_mpc.rs b/ipa-core/src/bin/test_mpc.rs index baf99a2ca..3735c4c18 100644 --- a/ipa-core/src/bin/test_mpc.rs +++ b/ipa-core/src/bin/test_mpc.rs @@ -1,4 +1,13 @@ -use std::{error::Error, fmt::Debug, ops::Add, path::PathBuf}; +use std::{ + error::Error, + fmt::Debug, + fs::File, + io::ErrorKind, + net::TcpListener, + ops::Add, + os::fd::{FromRawFd, RawFd}, + path::PathBuf, +}; use clap::{Parser, Subcommand}; use generic_array::ArrayLength; @@ -21,6 +30,8 @@ use ipa_core::{ net::{Helper, IpaHttpClient}, secret_sharing::{replicated::semi_honest::AdditiveShare, IntoShares}, }; +use tiny_http::{Response, ResponseBox, Server, StatusCode}; +use tracing::{error, info}; #[derive(Debug, Parser)] #[clap( @@ -95,6 +106,23 @@ enum TestAction { /// This is exactly what shuffle does and that's why it is picked /// for this purpose. ShardedShuffle, + ServeInput(ServeInputArgs), +} + +#[derive(Debug, clap::Args)] +#[clap(about = "Run a simple HTTP server to serve query input files")] +pub struct ServeInputArgs { + /// Port to listen on + #[arg(short, long)] + port: Option, + + /// Listen on the supplied prebound socket instead of binding a new socket + #[arg(long, conflicts_with = "port")] + fd: Option, + + /// Directory with input files to serve + #[arg(short, long = "dir")] + directory: PathBuf, } #[tokio::main] @@ -129,6 +157,7 @@ async fn main() -> Result<(), Box> { .await; sharded_shuffle(&args, clients).await } + TestAction::ServeInput(options) => serve_input(options), }; Ok(()) @@ -204,3 +233,52 @@ async fn sharded_shuffle(args: &Args, helper_clients: Vec<[IpaHttpClient assert_eq!(shuffled.len(), input_rows.len()); assert_ne!(shuffled, input_rows); } + +fn not_found() -> ResponseBox { + Response::from_string("not found") + .with_status_code(StatusCode(404)) + .boxed() +} + +#[tracing::instrument("serve_input", skip_all)] +fn serve_input(args: ServeInputArgs) { + let server = if let Some(port) = args.port { + Server::http(("localhost", port)).unwrap() + } else if let Some(fd) = args.fd { + Server::from_listener(unsafe { TcpListener::from_raw_fd(fd) }, None).unwrap() + } else { + Server::http("localhost:0").unwrap() + }; + + if args.port.is_none() { + info!( + "Listening on :{}", + server.server_addr().to_ip().unwrap().port() + ); + } + + loop { + let request = server.recv().unwrap(); + tracing::info!(target: "request_url", "{}", request.url()); + + let url = request.url()[1..].to_owned(); + let response = if url.contains('/') { + error!(target: "error", "Request URL contains a slash"); + not_found() + } else { + match File::open(args.directory.join(&url)) { + Ok(file) => Response::from_file(file).boxed(), + Err(err) => { + if err.kind() != ErrorKind::NotFound { + error!(target: "error", "{err}"); + } + not_found() + } + } + }; + + let _ = request.respond(response).map_err(|err| { + error!(target: "error", "{err}"); + }); + } +} diff --git a/ipa-core/src/cli/playbook/add.rs b/ipa-core/src/cli/playbook/add.rs index eafa1da8d..5c785bd1b 100644 --- a/ipa-core/src/cli/playbook/add.rs +++ b/ipa-core/src/cli/playbook/add.rs @@ -47,7 +47,7 @@ where .into_iter() .zip(clients) .map(|(input_stream, client)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/cli/playbook/hybrid.rs b/ipa-core/src/cli/playbook/hybrid.rs index 92d74b383..53bfc6c28 100644 --- a/ipa-core/src/cli/playbook/hybrid.rs +++ b/ipa-core/src/cli/playbook/hybrid.rs @@ -44,7 +44,7 @@ where |(shard_clients, shard_inputs)| { try_join_all(shard_clients.iter().zip(shard_inputs.into_iter()).map( |(client, input)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream: input, }) diff --git a/ipa-core/src/cli/playbook/ipa.rs b/ipa-core/src/cli/playbook/ipa.rs index 6f3691306..b294d7695 100644 --- a/ipa-core/src/cli/playbook/ipa.rs +++ b/ipa-core/src/cli/playbook/ipa.rs @@ -115,7 +115,7 @@ where .into_iter() .zip(clients) .map(|(input_stream, client)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/cli/playbook/multiply.rs b/ipa-core/src/cli/playbook/multiply.rs index ec777005a..265659c85 100644 --- a/ipa-core/src/cli/playbook/multiply.rs +++ b/ipa-core/src/cli/playbook/multiply.rs @@ -55,7 +55,7 @@ where .into_iter() .zip(clients) .map(|(input_stream, client)| { - client.query_input(QueryInput { + client.query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/cli/playbook/sharded_shuffle.rs b/ipa-core/src/cli/playbook/sharded_shuffle.rs index 0139a8171..fceb1f4e4 100644 --- a/ipa-core/src/cli/playbook/sharded_shuffle.rs +++ b/ipa-core/src/cli/playbook/sharded_shuffle.rs @@ -45,7 +45,7 @@ where let shared = chunk.iter().copied().share(); try_join_all(mpc_clients.each_ref().iter().zip(shared).map( |(mpc_client, input)| { - mpc_client.query_input(QueryInput { + mpc_client.query_input(QueryInput::Inline { query_id, input_stream: BodyStream::from_serializable_iter(input), }) diff --git a/ipa-core/src/helpers/transport/query/mod.rs b/ipa-core/src/helpers/transport/query/mod.rs index c1daa7c71..cd3e389d1 100644 --- a/ipa-core/src/helpers/transport/query/mod.rs +++ b/ipa-core/src/helpers/transport/query/mod.rs @@ -184,14 +184,58 @@ impl RouteParams for &PrepareQuery { } } -pub struct QueryInput { - pub query_id: QueryId, - pub input_stream: BodyStream, +pub enum QueryInput { + FromUrl { + query_id: QueryId, + url: String, + }, + Inline { + query_id: QueryId, + input_stream: BodyStream, + }, +} + +impl QueryInput { + #[must_use] + pub fn query_id(&self) -> QueryId { + match self { + Self::FromUrl { query_id, .. } | Self::Inline { query_id, .. } => *query_id, + } + } + + #[must_use] + pub fn input_stream(self) -> Option { + match self { + Self::Inline { input_stream, .. } => Some(input_stream), + Self::FromUrl { .. } => None, + } + } + + #[must_use] + pub fn url(&self) -> Option<&str> { + match self { + Self::FromUrl { url, .. } => Some(url), + Self::Inline { .. } => None, + } + } } impl Debug for QueryInput { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "query_inputs[{:?}]", self.query_id) + match self { + QueryInput::Inline { + query_id, + input_stream: _, + } => f + .debug_struct("QueryInput::Inline") + .field("query_id", query_id) + .finish(), + QueryInput::FromUrl { query_id, url } => f + .debug_struct("QueryInput::FromUrl") + .field("query_id", query_id) + .field("url", url) + .finish(), + } } } diff --git a/ipa-core/src/net/client/mod.rs b/ipa-core/src/net/client/mod.rs index d4789b198..42d7c1377 100644 --- a/ipa-core/src/net/client/mod.rs +++ b/ipa-core/src/net/client/mod.rs @@ -726,7 +726,7 @@ pub(crate) mod tests { }; test_query_command( |client| async move { - let data = QueryInput { + let data = QueryInput::Inline { query_id: expected_query_id, input_stream: expected_input.to_vec().into(), }; diff --git a/ipa-core/src/net/http_serde.rs b/ipa-core/src/net/http_serde.rs index c7c3eb1a4..8cc2be8ed 100644 --- a/ipa-core/src/net/http_serde.rs +++ b/ipa-core/src/net/http_serde.rs @@ -312,12 +312,23 @@ pub mod query { } pub mod input { - use axum::{body::Body, http::uri}; - use hyper::header::CONTENT_TYPE; + use axum::{ + async_trait, + body::Body, + extract::FromRequestParts, + http::{request::Parts, uri}, + }; + use hyper::{ + header::{HeaderValue, CONTENT_TYPE}, + Uri, + }; use crate::{ helpers::query::QueryInput, - net::{http_serde::query::BASE_AXUM_PATH, APPLICATION_OCTET_STREAM}, + net::{ + http_serde::query::BASE_AXUM_PATH, Error, APPLICATION_OCTET_STREAM, + HTTP_QUERY_INPUT_URL_HEADER, + }, }; #[derive(Debug)] @@ -341,17 +352,54 @@ pub mod query { .path_and_query(format!( "{}/{}/input", BASE_AXUM_PATH, - self.query_input.query_id.as_ref(), + self.query_input.query_id().as_ref(), )) .build()?; - let body = Body::from_stream(self.query_input.input_stream); - Ok(hyper::Request::post(uri) - .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) - .body(body)?) + let query_input_url = self.query_input.url().map(ToOwned::to_owned); + let body = self + .query_input + .input_stream() + .map_or_else(Body::empty, Body::from_stream); + let mut request = + hyper::Request::post(uri).header(CONTENT_TYPE, APPLICATION_OCTET_STREAM); + if let Some(url) = query_input_url { + request.headers_mut().unwrap().insert( + &HTTP_QUERY_INPUT_URL_HEADER, + HeaderValue::try_from(url).unwrap(), + ); + } + Ok(request.body(body)?) } } pub const AXUM_PATH: &str = "/:query_id/input"; + + pub struct QueryInputUrl(Option); + + #[async_trait] + impl FromRequestParts for QueryInputUrl { + type Rejection = Error; + + async fn from_request_parts( + req: &mut Parts, + _state: &S, + ) -> Result { + match req.headers.get(&HTTP_QUERY_INPUT_URL_HEADER) { + None => Ok(QueryInputUrl(None)), + Some(value) => { + let value_str = value.to_str()?; + let uri = value_str.parse()?; + Ok(QueryInputUrl(Some(uri))) + } + } + } + } + + impl From for Option { + fn from(value: QueryInputUrl) -> Self { + value.0 + } + } } pub mod step { diff --git a/ipa-core/src/net/mod.rs b/ipa-core/src/net/mod.rs index 621365077..5bce44553 100644 --- a/ipa-core/src/net/mod.rs +++ b/ipa-core/src/net/mod.rs @@ -18,6 +18,7 @@ use crate::{ mod client; mod error; mod http_serde; +pub mod query_input; mod server; #[cfg(all(test, not(feature = "shuttle")))] pub mod test; @@ -32,6 +33,7 @@ const APPLICATION_JSON: &str = "application/json"; const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; static HTTP_HELPER_ID_HEADER: HeaderName = HeaderName::from_static("x-unverified-helper-identity"); static HTTP_SHARD_INDEX_HEADER: HeaderName = HeaderName::from_static("x-unverified-shard-index"); +static HTTP_QUERY_INPUT_URL_HEADER: HeaderName = HeaderName::from_static("x-query-input-url"); /// This has the same meaning as const defined in h2 crate, but we don't import it directly. /// According to the [`spec`] it cannot exceed 2^31 - 1. diff --git a/ipa-core/src/net/query_input.rs b/ipa-core/src/net/query_input.rs new file mode 100644 index 000000000..f2608392f --- /dev/null +++ b/ipa-core/src/net/query_input.rs @@ -0,0 +1,59 @@ +use axum::{body::Body, BoxError}; +use http_body_util::BodyExt; +use hyper::Uri; +use hyper_rustls::HttpsConnectorBuilder; +use hyper_util::{ + client::legacy::Client, + rt::{TokioExecutor, TokioTimer}, +}; + +use crate::{helpers::BodyStream, net::Error}; + +/// Connect to a remote URL to download query input. +/// +/// # Errors +/// If the connection to the remote URL fails or returns an HTTP error. +/// +/// # Panics +/// If unable to create an HTTPS client using the system truststore. +pub async fn stream_query_input_from_url(uri: &Uri) -> Result { + let mut builder = Client::builder(TokioExecutor::new()); + // the following timer is necessary for http2, in particular for any timeouts + // and waits the clients will need to make + // TODO: implement IpaTimer to allow wrapping other than Tokio runtimes + builder.timer(TokioTimer::new()); + let client = builder.build::<_, Body>( + HttpsConnectorBuilder::default() + .with_native_roots() + .expect("System truststore is required") + .https_only() + .enable_all_versions() + .build(), + ); + + let resp = client + .get(uri.clone()) + .await + .map_err(|inner| Error::ConnectError { + dest: uri.to_string(), + inner, + })?; + + if !resp.status().is_success() { + let status = resp.status(); + assert!(status.is_client_error() || status.is_server_error()); // must be failure + return Err( + axum::body::to_bytes(Body::new(resp.into_body()), 36_000_000) // Roughly 36mb + .await + .map_or_else(Into::into, |reason_bytes| Error::FailedHttpRequest { + dest: uri.to_string(), + status, + reason: String::from_utf8_lossy(&reason_bytes).to_string(), + }), + ); + } + + Ok(BodyStream::from_bytes_stream( + resp.into_body().map_err(BoxError::from).into_data_stream(), + )) +} diff --git a/ipa-core/src/net/server/handlers/query/input.rs b/ipa-core/src/net/server/handlers/query/input.rs index da47e9386..1de6ea570 100644 --- a/ipa-core/src/net/server/handlers/query/input.rs +++ b/ipa-core/src/net/server/handlers/query/input.rs @@ -2,25 +2,29 @@ use axum::{extract::Path, routing::post, Extension, Router}; use hyper::StatusCode; use crate::{ - helpers::{query::QueryInput, routing::RouteId, BodyStream}, - net::{http_serde, transport::MpcHttpTransport, Error}, + helpers::{routing::RouteId, BodyStream}, + net::{ + http_serde::{self, query::input::QueryInputUrl}, + query_input::stream_query_input_from_url, + transport::MpcHttpTransport, + Error, + }, protocol::QueryId, }; async fn handler( transport: Extension, Path(query_id): Path, + input_url: QueryInputUrl, input_stream: BodyStream, ) -> Result<(), Error> { - let query_input = QueryInput { - query_id, - input_stream, + let input_stream = if let Some(url) = input_url.into() { + stream_query_input_from_url(&url).await? + } else { + input_stream }; let _ = transport - .dispatch( - (RouteId::QueryInput, query_input.query_id), - query_input.input_stream, - ) + .dispatch((RouteId::QueryInput, query_id), input_stream) .await .map_err(|e| Error::application(StatusCode::INTERNAL_SERVER_ERROR, e))?; @@ -35,10 +39,15 @@ pub fn router(transport: MpcHttpTransport) -> Router { #[cfg(all(test, unit_test))] mod tests { + use std::thread; + use axum::{ body::Body, http::uri::{Authority, Scheme}, }; + use bytes::BytesMut; + use futures::TryStreamExt; + use http_body_util::BodyExt; use hyper::StatusCode; use tokio::runtime::Handle; @@ -49,24 +58,22 @@ mod tests { net::{ http_serde, server::handlers::query::test_helpers::{assert_fails_with, assert_success_with}, + test::TestServer, }, protocol::QueryId, }; #[tokio::test(flavor = "multi_thread")] - async fn input_test() { - let expected_query_id = QueryId; + async fn input_inline() { + const QUERY_ID: QueryId = QueryId; let expected_input = &[4u8; 4]; - let req = http_serde::query::input::Request::new(QueryInput { - query_id: expected_query_id, - input_stream: expected_input.to_vec().into(), - }); + let req_handler = make_owned_handler(move |addr, data| async move { let RouteId::QueryInput = addr.route else { panic!("unexpected call"); }; - assert_eq!(addr.query_id, Some(expected_query_id)); + assert_eq!(addr.query_id, Some(QUERY_ID)); assert_eq!( tokio::task::block_in_place(move || { Handle::current().block_on(async move { data.to_vec().await }) @@ -76,10 +83,66 @@ mod tests { Ok(HelperResponse::ok()) }); - let req = req + + let req = http_serde::query::input::Request::new(QueryInput::Inline { + query_id: QUERY_ID, + input_stream: expected_input.to_vec().into(), + }); + let hyper_req = req .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) .unwrap(); - assert_success_with(req, req_handler).await; + + assert_success_with(hyper_req, req_handler).await; + } + + #[tokio::test(flavor = "multi_thread")] + async fn input_from_url() { + const QUERY_ID: QueryId = QueryId; + const DATA: &str = "input records"; + + let server = tiny_http::Server::http("localhost:0").unwrap(); + let addr = server.server_addr(); + thread::spawn(move || { + let request = server.recv().unwrap(); + let response = tiny_http::Response::from_string(DATA); + request.respond(response).unwrap(); + }); + + let req_handler = make_owned_handler(move |addr, body| async move { + let RouteId::QueryInput = addr.route else { + panic!("unexpected call"); + }; + + assert_eq!(addr.query_id, Some(QUERY_ID)); + assert_eq!(body.try_collect::().await.unwrap(), DATA); + + Ok(HelperResponse::ok()) + }); + let test_server = TestServer::builder() + .with_request_handler(req_handler) + .build() + .await; + + let url = format!( + "http://localhost:{}{}/{QUERY_ID}/input", + addr.to_ip().unwrap().port(), + http_serde::query::BASE_AXUM_PATH, + ); + let req = http_serde::query::input::Request::new(QueryInput::FromUrl { + query_id: QUERY_ID, + url, + }); + let hyper_req = req + .try_into_http_request(Scheme::HTTP, Authority::from_static("localhost")) + .unwrap(); + + let resp = test_server.server.handle_req(hyper_req).await; + if !resp.status().is_success() { + let (head, body) = resp.into_parts(); + let body_bytes = body.collect().await.unwrap().to_bytes(); + let body = String::from_utf8_lossy(&body_bytes); + panic!("{head:?}\n{body}"); + } } struct OverrideReq { diff --git a/ipa-core/src/net/transport.rs b/ipa-core/src/net/transport.rs index d8fbfc4b5..7c75d683e 100644 --- a/ipa-core/src/net/transport.rs +++ b/ipa-core/src/net/transport.rs @@ -574,7 +574,7 @@ mod tests { let mut handle_resps = Vec::with_capacity(helper_shares.len()); for (i, input_stream) in helper_shares.into_iter().enumerate() { - let data = QueryInput { + let data = QueryInput::Inline { query_id, input_stream, }; @@ -586,7 +586,7 @@ mod tests { // convention - first client is shard leader, and we submitted the inputs to it. try_join_all(clients.iter().skip(1).map(|ring| { try_join_all(ring.each_ref().map(|shard_client| { - shard_client.query_input(QueryInput { + shard_client.query_input(QueryInput::Inline { query_id, input_stream: BodyStream::empty(), }) @@ -638,7 +638,7 @@ mod tests { |(helper, shard_streams)| async move { try_join_all(shard_streams.into_iter().enumerate().map( |(shard, input_stream)| { - clients[shard][helper].query_input(QueryInput { + clients[shard][helper].query_input(QueryInput::Inline { query_id, input_stream, }) diff --git a/ipa-core/src/query/processor.rs b/ipa-core/src/query/processor.rs index df392b44f..95cfa0f44 100644 --- a/ipa-core/src/query/processor.rs +++ b/ipa-core/src/query/processor.rs @@ -11,10 +11,10 @@ use crate::{ error::Error as ProtocolError, executor::IpaRuntime, helpers::{ - query::{CompareStatusRequest, PrepareQuery, QueryConfig, QueryInput}, + query::{CompareStatusRequest, PrepareQuery, QueryConfig}, routing::RouteId, - BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, Role, - RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, + BodyStream, BroadcastError, Gateway, GatewayConfig, MpcTransportError, MpcTransportImpl, + Role, RoleAssignment, ShardTransportError, ShardTransportImpl, Transport, }, hpke::{KeyRegistry, PrivateKeyOnly}, protocol::QueryId, @@ -213,7 +213,7 @@ impl Processor { // to rollback 1,2 and 3 shard_transport.broadcast(prepare_request.clone()).await?; - handle.set_state(QueryState::AwaitingInputs(query_id, req, roles))?; + handle.set_state(QueryState::AwaitingInputs(req, roles))?; guard.restore(); Ok(prepare_request) @@ -249,11 +249,7 @@ impl Processor { // TODO: If shards 1,2 and 3 succeed but 4 fails, then we need to rollback 1,2 and 3. shard_transport.broadcast(req.clone()).await?; - handle.set_state(QueryState::AwaitingInputs( - req.query_id, - req.config, - req.roles, - ))?; + handle.set_state(QueryState::AwaitingInputs(req.config, req.roles))?; Ok(()) } @@ -280,11 +276,7 @@ impl Processor { return Err(PrepareQueryError::AlreadyRunning); } - handle.set_state(QueryState::AwaitingInputs( - req.query_id, - req.config, - req.roles, - ))?; + handle.set_state(QueryState::AwaitingInputs(req.config, req.roles))?; Ok(()) } @@ -300,17 +292,14 @@ impl Processor { &self, mpc_transport: MpcTransportImpl, shard_transport: ShardTransportImpl, - input: QueryInput, + query_id: QueryId, + input_stream: BodyStream, ) -> Result<(), QueryInputError> { let mut queries = self.queries.inner.lock().unwrap(); - match queries.entry(input.query_id) { + match queries.entry(query_id) { Entry::Occupied(entry) => { let state = entry.remove(); - if let QueryState::AwaitingInputs(query_id, config, role_assignment) = state { - assert_eq!( - input.query_id, query_id, - "received inputs for a different query" - ); + if let QueryState::AwaitingInputs(config, role_assignment) = state { let mut gateway_config = GatewayConfig::default(); if let Some(active_work) = self.active_work { gateway_config.active = active_work; @@ -325,13 +314,13 @@ impl Processor { shard_transport, ); queries.insert( - input.query_id, + query_id, QueryState::Running(executor::execute( &self.runtime, config, Arc::clone(&self.key_registry), gateway, - input.input_stream, + input_stream, )), ); Ok(()) @@ -340,11 +329,11 @@ impl Processor { from: QueryStatus::from(&state), to: QueryStatus::Running, }; - queries.insert(input.query_id, state); + queries.insert(query_id, state); Err(QueryInputError::StateError { source: error }) } } - Entry::Vacant(_) => Err(QueryInputError::NoSuchQuery(input.query_id)), + Entry::Vacant(_) => Err(QueryInputError::NoSuchQuery(query_id)), } } diff --git a/ipa-core/src/query/state.rs b/ipa-core/src/query/state.rs index bca5c7e1d..148a40565 100644 --- a/ipa-core/src/query/state.rs +++ b/ipa-core/src/query/state.rs @@ -46,7 +46,7 @@ impl From<&QueryState> for QueryStatus { match source { QueryState::Empty => panic!("Query cannot be in the empty state"), QueryState::Preparing(_) => QueryStatus::Preparing, - QueryState::AwaitingInputs(_, _, _) => QueryStatus::AwaitingInputs, + QueryState::AwaitingInputs(_, _) => QueryStatus::AwaitingInputs, QueryState::Running(_) => QueryStatus::Running, QueryState::AwaitingCompletion => QueryStatus::AwaitingCompletion, QueryState::Completed(_) => QueryStatus::Completed, @@ -78,7 +78,7 @@ pub fn min_status(a: QueryStatus, b: QueryStatus) -> QueryStatus { pub enum QueryState { Empty, Preparing(QueryConfig), - AwaitingInputs(QueryId, QueryConfig, RoleAssignment), + AwaitingInputs(QueryConfig, RoleAssignment), Running(RunningQuery), AwaitingCompletion, Completed(QueryResult), @@ -91,9 +91,9 @@ impl QueryState { match (cur_state, &new_state) { // If query is not running, coordinator initial state is preparing // and followers initial state is awaiting inputs - (Empty, Preparing(_) | AwaitingInputs(_, _, _)) - | (Preparing(_), AwaitingInputs(_, _, _)) - | (AwaitingInputs(_, _, _), Running(_)) => Ok(new_state), + (Empty, Preparing(_) | AwaitingInputs(_, _)) + | (Preparing(_), AwaitingInputs(_, _)) + | (AwaitingInputs(_, _), Running(_)) => Ok(new_state), (_, Preparing(_)) => Err(StateError::AlreadyRunning), (_, _) => Err(StateError::InvalidState { from: cur_state.into(), diff --git a/ipa-core/src/test_fixture/app.rs b/ipa-core/src/test_fixture/app.rs index 1866cef74..98c90142f 100644 --- a/ipa-core/src/test_fixture/app.rs +++ b/ipa-core/src/test_fixture/app.rs @@ -104,7 +104,7 @@ impl TestApp { .into_iter() .enumerate() .map(|(i, input)| { - self.drivers[i].execute_query(QueryInput { + self.drivers[i].execute_query(QueryInput::Inline { query_id, input_stream: input.into(), })