From f3f3d39aadcc5a8c39324a73fd6c3afbe7b7bcd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=87a=C4=9Fatay=20Yi=C4=9Fit=20=C5=9Eahin?= Date: Tue, 6 Aug 2024 15:38:25 +0200 Subject: [PATCH] refactor(virtq): reduce code duplication --- src/drivers/virtio/virtqueue/mod.rs | 96 ++++++++++++++-- src/drivers/virtio/virtqueue/packed.rs | 152 +++++++------------------ src/drivers/virtio/virtqueue/split.rs | 79 +++---------- 3 files changed, 138 insertions(+), 189 deletions(-) diff --git a/src/drivers/virtio/virtqueue/mod.rs b/src/drivers/virtio/virtqueue/mod.rs index 92e34a7b11..2ae238091d 100644 --- a/src/drivers/virtio/virtqueue/mod.rs +++ b/src/drivers/virtio/virtqueue/mod.rs @@ -8,7 +8,6 @@ //! Drivers who need a more fine grained access to the specific queues must //! use the respective virtqueue structs directly. #![allow(dead_code)] -#![allow(clippy::type_complexity)] pub mod packed; pub mod split; @@ -21,12 +20,14 @@ use core::mem::MaybeUninit; use core::{mem, ptr}; use async_channel::TryRecvError; +use virtio::{le32, le64, pvirtq, virtq}; use self::error::VirtqError; #[cfg(not(feature = "pci"))] use super::transport::mmio::{ComCfg, NotifCfg}; #[cfg(feature = "pci")] use super::transport::pci::{ComCfg, NotifCfg}; +use crate::arch::mm::{paging, VirtAddr}; use crate::mm::device_alloc::DeviceAlloc; /// A u16 newtype. If instantiated via ``VqIndex::from(T)``, the newtype is ensured to be @@ -98,7 +99,6 @@ type UsedBufferTokenSender = async_channel::Sender; /// might not provide the complete feature set of each queue. Drivers who /// do need these features should refrain from providing support for both /// Virtqueue types and use the structs directly instead. -#[allow(private_bounds)] pub trait Virtq { /// The `notif` parameter indicates if the driver wants to have a notification for this specific /// transfer. This is only for performance optimization. As it is NOT ensured, that the device sees the @@ -221,29 +221,33 @@ pub trait Virtq { /// These methods are an implementation detail and are meant only for consumption by the default method /// implementations in [Virtq]. trait VirtqPrivate { - type Descriptor; + type Descriptor: VirtqDescriptor; fn create_indirect_ctrl( - &self, - send: &[BufferElem], - recv: &[BufferElem], + buffer_tkn: &AvailBufferToken, ) -> Result, VirtqError>; + fn indirect_desc(table: &[Self::Descriptor]) -> Self::Descriptor { + Self::Descriptor::incomplete_desc( + paging::virt_to_phys(VirtAddr::from(table.as_ptr() as u64)) + .as_u64() + .into(), + (mem::size_of_val(table) as u32).into(), + virtq::DescF::INDIRECT, + ) + } + /// Consumes the [AvailBufferToken] and returns a [TransferToken], that can be used to actually start the transfer. /// /// After this call, the buffers are no longer writable. fn transfer_token_from_buffer_token( - &self, buff_tkn: AvailBufferToken, await_queue: Option, buffer_type: BufferType, ) -> TransferToken { let ctrl_desc = match buffer_type { BufferType::Direct => None, - BufferType::Indirect => Some( - self.create_indirect_ctrl(&buff_tkn.send_buff, &buff_tkn.recv_buff) - .unwrap(), - ), + BufferType::Indirect => Some(Self::create_indirect_ctrl(&buff_tkn).unwrap()), }; TransferToken { @@ -252,6 +256,76 @@ trait VirtqPrivate { ctrl_desc, } } + + // The descriptors returned by the iterator will be incomplete, as they do not + // have all the information necessary. + fn descriptor_iter<'a>( + buffer_tkn: &AvailBufferToken, + ) -> Result, VirtqError> { + let send_desc_iter = buffer_tkn + .send_buff + .iter() + .map(|elem| (elem, elem.len(), virtq::DescF::empty())); + let recv_desc_iter = buffer_tkn + .recv_buff + .iter() + .map(|elem| (elem, elem.capacity(), virtq::DescF::WRITE)); + let mut all_desc_iter = + send_desc_iter + .chain(recv_desc_iter) + .map(|(mem_descr, len, incomplete_flags)| { + Self::Descriptor::incomplete_desc( + paging::virt_to_phys(VirtAddr::from(mem_descr.addr() as u64)) + .as_u64() + .into(), + (len as u32).into(), + incomplete_flags | virtq::DescF::NEXT, + ) + }); + + let mut last_desc = all_desc_iter + .next_back() + .ok_or(VirtqError::BufferNotSpecified)?; + *last_desc.flags_mut() -= virtq::DescF::NEXT; + + Ok(all_desc_iter.chain([last_desc])) + } +} + +trait VirtqDescriptor { + fn flags_mut(&mut self) -> &mut virtq::DescF; + + fn incomplete_desc(addr: virtio::le64, len: virtio::le32, flags: virtq::DescF) -> Self; +} + +impl VirtqDescriptor for virtq::Desc { + fn flags_mut(&mut self) -> &mut virtq::DescF { + &mut self.flags + } + + fn incomplete_desc(addr: le64, len: le32, flags: virtq::DescF) -> Self { + Self { + addr, + len, + flags, + next: 0.into(), + } + } +} + +impl VirtqDescriptor for pvirtq::Desc { + fn flags_mut(&mut self) -> &mut virtq::DescF { + &mut self.flags + } + + fn incomplete_desc(addr: le64, len: le32, flags: virtq::DescF) -> Self { + Self { + addr, + len, + flags, + id: 0.into(), + } + } } /// The struct represents buffers which are ready to be send via the diff --git a/src/drivers/virtio/virtqueue/packed.rs b/src/drivers/virtio/virtqueue/packed.rs index 68850dac45..b1ea248eee 100644 --- a/src/drivers/virtio/virtqueue/packed.rs +++ b/src/drivers/virtio/virtqueue/packed.rs @@ -6,7 +6,7 @@ use alloc::boxed::Box; use alloc::vec::Vec; use core::cell::Cell; use core::sync::atomic::{fence, Ordering}; -use core::{mem, ops, ptr}; +use core::{ops, ptr}; use align_address::Align; #[cfg(not(feature = "pci"))] @@ -23,7 +23,7 @@ use super::super::transport::mmio::{ComCfg, NotifCfg, NotifCtrl}; use super::super::transport::pci::{ComCfg, NotifCfg, NotifCtrl}; use super::error::VirtqError; use super::{ - AvailBufferToken, BufferElem, BufferType, MemDescrId, MemPool, TransferToken, UsedBufferToken, + AvailBufferToken, BufferType, MemDescrId, MemPool, TransferToken, UsedBufferToken, UsedBufferTokenSender, Virtq, VirtqPrivate, VqIndex, VqSize, }; use crate::arch::mm::paging::{BasePageSize, PageSize}; @@ -147,25 +147,29 @@ impl DescriptorRing { fn push_batch( &mut self, - tkn_lst: Vec>, + tkn_lst: impl IntoIterator>, ) -> Result { // Catch empty push, in order to allow zero initialized first_ctrl_settings struct // which will be overwritten in the first iteration of the for-loop - assert!(!tkn_lst.is_empty()); - - let mut first_ctrl_settings = (0, MemDescrId(0), DescF::empty()); - let mut first_buffer = None; - - for (i, tkn) in tkn_lst.into_iter().enumerate() { - let mut ctrl = self.push_without_making_available(&tkn)?; - if i == 0 { - first_ctrl_settings = (ctrl.start, ctrl.buff_id, ctrl.first_flags); - first_buffer = Some(Box::new(tkn)); - } else { - // Update flags of the first descriptor and set new write_index - ctrl.make_avail(Box::new(tkn)); - } + + let first_ctrl_settings; + let first_buffer; + let mut ctrl; + + let mut tkn_iterator = tkn_lst.into_iter(); + if let Some(first_tkn) = tkn_iterator.next() { + ctrl = self.push_without_making_available(&first_tkn)?; + first_ctrl_settings = (ctrl.start, ctrl.buff_id, ctrl.first_flags); + first_buffer = Some(Box::new(first_tkn)); + } else { + // Empty batches are an error + return Err(VirtqError::BufferNotSpecified); } + // Push the remaining tokens (if any) + for tkn in tkn_iterator { + ctrl.make_avail(Box::new(tkn)); + } + // Manually make the first buffer available lastly // // Providing the first buffer in the list manually @@ -182,14 +186,7 @@ impl DescriptorRing { } fn push(&mut self, tkn: TransferToken) -> Result { - let mut ctrl = self.push_without_making_available(&tkn)?; - // Update flags of the first descriptor and set new write_index - ctrl.make_avail(Box::new(tkn)); - - Ok(RingIdx { - off: self.write_index, - wrap: self.drv_wc.0.into(), - }) + self.push_batch([tkn]) } fn push_without_making_available( @@ -210,52 +207,10 @@ impl DescriptorRing { // The buffer uses indirect descriptors if the ctrl_desc field is Some. if let Some(ctrl_desc) = tkn.ctrl_desc.as_ref() { - let indirect_table_slice_ref = ctrl_desc.as_ref(); - // One indirect descriptor with only flag indirect set - let desc = pvirtq::Desc { - addr: paging::virt_to_phys( - VirtAddr::from(indirect_table_slice_ref.as_ptr() as u64), - ) - .as_u64() - .into(), - len: (mem::size_of_val(indirect_table_slice_ref) as u32).into(), - id: 0.into(), - flags: virtq::DescF::INDIRECT, - }; + let desc = PackedVq::indirect_desc(ctrl_desc.as_ref()); ctrl.write_desc(desc); } else { - let send_desc_iter = tkn - .buff_tkn - .send_buff - .iter() - .map(|elem| (elem, elem.len(), virtq::DescF::empty())); - let recv_desc_iter = tkn - .buff_tkn - .recv_buff - .iter() - .map(|elem| (elem, elem.capacity(), virtq::DescF::WRITE)); - let mut all_desc_iter = - send_desc_iter - .chain(recv_desc_iter) - .map(|(mem_desc, len, incomplete_flags)| pvirtq::Desc { - addr: paging::virt_to_phys(VirtAddr::from(mem_desc.addr() as u64)) - .as_u64() - .into(), - len: (len as u32).into(), - id: 0.into(), - flags: incomplete_flags | virtq::DescF::NEXT, - }); - // We take all but the last pair to be able to remove the [virtq::DescF::NEXT] flag in the last one. - for incomplete_desc in all_desc_iter - .by_ref() - .take(usize::from(tkn.buff_tkn.num_descr()) - 1) - { - ctrl.write_desc(incomplete_desc); - } - { - // The iterator should have left the last element, as we took one less than what is available. - let mut incomplete_desc = all_desc_iter.next().unwrap(); - incomplete_desc.flags -= virtq::DescF::NEXT; + for incomplete_desc in PackedVq::descriptor_iter(&tkn.buff_tkn)? { ctrl.write_desc(incomplete_desc); } } @@ -596,12 +551,9 @@ impl Virtq for PackedVq { // Zero transfers are not allowed assert!(!buffer_tkns.is_empty()); - let transfer_tkns = buffer_tkns - .into_iter() - .map(|(buffer_tkn, buffer_type)| { - self.transfer_token_from_buffer_token(buffer_tkn, None, buffer_type) - }) - .collect(); + let transfer_tkns = buffer_tkns.into_iter().map(|(buffer_tkn, buffer_type)| { + Self::transfer_token_from_buffer_token(buffer_tkn, None, buffer_type) + }); let next_idx = self.descr_ring.push_batch(transfer_tkns)?; @@ -635,16 +587,13 @@ impl Virtq for PackedVq { // Zero transfers are not allowed assert!(!buffer_tkns.is_empty()); - let transfer_tkns = buffer_tkns - .into_iter() - .map(|(buffer_tkn, buffer_type)| { - self.transfer_token_from_buffer_token( - buffer_tkn, - Some(await_queue.clone()), - buffer_type, - ) - }) - .collect(); + let transfer_tkns = buffer_tkns.into_iter().map(|(buffer_tkn, buffer_type)| { + Self::transfer_token_from_buffer_token( + buffer_tkn, + Some(await_queue.clone()), + buffer_type, + ) + }); let next_idx = self.descr_ring.push_batch(transfer_tkns)?; @@ -676,7 +625,7 @@ impl Virtq for PackedVq { notif: bool, buffer_type: BufferType, ) -> Result<(), VirtqError> { - let transfer_tkn = self.transfer_token_from_buffer_token(buffer_tkn, sender, buffer_type); + let transfer_tkn = Self::transfer_token_from_buffer_token(buffer_tkn, sender, buffer_type); let next_idx = self.descr_ring.push(transfer_tkn)?; if notif { @@ -806,33 +755,10 @@ impl VirtqPrivate for PackedVq { type Descriptor = pvirtq::Desc; fn create_indirect_ctrl( - &self, - send: &[BufferElem], - recv: &[BufferElem], + buffer_tkn: &AvailBufferToken, ) -> Result, VirtqError> { - let send_desc_iter = send - .iter() - .map(|elem| (elem, elem.len(), virtq::DescF::empty())); - let recv_desc_iter = recv - .iter() - .map(|elem| (elem, elem.capacity(), virtq::DescF::WRITE)); - let all_desc_iter = - send_desc_iter - .chain(recv_desc_iter) - .map(|(mem_descr, len, incomplete_flags)| pvirtq::Desc { - addr: paging::virt_to_phys(VirtAddr::from(mem_descr.addr() as u64)) - .as_u64() - .into(), - len: (len as u32).into(), - id: 0.into(), - flags: incomplete_flags | virtq::DescF::NEXT, - }); - - let mut indirect_table: Vec<_> = all_desc_iter.collect(); - let last_desc = indirect_table - .last_mut() - .ok_or(VirtqError::BufferNotSpecified)?; - last_desc.flags -= virtq::DescF::NEXT; - Ok(indirect_table.into_boxed_slice()) + Ok(Self::descriptor_iter(buffer_tkn)? + .collect::>() + .into_boxed_slice()) } } diff --git a/src/drivers/virtio/virtqueue/split.rs b/src/drivers/virtio/virtqueue/split.rs index 7538d8a6ef..fb1d9c5908 100644 --- a/src/drivers/virtio/virtqueue/split.rs +++ b/src/drivers/virtio/virtqueue/split.rs @@ -19,8 +19,8 @@ use super::super::transport::mmio::{ComCfg, NotifCfg, NotifCtrl}; use super::super::transport::pci::{ComCfg, NotifCfg, NotifCtrl}; use super::error::VirtqError; use super::{ - AvailBufferToken, BufferElem, BufferType, MemPool, TransferToken, UsedBufferToken, - UsedBufferTokenSender, Virtq, VirtqPrivate, VqIndex, VqSize, + AvailBufferToken, BufferType, MemPool, TransferToken, UsedBufferToken, UsedBufferTokenSender, + Virtq, VirtqPrivate, VqIndex, VqSize, }; use crate::arch::memory_barrier; use crate::arch::mm::{paging, VirtAddr}; @@ -59,44 +59,12 @@ impl DescrRing { fn push(&mut self, tkn: TransferToken) -> Result { let mut index; if let Some(ctrl_desc) = tkn.ctrl_desc.as_ref() { - let indirect_table_slice_ref = ctrl_desc.as_ref(); - - let descriptor = virtq::Desc { - addr: paging::virt_to_phys( - VirtAddr::from(indirect_table_slice_ref.as_ptr() as u64), - ) - .as_u64() - .into(), - len: (mem::size_of_val(indirect_table_slice_ref) as u32).into(), - flags: virtq::DescF::INDIRECT, - next: 0.into(), - }; + let descriptor = SplitVq::indirect_desc(ctrl_desc.as_ref()); index = self.mem_pool.pool.pop().ok_or(VirtqError::NoDescrAvail)?.0; self.descr_table_mut()[usize::from(index)] = MaybeUninit::new(descriptor); } else { - let send_desc_iter = tkn - .buff_tkn - .send_buff - .iter() - .map(|elem| (elem, elem.len(), virtq::DescF::empty())); - let recv_desc_iter = tkn - .buff_tkn - .recv_buff - .iter() - .map(|elem| (elem, elem.capacity(), virtq::DescF::WRITE)); - let mut rev_all_desc_iter = - send_desc_iter - .chain(recv_desc_iter) - .rev() - .map(|(mem_descr, len, flags)| virtq::Desc { - addr: paging::virt_to_phys(VirtAddr::from(mem_descr.addr() as u64)) - .as_u64() - .into(), - len: (len as u32).into(), - flags, - next: 0.into(), - }); + let mut rev_all_desc_iter = SplitVq::descriptor_iter(&tkn.buff_tkn)?.rev(); // We need to handle the last descriptor (the first for the reversed iterator) specially to not set the next flag. { @@ -107,7 +75,6 @@ impl DescrRing { self.descr_table_mut()[usize::from(index)] = MaybeUninit::new(descriptor); } for mut descriptor in rev_all_desc_iter { - descriptor.flags |= virtq::DescF::NEXT; // We have not updated `index` yet, so it is at this point the index of the previous descriptor that had been written. descriptor.next = le16::from(index); @@ -241,7 +208,7 @@ impl Virtq for SplitVq { notif: bool, buffer_type: BufferType, ) -> Result<(), VirtqError> { - let transfer_tkn = self.transfer_token_from_buffer_token(buffer_tkn, sender, buffer_type); + let transfer_tkn = Self::transfer_token_from_buffer_token(buffer_tkn, sender, buffer_type); let next_idx = self.ring.push(transfer_tkn)?; if notif { @@ -364,33 +331,15 @@ impl Virtq for SplitVq { impl VirtqPrivate for SplitVq { type Descriptor = virtq::Desc; fn create_indirect_ctrl( - &self, - send: &[BufferElem], - recv: &[BufferElem], + buffer_tkn: &AvailBufferToken, ) -> Result, VirtqError> { - let send_desc_iter = send - .iter() - .map(|elem| (elem, elem.len(), virtq::DescF::empty())); - let recv_desc_iter = recv - .iter() - .map(|elem| (elem, elem.capacity(), virtq::DescF::WRITE)); - let all_desc_iter = send_desc_iter.chain(recv_desc_iter).zip(1u16..).map( - |((mem_descr, len, incomplete_flags), next_idx)| virtq::Desc { - addr: paging::virt_to_phys(VirtAddr::from(mem_descr.addr() as u64)) - .as_u64() - .into(), - len: (len as u32).into(), - flags: incomplete_flags | virtq::DescF::NEXT, - next: next_idx.into(), - }, - ); - - let mut indirect_table: Vec<_> = all_desc_iter.collect(); - let last_desc = indirect_table - .last_mut() - .ok_or(VirtqError::BufferNotSpecified)?; - last_desc.flags -= virtq::DescF::NEXT; - last_desc.next = 0.into(); - Ok(indirect_table.into_boxed_slice()) + Ok(Self::descriptor_iter(buffer_tkn)? + .zip(1..) + .map(|(descriptor, next_id)| Self::Descriptor { + next: next_id.into(), + ..descriptor + }) + .collect::>() + .into_boxed_slice()) } }