diff --git a/src/envoy/mod.rs b/src/envoy/mod.rs index db527204..6195cfae 100644 --- a/src/envoy/mod.rs +++ b/src/envoy/mod.rs @@ -37,7 +37,7 @@ pub use { AttributeContext_Request, }, base::Metadata, - external_auth::CheckRequest, + external_auth::{CheckRequest, DeniedHttpResponse, OkHttpResponse}, ratelimit::{RateLimitDescriptor, RateLimitDescriptor_Entry}, rls::{RateLimitRequest, RateLimitResponse, RateLimitResponse_Code}, }; diff --git a/src/filter/http_context.rs b/src/filter/http_context.rs index 2290266d..fe4832a2 100644 --- a/src/filter/http_context.rs +++ b/src/filter/http_context.rs @@ -1,9 +1,9 @@ -use crate::configuration::{FailureMode, FilterConfig}; +use crate::configuration::{ExtensionType, FailureMode, FilterConfig}; use crate::envoy::{RateLimitResponse, RateLimitResponse_Code}; use crate::operation_dispatcher::OperationDispatcher; use crate::policy::Policy; +use crate::service::grpc_message::GrpcMessageResponse; use log::{debug, warn}; -use protobuf::Message; use proxy_wasm::traits::{Context, HttpContext}; use proxy_wasm::types::Action; use std::rc::Rc; @@ -29,8 +29,8 @@ impl Filter { } } - fn process_rate_limit_policy(&self, rlp: &Policy) -> Action { - let descriptors = rlp.build_descriptors(self); + fn process_policy(&self, policy: &Policy) -> Action { + let descriptors = policy.build_descriptors(self); if descriptors.is_empty() { debug!( "#{} process_rate_limit_policy: empty descriptors", @@ -39,7 +39,8 @@ impl Filter { return Action::Continue; } - self.operation_dispatcher.build_operations(rlp, descriptors); + self.operation_dispatcher + .build_operations(policy, descriptors); if let Some(operation) = self.operation_dispatcher.next() { match operation.get_result() { @@ -63,22 +64,52 @@ impl Filter { } } - fn handle_error_on_grpc_response(&self) { - // todo(adam-cattermole): We need a method of knowing which service is the one currently - // being used (the current action) so that we can get the failure mode - let rls = self - .config - .services - .values() - .next() - .expect("expect a value"); - match rls.failure_mode() { + fn handle_error_on_grpc_response(&self, failure_mode: &FailureMode) { + match failure_mode { FailureMode::Deny => { self.send_http_response(500, vec![], Some(b"Internal Server Error.\n")) } FailureMode::Allow => self.resume_http_request(), } } + + fn process_ratelimit_grpc_response( + &mut self, + rl_resp: GrpcMessageResponse, + failure_mode: &FailureMode, + ) { + match rl_resp { + GrpcMessageResponse::RateLimit(RateLimitResponse { + overall_code: RateLimitResponse_Code::UNKNOWN, + .. + }) => { + self.handle_error_on_grpc_response(failure_mode); + } + GrpcMessageResponse::RateLimit(RateLimitResponse { + overall_code: RateLimitResponse_Code::OVER_LIMIT, + response_headers_to_add: rl_headers, + .. + }) => { + let mut response_headers = vec![]; + for header in &rl_headers { + response_headers.push((header.get_key(), header.get_value())); + } + self.send_http_response(429, response_headers, Some(b"Too Many Requests\n")); + } + GrpcMessageResponse::RateLimit(RateLimitResponse { + overall_code: RateLimitResponse_Code::OK, + response_headers_to_add: additional_headers, + .. + }) => { + for header in additional_headers { + self.response_headers_to_add + .push((header.key, header.value)); + } + } + _ => {} + } + self.operation_dispatcher.next(); + } } impl HttpContext for Filter { @@ -97,9 +128,12 @@ impl HttpContext for Filter { ); Action::Continue } - Some(rlp) => { - debug!("#{} ratelimitpolicy selected {}", self.context_id, rlp.name); - self.process_rate_limit_policy(rlp) + Some(policy) => { + debug!( + "#{} ratelimitpolicy selected {}", + self.context_id, policy.name + ); + self.process_policy(policy) } } } @@ -124,55 +158,42 @@ impl Context for Filter { self.context_id ); - let res_body_bytes = match self.get_grpc_call_response_body(0, resp_size) { - Some(bytes) => bytes, - None => { - warn!("grpc response body is empty!"); - self.handle_error_on_grpc_response(); - return; - } - }; - - let rl_resp: RateLimitResponse = match Message::parse_from_bytes(&res_body_bytes) { - Ok(res) => res, - Err(e) => { - warn!("failed to parse grpc response body into RateLimitResponse message: {e}"); - self.handle_error_on_grpc_response(); - return; - } - }; - - match rl_resp { - RateLimitResponse { - overall_code: RateLimitResponse_Code::UNKNOWN, - .. - } => { - self.handle_error_on_grpc_response(); - return; - } - RateLimitResponse { - overall_code: RateLimitResponse_Code::OVER_LIMIT, - response_headers_to_add: rl_headers, - .. - } => { - let mut response_headers = vec![]; - for header in &rl_headers { - response_headers.push((header.get_key(), header.get_value())); + if let Some(operation) = self.operation_dispatcher.get_operation(token_id) { + let failure_mode = &operation.get_failure_mode(); + let res_body_bytes = match self.get_grpc_call_response_body(0, resp_size) { + Some(bytes) => bytes, + None => { + warn!("grpc response body is empty!"); + self.handle_error_on_grpc_response(failure_mode); + return; } - self.send_http_response(429, response_headers, Some(b"Too Many Requests\n")); - return; - } - RateLimitResponse { - overall_code: RateLimitResponse_Code::OK, - response_headers_to_add: additional_headers, - .. - } => { - for header in additional_headers { - self.response_headers_to_add - .push((header.key, header.value)); + }; + let res = match GrpcMessageResponse::new( + operation.get_extension_type(), + &res_body_bytes, + status_code, + ) { + Ok(res) => res, + Err(e) => { + warn!( + "failed to parse grpc response body into GrpcMessageResponse message: {e}" + ); + self.handle_error_on_grpc_response(failure_mode); + return; } + }; + match operation.get_extension_type() { + ExtensionType::Auth => {} // TODO(didierofrivia): Process auth grpc response. + ExtensionType::RateLimit => self.process_ratelimit_grpc_response(res, failure_mode), } + + if let Some(_op) = self.operation_dispatcher.next() { + } else { + self.resume_http_request() + } + } else { + warn!("No Operation found with token_id: {token_id}"); + self.handle_error_on_grpc_response(&FailureMode::Deny); // TODO(didierofrivia): Decide on what's the default failure mode } - self.resume_http_request(); } } diff --git a/src/operation_dispatcher.rs b/src/operation_dispatcher.rs index a9ab3c18..764a3d08 100644 --- a/src/operation_dispatcher.rs +++ b/src/operation_dispatcher.rs @@ -1,7 +1,8 @@ use crate::configuration::{Extension, ExtensionType, FailureMode}; use crate::envoy::RateLimitDescriptor; use crate::policy::Policy; -use crate::service::{GetMapValuesBytesFn, GrpcCallFn, GrpcMessage, GrpcServiceHandler}; +use crate::service::grpc_message::GrpcMessageRequest; +use crate::service::{GetMapValuesBytesFn, GrpcCallFn, GrpcServiceHandler}; use protobuf::RepeatedField; use proxy_wasm::hostcalls; use proxy_wasm::types::{Bytes, MapType, Status}; @@ -29,7 +30,7 @@ impl State { } } -type Procedure = (Rc, GrpcMessage); +type Procedure = (Rc, GrpcMessageRequest); #[allow(dead_code)] #[derive(Clone)] @@ -47,7 +48,7 @@ impl Operation { pub fn new(extension: Rc, procedure: Procedure) -> Self { Self { state: State::Pending, - result: Err(Status::Empty), + result: Ok(0), // Heuristics: zero represents that it's not been triggered, following `hostcalls` example extension, procedure, grpc_call_fn, @@ -55,38 +56,46 @@ impl Operation { } } - fn trigger(&mut self) { - if let State::Done = self.state { - } else { - self.result = self.procedure.0.send( - self.get_map_values_bytes_fn, - self.grpc_call_fn, - self.procedure.1.clone(), - ); - self.state.next(); + fn trigger(&mut self) -> Result { + match self.state { + State::Pending => { + self.result = self.procedure.0.send( + self.get_map_values_bytes_fn, + self.grpc_call_fn, + self.procedure.1.clone(), + ); + self.state.next(); + self.result + } + State::Waiting => { + self.state.next(); + self.result + } + State::Done => self.result, } } - pub fn get_state(&self) -> State { - self.state.clone() + pub fn get_state(&self) -> &State { + &self.state } pub fn get_result(&self) -> Result { self.result } - pub fn get_extension_type(&self) -> ExtensionType { - self.extension.extension_type.clone() + pub fn get_extension_type(&self) -> &ExtensionType { + &self.extension.extension_type } - pub fn get_failure_mode(&self) -> FailureMode { - self.extension.failure_mode.clone() + pub fn get_failure_mode(&self) -> &FailureMode { + &self.extension.failure_mode } } #[allow(dead_code)] pub struct OperationDispatcher { operations: RefCell>, + waiting_operations: RefCell>, // TODO(didierofrivia): Maybe keep references or Rc service_handlers: HashMap>, } @@ -95,6 +104,7 @@ impl OperationDispatcher { pub fn default() -> Self { OperationDispatcher { operations: RefCell::new(vec![]), + waiting_operations: RefCell::new(HashMap::default()), service_handlers: HashMap::default(), } } @@ -102,9 +112,14 @@ impl OperationDispatcher { Self { service_handlers, operations: RefCell::new(vec![]), + waiting_operations: RefCell::new(HashMap::new()), } } + pub fn get_operation(&self, token_id: u32) -> Option { + self.waiting_operations.borrow_mut().get(&token_id).cloned() + } + pub fn build_operations( &self, policy: &Policy, @@ -114,7 +129,7 @@ impl OperationDispatcher { policy.actions.iter().for_each(|action| { // TODO(didierofrivia): Error handling if let Some(service) = self.service_handlers.get(&action.extension) { - let message = GrpcMessage::new( + let message = GrpcMessageRequest::new( service.get_extension_type(), policy.domain.clone(), descriptors.clone(), @@ -147,11 +162,24 @@ impl OperationDispatcher { let mut operations = self.operations.borrow_mut(); if let Some((i, operation)) = operations.iter_mut().enumerate().next() { if let State::Done = operation.get_state() { + if let Ok(token_id) = operation.result { + self.waiting_operations.borrow_mut().remove(&token_id); + } // If result was Err, means the operation wasn't indexed operations.remove(i); - operations.get(i).cloned() // The next op is now at `i` - } else { - operation.trigger(); + // The next op is now at `i` + } + if let Some(operation) = operations.get_mut(i) { + if let Ok(token_id) = operation.trigger() { + if *operation.get_state() == State::Waiting { + // We index only if it was just transitioned to Waiting after triggering + self.waiting_operations + .borrow_mut() + .insert(token_id, operation.clone()); + } // TODO(didierofrivia): Decide on indexing the failed operations. + } Some(operation.clone()) + } else { + None } } else { None @@ -187,7 +215,7 @@ mod tests { use crate::envoy::RateLimitRequest; use std::time::Duration; - fn grpc_call_fn_stub( + fn default_grpc_call_fn_stub( _upstream_name: &str, _service_name: &str, _method_name: &str, @@ -219,14 +247,18 @@ mod tests { } } - fn build_operation() -> Operation { + fn build_operation(grpc_call_fn_stub: GrpcCallFn, extension_type: ExtensionType) -> Operation { Operation { state: State::Pending, - result: Ok(1), - extension: Rc::new(Extension::default()), + result: Ok(0), + extension: Rc::new(Extension { + extension_type, + endpoint: "local".to_string(), + failure_mode: FailureMode::Deny, + }), procedure: ( Rc::new(build_grpc_service_handler()), - GrpcMessage::RateLimit(build_message()), + GrpcMessageRequest::RateLimit(build_message()), ), grpc_call_fn: grpc_call_fn_stub, get_map_values_bytes_fn: get_map_values_bytes_fn_stub, @@ -235,23 +267,26 @@ mod tests { #[test] fn operation_getters() { - let operation = build_operation(); + let operation = build_operation(default_grpc_call_fn_stub, ExtensionType::RateLimit); - assert_eq!(operation.get_state(), State::Pending); - assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit); - assert_eq!(operation.get_failure_mode(), FailureMode::Deny); - assert_eq!(operation.get_result(), Ok(1)); + assert_eq!(*operation.get_state(), State::Pending); + assert_eq!(*operation.get_extension_type(), ExtensionType::RateLimit); + assert_eq!(*operation.get_failure_mode(), FailureMode::Deny); + assert_eq!(operation.get_result(), Ok(0)); } #[test] fn operation_transition() { - let mut operation = build_operation(); - assert_eq!(operation.get_state(), State::Pending); - operation.trigger(); - assert_eq!(operation.get_state(), State::Waiting); - operation.trigger(); + let mut operation = build_operation(default_grpc_call_fn_stub, ExtensionType::RateLimit); + assert_eq!(operation.result, Ok(0)); + assert_eq!(*operation.get_state(), State::Pending); + let mut res = operation.trigger(); + assert_eq!(res, Ok(200)); + assert_eq!(*operation.get_state(), State::Waiting); + res = operation.trigger(); + assert_eq!(res, Ok(200)); assert_eq!(operation.result, Ok(200)); - assert_eq!(operation.get_state(), State::Done); + assert_eq!(*operation.get_state(), State::Done); } #[test] @@ -259,7 +294,10 @@ mod tests { let operation_dispatcher = OperationDispatcher::default(); assert_eq!(operation_dispatcher.operations.borrow().len(), 0); - operation_dispatcher.push_operations(vec![build_operation()]); + operation_dispatcher.push_operations(vec![build_operation( + default_grpc_call_fn_stub, + ExtensionType::RateLimit, + )]); assert_eq!(operation_dispatcher.operations.borrow().len(), 1); } @@ -267,7 +305,10 @@ mod tests { #[test] fn operation_dispatcher_get_current_action_state() { let operation_dispatcher = OperationDispatcher::default(); - operation_dispatcher.push_operations(vec![build_operation()]); + operation_dispatcher.push_operations(vec![build_operation( + default_grpc_call_fn_stub, + ExtensionType::RateLimit, + )]); assert_eq!( operation_dispatcher.get_current_operation_state(), Some(State::Pending) @@ -277,36 +318,86 @@ mod tests { #[test] fn operation_dispatcher_next() { let operation_dispatcher = OperationDispatcher::default(); - operation_dispatcher.push_operations(vec![build_operation(), build_operation()]); - assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(1)); + fn grpc_call_fn_stub_66( + _upstream_name: &str, + _service_name: &str, + _method_name: &str, + _initial_metadata: Vec<(&str, &[u8])>, + _message: Option<&[u8]>, + _timeout: Duration, + ) -> Result { + Ok(66) + } + + fn grpc_call_fn_stub_77( + _upstream_name: &str, + _service_name: &str, + _method_name: &str, + _initial_metadata: Vec<(&str, &[u8])>, + _message: Option<&[u8]>, + _timeout: Duration, + ) -> Result { + Ok(77) + } + + operation_dispatcher.push_operations(vec![ + build_operation(grpc_call_fn_stub_66, ExtensionType::RateLimit), + build_operation(grpc_call_fn_stub_77, ExtensionType::Auth), + ]); + + assert_eq!(operation_dispatcher.get_current_operation_result(), Ok(0)); assert_eq!( operation_dispatcher.get_current_operation_state(), Some(State::Pending) ); + assert_eq!( + operation_dispatcher.waiting_operations.borrow_mut().len(), + 0 + ); let mut op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().get_result(), Ok(200)); - assert_eq!(op.unwrap().get_state(), State::Waiting); - - op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().get_result(), Ok(200)); - assert_eq!(op.unwrap().get_state(), State::Done); + assert_eq!(op.clone().unwrap().get_result(), Ok(66)); + assert_eq!( + *op.clone().unwrap().get_extension_type(), + ExtensionType::RateLimit + ); + assert_eq!(*op.unwrap().get_state(), State::Waiting); + assert_eq!( + operation_dispatcher.waiting_operations.borrow_mut().len(), + 1 + ); op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().get_result(), Ok(1)); - assert_eq!(op.unwrap().get_state(), State::Pending); + assert_eq!(op.clone().unwrap().get_result(), Ok(66)); + assert_eq!(*op.unwrap().get_state(), State::Done); op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().get_result(), Ok(200)); - assert_eq!(op.unwrap().get_state(), State::Waiting); + assert_eq!(op.clone().unwrap().get_result(), Ok(77)); + assert_eq!( + *op.clone().unwrap().get_extension_type(), + ExtensionType::Auth + ); + assert_eq!(*op.unwrap().get_state(), State::Waiting); + assert_eq!( + operation_dispatcher.waiting_operations.borrow_mut().len(), + 1 + ); op = operation_dispatcher.next(); - assert_eq!(op.clone().unwrap().get_result(), Ok(200)); - assert_eq!(op.unwrap().get_state(), State::Done); + assert_eq!(op.clone().unwrap().get_result(), Ok(77)); + assert_eq!(*op.unwrap().get_state(), State::Done); + assert_eq!( + operation_dispatcher.waiting_operations.borrow_mut().len(), + 1 + ); op = operation_dispatcher.next(); assert!(op.is_none()); assert!(operation_dispatcher.get_current_operation_state().is_none()); + assert_eq!( + operation_dispatcher.waiting_operations.borrow_mut().len(), + 0 + ); } } diff --git a/src/service.rs b/src/service.rs index e89077f2..e712dfdb 100644 --- a/src/service.rs +++ b/src/service.rs @@ -1,148 +1,18 @@ pub(crate) mod auth; +pub(crate) mod grpc_message; pub(crate) mod rate_limit; use crate::configuration::{Extension, ExtensionType, FailureMode}; -use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest}; -use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; -use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; +use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME}; +use crate::service::grpc_message::GrpcMessageRequest; +use crate::service::rate_limit::{RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME}; use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate}; -use protobuf::reflect::MessageDescriptor; -use protobuf::{ - Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields, -}; +use protobuf::Message; use proxy_wasm::types::{Bytes, MapType, Status}; -use std::any::Any; use std::cell::OnceCell; -use std::fmt::Debug; use std::rc::Rc; use std::time::Duration; -#[derive(Clone, Debug)] -pub enum GrpcMessage { - Auth(CheckRequest), - RateLimit(RateLimitRequest), -} - -impl Default for GrpcMessage { - fn default() -> Self { - GrpcMessage::RateLimit(RateLimitRequest::new()) - } -} - -impl Clear for GrpcMessage { - fn clear(&mut self) { - match self { - GrpcMessage::Auth(msg) => msg.clear(), - GrpcMessage::RateLimit(msg) => msg.clear(), - } - } -} - -impl Message for GrpcMessage { - fn descriptor(&self) -> &'static MessageDescriptor { - match self { - GrpcMessage::Auth(msg) => msg.descriptor(), - GrpcMessage::RateLimit(msg) => msg.descriptor(), - } - } - - fn is_initialized(&self) -> bool { - match self { - GrpcMessage::Auth(msg) => msg.is_initialized(), - GrpcMessage::RateLimit(msg) => msg.is_initialized(), - } - } - - fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { - match self { - GrpcMessage::Auth(msg) => msg.merge_from(is), - GrpcMessage::RateLimit(msg) => msg.merge_from(is), - } - } - - fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { - match self { - GrpcMessage::Auth(msg) => msg.write_to_with_cached_sizes(os), - GrpcMessage::RateLimit(msg) => msg.write_to_with_cached_sizes(os), - } - } - - fn write_to_bytes(&self) -> ProtobufResult> { - match self { - GrpcMessage::Auth(msg) => msg.write_to_bytes(), - GrpcMessage::RateLimit(msg) => msg.write_to_bytes(), - } - } - - fn compute_size(&self) -> u32 { - match self { - GrpcMessage::Auth(msg) => msg.compute_size(), - GrpcMessage::RateLimit(msg) => msg.compute_size(), - } - } - - fn get_cached_size(&self) -> u32 { - match self { - GrpcMessage::Auth(msg) => msg.get_cached_size(), - GrpcMessage::RateLimit(msg) => msg.get_cached_size(), - } - } - - fn get_unknown_fields(&self) -> &UnknownFields { - match self { - GrpcMessage::Auth(msg) => msg.get_unknown_fields(), - GrpcMessage::RateLimit(msg) => msg.get_unknown_fields(), - } - } - - fn mut_unknown_fields(&mut self) -> &mut UnknownFields { - match self { - GrpcMessage::Auth(msg) => msg.mut_unknown_fields(), - GrpcMessage::RateLimit(msg) => msg.mut_unknown_fields(), - } - } - - fn as_any(&self) -> &dyn Any { - match self { - GrpcMessage::Auth(msg) => msg.as_any(), - GrpcMessage::RateLimit(msg) => msg.as_any(), - } - } - - fn new() -> Self - where - Self: Sized, - { - // Returning default value - GrpcMessage::default() - } - - fn default_instance() -> &'static Self - where - Self: Sized, - { - #[allow(non_upper_case_globals)] - static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; - instance.get(|| GrpcMessage::RateLimit(RateLimitRequest::new())) - } -} - -impl GrpcMessage { - // Using domain as ce_host for the time being, we might pass a DataType in the future. - pub fn new( - extension_type: ExtensionType, - domain: String, - descriptors: protobuf::RepeatedField, - ) -> Self { - match extension_type { - ExtensionType::RateLimit => { - GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)) - } - ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())), - } - } -} - #[derive(Default)] pub struct GrpcService { #[allow(dead_code)] @@ -176,6 +46,7 @@ impl GrpcService { fn method(&self) -> &str { self.method } + #[allow(dead_code)] pub fn failure_mode(&self) -> &FailureMode { &self.extension.failure_mode } @@ -209,7 +80,7 @@ impl GrpcServiceHandler { &self, get_map_values_bytes_fn: GetMapValuesBytesFn, grpc_call_fn: GrpcCallFn, - message: GrpcMessage, + message: GrpcMessageRequest, ) -> Result { let msg = Message::write_to_bytes(&message).unwrap(); let metadata = self diff --git a/src/service/auth.rs b/src/service/auth.rs index 0831cd6c..ece695cb 100644 --- a/src/service/auth.rs +++ b/src/service/auth.rs @@ -3,10 +3,12 @@ use crate::envoy::{ Address, AttributeContext, AttributeContext_HttpRequest, AttributeContext_Peer, AttributeContext_Request, CheckRequest, Metadata, SocketAddress, }; +use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; use chrono::{DateTime, FixedOffset, Timelike}; use protobuf::well_known_types::Timestamp; +use protobuf::Message; use proxy_wasm::hostcalls; -use proxy_wasm::types::MapType; +use proxy_wasm::types::{Bytes, MapType}; use std::collections::HashMap; pub const AUTH_SERVICE_NAME: &str = "envoy.service.auth.v3.Authorization"; @@ -16,10 +18,35 @@ pub struct AuthService; #[allow(dead_code)] impl AuthService { - pub fn message(ce_host: String) -> CheckRequest { + pub fn request_message(ce_host: String) -> CheckRequest { AuthService::build_check_req(ce_host) } + pub fn response_message( + res_body_bytes: &Bytes, + status_code: u32, + ) -> GrpcMessageResult { + if status_code % 2 == 0 { + AuthService::response_message_ok(res_body_bytes) + } else { + AuthService::response_message_denied(res_body_bytes) + } + } + + fn response_message_ok(res_body_bytes: &Bytes) -> GrpcMessageResult { + match Message::parse_from_bytes(res_body_bytes) { + Ok(res) => Ok(GrpcMessageResponse::AuthOk(res)), + Err(e) => Err(e), + } + } + + fn response_message_denied(res_body_bytes: &Bytes) -> GrpcMessageResult { + match Message::parse_from_bytes(res_body_bytes) { + Ok(res) => Ok(GrpcMessageResponse::AuthDenied(res)), + Err(e) => Err(e), + } + } + fn build_check_req(ce_host: String) -> CheckRequest { let mut auth_req = CheckRequest::default(); let mut attr = AttributeContext::default(); diff --git a/src/service/grpc_message.rs b/src/service/grpc_message.rs new file mode 100644 index 00000000..0a6f514f --- /dev/null +++ b/src/service/grpc_message.rs @@ -0,0 +1,275 @@ +use crate::configuration::ExtensionType; +use crate::envoy::{ + CheckRequest, DeniedHttpResponse, OkHttpResponse, RateLimitDescriptor, RateLimitRequest, + RateLimitResponse, +}; +use crate::service::auth::AuthService; +use crate::service::rate_limit::RateLimitService; +use protobuf::reflect::MessageDescriptor; +use protobuf::{ + Clear, CodedInputStream, CodedOutputStream, Message, ProtobufError, ProtobufResult, + UnknownFields, +}; +use proxy_wasm::types::Bytes; +use std::any::Any; + +#[derive(Clone, Debug)] +pub enum GrpcMessageRequest { + Auth(CheckRequest), + RateLimit(RateLimitRequest), +} + +impl Default for GrpcMessageRequest { + fn default() -> Self { + GrpcMessageRequest::RateLimit(RateLimitRequest::new()) + } +} + +impl Clear for GrpcMessageRequest { + fn clear(&mut self) { + match self { + GrpcMessageRequest::Auth(msg) => msg.clear(), + GrpcMessageRequest::RateLimit(msg) => msg.clear(), + } + } +} + +impl Message for GrpcMessageRequest { + fn descriptor(&self) -> &'static MessageDescriptor { + match self { + GrpcMessageRequest::Auth(msg) => msg.descriptor(), + GrpcMessageRequest::RateLimit(msg) => msg.descriptor(), + } + } + + fn is_initialized(&self) -> bool { + match self { + GrpcMessageRequest::Auth(msg) => msg.is_initialized(), + GrpcMessageRequest::RateLimit(msg) => msg.is_initialized(), + } + } + + fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { + match self { + GrpcMessageRequest::Auth(msg) => msg.merge_from(is), + GrpcMessageRequest::RateLimit(msg) => msg.merge_from(is), + } + } + + fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { + match self { + GrpcMessageRequest::Auth(msg) => msg.write_to_with_cached_sizes(os), + GrpcMessageRequest::RateLimit(msg) => msg.write_to_with_cached_sizes(os), + } + } + + fn write_to_bytes(&self) -> ProtobufResult> { + match self { + GrpcMessageRequest::Auth(msg) => msg.write_to_bytes(), + GrpcMessageRequest::RateLimit(msg) => msg.write_to_bytes(), + } + } + + fn compute_size(&self) -> u32 { + match self { + GrpcMessageRequest::Auth(msg) => msg.compute_size(), + GrpcMessageRequest::RateLimit(msg) => msg.compute_size(), + } + } + + fn get_cached_size(&self) -> u32 { + match self { + GrpcMessageRequest::Auth(msg) => msg.get_cached_size(), + GrpcMessageRequest::RateLimit(msg) => msg.get_cached_size(), + } + } + + fn get_unknown_fields(&self) -> &UnknownFields { + match self { + GrpcMessageRequest::Auth(msg) => msg.get_unknown_fields(), + GrpcMessageRequest::RateLimit(msg) => msg.get_unknown_fields(), + } + } + + fn mut_unknown_fields(&mut self) -> &mut UnknownFields { + match self { + GrpcMessageRequest::Auth(msg) => msg.mut_unknown_fields(), + GrpcMessageRequest::RateLimit(msg) => msg.mut_unknown_fields(), + } + } + + fn as_any(&self) -> &dyn Any { + match self { + GrpcMessageRequest::Auth(msg) => msg.as_any(), + GrpcMessageRequest::RateLimit(msg) => msg.as_any(), + } + } + + fn new() -> Self + where + Self: Sized, + { + // Returning default value + GrpcMessageRequest::default() + } + + fn default_instance() -> &'static Self + where + Self: Sized, + { + #[allow(non_upper_case_globals)] + static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; + instance.get(|| GrpcMessageRequest::RateLimit(RateLimitRequest::new())) + } +} + +impl GrpcMessageRequest { + // Using domain as ce_host for the time being, we might pass a DataType in the future. + pub fn new( + extension_type: ExtensionType, + domain: String, + descriptors: protobuf::RepeatedField, + ) -> Self { + match extension_type { + ExtensionType::RateLimit => GrpcMessageRequest::RateLimit( + RateLimitService::request_message(domain.clone(), descriptors), + ), + ExtensionType::Auth => { + GrpcMessageRequest::Auth(AuthService::request_message(domain.clone())) + } + } + } +} + +#[derive(Clone, Debug)] +pub enum GrpcMessageResponse { + AuthOk(OkHttpResponse), + AuthDenied(DeniedHttpResponse), + RateLimit(RateLimitResponse), +} + +impl Default for GrpcMessageResponse { + fn default() -> Self { + GrpcMessageResponse::RateLimit(RateLimitResponse::new()) + } +} + +impl Clear for GrpcMessageResponse { + fn clear(&mut self) { + todo!() + } +} + +impl Message for GrpcMessageResponse { + fn descriptor(&self) -> &'static MessageDescriptor { + match self { + GrpcMessageResponse::AuthOk(res) => res.descriptor(), + GrpcMessageResponse::AuthDenied(res) => res.descriptor(), + GrpcMessageResponse::RateLimit(res) => res.descriptor(), + } + } + + fn is_initialized(&self) -> bool { + match self { + GrpcMessageResponse::AuthOk(res) => res.is_initialized(), + GrpcMessageResponse::AuthDenied(res) => res.is_initialized(), + GrpcMessageResponse::RateLimit(res) => res.is_initialized(), + } + } + + fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> { + match self { + GrpcMessageResponse::AuthOk(res) => res.merge_from(is), + GrpcMessageResponse::AuthDenied(res) => res.merge_from(is), + GrpcMessageResponse::RateLimit(res) => res.merge_from(is), + } + } + + fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> { + match self { + GrpcMessageResponse::AuthOk(res) => res.write_to_with_cached_sizes(os), + GrpcMessageResponse::AuthDenied(res) => res.write_to_with_cached_sizes(os), + GrpcMessageResponse::RateLimit(res) => res.write_to_with_cached_sizes(os), + } + } + + fn write_to_bytes(&self) -> ProtobufResult> { + match self { + GrpcMessageResponse::AuthOk(res) => res.write_to_bytes(), + GrpcMessageResponse::AuthDenied(res) => res.write_to_bytes(), + GrpcMessageResponse::RateLimit(res) => res.write_to_bytes(), + } + } + + fn compute_size(&self) -> u32 { + match self { + GrpcMessageResponse::AuthOk(res) => res.compute_size(), + GrpcMessageResponse::AuthDenied(res) => res.compute_size(), + GrpcMessageResponse::RateLimit(res) => res.compute_size(), + } + } + + fn get_cached_size(&self) -> u32 { + match self { + GrpcMessageResponse::AuthOk(res) => res.get_cached_size(), + GrpcMessageResponse::AuthDenied(res) => res.get_cached_size(), + GrpcMessageResponse::RateLimit(res) => res.get_cached_size(), + } + } + + fn get_unknown_fields(&self) -> &UnknownFields { + match self { + GrpcMessageResponse::AuthOk(res) => res.get_unknown_fields(), + GrpcMessageResponse::AuthDenied(res) => res.get_unknown_fields(), + GrpcMessageResponse::RateLimit(res) => res.get_unknown_fields(), + } + } + + fn mut_unknown_fields(&mut self) -> &mut UnknownFields { + match self { + GrpcMessageResponse::AuthOk(res) => res.mut_unknown_fields(), + GrpcMessageResponse::AuthDenied(res) => res.mut_unknown_fields(), + GrpcMessageResponse::RateLimit(res) => res.mut_unknown_fields(), + } + } + + fn as_any(&self) -> &dyn Any { + match self { + GrpcMessageResponse::AuthOk(res) => res.as_any(), + GrpcMessageResponse::AuthDenied(res) => res.as_any(), + GrpcMessageResponse::RateLimit(res) => res.as_any(), + } + } + + fn new() -> Self + where + Self: Sized, + { + // Returning default value + GrpcMessageResponse::default() + } + + fn default_instance() -> &'static Self + where + Self: Sized, + { + #[allow(non_upper_case_globals)] + static instance: ::protobuf::rt::LazyV2 = ::protobuf::rt::LazyV2::INIT; + instance.get(|| GrpcMessageResponse::RateLimit(RateLimitResponse::new())) + } +} + +impl GrpcMessageResponse { + pub fn new( + extension_type: &ExtensionType, + res_body_bytes: &Bytes, + status_code: u32, + ) -> GrpcMessageResult { + match extension_type { + ExtensionType::RateLimit => RateLimitService::response_message(res_body_bytes), + ExtensionType::Auth => AuthService::response_message(res_body_bytes, status_code), + } + } +} + +pub type GrpcMessageResult = Result; diff --git a/src/service/rate_limit.rs b/src/service/rate_limit.rs index b6b0357c..4a81884a 100644 --- a/src/service/rate_limit.rs +++ b/src/service/rate_limit.rs @@ -1,5 +1,7 @@ use crate::envoy::{RateLimitDescriptor, RateLimitRequest}; -use protobuf::RepeatedField; +use crate::service::grpc_message::{GrpcMessageResponse, GrpcMessageResult}; +use protobuf::{Message, RepeatedField}; +use proxy_wasm::types::Bytes; pub const RATELIMIT_SERVICE_NAME: &str = "envoy.service.ratelimit.v3.RateLimitService"; pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; @@ -7,7 +9,7 @@ pub const RATELIMIT_METHOD_NAME: &str = "ShouldRateLimit"; pub struct RateLimitService; impl RateLimitService { - pub fn message( + pub fn request_message( domain: String, descriptors: RepeatedField, ) -> RateLimitRequest { @@ -19,6 +21,13 @@ impl RateLimitService { cached_size: Default::default(), } } + + pub fn response_message(res_body_bytes: &Bytes) -> GrpcMessageResult { + match Message::parse_from_bytes(res_body_bytes) { + Ok(res) => Ok(GrpcMessageResponse::RateLimit(res)), + Err(e) => Err(e), + } + } } #[cfg(test)] @@ -37,7 +46,7 @@ mod tests { field.set_entries(RepeatedField::from_vec(vec![entry])); let descriptors = RepeatedField::from_vec(vec![field]); - RateLimitService::message(domain.to_string(), descriptors.clone()) + RateLimitService::request_message(domain.to_string(), descriptors.clone()) } #[test] fn builds_correct_message() {