From c7742793bf15744e5afbc3610ad186fce57d679a Mon Sep 17 00:00:00 2001 From: Yota Toyama Date: Wed, 18 Oct 2023 20:46:24 +1100 Subject: [PATCH] Use C++ style strings in most places (#354) --- melior/src/context.rs | 24 +++++-------- melior/src/dialect/llvm.rs | 4 +-- melior/src/dialect/ods.rs | 16 ++++----- melior/src/execution_engine.rs | 4 +-- melior/src/ir/module.rs | 15 ++++---- melior/src/pass/manager.rs | 17 +++++---- melior/src/pass/operation_manager.rs | 13 +++---- melior/src/string_ref.rs | 52 ++++++++++++++++++++-------- 8 files changed, 79 insertions(+), 66 deletions(-) diff --git a/melior/src/context.rs b/melior/src/context.rs index 0bc2a34002..8e5407f7f5 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -13,10 +13,7 @@ use mlir_sys::{ mlirContextIsRegisteredOperation, mlirContextLoadAllAvailableDialects, mlirContextSetAllowUnregisteredDialects, MlirContext, MlirDiagnostic, MlirLogicalResult, }; -use std::{ - ffi::{c_void, CString}, - marker::PhantomData, -}; +use std::{ffi::c_void, marker::PhantomData, pin::Pin}; /// A context of IR, dialects, and passes. /// @@ -27,7 +24,7 @@ pub struct Context { raw: MlirContext, // We need to pass null-terminated strings to functions in the MLIR API although // Rust's strings are not. - string_cache: DashMap, + string_cache: DashMap, ()>, } impl Context { @@ -51,12 +48,9 @@ impl Context { /// Gets or loads a dialect. pub fn get_or_load_dialect(&self, name: &str) -> Dialect { - unsafe { - Dialect::from_raw(mlirContextGetOrLoadDialect( - self.raw, - StringRef::from_str(self, name).to_raw(), - )) - } + let name = StringRef::new(name); + + unsafe { Dialect::from_raw(mlirContextGetOrLoadDialect(self.raw, name.to_raw())) } } /// Appends a dialect registry. @@ -86,9 +80,9 @@ impl Context { /// Returns `true` if a given operation is registered in a context. pub fn is_registered_operation(&self, name: &str) -> bool { - unsafe { - mlirContextIsRegisteredOperation(self.raw, StringRef::from_str(self, name).to_raw()) - } + let name = StringRef::new(name); + + unsafe { mlirContextIsRegisteredOperation(self.raw, name.to_raw()) } } /// Converts a context into a raw object. @@ -131,7 +125,7 @@ impl Context { unsafe { ContextRef::from_raw(self.to_raw()) } } - pub(crate) fn string_cache(&self) -> &DashMap { + pub(crate) fn string_cache(&self) -> &DashMap, ()> { &self.string_cache } } diff --git a/melior/src/dialect/llvm.rs b/melior/src/dialect/llvm.rs index c87071423a..af6a28ff83 100644 --- a/melior/src/dialect/llvm.rs +++ b/melior/src/dialect/llvm.rs @@ -406,10 +406,10 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under(context, "func.func") + .nested_under("func.func") .add_pass(pass::conversion::create_arith_to_llvm()); pass_manager - .nested_under(context, "func.func") + .nested_under("func.func") .add_pass(pass::conversion::create_index_to_llvm()); pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); diff --git a/melior/src/dialect/ods.rs b/melior/src/dialect/ods.rs index 1c764ce274..1e9f7bf067 100644 --- a/melior/src/dialect/ods.rs +++ b/melior/src/dialect/ods.rs @@ -129,10 +129,10 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under(context, "func.func") + .nested_under("func.func") .add_pass(pass::conversion::create_arith_to_llvm()); pass_manager - .nested_under(context, "func.func") + .nested_under("func.func") .add_pass(pass::conversion::create_index_to_llvm()); pass_manager.add_pass(pass::conversion::create_scf_to_control_flow()); pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm()); @@ -148,13 +148,13 @@ mod tests { argument_types: &[Type<'c>], callback: impl FnOnce(&Block<'c>), ) { - let location = Location::unknown(&context); + let location = Location::unknown(context); let mut module = Module::new(location); module.body().append_operation(func::func( - &context, - StringAttribute::new(&context, "foo"), - TypeAttribute::new(FunctionType::new(&context, argument_types, &[]).into()), + context, + StringAttribute::new(context, "foo"), + TypeAttribute::new(FunctionType::new(context, argument_types, &[]).into()), { let block = Block::new( &argument_types @@ -174,7 +174,7 @@ mod tests { location, )); - convert_module(&context, &mut module); + convert_module(context, &mut module); assert!(module.as_operation().verify()); insta::assert_display_snapshot!(name, module.as_operation()); @@ -193,7 +193,7 @@ mod tests { block.append_operation( llvm::alloca( &context, - dialect::llvm::r#type::pointer(i64_type.into(), 0).into(), + dialect::llvm::r#type::pointer(i64_type.into(), 0), alloca_size, location, ) diff --git a/melior/src/execution_engine.rs b/melior/src/execution_engine.rs index a84bf97ac9..367ba8a240 100644 --- a/melior/src/execution_engine.rs +++ b/melior/src/execution_engine.rs @@ -120,7 +120,7 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under(&context, "func.func") + .nested_under("func.func") .add_pass(pass::conversion::create_arith_to_llvm()); assert_eq!(pass_manager.run(&mut module), Ok(())); @@ -169,7 +169,7 @@ mod tests { pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager - .nested_under(&context, "func.func") + .nested_under("func.func") .add_pass(pass::conversion::create_arith_to_llvm()); assert_eq!(pass_manager.run(&mut module), Ok(())); diff --git a/melior/src/ir/module.rs b/melior/src/ir/module.rs index 9c20329002..a853c82644 100644 --- a/melior/src/ir/module.rs +++ b/melior/src/ir/module.rs @@ -7,7 +7,7 @@ use mlir_sys::{ mlirModuleCreateEmpty, mlirModuleCreateParse, mlirModuleDestroy, mlirModuleFromOperation, mlirModuleGetBody, mlirModuleGetContext, mlirModuleGetOperation, MlirModule, }; -use std::marker::PhantomData; +use std::{ffi::CString, marker::PhantomData}; /// A module. #[derive(Debug)] @@ -24,13 +24,12 @@ impl<'c> Module<'c> { /// Parses a module. pub fn parse(context: &Context, source: &str) -> Option { - // TODO Should we allocate StringRef locally because sources can be big? - unsafe { - Self::from_option_raw(mlirModuleCreateParse( - context.to_raw(), - StringRef::from_str(context, source).to_raw(), - )) - } + // TODO Use a string not null-terminated. + // Somehow, we still need a null-terminated string for a source. + let source = CString::new(source).unwrap(); + let source = StringRef::from_c_str(&source); + + unsafe { Self::from_option_raw(mlirModuleCreateParse(context.to_raw(), source.to_raw())) } } /// Converts a module into an operation. diff --git a/melior/src/pass/manager.rs b/melior/src/pass/manager.rs index 4f99588852..fca584eed6 100644 --- a/melior/src/pass/manager.rs +++ b/melior/src/pass/manager.rs @@ -28,12 +28,11 @@ impl<'c> PassManager<'c> { /// Gets an operation pass manager for nested operations corresponding to a /// given name. - pub fn nested_under(&self, context: &'c Context, name: &str) -> OperationPassManager { + pub fn nested_under(&self, name: &str) -> OperationPassManager { + let name = StringRef::new(name); + unsafe { - OperationPassManager::from_raw(mlirPassManagerGetNestedUnder( - self.raw, - StringRef::from_str(context, name).to_raw(), - )) + OperationPassManager::from_raw(mlirPassManagerGetNestedUnder(self.raw, name.to_raw())) } } @@ -178,15 +177,15 @@ mod tests { let manager = PassManager::new(&context); manager - .nested_under(&context, "func.func") + .nested_under("func.func") .add_pass(pass::transform::create_print_op_stats()); assert_eq!(manager.run(&mut module), Ok(())); let manager = PassManager::new(&context); manager - .nested_under(&context, "builtin.module") - .nested_under(&context, "func.func") + .nested_under("builtin.module") + .nested_under("func.func") .add_pass(pass::transform::create_print_op_stats()); assert_eq!(manager.run(&mut module), Ok(())); @@ -196,7 +195,7 @@ mod tests { fn print_pass_pipeline() { let context = create_test_context(); let manager = PassManager::new(&context); - let function_manager = manager.nested_under(&context, "func.func"); + let function_manager = manager.nested_under("func.func"); function_manager.add_pass(pass::transform::create_print_op_stats()); diff --git a/melior/src/pass/operation_manager.rs b/melior/src/pass/operation_manager.rs index e4ec956bdb..41e4051538 100644 --- a/melior/src/pass/operation_manager.rs +++ b/melior/src/pass/operation_manager.rs @@ -1,5 +1,5 @@ use super::PassManager; -use crate::{pass::Pass, string_ref::StringRef, Context}; +use crate::{pass::Pass, string_ref::StringRef}; use mlir_sys::{ mlirOpPassManagerAddOwnedPass, mlirOpPassManagerGetNestedUnder, mlirPrintPassPipeline, MlirOpPassManager, MlirStringRef, @@ -20,13 +20,10 @@ pub struct OperationPassManager<'c, 'a> { impl<'c, 'a> OperationPassManager<'c, 'a> { /// Gets an operation pass manager for nested operations corresponding to a /// given name. - pub fn nested_under(&self, context: &'c Context, name: &str) -> Self { - unsafe { - Self::from_raw(mlirOpPassManagerGetNestedUnder( - self.raw, - StringRef::from_str(context, name).to_raw(), - )) - } + pub fn nested_under(&self, name: &str) -> Self { + let name = StringRef::new(name); + + unsafe { Self::from_raw(mlirOpPassManagerGetNestedUnder(self.raw, name.to_raw())) } } /// Adds a pass. diff --git a/melior/src/string_ref.rs b/melior/src/string_ref.rs index 5ef68a1216..bbce2b1ac4 100644 --- a/melior/src/string_ref.rs +++ b/melior/src/string_ref.rs @@ -1,8 +1,9 @@ use crate::Context; -use mlir_sys::{mlirStringRefCreateFromCString, mlirStringRefEqual, MlirStringRef}; +use mlir_sys::{mlirStringRefEqual, MlirStringRef}; use std::{ - ffi::CString, + ffi::CStr, marker::PhantomData, + pin::Pin, slice, str::{self, Utf8Error}, }; @@ -13,25 +14,48 @@ use std::{ // TODO The documentation says string refs do not have to be null-terminated. // But it looks like some functions do not handle strings not null-terminated? #[derive(Clone, Copy, Debug)] -pub struct StringRef<'c> { +pub struct StringRef<'a> { raw: MlirStringRef, - _parent: PhantomData<&'c Context>, + _parent: PhantomData<&'a str>, } -impl<'c> StringRef<'c> { - pub fn from_str(context: &'c Context, string: &str) -> Self { - let string = context - .string_cache() - .entry(CString::new(string).unwrap()) - .or_default() - .key() - .as_ptr(); +impl<'a> StringRef<'a> { + /// Creates a string reference. + pub fn new(string: &'a str) -> Self { + let string = MlirStringRef { + data: string.as_bytes().as_ptr() as *const i8, + length: string.len(), + }; + + unsafe { Self::from_raw(string) } + } + + /// Converts a C-style string into a string reference. + pub fn from_c_str(string: &'a CStr) -> Self { + let string = MlirStringRef { + data: string.as_ptr(), + length: string.to_bytes_with_nul().len() - 1, + }; - unsafe { Self::from_raw(mlirStringRefCreateFromCString(string)) } + unsafe { Self::from_raw(string) } + } + + /// Converts a string into a null-terminated string reference. + pub fn from_str(context: &'a Context, string: &str) -> Self { + let entry = context + .string_cache() + .entry(Pin::new(string.into())) + .or_default(); + let string = MlirStringRef { + data: entry.key().as_bytes().as_ptr() as *const i8, + length: entry.key().len(), + }; + + unsafe { Self::from_raw(string) } } /// Converts a string reference into a `str`. - pub fn as_str(&self) -> Result<&'c str, Utf8Error> { + pub fn as_str(&self) -> Result<&'a str, Utf8Error> { unsafe { let bytes = slice::from_raw_parts(self.raw.data as *mut u8, self.raw.length);