diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index ae04d00b4..6d0cf0865 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -29,7 +29,7 @@ //! # use hugr::Hugr; //! # use hugr::builder::{BuildError, BuildHandle, Container, DFGBuilder, Dataflow, DataflowHugr, ModuleBuilder, DataflowSubContainer, HugrBuilder}; //! use hugr::extension::prelude::bool_t; -//! use hugr::std_extensions::logic::{EXTENSION_ID, LOGIC_REG, LogicOp}; +//! use hugr::std_extensions::logic::{self, LogicOp}; //! use hugr::types::Signature; //! //! # fn doctest() -> Result<(), BuildError> { @@ -42,7 +42,7 @@ //! let _dfg_handle = { //! let mut dfg = module_builder.define_function( //! "main", -//! Signature::new_endo(bool_t()).with_extension_delta(EXTENSION_ID), +//! Signature::new_endo(bool_t()).with_extension_delta(logic::EXTENSION_ID), //! )?; //! //! // Get the wires from the function inputs. @@ -60,7 +60,7 @@ //! let mut dfg = module_builder.define_function( //! "circuit", //! Signature::new_endo(vec![bool_t(), bool_t()]) -//! .with_extension_delta(EXTENSION_ID), +//! .with_extension_delta(logic::EXTENSION_ID), //! )?; //! let mut circuit = dfg.as_circuit(dfg.input_wires()); //! diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index 44f5a00ac..9997c2f46 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -437,7 +437,7 @@ pub trait Dataflow: Container { }; let load_n = self.add_dataflow_op( - ops::LoadFunction::try_new(func_sig, type_args, self.hugr().extensions())?, + ops::LoadFunction::try_new(func_sig, type_args)?, // Static wire from the function node vec![Wire::new(func_node, func_op.static_output_port().unwrap())], )?; @@ -699,8 +699,7 @@ pub trait Dataflow: Container { }) } }; - let op: OpType = - ops::Call::try_new(type_scheme, type_args, self.hugr().extensions())?.into(); + let op: OpType = ops::Call::try_new(type_scheme, type_args)?.into(); let const_in_port = op.static_input_port().unwrap(); let op_id = self.add_dataflow_op(op, input_wires)?; let src_port = self.hugr_mut().num_outputs(function.node()) - 1; diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 338c1b260..01f5e3e45 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -245,7 +245,7 @@ mod test { use crate::builder::{Container, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY}; + use crate::extension::{ExtensionId, ExtensionSet}; use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; use crate::utils::test_quantum_extension::{ self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64, @@ -305,9 +305,7 @@ mod test { ) .unwrap(); }); - let my_custom_op = my_ext - .instantiate_extension_op("MyOp", [], &PRELUDE_REGISTRY) - .unwrap(); + let my_custom_op = my_ext.instantiate_extension_op("MyOp", []).unwrap(); let mut module_builder = ModuleBuilder::new(); let mut f_build = module_builder diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index e4720a676..2176d0ae6 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -316,7 +316,7 @@ pub(crate) mod test { }; use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::prelude::{Lift, Noop}; - use crate::extension::{ExtensionId, SignatureError, PRELUDE_REGISTRY}; + use crate::extension::{ExtensionId, SignatureError}; use crate::hugr::validate::InterGraphEdgeError; use crate::ops::{handle::NodeHandle, OpTag}; use crate::ops::{OpTrait, Value}; @@ -656,7 +656,6 @@ pub(crate) mod test { let ev = e.instantiate_extension_op( "eval", [vec![usize_t().into()].into(), vec![tv.into()].into()], - &PRELUDE_REGISTRY, ); assert_eq!( ev, diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 4e3066de4..408c88e15 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -113,20 +113,13 @@ impl ExtensionRegistry { self.exts.contains_key(name) } - /// Validate the set of extensions, ensuring that each extension requirements are also in the registry. - /// - /// Note this potentially asks extensions to validate themselves against other extensions that - /// may *not* be valid themselves yet. It'd be better to order these respecting dependencies, - /// or at least to validate the types first - which we don't do at all yet: - // - // TODO https://github.com/CQCL/hugr/issues/624. However, parametrized types could be - // cyclically dependent, so there is no perfect solution, and this is at least simple. + /// Validate the set of extensions. pub fn validate(&self) -> Result<(), ExtensionRegistryError> { if self.valid.load(Ordering::Relaxed) { return Ok(()); } for ext in self.exts.values() { - ext.validate(self) + ext.validate() .map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?; } self.valid.store(true, Ordering::Relaxed); @@ -398,14 +391,9 @@ pub enum SignatureError { /// Invalid type arguments #[error("Invalid type arguments for operation")] InvalidTypeArgs, - /// The Extension Registry did not contain an Extension referenced by the Signature - #[error("Extension '{missing}' is not part of the declared HUGR extensions [{}]", - available.iter().map(|e| e.to_string()).collect::>().join(", ") - )] - ExtensionNotFound { - missing: ExtensionId, - available: Vec, - }, + /// The weak [`Extension`] reference for a custom type has been dropped. + #[error("Type '{typ}' is defined in extension '{missing}', but the extension reference has been dropped.")] + MissingTypeExtension { typ: TypeName, missing: ExtensionId }, /// The Extension was found in the registry, but did not contain the Type(Def) referenced in the Signature #[error("Extension '{exn}' did not contain expected TypeDef '{typ}'")] ExtensionTypeNotFound { exn: ExtensionId, typ: TypeName }, @@ -740,18 +728,16 @@ impl Extension { &self, name: &OpNameRef, args: impl Into>, - ext_reg: &ExtensionRegistry, ) -> Result { let op_def = self.get_op(name).expect("Op not found."); - ExtensionOp::new(op_def.clone(), args, ext_reg) + ExtensionOp::new(op_def.clone(), args) } - /// Validates against a registry, which we can assume includes this extension itself. - // (TODO deal with the registry itself containing invalid extensions!) - fn validate(&self, all_exts: &ExtensionRegistry) -> Result<(), SignatureError> { + /// Validates the operation definitions in the register. + fn validate(&self) -> Result<(), SignatureError> { // We should validate TypeParams of TypeDefs too - https://github.com/CQCL/hugr/issues/624 for op_def in self.operations.values() { - op_def.validate(all_exts)?; + op_def.validate()?; } Ok(()) } diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index 2824aec80..64092981f 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -17,7 +17,7 @@ //! # const DECLARATIVE_YAML: &str = include_str!("../../examples/extension/declarative.yaml"); //! # use hugr::extension::declarative::load_extensions; //! // Required extensions must already be present in the registry. -//! let mut reg = hugr::std_extensions::logic::LOGIC_REG.clone(); +//! let mut reg = hugr::std_extensions::STD_REG.clone(); //! load_extensions(DECLARATIVE_YAML, &mut reg).unwrap(); //! ``` //! @@ -364,7 +364,7 @@ extensions: #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri #[rstest] - #[case(EXAMPLE_YAML_FILE, 1, 1, 3, &std_extensions::logic::LOGIC_REG)] + #[case(EXAMPLE_YAML_FILE, 1, 1, 3, &std_extensions::STD_REG)] fn test_decode_file( #[case] yaml_file: &str, #[case] num_declarations: usize, diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index b1dd7100e..31c576391 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -5,8 +5,8 @@ use std::fmt::{Debug, Formatter}; use std::sync::{Arc, Weak}; use super::{ - ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionRegistry, - ExtensionSet, SignatureError, + ConstFold, ConstFoldResult, Extension, ExtensionBuildError, ExtensionId, ExtensionSet, + SignatureError, }; use crate::ops::{OpName, OpNameRef}; @@ -24,7 +24,6 @@ pub trait CustomSignatureFunc: Send + Sync { &'a self, arg_values: &[TypeArg], def: &'o OpDef, - extension_registry: &ExtensionRegistry, ) -> Result; /// The declared type parameters which require values in order for signature to /// be computed. @@ -47,7 +46,6 @@ impl CustomSignatureFunc for T { &'a self, arg_values: &[TypeArg], _def: &'o OpDef, - _extension_registry: &ExtensionRegistry, ) -> Result { SignatureFromArgs::compute_signature(self, arg_values) } @@ -68,7 +66,6 @@ pub trait ValidateTypeArgs: Send + Sync { &self, arg_values: &[TypeArg], def: &'o OpDef, - extension_registry: &ExtensionRegistry, ) -> Result<(), SignatureError>; } @@ -86,7 +83,6 @@ impl ValidateTypeArgs for T { &self, arg_values: &[TypeArg], _def: &'o OpDef, - _extension_registry: &ExtensionRegistry, ) -> Result<(), SignatureError> { ValidateJustArgs::validate(self, arg_values) } @@ -228,12 +224,11 @@ impl SignatureFunc { &self, def: &OpDef, args: &[TypeArg], - exts: &ExtensionRegistry, ) -> Result { let temp: PolyFuncTypeRV; // to keep alive let (pf, args) = match &self { SignatureFunc::CustomValidator(custom) => { - custom.validate.validate(args, def, exts)?; + custom.validate.validate(args, def)?; (&custom.poly_func, args) } SignatureFunc::PolyFuncType(ts) => (ts, args), @@ -242,14 +237,14 @@ impl SignatureFunc { let (static_args, other_args) = args.split_at(min(static_params.len(), args.len())); check_type_args(static_args, static_params)?; - temp = func.compute_signature(static_args, def, exts)?; + temp = func.compute_signature(static_args, def)?; (&temp, other_args) } SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), // TODO raise warning: https://github.com/CQCL/hugr/issues/1432 SignatureFunc::MissingValidateFunc(ts) => (ts, args), }; - let mut res = pf.instantiate(args, exts)?; + let mut res = pf.instantiate(args)?; res.runtime_reqs.insert(def.extension.clone()); // If there are any row variables left, this will fail with an error: @@ -332,13 +327,12 @@ pub struct OpDef { } impl OpDef { - /// Check provided type arguments are valid against [ExtensionRegistry], + /// Check provided type arguments are valid against their extensions, /// against parameters, and that no type variables are used as static arguments /// (to [compute_signature][CustomSignatureFunc::compute_signature]) pub fn validate_args( &self, args: &[TypeArg], - exts: &ExtensionRegistry, var_decls: &[TypeParam], ) -> Result<(), SignatureError> { let temp: PolyFuncTypeRV; // to keep alive @@ -348,11 +342,9 @@ impl OpDef { SignatureFunc::CustomFunc(custom) => { let (static_args, other_args) = args.split_at(min(custom.static_params().len(), args.len())); - static_args - .iter() - .try_for_each(|ta| ta.validate(exts, &[]))?; + static_args.iter().try_for_each(|ta| ta.validate(&[]))?; check_type_args(static_args, custom.static_params())?; - temp = custom.compute_signature(static_args, self, exts)?; + temp = custom.compute_signature(static_args, self)?; (&temp, other_args) } SignatureFunc::MissingComputeFunc => return Err(SignatureError::MissingComputeFunc), @@ -360,20 +352,15 @@ impl OpDef { return Err(SignatureError::MissingValidateFunc) } }; - args.iter() - .try_for_each(|ta| ta.validate(exts, var_decls))?; + args.iter().try_for_each(|ta| ta.validate(var_decls))?; check_type_args(args, pf.params())?; Ok(()) } /// Computes the signature of a node, i.e. an instantiation of this /// OpDef with statically-provided [TypeArg]s. - pub fn compute_signature( - &self, - args: &[TypeArg], - exts: &ExtensionRegistry, - ) -> Result { - self.signature_func.compute_signature(self, args, exts) + pub fn compute_signature(&self, args: &[TypeArg]) -> Result { + self.signature_func.compute_signature(self, args) } /// Fallibly returns a Hugr that may replace an instance of this OpDef @@ -427,14 +414,14 @@ impl OpDef { self.signature_func.static_params() } - pub(super) fn validate(&self, exts: &ExtensionRegistry) -> Result<(), SignatureError> { + pub(super) fn validate(&self) -> Result<(), SignatureError> { // TODO https://github.com/CQCL/hugr/issues/624 validate declared TypeParams // for both type scheme and custom binary if let SignatureFunc::CustomValidator(ts) = &self.signature_func { // The type scheme may contain row variables so be of variable length; // these will have to be substituted to fixed-length concrete types when // the OpDef is instantiated into an actual OpType. - ts.poly_func.validate(exts)?; + ts.poly_func.validate()?; } Ok(()) } @@ -551,8 +538,8 @@ pub(super) mod test { use crate::builder::{endo_sig, DFGBuilder, Dataflow, DataflowHugr}; use crate::extension::op_def::{CustomValidator, LowerFunc, OpDef, SignatureFunc}; use crate::extension::prelude::usize_t; + use crate::extension::SignatureError; use crate::extension::{ExtensionRegistry, ExtensionSet, PRELUDE}; - use crate::extension::{SignatureError, EMPTY_REG, PRELUDE_REGISTRY}; use crate::ops::OpName; use crate::std_extensions::collections::list; use crate::types::type_param::{TypeArgError, TypeParam}; @@ -681,7 +668,7 @@ pub(super) mod test { Type::new_extension(list_def.instantiate(vec![TypeArg::Type { ty: usize_t() }])?); let mut dfg = DFGBuilder::new(endo_sig(vec![list_usize]))?; let rev = dfg.add_dataflow_op( - e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }], ®) + e.instantiate_extension_op(&OP_NAME, vec![TypeArg::Type { ty: usize_t() }]) .unwrap(), dfg.input_wires(), )?; @@ -728,32 +715,32 @@ pub(super) mod test { // Base case, no type variables: let args = [TypeArg::BoundedNat { n: 3 }, usize_t().into()]; assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), + def.compute_signature(&args), Ok(Signature::new( vec![usize_t(); 3], vec![Type::new_tuple(vec![usize_t(); 3])] ) .with_extension_delta(EXT_ID)) ); - assert_eq!(def.validate_args(&args, &PRELUDE_REGISTRY, &[]), Ok(())); + assert_eq!(def.validate_args(&args, &[]), Ok(())); // Second arg may be a variable (substitutable) let tyvar = Type::new_var_use(0, TypeBound::Copyable); let tyvars: Vec = vec![tyvar.clone(); 3]; let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), + def.compute_signature(&args), Ok( Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) .with_extension_delta(EXT_ID) ) ); - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Copyable.into()]) + def.validate_args(&args, &[TypeBound::Copyable.into()]) .unwrap(); // quick sanity check that we are validating the args - note changed bound: assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[TypeBound::Any.into()]), + def.validate_args(&args, &[TypeBound::Any.into()]), Err(SignatureError::TypeVarDoesNotMatchDeclaration { actual: TypeBound::Any.into(), cached: TypeBound::Copyable.into() @@ -765,12 +752,12 @@ pub(super) mod test { let args = [TypeArg::new_var_use(0, kind.clone()), usize_t().into()]; // We can't prevent this from getting into our compute_signature implementation: assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), + def.compute_signature(&args), Err(SignatureError::InvalidTypeArgs) ); // But validation rules it out, even when the variable is declared: assert_eq!( - def.validate_args(&args, &PRELUDE_REGISTRY, &[kind]), + def.validate_args(&args, &[kind]), Err(SignatureError::FreeTypeVar { idx: 0, num_decls: 0 @@ -800,15 +787,15 @@ pub(super) mod test { let tv = Type::new_var_use(1, TypeBound::Copyable); let args = [TypeArg::Type { ty: tv.clone() }]; let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; - def.validate_args(&args, &EMPTY_REG, &decls).unwrap(); + def.validate_args(&args, &decls).unwrap(); assert_eq!( - def.compute_signature(&args, &EMPTY_REG), + def.compute_signature(&args), Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) ); // But not with an external row variable let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); assert_eq!( - def.compute_signature(&[arg.clone()], &EMPTY_REG), + def.compute_signature(&[arg.clone()]), Err(SignatureError::TypeArgMismatch( TypeArgError::TypeMismatch { param: TypeBound::Any.into(), @@ -823,7 +810,7 @@ pub(super) mod test { #[test] fn instantiate_extension_delta() -> Result<(), Box> { - use crate::extension::prelude::{bool_t, PRELUDE_REGISTRY}; + use crate::extension::prelude::bool_t; let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { let params: Vec = vec![TypeParam::Extensions]; @@ -842,12 +829,8 @@ pub(super) mod test { let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone()); let args = [TypeArg::Extensions { es }]; - def.validate_args(&args, &PRELUDE_REGISTRY, ¶ms) - .unwrap(); - assert_eq!( - def.compute_signature(&args, &PRELUDE_REGISTRY), - Ok(exp_fun_ty) - ); + def.validate_args(&args, ¶ms).unwrap(); + assert_eq!(def.compute_signature(&args), Ok(exp_fun_ty)); Ok(()) })?; diff --git a/hugr-core/src/extension/op_def/serialize_signature_func.rs b/hugr-core/src/extension/op_def/serialize_signature_func.rs index 6c189cc84..c3c4fffcf 100644 --- a/hugr-core/src/extension/op_def/serialize_signature_func.rs +++ b/hugr-core/src/extension/op_def/serialize_signature_func.rs @@ -57,8 +57,8 @@ mod test { use super::*; use crate::{ extension::{ - prelude::usize_t, CustomSignatureFunc, CustomValidator, ExtensionRegistry, OpDef, - SignatureError, ValidateTypeArgs, + prelude::usize_t, CustomSignatureFunc, CustomValidator, OpDef, SignatureError, + ValidateTypeArgs, }, types::{FuncValueType, Signature, TypeArg}, }; @@ -96,7 +96,6 @@ mod test { &'a self, _arg_values: &[TypeArg], _def: &'o crate::extension::op_def::OpDef, - _extension_registry: &crate::extension::ExtensionRegistry, ) -> Result { Ok(Default::default()) } @@ -112,7 +111,6 @@ mod test { &self, _arg_values: &[TypeArg], _def: &'o OpDef, - _extension_registry: &ExtensionRegistry, ) -> Result<(), SignatureError> { Ok(()) } diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index b28814c81..0f8e369ad 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -650,8 +650,8 @@ impl MakeRegisteredOp for MakeTuple { PRELUDE_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { - &PRELUDE_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) } } @@ -712,8 +712,8 @@ impl MakeRegisteredOp for UnpackTuple { PRELUDE_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { - &PRELUDE_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) } } @@ -821,8 +821,8 @@ impl MakeRegisteredOp for Noop { PRELUDE_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { - &PRELUDE_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) } } @@ -947,8 +947,8 @@ impl MakeRegisteredOp for Lift { PRELUDE_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { - &PRELUDE_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) } } @@ -1088,11 +1088,7 @@ mod test { const TYPE_ARG_NONE: TypeArg = TypeArg::Sequence { elems: vec![] }; let op = PRELUDE - .instantiate_extension_op( - &PANIC_OP_ID, - [TYPE_ARG_NONE, TYPE_ARG_NONE], - &PRELUDE_REGISTRY, - ) + .instantiate_extension_op(&PANIC_OP_ID, [TYPE_ARG_NONE, TYPE_ARG_NONE]) .unwrap(); b.add_dataflow_op(op, [err]).unwrap(); @@ -1109,11 +1105,7 @@ mod test { elems: vec![type_arg_q.clone(), type_arg_q], }; let panic_op = PRELUDE - .instantiate_extension_op( - &PANIC_OP_ID, - [type_arg_2q.clone(), type_arg_2q.clone()], - &PRELUDE_REGISTRY, - ) + .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); let mut b = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()])).unwrap(); @@ -1157,9 +1149,7 @@ mod test { let mut b: DFGBuilder = DFGBuilder::new(endo_sig(vec![])).unwrap(); let greeting: ConstString = ConstString::new("Hello, world!".into()); let greeting_out: Wire = b.add_load_value(greeting); - let print_op = PRELUDE - .instantiate_extension_op(&PRINT_OP_ID, [], &PRELUDE_REGISTRY) - .unwrap(); + let print_op = PRELUDE.instantiate_extension_op(&PRINT_OP_ID, []).unwrap(); b.add_dataflow_op(print_op, [greeting_out]).unwrap(); b.finish_hugr_with_outputs([]).unwrap(); } diff --git a/hugr-core/src/extension/prelude/generic.rs b/hugr-core/src/extension/prelude/generic.rs index b79bd40bf..1a6294ec9 100644 --- a/hugr-core/src/extension/prelude/generic.rs +++ b/hugr-core/src/extension/prelude/generic.rs @@ -23,8 +23,8 @@ use crate::types::PolyFuncTypeRV; use crate::types::type_param::TypeArg; use crate::Extension; +use super::PRELUDE; use super::{ConstUsize, PRELUDE_ID}; -use super::{PRELUDE, PRELUDE_REGISTRY}; use crate::types::type_param::TypeParam; /// Name of the operation for loading generic BoundedNat parameters. @@ -139,8 +139,8 @@ impl MakeRegisteredOp for LoadNat { PRELUDE_ID } - fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { - &PRELUDE_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&PRELUDE) } } diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index eba1aa0c7..b3f649a66 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -34,7 +34,7 @@ pub trait UnwrapBuilder: Dataflow { .collect_vec() .into(); let prelude = reg.get(&PRELUDE_ID).unwrap(); - let op = prelude.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg], reg)?; + let op = prelude.instantiate_extension_op(&PANIC_OP_ID, [input_arg, output_arg])?; let err = self.add_load_value(err); self.add_dataflow_op(op, iter::once(err).chain(input_wires)) } diff --git a/hugr-core/src/extension/resolution/ops.rs b/hugr-core/src/extension/resolution/ops.rs index 78e6b3fbc..6af838385 100644 --- a/hugr-core/src/extension/resolution/ops.rs +++ b/hugr-core/src/extension/resolution/ops.rs @@ -86,13 +86,12 @@ pub(crate) fn resolve_op_extensions<'e>( .into()); }; - let ext_op = - ExtensionOp::new_with_cached(def.clone(), opaque.args().to_vec(), opaque, extensions) - .map_err(|e| OpaqueOpError::SignatureError { - node, - name: opaque.name().clone(), - cause: e, - })?; + let ext_op = ExtensionOp::new_with_cached(def.clone(), opaque.args().to_vec(), opaque) + .map_err(|e| OpaqueOpError::SignatureError { + node, + name: opaque.name().clone(), + cause: e, + })?; if opaque.signature() != ext_op.signature() { return Err(OpaqueOpError::SignatureMismatch { diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 65bcb2aad..c4b4bad79 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -8,9 +8,10 @@ use itertools::Itertools; use rstest::rstest; use crate::builder::{ - Container, Dataflow, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, + Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, + HugrBuilder, ModuleBuilder, }; -use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize}; +use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize, PRELUDE_ID}; use crate::extension::resolution::WeakExtensionRegistry; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError, @@ -26,7 +27,7 @@ use crate::std_extensions::arithmetic::int_ops; use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; use crate::types::{Signature, Type}; -use crate::{type_row, Extension, Hugr, HugrView}; +use crate::{std_extensions, type_row, Extension, Hugr, HugrView}; #[rstest] #[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())] @@ -84,7 +85,7 @@ fn make_extension(name: &str, op_name: &str) -> (Arc, OpType) { .unwrap(); }); let op_def = ext.get_op(op_name).unwrap(); - let op = ExtensionOp::new(op_def.clone(), vec![], &ExtensionRegistry::default()).unwrap(); + let op = ExtensionOp::new(op_def.clone(), vec![]).unwrap(); (ext, op.into()) } @@ -153,6 +154,62 @@ fn check_extension_resolution(mut hugr: Hugr) { ); } +/// Build a small hugr using the float types extension and check that the extensions are resolved. +#[rstest] +fn resolve_hugr_extensions_simple() { + let mut build = DFGBuilder::new( + Signature::new(vec![], vec![float64_type()]).with_extension_delta( + [ + PRELUDE_ID.to_owned(), + std_extensions::arithmetic::float_types::EXTENSION_ID.to_owned(), + ] + .into_iter() + .collect::(), + ), + ) + .unwrap(); + + // A constant op using a non-prelude extension. + let f_const = build.add_load_const(Value::extension(ConstF64::new(f64::consts::PI))); + + let mut hugr = build + .finish_hugr_with_outputs([f_const]) + .unwrap_or_else(|e| panic!("{e}")); + + let build_extensions = hugr.extensions().clone(); + + // Check that the read-only methods collect the same extensions. + let mut collected_exts = ExtensionRegistry::default(); + for node in hugr.nodes() { + let op = hugr.get_optype(node); + collected_exts.extend(op.used_extensions().unwrap()); + } + assert_eq!( + collected_exts, build_extensions, + "{collected_exts} != {build_extensions}" + ); + + // Check that the mutable methods collect the same extensions. + hugr.resolve_extension_defs(&build_extensions).unwrap(); + assert_eq!( + hugr.extensions(), + &build_extensions, + "{} != {build_extensions}", + hugr.extensions() + ); + + // Serialization roundtrip to drop the weak pointers. + let ser = serde_json::to_string(&hugr).unwrap(); + let deser_hugr = Hugr::load_json(ser.as_bytes(), &build_extensions).unwrap(); + + assert_eq!( + deser_hugr.extensions(), + &build_extensions, + "{} != {build_extensions}", + hugr.extensions() + ); +} + /// Build a hugr with all possible op nodes and resolve the extensions. #[rstest] fn resolve_hugr_extensions() { diff --git a/hugr-core/src/extension/simple_op.rs b/hugr-core/src/extension/simple_op.rs index 0681d5818..f812de8a3 100644 --- a/hugr-core/src/extension/simple_op.rs +++ b/hugr-core/src/extension/simple_op.rs @@ -11,10 +11,7 @@ use crate::{ Extension, }; -use super::{ - op_def::SignatureFunc, ExtensionBuildError, ExtensionId, ExtensionRegistry, OpDef, - SignatureError, -}; +use super::{op_def::SignatureFunc, ExtensionBuildError, ExtensionId, OpDef, SignatureError}; use delegate::delegate; use thiserror::Error; @@ -176,14 +173,14 @@ pub trait MakeExtensionOp: NamedOp { fn to_registered( self, extension_id: ExtensionId, - registry: &ExtensionRegistry, - ) -> RegisteredOp<'_, Self> + extension: Weak, + ) -> RegisteredOp where Self: Sized, { RegisteredOp { extension_id, - registry, + extension, op: self, } } @@ -226,32 +223,28 @@ where /// Wrap an [MakeExtensionOp] with an extension registry to allow type computation. /// Generate from [MakeExtensionOp::to_registered] #[derive(Clone, Debug)] -pub struct RegisteredOp<'r, T> { +pub struct RegisteredOp { /// The name of the extension these ops belong to. - extension_id: ExtensionId, + pub extension_id: ExtensionId, /// A registry of all extensions, used for type computation. - registry: &'r ExtensionRegistry, + extension: Weak, /// The inner [MakeExtensionOp] op: T, } -impl RegisteredOp<'_, T> { +impl RegisteredOp { /// Extract the inner wrapped value pub fn to_inner(self) -> T { self.op } } -impl RegisteredOp<'_, T> { +impl RegisteredOp { /// Generate an [OpType]. pub fn to_extension_op(&self) -> Option { ExtensionOp::new( - self.registry - .get(&self.extension_id)? - .get_op(&self.name())? - .clone(), + self.extension.upgrade()?.get_op(&self.name())?.clone(), self.type_args(), - self.registry, ) .ok() } @@ -272,9 +265,8 @@ impl RegisteredOp<'_, T> { pub trait MakeRegisteredOp: MakeExtensionOp { /// The ID of the extension this op belongs to. fn extension_id(&self) -> ExtensionId; - /// A reference to an [ExtensionRegistry] which is sufficient to generate - /// the signature of this op. - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry; + /// A reference to the [Extension] which defines this operation. + fn extension_ref(&self) -> Weak; /// Convert this operation in to an [ExtensionOp]. Returns None if the type /// cannot be computed. @@ -287,11 +279,11 @@ pub trait MakeRegisteredOp: MakeExtensionOp { } } -impl From for RegisteredOp<'_, T> { +impl From for RegisteredOp { fn from(ext_op: T) -> Self { let extension_id = ext_op.extension_id(); - let registry = ext_op.registry(); - ext_op.to_registered(extension_id, registry) + let extension = ext_op.extension_ref(); + ext_op.to_registered(extension_id, extension) } } @@ -358,15 +350,14 @@ mod test { .unwrap(); }) }; - static ref DUMMY_REG: ExtensionRegistry = ExtensionRegistry::new([EXT.clone()]); } impl MakeRegisteredOp for DummyEnum { fn extension_id(&self) -> ExtensionId { EXT_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &DUMMY_REG + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXT) } } diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/rewrite/replace.rs index 3e9558533..2e4300b2a 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/rewrite/replace.rs @@ -448,7 +448,7 @@ mod test { DataflowSubContainer, HugrBuilder, SubContainer, }; use crate::extension::prelude::{bool_t, usize_t}; - use crate::extension::{ExtensionRegistry, PRELUDE, PRELUDE_REGISTRY}; + use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::rewrite::replace::WhichHugr; use crate::hugr::{HugrMut, Rewrite}; @@ -643,15 +643,9 @@ mod test { .unwrap(); }); let ext_name = ext.name().clone(); - let foo = ext - .instantiate_extension_op("foo", [], &PRELUDE_REGISTRY) - .unwrap(); - let bar = ext - .instantiate_extension_op("bar", [], &PRELUDE_REGISTRY) - .unwrap(); - let baz = ext - .instantiate_extension_op("baz", [], &PRELUDE_REGISTRY) - .unwrap(); + let foo = ext.instantiate_extension_op("foo", []).unwrap(); + let bar = ext.instantiate_extension_op("bar", []).unwrap(); + let baz = ext.instantiate_extension_op("baz", []).unwrap(); let mut registry = test_quantum_extension::REG.clone(); registry.register(ext).unwrap(); diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 452f15812..49a7b9321 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -9,7 +9,7 @@ use crate::extension::prelude::Noop; use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; use crate::extension::simple_op::MakeRegisteredOp; use crate::extension::ExtensionRegistry; -use crate::extension::{test::SimpleOpDef, ExtensionSet, EMPTY_REG}; +use crate::extension::{test::SimpleOpDef, ExtensionSet}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::validate::ValidationError; use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError}; @@ -499,7 +499,7 @@ fn polyfunctype2() -> PolyFuncTypeRV { let res = PolyFuncTypeRV::new(params, FuncValueType::new(inputs, tv1)); // Just check we've got the arguments the right way round // (not that it really matters for the serialization schema we have) - res.validate(&EMPTY_REG).unwrap(); + res.validate().unwrap(); res } @@ -541,7 +541,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(PRELUDE_ID)} ], &EMPTY_REG).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(PRELUDE_ID)} ]).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index beefd2c96..8770474a2 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -300,11 +300,11 @@ impl<'a> ValidationContext<'a> { var_decls: &[TypeParam], ) -> Result<(), SignatureError> { match &port_kind { - EdgeKind::Value(ty) => ty.validate(self.hugr.extensions(), var_decls), + EdgeKind::Value(ty) => ty.validate(var_decls), // Static edges must *not* refer to type variables declared by enclosing FuncDefns // as these are only types at runtime. - EdgeKind::Const(ty) => ty.validate(self.hugr.extensions(), &[]), - EdgeKind::Function(pf) => pf.validate(self.hugr.extensions()), + EdgeKind::Const(ty) => ty.validate(&[]), + EdgeKind::Function(pf) => pf.validate(), _ => Ok(()), } } @@ -575,7 +575,7 @@ impl<'a> ValidationContext<'a> { // Check TypeArgs are valid, and if we can, fit the declared TypeParams ext_op .def() - .validate_args(ext_op.args(), self.hugr.extensions(), var_decls) + .validate_args(ext_op.args(), var_decls) .map_err(|cause| ValidationError::SignatureError { node, op: op_type.name(), @@ -592,22 +592,20 @@ impl<'a> ValidationContext<'a> { ))?; } OpType::Call(c) => { - c.validate(self.hugr.extensions()).map_err(|cause| { - ValidationError::SignatureError { + c.validate() + .map_err(|cause| ValidationError::SignatureError { node, op: op_type.name(), cause, - } - })?; + })?; } OpType::LoadFunction(c) => { - c.validate(self.hugr.extensions()).map_err(|cause| { - ValidationError::SignatureError { + c.validate() + .map_err(|cause| ValidationError::SignatureError { node, op: op_type.name(), cause, - } - })?; + })?; } _ => (), } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index fea6b336e..00c87d906 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -12,9 +12,7 @@ use crate::builder::{ }; use crate::extension::prelude::Noop; use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; -use crate::extension::{ - Extension, ExtensionRegistry, ExtensionSet, TypeDefBound, PRELUDE, PRELUDE_REGISTRY, -}; +use crate::extension::{Extension, ExtensionRegistry, ExtensionSet, TypeDefBound, PRELUDE}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrMut; use crate::ops::dataflow::IOTrait; @@ -28,7 +26,6 @@ use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, TypeRow, }; -use crate::utils::test_quantum_extension; use crate::{ const_extension_ids, test_file, type_row, Direction, IncomingPort, Node, OutgoingPort, }; @@ -363,14 +360,6 @@ fn identity_hugr_with_type(t: Type) -> (Hugr, Node) { b.connect(input, 0, output, 0); (b, def) } -#[test] -fn unregistered_extension() { - let (mut h, _def) = identity_hugr_with_type(usize_t()); - assert!(h.validate().is_err(),); - h.resolve_extension_defs(&test_quantum_extension::REG) - .unwrap(); - h.validate().unwrap(); -} const_extension_ids! { const EXT_ID: ExtensionId = "MyExt"; @@ -637,16 +626,14 @@ fn instantiate_row_variables() -> Result<(), Box> { vec![usize_t(); 4], // outputs (*2^2, three calls) ))?; let [func, int] = dfb.input_wires_arr(); - let eval = e.instantiate_extension_op("eval", [uint_seq(1), uint_seq(2)], &PRELUDE_REGISTRY)?; + let eval = e.instantiate_extension_op("eval", [uint_seq(1), uint_seq(2)])?; let [a, b] = dfb.add_dataflow_op(eval, [func, int])?.outputs_arr(); let par = e.instantiate_extension_op( "parallel", [uint_seq(1), uint_seq(1), uint_seq(2), uint_seq(2)], - &PRELUDE_REGISTRY, )?; let [par_func] = dfb.add_dataflow_op(par, [func, func])?.outputs_arr(); - let eval2 = - e.instantiate_extension_op("eval", [uint_seq(2), uint_seq(4)], &PRELUDE_REGISTRY)?; + let eval2 = e.instantiate_extension_op("eval", [uint_seq(2), uint_seq(4)])?; let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?; dfb.finish_hugr_with_outputs(eval2.outputs())?; Ok(()) @@ -682,7 +669,6 @@ fn row_variables() -> Result<(), Box> { let par = e.instantiate_extension_op( "parallel", [tv.clone(), usize_t().into(), tv.clone(), usize_t().into()].map(seq1ty), - &PRELUDE_REGISTRY, )?; let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?; fb.finish_hugr_with_outputs(par_func.outputs())?; @@ -761,7 +747,6 @@ fn test_polymorphic_call() -> Result<(), Box> { TypeArg::Extensions { es }, usize_t().into(), ], - &PRELUDE_REGISTRY, )?; let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 26cd6fd28..532802903 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -445,7 +445,7 @@ impl<'a> Context<'a> { .collect::, _>>()?; self.static_edges.push((func_node, node_id)); - let optype = OpType::Call(Call::try_new(func_sig, type_args, self.extensions)?); + let optype = OpType::Call(Call::try_new(func_sig, type_args)?); let node = self.make_node(node_id, optype, parent)?; Ok(Some(node)) @@ -466,11 +466,7 @@ impl<'a> Context<'a> { self.static_edges.push((func_node, node_id)); - let optype = OpType::LoadFunction(LoadFunction::try_new( - func_sig, - type_args, - self.extensions, - )?); + let optype = OpType::LoadFunction(LoadFunction::try_new(func_sig, type_args)?); let node = self.make_node(node_id, optype, parent)?; Ok(Some(node)) diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 41dae59a6..49728980f 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -375,7 +375,7 @@ mod test { use crate::{ extension::{ prelude::{qb_t, usize_t, PRELUDE_ID}, - ExtensionSet, PRELUDE_REGISTRY, + ExtensionSet, }, ops::{Conditional, DataflowOpTrait, DataflowParent}, types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV}, @@ -393,15 +393,12 @@ mod test { sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()], extension_delta: ExtensionSet::type_var(1), }; - let dfb2 = dfb.substitute(&Substitution::new( - &[ - qb_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ], - &PRELUDE_REGISTRY, - )); + let dfb2 = dfb.substitute(&Substitution::new(&[ + qb_t().into(), + TypeArg::Extensions { + es: PRELUDE_ID.into(), + }, + ])); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( dfb2.inner_signature(), @@ -419,15 +416,12 @@ mod test { outputs: vec![usize_t(), tv1].into(), extension_delta: ExtensionSet::new(), }; - let cond2 = cond.substitute(&Substitution::new( - &[ - TypeArg::Sequence { - elems: vec![usize_t().into(); 3], - }, - qb_t().into(), - ], - &PRELUDE_REGISTRY, - )); + let cond2 = cond.substitute(&Substitution::new(&[ + TypeArg::Sequence { + elems: vec![usize_t().into(); 3], + }, + qb_t().into(), + ])); let st = Type::new_sum(vec![usize_t(), qb_t()]); //both single-element variants assert_eq!( cond2.signature(), @@ -447,15 +441,12 @@ mod test { rest: vec![tv0.clone()].into(), extension_delta: ExtensionSet::type_var(1), }; - let tail2 = tail_loop.substitute(&Substitution::new( - &[ - usize_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ], - &PRELUDE_REGISTRY, - )); + let tail2 = tail_loop.substitute(&Substitution::new(&[ + usize_t().into(), + TypeArg::Extensions { + es: PRELUDE_ID.into(), + }, + ])); assert_eq!( tail2.signature(), Signature::new( diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index a0ad220e7..11ec390a6 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -12,7 +12,7 @@ use { ::proptest_derive::Arbitrary, }; -use crate::extension::{ConstFoldResult, ExtensionId, ExtensionRegistry, OpDef, SignatureError}; +use crate::extension::{ConstFoldResult, ExtensionId, OpDef, SignatureError}; use crate::types::{type_param::TypeArg, Signature}; use crate::{ops, IncomingPort, Node}; @@ -41,13 +41,9 @@ pub struct ExtensionOp { impl ExtensionOp { /// Create a new ExtensionOp given the type arguments and specified input extensions - pub fn new( - def: Arc, - args: impl Into>, - exts: &ExtensionRegistry, - ) -> Result { + pub fn new(def: Arc, args: impl Into>) -> Result { let args: Vec = args.into(); - let signature = def.compute_signature(&args, exts)?; + let signature = def.compute_signature(&args)?; Ok(Self { def, args, @@ -60,12 +56,11 @@ impl ExtensionOp { def: Arc, args: impl IntoIterator, opaque: &OpaqueOp, - exts: &ExtensionRegistry, ) -> Result { let args: Vec = args.into_iter().collect(); // TODO skip computation depending on config // see https://github.com/CQCL/hugr/issues/1363 - let signature = match def.compute_signature(&args, exts) { + let signature = match def.compute_signature(&args) { Ok(sig) => sig, Err(SignatureError::MissingComputeFunc) => { // TODO raise warning: https://github.com/CQCL/hugr/issues/1432 @@ -174,12 +169,6 @@ impl DataflowOpTrait for ExtensionOp { .map(|ta| ta.substitute(subst)) .collect::>(); let signature = self.signature.substitute(subst); - debug_assert_eq!( - self.def - .compute_signature(&args, subst.extension_registry()) - .as_ref(), - Ok(&signature) - ); Self { def: self.def.clone(), args, @@ -341,7 +330,9 @@ mod test { use ops::OpType; use crate::extension::resolution::resolve_op_extensions; - use crate::std_extensions::arithmetic::conversions::{self, CONVERT_OPS_REGISTRY}; + use crate::extension::ExtensionRegistry; + use crate::std_extensions::arithmetic::conversions::{self}; + use crate::std_extensions::STD_REG; use crate::{ extension::{ prelude::{bool_t, qb_t, usize_t}, @@ -380,7 +371,7 @@ mod test { #[test] fn resolve_opaque_op() { - let registry = &CONVERT_OPS_REGISTRY; + let registry = &STD_REG; let i0 = &INT_TYPES[0]; let opaque = OpaqueOp::new( conversions::EXTENSION_ID, diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index 7f9ae9c22..ca8c1b099 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use super::{impl_op_name, OpTag, OpTrait}; -use crate::extension::{ExtensionRegistry, ExtensionSet, SignatureError}; +use crate::extension::{ExtensionSet, SignatureError}; use crate::ops::StaticTag; use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow}; use crate::{type_row, IncomingPort}; @@ -220,9 +220,7 @@ impl DataflowOpTrait for Call { .collect::>(); let instantiation = self.instantiation.substitute(subst); debug_assert_eq!( - self.func_sig - .instantiate(&type_args, subst.extension_registry()) - .as_ref(), + self.func_sig.instantiate(&type_args).as_ref(), Ok(&instantiation) ); Self { @@ -240,10 +238,9 @@ impl Call { pub fn try_new( func_sig: PolyFuncType, type_args: impl Into>, - exts: &ExtensionRegistry, ) -> Result { let type_args: Vec<_> = type_args.into(); - let instantiation = func_sig.instantiate(&type_args, exts)?; + let instantiation = func_sig.instantiate(&type_args)?; Ok(Self { func_sig, type_args, @@ -268,7 +265,7 @@ impl Call { /// # use hugr::extension::prelude::qb_t; /// # use hugr::extension::PRELUDE_REGISTRY; /// let signature = Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]); - /// let call = Call::try_new(signature.into(), &[], &PRELUDE_REGISTRY).unwrap(); + /// let call = Call::try_new(signature.into(), &[]).unwrap(); /// let op = OpType::Call(call.clone()); /// assert_eq!(op.static_input_port(), Some(call.called_function_port())); /// ``` @@ -279,15 +276,8 @@ impl Call { self.instantiation.input_count().into() } - pub(crate) fn validate( - &self, - extension_registry: &ExtensionRegistry, - ) -> Result<(), SignatureError> { - let other = Self::try_new( - self.func_sig.clone(), - self.type_args.clone(), - extension_registry, - )?; + pub(crate) fn validate(&self) -> Result<(), SignatureError> { + let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?; if other.instantiation == self.instantiation { Ok(()) } else { @@ -428,9 +418,7 @@ impl DataflowOpTrait for LoadFunction { .collect::>(); let instantiation = self.instantiation.substitute(subst); debug_assert_eq!( - self.func_sig - .instantiate(&type_args, subst.extension_registry()) - .as_ref(), + self.func_sig.instantiate(&type_args).as_ref(), Ok(&instantiation) ); Self { @@ -448,10 +436,9 @@ impl LoadFunction { pub fn try_new( func_sig: PolyFuncType, type_args: impl Into>, - exts: &ExtensionRegistry, ) -> Result { let type_args: Vec<_> = type_args.into(); - let instantiation = func_sig.instantiate(&type_args, exts)?; + let instantiation = func_sig.instantiate(&type_args)?; Ok(Self { func_sig, type_args, @@ -475,15 +462,8 @@ impl LoadFunction { 0.into() } - pub(crate) fn validate( - &self, - extension_registry: &ExtensionRegistry, - ) -> Result<(), SignatureError> { - let other = Self::try_new( - self.func_sig.clone(), - self.type_args.clone(), - extension_registry, - )?; + pub(crate) fn validate(&self) -> Result<(), SignatureError> { + let other = Self::try_new(self.func_sig.clone(), self.type_args.clone())?; if other.instantiation == self.instantiation { Ok(()) } else { diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index bb3badc57..c77c81bfa 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -8,9 +8,7 @@ use crate::extension::prelude::sum_with_error; use crate::extension::prelude::{bool_t, string_type, usize_t}; use crate::extension::simple_op::{HasConcrete, HasDef}; use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; -use crate::extension::{ - ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, PRELUDE, -}; +use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc}; use crate::ops::OpName; use crate::ops::{custom::ExtensionOp, NamedOp}; use crate::std_extensions::arithmetic::int_ops::int_polytype; @@ -168,14 +166,6 @@ lazy_static! { ConvertOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; - - /// Registry of extensions required to validate integer operations. - pub static ref CONVERT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ - PRELUDE.clone(), - super::int_types::EXTENSION.clone(), - super::float_types::EXTENSION.clone(), - EXTENSION.clone(), - ]); } impl MakeRegisteredOp for ConvertOpType { @@ -183,8 +173,8 @@ impl MakeRegisteredOp for ConvertOpType { EXTENSION_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &CONVERT_OPS_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) } } diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index ce6f30e15..beb32346e 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -9,7 +9,7 @@ use crate::{ extension::{ prelude::{bool_t, string_type}, simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, - ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureFunc, PRELUDE, + ExtensionId, ExtensionSet, OpDef, SignatureFunc, }, types::Signature, Extension, @@ -115,13 +115,6 @@ lazy_static! { FloatOps::load_all_ops(extension, extension_ref).unwrap(); }) }; - - /// Registry of extensions required to validate float operations. - pub static ref FLOAT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ - PRELUDE.clone(), - super::float_types::EXTENSION.clone(), - EXTENSION.clone(), - ]); } impl MakeRegisteredOp for FloatOps { @@ -129,8 +122,8 @@ impl MakeRegisteredOp for FloatOps { EXTENSION_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &FLOAT_OPS_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) } } diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index 2bb269076..eb5cb2ee9 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -7,9 +7,7 @@ use crate::extension::prelude::{bool_t, sum_with_error}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ - CustomValidator, ExtensionRegistry, OpDef, SignatureFunc, ValidateJustArgs, PRELUDE, -}; +use crate::extension::{CustomValidator, OpDef, SignatureFunc, ValidateJustArgs}; use crate::ops::custom::ExtensionOp; use crate::ops::{NamedOp, OpName}; use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; @@ -258,13 +256,6 @@ lazy_static! { IntOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; - - /// Registry of extensions required to validate integer operations. - pub static ref INT_OPS_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ - PRELUDE.clone(), - super::int_types::EXTENSION.clone(), - EXTENSION.clone(), - ]); } impl HasConcrete for IntOpDef { @@ -321,8 +312,8 @@ impl MakeRegisteredOp for ConcreteIntOp { EXTENSION_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &INT_OPS_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) } } diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 48e84bd05..b38e762b5 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -9,9 +9,7 @@ use std::sync::Arc; use lazy_static::lazy_static; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; -use crate::extension::{ - ExtensionId, ExtensionRegistry, SignatureError, TypeDef, TypeDefBound, PRELUDE, -}; +use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; use crate::ops::{ExtensionOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{Type, TypeBound, TypeName}; @@ -46,12 +44,6 @@ lazy_static! { array_scan::ArrayScanDef.add_to_extension(extension, extension_ref).unwrap(); }) }; - - /// Registry of extensions required to validate list operations. - pub static ref ARRAY_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ - PRELUDE.clone(), - EXTENSION.clone(), - ]); } fn array_type_def() -> &'static TypeDef { diff --git a/hugr-core/src/std_extensions/collections/array/array_op.rs b/hugr-core/src/std_extensions/collections/array/array_op.rs index 0e4edce50..96b23a3a0 100644 --- a/hugr-core/src/std_extensions/collections/array/array_op.rs +++ b/hugr-core/src/std_extensions/collections/array/array_op.rs @@ -4,7 +4,7 @@ use std::sync::{Arc, Weak}; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; -use crate::extension::prelude::{either_type, option_type, usize_custom_t}; +use crate::extension::prelude::{either_type, option_type, usize_t}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; @@ -93,7 +93,7 @@ impl ArrayOpDef { fn signature_from_def( &self, array_def: &TypeDef, - extension_ref: &Weak, + _extension_ref: &Weak, ) -> SignatureFunc { use ArrayOpDef::*; if let new_array | pop_left | pop_right = self { @@ -107,11 +107,9 @@ impl ArrayOpDef { .expect("Array type instantiation failed"); let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; - // Construct the usize type using the passed extension reference. - // - // If we tried to use `usize_t()` directly it would try to access - // the `PRELUDE` lazy static recursively, causing a deadlock. - let usize_t: Type = usize_custom_t(extension_ref).into(); + // We can assume that the prelude has ben loaded at this point, + // since it doesn't depend on the array extension. + let usize_t: Type = usize_t(); match self { get => { @@ -263,8 +261,8 @@ impl MakeRegisteredOp for ArrayOp { super::EXTENSION_ID } - fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { - &super::ARRAY_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&super::EXTENSION) } } @@ -291,6 +289,7 @@ mod tests { use strum::IntoEnumIterator; use crate::extension::prelude::usize_t; + use crate::std_extensions::arithmetic::float_types::float64_type; use crate::std_extensions::collections::array::new_array_op; use crate::{ builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, @@ -436,4 +435,24 @@ mod tests { ) ); } + + #[test] + /// Initialize an array operation where the element type is not from the prelude. + fn test_non_prelude_op() { + let size = 2; + let element_ty = float64_type(); + let op = ArrayOpDef::get.to_concrete(element_ty.clone(), size); + + let optype: OpType = op.into(); + + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![array_type(size, element_ty.clone()), usize_t()].into(), + &vec![option_type(element_ty.clone()).into()].into() + ) + ); + } } diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index b908d6296..544866970 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -6,9 +6,7 @@ use std::sync::{Arc, Weak}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ - ExtensionId, ExtensionRegistry, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef, -}; +use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeBound}; @@ -157,8 +155,8 @@ impl MakeRegisteredOp for ArrayRepeat { super::EXTENSION_ID } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &super::ARRAY_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&super::EXTENSION) } } diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 4bfd8889a..86a0fe94e 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -14,7 +14,7 @@ use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; use crate::Extension; -use super::{array_type_def, instantiate_array, ARRAY_REGISTRY, ARRAY_TYPENAME}; +use super::{array_type_def, instantiate_array, ARRAY_TYPENAME}; /// Name of the operation for the combined map/fold operation pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan"); @@ -203,8 +203,8 @@ impl MakeRegisteredOp for ArrayScan { super::EXTENSION_ID } - fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { - &ARRAY_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&super::EXTENSION) } } diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 426055251..0a37238c4 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -18,7 +18,7 @@ use crate::extension::resolution::{ WeakExtensionRegistry, }; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; -use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc, PRELUDE}; +use crate::extension::{ExtensionBuildError, OpDef, SignatureFunc}; use crate::ops::constant::{maybe_hash_values, TryHash, ValueName}; use crate::ops::{OpName, Value}; use crate::types::{TypeName, TypeRowRV}; @@ -299,12 +299,6 @@ lazy_static! { ListOp::load_all_ops(extension, extension_ref).unwrap(); }) }; - - /// Registry of extensions required to validate list operations. - pub static ref LIST_REGISTRY: ExtensionRegistry = ExtensionRegistry::new([ - PRELUDE.clone(), - EXTENSION.clone(), - ]); } impl MakeRegisteredOp for ListOp { @@ -312,8 +306,8 @@ impl MakeRegisteredOp for ListOp { EXTENSION_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &LIST_REGISTRY + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) } } @@ -391,7 +385,6 @@ impl ListOpInst { ExtensionOp::new( registry.get(&EXTENSION_ID)?.get_op(&self.name())?.clone(), self.type_args(), - ®istry, ) .ok() } @@ -405,6 +398,7 @@ mod test { const_fail_tuple, const_none, const_ok_tuple, const_some_tuple, }; use crate::ops::OpTrait; + use crate::std_extensions::STD_REG; use crate::PortIndex; use crate::{ extension::{ @@ -421,7 +415,6 @@ mod test { fn test_extension() { assert_eq!(&ListOp::push.extension_id(), EXTENSION.name()); assert_eq!(&ListOp::push.extension(), EXTENSION.name()); - assert!(ListOp::pop.registry().contains(EXTENSION.name())); for (_, op_def) in EXTENSION.operations() { assert_eq!(op_def.extension_id(), &EXTENSION_ID); } @@ -538,7 +531,7 @@ mod test { let res = op .with_type(usize_t()) - .to_extension_op(&LIST_REGISTRY) + .to_extension_op(&STD_REG) .unwrap() .constant_fold(&consts) .unwrap(); diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index adcb2c970..6528009f6 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -12,7 +12,7 @@ use crate::{ extension::{ prelude::bool_t, simple_op::{try_from_name, MakeOpDef, MakeRegisteredOp, OpLoadError}, - ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, + ExtensionId, OpDef, SignatureFunc, }, ops, types::type_param::TypeArg, @@ -129,8 +129,6 @@ fn extension() -> Arc { lazy_static! { /// Reference to the logic Extension. pub static ref EXTENSION: Arc = extension(); - /// Registry required to validate logic extension. - pub static ref LOGIC_REG: ExtensionRegistry = ExtensionRegistry::new([EXTENSION.clone()]); } impl MakeRegisteredOp for LogicOp { @@ -138,8 +136,8 @@ impl MakeRegisteredOp for LogicOp { EXTENSION_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &LOGIC_REG + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) } } diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index d263f7a7b..16c98352a 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -14,7 +14,7 @@ use crate::{ simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }, - ExtensionId, ExtensionRegistry, OpDef, SignatureError, SignatureFunc, + ExtensionId, OpDef, SignatureError, SignatureFunc, }, ops::{custom::ExtensionOp, NamedOp}, type_row, @@ -109,8 +109,6 @@ fn extension() -> Arc { lazy_static! { /// Reference to the pointer Extension. pub static ref EXTENSION: Arc = extension(); - /// Registry required to validate pointer extension. - pub static ref PTR_REG: ExtensionRegistry = ExtensionRegistry::new([EXTENSION.clone()]); } /// Integer type of a given bit width (specified by the TypeArg). Depending on @@ -169,8 +167,8 @@ impl MakeRegisteredOp for PtrOp { EXTENSION_ID.to_owned() } - fn registry<'s, 'r: 's>(&'s self) -> &'r ExtensionRegistry { - &PTR_REG + fn extension_ref(&self) -> Weak { + Arc::downgrade(&EXTENSION) } } diff --git a/hugr-core/src/types.rs b/hugr-core/src/types.rs index 23db1f1d3..c285f3db6 100644 --- a/hugr-core/src/types.rs +++ b/hugr-core/src/types.rs @@ -32,7 +32,7 @@ use itertools::{repeat_n, Itertools}; use proptest_derive::Arbitrary; use serde::{Deserialize, Serialize}; -use crate::extension::{ExtensionRegistry, SignatureError}; +use crate::extension::SignatureError; use crate::ops::AliasDecl; use self::type_param::TypeParam; @@ -460,23 +460,19 @@ impl TypeBase { /// [RowVariable]: TypeEnum::RowVariable /// [validate]: crate::types::type_param::TypeArg::validate /// [TypeDef]: crate::extension::TypeDef - pub(crate) fn validate( - &self, - extension_registry: &ExtensionRegistry, - var_decls: &[TypeParam], - ) -> Result<(), SignatureError> { + pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { // There is no need to check the components against the bound, // that is guaranteed by construction (even for deserialization) match &self.0 { - TypeEnum::Sum(SumType::General { rows }) => rows - .iter() - .try_for_each(|row| row.validate(extension_registry, var_decls)), + TypeEnum::Sum(SumType::General { rows }) => { + rows.iter().try_for_each(|row| row.validate(var_decls)) + } TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), // No leaves there TypeEnum::Alias(_) => Ok(()), - TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls), + TypeEnum::Extension(custy) => custy.validate(var_decls), // Function values may be passed around without knowing their arity // (i.e. with row vars) as long as they are not called: - TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls), + TypeEnum::Function(ft) => ft.validate(var_decls), TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()), TypeEnum::RowVar(rv) => rv.validate(var_decls), } @@ -588,7 +584,7 @@ impl From for TypeRV { /// Details a replacement of type variables with a finite list of known values. /// (Variables out of the range of the list will result in a panic) -pub struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry); +pub struct Substitution<'a>(&'a [TypeArg]); impl<'a> Substitution<'a> { /// Create a new Substitution given the replacement values (indexed @@ -597,8 +593,8 @@ impl<'a> Substitution<'a> { /// containing a type-variable. /// /// [TypeDef]: crate::extension::TypeDef - pub fn new(items: &'a [TypeArg], exts: &'a ExtensionRegistry) -> Self { - Self(items, exts) + pub fn new(items: &'a [TypeArg]) -> Self { + Self(items) } pub(crate) fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg { @@ -639,10 +635,6 @@ impl<'a> Substitution<'a> { _ => panic!("Not a type or list of types - call validate() ?"), } } - - pub(crate) fn extension_registry(&self) -> &ExtensionRegistry { - self.1 - } } pub(crate) fn check_typevar_decl( diff --git a/hugr-core/src/types/custom.rs b/hugr-core/src/types/custom.rs index 81bc814ac..201dc74c8 100644 --- a/hugr-core/src/types/custom.rs +++ b/hugr-core/src/types/custom.rs @@ -2,9 +2,9 @@ //! //! [`Type`]: super::Type use std::fmt::{self, Display}; -use std::sync::Weak; +use std::sync::{Arc, Weak}; -use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDef}; +use crate::extension::{ExtensionId, SignatureError, TypeDef}; use crate::Extension; use super::{ @@ -77,46 +77,41 @@ impl CustomType { self.bound } - pub(super) fn validate( - &self, - extension_registry: &ExtensionRegistry, - var_decls: &[TypeParam], - ) -> Result<(), SignatureError> { + pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { // Check the args are individually ok - self.args - .iter() - .try_for_each(|a| a.validate(extension_registry, var_decls))?; + self.args.iter().try_for_each(|a| a.validate(var_decls))?; // And check they fit into the TypeParams declared by the TypeDef - let def = self.get_type_def(extension_registry)?; + let ext = self.get_extension()?; + let def = self.get_type_def(&ext)?; def.check_custom(self) } - fn get_type_def<'a>( - &self, - extension_registry: &'a ExtensionRegistry, - ) -> Result<&'a TypeDef, SignatureError> { - let ex = extension_registry.get(&self.extension); - // Even if OpDef's (+binaries) are not available, the part of the Extension definition - // describing the TypeDefs can easily be passed around (serialized), so should be available. - let ex = ex.ok_or(SignatureError::ExtensionNotFound { - missing: self.extension.clone(), - available: extension_registry.ids().cloned().collect(), - })?; - ex.get_type(&self.id) + fn get_type_def<'a>(&self, ext: &'a Arc) -> Result<&'a TypeDef, SignatureError> { + ext.get_type(&self.id) .ok_or(SignatureError::ExtensionTypeNotFound { exn: self.extension.clone(), typ: self.id.clone(), }) } + fn get_extension(&self) -> Result, SignatureError> { + self.extension_ref + .upgrade() + .ok_or(SignatureError::MissingTypeExtension { + missing: self.extension.clone(), + typ: self.name().clone(), + }) + } + pub(super) fn substitute(&self, tr: &Substitution) -> Self { let args = self .args .iter() .map(|arg| arg.substitute(tr)) .collect::>(); + let ext = self.get_extension().unwrap_or_else(|e| panic!("{}", e)); let bound = self - .get_type_def(tr.extension_registry()) + .get_type_def(&ext) .expect("validate should rule this out") .bound(&args); debug_assert!(self.bound.contains(bound)); diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 92d2a9479..30677751a 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -2,7 +2,7 @@ use itertools::Itertools; -use crate::extension::{ExtensionRegistry, SignatureError}; +use crate::extension::SignatureError; #[cfg(test)] use { crate::proptest::RecursionDepth, @@ -117,22 +117,17 @@ impl PolyFuncTypeBase { /// # Errors /// If there is not exactly one [TypeArg] for each binder ([Self::params]), /// or an arg does not fit into its corresponding [TypeParam] - pub(crate) fn instantiate( - &self, - args: &[TypeArg], - ext_reg: &ExtensionRegistry, - ) -> Result, SignatureError> { + pub(crate) fn instantiate(&self, args: &[TypeArg]) -> Result, SignatureError> { // Check that args are applicable, and that we have a value for each binder, // i.e. each possible free variable within the body. check_type_args(args, &self.params)?; - Ok(self.body.substitute(&Substitution(args, ext_reg))) + Ok(self.body.substitute(&Substitution(args))) } /// Validates this instance, checking that the types in the body are /// wellformed with respect to the registry, and the type variables declared. - pub fn validate(&self, reg: &ExtensionRegistry) -> Result<(), SignatureError> { - // TODO https://github.com/CQCL/hugr/issues/624 validate TypeParams declared here, too - self.body.validate(reg, &self.params) + pub fn validate(&self) -> Result<(), SignatureError> { + self.body.validate(&self.params) } /// Helper function for the Display implementation @@ -161,10 +156,7 @@ pub(crate) mod test { use lazy_static::lazy_static; use crate::extension::prelude::{bool_t, usize_t}; - use crate::extension::{ - ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound, EMPTY_REG, PRELUDE, - PRELUDE_REGISTRY, - }; + use crate::extension::{ExtensionId, ExtensionRegistry, SignatureError, TypeDefBound, PRELUDE}; use crate::std_extensions::collections::array::{self, array_type_parametric}; use crate::std_extensions::collections::list; use crate::types::signature::FuncTypeBase; @@ -185,10 +177,9 @@ pub(crate) mod test { fn new_validated( params: impl Into>, body: FuncTypeBase, - extension_registry: &ExtensionRegistry, ) -> Result { let res = Self::new(params, body); - res.validate(extension_registry)?; + res.validate()?; Ok(res) } } @@ -201,10 +192,9 @@ pub(crate) mod test { let list_len = PolyFuncTypeBase::new_validated( [TypeBound::Any.into()], Signature::new(vec![list_of_var], vec![usize_t()]), - ®ISTRY, )?; - let t = list_len.instantiate(&[TypeArg::Type { ty: usize_t() }], ®ISTRY)?; + let t = list_len.instantiate(&[TypeArg::Type { ty: usize_t() }])?; assert_eq!( t, Signature::new( @@ -228,28 +218,19 @@ pub(crate) mod test { // Valid schema... let good_array = array_type_parametric(size_var.clone(), ty_var.clone())?; - let good_ts = PolyFuncTypeBase::new_validated( - type_params.clone(), - Signature::new_endo(good_array), - &array::ARRAY_REGISTRY, - )?; + let good_ts = + PolyFuncTypeBase::new_validated(type_params.clone(), Signature::new_endo(good_array))?; // Sanity check (good args) - good_ts.instantiate( - &[ - TypeArg::BoundedNat { n: 5 }, - TypeArg::Type { ty: usize_t() }, - ], - &array::ARRAY_REGISTRY, - )?; - - let wrong_args = good_ts.instantiate( - &[ - TypeArg::Type { ty: usize_t() }, - TypeArg::BoundedNat { n: 5 }, - ], - &array::ARRAY_REGISTRY, - ); + good_ts.instantiate(&[ + TypeArg::BoundedNat { n: 5 }, + TypeArg::Type { ty: usize_t() }, + ])?; + + let wrong_args = good_ts.instantiate(&[ + TypeArg::Type { ty: usize_t() }, + TypeArg::BoundedNat { n: 5 }, + ]); assert_eq!( wrong_args, Err(SignatureError::TypeArgMismatch( @@ -277,11 +258,8 @@ pub(crate) mod test { TypeBound::Any, &Arc::downgrade(&array::EXTENSION), )); - let bad_ts = PolyFuncTypeBase::new_validated( - type_params.clone(), - Signature::new_endo(bad_array), - &array::ARRAY_REGISTRY, - ); + let bad_ts = + PolyFuncTypeBase::new_validated(type_params.clone(), Signature::new_endo(bad_array)); assert_eq!(bad_ts.err(), Some(arg_err)); Ok(()) @@ -303,8 +281,7 @@ pub(crate) mod test { params: vec![TypeBound::Any.into(), TypeParam::max_nat()], }, ] { - let invalid_ts = - PolyFuncTypeBase::new_validated([decl.clone()], body_type.clone(), ®ISTRY); + let invalid_ts = PolyFuncTypeBase::new_validated([decl.clone()], body_type.clone()); assert_eq!( invalid_ts.err(), Some(SignatureError::TypeVarDoesNotMatchDeclaration { @@ -314,7 +291,7 @@ pub(crate) mod test { ); } // Variable not declared at all - let invalid_ts = PolyFuncTypeBase::new_validated([], body_type, ®ISTRY); + let invalid_ts = PolyFuncTypeBase::new_validated([], body_type); assert_eq!( invalid_ts.err(), Some(SignatureError::FreeTypeVar { @@ -358,7 +335,6 @@ pub(crate) mod test { TypeBound::Any, &Arc::downgrade(&ext), ))), - ®, ) }; for decl in accepted { @@ -421,7 +397,6 @@ pub(crate) mod test { vec![usize_t()], vec![TypeRV::new_row_var_use(0, TypeBound::Copyable)], ), - &PRELUDE_REGISTRY, ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { @@ -432,7 +407,6 @@ pub(crate) mod test { let e = PolyFuncTypeBase::new_validated( [decl.clone()], Signature::new_endo(vec![Type::new_var_use(0, TypeBound::Any)]), - &EMPTY_REG, ) .unwrap_err(); assert_matches!(e, SignatureError::TypeVarDoesNotMatchDeclaration { actual, cached } => { @@ -450,25 +424,21 @@ pub(crate) mod test { vec![usize_t().into(), rty.clone()], vec![TypeRV::new_tuple(rty)], ), - &PRELUDE_REGISTRY, ) .unwrap(); fn seq2() -> Vec { vec![usize_t().into(), bool_t().into()] } - pf.instantiate(&[TypeArg::Type { ty: usize_t() }], &PRELUDE_REGISTRY) + pf.instantiate(&[TypeArg::Type { ty: usize_t() }]) .unwrap_err(); - pf.instantiate( - &[TypeArg::Sequence { - elems: vec![usize_t().into(), TypeArg::Sequence { elems: seq2() }], - }], - &PRELUDE_REGISTRY, - ) + pf.instantiate(&[TypeArg::Sequence { + elems: vec![usize_t().into(), TypeArg::Sequence { elems: seq2() }], + }]) .unwrap_err(); let t2 = pf - .instantiate(&[TypeArg::Sequence { elems: seq2() }], &PRELUDE_REGISTRY) + .instantiate(&[TypeArg::Sequence { elems: seq2() }]) .unwrap(); assert_eq!( t2, @@ -492,18 +462,14 @@ pub(crate) mod test { }), }], Signature::new(vec![usize_t(), inner_fty.clone()], vec![inner_fty]), - &PRELUDE_REGISTRY, ) .unwrap(); let inner3 = Type::new_function(Signature::new_endo(vec![usize_t(), bool_t(), usize_t()])); let t3 = pf - .instantiate( - &[TypeArg::Sequence { - elems: vec![usize_t().into(), bool_t().into(), usize_t().into()], - }], - &PRELUDE_REGISTRY, - ) + .instantiate(&[TypeArg::Sequence { + elems: vec![usize_t().into(), bool_t().into(), usize_t().into()], + }]) .unwrap(); assert_eq!( t3, diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index cac530291..cca50c94c 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -113,13 +113,9 @@ impl FuncTypeBase { (&self.input, &self.output) } - pub(super) fn validate( - &self, - extension_registry: &ExtensionRegistry, - var_decls: &[TypeParam], - ) -> Result<(), SignatureError> { - self.input.validate(extension_registry, var_decls)?; - self.output.validate(extension_registry, var_decls)?; + pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { + self.input.validate(var_decls)?; + self.output.validate(var_decls)?; self.runtime_reqs.validate(var_decls) } diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index d1e62d6b1..b81a20ace 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -11,7 +11,6 @@ use thiserror::Error; use super::row_var::MaybeRV; use super::{check_typevar_decl, NoRV, RowVariable, Substitution, Type, TypeBase, TypeBound}; -use crate::extension::ExtensionRegistry; use crate::extension::ExtensionSet; use crate::extension::SignatureError; @@ -294,17 +293,11 @@ impl TypeArg { /// Much as [Type::validate], also checks that the type of any [TypeArg::Opaque] /// is valid and closed. - pub(crate) fn validate( - &self, - extension_registry: &ExtensionRegistry, - var_decls: &[TypeParam], - ) -> Result<(), SignatureError> { + pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { match self { - TypeArg::Type { ty } => ty.validate(extension_registry, var_decls), + TypeArg::Type { ty } => ty.validate(var_decls), TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()), - TypeArg::Sequence { elems } => elems - .iter() - .try_for_each(|a| a.validate(extension_registry, var_decls)), + TypeArg::Sequence { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)), TypeArg::Extensions { es: _ } => Ok(()), TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, @@ -478,7 +471,7 @@ mod test { use itertools::Itertools; use super::{check_type_arg, Substitution, TypeArg, TypeParam}; - use crate::extension::prelude::{bool_t, usize_t, PRELUDE_REGISTRY}; + use crate::extension::prelude::{bool_t, usize_t}; use crate::types::{type_param::TypeArgError, TypeBound, TypeRV}; #[test] @@ -576,7 +569,7 @@ mod test { }; check_type_arg(&outer_arg, &outer_param).unwrap(); - let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg], &PRELUDE_REGISTRY)); + let outer_arg2 = outer_arg.substitute(&Substitution(&[row_arg])); assert_eq!( outer_arg2, vec![bool_t().into(), TypeArg::UNIT, usize_t().into()].into() @@ -618,8 +611,7 @@ mod test { // Now substitute a list of two types for that row-variable let row_var_arg = vec![usize_t().into(), bool_t().into()].into(); check_type_arg(&row_var_arg, &row_var_decl).unwrap(); - let subst_arg = - good_arg.substitute(&Substitution(&[row_var_arg.clone()], &PRELUDE_REGISTRY)); + let subst_arg = good_arg.substitute(&Substitution(&[row_var_arg.clone()])); check_type_arg(&subst_arg, &outer_param).unwrap(); // invariance of substitution assert_eq!( subst_arg, diff --git a/hugr-core/src/types/type_row.rs b/hugr-core/src/types/type_row.rs index c807b872f..38a4b0520 100644 --- a/hugr-core/src/types/type_row.rs +++ b/hugr-core/src/types/type_row.rs @@ -8,10 +8,7 @@ use std::{ }; use super::{type_param::TypeParam, MaybeRV, NoRV, RowVariable, Substitution, Type, TypeBase}; -use crate::{ - extension::{ExtensionRegistry, SignatureError}, - utils::display_list, -}; +use crate::{extension::SignatureError, utils::display_list}; use delegate::delegate; use itertools::Itertools; @@ -94,12 +91,8 @@ impl TypeRowBase { } } - pub(super) fn validate( - &self, - exts: &ExtensionRegistry, - var_decls: &[TypeParam], - ) -> Result<(), SignatureError> { - self.iter().try_for_each(|t| t.validate(exts, var_decls)) + pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { + self.iter().try_for_each(|t| t.validate(var_decls)) } } diff --git a/hugr-core/src/utils.rs b/hugr-core/src/utils.rs index f0c0d7b97..d488f8471 100644 --- a/hugr-core/src/utils.rs +++ b/hugr-core/src/utils.rs @@ -205,9 +205,7 @@ pub(crate) mod test_quantum_extension { } fn get_gate(gate_name: &OpNameRef) -> ExtensionOp { - EXTENSION - .instantiate_extension_op(gate_name, [], ®) - .unwrap() + EXTENSION.instantiate_extension_op(gate_name, []).unwrap() } pub(crate) fn h_gate() -> ExtensionOp { get_gate("H") diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 881f86932..ef4ad6f4a 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -256,8 +256,9 @@ mod test_fns { use hugr_core::ops::constant::CustomConst; use hugr_core::ops::{CallIndirect, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_ops::{self, INT_OPS_REGISTRY}; + use hugr_core::std_extensions::arithmetic::int_ops::{self}; use hugr_core::std_extensions::arithmetic::int_types::ConstInt; + use hugr_core::std_extensions::STD_REG; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{type_row, Hugr}; @@ -356,7 +357,7 @@ mod test_fns { let hugr = SimpleHugrConfig::new() .with_outs(v.get_type()) - .with_extensions(INT_OPS_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut builder: DFGW| { let konst = builder.add_load_value(v); builder.finish_with_outputs([konst]).unwrap() @@ -413,12 +414,12 @@ mod test_fns { let hugr = SimpleHugrConfig::new() .with_outs(v1.get_type()) - .with_extensions(INT_OPS_REGISTRY.to_owned()) - .finish_with_exts(|mut builder: DFGW, ext_reg| { + .with_extensions(STD_REG.to_owned()) + .finish(|mut builder: DFGW| { let k1 = builder.add_load_value(v1); let k2 = builder.add_load_value(v2); let ext_op = int_ops::EXTENSION - .instantiate_extension_op("iadd", [4.into()], ext_reg) + .instantiate_extension_op("iadd", [4.into()]) .unwrap(); let add = builder.add_dataflow_op(ext_op, [k1, k2]).unwrap(); builder.finish_with_outputs(add.outputs()).unwrap() diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 36f58dbc0..e4b9deb75 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -675,6 +675,7 @@ mod test { use hugr_core::extension::ExtensionSet; use hugr_core::ops::Tag; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; + use hugr_core::std_extensions::STD_REG; use hugr_core::types::Type; use hugr_core::{ builder::{Dataflow, DataflowSubContainer, SubContainer}, @@ -706,7 +707,7 @@ mod test { #[rstest] fn emit_all_ops(mut llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { array_op_builder::test::all_array_ops(builder.dfg_builder_endo([]).unwrap()) .finish_sub_container() @@ -723,7 +724,7 @@ mod test { #[rstest] fn emit_get(mut llvm_ctx: TestContext) { let hugr = SimpleHugrConfig::new() - .with_extensions(array::ARRAY_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let us1 = builder.add_load_value(ConstUsize::new(1)); let us2 = builder.add_load_value(ConstUsize::new(2)); diff --git a/hugr-llvm/src/extension/collections/list.rs b/hugr-llvm/src/extension/collections/list.rs index fb200ca97..004209445 100644 --- a/hugr-llvm/src/extension/collections/list.rs +++ b/hugr-llvm/src/extension/collections/list.rs @@ -390,7 +390,7 @@ mod test { #[case::length(ListOp::length)] fn test_list_emission(mut llvm_ctx: TestContext, #[case] op: ListOp) { let ext_op = list::EXTENSION - .instantiate_extension_op(op.name().as_ref(), [qb_t().into()], &list::LIST_REGISTRY) + .instantiate_extension_op(op.name().as_ref(), [qb_t().into()]) .unwrap(); let es = ExtensionRegistry::new([list::EXTENSION.to_owned(), prelude::PRELUDE.to_owned()]); es.validate().unwrap(); diff --git a/hugr-llvm/src/extension/conversions.rs b/hugr-llvm/src/extension/conversions.rs index 45e6fd485..a3de97d68 100644 --- a/hugr-llvm/src/extension/conversions.rs +++ b/hugr-llvm/src/extension/conversions.rs @@ -252,11 +252,12 @@ mod test { use crate::test::{exec_ctx, llvm_ctx, TestContext}; use hugr_core::builder::SubContainer; use hugr_core::std_extensions::arithmetic::int_types::ConstInt; + use hugr_core::std_extensions::STD_REG; use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, extension::prelude::{usize_t, ConstUsize, PRELUDE_REGISTRY}, std_extensions::arithmetic::{ - conversions::{ConvertOpDef, CONVERT_OPS_REGISTRY, EXTENSION}, + conversions::{ConvertOpDef, EXTENSION}, float_types::float64_type, int_types::INT_TYPES, }, @@ -274,15 +275,11 @@ mod test { SimpleHugrConfig::new() .with_ins(vec![in_type.clone()]) .with_outs(vec![out_type.clone()]) - .with_extensions(CONVERT_OPS_REGISTRY.clone()) + .with_extensions(STD_REG.clone()) .finish(|mut hugr_builder| { let [in1] = hugr_builder.input_wires_arr(); let ext_op = EXTENSION - .instantiate_extension_op( - name.as_ref(), - [(int_width as u64).into()], - &CONVERT_OPS_REGISTRY, - ) + .instantiate_extension_op(name.as_ref(), [(int_width as u64).into()]) .unwrap(); let outputs = hugr_builder .add_dataflow_op(ext_op, [in1]) @@ -350,12 +347,10 @@ mod test { let hugr = SimpleHugrConfig::new() .with_ins(vec![in_t]) .with_outs(vec![out_t]) - .with_extensions(CONVERT_OPS_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut hugr_builder| { let [in1] = hugr_builder.input_wires_arr(); - let ext_op = EXTENSION - .instantiate_extension_op(op_name, [], &CONVERT_OPS_REGISTRY) - .unwrap(); + let ext_op = EXTENSION.instantiate_extension_op(op_name, []).unwrap(); let [out1] = hugr_builder .add_dataflow_op(ext_op, [in1]) .unwrap() @@ -385,7 +380,7 @@ mod test { fn usize_roundtrip(mut exec_ctx: TestContext, #[case] val: u64) -> () { let hugr = SimpleHugrConfig::new() .with_outs(usize_t()) - .with_extensions(CONVERT_OPS_REGISTRY.clone()) + .with_extensions(STD_REG.clone()) .finish(|mut builder: DFGW| { let k = builder.add_load_value(ConstUsize::new(val)); let [int] = builder @@ -411,7 +406,7 @@ mod test { let int64 = INT_TYPES[6].clone(); SimpleHugrConfig::new() .with_outs(usize_t()) - .with_extensions(CONVERT_OPS_REGISTRY.clone()) + .with_extensions(STD_REG.clone()) .finish(|mut builder| { let k = builder.add_load_value(ConstUsize::new(val)); let [int] = builder @@ -573,12 +568,10 @@ mod test { let hugr = SimpleHugrConfig::new() .with_outs(vec![usize_t()]) - .with_extensions(CONVERT_OPS_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap()); - let ext_op = EXTENSION - .instantiate_extension_op("itobool", [], &CONVERT_OPS_REGISTRY) - .unwrap(); + let ext_op = EXTENSION.instantiate_extension_op("itobool", []).unwrap(); let [b] = builder.add_dataflow_op(ext_op, [i]).unwrap().outputs_arr(); let mut cond = builder .conditional_builder( @@ -609,16 +602,12 @@ mod test { fn itobool_roundtrip(mut exec_ctx: TestContext, #[values(0, 1)] i: u64) { let hugr = SimpleHugrConfig::new() .with_outs(vec![INT_TYPES[0].clone()]) - .with_extensions(CONVERT_OPS_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let i = builder.add_load_value(ConstInt::new_u(0, i).unwrap()); - let i2b = EXTENSION - .instantiate_extension_op("itobool", [], &CONVERT_OPS_REGISTRY) - .unwrap(); + let i2b = EXTENSION.instantiate_extension_op("itobool", []).unwrap(); let [b] = builder.add_dataflow_op(i2b, [i]).unwrap().outputs_arr(); - let b2i = EXTENSION - .instantiate_extension_op("ifrombool", [], &CONVERT_OPS_REGISTRY) - .unwrap(); + let b2i = EXTENSION.instantiate_extension_op("ifrombool", []).unwrap(); let [i] = builder.add_dataflow_op(b2i, [b]).unwrap().outputs_arr(); builder.finish_with_outputs([i]).unwrap() }); diff --git a/hugr-llvm/src/extension/float.rs b/hugr-llvm/src/extension/float.rs index 7cb694cb0..de40921eb 100644 --- a/hugr-llvm/src/extension/float.rs +++ b/hugr-llvm/src/extension/float.rs @@ -132,15 +132,13 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { mod test { use hugr_core::extension::simple_op::MakeOpDef; use hugr_core::extension::SignatureFunc; - use hugr_core::std_extensions::arithmetic::float_ops::{self, FloatOps}; + use hugr_core::std_extensions::arithmetic::float_ops::FloatOps; + use hugr_core::std_extensions::STD_REG; use hugr_core::types::TypeRow; use hugr_core::Hugr; use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, - std_extensions::arithmetic::{ - float_ops::FLOAT_OPS_REGISTRY, - float_types::{float64_type, ConstF64}, - }, + std_extensions::arithmetic::float_types::{float64_type, ConstF64}, }; use rstest::rstest; @@ -162,7 +160,7 @@ mod test { SimpleHugrConfig::new() .with_ins(inp) .with_outs(out) - .with_extensions(float_ops::FLOAT_OPS_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let outputs = builder .add_dataflow_op(op, builder.input_wires()) @@ -177,7 +175,7 @@ mod test { llvm_ctx.add_extensions(add_float_extensions); let hugr = SimpleHugrConfig::new() .with_outs(float64_type()) - .with_extensions(FLOAT_OPS_REGISTRY.to_owned()) + .with_extensions(STD_REG.to_owned()) .finish(|mut builder| { let c = builder.add_load_value(ConstF64::new(3.12)); builder.finish_with_outputs([c]).unwrap() diff --git a/hugr-llvm/src/extension/int.rs b/hugr-llvm/src/extension/int.rs index e6d045ceb..a22204061 100644 --- a/hugr-llvm/src/extension/int.rs +++ b/hugr-llvm/src/extension/int.rs @@ -165,6 +165,7 @@ impl<'a, H: HugrView + 'a> CodegenExtsBuilder<'a, H> { #[cfg(test)] mod test { + use hugr_core::std_extensions::STD_REG; use hugr_core::{ builder::{Dataflow, DataflowSubContainer}, extension::prelude::bool_t, @@ -198,15 +199,11 @@ mod test { SimpleHugrConfig::new() .with_ins(vec![ty.clone(), ty.clone()]) .with_outs(output_types.into()) - .with_extensions(int_ops::INT_OPS_REGISTRY.clone()) + .with_extensions(STD_REG.clone()) .finish(|mut hugr_builder| { let [in1, in2] = hugr_builder.input_wires_arr(); let ext_op = int_ops::EXTENSION - .instantiate_extension_op( - name.as_ref(), - [(log_width as u64).into()], - &int_ops::INT_OPS_REGISTRY, - ) + .instantiate_extension_op(name.as_ref(), [(log_width as u64).into()]) .unwrap(); let outputs = hugr_builder .add_dataflow_op(ext_op, [in1, in2]) @@ -221,15 +218,11 @@ mod test { SimpleHugrConfig::new() .with_ins(vec![ty.clone()]) .with_outs(vec![ty.clone()]) - .with_extensions(int_ops::INT_OPS_REGISTRY.clone()) + .with_extensions(STD_REG.clone()) .finish(|mut hugr_builder| { let [in1] = hugr_builder.input_wires_arr(); let ext_op = int_ops::EXTENSION - .instantiate_extension_op( - name.as_ref(), - [(log_width as u64).into()], - &int_ops::INT_OPS_REGISTRY, - ) + .instantiate_extension_op(name.as_ref(), [(log_width as u64).into()]) .unwrap(); let outputs = hugr_builder .add_dataflow_op(ext_op, [in1]) diff --git a/hugr-llvm/src/extension/prelude.rs b/hugr-llvm/src/extension/prelude.rs index 4e012a566..82297fc3c 100644 --- a/hugr-llvm/src/extension/prelude.rs +++ b/hugr-llvm/src/extension/prelude.rs @@ -295,7 +295,7 @@ fn add_prelude_extensions<'a, H: HugrView + 'a>( #[cfg(test)] mod test { use hugr_core::builder::{Dataflow, DataflowSubContainer}; - use hugr_core::extension::{PRELUDE, PRELUDE_REGISTRY}; + use hugr_core::extension::PRELUDE; use hugr_core::types::{Type, TypeArg}; use hugr_core::{type_row, Hugr}; use prelude::{bool_t, qb_t, usize_t, PANIC_OP_ID, PRINT_OP_ID}; @@ -435,11 +435,7 @@ mod test { elems: vec![type_arg_q.clone(), type_arg_q], }; let panic_op = PRELUDE - .instantiate_extension_op( - &PANIC_OP_ID, - [type_arg_2q.clone(), type_arg_2q.clone()], - &PRELUDE_REGISTRY, - ) + .instantiate_extension_op(&PANIC_OP_ID, [type_arg_2q.clone(), type_arg_2q.clone()]) .unwrap(); let hugr = SimpleHugrConfig::new() @@ -462,9 +458,7 @@ mod test { #[rstest] fn prelude_print(prelude_llvm_ctx: TestContext) { let greeting: ConstString = ConstString::new("Hello, world!".into()); - let print_op = PRELUDE - .instantiate_extension_op(&PRINT_OP_ID, [], &PRELUDE_REGISTRY) - .unwrap(); + let print_op = PRELUDE.instantiate_extension_op(&PRINT_OP_ID, []).unwrap(); let hugr = SimpleHugrConfig::new() .with_extensions(prelude::PRELUDE_REGISTRY.to_owned()) diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index 99cb6527a..28e664b97 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -1,5 +1,4 @@ use hugr_core::{ - extension::ExtensionRegistry, hugr::hugrmut::HugrMut, ops::{FuncDefn, LoadFunction, Value}, types::PolyFuncType, @@ -12,18 +11,12 @@ fn const_fn_name(konst_n: Node) -> String { format!("const_fun_{}", konst_n.index()) } -pub fn inline_constant_functions( - hugr: &mut impl HugrMut, - registry: &ExtensionRegistry, -) -> Result<()> { - while inline_constant_functions_impl(hugr, registry)? {} +pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> { + while inline_constant_functions_impl(hugr)? {} Ok(()) } -fn inline_constant_functions_impl( - hugr: &mut impl HugrMut, - registry: &ExtensionRegistry, -) -> Result { +fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { let mut const_funs = vec![]; for n in hugr.nodes() { @@ -76,10 +69,7 @@ fn inline_constant_functions_impl( hugr.insert_hugr(func_node, func_hugr); for lcn in load_constant_ns { - hugr.replace_op( - lcn, - LoadFunction::try_new(polysignature.clone(), [], registry)?, - )?; + hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?)?; } any_changes = true; } @@ -95,7 +85,7 @@ mod test { Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder, }, - extension::{prelude::qb_t, PRELUDE_REGISTRY}, + extension::prelude::qb_t, ops::{CallIndirect, Const, Value}, types::Signature, Hugr, HugrView, Wire, @@ -140,7 +130,7 @@ mod test { builder.finish_hugr().unwrap() }; - inline_constant_functions(&mut hugr, &PRELUDE_REGISTRY).unwrap(); + inline_constant_functions(&mut hugr).unwrap(); for n in hugr.nodes() { if let Some(konst) = hugr.get_optype(n).as_const() { @@ -189,7 +179,7 @@ mod test { builder.finish_hugr().unwrap() }; - inline_constant_functions(&mut hugr, &PRELUDE_REGISTRY).unwrap(); + inline_constant_functions(&mut hugr).unwrap(); for n in hugr.nodes() { if let Some(konst) = hugr.get_optype(n).as_const() { diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 51bd8442f..9a2ca013b 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -1,6 +1,7 @@ use std::collections::hash_map::RandomState; use std::collections::HashSet; +use hugr_core::std_extensions::STD_REG; use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; @@ -29,7 +30,6 @@ use hugr_core::types::{Signature, SumType, Type, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; -use crate::test::TEST_REG; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; @@ -186,7 +186,7 @@ fn test_list_ops() -> Result<(), Box> { .add_dataflow_op( ListOp::pop .with_type(bool_t()) - .to_extension_op(&TEST_REG) + .to_extension_op(&STD_REG) .unwrap(), [list], )? @@ -199,7 +199,7 @@ fn test_list_ops() -> Result<(), Box> { .add_dataflow_op( ListOp::push .with_type(bool_t()) - .to_extension_op(&TEST_REG) + .to_extension_op(&STD_REG) .unwrap(), [list, elem], )? diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 86bdc92f9..a4507e156 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -14,28 +14,3 @@ pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; - -#[cfg(test)] -pub(crate) mod test { - - use lazy_static::lazy_static; - - use hugr_core::extension::{ExtensionRegistry, PRELUDE}; - use hugr_core::std_extensions::arithmetic; - use hugr_core::std_extensions::collections; - use hugr_core::std_extensions::logic; - - lazy_static! { - /// A registry containing various extensions for testing. - pub(crate) static ref TEST_REG: ExtensionRegistry = ExtensionRegistry::new([ - PRELUDE.to_owned(), - arithmetic::int_ops::EXTENSION.to_owned(), - arithmetic::int_types::EXTENSION.to_owned(), - arithmetic::float_types::EXTENSION.to_owned(), - arithmetic::float_ops::EXTENSION.to_owned(), - logic::EXTENSION.to_owned(), - arithmetic::conversions::EXTENSION.to_owned(), - collections::list::EXTENSION.to_owned(), - ]); - } -} diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index 9f24daefa..5fde84b3e 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -165,7 +165,6 @@ mod test { use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize}; - use hugr_core::extension::PRELUDE_REGISTRY; use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::ops::constant::Value; use hugr_core::ops::handle::CfgID; @@ -230,7 +229,7 @@ mod test { let loop_variants: TypeRow = vec![qb_t()].into(); let exit_types: TypeRow = vec![usize_t()].into(); let e = extension(); - let tst_op = e.instantiate_extension_op("Test", [], &PRELUDE_REGISTRY)?; + let tst_op = e.instantiate_extension_op("Test", [])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; let n = no_b1.add_dataflow_op(Noop::new(qb_t()), no_b1.input_wires())?; @@ -317,9 +316,7 @@ mod test { // CFG Normalization would move everything outside the CFG and elide the CFG altogether, // but this is an easy-to-construct test of merge-basic-blocks only (no CFG normalization). let e = extension(); - let tst_op: OpType = e - .instantiate_extension_op("Test", &[], &PRELUDE_REGISTRY)? - .into(); + let tst_op: OpType = e.instantiate_extension_op("Test", &[])?.into(); let [res_t] = tst_op .dataflow_signature() .unwrap() diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index b85a2b580..81cbed5ab 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -5,7 +5,6 @@ use std::{ }; use hugr_core::{ - extension::ExtensionRegistry, ops::{Call, FuncDefn, LoadFunction, OpTrait}, types::{Signature, Substitution, TypeArg}, Node, @@ -35,15 +34,13 @@ pub fn monomorphize(mut h: Hugr) -> Hugr { // We clone the extension registry because we will need a reference to // create our mutable substitutions. This is cannot cause a problem because // we will not be adding any new types or extension ops to the HUGR. - let reg = h.extensions().to_owned(); - #[cfg(debug_assertions)] validate(&h); let root = h.root(); // If the root is a polymorphic function, then there are no external calls, so nothing to do if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(&mut h, root, None, &mut HashMap::new(), ®); + mono_scan(&mut h, root, None, &mut HashMap::new()); if !h.get_optype(root).is_module() { return remove_polyfuncs(h); } @@ -100,7 +97,6 @@ fn mono_scan( parent: Node, mut subst_into: Option<&mut Instantiating>, cache: &mut Instantiations, - reg: &ExtensionRegistry, ) { for old_ch in h.children(parent).collect_vec() { let ch_op = h.get_optype(old_ch); @@ -118,10 +114,10 @@ fn mono_scan( node_map: inst.node_map, ..**inst }; - mono_scan(h, old_ch, Some(&mut inst), cache, reg); + mono_scan(h, old_ch, Some(&mut inst), cache); new_ch } else { - mono_scan(h, old_ch, None, cache, reg); + mono_scan(h, old_ch, None, cache); old_ch }; @@ -133,7 +129,7 @@ fn mono_scan( ( &c.type_args, mono_sig.clone(), - OpType::from(Call::try_new(mono_sig.into(), [], reg).unwrap()), + OpType::from(Call::try_new(mono_sig.into(), []).unwrap()), ) } OpType::LoadFunction(lf) => { @@ -141,9 +137,7 @@ fn mono_scan( ( &lf.type_args, mono_sig.clone(), - LoadFunction::try_new(mono_sig.into(), [], reg) - .unwrap() - .into(), + LoadFunction::try_new(mono_sig.into(), []).unwrap().into(), ) } _ => continue, @@ -153,7 +147,7 @@ fn mono_scan( }; let fn_inp = ch_op.static_input_port().unwrap(); let tgt = h.static_source(old_ch).unwrap(); // Use old_ch as edges not copied yet - let new_tgt = instantiate(h, tgt, type_args.clone(), mono_sig.clone(), cache, reg); + let new_tgt = instantiate(h, tgt, type_args.clone(), mono_sig.clone(), cache); let fn_out = { let func = h.get_optype(new_tgt).as_func_defn().unwrap(); debug_assert_eq!(func.signature, mono_sig.into()); @@ -172,7 +166,6 @@ fn instantiate( type_args: Vec, mono_sig: Signature, cache: &mut Instantiations, - reg: &ExtensionRegistry, ) -> Node { let for_func = cache.entry(poly_func).or_insert_with(|| { // First time we've instantiated poly_func. Lift any nested FuncDefn's out to the same level. @@ -214,11 +207,11 @@ fn instantiate( // Now make the instantiation let mut node_map = HashMap::new(); let mut inst = Instantiating { - subst: &Substitution::new(&type_args, reg), + subst: &Substitution::new(&type_args), target_container: mono_tgt, node_map: &mut node_map, }; - mono_scan(h, poly_func, Some(&mut inst), cache, reg); + mono_scan(h, poly_func, Some(&mut inst), cache); // Copy edges...we have built a node_map for every node in the function. // Note we could avoid building the "large" map (smaller than the Hugr we've just created) // by doing this during recursion, but we'd need to be careful with nonlocal edges - @@ -494,12 +487,8 @@ mod test { let [inw] = pf2.input_wires_arr(); let [idx] = pf2.call(mono_func.handle(), &[], []).unwrap().outputs_arr(); let op_def = collections::array::EXTENSION.get_op("get").unwrap(); - let op = hugr_core::ops::ExtensionOp::new( - op_def.clone(), - vec![sv(0), tv(1).into()], - &STD_REG, - ) - .unwrap(); + let op = hugr_core::ops::ExtensionOp::new(op_def.clone(), vec![sv(0), tv(1).into()]) + .unwrap(); let [get] = pf2.add_dataflow_op(op, [inw, idx]).unwrap().outputs_arr(); let [got] = pf2 .build_unwrap_sum(&STD_REG, 1, SumType::new([vec![], vec![tv(1)]]), get) diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index eb2471e14..a567074db 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -1273,6 +1273,9 @@ def op_def(self) -> ext.OpDef: return std.PRELUDE.get_op("Noop") + def type_args(self) -> list[tys.TypeArg]: + return [tys.TypeTypeArg(self.type_)] + def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType.endo( [self.type_], diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index 56f87c38a..2607a4773 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -297,9 +297,6 @@ def test_invalid_recursive_function() -> None: f_recursive.set_outputs(f_recursive.input_node[0]) -@pytest.mark.skip( - "Temporarily disabled until https://github.com/CQCL/hugr/issues/1774 gets fixed" -) def test_higher_order() -> None: noop_fn = Dfg(tys.Qubit) noop_fn.set_outputs(noop_fn.add(ops.Noop()(noop_fn.input_node[0]))) diff --git a/hugr/benches/benchmarks/hugr.rs b/hugr/benches/benchmarks/hugr.rs index 6bdbe5938..1afbf8940 100644 --- a/hugr/benches/benchmarks/hugr.rs +++ b/hugr/benches/benchmarks/hugr.rs @@ -4,7 +4,7 @@ pub mod examples; use criterion::{black_box, criterion_group, AxisScale, BenchmarkId, Criterion, PlotConfiguration}; #[allow(unused)] -use hugr::std_extensions::arithmetic::float_ops::FLOAT_OPS_REGISTRY; +use hugr::std_extensions::STD_REG; use hugr::Hugr; pub use examples::{circuit, simple_cfg_hugr, simple_dfg_hugr}; @@ -38,7 +38,7 @@ impl Serializer for CapnpSer { fn deserialize(&self, bytes: &[u8]) -> Hugr { let bump = bumpalo::Bump::new(); let module = hugr_model::v0::binary::read_from_slice(bytes, &bump).unwrap(); - hugr_core::import::import_hugr(&module, &FLOAT_OPS_REGISTRY).unwrap() + hugr_core::import::import_hugr(&module, &STD_REG).unwrap() } } diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 74616e642..926a97205 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -7,7 +7,6 @@ use hugr::builder::{ HugrBuilder, ModuleBuilder, }; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; -use hugr::extension::PRELUDE_REGISTRY; use hugr::ops::OpName; use hugr::std_extensions::arithmetic::float_types::float64_type; use hugr::types::Signature; @@ -94,14 +93,10 @@ pub struct CircuitLayer { /// Construct a quantum circuit with two qubits and `layers` layers applying `H q0; CX q0, q1; CX q1, q0`. pub fn circuit(layers: usize) -> (Hugr, Vec) { - let h_gate = QUANTUM_EXT - .instantiate_extension_op("H", [], &PRELUDE_REGISTRY) - .unwrap(); - let cx_gate = QUANTUM_EXT - .instantiate_extension_op("CX", [], &PRELUDE_REGISTRY) - .unwrap(); + let h_gate = QUANTUM_EXT.instantiate_extension_op("H", []).unwrap(); + let cx_gate = QUANTUM_EXT.instantiate_extension_op("CX", []).unwrap(); // let rz = QUANTUM_EXT - // .instantiate_extension_op("Rz", [], &FLOAT_OPS_REGISTRY) + // .instantiate_extension_op("Rz", []) // .unwrap(); let signature = Signature::new_endo(vec![qb_t(), qb_t()]).with_extension_delta(QUANTUM_EXT.name().clone()); diff --git a/hugr/src/lib.rs b/hugr/src/lib.rs index e3313887e..708e8b47f 100644 --- a/hugr/src/lib.rs +++ b/hugr/src/lib.rs @@ -81,11 +81,10 @@ //! lazy_static! { //! /// Quantum extension definition. //! pub static ref EXTENSION: Arc = extension(); -//! pub static ref REG: ExtensionRegistry = ExtensionRegistry::new([EXTENSION.clone(), PRELUDE.clone()]); //! } //! fn get_gate(gate_name: impl Into) -> ExtensionOp { //! EXTENSION -//! .instantiate_extension_op(&gate_name.into(), [], ®) +//! .instantiate_extension_op(&gate_name.into(), []) //! .unwrap() //! .into() //! }