From f83db21b9af04a2cefd733b97f13efa16829f062 Mon Sep 17 00:00:00 2001 From: Douglas Wilson <141026920+doug-q@users.noreply.github.com> Date: Fri, 10 Jan 2025 17:36:29 +0000 Subject: [PATCH] feat(tket2-hseries): Lazify more flavours of measure ops (#742) Closes #740 --------- Co-authored-by: Seyon Sivarajah --- Cargo.lock | 1 + tket2-hseries/Cargo.toml | 1 + tket2-hseries/src/extension/qsystem.rs | 2 + tket2-hseries/src/lazify_measure.rs | 390 +++++++++++++++---------- tket2-hseries/src/lib.rs | 45 ++- 5 files changed, 288 insertions(+), 151 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 18f769ee..4d6095bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2056,6 +2056,7 @@ version = "0.7.1" dependencies = [ "clap", "cool_asserts", + "delegate 0.13.1", "derive_more 1.0.0", "hugr", "hugr-cli", diff --git a/tket2-hseries/Cargo.toml b/tket2-hseries/Cargo.toml index ac90c121..0b626058 100644 --- a/tket2-hseries/Cargo.toml +++ b/tket2-hseries/Cargo.toml @@ -39,6 +39,7 @@ derive_more = { workspace = true, features = [ "from", "into", ] } +delegate.workspace = true [dev-dependencies] cool_asserts.workspace = true diff --git a/tket2-hseries/src/extension/qsystem.rs b/tket2-hseries/src/extension/qsystem.rs index 5044bd49..f65c05c5 100644 --- a/tket2-hseries/src/extension/qsystem.rs +++ b/tket2-hseries/src/extension/qsystem.rs @@ -23,6 +23,7 @@ use hugr::{ Extension, Wire, }; +use derive_more::Display; use lazy_static::lazy_static; use strum_macros::{EnumIter, EnumString, IntoStaticStr}; @@ -76,6 +77,7 @@ lazy_static! { EnumIter, IntoStaticStr, EnumString, + Display, )] #[allow(missing_docs)] #[non_exhaustive] diff --git a/tket2-hseries/src/lazify_measure.rs b/tket2-hseries/src/lazify_measure.rs index 67a47b0a..e1e0c2f5 100644 --- a/tket2-hseries/src/lazify_measure.rs +++ b/tket2-hseries/src/lazify_measure.rs @@ -3,8 +3,9 @@ //! //! [Tket2Op::Measure]: tket2::Tk2Op::Measure //! [QSystemOp::Measure]: crate::extension::qsystem::QSystemOp::Measure -use std::collections::{HashMap, HashSet}; +use std::{collections::HashMap, iter}; +use delegate::delegate; use derive_more::{Display, Error, From}; use hugr::{ algorithms::{ @@ -12,27 +13,29 @@ use hugr::{ non_local::NonLocalEdgesError, validation::{ValidatePassError, ValidationLevel}, }, - builder::{DFGBuilder, Dataflow, DataflowHugr}, + builder::{DFGBuilder, Dataflow, DataflowHugr as _}, extension::prelude::{bool_t, qb_t}, hugr::{hugrmut::HugrMut, views::SiblingSubgraph, Rewrite, SimpleReplacementError}, + ops::{handle::NodeHandle as _, OpTrait as _}, types::Signature, - Hugr, HugrView, IncomingPort, Node, OutgoingPort, SimpleReplacement, + HugrView, Node, SimpleReplacement, Wire, }; +use itertools::Itertools as _; use tket2::Tk2Op; -use lazy_static::lazy_static; +use crate::extension::{futures::FutureOpBuilder as _, qsystem::QSystemOp}; -use crate::extension::{futures::FutureOpBuilder, qsystem::QSystemOpBuilder}; - -/// A `Hugr -> Hugr` pass that replaces [Tk2Op::Measure] nodes with -/// [QSystemOp::Measure] nodes. To construct a `LazifyMeasurePass` use -/// [Default::default]. +/// A HUGR -> HUGR pass that replaces measurement ops with lazy `tket2.qsystem` +/// measurement ops. /// -/// The `Hugr` must not contain any non-local edges. If validation is enabled, -/// this precondition will be verified. +/// [Tk2Op::Measure], [QSystemOp::Measure], and [QSystemOp::MeasureReset] nodes +/// are replaced by [QSystemOp::LazyMeasure] and [QSystemOp::LazyMeasureReset] +/// nodes. /// -/// [Tket2Op::Measure]: tket2::Tk2Op::Measure -/// [QSystemOp::Measure]: crate::extension::qsystem::QSystemOp::Measure +/// To construct a `LazifyMeasurePass` use [Default::default]. +/// +/// The HUGR must not contain any non-local edges. If validation is enabled, +/// this precondition will be verified. #[derive(Default)] pub struct LazifyMeasurePass(ValidationLevel); @@ -40,10 +43,18 @@ pub struct LazifyMeasurePass(ValidationLevel); #[non_exhaustive] /// An error reported from [LazifyMeasurePass]. pub enum LazifyMeasurePassError { - /// The [Hugr] was invalid either before or after a pass ran. + /// The HUGR was invalid either before or after a pass ran. ValidationError(ValidatePassError), - /// The [Hugr] was found to contain non-local edges. + /// The HUGR was found to contain non-local edges. NonLocalEdgesError(NonLocalEdgesError), + /// A [LazifyMeasureRewrite] was constructed targetting an invalid op. + #[display("A LazifyMeasureRewrite was constructed for node {node} with an invalid signature.\nExpected: {expected_signature}\nActual: {}", actual_signature.as_ref().map_or("None".to_string(), |x| format!("{x}")))] + #[allow(missing_docs)] + InvalidOp { + node: Node, + expected_signature: Signature, + actual_signature: Option, + }, /// A [SimpleReplacement] failed during the running of the pass. SimpleReplacementError(SimpleReplacementError), } @@ -52,17 +63,11 @@ impl LazifyMeasurePass { /// Run `LazifyMeasurePass` on the given [HugrMut]. `registry` is used for /// validation, if enabled. pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), LazifyMeasurePassError> { - self.0.run_validated_pass(hugr, |hugr, validation_level| { - if validation_level != &ValidationLevel::None { + self.0.run_validated_pass(hugr, |hugr, level| { + if *level != ValidationLevel::None { ensure_no_nonlocal_edges(hugr)?; } - let mut state = State::new(hugr.nodes().filter_map( - |n| match hugr.get_optype(n).cast() { - Some(Tk2Op::Measure) => Some(WorkItem::ReplaceMeasure(n)), - _ => None, - }, - )); - while state.work_one(hugr)? {} + replace_measure_ops(hugr)?; Ok(()) }) } @@ -74,126 +79,204 @@ impl LazifyMeasurePass { } } -enum WorkItem { - ReplaceMeasure(Node), -} +/// Implementation of [LazifyMeasurePass]. +/// +/// No validation is done here. +pub fn replace_measure_ops(hugr: &mut impl HugrMut) -> Result, LazifyMeasurePassError> { + let nodes_and_rewrites = hugr + .nodes() + .filter_map(|n| { + let optype = hugr.get_optype(n); + if let Some(Tk2Op::MeasureFree) = optype.cast() { + Some(LazifyMeasureRewrite::try_new_measure(n, &hugr)) + } else if let Some(QSystemOp::Measure) = optype.cast() { + Some(LazifyMeasureRewrite::try_new_measure(n, &hugr)) + } else if let Some(QSystemOp::MeasureReset) = optype.cast() { + Some(LazifyMeasureRewrite::try_new_measure_reset(n, &hugr)) + } else { + None + } + .map(|x| x.map(|y| (n, y))) + }) + .collect::, _>>()?; -struct State { - worklist: Vec, + nodes_and_rewrites + .into_iter() + .map(|(n, rw)| { + hugr.apply_rewrite(rw)?; + Ok(n) + }) + .try_collect() } -impl State { - fn new(items: impl IntoIterator) -> Self { - let worklist = items.into_iter().collect(); - Self { worklist } - } +/// A rewrite used in [LazifyMeasurePass] to replace strict measure ops with +/// either [QSystemOp::LazyMeasure] or [QSystemOp::LazyMeasureReset]. +pub struct LazifyMeasureRewrite(SimpleReplacement); - fn work_one(&mut self, hugr: &mut impl HugrMut) -> Result { - let Some(item) = self.worklist.pop() else { - return Ok(false); +impl LazifyMeasureRewrite { + /// Construct a new `LazifyMeasureRewrite` replacing `node` with a + /// [QSystemOp::LazyMeasure]. + /// + /// Fails if node does not have signature `[QB] -> [BOOL]` + pub fn try_new_measure( + node: Node, + hugr: impl HugrView, + ) -> Result { + Self::check_signature(node, QSystemOp::LazyMeasure, hugr.get_optype(node))?; + + let subgraph = SiblingSubgraph::from_node(node, &hugr); + let uses = hugr.linked_inputs(node, 0).collect_vec(); + let (lazy_measure_node, replacement) = { + let bool_uses = uses.len(); + let mut builder = + DFGBuilder::new(Signature::new(qb_t(), vec![bool_t(); bool_uses])).unwrap(); + let [qb] = builder.input_wires_arr(); + let (lazy_measure_node, future_wire) = { + let handle = builder + .add_dataflow_op(QSystemOp::LazyMeasure, [qb]) + .unwrap(); + (handle.node(), handle.out_wire(0)) + }; + let out_wires = Self::build_futures_gadget(&mut builder, future_wire, bool_uses); + ( + lazy_measure_node, + builder.finish_hugr_with_outputs(out_wires).unwrap(), + ) }; - self.worklist.extend(item.work(hugr)?); - Ok(true) + let nu_inp = HashMap::from_iter([((lazy_measure_node, 0.into()), (node, 0.into()))]); + let nu_out = iter::zip(uses, (0..).map_into()).collect(); + + Ok(Self(SimpleReplacement::new( + subgraph, + replacement, + nu_inp, + nu_out, + ))) } -} -lazy_static! { - static ref MEASURE_READ_HUGR: Hugr = { - let mut builder = DFGBuilder::new(Signature::new(qb_t(), vec![qb_t(), bool_t()])).unwrap(); - let [qb] = builder.input_wires_arr(); - let [qb, lazy_r] = builder.add_lazy_measure_reset(qb).unwrap(); - let [r] = builder.add_read(lazy_r, bool_t()).unwrap(); - builder.finish_hugr_with_outputs([qb, r]).unwrap() - }; -} + /// Construct a new `LazifyMeasureRewrite` replacing `node` with a + /// [QSystemOp::LazyMeasureReset]. + /// + /// Fails if node does not have signature `[QB] -> [QB,BOOL]` + pub fn try_new_measure_reset( + node: Node, + hugr: impl HugrView, + ) -> Result { + Self::check_signature(node, QSystemOp::LazyMeasureReset, hugr.get_optype(node))?; -fn measure_replacement(num_dups: usize) -> Hugr { - let mut out_types = vec![qb_t()]; - out_types.extend((0..num_dups).map(|_| bool_t())); - let num_out_types = out_types.len(); - let mut builder = DFGBuilder::new(Signature::new(qb_t(), out_types)).unwrap(); - let [qb] = builder.input_wires_arr(); - let [qb, mut future_r] = builder.add_lazy_measure_reset(qb).unwrap(); - let mut future_rs = vec![]; - if num_dups > 0 { - for _ in 0..num_dups - 1 { - let [r1, r2] = builder.add_dup(future_r, bool_t()).unwrap(); - future_rs.push(r1); - future_r = r2; - } - future_rs.push(future_r) - } else { - builder.add_free(future_r, bool_t()).unwrap(); - } - let mut rs = vec![qb]; - rs.extend( - future_rs - .into_iter() - .map(|r| builder.add_read(r, bool_t()).unwrap()[0]), - ); - assert_eq!(num_out_types, rs.len()); - assert_eq!(num_out_types, num_dups + 1); - builder.finish_hugr_with_outputs(rs).unwrap() -} + let subgraph = SiblingSubgraph::from_node(node, &hugr); + let uses = hugr.linked_inputs(node, 1).collect_vec(); + let (lazy_measure_reset_node, replacement) = { + let bool_uses = uses.len(); + let mut builder = { + let outputs = iter::once(qb_t()) + .chain(itertools::repeat_n(bool_t(), bool_uses)) + .collect_vec(); + DFGBuilder::new(Signature::new(qb_t(), outputs)).unwrap() + }; + let [qb] = builder.input_wires_arr(); + let (lazy_measure_reset_node, [qb_wire, future_wire]) = { + let handle = builder + .add_dataflow_op(QSystemOp::LazyMeasureReset, [qb]) + .unwrap(); + (handle.node(), handle.outputs_arr()) + }; + let out_wires = Self::build_futures_gadget(&mut builder, future_wire, bool_uses); + ( + lazy_measure_reset_node, + builder + .finish_hugr_with_outputs(iter::once(qb_wire).chain(out_wires)) + .unwrap(), + ) + }; + let nu_inp = HashMap::from_iter([((lazy_measure_reset_node, 0.into()), (node, 0.into()))]); + let qb_use = hugr.single_linked_input(node, 0).unwrap(); // qubit is linear so this can't fail + let nu_out = iter::zip(iter::once(qb_use).chain(uses), (0..).map_into()).collect(); -fn simple_replace_measure( - hugr: &impl HugrView, - node: Node, -) -> (HashSet<(Node, IncomingPort)>, SimpleReplacement) { - assert!( - hugr.get_optype(node).cast() == Some(Tk2Op::Measure), - "{:?}", - hugr.get_optype(node) - ); - let g = SiblingSubgraph::try_from_nodes([node], hugr).unwrap(); - let num_uses_of_bool = hugr.linked_inputs(node, OutgoingPort::from(1)).count(); - let replacement_hugr = measure_replacement(num_uses_of_bool); - let [i, o] = replacement_hugr.get_io(replacement_hugr.root()).unwrap(); + Ok(Self(SimpleReplacement::new( + subgraph, + replacement, + nu_inp, + nu_out, + ))) + } - // A map from (target ports of edges from the Input node of `replacement`) to (target ports of - // edges from nodes not in `removal` to nodes in `removal`). - let nu_inp = replacement_hugr - .all_linked_inputs(i) - .map(|(n, p)| ((n, p), (node, p))) - .collect(); + fn build_futures_gadget(builder: &mut impl Dataflow, wire: Wire, num_uses: usize) -> Vec { + let future_wires = if num_uses == 0 { + builder.add_free(wire, bool_t()).unwrap(); + vec![] + } else { + let mut future_wires = vec![wire]; + for _ in 1..num_uses { + let prev_wire = future_wires.last_mut().unwrap(); + let [wire1, wire2] = builder.add_dup(*prev_wire, bool_t()).unwrap(); + *prev_wire = wire1; + future_wires.push(wire2); + } + future_wires + }; + debug_assert_eq!(future_wires.len(), num_uses); - // qubit is linear, there must be exactly one - let (target_node, target_port) = hugr - .single_linked_input(node, OutgoingPort::from(0)) - .unwrap(); - // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - // (input ports of the Output node of `replacement`). - let mut nu_out: HashMap<_, _> = [((target_node, target_port), IncomingPort::from(0))] - .into_iter() - .collect(); - nu_out.extend( - hugr.linked_inputs(node, OutgoingPort::from(1)) - .enumerate() - .map(|(i, target)| (target, IncomingPort::from(i + 1))), - ); - assert_eq!(nu_out.len(), 1 + num_uses_of_bool); - assert_eq!(nu_out.len(), replacement_hugr.in_value_types(o).count()); + future_wires + .into_iter() + .map(|w| builder.add_read(w, bool_t()).unwrap()[0]) + .collect_vec() + } - let nu_out_set = nu_out.keys().copied().collect(); - ( - nu_out_set, - SimpleReplacement::new(g, replacement_hugr, nu_inp, nu_out), - ) + // We check that the signature of `op_to_replace` is correct, given the + // `qsystem_op` we intend to replace it with. + // + // Note that calling this private function with a non-sensical `qsystem_op` + // (i.e. not LazyMeasure or LazyMeasureReset) will panic. + fn check_signature( + node: Node, + qsystem_op: QSystemOp, + op_to_replace: &hugr::ops::OpType, + ) -> Result<(), LazifyMeasurePassError> { + let actual_signature = op_to_replace.dataflow_signature().map(|x| x.into_owned()); + match qsystem_op { + QSystemOp::LazyMeasure => { + let expected_signature = Signature::new(qb_t(), bool_t()); + if !actual_signature + .as_ref() + .is_some_and(|x| x.io() == expected_signature.io()) + { + Err(LazifyMeasurePassError::InvalidOp { + node, + expected_signature, + actual_signature, + })? + } + } + QSystemOp::LazyMeasureReset => { + let expected_signature = Signature::new(qb_t(), vec![qb_t(), bool_t()]); + if !actual_signature + .as_ref() + .is_some_and(|x| x.io() == expected_signature.io()) + { + Err(LazifyMeasurePassError::InvalidOp { + node, + expected_signature, + actual_signature, + })? + } + } + op => panic!("bug: {op} is unsupported"), + } + Ok(()) + } } -impl WorkItem { - fn work( - self, - hugr: &mut impl HugrMut, - ) -> Result, LazifyMeasurePassError> { - match self { - Self::ReplaceMeasure(node) => { - // for now we read immediately, but when we don't the first - // results are the linked inputs we must return - let (_, replace) = simple_replace_measure(hugr, node); - replace.apply(hugr)?; - Ok(std::iter::empty()) - } +impl Rewrite for LazifyMeasureRewrite { + type ApplyResult = ::ApplyResult; + type Error = ::Error; + const UNCHANGED_ON_FAILURE: bool = ::UNCHANGED_ON_FAILURE; + + delegate! { + to self.0 { + fn apply(self, hugr: &mut impl HugrMut) -> Result; + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; + fn invalidation_set(&self) -> impl Iterator; } } } @@ -201,7 +284,16 @@ impl WorkItem { #[cfg(test)] mod test { - use crate::extension::{futures::FutureOpDef, qsystem::QSystemOp}; + use hugr::{ + builder::{DFGBuilder, Dataflow as _, DataflowHugr as _}, + extension::prelude::qb_t, + types::Signature, + }; + + use crate::extension::{ + futures::FutureOpDef, + qsystem::{QSystemOp, QSystemOpBuilder as _}, + }; use super::*; @@ -209,23 +301,28 @@ mod test { fn simple() { let mut hugr = { let mut builder = - DFGBuilder::new(Signature::new(qb_t(), vec![qb_t(), bool_t()])).unwrap(); - let [qb] = builder.input_wires_arr(); - let outs = builder - .add_dataflow_op(Tk2Op::Measure, [qb]) + DFGBuilder::new(Signature::new(vec![qb_t(), qb_t()], bool_t())).unwrap(); + let [qb1, qb2] = builder.input_wires_arr(); + let [r1] = builder + .add_dataflow_op(Tk2Op::MeasureFree, [qb1]) .unwrap() - .outputs(); - builder.finish_hugr_with_outputs(outs).unwrap() + .outputs_arr(); + let [qb2, _r2] = builder.add_measure_reset(qb2).unwrap(); + let _r3 = builder.add_measure(qb2).unwrap(); + builder.finish_hugr_with_outputs([r1]).unwrap() }; - assert!(hugr.validate_no_extensions().is_ok()); LazifyMeasurePass::default().run(&mut hugr).unwrap(); - assert!(hugr.validate_no_extensions().is_ok()); + hugr.validate().unwrap(); + let mut num_read = 0; + let mut num_lazy_measure = 0; let mut num_lazy_measure_reset = 0; for n in hugr.nodes() { let ot = hugr.get_optype(n); if let Some(FutureOpDef::Read) = ot.cast() { num_read += 1; + } else if let Some(QSystemOp::LazyMeasure) = ot.cast() { + num_lazy_measure += 1; } else if let Some(QSystemOp::LazyMeasureReset) = ot.cast() { num_lazy_measure_reset += 1; } else { @@ -234,23 +331,22 @@ mod test { } assert_eq!(1, num_read); + assert_eq!(2, num_lazy_measure); assert_eq!(1, num_lazy_measure_reset); } #[test] fn multiple_uses() { let mut builder = - DFGBuilder::new(Signature::new(qb_t(), vec![qb_t(), bool_t(), bool_t()])).unwrap(); + DFGBuilder::new(Signature::new(qb_t(), vec![bool_t(), bool_t()])).unwrap(); let [qb] = builder.input_wires_arr(); - let [qb, bool] = builder - .add_dataflow_op(Tk2Op::Measure, [qb]) + let [bool] = builder + .add_dataflow_op(Tk2Op::MeasureFree, [qb]) .unwrap() .outputs_arr(); - let mut hugr = builder.finish_hugr_with_outputs([qb, bool, bool]).unwrap(); - - assert!(hugr.validate_no_extensions().is_ok()); + let mut hugr = builder.finish_hugr_with_outputs([bool, bool]).unwrap(); LazifyMeasurePass::default().run(&mut hugr).unwrap(); - assert!(hugr.validate_no_extensions().is_ok()); + hugr.validate().unwrap(); } #[test] @@ -262,8 +358,6 @@ mod test { .unwrap() .outputs_arr(); let mut hugr = builder.finish_hugr_with_outputs([qb]).unwrap(); - - assert!(hugr.validate_no_extensions().is_ok()); LazifyMeasurePass::default().run(&mut hugr).unwrap(); assert!(hugr.validate_no_extensions().is_ok()); } diff --git a/tket2-hseries/src/lib.rs b/tket2-hseries/src/lib.rs index 44e1a853..ad0eca27 100644 --- a/tket2-hseries/src/lib.rs +++ b/tket2-hseries/src/lib.rs @@ -37,6 +37,8 @@ pub struct QSystemPass { validation_level: ValidationLevel, constant_fold: bool, monomorphize: bool, + force_order: bool, + lazify: bool, } impl Default for QSystemPass { @@ -45,6 +47,8 @@ impl Default for QSystemPass { validation_level: ValidationLevel::default(), constant_fold: false, monomorphize: true, + force_order: true, + lazify: true, } } } @@ -86,7 +90,16 @@ impl QSystemPass { self.constant_fold().run(hugr)?; } self.lower_tk2().run(hugr)?; - self.lazify_measure().run(hugr)?; + if self.lazify { + self.lazify_measure().run(hugr)?; + } + if self.force_order { + self.force_order(hugr)?; + } + Ok(()) + } + + fn force_order(&self, hugr: &mut Hugr) -> Result<(), QSystemPassError> { self.validation_level.run_validated_pass(hugr, |hugr, _| { force_order(hugr, hugr.root(), |hugr, node| { let optype = hugr.get_optype(node); @@ -101,8 +114,7 @@ impl QSystemPass { } })?; Ok::<_, QSystemPassError>(()) - })?; - Ok(()) + }) } fn lower_tk2(&self) -> LowerTket2ToQSystemPass { @@ -129,6 +141,8 @@ impl QSystemPass { /// Returns a new `QSystemPass` with constant folding enabled according to /// `constant_fold`. + /// + /// Off by default. pub fn with_constant_fold(mut self, constant_fold: bool) -> Self { self.constant_fold = constant_fold; self @@ -136,10 +150,35 @@ impl QSystemPass { /// Returns a new `QSystemPass` with monomorphization enabled according to /// `monomorphize`. + /// + /// On by default. pub fn with_monormophize(mut self, monomorphize: bool) -> Self { self.monomorphize = monomorphize; self } + + /// Returns a new `QSystemPass` with forcing the HUGR to have + /// totally-ordered ops enabled according to `force_order`. + /// + /// On by default. + /// + /// When enabled, we push quantum ops as early as possible, and we push + /// `tket2.futures.read` ops as late as possible. + pub fn with_force_order(mut self, force_order: bool) -> Self { + self.force_order = force_order; + self + } + + /// Returns a new `QSystemPass` with lazification enabled according to `lazify`. + /// + /// On by default. + /// + /// When enabled we replace strict measurement ops with lazy equivalents + /// from `tket2.qsystem`. + pub fn with_lazify(mut self, lazify: bool) -> Self { + self.lazify = lazify; + self + } } #[cfg(test)]