diff --git a/.gitignore b/.gitignore index 60143d7..e7b44ed 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target /data/* +config diff --git a/src/common.rs b/src/common.rs index 6ac756b..bfc73be 100644 --- a/src/common.rs +++ b/src/common.rs @@ -39,6 +39,10 @@ pub(crate) struct Args { /// Option whether histogram should be used as an alternative evaluation method. #[arg(long)] pub(crate) hist: bool, + + /// Option whether histogram should be used as an alternative evaluation method. + #[arg(long, short)] + pub(crate) config: bool, } pub(crate) fn bit_value_in_block(bit: usize, block: &[u8]) -> bool { diff --git a/src/main.rs b/src/main.rs index ab052c5..e9d5487 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,9 @@ use crate::distinguishers::{ use clap::Parser; use common::{prepare_data, Data}; +use itertools::Itertools; +use std::fs::File; +use std::io::{BufRead, BufReader}; use std::time::Instant; fn print_results(p_value: f64, z_score: f64) { @@ -24,22 +27,22 @@ fn results( testing_data: &Data, patterns_combined: usize, hist: bool, -) { +) -> (f64, f64) { final_patterns.sort_by(|a, b| { f64::abs(b.z_score.unwrap()) .partial_cmp(&f64::abs(a.z_score.unwrap())) .unwrap() }); - let mut best_mp = best_multi_pattern(training_data, &final_patterns, patterns_combined); - - println!("trained in {:.2?}", start.elapsed()); - - println!("training z-score: {}", best_mp.z_score.unwrap()); - println!("best multi-pattern: {best_mp:?}"); if hist { - hist_result(final_patterns, training_data, testing_data); + hist_result(final_patterns, training_data, testing_data, start) } else { + let mut best_mp = best_multi_pattern(training_data, &final_patterns, patterns_combined); + + println!("trained in {:.2?}", start.elapsed()); + + println!("training z-score: {}", best_mp.z_score.unwrap()); + println!("best multi-pattern: {best_mp:?}"); let z_score = evaluate_distinguisher(&mut best_mp, testing_data); let p_value = p_value( best_mp.get_count(), @@ -47,21 +50,28 @@ fn results( best_mp.probability, ); print_results(p_value, z_score); + (p_value, z_score) } } -fn hist_result(final_patterns: Vec, training_data: &Data, testing_data: &Data) { +fn hist_result( + final_patterns: Vec, + training_data: &Data, + testing_data: &Data, + start: Instant, +) -> (f64, f64) { let bits = final_patterns[0].bits.clone(); println!("number of bits: {}", bits.len()); if bits.len() > 20 { println!("Too many bits in pattern, can't produce hist result."); - return; + return (1.0, 0.0); } let hist = Histogram::get_hist(&bits, training_data); + println!("trained in {:.2?}", start.elapsed()); println!("training z-score: {}", hist.z_score); let count = hist.evaluate(testing_data); @@ -71,9 +81,10 @@ fn hist_result(final_patterns: Vec, training_data: &Data, testing_data: let p_val = p_value(count, testing_data.num_of_blocks, prob); print_results(p_val, z); + (p_val, z) } -fn run_bottomup(args: Args) { +fn run_bottomup(args: Args) -> (f64, f64) { let s = Instant::now(); let (training_data, validation_data_option, testing_data_option) = prepare_data( &args.data_source, @@ -92,12 +103,41 @@ fn run_bottomup(args: Args) { &testing_data_option.unwrap(), args.patterns_combined, args.hist, - ); + ) +} + +fn parse_args(s: Vec<&str>) -> Args { + Args { + data_source: s[0].to_string(), + block_size: s[1].trim().parse().unwrap(), + k: s[2].trim().parse().unwrap(), + min_difference: s[3].trim().parse().unwrap(), + patterns_combined: s[4].trim().parse().unwrap(), + base_pattern_size: s[5].trim().parse().unwrap(), + validation_and_testing_split: s[6].trim().parse().unwrap(), + hist: s[7].trim().parse().unwrap(), + config: false, + } } fn main() { let args = Args::parse(); - println!("\n{args:?}\n"); - - run_bottomup(args); + if args.config { + let file = File::open(args.data_source).unwrap(); + + let reader = BufReader::new(file); + let mut results = Vec::new(); + for line in reader.lines() { + let l = line.unwrap(); + println!("config: {l}"); + let splitted = l.split(',').collect_vec(); + let args = parse_args(splitted); + results.push(run_bottomup(args)); + println!(); + } + println!("{results:?}"); + } else { + println!("\n{args:?}\n"); + run_bottomup(args); + } }