From cfc7932b18104d829caebed5db48f32dce301e47 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 27 Aug 2024 16:40:04 +0200 Subject: [PATCH 01/16] [feat] Action dispatcher state machine, naive impl Signed-off-by: dd di cesare --- src/action_dispatcher.rs | 201 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 2 files changed, 202 insertions(+) create mode 100644 src/action_dispatcher.rs diff --git a/src/action_dispatcher.rs b/src/action_dispatcher.rs new file mode 100644 index 00000000..d22550d0 --- /dev/null +++ b/src/action_dispatcher.rs @@ -0,0 +1,201 @@ +use std::cell::RefCell; + +#[derive(PartialEq, Debug, Clone)] +pub(crate) enum State { + Pending, + Waiting, + Done, +} + +impl State { + fn next(&mut self) { + match self { + State::Pending => *self = State::Waiting, + State::Waiting => *self = State::Done, + _ => {} + } + } +} +#[derive(PartialEq, Clone)] +pub(crate) enum Action { + Auth { state: State }, + RateLimit { state: State }, +} + +impl Action { + pub fn trigger(&mut self) { + match self { + Action::Auth { .. } => self.auth(), + Action::RateLimit { .. } => self.rate_limit(), + } + } + + fn get_state(&self) -> &State { + match self { + Action::Auth { state } => state, + Action::RateLimit { state } => state, + } + } + + fn rate_limit(&mut self) { + // Specifics for RL, returning State + if let Action::RateLimit { state } = self { + match state { + State::Pending => { + println!("Trigger the request and return State::Waiting"); + state.next(); + } + State::Waiting => { + println!( + "When got on_grpc_response, process RL response and return State::Done" + ); + state.next(); + } + State::Done => { + println!("Done for RL... calling next action (?)"); + } + } + } + } + + fn auth(&mut self) { + // Specifics for Auth, returning State + if let Action::Auth { state } = self { + match state { + State::Pending => { + println!("Trigger the request and return State::Waiting"); + state.next(); + } + State::Waiting => { + println!( + "When got on_grpc_response, process Auth response and return State::Done" + ); + state.next(); + } + State::Done => { + println!("Done for Auth... calling next action (?)"); + } + } + } + } +} + +pub struct ActionDispatcher { + actions: RefCell>, +} + +impl ActionDispatcher { + pub fn default() -> ActionDispatcher { + ActionDispatcher { + actions: RefCell::new(vec![]), + } + } + + pub fn new(/*vec of PluginConfig actions*/) -> ActionDispatcher { + ActionDispatcher::default() + } + + pub fn push_actions(&self, actions: Vec) { + self.actions.borrow_mut().extend(actions); + } + + pub fn get_current_action_state(&self) -> Option { + self.actions + .borrow() + .first() + .map(|action| action.get_state().clone()) + } + + pub fn next(&self) -> bool { + let mut actions = self.actions.borrow_mut(); + if let Some((i, action)) = actions.iter_mut().enumerate().next() { + if let State::Done = action.get_state() { + actions.remove(i); + actions.len() > 0 + } else { + action.trigger(); + true + } + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn action_transition() { + let mut action = Action::Auth { + state: State::Pending, + }; + assert_eq!(*action.get_state(), State::Pending); + action.trigger(); + assert_eq!(*action.get_state(), State::Waiting); + action.trigger(); + assert_eq!(*action.get_state(), State::Done); + } + + #[test] + fn action_dispatcher_push_actions() { + let mut action_dispatcher = ActionDispatcher { + actions: RefCell::new(vec![Action::RateLimit { + state: State::Pending, + }]), + }; + + assert_eq!(action_dispatcher.actions.borrow().len(), 1); + + action_dispatcher.push_actions(vec![Action::Auth { + state: State::Pending, + }]); + + assert_eq!(action_dispatcher.actions.borrow().len(), 2); + } + + #[test] + fn action_dispatcher_get_current_action_state() { + let action_dispatcher = ActionDispatcher { + actions: RefCell::new(vec![Action::RateLimit { + state: State::Waiting, + }]), + }; + + assert_eq!( + action_dispatcher.get_current_action_state(), + Some(State::Waiting) + ); + + let action_dispatcher2 = ActionDispatcher::default(); + + assert_eq!(action_dispatcher2.get_current_action_state(), None); + } + + #[test] + fn action_dispatcher_next() { + let mut action_dispatcher = ActionDispatcher { + actions: RefCell::new(vec![Action::RateLimit { + state: State::Pending, + }]), + }; + let mut res = action_dispatcher.next(); + assert_eq!(res, true); + assert_eq!( + action_dispatcher.get_current_action_state(), + Some(State::Waiting) + ); + + res = action_dispatcher.next(); + assert_eq!(res, true); + assert_eq!( + action_dispatcher.get_current_action_state(), + Some(State::Done) + ); + + res = action_dispatcher.next(); + assert_eq!(res, false); + assert_eq!(action_dispatcher.get_current_action_state(), None); + } +} diff --git a/src/lib.rs b/src/lib.rs index fb1c60aa..eb78d0d2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,4 @@ +mod action_dispatcher; mod attribute; mod configuration; mod envoy; From 1f7b8b3f839f6837b1b611b1bc583e147481a148 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Fri, 30 Aug 2024 20:15:48 +0200 Subject: [PATCH 02/16] [feat] A simplistic approach, agnostic to extension type Signed-off-by: dd di cesare --- src/action_dispatcher.rs | 129 ++++++++++++++------------------------- 1 file changed, 47 insertions(+), 82 deletions(-) diff --git a/src/action_dispatcher.rs b/src/action_dispatcher.rs index d22550d0..2e93a6cc 100644 --- a/src/action_dispatcher.rs +++ b/src/action_dispatcher.rs @@ -1,3 +1,4 @@ +use proxy_wasm::types::Status; use std::cell::RefCell; #[derive(PartialEq, Debug, Clone)] @@ -16,67 +17,40 @@ impl State { } } } -#[derive(PartialEq, Clone)] -pub(crate) enum Action { - Auth { state: State }, - RateLimit { state: State }, +#[derive(Clone)] +pub(crate) struct Action { + state: State, + result: Result, + operation: Option Result>, } impl Action { - pub fn trigger(&mut self) { - match self { - Action::Auth { .. } => self.auth(), - Action::RateLimit { .. } => self.rate_limit(), + pub fn default() -> Self { + Self { + state: State::Pending, + result: Err(Status::Empty), + operation: None, } } - fn get_state(&self) -> &State { - match self { - Action::Auth { state } => state, - Action::RateLimit { state } => state, - } + pub fn set_operation(&mut self, operation: fn() -> Result) { + self.operation = Some(operation); } - fn rate_limit(&mut self) { - // Specifics for RL, returning State - if let Action::RateLimit { state } = self { - match state { - State::Pending => { - println!("Trigger the request and return State::Waiting"); - state.next(); - } - State::Waiting => { - println!( - "When got on_grpc_response, process RL response and return State::Done" - ); - state.next(); - } - State::Done => { - println!("Done for RL... calling next action (?)"); - } - } + pub fn trigger(&mut self) { + if let State::Done = self.state { + } else if let Some(operation) = self.operation { + self.result = operation(); + self.state.next(); } } - fn auth(&mut self) { - // Specifics for Auth, returning State - if let Action::Auth { state } = self { - match state { - State::Pending => { - println!("Trigger the request and return State::Waiting"); - state.next(); - } - State::Waiting => { - println!( - "When got on_grpc_response, process Auth response and return State::Done" - ); - state.next(); - } - State::Done => { - println!("Done for Auth... calling next action (?)"); - } - } - } + fn get_state(&self) -> State { + self.state.clone() + } + + fn get_result(&self) -> Result { + self.result } } @@ -91,10 +65,6 @@ impl ActionDispatcher { } } - pub fn new(/*vec of PluginConfig actions*/) -> ActionDispatcher { - ActionDispatcher::default() - } - pub fn push_actions(&self, actions: Vec) { self.actions.borrow_mut().extend(actions); } @@ -106,6 +76,10 @@ impl ActionDispatcher { .map(|action| action.get_state().clone()) } + pub fn get_current_action_result(&self) -> Result { + self.actions.borrow().first().unwrap().get_result() + } + pub fn next(&self) -> bool { let mut actions = self.actions.borrow_mut(); if let Some((i, action)) = actions.iter_mut().enumerate().next() { @@ -128,29 +102,25 @@ mod tests { #[test] fn action_transition() { - let mut action = Action::Auth { - state: State::Pending, - }; - assert_eq!(*action.get_state(), State::Pending); + let mut action = Action::default(); + action.set_operation(|| -> Result { Ok(200) }); + assert_eq!(action.get_state(), State::Pending); action.trigger(); - assert_eq!(*action.get_state(), State::Waiting); + assert_eq!(action.get_state(), State::Waiting); action.trigger(); - assert_eq!(*action.get_state(), State::Done); + assert_eq!(action.result, Ok(200)); + assert_eq!(action.get_state(), State::Done); } #[test] fn action_dispatcher_push_actions() { - let mut action_dispatcher = ActionDispatcher { - actions: RefCell::new(vec![Action::RateLimit { - state: State::Pending, - }]), + let action_dispatcher = ActionDispatcher { + actions: RefCell::new(vec![Action::default()]), }; assert_eq!(action_dispatcher.actions.borrow().len(), 1); - action_dispatcher.push_actions(vec![Action::Auth { - state: State::Pending, - }]); + action_dispatcher.push_actions(vec![Action::default()]); assert_eq!(action_dispatcher.actions.borrow().len(), 2); } @@ -158,44 +128,39 @@ mod tests { #[test] fn action_dispatcher_get_current_action_state() { let action_dispatcher = ActionDispatcher { - actions: RefCell::new(vec![Action::RateLimit { - state: State::Waiting, - }]), + actions: RefCell::new(vec![Action::default()]), }; assert_eq!( action_dispatcher.get_current_action_state(), - Some(State::Waiting) + Some(State::Pending) ); - - let action_dispatcher2 = ActionDispatcher::default(); - - assert_eq!(action_dispatcher2.get_current_action_state(), None); } #[test] fn action_dispatcher_next() { - let mut action_dispatcher = ActionDispatcher { - actions: RefCell::new(vec![Action::RateLimit { - state: State::Pending, - }]), + let mut action = Action::default(); + action.set_operation(|| -> Result { Ok(200) }); + let action_dispatcher = ActionDispatcher { + actions: RefCell::new(vec![action]), }; let mut res = action_dispatcher.next(); - assert_eq!(res, true); + assert!(res); assert_eq!( action_dispatcher.get_current_action_state(), Some(State::Waiting) ); res = action_dispatcher.next(); - assert_eq!(res, true); + assert!(res); assert_eq!( action_dispatcher.get_current_action_state(), Some(State::Done) ); + assert_eq!(action_dispatcher.get_current_action_result(), Ok(200)); res = action_dispatcher.next(); - assert_eq!(res, false); + assert!(!res); assert_eq!(action_dispatcher.get_current_action_state(), None); } } From 4ce41486b8e3fd21410300c6d67c4f44f3c265c1 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Mon, 2 Sep 2024 10:49:52 +0200 Subject: [PATCH 03/16] [tmp] Allowing dead code Signed-off-by: dd di cesare --- src/action_dispatcher.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/action_dispatcher.rs b/src/action_dispatcher.rs index 2e93a6cc..dd43912d 100644 --- a/src/action_dispatcher.rs +++ b/src/action_dispatcher.rs @@ -1,6 +1,7 @@ use proxy_wasm::types::Status; use std::cell::RefCell; +#[allow(dead_code)] #[derive(PartialEq, Debug, Clone)] pub(crate) enum State { Pending, @@ -8,6 +9,7 @@ pub(crate) enum State { Done, } +#[allow(dead_code)] impl State { fn next(&mut self) { match self { @@ -17,6 +19,8 @@ impl State { } } } + +#[allow(dead_code)] #[derive(Clone)] pub(crate) struct Action { state: State, @@ -24,6 +28,7 @@ pub(crate) struct Action { operation: Option Result>, } +#[allow(dead_code)] impl Action { pub fn default() -> Self { Self { @@ -54,10 +59,12 @@ impl Action { } } +#[allow(dead_code)] pub struct ActionDispatcher { actions: RefCell>, } +#[allow(dead_code)] impl ActionDispatcher { pub fn default() -> ActionDispatcher { ActionDispatcher { From 698384d9bd6b76d08481f523474eaa28906a71d0 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Mon, 2 Sep 2024 15:51:55 +0200 Subject: [PATCH 04/16] [refactor] Changing name to `Operation` instead of `Action` * Could get confusing with proxy_wasm `Actions` * Also with plugin configuration `Action` Signed-off-by: dd di cesare --- src/action_dispatcher.rs | 173 ------------------------------------ src/lib.rs | 2 +- src/operation_dispatcher.rs | 173 ++++++++++++++++++++++++++++++++++++ 3 files changed, 174 insertions(+), 174 deletions(-) delete mode 100644 src/action_dispatcher.rs create mode 100644 src/operation_dispatcher.rs diff --git a/src/action_dispatcher.rs b/src/action_dispatcher.rs deleted file mode 100644 index dd43912d..00000000 --- a/src/action_dispatcher.rs +++ /dev/null @@ -1,173 +0,0 @@ -use proxy_wasm::types::Status; -use std::cell::RefCell; - -#[allow(dead_code)] -#[derive(PartialEq, Debug, Clone)] -pub(crate) enum State { - Pending, - Waiting, - Done, -} - -#[allow(dead_code)] -impl State { - fn next(&mut self) { - match self { - State::Pending => *self = State::Waiting, - State::Waiting => *self = State::Done, - _ => {} - } - } -} - -#[allow(dead_code)] -#[derive(Clone)] -pub(crate) struct Action { - state: State, - result: Result, - operation: Option Result>, -} - -#[allow(dead_code)] -impl Action { - pub fn default() -> Self { - Self { - state: State::Pending, - result: Err(Status::Empty), - operation: None, - } - } - - pub fn set_operation(&mut self, operation: fn() -> Result) { - self.operation = Some(operation); - } - - pub fn trigger(&mut self) { - if let State::Done = self.state { - } else if let Some(operation) = self.operation { - self.result = operation(); - self.state.next(); - } - } - - fn get_state(&self) -> State { - self.state.clone() - } - - fn get_result(&self) -> Result { - self.result - } -} - -#[allow(dead_code)] -pub struct ActionDispatcher { - actions: RefCell>, -} - -#[allow(dead_code)] -impl ActionDispatcher { - pub fn default() -> ActionDispatcher { - ActionDispatcher { - actions: RefCell::new(vec![]), - } - } - - pub fn push_actions(&self, actions: Vec) { - self.actions.borrow_mut().extend(actions); - } - - pub fn get_current_action_state(&self) -> Option { - self.actions - .borrow() - .first() - .map(|action| action.get_state().clone()) - } - - pub fn get_current_action_result(&self) -> Result { - self.actions.borrow().first().unwrap().get_result() - } - - pub fn next(&self) -> bool { - let mut actions = self.actions.borrow_mut(); - if let Some((i, action)) = actions.iter_mut().enumerate().next() { - if let State::Done = action.get_state() { - actions.remove(i); - actions.len() > 0 - } else { - action.trigger(); - true - } - } else { - false - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn action_transition() { - let mut action = Action::default(); - action.set_operation(|| -> Result { Ok(200) }); - assert_eq!(action.get_state(), State::Pending); - action.trigger(); - assert_eq!(action.get_state(), State::Waiting); - action.trigger(); - assert_eq!(action.result, Ok(200)); - assert_eq!(action.get_state(), State::Done); - } - - #[test] - fn action_dispatcher_push_actions() { - let action_dispatcher = ActionDispatcher { - actions: RefCell::new(vec![Action::default()]), - }; - - assert_eq!(action_dispatcher.actions.borrow().len(), 1); - - action_dispatcher.push_actions(vec![Action::default()]); - - assert_eq!(action_dispatcher.actions.borrow().len(), 2); - } - - #[test] - fn action_dispatcher_get_current_action_state() { - let action_dispatcher = ActionDispatcher { - actions: RefCell::new(vec![Action::default()]), - }; - - assert_eq!( - action_dispatcher.get_current_action_state(), - Some(State::Pending) - ); - } - - #[test] - fn action_dispatcher_next() { - let mut action = Action::default(); - action.set_operation(|| -> Result { Ok(200) }); - let action_dispatcher = ActionDispatcher { - actions: RefCell::new(vec![action]), - }; - let mut res = action_dispatcher.next(); - assert!(res); - assert_eq!( - action_dispatcher.get_current_action_state(), - Some(State::Waiting) - ); - - res = action_dispatcher.next(); - assert!(res); - assert_eq!( - action_dispatcher.get_current_action_state(), - Some(State::Done) - ); - assert_eq!(action_dispatcher.get_current_action_result(), Ok(200)); - - res = action_dispatcher.next(); - assert!(!res); - assert_eq!(action_dispatcher.get_current_action_state(), None); - } -} diff --git a/src/lib.rs b/src/lib.rs index eb78d0d2..4cd39f4f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -mod action_dispatcher; +mod operation_dispatcher; mod attribute; mod configuration; mod envoy; diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs new file mode 100644 index 00000000..86c6deac --- /dev/null +++ b/src/operation_dispatcher.rs @@ -0,0 +1,173 @@ +use proxy_wasm::types::Status; +use std::cell::RefCell; + +#[allow(dead_code)] +#[derive(PartialEq, Debug, Clone)] +pub(crate) enum State { + Pending, + Waiting, + Done, +} + +#[allow(dead_code)] +impl State { + fn next(&mut self) { + match self { + State::Pending => *self = State::Waiting, + State::Waiting => *self = State::Done, + _ => {} + } + } +} + +#[allow(dead_code)] +#[derive(Clone)] +pub(crate) struct Operation { + state: State, + result: Result, + action: Option Result>, +} + +#[allow(dead_code)] +impl Operation { + pub fn default() -> Self { + Self { + state: State::Pending, + result: Err(Status::Empty), + action: None, + } + } + + pub fn set_action(&mut self, action: fn() -> Result) { + self.action = Some(action); + } + + pub fn trigger(&mut self) { + if let State::Done = self.state { + } else if let Some(action) = self.action { + self.result = action(); + self.state.next(); + } + } + + fn get_state(&self) -> State { + self.state.clone() + } + + fn get_result(&self) -> Result { + self.result + } +} + +#[allow(dead_code)] +pub struct OperationDispatcher { + operations: RefCell>, +} + +#[allow(dead_code)] +impl OperationDispatcher { + pub fn default() -> OperationDispatcher { + OperationDispatcher { + operations: RefCell::new(vec![]), + } + } + + pub fn push_operations(&self, operations: Vec) { + self.operations.borrow_mut().extend(operations); + } + + pub fn get_current_operation_state(&self) -> Option { + self.operations + .borrow() + .first() + .map(|operation| operation.get_state().clone()) + } + + pub fn get_current_operation_result(&self) -> Result { + self.operations.borrow().first().unwrap().get_result() + } + + pub fn next(&self) -> bool { + let mut operations = self.operations.borrow_mut(); + if let Some((i, operation)) = operations.iter_mut().enumerate().next() { + if let State::Done = operation.get_state() { + operations.remove(i); + operations.len() > 0 + } else { + operation.trigger(); + true + } + } else { + false + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn operation_transition() { + let mut operation = Operation::default(); + operation.set_action(|| -> Result { Ok(200) }); + assert_eq!(operation.get_state(), State::Pending); + operation.trigger(); + assert_eq!(operation.get_state(), State::Waiting); + operation.trigger(); + assert_eq!(operation.result, Ok(200)); + assert_eq!(operation.get_state(), State::Done); + } + + #[test] + fn operation_dispatcher_push_actions() { + let operation_dispatcher = OperationDispatcher { + operations: RefCell::new(vec![Operation::default()]), + }; + + assert_eq!(operation_dispatcher.operations.borrow().len(), 1); + + operation_dispatcher.push_operations(vec![Operation::default()]); + + assert_eq!(operation_dispatcher.operations.borrow().len(), 2); + } + + #[test] + fn operation_dispatcher_get_current_action_state() { + let operation_dispatcher = OperationDispatcher { + operations: RefCell::new(vec![Operation::default()]), + }; + + assert_eq!( + operation_dispatcher.get_current_operation_state(), + Some(State::Pending) + ); + } + + #[test] + fn operation_dispatcher_next() { + let mut operation = Operation::default(); + operation.set_action(|| -> Result { Ok(200) }); + let operation_dispatcher = OperationDispatcher { + operations: RefCell::new(vec![operation]), + }; + let mut res = operation_dispatcher.next(); + assert!(res); + assert_eq!( + operation_dispatcher.get_current_operation_state(), + Some(State::Waiting) + ); + + res = operation_dispatcher.next(); + assert!(res); + assert_eq!( + operation_dispatcher.get_current_operation_state(), + Some(State::Done) + ); + assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(200)); + + res = operation_dispatcher.next(); + assert!(!res); + assert_eq!(operation_dispatcher.get_current_operation_state(), None); + } +} From 94bf5d12f225df8a09922a3901d2ed7b188ffc5b Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 3 Sep 2024 14:53:38 +0200 Subject: [PATCH 05/16] [refactor] Configuration, adding Actions Signed-off-by: dd di cesare --- src/configuration.rs | 102 ++++++++++++++++++++++++++++++++++++++---- src/policy.rs | 12 ++++- src/policy_index.rs | 8 +++- tests/rate_limited.rs | 52 +++++++++++++++++++-- 4 files changed, 159 insertions(+), 15 deletions(-) diff --git a/src/configuration.rs b/src/configuration.rs index 554c086d..174b7e83 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -513,10 +513,11 @@ pub enum FailureMode { Allow, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Default)] #[serde(rename_all = "lowercase")] pub enum ExtensionType { Auth, + #[default] RateLimit, } @@ -537,6 +538,14 @@ pub struct Extension { pub failure_mode: FailureMode, } +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Action { + pub extension: String, + #[allow(dead_code)] + pub data: DataType, +} + #[cfg(test)] mod test { use super::*; @@ -587,7 +596,18 @@ mod test { "selector": "auth.metadata.username" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -682,7 +702,18 @@ mod test { "default": "my_selector_default_value" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(config); @@ -759,7 +790,18 @@ mod test { }] }], "data": [ { "selector": { "selector": "my.selector.path" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(config); @@ -825,7 +867,18 @@ mod test { "selector": "auth.metadata.username" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(config); @@ -872,7 +925,18 @@ mod test { "selector": "auth.metadata.username" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(bad_config); @@ -902,7 +966,18 @@ mod test { "value": "1" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(bad_config); @@ -934,7 +1009,18 @@ mod test { }] }], "data": [ { "selector": { "selector": "my.selector.path" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(bad_config); diff --git a/src/policy.rs b/src/policy.rs index 7505b8fd..c7ee430a 100644 --- a/src/policy.rs +++ b/src/policy.rs @@ -1,5 +1,5 @@ use crate::attribute::Attribute; -use crate::configuration::{DataItem, DataType, PatternExpression}; +use crate::configuration::{Action, DataItem, DataType, PatternExpression}; use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry}; use crate::filter::http_context::Filter; use log::debug; @@ -28,16 +28,24 @@ pub struct Policy { pub domain: String, pub hostnames: Vec, pub rules: Vec, + pub actions: Vec, } impl Policy { #[cfg(test)] - pub fn new(name: String, domain: String, hostnames: Vec, rules: Vec) -> Self { + pub fn new( + name: String, + domain: String, + hostnames: Vec, + rules: Vec, + actions: Vec, + ) -> Self { Policy { name, domain, hostnames, rules, + actions, } } diff --git a/src/policy_index.rs b/src/policy_index.rs index 58b31d94..3620179e 100644 --- a/src/policy_index.rs +++ b/src/policy_index.rs @@ -41,7 +41,13 @@ mod tests { use crate::policy_index::PolicyIndex; fn build_ratelimit_policy(name: &str) -> Policy { - Policy::new(name.to_owned(), "".to_owned(), Vec::new(), Vec::new()) + Policy::new( + name.to_owned(), + "".to_owned(), + Vec::new(), + Vec::new(), + Vec::new(), + ) } #[test] diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index d1421350..6bc3168f 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -132,7 +132,18 @@ fn it_limits() { } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -275,7 +286,18 @@ fn it_passes_additional_headers() { } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -412,7 +434,18 @@ fn it_rate_limits_with_empty_conditions() { } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -528,7 +561,18 @@ fn it_does_not_rate_limits_when_selector_does_not_exist_and_misses_default_value } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; From 63a211a33be35d85c5a2b0b7a3c0d0361d51a196 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 3 Sep 2024 14:55:09 +0200 Subject: [PATCH 06/16] [wip, refactor] GrpcServiceHandler builds message * GrpcMessage type created Signed-off-by: dd di cesare --- src/service.rs | 61 +++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 55 insertions(+), 6 deletions(-) diff --git a/src/service.rs b/src/service.rs index e6b13d61..4d553b77 100644 --- a/src/service.rs +++ b/src/service.rs @@ -2,8 +2,9 @@ pub(crate) mod auth; pub(crate) mod rate_limit; use crate::configuration::{ExtensionType, FailureMode}; +use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; -use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; use protobuf::Message; use proxy_wasm::hostcalls; @@ -13,9 +14,26 @@ use std::cell::OnceCell; use std::rc::Rc; use std::time::Duration; +#[derive(Clone)] +pub enum GrpcMessage { + //Auth(CheckRequest), + RateLimit(RateLimitRequest), +} + +impl GrpcMessage { + pub fn get_message(&self) -> &RateLimitRequest { + //TODO(didierofrivia): Should return Message + match self { + GrpcMessage::RateLimit(message) => message, + } + } +} + #[derive(Default)] pub struct GrpcService { endpoint: String, + #[allow(dead_code)] + extension_type: ExtensionType, name: &'static str, method: &'static str, failure_mode: FailureMode, @@ -26,18 +44,21 @@ impl GrpcService { match extension_type { ExtensionType::Auth => Self { endpoint, + extension_type, name: AUTH_SERVICE_NAME, method: AUTH_METHOD_NAME, failure_mode, }, ExtensionType::RateLimit => Self { endpoint, + extension_type, name: RATELIMIT_SERVICE_NAME, method: RATELIMIT_METHOD_NAME, failure_mode, }, } } + fn endpoint(&self) -> &str { &self.endpoint } @@ -52,22 +73,36 @@ impl GrpcService { } } -#[derive(Default)] +type GrpcCall = fn( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, +) -> Result; + pub struct GrpcServiceHandler { service: Rc, header_resolver: Rc, + grpc_call: GrpcCall, } impl GrpcServiceHandler { - pub fn new(service: Rc, header_resolver: Rc) -> Self { + pub fn new( + service: Rc, + header_resolver: Rc, + grpc_call: Option, + ) -> Self { Self { service, header_resolver, + grpc_call: grpc_call.unwrap_or(dispatch_grpc_call), } } - pub fn send(&self, message: M) -> Result { - let msg = Message::write_to_bytes(&message).unwrap(); + pub fn send(&self, message: GrpcMessage) -> Result { + let msg = Message::write_to_bytes(message.get_message()).unwrap(); let metadata = self .header_resolver .get() @@ -75,7 +110,7 @@ impl GrpcServiceHandler { .map(|(header, value)| (*header, value.as_slice())) .collect(); - dispatch_grpc_call( + (self.grpc_call)( self.service.endpoint(), self.service.name(), self.service.method(), @@ -84,6 +119,20 @@ impl GrpcServiceHandler { Duration::from_secs(5), ) } + + // Using domain as ce_host for the time being, we might pass a DataType in the future. + //TODO(didierofrivia): Make it work with Message. for both Auth and RL + pub fn build_message( + &self, + domain: String, + descriptors: protobuf::RepeatedField, + ) -> GrpcMessage { + /*match self.service.extension_type { + //ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())), + //ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)), + }*/ + GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)) + } } pub struct HeaderResolver { From 172bdbd7421a450c8b058ff622c29fa8098a412c Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 3 Sep 2024 14:56:09 +0200 Subject: [PATCH 07/16] [clean] Removing obsolete code Signed-off-by: dd di cesare --- src/service/rate_limit.rs | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 6dfc3c89..b6b0357c 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -27,8 +27,6 @@ mod tests { use crate::service::rate_limit::RateLimitService; //use crate::service::Service; use protobuf::{CachedSize, RepeatedField, UnknownFields}; - //use proxy_wasm::types::Status; - //use crate::filter::http_context::{Filter}; fn build_message() -> RateLimitRequest { let domain = "rlp1"; @@ -52,20 +50,4 @@ mod tests { assert_eq!(msg.unknown_fields, UnknownFields::default()); assert_eq!(msg.cached_size, CachedSize::default()); } - /*#[test] - fn sends_message() { - let msg = build_message(); - let metadata = vec![("header-1", "value-1".as_bytes())]; - let rls = RateLimitService::new("limitador-cluster", metadata); - - // TODO(didierofrivia): When we have a grpc response type, assert the async response - } - - fn grpc_call( - _upstream_name: &str, - _initial_metadata: Vec<(&str, &[u8])>, - _message: RateLimitRequest, - ) -> Result { - Ok(1) - } */ } From eb8b9246c380f1bb8327f7edc7a98bf8b280125c Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 3 Sep 2024 14:56:49 +0200 Subject: [PATCH 08/16] [refactor] OperationDispatcher triggering procedures Signed-off-by: dd di cesare --- src/lib.rs | 2 +- src/operation_dispatcher.rs | 126 +++++++++++++++++++++++++++--------- 2 files changed, 97 insertions(+), 31 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 4cd39f4f..a7a500a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,9 @@ -mod operation_dispatcher; mod attribute; mod configuration; mod envoy; mod filter; mod glob; +mod operation_dispatcher; mod policy; mod policy_index; mod service; diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 86c6deac..91efc538 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,5 +1,11 @@ +use crate::envoy::RateLimitDescriptor; +use crate::policy::Policy; +use crate::service::{GrpcMessage, GrpcServiceHandler}; +use protobuf::RepeatedField; use proxy_wasm::types::Status; use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; #[allow(dead_code)] #[derive(PartialEq, Debug, Clone)] @@ -20,32 +26,33 @@ impl State { } } +type Procedure = (Rc, GrpcMessage); + #[allow(dead_code)] -#[derive(Clone)] pub(crate) struct Operation { state: State, result: Result, - action: Option Result>, + procedure: Procedure, } #[allow(dead_code)] impl Operation { - pub fn default() -> Self { + pub fn new(procedure: Procedure) -> Self { Self { state: State::Pending, result: Err(Status::Empty), - action: None, + procedure, } } - pub fn set_action(&mut self, action: fn() -> Result) { - self.action = Some(action); + pub fn set_action(&mut self, procedure: Procedure) { + self.procedure = procedure; } pub fn trigger(&mut self) { if let State::Done = self.state { - } else if let Some(action) = self.action { - self.result = action(); + } else { + self.result = self.procedure.0.send(self.procedure.1.clone()); self.state.next(); } } @@ -62,15 +69,39 @@ impl Operation { #[allow(dead_code)] pub struct OperationDispatcher { operations: RefCell>, + service_handlers: HashMap>, } #[allow(dead_code)] impl OperationDispatcher { - pub fn default() -> OperationDispatcher { + pub fn default() -> Self { OperationDispatcher { operations: RefCell::new(vec![]), + service_handlers: HashMap::default(), } } + pub fn new(service_handlers: HashMap>) -> Self { + Self { + service_handlers, + operations: RefCell::new(vec![]), + } + } + + pub fn build_operations( + &self, + policy: &Policy, + descriptors: RepeatedField, + ) { + let mut operations: Vec = vec![]; + policy.actions.iter().for_each(|action| { + // TODO(didierofrivia): Error handling + if let Some(service) = self.service_handlers.get(&action.extension) { + let message = service.build_message(policy.domain.clone(), descriptors.clone()); + operations.push(Operation::new((service.clone(), message))) + } + }); + self.push_operations(operations); + } pub fn push_operations(&self, operations: Vec) { self.operations.borrow_mut().extend(operations); @@ -87,18 +118,19 @@ impl OperationDispatcher { self.operations.borrow().first().unwrap().get_result() } - pub fn next(&self) -> bool { + pub fn next(&self) -> Option<(State, Result)> { let mut operations = self.operations.borrow_mut(); if let Some((i, operation)) = operations.iter_mut().enumerate().next() { if let State::Done = operation.get_state() { + let res = operation.get_result(); operations.remove(i); - operations.len() > 0 + Some((State::Done, res)) } else { operation.trigger(); - true + Some((operation.state.clone(), operation.result)) } } else { - false + None } } } @@ -106,11 +138,44 @@ impl OperationDispatcher { #[cfg(test)] mod tests { use super::*; + use crate::envoy::RateLimitRequest; + use std::time::Duration; + + fn grpc_call( + _upstream_name: &str, + _service_name: &str, + _method_name: &str, + _initial_metadata: Vec<(&str, &[u8])>, + _message: Option<&[u8]>, + _timeout: Duration, + ) -> Result { + Ok(1) + } + + fn build_grpc_service_handler() -> GrpcServiceHandler { + GrpcServiceHandler::new( + Rc::new(Default::default()), + Rc::new(Default::default()), + Some(grpc_call), + ) + } + + fn build_message() -> RateLimitRequest { + RateLimitRequest { + domain: "example.org".to_string(), + descriptors: RepeatedField::new(), + hits_addend: 1, + unknown_fields: Default::default(), + cached_size: Default::default(), + } + } #[test] fn operation_transition() { - let mut operation = Operation::default(); - operation.set_action(|| -> Result { Ok(200) }); + let mut operation = Operation::new(( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + )); assert_eq!(operation.get_state(), State::Pending); operation.trigger(); assert_eq!(operation.get_state(), State::Waiting); @@ -121,22 +186,21 @@ mod tests { #[test] fn operation_dispatcher_push_actions() { - let operation_dispatcher = OperationDispatcher { - operations: RefCell::new(vec![Operation::default()]), - }; + let operation_dispatcher = OperationDispatcher::default(); assert_eq!(operation_dispatcher.operations.borrow().len(), 1); - operation_dispatcher.push_operations(vec![Operation::default()]); + operation_dispatcher.push_operations(vec![Operation::new(( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ))]); assert_eq!(operation_dispatcher.operations.borrow().len(), 2); } #[test] fn operation_dispatcher_get_current_action_state() { - let operation_dispatcher = OperationDispatcher { - operations: RefCell::new(vec![Operation::default()]), - }; + let operation_dispatcher = OperationDispatcher::default(); assert_eq!( operation_dispatcher.get_current_operation_state(), @@ -146,20 +210,22 @@ mod tests { #[test] fn operation_dispatcher_next() { - let mut operation = Operation::default(); - operation.set_action(|| -> Result { Ok(200) }); - let operation_dispatcher = OperationDispatcher { - operations: RefCell::new(vec![operation]), - }; + let operation = Operation::new(( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + )); + let operation_dispatcher = OperationDispatcher::default(); + operation_dispatcher.push_operations(vec![operation]); + let mut res = operation_dispatcher.next(); - assert!(res); + assert_eq!(res, Some((State::Waiting, Ok(200)))); assert_eq!( operation_dispatcher.get_current_operation_state(), Some(State::Waiting) ); res = operation_dispatcher.next(); - assert!(res); + assert_eq!(res, Some((State::Done, Ok(200)))); assert_eq!( operation_dispatcher.get_current_operation_state(), Some(State::Done) @@ -167,7 +233,7 @@ mod tests { assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(200)); res = operation_dispatcher.next(); - assert!(!res); + assert_eq!(res, None); assert_eq!(operation_dispatcher.get_current_operation_state(), None); } } From 6dc8aeea97bba5bcaaa10343b0d5602196958102 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Tue, 3 Sep 2024 14:57:26 +0200 Subject: [PATCH 09/16] [refactor] Wiring up altogether Signed-off-by: dd di cesare --- src/filter/http_context.rs | 51 ++++++++++++++++++-------------------- src/filter/root_context.rs | 19 ++++++++++++-- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index fa80e75c..1d2a140b 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,8 +1,7 @@ use crate::configuration::{FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; +use crate::operation_dispatcher::OperationDispatcher; use crate::policy::Policy; -use crate::service::rate_limit::RateLimitService; -use crate::service::{GrpcServiceHandler, HeaderResolver}; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; @@ -13,7 +12,7 @@ pub struct Filter { pub context_id: u32, pub config: Rc, pub response_headers_to_add: Vec<(String, String)>, - pub header_resolver: Rc, + pub operation_dispatcher: OperationDispatcher, } impl Filter { @@ -40,33 +39,31 @@ impl Filter { return Action::Continue; } - // todo(adam-cattermole): For now we just get the first GrpcService but we expect to have - // an action which links to the service that should be used - let rls = self - .config - .services - .values() - .next() - .expect("expect a value"); - - let handler = GrpcServiceHandler::new(Rc::clone(rls), Rc::clone(&self.header_resolver)); - let message = RateLimitService::message(rlp.domain.clone(), descriptors); + // Build Actions from config actions + self.operation_dispatcher.build_operations(rlp, descriptors); + // populate actions in the dispatcher + // call the next on the match - match handler.send(message) { - Ok(call_id) => { - debug!( - "#{} initiated gRPC call (id# {}) to Limitador", - self.context_id, call_id - ); - Action::Pause - } - Err(e) => { - warn!("gRPC call to Limitador failed! {e:?}"); - if let FailureMode::Deny = rls.failure_mode() { - self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) + if let Some(result) = self.operation_dispatcher.next() { + match result { + (_state, Ok(call_id)) => { + debug!( + "#{} initiated gRPC call (id# {}) to Limitador", + self.context_id, call_id + ); + Action::Pause + } + (_state, Err(e)) => { + warn!("gRPC call to Limitador failed! {e:?}"); + // TODO(didierofrivia): Get the failure_mode + /*if let FailureMode::Deny = rls.failure_mode() { + self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) + } */ + Action::Continue } - Action::Continue } + } else { + Action::Continue } } diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index 90774e1c..b4fafb77 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -1,10 +1,12 @@ use crate::configuration::{FilterConfig, PluginConfiguration}; use crate::filter::http_context::Filter; -use crate::service::HeaderResolver; +use crate::operation_dispatcher::OperationDispatcher; +use crate::service::{GrpcServiceHandler, HeaderResolver}; use const_format::formatcp; use log::{debug, error, info}; use proxy_wasm::traits::{Context, HttpContext, RootContext}; use proxy_wasm::types::ContextType; +use std::collections::HashMap; use std::rc::Rc; const WASM_SHIM_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -37,11 +39,24 @@ impl RootContext for FilterRoot { fn create_http_context(&self, context_id: u32) -> Option> { debug!("#{} create_http_context", context_id); + let mut service_handlers: HashMap> = HashMap::new(); + self.config + .services + .iter() + .for_each(|(extension, service)| { + service_handlers + .entry(extension.clone()) + .or_insert(Rc::from(GrpcServiceHandler::new( + Rc::clone(service), + Rc::new(HeaderResolver::new()), + None, + ))); + }); Some(Box::new(Filter { context_id, config: Rc::clone(&self.config), response_headers_to_add: Vec::default(), - header_resolver: Rc::new(HeaderResolver::new()), + operation_dispatcher: OperationDispatcher::new(service_handlers), })) } From fb9ff4521aab99f7585d5cc807dfdef072a08e82 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Wed, 4 Sep 2024 15:11:19 +0200 Subject: [PATCH 10/16] [refactor] Implementing own Message for GrpcMessage Signed-off-by: dd di cesare --- src/operation_dispatcher.rs | 6 +- src/service.rs | 150 ++++++++++++++++++++++++++++++------ 2 files changed, 133 insertions(+), 23 deletions(-) diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 91efc538..0b104402 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -96,7 +96,11 @@ impl OperationDispatcher { policy.actions.iter().for_each(|action| { // TODO(didierofrivia): Error handling if let Some(service) = self.service_handlers.get(&action.extension) { - let message = service.build_message(policy.domain.clone(), descriptors.clone()); + let message = GrpcMessage::new( + service.get_extension_type(), + policy.domain.clone(), + descriptors.clone(), + ); operations.push(Operation::new((service.clone(), message))) } }); diff --git a/src/service.rs b/src/service.rs index 4d553b77..a28ecfde 100644 --- a/src/service.rs +++ b/src/service.rs @@ -2,29 +2,145 @@ pub(crate) mod auth; pub(crate) mod rate_limit; use crate::configuration::{ExtensionType, FailureMode}; -use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; -use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest}; +use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; -use protobuf::Message; +use protobuf::reflect::MessageDescriptor; +use protobuf::{ + Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, +}; use proxy_wasm::hostcalls; use proxy_wasm::hostcalls::dispatch_grpc_call; use proxy_wasm::types::{Bytes, MapType, Status}; +use std::any::Any; use std::cell::OnceCell; +use std::fmt::Debug; use std::rc::Rc; use std::time::Duration; -#[derive(Clone)] +#[derive(Clone, Debug)] pub enum GrpcMessage { - //Auth(CheckRequest), + Auth(CheckRequest), RateLimit(RateLimitRequest), } -impl GrpcMessage { - pub fn get_message(&self) -> &RateLimitRequest { - //TODO(didierofrivia): Should return Message +impl Default for GrpcMessage { + fn default() -> Self { + GrpcMessage::RateLimit(RateLimitRequest::new()) + } +} + +impl Clear for GrpcMessage { + fn clear(&mut self) { match self { - GrpcMessage::RateLimit(message) => message, + GrpcMessage::Auth(msg) => msg.clear(), + GrpcMessage::RateLimit(msg) => msg.clear(), + } + } +} + +impl Message for GrpcMessage { + fn descriptor(&self) -> &'static MessageDescriptor { + match self { + GrpcMessage::Auth(msg) => msg.descriptor(), + GrpcMessage::RateLimit(msg) => msg.descriptor(), + } + } + + fn is_initialized(&self) -> bool { + match self { + GrpcMessage::Auth(msg) => msg.is_initialized(), + GrpcMessage::RateLimit(msg) => msg.is_initialized(), + } + } + + fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { + match self { + GrpcMessage::Auth(msg) => msg.merge_from(is), + GrpcMessage::RateLimit(msg) => msg.merge_from(is), + } + } + + fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { + match self { + GrpcMessage::Auth(msg) => msg.write_to_with_cached_sizes(os), + GrpcMessage::RateLimit(msg) => msg.write_to_with_cached_sizes(os), + } + } + + fn write_to_bytes(&self) -> ProtobufResult> { + match self { + GrpcMessage::Auth(msg) => msg.write_to_bytes(), + GrpcMessage::RateLimit(msg) => msg.write_to_bytes(), + } + } + + fn compute_size(&self) -> u32 { + match self { + GrpcMessage::Auth(msg) => msg.compute_size(), + GrpcMessage::RateLimit(msg) => msg.compute_size(), + } + } + + fn get_cached_size(&self) -> u32 { + match self { + GrpcMessage::Auth(msg) => msg.get_cached_size(), + GrpcMessage::RateLimit(msg) => msg.get_cached_size(), + } + } + + fn get_unknown_fields(&self) -> &UnknownFields { + match self { + GrpcMessage::Auth(msg) => msg.get_unknown_fields(), + GrpcMessage::RateLimit(msg) => msg.get_unknown_fields(), + } + } + + fn mut_unknown_fields(&mut self) -> &mut UnknownFields { + match self { + GrpcMessage::Auth(msg) => msg.mut_unknown_fields(), + GrpcMessage::RateLimit(msg) => msg.mut_unknown_fields(), + } + } + + fn as_any(&self) -> &dyn Any { + match self { + GrpcMessage::Auth(msg) => msg.as_any(), + GrpcMessage::RateLimit(msg) => msg.as_any(), + } + } + + fn new() -> Self + where + Self: Sized, + { + // Returning default value + GrpcMessage::default() + } + + fn default_instance() -> &'static Self + where + Self: Sized, + { + #[allow(non_upper_case_globals)] + static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; + instance.get(|| GrpcMessage::RateLimit(RateLimitRequest::new())) + } +} + +impl GrpcMessage { + // Using domain as ce_host for the time being, we might pass a DataType in the future. + pub fn new( + extension_type: ExtensionType, + domain: String, + descriptors: protobuf::RepeatedField, + ) -> Self { + match extension_type { + ExtensionType::RateLimit => { + GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)) + } + ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())), } } } @@ -102,7 +218,7 @@ impl GrpcServiceHandler { } pub fn send(&self, message: GrpcMessage) -> Result { - let msg = Message::write_to_bytes(message.get_message()).unwrap(); + let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver .get() @@ -120,18 +236,8 @@ impl GrpcServiceHandler { ) } - // Using domain as ce_host for the time being, we might pass a DataType in the future. - //TODO(didierofrivia): Make it work with Message. for both Auth and RL - pub fn build_message( - &self, - domain: String, - descriptors: protobuf::RepeatedField, - ) -> GrpcMessage { - /*match self.service.extension_type { - //ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())), - //ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)), - }*/ - GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)) + pub fn get_extension_type(&self) -> ExtensionType { + self.service.extension_type.clone() } } From 64510b03c41f5d584cb5e81ef04ce64fc659bf23 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Wed, 4 Sep 2024 16:37:15 +0200 Subject: [PATCH 11/16] [refactor] Inlucing Extension within Service and Operation as Rc Signed-off-by: dd di cesare --- src/configuration.rs | 17 ++------- src/operation_dispatcher.rs | 76 ++++++++++++++++++++++++++++--------- src/service.rs | 28 +++++++------- 3 files changed, 76 insertions(+), 45 deletions(-) diff --git a/src/configuration.rs b/src/configuration.rs index 174b7e83..e026254d 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -486,16 +486,7 @@ impl TryFrom for FilterConfig { let services = config .extensions .into_iter() - .map(|(name, ext)| { - ( - name, - Rc::new(GrpcService::new( - ext.extension_type, - ext.endpoint, - ext.failure_mode, - )), - ) - }) + .map(|(name, ext)| (name, Rc::new(GrpcService::new(Rc::new(ext))))) .collect(); Ok(Self { @@ -505,7 +496,7 @@ impl TryFrom for FilterConfig { } } -#[derive(Deserialize, Debug, Clone, Default)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum FailureMode { #[default] @@ -513,7 +504,7 @@ pub enum FailureMode { Allow, } -#[derive(Deserialize, Debug, Clone, Default)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ExtensionType { Auth, @@ -528,7 +519,7 @@ pub struct PluginConfiguration { pub policies: Vec, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct Extension { #[serde(rename = "type")] diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 0b104402..1584dc05 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,3 +1,4 @@ +use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; use crate::service::{GrpcMessage, GrpcServiceHandler}; @@ -32,15 +33,17 @@ type Procedure = (Rc, GrpcMessage); pub(crate) struct Operation { state: State, result: Result, + extension: Rc, procedure: Procedure, } #[allow(dead_code)] impl Operation { - pub fn new(procedure: Procedure) -> Self { + pub fn new(extension: Rc, procedure: Procedure) -> Self { Self { state: State::Pending, result: Err(Status::Empty), + extension, procedure, } } @@ -57,13 +60,21 @@ impl Operation { } } - fn get_state(&self) -> State { + pub fn get_state(&self) -> State { self.state.clone() } - fn get_result(&self) -> Result { + pub fn get_result(&self) -> Result { self.result } + + pub fn get_extension_type(&self) -> ExtensionType { + self.extension.extension_type.clone() + } + + pub fn get_failure_mode(&self) -> FailureMode { + self.extension.failure_mode.clone() + } } #[allow(dead_code)] @@ -101,7 +112,10 @@ impl OperationDispatcher { policy.domain.clone(), descriptors.clone(), ); - operations.push(Operation::new((service.clone(), message))) + operations.push(Operation::new( + service.get_extension(), + (Rc::clone(service), message), + )) } }); self.push_operations(operations); @@ -174,12 +188,33 @@ mod tests { } } + #[test] + fn operation_getters() { + let extension = Rc::new(Extension::default()); + let operation = Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + ); + + assert_eq!(operation.get_state(), State::Pending); + assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); + assert_eq!(operation.get_failure_mode(), FailureMode::Deny); + assert_eq!(operation.get_result(), Result::Ok(1)); + } + #[test] fn operation_transition() { - let mut operation = Operation::new(( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - )); + let extension = Rc::new(Extension::default()); + let mut operation = Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + ); assert_eq!(operation.get_state(), State::Pending); operation.trigger(); assert_eq!(operation.get_state(), State::Waiting); @@ -193,11 +228,14 @@ mod tests { let operation_dispatcher = OperationDispatcher::default(); assert_eq!(operation_dispatcher.operations.borrow().len(), 1); - - operation_dispatcher.push_operations(vec![Operation::new(( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ))]); + let extension = Rc::new(Extension::default()); + operation_dispatcher.push_operations(vec![Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + )]); assert_eq!(operation_dispatcher.operations.borrow().len(), 2); } @@ -214,10 +252,14 @@ mod tests { #[test] fn operation_dispatcher_next() { - let operation = Operation::new(( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - )); + let extension = Rc::new(Extension::default()); + let operation = Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + ); let operation_dispatcher = OperationDispatcher::default(); operation_dispatcher.push_operations(vec![operation]); diff --git a/src/service.rs b/src/service.rs index a28ecfde..4bd77c08 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,7 +1,7 @@ pub(crate) mod auth; pub(crate) mod rate_limit; -use crate::configuration::{ExtensionType, FailureMode}; +use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest}; use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; @@ -147,36 +147,30 @@ impl GrpcMessage { #[derive(Default)] pub struct GrpcService { - endpoint: String, #[allow(dead_code)] - extension_type: ExtensionType, + extension: Rc, name: &'static str, method: &'static str, - failure_mode: FailureMode, } impl GrpcService { - pub fn new(extension_type: ExtensionType, endpoint: String, failure_mode: FailureMode) -> Self { - match extension_type { + pub fn new(extension: Rc) -> Self { + match extension.extension_type { ExtensionType::Auth => Self { - endpoint, - extension_type, + extension, name: AUTH_SERVICE_NAME, method: AUTH_METHOD_NAME, - failure_mode, }, ExtensionType::RateLimit => Self { - endpoint, - extension_type, + extension, name: RATELIMIT_SERVICE_NAME, method: RATELIMIT_METHOD_NAME, - failure_mode, }, } } fn endpoint(&self) -> &str { - &self.endpoint + &self.extension.endpoint } fn name(&self) -> &str { self.name @@ -185,7 +179,7 @@ impl GrpcService { self.method } pub fn failure_mode(&self) -> &FailureMode { - &self.failure_mode + &self.extension.failure_mode } } @@ -236,8 +230,12 @@ impl GrpcServiceHandler { ) } + pub fn get_extension(&self) -> Rc { + Rc::clone(&self.service.extension) + } + pub fn get_extension_type(&self) -> ExtensionType { - self.service.extension_type.clone() + self.service.extension.extension_type.clone() } } From f0b06489ddbaa5cd2351e993bc1a283d3c991726 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Wed, 4 Sep 2024 16:51:46 +0200 Subject: [PATCH 12/16] [refactor] OperationDispatcher.next() returns Option Signed-off-by: dd di cesare --- src/operation_dispatcher.rs | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 1584dc05..41480f91 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -30,6 +30,7 @@ impl State { type Procedure = (Rc, GrpcMessage); #[allow(dead_code)] +#[derive(Clone)] pub(crate) struct Operation { state: State, result: Result, @@ -136,16 +137,14 @@ impl OperationDispatcher { self.operations.borrow().first().unwrap().get_result() } - pub fn next(&self) -> Option<(State, Result)> { + pub fn next(&self) -> Option { let mut operations = self.operations.borrow_mut(); if let Some((i, operation)) = operations.iter_mut().enumerate().next() { if let State::Done = operation.get_state() { - let res = operation.get_result(); - operations.remove(i); - Some((State::Done, res)) + Some(operations.remove(i)) } else { operation.trigger(); - Some((operation.state.clone(), operation.result)) + Some(operation.clone()) } } else { None @@ -167,7 +166,7 @@ mod tests { _message: Option<&[u8]>, _timeout: Duration, ) -> Result { - Ok(1) + Ok(200) } fn build_grpc_service_handler() -> GrpcServiceHandler { @@ -263,23 +262,16 @@ mod tests { let operation_dispatcher = OperationDispatcher::default(); operation_dispatcher.push_operations(vec![operation]); - let mut res = operation_dispatcher.next(); - assert_eq!(res, Some((State::Waiting, Ok(200)))); - assert_eq!( - operation_dispatcher.get_current_operation_state(), - Some(State::Waiting) - ); - - res = operation_dispatcher.next(); - assert_eq!(res, Some((State::Done, Ok(200)))); - assert_eq!( - operation_dispatcher.get_current_operation_state(), - Some(State::Done) - ); - assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(200)); + if let Some(operation) = operation_dispatcher.next() { + assert_eq!(operation.get_result(), Ok(200)); + assert_eq!(operation.get_state(), State::Waiting); + } - res = operation_dispatcher.next(); - assert_eq!(res, None); + if let Some(operation) = operation_dispatcher.next() { + assert_eq!(operation.get_result(), Ok(200)); + assert_eq!(operation.get_state(), State::Done); + } + operation_dispatcher.next(); assert_eq!(operation_dispatcher.get_current_operation_state(), None); } } From a2c69f190bb25f45bf1275e68fbad3adba93ac1f Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Wed, 4 Sep 2024 16:52:32 +0200 Subject: [PATCH 13/16] [refactor] Wiring up with the new API Signed-off-by: dd di cesare --- src/filter/http_context.rs | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 1d2a140b..2290266d 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -39,26 +39,22 @@ impl Filter { return Action::Continue; } - // Build Actions from config actions self.operation_dispatcher.build_operations(rlp, descriptors); - // populate actions in the dispatcher - // call the next on the match - if let Some(result) = self.operation_dispatcher.next() { - match result { - (_state, Ok(call_id)) => { + if let Some(operation) = self.operation_dispatcher.next() { + match operation.get_result() { + Ok(call_id) => { debug!( "#{} initiated gRPC call (id# {}) to Limitador", self.context_id, call_id ); Action::Pause } - (_state, Err(e)) => { + Err(e) => { warn!("gRPC call to Limitador failed! {e:?}"); - // TODO(didierofrivia): Get the failure_mode - /*if let FailureMode::Deny = rls.failure_mode() { + if let FailureMode::Deny = operation.get_failure_mode() { self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) - } */ + } Action::Continue } } From 159d247a55716e4897bb1b70a05f07ea3f2c6ee5 Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Wed, 4 Sep 2024 17:47:38 +0200 Subject: [PATCH 14/16] [refactor] grpc_call function delegated to the caller Signed-off-by: dd di cesare --- src/filter/root_context.rs | 1 - src/operation_dispatcher.rs | 28 ++++++++++++++++++++++------ src/service.rs | 13 +++---------- 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index b4fafb77..6dcd4bdc 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -49,7 +49,6 @@ impl RootContext for FilterRoot { .or_insert(Rc::from(GrpcServiceHandler::new( Rc::clone(service), Rc::new(HeaderResolver::new()), - None, ))); }); Some(Box::new(Filter { diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 41480f91..61e12b8a 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -3,10 +3,12 @@ use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; use crate::service::{GrpcMessage, GrpcServiceHandler}; use protobuf::RepeatedField; +use proxy_wasm::hostcalls; use proxy_wasm::types::Status; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; +use std::time::Duration; #[allow(dead_code)] #[derive(PartialEq, Debug, Clone)] @@ -27,6 +29,24 @@ impl State { } } +fn grpc_call( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, +) -> Result { + hostcalls::dispatch_grpc_call( + upstream_name, + service_name, + method_name, + initial_metadata, + message, + timeout, + ) +} + type Procedure = (Rc, GrpcMessage); #[allow(dead_code)] @@ -56,7 +76,7 @@ impl Operation { pub fn trigger(&mut self) { if let State::Done = self.state { } else { - self.result = self.procedure.0.send(self.procedure.1.clone()); + self.result = self.procedure.0.send(grpc_call, self.procedure.1.clone()); self.state.next(); } } @@ -170,11 +190,7 @@ mod tests { } fn build_grpc_service_handler() -> GrpcServiceHandler { - GrpcServiceHandler::new( - Rc::new(Default::default()), - Rc::new(Default::default()), - Some(grpc_call), - ) + GrpcServiceHandler::new(Rc::new(Default::default()), Rc::new(Default::default())) } fn build_message() -> RateLimitRequest { diff --git a/src/service.rs b/src/service.rs index 4bd77c08..38fe2fce 100644 --- a/src/service.rs +++ b/src/service.rs @@ -11,7 +11,6 @@ use protobuf::{ Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, }; use proxy_wasm::hostcalls; -use proxy_wasm::hostcalls::dispatch_grpc_call; use proxy_wasm::types::{Bytes, MapType, Status}; use std::any::Any; use std::cell::OnceCell; @@ -195,23 +194,17 @@ type GrpcCall = fn( pub struct GrpcServiceHandler { service: Rc, header_resolver: Rc, - grpc_call: GrpcCall, } impl GrpcServiceHandler { - pub fn new( - service: Rc, - header_resolver: Rc, - grpc_call: Option, - ) -> Self { + pub fn new(service: Rc, header_resolver: Rc) -> Self { Self { service, header_resolver, - grpc_call: grpc_call.unwrap_or(dispatch_grpc_call), } } - pub fn send(&self, message: GrpcMessage) -> Result { + pub fn send(&self, grpc_call: GrpcCall, message: GrpcMessage) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver @@ -220,7 +213,7 @@ impl GrpcServiceHandler { .map(|(header, value)| (*header, value.as_slice())) .collect(); - (self.grpc_call)( + grpc_call( self.service.endpoint(), self.service.name(), self.service.method(), From 41b920d81b65e62884a4e6e5fdb42aa83fb70e1b Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Wed, 4 Sep 2024 19:05:17 +0200 Subject: [PATCH 15/16] [refactor] Operation responsible of providing hostcalls fns * Easier to test, mocking fn * Assigned fn on creation, default hostcall and mock on tests Signed-off-by: dd di cesare --- src/operation_dispatcher.rs | 120 ++++++++++++++++++------------------ src/service.rs | 18 ++++-- 2 files changed, 71 insertions(+), 67 deletions(-) diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 61e12b8a..1d855a8f 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,10 +1,10 @@ use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; -use crate::service::{GrpcMessage, GrpcServiceHandler}; +use crate::service::{GetMapValuesBytes, GrpcCall, GrpcMessage, GrpcServiceHandler}; use protobuf::RepeatedField; use proxy_wasm::hostcalls; -use proxy_wasm::types::Status; +use proxy_wasm::types::{Bytes, MapType, Status}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; @@ -29,24 +29,6 @@ impl State { } } -fn grpc_call( - upstream_name: &str, - service_name: &str, - method_name: &str, - initial_metadata: Vec<(&str, &[u8])>, - message: Option<&[u8]>, - timeout: Duration, -) -> Result { - hostcalls::dispatch_grpc_call( - upstream_name, - service_name, - method_name, - initial_metadata, - message, - timeout, - ) -} - type Procedure = (Rc, GrpcMessage); #[allow(dead_code)] @@ -56,6 +38,8 @@ pub(crate) struct Operation { result: Result, extension: Rc, procedure: Procedure, + grpc_call: GrpcCall, + get_map_values_bytes: GetMapValuesBytes, } #[allow(dead_code)] @@ -66,17 +50,19 @@ impl Operation { result: Err(Status::Empty), extension, procedure, + grpc_call, + get_map_values_bytes, } } - pub fn set_action(&mut self, procedure: Procedure) { - self.procedure = procedure; - } - - pub fn trigger(&mut self) { + fn trigger(&mut self) { if let State::Done = self.state { } else { - self.result = self.procedure.0.send(grpc_call, self.procedure.1.clone()); + self.result = self.procedure.0.send( + self.get_map_values_bytes, + self.grpc_call, + self.procedure.1.clone(), + ); self.state.next(); } } @@ -172,6 +158,28 @@ impl OperationDispatcher { } } +fn grpc_call( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, +) -> Result { + hostcalls::dispatch_grpc_call( + upstream_name, + service_name, + method_name, + initial_metadata, + message, + timeout, + ) +} + +fn get_map_values_bytes(map_type: MapType, key: &str) -> Result, Status> { + hostcalls::get_map_value_bytes(map_type, key) +} + #[cfg(test)] mod tests { use super::*; @@ -189,6 +197,10 @@ mod tests { Ok(200) } + fn get_map_values_bytes(_map_type: MapType, _key: &str) -> Result, Status> { + Ok(Some(Vec::new())) + } + fn build_grpc_service_handler() -> GrpcServiceHandler { GrpcServiceHandler::new(Rc::new(Default::default()), Rc::new(Default::default())) } @@ -203,33 +215,33 @@ mod tests { } } - #[test] - fn operation_getters() { - let extension = Rc::new(Extension::default()); - let operation = Operation::new( - extension, - ( + fn build_operation() -> Operation { + Operation { + state: State::Pending, + result: Ok(200), + extension: Rc::new(Extension::default()), + procedure: ( Rc::new(build_grpc_service_handler()), GrpcMessage::RateLimit(build_message()), ), - ); + grpc_call, + get_map_values_bytes, + } + } + + #[test] + fn operation_getters() { + let operation = build_operation(); assert_eq!(operation.get_state(), State::Pending); assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); assert_eq!(operation.get_failure_mode(), FailureMode::Deny); - assert_eq!(operation.get_result(), Result::Ok(1)); + assert_eq!(operation.get_result(), Ok(200)); } #[test] fn operation_transition() { - let extension = Rc::new(Extension::default()); - let mut operation = Operation::new( - extension, - ( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ), - ); + let mut operation = build_operation(); assert_eq!(operation.get_state(), State::Pending); operation.trigger(); assert_eq!(operation.get_state(), State::Waiting); @@ -242,23 +254,16 @@ mod tests { fn operation_dispatcher_push_actions() { let operation_dispatcher = OperationDispatcher::default(); - assert_eq!(operation_dispatcher.operations.borrow().len(), 1); - let extension = Rc::new(Extension::default()); - operation_dispatcher.push_operations(vec![Operation::new( - extension, - ( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ), - )]); + assert_eq!(operation_dispatcher.operations.borrow().len(), 0); + operation_dispatcher.push_operations(vec![build_operation()]); - assert_eq!(operation_dispatcher.operations.borrow().len(), 2); + assert_eq!(operation_dispatcher.operations.borrow().len(), 1); } #[test] fn operation_dispatcher_get_current_action_state() { let operation_dispatcher = OperationDispatcher::default(); - + operation_dispatcher.push_operations(vec![build_operation()]); assert_eq!( operation_dispatcher.get_current_operation_state(), Some(State::Pending) @@ -267,14 +272,7 @@ mod tests { #[test] fn operation_dispatcher_next() { - let extension = Rc::new(Extension::default()); - let operation = Operation::new( - extension, - ( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ), - ); + let operation = build_operation(); let operation_dispatcher = OperationDispatcher::default(); operation_dispatcher.push_operations(vec![operation]); diff --git a/src/service.rs b/src/service.rs index 38fe2fce..aec2aff0 100644 --- a/src/service.rs +++ b/src/service.rs @@ -10,7 +10,6 @@ use protobuf::reflect::MessageDescriptor; use protobuf::{ Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, }; -use proxy_wasm::hostcalls; use proxy_wasm::types::{Bytes, MapType, Status}; use std::any::Any; use std::cell::OnceCell; @@ -182,7 +181,7 @@ impl GrpcService { } } -type GrpcCall = fn( +pub type GrpcCall = fn( upstream_name: &str, service_name: &str, method_name: &str, @@ -191,6 +190,8 @@ type GrpcCall = fn( timeout: Duration, ) -> Result; +pub type GetMapValuesBytes = fn(map_type: MapType, key: &str) -> Result, Status>; + pub struct GrpcServiceHandler { service: Rc, header_resolver: Rc, @@ -204,11 +205,16 @@ impl GrpcServiceHandler { } } - pub fn send(&self, grpc_call: GrpcCall, message: GrpcMessage) -> Result { + pub fn send( + &self, + get_map_values_bytes: GetMapValuesBytes, + grpc_call: GrpcCall, + message: GrpcMessage, + ) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver - .get() + .get(get_map_values_bytes) .iter() .map(|(header, value)| (*header, value.as_slice())) .collect(); @@ -249,12 +255,12 @@ impl HeaderResolver { } } - pub fn get(&self) -> &Vec<(&'static str, Bytes)> { + pub fn get(&self, get_map_values_bytes: GetMapValuesBytes) -> &Vec<(&'static str, Bytes)> { self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() { if let Ok(Some(value)) = - hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str()) + get_map_values_bytes(MapType::HttpRequestHeaders, (*header).as_str()) { headers.push(((*header).as_str(), value)); } From c22103675ca636947eee741d39b1c6f1ebb73ead Mon Sep 17 00:00:00 2001 From: dd di cesare Date: Fri, 6 Sep 2024 11:29:10 +0200 Subject: [PATCH 16/16] [refactor] Fix `OperationDispatcher.next()` behaviour * Bonus: Addressed review regarding testing and Fn types Signed-off-by: dd di cesare --- src/operation_dispatcher.rs | 77 +++++++++++++++++++++++-------------- src/service.rs | 16 ++++---- 2 files changed, 57 insertions(+), 36 deletions(-) diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 1d855a8f..a9ab3c18 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,7 +1,7 @@ use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; -use crate::service::{GetMapValuesBytes, GrpcCall, GrpcMessage, GrpcServiceHandler}; +use crate::service::{GetMapValuesBytesFn, GrpcCallFn, GrpcMessage, GrpcServiceHandler}; use protobuf::RepeatedField; use proxy_wasm::hostcalls; use proxy_wasm::types::{Bytes, MapType, Status}; @@ -38,8 +38,8 @@ pub(crate) struct Operation { result: Result, extension: Rc, procedure: Procedure, - grpc_call: GrpcCall, - get_map_values_bytes: GetMapValuesBytes, + grpc_call_fn: GrpcCallFn, + get_map_values_bytes_fn: GetMapValuesBytesFn, } #[allow(dead_code)] @@ -50,8 +50,8 @@ impl Operation { result: Err(Status::Empty), extension, procedure, - grpc_call, - get_map_values_bytes, + grpc_call_fn, + get_map_values_bytes_fn, } } @@ -59,8 +59,8 @@ impl Operation { if let State::Done = self.state { } else { self.result = self.procedure.0.send( - self.get_map_values_bytes, - self.grpc_call, + self.get_map_values_bytes_fn, + self.grpc_call_fn, self.procedure.1.clone(), ); self.state.next(); @@ -147,7 +147,8 @@ impl OperationDispatcher { let mut operations = self.operations.borrow_mut(); if let Some((i, operation)) = operations.iter_mut().enumerate().next() { if let State::Done = operation.get_state() { - Some(operations.remove(i)) + operations.remove(i); + operations.get(i).cloned() // The next op is now at `i` } else { operation.trigger(); Some(operation.clone()) @@ -158,7 +159,7 @@ impl OperationDispatcher { } } -fn grpc_call( +fn grpc_call_fn( upstream_name: &str, service_name: &str, method_name: &str, @@ -176,7 +177,7 @@ fn grpc_call( ) } -fn get_map_values_bytes(map_type: MapType, key: &str) -> Result, Status> { +fn get_map_values_bytes_fn(map_type: MapType, key: &str) -> Result, Status> { hostcalls::get_map_value_bytes(map_type, key) } @@ -186,7 +187,7 @@ mod tests { use crate::envoy::RateLimitRequest; use std::time::Duration; - fn grpc_call( + fn grpc_call_fn_stub( _upstream_name: &str, _service_name: &str, _method_name: &str, @@ -197,7 +198,10 @@ mod tests { Ok(200) } - fn get_map_values_bytes(_map_type: MapType, _key: &str) -> Result, Status> { + fn get_map_values_bytes_fn_stub( + _map_type: MapType, + _key: &str, + ) -> Result, Status> { Ok(Some(Vec::new())) } @@ -218,14 +222,14 @@ mod tests { fn build_operation() -> Operation { Operation { state: State::Pending, - result: Ok(200), + result: Ok(1), extension: Rc::new(Extension::default()), procedure: ( Rc::new(build_grpc_service_handler()), GrpcMessage::RateLimit(build_message()), ), - grpc_call, - get_map_values_bytes, + grpc_call_fn: grpc_call_fn_stub, + get_map_values_bytes_fn: get_map_values_bytes_fn_stub, } } @@ -236,7 +240,7 @@ mod tests { assert_eq!(operation.get_state(), State::Pending); assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); assert_eq!(operation.get_failure_mode(), FailureMode::Deny); - assert_eq!(operation.get_result(), Ok(200)); + assert_eq!(operation.get_result(), Ok(1)); } #[test] @@ -272,20 +276,37 @@ mod tests { #[test] fn operation_dispatcher_next() { - let operation = build_operation(); let operation_dispatcher = OperationDispatcher::default(); - operation_dispatcher.push_operations(vec![operation]); + operation_dispatcher.push_operations(vec![build_operation(), build_operation()]); - if let Some(operation) = operation_dispatcher.next() { - assert_eq!(operation.get_result(), Ok(200)); - assert_eq!(operation.get_state(), State::Waiting); - } + assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(1)); + assert_eq!( + operation_dispatcher.get_current_operation_state(), + Some(State::Pending) + ); - if let Some(operation) = operation_dispatcher.next() { - assert_eq!(operation.get_result(), Ok(200)); - assert_eq!(operation.get_state(), State::Done); - } - operation_dispatcher.next(); - assert_eq!(operation_dispatcher.get_current_operation_state(), None); + let mut op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Waiting); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Done); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(1)); + assert_eq!(op.unwrap().get_state(), State::Pending); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Waiting); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Done); + + op = operation_dispatcher.next(); + assert!(op.is_none()); + assert!(operation_dispatcher.get_current_operation_state().is_none()); } } diff --git a/src/service.rs b/src/service.rs index aec2aff0..e89077f2 100644 --- a/src/service.rs +++ b/src/service.rs @@ -181,7 +181,7 @@ impl GrpcService { } } -pub type GrpcCall = fn( +pub type GrpcCallFn = fn( upstream_name: &str, service_name: &str, method_name: &str, @@ -190,7 +190,7 @@ pub type GrpcCall = fn( timeout: Duration, ) -> Result; -pub type GetMapValuesBytes = fn(map_type: MapType, key: &str) -> Result, Status>; +pub type GetMapValuesBytesFn = fn(map_type: MapType, key: &str) -> Result, Status>; pub struct GrpcServiceHandler { service: Rc, @@ -207,19 +207,19 @@ impl GrpcServiceHandler { pub fn send( &self, - get_map_values_bytes: GetMapValuesBytes, - grpc_call: GrpcCall, + get_map_values_bytes_fn: GetMapValuesBytesFn, + grpc_call_fn: GrpcCallFn, message: GrpcMessage, ) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver - .get(get_map_values_bytes) + .get(get_map_values_bytes_fn) .iter() .map(|(header, value)| (*header, value.as_slice())) .collect(); - grpc_call( + grpc_call_fn( self.service.endpoint(), self.service.name(), self.service.method(), @@ -255,12 +255,12 @@ impl HeaderResolver { } } - pub fn get(&self, get_map_values_bytes: GetMapValuesBytes) -> &Vec<(&'static str, Bytes)> { + pub fn get(&self, get_map_values_bytes_fn: GetMapValuesBytesFn) -> &Vec<(&'static str, Bytes)> { self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() { if let Ok(Some(value)) = - get_map_values_bytes(MapType::HttpRequestHeaders, (*header).as_str()) + get_map_values_bytes_fn(MapType::HttpRequestHeaders, (*header).as_str()) { headers.push(((*header).as_str(), value)); }