Skip to content

Commit

Permalink
feat(python): Added support for pickling lace.Engine (#184)
Browse files Browse the repository at this point in the history
* feat(python): Added support for pickling lace.Engine
* fix(pylace): Removed unnecessary expect.
  • Loading branch information
schmidmt authored Feb 13, 2024
1 parent 8a8e39c commit f820017
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 9 deletions.
10 changes: 6 additions & 4 deletions pylace/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pylace/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ serde_json = "1.0.91"
serde_yaml = "0.9.17"
polars = "0.36"
polars-arrow = "0.36.2"
serde = { version = "1.0.196", features = ["derive"] }
bincode = "1.3.3"

[package.metadata.maturin]
name = "lace.core"
43 changes: 39 additions & 4 deletions pylace/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ use lace::metadata::SerializedType;
use lace::prelude::ColMetadataList;
use lace::{EngineUpdateConfig, FType, HasStates, OracleT};
use polars::prelude::{DataFrame, NamedFrom, Series};
use pyo3::create_exception;
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
use pyo3::types::{PyDict, PyList, PyType};
use pyo3::{create_exception, prelude::*};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList, PyType};
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};

use metadata::{Codebook, CodebookBuilder};

use crate::update_handler::PyUpdateHandler;
use crate::utils::*;

#[derive(Clone)]
#[pyclass(subclass)]
#[derive(Clone, Serialize, Deserialize)]
#[pyclass(subclass, module = "lace.core")]
struct CoreEngine {
engine: lace::Engine,
col_indexer: Indexer,
Expand Down Expand Up @@ -1320,6 +1322,39 @@ impl CoreEngine {
})
})
}

pub fn __setstate__(
&mut self,
py: Python,
state: PyObject,
) -> PyResult<()> {
let s = state.extract::<&PyBytes>(py)?;
*self = bincode::deserialize(s.as_bytes()).map_err(|e| {
PyValueError::new_err(format!("Cannot Deserialize CoreEngine: {e}"))
})?;
Ok(())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(
py,
&bincode::serialize(&self).map_err(|e| {
PyValueError::new_err(format!(
"Cannot Serialize CoreEngine: {e}"
))
})?,
)
.to_object(py))
}

pub fn __getnewargs__(&self) -> PyResult<(PyDataFrame,)> {
Ok((PyDataFrame(
polars::df! {
"ID" => [0],
}
.map_err(|e| PyValueError::new_err(format!("Polars error: {e}")))?,
),))
}
}

#[pyfunction]
Expand Down
3 changes: 2 additions & 1 deletion pylace/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use pyo3::prelude::*;
use pyo3::types::{
PyAny, PyBool, PyDict, PyInt, PyList, PySlice, PyString, PyTuple,
};
use serde::{Deserialize, Serialize};

use crate::df::{PyDataFrame, PySeries};

Expand Down Expand Up @@ -381,7 +382,7 @@ pub(crate) fn str_to_mitype(mi_type: &str) -> PyResult<lace::MiType> {
}
}

#[derive(Clone, PartialEq, Eq)]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct Indexer {
pub to_ix: HashMap<String, usize>,
pub to_name: HashMap<usize, String>,
Expand Down
14 changes: 14 additions & 0 deletions pylace/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pickle

from lace import examples


def test_pickle_engine():
engine = examples.Animals().engine
s = pickle.dumps(engine)
engine_b = pickle.loads(s)

sim_a = engine.simulate(["swims", "flys"], n=10)
sim_b = engine_b.simulate(["swims", "flys"], n=10)

assert sim_a.equals(sim_b)

0 comments on commit f820017

Please sign in to comment.