diff --git a/src/header/map.rs b/src/header/map.rs index 1f93b22c..c954b4a7 100644 --- a/src/header/map.rs +++ b/src/header/map.rs @@ -8,7 +8,7 @@ use std::{fmt, mem, ops, ptr, vec}; use crate::Error; -use super::name::{HdrName, HeaderName, InvalidHeaderName}; +use super::name::{FastHash, HdrName, HeaderName, InvalidHeaderName}; use super::HeaderValue; pub use self::as_header_name::AsHeaderName; @@ -1077,7 +1077,7 @@ impl HeaderMap { fn entry2(&mut self, key: K) -> Entry<'_, T> where - K: Hash + Into, + K: FastHash + Into, HeaderName: PartialEq, { // Ensure that there is space in the map @@ -1149,7 +1149,7 @@ impl HeaderMap { #[inline] fn insert2(&mut self, key: K, value: T) -> Option where - K: Hash + Into, + K: FastHash + Into, HeaderName: PartialEq, { self.reserve_one(); @@ -1250,7 +1250,7 @@ impl HeaderMap { #[inline] fn append2(&mut self, key: K, value: T) -> bool where - K: Hash + Into, + K: FastHash + Into, HeaderName: PartialEq, { self.reserve_one(); @@ -1287,7 +1287,7 @@ impl HeaderMap { #[inline] fn find(&self, key: &K) -> Option<(usize, usize)> where - K: Hash + Into, + K: FastHash + Into, HeaderName: PartialEq, { if self.entries.is_empty() { @@ -3272,10 +3272,8 @@ fn probe_distance(mask: Size, hash: HashValue, current: usize) -> usize { fn hash_elem_using(danger: &Danger, k: &K) -> HashValue where - K: Hash, + K: FastHash, { - use fnv::FnvHasher; - const MASK: u64 = (MAX_SIZE as u64) - 1; let hash = match *danger { @@ -3286,11 +3284,7 @@ where h.finish() } // Fast hash - _ => { - let mut h = FnvHasher::default(); - k.hash(&mut h); - h.finish() - } + _ => k.fast_hash() }; HashValue((hash & MASK) as u16) diff --git a/src/header/name.rs b/src/header/name.rs index ecf0ad23..39681e64 100644 --- a/src/header/name.rs +++ b/src/header/name.rs @@ -29,32 +29,127 @@ use std::str::FromStr; /// /// [`HeaderMap`]: struct.HeaderMap.html /// [`header`]: index.html -#[derive(Clone, Eq, PartialEq, Hash)] +#[derive(Clone, Eq, PartialEq)] pub struct HeaderName { inner: Repr, } // Almost a full `HeaderName` -#[derive(Debug, Hash)] +#[derive(Debug)] pub struct HdrName<'a> { inner: Repr>, } -#[derive(Debug, Clone, Eq, PartialEq, Hash)] -enum Repr { - Standard(StandardHeader), - Custom(T), +#[derive(Debug, Clone, Eq)] +enum Repr + PartialEq> { + Static(StaticRepresentation), + Runtime(T), +} + +impl + PartialEq> PartialEq> for Repr { + #[inline] + fn eq(&self, other: &Repr) -> bool { + use Repr::*; + + match (self, other) { + (Static(a), Static(b)) => a == b, + (Static(a), Runtime(b)) => b == a, + (Runtime(a), Static(b)) => a == b, + (Runtime(a), Runtime(b)) => a == b, + } + } +} + +pub trait FastHash: Hash { + #[inline] + fn fast_hash(&self) -> u64 { + use fnv::FnvHasher; + + let mut h = FnvHasher::default(); + self.hash(&mut h); + h.finish() + } } // Used to hijack the Hash impl #[derive(Debug, Clone, Eq, PartialEq)] struct Custom(ByteStr); -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Eq, PartialEq)] // Invariant: If lower then buf is valid UTF-8. struct MaybeLower<'a> { - buf: &'a [u8], lower: bool, + buf: &'a [u8], +} + +#[derive(Debug, Clone, Eq, PartialEq)] +struct StaticHeader { + compile_time_hash: u64, + name: &'static str, +} + +impl Hash for StaticHeader { + #[inline] + fn hash(&self, state: &mut H) { + state.write(self.name.as_bytes()) + } +} + +impl FastHash for StaticHeader { + #[inline] + fn fast_hash(&self) -> u64 { + self.compile_time_hash + } +} + +#[derive(Debug, Clone, Eq, PartialEq)] +enum StaticRepresentation { + Custom(StaticHeader), + Standard(StandardHeader), +} + +impl PartialEq for StaticRepresentation { + #[inline] + fn eq(&self, other: &Custom) -> bool { + use StaticRepresentation::*; + + match self { + Standard(_) => false, + Custom(a) => a.name.as_bytes() == other.0.as_bytes(), + } + } +} + +impl PartialEq for Custom { + #[inline] + fn eq(&self, other: &StaticRepresentation) -> bool { + other == self + } +} + +impl<'a> PartialEq> for StaticRepresentation { + #[inline] + fn eq(&self, other: &MaybeLower<'a>) -> bool { + use StaticRepresentation::*; + + match self { + Standard(_) => false, + Custom(a) => { + if other.lower { + a.name.as_bytes() == other.buf + } else { + eq_ignore_ascii_case(a.name.as_bytes(), other.buf) + } + } + } + } +} + +impl<'a> PartialEq for MaybeLower<'a> { + #[inline] + fn eq(&self, other: &StaticRepresentation) -> bool { + other == self + } } /// A possible error when converting a `HeaderName` from another type. @@ -79,7 +174,7 @@ macro_rules! standard_headers { $( $(#[$docs])* pub const $upcase: HeaderName = HeaderName { - inner: Repr::Standard(StandardHeader::$konst), + inner: Repr::Static(StaticRepresentation::Standard(StandardHeader::$konst)), }; )+ @@ -1067,6 +1162,21 @@ const HEADER_CHARS_H2: [u8; 256] = [ 0, 0, 0, 0, 0, 0 // 25x ]; +#[inline] +const fn const_fnv(bytes: &[u8]) -> u64 { + let mut hash: u64 = 0xcbf29ce484222325; + let mut i = 0; + let size = bytes.len(); + + while i < size { + hash ^= bytes[i] as u64; + hash = hash.wrapping_mul(0x100000001b3 as u64); + i += 1; + } + + hash +} + fn parse_hdr<'a>( data: &'a [u8], b: &'a mut [MaybeUninit; SCRATCH_BUF_SIZE], @@ -1100,7 +1210,7 @@ fn parse_hdr<'a>( impl<'a> From for HdrName<'a> { fn from(hdr: StandardHeader) -> HdrName<'a> { HdrName { - inner: Repr::Standard(hdr), + inner: Repr::Static(StaticRepresentation::Standard(hdr)), } } } @@ -1113,14 +1223,14 @@ impl HeaderName { let mut buf = uninit_u8_array(); // Precondition: HEADER_CHARS is a valid table for parse_hdr(). match parse_hdr(src, &mut buf, &HEADER_CHARS)?.inner { - Repr::Standard(std) => Ok(std.into()), - Repr::Custom(MaybeLower { buf, lower: true }) => { + Repr::Static(s) => Ok(s.into()), + Repr::Runtime(MaybeLower { buf, lower: true }) => { let buf = Bytes::copy_from_slice(buf); // Safety: the invariant on MaybeLower ensures buf is valid UTF-8. let val = unsafe { ByteStr::from_utf8_unchecked(buf) }; Ok(Custom(val).into()) } - Repr::Custom(MaybeLower { buf, lower: false }) => { + Repr::Runtime(MaybeLower { buf, lower: false }) => { use bytes::BufMut; let mut dst = BytesMut::with_capacity(buf.len()); @@ -1167,14 +1277,14 @@ impl HeaderName { let mut buf = uninit_u8_array(); // Precondition: HEADER_CHARS_H2 is a valid table for parse_hdr() match parse_hdr(src, &mut buf, &HEADER_CHARS_H2)?.inner { - Repr::Standard(std) => Ok(std.into()), - Repr::Custom(MaybeLower { buf, lower: true }) => { + Repr::Static(s) => Ok(s.into()), + Repr::Runtime(MaybeLower { buf, lower: true }) => { let buf = Bytes::copy_from_slice(buf); // Safety: the invariant on MaybeLower ensures buf is valid UTF-8. let val = unsafe { ByteStr::from_utf8_unchecked(buf) }; Ok(Custom(val).into()) } - Repr::Custom(MaybeLower { buf, lower: false }) => { + Repr::Runtime(MaybeLower { buf, lower: false }) => { for &b in buf.iter() { // HEADER_CHARS maps all bytes that are not valid single-byte // UTF-8 to 0 so this check returns an error for invalid UTF-8. @@ -1254,7 +1364,7 @@ impl HeaderName { let name_bytes = src.as_bytes(); if let Some(standard) = StandardHeader::from_bytes(name_bytes) { return HeaderName { - inner: Repr::Standard(standard), + inner: Repr::Static(StaticRepresentation::Standard(standard)), }; } @@ -1273,7 +1383,10 @@ impl HeaderName { } HeaderName { - inner: Repr::Custom(Custom(ByteStr::from_static(src))), + inner: Repr::Static(StaticRepresentation::Custom(StaticHeader { + compile_time_hash: const_fnv(src.as_bytes()), + name: src, + })), } } @@ -1283,8 +1396,11 @@ impl HeaderName { #[inline] pub fn as_str(&self) -> &str { match self.inner { - Repr::Standard(v) => v.as_str(), - Repr::Custom(ref v) => &*v.0, + Repr::Static(ref s) => match s { + StaticRepresentation::Standard(std) => std.as_str(), + StaticRepresentation::Custom(custom) => custom.name, + }, + Repr::Runtime(ref r) => &*r.0, } } @@ -1343,15 +1459,28 @@ impl<'a> From<&'a HeaderName> for HeaderName { } } +impl From for HeaderName { + #[inline] + fn from(s: StaticRepresentation) -> HeaderName { + HeaderName { + inner: Repr::Static(s), + } + } +} + #[doc(hidden)] -impl From> for Bytes +impl> From> for Bytes where T: Into, { + #[inline] fn from(repr: Repr) -> Bytes { match repr { - Repr::Standard(header) => Bytes::from_static(header.as_str().as_bytes()), - Repr::Custom(header) => header.into(), + Repr::Static(s) => match s { + StaticRepresentation::Standard(std) => Bytes::from_static(std.as_str().as_bytes()), + StaticRepresentation::Custom(custom) => Bytes::from_static(custom.name.as_bytes()), + }, + Repr::Runtime(header) => header.into(), } } } @@ -1407,18 +1536,30 @@ impl TryFrom> for HeaderName { #[doc(hidden)] impl From for HeaderName { + #[inline] fn from(src: StandardHeader) -> HeaderName { HeaderName { - inner: Repr::Standard(src), + inner: Repr::Static(StaticRepresentation::Standard(src)), + } + } +} + +#[doc(hidden)] +impl From for HeaderName { + #[inline] + fn from(src: StaticHeader) -> HeaderName { + HeaderName { + inner: Repr::Static(StaticRepresentation::Custom(src)), } } } #[doc(hidden)] impl From for HeaderName { + #[inline] fn from(src: Custom) -> HeaderName { HeaderName { - inner: Repr::Custom(src), + inner: Repr::Runtime(src), } } } @@ -1516,7 +1657,7 @@ impl<'a> HdrName<'a> { fn custom(buf: &'a [u8], lower: bool) -> HdrName<'a> { HdrName { // Invariant (on MaybeLower): follows from the precondition - inner: Repr::Custom(MaybeLower { + inner: Repr::Runtime(MaybeLower { buf: buf, lower: lower, }), @@ -1549,17 +1690,17 @@ impl<'a> HdrName<'a> { impl<'a> From> for HeaderName { fn from(src: HdrName<'a>) -> HeaderName { match src.inner { - Repr::Standard(s) => HeaderName { - inner: Repr::Standard(s), + Repr::Static(s) => HeaderName { + inner: Repr::Static(s), }, - Repr::Custom(maybe_lower) => { + Repr::Runtime(maybe_lower) => { if maybe_lower.lower { let buf = Bytes::copy_from_slice(&maybe_lower.buf[..]); // Safety: the invariant on MaybeLower ensures buf is valid UTF-8. let byte_str = unsafe { ByteStr::from_utf8_unchecked(buf) }; HeaderName { - inner: Repr::Custom(Custom(byte_str)), + inner: Repr::Runtime(Custom(byte_str)), } } else { use bytes::BufMut; @@ -1577,7 +1718,7 @@ impl<'a> From> for HeaderName { let buf = unsafe { ByteStr::from_utf8_unchecked(dst.freeze()) }; HeaderName { - inner: Repr::Custom(Custom(buf)), + inner: Repr::Runtime(Custom(buf)), } } } @@ -1589,21 +1730,17 @@ impl<'a> From> for HeaderName { impl<'a> PartialEq> for HeaderName { #[inline] fn eq(&self, other: &HdrName<'a>) -> bool { - match self.inner { - Repr::Standard(a) => match other.inner { - Repr::Standard(b) => a == b, - _ => false, - }, - Repr::Custom(Custom(ref a)) => match other.inner { - Repr::Custom(ref b) => { - if b.lower { - a.as_bytes() == b.buf - } else { - eq_ignore_ascii_case(a.as_bytes(), b.buf) - } + match (&self.inner, &other.inner) { + (Repr::Static(ref a), Repr::Static(ref b)) => a == b, + (Repr::Static(ref a), Repr::Runtime(ref b)) => a == b, + (Repr::Runtime(ref a), Repr::Static(ref b)) => a == b, + (Repr::Runtime(Custom(ref a)), Repr::Runtime(ref b)) => { + if b.lower { + a.as_bytes() == b.buf + } else { + eq_ignore_ascii_case(a.as_bytes(), b.buf) } - _ => false, - }, + } } } } @@ -1632,6 +1769,94 @@ impl<'a> Hash for MaybeLower<'a> { } } +impl Hash for HeaderName { + #[inline] + fn hash(&self, state: &mut H) { + self.inner.hash(state) + } +} + +impl FastHash for HeaderName { + #[inline] + fn fast_hash(&self) -> u64 { + self.inner.fast_hash() + } +} + +impl FastHash for &HeaderName { + #[inline] + fn fast_hash(&self) -> u64 { + (*self).fast_hash() + } +} + +impl<'a> Hash for HdrName<'a> { + #[inline] + fn hash(&self, state: &mut H) { + self.inner.hash(state) + } +} + +impl<'a> FastHash for HdrName<'a> { + #[inline] + fn fast_hash(&self) -> u64 { + self.inner.fast_hash() + } +} + +impl FastHash for Custom {} +impl<'a> FastHash for MaybeLower<'a> {} + +impl FastHash for StandardHeader {} + +impl Hash for StaticRepresentation { + #[inline] + fn hash(&self, state: &mut H) { + use StaticRepresentation::*; + + match self { + Custom(c) => c.hash(state), + Standard(std) => std.hash(state), + } + } +} + +impl FastHash for StaticRepresentation { + #[inline] + fn fast_hash(&self) -> u64 { + use StaticRepresentation::*; + + match self { + Custom(c) => c.fast_hash(), + Standard(s) => s.fast_hash(), + } + } +} + +impl> Hash for Repr { + #[inline] + fn hash(&self, state: &mut H) { + use Repr::*; + + match self { + Static(s) => s.hash(state), + Runtime(r) => r.hash(state), + } + } +} + +impl> FastHash for Repr { + #[inline] + fn fast_hash(&self) -> u64 { + use Repr::*; + + match self { + Static(s) => s.fast_hash(), + Runtime(r) => r.fast_hash(), + } + } +} + // Assumes that the left hand side is already lower case #[inline] fn eq_ignore_ascii_case(lower: &[u8], s: &[u8]) -> bool { @@ -1667,6 +1892,8 @@ unsafe fn slice_assume_init(slice: &[MaybeUninit]) -> &[T] { #[cfg(test)] mod tests { + use fnv::FnvHasher; + use self::StandardHeader::Vary; use super::*; @@ -1724,13 +1951,16 @@ mod tests { use self::StandardHeader::Vary; let name = HeaderName::from(HdrName { - inner: Repr::Standard(Vary), + inner: Repr::Static(StaticRepresentation::Standard(Vary)), }); - assert_eq!(name.inner, Repr::Standard(Vary)); + assert_eq!( + name.inner, + Repr::Static(StaticRepresentation::Standard(Vary)) + ); let name = HeaderName::from(HdrName { - inner: Repr::Custom(MaybeLower { + inner: Repr::Runtime(MaybeLower { buf: b"hello-world", lower: true, }), @@ -1738,11 +1968,11 @@ mod tests { assert_eq!( name.inner, - Repr::Custom(Custom(ByteStr::from_static("hello-world"))) + Repr::Runtime(Custom(ByteStr::from_static("hello-world"))) ); let name = HeaderName::from(HdrName { - inner: Repr::Custom(MaybeLower { + inner: Repr::Runtime(MaybeLower { buf: b"Hello-World", lower: false, }), @@ -1750,7 +1980,7 @@ mod tests { assert_eq!( name.inner, - Repr::Custom(Custom(ByteStr::from_static("hello-world"))) + Repr::Runtime(Custom(ByteStr::from_static("hello-world"))) ); } @@ -1759,21 +1989,21 @@ mod tests { use self::StandardHeader::Vary; let a = HeaderName { - inner: Repr::Standard(Vary), + inner: Repr::Static(StaticRepresentation::Standard(Vary)), }; let b = HdrName { - inner: Repr::Standard(Vary), + inner: Repr::Static(StaticRepresentation::Standard(Vary)), }; assert_eq!(a, b); let a = HeaderName { - inner: Repr::Custom(Custom(ByteStr::from_static("vaary"))), + inner: Repr::Runtime(Custom(ByteStr::from_static("vaary"))), }; assert_ne!(a, b); let b = HdrName { - inner: Repr::Custom(MaybeLower { + inner: Repr::Runtime(MaybeLower { buf: b"vaary", lower: true, }), @@ -1782,7 +2012,7 @@ mod tests { assert_eq!(a, b); let b = HdrName { - inner: Repr::Custom(MaybeLower { + inner: Repr::Runtime(MaybeLower { buf: b"vaary", lower: false, }), @@ -1791,7 +2021,7 @@ mod tests { assert_eq!(a, b); let b = HdrName { - inner: Repr::Custom(MaybeLower { + inner: Repr::Runtime(MaybeLower { buf: b"VAARY", lower: false, }), @@ -1800,7 +2030,7 @@ mod tests { assert_eq!(a, b); let a = HeaderName { - inner: Repr::Standard(Vary), + inner: Repr::Static(StaticRepresentation::Standard(Vary)), }; assert_ne!(a, b); } @@ -1808,7 +2038,7 @@ mod tests { #[test] fn test_from_static_std() { let a = HeaderName { - inner: Repr::Standard(Vary), + inner: Repr::Static(StaticRepresentation::Standard(Vary)), }; let b = HeaderName::from_static("vary"); @@ -1834,7 +2064,7 @@ mod tests { #[test] fn test_from_static_custom_short() { let a = HeaderName { - inner: Repr::Custom(Custom(ByteStr::from_static("customheader"))), + inner: Repr::Runtime(Custom(ByteStr::from_static("customheader"))), }; let b = HeaderName::from_static("customheader"); assert_eq!(a, b); @@ -1843,20 +2073,20 @@ mod tests { #[test] #[should_panic] fn test_from_static_custom_short_uppercase() { - HeaderName::from_static("custom header"); + HeaderName::from_static("CustomHeader"); } #[test] #[should_panic] fn test_from_static_custom_short_symbol() { - HeaderName::from_static("CustomHeader"); + HeaderName::from_static("custom header"); } // MaybeLower { lower: false } #[test] fn test_from_static_custom_long() { let a = HeaderName { - inner: Repr::Custom(Custom(ByteStr::from_static( + inner: Repr::Runtime(Custom(ByteStr::from_static( "longer-than-63--thisheaderislongerthansixtythreecharactersandthushandleddifferent", ))), }; @@ -1885,7 +2115,7 @@ mod tests { #[test] fn test_from_static_custom_single_char() { let a = HeaderName { - inner: Repr::Custom(Custom(ByteStr::from_static("a"))), + inner: Repr::Runtime(Custom(ByteStr::from_static("a"))), }; let b = HeaderName::from_static("a"); assert_eq!(a, b); @@ -1901,4 +2131,38 @@ mod tests { fn test_all_tokens() { HeaderName::from_static("!#$%&'*+-.^_`|~0123456789abcdefghijklmnopqrstuvwxyz"); } + + fn hash_header(header: &T) -> u64 { + let mut h = FnvHasher::default(); + header.hash(&mut h); + h.finish() + } + + #[test] + fn test_eq_hash_holds() { + let a = HeaderName { + inner: Repr::Runtime(Custom(ByteStr::from_static("some-header"))), + }; + let b = HeaderName::from_static("some-header"); + + assert_eq!(a, b); + assert_eq!(a.fast_hash(), b.fast_hash()); + + assert_eq!(hash_header(&a), hash_header(&b)); + } + + #[test] + fn test_eq_hash_holds_for_hdrname() { + let a = HeaderName::from_static("some-header"); + let b = HdrName { + inner: Repr::Runtime(MaybeLower { + lower: false, + buf: b"Some-HeaDer", + }), + }; + + assert_eq!(a, b); + assert_eq!(a.fast_hash(), b.fast_hash()); + assert_eq!(hash_header(&a), hash_header(&b)); + } }