diff --git a/melior/src/utility.rs b/melior/src/utility.rs index ed450bf673..ac8fa9d0c3 100644 --- a/melior/src/utility.rs +++ b/melior/src/utility.rs @@ -1,12 +1,12 @@ //! Utility functions. use crate::{ - context::Context, dialect::DialectRegistry, logical_result::LogicalResult, pass, + context::Context, dialect::DialectRegistry, ir::Module, logical_result::LogicalResult, pass, string_ref::StringRef, Error, }; use mlir_sys::{ - mlirParsePassPipeline, mlirRegisterAllDialects, mlirRegisterAllLLVMTranslations, - mlirRegisterAllPasses, MlirStringRef, + mlirLoadIRDLDialects, mlirParsePassPipeline, mlirRegisterAllDialects, + mlirRegisterAllLLVMTranslations, mlirRegisterAllPasses, MlirStringRef, }; use std::{ ffi::c_void, @@ -54,6 +54,11 @@ pub fn parse_pass_pipeline(manager: pass::OperationPassManager, source: &str) -> } } +/// Loads all IRDL dialects in the provided module, registering the dialects in the module's associated context. +pub fn load_irdl_dialects(module: &Module) -> bool { + unsafe { mlirLoadIRDLDialects(module.to_raw()).value == 1 } +} + unsafe extern "C" fn handle_parse_error(raw_string: MlirStringRef, data: *mut c_void) { let string = StringRef::from_raw(raw_string); let data = &mut *(data as *mut Option); @@ -99,6 +104,8 @@ pub(crate) unsafe extern "C" fn print_string_callback(string: MlirStringRef, dat #[cfg(test)] mod tests { + use crate::ir::Location; + use super::*; #[test] @@ -148,4 +155,12 @@ mod tests { register_all_passes(); } } + + #[test] + fn test_load_irdl_dialects() { + let context = Context::new(); + let module = Module::new(Location::unknown(&context)); + + assert!(load_irdl_dialects(&module)); + } }