diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 289559539..8c864fa6c 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -7,7 +7,7 @@ use crate::{ type_param::{TypeArgVariable, TypeParam}, type_row::TypeRowBase, CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg, - TypeBase, TypeBound, TypeEnum, + TypeBase, TypeBound, TypeEnum, TypeRow, }, Direction, Hugr, HugrView, IncomingPort, Node, Port, }; @@ -728,8 +728,31 @@ impl<'a> Context<'a> { } // Get the signature of the control flow region. - // This is the same as the signature of the parent node. - let signature = Some(self.export_func_type(&self.hugr.signature(node).unwrap())); + let signature = { + let node_signature = self.hugr.signature(node).unwrap(); + + let mut wrap_ctrl = |types: &TypeRow| { + let types = self.export_type_row(types); + let types_ctrl = self.make_term(model::Term::Control { values: types }); + self.make_term(model::Term::List { + parts: self + .bump + .alloc_slice_copy(&[model::ListPart::Item(types_ctrl)]), + }) + }; + + let inputs = wrap_ctrl(node_signature.input()); + let outputs = wrap_ctrl(node_signature.output()); + let extensions = self.export_ext_set(&node_signature.runtime_reqs); + + let func_type = self.make_term(model::Term::FuncType { + inputs, + outputs, + extensions, + }); + + Some(func_type) + }; let scope = match closure { model::ScopeClosure::Closed => { diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 828047ea2..650d585a2 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -216,24 +216,26 @@ impl<'a> Context<'a> { } } - if inputs.is_empty() || outputs.is_empty() { - return Err(error_unsupported!( - "link {}#{} is missing either an input or an output port", - link_id.0, - link_id.1 - )); - } - - // We connect the first output to all the inputs, and the first input to all the outputs - // (except the first one, which we already connected to the first input). This should - // result in the hugr having a (hyper)edge that connects all the ports. - // There should be a better way to do this. - for (node, port) in inputs.iter() { - self.hugr.connect(outputs[0].0, outputs[0].1, *node, *port); - } - - for (node, port) in outputs.iter().skip(1) { - self.hugr.connect(*node, *port, inputs[0].0, inputs[0].1); + match (inputs.as_slice(), outputs.as_slice()) { + ([], []) => { + unreachable!(); + } + (_, [output]) => { + for (node, port) in inputs.iter() { + self.hugr.connect(output.0, output.1, *node, *port); + } + } + ([input], _) => { + for (node, port) in outputs.iter() { + self.hugr.connect(*node, *port, input.0, input.1); + } + } + _ => { + return Err(error_unsupported!( + "link {:?} would require hyperedge", + link_id + )); + } } inputs.clear(); @@ -996,7 +998,6 @@ impl<'a> Context<'a> { model::Term::ListType { .. } => Err(error_unsupported!("`(list ...)` as `TypeArg`")), model::Term::ExtSetType => Err(error_unsupported!("`ext-set` as `TypeArg`")), model::Term::Type => Err(error_unsupported!("`type` as `TypeArg`")), - model::Term::ApplyFull { .. } => Err(error_unsupported!("custom types as `TypeArg`")), model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")), model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")), model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")), @@ -1010,8 +1011,12 @@ impl<'a> Context<'a> { model::Term::FuncType { .. } | model::Term::Adt { .. } - | model::Term::Control { .. } - | model::Term::NonLinearConstraint { .. } => { + | model::Term::ApplyFull { .. } => { + let ty = self.import_type(term_id)?; + Ok(TypeArg::Type { ty }) + } + + model::Term::Control { .. } | model::Term::NonLinearConstraint { .. } => { Err(model::ModelError::TypeError(term_id).into()) } } diff --git a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap index d3ed92bc7..cb6666ba9 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_cfg.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_cfg.snap @@ -14,16 +14,19 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg. (signature (-> [?0] [?0] (ext))) (cfg [%2] [%3] - (signature (-> [?0] [?0] (ext))) + (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) (block [%2] [%6] (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) (dfg [%4] [%5] (signature (-> [?0] [(adt [[?0]])] (ext))) (tag 0 [%4] [%5] (signature (-> [?0] [(adt [[?0]])] (ext)))))) - (block [%6] [%3] - (signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext))) + (block [%6] [%3 %9] + (signature (-> [(ctrl [?0])] [(ctrl [?0]) (ctrl [?0])] (ext))) (dfg [%7] [%8] - (signature (-> [?0] [(adt [[?0]])] (ext))) - (tag 0 [%7] [%8] (signature (-> [?0] [(adt [[?0]])] (ext)))))))))) + (signature (-> [?0] [(adt [[?0] [?0]])] (ext))) + (tag + 0 + [%7] [%8] + (signature (-> [?0] [(adt [[?0] [?0]])] (ext)))))))))) diff --git a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap index d7cb2bf01..a6c6a34f4 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_constraints.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_constraints.snap @@ -16,3 +16,9 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cons (forall ?0 type) (where (nonlinear ?0)) [(@ prelude.Array ?0)] [(@ prelude.Array ?0) (@ prelude.Array ?0)] (ext)) + +(define-func util.copy + (forall ?0 type) + (where (nonlinear ?0)) + [?0] [?0 ?0] (ext) + (dfg [%0] [%0 %0] (signature (-> [?0] [?0 ?0] (ext))))) diff --git a/hugr-model/tests/fixtures/model-cfg.edn b/hugr-model/tests/fixtures/model-cfg.edn index 3987450f9..8c25ad91a 100644 --- a/hugr-model/tests/fixtures/model-cfg.edn +++ b/hugr-model/tests/fixtures/model-cfg.edn @@ -9,9 +9,9 @@ (signature (-> [?a] [?a] (ext))) (cfg [%2] [%4] (signature (-> [(ctrl [?a])] [(ctrl [?a])] (ext))) - (block [%2] [%4] - (signature (-> [(ctrl [?a])] [(ctrl [?a])] (ext))) + (block [%2] [%4 %2] + (signature (-> [(ctrl [?a])] [(ctrl [?a]) (ctrl [?a])] (ext))) (dfg [%5] [%6] - (signature (-> [?a] [(adt [[?a]])] (ext))) + (signature (-> [?a] [(adt [[?a] [?a]])] (ext))) (tag 0 [%5] [%6] - (signature (-> [?a] [(adt [[?a]])] (ext)))))))))) + (signature (-> [?a] [(adt [[?a] [?a]])] (ext)))))))))) diff --git a/hugr-model/tests/fixtures/model-constraints.edn b/hugr-model/tests/fixtures/model-constraints.edn index ddfbd659f..8987e4c69 100644 --- a/hugr-model/tests/fixtures/model-constraints.edn +++ b/hugr-model/tests/fixtures/model-constraints.edn @@ -11,3 +11,10 @@ (forall ?t type) (where (nonlinear ?t)) [(@ prelude.Array ?t)] [(@ prelude.Array ?t) (@ prelude.Array ?t)] (ext)) + +(define-func util.copy + (forall ?t type) + (where (nonlinear ?t)) + [?t] [?t ?t] (ext) + (dfg [%0] [%0 %0] + (signature (-> [?t] [?t ?t] (ext)))))