Skip to content

Commit

Permalink
ckp: implement blocked time stepping
Browse files Browse the repository at this point in the history
  • Loading branch information
speglich committed Apr 28, 2022
1 parent f695c7c commit 28af91b
Showing 1 changed file with 53 additions and 24 deletions.
77 changes: 53 additions & 24 deletions pyrevolve/pyrevolve.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABCMeta, abstractproperty, abstractmethod
import numpy as np
import math
from . import crevolve as cr
from .compression import init_compression as init
from .schedulers import CRevolve, HRevolve, Action, Architecture
Expand Down Expand Up @@ -69,11 +70,12 @@ def __init__(
fwd_operator,
rev_operator,
n_checkpoints,
n_timesteps,
op_timesteps,
storage_list=None,
scheduler=None,
timings=None,
profiler=None,
block_size=1,
):
"""
Initialises checkpointer for a given forward- and reverse operator, a
Expand All @@ -82,10 +84,10 @@ def __init__(
methods and a scheduler object must be provided as well. Otherwise
NumpyStorage and CRevolve are used as default
"""
if n_timesteps is None:
if op_timesteps is None:
raise Exception(
"Online checkpointing not yet supported. Specify \
number of time steps!"
number of Operator time steps!"
)

if profiler is None:
Expand All @@ -100,12 +102,15 @@ def __init__(

self.checkpoint = checkpoint
self.n_checkpoints = n_checkpoints
self.n_timesteps = n_timesteps
self.block_size = block_size
self.op_timesteps = op_timesteps
self.timings = timings
self.fwd_operator = fwd_operator
self.rev_operator = rev_operator
self.scheduler = scheduler

self.cp_timesteps = int(math.ceil(self.op_timesteps / self.block_size))

def addStorage(self, new_storage):
    """Append ``new_storage`` to the end of this object's ``storage_list``."""
    storages = self.storage_list
    storages.append(new_storage)

Expand Down Expand Up @@ -171,6 +176,20 @@ def addByteStorage(self, compression_params):
def makespan(self):
return 0

@property
def op_old_capo(self):
    """Scheduler's ``old_capo`` converted to operator time steps.

    The scheduler position is multiplied by ``block_size``; the result is
    what the operators receive as ``t_start`` when advancing.
    """
    previous_block = self.scheduler.old_capo
    return previous_block * self.block_size

@property
def op_capo(self):
    """Scheduler's ``capo`` converted to operator time steps.

    The scheduler position is multiplied by ``block_size`` and capped at
    ``op_timesteps``. The cap matters when ``op_timesteps`` is not a
    multiple of ``block_size``: the final block is then shorter than
    ``block_size`` and the plain product would overshoot the last step.
    """
    return min(self.scheduler.capo * self.block_size, self.op_timesteps)

@property
def next_op_capo(self):
    """Operator time step at which the block after ``capo`` ends.

    Computed as ``(capo + 1) * block_size`` and capped at ``op_timesteps``
    so that a shorter final block never extends past the last operator
    time step.
    """
    return min((self.scheduler.capo + 1) * self.block_size, self.op_timesteps)

def apply_forward(self):
"""Executes only the forward computation while storing checkpoints,
then returns."""
Expand All @@ -181,7 +200,7 @@ def apply_forward(self):
# advance forward computation
with self.profiler.get_timer("forward", "advance"):
self.fwd_operator.apply(
t_start=self.scheduler.old_capo, t_end=self.scheduler.capo
t_start=self.op_old_capo, t_end=self.op_capo
)
elif action.type == Action.TAKESHOT:
# take a snapshot: copy from workspace into storage
Expand All @@ -195,7 +214,7 @@ def apply_forward(self):
# final step in the forward computation
with self.profiler.get_timer("forward", "lastfw"):
self.fwd_operator.apply(
t_start=self.scheduler.old_capo, t_end=self.n_timesteps
t_start=self.op_old_capo, t_end=self.op_timesteps
)
break
elif action.type == Action.REVERSE:
Expand All @@ -220,10 +239,10 @@ def apply_reverse(self):
# advance adjoint computation by a single step
with self.profiler.get_timer("reverse", "reverse"):
self.fwd_operator.apply(
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
t_start=self.op_capo, t_end=self.next_op_capo
)
self.rev_operator.apply(
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
t_start=self.op_capo, t_end=self.next_op_capo
)
elif action.type == Action.REVSTART:
"""Sets the rev_operator to 'nt' only if its not already there.
Expand All @@ -232,7 +251,7 @@ def apply_reverse(self):
"""
with self.profiler.get_timer("reverse", "reverse"):
self.rev_operator.apply(
t_start=self.scheduler.capo, t_end=self.scheduler.capo + 1
t_start=self.op_capo, t_end=self.next_op_capo
)
elif action.type == Action.TAKESHOT:
# take a snapshot: copy from workspace into storage
Expand All @@ -242,7 +261,7 @@ def apply_reverse(self):
# advance forward computation
with self.profiler.get_timer("reverse", "advance"):
self.fwd_operator.apply(
t_start=self.scheduler.old_capo, t_end=self.scheduler.capo
t_start=self.op_old_capo, t_end=self.op_capo
)
elif action.type == Action.RESTORE:
# restore a snapshot: copy from storage into workspace
Expand Down Expand Up @@ -294,13 +313,14 @@ def __init__(
fwd_operator,
rev_operator,
n_checkpoints,
n_timesteps,
op_timesteps,
timings=None,
profiler=None,
compression_params=None,
diskstorage=False,
filedir="./",
singlefile=True,
block_size=1,
):
"""
Initializes a single-level Revolver
Expand All @@ -309,7 +329,7 @@ def __init__(
fwd_operator: forward operator
rev_operator: backward operator
n_checkpoints: number of checkpoints
n_timesteps: number of timesteps
op_timesteps: number of timesteps
timings: timings
profiler: Profiler
compression_params: compression scheme
Expand All @@ -322,20 +342,21 @@ def __init__(
fwd_operator,
rev_operator,
n_checkpoints,
n_timesteps,
op_timesteps,
timings=timings,
profiler=profiler,
block_size=block_size,
)

self.filedir = filedir
self.singlefile = singlefile

if n_checkpoints is None:
self.n_checkpoints = cr.adjust(n_timesteps)
self.n_checkpoints = cr.adjust(self.cp_timesteps)
else:
self.n_checkpoints = n_checkpoints

self.scheduler = CRevolve(self.n_checkpoints, self.n_timesteps)
self.scheduler = CRevolve(self.n_checkpoints, self.cp_timesteps)

# remove storage list to avoid memory overflow
self.resetStorageList()
Expand Down Expand Up @@ -369,13 +390,14 @@ def __init__(
checkpoint,
fwd_operator,
rev_operator,
n_timesteps,
op_timesteps,
storage_list,
timings=None,
profiler=None,
uf=1,
ub=1,
up=1,
block_size=1,
):
"""
Initializes a multi-level Revolver using HRevolve
Expand All @@ -390,7 +412,7 @@ def __init__(
fwd_operator: forward operator
rev_operator: backward operator
n_checkpoints: number of checkpoints
n_timesteps: number of timesteps
op_timesteps: number of timesteps
timings: timings
profiler: profiler
storage_list: list of storage objects
Expand All @@ -402,12 +424,14 @@ def __init__(
checkpoint,
fwd_operator,
rev_operator,
n_timesteps,
n_timesteps,
op_timesteps,
op_timesteps,
storage_list=storage_list,
timings=timings,
profiler=profiler,
block_size=block_size,
)

self.uf = uf # forward cost (default=1)
self.ub = ub # backward cost (default=1)
self.up = up # turn cost (default=1)
Expand All @@ -431,7 +455,8 @@ def reload_scheduler(self, uf=1, ub=1, up=1):
self.up = up
self.arch = Architecture(self.storage_list)
self.scheduler = HRevolve(
self.n_checkpoints, self.n_timesteps, self.arch, self.uf, self.ub, self.up
self.n_checkpoints, self.cp_timesteps, self.arch, self.uf, self.ub,
self.up
)
else:
raise ValueError(
Expand Down Expand Up @@ -471,21 +496,23 @@ def __init__(
fwd_operator,
rev_operator,
n_checkpoints,
n_timesteps,
op_timesteps,
timings=None,
profiler=None,
compression_params=None,
block_size=1,
):
super().__init__(
checkpoint,
fwd_operator,
rev_operator,
n_checkpoints,
n_timesteps,
op_timesteps,
timings=timings,
profiler=profiler,
compression_params=compression_params,
diskstorage=False,
block_size=block_size,
)


Expand All @@ -510,23 +537,25 @@ def __init__(
fwd_operator,
rev_operator,
n_checkpoints,
n_timesteps,
op_timesteps,
timings=None,
profiler=None,
filedir="./",
singlefile=True,
block_size=1,
):
super().__init__(
checkpoint,
fwd_operator,
rev_operator,
n_checkpoints,
n_timesteps,
op_timesteps,
timings=timings,
profiler=profiler,
diskstorage=True,
filedir=filedir,
singlefile=singlefile,
block_size=block_size,
)


Expand Down

0 comments on commit 28af91b

Please sign in to comment.