Skip to content

Commit

Permalink
Add __getnewargs__ to Variable and Derivative to support serial…
Browse files Browse the repository at this point in the history
…ization with `cloudpickle`

Locally-defined, i.e. inside a function, systems cannot be serialized with `pickle`, which uses references,
but can be serialized with `cloudpickle`, which serializes by value.
As Derivative defines its own `__new__` method, it requires to define `__getnewargs_ex__` for pickling.
  • Loading branch information
maurosilber committed Jan 19, 2024
1 parent 0d5d092 commit f8d3b41
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
## unreleased

- Update to `symbolite >= 0.6` to support serialization (pickle) of `System`.
- Add `__getnewargs__` to `Variable` and `Derivative` to support serialization
of "locally-defined" `System`s with `cloudpickle`.

## 0.4.0

Expand Down
1 change: 1 addition & 0 deletions requirements.test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
cloudpickle
jax[cpu]
matplotlib
numba
Expand Down
29 changes: 29 additions & 0 deletions src/poincare/tests/test_serialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle

import cloudpickle
from pytest import mark

from .. import Derivative, System, Variable, initial
Expand Down Expand Up @@ -36,3 +37,31 @@ def test_roundtrip(model: System):
dump = pickle.dumps(model)
load = pickle.loads(dump)
assert load == model


@mark.parametrize("model", models)
def test_local_roundtrip(model: System):
class EmptyModel(System):
pass

class SingleVariable(System):
x: Variable = initial(default=0)

class SingleDerivative(System):
x: Variable = initial(default=0)
v: Derivative = x.derive(initial=0)

class SingleEquation(System):
x: Variable = initial(default=0)
eq = x.derive() << -x

inner_model = {
EmptyModel.__name__: EmptyModel,
SingleVariable.__name__: SingleVariable,
SingleDerivative.__name__: SingleDerivative,
SingleEquation.__name__: SingleEquation,
}[model.__name__]

dump = cloudpickle.dumps(inner_model)
load = cloudpickle.loads(dump)
assert load == inner_model
11 changes: 11 additions & 0 deletions src/poincare/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ def __init__(self, *, initial: Initial | None):
self.derivatives = {}
units.check_units(self, initial)

def __getnewargs__(self):
return (self.initial, self.derivatives, self.equation_order)

def eval(self, libsl=None):
if libsl is libabstract:
return self
Expand Down Expand Up @@ -246,6 +249,14 @@ def __set_name__(self, cls: Node, name: str):
super().__set_name__(cls, name)
self.variable.derivatives[self.order] = self

def __getnewargs_ex__(self):
args = (self.variable,)
kwargs = {
"initial": self.initial,
"order": self.order,
}
return args, kwargs

def eval(self, libsl=None):
if libsl is libabstract:
return self
Expand Down

0 comments on commit f8d3b41

Please sign in to comment.