From 7c4edc6f2e291dc1912dbb8086198031e01b0c83 Mon Sep 17 00:00:00 2001 From: Lukas Heidemann Date: Wed, 8 Jan 2025 14:53:58 +0000 Subject: [PATCH] Bugfixes in model import and export. --- hugr-core/src/export.rs | 29 ++++++++++-- hugr-core/src/import.rs | 47 ++++++++++--------- .../tests/snapshots/model__roundtrip_cfg.snap | 2 +- 3 files changed, 53 insertions(+), 25 deletions(-) diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 289559539..8c864fa6c 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeBound, TypeEnum, + TypeBase, TypeBound, TypeEnum, TypeRow, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -728,8 +728,31 @@ impl<'a> Context<'a> { } // Get the signature of the control flow region. - // This is the same as the signature of the parent node. - let signature = Some(self.export_func_type(&self.hugr.signature(node).unwrap())); + let signature = { + let node_signature = self.hugr.signature(node).unwrap(); + + let mut wrap_ctrl = |types: &TypeRow| { + let types = self.export_type_row(types); + let types_ctrl = self.make_term(model::Term::Control { values: types }); + self.make_term(model::Term::List { + parts: self + .bump + .alloc_slice_copy(&[model::ListPart::Item(types_ctrl)]), + }) + }; + + let inputs = wrap_ctrl(node_signature.input()); + let outputs = wrap_ctrl(node_signature.output()); + let extensions = self.export_ext_set(&node_signature.runtime_reqs); + + let func_type = self.make_term(model::Term::FuncType { + inputs, + outputs, + extensions, + }); + + Some(func_type) + }; let scope = match closure { model::ScopeClosure::Closed => { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 6a08b4a78..d0c0601de 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -216,24 +216,26 @@ impl<'a> Context<'a> { } } - if inputs.is_empty() || outputs.is_empty() { - return Err(error_unsupported!( - "link {}#{} is missing either an input or an output port", - link_id.0, - link_id.1 - )); - } - - // We connect the first output to all the inputs, and the first input to all the outputs - // (except the first one, which we already connected to the first input). This should - // result in the hugr having a (hyper)edge that connects all the ports. - // There should be a better way to do this. - for (node, port) in inputs.iter() { - self.hugr.connect(outputs[0].0, outputs[0].1, *node, *port); - } - - for (node, port) in outputs.iter().skip(1) { - self.hugr.connect(*node, *port, inputs[0].0, inputs[0].1); + match (inputs.as_slice(), outputs.as_slice()) { + ([], []) => { + unreachable!(); + } + (_, [output]) => { + for (node, port) in inputs.iter() { + self.hugr.connect(output.0, output.1, *node, *port); + } + } + ([input], _) => { + for (node, port) in outputs.iter() { + self.hugr.connect(*node, *port, input.0, input.1); + } + } + _ => { + return Err(error_unsupported!( + "link {:?} would require hyperedge", + link_id + )); + } } inputs.clear(); @@ -995,7 +997,6 @@ impl<'a> Context<'a> { model::Term::ListType { .. } => Err(error_unsupported!("`(list ...)` as `TypeArg`")), model::Term::ExtSetType => Err(error_unsupported!("`ext-set` as `TypeArg`")), model::Term::Type => Err(error_unsupported!("`type` as `TypeArg`")), - model::Term::ApplyFull { .. } => Err(error_unsupported!("custom types as `TypeArg`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")), model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")), model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")), @@ -1007,8 +1008,12 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } - | model::Term::NonLinearConstraint { .. } => { + | model::Term::ApplyFull { .. } => { + let ty = self.import_type(term_id)?; + Ok(TypeArg::Type { ty }) + } + + model::Term::Control { .. } | model::Term::NonLinearConstraint { .. } => { Err(model::ModelError::TypeError(term_id).into()) } } diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index d3ed92bc7..02f827a80 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -14,7 +14,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (signature (-> [?0] [?0] (ext))) (cfg [%2] [%3] - (signature (-> [?0] [?0] (ext))) + (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) (block [%2] [%6] (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg