Skip to content

Commit

Permalink
String ref and value APIs (#27)
Browse files Browse the repository at this point in the history
Part of #24.
  • Loading branch information
raviqqe authored Sep 14, 2022
1 parent 2eeb5fc commit db29582
Show file tree
Hide file tree
Showing 3 changed files with 269 additions and 65 deletions.
28 changes: 26 additions & 2 deletions src/string_ref.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use mlir_sys::{mlirStringRefCreateFromCString, MlirStringRef};
use mlir_sys::{mlirStringRefCreateFromCString, mlirStringRefEqual, MlirStringRef};
use once_cell::sync::Lazy;
use std::{collections::HashMap, ffi::CString, marker::PhantomData, slice, str, sync::RwLock};

Expand All @@ -10,6 +10,7 @@ static STRING_CACHE: Lazy<RwLock<HashMap<String, CString>>> = Lazy::new(Default:
//
// TODO The documentation says string refs do not have to be null-terminated.
// But it looks like some functions do not handle strings not null-terminated?
#[derive(Clone, Copy, Debug)]
pub struct StringRef<'a> {
raw: MlirStringRef,
_parent: PhantomData<&'a ()>,
Expand All @@ -29,7 +30,7 @@ impl<'a> StringRef<'a> {
}
}

pub(crate) unsafe fn to_raw(&self) -> MlirStringRef {
pub(crate) unsafe fn to_raw(self) -> MlirStringRef {
self.raw
}

Expand All @@ -41,6 +42,14 @@ impl<'a> StringRef<'a> {
}
}

impl<'a> PartialEq for StringRef<'a> {
fn eq(&self, other: &Self) -> bool {
unsafe { mlirStringRefEqual(self.raw, other.raw) }
}
}

impl<'a> Eq for StringRef<'a> {}

impl From<&str> for StringRef<'static> {
fn from(string: &str) -> Self {
if !STRING_CACHE.read().unwrap().contains_key(string) {
Expand All @@ -56,3 +65,18 @@ impl From<&str> for StringRef<'static> {
unsafe { Self::from_raw(mlirStringRefCreateFromCString(string.as_ptr())) }
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn equal() {
assert_eq!(StringRef::from("foo"), StringRef::from("foo"));
}

#[test]
fn not_equal() {
assert_ne!(StringRef::from("foo"), StringRef::from("bar"));
}
}
166 changes: 112 additions & 54 deletions src/type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ impl<'c> PartialEq for Type<'c> {

impl<'c> Eq for Type<'c> {}

impl<'c> Display for &Type<'c> {
impl<'c> Display for Type<'c> {
fn fmt(&self, formatter: &mut Formatter) -> fmt::Result {
let mut data = (formatter, Ok(()));

Expand Down Expand Up @@ -228,6 +228,40 @@ mod tests {
Type::parse(&Context::new(), "i8").context();
}

#[test]
fn integer() {
let context = Context::new();

assert_eq!(Type::integer(&context, 42), Type::parse(&context, "i42"));
}

#[test]
fn signed_integer() {
let context = Context::new();

assert_eq!(
Type::signed_integer(&context, 42),
Type::parse(&context, "si42")
);
}

#[test]
fn unsigned_integer() {
let context = Context::new();

assert_eq!(
Type::unsigned_integer(&context, 42),
Type::parse(&context, "ui42")
);
}

#[test]
fn display() {
let context = Context::new();

assert_eq!(Type::integer(&context, 42).to_string(), "i42");
}

mod function {
use super::*;

Expand Down Expand Up @@ -330,72 +364,96 @@ mod tests {
}
}

#[test]
fn integer() {
let context = Context::new();
mod llvm {
use super::*;

assert_eq!(Type::integer(&context, 42), Type::parse(&context, "i42"));
}
fn create_context() -> Context {
let context = Context::new();

#[test]
fn signed_integer() {
let context = Context::new();
DialectHandle::llvm().register_dialect(&context);
context.get_or_load_dialect("llvm");

assert_eq!(
Type::signed_integer(&context, 42),
Type::parse(&context, "si42")
);
}
context
}

#[test]
fn unsigned_integer() {
let context = Context::new();
#[test]
fn pointer() {
let context = create_context();
let i32 = Type::integer(&context, 32);

assert_eq!(
Type::unsigned_integer(&context, 42),
Type::parse(&context, "ui42")
);
}
assert_eq!(
Type::llvm_pointer(i32, 0),
Type::parse(&context, "!llvm.ptr<i32>")
);
}

#[test]
fn create_llvm_types() {
let context = Context::new();
#[test]
fn pointer_with_address_space() {
let context = create_context();
let i32 = Type::integer(&context, 32);

DialectHandle::llvm().register_dialect(&context);
context.get_or_load_dialect("llvm");
assert_eq!(
Type::llvm_pointer(i32, 4),
Type::parse(&context, "!llvm.ptr<i32, 4>")
);
}

let i8 = Type::integer(&context, 8);
let i32 = Type::integer(&context, 32);
let i64 = Type::integer(&context, 64);
#[test]
fn void() {
let context = create_context();

assert_eq!(
Type::llvm_pointer(i32, 0),
Type::parse(&context, "!llvm.ptr<i32>")
);
assert_eq!(
Type::llvm_void(&context),
Type::parse(&context, "!llvm.void")
);
}

assert_eq!(
Type::llvm_pointer(i32, 4),
Type::parse(&context, "!llvm.ptr<i32, 4>")
);
#[test]
fn array() {
let context = create_context();
let i32 = Type::integer(&context, 32);

assert_eq!(
Type::llvm_void(&context),
Type::parse(&context, "!llvm.void")
);
assert_eq!(
Type::llvm_array(i32, 4),
Type::parse(&context, "!llvm.array<4xi32>")
);
}

assert_eq!(
Type::llvm_array(i32, 4),
Type::parse(&context, "!llvm.array<4xi32>")
);
#[test]
fn function() {
let context = create_context();
let i8 = Type::integer(&context, 8);
let i32 = Type::integer(&context, 32);
let i64 = Type::integer(&context, 64);

assert_eq!(
Type::llvm_function(i8, &[i32, i64], false),
Type::parse(&context, "!llvm.func<i8 (i32, i64)>")
);
assert_eq!(
Type::llvm_function(i8, &[i32, i64], false),
Type::parse(&context, "!llvm.func<i8 (i32, i64)>")
);
}

assert_eq!(
Type::llvm_struct(&context, &[i32, i64], false),
Type::parse(&context, "!llvm.struct<(i32, i64)>")
);
#[test]
fn r#struct() {
let context = create_context();
let i32 = Type::integer(&context, 32);
let i64 = Type::integer(&context, 64);

assert_eq!(
Type::llvm_struct(&context, &[i32, i64], false),
Type::parse(&context, "!llvm.struct<(i32, i64)>")
);
}

#[test]
fn packed_struct() {
let context = create_context();
let i32 = Type::integer(&context, 32);
let i64 = Type::integer(&context, 64);

assert_eq!(
Type::llvm_struct(&context, &[i32, i64], true),
Type::parse(&context, "!llvm.struct<packed (i32, i64)>")
);
}
}
}
Loading

0 comments on commit db29582

Please sign in to comment.