From 49b83cd5f2e244ffa2a7a3902013f68ec11e33f3 Mon Sep 17 00:00:00 2001 From: nmatthews-asapp Date: Wed, 29 Jan 2020 10:04:51 -0500 Subject: [PATCH 1/2] add experiment flag to pickle only on checkpoints --- flambe/experiment/experiment.py | 7 +++++-- flambe/experiment/tune_adapter.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/flambe/experiment/experiment.py b/flambe/experiment/experiment.py index d120337b..e6a85351 100644 --- a/flambe/experiment/experiment.py +++ b/flambe/experiment/experiment.py @@ -106,7 +106,8 @@ def __init__(self, max_failures: int = 1, stop_on_failure: bool = True, merge_plot: bool = True, - user_provider: Callable[[], str] = None) -> None: + user_provider: Callable[[], str] = None, + pickle_checkpoints: bool = False) -> None: super().__init__(env=env, user_provider=user_provider) self.name = name @@ -152,6 +153,7 @@ def __init__(self, raise TypeError("Pipeline argument is not of type Dict[str, Schema]. " f"Got {type(pipeline).__name__} instead") self.pipeline = pipeline + self.pickle_checkpoints = pickle_checkpoints def process_resources( self, @@ -363,7 +365,8 @@ def trial_name_creator(trial): 'global_vars': resources, 'verbose': verbose, 'custom_modules': list(self.extensions.keys()), - 'debug': debug} + 'debug': debug, + 'pickle_checkpoints': self.pickle_checkpoints} # Filter out the tensorboard logger as we handle # general and tensorboard-specific logging ourselves tune_loggers = list(filter(lambda l: l != tf2_compat_logger and # noqa: E741 diff --git a/flambe/experiment/tune_adapter.py b/flambe/experiment/tune_adapter.py index 39ec7d88..a97f8ba3 100644 --- a/flambe/experiment/tune_adapter.py +++ b/flambe/experiment/tune_adapter.py @@ -36,6 +36,7 @@ def _setup(self, config: Dict): self.verbose = config['verbose'] self.hyper_params = config['hyper_params'] self.debug = config['debug'] + self.pickle_checkpoints = config['pickle_checkpoints'] with TrialLogging(log_dir=self.logdir, verbose=self.verbose, @@ -152,7 
+153,7 @@ def _train(self) -> Dict: def _save(self, checkpoint_dir: str) -> str: """Subclasses should override this to implement save().""" path = os.path.join(checkpoint_dir, "checkpoint.flambe") - self.block.save(path, overwrite=True) + self.block.save(path, pickle_only=self.pickle_checkpoints, overwrite=True) return path def _restore(self, checkpoint: str) -> None: From 75e988086c794e58991e6e70e75960a55916d42b Mon Sep 17 00:00:00 2001 From: nmatthews-asapp Date: Thu, 30 Jan 2020 14:36:40 -0500 Subject: [PATCH 2/2] change pickle only flag to use torch in tune adapter --- flambe/experiment/tune_adapter.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/flambe/experiment/tune_adapter.py b/flambe/experiment/tune_adapter.py index a97f8ba3..268f9287 100644 --- a/flambe/experiment/tune_adapter.py +++ b/flambe/experiment/tune_adapter.py @@ -8,6 +8,8 @@ import shutil import ray +import torch +import dill from flambe.compile import load_state_from_file, Schema, Component from flambe.compile.extensions import setup_default_modules, import_modules @@ -153,13 +155,19 @@ def _train(self) -> Dict: def _save(self, checkpoint_dir: str) -> str: """Subclasses should override this to implement save().""" path = os.path.join(checkpoint_dir, "checkpoint.flambe") - self.block.save(path, pickle_only=self.pickle_checkpoints, overwrite=True) + if self.pickle_checkpoints: + torch.save(self.block, path, pickle_module=dill) + else: + self.block.save(path, overwrite=True) return path def _restore(self, checkpoint: str) -> None: """Subclasses should override this to implement restore().""" - state = load_state_from_file(checkpoint) - self.block.load_state(state) + if self.pickle_checkpoints: + self.block = torch.load(checkpoint, pickle_module=dill) + else: + state = load_state_from_file(checkpoint) + self.block.load_state(state) def _stop(self): """Subclasses should override this for any cleanup on stop."""