From 9fca0a10c3e3e56b07e126a71719bc555f4fcc90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=87a=C4=9Fatay=20Yi=C4=9Fit=20=C5=9Eahin?= Date: Sun, 3 Mar 2024 20:35:43 +0100 Subject: [PATCH] virtio_net: use smoltcp methods instead of hardcoding field offsets --- src/drivers/net/virtio_net.rs | 56 +++++++++++++++++------------------ 1 file changed, 27 insertions(+), 29 deletions(-) diff --git a/src/drivers/net/virtio_net.rs b/src/drivers/net/virtio_net.rs index 3d79d9dc98..8ccf9dfb91 100644 --- a/src/drivers/net/virtio_net.rs +++ b/src/drivers/net/virtio_net.rs @@ -12,7 +12,7 @@ use core::mem; use align_address::Align; use pci_types::InterruptLine; use smoltcp::phy::{Checksum, ChecksumCapabilities}; -use smoltcp::wire::{ETHERNET_HEADER_LEN, IPV4_HEADER_LEN, IPV6_HEADER_LEN}; +use smoltcp::wire::{EthernetFrame, Ipv4Packet, Ipv6Packet, ETHERNET_HEADER_LEN}; use zerocopy::AsBytes; use self::constants::{FeatureSet, Features, NetHdrFlag, NetHdrGSO, Status, MAX_NUM_VQ}; @@ -523,35 +523,33 @@ impl NetworkDriver for VirtioNetDriver { // If a checksum isn't necessary, we have inform the host within the header // see Virtio specification 5.1.6.2 if !self.checksums.tcp.tx() || !self.checksums.udp.tx() { - let type_ = unsafe { u16::from_be(*(buff_ptr.offset(12) as *const u16)) }; - - match type_ { - 0x0800 /* IPv4 */ => { - let protocol = unsafe { *(buff_ptr.offset((ETHERNET_HEADER_LEN+9).try_into().unwrap()) as *const u8) }; - if protocol == 6 /* TCP */ { - header.flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; - header.csum_start = (ETHERNET_HEADER_LEN+IPV4_HEADER_LEN).try_into().unwrap(); - header.csum_offset = 16; - } else if protocol == 17 /* UDP */ { - header.flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; - header.csum_start = (ETHERNET_HEADER_LEN+IPV4_HEADER_LEN).try_into().unwrap(); - header.csum_offset = 6; - } - }, - 0x86DD /* IPv6 */ => { - let protocol = unsafe { *(buff_ptr.offset((ETHERNET_HEADER_LEN+9).try_into().unwrap()) as *const u8) }; - if protocol == 6 /* TCP */ { - header.flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; - header.csum_start = (ETHERNET_HEADER_LEN+IPV6_HEADER_LEN).try_into().unwrap(); - header.csum_offset = 16; - } else if protocol == 17 /* UDP */ { - header.flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; - header.csum_start = (ETHERNET_HEADER_LEN+IPV6_HEADER_LEN).try_into().unwrap(); - header.csum_offset = 6; - } - }, - _ => {}, + header.flags = NetHdrFlag::VIRTIO_NET_HDR_F_NEEDS_CSUM; + let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> = + EthernetFrame::new_unchecked(buf_slice); + let packet_header_len: u16; + let protocol; + match ethernet_frame.ethertype() { + smoltcp::wire::EthernetProtocol::Ipv4 => { + let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload()); + packet_header_len = packet.header_len().into(); + protocol = Some(packet.next_header()); + } + smoltcp::wire::EthernetProtocol::Ipv6 => { + let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload()); + packet_header_len = packet.header_len().try_into().unwrap(); + protocol = Some(packet.next_header()); + } + _ => { + packet_header_len = 0; + protocol = None; + } } + header.csum_start = u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len; + header.csum_offset = match protocol { + Some(smoltcp::wire::IpProtocol::Tcp) => 16, + Some(smoltcp::wire::IpProtocol::Udp) => 6, + _ => 0, + }; } buff_tkn