From 1d8d896adcc65486b3065074c47c5c39184747eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Tue, 26 Nov 2024 15:36:06 +0000 Subject: [PATCH] feat!: OpDefs and TypeDefs keep a reference to their extension --- hugr-core/src/builder/circuit.rs | 16 +- hugr-core/src/export.rs | 6 +- hugr-core/src/extension.rs | 133 ++++++-- hugr-core/src/extension/declarative.rs | 32 +- hugr-core/src/extension/declarative/ops.rs | 12 +- hugr-core/src/extension/declarative/types.rs | 7 + hugr-core/src/extension/op_def.rs | 305 ++++++++++-------- hugr-core/src/extension/prelude.rs | 149 ++++----- hugr-core/src/extension/prelude/array.rs | 11 +- hugr-core/src/extension/simple_op.rs | 36 ++- hugr-core/src/extension/type_def.rs | 47 ++- hugr-core/src/hugr/rewrite/replace.rs | 36 ++- hugr-core/src/hugr/validate/test.rs | 128 ++++---- hugr-core/src/ops/custom.rs | 42 +-- .../std_extensions/arithmetic/conversions.rs | 23 +- .../std_extensions/arithmetic/float_ops.rs | 15 +- .../std_extensions/arithmetic/float_types.rs | 23 +- .../src/std_extensions/arithmetic/int_ops.rs | 15 +- .../std_extensions/arithmetic/int_types.rs | 23 +- hugr-core/src/std_extensions/collections.rs | 39 +-- hugr-core/src/std_extensions/logic.rs | 20 +- hugr-core/src/std_extensions/ptr.rs | 25 +- hugr-core/src/types/poly_func.rs | 20 +- hugr-core/src/utils.rs | 96 +++--- hugr-llvm/src/custom/extension_op.rs | 2 +- hugr-passes/src/merge_bbs.rs | 42 +-- hugr/benches/benchmarks/hugr/examples.rs | 56 ++-- hugr/src/lib.rs | 20 +- 28 files changed, 829 insertions(+), 550 deletions(-) diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 112cb83fb..43c106209 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -243,7 +243,7 @@ mod test { use super::*; use cool_asserts::assert_matches; - use crate::extension::{ExtensionId, ExtensionSet}; + use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY}; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::utils::test_quantum_extension::{ self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64, @@ -298,8 +298,18 @@ mod test { #[test] fn with_nonlinear_and_outputs() { let my_ext_name: ExtensionId = "MyExt".try_into().unwrap(); - let mut my_ext = Extension::new_test(my_ext_name.clone()); - let my_custom_op = my_ext.simple_ext_op("MyOp", Signature::new(vec![QB, NAT], vec![QB])); + let my_ext = Extension::new_test_arc(my_ext_name.clone(), |ext, extension_ref| { + ext.add_op( + "MyOp".into(), + "".to_string(), + Signature::new(vec![QB, NAT], vec![QB]), + extension_ref, + ) + .unwrap(); + }); + let my_custom_op = my_ext + .instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY) + .unwrap(); let build_res = build_main( Signature::new(type_row![QB, QB, NAT], type_row![QB, QB, BOOL_T]) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 093368b60..220b03fea 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -443,10 +443,10 @@ impl<'a> Context<'a> { let poly_func_type = match opdef.signature_func() { SignatureFunc::PolyFuncType(poly_func_type) => poly_func_type, - _ => return self.make_named_global_ref(opdef.extension(), opdef.name()), + _ => return self.make_named_global_ref(opdef.extension_id(), opdef.name()), }; - let key = (opdef.extension().clone(), opdef.name().clone()); + let key = (opdef.extension_id().clone(), opdef.name().clone()); let entry = self.decl_operations.entry(key); let node = match entry { @@ -467,7 +467,7 @@ impl<'a> Context<'a> { }; let decl = self.with_local_scope(node, |this| { - let name = this.make_qualified_name(opdef.extension(), opdef.name()); + let name = this.make_qualified_name(opdef.extension_id(), opdef.name()); let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type); let decl = this.bump.alloc(model::OperationDecl { name, diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 4d22ba7f0..10a8a7ebf 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -7,7 +7,8 @@ pub use semver::Version; use std::collections::btree_map; use std::collections::{BTreeMap, BTreeSet}; use std::fmt::{Debug, Display, Formatter}; -use std::sync::Arc; +use std::mem; +use std::sync::{Arc, Weak}; use thiserror::Error; @@ -335,6 +336,45 @@ impl ExtensionValue { pub type ExtensionId = IdentList; /// A extension is a set of capabilities required to execute a graph. +/// +/// These are normally defined once and shared across multiple graphs and +/// operations wrapped in [`Arc`]s inside [`ExtensionRegistry`]. +/// +/// # Example +/// +/// The following example demonstrates how to define a new extension with a +/// custom operation and a custom type. +/// +/// When using `arc`s, the extension can only be modified at creation time. The +/// defined operations and types keep a [`Weak`] reference to their extension. We provide a +/// helper method [`Extension::new_arc`] to aid their definition. +/// +/// ``` +/// # use hugr_core::types::Signature; +/// # use hugr_core::extension::{Extension, ExtensionId, Version}; +/// # use hugr_core::extension::{TypeDefBound}; +/// Extension::new_arc( +/// ExtensionId::new_unchecked("my.extension"), +/// Version::new(0, 1, 0), +/// |ext, extension_ref| { +/// // Add a custom type definition +/// ext.add_type( +/// "MyType".into(), +/// vec![], // No type parameters +/// "Some type".into(), +/// TypeDefBound::any(), +/// extension_ref, +/// ); +/// // Add a custom operation +/// ext.add_op( +/// "MyOp".into(), +/// "Some operation".into(), +/// Signature::new_endo(vec![]), +/// extension_ref, +/// ); +/// }, +/// ); +/// ``` #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct Extension { /// Extension version, follows semver. @@ -361,6 +401,12 @@ pub struct Extension { impl Extension { /// Creates a new extension with the given name. + /// + /// In most cases extensions are contained inside an [`Arc`] so that they + /// can be shared across hugr instances and operation definitions. + /// + /// See [`Extension::new_arc`] for a more ergonomic way to create boxed + /// extensions. pub fn new(name: ExtensionId, version: Version) -> Self { Self { name, @@ -372,14 +418,63 @@ impl Extension { } } - /// Extend the requirements of this extension with another set of extensions. - pub fn with_reqs(self, extension_reqs: impl Into) -> Self { - Self { - extension_reqs: self.extension_reqs.union(extension_reqs.into()), - ..self + /// Creates a new extension wrapped in an [`Arc`]. + /// + /// The closure lets us use a weak reference to the arc while the extension + /// is being built. This is necessary for calling [`Extension::add_op`] and + /// [`Extension::add_type`]. + pub fn new_arc( + name: ExtensionId, + version: Version, + init: impl FnOnce(&mut Extension, &Weak), + ) -> Arc { + Arc::new_cyclic(|extension_ref| { + let mut ext = Self::new(name, version); + init(&mut ext, extension_ref); + ext + }) + } + + /// Creates a new extension wrapped in an [`Arc`], using a fallible + /// initialization function. + /// + /// The closure lets us use a weak reference to the arc while the extension + /// is being built. This is necessary for calling [`Extension::add_op`] and + /// [`Extension::add_type`]. + pub fn try_new_arc( + name: ExtensionId, + version: Version, + init: impl FnOnce(&mut Extension, &Weak) -> Result<(), E>, + ) -> Result, E> { + // Annoying hack around not having `Arc::try_new_cyclic` that can return + // a Result. + // https://github.com/rust-lang/rust/issues/75861#issuecomment-980455381 + // + // When there is an error, we store it in `error` and return it at the + // end instead of the partially-initialized extension. + let mut error = None; + let ext = Arc::new_cyclic(|extension_ref| { + let mut ext = Self::new(name, version); + match init(&mut ext, extension_ref) { + Ok(_) => ext, + Err(e) => { + error = Some(e); + ext + } + } + }); + match error { + Some(e) => Err(e), + None => Ok(ext), } } + /// Extend the requirements of this extension with another set of extensions. + pub fn set_reqs(&mut self, extension_reqs: impl Into) { + let reqs = mem::take(&mut self.extension_reqs); + self.extension_reqs = reqs.union(extension_reqs.into()); + } + /// Allows read-only access to the operations in this Extension pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc> { self.operations.get(name) @@ -634,20 +729,22 @@ pub mod test { impl Extension { /// Create a new extension for testing, with a 0 version. - pub(crate) fn new_test(name: ExtensionId) -> Self { - Self::new(name, Version::new(0, 0, 0)) + pub(crate) fn new_test_arc( + name: ExtensionId, + init: impl FnOnce(&mut Extension, &Weak), + ) -> Arc { + Self::new_arc(name, Version::new(0, 0, 0), init) } - /// Add a simple OpDef to the extension and return an extension op for it. - /// No description, no type parameters. - pub(crate) fn simple_ext_op( - &mut self, - name: &str, - signature: impl Into, - ) -> ExtensionOp { - self.add_op(name.into(), "".to_string(), signature).unwrap(); - self.instantiate_extension_op(name, [], &PRELUDE_REGISTRY) - .unwrap() + /// Create a new extension for testing, with a 0 version. + pub(crate) fn try_new_test_arc( + name: ExtensionId, + init: impl FnOnce( + &mut Extension, + &Weak, + ) -> Result<(), Box>, + ) -> Result, Box> { + Self::try_new_arc(name, Version::new(0, 0, 0), init) } } diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index c81414c9f..e17be112e 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -29,6 +29,7 @@ mod types; use std::fs::File; use std::path::Path; +use std::sync::Arc; use crate::extension::prelude::PRELUDE_ID; use crate::ops::OpName; @@ -150,19 +151,24 @@ impl ExtensionDeclaration { &self, imports: &ExtensionSet, ctx: DeclarationContext<'_>, - ) -> Result { - let mut ext = Extension::new(self.name.clone(), crate::extension::Version::new(0, 0, 0)) - .with_reqs(imports.clone()); - - for t in &self.types { - t.register(&mut ext, ctx)?; - } - - for o in &self.operations { - o.register(&mut ext, ctx)?; - } - - Ok(ext) + ) -> Result, ExtensionDeclarationError> { + Extension::try_new_arc( + self.name.clone(), + // TODO: Get the version as a parameter. + crate::extension::Version::new(0, 0, 0), + |ext, extension_ref| { + for t in &self.types { + t.register(ext, ctx, extension_ref)?; + } + + for o in &self.operations { + o.register(ext, ctx, extension_ref)?; + } + ext.set_reqs(imports.clone()); + + Ok(()) + }, + ) } } diff --git a/hugr-core/src/extension/declarative/ops.rs b/hugr-core/src/extension/declarative/ops.rs index 8bd769e10..39e688a6b 100644 --- a/hugr-core/src/extension/declarative/ops.rs +++ b/hugr-core/src/extension/declarative/ops.rs @@ -8,6 +8,7 @@ //! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration use std::collections::HashMap; +use std::sync::Weak; use serde::{Deserialize, Serialize}; use smol_str::SmolStr; @@ -55,10 +56,14 @@ pub(super) struct OperationDeclaration { impl OperationDeclaration { /// Register this operation in the given extension. + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. pub fn register<'ext>( &self, ext: &'ext mut Extension, ctx: DeclarationContext<'_>, + extension_ref: &Weak, ) -> Result<&'ext mut OpDef, ExtensionDeclarationError> { // We currently only support explicit signatures. // @@ -88,7 +93,12 @@ impl OperationDeclaration { let signature_func: SignatureFunc = signature.make_signature(ext, ctx, ¶ms)?; - let op_def = ext.add_op(self.name.clone(), self.description.clone(), signature_func)?; + let op_def = ext.add_op( + self.name.clone(), + self.description.clone(), + signature_func, + extension_ref, + )?; for (k, v) in &self.misc { op_def.add_misc(k, v.clone()); diff --git a/hugr-core/src/extension/declarative/types.rs b/hugr-core/src/extension/declarative/types.rs index 10b6e41a0..e426c69f2 100644 --- a/hugr-core/src/extension/declarative/types.rs +++ b/hugr-core/src/extension/declarative/types.rs @@ -7,6 +7,8 @@ //! [specification]: https://github.com/CQCL/hugr/blob/main/specification/hugr.md#declarative-format //! [`ExtensionSetDeclaration`]: super::ExtensionSetDeclaration +use std::sync::Weak; + use crate::extension::{TypeDef, TypeDefBound}; use crate::types::type_param::TypeParam; use crate::types::{TypeBound, TypeName}; @@ -49,10 +51,14 @@ impl TypeDeclaration { /// /// Types in the definition will be resolved using the extensions in `scope` /// and the current extension. + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. pub fn register<'ext>( &self, ext: &'ext mut Extension, ctx: DeclarationContext<'_>, + extension_ref: &Weak, ) -> Result<&'ext TypeDef, ExtensionDeclarationError> { let params = self .params @@ -64,6 +70,7 @@ impl TypeDeclaration { params, self.description.clone(), self.bound.into(), + extension_ref, )?; Ok(type_def) } diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index 5e74b9e9c..6f33cf3ef 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -2,7 +2,7 @@ use std::cmp::min; use std::collections::btree_map::Entry; use std::collections::HashMap; use std::fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use super::{ ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, @@ -302,6 +302,9 @@ impl Debug for LowerFunc { pub struct OpDef { /// The unique Extension owning this OpDef (of which this OpDef is a member) extension: ExtensionId, + /// A weak reference to the extension defining this operation. + #[serde(skip)] + extension_ref: Weak, /// Unique identifier of the operation. Used to look up OpDefs in the registry /// when deserializing nodes (which store only the name). name: OpName, @@ -394,11 +397,16 @@ impl OpDef { &self.name } - /// Returns a reference to the extension of this [`OpDef`]. - pub fn extension(&self) -> &ExtensionId { + /// Returns a reference to the extension id of this [`OpDef`]. + pub fn extension_id(&self) -> &ExtensionId { &self.extension } + /// Returns a weak reference to the extension defining this operation. + pub fn extension(&self) -> Weak { + self.extension_ref.clone() + } + /// Returns a reference to the description of this [`OpDef`]. pub fn description(&self) -> &str { self.description.as_ref() @@ -467,15 +475,41 @@ impl Extension { /// Add an operation definition to the extension. Must be a type scheme /// (defined by a [`PolyFuncTypeRV`]), a type scheme along with binary /// validation for type arguments ([`CustomValidator`]), or a custom binary - /// function for computing the signature given type arguments (`impl [CustomSignatureFunc]`). + /// function for computing the signature given type arguments (implementing + /// `[CustomSignatureFunc]`). + /// + /// This method requires a [`Weak`] reference to the [`Arc`] containing the + /// extension being defined. The intended way to call this method is inside + /// the closure passed to [`Extension::new_arc`] when defining the extension. + /// + /// # Example + /// + /// ``` + /// # use hugr_core::types::Signature; + /// # use hugr_core::extension::{Extension, ExtensionId, Version}; + /// Extension::new_arc( + /// ExtensionId::new_unchecked("my.extension"), + /// Version::new(0, 1, 0), + /// |ext, extension_ref| { + /// ext.add_op( + /// "MyOp".into(), + /// "Some operation".into(), + /// Signature::new_endo(vec![]), + /// extension_ref, + /// ); + /// }, + /// ); + /// ``` pub fn add_op( &mut self, name: OpName, description: String, signature_func: impl Into, + extension_ref: &Weak, ) -> Result<&mut OpDef, ExtensionBuildError> { let op = OpDef { extension: self.name.clone(), + extension_ref: extension_ref.clone(), name, description, signature_func: signature_func.into(), @@ -544,6 +578,7 @@ pub(super) mod test { fn eq(&self, other: &Self) -> bool { let OpDef { extension, + extension_ref: _, name, description, misc, @@ -553,6 +588,7 @@ pub(super) mod test { } = &self.0; let OpDef { extension: other_extension, + extension_ref: _, name: other_name, description: other_description, misc: other_misc, @@ -601,25 +637,28 @@ pub(super) mod test { #[test] fn op_def_with_type_scheme() -> Result<(), Box> { let list_def = EXTENSION.get_type(&LIST_TYPENAME).unwrap(); - let mut e = Extension::new_test(EXT_ID); - const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; - let list_of_var = - Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); const OP_NAME: OpName = OpName::new_inline("Reverse"); - let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); - - let def = e.add_op(OP_NAME, "desc".into(), type_scheme)?; - def.add_lower_func(LowerFunc::FixedHugr { - extensions: ExtensionSet::new(), - hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here - }); - def.add_misc("key", Default::default()); - assert_eq!(def.description(), "desc"); - assert_eq!(def.lower_funcs.len(), 1); - assert_eq!(def.misc.len(), 1); - - let reg = - ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), e.into()]).unwrap(); + + let ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + const TP: TypeParam = TypeParam::Type { b: TypeBound::Any }; + let list_of_var = + Type::new_extension(list_def.instantiate(vec![TypeArg::new_var_use(0, TP)])?); + let type_scheme = PolyFuncTypeRV::new(vec![TP], Signature::new_endo(vec![list_of_var])); + + let def = ext.add_op(OP_NAME, "desc".into(), type_scheme, extension_ref)?; + def.add_lower_func(LowerFunc::FixedHugr { + extensions: ExtensionSet::new(), + hugr: crate::builder::test::simple_dfg_hugr(), // this is nonsense, but we are not testing the actual lowering here + }); + def.add_misc("key", Default::default()); + assert_eq!(def.description(), "desc"); + assert_eq!(def.lower_funcs.len(), 1); + assert_eq!(def.misc.len(), 1); + + Ok(()) + })?; + + let reg = ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), ext]).unwrap(); let e = reg.get(&EXT_ID).unwrap(); let list_usize = @@ -666,60 +705,63 @@ pub(super) mod test { MAX_NAT } } - let mut e = Extension::new_test(EXT_ID); - let def: &mut crate::extension::OpDef = - e.add_op("MyOp".into(), "".to_string(), SigFun())?; - - // Base case, no type variables: - let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok( - Signature::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) - .with_extension_delta(EXT_ID) - ) - ); - assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); - - // Second arg may be a variable (substitutable) - let tyvar = Type::new_var_use(0, TypeBound::Copyable); - let tyvars: Vec = vec![tyvar.clone(); 3]; - let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok( - Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) - .with_extension_delta(EXT_ID) - ) - ); - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) - .unwrap(); - - // quick sanity check that we are validating the args - note changed bound: - assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Any.into()]), - Err(SignatureError::TypeVarDoesNotMatchDeclaration { - actual: TypeBound::Any.into(), - cached: TypeBound::Copyable.into() - }) - ); - - // First arg must be concrete, not a variable - let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); - let args = [TypeArg::new_var_use(0, kind.clone()), USIZE_T.into()]; - // We can't prevent this from getting into our compute_signature implementation: - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Err(SignatureError::InvalidTypeArgs) - ); - // But validation rules it out, even when the variable is declared: - assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), - Err(SignatureError::FreeTypeVar { - idx: 0, - num_decls: 0 - }) - ); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let def: &mut crate::extension::OpDef = + ext.add_op("MyOp".into(), "".to_string(), SigFun(), extension_ref)?; + + // Base case, no type variables: + let args = [TypeArg::BoundedNat { n: 3 }, USIZE_T.into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok( + Signature::new(vec![USIZE_T; 3], vec![Type::new_tuple(vec![USIZE_T; 3])]) + .with_extension_delta(EXT_ID) + ) + ); + assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); + + // Second arg may be a variable (substitutable) + let tyvar = Type::new_var_use(0, TypeBound::Copyable); + let tyvars: Vec = vec![tyvar.clone(); 3]; + let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok( + Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) + .with_extension_delta(EXT_ID) + ) + ); + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) + .unwrap(); + + // quick sanity check that we are validating the args - note changed bound: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Any.into()]), + Err(SignatureError::TypeVarDoesNotMatchDeclaration { + actual: TypeBound::Any.into(), + cached: TypeBound::Copyable.into() + }) + ); + + // First arg must be concrete, not a variable + let kind = TypeParam::bounded_nat(NonZeroU64::new(5).unwrap()); + let args = [TypeArg::new_var_use(0, kind.clone()), USIZE_T.into()]; + // We can't prevent this from getting into our compute_signature implementation: + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Err(SignatureError::InvalidTypeArgs) + ); + // But validation rules it out, even when the variable is declared: + assert_eq!( + def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), + Err(SignatureError::FreeTypeVar { + idx: 0, + num_decls: 0 + }) + ); + + Ok(()) + })?; Ok(()) } @@ -728,34 +770,37 @@ pub(super) mod test { fn type_scheme_instantiate_var() -> Result<(), Box> { // Check that we can instantiate a PolyFuncTypeRV-scheme with an (external) // type variable - let mut e = Extension::new_test(EXT_ID); - let def = e.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new( - vec![TypeBound::Any.into()], - Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), - ), - )?; - let tv = Type::new_var_use(1, TypeBound::Copyable); - let args = [TypeArg::Type { ty: tv.clone() }]; - let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; - def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); - assert_eq!( - def.compute_signature(&args, &EMPTY_REG), - Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) - ); - // But not with an external row variable - let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); - assert_eq!( - def.compute_signature(&[arg.clone()], &EMPTY_REG), - Err(SignatureError::TypeArgMismatch( - TypeArgError::TypeMismatch { - param: TypeBound::Any.into(), - arg - } - )) - ); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let def = ext.add_op( + "SimpleOp".into(), + "".into(), + PolyFuncTypeRV::new( + vec![TypeBound::Any.into()], + Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), + ), + extension_ref, + )?; + let tv = Type::new_var_use(1, TypeBound::Copyable); + let args = [TypeArg::Type { ty: tv.clone() }]; + let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; + def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); + assert_eq!( + def.compute_signature(&args, &EMPTY_REG), + Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) + ); + // But not with an external row variable + let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); + assert_eq!( + def.compute_signature(&[arg.clone()], &EMPTY_REG), + Err(SignatureError::TypeArgMismatch( + TypeArgError::TypeMismatch { + param: TypeBound::Any.into(), + arg + } + )) + ); + Ok(()) + })?; Ok(()) } @@ -763,33 +808,39 @@ pub(super) mod test { fn instantiate_extension_delta() -> Result<(), Box> { use crate::extension::prelude::{BOOL_T, PRELUDE_REGISTRY}; - let mut e = Extension::new_test(EXT_ID); - - let params: Vec = vec![TypeParam::Extensions]; - let db_set = ExtensionSet::type_var(0); - let fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(db_set); + let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let params: Vec = vec![TypeParam::Extensions]; + let db_set = ExtensionSet::type_var(0); + let fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(db_set); + + let def = ext.add_op( + "SimpleOp".into(), + "".into(), + PolyFuncTypeRV::new(params.clone(), fun_ty), + extension_ref, + )?; + + // Concrete extension set + let es = ExtensionSet::singleton(&EXT_ID); + let exp_fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(es.clone()); + let args = [TypeArg::Extensions { es }]; + + def.validate_args(&args, &PRELUDE_REGISTRY, ¶ms) + .unwrap(); + assert_eq!( + def.compute_signature(&args, &PRELUDE_REGISTRY), + Ok(exp_fun_ty) + ); + + Ok(()) + })?; - let def = e.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new(params.clone(), fun_ty), - )?; - - // Concrete extension set - let es = ExtensionSet::singleton(&EXT_ID); - let exp_fun_ty = Signature::new_endo(BOOL_T).with_extension_delta(es.clone()); - let args = [TypeArg::Extensions { es }]; - - def.validate_args(&args, &PRELUDE_REGISTRY, ¶ms) - .unwrap(); - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(exp_fun_ty) - ); Ok(()) } mod proptest { + use std::sync::Weak; + use super::SimpleOpDef; use ::proptest::prelude::*; @@ -846,6 +897,8 @@ pub(super) mod test { |(extension, name, description, misc, signature_func, lower_funcs)| { Self::new(OpDef { extension, + // Use a dead weak reference. Trying to access the extension will always return None. + extension_ref: Weak::default(), name, description, misc, diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index e7192f6ee..691a9cfb6 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -41,75 +41,80 @@ pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude"); pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); lazy_static! { static ref PRELUDE_DEF: Arc = { - let mut prelude = Extension::new(PRELUDE_ID, VERSION); - prelude - .add_type( - TypeName::new_inline("usize"), - vec![], - "usize".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - prelude.add_type( - STRING_TYPE_NAME, - vec![], - "string".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - prelude.add_op( - PRINT_OP_ID, - "Print the string to standard output".to_string(), - Signature::new(type_row![STRING_TYPE], type_row![]), - ) - .unwrap(); - prelude.add_type( - TypeName::new_inline(ARRAY_TYPE_NAME), - vec![ TypeParam::max_nat(), TypeBound::Any.into()], - "array".into(), - TypeDefBound::from_params(vec![1] ), - ) - .unwrap(); - - prelude - .add_type( - TypeName::new_inline("qubit"), - vec![], - "qubit".into(), - TypeDefBound::any(), - ) - .unwrap(); - prelude - .add_type( - ERROR_TYPE_NAME, - vec![], - "Simple opaque error type.".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - - - prelude - .add_op( - PANIC_OP_ID, - "Panic with input error".to_string(), - PolyFuncTypeRV::new( - [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], - FuncValueType::new( - vec![TypeRV::new_extension(ERROR_CUSTOM_TYPE), TypeRV::new_row_var_use(0, TypeBound::Any)], - vec![TypeRV::new_row_var_use(1, TypeBound::Any)], - ), - ), - ) - .unwrap(); - - TupleOpDef::load_all_ops(&mut prelude).unwrap(); - NoopDef.add_to_extension(&mut prelude).unwrap(); - LiftDef.add_to_extension(&mut prelude).unwrap(); - array::ArrayOpDef::load_all_ops(&mut prelude).unwrap(); - array::ArrayScanDef.add_to_extension(&mut prelude).unwrap(); - - Arc::new(prelude) + Extension::new_arc(PRELUDE_ID, VERSION, |prelude, extension_ref| { + prelude + .add_type( + TypeName::new_inline("usize"), + vec![], + "usize".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + prelude.add_type( + STRING_TYPE_NAME, + vec![], + "string".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + prelude.add_op( + PRINT_OP_ID, + "Print the string to standard output".to_string(), + Signature::new(type_row![STRING_TYPE], type_row![]), + extension_ref, + ) + .unwrap(); + prelude.add_type( + TypeName::new_inline(ARRAY_TYPE_NAME), + vec![ TypeParam::max_nat(), TypeBound::Any.into()], + "array".into(), + TypeDefBound::from_params(vec![1] ), + extension_ref, + ) + .unwrap(); + + prelude + .add_type( + TypeName::new_inline("qubit"), + vec![], + "qubit".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + prelude + .add_type( + ERROR_TYPE_NAME, + vec![], + "Simple opaque error type.".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + + prelude + .add_op( + PANIC_OP_ID, + "Panic with input error".to_string(), + PolyFuncTypeRV::new( + [TypeParam::new_list(TypeBound::Any), TypeParam::new_list(TypeBound::Any)], + FuncValueType::new( + vec![TypeRV::new_extension(ERROR_CUSTOM_TYPE), TypeRV::new_row_var_use(0, TypeBound::Any)], + vec![TypeRV::new_row_var_use(1, TypeBound::Any)], + ), + ), + extension_ref, + ) + .unwrap(); + + TupleOpDef::load_all_ops(prelude, extension_ref).unwrap(); + NoopDef.add_to_extension(prelude, extension_ref).unwrap(); + LiftDef.add_to_extension(prelude, extension_ref).unwrap(); + array::ArrayOpDef::load_all_ops(prelude, extension_ref).unwrap(); + array::ArrayScanDef.add_to_extension(prelude, extension_ref).unwrap(); + }) }; /// An extension registry containing only the prelude pub static ref PRELUDE_REGISTRY: ExtensionRegistry = @@ -528,7 +533,7 @@ impl MakeOpDef for TupleOpDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -695,7 +700,7 @@ impl MakeOpDef for NoopDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -805,7 +810,7 @@ impl MakeOpDef for LiftDef { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index 6013039d4..c419a67c7 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -1,4 +1,5 @@ use std::str::FromStr; +use std::sync::Weak; use itertools::Itertools; use strum_macros::EnumIter; @@ -180,7 +181,7 @@ impl MakeOpDef for ArrayOpDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn signature(&self) -> SignatureFunc { @@ -216,9 +217,10 @@ impl MakeOpDef for ArrayOpDef { fn add_to_extension( &self, extension: &mut Extension, + extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -394,7 +396,7 @@ impl MakeOpDef for ArrayScanDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn signature(&self) -> SignatureFunc { @@ -421,9 +423,10 @@ impl MakeOpDef for ArrayScanDef { fn add_to_extension( &self, extension: &mut Extension, + extension_ref: &Weak, ) -> Result<(), crate::extension::ExtensionBuildError> { let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index c338a693d..6d1c678c5 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -1,5 +1,7 @@ //! A trait that enum for op definitions that gathers up some shared functionality. +use std::sync::Weak; + use strum::IntoEnumIterator; use crate::ops::{ExtensionOp, OpName, OpNameRef}; @@ -67,8 +69,20 @@ pub trait MakeOpDef: NamedOp { /// Add an operation implemented as an [MakeOpDef], which can provide the data /// required to define an [OpDef], to an extension. - fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { - let def = extension.add_op(self.name(), self.description(), self.signature())?; + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> { + let def = extension.add_op( + self.name(), + self.description(), + self.signature(), + extension_ref, + )?; self.post_opdef(def); @@ -77,12 +91,18 @@ pub trait MakeOpDef: NamedOp { /// Load all variants of an enum of op definitions in to an extension as op defs. /// See [strum::IntoEnumIterator]. - fn load_all_ops(extension: &mut Extension) -> Result<(), ExtensionBuildError> + /// + /// Requires a [`Weak`] reference to the extension defining the operation. + /// This method is intended to be used inside the closure passed to [`Extension::new_arc`]. + fn load_all_ops( + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> where Self: IntoEnumIterator, { for op in Self::iter() { - op.add_to_extension(extension)?; + op.add_to_extension(extension, extension_ref)?; } Ok(()) } @@ -316,9 +336,11 @@ mod test { lazy_static! { static ref EXT: Arc = { - let mut e = Extension::new_test(EXT_ID.clone()); - DummyEnum::Dumb.add_to_extension(&mut e).unwrap(); - Arc::new(e) + Extension::new_test_arc(EXT_ID.clone(), |ext, extension_ref| { + DummyEnum::Dumb + .add_to_extension(ext, extension_ref) + .unwrap(); + }) }; static ref DUMMY_REG: ExtensionRegistry = ExtensionRegistry::try_new([EXT.clone()]).unwrap(); diff --git a/hugr-core/src/extension/type_def.rs b/hugr-core/src/extension/type_def.rs index 1affe68f0..7f0daa3ca 100644 --- a/hugr-core/src/extension/type_def.rs +++ b/hugr-core/src/extension/type_def.rs @@ -1,4 +1,5 @@ use std::collections::btree_map::Entry; +use std::sync::Weak; use super::{CustomConcrete, ExtensionBuildError}; use super::{Extension, ExtensionId, SignatureError}; @@ -56,6 +57,9 @@ impl TypeDefBound { pub struct TypeDef { /// The unique Extension owning this TypeDef (of which this TypeDef is a member) extension: ExtensionId, + /// A weak reference to the extension defining this operation. + #[serde(skip)] + extension_ref: Weak, /// The unique name of the type name: TypeName, /// Declaration of type parameters. The TypeDef must be instantiated @@ -82,9 +86,9 @@ impl TypeDef { /// This function will return an error if the type of the instance does not /// match the definition. pub fn check_custom(&self, custom: &CustomType) -> Result<(), SignatureError> { - if self.extension() != custom.parent_extension() { + if self.extension_id() != custom.parent_extension() { return Err(SignatureError::ExtensionMismatch( - self.extension().clone(), + self.extension_id().clone(), custom.parent_extension().clone(), )); } @@ -121,7 +125,7 @@ impl TypeDef { Ok(CustomType::new( self.name().clone(), args, - self.extension().clone(), + self.extension_id().clone(), bound, )) } @@ -156,22 +160,55 @@ impl TypeDef { &self.name } - fn extension(&self) -> &ExtensionId { + /// Returns a reference to the extension id of this [`TypeDef`]. + pub fn extension_id(&self) -> &ExtensionId { &self.extension } + + /// Returns a weak reference to the extension defining this type. + pub fn extension(&self) -> Weak { + self.extension_ref.clone() + } } impl Extension { /// Add an exported type to the extension. + /// + /// This method requires a [`Weak`] reference to the [`std::sync::Arc`] containing the + /// extension being defined. The intended way to call this method is inside + /// the closure passed to [`Extension::new_arc`] when defining the extension. + /// + /// # Example + /// + /// ``` + /// # use hugr_core::types::Signature; + /// # use hugr_core::extension::{Extension, ExtensionId, Version}; + /// # use hugr_core::extension::{TypeDefBound}; + /// Extension::new_arc( + /// ExtensionId::new_unchecked("my.extension"), + /// Version::new(0, 1, 0), + /// |ext, extension_ref| { + /// ext.add_type( + /// "MyType".into(), + /// vec![], // No type parameters + /// "Some type".into(), + /// TypeDefBound::any(), + /// extension_ref, + /// ); + /// }, + /// ); + /// ``` pub fn add_type( &mut self, name: TypeName, params: Vec, description: String, bound: TypeDefBound, + extension_ref: &Weak, ) -> Result<&TypeDef, ExtensionBuildError> { let ty = TypeDef { extension: self.name.clone(), + extension_ref: extension_ref.clone(), name, params, description, @@ -202,6 +239,8 @@ mod test { b: TypeBound::Copyable, }], extension: "MyRsrc".try_into().unwrap(), + // Dummy extension. Will return `None` when trying to upgrade it into an `Arc`. + extension_ref: Default::default(), description: "Some parametrised type".into(), bound: TypeDefBound::FromParams { indices: vec![0] }, }; diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 8967df9a5..e23fab7c5 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -463,7 +463,7 @@ mod test { use crate::std_extensions::collections::{self, list_type, ListOp}; use crate::types::{Signature, Type, TypeRow}; use crate::utils::depth; - use crate::{type_row, Direction, Hugr, HugrView, OutgoingPort}; + use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; @@ -638,10 +638,26 @@ mod test { #[test] fn test_invalid() { - let mut new_ext = crate::Extension::new_test("new_ext".try_into().unwrap()); - let ext_name = new_ext.name().clone(); let utou = Signature::new_endo(vec![USIZE_T]); - let mut mk_op = |s| new_ext.simple_ext_op(s, utou.clone()); + let ext = Extension::new_test_arc("new_ext".try_into().unwrap(), |ext, extension_ref| { + ext.add_op("foo".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + ext.add_op("bar".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref) + .unwrap(); + }); + let ext_name = ext.name().clone(); + let foo = ext + .instantiate_extension_op("foo", [], &PRELUDE_REGISTRY) + .unwrap(); + let bar = ext + .instantiate_extension_op("bar", [], &PRELUDE_REGISTRY) + .unwrap(); + let baz = ext + .instantiate_extension_op("baz", [], &PRELUDE_REGISTRY) + .unwrap(); + let mut h = DFGBuilder::new( Signature::new(type_row![USIZE_T, BOOL_T], type_row![USIZE_T]) .with_extension_delta(ext_name.clone()), @@ -657,23 +673,17 @@ mod test { ) .unwrap(); let mut case1 = cond.case_builder(0).unwrap(); - let foo = case1 - .add_dataflow_op(mk_op("foo"), case1.input_wires()) - .unwrap(); + let foo = case1.add_dataflow_op(foo, case1.input_wires()).unwrap(); let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node(); let mut case2 = cond.case_builder(1).unwrap(); - let bar = case2 - .add_dataflow_op(mk_op("bar"), case2.input_wires()) - .unwrap(); + let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap(); let mut baz_dfg = case2 .dfg_builder( utou.clone().with_extension_delta(ext_name.clone()), bar.outputs(), ) .unwrap(); - let baz = baz_dfg - .add_dataflow_op(mk_op("baz"), baz_dfg.input_wires()) - .unwrap(); + let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap(); let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap(); let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); let cond = cond.finish_sub_container().unwrap(); diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index cf934e18b..97ffc3d53 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -1,5 +1,6 @@ use std::fs::File; use std::io::BufReader; +use std::sync::Arc; use cool_asserts::assert_matches; @@ -378,15 +379,17 @@ const_extension_ids! { } #[test] fn invalid_types() { - let mut e = Extension::new_test(EXT_ID); - e.add_type( - "MyContainer".into(), - vec![TypeBound::Copyable.into()], - "".into(), - TypeDefBound::any(), - ) - .unwrap(); - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()]).unwrap(); + let ext = Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + ext.add_type( + "MyContainer".into(), + vec![TypeBound::Copyable.into()], + "".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + }); + let reg = ExtensionRegistry::try_new([ext, PRELUDE.clone()]).unwrap(); let validate_to_sig_error = |t: CustomType| { let (h, def) = identity_hugr_with_type(Type::new_extension(t)); @@ -587,33 +590,33 @@ fn no_polymorphic_consts() -> Result<(), Box> { Ok(()) } -pub(crate) fn extension_with_eval_parallel() -> Extension { +pub(crate) fn extension_with_eval_parallel() -> Arc { let rowp = TypeParam::new_list(TypeBound::Any); - let mut e = Extension::new_test(EXT_ID); - - let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); - let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); - let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); - let pf = PolyFuncTypeRV::new( - [rowp.clone(), rowp.clone()], - FuncValueType::new(vec![evaled_fn, inputs], outputs), - ); - e.add_op("eval".into(), "".into(), pf).unwrap(); - - let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); - let pf = PolyFuncTypeRV::new( - [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], - Signature::new( - vec![ - Type::new_function(FuncValueType::new(rv(0), rv(2))), - Type::new_function(FuncValueType::new(rv(1), rv(3))), - ], - Type::new_function(FuncValueType::new(vec![rv(0), rv(1)], vec![rv(2), rv(3)])), - ), - ); - e.add_op("parallel".into(), "".into(), pf).unwrap(); + Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + let inputs = TypeRV::new_row_var_use(0, TypeBound::Any); + let outputs = TypeRV::new_row_var_use(1, TypeBound::Any); + let evaled_fn = TypeRV::new_function(FuncValueType::new(inputs.clone(), outputs.clone())); + let pf = PolyFuncTypeRV::new( + [rowp.clone(), rowp.clone()], + FuncValueType::new(vec![evaled_fn, inputs], outputs), + ); + ext.add_op("eval".into(), "".into(), pf, extension_ref) + .unwrap(); - e + let rv = |idx| TypeRV::new_row_var_use(idx, TypeBound::Any); + let pf = PolyFuncTypeRV::new( + [rowp.clone(), rowp.clone(), rowp.clone(), rowp.clone()], + Signature::new( + vec![ + Type::new_function(FuncValueType::new(rv(0), rv(2))), + Type::new_function(FuncValueType::new(rv(1), rv(3))), + ], + Type::new_function(FuncValueType::new(vec![rv(0), rv(1)], vec![rv(2), rv(3)])), + ), + ); + ext.add_op("parallel".into(), "".into(), pf, extension_ref) + .unwrap(); + }) } #[test] @@ -643,7 +646,7 @@ fn instantiate_row_variables() -> Result<(), Box> { let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; dfb.finish_hugr_with_outputs( eval2.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), )?; Ok(()) } @@ -683,41 +686,44 @@ fn row_variables() -> Result<(), Box> { let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs( par_func.outputs(), - &ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(), + &ExtensionRegistry::try_new([PRELUDE.clone(), e]).unwrap(), )?; Ok(()) } #[test] fn test_polymorphic_call() -> Result<(), Box> { - let mut e = Extension::new_test(EXT_ID); - - let params: Vec = vec![ - TypeBound::Any.into(), - TypeParam::Extensions, - TypeBound::Any.into(), - ]; - let evaled_fn = Type::new_function( - Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), - ); - // Single-input/output version of the higher-order "eval" operation, with extension param. - // Note the extension-delta of the eval node includes that of the input function. - e.add_op( - "eval".into(), - "".into(), - PolyFuncTypeRV::new( - params.clone(), + let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { + let params: Vec = vec![ + TypeBound::Any.into(), + TypeParam::Extensions, + TypeBound::Any.into(), + ]; + let evaled_fn = Type::new_function( Signature::new( - vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(0, TypeBound::Any), Type::new_var_use(2, TypeBound::Any), ) .with_extension_delta(ExtensionSet::type_var(1)), - ), - )?; + ); + // Single-input/output version of the higher-order "eval" operation, with extension param. + // Note the extension-delta of the eval node includes that of the input function. + ext.add_op( + "eval".into(), + "".into(), + PolyFuncTypeRV::new( + params.clone(), + Signature::new( + vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], + Type::new_var_use(2, TypeBound::Any), + ) + .with_extension_delta(ExtensionSet::type_var(1)), + ), + extension_ref, + )?; + + Ok(()) + })?; fn utou(e: impl Into) -> Type { Type::new_function(Signature::new_endo(USIZE_T).with_extension_delta(e.into())) @@ -763,7 +769,7 @@ fn test_polymorphic_call() -> Result<(), Box> { f.finish_with_outputs([tup])? }; - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; + let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; let [func, tup] = d.input_wires_arr(); let call = d.call( f.handle(), diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index eec5f4d34..d7f1c2c57 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -104,7 +104,7 @@ impl ExtensionOp { /// For a non-cloning version of this operation, use [`OpaqueOp::from`]. pub fn make_opaque(&self) -> OpaqueOp { OpaqueOp { - extension: self.def.extension().clone(), + extension: self.def.extension_id().clone(), name: self.def.name().clone(), description: self.def.description().into(), args: self.args.clone(), @@ -121,7 +121,7 @@ impl From for OpaqueOp { signature, } = op; OpaqueOp { - extension: def.extension().clone(), + extension: def.extension_id().clone(), name: def.name().clone(), description: def.description().into(), args, @@ -141,7 +141,7 @@ impl Eq for ExtensionOp {} impl NamedOp for ExtensionOp { /// The name of the operation. fn name(&self) -> OpName { - qualify_name(self.def.extension(), self.def.name()) + qualify_name(self.def.extension_id(), self.def.name()) } } @@ -402,26 +402,30 @@ mod test { #[test] fn resolve_missing() { - let mut ext = Extension::new_test("ext".try_into().unwrap()); - let ext_id = ext.name().clone(); let val_name = "missing_val"; let comp_name = "missing_comp"; - let endo_sig = Signature::new_endo(BOOL_T); - ext.add_op( - val_name.into(), - "".to_string(), - SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()), - ) - .unwrap(); - ext.add_op( - comp_name.into(), - "".to_string(), - SignatureFunc::MissingComputeFunc, - ) - .unwrap(); - let registry = ExtensionRegistry::try_new([ext.into()]).unwrap(); + let ext = Extension::new_test_arc("ext".try_into().unwrap(), |ext, extension_ref| { + ext.add_op( + val_name.into(), + "".to_string(), + SignatureFunc::MissingValidateFunc(FuncValueType::from(endo_sig.clone()).into()), + extension_ref, + ) + .unwrap(); + + ext.add_op( + comp_name.into(), + "".to_string(), + SignatureFunc::MissingComputeFunc, + extension_ref, + ) + .unwrap(); + }); + let ext_id = ext.name().clone(); + + let registry = ExtensionRegistry::try_new([ext]).unwrap(); let opaque_val = OpaqueOp::new( ext_id.clone(), val_name, diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index deb93f8c2..f26e27485 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -50,7 +50,7 @@ pub enum ConvertOpDef { impl MakeOpDef for ConvertOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -158,18 +158,15 @@ impl MakeExtensionOp for ConvertOpType { lazy_static! { /// Extension for conversions between integers and floats. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ]), - ); - - ConvertOpDef::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.set_reqs( + ExtensionSet::from_iter(vec![ + super::int_types::EXTENSION_ID, + super::float_types::EXTENSION_ID, + ])); + + ConvertOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate integer operations. diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 7d353e71a..1ef416d0d 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -50,7 +50,7 @@ pub enum FloatOps { impl MakeOpDef for FloatOps { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -107,15 +107,10 @@ impl MakeOpDef for FloatOps { lazy_static! { /// Extension for basic float operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::singleton(&super::int_types::EXTENSION_ID), - ); - - FloatOps::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.set_reqs(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + FloatOps::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate float operations. diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index ec145008f..0af5f8728 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -82,18 +82,17 @@ impl CustomConst for ConstF64 { lazy_static! { /// Extension defining the float type. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - extension - .add_type( - FLOAT_TYPE_ID, - vec![], - "64-bit IEEE 754-2019 floating-point value".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + FLOAT_TYPE_ID, + vec![], + "64-bit IEEE 754-2019 floating-point value".to_owned(), + TypeBound::Copyable.into(), + extension_ref, + ) + .unwrap(); + }) }; } #[cfg(test)] diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 97bb247a2..235fbc9ba 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -104,7 +104,7 @@ pub enum IntOpDef { impl MakeOpDef for IntOpDef { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -250,15 +250,10 @@ fn iunop_sig() -> PolyFuncTypeRV { lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new( - EXTENSION_ID, - VERSION).with_reqs( - ExtensionSet::singleton(&super::int_types::EXTENSION_ID) - ); - - IntOpDef::load_all_ops(&mut extension).unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.set_reqs(ExtensionSet::singleton(&super::int_types::EXTENSION_ID)); + IntOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate integer operations. diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index 3d257b9d0..82f1c27ae 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -188,18 +188,17 @@ impl CustomConst for ConstInt { /// Extension for basic integer types. pub fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - extension - .add_type( - INT_TYPE_ID, - vec![LOG_WIDTH_TYPE_PARAM], - "integral value of a given bit width".to_owned(), - TypeBound::Copyable.into(), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + INT_TYPE_ID, + vec![LOG_WIDTH_TYPE_PARAM], + "integral value of a given bit width".to_owned(), + TypeBound::Copyable.into(), + extension_ref, + ) + .unwrap(); + }) } lazy_static! { diff --git a/hugr-core/src/std_extensions/collections.rs b/hugr-core/src/std_extensions/collections.rs index 17a1b0d03..2f416c92b 100644 --- a/hugr-core/src/std_extensions/collections.rs +++ b/hugr-core/src/std_extensions/collections.rs @@ -5,7 +5,7 @@ use std::hash::{Hash, Hasher}; mod list_fold; use std::str::FromStr; -use std::sync::Arc; +use std::sync::{Arc, Weak}; use itertools::Itertools; use lazy_static::lazy_static; @@ -204,7 +204,7 @@ impl ListOp { impl MakeOpDef for ListOp { fn from_def(op_def: &OpDef) -> Result { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -216,9 +216,13 @@ impl MakeOpDef for ListOp { // // This method is re-defined here since we need to pass the list type def while computing the signature, // to avoid recursive loops initializing the extension. - fn add_to_extension(&self, extension: &mut Extension) -> Result<(), ExtensionBuildError> { + fn add_to_extension( + &self, + extension: &mut Extension, + extension_ref: &Weak, + ) -> Result<(), ExtensionBuildError> { let sig = self.compute_signature(extension.get_type(&LIST_TYPENAME).unwrap()); - let def = extension.add_op(self.name(), self.description(), sig)?; + let def = extension.add_op(self.name(), self.description(), sig, extension_ref)?; self.post_opdef(def); @@ -251,20 +255,19 @@ impl MakeOpDef for ListOp { lazy_static! { /// Extension for list operations. pub static ref EXTENSION: Arc = { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - - // The list type must be defined before the operations are added. - extension.add_type( - LIST_TYPENAME, - vec![ListOp::TP], - "Generic dynamically sized list of type T.".into(), - TypeDefBound::from_params(vec![0]), - ) - .unwrap(); - - ListOp::load_all_ops(&mut extension).unwrap(); + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension.add_type( + LIST_TYPENAME, + vec![ListOp::TP], + "Generic dynamically sized list of type T.".into(), + TypeDefBound::from_params(vec![0]), + extension_ref + ) + .unwrap(); - Arc::new(extension) + // The list type must be defined before the operations are added. + ListOp::load_all_ops(extension, extension_ref).unwrap(); + }) }; /// Registry of extensions required to validate list operations. @@ -392,7 +395,7 @@ mod test { assert_eq!(&ListOp::push.extension(), EXTENSION.name()); assert!(ListOp::pop.registry().contains(EXTENSION.name())); for (_, op_def) in EXTENSION.operations() { - assert_eq!(op_def.extension(), &EXTENSION_ID); + assert_eq!(op_def.extension_id(), &EXTENSION_ID); } } diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index 89f9dfa8b..4799e3f33 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -91,7 +91,7 @@ impl MakeOpDef for LogicOp { } fn from_def(op_def: &OpDef) -> Result { - try_from_name(op_def.name(), op_def.extension()) + try_from_name(op_def.name(), op_def.extension_id()) } fn extension(&self) -> ExtensionId { @@ -110,16 +110,16 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for basic logical operations. fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - LogicOp::load_all_ops(&mut extension).unwrap(); + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + LogicOp::load_all_ops(extension, extension_ref).unwrap(); - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); - Arc::new(extension) + extension + .add_value(FALSE_NAME, ops::Value::false_val()) + .unwrap(); + extension + .add_value(TRUE_NAME, ops::Value::true_val()) + .unwrap(); + }) } lazy_static! { diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index e3023e4b5..1822967b7 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -47,7 +47,7 @@ impl MakeOpDef for PtrOpDef { where Self: Sized, { - crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension_id()) } fn signature(&self) -> SignatureFunc { @@ -87,17 +87,18 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); /// Extension for pointer operations. fn extension() -> Arc { - let mut extension = Extension::new(EXTENSION_ID, VERSION); - extension - .add_type( - PTR_TYPE_ID, - TYPE_PARAMS.into(), - "Standard extension pointer type.".into(), - TypeDefBound::copyable(), - ) - .unwrap(); - PtrOpDef::load_all_ops(&mut extension).unwrap(); - Arc::new(extension) + Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { + extension + .add_type( + PTR_TYPE_ID, + TYPE_PARAMS.into(), + "Standard extension pointer type.".into(), + TypeDefBound::copyable(), + extension_ref, + ) + .unwrap(); + PtrOpDef::load_all_ops(extension, extension_ref).unwrap(); + }) } lazy_static! { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 7e3f4f664..77c4ab990 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -321,16 +321,18 @@ pub(crate) mod test { const EXT_ID: ExtensionId = ExtensionId::new_unchecked("my_ext"); const TYPE_NAME: TypeName = TypeName::new_inline("MyType"); - let mut e = Extension::new_test(EXT_ID); - e.add_type( - TYPE_NAME, - vec![bound.clone()], - "".into(), - TypeDefBound::any(), - ) - .unwrap(); + let ext = Extension::new_test_arc(EXT_ID, |ext, extension_ref| { + ext.add_type( + TYPE_NAME, + vec![bound.clone()], + "".into(), + TypeDefBound::any(), + extension_ref, + ) + .unwrap(); + }); - let reg = ExtensionRegistry::try_new([e.into()]).unwrap(); + let reg = ExtensionRegistry::try_new([ext]).unwrap(); let make_scheme = |tp: TypeParam| { PolyFuncTypeBase::new_validated( diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index f1f97cc5a..702ad8c19 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -131,48 +131,60 @@ pub(crate) mod test_quantum_extension { /// The extension identifier. pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("test.quantum"); fn extension() -> Arc { - let mut extension = Extension::new_test(EXTENSION_ID); - - extension - .add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func()) - .unwrap(); - extension - .add_op( - OpName::new_inline("RzF64"), - "Rotation specified by float".into(), - Signature::new(type_row![QB_T, float_types::FLOAT64_TYPE], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func()) - .unwrap(); - - extension - .add_op( - OpName::new_inline("Measure"), - "Measure a qubit, returning the qubit and the measurement result.".into(), - Signature::new(type_row![QB_T], type_row![QB_T, BOOL_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("QAlloc"), - "Allocate a new qubit.".into(), - Signature::new(type_row![], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("QDiscard"), - "Discard a qubit.".into(), - Signature::new(type_row![QB_T], type_row![]), - ) - .unwrap(); - - Arc::new(extension) + Extension::new_test_arc(EXTENSION_ID, |extension, extension_ref| { + extension + .add_op( + OpName::new_inline("H"), + "Hadamard".into(), + one_qb_func(), + extension_ref, + ) + .unwrap(); + extension + .add_op( + OpName::new_inline("RzF64"), + "Rotation specified by float".into(), + Signature::new(type_row![QB_T, float_types::FLOAT64_TYPE], type_row![QB_T]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("CX"), + "CX".into(), + two_qb_func(), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("Measure"), + "Measure a qubit, returning the qubit and the measurement result.".into(), + Signature::new(type_row![QB_T], type_row![QB_T, BOOL_T]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("QAlloc"), + "Allocate a new qubit.".into(), + Signature::new(type_row![], type_row![QB_T]), + extension_ref, + ) + .unwrap(); + + extension + .add_op( + OpName::new_inline("QDiscard"), + "Discard a qubit.".into(), + Signature::new(type_row![QB_T], type_row![]), + extension_ref, + ) + .unwrap(); + }) } lazy_static! { diff --git a/hugr-llvm/src/custom/extension_op.rs b/hugr-llvm/src/custom/extension_op.rs index 08b392036..cd3c3b6e7 100644 --- a/hugr-llvm/src/custom/extension_op.rs +++ b/hugr-llvm/src/custom/extension_op.rs @@ -100,7 +100,7 @@ impl<'a, H: HugrView> ExtensionOpMap<'a, H> { args: EmitOpArgs<'c, '_, ExtensionOp, H>, ) -> Result<()> { let node = args.node(); - let key = (node.def().extension().clone(), node.def().name().clone()); + let key = (node.def().extension_id().clone(), node.def().name().clone()); let Some(handler) = self.0.get(&key) else { bail!("No extension could emit extension op: {key:?}") }; diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index a34ecc351..3ac3dacb7 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -157,6 +157,7 @@ fn mk_rep( #[cfg(test)] mod test { use std::collections::HashSet; + use std::sync::Arc; use hugr_core::extension::prelude::Lift; use itertools::Itertools; @@ -178,21 +179,26 @@ mod test { const EXT_ID: ExtensionId = "TestExt"; } - fn extension() -> Extension { - let mut e = Extension::new(EXT_ID, hugr_core::extension::Version::new(0, 1, 0)); - e.add_op( - "Test".into(), - String::new(), - Signature::new( - type_row![QB_T, USIZE_T], - TypeRow::from(vec![Type::new_sum(vec![ - type_row![QB_T], - type_row![USIZE_T], - ])]), - ), + fn extension() -> Arc { + Extension::new_arc( + EXT_ID, + hugr_core::extension::Version::new(0, 1, 0), + |ext, extension_ref| { + ext.add_op( + "Test".into(), + String::new(), + Signature::new( + type_row![QB_T, USIZE_T], + TypeRow::from(vec![Type::new_sum(vec![ + type_row![QB_T], + type_row![USIZE_T], + ])]), + ), + extension_ref, + ) + .unwrap(); + }, ) - .unwrap(); - e } fn lifted_unary_unit_sum + AsRef, T>(b: &mut DFGWrapper) -> Wire { @@ -228,7 +234,7 @@ mod test { let exit_types = type_row![USIZE_T]; let e = extension(); let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; - let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e.into()])?; + let reg = ExtensionRegistry::try_new([PRELUDE.clone(), e])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; let n = no_b1.add_dataflow_op(Noop::new(QB_T), no_b1.input_wires())?; @@ -299,7 +305,7 @@ mod test { // And the Noop in the entry block is consumed by the custom Test op let tst = find_unique( h.nodes(), - |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension() != &PRELUDE_ID), + |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension_id() != &PRELUDE_ID), ); assert_eq!(h.get_parent(tst), Some(entry)); assert_eq!( @@ -355,7 +361,7 @@ mod test { h.branch(&bb2, 0, &bb3)?; h.branch(&bb3, 0, &h.exit_block())?; - let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?; + let reg = ExtensionRegistry::try_new([e, PRELUDE.clone()])?; let mut h = h.finish_hugr(®)?; let root = h.root(); merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); @@ -365,7 +371,7 @@ mod test { let [bb, _exit] = h.children(h.root()).collect::>().try_into().unwrap(); let tst = find_unique( h.nodes(), - |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension() != &PRELUDE_ID), + |n| matches!(h.get_optype(*n), OpType::ExtensionOp(c) if c.def().extension_id() != &PRELUDE_ID), ); assert_eq!(h.get_parent(tst), Some(bb)); diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 3abcee535..48a336d2d 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -1,5 +1,7 @@ //! Builders and utilities for benchmarks. +use std::sync::Arc; + use hugr::builder::{ BuildError, CFGBuilder, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, @@ -53,35 +55,35 @@ pub fn simple_cfg_hugr() -> Hugr { } lazy_static! { - static ref QUANTUM_EXT: Extension = { - let mut extension = Extension::new( + static ref QUANTUM_EXT: Arc = { + Extension::new_arc( "bench.quantum".try_into().unwrap(), hugr::extension::Version::new(0, 0, 0), - ); - - extension - .add_op( - OpName::new_inline("H"), - "".into(), - Signature::new_endo(QB_T), - ) - .unwrap(); - extension - .add_op( - OpName::new_inline("Rz"), - "".into(), - Signature::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]), - ) - .unwrap(); - - extension - .add_op( - OpName::new_inline("CX"), - "".into(), - Signature::new_endo(type_row![QB_T, QB_T]), - ) - .unwrap(); - extension + |ext, extension_ref| { + ext.add_op( + OpName::new_inline("H"), + "".into(), + Signature::new_endo(QB_T), + extension_ref, + ) + .unwrap(); + ext.add_op( + OpName::new_inline("Rz"), + "".into(), + Signature::new(type_row![QB_T, FLOAT64_TYPE], type_row![QB_T]), + extension_ref, + ) + .unwrap(); + + ext.add_op( + OpName::new_inline("CX"), + "".into(), + Signature::new_endo(type_row![QB_T, QB_T]), + extension_ref, + ) + .unwrap(); + }, + ) }; } diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index 94dd141a3..d23063dfa 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -61,25 +61,21 @@ //! pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("mini.quantum"); //! pub const VERSION: Version = Version::new(0, 1, 0); //! fn extension() -> Arc { -//! let mut extension = Extension::new(EXTENSION_ID, VERSION); +//! Extension::new_arc(EXTENSION_ID, VERSION, |ext, extension_ref| { +//! ext.add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func(), extension_ref) +//! .unwrap(); //! -//! extension -//! .add_op(OpName::new_inline("H"), "Hadamard".into(), one_qb_func()) -//! .unwrap(); -//! -//! extension -//! .add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func()) -//! .unwrap(); +//! ext.add_op(OpName::new_inline("CX"), "CX".into(), two_qb_func(), extension_ref) +//! .unwrap(); //! -//! extension -//! .add_op( +//! ext.add_op( //! OpName::new_inline("Measure"), //! "Measure a qubit, returning the qubit and the measurement result.".into(), //! FuncValueType::new(type_row![QB_T], type_row![QB_T, BOOL_T]), +//! extension_ref, //! ) //! .unwrap(); -//! -//! Arc::new(extension) +//! }) //! } //! //! lazy_static! {