Skip to content

Commit

Permalink
[refactor] Implementing own Message for GrpcMessage
Browse files Browse the repository at this point in the history
Signed-off-by: dd di cesare <[email protected]>
  • Loading branch information
didierofrivia committed Sep 4, 2024
1 parent c4796de commit ca3cdff
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/operation_dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl OperationDispatcher {
policy.actions.iter().for_each(|action| {
// TODO(didierofrivia): Error handling
if let Some(service) = self.service_handlers.get(&action.extension) {
let message = service.build_message(policy.domain.clone(), descriptors.clone());
let message = GrpcMessage::new(service.get_extension_type(), policy.domain.clone(), descriptors.clone());
operations.push(Operation::new((service.clone(), message)))
}
});
Expand Down
145 changes: 123 additions & 22 deletions src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,144 @@ pub(crate) mod auth;
pub(crate) mod rate_limit;

use crate::configuration::{ExtensionType, FailureMode};
use crate::envoy::{RateLimitDescriptor, RateLimitRequest};
use crate::service::auth::{AUTH_METHOD_NAME, AUTH_SERVICE_NAME};
use crate::envoy::{CheckRequest, RateLimitDescriptor, RateLimitRequest};
use crate::service::auth::{AuthService, AUTH_METHOD_NAME, AUTH_SERVICE_NAME};
use crate::service::rate_limit::{RateLimitService, RATELIMIT_METHOD_NAME, RATELIMIT_SERVICE_NAME};
use crate::service::TracingHeader::{Baggage, Traceparent, Tracestate};
use protobuf::Message;
use protobuf::reflect::MessageDescriptor;
use protobuf::{
Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields,
};
use proxy_wasm::hostcalls;
use proxy_wasm::hostcalls::dispatch_grpc_call;
use proxy_wasm::types::{Bytes, MapType, Status};
use std::any::Any;
use std::cell::OnceCell;
use std::fmt::{Debug};
use std::rc::Rc;
use std::time::Duration;

#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum GrpcMessage {
//Auth(CheckRequest),
Auth(CheckRequest),
RateLimit(RateLimitRequest),
}

impl GrpcMessage {
pub fn get_message(&self) -> &RateLimitRequest {
//TODO(didierofrivia): Should return Message
impl Default for GrpcMessage {
fn default() -> Self {
GrpcMessage::RateLimit(RateLimitRequest::new())
}
}

impl Clear for GrpcMessage {
fn clear(&mut self) {
match self {
GrpcMessage::RateLimit(message) => message,
GrpcMessage::Auth(msg) => msg.clear(),
GrpcMessage::RateLimit(msg) => msg.clear(),
}
}
}

impl Message for GrpcMessage {
fn descriptor(&self) -> &'static MessageDescriptor {
match self {
GrpcMessage::Auth(msg) => msg.descriptor(),
GrpcMessage::RateLimit(msg) => msg.descriptor(),
}
}

fn is_initialized(&self) -> bool {
match self {
GrpcMessage::Auth(msg) => msg.is_initialized(),
GrpcMessage::RateLimit(msg) => msg.is_initialized(),
}
}

fn merge_from(&mut self, is: &mut CodedInputStream) -> ProtobufResult<()> {
match self {
GrpcMessage::Auth(msg) => msg.merge_from(is),
GrpcMessage::RateLimit(msg) => msg.merge_from(is),
}
}

fn write_to_with_cached_sizes(&self, os: &mut CodedOutputStream) -> ProtobufResult<()> {
match self {
GrpcMessage::Auth(msg) => msg.write_to_with_cached_sizes(os),
GrpcMessage::RateLimit(msg) => msg.write_to_with_cached_sizes(os),
}
}

fn write_to_bytes(&self) -> ProtobufResult<Vec<u8>> {
match self {
GrpcMessage::Auth(msg) => msg.write_to_bytes(),
GrpcMessage::RateLimit(msg) => msg.write_to_bytes(),
}
}

fn compute_size(&self) -> u32 {
match self {
GrpcMessage::Auth(msg) => msg.compute_size(),
GrpcMessage::RateLimit(msg) => msg.compute_size(),
}
}

fn get_cached_size(&self) -> u32 {
match self {
GrpcMessage::Auth(msg) => msg.get_cached_size(),
GrpcMessage::RateLimit(msg) => msg.get_cached_size(),
}
}

fn get_unknown_fields(&self) -> &UnknownFields {
match self {
GrpcMessage::Auth(msg) => msg.get_unknown_fields(),
GrpcMessage::RateLimit(msg) => msg.get_unknown_fields(),
}
}

fn mut_unknown_fields(&mut self) -> &mut UnknownFields {
match self {
GrpcMessage::Auth(msg) => msg.mut_unknown_fields(),
GrpcMessage::RateLimit(msg) => msg.mut_unknown_fields(),
}
}

fn as_any(&self) -> &dyn Any {
match self {
GrpcMessage::Auth(msg) => msg.as_any(),
GrpcMessage::RateLimit(msg) => msg.as_any(),
}
}

fn new() -> Self
where
Self: Sized,
{
// Returning default value
GrpcMessage::default()
}

fn default_instance() -> &'static Self
where
Self: Sized,
{
#[allow(non_upper_case_globals)]
static instance: ::protobuf::rt::LazyV2<GrpcMessage> = ::protobuf::rt::LazyV2::INIT;
instance.get(|| GrpcMessage::RateLimit(RateLimitRequest::new()))
}
}

impl GrpcMessage {
// Using domain as ce_host for the time being, we might pass a DataType in the future.
pub fn new(extension_type: ExtensionType, domain: String, descriptors: protobuf::RepeatedField<RateLimitDescriptor>) -> Self {
match extension_type {
ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)),
ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone()))
}
}

}

#[derive(Default)]
pub struct GrpcService {
endpoint: String,
Expand Down Expand Up @@ -102,7 +213,7 @@ impl GrpcServiceHandler {
}

pub fn send(&self, message: GrpcMessage) -> Result<u32, Status> {
let msg = Message::write_to_bytes(message.get_message()).unwrap();
let msg = Message::write_to_bytes(&message).unwrap();
let metadata = self
.header_resolver
.get()
Expand All @@ -120,18 +231,8 @@ impl GrpcServiceHandler {
)
}

// Using domain as ce_host for the time being, we might pass a DataType in the future.
//TODO(didierofrivia): Make it work with Message. for both Auth and RL
pub fn build_message(
&self,
domain: String,
descriptors: protobuf::RepeatedField<RateLimitDescriptor>,
) -> GrpcMessage {
/*match self.service.extension_type {
//ExtensionType::Auth => GrpcMessage::Auth(AuthService::message(domain.clone())),
//ExtensionType::RateLimit => GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors)),
}*/
GrpcMessage::RateLimit(RateLimitService::message(domain.clone(), descriptors))
pub fn get_extension_type(&self) -> ExtensionType {
self.service.extension_type.clone()
}
}

Expand Down

0 comments on commit ca3cdff

Please sign in to comment.