Skip to content

Commit

Permalink
Use C++ style strings in most places (#354)
Browse files Browse the repository at this point in the history
  • Loading branch information
raviqqe authored Oct 18, 2023
1 parent f14e2a5 commit c774279
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 66 deletions.
24 changes: 9 additions & 15 deletions melior/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,7 @@ use mlir_sys::{
mlirContextIsRegisteredOperation, mlirContextLoadAllAvailableDialects,
mlirContextSetAllowUnregisteredDialects, MlirContext, MlirDiagnostic, MlirLogicalResult,
};
use std::{
ffi::{c_void, CString},
marker::PhantomData,
};
use std::{ffi::c_void, marker::PhantomData, pin::Pin};

/// A context of IR, dialects, and passes.
///
Expand All @@ -27,7 +24,7 @@ pub struct Context {
raw: MlirContext,
// We need to pass null-terminated strings to functions in the MLIR API although
// Rust's strings are not.
string_cache: DashMap<CString, ()>,
string_cache: DashMap<Pin<String>, ()>,
}

impl Context {
Expand All @@ -51,12 +48,9 @@ impl Context {

/// Gets or loads a dialect.
pub fn get_or_load_dialect(&self, name: &str) -> Dialect {
unsafe {
Dialect::from_raw(mlirContextGetOrLoadDialect(
self.raw,
StringRef::from_str(self, name).to_raw(),
))
}
let name = StringRef::new(name);

unsafe { Dialect::from_raw(mlirContextGetOrLoadDialect(self.raw, name.to_raw())) }
}

/// Appends a dialect registry.
Expand Down Expand Up @@ -86,9 +80,9 @@ impl Context {

/// Returns `true` if a given operation is registered in a context.
pub fn is_registered_operation(&self, name: &str) -> bool {
unsafe {
mlirContextIsRegisteredOperation(self.raw, StringRef::from_str(self, name).to_raw())
}
let name = StringRef::new(name);

unsafe { mlirContextIsRegisteredOperation(self.raw, name.to_raw()) }
}

/// Converts a context into a raw object.
Expand Down Expand Up @@ -131,7 +125,7 @@ impl Context {
unsafe { ContextRef::from_raw(self.to_raw()) }
}

pub(crate) fn string_cache(&self) -> &DashMap<CString, ()> {
pub(crate) fn string_cache(&self) -> &DashMap<Pin<String>, ()> {
&self.string_cache
}
}
Expand Down
4 changes: 2 additions & 2 deletions melior/src/dialect/llvm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,10 @@ mod tests {

pass_manager.add_pass(pass::conversion::create_func_to_llvm());
pass_manager
.nested_under(context, "func.func")
.nested_under("func.func")
.add_pass(pass::conversion::create_arith_to_llvm());
pass_manager
.nested_under(context, "func.func")
.nested_under("func.func")
.add_pass(pass::conversion::create_index_to_llvm());
pass_manager.add_pass(pass::conversion::create_scf_to_control_flow());
pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm());
Expand Down
16 changes: 8 additions & 8 deletions melior/src/dialect/ods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,10 +129,10 @@ mod tests {

pass_manager.add_pass(pass::conversion::create_func_to_llvm());
pass_manager
.nested_under(context, "func.func")
.nested_under("func.func")
.add_pass(pass::conversion::create_arith_to_llvm());
pass_manager
.nested_under(context, "func.func")
.nested_under("func.func")
.add_pass(pass::conversion::create_index_to_llvm());
pass_manager.add_pass(pass::conversion::create_scf_to_control_flow());
pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm());
Expand All @@ -148,13 +148,13 @@ mod tests {
argument_types: &[Type<'c>],
callback: impl FnOnce(&Block<'c>),
) {
let location = Location::unknown(&context);
let location = Location::unknown(context);
let mut module = Module::new(location);

module.body().append_operation(func::func(
&context,
StringAttribute::new(&context, "foo"),
TypeAttribute::new(FunctionType::new(&context, argument_types, &[]).into()),
context,
StringAttribute::new(context, "foo"),
TypeAttribute::new(FunctionType::new(context, argument_types, &[]).into()),
{
let block = Block::new(
&argument_types
Expand All @@ -174,7 +174,7 @@ mod tests {
location,
));

convert_module(&context, &mut module);
convert_module(context, &mut module);

assert!(module.as_operation().verify());
insta::assert_display_snapshot!(name, module.as_operation());
Expand All @@ -193,7 +193,7 @@ mod tests {
block.append_operation(
llvm::alloca(
&context,
dialect::llvm::r#type::pointer(i64_type.into(), 0).into(),
dialect::llvm::r#type::pointer(i64_type.into(), 0),
alloca_size,
location,
)
Expand Down
4 changes: 2 additions & 2 deletions melior/src/execution_engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ mod tests {
pass_manager.add_pass(pass::conversion::create_func_to_llvm());

pass_manager
.nested_under(&context, "func.func")
.nested_under("func.func")
.add_pass(pass::conversion::create_arith_to_llvm());

assert_eq!(pass_manager.run(&mut module), Ok(()));
Expand Down Expand Up @@ -169,7 +169,7 @@ mod tests {
pass_manager.add_pass(pass::conversion::create_func_to_llvm());

pass_manager
.nested_under(&context, "func.func")
.nested_under("func.func")
.add_pass(pass::conversion::create_arith_to_llvm());

assert_eq!(pass_manager.run(&mut module), Ok(()));
Expand Down
15 changes: 7 additions & 8 deletions melior/src/ir/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use mlir_sys::{
mlirModuleCreateEmpty, mlirModuleCreateParse, mlirModuleDestroy, mlirModuleFromOperation,
mlirModuleGetBody, mlirModuleGetContext, mlirModuleGetOperation, MlirModule,
};
use std::marker::PhantomData;
use std::{ffi::CString, marker::PhantomData};

/// A module.
#[derive(Debug)]
Expand All @@ -24,13 +24,12 @@ impl<'c> Module<'c> {

/// Parses a module.
pub fn parse(context: &Context, source: &str) -> Option<Self> {
// TODO Should we allocate StringRef locally because sources can be big?
unsafe {
Self::from_option_raw(mlirModuleCreateParse(
context.to_raw(),
StringRef::from_str(context, source).to_raw(),
))
}
// TODO Use a string not null-terminated.
// Somehow, we still need a null-terminated string for a source.
let source = CString::new(source).unwrap();
let source = StringRef::from_c_str(&source);

unsafe { Self::from_option_raw(mlirModuleCreateParse(context.to_raw(), source.to_raw())) }
}

/// Converts a module into an operation.
Expand Down
17 changes: 8 additions & 9 deletions melior/src/pass/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,11 @@ impl<'c> PassManager<'c> {

/// Gets an operation pass manager for nested operations corresponding to a
/// given name.
pub fn nested_under(&self, context: &'c Context, name: &str) -> OperationPassManager {
pub fn nested_under(&self, name: &str) -> OperationPassManager {
let name = StringRef::new(name);

unsafe {
OperationPassManager::from_raw(mlirPassManagerGetNestedUnder(
self.raw,
StringRef::from_str(context, name).to_raw(),
))
OperationPassManager::from_raw(mlirPassManagerGetNestedUnder(self.raw, name.to_raw()))
}
}

Expand Down Expand Up @@ -178,15 +177,15 @@ mod tests {

let manager = PassManager::new(&context);
manager
.nested_under(&context, "func.func")
.nested_under("func.func")
.add_pass(pass::transform::create_print_op_stats());

assert_eq!(manager.run(&mut module), Ok(()));

let manager = PassManager::new(&context);
manager
.nested_under(&context, "builtin.module")
.nested_under(&context, "func.func")
.nested_under("builtin.module")
.nested_under("func.func")
.add_pass(pass::transform::create_print_op_stats());

assert_eq!(manager.run(&mut module), Ok(()));
Expand All @@ -196,7 +195,7 @@ mod tests {
fn print_pass_pipeline() {
let context = create_test_context();
let manager = PassManager::new(&context);
let function_manager = manager.nested_under(&context, "func.func");
let function_manager = manager.nested_under("func.func");

function_manager.add_pass(pass::transform::create_print_op_stats());

Expand Down
13 changes: 5 additions & 8 deletions melior/src/pass/operation_manager.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::PassManager;
use crate::{pass::Pass, string_ref::StringRef, Context};
use crate::{pass::Pass, string_ref::StringRef};
use mlir_sys::{
mlirOpPassManagerAddOwnedPass, mlirOpPassManagerGetNestedUnder, mlirPrintPassPipeline,
MlirOpPassManager, MlirStringRef,
Expand All @@ -20,13 +20,10 @@ pub struct OperationPassManager<'c, 'a> {
impl<'c, 'a> OperationPassManager<'c, 'a> {
/// Gets an operation pass manager for nested operations corresponding to a
/// given name.
pub fn nested_under(&self, context: &'c Context, name: &str) -> Self {
unsafe {
Self::from_raw(mlirOpPassManagerGetNestedUnder(
self.raw,
StringRef::from_str(context, name).to_raw(),
))
}
pub fn nested_under(&self, name: &str) -> Self {
let name = StringRef::new(name);

unsafe { Self::from_raw(mlirOpPassManagerGetNestedUnder(self.raw, name.to_raw())) }
}

/// Adds a pass.
Expand Down
52 changes: 38 additions & 14 deletions melior/src/string_ref.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::Context;
use mlir_sys::{mlirStringRefCreateFromCString, mlirStringRefEqual, MlirStringRef};
use mlir_sys::{mlirStringRefEqual, MlirStringRef};
use std::{
ffi::CString,
ffi::CStr,
marker::PhantomData,
pin::Pin,
slice,
str::{self, Utf8Error},
};
Expand All @@ -13,25 +14,48 @@ use std::{
// TODO The documentation says string refs do not have to be null-terminated.
// But it looks like some functions do not handle strings not null-terminated?
#[derive(Clone, Copy, Debug)]
pub struct StringRef<'c> {
pub struct StringRef<'a> {
raw: MlirStringRef,
_parent: PhantomData<&'c Context>,
_parent: PhantomData<&'a str>,
}

impl<'c> StringRef<'c> {
pub fn from_str(context: &'c Context, string: &str) -> Self {
let string = context
.string_cache()
.entry(CString::new(string).unwrap())
.or_default()
.key()
.as_ptr();
impl<'a> StringRef<'a> {
/// Creates a string reference.
pub fn new(string: &'a str) -> Self {
let string = MlirStringRef {
data: string.as_bytes().as_ptr() as *const i8,
length: string.len(),
};

unsafe { Self::from_raw(string) }
}

/// Converts a C-style string into a string reference.
pub fn from_c_str(string: &'a CStr) -> Self {
let string = MlirStringRef {
data: string.as_ptr(),
length: string.to_bytes_with_nul().len() - 1,
};

unsafe { Self::from_raw(mlirStringRefCreateFromCString(string)) }
unsafe { Self::from_raw(string) }
}

/// Converts a string into a null-terminated string reference.
pub fn from_str(context: &'a Context, string: &str) -> Self {
let entry = context
.string_cache()
.entry(Pin::new(string.into()))
.or_default();
let string = MlirStringRef {
data: entry.key().as_bytes().as_ptr() as *const i8,
length: entry.key().len(),
};

unsafe { Self::from_raw(string) }
}

/// Converts a string reference into a `str`.
pub fn as_str(&self) -> Result<&'c str, Utf8Error> {
pub fn as_str(&self) -> Result<&'a str, Utf8Error> {
unsafe {
let bytes = slice::from_raw_parts(self.raw.data as *mut u8, self.raw.length);

Expand Down

0 comments on commit c774279

Please sign in to comment.