diff --git a/src/configuration.rs b/src/configuration.rs index b696e2b0..2bdbc9f7 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -486,16 +486,7 @@ impl TryFrom for FilterConfig { let services = config .extensions .into_iter() - .map(|(name, ext)| { - ( - name, - Rc::new(GrpcService::new( - ext.extension_type, - ext.endpoint, - ext.failure_mode, - )), - ) - }) + .map(|(name, ext)| (name, Rc::new(GrpcService::new(Rc::new(ext))))) .collect(); Ok(Self { @@ -505,7 +496,7 @@ impl TryFrom for FilterConfig { } } -#[derive(Deserialize, Debug, Clone, Default)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum FailureMode { #[default] @@ -513,7 +504,7 @@ pub enum FailureMode { Allow, } -#[derive(Deserialize, Debug, Clone, Default)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ExtensionType { Auth, @@ -528,7 +519,7 @@ pub struct PluginConfiguration { pub policies: Vec, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct Extension { #[serde(rename = "type")] diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index 0b104402..1584dc05 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,3 +1,4 @@ +use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; use crate::service::{GrpcMessage, GrpcServiceHandler}; @@ -32,15 +33,17 @@ type Procedure = (Rc, GrpcMessage); pub(crate) struct Operation { state: State, result: Result, + extension: Rc, procedure: Procedure, } #[allow(dead_code)] impl Operation { - pub fn new(procedure: Procedure) -> Self { + pub fn new(extension: Rc, procedure: Procedure) -> Self { Self { state: State::Pending, result: Err(Status::Empty), + extension, procedure, } } @@ -57,13 +60,21 @@ impl Operation { } } - fn get_state(&self) -> State { + pub fn get_state(&self) -> State { self.state.clone() } - fn get_result(&self) -> Result { + pub fn get_result(&self) -> Result { self.result } + + pub fn get_extension_type(&self) -> ExtensionType { + self.extension.extension_type.clone() + } + + pub fn get_failure_mode(&self) -> FailureMode { + self.extension.failure_mode.clone() + } } #[allow(dead_code)] @@ -101,7 +112,10 @@ impl OperationDispatcher { policy.domain.clone(), descriptors.clone(), ); - operations.push(Operation::new((service.clone(), message))) + operations.push(Operation::new( + service.get_extension(), + (Rc::clone(service), message), + )) } }); self.push_operations(operations); @@ -174,12 +188,33 @@ mod tests { } } + #[test] + fn operation_getters() { + let extension = Rc::new(Extension::default()); + let operation = Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + ); + + assert_eq!(operation.get_state(), State::Pending); + assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); + assert_eq!(operation.get_failure_mode(), FailureMode::Deny); + assert_eq!(operation.get_result(), Result::Ok(1)); + } + #[test] fn operation_transition() { - let mut operation = Operation::new(( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - )); + let extension = Rc::new(Extension::default()); + let mut operation = Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + ); assert_eq!(operation.get_state(), State::Pending); operation.trigger(); assert_eq!(operation.get_state(), State::Waiting); @@ -193,11 +228,14 @@ mod tests { let operation_dispatcher = OperationDispatcher::default(); assert_eq!(operation_dispatcher.operations.borrow().len(), 1); - - operation_dispatcher.push_operations(vec![Operation::new(( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - ))]); + let extension = Rc::new(Extension::default()); + operation_dispatcher.push_operations(vec![Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + )]); assert_eq!(operation_dispatcher.operations.borrow().len(), 2); } @@ -214,10 +252,14 @@ mod tests { #[test] fn operation_dispatcher_next() { - let operation = Operation::new(( - Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), - )); + let extension = Rc::new(Extension::default()); + let operation = Operation::new( + extension, + ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + ); let operation_dispatcher = OperationDispatcher::default(); operation_dispatcher.push_operations(vec![operation]); diff --git a/src/service.rs b/src/service.rs index a28ecfde..4bd77c08 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,7 +1,7 @@ pub(crate) mod auth; pub(crate) mod rate_limit; -use crate::configuration::{ExtensionType, FailureMode}; +use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest}; use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; @@ -147,36 +147,30 @@ impl GrpcMessage { #[derive(Default)] pub struct GrpcService { - endpoint: String, #[allow(dead_code)] - extension_type: ExtensionType, + extension: Rc, name: &'static str, method: &'static str, - failure_mode: FailureMode, } impl GrpcService { - pub fn new(extension_type: ExtensionType, endpoint: String, failure_mode: FailureMode) -> Self { - match extension_type { + pub fn new(extension: Rc) -> Self { + match extension.extension_type { ExtensionType::Auth => Self { - endpoint, - extension_type, + extension, name: AUTH_SERVICE_NAME, method: AUTH_METHOD_NAME, - failure_mode, }, ExtensionType::RateLimit => Self { - endpoint, - extension_type, + extension, name: RATELIMIT_SERVICE_NAME, method: RATELIMIT_METHOD_NAME, - failure_mode, }, } } fn endpoint(&self) -> &str { - &self.endpoint + &self.extension.endpoint } fn name(&self) -> &str { self.name @@ -185,7 +179,7 @@ impl GrpcService { self.method } pub fn failure_mode(&self) -> &FailureMode { - &self.failure_mode + &self.extension.failure_mode } } @@ -236,8 +230,12 @@ impl GrpcServiceHandler { ) } + pub fn get_extension(&self) -> Rc { + Rc::clone(&self.service.extension) + } + pub fn get_extension_type(&self) -> ExtensionType { - self.service.extension_type.clone() + self.service.extension.extension_type.clone() } }