Skip to content

Commit

Permalink
adding a lot of callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
jizong committed Jul 8, 2020
1 parent 8a5b6ab commit 358d332
Show file tree
Hide file tree
Showing 22 changed files with 670 additions and 131 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
### Example user template
.data
runs
demo/MNIST/
# IntelliJ project files
.idea
*.iml
Expand Down
5 changes: 5 additions & 0 deletions contrastyou/callbacks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .scheduler_callback import *
from .tensorboard_callback import *
from .toconsole_callback import *
from ._callback import *
from .storage_callback import *
60 changes: 60 additions & 0 deletions contrastyou/callbacks/_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import weakref


class EpochCallBacks:

def __init__(self, train_callbacks=None, val_callbacks=None, test_callbacks=None) -> None:
self._train_callbacks = train_callbacks
self._val_callbacks = val_callbacks
self._test_callbacks = test_callbacks
if train_callbacks:
for c in self._train_callbacks:
assert isinstance(c, _EpochCallack), c
if val_callbacks:
for c in self._val_callbacks:
assert isinstance(c, _EpochCallack), c
if test_callbacks:
for c in self._test_callbacks:
assert isinstance(c, _EpochCallack), c


class _EpochCallack:
"""
callback for epocher
"""

def set_epocher(self, epocher):
self._epocher = weakref.proxy(epocher)

def before_run(self):
pass

def after_run(self, *args, **kwargs):
pass

def before_step(self):
pass

def after_step(self, *args, **kwargs):
pass


class _TrainerCallback:
"""
callbacks for trainer
"""

def set_trainer(self, trainer):
self._trainer = weakref.proxy(trainer)

def before_train(self, *args, **kwargs):
pass

def after_train(self, *args, **kwargs):
pass

def before_epoch(self, *args, **kwargs):
pass

def after_epoch(self, *args, **kwargs):
pass
9 changes: 9 additions & 0 deletions contrastyou/callbacks/scheduler_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ._callback import _TrainerCallback


class SchedulerCallback(_TrainerCallback):

def after_epoch(self, *args, **kwargs):
scheduler = self._trainer._model.scheduler
if scheduler:
scheduler.step()
15 changes: 15 additions & 0 deletions contrastyou/callbacks/storage_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from ._callback import _TrainerCallback
from ..trainer._trainer import EpochResult


class StorageCallback(_TrainerCallback):
def after_epoch(self, epoch_result: EpochResult = None, *args, **kwargs):
if epoch_result:
storage = self._trainer._storage
if epoch_result.train_result:
storage.put_all({"tra": epoch_result.train_result})
if epoch_result.val_result:
storage.put_all({"val": epoch_result.val_result})
if epoch_result.test_result:
storage.put_all({"test": epoch_result.test_result})
storage = None
24 changes: 24 additions & 0 deletions contrastyou/callbacks/tensorboard_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from contrastyou.callbacks._callback import _TrainerCallback
from contrastyou.helper import flatten_dict
from contrastyou.trainer._trainer import EpochResult
from contrastyou.writer import SummaryWriter


class SummaryCallback(_TrainerCallback):
def __init__(self, log_dir=None) -> None:
self._writer = SummaryWriter(log_dir)

def after_epoch(self, epoch_result: EpochResult = None, *args, **kwargs):

current_epoch = self._trainer._cur_epoch
if epoch_result.train_result:
self._writer.add_scalar_with_tag(tag="tra", tag_scalar_dict=flatten_dict(epoch_result.train_result),
global_step=current_epoch)

if epoch_result.val_result:
self._writer.add_scalar_with_tag(tag="val", tag_scalar_dict=flatten_dict(epoch_result.val_result),
global_step=current_epoch)

if epoch_result.test_result:
self._writer.add_scalar_with_tag(tag="test", tag_scalar_dict=flatten_dict(epoch_result.test_result),
global_step=current_epoch)
49 changes: 49 additions & 0 deletions contrastyou/callbacks/toconsole_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys

from tqdm import tqdm

from contrastyou.callbacks._callback import _EpochCallack, _TrainerCallback
from contrastyou.helper import flatten_dict, nice_dict


class TQDMCallback(_EpochCallack):

def __init__(self, indicator_length=0, frequency_print=10) -> None:
self._indicator_length = indicator_length
self._frequency_print = frequency_print

def before_run(self):
self._indicator = tqdm(ncols=10, leave=True, dynamic_ncols=True)
if self._indicator_length > 0:
self._indicator = tqdm(total=self._indicator_length, ncols=10, leave=True, dynamic_ncols=True)
self._n = 0

def after_step(self, report_dict=None, *args, **kwargs):
self._indicator.update(1)
self._n += 1
if self._n % self._frequency_print == 0:
class_name = self._epocher.__class__.__name__
current_epoch = self._epocher._cur_epoch
report_dict = flatten_dict(report_dict)
self._indicator.set_description(f"{class_name} Epoch {current_epoch:03d}")
self._indicator.set_postfix(report_dict)
sys.stdout.flush()

def after_run(self, *args, **kwargs):
self._indicator.close()


class PrintResultCallback(_EpochCallack, _TrainerCallback):
def after_run(self, report_dict=None, *args, **kwargs):
if report_dict:
class_name = self._epocher.__class__.__name__
cur_epoch = self._epocher._cur_epoch
sys.stdout.flush()
print(f"{class_name} Epoch {cur_epoch}: {nice_dict(flatten_dict(report_dict))}")
sys.stdout.flush()

def after_train(self, *args, **kwargs):
storage = self._trainer._storage
sys.stdout.flush()
print(storage.summary())
sys.stdout.flush()
Empty file added contrastyou/epoch/__init__.py
Empty file.
123 changes: 123 additions & 0 deletions contrastyou/epoch/_epoch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from abc import abstractmethod, ABCMeta
from contextlib import contextmanager
from typing import Union, Dict, List

import torch

from contrastyou.meters2 import MeterInterface
from contrastyou.modules.model import Model
from ..callbacks._callback import _EpochCallack

_REPORT_DICT = Dict[str, Union[float, int]]
REPORT_DICT = Union[_REPORT_DICT, Dict[str, _REPORT_DICT]]


class _EpochMixin:

def __init__(self, *args, **kwargs) -> None:
super(_EpochMixin, self).__init__(*args, **kwargs)
self._callbacks: List[_EpochCallack] = []

def register_callbacks(self, callbacks: List[_EpochCallack]):
if not isinstance(callbacks, list):
callbacks = [callbacks, ]
for i, c in enumerate(callbacks):
if not isinstance(c, _EpochCallack):
raise TypeError(f"callbacks [{i}] should be an instance of {_EpochCallack.__name__}, "
f"given {c.__class__.__name__}.")
c.set_epocher(self)
self._callbacks.append(c)

def run(self, *args, **kwargs) -> REPORT_DICT:
with self._register_meters() as self.meters:
self._before_run()
result = self._run(*args, **kwargs)
self._after_run(report_dict=result)
return result

def step(self, *args, **kwargs) -> REPORT_DICT:
# return accumulated dict by the design
self._before_step()
result = self._step(*args, **kwargs)
self._after_step(report_dict=result)
return result

def _before_run(self):
for c in self._callbacks:
c.before_run()

def _after_run(self, *args, **kwargs):
for c in self._callbacks:
c.after_run(*args, **kwargs)

def _before_step(self):
for c in self._callbacks:
c.before_step()

def _after_step(self, *args, **kwargs):
for c in self._callbacks:
c.after_step(*args, **kwargs)


class _Epoch(metaclass=ABCMeta):

def __init__(self, model: Model, cur_epoch=0, device="cpu") -> None:
super().__init__()
self._model = model
self._device = device
self._cur_epoch = cur_epoch
self.to(self._device)

@classmethod
def create_from_trainer(cls, trainer,*args,**kwargs):
model = trainer._model
device = trainer._device
cur_epoch = trainer._cur_epoch
return cls(model, cur_epoch, device)

@contextmanager
def _register_meters(self):
meters: MeterInterface = MeterInterface()
meters = self._configure_meters(meters)
yield meters

@abstractmethod
def _configure_meters(self, meters: MeterInterface) -> MeterInterface:
# todo: to be overrided to add or delete individual meters
return meters

@abstractmethod
def _run(self, *args, **kwargs) -> REPORT_DICT:
pass

def run(self, *args, **kwargs) -> REPORT_DICT:
with self._register_meters() as self.meters:
return self._run(*args, **kwargs)

@abstractmethod
def _step(self, *args, **kwargs) -> REPORT_DICT:
# return accumulated dict by the design
pass

def step(self, *args, **kwargs) -> REPORT_DICT:
# return accumulated dict by the design
return self._step(*args, **kwargs)

@abstractmethod
def _prepare_batch(self, *args, **kwargs):
pass

def to(self, device="cpu"):
if isinstance(device, str):
device = torch.device(device)
assert isinstance(device, torch.device)
self._model.to(device)
self._device = device

def assert_model_status(self):
assert self._model.training, self._model.training


class Epoch(_EpochMixin, _Epoch):
"""Epocher with Mixin"""
pass
1 change: 1 addition & 0 deletions contrastyou/helper/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .utils import *
42 changes: 42 additions & 0 deletions contrastyou/helper/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# dictionary helper functions
import collections
from typing import Union, Dict

from torch.utils.data.dataloader import DataLoader, _BaseDataLoaderIter


def flatten_dict(d, parent_key="", sep="_"):
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)


def toDataLoaderIterator(loader_or_iter: Union[DataLoader, _BaseDataLoaderIter]):
if not isinstance(loader_or_iter, (_BaseDataLoaderIter, DataLoader)):
raise TypeError(f"{loader_or_iter} should an instance of DataLoader or _BaseDataLoaderIter, "
f"given {loader_or_iter.__class__.__name__}.")
return loader_or_iter if isinstance(loader_or_iter, _BaseDataLoaderIter) else iter(loader_or_iter)

# make a flatten dictionary to be printablely nice.
def nice_dict(input_dict: Dict[str, Union[int, float]]) -> str:
"""
this function is to return a nice string to dictionary displace propose.
:param input_dict: dictionary
:return: string
"""
assert isinstance(
input_dict, dict
), f"{input_dict} should be a dict, given {type(input_dict)}."
is_flat_dict = True
for k, v in input_dict.items():
if isinstance(v, dict):
is_flat_dict = False
break
flat_dict = input_dict if is_flat_dict else flatten_dict(input_dict, sep="")
string_list = [f"{k}:{v:.3f}" for k, v in flat_dict.items()]
return ", ".join(string_list)
15 changes: 3 additions & 12 deletions contrastyou/meters2/meter_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,14 @@ def tracking_status(
assert group_name in self.group
return {
k: v.detailed_summary() if detailed_summary else v.summary()
for k, v in self.individual_meters.items()
for k, v in self.meters.items()
if k in self._group_dicts[group_name]
}
return {
k: v.detailed_summary() if detailed_summary else v.summary()
for k, v in self.individual_meters.items()
for k, v in self.meters.items()
}

def add(self, meter_name, *args, **kwargs):
assert meter_name in self.meter_names
self._ind_meter_dicts[meter_name].add(*args, **kwargs)

def reset(self) -> None:
"""
reset individual meters
Expand All @@ -63,8 +59,7 @@ def __getitem__(self, meter_name: str) -> _Metric:
try:
return self._ind_meter_dicts[meter_name]
except KeyError as e:
print(f"meter_interface.meter_names:{self.meter_names}")
raise e
raise KeyError(f"meter_interface.meter_names:{self.meter_names} with error {e}")

def register_meter(self, name: str, meter: _Metric, group_name=None) -> None:
assert isinstance(name, str), name
Expand Down Expand Up @@ -107,7 +102,3 @@ def meters(self) -> Optional[Dict[str, _Metric]]:
@property
def group(self) -> List[str]:
return sorted(self._group_dicts.keys())

@property
def individual_meters(self):
return self._ind_meter_dicts
Loading

0 comments on commit 358d332

Please sign in to comment.