From a91a980c38fd18b6f8787b93df92aa9153b1e717 Mon Sep 17 00:00:00 2001
From: Andrew Walbran <qwandor@google.com>
Date: Fri, 25 Oct 2024 15:31:09 +0100
Subject: [PATCH] Factor configuration access out from PciRoot to a new trait.

This will allow implementations using methods other than MMIO to be
implemented outside of this crate.
---
 examples/aarch64/src/main.rs |  20 ++-
 examples/x86_64/src/main.rs  |  12 +-
 src/transport/pci.rs         |  40 +++---
 src/transport/pci/bus.rs     | 262 ++++++++++++++++++++---------------
 4 files changed, 197 insertions(+), 137 deletions(-)

diff --git a/examples/aarch64/src/main.rs b/examples/aarch64/src/main.rs
index 9716db7a..0158d686 100644
--- a/examples/aarch64/src/main.rs
+++ b/examples/aarch64/src/main.rs
@@ -38,7 +38,10 @@ use virtio_drivers::{
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
         pci::{
-            bus::{BarInfo, Cam, Command, DeviceFunction, MemoryBarType, PciRoot},
+            bus::{
+                BarInfo, Cam, Command, ConfigurationAccess, DeviceFunction, MemoryBarType, MmioCam,
+                PciRoot,
+            },
             virtio_device_type, PciTransport,
         },
         DeviceType, Transport,
@@ -291,8 +294,9 @@ fn enumerate_pci(pci_node: FdtNode, cam: Cam) {
             region.starting_address as usize + region.size.unwrap()
         );
         assert_eq!(region.size.unwrap(), cam.size() as usize);
-        // Safe because we know the pointer is to a valid MMIO region.
-        let mut pci_root = unsafe { PciRoot::new(region.starting_address as *mut u8, cam) };
+        // SAFETY: We know the pointer is to a valid MMIO region.
+        let mut pci_root =
+            PciRoot::new(unsafe { MmioCam::new(region.starting_address as *mut u8, cam) });
         for (device_function, info) in pci_root.enumerate_bus(0) {
             let (status, command) = pci_root.get_status_command(device_function);
             info!(
@@ -304,7 +308,7 @@ fn enumerate_pci(pci_node: FdtNode, cam: Cam) {
                 allocate_bars(&mut pci_root, device_function, &mut allocator);
                 dump_bar_contents(&mut pci_root, device_function, 4);
                 let mut transport =
-                    PciTransport::new::<HalImpl>(&mut pci_root, device_function).unwrap();
+                    PciTransport::new::<HalImpl, _>(&mut pci_root, device_function).unwrap();
                 info!(
                     "Detected virtio PCI device with device type {:?}, features {:#018x}",
                     transport.device_type(),
@@ -379,7 +383,11 @@ const fn align_up(value: u32, alignment: u32) -> u32 {
     ((value - 1) | (alignment - 1)) + 1
 }
 
-fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_index: u8) {
+fn dump_bar_contents(
+    root: &mut PciRoot<impl ConfigurationAccess>,
+    device_function: DeviceFunction,
+    bar_index: u8,
+) {
     let bar_info = root.bar_info(device_function, bar_index).unwrap();
     trace!("Dumping bar {}: {:#x?}", bar_index, bar_info);
     if let BarInfo::Memory { address, size, .. } = bar_info {
@@ -400,7 +408,7 @@ fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_in
 
 /// Allocates appropriately-sized memory regions and assigns them to the device's BARs.
 fn allocate_bars(
-    root: &mut PciRoot,
+    root: &mut PciRoot<impl ConfigurationAccess>,
     device_function: DeviceFunction,
     allocator: &mut PciMemory32Allocator,
 ) {
diff --git a/examples/x86_64/src/main.rs b/examples/x86_64/src/main.rs
index f9c9f3ed..0e165b83 100644
--- a/examples/x86_64/src/main.rs
+++ b/examples/x86_64/src/main.rs
@@ -21,7 +21,7 @@ use virtio_drivers::{
     device::{blk::VirtIOBlk, gpu::VirtIOGpu},
     transport::{
         pci::{
-            bus::{BarInfo, Cam, Command, DeviceFunction, PciRoot},
+            bus::{BarInfo, Cam, Command, ConfigurationAccess, DeviceFunction, MmioCam, PciRoot},
             virtio_device_type, PciTransport,
         },
         DeviceType, Transport,
@@ -150,7 +150,7 @@ fn virtio_net<T: Transport>(transport: T) {
 fn enumerate_pci(mmconfig_base: *mut u8) {
     info!("mmconfig_base = {:#x}", mmconfig_base as usize);
 
-    let mut pci_root = unsafe { PciRoot::new(mmconfig_base, Cam::Ecam) };
+    let mut pci_root = PciRoot::new(unsafe { MmioCam::new(mmconfig_base, Cam::Ecam) });
     for (device_function, info) in pci_root.enumerate_bus(0) {
         let (status, command) = pci_root.get_status_command(device_function);
         info!(
@@ -168,7 +168,7 @@ fn enumerate_pci(mmconfig_base: *mut u8) {
             dump_bar_contents(&mut pci_root, device_function, 4);
 
             let mut transport =
-                PciTransport::new::<HalImpl>(&mut pci_root, device_function).unwrap();
+                PciTransport::new::<HalImpl, _>(&mut pci_root, device_function).unwrap();
             info!(
                 "Detected virtio PCI device with device type {:?}, features {:#018x}",
                 transport.device_type(),
@@ -179,7 +179,11 @@ fn enumerate_pci(mmconfig_base: *mut u8) {
     }
 }
 
-fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_index: u8) {
+fn dump_bar_contents(
+    root: &mut PciRoot<impl ConfigurationAccess>,
+    device_function: DeviceFunction,
+    bar_index: u8,
+) {
     let bar_info = root.bar_info(device_function, bar_index).unwrap();
     trace!("Dumping bar {}: {:#x?}", bar_index, bar_info);
     if let BarInfo::Memory { address, size, .. } = bar_info {
diff --git a/src/transport/pci.rs b/src/transport/pci.rs
index 94b4cdfa..fbbb9524 100644
--- a/src/transport/pci.rs
+++ b/src/transport/pci.rs
@@ -2,7 +2,9 @@
 
 pub mod bus;
 
-use self::bus::{DeviceFunction, DeviceFunctionInfo, PciError, PciRoot, PCI_CAP_ID_VNDR};
+use self::bus::{
+    ConfigurationAccess, DeviceFunction, DeviceFunctionInfo, PciError, PciRoot, PCI_CAP_ID_VNDR,
+};
 use super::{DeviceStatus, DeviceType, Transport};
 use crate::{
     hal::{Hal, PhysAddr},
@@ -99,11 +101,11 @@ impl PciTransport {
     /// root controller.
     ///
     /// The PCI device must already have had its BARs allocated.
-    pub fn new<H: Hal>(
-        root: &mut PciRoot,
+    pub fn new<H: Hal, C: ConfigurationAccess>(
+        root: &mut PciRoot<C>,
         device_function: DeviceFunction,
     ) -> Result<Self, VirtioPciError> {
-        let device_vendor = root.config_read_word(device_function, 0);
+        let device_vendor = root.configuration_access.read_word(device_function, 0);
         let device_id = (device_vendor >> 16) as u16;
         let vendor_id = device_vendor as u16;
         if vendor_id != VIRTIO_VENDOR_ID {
@@ -127,12 +129,16 @@ impl PciTransport {
                 continue;
             }
             let struct_info = VirtioCapabilityInfo {
-                bar: root.config_read_word(device_function, capability.offset + CAP_BAR_OFFSET)
+                bar: root
+                    .configuration_access
+                    .read_word(device_function, capability.offset + CAP_BAR_OFFSET)
                     as u8,
                 offset: root
-                    .config_read_word(device_function, capability.offset + CAP_BAR_OFFSET_OFFSET),
+                    .configuration_access
+                    .read_word(device_function, capability.offset + CAP_BAR_OFFSET_OFFSET),
                 length: root
-                    .config_read_word(device_function, capability.offset + CAP_LENGTH_OFFSET),
+                    .configuration_access
+                    .read_word(device_function, capability.offset + CAP_LENGTH_OFFSET),
             };
 
             match cfg_type {
@@ -141,7 +147,7 @@ impl PciTransport {
                 }
                 VIRTIO_PCI_CAP_NOTIFY_CFG if cap_len >= 20 && notify_cfg.is_none() => {
                     notify_cfg = Some(struct_info);
-                    notify_off_multiplier = root.config_read_word(
+                    notify_off_multiplier = root.configuration_access.read_word(
                         device_function,
                         capability.offset + CAP_NOTIFY_OFF_MULTIPLIER_OFFSET,
                     );
@@ -156,7 +162,7 @@ impl PciTransport {
             }
         }
 
-        let common_cfg = get_bar_region::<H, _>(
+        let common_cfg = get_bar_region::<H, _, _>(
             root,
             device_function,
             &common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?,
@@ -168,16 +174,16 @@ impl PciTransport {
                 notify_off_multiplier,
             ));
         }
-        let notify_region = get_bar_region_slice::<H, _>(root, device_function, &notify_cfg)?;
+        let notify_region = get_bar_region_slice::<H, _, _>(root, device_function, &notify_cfg)?;
 
-        let isr_status = get_bar_region::<H, _>(
+        let isr_status = get_bar_region::<H, _, _>(
             root,
             device_function,
             &isr_cfg.ok_or(VirtioPciError::MissingIsrConfig)?,
         )?;
 
         let config_space = if let Some(device_cfg) = device_cfg {
-            Some(get_bar_region_slice::<H, _>(
+            Some(get_bar_region_slice::<H, _, _>(
                 root,
                 device_function,
                 &device_cfg,
@@ -387,8 +393,8 @@ struct VirtioCapabilityInfo {
     length: u32,
 }
 
-fn get_bar_region<H: Hal, T>(
-    root: &mut PciRoot,
+fn get_bar_region<H: Hal, T, C: ConfigurationAccess>(
+    root: &mut PciRoot<C>,
     device_function: DeviceFunction,
     struct_info: &VirtioCapabilityInfo,
 ) -> Result<NonNull<T>, VirtioPciError> {
@@ -417,12 +423,12 @@ fn get_bar_region<H: Hal, T>(
     Ok(vaddr.cast())
 }
 
-fn get_bar_region_slice<H: Hal, T>(
-    root: &mut PciRoot,
+fn get_bar_region_slice<H: Hal, T, C: ConfigurationAccess>(
+    root: &mut PciRoot<C>,
     device_function: DeviceFunction,
     struct_info: &VirtioCapabilityInfo,
 ) -> Result<NonNull<[T]>, VirtioPciError> {
-    let ptr = get_bar_region::<H, T>(root, device_function, struct_info)?;
+    let ptr = get_bar_region::<H, T, C>(root, device_function, struct_info)?;
     Ok(nonnull_slice_from_raw_parts(
         ptr,
         struct_info.length as usize / size_of::<T>(),
diff --git a/src/transport/pci/bus.rs b/src/transport/pci/bus.rs
index cbf43271..a31a43b0 100644
--- a/src/transport/pci/bus.rs
+++ b/src/transport/pci/bus.rs
@@ -92,9 +92,8 @@ pub enum PciError {
 
 /// The root complex of a PCI bus.
 #[derive(Debug)]
-pub struct PciRoot {
-    mmio_base: *mut u32,
-    cam: Cam,
+pub struct PciRoot<C: ConfigurationAccess> {
+    pub(crate) configuration_access: C,
 }
 
 /// A PCI Configuration Access Mechanism.
@@ -120,93 +119,21 @@ impl Cam {
     }
 }
 
-impl PciRoot {
-    /// Wraps the PCI root complex with the given MMIO base address.
-    ///
-    /// Panics if the base address is not aligned to a 4-byte boundary.
-    ///
-    /// # Safety
-    ///
-    /// `mmio_base` must be a valid pointer to an appropriately-mapped MMIO region of at least
-    /// 16 MiB (if `cam == Cam::MmioCam`) or 256 MiB (if `cam == Cam::Ecam`). The pointer must be
-    /// valid for the entire lifetime of the program (i.e. `'static`), which implies that no Rust
-    /// references may be used to access any of the memory region at any point.
-    pub unsafe fn new(mmio_base: *mut u8, cam: Cam) -> Self {
-        assert!(mmio_base as usize & 0x3 == 0);
+impl<C: ConfigurationAccess> PciRoot<C> {
+    /// Creates a new `PciRoot` to access a PCI root complex through the given configuration access
+    /// implementation.
+    pub fn new(configuration_access: C) -> Self {
         Self {
-            mmio_base: mmio_base as *mut u32,
-            cam,
-        }
-    }
-
-    /// Makes a clone of the `PciRoot`, pointing at the same MMIO region.
-    ///
-    /// # Safety
-    ///
-    /// This function allows concurrent mutable access to the PCI CAM. To avoid this causing
-    /// problems, the returned `PciRoot` instance must only be used to read read-only fields.
-    unsafe fn unsafe_clone(&self) -> Self {
-        Self {
-            mmio_base: self.mmio_base,
-            cam: self.cam,
-        }
-    }
-
-    fn cam_offset(&self, device_function: DeviceFunction, register_offset: u8) -> u32 {
-        assert!(device_function.valid());
-
-        let bdf = (device_function.bus as u32) << 8
-            | (device_function.device as u32) << 3
-            | device_function.function as u32;
-        let address =
-            bdf << match self.cam {
-                Cam::MmioCam => 8,
-                Cam::Ecam => 12,
-            } | register_offset as u32;
-        // Ensure that address is within range.
-        assert!(address < self.cam.size());
-        // Ensure that address is word-aligned.
-        assert!(address & 0x3 == 0);
-        address
-    }
-
-    /// Reads 4 bytes from configuration space using the appropriate CAM.
-    pub(crate) fn config_read_word(
-        &self,
-        device_function: DeviceFunction,
-        register_offset: u8,
-    ) -> u32 {
-        let address = self.cam_offset(device_function, register_offset);
-        // Safe because both the `mmio_base` and the address offset are properly aligned, and the
-        // resulting pointer is within the MMIO range of the CAM.
-        unsafe {
-            // Right shift to convert from byte offset to word offset.
-            (self.mmio_base.add((address >> 2) as usize)).read_volatile()
-        }
-    }
-
-    /// Writes 4 bytes to configuration space using the appropriate CAM.
-    pub(crate) fn config_write_word(
-        &mut self,
-        device_function: DeviceFunction,
-        register_offset: u8,
-        data: u32,
-    ) {
-        let address = self.cam_offset(device_function, register_offset);
-        // Safe because both the `mmio_base` and the address offset are properly aligned, and the
-        // resulting pointer is within the MMIO range of the CAM.
-        unsafe {
-            // Right shift to convert from byte offset to word offset.
-            (self.mmio_base.add((address >> 2) as usize)).write_volatile(data)
+            configuration_access,
         }
     }
 
     /// Enumerates PCI devices on the given bus.
-    pub fn enumerate_bus(&self, bus: u8) -> BusDeviceIterator {
+    pub fn enumerate_bus(&self, bus: u8) -> BusDeviceIterator<C> {
         // Safe because the BusDeviceIterator only reads read-only fields.
-        let root = unsafe { self.unsafe_clone() };
+        let configuration_access = unsafe { self.configuration_access.unsafe_clone() };
         BusDeviceIterator {
-            root,
+            configuration_access,
             next: DeviceFunction {
                 bus,
                 device: 0,
@@ -217,7 +144,9 @@ impl PciRoot {
 
     /// Reads the status and command registers of the given device function.
     pub fn get_status_command(&self, device_function: DeviceFunction) -> (Status, Command) {
-        let status_command = self.config_read_word(device_function, STATUS_COMMAND_OFFSET);
+        let status_command = self
+            .configuration_access
+            .read_word(device_function, STATUS_COMMAND_OFFSET);
         let status = Status::from_bits_truncate((status_command >> 16) as u16);
         let command = Command::from_bits_truncate(status_command as u16);
         (status, command)
@@ -225,7 +154,7 @@ impl PciRoot {
 
     /// Sets the command register of the given device function.
     pub fn set_command(&mut self, device_function: DeviceFunction, command: Command) {
-        self.config_write_word(
+        self.configuration_access.write_word(
             device_function,
             STATUS_COMMAND_OFFSET,
             command.bits().into(),
@@ -233,9 +162,9 @@ impl PciRoot {
     }
 
     /// Gets an iterator over the capabilities of the given device function.
-    pub fn capabilities(&self, device_function: DeviceFunction) -> CapabilityIterator {
+    pub fn capabilities(&self, device_function: DeviceFunction) -> CapabilityIterator<C> {
         CapabilityIterator {
-            root: self,
+            configuration_access: &self.configuration_access,
             device_function,
             next_capability_offset: self.capabilities_offset(device_function),
         }
@@ -263,17 +192,29 @@ impl PciRoot {
         device_function: DeviceFunction,
         bar_index: u8,
     ) -> Result<BarInfo, PciError> {
-        let bar_orig = self.config_read_word(device_function, BAR0_OFFSET + 4 * bar_index);
+        let bar_orig = self
+            .configuration_access
+            .read_word(device_function, BAR0_OFFSET + 4 * bar_index);
 
         // Get the size of the BAR.
-        self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, 0xffffffff);
-        let size_mask = self.config_read_word(device_function, BAR0_OFFSET + 4 * bar_index);
+        self.configuration_access.write_word(
+            device_function,
+            BAR0_OFFSET + 4 * bar_index,
+            0xffffffff,
+        );
+        let size_mask = self
+            .configuration_access
+            .read_word(device_function, BAR0_OFFSET + 4 * bar_index);
         // A wrapping add is necessary to correctly handle the case of unused BARs, which read back
         // as 0, and should be treated as size 0.
         let size = (!(size_mask & 0xfffffff0)).wrapping_add(1);
 
         // Restore the original value.
-        self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, bar_orig);
+        self.configuration_access.write_word(
+            device_function,
+            BAR0_OFFSET + 4 * bar_index,
+            bar_orig,
+        );
 
         if bar_orig & 0x00000001 == 0x00000001 {
             // I/O space
@@ -288,8 +229,9 @@ impl PciRoot {
                 if bar_index >= 5 {
                     return Err(PciError::InvalidBarType);
                 }
-                let address_top =
-                    self.config_read_word(device_function, BAR0_OFFSET + 4 * (bar_index + 1));
+                let address_top = self
+                    .configuration_access
+                    .read_word(device_function, BAR0_OFFSET + 4 * (bar_index + 1));
                 address |= u64::from(address_top) << 32;
             }
             Ok(BarInfo::Memory {
@@ -303,13 +245,18 @@ impl PciRoot {
 
     /// Sets the address of the given 32-bit memory or I/O BAR of the given device function.
     pub fn set_bar_32(&mut self, device_function: DeviceFunction, bar_index: u8, address: u32) {
-        self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, address);
+        self.configuration_access
+            .write_word(device_function, BAR0_OFFSET + 4 * bar_index, address);
     }
 
     /// Sets the address of the given 64-bit memory BAR of the given device function.
     pub fn set_bar_64(&mut self, device_function: DeviceFunction, bar_index: u8, address: u64) {
-        self.config_write_word(device_function, BAR0_OFFSET + 4 * bar_index, address as u32);
-        self.config_write_word(
+        self.configuration_access.write_word(
+            device_function,
+            BAR0_OFFSET + 4 * bar_index,
+            address as u32,
+        );
+        self.configuration_access.write_word(
             device_function,
             BAR0_OFFSET + 4 * (bar_index + 1),
             (address >> 32) as u32,
@@ -320,19 +267,112 @@ impl PciRoot {
     fn capabilities_offset(&self, device_function: DeviceFunction) -> Option<u8> {
         let (status, _) = self.get_status_command(device_function);
         if status.contains(Status::CAPABILITIES_LIST) {
-            Some((self.config_read_word(device_function, 0x34) & 0xFC) as u8)
+            Some((self.configuration_access.read_word(device_function, 0x34) & 0xFC) as u8)
         } else {
             None
         }
     }
 }
 
+/// A method to access PCI configuration space for a particular PCI bus.
+pub trait ConfigurationAccess {
+    /// Reads 4 bytes from the configuration space.
+    fn read_word(&self, device_function: DeviceFunction, register_offset: u8) -> u32;
+
+    /// Writes 4 bytes to the configuration space.
+    fn write_word(&mut self, device_function: DeviceFunction, register_offset: u8, data: u32);
+
+    /// Makes a clone of the `ConfigurationAccess`, accessing the same PCI bus.
+    ///
+    /// # Safety
+    ///
+    /// This function allows concurrent mutable access to the PCI CAM. To avoid this causing
+    /// problems, the returned `ConfigurationAccess` instance must only be used to read read-only
+    /// fields.
+    unsafe fn unsafe_clone(&self) -> Self;
+}
+
+/// `ConfigurationAccess` implementation for memory-mapped access to a PCI root complex, via either
+/// a 16 MiB region for the PCI Configuration Access Mechanism or a 256 MiB region for the PCIe
+/// Enhanced Configuration Access Mechanism.
+pub struct MmioCam {
+    mmio_base: *mut u32,
+    cam: Cam,
+}
+
+impl MmioCam {
+    /// Wraps the PCI root complex with the given MMIO base address.
+    ///
+    /// Panics if the base address is not aligned to a 4-byte boundary.
+    ///
+    /// # Safety
+    ///
+    /// `mmio_base` must be a valid pointer to an appropriately-mapped MMIO region of at least
+    /// 16 MiB (if `cam == Cam::MmioCam`) or 256 MiB (if `cam == Cam::Ecam`). The pointer must be
+    /// valid for the entire lifetime of the program (i.e. `'static`), which implies that no Rust
+    /// references may be used to access any of the memory region at any point.
+    pub unsafe fn new(mmio_base: *mut u8, cam: Cam) -> Self {
+        assert!(mmio_base as usize & 0x3 == 0);
+        Self {
+            mmio_base: mmio_base as *mut u32,
+            cam,
+        }
+    }
+
+    fn cam_offset(&self, device_function: DeviceFunction, register_offset: u8) -> u32 {
+        assert!(device_function.valid());
+
+        let bdf = (device_function.bus as u32) << 8
+            | (device_function.device as u32) << 3
+            | device_function.function as u32;
+        let address =
+            bdf << match self.cam {
+                Cam::MmioCam => 8,
+                Cam::Ecam => 12,
+            } | register_offset as u32;
+        // Ensure that address is within range.
+        assert!(address < self.cam.size());
+        // Ensure that address is word-aligned.
+        assert!(address & 0x3 == 0);
+        address
+    }
+}
+
+impl ConfigurationAccess for MmioCam {
+    fn read_word(&self, device_function: DeviceFunction, register_offset: u8) -> u32 {
+        let address = self.cam_offset(device_function, register_offset);
+        // Safe because both the `mmio_base` and the address offset are properly aligned, and the
+        // resulting pointer is within the MMIO range of the CAM.
+        unsafe {
+            // Right shift to convert from byte offset to word offset.
+            (self.mmio_base.add((address >> 2) as usize)).read_volatile()
+        }
+    }
+
+    fn write_word(&mut self, device_function: DeviceFunction, register_offset: u8, data: u32) {
+        let address = self.cam_offset(device_function, register_offset);
+        // Safe because both the `mmio_base` and the address offset are properly aligned, and the
+        // resulting pointer is within the MMIO range of the CAM.
+        unsafe {
+            // Right shift to convert from byte offset to word offset.
+            (self.mmio_base.add((address >> 2) as usize)).write_volatile(data)
+        }
+    }
+
+    unsafe fn unsafe_clone(&self) -> Self {
+        Self {
+            mmio_base: self.mmio_base,
+            cam: self.cam,
+        }
+    }
+}
+
 // SAFETY: `mmio_base` is only used for MMIO, which can happen from any thread or CPU core.
-unsafe impl Send for PciRoot {}
+unsafe impl Send for MmioCam {}
 
-// SAFETY: `&PciRoot` only allows MMIO reads, which are fine to happen concurrently on different CPU
+// SAFETY: `&MmioCam` only allows MMIO reads, which are fine to happen concurrently on different CPU
 // cores.
-unsafe impl Sync for PciRoot {}
+unsafe impl Sync for MmioCam {}
 
 /// Information about a PCI Base Address Register.
 #[derive(Clone, Debug, Eq, PartialEq)]
@@ -438,20 +478,22 @@ impl TryFrom<u8> for MemoryBarType {
 
 /// Iterator over capabilities for a device.
 #[derive(Debug)]
-pub struct CapabilityIterator<'a> {
-    root: &'a PciRoot,
+pub struct CapabilityIterator<'a, C: ConfigurationAccess> {
+    configuration_access: &'a C,
     device_function: DeviceFunction,
     next_capability_offset: Option<u8>,
 }
 
-impl<'a> Iterator for CapabilityIterator<'a> {
+impl<'a, C: ConfigurationAccess> Iterator for CapabilityIterator<'a, C> {
     type Item = CapabilityInfo;
 
     fn next(&mut self) -> Option<Self::Item> {
         let offset = self.next_capability_offset?;
 
         // Read the first 4 bytes of the capability.
-        let capability_header = self.root.config_read_word(self.device_function, offset);
+        let capability_header = self
+            .configuration_access
+            .read_word(self.device_function, offset);
         let id = capability_header as u8;
         let next_offset = (capability_header >> 8) as u8;
         let private_header = (capability_header >> 16) as u16;
@@ -486,21 +528,21 @@ pub struct CapabilityInfo {
 
 /// An iterator which enumerates PCI devices and functions on a given bus.
 #[derive(Debug)]
-pub struct BusDeviceIterator {
+pub struct BusDeviceIterator<C: ConfigurationAccess> {
     /// This must only be used to read read-only fields, and must not be exposed outside this
     /// module, because it uses the same CAM as the main `PciRoot` instance.
-    root: PciRoot,
+    configuration_access: C,
     next: DeviceFunction,
 }
 
-impl Iterator for BusDeviceIterator {
+impl<C: ConfigurationAccess> Iterator for BusDeviceIterator<C> {
     type Item = (DeviceFunction, DeviceFunctionInfo);
 
     fn next(&mut self) -> Option<Self::Item> {
         while self.next.device < MAX_DEVICES {
             // Read the header for the current device and function.
             let current = self.next;
-            let device_vendor = self.root.config_read_word(current, 0);
+            let device_vendor = self.configuration_access.read_word(current, 0);
 
             // Advance to the next device or function.
             self.next.function += 1;
@@ -510,14 +552,14 @@ impl Iterator for BusDeviceIterator {
             }
 
             if device_vendor != INVALID_READ {
-                let class_revision = self.root.config_read_word(current, 8);
+                let class_revision = self.configuration_access.read_word(current, 8);
                 let device_id = (device_vendor >> 16) as u16;
                 let vendor_id = device_vendor as u16;
                 let class = (class_revision >> 24) as u8;
                 let subclass = (class_revision >> 16) as u8;
                 let prog_if = (class_revision >> 8) as u8;
                 let revision = class_revision as u8;
-                let bist_type_latency_cache = self.root.config_read_word(current, 12);
+                let bist_type_latency_cache = self.configuration_access.read_word(current, 12);
                 let header_type = HeaderType::from((bist_type_latency_cache >> 16) as u8 & 0x7f);
                 return Some((
                     current,