Skip to content

Commit

Permalink
fix(optimizer): faulty output regeneration
Browse files Browse the repository at this point in the history
  • Loading branch information
aPere3 committed May 2, 2024
1 parent 3417a34 commit 097bc39
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 12 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::dag::operator::operator::Operator;
use crate::dag::operator::OperatorIndex;
use crate::dag::unparametrized::Dag;
use crate::dag::unparametrized::{Dag, DagBuilder};

fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {
let mut op = op.clone();
Expand All @@ -20,14 +20,14 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator {

pub(crate) fn regen(
dag: &Dag,
f: &mut dyn FnMut(usize, &Operator, &mut Dag) -> Option<OperatorIndex>,
f: &mut dyn FnMut(usize, &Operator, &mut DagBuilder<'_>) -> Option<OperatorIndex>,
) -> (Dag, Vec<Vec<OperatorIndex>>) {
let mut regen_dag = Dag::new();
let mut old_index_to_new = vec![];
for (i, op) in dag.operators.iter().enumerate() {
let op = reindex_op_inputs(op, &old_index_to_new);
let size = regen_dag.operators.len();
if let Some(op_i) = f(i, &op, &mut regen_dag) {
if let Some(op_i) = f(i, &op, &mut regen_dag.builder(dag.circuit_tags[i].clone())) {
old_index_to_new.push(op_i.0);
} else {
assert!(size == regen_dag.operators.len());
Expand All @@ -37,6 +37,8 @@ pub(crate) fn regen(
regen_dag.out_shapes.push(dag.out_shapes[i].clone());
regen_dag.output_tags.push(dag.output_tags[i]);
regen_dag.circuit_tags.push(dag.circuit_tags[i].clone());
op.get_inputs_iter()
.for_each(|n| regen_dag.output_tags[n.0].use_as_input());
}
}
(regen_dag, instructions_multi_map(&old_index_to_new))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::dag::operator::{Operator, OperatorIndex};
use crate::dag::unparametrized::Dag;
use crate::dag::unparametrized::{Dag, DagBuilder};

use super::regen::regen;

fn regen_round(_: usize, op: &Operator, dag: &mut Dag) -> Option<OperatorIndex> {
fn regen_round(_: usize, op: &Operator, dag: &mut DagBuilder<'_>) -> Option<OperatorIndex> {
match *op {
Operator::Round {
input,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,24 @@ use std::{collections::HashSet, fmt};
/// The name of the default. Used when adding operations directly on the dag instead of via a builder.
const DEFAULT_CIRCUIT: &str = "_";

#[derive(Debug, Clone, PartialEq, Copy)]
pub(crate) enum OutputKind {
// The operator was actively marked as being an output. Using it as input should not turn it into an input.
TaggedAsOutput,
// The operator was initialized as being an output. Using it as input should turn it into an input.
InitializedAsOutput,
// The operator was used as input to another operator.
UsedAsInput,
}

impl OutputKind {
pub(crate) fn use_as_input(&mut self) {
if let OutputKind::InitializedAsOutput = self {
*self = OutputKind::UsedAsInput;
}
}
}

/// A type referencing every informations related to an operator of the dag.
#[derive(Debug, Clone)]
#[allow(unused)]
Expand All @@ -16,7 +34,7 @@ pub(crate) struct DagOperator<'dag> {
pub(crate) operator: &'dag Operator,
pub(crate) shape: &'dag Shape,
pub(crate) precision: &'dag Precision,
pub(crate) output_tag: &'dag bool,
pub(crate) output_tag: &'dag OutputKind,
pub(crate) circuit_tag: &'dag String,
}

Expand All @@ -28,7 +46,10 @@ impl<'dag> DagOperator<'dag> {

/// Returns if the operator is an output.
pub(crate) fn is_output(&self) -> bool {
*self.output_tag
match self.output_tag {
OutputKind::TaggedAsOutput | OutputKind::InitializedAsOutput => true,
OutputKind::UsedAsInput => false,
}
}

/// Returns an iterator over the operators used as input to this operator.
Expand Down Expand Up @@ -96,9 +117,9 @@ impl<'dag> DagBuilder<'dag> {
self.dag.out_shapes.push(self.infer_out_shape(&operator));
operator
.get_inputs_iter()
.for_each(|id| self.dag.output_tags[id.0] = false);
.for_each(|id| self.dag.output_tags[id.0].use_as_input());
self.dag.operators.push(operator);
self.dag.output_tags.push(true);
self.dag.output_tags.push(OutputKind::InitializedAsOutput);
self.dag.circuit_tags.push(self.circuit.clone());
OperatorIndex(i)
}
Expand Down Expand Up @@ -296,7 +317,7 @@ impl<'dag> DagBuilder<'dag> {
pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) {
assert!(operator.0 < self.dag.len());
debug_assert!(self.dag.circuit_tags[operator.0] == self.circuit);
self.dag.output_tags[operator.0] = true;
self.dag.output_tags[operator.0] = OutputKind::TaggedAsOutput;
}

pub fn get_circuit(&self) -> DagCircuit<'_> {
Expand Down Expand Up @@ -373,7 +394,7 @@ pub struct Dag {
// Collect all operators output precision
pub(crate) out_precisions: Vec<Precision>,
// Collect whether operators are tagged as outputs
pub(crate) output_tags: Vec<bool>,
pub(crate) output_tags: Vec<OutputKind>,
// Collect the circuit the operators are associated with
pub(crate) circuit_tags: Vec<String>,
}
Expand Down Expand Up @@ -633,7 +654,7 @@ impl Dag {
/// tagged using this method.
pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) {
assert!(operator.0 < self.len());
self.output_tags[operator.0] = true;
self.output_tags[operator.0] = OutputKind::TaggedAsOutput;
}

/// Returns the number of circuits in the dag.
Expand All @@ -647,6 +668,18 @@ mod tests {
use super::*;
use crate::dag::operator::Shape;

#[test]
fn output_marking() {
let mut graph = Dag::new();
let mut builder = graph.builder("main1");
let a = builder.add_input(1, Shape::number());
let b = builder.add_input(1, Shape::number());
builder.tag_operator_as_output(b);
let _ = builder.add_dot([a, b], [1, 1]);
assert!(graph.get_operator(b).is_output());
assert!(!graph.get_operator(a).is_output());
}

#[test]
#[allow(clippy::many_single_char_names)]
fn graph_builder() {
Expand Down

0 comments on commit 097bc39

Please sign in to comment.