Skip to content

Commit

Permalink
perf: make filters synchronous
Browse files Browse the repository at this point in the history
  • Loading branch information
XAMPPRocky committed Oct 7, 2024
1 parent 7a81ba0 commit 007df10
Show file tree
Hide file tree
Showing 24 changed files with 147 additions and 193 deletions.
7 changes: 2 additions & 5 deletions src/components/proxy/packet_router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ impl DownstreamReceiveWorkerConfig {
packet.source.into(),
packet.contents,
);
filters
.read(&mut context)
.await
.map_err(PipelineError::Filter)?;
filters.read(&mut context).map_err(PipelineError::Filter)?;

let ReadContext {
destinations,
Expand All @@ -123,7 +120,7 @@ impl DownstreamReceiveWorkerConfig {
for epa in destinations {
let session_key = SessionKey {
source: packet.source,
dest: epa.to_socket_addr().await?,
dest: epa.to_socket_addr()?,
};

sessions.send(session_key, contents.clone()).await?;
Expand Down
7 changes: 3 additions & 4 deletions src/components/proxy/sessions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ impl SessionPool {
asn_info,
packet,
)
.await
};

if let Err((asn_info, error)) = result {
Expand Down Expand Up @@ -334,7 +333,7 @@ impl SessionPool {
}

/// process_recv_packet processes a packet that is received by this session.
async fn process_recv_packet(
fn process_recv_packet(
config: Arc<crate::Config>,
downstream_sender: &DownstreamSender,
source: SocketAddr,
Expand All @@ -346,7 +345,7 @@ impl SessionPool {

let mut context = crate::filters::WriteContext::new(source.into(), dest.into(), packet);

if let Err(err) = config.filters.load().write(&mut context).await {
if let Err(err) = config.filters.load().write(&mut context) {
return Err((asn_info, err.into()));
}

Expand Down Expand Up @@ -723,7 +722,7 @@ mod tests {
async fn send_and_recv() {
let mut t = TestHelper::default();
let dest = t.run_echo_server(AddressType::Ipv6).await;
let mut dest = dest.to_socket_addr().await.unwrap();
let mut dest = dest.to_socket_addr().unwrap();
crate::test::map_addr_to_localhost(&mut dest);
let source = available_addr(AddressType::Ipv6).await;
let socket = tokio::net::UdpSocket::bind(source).await.unwrap();
Expand Down
9 changes: 4 additions & 5 deletions src/config/slot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,13 @@ impl<T: JsonSchema + Default> JsonSchema for Slot<T> {
}
}

#[async_trait::async_trait]
impl<T: crate::filters::Filter + Default> crate::filters::Filter for Slot<T> {
async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
self.load().read(ctx).await
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
self.load().read(ctx)
}

async fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> {
self.load().write(ctx).await
fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> {
self.load().write(ctx)
}
}

Expand Down
5 changes: 2 additions & 3 deletions src/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,14 @@ where
/// `write` implementation to execute.
/// * Labels
/// * `filter` The name of the filter being executed.
#[async_trait::async_trait]
pub trait Filter: Send + Sync {
/// [`Filter::read`] is invoked when the proxy receives data from a
/// downstream connection on the listening port.
///
/// This function should return an `Some` if the packet processing should
/// proceed. If the packet should be rejected, it will return [`None`]
/// instead. By default, the context passes through unchanged.
async fn read(&self, _: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, _: &mut ReadContext) -> Result<(), FilterError> {
Ok(())
}

Expand All @@ -226,7 +225,7 @@ pub trait Filter: Send + Sync {
///
/// This function should return an `Some` if the packet processing should
/// proceed. If the packet should be rejected, it will return [`None`]
async fn write(&self, _: &mut WriteContext) -> Result<(), FilterError> {
fn write(&self, _: &mut WriteContext) -> Result<(), FilterError> {
Ok(())
}
}
16 changes: 7 additions & 9 deletions src/filters/capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,9 @@ impl Capture {
}
}

#[async_trait::async_trait]
impl Filter for Capture {
#[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))]
async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
let capture = self.capture.capture(&mut ctx.contents);
ctx.metadata.insert(
self.is_present_key,
Expand Down Expand Up @@ -109,7 +108,7 @@ mod tests {
}
});
let filter = Capture::from_config(Some(serde_json::from_value(config).unwrap()));
assert_end_strategy(&filter, TOKEN_KEY.into(), true).await;
assert_end_strategy(&filter, TOKEN_KEY.into(), true);
}

#[tokio::test]
Expand All @@ -121,7 +120,7 @@ mod tests {
});

let filter = Capture::from_config(Some(serde_json::from_value(config).unwrap()));
assert_end_strategy(&filter, CAPTURED_BYTES.into(), false).await;
assert_end_strategy(&filter, CAPTURED_BYTES.into(), false);
}

#[test]
Expand All @@ -145,7 +144,7 @@ mod tests {
};

let filter = Capture::from_config(config.into());
assert_end_strategy(&filter, TOKEN_KEY.into(), true).await;
assert_end_strategy(&filter, TOKEN_KEY.into(), true);
}

#[tokio::test]
Expand All @@ -167,7 +166,6 @@ mod tests {
(std::net::Ipv4Addr::LOCALHOST, 80).into(),
alloc_buffer(b"abc"),
))
.await
.is_err());
}

Expand All @@ -181,7 +179,7 @@ mod tests {
metadata_key: TOKEN_KEY.into(),
};
let filter = Capture::from_config(config.into());
assert_write_no_change(&filter).await;
assert_write_no_change(&filter);
}

#[test]
Expand Down Expand Up @@ -232,7 +230,7 @@ mod tests {
assert_eq!(b"hello", &*contents);
}

async fn assert_end_strategy<F>(filter: &F, key: metadata::Key, remove: bool)
fn assert_end_strategy<F>(filter: &F, key: metadata::Key, remove: bool)
where
F: Filter + ?Sized,
{
Expand All @@ -245,7 +243,7 @@ mod tests {
alloc_buffer(b"helloabc"),
);

filter.read(&mut context).await.unwrap();
filter.read(&mut context).unwrap();

if remove {
assert_eq!(b"hello", &*context.contents);
Expand Down
17 changes: 8 additions & 9 deletions src/filters/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,17 +272,16 @@ impl schemars::JsonSchema for FilterChain {
}
}

#[async_trait::async_trait]
impl Filter for FilterChain {
async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
for ((id, instance), histogram) in self
.filters
.iter()
.zip(self.filter_read_duration_seconds.iter())
{
tracing::trace!(%id, "read filtering packet");
let timer = histogram.start_timer();
let result = instance.filter().read(ctx).await;
let result = instance.filter().read(ctx);
timer.stop_and_record();
match result {
Ok(()) => tracing::trace!(%id, "read passing packet"),
Expand All @@ -308,7 +307,7 @@ impl Filter for FilterChain {
Ok(())
}

async fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> {
fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> {
for ((id, instance), histogram) in self
.filters
.iter()
Expand All @@ -317,7 +316,7 @@ impl Filter for FilterChain {
{
tracing::trace!(%id, "write filtering packet");
let timer = histogram.start_timer();
let result = instance.filter().write(ctx).await;
let result = instance.filter().write(ctx);
timer.stop_and_record();
match result {
Ok(()) => tracing::trace!(%id, "write passing packet"),
Expand Down Expand Up @@ -389,7 +388,7 @@ mod tests {
alloc_buffer(b"hello"),
);

config.filters.read(&mut context).await.unwrap();
config.filters.read(&mut context).unwrap();
let expected = endpoints_fixture.clone();

assert_eq!(&*expected.endpoints(), &*context.destinations);
Expand All @@ -412,7 +411,7 @@ mod tests {
"127.0.0.1:70".parse().unwrap(),
alloc_buffer(b"hello"),
);
config.filters.write(&mut context).await.unwrap();
config.filters.write(&mut context).unwrap();

assert_eq!(
"receive",
Expand Down Expand Up @@ -442,7 +441,7 @@ mod tests {
alloc_buffer(b"hello"),
);

chain.read(&mut context).await.unwrap();
chain.read(&mut context).unwrap();
let expected = endpoints_fixture.clone();
assert_eq!(expected.endpoints(), context.destinations);
assert_eq!(
Expand All @@ -465,7 +464,7 @@ mod tests {
alloc_buffer(b"hello"),
);

chain.write(&mut context).await.unwrap();
chain.write(&mut context).unwrap();
assert_eq!(
"hello:our:127.0.0.1:80:127.0.0.1:70:our:127.0.0.1:80:127.0.0.1:70",
std::str::from_utf8(&context.contents).unwrap(),
Expand Down
35 changes: 11 additions & 24 deletions src/filters/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ impl Compress {
}
}

#[async_trait::async_trait]
impl Filter for Compress {
#[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))]
async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> {
let original_size = ctx.contents.len();

match self.on_read {
Expand Down Expand Up @@ -102,7 +101,7 @@ impl Filter for Compress {
}

#[cfg_attr(feature = "instrument", tracing::instrument(skip(self, ctx)))]
async fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> {
fn write(&self, ctx: &mut WriteContext) -> Result<(), FilterError> {
let original_size = ctx.contents.len();
match self.on_write {
Action::Compress => {
Expand Down Expand Up @@ -254,7 +253,7 @@ mod tests {

});
let filter = Compress::from_config(Some(serde_json::from_value(config).unwrap()));
assert_downstream(&filter).await;
assert_downstream(&filter);
}

#[tokio::test]
Expand All @@ -266,7 +265,7 @@ mod tests {

});
let filter = Compress::from_config(Some(serde_json::from_value(config).unwrap()));
assert_downstream(&filter).await;
assert_downstream(&filter);
}

#[tokio::test]
Expand All @@ -278,7 +277,7 @@ mod tests {

});
let filter = Compress::from_config(Some(serde_json::from_value(config).unwrap()));
assert_downstream(&filter).await;
assert_downstream(&filter);
}

#[tokio::test]
Expand All @@ -302,10 +301,7 @@ mod tests {
"127.0.0.1:8080".parse().unwrap(),
alloc_buffer(&expected),
);
compress
.read(&mut read_context)
.await
.expect("should compress");
compress.read(&mut read_context).expect("should compress");

assert_ne!(expected, &*read_context.contents);
assert!(
Expand All @@ -324,7 +320,6 @@ mod tests {

compress
.write(&mut write_context)
.await
.expect("should decompress");

assert_eq!(expected, &*write_context.contents);
Expand All @@ -347,7 +342,6 @@ mod tests {
"127.0.0.1:8081".parse().unwrap(),
alloc_buffer(b"hello"),
))
.await
.is_err());

let compression = Compress::new(
Expand All @@ -368,7 +362,6 @@ mod tests {
"127.0.0.1:8080".parse().unwrap(),
alloc_buffer(b"hello"),
))
.await
.is_err());
}

Expand All @@ -391,7 +384,7 @@ mod tests {
"127.0.0.1:8080".parse().unwrap(),
alloc_buffer(b"hello"),
);
compression.read(&mut read_context).await.unwrap();
compression.read(&mut read_context).unwrap();
assert_eq!(b"hello", &*read_context.contents);

let mut write_context = WriteContext::new(
Expand All @@ -400,7 +393,7 @@ mod tests {
alloc_buffer(b"hello"),
);

compression.write(&mut write_context).await.unwrap();
compression.write(&mut write_context).unwrap();

assert_eq!(b"hello".to_vec(), &*write_context.contents)
}
Expand Down Expand Up @@ -455,7 +448,7 @@ mod tests {

/// assert compression work with decompress on read and compress on write
/// Returns the original data packet, and the compressed version
async fn assert_downstream<F>(filter: &F)
fn assert_downstream<F>(filter: &F)
where
F: Filter + ?Sized,
{
Expand All @@ -467,10 +460,7 @@ mod tests {
alloc_buffer(&expected),
);

filter
.write(&mut write_context)
.await
.expect("should compress");
filter.write(&mut write_context).expect("should compress");

assert_ne!(expected, &*write_context.contents);
assert!(
Expand All @@ -490,10 +480,7 @@ mod tests {
write_context.contents,
);

filter
.read(&mut read_context)
.await
.expect("should decompress");
filter.read(&mut read_context).expect("should decompress");

assert_eq!(expected, &*read_context.contents);
}
Expand Down
Loading

0 comments on commit 007df10

Please sign in to comment.