style: just some small refactors/fixes/added comments #16

Merged: 3 commits, Nov 22, 2024
19 changes: 15 additions & 4 deletions src/helpers.rs
@@ -311,7 +311,7 @@ pub fn process_chunks<T>(receiver: Receiver<(Vec<u8>, usize, bool)>) -> Vec<(T,
where
T: ReadPointFromBytes,
{
#[allow(clippy::unnecessary_filter_map)]
// TODO: should we use rayon to process this in parallel?
Collaborator:

process_chunks is called from a parallel processor which splits the entire file into chunks, calls the function, and later arranges the results in order, if that's what you were thinking.

Contributor Author:

No, I meant using par_iter here instead of iter. If you only have a few CPUs, say, then you'll only have a few workers, and each will be processing a lot of points sequentially. Not sure if all the CPUs are saturated already... maybe.
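
For concreteness, a minimal sketch of that suggestion (an editorial assumption, not part of the PR): with rayon as a dependency, `par_bridge()` turns the channel iterator into a parallel one. It does not preserve order, but the positions carried through the channel let the caller sort afterwards, as it already does. `T` would also need a `Send` bound:

```rust
use crossbeam_channel::Receiver;
use rayon::iter::{ParallelBridge, ParallelIterator};

pub fn process_chunks<T>(receiver: Receiver<(Vec<u8>, usize, bool)>) -> Vec<(T, usize)>
where
    T: ReadPointFromBytes + Send,
{
    receiver
        .iter()
        .par_bridge() // fan chunks out across rayon's thread pool
        .map(|(chunk, position, is_native)| {
            // parse exactly as in the current body (native branch elided in this sketch)
            let point = if is_native {
                unimplemented!("same native-format branch as the existing code")
            } else {
                T::read_point_from_bytes_be(&chunk).expect("Failed to read point from bytes")
            };
            (point, position)
        })
        .collect()
}
```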

receiver
.iter()
.map(|(chunk, position, is_native)| {
@@ -321,7 +321,6 @@ where
} else {
T::read_point_from_bytes_be(&chunk).expect("Failed to read point from bytes")
};

(point, position)
})
.collect()
@@ -378,7 +377,17 @@ pub fn is_on_curve_g2(g2: &G2Projective) -> bool {
left == right
}

pub fn check_directory<P: AsRef<Path> + std::fmt::Display>(path: P) -> Result<bool, String> {
/// Checks if the directory specified by `path`:
/// 1. can be written to by creating a temporary file, writing to it, and deleting it.
/// 2. has at least [REQUIRED_FREE_SPACE] available to store the required amount of data.
///
/// # Arguments
/// * `path` - The directory path to check
///
/// # Returns
/// * `Ok(())` if the directory is writable and has enough space
/// * `Err(String)` with error description if checks fail
pub fn check_directory<P: AsRef<Path> + std::fmt::Display>(path: P) -> Result<(), String> {
let test_file_path = path.as_ref().join("cache_dir_write_test.tmp");

// Try to create and write to a temporary file
@@ -401,6 +410,8 @@ pub fn check_directory<P: AsRef<Path> + std::fmt::Display>(path: P) -> Result<bo
}

// Get disk information
// FIXME: I don't think this works... the directory might actually be on a separate
// disk than the default one disk_info looks at.
let disk = match disk_info() {
Ok(info) => info,
Err(_) => return Err(format!("unable to get disk information for path {}", path)),
@@ -415,5 +426,5 @@ pub fn check_directory<P: AsRef<Path> + std::fmt::Display>(path: P) -> Result<bo
path, REQUIRED_FREE_SPACE
));
}
Ok(true)
Ok(())
}
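
One way to address the disk_info FIXME above, sketched under the assumption that adding the fs2 crate is acceptable: `fs2::available_space` queries the filesystem that actually contains the given path, so a cache directory mounted on a different disk is measured correctly (the helper name below is hypothetical):

```rust
use std::path::Path;

// Sketch: a path-aware replacement for the disk_info() block.
fn check_free_space<P: AsRef<Path>>(path: P, required: u64) -> Result<(), String> {
    // fs2::available_space asks the filesystem holding `path`, not a default disk.
    let free = fs2::available_space(path.as_ref()).map_err(|e| {
        format!(
            "unable to get disk information for path {}: {}",
            path.as_ref().display(),
            e
        )
    })?;
    if free < required {
        return Err(format!(
            "path {} requires at least {} bytes of free space",
            path.as_ref().display(),
            required
        ));
    }
    Ok(())
}
```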
70 changes: 41 additions & 29 deletions src/kzg.rs
@@ -54,34 +54,31 @@ impl Kzg {
}

if !cache_dir.is_empty() {
match check_directory(&cache_dir) {
Ok(info) => info,
Err(err) => return Err(KzgError::GenericError(err)),
};
if let Err(err) = check_directory(&cache_dir) {
return Err(KzgError::GenericError(err));
}
}

let g1_points =
Self::parallel_read_g1_points(path_to_g1_points.to_owned(), srs_points_to_load, false)
.map_err(|e| KzgError::SerializationError(e.to_string()))?;

let g2_points_result: Result<Vec<G2Affine>, KzgError> =
let g2_points: Vec<G2Affine> =
match (path_to_g2_points.is_empty(), g2_power_of2_path.is_empty()) {
(false, _) => Self::parallel_read_g2_points(
path_to_g2_points.to_owned(),
srs_points_to_load,
false,
)
.map_err(|e| KzgError::SerializationError(e.to_string())),
(_, false) => Self::read_g2_point_on_power_of_2(g2_power_of2_path),
.map_err(|e| KzgError::SerializationError(e.to_string()))?,
(_, false) => Self::read_g2_point_on_power_of_2(g2_power_of2_path)?,
(true, true) => {
return Err(KzgError::GenericError(
"both g2 point files are empty, need the proper file specified".to_string(),
))
},
};

let g2_points = g2_points_result?;

Ok(Self {
g1: g1_points,
g2: g2_points,
@@ -282,6 +279,7 @@ impl Kzg {
}

/// read files in chunks with specified length
/// TODO: chunks seems misleading here, since we read one field element at a time.
fn read_file_chunks(
file_path: &str,
sender: Sender<(Vec<u8>, usize, bool)>,
@@ -295,6 +293,9 @@
let mut buffer = vec![0u8; point_size];

let mut i = 0;
// We are making one syscall per field element, which is super inefficient.
// FIXME: Read the entire file (or large segments) into memory and then split it into field elements.
// Entire G1 file might be ~8GiB, so might not fit in RAM.
while let Ok(bytes_read) = reader.read(&mut buffer) {
if bytes_read == 0 {
break;
@@ -358,11 +359,25 @@ impl Kzg {
Ok(all_points.iter().map(|(point, _)| *point).collect())
}

/// Read G1 points in parallel by creating one reader thread, which reads bytes from the file
/// and fans them out to worker threads (one per CPU) that parse the bytes into G1Affine points.
/// The worker threads then fan in the parsed points to the main thread, which sorts them by
/// their original position in the file to maintain order.
///
/// # Arguments
/// * `file_path` - The path to the file containing the G1 points
/// * `srs_points_to_load` - The number of points to load from the file
/// * `is_native` - Whether the points are in native arkworks format or not
///
/// # Returns
/// * `Ok(Vec<G1Affine>)` - The G1 points read from the file
/// * `Err(KzgError)` - An error occurred while reading the file
pub fn parallel_read_g1_points_native(
file_path: String,
srs_points_to_load: u32,
is_native: bool,
) -> Result<Vec<G1Affine>, KzgError> {
// Channel contains (bytes, position, is_native) tuples. The position is used to reorder the points after processing them.
let (sender, receiver) = bounded::<(Vec<u8>, usize, bool)>(1000);

// Spawning the reader thread
Expand Down Expand Up @@ -413,7 +428,7 @@ impl Kzg {
let (sender, receiver) = bounded::<(Vec<u8>, usize, bool)>(1000);

// Spawning the reader thread
let reader_thread = std::thread::spawn(
let reader_handle = std::thread::spawn(
move || -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Self::read_file_chunks(&file_path, sender, 32, srs_points_to_load, is_native)
.map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })
@@ -422,15 +437,15 @@

let num_workers = num_cpus::get();

let workers: Vec<_> = (0..num_workers)
let worker_handles: Vec<_> = (0..num_workers)
.map(|_| {
let receiver = receiver.clone();
std::thread::spawn(move || helpers::process_chunks::<G1Affine>(receiver))
})
.collect();

// Wait for the reader thread to finish
match reader_thread.join() {
match reader_handle.join() {
Ok(result) => match result {
Ok(_) => {},
Err(e) => return Err(KzgError::GenericError(e.to_string())),
@@ -440,8 +455,8 @@

// Collect and sort results
let mut all_points = Vec::new();
for worker in workers {
let points = worker.join().expect("Worker thread panicked");
for handle in worker_handles {
let points = handle.join().expect("Worker thread panicked");
all_points.extend(points);
}

@@ -600,28 +615,25 @@ impl Kzg {
z_fr: Fr,
eval_fr: &[Fr],
value_fr: Fr,
roots_of_unities: &[Fr],
roots_of_unity: &[Fr],
) -> Fr {
let mut quotient = Fr::zero();
let mut fi: Fr = Fr::zero();
let mut numerator: Fr = Fr::zero();
let mut denominator: Fr = Fr::zero();
let mut temp: Fr = Fr::zero();

roots_of_unities
.iter()
.enumerate()
.for_each(|(i, omega_i)| {
if *omega_i == z_fr {
return;
}
fi = eval_fr[i] - value_fr;
numerator = fi * omega_i;
denominator = z_fr - omega_i;
denominator *= z_fr;
temp = numerator.div(denominator);
quotient += temp;
});
roots_of_unity.iter().enumerate().for_each(|(i, omega_i)| {
if *omega_i == z_fr {
return;
}
fi = eval_fr[i] - value_fr;
numerator = fi * omega_i;
denominator = z_fr - omega_i;
denominator *= z_fr;
temp = numerator.div(denominator);
quotient += temp;
});

quotient
}
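
On the read_file_chunks FIXME above (one read syscall per 32-byte point): a sketch of a buffered reader thread, assuming the same `(bytes, position, is_native)` channel protocol as the existing fan-out/fan-in pipeline. A `BufReader` with a large capacity serves most point-sized reads from memory without buffering the whole ~8 GiB file; the function name is hypothetical:

```rust
use std::fs::File;
use std::io::{BufReader, Read};

use crossbeam_channel::Sender;

// Sketch only: the 1 MiB BufReader batches the underlying syscalls; each
// point is still sent to the worker pool individually, so the rest of the
// pipeline stays unchanged.
fn read_file_chunks_buffered(
    file_path: &str,
    sender: Sender<(Vec<u8>, usize, bool)>,
    point_size: usize,
    num_points: u32,
    is_native: bool,
) -> std::io::Result<()> {
    let file = File::open(file_path)?;
    let mut reader = BufReader::with_capacity(1 << 20, file);
    let mut buffer = vec![0u8; point_size];

    for position in 0..num_points as usize {
        reader.read_exact(&mut buffer)?; // usually a memcpy, not a syscall
        sender
            .send((buffer.clone(), position, is_native))
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
    }
    Ok(())
}
```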
5 changes: 4 additions & 1 deletion src/polynomial.rs
@@ -12,6 +12,9 @@ pub enum PolynomialFormat {
#[derive(Clone, Debug, PartialEq)]
pub struct Polynomial {
elements: Vec<Fr>,
// TODO: Remove this field; it's just a duplicate of length_of_padded_blob_as_fr_vector.
// One can easily convert between them by *4 or /4. Also, it should be calculated rather than passed in,
// which is error-prone (a user might think the length is in field elements).
length_of_padded_blob: usize,
length_of_padded_blob_as_fr_vector: usize,
form: PolynomialFormat,
@@ -20,7 +23,7 @@ pub struct Polynomial {
impl Polynomial {
/// Constructs a new `Polynomial` with a given vector of `Fr` elements.
pub fn new(
elements: &Vec<Fr>,
elements: &[Fr],
length_of_padded_blob: usize,
form: PolynomialFormat,
) -> Result<Self, PolynomialError> {
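
On the struct-field TODO above, a sketch of the deduplicated struct. The factor of 4 is the relationship the comment states, not something verified here:

```rust
use ark_bn254::Fr; // Fr as used elsewhere in the crate (assumption)

#[derive(Clone, Debug, PartialEq)]
pub struct Polynomial {
    elements: Vec<Fr>,
    length_of_padded_blob_as_fr_vector: usize,
    form: PolynomialFormat,
}

impl Polynomial {
    /// Derived instead of stored, per the TODO; the *4 factor is the
    /// comment's claim and is an assumption in this sketch.
    pub fn length_of_padded_blob(&self) -> usize {
        self.length_of_padded_blob_as_fr_vector * 4
    }
}
```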