diff --git a/src/lib.rs b/src/lib.rs index 796c20c78..b31c5251d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,4 @@ pub mod dynamic_resources; mod macros; pub mod piping_server; -pub mod req_res_handler; pub mod util; diff --git a/src/main.rs b/src/main.rs index e5f3bf82a..4f7e14038 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,6 @@ use tokio::net::TcpListener; use tokio_rustls::TlsAcceptor; use piping_server::piping_server::PipingServer; -use piping_server::req_res_handler::req_res_handler; use piping_server::util; /// Piping Server in Rust @@ -86,10 +85,9 @@ async fn main() -> std::io::Result<()> { }); let https_svc = make_service_fn(move |_| { let piping_server = piping_server.clone(); - let handler = req_res_handler(move |req, res_sender| { - piping_server.handler(true, req, res_sender) - }); - futures::future::ok::<_, Infallible>(service_fn(handler)) + futures::future::ok::<_, Infallible>(service_fn(move |req| { + piping_server.handler(true, req) + })) }); let https_server = Server::builder(util::HyperAcceptor { acceptor: incoming_tls_stream, @@ -107,9 +105,9 @@ async fn main() -> std::io::Result<()> { let http_svc = make_service_fn(|_| { let piping_server = piping_server.clone(); - let handler = - req_res_handler(move |req, res_sender| piping_server.handler(false, req, res_sender)); - futures::future::ok::<_, Infallible>(service_fn(handler)) + futures::future::ok::<_, Infallible>(service_fn(move |req| { + piping_server.handler(false, req) + })) }); let http_server = Server::bind(&(args.host, args.http_port).into()).serve(http_svc); diff --git a/src/piping_server.rs b/src/piping_server.rs index f5acffa48..31ac5b973 100644 --- a/src/piping_server.rs +++ b/src/piping_server.rs @@ -62,8 +62,9 @@ impl PipingServer { &self, uses_https: bool, req: Request, - res_sender: oneshot::Sender>, - ) -> impl std::future::Future { + ) -> impl std::future::Future, oneshot::Canceled>> /* TODO: use better Error instead of oneshot::Canceled */ + { + let (res_sender, res_receiver) = oneshot::channel::>(); let path_to_sender = Arc::clone(&self.path_to_sender); let path_to_receiver = Arc::clone(&self.path_to_receiver); async move { @@ -81,7 +82,7 @@ impl PipingServer { .body(Body::from(dynamic_resources::index())) .unwrap(); res_sender.send(res).unwrap(); - return; + return res_receiver.await; } reserved_paths::NO_SCRIPT => { let query_params = query_param_to_hash_map(req.uri().query()); @@ -102,7 +103,7 @@ impl PipingServer { .body(Body::from(html)) .unwrap(); res_sender.send(res).unwrap(); - return; + return res_receiver.await; } reserved_paths::VERSION => { let version: &'static str = env!("CARGO_PKG_VERSION"); @@ -113,7 +114,7 @@ impl PipingServer { .body(Body::from(format!("{version} (Rust)\n"))) .unwrap(); res_sender.send(res).unwrap(); - return; + return res_receiver.await; } reserved_paths::HELP => { let host: &str = req @@ -142,12 +143,12 @@ impl PipingServer { .body(Body::from(help)) .unwrap(); res_sender.send(res).unwrap(); - return; + return res_receiver.await; } reserved_paths::FAVICON_ICO => { let res = Response::builder().status(204).body(Body::empty()).unwrap(); res_sender.send(res).unwrap(); - return; + return res_receiver.await; } reserved_paths::ROBOTS_TXT => { let res = Response::builder() @@ -157,7 +158,7 @@ impl PipingServer { .body(Body::empty()) .unwrap(); res_sender.send(res).unwrap(); - return; + return res_receiver.await; } _ => {} } @@ -173,7 +174,7 @@ impl PipingServer { "[ERROR] Service Worker registration is rejected.\n", ))) .unwrap(); - return; + return res_receiver.await; } } let query_params = query_param_to_hash_map(req.uri().query()); @@ -184,7 +185,7 @@ impl PipingServer { "[ERROR] Invalid \"n\" query parameter\n", ))) .unwrap(); - return; + return res_receiver.await; } let n_receivers = n_receivers_result.unwrap(); if n_receivers <= 0 { @@ -193,7 +194,7 @@ impl PipingServer { "[ERROR] n should > 0, but n = {n_receivers}.\n" )))) .unwrap(); - return; + return res_receiver.await; } if n_receivers > 1 { res_sender @@ -201,7 +202,7 @@ impl PipingServer { "[ERROR] n > 1 not supported yet.\n", ))) .unwrap(); - return; + return res_receiver.await; } let receiver_connected: bool = path_to_receiver.contains_key(path); // If a receiver has been connected already @@ -211,7 +212,7 @@ impl PipingServer { "[ERROR] Another receiver has been connected on '{path}'.\n", )))) .unwrap(); - return; + return res_receiver.await; } let sender = path_to_sender.remove(path); match sender { @@ -236,7 +237,7 @@ impl PipingServer { if reserved_paths::VALUES.contains(&path) { // Reject reserved path sending res_sender.send(rejection_response(Body::from(format!("[ERROR] Cannot send to the reserved path '{path}'. (e.g. '/mypath123')\n")))).unwrap(); - return; + return res_receiver.await; } // Notify that Content-Range is not supported // In the future, resumable upload using Content-Range might be supported @@ -249,7 +250,7 @@ impl PipingServer { req.method() )))) .unwrap(); - return; + return res_receiver.await; } let query_params = query_param_to_hash_map(req.uri().query()); let n_receivers_result: Result = get_n_receivers_result(&query_params); @@ -259,7 +260,7 @@ impl PipingServer { "[ERROR] Invalid \"n\" query parameter\n", ))) .unwrap(); - return; + return res_receiver.await; } let n_receivers = n_receivers_result.unwrap(); if n_receivers <= 0 { @@ -268,7 +269,7 @@ impl PipingServer { "[ERROR] n should > 0, but n = {n_receivers}.\n" )))) .unwrap(); - return; + return res_receiver.await; } if n_receivers > 1 { res_sender @@ -276,7 +277,7 @@ impl PipingServer { "[ERROR] n > 1 not supported yet.\n", ))) .unwrap(); - return; + return res_receiver.await; } let sender_connected: bool = path_to_sender.contains_key(path); // If a sender has been connected already @@ -286,7 +287,7 @@ impl PipingServer { "[ERROR] Another sender has been connected on '{path}'.\n", )))) .unwrap(); - return; + return res_receiver.await; } let (mut res_body_sender, body) = Body::channel(); @@ -381,7 +382,8 @@ impl PipingServer { .unwrap(); res_sender.send(res).unwrap(); } - } + }; + return res_receiver.await; } } } diff --git a/src/req_res_handler.rs b/src/req_res_handler.rs deleted file mode 100644 index 165b1dc96..000000000 --- a/src/req_res_handler.rs +++ /dev/null @@ -1,25 +0,0 @@ -use core::future::Future; -use futures::channel::oneshot; -use futures::FutureExt; -use http::{Request, Response}; -use hyper::Body; - -// NOTE: futures::future::Map<..., oneshot::Receiver, ...> can be a Future -pub fn req_res_handler( - mut handler: impl FnMut(Request, oneshot::Sender>) -> Fut, -) -> impl (FnMut( - Request, -) -> futures::future::Map< - futures::future::Join>>, - fn( - ((), Result, oneshot::Canceled>), - ) -> Result, oneshot::Canceled>, ->) -where - Fut: Future, -{ - move |req| { - let (res_sender, res_receiver) = oneshot::channel::>(); - futures::future::join(handler(req, res_sender), res_receiver).map(|(_, x)| x) - } -} diff --git a/tests/piping_server.rs b/tests/piping_server.rs index ffefd409e..800913f57 100644 --- a/tests/piping_server.rs +++ b/tests/piping_server.rs @@ -7,7 +7,6 @@ use specit::tokio_it as it; use std::convert::Infallible; use piping_server::piping_server::PipingServer; -use piping_server::req_res_handler::req_res_handler; use std::net::SocketAddr; use std::time; @@ -52,10 +51,9 @@ async fn serve() -> Serve { tokio::spawn(async move { let http_svc = make_service_fn(|_| { let piping_server = piping_server.clone(); - let handler = req_res_handler(move |req, res_sender| { - piping_server.handler(false, req, res_sender) - }); - futures::future::ok::<_, Infallible>(service_fn(handler)) + futures::future::ok::<_, Infallible>(service_fn(move |req| { + piping_server.handler(false, req) + })) }); let http_server = Server::bind(&([127, 0, 0, 1], 0).into()).serve(http_svc); addr_tx