diff --git a/pgrx-tests/src/tests/hooks_tests.rs b/pgrx-tests/src/tests/hooks_tests.rs index bc46d21259..67aefcc49c 100644 --- a/pgrx-tests/src/tests/hooks_tests.rs +++ b/pgrx-tests/src/tests/hooks_tests.rs @@ -121,13 +121,34 @@ mod tests { self.events += 1; prev_hook(parse_state, query, jumble_state) } + + fn object_access_hook( + &mut self, + access: pg_sys::ObjectAccessType, + class_id: pg_sys::Oid, + object_id: pg_sys::Oid, + sub_id: ::std::os::raw::c_int, + arg: *mut ::std::os::raw::c_void, + prev_hook: fn( + access: pg_sys::ObjectAccessType, + class_id: pg_sys::Oid, + object_id: pg_sys::Oid, + sub_id: ::std::os::raw::c_int, + arg: *mut ::std::os::raw::c_void, + ) -> HookResult<()>, + ) -> HookResult<()> { + self.events += 1; + prev_hook(access, class_id, object_id, sub_id, arg) + } } + Spi::run("CREATE TABLE foo (bar int)").expect("SPI failed"); + static mut HOOK: TestHook = TestHook { events: 0 }; pgrx::hooks::register_hook(&mut HOOK); // To trigger the emit_log hook, we need something to log. // We therefore ensure the select statement will be logged. - Spi::run("SET local log_statement to 'all'; SELECT 1").expect("SPI failed"); + Spi::run("SET local log_statement to 'all'; SELECT * from foo").expect("SPI failed"); assert_eq!(8, HOOK.events); // TODO: it'd be nice to also test that .commit() and .abort() also get called diff --git a/pgrx/src/hooks.rs b/pgrx/src/hooks.rs index 1ed73855e1..6947b3866c 100644 --- a/pgrx/src/hooks.rs +++ b/pgrx/src/hooks.rs @@ -212,7 +212,7 @@ struct Hooks { prev_process_utility_hook: pg_sys::ProcessUtility_hook_type, prev_planner_hook: pg_sys::planner_hook_type, prev_post_parse_analyze_hook: pg_sys::post_parse_analyze_hook_type, - object_access_hook: pg_sys::object_access_hook_type + prev_object_access_hook: pg_sys::object_access_hook_type, } static mut HOOKS: Option = None; @@ -256,9 +256,7 @@ pub unsafe fn register_hook(hook: &'static mut (dyn PgHooks)) { prev_post_parse_analyze_hook: pg_sys::post_parse_analyze_hook .replace(pgrx_post_parse_analyze), prev_emit_log_hook: pg_sys::emit_log_hook.replace(pgrx_emit_log), - object_access_hook: pg_sys::object_access_hook_type - .replace(pgrx_object_access_hook) - .or(Some(pgrx_standard_object_access_hook_wrapper)), + prev_object_access_hook: pg_sys::object_access_hook.replace(pgrx_object_access_hook), }); #[pg_guard] @@ -630,6 +628,33 @@ unsafe extern "C" fn pgrx_emit_log(error_data: *mut pg_sys::ErrorData) { hook.emit_log(PgBox::from_pg(error_data), prev).inner } +#[pg_guard] +unsafe extern "C" fn pgrx_object_access_hook( + access: pg_sys::ObjectAccessType, + class_id: pg_sys::Oid, + object_id: pg_sys::Oid, + sub_id: ::std::os::raw::c_int, + arg: *mut ::std::os::raw::c_void, +) { + fn prev( + access: pg_sys::ObjectAccessType, + class_id: pg_sys::Oid, + object_id: pg_sys::Oid, + sub_id: ::std::os::raw::c_int, + arg: *mut ::std::os::raw::c_void, + ) -> HookResult<()> { + HookResult::new(unsafe { + match HOOKS.as_mut().unwrap().prev_object_access_hook.as_ref() { + None => (), + Some(f) => (f)(access, class_id, object_id, sub_id, arg), + } + }) + } + + let hook = &mut HOOKS.as_mut().unwrap().current_hook; + hook.object_access_hook(access, class_id, object_id, sub_id, arg, prev).inner +} + #[pg_guard] unsafe extern "C" fn pgrx_standard_executor_start_wrapper( query_desc: *mut pg_sys::QueryDesc,