Skip to content

Commit

Permalink
Add const representation for standard header names
Browse files Browse the repository at this point in the history
  • Loading branch information
caspervonb authored May 29, 2023
1 parent 4211b60 commit f30b3ec
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 40 deletions.
162 changes: 127 additions & 35 deletions async-nats/src/header.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,6 @@
use std::{collections::HashMap, fmt, slice, str::FromStr};

use serde::Serialize;

pub const NATS_LAST_STREAM: &str = "Nats-Last-Stream";
pub const NATS_LAST_CONSUMER: &str = "Nats-Last-Consumer";

/// Direct Get headers
pub const NATS_STREAM: &str = "Nats-Stream";
pub const NATS_SEQUENCE: &str = "Nats-Sequence";
pub const NATS_TIME_STAMP: &str = "Nats-Time-Stamp";
pub const NATS_SUBJECT: &str = "Nats-Subject";
pub const NATS_LAST_SEQUENCE: &str = "Nats-Last-Sequence";

/// Nats-Expected-Last-Subject-Sequence
pub const NATS_EXPECTED_LAST_SUBJECT_SEQUENCE: &str = "Nats-Expected-Last-Subject-Sequence";
/// Message identifier used for deduplication window
pub const NATS_MESSAGE_ID: &str = "Nats-Msg-Id";
/// Last expected message ID for JetStream message publish
pub const NATS_EXPECTED_LAST_MESSAGE_ID: &str = "Nats-Expected-Last-Msg-Id";
/// Last expected sequence for JetStream message publish
pub const NATS_EXPECTED_LAST_SEQUENCE: &str = "Nats-Expected-Last-Sequence";
/// Expect that given message will be ingested by specified stream.
pub const NATS_EXPECTED_STREAM: &str = "Nats-Expected-Stream";

/// A struct for handling NATS headers.
/// Has a similar API to [http::header], but properly serializes and deserializes
/// according to NATS requirements.
Expand All @@ -56,7 +33,7 @@ pub const NATS_EXPECTED_STREAM: &str = "Nats-Expected-Stream";
/// # Ok(())
/// # }
/// ```
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Default)]
#[derive(Clone, PartialEq, Eq, Debug, Default)]
pub struct HeaderMap {
inner: HashMap<HeaderName, HeaderValue>,
}
Expand Down Expand Up @@ -154,7 +131,7 @@ impl HeaderMap {
buf.extend_from_slice(b"NATS/1.0\r\n");
for (k, vs) in &self.inner {
for v in vs.iter() {
buf.extend_from_slice(k.value.as_bytes());
buf.extend_from_slice(k.as_str().as_bytes());
buf.extend_from_slice(b": ");
buf.extend_from_slice(v.as_bytes());
buf.extend_from_slice(b"\r\n");
Expand All @@ -179,7 +156,7 @@ impl HeaderMap {
/// # Ok(())
/// # }
/// ```
#[derive(Clone, PartialEq, Eq, Debug, Serialize, Default)]
#[derive(Clone, PartialEq, Eq, Debug, Default)]
pub struct HeaderValue {
value: Vec<String>,
}
Expand Down Expand Up @@ -288,7 +265,7 @@ pub trait IntoHeaderName {
impl IntoHeaderName for &str {
fn into_header_name(self) -> HeaderName {
HeaderName {
value: self.to_string(),
inner: HeaderRepr::Custom(self.to_string()),
}
}
}
Expand Down Expand Up @@ -316,10 +293,120 @@ impl IntoHeaderValue for HeaderValue {
}
}

#[derive(Clone, PartialEq, Eq, Hash, Debug, Serialize)]
macro_rules! standard_headers {
(
$(
$(#[$docs:meta])*
($variant:ident, $constant:ident, $bytes:literal);
)+
) => {
#[allow(clippy::enum_variant_names)]
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
enum StandardHeader {
$(
$variant,
)+
}

$(
$(#[$docs])*
pub const $constant: HeaderName = HeaderName {
inner: HeaderRepr::Standard(StandardHeader::$variant),
};
)+

impl StandardHeader {
#[inline]
fn as_str(&self) -> &'static str {
match *self {
$(
StandardHeader::$variant => unsafe { std::str::from_utf8_unchecked( $bytes ) },
)+
}
}

const fn from_bytes(bytes: &[u8]) -> Option<StandardHeader> {
match bytes {
$(
$bytes => Some(StandardHeader::$variant),
)+
_ => None,
}
}
}

#[cfg(test)]
mod standard_header_tests {
use super::HeaderName;
use std::str::{self, FromStr};

const TEST_HEADERS: &'static [(&'static HeaderName, &'static [u8])] = &[
$(
(&super::$constant, $bytes),
)+
];

#[test]
fn from_str() {
for &(header, bytes) in TEST_HEADERS {
let utf8 = str::from_utf8(bytes).expect("string constants isn't utf8");
assert_eq!(HeaderName::from_str(utf8).unwrap(), *header);
}
}
}
}
}

// Generate constants for all standard NATS headers.
standard_headers! {
/// The name of the stream the message belongs to.
(NatsStream, NATS_STREAM, b"Nats-Stream");
/// The sequence number of the message within the stream.
(NatsSequence, NATS_SEQUENCE, b"Nats-Sequence");
/// The timestamp of when the message was sent.
(NatsTimeStamp, NATS_TIME_STAMP, b"Nats-Time-Stamp");
/// The subject of the message, used for routing and filtering messages.
(NatsSubject, NATS_SUBJECT, b"Nats-Subject");
/// A unique identifier for the message.
(NatsMessageId, NATS_MESSAGE_ID, b"Nats-Msg-Id");
/// The last known stream the message was part of.
(NatsLastStream, NATS_LAST_STREAM, b"Nats-Last-Stream");
/// The last known consumer that processed the message.
(NatsLastConsumer, NATS_LAST_CONSUMER, b"Nats-Last-Consumer");
/// The last known sequence number of the message.
(NatsLastSequence, NATS_LAST_SEQUENCE, b"Nats-Last-Sequence");
/// The expected last sequence number of the subject.
(NatsExpectgedLastSubjectSequence, NATS_EXPECTED_LAST_SUBJECT_SEQUENCE, b"Nats-Expected-Last-Subject-Sequence");
/// The expected last message ID within the stream.
(NatsExpectedLastMessageId, NATS_EXPECTED_LAST_MESSAGE_ID, b"Nats-Expected-Last-Msg-Id");
/// The expected last sequence number within the stream.
(NatsExpectedLastSequence, NATS_EXPECTED_LAST_SEQUENCE, b"Nats-Expected-Last-Sequence");
/// The expected stream the message should be part of.
(NatsExpectedStream, NATS_EXPECTED_STREAM, b"Nats-Expected-Stream");
}

#[derive(Debug, Hash, PartialEq, Eq, Clone)]
enum HeaderRepr {
Standard(StandardHeader),
Custom(String),
}

#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct HeaderName {
value: String,
inner: HeaderRepr,
}

impl HeaderName {
/// Returns a `str` representation of the header.
#[inline]
fn as_str(&self) -> &str {
match self.inner {
HeaderRepr::Standard(v) => v.as_str(),
HeaderRepr::Custom(ref v) => v.as_str(),
}
}
}

impl FromStr for HeaderName {
type Err = ParseHeaderNameError;

Expand All @@ -328,27 +415,32 @@ impl FromStr for HeaderName {
return Err(ParseHeaderNameError);
}

Ok(HeaderName {
value: s.to_string(),
})
match StandardHeader::from_bytes(s.as_ref()) {
Some(v) => Ok(HeaderName {
inner: HeaderRepr::Standard(v),
}),
None => Ok(HeaderName {
inner: HeaderRepr::Custom(s.to_string()),
}),
}
}
}

impl fmt::Display for HeaderName {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt::Display::fmt(&self.value, f)
fmt::Display::fmt(&self.as_str(), f)
}
}

impl AsRef<[u8]> for HeaderName {
fn as_ref(&self) -> &[u8] {
self.value.as_bytes()
self.as_str().as_bytes()
}
}

impl AsRef<str> for HeaderName {
fn as_ref(&self) -> &str {
self.value.as_ref()
self.as_str()
}
}

Expand Down
10 changes: 5 additions & 5 deletions async-nats/tests/jetstream_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ mod jetstream {

use super::*;
use async_nats::connection::State;
use async_nats::header::{HeaderMap, NATS_MESSAGE_ID};
use async_nats::header::{self, HeaderMap, NATS_MESSAGE_ID};
use async_nats::jetstream::consumer::{
self, AckPolicy, DeliverPolicy, Info, OrderedPushConsumer, PullConsumer, PushConsumer,
};
Expand Down Expand Up @@ -680,7 +680,7 @@ mod jetstream {
.headers
.as_ref()
.unwrap()
.get("Nats-Sequence")
.get(header::NATS_SEQUENCE)
.unwrap()
.iter()
.next()
Expand Down Expand Up @@ -739,7 +739,7 @@ mod jetstream {
.headers
.as_ref()
.unwrap()
.get("Nats-Sequence")
.get(header::NATS_SEQUENCE)
.unwrap()
.iter()
.next()
Expand Down Expand Up @@ -813,7 +813,7 @@ mod jetstream {
.headers
.as_ref()
.unwrap()
.get("Nats-Sequence")
.get(header::NATS_SEQUENCE)
.unwrap()
.iter()
.next()
Expand Down Expand Up @@ -882,7 +882,7 @@ mod jetstream {
.headers
.as_ref()
.unwrap()
.get("Nats-Sequence")
.get(header::NATS_SEQUENCE)
.unwrap()
.iter()
.next()
Expand Down

0 comments on commit f30b3ec

Please sign in to comment.