Skip to content

Commit

Permalink
refactor Wrapper type
Browse files Browse the repository at this point in the history
  • Loading branch information
rukai committed Aug 21, 2024
1 parent 70cf0be commit f93baab
Show file tree
Hide file tree
Showing 30 changed files with 389 additions and 336 deletions.
3 changes: 2 additions & 1 deletion changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ This assists us in knowing when to make the next release a breaking release and

### shotover rust API

`Transform::transform` now takes `&mut Wrapper` instead of `Wrapper`.
`Transform::transform` previously took a `Wrapper` type as an argument.
That has now been split into 2 separate types: `&mut ChainState` and `DownChainTransforms`.

## 0.4.0

Expand Down
14 changes: 8 additions & 6 deletions custom-transforms-example/src/redis_get_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use shotover::frame::{Frame, MessageType, RedisFrame};
use shotover::message::{MessageIdSet, Messages};
use shotover::transforms::{DownChainProtocol, TransformContextBuilder, UpChainProtocol};
use shotover::transforms::{
Transform, TransformBuilder, TransformConfig, TransformContextConfig, Wrapper,
ChainState, DownChainTransforms, Transform, TransformBuilder, TransformConfig,
TransformContextConfig,
};
use shotover::transforms::{DownChainProtocol, TransformContextBuilder, UpChainProtocol};

#[derive(Serialize, Deserialize, Debug)]
#[serde(deny_unknown_fields)]
Expand Down Expand Up @@ -64,18 +65,19 @@ impl Transform for RedisGetRewrite {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
requests_wrapper: &'shorter mut Wrapper<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
for message in requests_wrapper.requests.iter_mut() {
for message in chain_state.requests.iter_mut() {
if let Some(frame) = message.frame() {
if is_get(frame) {
self.get_requests.insert(message.id());
}
}
}
let mut responses = requests_wrapper.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in responses.iter_mut() {
if response
Expand Down
64 changes: 32 additions & 32 deletions shotover/benches/benches/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use shotover::transforms::query_counter::QueryCounter;
use shotover::transforms::redis::cluster_ports_rewrite::RedisClusterPortsRewrite;
use shotover::transforms::throttling::RequestThrottlingConfig;
use shotover::transforms::{
TransformConfig, TransformContextBuilder, TransformContextConfig, Wrapper,
ChainState, TransformConfig, TransformContextBuilder, TransformContextConfig,
};

fn criterion_benchmark(c: &mut Criterion) {
Expand All @@ -32,14 +32,14 @@ fn criterion_benchmark(c: &mut Criterion) {
// loopback is the fastest possible transform as it does not even have to drop the received requests
{
let chain = TransformChainBuilder::new(vec![Box::<Loopback>::default()], "bench");
let wrapper = Wrapper::new_with_addr(
let chain_state = ChainState::new_with_addr(
vec![Message::from_frame(Frame::Redis(RedisFrame::Null))],
"127.0.0.1:6379".parse().unwrap(),
);

group.bench_function("loopback", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand All @@ -48,14 +48,14 @@ fn criterion_benchmark(c: &mut Criterion) {

{
let chain = TransformChainBuilder::new(vec![Box::<NullSink>::default()], "bench");
let wrapper = Wrapper::new_with_addr(
let chain_state = ChainState::new_with_addr(
vec![Message::from_frame(Frame::Redis(RedisFrame::Null))],
"127.0.0.1:6379".parse().unwrap(),
);

group.bench_function("nullsink", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand All @@ -73,7 +73,7 @@ fn criterion_benchmark(c: &mut Criterion) {
],
"bench",
);
let wrapper = Wrapper::new_with_addr(
let chain_state = ChainState::new_with_addr(
vec![
Message::from_frame(Frame::Redis(RedisFrame::Array(vec![
RedisFrame::BulkString(Bytes::from_static(b"SET")),
Expand All @@ -90,7 +90,7 @@ fn criterion_benchmark(c: &mut Criterion) {

group.bench_function("redis_filter", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand All @@ -105,7 +105,7 @@ fn criterion_benchmark(c: &mut Criterion) {
],
"bench",
);
let wrapper = Wrapper::new_with_addr(
let chain_state = ChainState::new_with_addr(
vec![Message::from_frame(Frame::Redis(RedisFrame::Array(vec![
RedisFrame::BulkString(Bytes::from_static(b"SET")),
RedisFrame::BulkString(Bytes::from_static(b"foo")),
Expand All @@ -116,7 +116,7 @@ fn criterion_benchmark(c: &mut Criterion) {

group.bench_function("redis_cluster_ports_rewrite", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand All @@ -141,7 +141,7 @@ fn criterion_benchmark(c: &mut Criterion) {
],
"bench",
);
let wrapper = Wrapper::new_with_addr(
let chain_state = ChainState::new_with_addr(
vec![Message::from_bytes(
Bytes::from(
// a simple select query
Expand All @@ -160,7 +160,7 @@ fn criterion_benchmark(c: &mut Criterion) {

group.bench_function("cassandra_request_throttling_unparsed", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand All @@ -176,7 +176,7 @@ fn criterion_benchmark(c: &mut Criterion) {
"bench",
);

let wrapper = Wrapper::new_with_addr(
let chain_state = ChainState::new_with_addr(
vec![Message::from_bytes(
CassandraFrame {
version: Version::V4,
Expand Down Expand Up @@ -211,7 +211,7 @@ fn criterion_benchmark(c: &mut Criterion) {

group.bench_function("cassandra_rewrite_peers_passthrough", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand Down Expand Up @@ -248,25 +248,25 @@ fn criterion_benchmark(c: &mut Criterion) {
"bench",
);

let wrapper = cassandra_parsed_query(
let chain_state = cassandra_parsed_query(
"INSERT INTO test_protect_keyspace.unprotected_table (pk, cluster, col1, col2, col3) VALUES ('pk1', 'cluster', 'I am gonna get encrypted!!', 42, true);"
);

group.bench_function("cassandra_protect_unprotected", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
});

let wrapper = cassandra_parsed_query(
let chain_state = cassandra_parsed_query(
"INSERT INTO test_protect_keyspace.protected_table (pk, cluster, col1, col2, col3) VALUES ('pk1', 'cluster', 'I am gonna get encrypted!!', 42, true);"
);

group.bench_function("cassandra_protect_protected", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand All @@ -281,7 +281,7 @@ fn criterion_benchmark(c: &mut Criterion) {
],
"bench",
);
let wrapper = Wrapper::new_with_addr(
let chain_state = ChainState::new_with_addr(
vec![
Message::from_frame(Frame::Redis(RedisFrame::Array(vec![
RedisFrame::BulkString(Bytes::from_static(b"SET")),
Expand All @@ -298,15 +298,15 @@ fn criterion_benchmark(c: &mut Criterion) {

group.bench_function("query_counter_fresh", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_fresh(&chain, &wrapper),
|| BenchInput::new_fresh(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
});

group.bench_function("query_counter_pre_used", |b| {
b.to_async(&rt).iter_batched(
|| BenchInput::new_pre_used(&chain, &wrapper),
|| BenchInput::new_pre_used(&chain, &chain_state),
BenchInput::bench,
BatchSize::SmallInput,
)
Expand All @@ -315,8 +315,8 @@ fn criterion_benchmark(c: &mut Criterion) {
}

#[cfg(feature = "alpha-transforms")]
fn cassandra_parsed_query(query: &str) -> Wrapper {
Wrapper::new_with_addr(
fn cassandra_parsed_query(query: &str) -> ChainState {
ChainState::new_with_addr(
vec![Message::from_frame(Frame::Cassandra(CassandraFrame {
version: Version::V4,
stream_id: 0,
Expand All @@ -341,38 +341,38 @@ fn cassandra_parsed_query(query: &str) -> Wrapper {
)
}

struct BenchInput<'a> {
struct BenchInput {
chain: TransformChain,
wrapper: Wrapper<'a>,
chain_state: ChainState,
}

impl<'a> BenchInput<'a> {
impl BenchInput {
// Setup the bench such that the chain is completely fresh
fn new_fresh(chain: &TransformChainBuilder, wrapper: &Wrapper<'a>) -> Self {
fn new_fresh(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self {
BenchInput {
chain: chain.build(TransformContextBuilder::new_test()),
wrapper: wrapper.clone(),
chain_state: chain_state.clone(),
}
}

// Setup the bench such that the chain has already had the test wrapper passed through it.
// Setup the bench such that the chain has already had the test chain_state passed through it.
// This ensures that any adhoc setup for that message type has been performed.
// This is a more realistic bench for typical usage.
fn new_pre_used(chain: &TransformChainBuilder, wrapper: &Wrapper<'a>) -> Self {
fn new_pre_used(chain: &TransformChainBuilder, chain_state: &ChainState) -> Self {
let mut chain = chain.build(TransformContextBuilder::new_test());

// Run the chain once so we are measuring the chain once each transform has been fully initialized
futures::executor::block_on(chain.process_request(&mut wrapper.clone())).unwrap();
futures::executor::block_on(chain.process_request(&mut chain_state.clone())).unwrap();

BenchInput {
chain,
wrapper: wrapper.clone(),
chain_state: chain_state.clone(),
}
}

async fn bench(mut self) -> (Vec<Message>, TransformChain) {
// Return both the chain itself and the response to avoid measuring the time to drop the values in the benchmark
let mut wrapper = self.wrapper;
let mut wrapper = self.chain_state;
(
self.chain.process_request(&mut wrapper).await.unwrap(),
self.chain,
Expand Down
13 changes: 7 additions & 6 deletions shotover/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::sources::Transport;
use crate::tls::{AcceptError, TlsAcceptor};
use crate::transforms::chain::{TransformChain, TransformChainBuilder};
use crate::transforms::{TransformContextBuilder, TransformContextConfig, Wrapper};
use crate::transforms::{ChainState, TransformContextBuilder, TransformContextConfig};
use anyhow::{anyhow, Result};
use bytes::BytesMut;
use futures::future::join_all;
Expand Down Expand Up @@ -637,7 +637,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
// Only flush messages if we are shutting down due to shotover shutdown or client disconnect
// If a Transform::transform returns an Err the transform is no longer in a usable state and needs to be destroyed without reusing.
if let Ok(CloseReason::ShotoverShutdown | CloseReason::ClientClosed) = result {
match self.chain.process_request(&mut Wrapper::flush()).await {
match self.chain.process_request(&mut ChainState::flush()).await {
Ok(_) => {}
Err(e) => error!(
"{:?}",
Expand Down Expand Up @@ -727,10 +727,11 @@ impl<C: CodecBuilder + 'static> Handler<C> {
out_tx: &mpsc::UnboundedSender<Messages>,
requests: Messages,
) -> Result<Option<CloseReason>> {
let mut wrapper = Wrapper::new_with_addr(requests, local_addr);
let mut chain_state = ChainState::new_with_addr(requests, local_addr);

self.pending_requests.process_requests(&wrapper.requests);
let responses = match self.chain.process_request(&mut wrapper).await {
self.pending_requests
.process_requests(&chain_state.requests);
let responses = match self.chain.process_request(&mut chain_state).await {
Ok(x) => x,
Err(err) => {
let err = err.context("Chain failed to send and/or receive messages, the connection will now be closed.");
Expand All @@ -752,7 +753,7 @@ impl<C: CodecBuilder + 'static> Handler<C> {
}

// if requested by a transform, close connection AFTER sending any responses back to the client
if wrapper.close_client_connection {
if chain_state.close_client_connection {
return Ok(Some(CloseReason::TransformRequested));
}

Expand Down
13 changes: 7 additions & 6 deletions shotover/src/transforms/cassandra/peers_rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use crate::frame::MessageType;
use crate::message::{Message, MessageIdMap, Messages};
use crate::transforms::cassandra::peers_rewrite::CassandraOperation::Event;
use crate::transforms::{
DownChainProtocol, Transform, TransformBuilder, TransformConfig, TransformContextBuilder,
UpChainProtocol, Wrapper,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, UpChainProtocol,
};
use crate::{
frame::{
Expand Down Expand Up @@ -79,18 +79,19 @@ impl Transform for CassandraPeersRewrite {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
requests_wrapper: &'shorter mut Wrapper<'longer>,
chain_state: &mut ChainState,
down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
// Find the indices of queries to system.peers & system.peers_v2
// we need to know which columns in which CQL queries in which messages have system peers
for request in &mut requests_wrapper.requests {
for request in &mut chain_state.requests {
let sys_peers = extract_native_port_column(&self.peer_table, request);
self.column_names_to_rewrite.insert(request.id(), sys_peers);
}

let mut responses = requests_wrapper.call_next_transform().await?;
let mut responses = down_chain.call_next_transform(chain_state).await?;

for response in &mut responses {
if let Some(Frame::Cassandra(frame)) = response.frame() {
Expand Down
11 changes: 6 additions & 5 deletions shotover/src/transforms/cassandra/sink_cluster/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use crate::frame::{CassandraFrame, CassandraOperation, CassandraResult, Frame, M
use crate::message::{Message, MessageIdMap, Messages, Metadata};
use crate::tls::{TlsConnector, TlsConnectorConfig};
use crate::transforms::{
DownChainProtocol, Transform, TransformBuilder, TransformConfig, TransformContextBuilder,
TransformContextConfig, UpChainProtocol, Wrapper,
ChainState, DownChainProtocol, DownChainTransforms, Transform, TransformBuilder,
TransformConfig, TransformContextBuilder, TransformContextConfig, UpChainProtocol,
};
use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
Expand Down Expand Up @@ -761,11 +761,12 @@ impl Transform for CassandraSinkCluster {
NAME
}

async fn transform<'shorter, 'longer: 'shorter>(
async fn transform(
&mut self,
requests_wrapper: &'shorter mut Wrapper<'longer>,
chain_state: &mut ChainState,
_down_chain: DownChainTransforms<'_>,
) -> Result<Messages> {
self.send_message(std::mem::take(&mut requests_wrapper.requests))
self.send_message(std::mem::take(&mut chain_state.requests))
.await
}
}
Loading

0 comments on commit f93baab

Please sign in to comment.