Skip to content

Commit

Permalink
[refactor] Fix OperationDispatcher.next() behaviour
Browse files Browse the repository at this point in the history
* Bonus: Addressed review regarding testing and Fn types

Signed-off-by: dd di cesare <[email protected]>
  • Loading branch information
didierofrivia committed Sep 16, 2024
1 parent 502d62b commit 312710e
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 36 deletions.
77 changes: 49 additions & 28 deletions src/operation_dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::configuration::{Extension, ExtensionType, FailureMode};
use crate::envoy::RateLimitDescriptor;
use crate::policy::Policy;
use crate::service::{GetMapValuesBytes, GrpcCall, GrpcMessage, GrpcServiceHandler};
use crate::service::{GetMapValuesBytesFn, GrpcCallFn, GrpcMessage, GrpcServiceHandler};
use protobuf::RepeatedField;
use proxy_wasm::hostcalls;
use proxy_wasm::types::{Bytes, MapType, Status};
Expand Down Expand Up @@ -38,8 +38,8 @@ pub(crate) struct Operation {
result: Result<u32, Status>,
extension: Rc<Extension>,
procedure: Procedure,
grpc_call: GrpcCall,
get_map_values_bytes: GetMapValuesBytes,
grpc_call_fn: GrpcCallFn,
get_map_values_bytes_fn: GetMapValuesBytesFn,
}

#[allow(dead_code)]
Expand All @@ -50,17 +50,17 @@ impl Operation {
result: Err(Status::Empty),
extension,
procedure,
grpc_call,
get_map_values_bytes,
grpc_call_fn,
get_map_values_bytes_fn,
}
}

fn trigger(&mut self) {
if let State::Done = self.state {
} else {
self.result = self.procedure.0.send(
self.get_map_values_bytes,
self.grpc_call,
self.get_map_values_bytes_fn,
self.grpc_call_fn,
self.procedure.1.clone(),
);
self.state.next();
Expand Down Expand Up @@ -147,7 +147,8 @@ impl OperationDispatcher {
let mut operations = self.operations.borrow_mut();
if let Some((i, operation)) = operations.iter_mut().enumerate().next() {
if let State::Done = operation.get_state() {
Some(operations.remove(i))
operations.remove(i);
operations.get(i).cloned() // The next op is now at `i`
} else {
operation.trigger();
Some(operation.clone())
Expand All @@ -158,7 +159,7 @@ impl OperationDispatcher {
}
}

fn grpc_call(
fn grpc_call_fn(
upstream_name: &str,
service_name: &str,
method_name: &str,
Expand All @@ -176,7 +177,7 @@ fn grpc_call(
)
}

fn get_map_values_bytes(map_type: MapType, key: &str) -> Result<Option<Bytes>, Status> {
fn get_map_values_bytes_fn(map_type: MapType, key: &str) -> Result<Option<Bytes>, Status> {
hostcalls::get_map_value_bytes(map_type, key)
}

Expand All @@ -186,7 +187,7 @@ mod tests {
use crate::envoy::RateLimitRequest;
use std::time::Duration;

fn grpc_call(
fn grpc_call_fn_stub(
_upstream_name: &str,
_service_name: &str,
_method_name: &str,
Expand All @@ -197,7 +198,10 @@ mod tests {
Ok(200)
}

fn get_map_values_bytes(_map_type: MapType, _key: &str) -> Result<Option<Bytes>, Status> {
fn get_map_values_bytes_fn_stub(
_map_type: MapType,
_key: &str,
) -> Result<Option<Bytes>, Status> {
Ok(Some(Vec::new()))
}

Expand All @@ -218,14 +222,14 @@ mod tests {
fn build_operation() -> Operation {
Operation {
state: State::Pending,
result: Ok(200),
result: Ok(1),
extension: Rc::new(Extension::default()),
procedure: (
Rc::new(build_grpc_service_handler()),
GrpcMessage::RateLimit(build_message()),
),
grpc_call,
get_map_values_bytes,
grpc_call_fn: grpc_call_fn_stub,
get_map_values_bytes_fn: get_map_values_bytes_fn_stub,
}
}

Expand All @@ -236,7 +240,7 @@ mod tests {
assert_eq!(operation.get_state(), State::Pending);
assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit);
assert_eq!(operation.get_failure_mode(), FailureMode::Deny);
assert_eq!(operation.get_result(), Ok(200));
assert_eq!(operation.get_result(), Ok(1));
}

#[test]
Expand Down Expand Up @@ -272,20 +276,37 @@ mod tests {

#[test]
fn operation_dispatcher_next() {
let operation = build_operation();
let operation_dispatcher = OperationDispatcher::default();
operation_dispatcher.push_operations(vec![operation]);
operation_dispatcher.push_operations(vec![build_operation(), build_operation()]);

if let Some(operation) = operation_dispatcher.next() {
assert_eq!(operation.get_result(), Ok(200));
assert_eq!(operation.get_state(), State::Waiting);
}
assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(1));
assert_eq!(
operation_dispatcher.get_current_operation_state(),
Some(State::Pending)
);

if let Some(operation) = operation_dispatcher.next() {
assert_eq!(operation.get_result(), Ok(200));
assert_eq!(operation.get_state(), State::Done);
}
operation_dispatcher.next();
assert_eq!(operation_dispatcher.get_current_operation_state(), None);
let mut op = operation_dispatcher.next();
assert_eq!(op.clone().unwrap().get_result(), Ok(200));
assert_eq!(op.unwrap().get_state(), State::Waiting);

op = operation_dispatcher.next();
assert_eq!(op.clone().unwrap().get_result(), Ok(200));
assert_eq!(op.unwrap().get_state(), State::Done);

op = operation_dispatcher.next();
assert_eq!(op.clone().unwrap().get_result(), Ok(1));
assert_eq!(op.unwrap().get_state(), State::Pending);

op = operation_dispatcher.next();
assert_eq!(op.clone().unwrap().get_result(), Ok(200));
assert_eq!(op.unwrap().get_state(), State::Waiting);

op = operation_dispatcher.next();
assert_eq!(op.clone().unwrap().get_result(), Ok(200));
assert_eq!(op.unwrap().get_state(), State::Done);

op = operation_dispatcher.next();
assert!(op.is_none());
assert!(operation_dispatcher.get_current_operation_state().is_none());
}
}
16 changes: 8 additions & 8 deletions src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ impl GrpcService {
}
}

pub type GrpcCall = fn(
pub type GrpcCallFn = fn(
upstream_name: &str,
service_name: &str,
method_name: &str,
Expand All @@ -190,7 +190,7 @@ pub type GrpcCall = fn(
timeout: Duration,
) -> Result<u32, Status>;

pub type GetMapValuesBytes = fn(map_type: MapType, key: &str) -> Result<Option<Bytes>, Status>;
pub type GetMapValuesBytesFn = fn(map_type: MapType, key: &str) -> Result<Option<Bytes>, Status>;

pub struct GrpcServiceHandler {
service: Rc<GrpcService>,
Expand All @@ -207,19 +207,19 @@ impl GrpcServiceHandler {

pub fn send(
&self,
get_map_values_bytes: GetMapValuesBytes,
grpc_call: GrpcCall,
get_map_values_bytes_fn: GetMapValuesBytesFn,
grpc_call_fn: GrpcCallFn,
message: GrpcMessage,
) -> Result<u32, Status> {
let msg = Message::write_to_bytes(&message).unwrap();
let metadata = self
.header_resolver
.get(get_map_values_bytes)
.get(get_map_values_bytes_fn)
.iter()
.map(|(header, value)| (*header, value.as_slice()))
.collect();

grpc_call(
grpc_call_fn(
self.service.endpoint(),
self.service.name(),
self.service.method(),
Expand Down Expand Up @@ -255,12 +255,12 @@ impl HeaderResolver {
}
}

pub fn get(&self, get_map_values_bytes: GetMapValuesBytes) -> &Vec<(&'static str, Bytes)> {
pub fn get(&self, get_map_values_bytes_fn: GetMapValuesBytesFn) -> &Vec<(&'static str, Bytes)> {
self.headers.get_or_init(|| {
let mut headers = Vec::new();
for header in TracingHeader::all() {
if let Ok(Some(value)) =
get_map_values_bytes(MapType::HttpRequestHeaders, (*header).as_str())
get_map_values_bytes_fn(MapType::HttpRequestHeaders, (*header).as_str())
{
headers.push(((*header).as_str(), value));
}
Expand Down

0 comments on commit 312710e

Please sign in to comment.