diff --git a/src/configuration.rs b/src/configuration.rs index 554c086d..e026254d 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -486,16 +486,7 @@ impl TryFrom for FilterConfig { let services = config .extensions .into_iter() - .map(|(name, ext)| { - ( - name, - Rc::new(GrpcService::new( - ext.extension_type, - ext.endpoint, - ext.failure_mode, - )), - ) - }) + .map(|(name, ext)| (name, Rc::new(GrpcService::new(Rc::new(ext))))) .collect(); Ok(Self { @@ -505,7 +496,7 @@ impl TryFrom for FilterConfig { } } -#[derive(Deserialize, Debug, Clone, Default)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum FailureMode { #[default] @@ -513,10 +504,11 @@ pub enum FailureMode { Allow, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Default, PartialEq)] #[serde(rename_all = "lowercase")] pub enum ExtensionType { Auth, + #[default] RateLimit, } @@ -527,7 +519,7 @@ pub struct PluginConfiguration { pub policies: Vec, } -#[derive(Deserialize, Debug, Clone)] +#[derive(Deserialize, Debug, Clone, Default)] #[serde(rename_all = "camelCase")] pub struct Extension { #[serde(rename = "type")] @@ -537,6 +529,14 @@ pub struct Extension { pub failure_mode: FailureMode, } +#[derive(Deserialize, Debug, Clone)] +#[serde(rename_all = "camelCase")] +pub struct Action { + pub extension: String, + #[allow(dead_code)] + pub data: DataType, +} + #[cfg(test)] mod test { use super::*; @@ -587,7 +587,18 @@ mod test { "selector": "auth.metadata.username" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -682,7 +693,18 @@ mod test { "default": "my_selector_default_value" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(config); @@ -759,7 +781,18 @@ mod test { }] }], "data": [ { "selector": { "selector": "my.selector.path" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(config); @@ -825,7 +858,18 @@ mod test { "selector": "auth.metadata.username" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(config); @@ -872,7 +916,18 @@ mod test { "selector": "auth.metadata.username" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(bad_config); @@ -902,7 +957,18 @@ mod test { "value": "1" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(bad_config); @@ -934,7 +1000,18 @@ mod test { }] }], "data": [ { "selector": { "selector": "my.selector.path" } }] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; let res = serde_json::from_str::(bad_config); diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index fa80e75c..2290266d 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,8 +1,7 @@ use crate::configuration::{FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; +use crate::operation_dispatcher::OperationDispatcher; use crate::policy::Policy; -use crate::service::rate_limit::RateLimitService; -use crate::service::{GrpcServiceHandler, HeaderResolver}; use log::{debug, warn}; use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; @@ -13,7 +12,7 @@ pub struct Filter { pub context_id: u32, pub config: Rc, pub response_headers_to_add: Vec<(String, String)>, - pub header_resolver: Rc, + pub operation_dispatcher: OperationDispatcher, } impl Filter { @@ -40,33 +39,27 @@ impl Filter { return Action::Continue; } - // todo(adam-cattermole): For now we just get the first GrpcService but we expect to have - // an action which links to the service that should be used - let rls = self - .config - .services - .values() - .next() - .expect("expect a value"); - - let handler = GrpcServiceHandler::new(Rc::clone(rls), Rc::clone(&self.header_resolver)); - let message = RateLimitService::message(rlp.domain.clone(), descriptors); + self.operation_dispatcher.build_operations(rlp, descriptors); - match handler.send(message) { - Ok(call_id) => { - debug!( - "#{} initiated gRPC call (id# {}) to Limitador", - self.context_id, call_id - ); - Action::Pause - } - Err(e) => { - warn!("gRPC call to Limitador failed! {e:?}"); - if let FailureMode::Deny = rls.failure_mode() { - self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) + if let Some(operation) = self.operation_dispatcher.next() { + match operation.get_result() { + Ok(call_id) => { + debug!( + "#{} initiated gRPC call (id# {}) to Limitador", + self.context_id, call_id + ); + Action::Pause + } + Err(e) => { + warn!("gRPC call to Limitador failed! {e:?}"); + if let FailureMode::Deny = operation.get_failure_mode() { + self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) + } + Action::Continue } - Action::Continue } + } else { + Action::Continue } } diff --git a/src/filter/root_context.rs b/src/filter/root_context.rs index 90774e1c..6dcd4bdc 100644 --- a/src/filter/root_context.rs +++ b/src/filter/root_context.rs @@ -1,10 +1,12 @@ use crate::configuration::{FilterConfig, PluginConfiguration}; use crate::filter::http_context::Filter; -use crate::service::HeaderResolver; +use crate::operation_dispatcher::OperationDispatcher; +use crate::service::{GrpcServiceHandler, HeaderResolver}; use const_format::formatcp; use log::{debug, error, info}; use proxy_wasm::traits::{Context, HttpContext, RootContext}; use proxy_wasm::types::ContextType; +use std::collections::HashMap; use std::rc::Rc; const WASM_SHIM_VERSION: &str = env!("CARGO_PKG_VERSION"); @@ -37,11 +39,23 @@ impl RootContext for FilterRoot { fn create_http_context(&self, context_id: u32) -> Option> { debug!("#{} create_http_context", context_id); + let mut service_handlers: HashMap> = HashMap::new(); + self.config + .services + .iter() + .for_each(|(extension, service)| { + service_handlers + .entry(extension.clone()) + .or_insert(Rc::from(GrpcServiceHandler::new( + Rc::clone(service), + Rc::new(HeaderResolver::new()), + ))); + }); Some(Box::new(Filter { context_id, config: Rc::clone(&self.config), response_headers_to_add: Vec::default(), - header_resolver: Rc::new(HeaderResolver::new()), + operation_dispatcher: OperationDispatcher::new(service_handlers), })) } diff --git a/src/lib.rs b/src/lib.rs index fb1c60aa..a7a500a3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,6 +3,7 @@ mod configuration; mod envoy; mod filter; mod glob; +mod operation_dispatcher; mod policy; mod policy_index; mod service; diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs new file mode 100644 index 00000000..a9ab3c18 --- /dev/null +++ b/src/operation_dispatcher.rs @@ -0,0 +1,312 @@ +use crate::configuration::{Extension, ExtensionType, FailureMode}; +use crate::envoy::RateLimitDescriptor; +use crate::policy::Policy; +use crate::service::{GetMapValuesBytesFn, GrpcCallFn, GrpcMessage, GrpcServiceHandler}; +use protobuf::RepeatedField; +use proxy_wasm::hostcalls; +use proxy_wasm::types::{Bytes, MapType, Status}; +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; +use std::time::Duration; + +#[allow(dead_code)] +#[derive(PartialEq, Debug, Clone)] +pub(crate) enum State { + Pending, + Waiting, + Done, +} + +#[allow(dead_code)] +impl State { + fn next(&mut self) { + match self { + State::Pending => *self = State::Waiting, + State::Waiting => *self = State::Done, + _ => {} + } + } +} + +type Procedure = (Rc, GrpcMessage); + +#[allow(dead_code)] +#[derive(Clone)] +pub(crate) struct Operation { + state: State, + result: Result, + extension: Rc, + procedure: Procedure, + grpc_call_fn: GrpcCallFn, + get_map_values_bytes_fn: GetMapValuesBytesFn, +} + +#[allow(dead_code)] +impl Operation { + pub fn new(extension: Rc, procedure: Procedure) -> Self { + Self { + state: State::Pending, + result: Err(Status::Empty), + extension, + procedure, + grpc_call_fn, + get_map_values_bytes_fn, + } + } + + fn trigger(&mut self) { + if let State::Done = self.state { + } else { + self.result = self.procedure.0.send( + self.get_map_values_bytes_fn, + self.grpc_call_fn, + self.procedure.1.clone(), + ); + self.state.next(); + } + } + + pub fn get_state(&self) -> State { + self.state.clone() + } + + pub fn get_result(&self) -> Result { + self.result + } + + pub fn get_extension_type(&self) -> ExtensionType { + self.extension.extension_type.clone() + } + + pub fn get_failure_mode(&self) -> FailureMode { + self.extension.failure_mode.clone() + } +} + +#[allow(dead_code)] +pub struct OperationDispatcher { + operations: RefCell>, + service_handlers: HashMap>, +} + +#[allow(dead_code)] +impl OperationDispatcher { + pub fn default() -> Self { + OperationDispatcher { + operations: RefCell::new(vec![]), + service_handlers: HashMap::default(), + } + } + pub fn new(service_handlers: HashMap>) -> Self { + Self { + service_handlers, + operations: RefCell::new(vec![]), + } + } + + pub fn build_operations( + &self, + policy: &Policy, + descriptors: RepeatedField, + ) { + let mut operations: Vec = vec![]; + policy.actions.iter().for_each(|action| { + // TODO(didierofrivia): Error handling + if let Some(service) = self.service_handlers.get(&action.extension) { + let message = GrpcMessage::new( + service.get_extension_type(), + policy.domain.clone(), + descriptors.clone(), + ); + operations.push(Operation::new( + service.get_extension(), + (Rc::clone(service), message), + )) + } + }); + self.push_operations(operations); + } + + pub fn push_operations(&self, operations: Vec) { + self.operations.borrow_mut().extend(operations); + } + + pub fn get_current_operation_state(&self) -> Option { + self.operations + .borrow() + .first() + .map(|operation| operation.get_state().clone()) + } + + pub fn get_current_operation_result(&self) -> Result { + self.operations.borrow().first().unwrap().get_result() + } + + pub fn next(&self) -> Option { + let mut operations = self.operations.borrow_mut(); + if let Some((i, operation)) = operations.iter_mut().enumerate().next() { + if let State::Done = operation.get_state() { + operations.remove(i); + operations.get(i).cloned() // The next op is now at `i` + } else { + operation.trigger(); + Some(operation.clone()) + } + } else { + None + } + } +} + +fn grpc_call_fn( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, +) -> Result { + hostcalls::dispatch_grpc_call( + upstream_name, + service_name, + method_name, + initial_metadata, + message, + timeout, + ) +} + +fn get_map_values_bytes_fn(map_type: MapType, key: &str) -> Result, Status> { + hostcalls::get_map_value_bytes(map_type, key) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::envoy::RateLimitRequest; + use std::time::Duration; + + fn grpc_call_fn_stub( + _upstream_name: &str, + _service_name: &str, + _method_name: &str, + _initial_metadata: Vec<(&str, &[u8])>, + _message: Option<&[u8]>, + _timeout: Duration, + ) -> Result { + Ok(200) + } + + fn get_map_values_bytes_fn_stub( + _map_type: MapType, + _key: &str, + ) -> Result, Status> { + Ok(Some(Vec::new())) + } + + fn build_grpc_service_handler() -> GrpcServiceHandler { + GrpcServiceHandler::new(Rc::new(Default::default()), Rc::new(Default::default())) + } + + fn build_message() -> RateLimitRequest { + RateLimitRequest { + domain: "example.org".to_string(), + descriptors: RepeatedField::new(), + hits_addend: 1, + unknown_fields: Default::default(), + cached_size: Default::default(), + } + } + + fn build_operation() -> Operation { + Operation { + state: State::Pending, + result: Ok(1), + extension: Rc::new(Extension::default()), + procedure: ( + Rc::new(build_grpc_service_handler()), + GrpcMessage::RateLimit(build_message()), + ), + grpc_call_fn: grpc_call_fn_stub, + get_map_values_bytes_fn: get_map_values_bytes_fn_stub, + } + } + + #[test] + fn operation_getters() { + let operation = build_operation(); + + assert_eq!(operation.get_state(), State::Pending); + assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); + assert_eq!(operation.get_failure_mode(), FailureMode::Deny); + assert_eq!(operation.get_result(), Ok(1)); + } + + #[test] + fn operation_transition() { + let mut operation = build_operation(); + assert_eq!(operation.get_state(), State::Pending); + operation.trigger(); + assert_eq!(operation.get_state(), State::Waiting); + operation.trigger(); + assert_eq!(operation.result, Ok(200)); + assert_eq!(operation.get_state(), State::Done); + } + + #[test] + fn operation_dispatcher_push_actions() { + let operation_dispatcher = OperationDispatcher::default(); + + assert_eq!(operation_dispatcher.operations.borrow().len(), 0); + operation_dispatcher.push_operations(vec![build_operation()]); + + assert_eq!(operation_dispatcher.operations.borrow().len(), 1); + } + + #[test] + fn operation_dispatcher_get_current_action_state() { + let operation_dispatcher = OperationDispatcher::default(); + operation_dispatcher.push_operations(vec![build_operation()]); + assert_eq!( + operation_dispatcher.get_current_operation_state(), + Some(State::Pending) + ); + } + + #[test] + fn operation_dispatcher_next() { + let operation_dispatcher = OperationDispatcher::default(); + operation_dispatcher.push_operations(vec![build_operation(), build_operation()]); + + assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(1)); + assert_eq!( + operation_dispatcher.get_current_operation_state(), + Some(State::Pending) + ); + + let mut op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Waiting); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Done); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(1)); + assert_eq!(op.unwrap().get_state(), State::Pending); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Waiting); + + op = operation_dispatcher.next(); + assert_eq!(op.clone().unwrap().get_result(), Ok(200)); + assert_eq!(op.unwrap().get_state(), State::Done); + + op = operation_dispatcher.next(); + assert!(op.is_none()); + assert!(operation_dispatcher.get_current_operation_state().is_none()); + } +} diff --git a/src/policy.rs b/src/policy.rs index 7505b8fd..c7ee430a 100644 --- a/src/policy.rs +++ b/src/policy.rs @@ -1,5 +1,5 @@ use crate::attribute::Attribute; -use crate::configuration::{DataItem, DataType, PatternExpression}; +use crate::configuration::{Action, DataItem, DataType, PatternExpression}; use crate::envoy::{RateLimitDescriptor, RateLimitDescriptor_Entry}; use crate::filter::http_context::Filter; use log::debug; @@ -28,16 +28,24 @@ pub struct Policy { pub domain: String, pub hostnames: Vec, pub rules: Vec, + pub actions: Vec, } impl Policy { #[cfg(test)] - pub fn new(name: String, domain: String, hostnames: Vec, rules: Vec) -> Self { + pub fn new( + name: String, + domain: String, + hostnames: Vec, + rules: Vec, + actions: Vec, + ) -> Self { Policy { name, domain, hostnames, rules, + actions, } } diff --git a/src/policy_index.rs b/src/policy_index.rs index 58b31d94..3620179e 100644 --- a/src/policy_index.rs +++ b/src/policy_index.rs @@ -41,7 +41,13 @@ mod tests { use crate::policy_index::PolicyIndex; fn build_ratelimit_policy(name: &str) -> Policy { - Policy::new(name.to_owned(), "".to_owned(), Vec::new(), Vec::new()) + Policy::new( + name.to_owned(), + "".to_owned(), + Vec::new(), + Vec::new(), + Vec::new(), + ) } #[test] diff --git a/src/service.rs b/src/service.rs index e6b13d61..e89077f2 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,45 +1,174 @@ pub(crate) mod auth; pub(crate) mod rate_limit; -use crate::configuration::{ExtensionType, FailureMode}; -use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; -use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::configuration::{Extension, ExtensionType, FailureMode}; +use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest}; +use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; -use protobuf::Message; -use proxy_wasm::hostcalls; -use proxy_wasm::hostcalls::dispatch_grpc_call; +use protobuf::reflect::MessageDescriptor; +use protobuf::{ + Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, +}; use proxy_wasm::types::{Bytes, MapType, Status}; +use std::any::Any; use std::cell::OnceCell; +use std::fmt::Debug; use std::rc::Rc; use std::time::Duration; +#[derive(Clone, Debug)] +pub enum GrpcMessage { + Auth(CheckRequest), + RateLimit(RateLimitRequest), +} + +impl Default for GrpcMessage { + fn default() -> Self { + GrpcMessage::RateLimit(RateLimitRequest::new()) + } +} + +impl Clear for GrpcMessage { + fn clear(&mut self) { + match self { + GrpcMessage::Auth(msg) => msg.clear(), + GrpcMessage::RateLimit(msg) => msg.clear(), + } + } +} + +impl Message for GrpcMessage { + fn descriptor(&self) -> &'static MessageDescriptor { + match self { + GrpcMessage::Auth(msg) => msg.descriptor(), + GrpcMessage::RateLimit(msg) => msg.descriptor(), + } + } + + fn is_initialized(&self) -> bool { + match self { + GrpcMessage::Auth(msg) => msg.is_initialized(), + GrpcMessage::RateLimit(msg) => msg.is_initialized(), + } + } + + fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { + match self { + GrpcMessage::Auth(msg) => msg.merge_from(is), + GrpcMessage::RateLimit(msg) => msg.merge_from(is), + } + } + + fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { + match self { + GrpcMessage::Auth(msg) => msg.write_to_with_cached_sizes(os), + GrpcMessage::RateLimit(msg) => msg.write_to_with_cached_sizes(os), + } + } + + fn write_to_bytes(&self) -> ProtobufResult> { + match self { + GrpcMessage::Auth(msg) => msg.write_to_bytes(), + GrpcMessage::RateLimit(msg) => msg.write_to_bytes(), + } + } + + fn compute_size(&self) -> u32 { + match self { + GrpcMessage::Auth(msg) => msg.compute_size(), + GrpcMessage::RateLimit(msg) => msg.compute_size(), + } + } + + fn get_cached_size(&self) -> u32 { + match self { + GrpcMessage::Auth(msg) => msg.get_cached_size(), + GrpcMessage::RateLimit(msg) => msg.get_cached_size(), + } + } + + fn get_unknown_fields(&self) -> &UnknownFields { + match self { + GrpcMessage::Auth(msg) => msg.get_unknown_fields(), + GrpcMessage::RateLimit(msg) => msg.get_unknown_fields(), + } + } + + fn mut_unknown_fields(&mut self) -> &mut UnknownFields { + match self { + GrpcMessage::Auth(msg) => msg.mut_unknown_fields(), + GrpcMessage::RateLimit(msg) => msg.mut_unknown_fields(), + } + } + + fn as_any(&self) -> &dyn Any { + match self { + GrpcMessage::Auth(msg) => msg.as_any(), + GrpcMessage::RateLimit(msg) => msg.as_any(), + } + } + + fn new() -> Self + where + Self: Sized, + { + // Returning default value + GrpcMessage::default() + } + + fn default_instance() -> &'static Self + where + Self: Sized, + { + #[allow(non_upper_case_globals)] + static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; + instance.get(|| GrpcMessage::RateLimit(RateLimitRequest::new())) + } +} + +impl GrpcMessage { + // Using domain as ce_host for the time being, we might pass a DataType in the future. + pub fn new( + extension_type: ExtensionType, + domain: String, + descriptors: protobuf::RepeatedField, + ) -> Self { + match extension_type { + ExtensionType::RateLimit => { + GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)) + } + ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())), + } + } +} + #[derive(Default)] pub struct GrpcService { - endpoint: String, + #[allow(dead_code)] + extension: Rc, name: &'static str, method: &'static str, - failure_mode: FailureMode, } impl GrpcService { - pub fn new(extension_type: ExtensionType, endpoint: String, failure_mode: FailureMode) -> Self { - match extension_type { + pub fn new(extension: Rc) -> Self { + match extension.extension_type { ExtensionType::Auth => Self { - endpoint, + extension, name: AUTH_SERVICE_NAME, method: AUTH_METHOD_NAME, - failure_mode, }, ExtensionType::RateLimit => Self { - endpoint, + extension, name: RATELIMIT_SERVICE_NAME, method: RATELIMIT_METHOD_NAME, - failure_mode, }, } } + fn endpoint(&self) -> &str { - &self.endpoint + &self.extension.endpoint } fn name(&self) -> &str { self.name @@ -48,11 +177,21 @@ impl GrpcService { self.method } pub fn failure_mode(&self) -> &FailureMode { - &self.failure_mode + &self.extension.failure_mode } } -#[derive(Default)] +pub type GrpcCallFn = fn( + upstream_name: &str, + service_name: &str, + method_name: &str, + initial_metadata: Vec<(&str, &[u8])>, + message: Option<&[u8]>, + timeout: Duration, +) -> Result; + +pub type GetMapValuesBytesFn = fn(map_type: MapType, key: &str) -> Result, Status>; + pub struct GrpcServiceHandler { service: Rc, header_resolver: Rc, @@ -66,16 +205,21 @@ impl GrpcServiceHandler { } } - pub fn send(&self, message: M) -> Result { + pub fn send( + &self, + get_map_values_bytes_fn: GetMapValuesBytesFn, + grpc_call_fn: GrpcCallFn, + message: GrpcMessage, + ) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self .header_resolver - .get() + .get(get_map_values_bytes_fn) .iter() .map(|(header, value)| (*header, value.as_slice())) .collect(); - dispatch_grpc_call( + grpc_call_fn( self.service.endpoint(), self.service.name(), self.service.method(), @@ -84,6 +228,14 @@ impl GrpcServiceHandler { Duration::from_secs(5), ) } + + pub fn get_extension(&self) -> Rc { + Rc::clone(&self.service.extension) + } + + pub fn get_extension_type(&self) -> ExtensionType { + self.service.extension.extension_type.clone() + } } pub struct HeaderResolver { @@ -103,12 +255,12 @@ impl HeaderResolver { } } - pub fn get(&self) -> &Vec<(&'static str, Bytes)> { + pub fn get(&self, get_map_values_bytes_fn: GetMapValuesBytesFn) -> &Vec<(&'static str, Bytes)> { self.headers.get_or_init(|| { let mut headers = Vec::new(); for header in TracingHeader::all() { if let Ok(Some(value)) = - hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str()) + get_map_values_bytes_fn(MapType::HttpRequestHeaders, (*header).as_str()) { headers.push(((*header).as_str(), value)); } diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index 6dfc3c89..b6b0357c 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -27,8 +27,6 @@ mod tests { use crate::service::rate_limit::RateLimitService; //use crate::service::Service; use protobuf::{CachedSize, RepeatedField, UnknownFields}; - //use proxy_wasm::types::Status; - //use crate::filter::http_context::{Filter}; fn build_message() -> RateLimitRequest { let domain = "rlp1"; @@ -52,20 +50,4 @@ mod tests { assert_eq!(msg.unknown_fields, UnknownFields::default()); assert_eq!(msg.cached_size, CachedSize::default()); } - /*#[test] - fn sends_message() { - let msg = build_message(); - let metadata = vec![("header-1", "value-1".as_bytes())]; - let rls = RateLimitService::new("limitador-cluster", metadata); - - // TODO(didierofrivia): When we have a grpc response type, assert the async response - } - - fn grpc_call( - _upstream_name: &str, - _initial_metadata: Vec<(&str, &[u8])>, - _message: RateLimitRequest, - ) -> Result { - Ok(1) - } */ } diff --git a/tests/rate_limited.rs b/tests/rate_limited.rs index d1421350..6bc3168f 100644 --- a/tests/rate_limited.rs +++ b/tests/rate_limited.rs @@ -132,7 +132,18 @@ fn it_limits() { } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -275,7 +286,18 @@ fn it_passes_additional_headers() { } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -412,7 +434,18 @@ fn it_rate_limits_with_empty_conditions() { } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#; @@ -528,7 +561,18 @@ fn it_does_not_rate_limits_when_selector_does_not_exist_and_misses_default_value } } ] - }] + }], + "actions": [ + { + "extension": "limitador", + "data": { + "static": { + "key": "rlp-ns-A/rlp-name-A", + "value": "1" + } + } + } + ] }] }"#;