From 5dc030cf478c00bbcca4a269b20a7e78148240b6 Mon Sep 17 00:00:00 2001 From: Samuel Laferriere Date: Fri, 22 Nov 2024 16:37:50 +0400 Subject: [PATCH 1/3] style: just some small refactors/fixes/added comments from first reading of this lib --- src/helpers.rs | 23 +++++++++++++++++------ src/kzg.rs | 46 ++++++++++++++++++++++++++++++---------------- src/polynomial.rs | 5 ++++- 3 files changed, 51 insertions(+), 23 deletions(-) diff --git a/src/helpers.rs b/src/helpers.rs index 4688a10..a3d0e02 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -51,13 +51,13 @@ pub fn convert_by_padding_empty_byte(data: &[u8]) -> Vec { let mut end = (i + 1) * parse_size; if end > data_size { end = data_size; - valid_end = end - start + 1 + i * put_size; + valid_end = put_size * (i + 1); } // Set the first byte of each chunk to 0 valid_data[i * BYTES_PER_FIELD_ELEMENT] = 0x00; // Copy data from original to new vector, adjusting for the initial zero byte - valid_data[i * BYTES_PER_FIELD_ELEMENT + 1..i * BYTES_PER_FIELD_ELEMENT + 1 + end - start] + valid_data[i * BYTES_PER_FIELD_ELEMENT + 1..(i + 1) * BYTES_PER_FIELD_ELEMENT] .copy_from_slice(&data[start..end]); } @@ -311,7 +311,7 @@ pub fn process_chunks(receiver: Receiver<(Vec, usize, bool)>) -> Vec<(T, where T: ReadPointFromBytes, { - #[allow(clippy::unnecessary_filter_map)] + // TODO: should we use rayon to process this in parallel? receiver .iter() .map(|(chunk, position, is_native)| { @@ -321,7 +321,6 @@ where } else { T::read_point_from_bytes_be(&chunk).expect("Failed to read point from bytes") }; - (point, position) }) .collect() @@ -378,7 +377,17 @@ pub fn is_on_curve_g2(g2: &G2Projective) -> bool { left == right } -pub fn check_directory + std::fmt::Display>(path: P) -> Result { +/// Checks if the directory specified by `path`: +/// 1. can be written to by creating a temporary file, writing to it, and deleting it. +/// 2. has [REQUIRED_FREE_SPACE] to store the required amount of data. +/// +/// # Arguments +/// * `path` - The directory path to check +/// +/// # Returns +/// * `Ok(true)` if directory is writable and has enough space +/// * `Err(String)` with error description if checks fail +pub fn check_directory + std::fmt::Display>(path: P) -> Result<(), String> { let test_file_path = path.as_ref().join("cache_dir_write_test.tmp"); // Try to create and write to a temporary file @@ -401,6 +410,8 @@ pub fn check_directory + std::fmt::Display>(path: P) -> Result info, Err(_) => return Err(format!("unable to get disk information for path {}", path)), @@ -415,5 +426,5 @@ pub fn check_directory + std::fmt::Display>(path: P) -> Result info, - Err(err) => return Err(KzgError::GenericError(err)), - }; + if let Err(err) = check_directory(&cache_dir) { + return Err(KzgError::GenericError(err)); + } } let g1_points = Self::parallel_read_g1_points(path_to_g1_points.to_owned(), srs_points_to_load, false) .map_err(|e| KzgError::SerializationError(e.to_string()))?; - let g2_points_result: Result, KzgError> = + let g2_points: Vec = match (path_to_g2_points.is_empty(), g2_power_of2_path.is_empty()) { (false, _) => Self::parallel_read_g2_points( path_to_g2_points.to_owned(), srs_points_to_load, false, ) - .map_err(|e| KzgError::SerializationError(e.to_string())), - (_, false) => Self::read_g2_point_on_power_of_2(g2_power_of2_path), + .map_err(|e| KzgError::SerializationError(e.to_string()))?, + (_, false) => Self::read_g2_point_on_power_of_2(g2_power_of2_path)?, (true, true) => { return Err(KzgError::GenericError( "both g2 point files are empty, need the proper file specified".to_string(), @@ -80,8 +79,6 @@ impl Kzg { }, }; - let g2_points = g2_points_result?; - Ok(Self { g1: g1_points, g2: g2_points, @@ -282,6 +279,7 @@ impl Kzg { } /// read files in chunks with specified length + /// TODO: chunks seems misleading here, since we read one field element at a time. fn read_file_chunks( file_path: &str, sender: Sender<(Vec, usize, bool)>, @@ -295,6 +293,8 @@ impl Kzg { let mut buffer = vec![0u8; point_size]; let mut i = 0; + // We are making one syscall per field element, which is super inefficient. + // FIXME: read the entire file into memory and then split it into field elements. while let Ok(bytes_read) = reader.read(&mut buffer) { if bytes_read == 0 { break; @@ -358,11 +358,25 @@ impl Kzg { Ok(all_points.iter().map(|(point, _)| *point).collect()) } + /// read G1 points in parallel, by creating one reader thread, which reads bytes from the file, + /// and fans them out to worker threads (one per cpu) which parse the bytes into G1Affine points. + /// The worker threads then fan in the parsed points to the main thread, which sorts them by + /// their original position in the file to maintain order. + /// + /// # Arguments + /// * `file_path` - The path to the file containing the G1 points + /// * `srs_points_to_load` - The number of points to load from the file + /// * `is_native` - Whether the points are in native arkworks format or not + /// + /// # Returns + /// * `Ok(Vec)` - The G1 points read from the file + /// * `Err(KzgError)` - An error occurred while reading the file pub fn parallel_read_g1_points_native( file_path: String, srs_points_to_load: u32, is_native: bool, ) -> Result, KzgError> { + // Channel contains (bytes, position, is_native) tuples. The position is used to reorder the points after processing them. let (sender, receiver) = bounded::<(Vec, usize, bool)>(1000); // Spawning the reader thread @@ -413,7 +427,7 @@ impl Kzg { let (sender, receiver) = bounded::<(Vec, usize, bool)>(1000); // Spawning the reader thread - let reader_thread = std::thread::spawn( + let reader_handle = std::thread::spawn( move || -> Result<(), Box> { Self::read_file_chunks(&file_path, sender, 32, srs_points_to_load, is_native) .map_err(|e| -> Box { Box::new(e) }) @@ -422,7 +436,7 @@ impl Kzg { let num_workers = num_cpus::get(); - let workers: Vec<_> = (0..num_workers) + let worker_handles: Vec<_> = (0..num_workers) .map(|_| { let receiver = receiver.clone(); std::thread::spawn(move || helpers::process_chunks::(receiver)) @@ -430,7 +444,7 @@ impl Kzg { .collect(); // Wait for the reader thread to finish - match reader_thread.join() { + match reader_handle.join() { Ok(result) => match result { Ok(_) => {}, Err(e) => return Err(KzgError::GenericError(e.to_string())), @@ -440,8 +454,8 @@ impl Kzg { // Collect and sort results let mut all_points = Vec::new(); - for worker in workers { - let points = worker.join().expect("Worker thread panicked"); + for handle in worker_handles { + let points = handle.join().expect("Worker thread panicked"); all_points.extend(points); } @@ -600,7 +614,7 @@ impl Kzg { z_fr: Fr, eval_fr: &[Fr], value_fr: Fr, - roots_of_unities: &[Fr], + roots_of_unity: &[Fr], ) -> Fr { let mut quotient = Fr::zero(); let mut fi: Fr = Fr::zero(); @@ -608,7 +622,7 @@ impl Kzg { let mut denominator: Fr = Fr::zero(); let mut temp: Fr = Fr::zero(); - roots_of_unities + roots_of_unity .iter() .enumerate() .for_each(|(i, omega_i)| { diff --git a/src/polynomial.rs b/src/polynomial.rs index 911bffb..0595f3a 100644 --- a/src/polynomial.rs +++ b/src/polynomial.rs @@ -12,6 +12,9 @@ pub enum PolynomialFormat { #[derive(Clone, Debug, PartialEq)] pub struct Polynomial { elements: Vec, + // TODO: Remove this field, its just a duplicate of length_of_padded_blob_as_fr_vector. + // One can easily convert between them by *4 or /4. Also it should be calculated and not passed in, + // which is error prone (user might think length is in field elements). length_of_padded_blob: usize, length_of_padded_blob_as_fr_vector: usize, form: PolynomialFormat, @@ -20,7 +23,7 @@ pub struct Polynomial { impl Polynomial { /// Constructs a new `Polynomial` with a given vector of `Fr` elements. pub fn new( - elements: &Vec, + elements: &[Fr], length_of_padded_blob: usize, form: PolynomialFormat, ) -> Result { From 8a1133c92350dc636900a6893019cf2b0afe9e2d Mon Sep 17 00:00:00 2001 From: Samuel Laferriere Date: Fri, 22 Nov 2024 21:36:41 +0400 Subject: [PATCH 2/3] fix(tests): revert my stupid convert_by_padding_empty_byte changes --- src/helpers.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/helpers.rs b/src/helpers.rs index a3d0e02..85e99f5 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -51,13 +51,13 @@ pub fn convert_by_padding_empty_byte(data: &[u8]) -> Vec { let mut end = (i + 1) * parse_size; if end > data_size { end = data_size; - valid_end = put_size * (i + 1); + valid_end = end - start + 1 + i * put_size; } // Set the first byte of each chunk to 0 valid_data[i * BYTES_PER_FIELD_ELEMENT] = 0x00; // Copy data from original to new vector, adjusting for the initial zero byte - valid_data[i * BYTES_PER_FIELD_ELEMENT + 1..(i + 1) * BYTES_PER_FIELD_ELEMENT] + valid_data[i * BYTES_PER_FIELD_ELEMENT + 1..i * BYTES_PER_FIELD_ELEMENT + 1 + end - start] .copy_from_slice(&data[start..end]); } From bef408a74da3167cd9ab0c447eab2d547b8b21f9 Mon Sep 17 00:00:00 2001 From: Samuel Laferriere Date: Fri, 22 Nov 2024 23:59:07 +0400 Subject: [PATCH 3/3] docs: expand on FIXME comment --- src/kzg.rs | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/kzg.rs b/src/kzg.rs index b9b9942..8eb4376 100644 --- a/src/kzg.rs +++ b/src/kzg.rs @@ -294,7 +294,8 @@ impl Kzg { let mut i = 0; // We are making one syscall per field element, which is super inefficient. - // FIXME: read the entire file into memory and then split it into field elements. + // FIXME: Read the entire file (or large segments) into memory and then split it into field elements. + // Entire G1 file might be ~8GiB, so might not fit in RAM. while let Ok(bytes_read) = reader.read(&mut buffer) { if bytes_read == 0 { break; @@ -622,20 +623,17 @@ impl Kzg { let mut denominator: Fr = Fr::zero(); let mut temp: Fr = Fr::zero(); - roots_of_unity - .iter() - .enumerate() - .for_each(|(i, omega_i)| { - if *omega_i == z_fr { - return; - } - fi = eval_fr[i] - value_fr; - numerator = fi * omega_i; - denominator = z_fr - omega_i; - denominator *= z_fr; - temp = numerator.div(denominator); - quotient += temp; - }); + roots_of_unity.iter().enumerate().for_each(|(i, omega_i)| { + if *omega_i == z_fr { + return; + } + fi = eval_fr[i] - value_fr; + numerator = fi * omega_i; + denominator = z_fr - omega_i; + denominator *= z_fr; + temp = numerator.div(denominator); + quotient += temp; + }); quotient }