Skip to content

Commit

Permalink
Make ReadContext::destinations a reference
Browse files Browse the repository at this point in the history
This removes a heap allocation done for every downstream packet that uses eg. tokenrouter so that it can be allocated (usually) just once and reused
  • Loading branch information
Jake-Shadle committed Nov 26, 2024
1 parent e431875 commit 072488d
Show file tree
Hide file tree
Showing 20 changed files with 136 additions and 79 deletions.
2 changes: 1 addition & 1 deletion src/config/slot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl<T: JsonSchema + Default> JsonSchema for Slot<T> {
}

impl<T: crate::filters::Filter + Default> crate::filters::Filter for Slot<T> {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
self.load().read(ctx)
}

Expand Down
2 changes: 1 addition & 1 deletion src/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ pub trait Filter: Send + Sync {
/// This function should return an `Some` if the packet processing should
/// proceed. If the packet should be rejected, it will return [`None`]
/// instead. By default, the context passes through unchanged.
fn read(&self, _: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, _: &mut ReadContext<'_>) -> Result<(), FilterError> {
Ok(())
}

Expand Down
6 changes: 5 additions & 1 deletion src/filters/capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ impl Capture {

impl Filter for Capture {
#[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))]
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
let capture = self.capture.capture(&mut ctx.contents);
ctx.metadata.insert(
self.is_present_key,
Expand Down Expand Up @@ -160,11 +160,13 @@ mod tests {
let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new("127.0.0.1:81".parse().unwrap())].into(),
);
let mut dest = Vec::new();
assert!(filter
.read(&mut ReadContext::new(
endpoints.into(),
(std::net::Ipv4Addr::LOCALHOST, 80).into(),
alloc_buffer(b"abc"),
&mut dest,
))
.is_err());
}
Expand Down Expand Up @@ -237,10 +239,12 @@ mod tests {
let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new("127.0.0.1:81".parse().unwrap())].into(),
);
let mut dest = Vec::new();
let mut context = ReadContext::new(
endpoints.into(),
"127.0.0.1:80".parse().unwrap(),
alloc_buffer(b"helloabc"),
&mut dest,
);

filter.read(&mut context).unwrap();
Expand Down
40 changes: 20 additions & 20 deletions src/filters/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ impl schemars::JsonSchema for FilterChain {
}

impl Filter for FilterChain {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
for ((id, instance), histogram) in self
.filters
.iter()
Expand All @@ -296,12 +296,8 @@ impl Filter for FilterChain {
// has rejected, and the destinations is empty, we passthrough to all.
// Which mimics the old behaviour while avoid clones in most cases.
if ctx.destinations.is_empty() {
ctx.destinations = ctx
.endpoints
.endpoints()
.into_iter()
.map(|ep| ep.address)
.collect();
ctx.destinations
.extend(ctx.endpoints.endpoints().into_iter().map(|ep| ep.address));
}

Ok(())
Expand Down Expand Up @@ -382,10 +378,12 @@ mod tests {
crate::test::load_test_filters();
let config = TestConfig::new();
let endpoints_fixture = endpoints();
let mut dest = Vec::new();
let mut context = ReadContext::new(
endpoints_fixture.clone(),
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
&mut dest,
);

config.filters.read(&mut context).unwrap();
Expand Down Expand Up @@ -435,22 +433,24 @@ mod tests {
.unwrap();

let endpoints_fixture = endpoints();
let mut context = ReadContext::new(
endpoints_fixture.clone(),
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
);

chain.read(&mut context).unwrap();
let mut dest = Vec::new();

let (contents, metadata) = {
let mut context = ReadContext::new(
endpoints_fixture.clone(),
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
&mut dest,
);
chain.read(&mut context).unwrap();
(context.contents, context.metadata)
};
let expected = endpoints_fixture.clone();
assert_eq!(expected.endpoints(), context.destinations);
assert_eq!(
b"hello:odr:127.0.0.1:70:odr:127.0.0.1:70",
&*context.contents
);
assert_eq!(expected.endpoints(), dest);
assert_eq!(b"hello:odr:127.0.0.1:70:odr:127.0.0.1:70", &*contents);
assert_eq!(
"receive:receive",
context.metadata[&"downstream".into()].as_string().unwrap()
metadata[&"downstream".into()].as_string().unwrap()
);

let mut context = WriteContext::new(
Expand Down
10 changes: 9 additions & 1 deletion src/filters/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ impl Compress {

impl Filter for Compress {
#[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))]
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
let original_size = ctx.contents.len();

match self.on_read {
Expand Down Expand Up @@ -296,10 +296,12 @@ mod tests {
let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new("127.0.0.1:81".parse().unwrap())].into(),
);
let mut dest = Vec::new();
let mut read_context = ReadContext::new(
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
alloc_buffer(&expected),
&mut dest,
);
compress.read(&mut read_context).expect("should compress");

Expand Down Expand Up @@ -356,11 +358,13 @@ mod tests {
let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new("127.0.0.1:81".parse().unwrap())].into(),
);
let mut dest = Vec::new();
assert!(compression
.read(&mut ReadContext::new(
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
alloc_buffer(b"hello"),
&mut dest,
))
.is_err());
}
Expand All @@ -379,10 +383,12 @@ mod tests {
let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new("127.0.0.1:81".parse().unwrap())].into(),
);
let mut dest = Vec::new();
let mut read_context = ReadContext::new(
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
alloc_buffer(b"hello"),
&mut dest,
);
compression.read(&mut read_context).unwrap();
assert_eq!(b"hello", &*read_context.contents);
Expand Down Expand Up @@ -474,10 +480,12 @@ mod tests {
let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new("127.0.0.1:81".parse().unwrap())].into(),
);
let mut dest = Vec::new();
let mut read_context = ReadContext::new(
endpoints.into(),
"127.0.0.1:8080".parse().unwrap(),
write_context.contents,
&mut dest,
);

filter.read(&mut read_context).expect("should decompress");
Expand Down
2 changes: 1 addition & 1 deletion src/filters/concatenate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl Concatenate {
}

impl Filter for Concatenate {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
match self.on_read {
Strategy::Append => {
ctx.contents.extend_from_slice(&self.bytes);
Expand Down
2 changes: 1 addition & 1 deletion src/filters/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl Debug {

impl Filter for Debug {
#[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))]
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
info!(id = ?self.config.id, source = ?&ctx.source, contents = ?String::from_utf8_lossy(&ctx.contents), "Read filter event");
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/filters/drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ impl Drop {

impl Filter for Drop {
#[cfg_attr(feature = "instrument", tracing::instrument(skip_all))]
fn read(&self, _: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, _: &mut ReadContext<'_>) -> Result<(), FilterError> {
Err(FilterError::Dropped)
}

Expand Down
18 changes: 15 additions & 3 deletions src/filters/firewall.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ impl StaticFilter for Firewall {

impl Filter for Firewall {
#[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))]
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
for rule in &self.on_read {
if rule.contains(ctx.source.to_socket_addr()?) {
return match rule.action {
Expand Down Expand Up @@ -134,13 +134,25 @@ mod tests {
let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())].into(),
);
let mut ctx = ReadContext::new(endpoints.into(), (local_ip, 80).into(), alloc_buffer([]));
let mut dest = Vec::new();
let mut ctx = ReadContext::new(
endpoints.into(),
(local_ip, 80).into(),
alloc_buffer([]),
&mut dest,
);
assert!(firewall.read(&mut ctx).is_ok());

let endpoints = crate::net::cluster::ClusterMap::new_default(
[Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())].into(),
);
let mut ctx = ReadContext::new(endpoints.into(), (local_ip, 2000).into(), alloc_buffer([]));
let mut dest = Vec::new();
let mut ctx = ReadContext::new(
endpoints.into(),
(local_ip, 2000).into(),
alloc_buffer([]),
&mut dest,
);
assert!(logs_contain("quilkin::filters::firewall")); // the given name to the the logger by tracing
assert!(logs_contain("Allow"));

Expand Down
13 changes: 8 additions & 5 deletions src/filters/load_balancer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl LoadBalancer {
}

impl Filter for LoadBalancer {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
self.endpoint_chooser.choose_endpoints(ctx);
Ok(())
}
Expand Down Expand Up @@ -75,11 +75,14 @@ mod tests {
.map(Endpoint::new)
.collect::<std::collections::BTreeSet<_>>();
let endpoints = crate::net::cluster::ClusterMap::new_default(endpoints);
let mut context = ReadContext::new(endpoints.into(), source, alloc_buffer([]));

filter.read(&mut context).unwrap();
let mut dest = Vec::new();
{
let mut context =
ReadContext::new(endpoints.into(), source, alloc_buffer([]), &mut dest);
filter.read(&mut context).unwrap();
}

context.destinations
dest
}

#[tokio::test]
Expand Down
37 changes: 20 additions & 17 deletions src/filters/load_balancer/endpoint_chooser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::filters::ReadContext;
/// EndpointChooser chooses from a set of endpoints that a proxy is connected to.
pub trait EndpointChooser: Send + Sync {
/// choose_endpoints asks for the next endpoint(s) to use.
fn choose_endpoints(&self, endpoints: &mut ReadContext);
fn choose_endpoints(&self, endpoints: &mut ReadContext<'_>);
}

/// RoundRobinEndpointChooser chooses endpoints in round-robin order.
Expand All @@ -45,41 +45,44 @@ impl RoundRobinEndpointChooser {
}

impl EndpointChooser for RoundRobinEndpointChooser {
fn choose_endpoints(&self, ctx: &mut ReadContext) {
fn choose_endpoints(&self, ctx: &mut ReadContext<'_>) {
let count = self.next_endpoint.fetch_add(1, Ordering::Relaxed);
// Note: The index is guaranteed to be in range.
ctx.destinations = vec![ctx
.endpoints
.nth_endpoint(count % ctx.endpoints.num_of_endpoints())
.unwrap()
.address
.clone()];
ctx.destinations.push(
ctx.endpoints
.nth_endpoint(count % ctx.endpoints.num_of_endpoints())
.unwrap()
.address
.clone(),
);
}
}

/// RandomEndpointChooser chooses endpoints in random order.
pub struct RandomEndpointChooser;

impl EndpointChooser for RandomEndpointChooser {
fn choose_endpoints(&self, ctx: &mut ReadContext) {
fn choose_endpoints(&self, ctx: &mut ReadContext<'_>) {
// The index is guaranteed to be in range.
let index = thread_rng().gen_range(0..ctx.endpoints.num_of_endpoints());
ctx.destinations = vec![ctx.endpoints.nth_endpoint(index).unwrap().address.clone()];
ctx.destinations
.push(ctx.endpoints.nth_endpoint(index).unwrap().address.clone());
}
}

/// HashEndpointChooser chooses endpoints based on a hash of source IP and port.
pub struct HashEndpointChooser;

impl EndpointChooser for HashEndpointChooser {
fn choose_endpoints(&self, ctx: &mut ReadContext) {
fn choose_endpoints(&self, ctx: &mut ReadContext<'_>) {
let mut hasher = DefaultHasher::new();
ctx.source.hash(&mut hasher);
ctx.destinations = vec![ctx
.endpoints
.nth_endpoint(hasher.finish() as usize % ctx.endpoints.num_of_endpoints())
.unwrap()
.address
.clone()];
ctx.destinations.push(
ctx.endpoints
.nth_endpoint(hasher.finish() as usize % ctx.endpoints.num_of_endpoints())
.unwrap()
.address
.clone(),
);
}
}
10 changes: 8 additions & 2 deletions src/filters/local_rate_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ impl LocalRateLimit {
}

impl Filter for LocalRateLimit {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext<'_>) -> Result<(), FilterError> {
if self.acquire_token(&ctx.source) {
Ok(())
} else {
Expand Down Expand Up @@ -235,7 +235,13 @@ mod tests {
.into(),
);

let mut context = ReadContext::new(endpoints.into(), address.clone(), alloc_buffer([9]));
let mut dest = Vec::new();
let mut context = ReadContext::new(
endpoints.into(),
address.clone(),
alloc_buffer([9]),
&mut dest,
);
let result = r.read(&mut context);

if should_succeed {
Expand Down
Loading

0 comments on commit 072488d

Please sign in to comment.