diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs index 94c9ce3e58..7bca1d69de 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/regen.rs @@ -1,6 +1,6 @@ use crate::dag::operator::operator::Operator; use crate::dag::operator::OperatorIndex; -use crate::dag::unparametrized::Dag; +use crate::dag::unparametrized::{Dag, DagBuilder}; fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { let mut op = op.clone(); @@ -20,14 +20,14 @@ fn reindex_op_inputs(op: &Operator, old_index_to_new: &[usize]) -> Operator { pub(crate) fn regen( dag: &Dag, - f: &mut dyn FnMut(usize, &Operator, &mut Dag) -> Option, + f: &mut dyn FnMut(usize, &Operator, &mut DagBuilder<'_>) -> Option, ) -> (Dag, Vec>) { let mut regen_dag = Dag::new(); let mut old_index_to_new = vec![]; for (i, op) in dag.operators.iter().enumerate() { let op = reindex_op_inputs(op, &old_index_to_new); let size = regen_dag.operators.len(); - if let Some(op_i) = f(i, &op, &mut regen_dag) { + if let Some(op_i) = f(i, &op, &mut regen_dag.builder(dag.circuit_tags[i].clone())) { old_index_to_new.push(op_i.0); } else { assert!(size == regen_dag.operators.len()); @@ -37,6 +37,8 @@ pub(crate) fn regen( regen_dag.out_shapes.push(dag.out_shapes[i].clone()); regen_dag.output_tags.push(dag.output_tags[i]); regen_dag.circuit_tags.push(dag.circuit_tags[i].clone()); + op.get_inputs_iter() + .for_each(|n| regen_dag.output_tags[n.0].use_as_input()); } } (regen_dag, instructions_multi_map(&old_index_to_new)) diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs index c243a33fa1..92b40d85f2 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/rewrite/round.rs @@ -1,9 +1,9 @@ use crate::dag::operator::{Operator, OperatorIndex}; -use crate::dag::unparametrized::Dag; +use crate::dag::unparametrized::{Dag, DagBuilder}; use super::regen::regen; -fn regen_round(_: usize, op: &Operator, dag: &mut Dag) -> Option { +fn regen_round(_: usize, op: &Operator, dag: &mut DagBuilder<'_>) -> Option { match *op { Operator::Round { input, diff --git a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs index 6c2f3c7e69..54f626ef63 100644 --- a/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs +++ b/compilers/concrete-optimizer/concrete-optimizer/src/dag/unparametrized.rs @@ -7,6 +7,24 @@ use std::{collections::HashSet, fmt}; /// The name of the default. Used when adding operations directly on the dag instead of via a builder. const DEFAULT_CIRCUIT: &str = "_"; +#[derive(Debug, Clone, PartialEq, Copy)] +pub(crate) enum OutputKind { + // The operator was actively marked as being an output. Using it as input should not turn it into an input. + TaggedAsOutput, + // The operator was initialized as being an output. Using it as input should turn it into an input. + InitializedAsOutput, + // The operator was used as input to another operator. + UsedAsInput, +} + +impl OutputKind { + pub(crate) fn use_as_input(&mut self) { + if let OutputKind::InitializedAsOutput = self { + *self = OutputKind::UsedAsInput; + } + } +} + /// A type referencing every informations related to an operator of the dag. #[derive(Debug, Clone)] #[allow(unused)] @@ -16,7 +34,7 @@ pub(crate) struct DagOperator<'dag> { pub(crate) operator: &'dag Operator, pub(crate) shape: &'dag Shape, pub(crate) precision: &'dag Precision, - pub(crate) output_tag: &'dag bool, + pub(crate) output_tag: &'dag OutputKind, pub(crate) circuit_tag: &'dag String, } @@ -28,7 +46,10 @@ impl<'dag> DagOperator<'dag> { /// Returns if the operator is an output. pub(crate) fn is_output(&self) -> bool { - *self.output_tag + match self.output_tag { + OutputKind::TaggedAsOutput | OutputKind::InitializedAsOutput => true, + OutputKind::UsedAsInput => false, + } } /// Returns an iterator over the operators used as input to this operator. @@ -96,9 +117,9 @@ impl<'dag> DagBuilder<'dag> { self.dag.out_shapes.push(self.infer_out_shape(&operator)); operator .get_inputs_iter() - .for_each(|id| self.dag.output_tags[id.0] = false); + .for_each(|id| self.dag.output_tags[id.0].use_as_input()); self.dag.operators.push(operator); - self.dag.output_tags.push(true); + self.dag.output_tags.push(OutputKind::InitializedAsOutput); self.dag.circuit_tags.push(self.circuit.clone()); OperatorIndex(i) } @@ -296,7 +317,7 @@ impl<'dag> DagBuilder<'dag> { pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) { assert!(operator.0 < self.dag.len()); debug_assert!(self.dag.circuit_tags[operator.0] == self.circuit); - self.dag.output_tags[operator.0] = true; + self.dag.output_tags[operator.0] = OutputKind::TaggedAsOutput; } pub fn get_circuit(&self) -> DagCircuit<'_> { @@ -373,7 +394,7 @@ pub struct Dag { // Collect all operators output precision pub(crate) out_precisions: Vec, // Collect whether operators are tagged as outputs - pub(crate) output_tags: Vec, + pub(crate) output_tags: Vec, // Collect the circuit the operators are associated with pub(crate) circuit_tags: Vec, } @@ -633,7 +654,7 @@ impl Dag { /// tagged using this method. pub fn tag_operator_as_output(&mut self, operator: OperatorIndex) { assert!(operator.0 < self.len()); - self.output_tags[operator.0] = true; + self.output_tags[operator.0] = OutputKind::TaggedAsOutput; } /// Returns the number of circuits in the dag. @@ -647,6 +668,18 @@ mod tests { use super::*; use crate::dag::operator::Shape; + #[test] + fn output_marking() { + let mut graph = Dag::new(); + let mut builder = graph.builder("main1"); + let a = builder.add_input(1, Shape::number()); + let b = builder.add_input(1, Shape::number()); + builder.tag_operator_as_output(b); + let _ = builder.add_dot([a, b], [1, 1]); + assert!(graph.get_operator(b).is_output()); + assert!(!graph.get_operator(a).is_output()); + } + #[test] #[allow(clippy::many_single_char_names)] fn graph_builder() {