Skip to content

Commit

Permalink
feat(python): Added support for pickling lace.Engine
Browse files Browse the repository at this point in the history
  • Loading branch information
schmidmt committed Feb 13, 2024
1 parent 8a8e39c commit 6ccbed7
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 4 deletions.
41 changes: 38 additions & 3 deletions pylace/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@ use lace::prelude::ColMetadataList;
use lace::{EngineUpdateConfig, FType, HasStates, OracleT};
use polars::prelude::{DataFrame, NamedFrom, Series};
use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError};
use pyo3::types::{PyDict, PyList, PyType};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList, PyType};
use pyo3::{create_exception, prelude::*};
use rand::SeedableRng;
use rand_xoshiro::Xoshiro256Plus;
use serde::{Deserialize, Serialize};

use metadata::{Codebook, CodebookBuilder};

use crate::update_handler::PyUpdateHandler;
use crate::utils::*;

#[derive(Clone)]
#[pyclass(subclass)]
#[derive(Clone, Serialize, Deserialize)]
#[pyclass(subclass, module = "lace.core")]
struct CoreEngine {
engine: lace::Engine,
col_indexer: Indexer,
Expand Down Expand Up @@ -1320,6 +1322,39 @@ impl CoreEngine {
})
})
}

pub fn __setstate__(
&mut self,
py: Python,
state: PyObject,
) -> PyResult<()> {
let s = state.extract::<&PyBytes>(py)?;
*self = bincode::deserialize(s.as_bytes()).map_err(|e| {
PyValueError::new_err(format!("Cannot Deserialize CoreEngine: {e}"))
})?;
Ok(())
}

pub fn __getstate__(&self, py: Python) -> PyResult<PyObject> {
Ok(PyBytes::new(
py,
&bincode::serialize(&self).map_err(|e| {
PyValueError::new_err(format!(
"Cannot Serialize CoreEngine: {e}"
))
})?,
)
.to_object(py))
}

pub fn __getnewargs__(&self) -> PyResult<(PyDataFrame,)> {
Ok((PyDataFrame(
polars::df! {
"ID" => [0],
}
.expect("Should be a df"),
),))
}
}

#[pyfunction]
Expand Down
3 changes: 2 additions & 1 deletion pylace/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use pyo3::prelude::*;
use pyo3::types::{
PyAny, PyBool, PyDict, PyInt, PyList, PySlice, PyString, PyTuple,
};
use serde::{Deserialize, Serialize};

use crate::df::{PyDataFrame, PySeries};

Expand Down Expand Up @@ -381,7 +382,7 @@ pub(crate) fn str_to_mitype(mi_type: &str) -> PyResult<lace::MiType> {
}
}

#[derive(Clone, PartialEq, Eq)]
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub(crate) struct Indexer {
pub to_ix: HashMap<String, usize>,
pub to_name: HashMap<usize, String>,
Expand Down
14 changes: 14 additions & 0 deletions pylace/tests/test_pickle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pickle

from lace import examples


def test_pickle_engine():
engine = examples.Animals().engine
s = pickle.dumps(engine)
engine_b = pickle.loads(s)

sim_a = engine.simulate(["swims", "flys"], n=10)
sim_b = engine_b.simulate(["swims", "flys"], n=10)

assert sim_a.equals(sim_b)

0 comments on commit 6ccbed7

Please sign in to comment.