-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
jizong
committed
Jul 8, 2020
1 parent
8a5b6ab
commit 358d332
Showing
22 changed files
with
670 additions
and
131 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
### Example user template | ||
.data | ||
runs | ||
demo/MNIST/ | ||
# IntelliJ project files | ||
.idea | ||
*.iml | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .scheduler_callback import * | ||
from .tensorboard_callback import * | ||
from .toconsole_callback import * | ||
from ._callback import * | ||
from .storage_callback import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import weakref | ||
|
||
|
||
class EpochCallBacks: | ||
|
||
def __init__(self, train_callbacks=None, val_callbacks=None, test_callbacks=None) -> None: | ||
self._train_callbacks = train_callbacks | ||
self._val_callbacks = val_callbacks | ||
self._test_callbacks = test_callbacks | ||
if train_callbacks: | ||
for c in self._train_callbacks: | ||
assert isinstance(c, _EpochCallack), c | ||
if val_callbacks: | ||
for c in self._val_callbacks: | ||
assert isinstance(c, _EpochCallack), c | ||
if test_callbacks: | ||
for c in self._test_callbacks: | ||
assert isinstance(c, _EpochCallack), c | ||
|
||
|
||
class _EpochCallack: | ||
""" | ||
callback for epocher | ||
""" | ||
|
||
def set_epocher(self, epocher): | ||
self._epocher = weakref.proxy(epocher) | ||
|
||
def before_run(self): | ||
pass | ||
|
||
def after_run(self, *args, **kwargs): | ||
pass | ||
|
||
def before_step(self): | ||
pass | ||
|
||
def after_step(self, *args, **kwargs): | ||
pass | ||
|
||
|
||
class _TrainerCallback: | ||
""" | ||
callbacks for trainer | ||
""" | ||
|
||
def set_trainer(self, trainer): | ||
self._trainer = weakref.proxy(trainer) | ||
|
||
def before_train(self, *args, **kwargs): | ||
pass | ||
|
||
def after_train(self, *args, **kwargs): | ||
pass | ||
|
||
def before_epoch(self, *args, **kwargs): | ||
pass | ||
|
||
def after_epoch(self, *args, **kwargs): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from ._callback import _TrainerCallback | ||
|
||
|
||
class SchedulerCallback(_TrainerCallback): | ||
|
||
def after_epoch(self, *args, **kwargs): | ||
scheduler = self._trainer._model.scheduler | ||
if scheduler: | ||
scheduler.step() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from ._callback import _TrainerCallback | ||
from ..trainer._trainer import EpochResult | ||
|
||
|
||
class StorageCallback(_TrainerCallback): | ||
def after_epoch(self, epoch_result: EpochResult = None, *args, **kwargs): | ||
if epoch_result: | ||
storage = self._trainer._storage | ||
if epoch_result.train_result: | ||
storage.put_all({"tra": epoch_result.train_result}) | ||
if epoch_result.val_result: | ||
storage.put_all({"val": epoch_result.val_result}) | ||
if epoch_result.test_result: | ||
storage.put_all({"test": epoch_result.test_result}) | ||
storage = None |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from contrastyou.callbacks._callback import _TrainerCallback | ||
from contrastyou.helper import flatten_dict | ||
from contrastyou.trainer._trainer import EpochResult | ||
from contrastyou.writer import SummaryWriter | ||
|
||
|
||
class SummaryCallback(_TrainerCallback): | ||
def __init__(self, log_dir=None) -> None: | ||
self._writer = SummaryWriter(log_dir) | ||
|
||
def after_epoch(self, epoch_result: EpochResult = None, *args, **kwargs): | ||
|
||
current_epoch = self._trainer._cur_epoch | ||
if epoch_result.train_result: | ||
self._writer.add_scalar_with_tag(tag="tra", tag_scalar_dict=flatten_dict(epoch_result.train_result), | ||
global_step=current_epoch) | ||
|
||
if epoch_result.val_result: | ||
self._writer.add_scalar_with_tag(tag="val", tag_scalar_dict=flatten_dict(epoch_result.val_result), | ||
global_step=current_epoch) | ||
|
||
if epoch_result.test_result: | ||
self._writer.add_scalar_with_tag(tag="test", tag_scalar_dict=flatten_dict(epoch_result.test_result), | ||
global_step=current_epoch) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import sys | ||
|
||
from tqdm import tqdm | ||
|
||
from contrastyou.callbacks._callback import _EpochCallack, _TrainerCallback | ||
from contrastyou.helper import flatten_dict, nice_dict | ||
|
||
|
||
class TQDMCallback(_EpochCallack): | ||
|
||
def __init__(self, indicator_length=0, frequency_print=10) -> None: | ||
self._indicator_length = indicator_length | ||
self._frequency_print = frequency_print | ||
|
||
def before_run(self): | ||
self._indicator = tqdm(ncols=10, leave=True, dynamic_ncols=True) | ||
if self._indicator_length > 0: | ||
self._indicator = tqdm(total=self._indicator_length, ncols=10, leave=True, dynamic_ncols=True) | ||
self._n = 0 | ||
|
||
def after_step(self, report_dict=None, *args, **kwargs): | ||
self._indicator.update(1) | ||
self._n += 1 | ||
if self._n % self._frequency_print == 0: | ||
class_name = self._epocher.__class__.__name__ | ||
current_epoch = self._epocher._cur_epoch | ||
report_dict = flatten_dict(report_dict) | ||
self._indicator.set_description(f"{class_name} Epoch {current_epoch:03d}") | ||
self._indicator.set_postfix(report_dict) | ||
sys.stdout.flush() | ||
|
||
def after_run(self, *args, **kwargs): | ||
self._indicator.close() | ||
|
||
|
||
class PrintResultCallback(_EpochCallack, _TrainerCallback): | ||
def after_run(self, report_dict=None, *args, **kwargs): | ||
if report_dict: | ||
class_name = self._epocher.__class__.__name__ | ||
cur_epoch = self._epocher._cur_epoch | ||
sys.stdout.flush() | ||
print(f"{class_name} Epoch {cur_epoch}: {nice_dict(flatten_dict(report_dict))}") | ||
sys.stdout.flush() | ||
|
||
def after_train(self, *args, **kwargs): | ||
storage = self._trainer._storage | ||
sys.stdout.flush() | ||
print(storage.summary()) | ||
sys.stdout.flush() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from abc import abstractmethod, ABCMeta | ||
from contextlib import contextmanager | ||
from typing import Union, Dict, List | ||
|
||
import torch | ||
|
||
from contrastyou.meters2 import MeterInterface | ||
from contrastyou.modules.model import Model | ||
from ..callbacks._callback import _EpochCallack | ||
|
||
_REPORT_DICT = Dict[str, Union[float, int]] | ||
REPORT_DICT = Union[_REPORT_DICT, Dict[str, _REPORT_DICT]] | ||
|
||
|
||
class _EpochMixin: | ||
|
||
def __init__(self, *args, **kwargs) -> None: | ||
super(_EpochMixin, self).__init__(*args, **kwargs) | ||
self._callbacks: List[_EpochCallack] = [] | ||
|
||
def register_callbacks(self, callbacks: List[_EpochCallack]): | ||
if not isinstance(callbacks, list): | ||
callbacks = [callbacks, ] | ||
for i, c in enumerate(callbacks): | ||
if not isinstance(c, _EpochCallack): | ||
raise TypeError(f"callbacks [{i}] should be an instance of {_EpochCallack.__name__}, " | ||
f"given {c.__class__.__name__}.") | ||
c.set_epocher(self) | ||
self._callbacks.append(c) | ||
|
||
def run(self, *args, **kwargs) -> REPORT_DICT: | ||
with self._register_meters() as self.meters: | ||
self._before_run() | ||
result = self._run(*args, **kwargs) | ||
self._after_run(report_dict=result) | ||
return result | ||
|
||
def step(self, *args, **kwargs) -> REPORT_DICT: | ||
# return accumulated dict by the design | ||
self._before_step() | ||
result = self._step(*args, **kwargs) | ||
self._after_step(report_dict=result) | ||
return result | ||
|
||
def _before_run(self): | ||
for c in self._callbacks: | ||
c.before_run() | ||
|
||
def _after_run(self, *args, **kwargs): | ||
for c in self._callbacks: | ||
c.after_run(*args, **kwargs) | ||
|
||
def _before_step(self): | ||
for c in self._callbacks: | ||
c.before_step() | ||
|
||
def _after_step(self, *args, **kwargs): | ||
for c in self._callbacks: | ||
c.after_step(*args, **kwargs) | ||
|
||
|
||
class _Epoch(metaclass=ABCMeta): | ||
|
||
def __init__(self, model: Model, cur_epoch=0, device="cpu") -> None: | ||
super().__init__() | ||
self._model = model | ||
self._device = device | ||
self._cur_epoch = cur_epoch | ||
self.to(self._device) | ||
|
||
@classmethod | ||
def create_from_trainer(cls, trainer,*args,**kwargs): | ||
model = trainer._model | ||
device = trainer._device | ||
cur_epoch = trainer._cur_epoch | ||
return cls(model, cur_epoch, device) | ||
|
||
@contextmanager | ||
def _register_meters(self): | ||
meters: MeterInterface = MeterInterface() | ||
meters = self._configure_meters(meters) | ||
yield meters | ||
|
||
@abstractmethod | ||
def _configure_meters(self, meters: MeterInterface) -> MeterInterface: | ||
# todo: to be overrided to add or delete individual meters | ||
return meters | ||
|
||
@abstractmethod | ||
def _run(self, *args, **kwargs) -> REPORT_DICT: | ||
pass | ||
|
||
def run(self, *args, **kwargs) -> REPORT_DICT: | ||
with self._register_meters() as self.meters: | ||
return self._run(*args, **kwargs) | ||
|
||
@abstractmethod | ||
def _step(self, *args, **kwargs) -> REPORT_DICT: | ||
# return accumulated dict by the design | ||
pass | ||
|
||
def step(self, *args, **kwargs) -> REPORT_DICT: | ||
# return accumulated dict by the design | ||
return self._step(*args, **kwargs) | ||
|
||
@abstractmethod | ||
def _prepare_batch(self, *args, **kwargs): | ||
pass | ||
|
||
def to(self, device="cpu"): | ||
if isinstance(device, str): | ||
device = torch.device(device) | ||
assert isinstance(device, torch.device) | ||
self._model.to(device) | ||
self._device = device | ||
|
||
def assert_model_status(self): | ||
assert self._model.training, self._model.training | ||
|
||
|
||
class Epoch(_EpochMixin, _Epoch): | ||
"""Epocher with Mixin""" | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .utils import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# dictionary helper functions | ||
import collections | ||
from typing import Union, Dict | ||
|
||
from torch.utils.data.dataloader import DataLoader, _BaseDataLoaderIter | ||
|
||
|
||
def flatten_dict(d, parent_key="", sep="_"): | ||
items = [] | ||
for k, v in d.items(): | ||
new_key = parent_key + sep + k if parent_key else k | ||
if isinstance(v, collections.MutableMapping): | ||
items.extend(flatten_dict(v, new_key, sep=sep).items()) | ||
else: | ||
items.append((new_key, v)) | ||
return dict(items) | ||
|
||
|
||
def toDataLoaderIterator(loader_or_iter: Union[DataLoader, _BaseDataLoaderIter]): | ||
if not isinstance(loader_or_iter, (_BaseDataLoaderIter, DataLoader)): | ||
raise TypeError(f"{loader_or_iter} should an instance of DataLoader or _BaseDataLoaderIter, " | ||
f"given {loader_or_iter.__class__.__name__}.") | ||
return loader_or_iter if isinstance(loader_or_iter, _BaseDataLoaderIter) else iter(loader_or_iter) | ||
|
||
# make a flatten dictionary to be printablely nice. | ||
def nice_dict(input_dict: Dict[str, Union[int, float]]) -> str: | ||
""" | ||
this function is to return a nice string to dictionary displace propose. | ||
:param input_dict: dictionary | ||
:return: string | ||
""" | ||
assert isinstance( | ||
input_dict, dict | ||
), f"{input_dict} should be a dict, given {type(input_dict)}." | ||
is_flat_dict = True | ||
for k, v in input_dict.items(): | ||
if isinstance(v, dict): | ||
is_flat_dict = False | ||
break | ||
flat_dict = input_dict if is_flat_dict else flatten_dict(input_dict, sep="") | ||
string_list = [f"{k}:{v:.3f}" for k, v in flat_dict.items()] | ||
return ", ".join(string_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.