Skip to content

Commit

Permalink
[refactor] Operation responsible of providing hostcalls fns
Browse files Browse the repository at this point in the history
* Easier to test, mocking fn
* Assigned fn on creation, default hostcall and mock on tests

Signed-off-by: dd di cesare <[email protected]>
  • Loading branch information
didierofrivia committed Sep 6, 2024
1 parent 159d247 commit 41b920d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 67 deletions.
120 changes: 59 additions & 61 deletions src/operation_dispatcher.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use crate::configuration::{Extension, ExtensionType, FailureMode};
use crate::envoy::RateLimitDescriptor;
use crate::policy::Policy;
use crate::service::{GrpcMessage, GrpcServiceHandler};
use crate::service::{GetMapValuesBytes, GrpcCall, GrpcMessage, GrpcServiceHandler};
use protobuf::RepeatedField;
use proxy_wasm::hostcalls;
use proxy_wasm::types::Status;
use proxy_wasm::types::{Bytes, MapType, Status};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
Expand All @@ -29,24 +29,6 @@ impl State {
}
}

fn grpc_call(
upstream_name: &str,
service_name: &str,
method_name: &str,
initial_metadata: Vec<(&str, &[u8])>,
message: Option<&[u8]>,
timeout: Duration,
) -> Result<u32, Status> {
hostcalls::dispatch_grpc_call(
upstream_name,
service_name,
method_name,
initial_metadata,
message,
timeout,
)
}

type Procedure = (Rc<GrpcServiceHandler>, GrpcMessage);

#[allow(dead_code)]
Expand All @@ -56,6 +38,8 @@ pub(crate) struct Operation {
result: Result<u32, Status>,
extension: Rc<Extension>,
procedure: Procedure,
grpc_call: GrpcCall,
get_map_values_bytes: GetMapValuesBytes,
}

#[allow(dead_code)]
Expand All @@ -66,17 +50,19 @@ impl Operation {
result: Err(Status::Empty),
extension,
procedure,
grpc_call,
get_map_values_bytes,
}
}

pub fn set_action(&mut self, procedure: Procedure) {
self.procedure = procedure;
}

pub fn trigger(&mut self) {
fn trigger(&mut self) {
if let State::Done = self.state {
} else {
self.result = self.procedure.0.send(grpc_call, self.procedure.1.clone());
self.result = self.procedure.0.send(
self.get_map_values_bytes,
self.grpc_call,
self.procedure.1.clone(),
);
self.state.next();
}
}
Expand Down Expand Up @@ -172,6 +158,28 @@ impl OperationDispatcher {
}
}

fn grpc_call(
upstream_name: &str,
service_name: &str,
method_name: &str,
initial_metadata: Vec<(&str, &[u8])>,
message: Option<&[u8]>,
timeout: Duration,
) -> Result<u32, Status> {
hostcalls::dispatch_grpc_call(
upstream_name,
service_name,
method_name,
initial_metadata,
message,
timeout,
)
}

fn get_map_values_bytes(map_type: MapType, key: &str) -> Result<Option<Bytes>, Status> {
hostcalls::get_map_value_bytes(map_type, key)
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -189,6 +197,10 @@ mod tests {
Ok(200)
}

fn get_map_values_bytes(_map_type: MapType, _key: &str) -> Result<Option<Bytes>, Status> {
Ok(Some(Vec::new()))
}

fn build_grpc_service_handler() -> GrpcServiceHandler {
GrpcServiceHandler::new(Rc::new(Default::default()), Rc::new(Default::default()))
}
Expand All @@ -203,33 +215,33 @@ mod tests {
}
}

#[test]
fn operation_getters() {
let extension = Rc::new(Extension::default());
let operation = Operation::new(
extension,
(
fn build_operation() -> Operation {
Operation {
state: State::Pending,
result: Ok(200),
extension: Rc::new(Extension::default()),
procedure: (
Rc::new(build_grpc_service_handler()),
GrpcMessage::RateLimit(build_message()),
),
);
grpc_call,
get_map_values_bytes,
}
}

#[test]
fn operation_getters() {
let operation = build_operation();

assert_eq!(operation.get_state(), State::Pending);
assert_eq!(operation.get_extension_type(), ExtensionType::RateLimit);
assert_eq!(operation.get_failure_mode(), FailureMode::Deny);
assert_eq!(operation.get_result(), Result::Ok(1));
assert_eq!(operation.get_result(), Ok(200));
}

#[test]
fn operation_transition() {
let extension = Rc::new(Extension::default());
let mut operation = Operation::new(
extension,
(
Rc::new(build_grpc_service_handler()),
GrpcMessage::RateLimit(build_message()),
),
);
let mut operation = build_operation();
assert_eq!(operation.get_state(), State::Pending);
operation.trigger();
assert_eq!(operation.get_state(), State::Waiting);
Expand All @@ -242,23 +254,16 @@ mod tests {
fn operation_dispatcher_push_actions() {
let operation_dispatcher = OperationDispatcher::default();

assert_eq!(operation_dispatcher.operations.borrow().len(), 1);
let extension = Rc::new(Extension::default());
operation_dispatcher.push_operations(vec![Operation::new(
extension,
(
Rc::new(build_grpc_service_handler()),
GrpcMessage::RateLimit(build_message()),
),
)]);
assert_eq!(operation_dispatcher.operations.borrow().len(), 0);
operation_dispatcher.push_operations(vec![build_operation()]);

assert_eq!(operation_dispatcher.operations.borrow().len(), 2);
assert_eq!(operation_dispatcher.operations.borrow().len(), 1);
}

#[test]
fn operation_dispatcher_get_current_action_state() {
let operation_dispatcher = OperationDispatcher::default();

operation_dispatcher.push_operations(vec![build_operation()]);
assert_eq!(
operation_dispatcher.get_current_operation_state(),
Some(State::Pending)
Expand All @@ -267,14 +272,7 @@ mod tests {

#[test]
fn operation_dispatcher_next() {
let extension = Rc::new(Extension::default());
let operation = Operation::new(
extension,
(
Rc::new(build_grpc_service_handler()),
GrpcMessage::RateLimit(build_message()),
),
);
let operation = build_operation();
let operation_dispatcher = OperationDispatcher::default();
operation_dispatcher.push_operations(vec![operation]);

Expand Down
18 changes: 12 additions & 6 deletions src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use protobuf::reflect::MessageDescriptor;
use protobuf::{
Clear, CodedInputStream, CodedOutputStream, Message, ProtobufResult, UnknownFields,
};
use proxy_wasm::hostcalls;
use proxy_wasm::types::{Bytes, MapType, Status};
use std::any::Any;
use std::cell::OnceCell;
Expand Down Expand Up @@ -182,7 +181,7 @@ impl GrpcService {
}
}

type GrpcCall = fn(
pub type GrpcCall = fn(
upstream_name: &str,
service_name: &str,
method_name: &str,
Expand All @@ -191,6 +190,8 @@ type GrpcCall = fn(
timeout: Duration,
) -> Result<u32, Status>;

pub type GetMapValuesBytes = fn(map_type: MapType, key: &str) -> Result<Option<Bytes>, Status>;

pub struct GrpcServiceHandler {
service: Rc<GrpcService>,
header_resolver: Rc<HeaderResolver>,
Expand All @@ -204,11 +205,16 @@ impl GrpcServiceHandler {
}
}

pub fn send(&self, grpc_call: GrpcCall, message: GrpcMessage) -> Result<u32, Status> {
pub fn send(
&self,
get_map_values_bytes: GetMapValuesBytes,
grpc_call: GrpcCall,
message: GrpcMessage,
) -> Result<u32, Status> {
let msg = Message::write_to_bytes(&message).unwrap();
let metadata = self
.header_resolver
.get()
.get(get_map_values_bytes)
.iter()
.map(|(header, value)| (*header, value.as_slice()))
.collect();
Expand Down Expand Up @@ -249,12 +255,12 @@ impl HeaderResolver {
}
}

pub fn get(&self) -> &Vec<(&'static str, Bytes)> {
pub fn get(&self, get_map_values_bytes: GetMapValuesBytes) -> &Vec<(&'static str, Bytes)> {
self.headers.get_or_init(|| {
let mut headers = Vec::new();
for header in TracingHeader::all() {
if let Ok(Some(value)) =
hostcalls::get_map_value_bytes(MapType::HttpRequestHeaders, (*header).as_str())
get_map_values_bytes(MapType::HttpRequestHeaders, (*header).as_str())
{
headers.push(((*header).as_str(), value));
}
Expand Down

0 comments on commit 41b920d

Please sign in to comment.