From 322df14e7908b87c609aab870ed49d58457e4118 Mon Sep 17 00:00:00 2001 From: Irene Zhang Date: Thu, 16 Nov 2023 00:56:49 +0000 Subject: [PATCH 1/3] [runtime] Enhancement: Add new remove call --- src/rust/runtime/mod.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/rust/runtime/mod.rs b/src/rust/runtime/mod.rs index aae72c185..5baa65794 100644 --- a/src/rust/runtime/mod.rs +++ b/src/rust/runtime/mod.rs @@ -177,6 +177,11 @@ impl SharedDemiRuntime { OperationTask::from(boxed_task.as_any()) } + /// Removes a coroutine from the underlying scheduler given its associated [QToken] `qt`. + pub fn remove_coroutine_with_qtoken(&mut self, qt: QToken) -> OperationTask { + self.remove_coroutine(&self.scheduler.from_task_id(qt.into()).expect("coroutine should exist")) + } + /// Removes a coroutine from the underlying scheduler given its associated [TaskHandle] `handle` /// and gets the result immediately. pub fn remove_coroutine_and_get_result(&mut self, handle: &TaskHandle, qt: u64) -> demi_qresult_t { From 30a1cd8c395563e3db24128781cd8c894defbc6b Mon Sep 17 00:00:00 2001 From: Irene Zhang Date: Thu, 16 Nov 2023 01:15:28 +0000 Subject: [PATCH 2/3] [inetstack] Enhancement: Move to shared state machine --- src/rust/catnip/mod.rs | 4 +- src/rust/catpowder/mod.rs | 4 +- src/rust/inetstack/mod.rs | 69 ++-- .../inetstack/protocols/tcp/passive_open.rs | 2 +- src/rust/inetstack/protocols/tcp/peer.rs | 306 +++++++++++------ src/rust/inetstack/protocols/tcp/queue.rs | 325 ++++++++++++------ .../protocols/tcp/tests/established.rs | 20 +- .../inetstack/protocols/tcp/tests/setup.rs | 50 ++- src/rust/inetstack/test_helpers/engine.rs | 13 +- tests/rust/tcp.rs | 4 +- 10 files changed, 468 insertions(+), 329 deletions(-) diff --git a/src/rust/catnip/mod.rs b/src/rust/catnip/mod.rs index 55c7dcb95..7790f099e 100644 --- a/src/rust/catnip/mod.rs +++ b/src/rust/catnip/mod.rs @@ -118,9 +118,7 @@ impl CatnipLibOS { return Err(Fail::new(libc::EINVAL, "zero-length buffer")); } - let handle: TaskHandle = self.do_push(qd, buf)?; - let qt: QToken = handle.get_task_id().into(); - Ok(qt) + self.do_push(qd, buf) }, Err(e) => Err(e), } diff --git a/src/rust/catpowder/mod.rs b/src/rust/catpowder/mod.rs index 36ed7335a..013a9357c 100644 --- a/src/rust/catpowder/mod.rs +++ b/src/rust/catpowder/mod.rs @@ -98,9 +98,7 @@ impl CatpowderLibOS { if buf.len() == 0 { return Err(Fail::new(libc::EINVAL, "zero-length buffer")); } - let handle: TaskHandle = self.do_push(qd, buf)?; - let qt: QToken = handle.get_task_id().into(); - Ok(qt) + self.do_push(qd, buf) }, Err(e) => Err(e), } diff --git a/src/rust/inetstack/mod.rs b/src/rust/inetstack/mod.rs index 6edcdbc6d..8843bd03a 100644 --- a/src/rust/inetstack/mod.rs +++ b/src/rust/inetstack/mod.rs @@ -249,12 +249,7 @@ impl InetStack { // Search for target queue descriptor. match self.runtime.get_queue_type(&qd)? { - QType::TcpSocket => { - let coroutine: Pin> = self.ipv4.tcp.accept(qd)?; - let task_id: String = format!("Inetstack::TCP::accept for qd={:?}", qd); - let handle: TaskHandle = self.runtime.insert_coroutine(task_id.as_str(), coroutine)?; - Ok(handle.get_task_id().into()) - }, + QType::TcpSocket => self.ipv4.tcp.accept(qd), // This queue descriptor does not concern a TCP socket. _ => Err(Fail::new(libc::EINVAL, "invalid queue type")), } @@ -281,12 +276,7 @@ impl InetStack { let remote: SocketAddrV4 = unwrap_socketaddr(remote)?; match self.runtime.get_queue_type(&qd)? { - QType::TcpSocket => { - let coroutine: Pin> = self.ipv4.tcp.connect(qd, remote)?; - let task_id: String = format!("Inetstack::TCP::connect for qd={:?}", qd); - let handle: TaskHandle = self.runtime.insert_coroutine(task_id.as_str(), coroutine)?; - Ok(handle.get_task_id().into()) - }, + QType::TcpSocket => self.ipv4.tcp.connect(qd, remote), _ => Err(Fail::new(libc::EINVAL, "invalid queue type")), } } @@ -328,12 +318,8 @@ impl InetStack { timer!("inetstack::async_close"); trace!("async_close(): qd={:?}", qd); - let (task_id, coroutine): (String, Pin>) = match self.runtime.get_queue_type(&qd)? { - QType::TcpSocket => { - let task_id: String = format!("Inetstack::TCP::close for qd={:?}", qd); - let coroutine: Pin> = self.ipv4.tcp.async_close(qd)?; - (task_id, coroutine) - }, + match self.runtime.get_queue_type(&qd)? { + QType::TcpSocket => self.ipv4.tcp.async_close(qd), QType::UdpSocket => { self.ipv4.udp.close(qd)?; let task_id: String = format!("Inetstack::UDP::close for qd={:?}", qd); @@ -346,26 +332,20 @@ impl InetStack { .expect("queue should exist"); (qd, OperationResult::Close) }); - (task_id, coroutine) + let handle: TaskHandle = self.runtime.insert_coroutine(task_id.as_str(), coroutine)?; + let qt: QToken = handle.get_task_id().into(); + trace!("async_close() qt={:?}", qt); + Ok(qt) }, - _ => return Err(Fail::new(libc::EINVAL, "invalid queue type")), - }; - - let handle: TaskHandle = self.runtime.insert_coroutine(task_id.as_str(), coroutine)?; - let qt: QToken = handle.get_task_id().into(); - trace!("async_close() qt={:?}", qt); - Ok(qt) + _ => Err(Fail::new(libc::EINVAL, "invalid queue type")), + } } /// Pushes a buffer to a TCP socket. /// TODO: Rename this function to push() once we have a common representation across all libOSes. - pub fn do_push(&mut self, qd: QDesc, buf: DemiBuffer) -> Result { + pub fn do_push(&mut self, qd: QDesc, buf: DemiBuffer) -> Result { match self.runtime.get_queue_type(&qd)? { - QType::TcpSocket => { - let coroutine: Pin> = self.ipv4.tcp.push(qd, buf)?; - let task_id: String = format!("Inetstack::TCP::push for qd={:?}", qd); - self.runtime.insert_coroutine(task_id.as_str(), coroutine) - }, + QType::TcpSocket => self.ipv4.tcp.push(qd, buf), _ => Err(Fail::new(libc::EINVAL, "invalid queue type")), } } @@ -384,10 +364,7 @@ impl InetStack { } // Issue operation. - let handle: TaskHandle = self.do_push(qd, buf)?; - let qt: QToken = handle.get_task_id().into(); - trace!("push2() qt={:?}", qt); - Ok(qt) + self.do_push(qd, buf) } /// Pushes a buffer to a UDP socket. @@ -436,24 +413,18 @@ impl InetStack { // We just assert 'size' here, because it was previously checked at PDPIX layer. debug_assert!(size.is_none() || ((size.unwrap() > 0) && (size.unwrap() <= limits::POP_SIZE_MAX))); - let (task_id, coroutine): (String, Pin>) = match self.runtime.get_queue_type(&qd)? { - QType::TcpSocket => { - let task_id: String = format!("Inetstack::TCP::pop for qd={:?}", qd); - let coroutine: Pin> = self.ipv4.tcp.pop(qd, size)?; - (task_id, coroutine) - }, + match self.runtime.get_queue_type(&qd)? { + QType::TcpSocket => self.ipv4.tcp.pop(qd, size), QType::UdpSocket => { let task_id: String = format!("Inetstack::UDP::pop for qd={:?}", qd); let coroutine: Pin> = self.ipv4.udp.pop(qd, size)?; - (task_id, coroutine) + let handle: TaskHandle = self.runtime.insert_coroutine(task_id.as_str(), coroutine)?; + let qt: QToken = handle.get_task_id().into(); + trace!("async_close() qt={:?}", qt); + Ok(qt) }, _ => return Err(Fail::new(libc::EINVAL, "invalid queue type")), - }; - - let handle: TaskHandle = self.runtime.insert_coroutine(task_id.as_str(), coroutine)?; - let qt: QToken = handle.get_task_id().into(); - trace!("pop() qt={:?}", qt); - Ok(qt) + } } /// Waits for an operation to complete. diff --git a/src/rust/inetstack/protocols/tcp/passive_open.rs b/src/rust/inetstack/protocols/tcp/passive_open.rs index e13ee452b..00788502a 100644 --- a/src/rust/inetstack/protocols/tcp/passive_open.rs +++ b/src/rust/inetstack/protocols/tcp/passive_open.rs @@ -164,7 +164,7 @@ impl SharedPassiveSocket { self.local } - pub async fn accept(&mut self, yielder: Yielder) -> Result, Fail> { + pub async fn do_accept(&mut self, yielder: Yielder) -> Result, Fail> { self.ready.pop(yielder).await } diff --git a/src/rust/inetstack/protocols/tcp/peer.rs b/src/rust/inetstack/protocols/tcp/peer.rs index 7513418c4..fba766863 100644 --- a/src/rust/inetstack/protocols/tcp/peer.rs +++ b/src/rust/inetstack/protocols/tcp/peer.rs @@ -26,10 +26,14 @@ use crate::{ NetworkRuntime, }, queue::NetworkQueue, - scheduler::Yielder, + scheduler::{ + TaskHandle, + Yielder, + }, Operation, OperationResult, QDesc, + QToken, SharedBox, SharedDemiRuntime, SharedObject, @@ -143,7 +147,7 @@ impl SharedTcpPeer { // TODO: Check if we are binding to a non-local address. - // Check wether the address is in use. + // Check whether the address is in use. if self.runtime.addr_in_use(local) { let cause: String = format!("address is already bound to a socket (qd={:?}", qd); error!("bind(): {}", &cause); @@ -199,98 +203,175 @@ impl SharedTcpPeer { } /// Sets up the coroutine for accepting a new connection. - pub fn accept(&self, qd: QDesc) -> Result>, Fail> { - let yielder: Yielder = Yielder::new(); + pub fn accept(&mut self, qd: QDesc) -> Result { + #[cfg(feature = "profiler")] + timer!("inet::tcp::accept"); + trace!("accept(): qd={:?}", qd); let mut queue: SharedTcpQueue = self.get_shared_queue(&qd)?; - let mut runtime: SharedDemiRuntime = self.runtime.clone(); - Ok(Box::pin(async move { - // Wait for accept to complete. - // Handle result: If successful, allocate a new queue. - match queue.accept(yielder).await { - Ok(new_queue) => { - let endpoints: (SocketAddrV4, SocketAddrV4) = match new_queue.endpoints() { - Ok(endpoints) => endpoints, - Err(e) => return (qd, OperationResult::Failed(e)), - }; - let new_qd: QDesc = runtime.alloc_queue::>(new_queue.clone()); - if let Some(existing_qd) = - runtime.insert_socket_id_to_qd(SocketId::Active(endpoints.0, endpoints.1), new_qd) - { - // We should panic here because the ephemeral port allocator should not allocate the same port more than - // once. - unreachable!( - "There is already a queue listening on this queue descriptor {:?}", - existing_qd - ); - } - (qd, OperationResult::Accept((new_qd, endpoints.1))) - }, - Err(e) => (qd, OperationResult::Failed(e)), - } - })) + let coroutine_constructor = |yielder: Yielder| -> Result { + // Asynchronous accept code. Clone the self reference and move into the coroutine. + let coroutine: Pin> = Box::pin(self.clone().accept_coroutine(qd, yielder)); + // Insert async coroutine into the scheduler. + let task_name: String = format!("Catnap::accept for qd={:?}", qd); + self.runtime.insert_coroutine(&task_name, coroutine) + }; + + queue.accept(coroutine_constructor) + } + + /// Runs until a new connection is accepted. + async fn accept_coroutine(mut self, qd: QDesc, yielder: Yielder) -> (QDesc, OperationResult) { + // Grab the queue, make sure it hasn't been closed in the meantime. + // This will bump the Rc refcount so the coroutine can have it's own reference to the shared queue data + // structure and the SharedTcpQueue will not be freed until this coroutine finishes. + let mut queue: SharedTcpQueue = match self.get_shared_queue(&qd) { + Ok(queue) => queue.clone(), + Err(e) => return (qd, OperationResult::Failed(e)), + }; + // Wait for accept to complete. + match queue.accept_coroutine(yielder).await { + Ok(new_queue) => { + // Handle result: If successful, allocate a new queue. + let endpoints: (SocketAddrV4, SocketAddrV4) = match new_queue.endpoints() { + Ok(endpoints) => endpoints, + Err(e) => return (qd, OperationResult::Failed(e)), + }; + let new_qd: QDesc = self.runtime.alloc_queue::>(new_queue.clone()); + if let Some(existing_qd) = self + .runtime + .insert_socket_id_to_qd(SocketId::Active(endpoints.0, endpoints.1), new_qd) + { + // We should panic here because the ephemeral port allocator should not allocate the same port more than + // once. + unreachable!( + "There is already a queue listening on this queue descriptor {:?}", + existing_qd + ); + } + (qd, OperationResult::Accept((new_qd, endpoints.1))) + }, + Err(e) => (qd, OperationResult::Failed(e)), + } } /// Sets up the coroutine for connecting the socket to [remote]. - pub fn connect(&mut self, qd: QDesc, remote: SocketAddrV4) -> Result>, Fail> { - let yielder: Yielder = Yielder::new(); + pub fn connect(&mut self, qd: QDesc, remote: SocketAddrV4) -> Result { + #[cfg(feature = "profiler")] + timer!("inet::tcp::connect"); + trace!("connect(): qd={:?} remote={:?}", qd, remote); let mut queue: SharedTcpQueue = self.get_shared_queue(&qd)?; - let local: SocketAddrV4 = { - // TODO: we should free this when closing. - let local_port: u16 = self.runtime.alloc_ephemeral_port()?; - SocketAddrV4::new(self.local_ipv4_addr, local_port) + // Check whether we need to allocate an ephemeral port. + let local: SocketAddrV4 = match queue.local() { + Some(addr) => addr, + None => { + // TODO: we should free this when closing. + // FIXME: https://github.com/microsoft/demikernel/issues/236 + let local_port: u16 = self.runtime.alloc_ephemeral_port()?; + SocketAddrV4::new(self.local_ipv4_addr, local_port) + }, }; + // Insert the connection to receive incoming packets for this address pair. + // Should we remove the passive entry for the local address if the socket was previously bound? + if let Some(existing_qd) = self + .runtime + .insert_socket_id_to_qd(SocketId::Active(local, remote.clone()), qd) + { + // We should panic here because the ephemeral port allocator should not allocate the same port more than + // once. + unreachable!( + "There is already a queue listening on this queue descriptor {:?}", + existing_qd + ); + } let local_isn: SeqNumber = self.isn_generator.generate(&local, &remote); - let mut peer: SharedTcpPeer = self.clone(); - Ok(Box::pin(async move { - // Wait for connect to complete. - if let Some(existing_qd) = peer - .runtime - .insert_socket_id_to_qd(SocketId::Active(local, remote.clone()), qd) - { - // We should panic here because the ephemeral port allocator should not allocate the same port more than - // once. - unreachable!( - "There is already a queue listening on this queue descriptor {:?}", - existing_qd - ); - } - match queue.connect(local, remote, local_isn, yielder).await { - Ok(()) => (qd, OperationResult::Connect), - Err(e) => { - peer.runtime - .remove_socket_id_to_qd(&SocketId::Active(local, remote.clone())); - (qd, OperationResult::Failed(e)) - }, - } - })) + let coroutine_constructor = |yielder: Yielder| -> Result { + // Clone the self reference and move into the coroutine. + let coroutine: Pin> = Box::pin(self.clone().connect_coroutine(qd, yielder)); + let task_name: String = format!("inetstack::tcp::connect for qd={:?}", qd); + self.runtime.insert_coroutine(&task_name, coroutine) + }; + + queue.connect(local, remote, local_isn, coroutine_constructor) + } + + /// Runs until the connect to remote is made or times out. + async fn connect_coroutine(mut self, qd: QDesc, yielder: Yielder) -> (QDesc, OperationResult) { + // Grab the queue, make sure it hasn't been closed in the meantime. + // This will bump the Rc refcount so the coroutine can have it's own reference to the shared queue data + // structure and the SharedTcpQueue will not be freed until this coroutine finishes. + let mut queue: SharedTcpQueue = match self.runtime.get_shared_queue(&qd) { + Ok(queue) => queue, + Err(e) => return (qd, OperationResult::Failed(e)), + }; + let (local, remote): (SocketAddrV4, SocketAddrV4) = queue + .endpoints() + .expect("We should have allocated endpoints when we allocated the coroutine"); + // Wait for connect to complete. + match queue.connect_coroutine(yielder).await { + Ok(()) => (qd, OperationResult::Connect), + Err(e) => { + self.runtime.remove_socket_id_to_qd(&SocketId::Active(local, remote)); + (qd, OperationResult::Failed(e)) + }, + } } /// Pushes immediately to the socket and returns the result asynchronously. - pub fn push(&self, qd: QDesc, buf: DemiBuffer) -> Result>, Fail> { + pub fn push(&mut self, qd: QDesc, buf: DemiBuffer) -> Result { let mut queue: SharedTcpQueue = self.get_shared_queue(&qd)?; - let result: Result<(), Fail> = queue.push(buf); - Ok(Box::pin(async move { - // Wait for push to complete. - match result { - Ok(()) => (qd, OperationResult::Push), - Err(e) => (qd, OperationResult::Failed(e)), - } - })) + let coroutine_constructor = |yielder: Yielder| -> Result { + // Clone the self reference and move into the coroutine. + let coroutine: Pin> = Box::pin(self.clone().push_coroutine(qd, yielder)); + let task_name: String = format!("inetstack::tcp::push for qd={:?}", qd); + self.runtime.insert_coroutine(&task_name, coroutine) + }; + queue.push(buf, coroutine_constructor) + } + + async fn push_coroutine(self, qd: QDesc, yielder: Yielder) -> (QDesc, OperationResult) { + // Grab the queue, make sure it hasn't been closed in the meantime. + // This will bump the Rc refcount so the coroutine can have it's own reference to the shared queue data + // structure and the SharedTcpQueue will not be freed until this coroutine finishes. + let mut queue: SharedTcpQueue = match self.get_shared_queue(&qd) { + Ok(queue) => queue, + Err(e) => return (qd, OperationResult::Failed(e)), + }; + // Wait for push to complete. + match queue.push_coroutine(yielder).await { + Ok(()) => (qd, OperationResult::Push), + Err(e) => { + warn!("push() qd={:?}: {:?}", qd, &e); + (qd, OperationResult::Failed(e)) + }, + } } /// Sets up a coroutine for popping data from the socket. - pub fn pop(&self, qd: QDesc, size: Option) -> Result>, Fail> { - let yielder: Yielder = Yielder::new(); + pub fn pop(&mut self, qd: QDesc, size: Option) -> Result { // Get local address bound to socket. let mut queue: SharedTcpQueue = self.get_shared_queue(&qd)?; + let coroutine_constructor = |yielder: Yielder| -> Result { + // Clone the self reference and move into the coroutine. + let coroutine: Pin> = Box::pin(self.clone().pop_coroutine(qd, size, yielder)); + let task_name: String = format!("inetstack::tcp::pop for qd={:?}", qd); + self.runtime.insert_coroutine(&task_name, coroutine) + }; + queue.pop(coroutine_constructor) + } - Ok(Box::pin(async move { - // Wait for pop to complete. - match queue.pop(size, yielder).await { - Ok(buf) => (qd, OperationResult::Pop(None, buf)), - Err(e) => (qd, OperationResult::Failed(e)), - } - })) + async fn pop_coroutine(self, qd: QDesc, size: Option, yielder: Yielder) -> (QDesc, OperationResult) { + // Grab the queue, make sure it hasn't been closed in the meantime. + // This will bump the Rc refcount so the coroutine can have it's own reference to the shared queue data + // structure and the SharedTcpQueue will not be freed until this coroutine finishes. + let mut queue: SharedTcpQueue = match self.get_shared_queue(&qd) { + Ok(queue) => queue, + Err(e) => return (qd, OperationResult::Failed(e)), + }; + // Wait for pop to complete. + match queue.pop_coroutine(size, yielder).await { + Ok(buf) => (qd, OperationResult::Pop(None, buf)), + Err(e) => (qd, OperationResult::Failed(e)), + } } /// Closes a TCP socket. @@ -320,40 +401,51 @@ impl SharedTcpPeer { } /// Closes a TCP socket. - pub fn async_close(&self, qd: QDesc) -> Result>, Fail> { + pub fn async_close(&mut self, qd: QDesc) -> Result { trace!("Closing socket: qd={:?}", qd); - let yielder: Yielder = Yielder::new(); let mut queue: SharedTcpQueue = self.get_shared_queue(&qd)?; - let mut peer: SharedTcpPeer = self.clone(); - Ok(Box::pin(async move { - // Wait for accept to complete. - // Handle result: If unsuccessful, free the new queue descriptor. - match queue.async_close(yielder).await { - Ok(socket_id) => { - if let Some(socket_id) = socket_id { - match peer.runtime.remove_socket_id_to_qd(&socket_id) { - Some(existing_qd) if existing_qd == qd => {}, - _ => { - return ( - qd, - OperationResult::Failed(Fail::new( - libc::EINVAL, - "socket id did not map to this qd!", - )), - ) - }, - } + let coroutine_constructor = |yielder: Yielder| -> Result { + // Clone the self reference and move into the coroutine. + let coroutine: Pin> = Box::pin(self.clone().close_coroutine(qd, yielder)); + let task_name: String = format!("inetstack::tcp::close for qd={:?}", qd); + self.runtime.insert_coroutine(&task_name, coroutine) + }; + + queue.async_close(coroutine_constructor) + } + + async fn close_coroutine(mut self, qd: QDesc, yielder: Yielder) -> (QDesc, OperationResult) { + // Grab the queue, make sure it hasn't been closed in the meantime. + // This will bump the Rc refcount so the coroutine can have it's own reference to the shared queue data + // structure and the SharedTcpQueue will not be freed until this coroutine finishes. + let mut queue: SharedTcpQueue = match self.get_shared_queue(&qd) { + Ok(queue) => queue, + Err(e) => return (qd, OperationResult::Failed(e)), + }; + // Wait for close to complete. + // Handle result: If unsuccessful, free the new queue descriptor. + match queue.close_coroutine(yielder).await { + Ok(socket_id) => { + if let Some(socket_id) = socket_id { + match self.runtime.remove_socket_id_to_qd(&socket_id) { + Some(existing_qd) if existing_qd == qd => {}, + _ => { + return ( + qd, + OperationResult::Failed(Fail::new(libc::EINVAL, "socket id did not map to this qd!")), + ) + }, } - // Free the queue. - peer.runtime - .free_queue::>(&qd) - .expect("queue should exist"); + } + // Free the queue. + self.runtime + .free_queue::>(&qd) + .expect("queue should exist"); - (qd, OperationResult::Close) - }, - Err(e) => (qd, OperationResult::Failed(e)), - } - })) + (qd, OperationResult::Close) + }, + Err(e) => (qd, OperationResult::Failed(e)), + } } pub fn remote_mss(&self, qd: QDesc) -> Result { diff --git a/src/rust/inetstack/protocols/tcp/queue.rs b/src/rust/inetstack/protocols/tcp/queue.rs index 6e356d415..89b053b29 100644 --- a/src/rust/inetstack/protocols/tcp/queue.rs +++ b/src/rust/inetstack/protocols/tcp/queue.rs @@ -33,15 +33,24 @@ use crate::{ fail::Fail, memory::DemiBuffer, network::{ - socket::SocketId, + socket::{ + operation::SocketOp, + state::SocketStateMachine, + SocketId, + }, NetworkRuntime, }, queue::{ IoQueue, NetworkQueue, }, - scheduler::Yielder, + scheduler::{ + TaskHandle, + Yielder, + YielderHandle, + }, QDesc, + QToken, QType, SharedBox, SharedDemiRuntime, @@ -51,6 +60,7 @@ use crate::{ use ::futures::channel::mpsc; use ::std::{ any::Any, + collections::HashMap, net::SocketAddrV4, ops::{ Deref, @@ -64,7 +74,8 @@ use ::std::{ //====================================================================================================================== pub enum Socket { - Inactive(Option), + Unbound, + Bound(SocketAddrV4), Listening(SharedPassiveSocket), Connecting(SharedActiveOpenSocket), Established(EstablishedSocket), @@ -77,6 +88,7 @@ pub enum Socket { /// Per-queue metadata for the TCP socket. pub struct TcpQueue { + state_machine: SocketStateMachine, socket: Socket, runtime: SharedDemiRuntime, transport: SharedBox>, @@ -84,6 +96,7 @@ pub struct TcpQueue { tcp_config: TcpConfig, arp: SharedArpPeer, dead_socket_tx: mpsc::UnboundedSender, + pending_ops: HashMap, } #[derive(Clone)] @@ -104,13 +117,15 @@ impl SharedTcpQueue { dead_socket_tx: mpsc::UnboundedSender, ) -> Self { Self(SharedObject::>::new(TcpQueue { - socket: Socket::Inactive(None), + state_machine: SocketStateMachine::new_unbound(libc::SOCK_STREAM), + socket: Socket::Unbound, runtime, transport, local_link_addr, tcp_config, arp, dead_socket_tx, + pending_ops: HashMap::::new(), })) } @@ -124,6 +139,7 @@ impl SharedTcpQueue { dead_socket_tx: mpsc::UnboundedSender, ) -> Self { Self(SharedObject::>::new(TcpQueue { + state_machine: SocketStateMachine::new_connected(), socket: Socket::Established(socket), runtime, transport, @@ -131,56 +147,57 @@ impl SharedTcpQueue { tcp_config, arp, dead_socket_tx, + pending_ops: HashMap::::new(), })) } /// Binds the target queue to `local` address. pub fn bind(&mut self, local: SocketAddrV4) -> Result<(), Fail> { - match self.socket { - Socket::Inactive(None) => { - self.socket = Socket::Inactive(Some(local)); - Ok(()) - }, - Socket::Inactive(_) => Err(Fail::new(libc::EINVAL, "socket is already bound to an address")), - Socket::Listening(_) => Err(Fail::new(libc::EINVAL, "socket is already listening")), - Socket::Connecting(_) => Err(Fail::new(libc::EINVAL, "socket is connecting")), - Socket::Established(_) => Err(Fail::new(libc::EINVAL, "socket is connected")), - Socket::Closing(_) => Err(Fail::new(libc::EINVAL, "socket is closed")), - } + self.state_machine.prepare(SocketOp::Bind)?; + self.socket = Socket::Bound(local); + self.state_machine.commit(); + Ok(()) } /// Sets the target queue to listen for incoming connections. pub fn listen(&mut self, backlog: usize, nonce: u32) -> Result<(), Fail> { - match self.socket { - Socket::Inactive(Some(local)) => { - self.socket = Socket::Listening(SharedPassiveSocket::new( - local, - backlog, - self.runtime.clone(), - self.transport.clone(), - self.tcp_config.clone(), - self.local_link_addr, - self.arp.clone(), - self.dead_socket_tx.clone(), - nonce, - )); - Ok(()) - }, - Socket::Inactive(None) => Err(Fail::new(libc::EDESTADDRREQ, "socket is not bound to a local address")), - Socket::Listening(_) => Err(Fail::new(libc::EINVAL, "socket is already listening")), - Socket::Connecting(_) => Err(Fail::new(libc::EINVAL, "socket is connecting")), - Socket::Established(_) => Err(Fail::new(libc::EINVAL, "socket is connected")), - Socket::Closing(_) => Err(Fail::new(libc::EINVAL, "socket is closed")), - } + self.state_machine.prepare(SocketOp::Listen)?; + self.socket = Socket::Listening(SharedPassiveSocket::new( + self.local() + .expect("If we were able to prepare, then the socket must be bound"), + backlog, + self.runtime.clone(), + self.transport.clone(), + self.tcp_config.clone(), + self.local_link_addr, + self.arp.clone(), + self.dead_socket_tx.clone(), + nonce, + )); + self.state_machine.commit(); + Ok(()) } - pub async fn accept(&mut self, yielder: Yielder) -> Result, Fail> { + pub fn accept(&mut self, coroutine_constructor: F) -> Result + where + F: FnOnce(Yielder) -> Result, + { + self.state_machine.prepare(SocketOp::Accept)?; + Ok(self + .do_generic_sync_control_path_call(coroutine_constructor)? + .get_task_id() + .into()) + } + + pub async fn accept_coroutine(&mut self, yielder: Yielder) -> Result, Fail> { // Wait for a new connection on the listening socket. + self.state_machine.may_accept()?; let mut listening_socket: SharedPassiveSocket = match self.socket { Socket::Listening(ref listening_socket) => listening_socket.clone(), - _ => return Err(Fail::new(libc::EOPNOTSUPP, "socket not listening")), + _ => unreachable!("State machine check should ensure that this socket is listening"), }; - let new_socket: EstablishedSocket = listening_socket.accept(yielder).await?; + let new_socket: EstablishedSocket = listening_socket.do_accept(yielder).await?; + self.state_machine.prepare(SocketOp::Accepted)?; // Insert queue into queue table and get new queue descriptor. let new_queue = Self::new_established( new_socket, @@ -191,88 +208,116 @@ impl SharedTcpQueue { self.arp.clone(), self.dead_socket_tx.clone(), ); + self.state_machine.commit(); Ok(new_queue) } - pub async fn connect( + pub fn connect( &mut self, local: SocketAddrV4, remote: SocketAddrV4, local_isn: SeqNumber, - yielder: Yielder, - ) -> Result<(), Fail> { - let socket: SharedActiveOpenSocket = match self.socket { - Socket::Inactive(Some(local)) => { - // Create active socket. - SharedActiveOpenSocket::new( - local_isn, - local, - remote, - self.runtime.clone(), - self.transport.clone(), - self.tcp_config.clone(), - self.local_link_addr, - self.arp.clone(), - self.dead_socket_tx.clone(), - )? - }, - Socket::Inactive(None) => { - // Create active socket. - SharedActiveOpenSocket::new( - local_isn, - local, - remote, - self.runtime.clone(), - self.transport.clone(), - self.tcp_config.clone(), - self.local_link_addr, - self.arp.clone(), - self.dead_socket_tx.clone(), - )? - }, - Socket::Listening(_) => return Err(Fail::new(libc::EOPNOTSUPP, "socket is listening")), - Socket::Connecting(_) => return Err(Fail::new(libc::EALREADY, "socket is connecting")), - Socket::Established(_) => return Err(Fail::new(libc::EISCONN, "socket is connected")), - Socket::Closing(_) => return Err(Fail::new(libc::EINVAL, "socket is closed")), - }; - // Update socket state to active open. - self.socket = Socket::Connecting(socket.clone()); + coroutine_constructor: F, + ) -> Result + where + F: FnOnce(Yielder) -> Result, + { + self.state_machine.prepare(SocketOp::Connect)?; + + // Create active socket. + self.socket = Socket::Connecting(SharedActiveOpenSocket::new( + local_isn, + local, + remote, + self.runtime.clone(), + self.transport.clone(), + self.tcp_config.clone(), + self.local_link_addr, + self.arp.clone(), + self.dead_socket_tx.clone(), + )?); + + Ok(self + .do_generic_sync_control_path_call(coroutine_constructor)? + .get_task_id() + .into()) + } + + pub async fn connect_coroutine(&mut self, yielder: Yielder) -> Result<(), Fail> { // Wait for the established socket to come back and update again. - self.socket = Socket::Established(socket.connect(yielder).await?); + let connecting_socket: SharedActiveOpenSocket = match self.socket { + Socket::Connecting(ref connecting_socket) => connecting_socket.clone(), + _ => unreachable!("State machine check should ensure that this socket is connecting"), + }; + let socket: EstablishedSocket = connecting_socket.connect(yielder).await?; + self.state_machine.prepare(SocketOp::Connected)?; + self.socket = Socket::Established(socket); + self.state_machine.commit(); Ok(()) } - pub fn push(&mut self, buf: DemiBuffer) -> Result<(), Fail> { + pub fn push(&mut self, buf: DemiBuffer, coroutine_constructor: F) -> Result + where + F: FnOnce(Yielder) -> Result, + { + self.state_machine.may_push()?; + // Send synchronously. match self.socket { - Socket::Established(ref mut socket) => socket.send(buf), - _ => Err(Fail::new(libc::ENOTCONN, "connection not established")), - } + Socket::Established(ref mut socket) => socket.send(buf)?, + _ => unreachable!("State machine check should ensure that this socket is connected"), + }; + Ok(self + .do_generic_sync_data_path_call(coroutine_constructor)? + .get_task_id() + .into()) + } + + pub async fn push_coroutine(&mut self, _yielder: Yielder) -> Result<(), Fail> { + Ok(()) + } + + pub fn pop(&mut self, coroutine_constructor: F) -> Result + where + F: FnOnce(Yielder) -> Result, + { + self.state_machine.may_pop()?; + Ok(self + .do_generic_sync_data_path_call(coroutine_constructor)? + .get_task_id() + .into()) } - pub async fn pop(&mut self, size: Option, yielder: Yielder) -> Result { + pub async fn pop_coroutine(&mut self, size: Option, yielder: Yielder) -> Result { + self.state_machine.may_pop()?; match self.socket { Socket::Established(ref mut socket) => socket.pop(size, yielder).await, - Socket::Closing(_) => Err(Fail::new(libc::EBADF, "socket closing")), - Socket::Connecting(_) => Err(Fail::new(libc::EINPROGRESS, "socket connecting")), - Socket::Inactive(_) => Err(Fail::new(libc::EBADF, "socket inactive")), - Socket::Listening(_) => Err(Fail::new(libc::ENOTCONN, "socket listening")), + _ => unreachable!("State machine check should ensure that this socket is connected"), } } + pub fn async_close(&mut self, coroutine_constructor: F) -> Result + where + F: FnOnce(Yielder) -> Result, + { + self.state_machine.prepare(SocketOp::Close)?; + Ok(self + .do_generic_sync_control_path_call(coroutine_constructor)? + .get_task_id() + .into()) + } + + pub async fn close_coroutine(&mut self, _: Yielder) -> Result, Fail> { + self.close() + } + pub fn close(&mut self) -> Result, Fail> { - let socket: EstablishedSocket = match self.socket { + self.state_machine.prepare(SocketOp::Close)?; + let new_socket: Option> = match self.socket { // Closing an active socket. Socket::Established(ref mut socket) => { socket.close()?; - // Only using a clone here because we need to read and write the socket. - socket.clone() - }, - // Closing an unbound socket. - Socket::Inactive(None) => { - return Ok(None); + Some(Socket::Closing(socket.clone())) }, - // Closing a bound socket. - Socket::Inactive(Some(addr)) => return Ok(Some(SocketId::Passive(addr.clone()))), // Closing a listening socket. Socket::Listening(_) => { let cause: String = format!("cannot close a listening socket"); @@ -291,13 +336,18 @@ impl SharedTcpQueue { error!("do_close(): {}", &cause); return Err(Fail::new(libc::ENOTSUP, &cause)); }, + _ => None, }; - self.socket = Socket::Closing(socket.clone()); - return Ok(Some(SocketId::Active(socket.endpoints().0, socket.endpoints().1))); - } - - pub async fn async_close(&mut self, _: Yielder) -> Result, Fail> { - self.close() + if let Some(socket) = new_socket { + self.socket = socket; + } + self.state_machine.commit(); + match self.socket { + Socket::Closing(ref socket) => Ok(Some(SocketId::Active(socket.endpoints().0, socket.endpoints().1))), + Socket::Bound(addr) => Ok(Some(SocketId::Passive(addr))), + Socket::Unbound => Ok(None), + _ => unreachable!("We do not support closing of other socket types"), + } } pub fn remote_mss(&self) -> Result { @@ -317,6 +367,7 @@ impl SharedTcpQueue { pub fn endpoints(&self) -> Result<(SocketAddrV4, SocketAddrV4), Fail> { match self.socket { Socket::Established(ref socket) => Ok(socket.endpoints()), + Socket::Connecting(ref socket) => Ok(socket.endpoints()), _ => Err(Fail::new(libc::ENOTCONN, "connection not established")), } } @@ -352,14 +403,13 @@ impl SharedTcpQueue { } }, // The segment is for an inactive connection. - Socket::Inactive(addr) => { - // It is safe to expect a bound socket here because we would not have found this queue otherwise. - debug!( - "Routing to inactive connection: {:?}", - addr.expect("This queue must be bound or we could not have routed to it") - ); + Socket::Bound(addr) => { + debug!("Routing to inactive connection: {:?}", addr); // Fall through and send a RST segment back. }, + // The segment is for a totally unbound connection. + Socket::Unbound => unreachable!("This socket must be at least bound for us to find it"), + // Fall through and send a RST segment back. Socket::Closing(ref mut socket) => { debug!("Routing to closing connection: {:?}", socket.endpoints()); socket.receive(tcp_hdr, buf); @@ -431,6 +481,57 @@ impl SharedTcpQueue { Ok(()) } + + /// Removes an operation from the list of pending operations on this queue. This function should only be called if + /// add_pending_op() was previously called. + /// TODO: Remove this when we clean up take_result(). + /// This function is deprecated, do not use. + /// FIXME: https://github.com/microsoft/demikernel/issues/888 + pub fn remove_pending_op(&mut self, handle: &TaskHandle) { + self.pending_ops.remove(handle); + } + + /// Adds a new operation to the list of pending operations on this queue. + fn add_pending_op(&mut self, handle: &TaskHandle, yielder_handle: &YielderHandle) { + self.pending_ops.insert(handle.clone(), yielder_handle.clone()); + } + + /// Generic function for spawning a control-path coroutine on [self]. + fn do_generic_sync_control_path_call(&mut self, coroutine: F) -> Result + where + F: FnOnce(Yielder) -> Result, + { + let yielder: Yielder = Yielder::new(); + let yielder_handle: YielderHandle = yielder.get_handle(); + // Spawn coroutine. + match coroutine(Yielder::new()) { + // We successfully spawned the coroutine. + Ok(handle) => { + // Commit the operation on the socket. + self.add_pending_op(&handle, &yielder_handle); + self.state_machine.commit(); + Ok(handle) + }, + // We failed to spawn the coroutine. + Err(e) => { + // Abort the operation on the socket. + self.state_machine.abort(); + Err(e) + }, + } + } + + /// Generic function for spawning a data-path coroutine on [self]. + fn do_generic_sync_data_path_call(&mut self, coroutine: F) -> Result + where + F: FnOnce(Yielder) -> Result, + { + let yielder: Yielder = Yielder::new(); + let yielder_handle: YielderHandle = yielder.get_handle(); + let task_handle: TaskHandle = coroutine(yielder)?; + self.add_pending_op(&task_handle, &yielder_handle); + Ok(task_handle) + } } //====================================================================================================================== @@ -459,7 +560,8 @@ impl NetworkQueue for SharedTcpQueue { /// Returns the local address to which the target queue is bound. fn local(&self) -> Option { match self.socket { - Socket::Inactive(addr) => addr, + Socket::Unbound => None, + Socket::Bound(addr) => Some(addr), Socket::Listening(ref socket) => Some(socket.endpoint()), Socket::Connecting(ref socket) => Some(socket.endpoints().0), Socket::Established(ref socket) => Some(socket.endpoints().0), @@ -470,7 +572,8 @@ impl NetworkQueue for SharedTcpQueue { /// Returns the remote address to which the target queue is connected to. fn remote(&self) -> Option { match self.socket { - Socket::Inactive(_) => None, + Socket::Unbound => None, + Socket::Bound(_) => None, Socket::Listening(_) => None, Socket::Connecting(ref socket) => Some(socket.endpoints().1), Socket::Established(ref socket) => Some(socket.endpoints().1), diff --git a/src/rust/inetstack/protocols/tcp/tests/established.rs b/src/rust/inetstack/protocols/tcp/tests/established.rs index 79d8e4805..65c728437 100644 --- a/src/rust/inetstack/protocols/tcp/tests/established.rs +++ b/src/rust/inetstack/protocols/tcp/tests/established.rs @@ -26,10 +26,9 @@ use crate::{ runtime::{ memory::DemiBuffer, network::consts::RECEIVE_BATCH_SIZE, - scheduler::TaskHandle, - Operation, OperationResult, QDesc, + QToken, }, }; use ::anyhow::Result; @@ -37,7 +36,6 @@ use ::rand; use ::std::{ collections::VecDeque, net::SocketAddrV4, - pin::Pin, time::Instant, }; @@ -73,11 +71,7 @@ fn send_data( ); // Push data. - let push_coroutine: Pin> = sender.tcp_push(sender_qd, bytes.clone())?; - let handle: TaskHandle = sender - .get_test_rig() - .get_runtime() - .insert_coroutine("test::send_data::push_coroutine", push_coroutine)?; + let qt: QToken = sender.tcp_push(sender_qd, bytes.clone())?; // Poll the coroutine. sender.get_test_rig().poll_scheduler(); @@ -111,7 +105,7 @@ fn send_data( match sender .get_test_rig() .get_runtime() - .remove_coroutine(&handle) + .remove_coroutine_with_qtoken(qt) .get_result() { Some((_, OperationResult::Push)) => { @@ -137,11 +131,7 @@ fn recv_data( ); // Pop data. - let pop_coroutine: Pin> = receiver.tcp_pop(receiver_qd)?; - let handle: TaskHandle = receiver - .get_test_rig() - .get_runtime() - .insert_coroutine("test::recv_data::pop_coroutine", pop_coroutine)?; + let qt: QToken = receiver.tcp_pop(receiver_qd)?; // Deliver data. if let Err(e) = receiver.receive(bytes.clone()) { @@ -155,7 +145,7 @@ fn recv_data( match receiver .get_test_rig() .get_runtime() - .remove_coroutine(&handle) + .remove_coroutine_with_qtoken(qt) .get_result() { Some((_, OperationResult::Pop(_, _))) => { diff --git a/src/rust/inetstack/protocols/tcp/tests/setup.rs b/src/rust/inetstack/protocols/tcp/tests/setup.rs index 4e967e8f7..66c269cc0 100644 --- a/src/rust/inetstack/protocols/tcp/tests/setup.rs +++ b/src/rust/inetstack/protocols/tcp/tests/setup.rs @@ -33,10 +33,9 @@ use crate::{ types::MacAddress, PacketBuf, }, - scheduler::TaskHandle, - Operation, OperationResult, QDesc, + QToken, }, }; use ::anyhow::Result; @@ -46,7 +45,6 @@ use ::std::{ Ipv4Addr, SocketAddrV4, }, - pin::Pin, time::{ Duration, Instant, @@ -75,8 +73,7 @@ fn test_connection_timeout() -> Result<()> { advance_clock(None, Some(&mut client), &mut now); // Client: SYN_SENT state at T(1). - let (_, connect_handle, bytes): (QDesc, TaskHandle, DemiBuffer) = - connection_setup_listen_syn_sent(&mut client, listen_addr)?; + let (_, qt, bytes): (QDesc, QToken, DemiBuffer) = connection_setup_listen_syn_sent(&mut client, listen_addr)?; // Sanity check packet. check_packet_pure_syn( @@ -98,7 +95,7 @@ fn test_connection_timeout() -> Result<()> { match client .get_test_rig() .get_runtime() - .remove_coroutine(&connect_handle) + .remove_coroutine_with_qtoken(qt) .get_result() { None => Ok(()), @@ -120,13 +117,13 @@ fn test_refuse_connection_early_rst() -> Result<()> { let mut client: SharedEngine = test_helpers::new_alice2(now); // Server: LISTEN state at T(0). - let _: TaskHandle = connection_setup_closed_listen(&mut server, listen_addr)?; + let _: QToken = connection_setup_closed_listen(&mut server, listen_addr)?; // T(0) -> T(1) advance_clock(Some(&mut server), Some(&mut client), &mut now); // Client: SYN_SENT state at T(1). - let (_, _, bytes): (QDesc, TaskHandle, DemiBuffer) = connection_setup_listen_syn_sent(&mut client, listen_addr)?; + let (_, _, bytes): (QDesc, QToken, DemiBuffer) = connection_setup_listen_syn_sent(&mut client, listen_addr)?; // Temper packet. let (eth2_header, ipv4_header, tcp_header): (Ethernet2Header, Ipv4Header, TcpHeader) = @@ -184,13 +181,13 @@ fn test_refuse_connection_early_ack() -> Result<()> { let mut client: SharedEngine = test_helpers::new_alice2(now); // Server: LISTEN state at T(0). - let _: TaskHandle = connection_setup_closed_listen(&mut server, listen_addr)?; + let _: QToken = connection_setup_closed_listen(&mut server, listen_addr)?; // T(0) -> T(1) advance_clock(Some(&mut server), Some(&mut client), &mut now); // Client: SYN_SENT state at T(1). - let (_, _, bytes): (QDesc, TaskHandle, DemiBuffer) = connection_setup_listen_syn_sent(&mut client, listen_addr)?; + let (_, _, bytes): (QDesc, QToken, DemiBuffer) = connection_setup_listen_syn_sent(&mut client, listen_addr)?; // Temper packet. let (eth2_header, ipv4_header, tcp_header): (Ethernet2Header, Ipv4Header, TcpHeader) = @@ -248,13 +245,13 @@ fn test_refuse_connection_missing_syn() -> Result<()> { let mut client: SharedEngine = test_helpers::new_alice2(now); // Server: LISTEN state at T(0). - let _: TaskHandle = connection_setup_closed_listen(&mut server, listen_addr)?; + let _: QToken = connection_setup_closed_listen(&mut server, listen_addr)?; // T(0) -> T(1) advance_clock(Some(&mut server), Some(&mut client), &mut now); // Client: SYN_SENT state at T(1). - let (_, _, bytes): (QDesc, TaskHandle, DemiBuffer) = connection_setup_listen_syn_sent(&mut client, listen_addr)?; + let (_, _, bytes): (QDesc, QToken, DemiBuffer) = connection_setup_listen_syn_sent(&mut client, listen_addr)?; // Sanity check packet. check_packet_pure_syn( @@ -356,17 +353,13 @@ fn serialize_segment(pkt: TcpSegment) -> Result { fn connection_setup_listen_syn_sent( client: &mut SharedEngine, listen_addr: SocketAddrV4, -) -> Result<(QDesc, TaskHandle, DemiBuffer)> { +) -> Result<(QDesc, QToken, DemiBuffer)> { // Issue CONNECT operation. let client_fd: QDesc = match client.tcp_socket() { Ok(fd) => fd, Err(e) => anyhow::bail!("client tcp socket returned error: {:?}", e), }; - let connect_coroutine: Pin> = client.tcp_connect(client_fd, listen_addr)?; - let connect_handle: TaskHandle = client - .get_test_rig() - .get_runtime() - .insert_coroutine("test::connection_setup_listen_syn_sent()", connect_coroutine)?; + let qt: QToken = client.tcp_connect(client_fd, listen_addr)?; // SYN_SENT state. client.get_test_rig().poll_scheduler(); @@ -374,14 +367,14 @@ fn connection_setup_listen_syn_sent( let bytes: DemiBuffer = client.get_test_rig().pop_frame(); - Ok((client_fd, connect_handle, bytes)) + Ok((client_fd, qt, bytes)) } /// Triggers CLOSED -> LISTEN state transition. fn connection_setup_closed_listen( server: &mut SharedEngine, listen_addr: SocketAddrV4, -) -> Result { +) -> Result { // Issue ACCEPT operation. let socket_fd: QDesc = match server.tcp_socket() { Ok(fd) => fd, @@ -393,15 +386,12 @@ fn connection_setup_closed_listen( if let Err(e) = server.tcp_listen(socket_fd, 1) { anyhow::bail!("server listen returned an error: {:?}", e); } - let accept_coroutine: Pin> = server.tcp_accept(socket_fd)?; - let accept_handle: TaskHandle = server.get_test_rig().get_runtime().insert_coroutine( - "test::connection_setup_closed_listen::accept_coroutine", - accept_coroutine, - )?; + let accept_qt: QToken = server.tcp_accept(socket_fd)?; + // LISTEN state. server.get_test_rig().poll_scheduler(); - Ok(accept_handle) + Ok(accept_qt) } /// Triggers LISTEN -> SYN_RCVD state transition. @@ -538,13 +528,13 @@ pub fn connection_setup( listen_addr: SocketAddrV4, ) -> Result<((QDesc, SocketAddrV4), QDesc)> { // Server: LISTEN state at T(0). - let accept_handle: TaskHandle = connection_setup_closed_listen(server, listen_addr)?; + let accept_qt: QToken = connection_setup_closed_listen(server, listen_addr)?; // T(0) -> T(1) advance_clock(Some(server), Some(client), now); // Client: SYN_SENT state at T(1). - let (client_fd, connect_handle, mut bytes): (QDesc, TaskHandle, DemiBuffer) = + let (client_fd, connect_qt, mut bytes): (QDesc, QToken, DemiBuffer) = connection_setup_listen_syn_sent(client, listen_addr)?; // Sanity check packet. @@ -597,7 +587,7 @@ pub fn connection_setup( let (server_fd, addr): (QDesc, SocketAddrV4) = match server .get_test_rig() .get_runtime() - .remove_coroutine(&accept_handle) + .remove_coroutine_with_qtoken(accept_qt) .get_result() { Some((_, crate::OperationResult::Accept((server_fd, addr)))) => (server_fd, addr), @@ -606,7 +596,7 @@ pub fn connection_setup( match client .get_test_rig() .get_runtime() - .remove_coroutine(&connect_handle) + .remove_coroutine_with_qtoken(connect_qt) .get_result() { Some((_, OperationResult::Connect)) => {}, diff --git a/src/rust/inetstack/test_helpers/engine.rs b/src/rust/inetstack/test_helpers/engine.rs index e6d6277e2..b8ae1ff91 100644 --- a/src/rust/inetstack/test_helpers/engine.rs +++ b/src/rust/inetstack/test_helpers/engine.rs @@ -26,6 +26,7 @@ use crate::{ scheduler::Yielder, Operation, QDesc, + QToken, SharedBox, SharedObject, }, @@ -136,11 +137,7 @@ impl SharedEngine { self.ipv4.tcp.socket() } - pub fn tcp_connect( - &mut self, - socket_fd: QDesc, - remote_endpoint: SocketAddrV4, - ) -> Result>, Fail> { + pub fn tcp_connect(&mut self, socket_fd: QDesc, remote_endpoint: SocketAddrV4) -> Result { self.ipv4.tcp.connect(socket_fd, remote_endpoint) } @@ -148,15 +145,15 @@ impl SharedEngine { self.ipv4.tcp.bind(socket_fd, endpoint) } - pub fn tcp_accept(&self, fd: QDesc) -> Result>, Fail> { + pub fn tcp_accept(&mut self, fd: QDesc) -> Result { self.ipv4.tcp.accept(fd) } - pub fn tcp_push(&mut self, socket_fd: QDesc, buf: DemiBuffer) -> Result>, Fail> { + pub fn tcp_push(&mut self, socket_fd: QDesc, buf: DemiBuffer) -> Result { self.ipv4.tcp.push(socket_fd, buf) } - pub fn tcp_pop(&mut self, socket_fd: QDesc) -> Result>, Fail> { + pub fn tcp_pop(&mut self, socket_fd: QDesc) -> Result { self.ipv4.tcp.pop(socket_fd, None) } diff --git a/tests/rust/tcp.rs b/tests/rust/tcp.rs index afe819ccf..65766fcc7 100644 --- a/tests/rust/tcp.rs +++ b/tests/rust/tcp.rs @@ -562,11 +562,11 @@ fn tcp_bad_listen() -> Result<()> { safe_bind(&mut libos, sockqd, local2)?; safe_listen(&mut libos, sockqd)?; match libos.listen(sockqd, 16) { - Err(e) if e.errno == libc::EINVAL => (), + Err(e) if e.errno == libc::EADDRINUSE => (), _ => { // Close socket if not error because this test cannot continue. // FIXME: https://github.com/demikernel/demikernel/issues/633 - anyhow::bail!("listen() called on an already listening socket should fail with EINVAL") + anyhow::bail!("listen() called on an already listening socket should fail with EADDRINUSE") }, }; safe_close_passive(&mut libos, sockqd)?; From c38446dd3b48e0cc8406f5169599045457f1228b Mon Sep 17 00:00:00 2001 From: Irene Zhang Date: Thu, 16 Nov 2023 01:20:00 +0000 Subject: [PATCH 3/3] [inetstack] Enhancement: Cancel pending ops --- src/rust/inetstack/protocols/tcp/queue.rs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/rust/inetstack/protocols/tcp/queue.rs b/src/rust/inetstack/protocols/tcp/queue.rs index 89b053b29..292911b1a 100644 --- a/src/rust/inetstack/protocols/tcp/queue.rs +++ b/src/rust/inetstack/protocols/tcp/queue.rs @@ -312,6 +312,7 @@ impl SharedTcpQueue { pub fn close(&mut self) -> Result, Fail> { self.state_machine.prepare(SocketOp::Close)?; + self.cancel_pending_ops(Fail::new(libc::ECANCELED, "This queue was closed")); let new_socket: Option> = match self.socket { // Closing an active socket. Socket::Established(ref mut socket) => { @@ -496,6 +497,16 @@ impl SharedTcpQueue { self.pending_ops.insert(handle.clone(), yielder_handle.clone()); } + /// Cancel all currently pending operations on this queue. If the operation is not complete and the coroutine has + /// yielded, wake the coroutine with an error. + fn cancel_pending_ops(&mut self, cause: Fail) { + for (handle, mut yielder_handle) in self.pending_ops.drain() { + if !handle.has_completed() { + yielder_handle.wake_with(Err(cause.clone())); + } + } + } + /// Generic function for spawning a control-path coroutine on [self]. fn do_generic_sync_control_path_call(&mut self, coroutine: F) -> Result where