diff --git a/src/errors.rs b/src/errors.rs index 5dcc64d..4a84ba2 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -32,29 +32,56 @@ impl From for EncodeError { } #[derive(Debug, Error)] -#[error("Decode error occurred: {inner}")] -pub struct DecodeError { - inner: anyhow::Error, +pub enum DecodeError { + #[error("Invalid MAC address")] + InvalidMACAddress, + + #[error("Invalid IP address")] + InvalidIPAddress, + + #[error("Invalid string")] + Utf8Error(#[from] std::string::FromUtf8Error), + + #[error("Invalid u8")] + InvalidU8, + + #[error("Invalid u16")] + InvalidU16, + + #[error("Invalid u32")] + InvalidU32, + + #[error("Invalid u64")] + InvalidU64, + + #[error("Invalid u128")] + InvalidU128, + + #[error("Invalid i32")] + InvalidI32, + + #[error("Invalid {name}: length {len} < {buffer_len}")] + InvalidBufferLength { + name: &'static str, + len: usize, + buffer_len: usize, + }, + + #[error(transparent)] + Nla(#[from] crate::nla::NlaError), + + #[error(transparent)] + Other(#[from] anyhow::Error), } impl From<&'static str> for DecodeError { fn from(msg: &'static str) -> Self { - DecodeError { - inner: anyhow!(msg), - } + DecodeError::Other(anyhow!(msg)) } } impl From for DecodeError { fn from(msg: String) -> Self { - DecodeError { - inner: anyhow!(msg), - } - } -} - -impl From for DecodeError { - fn from(inner: anyhow::Error) -> DecodeError { - DecodeError { inner } + DecodeError::Other(anyhow!(msg)) } } diff --git a/src/macros.rs b/src/macros.rs index 836b5cb..b4239cf 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -192,15 +192,11 @@ macro_rules! buffer_check_length { fn check_buffer_length(&self) -> Result<(), DecodeError> { let len = self.buffer.as_ref().len(); if len < $buffer_len { - Err(format!( - concat!( - "invalid ", - stringify!($name), - ": length {} < {}" - ), - len, $buffer_len - ) - .into()) + Err(DecodeError::InvalidBufferLength { + name: stringify!($name), + len, + buffer_len: $buffer_len, + }) } else { Ok(()) } diff --git a/src/nla.rs b/src/nla.rs index 96f7d79..2c70989 100644 --- a/src/nla.rs +++ b/src/nla.rs @@ -1,14 +1,12 @@ // SPDX-License-Identifier: MIT -use core::ops::Range; - -use anyhow::Context; -use byteorder::{ByteOrder, NativeEndian}; - use crate::{ traits::{Emitable, Parseable}, DecodeError, }; +use byteorder::{ByteOrder, NativeEndian}; +use core::ops::Range; +use thiserror::Error; /// Represent a multi-bytes field with a fixed size in a packet type Field = Range; @@ -25,6 +23,20 @@ pub const NLA_ALIGNTO: usize = 4; /// NlA(RTA) header size. (unsigned short rta_len) + (unsigned short rta_type) pub const NLA_HEADER_SIZE: usize = 4; +#[derive(Debug, Error)] +pub enum NlaError { + #[error("buffer has length {buffer_len}, but an NLA header is {} bytes", TYPE.end)] + BufferTooSmall { buffer_len: usize }, + + #[error("buffer has length: {buffer_len}, but the NLA is {nla_len} bytes")] + LengthMismatch { buffer_len: usize, nla_len: u16 }, + + #[error( + "NLA has invalid length: {nla_len} (should be at least {} bytes", TYPE.end + )] + InvalidLength { nla_len: u16 }, +} + #[macro_export] macro_rules! nla_align { ($len: expr) => { @@ -52,33 +64,26 @@ impl> NlaBuffer { NlaBuffer { buffer } } - pub fn new_checked(buffer: T) -> Result, DecodeError> { + pub fn new_checked(buffer: T) -> Result, NlaError> { let buffer = Self::new(buffer); - buffer.check_buffer_length().context("invalid NLA buffer")?; + buffer.check_buffer_length()?; Ok(buffer) } - pub fn check_buffer_length(&self) -> Result<(), DecodeError> { + pub fn check_buffer_length(&self) -> Result<(), NlaError> { let len = self.buffer.as_ref().len(); if len < TYPE.end { - Err(format!( - "buffer has length {}, but an NLA header is {} bytes", - len, TYPE.end - ) - .into()) + Err(NlaError::BufferTooSmall { buffer_len: len }.into()) } else if len < self.length() as usize { - Err(format!( - "buffer has length: {}, but the NLA is {} bytes", - len, - self.length() - ) + Err(NlaError::LengthMismatch { + buffer_len: len, + nla_len: self.length(), + } .into()) } else if (self.length() as usize) < TYPE.end { - Err(format!( - "NLA has invalid length: {} (should be at least {} bytes", - self.length(), - TYPE.end, - ) + Err(NlaError::InvalidLength { + nla_len: self.length(), + } .into()) } else { Ok(()) @@ -204,7 +209,9 @@ impl Nla for DefaultNla { impl<'buffer, T: AsRef<[u8]> + ?Sized> Parseable> for DefaultNla { - fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { + type Error = DecodeError; + + fn parse(buf: &NlaBuffer<&'buffer T>) -> Result { let mut kind = buf.kind(); if buf.network_byte_order_flag() { @@ -314,7 +321,7 @@ impl NlasIterator { impl<'buffer, T: AsRef<[u8]> + ?Sized + 'buffer> Iterator for NlasIterator<&'buffer T> { - type Item = Result, DecodeError>; + type Item = Result, NlaError>; fn next(&mut self) -> Option { if self.position >= self.buffer.as_ref().len() { diff --git a/src/parsers.rs b/src/parsers.rs index f1198d3..50364c9 100644 --- a/src/parsers.rs +++ b/src/parsers.rs @@ -5,14 +5,13 @@ use std::{ net::{IpAddr, Ipv4Addr, Ipv6Addr}, }; -use anyhow::Context; use byteorder::{BigEndian, ByteOrder, NativeEndian}; use crate::DecodeError; pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> { if payload.len() != 6 { - return Err(format!("invalid MAC address: {payload:?}").into()); + return Err(DecodeError::InvalidMACAddress); } let mut address: [u8; 6] = [0; 6]; for (i, byte) in payload.iter().enumerate() { @@ -23,7 +22,7 @@ pub fn parse_mac(payload: &[u8]) -> Result<[u8; 6], DecodeError> { pub fn parse_ipv6(payload: &[u8]) -> Result<[u8; 16], DecodeError> { if payload.len() != 16 { - return Err(format!("invalid IPv6 address: {payload:?}").into()); + return Err(DecodeError::InvalidIPAddress); } let mut address: [u8; 16] = [0; 16]; for (i, byte) in payload.iter().enumerate() { @@ -57,7 +56,7 @@ pub fn parse_ip(payload: &[u8]) -> Result { payload[15], ]) .into()), - _ => Err(format!("invalid IPv6 address: {payload:?}").into()), + _ => Err(DecodeError::InvalidIPAddress), } } @@ -71,62 +70,62 @@ pub fn parse_string(payload: &[u8]) -> Result { } else { &payload[..payload.len()] }; - let s = String::from_utf8(slice.to_vec()).context("invalid string")?; + let s = String::from_utf8(slice.to_vec())?; Ok(s) } pub fn parse_u8(payload: &[u8]) -> Result { if payload.len() != 1 { - return Err(format!("invalid u8: {payload:?}").into()); + return Err(DecodeError::InvalidU8); } Ok(payload[0]) } pub fn parse_u32(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u32: {payload:?}").into()); + return Err(DecodeError::InvalidU32); } Ok(NativeEndian::read_u32(payload)) } pub fn parse_u64(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u64: {payload:?}").into()); + return Err(DecodeError::InvalidU64); } Ok(NativeEndian::read_u64(payload)) } pub fn parse_u128(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u128: {payload:?}").into()); + return Err(DecodeError::InvalidU128); } Ok(NativeEndian::read_u128(payload)) } pub fn parse_u16(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u16: {payload:?}").into()); + return Err(DecodeError::InvalidU16); } Ok(NativeEndian::read_u16(payload)) } pub fn parse_i32(payload: &[u8]) -> Result { if payload.len() != 4 { - return Err(format!("invalid u32: {payload:?}").into()); + return Err(DecodeError::InvalidI32); } Ok(NativeEndian::read_i32(payload)) } pub fn parse_u16_be(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u16: {payload:?}").into()); + return Err(DecodeError::InvalidU16); } Ok(BigEndian::read_u16(payload)) } pub fn parse_u32_be(payload: &[u8]) -> Result { if payload.len() != size_of::() { - return Err(format!("invalid u32: {payload:?}").into()); + return Err(DecodeError::InvalidU32); } Ok(BigEndian::read_u32(payload)) } diff --git a/src/traits.rs b/src/traits.rs index 89c1bed..855dc60 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,7 +1,5 @@ // SPDX-License-Identifier: MIT -use crate::DecodeError; - /// A type that implements `Emitable` can be serialized. pub trait Emitable { /// Return the length of the serialized data. @@ -26,8 +24,10 @@ where Self: Sized, T: ?Sized, { + type Error; + /// Deserialize the current type. - fn parse(buf: &T) -> Result; + fn parse(buf: &T) -> Result; } /// A `Parseable` type can be used to deserialize data from the type `T` for @@ -37,6 +37,8 @@ where Self: Sized, T: ?Sized, { + type Error; + /// Deserialize the current type. - fn parse_with_param(buf: &T, params: P) -> Result; + fn parse_with_param(buf: &T, params: P) -> Result; }