From f30b3ec92ec3e356b00d7a932131bb6e78edd36f Mon Sep 17 00:00:00 2001 From: Casper Beyer Date: Tue, 30 May 2023 00:38:32 +0800 Subject: [PATCH] Add const representation for standard header names --- async-nats/src/header.rs | 162 ++++++++++++++++++++++------ async-nats/tests/jetstream_tests.rs | 10 +- 2 files changed, 132 insertions(+), 40 deletions(-) diff --git a/async-nats/src/header.rs b/async-nats/src/header.rs index 6b837df62..52cb3e6c7 100644 --- a/async-nats/src/header.rs +++ b/async-nats/src/header.rs @@ -15,29 +15,6 @@ use std::{collections::HashMap, fmt, slice, str::FromStr}; -use serde::Serialize; - -pub const NATS_LAST_STREAM: &str = "Nats-Last-Stream"; -pub const NATS_LAST_CONSUMER: &str = "Nats-Last-Consumer"; - -/// Direct Get headers -pub const NATS_STREAM: &str = "Nats-Stream"; -pub const NATS_SEQUENCE: &str = "Nats-Sequence"; -pub const NATS_TIME_STAMP: &str = "Nats-Time-Stamp"; -pub const NATS_SUBJECT: &str = "Nats-Subject"; -pub const NATS_LAST_SEQUENCE: &str = "Nats-Last-Sequence"; - -/// Nats-Expected-Last-Subject-Sequence -pub const NATS_EXPECTED_LAST_SUBJECT_SEQUENCE: &str = "Nats-Expected-Last-Subject-Sequence"; -/// Message identifier used for deduplication window -pub const NATS_MESSAGE_ID: &str = "Nats-Msg-Id"; -/// Last expected message ID for JetStream message publish -pub const NATS_EXPECTED_LAST_MESSAGE_ID: &str = "Nats-Expected-Last-Msg-Id"; -/// Last expected sequence for JetStream message publish -pub const NATS_EXPECTED_LAST_SEQUENCE: &str = "Nats-Expected-Last-Sequence"; -/// Expect that given message will be ingested by specified stream. -pub const NATS_EXPECTED_STREAM: &str = "Nats-Expected-Stream"; - /// A struct for handling NATS headers. /// Has a similar API to [http::header], but properly serializes and deserializes /// according to NATS requirements. @@ -56,7 +33,7 @@ pub const NATS_EXPECTED_STREAM: &str = "Nats-Expected-Stream"; /// # Ok(()) /// # } /// ``` -#[derive(Clone, PartialEq, Eq, Debug, Serialize, Default)] +#[derive(Clone, PartialEq, Eq, Debug, Default)] pub struct HeaderMap { inner: HashMap, } @@ -154,7 +131,7 @@ impl HeaderMap { buf.extend_from_slice(b"NATS/1.0\r\n"); for (k, vs) in &self.inner { for v in vs.iter() { - buf.extend_from_slice(k.value.as_bytes()); + buf.extend_from_slice(k.as_str().as_bytes()); buf.extend_from_slice(b": "); buf.extend_from_slice(v.as_bytes()); buf.extend_from_slice(b"\r\n"); @@ -179,7 +156,7 @@ impl HeaderMap { /// # Ok(()) /// # } /// ``` -#[derive(Clone, PartialEq, Eq, Debug, Serialize, Default)] +#[derive(Clone, PartialEq, Eq, Debug, Default)] pub struct HeaderValue { value: Vec, } @@ -288,7 +265,7 @@ pub trait IntoHeaderName { impl IntoHeaderName for &str { fn into_header_name(self) -> HeaderName { HeaderName { - value: self.to_string(), + inner: HeaderRepr::Custom(self.to_string()), } } } @@ -316,10 +293,120 @@ impl IntoHeaderValue for HeaderValue { } } -#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)] +macro_rules! standard_headers { + ( + $( + $(#[$docs:meta])* + ($variant:ident, $constant:ident, $bytes:literal); + )+ + ) => { + #[allow(clippy::enum_variant_names)] + #[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] + enum StandardHeader { + $( + $variant, + )+ + } + + $( + $(#[$docs])* + pub const $constant: HeaderName = HeaderName { + inner: HeaderRepr::Standard(StandardHeader::$variant), + }; + )+ + + impl StandardHeader { + #[inline] + fn as_str(&self) -> &'static str { + match *self { + $( + StandardHeader::$variant => unsafe { std::str::from_utf8_unchecked( $bytes ) }, + )+ + } + } + + const fn from_bytes(bytes: &[u8]) -> Option { + match bytes { + $( + $bytes => Some(StandardHeader::$variant), + )+ + _ => None, + } + } + } + + #[cfg(test)] + mod standard_header_tests { + use super::HeaderName; + use std::str::{self, FromStr}; + + const TEST_HEADERS: &'static [(&'static HeaderName, &'static [u8])] = &[ + $( + (&super::$constant, $bytes), + )+ + ]; + + #[test] + fn from_str() { + for &(header, bytes) in TEST_HEADERS { + let utf8 = str::from_utf8(bytes).expect("string constants isn't utf8"); + assert_eq!(HeaderName::from_str(utf8).unwrap(), *header); + } + } + } + } +} + +// Generate constants for all standard NATS headers. +standard_headers! { + /// The name of the stream the message belongs to. + (NatsStream, NATS_STREAM, b"Nats-Stream"); + /// The sequence number of the message within the stream. + (NatsSequence, NATS_SEQUENCE, b"Nats-Sequence"); + /// The timestamp of when the message was sent. + (NatsTimeStamp, NATS_TIME_STAMP, b"Nats-Time-Stamp"); + /// The subject of the message, used for routing and filtering messages. + (NatsSubject, NATS_SUBJECT, b"Nats-Subject"); + /// A unique identifier for the message. + (NatsMessageId, NATS_MESSAGE_ID, b"Nats-Msg-Id"); + /// The last known stream the message was part of. + (NatsLastStream, NATS_LAST_STREAM, b"Nats-Last-Stream"); + /// The last known consumer that processed the message. + (NatsLastConsumer, NATS_LAST_CONSUMER, b"Nats-Last-Consumer"); + /// The last known sequence number of the message. + (NatsLastSequence, NATS_LAST_SEQUENCE, b"Nats-Last-Sequence"); + /// The expected last sequence number of the subject. + (NatsExpectgedLastSubjectSequence, NATS_EXPECTED_LAST_SUBJECT_SEQUENCE, b"Nats-Expected-Last-Subject-Sequence"); + /// The expected last message ID within the stream. + (NatsExpectedLastMessageId, NATS_EXPECTED_LAST_MESSAGE_ID, b"Nats-Expected-Last-Msg-Id"); + /// The expected last sequence number within the stream. + (NatsExpectedLastSequence, NATS_EXPECTED_LAST_SEQUENCE, b"Nats-Expected-Last-Sequence"); + /// The expected stream the message should be part of. + (NatsExpectedStream, NATS_EXPECTED_STREAM, b"Nats-Expected-Stream"); +} + +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +enum HeaderRepr { + Standard(StandardHeader), + Custom(String), +} + +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct HeaderName { - value: String, + inner: HeaderRepr, } + +impl HeaderName { + /// Returns a `str` representation of the header. + #[inline] + fn as_str(&self) -> &str { + match self.inner { + HeaderRepr::Standard(v) => v.as_str(), + HeaderRepr::Custom(ref v) => v.as_str(), + } + } +} + impl FromStr for HeaderName { type Err = ParseHeaderNameError; @@ -328,27 +415,32 @@ impl FromStr for HeaderName { return Err(ParseHeaderNameError); } - Ok(HeaderName { - value: s.to_string(), - }) + match StandardHeader::from_bytes(s.as_ref()) { + Some(v) => Ok(HeaderName { + inner: HeaderRepr::Standard(v), + }), + None => Ok(HeaderName { + inner: HeaderRepr::Custom(s.to_string()), + }), + } } } impl fmt::Display for HeaderName { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Display::fmt(&self.value, f) + fmt::Display::fmt(&self.as_str(), f) } } impl AsRef<[u8]> for HeaderName { fn as_ref(&self) -> &[u8] { - self.value.as_bytes() + self.as_str().as_bytes() } } impl AsRef for HeaderName { fn as_ref(&self) -> &str { - self.value.as_ref() + self.as_str() } } diff --git a/async-nats/tests/jetstream_tests.rs b/async-nats/tests/jetstream_tests.rs index 60a599ef0..0531432aa 100644 --- a/async-nats/tests/jetstream_tests.rs +++ b/async-nats/tests/jetstream_tests.rs @@ -31,7 +31,7 @@ mod jetstream { use super::*; use async_nats::connection::State; - use async_nats::header::{HeaderMap, NATS_MESSAGE_ID}; + use async_nats::header::{self, HeaderMap, NATS_MESSAGE_ID}; use async_nats::jetstream::consumer::{ self, AckPolicy, DeliverPolicy, Info, OrderedPushConsumer, PullConsumer, PushConsumer, }; @@ -680,7 +680,7 @@ mod jetstream { .headers .as_ref() .unwrap() - .get("Nats-Sequence") + .get(header::NATS_SEQUENCE) .unwrap() .iter() .next() @@ -739,7 +739,7 @@ mod jetstream { .headers .as_ref() .unwrap() - .get("Nats-Sequence") + .get(header::NATS_SEQUENCE) .unwrap() .iter() .next() @@ -813,7 +813,7 @@ mod jetstream { .headers .as_ref() .unwrap() - .get("Nats-Sequence") + .get(header::NATS_SEQUENCE) .unwrap() .iter() .next() @@ -882,7 +882,7 @@ mod jetstream { .headers .as_ref() .unwrap() - .get("Nats-Sequence") + .get(header::NATS_SEQUENCE) .unwrap() .iter() .next()