From 139cd0bbd018696e1bca6ea6c6ed4d9148b89003 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=87a=C4=9Fatay=20Yi=C4=9Fit=20=C5=9Eahin?= Date: Fri, 3 May 2024 18:06:02 +0200 Subject: [PATCH] virtq: use async_channel for moving buffers to the callers --- Cargo.lock | 46 +++++++++++++++++-- Cargo.toml | 1 + src/drivers/net/virtio_mmio.rs | 15 +------ src/drivers/net/virtio_net.rs | 61 ++++++++++++-------------- src/drivers/net/virtio_pci.rs | 12 +---- src/drivers/virtio/virtqueue/mod.rs | 45 +++++++++---------- src/drivers/virtio/virtqueue/packed.rs | 9 ++-- src/drivers/virtio/virtqueue/split.rs | 7 +-- 8 files changed, 105 insertions(+), 91 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2fdc7ee8b3..0232f45809 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -129,14 +129,27 @@ dependencies = [ "bitflags 2.5.0", ] +[[package]] +name = "async-channel" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f2776ead772134d55b62dd45e59a79e21612d85d0af729b8b7d3967d601a62a" +dependencies = [ + "concurrent-queue", + "event-listener 5.3.0", + "event-listener-strategy 0.5.2", + "futures-core", + "pin-project-lite", +] + [[package]] name = "async-lock" version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b" dependencies = [ - "event-listener", - "event-listener-strategy", + "event-listener 4.0.3", + "event-listener-strategy 0.4.0", "pin-project-lite", ] @@ -412,13 +425,33 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "event-listener" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24" +dependencies = [ + "concurrent-queue", + "pin-project-lite", +] + [[package]] name = "event-listener-strategy" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" dependencies = [ - "event-listener", + "event-listener 4.0.3", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f214dc438f977e6d4e3500aaa277f5ad94ca83fbbd9b1a15713ce2344ccc5a1" +dependencies = [ + "event-listener 5.3.0", "pin-project-lite", ] @@ -473,6 +506,12 @@ dependencies = [ "x86_64 0.15.1", ] +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + [[package]] name = "generic_once_cell" version = "0.1.1" @@ -561,6 +600,7 @@ dependencies = [ "anstyle", "anyhow", "arm-gic", + "async-channel", "async-lock", "async-trait", "bit_field", diff --git a/Cargo.toml b/Cargo.toml index a895abea8e..e8d797b6bc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -104,6 +104,7 @@ async-lock = { version = "3.3.0", default-features = false } simple-shell = { version = "0.0.1", optional = true } volatile = { version = "0.5.4", features = ["unstable"] } anstyle = { version = "1", default-features = false } +async-channel = { version = "2.2", default-features = false } [dependencies.smoltcp] version = "0.11" diff --git a/src/drivers/net/virtio_mmio.rs b/src/drivers/net/virtio_mmio.rs index 331bf85eb1..140843026e 100644 --- a/src/drivers/net/virtio_mmio.rs +++ b/src/drivers/net/virtio_mmio.rs @@ -2,10 +2,8 @@ //! //! The module contains ... -use alloc::collections::VecDeque; use alloc::rc::Rc; use alloc::vec::Vec; -use core::cell::RefCell; use core::ptr; use core::ptr::read_volatile; use core::str::FromStr; @@ -138,17 +136,8 @@ impl VirtioNetDriver { isr_stat, notif_cfg, ctrl_vq: CtrlQueue::new(None), - recv_vqs: RxQueues::new( - Vec::>::new(), - Rc::new(RefCell::new(VecDeque::new())), - false, - ), - send_vqs: TxQueues::new( - Vec::>::new(), - Rc::new(RefCell::new(VecDeque::new())), - Vec::new(), - false, - ), + recv_vqs: RxQueues::new(Vec::>::new(), false), + send_vqs: TxQueues::new(Vec::>::new(), Vec::new(), false), num_vqs: 0, irq, mtu, diff --git a/src/drivers/net/virtio_net.rs b/src/drivers/net/virtio_net.rs index b3861d0e09..a40a74a78f 100644 --- a/src/drivers/net/virtio_net.rs +++ b/src/drivers/net/virtio_net.rs @@ -3,10 +3,8 @@ //! The module contains ... use alloc::boxed::Box; -use alloc::collections::VecDeque; use alloc::rc::Rc; use alloc::vec::Vec; -use core::cell::RefCell; use core::cmp::Ordering; use core::mem; @@ -154,19 +152,18 @@ enum MqCmd { pub struct RxQueues { vqs: Vec>, - poll_queue: Rc>>>, + poll_sender: async_channel::Sender>, + poll_receiver: async_channel::Receiver>, is_multi: bool, } impl RxQueues { - pub fn new( - vqs: Vec>, - poll_queue: Rc>>>, - is_multi: bool, - ) -> Self { + pub fn new(vqs: Vec>, is_multi: bool) -> Self { + let (poll_sender, poll_receiver) = async_channel::unbounded(); Self { vqs, - poll_queue, + poll_sender, + poll_receiver, is_multi, } } @@ -212,7 +209,7 @@ impl RxQueues { // Transfers will be awaited at the queue buff_tkn .provide() - .dispatch_await(Rc::clone(&self.poll_queue), false); + .dispatch_await(self.poll_sender.clone(), false); } // Safe virtqueue @@ -224,14 +221,16 @@ impl RxQueues { } fn get_next(&mut self) -> Option> { - let transfer = self.poll_queue.borrow_mut().pop_front(); + let transfer = self.poll_receiver.try_recv(); - transfer.or_else(|| { - // Check if any not yet provided transfers are in the queue. - self.poll(); + transfer + .or_else(|_| { + // Check if any not yet provided transfers are in the queue. + self.poll(); - self.poll_queue.borrow_mut().pop_front() - }) + self.poll_receiver.try_recv() + }) + .ok() } fn poll(&self) { @@ -269,7 +268,8 @@ impl RxQueues { /// to the respective queue structures. pub struct TxQueues { vqs: Vec>, - poll_queue: Rc>>>, + poll_sender: async_channel::Sender>, + poll_receiver: async_channel::Receiver>, ready_queue: Vec, /// Indicates, whether the Driver/Device are using multiple /// queues for communication. @@ -277,15 +277,12 @@ pub struct TxQueues { } impl TxQueues { - pub fn new( - vqs: Vec>, - poll_queue: Rc>>>, - ready_queue: Vec, - is_multi: bool, - ) -> Self { + pub fn new(vqs: Vec>, ready_queue: Vec, is_multi: bool) -> Self { + let (poll_sender, poll_receiver) = async_channel::unbounded(); Self { vqs, - poll_queue, + poll_sender, + poll_receiver, ready_queue, is_multi, } @@ -408,11 +405,11 @@ impl TxQueues { } } - if self.poll_queue.borrow().is_empty() { + if self.poll_receiver.is_empty() { self.poll(); } - while let Some(buffer_token) = self.poll_queue.borrow_mut().pop_back() { + while let Ok(buffer_token) = self.poll_receiver.try_recv() { let mut tkn = buffer_token.reset(); let (send_len, _) = tkn.len(); @@ -484,7 +481,7 @@ impl NetworkDriver for VirtioNetDriver { #[allow(dead_code)] fn has_packet(&self) -> bool { self.recv_vqs.poll(); - !self.recv_vqs.poll_queue.borrow().is_empty() + !self.recv_vqs.poll_receiver.is_empty() } /// Provides smoltcp a slice to copy the IP packet and transfer the packet @@ -547,7 +544,7 @@ impl NetworkDriver for VirtioNetDriver { buff_tkn .provide() - .dispatch_await(Rc::clone(&self.send_vqs.poll_queue), false); + .dispatch_await(self.send_vqs.poll_sender.clone(), false); result } else { @@ -581,7 +578,7 @@ impl NetworkDriver for VirtioNetDriver { transfer .reset() .provide() - .dispatch_await(Rc::clone(&self.recv_vqs.poll_queue), false); + .dispatch_await(self.recv_vqs.poll_sender.clone(), false); return None; } @@ -598,7 +595,7 @@ impl NetworkDriver for VirtioNetDriver { transfer .reset() .provide() - .dispatch_await(Rc::clone(&self.recv_vqs.poll_queue), false); + .dispatch_await(self.recv_vqs.poll_sender.clone(), false); num_buffers }; @@ -620,7 +617,7 @@ impl NetworkDriver for VirtioNetDriver { transfer .reset() .provide() - .dispatch_await(Rc::clone(&self.recv_vqs.poll_queue), false); + .dispatch_await(self.recv_vqs.poll_sender.clone(), false); } Some((RxToken::new(vec_data), TxToken::new())) @@ -631,7 +628,7 @@ impl NetworkDriver for VirtioNetDriver { .write_seq(None::<&VirtioNetHdr>, Some(&VirtioNetHdr::default())) .unwrap() .provide() - .dispatch_await(Rc::clone(&self.recv_vqs.poll_queue), false); + .dispatch_await(self.recv_vqs.poll_sender.clone(), false); None } diff --git a/src/drivers/net/virtio_pci.rs b/src/drivers/net/virtio_pci.rs index f1852138f8..5559c6c009 100644 --- a/src/drivers/net/virtio_pci.rs +++ b/src/drivers/net/virtio_pci.rs @@ -2,10 +2,7 @@ //! //! The module contains ... -use alloc::collections::VecDeque; -use alloc::rc::Rc; use alloc::vec::Vec; -use core::cell::RefCell; use core::str::FromStr; use smoltcp::phy::ChecksumCapabilities; @@ -146,13 +143,8 @@ impl VirtioNetDriver { notif_cfg, ctrl_vq: CtrlQueue::new(None), - recv_vqs: RxQueues::new(Vec::new(), Rc::new(RefCell::new(VecDeque::new())), false), - send_vqs: TxQueues::new( - Vec::new(), - Rc::new(RefCell::new(VecDeque::new())), - Vec::new(), - false, - ), + recv_vqs: RxQueues::new(Vec::new(), false), + send_vqs: TxQueues::new(Vec::new(), Vec::new(), false), num_vqs: 0, irq: device.get_irq().unwrap(), mtu, diff --git a/src/drivers/virtio/virtqueue/mod.rs b/src/drivers/virtio/virtqueue/mod.rs index b90100d72b..72b75529ba 100644 --- a/src/drivers/virtio/virtqueue/mod.rs +++ b/src/drivers/virtio/virtqueue/mod.rs @@ -14,7 +14,6 @@ pub mod packed; pub mod split; use alloc::boxed::Box; -use alloc::collections::VecDeque; use alloc::rc::Rc; use alloc::vec::Vec; use core::cell::RefCell; @@ -22,6 +21,7 @@ use core::ops::{BitAnd, Deref, DerefMut}; use core::ptr; use align_address::Align; +use async_channel::TryRecvError; use zerocopy::AsBytes; use self::error::{BufferError, VirtqError}; @@ -99,6 +99,8 @@ struct Descriptor { flags: u16, } +type BufferTokenSender = async_channel::Sender>; + // Public interface of Virtq /// The Virtq trait unifies access to the two different Virtqueue types @@ -155,7 +157,7 @@ pub trait Virtq: VirtqPrivate { fn dispatch_batch_await( &self, tkns: Vec, - await_queue: Rc>>>, + await_queue: BufferTokenSender, notif: bool, ); @@ -1385,11 +1387,7 @@ pub fn dispatch_batch(tkns: Vec, notif: bool) { /// The `notif` parameter indicates if the driver wants to have a notification for this specific /// transfer. This is only for performance optimization. As it is NOT ensured, that the device sees the /// updated notification flags before finishing transfers! -pub fn dispatch_batch_await( - tkns: Vec, - await_queue: Rc>>>, - notif: bool, -) { +pub fn dispatch_batch_await(tkns: Vec, await_queue: BufferTokenSender, notif: bool) { let mut used_vqs: Vec<(Rc, Vec)> = Vec::new(); // Sort the TransferTokens depending in the queue their coming from. @@ -1420,7 +1418,7 @@ pub fn dispatch_batch_await( } for (vq, tkn_lst) in used_vqs { - vq.dispatch_batch_await(tkn_lst, Rc::clone(&await_queue), notif); + vq.dispatch_batch_await(tkn_lst, await_queue.clone(), notif); } } @@ -1476,7 +1474,7 @@ pub struct TransferToken { /// If Some, finished TransferTokens will be placed here /// as finished `Transfers`. If None, only the state /// of the Token will be changed. - await_queue: Option>>>>, + await_queue: Option, } /// Public Interface for TransferToken @@ -1492,12 +1490,8 @@ impl TransferToken { /// The `notif` parameter indicates if the driver wants to have a notification for this specific /// transfer. This is only for performance optimization. As it is NOT ensured, that the device sees the /// updated notification flags before finishing transfers! - pub fn dispatch_await( - mut self, - await_queue: Rc>>>, - notif: bool, - ) { - self.await_queue = Some(Rc::clone(&await_queue)); + pub fn dispatch_await(mut self, await_queue: BufferTokenSender, notif: bool) { + self.await_queue = Some(await_queue.clone()); self.get_vq().dispatch(self, notif); } @@ -1522,20 +1516,27 @@ impl TransferToken { /// Upon finish notifications are enabled again. pub fn dispatch_blocking(self) -> Result, VirtqError> { let vq = self.get_vq(); - let rcv_queue = Rc::new(RefCell::new(VecDeque::with_capacity(1))); - self.dispatch_await(rcv_queue.clone(), false); + let (sender, receiver) = async_channel::bounded(1); + self.dispatch_await(sender, false); vq.disable_notifs(); - while rcv_queue.borrow().is_empty() { - // Keep Spinning until the receive queue is filled - vq.poll() + let result: Box; + // Keep Spinning until the receive queue is filled + loop { + match receiver.try_recv() { + Ok(buffer_tkn) => { + result = buffer_tkn; + break; + } + Err(TryRecvError::Closed) => return Err(VirtqError::General), + Err(TryRecvError::Empty) => vq.poll(), + } } vq.enable_notifs(); - let result = Ok(rcv_queue.borrow_mut().pop_front().unwrap()); - result + Ok(result) } } diff --git a/src/drivers/virtio/virtqueue/packed.rs b/src/drivers/virtio/virtqueue/packed.rs index b82a17571a..d7edab90a4 100644 --- a/src/drivers/virtio/virtqueue/packed.rs +++ b/src/drivers/virtio/virtqueue/packed.rs @@ -3,7 +3,6 @@ #![allow(dead_code)] use alloc::boxed::Box; -use alloc::collections::VecDeque; use alloc::rc::Rc; use alloc::vec::Vec; use core::cell::RefCell; @@ -138,9 +137,7 @@ impl DescriptorRing { if let Some(mut tkn) = ctrl.poll_next() { if let Some(queue) = tkn.await_queue.take() { // Place the TransferToken in a Transfer, which will hold ownership of the token - queue - .borrow_mut() - .push_back(Box::new(tkn.buff_tkn.unwrap())); + queue.try_send(Box::new(tkn.buff_tkn.unwrap())).unwrap(); } } } @@ -996,7 +993,7 @@ impl Virtq for PackedVq { fn dispatch_batch_await( &self, mut tkns: Vec, - await_queue: Rc>>>, + await_queue: super::BufferTokenSender, notif: bool, ) { // Zero transfers are not allowed @@ -1004,7 +1001,7 @@ impl Virtq for PackedVq { // We have to iterate here too, in order to ensure, tokens are placed into the await_queue for tkn in tkns.iter_mut() { - tkn.await_queue = Some(Rc::clone(&await_queue)); + tkn.await_queue = Some(await_queue.clone()); } let (next_off, next_wrap) = self.descr_ring.borrow_mut().push_batch(tkns); diff --git a/src/drivers/virtio/virtqueue/split.rs b/src/drivers/virtio/virtqueue/split.rs index cf7825c44a..0d1cc81c54 100644 --- a/src/drivers/virtio/virtqueue/split.rs +++ b/src/drivers/virtio/virtqueue/split.rs @@ -3,7 +3,6 @@ #![allow(dead_code)] use alloc::boxed::Box; -use alloc::collections::VecDeque; use alloc::rc::Rc; use alloc::vec::Vec; use core::alloc::{Allocator, Layout}; @@ -305,9 +304,7 @@ impl DescrRing { .unwrap(); } if let Some(queue) = tkn.await_queue.take() { - queue - .borrow_mut() - .push_back(Box::new(tkn.buff_tkn.unwrap())) + queue.try_send(Box::new(tkn.buff_tkn.unwrap())).unwrap() } memory_barrier(); self.read_idx = self.read_idx.wrapping_add(1); @@ -363,7 +360,7 @@ impl Virtq for SplitVq { fn dispatch_batch_await( &self, _tkns: Vec, - _await_queue: Rc>>>, + _await_queue: super::BufferTokenSender, _notif: bool, ) { unimplemented!()