Skip to content

Commit

Permalink
Merge pull request #47 from CovertLab/serialize_rng_state
Browse files Browse the repository at this point in the history
Serialize PRNG state
  • Loading branch information
U8NWXD authored Oct 27, 2021
2 parents c91d6fe + 8ab110d commit 197e6e0
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ given duration. `evolve` returns a dictionary with five keys:
* outcome - the final state of the system

```python
result = system.evolve(state, duration, rates)
result = system.evolve(duration, state, rates)
```

If you are interested in the history of states for plotting or otherwise, these can be
Expand Down
8 changes: 5 additions & 3 deletions arrow/arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def flat_indexes(assorted_lists):
conjunction with the flat array to recover the original list of lists.
Args:
assorted_lists (List[List]): A list of variable length lists.
assorted_lists (List[List]): A list of variable length lists.
Returns numpy arrays:
flat: The flattened data.
Expand Down Expand Up @@ -78,7 +78,7 @@ class StochasticSystem(object):
The stoichiometric matrix has a reaction for each row, with the values in that row
encoding how many of each substrate are either consumed or produced by the reaction
(and zero everywhere else).
(and zero everywhere else).
'''

def __init__(self, stoichiometry, random_seed=0):
Expand Down Expand Up @@ -134,6 +134,7 @@ def __getstate__(self):

return (
self.random_seed,
self.obsidian.get_random_state(),
self.stoichiometry,
self.reactants_lengths,
self.reactants_indexes,
Expand All @@ -151,7 +152,7 @@ def __setstate__(self, state):
Import from pickled state.
'''

self.random_seed, self.stoichiometry, self.reactants_lengths, self.reactants_indexes, self.reactants_flat, self.reactions_flat, self.dependencies_lengths, self.dependencies_indexes, self.dependencies_flat, self.substrates_lengths, self.substrates_indexes, self.substrates_flat = state
self.random_seed, random_state, self.stoichiometry, self.reactants_lengths, self.reactants_indexes, self.reactants_flat, self.reactions_flat, self.dependencies_lengths, self.dependencies_indexes, self.dependencies_flat, self.substrates_lengths, self.substrates_indexes, self.substrates_flat = state

self.obsidian = Arrowhead(
self.random_seed,
Expand All @@ -166,6 +167,7 @@ def __setstate__(self, state):
self.substrates_lengths,
self.substrates_indexes,
self.substrates_flat)
self.obsidian.set_random_state(*random_state)

def evolve(self, duration, state, rates):
status, steps, time, events, outcome = self.obsidian.evolve(duration, state, rates)
Expand Down
28 changes: 27 additions & 1 deletion arrow/arrowhead.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from __future__ import absolute_import, division, print_function

cimport cython
from cpython.mem cimport PyMem_Malloc, PyMem_Free
from libc.stdint cimport int64_t
from libc.stdint cimport int64_t, uint32_t
from libc.string cimport memset, memcpy
from libc.stdlib cimport free

Expand Down Expand Up @@ -132,6 +132,32 @@ cdef class Arrowhead:
"""Returns the number of substrates this system operates on."""
return self.info.substrates_count

def get_random_state(self):
"""Returns the state of the pseudorandom number generator."""
cdef mersenne.MTState state
obsidian.get_random_state(&self.info, &state)

mt = copy_c_array(
&state.MT[0], mersenne.TWISTER_SIZE, sizeof(uint32_t),
np.NPY_UINT32)
mt_tempered = copy_c_array(
&state.MT_TEMPERED[0], mersenne.TWISTER_SIZE, sizeof(uint32_t),
np.NPY_UINT32)
index = state.index

return mt, mt_tempered, index

def set_random_state(
self, uint32_t[::1] mt, uint32_t[::1] mt_tempered,
size_t index):
cdef mersenne.MTState state
memcpy(&state.MT[0], &mt[0], sizeof(state.MT))
memcpy(
&state.MT_TEMPERED[0], &mt_tempered[0],
sizeof(state.MT_TEMPERED))
state.index = index
obsidian.set_random_state(&self.info, &state)


cdef np.ndarray copy_c_array(
void *source, np.npy_intp element_count, size_t element_size, int np_typenum):
Expand Down
8 changes: 7 additions & 1 deletion arrow/mersenne.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,14 @@ from libc.stdint cimport uint32_t

cdef extern from "mersenne.h":

size_t TWISTER_SIZE

ctypedef struct MTState:
pass
# mersenne.h defines TWISTER_SIZE to be 624.
uint32_t MT[624]
uint32_t MT_TEMPERED[624]
size_t index


void seed(MTState *state, uint32_t seed_value)

Expand Down
10 changes: 9 additions & 1 deletion arrow/obsidian.c
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ evolve_result evolve(Info *info, double duration, int64_t *state, double *rates)
choice = -1;
break;

// Otherwise we need to find the next reaction to perform.
// Otherwise we need to find the next reaction to perform.
} else {

// First, sample two random values, `point` from a linear distribution and
Expand Down Expand Up @@ -310,6 +310,14 @@ evolve_result evolve(Info *info, double duration, int64_t *state, double *rates)
return result;
}

void get_random_state(const Info *info, MTState *exported_random_state) {
memcpy(exported_random_state, info->random_state, sizeof(MTState));
}

void set_random_state(Info *info, const MTState *state) {
memcpy(info->random_state, state, sizeof(MTState));
}

// Print an array of doubles
int
print_array(double *array, int length) {
Expand Down
4 changes: 4 additions & 0 deletions arrow/obsidian.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ typedef struct Info {
// arrays that the caller must free().
evolve_result evolve(Info *info, double duration, int64_t *state, double *rates);

void get_random_state(const Info *info, MTState *exported_random_state);

void set_random_state(Info *info, const MTState *state);

// Supporting print utilities
int print_array(double *array, int length);

Expand Down
6 changes: 5 additions & 1 deletion arrow/obsidian.pxd
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# cython: language_level=3str

from libc.stdint cimport int64_t
from libc.stdint cimport int64_t, uint32_t

from mersenne cimport MTState

Expand Down Expand Up @@ -37,4 +37,8 @@ cdef extern from "obsidian.h":

evolve_result evolve(Info *info, double duration, int64_t *state, double *rates)

void get_random_state(Info *info, MTState *exported_random_state)

void set_random_state(Info *info, MTState *state)

int print_array(double *array, int length)
26 changes: 26 additions & 0 deletions arrow/test/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,32 @@ def test_flagella():

print('flagella result: {}'.format(result))

def test_get_set_random_state():
stoich = np.array([[1, 1, -1, 0], [-2, 0, 0, 1], [-1, -1, 1, 0]])
system = StochasticSystem(stoich)

state = np.array([1000, 1000, 0, 0])
rates = np.array([3.0, 1.0, 1.0])

system.evolve(1, state, rates)

rand_state = system.obsidian.get_random_state()

result_1 = system.evolve(1, state, rates)
result_2 = system.evolve(1, state, rates)

with np.testing.assert_raises(AssertionError):
for key in ('time', 'events', 'occurrences', 'outcome'):
np.testing.assert_array_equal(
result_1[key], result_2[key])

system.obsidian.set_random_state(*rand_state)
result_1_again = system.evolve(1, state, rates)

for key in ('time', 'events', 'occurrences', 'outcome'):
np.testing.assert_array_equal(
result_1[key], result_1_again[key])


def main(args):
systems = (
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@

setup(
name='stochastic-arrow',
version='0.4.3',
version='0.4.4',
packages=['arrow'],
author='Ryan Spangler, John Mason, Jerry Morrison',
author_email='[email protected]',
Expand Down

0 comments on commit 197e6e0

Please sign in to comment.