Skip to content

Commit

Permalink
polish
Browse files Browse the repository at this point in the history
  • Loading branch information
ShuangWu121 committed Feb 7, 2025
1 parent 85e6df5 commit 41df6e0
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 64 deletions.
21 changes: 11 additions & 10 deletions backend/src/stwo/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub const STAGE0_TRACE_IDX: usize = 1;
pub const STAGE1_TRACE_IDX: usize = 2;

pub type PowdrComponent = FrameworkComponent<PowdrEval>;
pub type PublicEntry = (String, String, PolyID, usize, M31);

pub fn gen_stwo_circle_column<B, F>(
domain: CircleDomain,
Expand Down Expand Up @@ -60,15 +61,14 @@ pub struct PowdrEval {
preprocess_col_offset: usize,
// for each stage, for each public input of that stage, the name of the public,
// the name of the witness column that this public is related to, the poly_id, the row index and its value
pub(crate) publics_by_stage: Vec<Vec<(String, String, PolyID, usize)>>,
pub(crate) publics_by_stage: Vec<Vec<PublicEntry>>,
stage0_witness_columns: BTreeMap<PolyID, usize>,
stage1_witness_columns: BTreeMap<PolyID, usize>,
constant_shifted: BTreeMap<PolyID, usize>,
constant_columns: BTreeMap<PolyID, usize>,
// stwo supports maximum 2 stages, challenges are only created after stage 0
pub challenges: BTreeMap<u64, M31>,
poly_stage_map: BTreeMap<PolyID, usize>,
public_values: BTreeMap<String, M31>,
}

impl PowdrEval {
Expand Down Expand Up @@ -116,7 +116,13 @@ impl PowdrEval {
let publics_by_stage = analyzed.get_publics().into_iter().fold(
vec![vec![]; analyzed.stage_count()],
|mut acc, (name, column_name, id, row, stage)| {
acc[stage as usize].push((name, column_name, id, row));
acc[stage as usize].push((
name.clone(),
column_name,
id,
row,
*public_values.get(&name).unwrap(),
));
acc
},
);
Expand All @@ -138,7 +144,6 @@ impl PowdrEval {
constant_columns,
challenges,
poly_stage_map,
public_values,
}
}
}
Expand Down Expand Up @@ -257,7 +262,7 @@ impl FrameworkEval for PowdrEval {

// build selector columns and constraints for publics, for now I am using constant columns as selectors
self.publics_by_stage.iter().flatten().enumerate().for_each(
|(index, (name, _, poly_id, _))| {
|(index, (_, _, poly_id, _, value))| {
let selector = eval.get_preprocessed_column(PreprocessedColumn::Plonk(
index
+ constant_eval.len()
Expand All @@ -273,11 +278,7 @@ impl FrameworkEval for PowdrEval {
};

// constraining s(i) * (pub[i] - x(i)) = 0
eval.add_constraint(
selector
* (E::F::from(into_stwo_field(self.public_values.get(name).unwrap()))
- witness_col),
);
eval.add_constraint(selector * (E::F::from(into_stwo_field(value)) - witness_col));
},
);

Expand Down
94 changes: 40 additions & 54 deletions backend/src/stwo/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub struct StwoProver<B: BackendForChannel<MC> + Send, MC: MerkleChannel, C: Cha

/// Proving key
proving_key: StarkProvingKey<B>,
/// Verifying key placeholder
/// TODO: Add verification key.
_verifying_key: Option<()>,
_channel_marker: PhantomData<C>,
_merkle_channel_marker: PhantomData<MC>,
Expand Down Expand Up @@ -205,6 +205,7 @@ where
.get_publics()
.into_iter()
.map(|(_, _, _, row_id, _)| {
// Create a column with a single 1 at the row_id-th (in circle domain bitreverse order) position
let mut col = Col::<B, BaseField>::zeros(1 << log_size);
col.set(
bit_reverse_index(
Expand Down Expand Up @@ -274,14 +275,6 @@ where
)
})
.collect();
// TODO:in polonky3, this is built in ConstraintSystem, which is machine-specific, check if it's the same here
let publics_by_stage = self.analyzed.get_publics().into_iter().fold(
vec![vec![]; self.analyzed.stage_count()],
|mut acc, (name, column_name, id, row, stage)| {
acc[stage as usize].push((name, column_name, id, row));
acc
},
);

// Generate witness for stage 0, build constant columns in circle domain at the same time
let mut machine_log_sizes: BTreeMap<String, u32> = BTreeMap::new();
Expand Down Expand Up @@ -327,17 +320,25 @@ where

// get publics of stage0
// TODO:when publics are supplied, the order of the witness might cause problem
let mut public_values: BTreeMap<String, M31> = witness_by_machine
// TODO:in polonky3, this is built in ConstraintSystem, which is machine-specific, check if it's the same here
let publics_by_stage = self.analyzed.get_publics().into_iter().fold(
vec![vec![]; self.analyzed.stage_count()],
|mut acc, (name, column_name, id, row, stage)| {
acc[stage as usize].push((name, column_name, id, row));
acc
},
);

let mut public_values: BTreeMap<String, M31> = publics_by_stage[0]
.iter()
.flat_map(|(_, witness_cols)| {
publics_by_stage[0]
.flat_map(|(name, ref_witness_col_name, _, row)| {
let namespace = ref_witness_col_name.split("::").next().unwrap();
witness_by_machine
.get(namespace)
.unwrap()
.iter()
.filter_map(|(name, ref_witness_col_name, _, row)| {
witness_cols.iter().find_map(|(witness_col_name, v)| {
(ref_witness_col_name == witness_col_name)
.then(|| (name.clone(), v[*row]))
})
})
.filter(move |(witness_col_name, _)| ref_witness_col_name == witness_col_name)
.map(|(_, col)| (name.clone(), col[*row]))
})
.collect();

Expand Down Expand Up @@ -389,7 +390,7 @@ where
if self.analyzed.stage_count() > 1 {
// Build witness columns for stage 1 using the callback function, with the generated challenges

let stage0_witness_names_stage1_witness = witness_by_machine
let stage0_witness_names_stage1_witness_cols = witness_by_machine
.iter()
.map(|(machine_name, machine_witness)| {
(
Expand All @@ -407,43 +408,28 @@ where
})
.collect_vec();

// TODO: previous publics are built with the order in publics_by_stage (find map from witness machine, that matches publics by stage))
// here is with the order in witness_by_machine, (find map from publics by stage that matach the witness. )
// if the orders are different, the publics will be wrong, check
let public_values_stage1: BTreeMap<String, M31> = stage0_witness_names_stage1_witness
.iter()
.flat_map(|(stage0_columns, callback_result)| {
callback_result.iter().filter_map(|(witness_name, vec)| {
if stage0_columns.contains(witness_name) {
None
} else {
publics_by_stage[1].iter().find_map(
|(name, ref_witness_col_name, _, row)| {
(witness_name == ref_witness_col_name)
.then(|| (name.clone(), vec[*row]))
},
)
}
// Get publics of stage 1
let public_values_stage1: BTreeMap<String, M31> =
stage0_witness_names_stage1_witness_cols
.iter()
.flat_map(|(stage0_columns, callback_result)| {
callback_result.iter().filter_map(|(witness_name, vec)| {
if stage0_columns.contains(witness_name) {
None
} else {
publics_by_stage[1].iter().find_map(
|(name, ref_witness_col_name, _, row)| {
(witness_name == ref_witness_col_name)
.then(|| (name.clone(), vec[*row]))
},
)
}
})
})
})
.collect();
.collect();

let stage1_witness_cols_circle_domain_eval = witness_by_machine
.into_iter()
.map(|(machine_name, machine_witness)| {
(
machine_witness
.iter()
.map(|(k, _)| k.clone())
.collect::<BTreeSet<_>>(),
witgen_callback.next_stage_witness(
&self.split[&machine_name],
&machine_witness,
stage0_challenges.clone(),
1,
),
)
})
let stage1_witness_cols_circle_domain_eval = stage0_witness_names_stage1_witness_cols
.iter()
.flat_map(move |(stage0_columns, callback_result)| {
callback_result
.iter()
Expand Down

0 comments on commit 41df6e0

Please sign in to comment.