Skip to content

Commit e534d9f

Browse files
committed
derive encode and decode on grid
1 parent e36be28 commit e534d9f

File tree

7 files changed

+25
-20
lines changed

7 files changed

+25
-20
lines changed

.devenv/load-exports

-1
This file was deleted.

.devenv/profile

-1
This file was deleted.

.devenv/run

-1
This file was deleted.

.gitignore

+2-2
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@ target/
66
.direnv/
77
patch_symbolica.py
88
symbolica_path.txt
9-
.devenv/*
9+
.devenv/
1010
devenv.local.nix
11-
1211
# direnv
1312
.direnv
13+
.devenv
1414

1515
# pre-commit
1616
.pre-commit-config.yaml

flake.nix

+2
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757
pkgs.gnum4
5858
pkgs.gmp
5959
pkgs.mpfr
60+
pkgs.python3
61+
pkgs.maturin
6062
pkgs.gnumake
6163
pkgs.diffutils
6264
pkgs.glibc

src/api/python.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ use std::{
77
sync::Arc,
88
};
99

10+
use bincode::{
11+
de::read::Reader, enc::write::Writer, impl_borrow_decode_with_context, BorrowDecode, Decode,
12+
Encode,
13+
};
14+
1015
use ahash::HashMap;
1116
use brotli::CompressorWriter;
1217
use pyo3::{
@@ -11148,7 +11153,7 @@ impl PythonNumericalIntegrator {
1114811153
/// Use `export_grid` to export the grid.
1114911154
#[classmethod]
1115011155
fn import_grid(_cls: &Bound<'_, PyType>, grid: &[u8]) -> PyResult<Self> {
11151-
let grid = bincode::deserialize(grid)
11156+
let (grid, _) = bincode::decode_from_slice(grid, bincode::config::standard())
1115211157
.map_err(|e| pyo3::exceptions::PyIOError::new_err(e.to_string()))?;
1115311158

1115411159
Ok(PythonNumericalIntegrator { grid })
@@ -11157,7 +11162,7 @@ impl PythonNumericalIntegrator {
1115711162
/// Export the grid, so that it can be sent to another thread or machine.
1115811163
/// Use `import_grid` to load the grid.
1115911164
fn export_grid<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyBytes>> {
11160-
bincode::serialize(&self.grid)
11165+
bincode::encode_to_vec(&self.grid, bincode::config::standard())
1116111166
.map(|a| PyBytes::new(py, &a))
1116211167
.map_err(|e| pyo3::exceptions::PyIOError::new_err(e.to_string()))
1116311168
}

src/numerical_integration.rs

+14-13
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,14 @@
2323
//! // sample 10_000 times per iteration
2424
//! for _ in 0..10_000 {
2525
//! grid.sample(&mut rng, &mut sample);
26-
//!
26+
//!
2727
//! if let Sample::Continuous(_cont_weight, xs) = &sample {
2828
//! grid.add_training_sample(&sample, f(xs)).unwrap();
2929
//! }
3030
//! }
31-
//!
31+
//!
3232
//! grid.update(1.5, 1.5);
33-
//!
33+
//!
3434
//! println!(
3535
//! "Integral at iteration {}: {}",
3636
//! iteration,
@@ -39,6 +39,7 @@
3939
//! }
4040
//! ```
4141
42+
use bincode::{Decode, Encode};
4243
use rand::{Rng, RngCore, SeedableRng};
4344
use rand_xoshiro::Xoshiro256StarStar;
4445
use serde::{Deserialize, Serialize};
@@ -60,7 +61,7 @@ use crate::domains::float::{ConstructibleFloat, Real, RealNumberLike};
6061
///
6162
/// The accumulator also stores which samples yielded the highest weight thus far.
6263
/// This can be used to study the input that impacted the average and error the most.
63-
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
64+
#[derive(Debug, Default, Clone, Serialize, Deserialize, Encode, Decode)]
6465
pub struct StatisticsAccumulator<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd>
6566
{
6667
sum: T,
@@ -383,7 +384,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Statisti
383384
/// and contains the weight and the list of sample points.
384385
/// If the sample comes from a [DiscreteGrid], it is the variant [Discrete](Sample::Discrete) and contains
385386
/// the weight, the bin and the subsample if the bin has a nested grid.
386-
#[derive(Debug, Clone, Serialize, Deserialize)]
387+
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
387388
pub enum Sample<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
388389
Continuous(T, Vec<T>),
389390
Discrete(T, usize, Option<Box<Sample<T>>>),
@@ -453,22 +454,22 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Sample<T
453454
/// // sample 10_000 times per iteration
454455
/// for _ in 0..10_000 {
455456
/// grid.sample(&mut rng, &mut sample);
456-
///
457+
///
457458
/// if let Sample::Continuous(_cont_weight, xs) = &sample {
458459
/// grid.add_training_sample(&sample, f(xs)).unwrap();
459460
/// }
460461
/// }
461-
///
462+
///
462463
/// grid.update(1.5, 1.5);
463-
///
464+
///
464465
/// println!(
465466
/// "Integral at iteration {}: {}",
466467
/// iteration,
467468
/// grid.get_statistics().format_uncertainty()
468469
/// );
469470
/// }
470471
/// ```
471-
#[derive(Debug, Clone, Serialize, Deserialize)]
472+
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
472473
pub enum Grid<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
473474
Continuous(ContinuousGrid<T>),
474475
Discrete(DiscreteGrid<T>),
@@ -539,7 +540,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Grid<T>
539540
}
540541
}
541542
/// A bin of a discrete grid, which may contain a subgrid.
542-
#[derive(Debug, Clone, Serialize, Deserialize)]
543+
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
543544
pub struct Bin<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
544545
pub pdf: T,
545546
pub accumulator: StatisticsAccumulator<T>,
@@ -578,7 +579,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Bin<T> {
578579
/// of a sample from the grid landing in a bin is proportional to its
579580
/// average value if training happens on the average, or to its
580581
/// variance (recommended).
581-
#[derive(Debug, Clone, Serialize, Deserialize)]
582+
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
582583
pub struct DiscreteGrid<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
583584
pub bins: Vec<Bin<T>>,
584585
pub accumulator: StatisticsAccumulator<T>,
@@ -797,7 +798,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Discrete
797798
/// of a sample from the grid landing in a bin is proportional to its
798799
/// average value if training happens on the average, or to its
799800
/// variance (recommended).
800-
#[derive(Debug, Clone, Serialize, Deserialize)]
801+
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
801802
pub struct ContinuousGrid<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
802803
pub continuous_dimensions: Vec<ContinuousDimension<T>>,
803804
pub accumulator: StatisticsAccumulator<T>,
@@ -925,7 +926,7 @@ impl<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> Continuo
925926
}
926927

927928
/// A dimension in a continuous grid that contains a partitioning.
928-
#[derive(Debug, Clone, Serialize, Deserialize)]
929+
#[derive(Debug, Clone, Serialize, Deserialize, Encode, Decode)]
929930
pub struct ContinuousDimension<T: Real + ConstructibleFloat + Copy + RealNumberLike + PartialOrd> {
930931
pub partitioning: Vec<T>,
931932
bin_accumulator: Vec<StatisticsAccumulator<T>>,

0 commit comments

Comments
 (0)