diff --git a/melior/src/context.rs b/melior/src/context.rs index addd690976..0bc2a34002 100644 --- a/melior/src/context.rs +++ b/melior/src/context.rs @@ -127,6 +127,10 @@ impl Context { unsafe { mlirContextDetachDiagnosticHandler(self.to_raw(), id.to_raw()) } } + pub(crate) fn to_ref(&self) -> ContextRef { + unsafe { ContextRef::from_raw(self.to_raw()) } + } + pub(crate) fn string_cache(&self) -> &DashMap { &self.string_cache } @@ -150,6 +154,12 @@ impl PartialEq for Context { } } +impl<'a> PartialEq> for Context { + fn eq(&self, &other: &ContextRef<'a>) -> bool { + self.to_ref() == other + } +} + impl Eq for Context {} /// A reference to a context. @@ -173,13 +183,19 @@ impl<'c> ContextRef<'c> { } } -impl<'a> PartialEq for ContextRef<'a> { +impl<'c> PartialEq for ContextRef<'c> { fn eq(&self, other: &Self) -> bool { unsafe { mlirContextEqual(self.raw, other.raw) } } } -impl<'a> Eq for ContextRef<'a> {} +impl<'c> PartialEq for ContextRef<'c> { + fn eq(&self, other: &Context) -> bool { + self == &other.to_ref() + } +} + +impl<'c> Eq for ContextRef<'c> {} #[cfg(test)] mod tests { @@ -266,4 +282,36 @@ mod tests { context.detach_diagnostic_handler(id); } + + #[test] + fn compare_contexts() { + let one = Context::new(); + let other = Context::new(); + + assert_eq!(&one, &one); + assert_ne!(&one, &other); + assert_ne!(&other, &one); + assert_eq!(&other, &other); + } + + #[test] + fn compare_context_refs() { + let one = Context::new(); + let other = Context::new(); + + let one_ref = one.to_ref(); + let other_ref = other.to_ref(); + + assert_eq!(&one, &one_ref); + assert_eq!(&one_ref, &one); + + assert_eq!(&other, &other_ref); + assert_eq!(&other_ref, &other); + + assert_ne!(&one, &other_ref); + assert_ne!(&other_ref, &one); + + assert_ne!(&other, &one_ref); + assert_ne!(&one_ref, &other); + } }