Skip to content

Commit

Permalink
fix grid encode decode
Browse files Browse the repository at this point in the history
  • Loading branch information
lcnbr committed Jan 6, 2025
1 parent 28e3bac commit 283187f
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/api/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11165,7 +11165,7 @@ impl PythonNumericalIntegrator {
/// Use `export_grid` to export the grid.
#[classmethod]
fn import_grid(_cls: &Bound<'_, PyType>, grid: &[u8]) -> PyResult<Self> {
let grid = bincode::deserialize(grid)
let (grid, _) = bincode::decode_from_slice(grid, bincode::config::standard())
.map_err(|e| pyo3::exceptions::PyIOError::new_err(e.to_string()))?;

Ok(PythonNumericalIntegrator { grid })
Expand All @@ -11174,7 +11174,7 @@ impl PythonNumericalIntegrator {
/// Export the grid, so that it can be sent to another thread or machine.
/// Use `import_grid` to load the grid.
fn export_grid<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyBytes>> {
bincode::serialize(&self.grid)
bincode::encode_to_vec(&self.grid, bincode::config::standard())
.map(|a| PyBytes::new(py, &a))
.map_err(|e| pyo3::exceptions::PyIOError::new_err(e.to_string()))
}
Expand Down
27 changes: 14 additions & 13 deletions src/numerical_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
//! // sample 10_000 times per iteration
//! for _ in 0..10_000 {
//! grid.sample(&mut rng, &mut sample);
//!
//!
//! if let Sample::Continuous(_cont_weight, xs) = &sample {
//! grid.add_training_sample(&sample, f(xs)).unwrap();
//! }
//! }
//!
//!
//! grid.update(1.5, 1.5);
//!
//!
//! println!(
//! "Integral at iteration {}: {}",
//! iteration,
Expand All @@ -39,6 +39,7 @@
//! }
//! ```
use bincode::{Decode, Encode};
use rand::{Rng, RngCore, SeedableRng};
use rand_xoshiro::Xoshiro256StarStar;
use serde::{Deserialize, Serialize};
Expand All @@ -60,7 +61,7 @@ use crate::domains::float::{ConstructibleFloat, Real, RealNumberLike};
///
/// The accumulator also stores which samples yielded the highest weight thus far.
/// This can be used to study the input that impacted the average and error the most.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
#[derive(Debug, Default, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct StatisticsAccumulator<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd>
{
sum: T,
Expand Down Expand Up @@ -383,7 +384,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Statisti
/// and contains the weight and the list of sample points.
/// If the sample comes from a [DiscreteGrid], it is the variant [Discrete](Sample::Discrete) and contains
/// the weight, the bin and the subsample if the bin has a nested grid.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub enum Sample<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
Continuous(T, Vec<T>),
Discrete(T, usize, Option<Box<Sample<T>>>),
Expand Down Expand Up @@ -453,22 +454,22 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Sample<T
/// // sample 10_000 times per iteration
/// for _ in 0..10_000 {
/// grid.sample(&mut rng, &mut sample);
///
///
/// if let Sample::Continuous(_cont_weight, xs) = &sample {
/// grid.add_training_sample(&sample, f(xs)).unwrap();
/// }
/// }
///
///
/// grid.update(1.5, 1.5);
///
///
/// println!(
/// "Integral at iteration {}: {}",
/// iteration,
/// grid.get_statistics().format_uncertainty()
/// );
/// }
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub enum Grid<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
Continuous(ContinuousGrid<T>),
Discrete(DiscreteGrid<T>),
Expand Down Expand Up @@ -539,7 +540,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Grid<T>
}
}
/// A bin of a discrete grid, which may contain a subgrid.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct Bin<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
pub pdf: T,
pub accumulator: StatisticsAccumulator<T>,
Expand Down Expand Up @@ -578,7 +579,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Bin<T> {
/// of a sample from the grid landing in a bin is proportional to its
/// average value if training happens on the average, or to its
/// variance (recommended).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct DiscreteGrid<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
pub bins: Vec<Bin<T>>,
pub accumulator: StatisticsAccumulator<T>,
Expand Down Expand Up @@ -797,7 +798,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Discrete
/// of a sample from the grid landing in a bin is proportional to its
/// average value if training happens on the average, or to its
/// variance (recommended).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct ContinuousGrid<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
pub continuous_dimensions: Vec<ContinuousDimension<T>>,
pub accumulator: StatisticsAccumulator<T>,
Expand Down Expand Up @@ -925,7 +926,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Continuo
}

/// A dimension in a continuous grid that contains a partitioning.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
pub struct ContinuousDimension<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
pub partitioning: Vec<T>,
bin_accumulator: Vec<StatisticsAccumulator<T>>,
Expand Down

0 comments on commit 283187f

Please sign in to comment.