Skip to content

Commit

Permalink
Merge pull request #161 from rcore-os/pciroot
Browse files Browse the repository at this point in the history
Factor configuration access out from PciRoot to a new trait.
  • Loading branch information
qwandor authored Nov 21, 2024
2 parents fea238e + 3bcc5e4 commit 0c7f230
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 124 deletions.
20 changes: 14 additions & 6 deletions examples/aarch64/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ use virtio_drivers::{
transport::{
mmio::{MmioTransport, VirtIOHeader},
pci::{
bus::{BarInfo, Cam, Command, DeviceFunction, MemoryBarType, PciRoot},
bus::{
BarInfo, Cam, Command, ConfigurationAccess, DeviceFunction, MemoryBarType, MmioCam,
PciRoot,
},
virtio_device_type, PciTransport,
},
DeviceType, Transport,
Expand Down Expand Up @@ -291,8 +294,9 @@ fn enumerate_pci(pci_node: FdtNode, cam: Cam) {
region.starting_address as usize + region.size.unwrap()
);
assert_eq!(region.size.unwrap(), cam.size() as usize);
// Safe because we know the pointer is to a valid MMIO region.
let mut pci_root = unsafe { PciRoot::new(region.starting_address as *mut u8, cam) };
// SAFETY: We know the pointer is to a valid MMIO region.
let mut pci_root =
PciRoot::new(unsafe { MmioCam::new(region.starting_address as *mut u8, cam) });
for (device_function, info) in pci_root.enumerate_bus(0) {
let (status, command) = pci_root.get_status_command(device_function);
info!(
Expand All @@ -304,7 +308,7 @@ fn enumerate_pci(pci_node: FdtNode, cam: Cam) {
allocate_bars(&mut pci_root, device_function, &mut allocator);
dump_bar_contents(&mut pci_root, device_function, 4);
let mut transport =
PciTransport::new::<HalImpl>(&mut pci_root, device_function).unwrap();
PciTransport::new::<HalImpl, _>(&mut pci_root, device_function).unwrap();
info!(
"Detected virtio PCI device with device type {:?}, features {:#018x}",
transport.device_type(),
Expand Down Expand Up @@ -379,7 +383,11 @@ const fn align_up(value: u32, alignment: u32) -> u32 {
((value - 1) | (alignment - 1)) + 1
}

fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_index: u8) {
fn dump_bar_contents(
root: &mut PciRoot<impl ConfigurationAccess>,
device_function: DeviceFunction,
bar_index: u8,
) {
let bar_info = root.bar_info(device_function, bar_index).unwrap();
trace!("Dumping bar {}: {:#x?}", bar_index, bar_info);
if let BarInfo::Memory { address, size, .. } = bar_info {
Expand All @@ -400,7 +408,7 @@ fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_in

/// Allocates appropriately-sized memory regions and assigns them to the device's BARs.
fn allocate_bars(
root: &mut PciRoot,
root: &mut PciRoot<impl ConfigurationAccess>,
device_function: DeviceFunction,
allocator: &mut PciMemory32Allocator,
) {
Expand Down
12 changes: 8 additions & 4 deletions examples/x86_64/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use virtio_drivers::{
device::{blk::VirtIOBlk, gpu::VirtIOGpu},
transport::{
pci::{
bus::{BarInfo, Cam, Command, DeviceFunction, PciRoot},
bus::{BarInfo, Cam, Command, ConfigurationAccess, DeviceFunction, MmioCam, PciRoot},
virtio_device_type, PciTransport,
},
DeviceType, Transport,
Expand Down Expand Up @@ -150,7 +150,7 @@ fn virtio_net<T: Transport>(transport: T) {
fn enumerate_pci(mmconfig_base: *mut u8) {
info!("mmconfig_base = {:#x}", mmconfig_base as usize);

let mut pci_root = unsafe { PciRoot::new(mmconfig_base, Cam::Ecam) };
let mut pci_root = PciRoot::new(unsafe { MmioCam::new(mmconfig_base, Cam::Ecam) });
for (device_function, info) in pci_root.enumerate_bus(0) {
let (status, command) = pci_root.get_status_command(device_function);
info!(
Expand All @@ -168,7 +168,7 @@ fn enumerate_pci(mmconfig_base: *mut u8) {
dump_bar_contents(&mut pci_root, device_function, 4);

let mut transport =
PciTransport::new::<HalImpl>(&mut pci_root, device_function).unwrap();
PciTransport::new::<HalImpl, _>(&mut pci_root, device_function).unwrap();
info!(
"Detected virtio PCI device with device type {:?}, features {:#018x}",
transport.device_type(),
Expand All @@ -179,7 +179,11 @@ fn enumerate_pci(mmconfig_base: *mut u8) {
}
}

fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_index: u8) {
fn dump_bar_contents(
root: &mut PciRoot<impl ConfigurationAccess>,
device_function: DeviceFunction,
bar_index: u8,
) {
let bar_info = root.bar_info(device_function, bar_index).unwrap();
trace!("Dumping bar {}: {:#x?}", bar_index, bar_info);
if let BarInfo::Memory { address, size, .. } = bar_info {
Expand Down
40 changes: 23 additions & 17 deletions src/transport/pci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
pub mod bus;

use self::bus::{DeviceFunction, DeviceFunctionInfo, PciError, PciRoot, PCI_CAP_ID_VNDR};
use self::bus::{
ConfigurationAccess, DeviceFunction, DeviceFunctionInfo, PciError, PciRoot, PCI_CAP_ID_VNDR,
};
use super::{DeviceStatus, DeviceType, Transport};
use crate::{
hal::{Hal, PhysAddr},
Expand Down Expand Up @@ -99,11 +101,11 @@ impl PciTransport {
/// root controller.
///
/// The PCI device must already have had its BARs allocated.
pub fn new<H: Hal>(
root: &mut PciRoot,
pub fn new<H: Hal, C: ConfigurationAccess>(
root: &mut PciRoot<C>,
device_function: DeviceFunction,
) -> Result<Self, VirtioPciError> {
let device_vendor = root.config_read_word(device_function, 0);
let device_vendor = root.configuration_access.read_word(device_function, 0);
let device_id = (device_vendor >> 16) as u16;
let vendor_id = device_vendor as u16;
if vendor_id != VIRTIO_VENDOR_ID {
Expand All @@ -127,12 +129,16 @@ impl PciTransport {
continue;
}
let struct_info = VirtioCapabilityInfo {
bar: root.config_read_word(device_function, capability.offset + CAP_BAR_OFFSET)
bar: root
.configuration_access
.read_word(device_function, capability.offset + CAP_BAR_OFFSET)
as u8,
offset: root
.config_read_word(device_function, capability.offset + CAP_BAR_OFFSET_OFFSET),
.configuration_access
.read_word(device_function, capability.offset + CAP_BAR_OFFSET_OFFSET),
length: root
.config_read_word(device_function, capability.offset + CAP_LENGTH_OFFSET),
.configuration_access
.read_word(device_function, capability.offset + CAP_LENGTH_OFFSET),
};

match cfg_type {
Expand All @@ -141,7 +147,7 @@ impl PciTransport {
}
VIRTIO_PCI_CAP_NOTIFY_CFG if cap_len >= 20 && notify_cfg.is_none() => {
notify_cfg = Some(struct_info);
notify_off_multiplier = root.config_read_word(
notify_off_multiplier = root.configuration_access.read_word(
device_function,
capability.offset + CAP_NOTIFY_OFF_MULTIPLIER_OFFSET,
);
Expand All @@ -156,7 +162,7 @@ impl PciTransport {
}
}

let common_cfg = get_bar_region::<H, _>(
let common_cfg = get_bar_region::<H, _, _>(
root,
device_function,
&common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?,
Expand All @@ -168,16 +174,16 @@ impl PciTransport {
notify_off_multiplier,
));
}
let notify_region = get_bar_region_slice::<H, _>(root, device_function, &notify_cfg)?;
let notify_region = get_bar_region_slice::<H, _, _>(root, device_function, &notify_cfg)?;

let isr_status = get_bar_region::<H, _>(
let isr_status = get_bar_region::<H, _, _>(
root,
device_function,
&isr_cfg.ok_or(VirtioPciError::MissingIsrConfig)?,
)?;

let config_space = if let Some(device_cfg) = device_cfg {
Some(get_bar_region_slice::<H, _>(
Some(get_bar_region_slice::<H, _, _>(
root,
device_function,
&device_cfg,
Expand Down Expand Up @@ -387,8 +393,8 @@ struct VirtioCapabilityInfo {
length: u32,
}

fn get_bar_region<H: Hal, T>(
root: &mut PciRoot,
fn get_bar_region<H: Hal, T, C: ConfigurationAccess>(
root: &mut PciRoot<C>,
device_function: DeviceFunction,
struct_info: &VirtioCapabilityInfo,
) -> Result<NonNull<T>, VirtioPciError> {
Expand Down Expand Up @@ -417,12 +423,12 @@ fn get_bar_region<H: Hal, T>(
Ok(vaddr.cast())
}

fn get_bar_region_slice<H: Hal, T>(
root: &mut PciRoot,
fn get_bar_region_slice<H: Hal, T, C: ConfigurationAccess>(
root: &mut PciRoot<C>,
device_function: DeviceFunction,
struct_info: &VirtioCapabilityInfo,
) -> Result<NonNull<[T]>, VirtioPciError> {
let ptr = get_bar_region::<H, T>(root, device_function, struct_info)?;
let ptr = get_bar_region::<H, T, C>(root, device_function, struct_info)?;
Ok(nonnull_slice_from_raw_parts(
ptr,
struct_info.length as usize / size_of::<T>(),
Expand Down
Loading

0 comments on commit 0c7f230

Please sign in to comment.