From 6ccbed7f8f96fbf5430b1c18bb77bda69eb0ecf4 Mon Sep 17 00:00:00 2001 From: Mike Schmidt Date: Tue, 16 Jan 2024 09:43:31 -0700 Subject: [PATCH] feat(python): Added support for pickling lace.Engine --- pylace/src/lib.rs | 41 ++++++++++++++++++++++++++++++++++--- pylace/src/utils.rs | 3 ++- pylace/tests/test_pickle.py | 14 +++++++++++++ 3 files changed, 54 insertions(+), 4 deletions(-) create mode 100644 pylace/tests/test_pickle.py diff --git a/pylace/src/lib.rs b/pylace/src/lib.rs index fe65a227..4b685060 100644 --- a/pylace/src/lib.rs +++ b/pylace/src/lib.rs @@ -16,18 +16,20 @@ use lace::prelude::ColMetadataList; use lace::{EngineUpdateConfig, FType, HasStates, OracleT}; use polars::prelude::{DataFrame, NamedFrom, Series}; use pyo3::exceptions::{PyIndexError, PyRuntimeError, PyValueError}; -use pyo3::types::{PyDict, PyList, PyType}; +use pyo3::prelude::*; +use pyo3::types::{PyBytes, PyDict, PyList, PyType}; use pyo3::{create_exception, prelude::*}; use rand::SeedableRng; use rand_xoshiro::Xoshiro256Plus; +use serde::{Deserialize, Serialize}; use metadata::{Codebook, CodebookBuilder}; use crate::update_handler::PyUpdateHandler; use crate::utils::*; -#[derive(Clone)] -#[pyclass(subclass)] +#[derive(Clone, Serialize, Deserialize)] +#[pyclass(subclass, module = "lace.core")] struct CoreEngine { engine: lace::Engine, col_indexer: Indexer, @@ -1320,6 +1322,39 @@ impl CoreEngine { }) }) } + + pub fn __setstate__( + &mut self, + py: Python, + state: PyObject, + ) -> PyResult<()> { + let s = state.extract::<&PyBytes>(py)?; + *self = bincode::deserialize(s.as_bytes()).map_err(|e| { + PyValueError::new_err(format!("Cannot Deserialize CoreEngine: {e}")) + })?; + Ok(()) + } + + pub fn __getstate__(&self, py: Python) -> PyResult { + Ok(PyBytes::new( + py, + &bincode::serialize(&self).map_err(|e| { + PyValueError::new_err(format!( + "Cannot Serialize CoreEngine: {e}" + )) + })?, + ) + .to_object(py)) + } + + pub fn __getnewargs__(&self) -> PyResult<(PyDataFrame,)> { + Ok((PyDataFrame( + polars::df! { + "ID" => [0], + } + .expect("Should be a df"), + ),)) + } } #[pyfunction] diff --git a/pylace/src/utils.rs b/pylace/src/utils.rs index 55e3bc1f..033dcd94 100644 --- a/pylace/src/utils.rs +++ b/pylace/src/utils.rs @@ -15,6 +15,7 @@ use pyo3::prelude::*; use pyo3::types::{ PyAny, PyBool, PyDict, PyInt, PyList, PySlice, PyString, PyTuple, }; +use serde::{Deserialize, Serialize}; use crate::df::{PyDataFrame, PySeries}; @@ -381,7 +382,7 @@ pub(crate) fn str_to_mitype(mi_type: &str) -> PyResult { } } -#[derive(Clone, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] pub(crate) struct Indexer { pub to_ix: HashMap, pub to_name: HashMap, diff --git a/pylace/tests/test_pickle.py b/pylace/tests/test_pickle.py new file mode 100644 index 00000000..b23d941f --- /dev/null +++ b/pylace/tests/test_pickle.py @@ -0,0 +1,14 @@ +import pickle + +from lace import examples + + +def test_pickle_engine(): + engine = examples.Animals().engine + s = pickle.dumps(engine) + engine_b = pickle.loads(s) + + sim_a = engine.simulate(["swims", "flys"], n=10) + sim_b = engine_b.simulate(["swims", "flys"], n=10) + + assert sim_a.equals(sim_b)