Skip to content

Commit

Permalink
feat!: Resolve extension references inside the extensions themselves (#…
Browse files Browse the repository at this point in the history
…1783)

When deserializing an `ExtensionRegistry`, this change traverses the extension
definitions and updates every `Weak<Extension>` so that it points to a valid
extension (including extensions in the registry being
deserialized!).

This needed a bit of `unsafe` trickery to mimic
[`Arc::new_cyclic`](https://doc.rust-lang.org/std/sync/struct.Arc.html#method.new_cyclic)
but for vectors of extensions.

As with `Package`, I left the `serde::Deserialize` implementation
available but now extension registries have a `load_json` method that
should be used when possible, as it runs the required pointer updates.

BREAKING CHANGE: Marked `ExtensionBuildError` and
`ExtensionRegistryError` as non-exhaustive.
  • Loading branch information
aborgna-q authored Dec 16, 2024
1 parent 2d08fc1 commit 1091755
Show file tree
Hide file tree
Showing 8 changed files with 418 additions and 18 deletions.
128 changes: 125 additions & 3 deletions hugr-core/src/extension.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
//! system (outside the `types` module), which also parses nested [`OpDef`]s.
use itertools::Itertools;
use resolution::{ExtensionResolutionError, WeakExtensionRegistry};
pub use semver::Version;
use serde::{Deserialize, Deserializer, Serialize};
use std::cell::UnsafeCell;
use std::collections::btree_map;
use std::collections::{BTreeMap, BTreeSet};
use std::fmt::Debug;
use std::mem;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Weak};
use std::{io, mem};

use derive_more::Display;
use thiserror::Error;
Expand Down Expand Up @@ -83,6 +85,24 @@ impl ExtensionRegistry {
res
}

/// Load an ExtensionRegistry serialized as json.
///
/// The input is deserialized as a list of [`Extension`]s. After
/// deserialization, updates all the internal `Weak<Extension>` references
/// to point to the newly created [`Arc`]s in the registry, or to extensions
/// in the `other_extensions` parameter.
///
/// # Errors
///
/// Returns [`ExtensionRegistryLoadError::SerdeError`] if the reader's
/// contents cannot be deserialized, and
/// [`ExtensionRegistryLoadError::ExtensionResolutionError`] if resolving
/// the internal extension references fails.
pub fn load_json(
    reader: impl io::Read,
    other_extensions: &ExtensionRegistry,
) -> Result<Self, ExtensionRegistryLoadError> {
    let extensions: Vec<Extension> = serde_json::from_reader(reader)?;
    // Plain deserialization does not fix up the internal `Weak<Extension>`
    // references; building the registry through the resolution constructor
    // does.
    let registry =
        ExtensionRegistry::new_with_extension_resolution(extensions, &other_extensions.into())?;
    Ok(registry)
}

/// Gets the Extension with the given name
pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
self.exts.get(name)
Expand Down Expand Up @@ -213,6 +233,86 @@ impl ExtensionRegistry {

self.exts.remove(name)
}

/// Constructs a new ExtensionRegistry from a list of [`Extension`]s while
/// giving you a [`WeakExtensionRegistry`] to the allocation. This allows
/// you to add [`Weak`] self-references to the [`Extension`]s while
/// constructing them, before wrapping them in [`Arc`]s.
///
/// This is similar to [`Arc::new_cyclic`], but for ExtensionRegistries.
///
/// The `init` closure receives the collected input extensions by value,
/// together with a [`WeakExtensionRegistry`] holding one weak reference per
/// input extension, registered under that extension's name. The extensions
/// it returns are written, position by position, into the allocations those
/// weak references point to.
///
/// Calling [`Weak::upgrade`] on a weak reference in the
/// [`WeakExtensionRegistry`] inside your closure will return an extension
/// with no internal (op / type / value) definitions.
///
/// # Errors
///
/// Returns the error produced by `init`, if any; no registry is built in
/// that case.
//
// It may be possible to implement this safely using `Arc::new_cyclic`
// directly, but the callback type does not allow for returning extra
// data so it seems unlikely.
pub fn new_cyclic<F, E>(
extensions: impl IntoIterator<Item = Extension>,
init: F,
) -> Result<Self, E>
where
F: FnOnce(Vec<Extension>, &WeakExtensionRegistry) -> Result<Vec<Extension>, E>,
{
let extensions = extensions.into_iter().collect_vec();

// Unsafe internally-mutable wrapper around an extension. Important:
// `repr(transparent)` ensures the layout is identical to `Extension`,
// so it can be safely transmuted.
#[repr(transparent)]
struct ExtensionCell {
ext: UnsafeCell<Extension>,
}

// Create the arcs with internal mutability, and collect weak references
// over immutable references.
//
// This is safe as long as the cell mutation happens when we can guarantee
// that the weak references are not used.
let (arcs, weaks): (Vec<Arc<ExtensionCell>>, Vec<Weak<Extension>>) = extensions
.iter()
.map(|ext| {
// Create a new arc with an empty extension sharing the name and version of the original,
// but with no internal definitions.
//
// `UnsafeCell` is not sync, but we are not writing to it while the weak references are
// being used.
#[allow(clippy::arc_with_non_send_sync)]
let arc = Arc::new(ExtensionCell {
ext: UnsafeCell::new(Extension::new(ext.name().clone(), ext.version().clone())),
});

// SAFETY: `ExtensionCell` is `repr(transparent)`, so it has the same layout as `Extension`.
// NOTE(review): this additionally assumes `Weak<ExtensionCell>` and `Weak<Extension>`
// themselves share a representation — confirm.
let weak_arc: Weak<Extension> = unsafe { mem::transmute(Arc::downgrade(&arc)) };
(arc, weak_arc)
})
.unzip();

// Register each weak reference under the name of the input extension it
// was created for, so `init` can look them up by name.
let mut weak_registry = WeakExtensionRegistry::default();
for (ext, weak) in extensions.iter().zip(weaks) {
weak_registry.register(ext.name().clone(), weak);
}

// Actual initialization here
// Upgrading the weak references at any point here will access the empty extensions in the arcs.
let extensions = init(extensions, &weak_registry)?;

// We're done.
// NOTE(review): `zip` stops at the shorter sequence, so this assumes `init`
// returns exactly one extension per input, in the same order. Otherwise a
// weak reference registered under one name would end up pointing at a
// different extension, or at an allocation that is never filled in —
// TODO confirm callers uphold this.
let arcs: Vec<Arc<Extension>> = arcs
.into_iter()
.zip(extensions)
.map(|(arc, ext)| {
// Replace the dummy extensions with the updated ones.
// The assignment drops the placeholder and moves the initialised
// extension into the same allocation the weak references point to.
// SAFETY: The cell is only mutated when the weak references are not used.
unsafe { *arc.ext.get() = ext };
// Pretend the UnsafeCells never existed.
// SAFETY: `ExtensionCell` is `repr(transparent)`, so it has the same layout as `Extension`.
// NOTE(review): `Arc::into_raw` + pointer cast + `Arc::from_raw` would state the
// same-size-and-alignment requirement more explicitly than a transmute — consider.
unsafe { mem::transmute::<Arc<ExtensionCell>, Arc<Extension>>(arc) }
})
.collect();
Ok(ExtensionRegistry::new(arcs))
}
}

impl IntoIterator for ExtensionRegistry {
Expand Down Expand Up @@ -251,8 +351,10 @@ impl Extend<Arc<Extension>> for ExtensionRegistry {
}
}

// Encode/decode ExtensionRegistry as a list of extensions.
// We can get the map key from the extension itself.
/// Encode/decode ExtensionRegistry as a list of extensions.
///
/// Any `Weak<Extension>` references inside the registry will be left unresolved.
/// Prefer using [`ExtensionRegistry::load_json`] when deserializing.
impl<'de> Deserialize<'de> for ExtensionRegistry {
fn deserialize<D>(deserializer: D) -> Result<ExtensionRegistry, D::Error>
where
Expand Down Expand Up @@ -421,6 +523,11 @@ impl ExtensionValue {
&self.typed_value
}

/// Returns a mutable reference to the typed value of this [`ExtensionValue`].
///
/// Restricted to `pub(super)`: it exposes the stored [`ops::Value`] for
/// in-place modification and is not part of the public API.
// NOTE(review): presumably used by the extension-resolution pass to update
// weak extension references inside the value — confirm against callers.
pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value {
&mut self.typed_value
}

/// Returns a reference to the name of this [`ExtensionValue`].
pub fn name(&self) -> &str {
self.name.as_str()
Expand Down Expand Up @@ -658,6 +765,7 @@ impl PartialEq for Extension {

/// An error that can occur in defining an extension registry.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionRegistryError {
/// Extension already defined.
#[error("The registry already contains an extension with id {0} and version {1}. New extension has version {2}.")]
Expand All @@ -667,8 +775,21 @@ pub enum ExtensionRegistryError {
InvalidSignature(ExtensionId, #[source] SignatureError),
}

/// An error that can occur while loading an extension registry.
///
/// Both variants are `#[error(transparent)]`, so the displayed message and
/// the error source are those of the wrapped error. The `#[from]`
/// conversions let either underlying error be propagated with `?`.
///
/// The enum is `#[non_exhaustive]`: more variants may be added later, so
/// matches outside this crate need a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ExtensionRegistryLoadError {
/// Deserialization error.
///
/// Wraps the [`serde_json::Error`] raised while decoding the input.
#[error(transparent)]
SerdeError(#[from] serde_json::Error),
/// Error when resolving internal extension references.
///
/// Wraps the [`ExtensionResolutionError`] raised after decoding succeeded.
#[error(transparent)]
ExtensionResolutionError(#[from] ExtensionResolutionError),
}

/// An error that can occur in building a new extension.
#[derive(Debug, Clone, Error, PartialEq, Eq)]
#[non_exhaustive]
pub enum ExtensionBuildError {
/// Existing [`OpDef`]
#[error("Extension already has an op called {0}.")]
Expand Down Expand Up @@ -909,6 +1030,7 @@ pub mod test {
assert!(reg.remove_extension(&ext_1_id).unwrap().version() == &Version::new(1, 1, 0));
assert_eq!(reg.len(), 1);
}

mod proptest {

use ::proptest::{collection::hash_set, prelude::*};
Expand Down
15 changes: 15 additions & 0 deletions hugr-core/src/extension/op_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ impl CustomValidator {
validate: Box::new(validate),
}
}

/// Returns a mutable reference to the [`PolyFuncTypeRV`] stored in this validator.
///
/// Restricted to `pub(super)`: it allows the stored signature to be
/// modified in place and is not part of the public API.
// NOTE(review): presumably used by the extension-resolution pass to update
// weak extension references inside the signature — confirm against callers.
pub(super) fn poly_func_mut(&mut self) -> &mut PolyFuncTypeRV {
&mut self.poly_func
}
}

/// The ways in which an OpDef may compute the Signature of each operation node.
Expand Down Expand Up @@ -407,6 +412,11 @@ impl OpDef {
self.extension_ref.clone()
}

/// Returns a mutable reference to the weak extension pointer in the operation definition.
///
/// Writing through the returned reference replaces the [`Weak`] pointer
/// itself; the extension it points to is not modified. Restricted to
/// `pub(super)`, so it is not part of the public API.
pub(super) fn extension_mut(&mut self) -> &mut Weak<Extension> {
&mut self.extension_ref
}

/// Returns a reference to the description of this [`OpDef`].
pub fn description(&self) -> &str {
self.description.as_ref()
Expand Down Expand Up @@ -469,6 +479,11 @@ impl OpDef {
pub fn signature_func(&self) -> &SignatureFunc {
&self.signature_func
}

/// Returns a mutable reference to the signature function of this [`OpDef`].
///
/// Mutable counterpart of `signature_func`. Restricted to `pub(super)`, so
/// it is not part of the public API.
// NOTE(review): presumably used by the extension-resolution pass to update
// weak extension references inside the signature — confirm against callers.
pub(super) fn signature_func_mut(&mut self) -> &mut SignatureFunc {
&mut self.signature_func
}
}

impl Extension {
Expand Down
27 changes: 26 additions & 1 deletion hugr-core/src/extension/resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
//! (will) automatically resolve extensions as the operations are created,
//! we will no longer require this post-facto resolution step.
mod extension;
mod ops;
mod types;
mod types_mut;
Expand Down Expand Up @@ -102,12 +103,12 @@ pub enum ExtensionResolutionError {
/// A list of available extensions.
available_extensions: Vec<ExtensionId>,
},
/// A type references an extension that is not in the given registry.
#[display(
"Type {ty}{} requires extension {missing_extension}, but it could not be found in the extension list used during resolution. The available extensions are: {}",
node.map(|n| format!(" in {}", n)).unwrap_or_default(),
available_extensions.join(", ")
)]
/// A type references an extension that is not in the given registry.
MissingTypeExtension {
/// The node that requires the extension.
node: Option<Node>,
Expand All @@ -118,6 +119,30 @@ pub enum ExtensionResolutionError {
/// A list of available extensions.
available_extensions: Vec<ExtensionId>,
},
/// A type definition's `extension_id` does not match the extension it is in.
#[display(
"Type definition {def} in extension {extension} declares it was defined in {wrong_extension} instead."
)]
WrongTypeDefExtension {
/// The extension that defines the type.
extension: ExtensionId,
/// The type definition name.
def: TypeName,
/// The extension declared in the type definition's `extension_id`.
wrong_extension: ExtensionId,
},
/// An operation definition's `extension_id` does not match the extension it is in.
#[display(
"Operation definition {def} in extension {extension} declares it was defined in {wrong_extension} instead."
)]
WrongOpDefExtension {
/// The extension that defines the op.
extension: ExtensionId,
/// The op definition name.
def: OpName,
/// The extension declared in the op definition's `extension_id`.
wrong_extension: ExtensionId,
},
/// The type of an `OpaqueValue` has types which do not reference their defining extensions.
#[display("The type of the opaque value '{value}' requires extensions {missing_extensions}, but does not reference their definition.")]
InvalidConstTypes {
Expand Down
Loading

0 comments on commit 1091755

Please sign in to comment.