Skip to content

Commit

Permalink
Bugfixes in model import and export.
Browse files Browse the repository at this point in the history
  • Loading branch information
zrho committed Jan 10, 2025
1 parent 774a13e commit 7c4edc6
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 25 deletions.
29 changes: 26 additions & 3 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
TypeBase, TypeBound, TypeEnum,
TypeBase, TypeBound, TypeEnum, TypeRow,
},
Direction, Hugr, HugrView, IncomingPort, Node, Port,
};
Expand Down Expand Up @@ -728,8 +728,31 @@ impl<'a> Context<'a> {
}

// Get the signature of the control flow region.
// This is the same as the signature of the parent node.
let signature = Some(self.export_func_type(&self.hugr.signature(node).unwrap()));
let signature = {
let node_signature = self.hugr.signature(node).unwrap();

let mut wrap_ctrl = |types: &TypeRow| {
let types = self.export_type_row(types);
let types_ctrl = self.make_term(model::Term::Control { values: types });
self.make_term(model::Term::List {
parts: self
.bump
.alloc_slice_copy(&[model::ListPart::Item(types_ctrl)]),
})
};

let inputs = wrap_ctrl(node_signature.input());
let outputs = wrap_ctrl(node_signature.output());
let extensions = self.export_ext_set(&node_signature.runtime_reqs);

let func_type = self.make_term(model::Term::FuncType {
inputs,
outputs,
extensions,
});

Some(func_type)
};

let scope = match closure {
model::ScopeClosure::Closed => {
Expand Down
47 changes: 26 additions & 21 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,24 +216,26 @@ impl<'a> Context<'a> {
}
}

if inputs.is_empty() || outputs.is_empty() {
return Err(error_unsupported!(
"link {}#{} is missing either an input or an output port",
link_id.0,
link_id.1
));
}

// We connect the first output to all the inputs, and the first input to all the outputs
// (except the first one, which we already connected to the first input). This should
// result in the hugr having a (hyper)edge that connects all the ports.
// There should be a better way to do this.
for (node, port) in inputs.iter() {
self.hugr.connect(outputs[0].0, outputs[0].1, *node, *port);
}

for (node, port) in outputs.iter().skip(1) {
self.hugr.connect(*node, *port, inputs[0].0, inputs[0].1);
match (inputs.as_slice(), outputs.as_slice()) {
([], []) => {
unreachable!();
}
(_, [output]) => {
for (node, port) in inputs.iter() {
self.hugr.connect(output.0, output.1, *node, *port);
}
}
([input], _) => {
for (node, port) in outputs.iter() {
self.hugr.connect(*node, *port, input.0, input.1);
}
}
_ => {
return Err(error_unsupported!(
"link {:?} would require hyperedge",
link_id
));
}
}

inputs.clear();
Expand Down Expand Up @@ -995,7 +997,6 @@ impl<'a> Context<'a> {
model::Term::ListType { .. } => Err(error_unsupported!("`(list ...)` as `TypeArg`")),
model::Term::ExtSetType => Err(error_unsupported!("`ext-set` as `TypeArg`")),
model::Term::Type => Err(error_unsupported!("`type` as `TypeArg`")),
model::Term::ApplyFull { .. } => Err(error_unsupported!("custom types as `TypeArg`")),
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeArg`")),
model::Term::StaticType => Err(error_unsupported!("`static` as `TypeArg`")),
model::Term::ControlType => Err(error_unsupported!("`ctrl` as `TypeArg`")),
Expand All @@ -1007,8 +1008,12 @@ impl<'a> Context<'a> {

model::Term::FuncType { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. }
| model::Term::NonLinearConstraint { .. } => {
| model::Term::ApplyFull { .. } => {
let ty = self.import_type(term_id)?;
Ok(TypeArg::Type { ty })
}

model::Term::Control { .. } | model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
Expand Down
2 changes: 1 addition & 1 deletion hugr-core/tests/snapshots/model__roundtrip_cfg.snap
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-cfg.
(signature (-> [?0] [?0] (ext)))
(cfg
[%2] [%3]
(signature (-> [?0] [?0] (ext)))
(signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext)))
(block [%2] [%6]
(signature (-> [(ctrl [?0])] [(ctrl [?0])] (ext)))
(dfg
Expand Down

0 comments on commit 7c4edc6

Please sign in to comment.