diff --git a/examples/http_auth_random.rs b/examples/http_auth_random.rs index b9e747c6..ddb5fbf7 100644 --- a/examples/http_auth_random.rs +++ b/examples/http_auth_random.rs @@ -20,7 +20,17 @@ use std::time::Duration; #[no_mangle] pub fn _start() { proxy_wasm::set_log_level(LogLevel::Trace); - proxy_wasm::set_http_context(|_, _| -> Box { Box::new(HttpAuthRandom) }); + proxy_wasm::set_root_context(|_| -> Box { Box::new(HttpAuthRandomRoot) }); +} + +struct HttpAuthRandomRoot; + +impl Context for HttpAuthRandomRoot {} + +impl RootContext for HttpAuthRandomRoot { + fn on_create_child_context(&mut self, _context_id: u32) -> Option { + Some(ChildContext::HttpContext(Box::new(HttpAuthRandom))) + } } struct HttpAuthRandom; diff --git a/examples/http_body.rs b/examples/http_body.rs index ff884648..cf160fd0 100644 --- a/examples/http_body.rs +++ b/examples/http_body.rs @@ -26,12 +26,8 @@ struct HttpBodyRoot; impl Context for HttpBodyRoot {} impl RootContext for HttpBodyRoot { - fn get_type(&self) -> Option { - Some(ContextType::HttpContext) - } - - fn create_http_context(&self, _: u32) -> Option> { - Some(Box::new(HttpBody)) + fn on_create_child_context(&mut self, _context_id: u32) -> Option { + Some(ChildContext::HttpContext(Box::new(HttpBody))) } } diff --git a/examples/http_config.rs b/examples/http_config.rs index d912ae03..81838315 100644 --- a/examples/http_config.rs +++ b/examples/http_config.rs @@ -52,13 +52,9 @@ impl RootContext for HttpConfigHeaderRoot { true } - fn create_http_context(&self, _: u32) -> Option> { - Some(Box::new(HttpConfigHeader { + fn on_create_child_context(&mut self, _context_id: u32) -> Option { + Some(ChildContext::HttpContext(Box::new(HttpConfigHeader { header_content: self.header_content.clone(), - })) - } - - fn get_type(&self) -> Option { - Some(ContextType::HttpContext) + }))) } } diff --git a/examples/http_headers.rs b/examples/http_headers.rs index b0f1a745..d9692698 100644 --- a/examples/http_headers.rs +++ b/examples/http_headers.rs @@ -27,12 +27,10 @@ struct HttpHeadersRoot; impl Context for HttpHeadersRoot {} impl RootContext for HttpHeadersRoot { - fn get_type(&self) -> Option { - Some(ContextType::HttpContext) - } - - fn create_http_context(&self, context_id: u32) -> Option> { - Some(Box::new(HttpHeaders { context_id })) + fn on_create_child_context(&mut self, context_id: u32) -> Option { + Some(ChildContext::HttpContext(Box::new(HttpHeaders { + context_id, + }))) } } diff --git a/src/dispatcher.rs b/src/dispatcher.rs index d9c9d491..769921d0 100644 --- a/src/dispatcher.rs +++ b/src/dispatcher.rs @@ -26,14 +26,6 @@ pub(crate) fn set_root_context(callback: NewRootContext) { DISPATCHER.with(|dispatcher| dispatcher.set_root_context(callback)); } -pub(crate) fn set_stream_context(callback: NewStreamContext) { - DISPATCHER.with(|dispatcher| dispatcher.set_stream_context(callback)); -} - -pub(crate) fn set_http_context(callback: NewHttpContext) { - DISPATCHER.with(|dispatcher| dispatcher.set_http_context(callback)); -} - pub(crate) fn register_callout(token_id: u32) { DISPATCHER.with(|dispatcher| dispatcher.register_callout(token_id)); } @@ -46,9 +38,7 @@ impl RootContext for NoopRoot {} struct Dispatcher { new_root: Cell>, roots: RefCell>>, - new_stream: Cell>, streams: RefCell>>, - new_http_stream: Cell>, http_streams: RefCell>>, active_id: Cell, callouts: RefCell>, @@ -59,9 +49,7 @@ impl Dispatcher { Dispatcher { new_root: Cell::new(None), roots: RefCell::new(HashMap::new()), - new_stream: Cell::new(None), streams: RefCell::new(HashMap::new()), - new_http_stream: Cell::new(None), http_streams: RefCell::new(HashMap::new()), active_id: Cell::new(0), callouts: RefCell::new(HashMap::new()), @@ -72,14 +60,6 @@ impl Dispatcher { self.new_root.set(Some(callback)); } - fn set_stream_context(&self, callback: NewStreamContext) { - self.new_stream.set(Some(callback)); - } - - fn set_http_context(&self, callback: NewHttpContext) { - self.new_http_stream.set(Some(callback)); - } - fn register_callout(&self, token_id: u32) { if self .callouts @@ -106,67 +86,46 @@ impl Dispatcher { } } - fn create_stream_context(&self, context_id: u32, root_context_id: u32) { - let new_context = match self.roots.borrow().get(&root_context_id) { - Some(root_context) => match self.new_stream.get() { - Some(f) => f(context_id, root_context_id), - None => match root_context.create_stream_context(context_id) { - Some(stream_context) => stream_context, - None => panic!("create_stream_context returned None"), - }, - }, - None => panic!("invalid root_context_id"), - }; + fn register_stream_context(&self, context_id: u32, stream_context: Box) { if self .streams .borrow_mut() - .insert(context_id, new_context) + .insert(context_id, stream_context) .is_some() { - panic!("duplicate context_id") + panic!("duplicate context_id {}", context_id); } } - fn create_http_context(&self, context_id: u32, root_context_id: u32) { - let new_context = match self.roots.borrow().get(&root_context_id) { - Some(root_context) => match self.new_http_stream.get() { - Some(f) => f(context_id, root_context_id), - None => match root_context.create_http_context(context_id) { - Some(stream_context) => stream_context, - None => panic!("create_http_context returned None"), - }, - }, - None => panic!("invalid root_context_id"), - }; + fn register_http_context(&self, context_id: u32, http_context: Box) { if self .http_streams .borrow_mut() - .insert(context_id, new_context) + .insert(context_id, http_context) .is_some() { - panic!("duplicate context_id") + panic!("duplicate context_id {}", context_id); } } fn on_create_context(&self, context_id: u32, root_context_id: u32) { if root_context_id == 0 { self.create_root_context(context_id); - } else if self.new_http_stream.get().is_some() { - self.create_http_context(context_id, root_context_id); - } else if self.new_stream.get().is_some() { - self.create_stream_context(context_id, root_context_id); - } else if let Some(root_context) = self.roots.borrow().get(&root_context_id) { - match root_context.get_type() { - Some(ContextType::HttpContext) => { - self.create_http_context(context_id, root_context_id) + return; + } + + if let Some(root_context) = self.roots.borrow_mut().get_mut(&root_context_id) { + match root_context.on_create_child_context(context_id) { + Some(ChildContext::HttpContext(http_context)) => { + self.register_http_context(context_id, http_context); } - Some(ContextType::StreamContext) => { - self.create_stream_context(context_id, root_context_id) + Some(ChildContext::StreamContext(stream_context)) => { + self.register_stream_context(context_id, stream_context); } - None => panic!("missing ContextType on root_context"), + None => panic!("you must implement on_create_child_context in your root context"), } } else { - panic!("invalid root_context_id and missing constructors"); + panic!("invalid root_context_id {}", root_context_id); } } diff --git a/src/lib.rs b/src/lib.rs index cee98b72..ff55cbac 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,13 +28,5 @@ pub fn set_root_context(callback: types::NewRootContext) { dispatcher::set_root_context(callback); } -pub fn set_stream_context(callback: types::NewStreamContext) { - dispatcher::set_stream_context(callback); -} - -pub fn set_http_context(callback: types::NewHttpContext) { - dispatcher::set_http_context(callback); -} - #[no_mangle] pub extern "C" fn proxy_abi_version_0_1_0() {} diff --git a/src/traits.rs b/src/traits.rs index 5b7fc4be..a02a2512 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -122,15 +122,8 @@ pub trait RootContext: Context { fn on_log(&mut self) {} - fn create_http_context(&self, _context_id: u32) -> Option> { - None - } - - fn create_stream_context(&self, _context_id: u32) -> Option> { - None - } - - fn get_type(&self) -> Option { + fn on_create_child_context(&mut self, _context_id: u32) -> Option { + // on_create_child_context is required to create non root contexts None } } diff --git a/src/types.rs b/src/types.rs index 855a414b..743a9f8a 100644 --- a/src/types.rs +++ b/src/types.rs @@ -15,8 +15,11 @@ use crate::traits::*; pub type NewRootContext = fn(context_id: u32) -> Box; -pub type NewStreamContext = fn(context_id: u32, root_context_id: u32) -> Box; -pub type NewHttpContext = fn(context_id: u32, root_context_id: u32) -> Box; + +pub enum ChildContext { + StreamContext(Box), + HttpContext(Box), +} #[repr(u32)] #[derive(Debug)] @@ -47,13 +50,6 @@ pub enum Status { InternalFailure = 10, } -#[repr(u32)] -#[derive(Debug)] -pub enum ContextType { - HttpContext = 0, - StreamContext = 1, -} - #[repr(u32)] #[derive(Debug)] pub enum BufferType {