Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add backup writing to vivarium simulation Context #455

Merged
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
"networkx",
"loguru",
"pyarrow",
"dill",
# Type stubs
"pandas-stubs",
]
Expand Down
10 changes: 9 additions & 1 deletion src/vivarium/framework/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pprint import pformat
from typing import Any, Dict, List, Optional, Set, Union

import dill
import numpy as np
import pandas as pd
import yaml
Expand Down Expand Up @@ -259,7 +260,7 @@ def step(self) -> None:
self._clock.step_forward(self.get_population().index)

def run(self) -> None:
while self._clock.time < self._clock.stop_time:
while not self.past_stop_time():
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
self.step()

def finalize(self) -> None:
Expand Down Expand Up @@ -297,6 +298,10 @@ def _write_results(self) -> None:
except ConfigurationKeyError:
self._logger.info("No results directory set; results are not written to disk.")

def write_backup(self, backup_path: Path) -> None:
with open(backup_path, "wb") as f:
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
dill.dump(self, f)

def get_performance_metrics(self) -> pd.DataFrame:
timing_dict = self._lifecycle.timings
total_time = np.sum([np.sum(v) for v in timing_dict.values()])
Expand All @@ -322,6 +327,9 @@ def add_components(self, component_list: List[Component]) -> None:
def get_population(self, untracked: bool = True) -> pd.DataFrame:
return self._population.get_population(untracked)

def past_stop_time(self) -> bool:
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
return self._clock.time >= self._clock.stop_time

def __repr__(self):
return f"SimulationContext({self.name})"

Expand Down
12 changes: 12 additions & 0 deletions tests/framework/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import Dict, List

import dill
import pandas as pd
import pytest

Expand Down Expand Up @@ -347,6 +348,17 @@ def test_SimulationContext_report_write(SimulationContext, base_config, componen
assert results.equals(written_results)


@pytest.mark.skip(reason="TODO: Figure out how to make Dill serialize in pytest")
patricktnast marked this conversation as resolved.
Show resolved Hide resolved
def test_SimulationContext_write_backup(SimulationContext, tmpdir):
sim = SimulationContext()
backup_path = tmpdir / "backup.pkl"
sim.write_backup(backup_path)
assert backup_path.exists()
with open(backup_path, "rb") as f:
sim_backup = dill.load(f)
assert isinstance(sim_backup, SimulationContext)

patricktnast marked this conversation as resolved.
Show resolved Hide resolved

def test_get_results_formatting(SimulationContext, base_config):
"""Test formatted results are as expected"""
components = [
Expand Down
Loading