From 358d332c493758e8014a4e3ed6709eddff48ccea Mon Sep 17 00:00:00 2001 From: jizong Date: Wed, 8 Jul 2020 00:51:30 -0400 Subject: [PATCH] adding a lot of callbacks --- .gitignore | 1 + contrastyou/callbacks/__init__.py | 5 + contrastyou/callbacks/_callback.py | 60 +++++++ contrastyou/callbacks/scheduler_callback.py | 9 + contrastyou/callbacks/storage_callback.py | 15 ++ contrastyou/callbacks/tensorboard_callback.py | 24 +++ contrastyou/callbacks/toconsole_callback.py | 49 +++++ contrastyou/epoch/__init__.py | 0 contrastyou/epoch/_epoch.py | 123 +++++++++++++ contrastyou/helper/__init__.py | 1 + contrastyou/helper/utils.py | 42 +++++ contrastyou/meters2/meter_interface.py | 15 +- contrastyou/modules/model.py | 9 +- contrastyou/storage/_historical_container.py | 3 + contrastyou/storage/storage.py | 20 +-- contrastyou/trainer/_buffer.py | 2 +- contrastyou/trainer/_epoch.py | 54 ------ contrastyou/trainer/_trainer.py | 126 ++++++++----- contrastyou/writer/tensorboard.py | 7 +- demo/demo.py | 167 ++++++++++++++++++ demo/demo.yaml | 13 ++ test/contextmanager.py | 56 ++++++ 22 files changed, 670 insertions(+), 131 deletions(-) create mode 100644 contrastyou/callbacks/__init__.py create mode 100644 contrastyou/callbacks/_callback.py create mode 100644 contrastyou/callbacks/scheduler_callback.py create mode 100644 contrastyou/callbacks/storage_callback.py create mode 100644 contrastyou/callbacks/tensorboard_callback.py create mode 100644 contrastyou/callbacks/toconsole_callback.py create mode 100644 contrastyou/epoch/__init__.py create mode 100644 contrastyou/epoch/_epoch.py create mode 100644 contrastyou/helper/utils.py delete mode 100644 contrastyou/trainer/_epoch.py create mode 100644 demo/demo.py create mode 100644 demo/demo.yaml create mode 100644 test/contextmanager.py diff --git a/.gitignore b/.gitignore index f70be632..6a2ed010 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ ### Example user template .data runs +demo/MNIST/ # IntelliJ project files .idea 
*.iml diff --git a/contrastyou/callbacks/__init__.py b/contrastyou/callbacks/__init__.py new file mode 100644 index 00000000..c5c1c957 --- /dev/null +++ b/contrastyou/callbacks/__init__.py @@ -0,0 +1,5 @@ +from .scheduler_callback import * +from .tensorboard_callback import * +from .toconsole_callback import * +from ._callback import * +from .storage_callback import * \ No newline at end of file diff --git a/contrastyou/callbacks/_callback.py b/contrastyou/callbacks/_callback.py new file mode 100644 index 00000000..303d6431 --- /dev/null +++ b/contrastyou/callbacks/_callback.py @@ -0,0 +1,60 @@ +import weakref + + +class EpochCallBacks: + + def __init__(self, train_callbacks=None, val_callbacks=None, test_callbacks=None) -> None: + self._train_callbacks = train_callbacks + self._val_callbacks = val_callbacks + self._test_callbacks = test_callbacks + if train_callbacks: + for c in self._train_callbacks: + assert isinstance(c, _EpochCallack), c + if val_callbacks: + for c in self._val_callbacks: + assert isinstance(c, _EpochCallack), c + if test_callbacks: + for c in self._test_callbacks: + assert isinstance(c, _EpochCallack), c + + +class _EpochCallack: + """ + callback for epocher + """ + + def set_epocher(self, epocher): + self._epocher = weakref.proxy(epocher) + + def before_run(self): + pass + + def after_run(self, *args, **kwargs): + pass + + def before_step(self): + pass + + def after_step(self, *args, **kwargs): + pass + + +class _TrainerCallback: + """ + callbacks for trainer + """ + + def set_trainer(self, trainer): + self._trainer = weakref.proxy(trainer) + + def before_train(self, *args, **kwargs): + pass + + def after_train(self, *args, **kwargs): + pass + + def before_epoch(self, *args, **kwargs): + pass + + def after_epoch(self, *args, **kwargs): + pass diff --git a/contrastyou/callbacks/scheduler_callback.py b/contrastyou/callbacks/scheduler_callback.py new file mode 100644 index 00000000..bbc0334c --- /dev/null +++ 
b/contrastyou/callbacks/scheduler_callback.py @@ -0,0 +1,9 @@ +from ._callback import _TrainerCallback + + +class SchedulerCallback(_TrainerCallback): + + def after_epoch(self, *args, **kwargs): + scheduler = self._trainer._model.scheduler + if scheduler: + scheduler.step() diff --git a/contrastyou/callbacks/storage_callback.py b/contrastyou/callbacks/storage_callback.py new file mode 100644 index 00000000..5ca8e0b5 --- /dev/null +++ b/contrastyou/callbacks/storage_callback.py @@ -0,0 +1,15 @@ +from ._callback import _TrainerCallback +from ..trainer._trainer import EpochResult + + +class StorageCallback(_TrainerCallback): + def after_epoch(self, epoch_result: EpochResult = None, *args, **kwargs): + if epoch_result: + storage = self._trainer._storage + if epoch_result.train_result: + storage.put_all({"tra": epoch_result.train_result}) + if epoch_result.val_result: + storage.put_all({"val": epoch_result.val_result}) + if epoch_result.test_result: + storage.put_all({"test": epoch_result.test_result}) + storage = None diff --git a/contrastyou/callbacks/tensorboard_callback.py b/contrastyou/callbacks/tensorboard_callback.py new file mode 100644 index 00000000..56e7ffc1 --- /dev/null +++ b/contrastyou/callbacks/tensorboard_callback.py @@ -0,0 +1,24 @@ +from contrastyou.callbacks._callback import _TrainerCallback +from contrastyou.helper import flatten_dict +from contrastyou.trainer._trainer import EpochResult +from contrastyou.writer import SummaryWriter + + +class SummaryCallback(_TrainerCallback): + def __init__(self, log_dir=None) -> None: + self._writer = SummaryWriter(log_dir) + + def after_epoch(self, epoch_result: EpochResult = None, *args, **kwargs): + + current_epoch = self._trainer._cur_epoch + if epoch_result.train_result: + self._writer.add_scalar_with_tag(tag="tra", tag_scalar_dict=flatten_dict(epoch_result.train_result), + global_step=current_epoch) + + if epoch_result.val_result: + self._writer.add_scalar_with_tag(tag="val", 
tag_scalar_dict=flatten_dict(epoch_result.val_result), + global_step=current_epoch) + + if epoch_result.test_result: + self._writer.add_scalar_with_tag(tag="test", tag_scalar_dict=flatten_dict(epoch_result.test_result), + global_step=current_epoch) diff --git a/contrastyou/callbacks/toconsole_callback.py b/contrastyou/callbacks/toconsole_callback.py new file mode 100644 index 00000000..9053cb74 --- /dev/null +++ b/contrastyou/callbacks/toconsole_callback.py @@ -0,0 +1,49 @@ +import sys + +from tqdm import tqdm + +from contrastyou.callbacks._callback import _EpochCallack, _TrainerCallback +from contrastyou.helper import flatten_dict, nice_dict + + +class TQDMCallback(_EpochCallack): + + def __init__(self, indicator_length=0, frequency_print=10) -> None: + self._indicator_length = indicator_length + self._frequency_print = frequency_print + + def before_run(self): + self._indicator = tqdm(ncols=10, leave=True, dynamic_ncols=True) + if self._indicator_length > 0: + self._indicator = tqdm(total=self._indicator_length, ncols=10, leave=True, dynamic_ncols=True) + self._n = 0 + + def after_step(self, report_dict=None, *args, **kwargs): + self._indicator.update(1) + self._n += 1 + if self._n % self._frequency_print == 0: + class_name = self._epocher.__class__.__name__ + current_epoch = self._epocher._cur_epoch + report_dict = flatten_dict(report_dict) + self._indicator.set_description(f"{class_name} Epoch {current_epoch:03d}") + self._indicator.set_postfix(report_dict) + sys.stdout.flush() + + def after_run(self, *args, **kwargs): + self._indicator.close() + + +class PrintResultCallback(_EpochCallack, _TrainerCallback): + def after_run(self, report_dict=None, *args, **kwargs): + if report_dict: + class_name = self._epocher.__class__.__name__ + cur_epoch = self._epocher._cur_epoch + sys.stdout.flush() + print(f"{class_name} Epoch {cur_epoch}: {nice_dict(flatten_dict(report_dict))}") + sys.stdout.flush() + + def after_train(self, *args, **kwargs): + storage = 
self._trainer._storage + sys.stdout.flush() + print(storage.summary()) + sys.stdout.flush() diff --git a/contrastyou/epoch/__init__.py b/contrastyou/epoch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/contrastyou/epoch/_epoch.py b/contrastyou/epoch/_epoch.py new file mode 100644 index 00000000..9c3eb0a4 --- /dev/null +++ b/contrastyou/epoch/_epoch.py @@ -0,0 +1,123 @@ +from abc import abstractmethod, ABCMeta +from contextlib import contextmanager +from typing import Union, Dict, List + +import torch + +from contrastyou.meters2 import MeterInterface +from contrastyou.modules.model import Model +from ..callbacks._callback import _EpochCallack + +_REPORT_DICT = Dict[str, Union[float, int]] +REPORT_DICT = Union[_REPORT_DICT, Dict[str, _REPORT_DICT]] + + +class _EpochMixin: + + def __init__(self, *args, **kwargs) -> None: + super(_EpochMixin, self).__init__(*args, **kwargs) + self._callbacks: List[_EpochCallack] = [] + + def register_callbacks(self, callbacks: List[_EpochCallack]): + if not isinstance(callbacks, list): + callbacks = [callbacks, ] + for i, c in enumerate(callbacks): + if not isinstance(c, _EpochCallack): + raise TypeError(f"callbacks [{i}] should be an instance of {_EpochCallack.__name__}, " + f"given {c.__class__.__name__}.") + c.set_epocher(self) + self._callbacks.append(c) + + def run(self, *args, **kwargs) -> REPORT_DICT: + with self._register_meters() as self.meters: + self._before_run() + result = self._run(*args, **kwargs) + self._after_run(report_dict=result) + return result + + def step(self, *args, **kwargs) -> REPORT_DICT: + # return accumulated dict by the design + self._before_step() + result = self._step(*args, **kwargs) + self._after_step(report_dict=result) + return result + + def _before_run(self): + for c in self._callbacks: + c.before_run() + + def _after_run(self, *args, **kwargs): + for c in self._callbacks: + c.after_run(*args, **kwargs) + + def _before_step(self): + for c in self._callbacks: + 
c.before_step() + + def _after_step(self, *args, **kwargs): + for c in self._callbacks: + c.after_step(*args, **kwargs) + + +class _Epoch(metaclass=ABCMeta): + + def __init__(self, model: Model, cur_epoch=0, device="cpu") -> None: + super().__init__() + self._model = model + self._device = device + self._cur_epoch = cur_epoch + self.to(self._device) + + @classmethod + def create_from_trainer(cls, trainer,*args,**kwargs): + model = trainer._model + device = trainer._device + cur_epoch = trainer._cur_epoch + return cls(model, cur_epoch, device) + + @contextmanager + def _register_meters(self): + meters: MeterInterface = MeterInterface() + meters = self._configure_meters(meters) + yield meters + + @abstractmethod + def _configure_meters(self, meters: MeterInterface) -> MeterInterface: + # todo: to be overrided to add or delete individual meters + return meters + + @abstractmethod + def _run(self, *args, **kwargs) -> REPORT_DICT: + pass + + def run(self, *args, **kwargs) -> REPORT_DICT: + with self._register_meters() as self.meters: + return self._run(*args, **kwargs) + + @abstractmethod + def _step(self, *args, **kwargs) -> REPORT_DICT: + # return accumulated dict by the design + pass + + def step(self, *args, **kwargs) -> REPORT_DICT: + # return accumulated dict by the design + return self._step(*args, **kwargs) + + @abstractmethod + def _prepare_batch(self, *args, **kwargs): + pass + + def to(self, device="cpu"): + if isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + self._model.to(device) + self._device = device + + def assert_model_status(self): + assert self._model.training, self._model.training + + +class Epoch(_EpochMixin, _Epoch): + """Epocher with Mixin""" + pass diff --git a/contrastyou/helper/__init__.py b/contrastyou/helper/__init__.py index e69de29b..90f60fdd 100644 --- a/contrastyou/helper/__init__.py +++ b/contrastyou/helper/__init__.py @@ -0,0 +1 @@ +from .utils import * \ No newline at end of file 
diff --git a/contrastyou/helper/utils.py b/contrastyou/helper/utils.py new file mode 100644 index 00000000..fd09582a --- /dev/null +++ b/contrastyou/helper/utils.py @@ -0,0 +1,42 @@ +# dictionary helper functions +import collections +from typing import Union, Dict + +from torch.utils.data.dataloader import DataLoader, _BaseDataLoaderIter + + +def flatten_dict(d, parent_key="", sep="_"): + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, collections.MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def toDataLoaderIterator(loader_or_iter: Union[DataLoader, _BaseDataLoaderIter]): + if not isinstance(loader_or_iter, (_BaseDataLoaderIter, DataLoader)): + raise TypeError(f"{loader_or_iter} should an instance of DataLoader or _BaseDataLoaderIter, " + f"given {loader_or_iter.__class__.__name__}.") + return loader_or_iter if isinstance(loader_or_iter, _BaseDataLoaderIter) else iter(loader_or_iter) + +# make a flattened dictionary print nicely. +def nice_dict(input_dict: Dict[str, Union[int, float]]) -> str: + """ + this function returns a nice string for dictionary display purposes. + :param input_dict: dictionary + :return: string + """ + assert isinstance( + input_dict, dict + ), f"{input_dict} should be a dict, given {type(input_dict)}."
+ is_flat_dict = True + for k, v in input_dict.items(): + if isinstance(v, dict): + is_flat_dict = False + break + flat_dict = input_dict if is_flat_dict else flatten_dict(input_dict, sep="") + string_list = [f"{k}:{v:.3f}" for k, v in flat_dict.items()] + return ", ".join(string_list) \ No newline at end of file diff --git a/contrastyou/meters2/meter_interface.py b/contrastyou/meters2/meter_interface.py index ef793263..ba67d0b1 100644 --- a/contrastyou/meters2/meter_interface.py +++ b/contrastyou/meters2/meter_interface.py @@ -25,18 +25,14 @@ def tracking_status( assert group_name in self.group return { k: v.detailed_summary() if detailed_summary else v.summary() - for k, v in self.individual_meters.items() + for k, v in self.meters.items() if k in self._group_dicts[group_name] } return { k: v.detailed_summary() if detailed_summary else v.summary() - for k, v in self.individual_meters.items() + for k, v in self.meters.items() } - def add(self, meter_name, *args, **kwargs): - assert meter_name in self.meter_names - self._ind_meter_dicts[meter_name].add(*args, **kwargs) - def reset(self) -> None: """ reset individual meters @@ -63,8 +59,7 @@ def __getitem__(self, meter_name: str) -> _Metric: try: return self._ind_meter_dicts[meter_name] except KeyError as e: - print(f"meter_interface.meter_names:{self.meter_names}") - raise e + raise KeyError(f"meter_interface.meter_names:{self.meter_names} with error {e}") def register_meter(self, name: str, meter: _Metric, group_name=None) -> None: assert isinstance(name, str), name @@ -107,7 +102,3 @@ def meters(self) -> Optional[Dict[str, _Metric]]: @property def group(self) -> List[str]: return sorted(self._group_dicts.keys()) - - @property - def individual_meters(self): - return self._ind_meter_dicts diff --git a/contrastyou/modules/model.py b/contrastyou/modules/model.py index f95b33a3..0f0b9c7d 100644 --- a/contrastyou/modules/model.py +++ b/contrastyou/modules/model.py @@ -150,10 +150,10 @@ def schedulerStep(self, *args, 
**kwargs): self._scheduler.step(*args, **kwargs) def set_mode(self, mode): - assert mode in (ModelState.TRAIN, ModelState.EVAL) or mode in ("train", "eval") + assert (mode in (ModelState.TRAIN, ModelState.TEST)) or (mode in ("train", "val")) if mode in (ModelState.TRAIN, "train"): self.train() - elif mode in (ModelState.EVAL, "eval"): + elif mode in (ModelState.TEST, "val"): self.eval() def train(self): @@ -175,8 +175,9 @@ def apply(self, *args, **kwargs) -> None: def get_lr(self): if self._scheduler is not None: - return self._scheduler.get_lr() - return None + return self._scheduler.get_last_lr() + warnings.warn("No scheduler is found while calling for `get_lr()`") + return [0] @property def optimizer(self): diff --git a/contrastyou/storage/_historical_container.py b/contrastyou/storage/_historical_container.py index 3e2fef67..bdd0fade 100644 --- a/contrastyou/storage/_historical_container.py +++ b/contrastyou/storage/_historical_container.py @@ -5,6 +5,8 @@ import pandas as pd +from contrastyou.helper import flatten_dict + _Record_Type = Dict[str, float] _Save_Type = OrderedDict_Type[int, _Record_Type] @@ -26,6 +28,7 @@ def __init__(self) -> None: def add(self, input_dict: _Record_Type, epoch=None) -> None: # only str-num dict can be added. + input_dict = flatten_dict(input_dict) for v in input_dict.values(): assert isinstance(v, numbers.Number), v if epoch: diff --git a/contrastyou/storage/storage.py b/contrastyou/storage/storage.py index eeb37515..29964b76 100644 --- a/contrastyou/storage/storage.py +++ b/contrastyou/storage/storage.py @@ -60,17 +60,15 @@ def get(self, name, epoch=None): return self._storage[name][epoch] def summary(self) -> pd.DataFrame: - """ - summary on the list of sub summarys, merging them together. 
- :return: - """ - result_dict = {} - for k, v in self._storage.items(): - result_dict[k]=v.record_dict - # flatten the dict - from deepclustering.utils import flatten_dict - flatten_result = flatten_dict(result_dict) - return pd.DataFrame(flatten_result) + list_of_summary = [ + rename_df_columns(v.summary(), k) for k, v in self._storage.items() + ] + # merge the list + summary = functools.reduce( + lambda x, y: pd.merge(x, y, left_index=True, right_index=True), + list_of_summary, + ) + return pd.DataFrame(summary) @property def meter_names(self, sorted=False) -> List[str]: diff --git a/contrastyou/trainer/_buffer.py b/contrastyou/trainer/_buffer.py index 21ee759c..31516a1d 100644 --- a/contrastyou/trainer/_buffer.py +++ b/contrastyou/trainer/_buffer.py @@ -17,7 +17,7 @@ class _BufferMixin: def __init__(self) -> None: self._buffers = OrderedDict() - def register_buffer(self, name: str, value: Union[str, N]): + def _register_buffer(self, name: str, value: Union[str, N]): r"""Adds a persistent buffer to the module. 
""" if '_buffers' not in self.__dict__: diff --git a/contrastyou/trainer/_epoch.py b/contrastyou/trainer/_epoch.py deleted file mode 100644 index d0e197f1..00000000 --- a/contrastyou/trainer/_epoch.py +++ /dev/null @@ -1,54 +0,0 @@ -from abc import abstractmethod -from typing import Union, Dict - -from torch.utils.data import DataLoader -from torch.utils.data.dataloader import _BaseDataLoaderIter -from contrastyou.modules.model import Model -from contrastyou.meters2 import MeterInterface -from contrastyou import ModelState -from contextlib import contextmanager -class _Epoch: - - @abstractmethod - def register_meters(self): - self._meters: MeterInterface = MeterInterface() - - def write2tensorboard(self): - pass - - def _data_preprocessing(self): - pass - - def run_epoch(self): - pass - - -class TrainEpoch(_Epoch): - - def __init__(self, model:Model, train_loader: Union[DataLoader, _BaseDataLoaderIter], num_batches: int = 512) -> None: - super().__init__() - self._model = model - self._loader = train_loader - self._num_batches = num_batches - self._indicator = range(self._num_batches) - - @contextmanager - def register_meters(self): - super(TrainEpoch, self).register_meters() - self._meters.register_meter() - yield self._meters - self._meters.reset() - - - def run_epoch(self, mode=ModelState.TRAIN ) -> Dict[str, float]: - self._model.set_mode(mode) - - with self.register_meters() as meters: - pass - - - - - -class ValEpoch(_Epoch): - pass diff --git a/contrastyou/trainer/_trainer.py b/contrastyou/trainer/_trainer.py index e6eb62f4..4c87228e 100644 --- a/contrastyou/trainer/_trainer.py +++ b/contrastyou/trainer/_trainer.py @@ -1,25 +1,35 @@ from abc import abstractmethod from copy import deepcopy +from dataclasses import dataclass from pathlib import Path -from typing import Union, Dict, Any, TypeVar +from pprint import pprint +from typing import Union, Dict, Any, TypeVar, List import numpy as np import torch -from deepclustering.utils import path2Path -from 
deepclustering.writer import SummaryWriter from torch import Tensor from torch.utils.data import DataLoader from torch.utils.data.dataloader import _BaseDataLoaderIter from ._buffer import _BufferMixin from .. import PROJECT_PATH -# from ..meters import MeterInterface, AverageValueMeter +from ..callbacks._callback import _TrainerCallback, EpochCallBacks +from ..helper.utils import toDataLoaderIterator from ..modules.model import Model from ..storage import Storage +from ..writer.tensorboard import path2Path N = TypeVar('N', int, float, Tensor, np.ndarray) +@dataclass() +class EpochResult: + train_result: Dict[str, Dict[str, Union[int, float]]] = None + val_result: Dict[str, Dict[str, Union[int, float]]] = None + test_result: Dict[str, Dict[str, Union[int, float]]] = None + + + class _Trainer(_BufferMixin): """ Abstract class for a general trainer, which has _train_loop, _eval_loop,load_state, state_dict, and save_checkpoint @@ -36,23 +46,21 @@ def __init__( val_loader: DataLoader, max_epoch: int = 100, save_dir: str = "base", - checkpoint: str = None, device="cpu", config: dict = None, ) -> None: super(_Trainer, self).__init__() self._model = model - self._train_loader = train_loader + self._train_loader = toDataLoaderIterator(train_loader) self._val_loader = val_loader - self.register_buffer("_max_epoch", int(max_epoch)) - self.register_buffer("_best_score", -1.0) - self.register_buffer("_start_epoch", 0) # whether 0 or loaded from the checkpoint. - self.register_buffer("_epoch", None) + self._register_buffer("_max_epoch", int(max_epoch)) + self._register_buffer("_best_score", -1.0) + self._register_buffer("_start_epoch", 0) # whether 0 or loaded from the checkpoint. 
+ self._register_buffer("_cur_epoch", 0) self._save_dir: Path = Path(self.RUN_PATH) / str(save_dir) self._save_dir.mkdir(exist_ok=True, parents=True) - self._checkpoint = checkpoint self._device = torch.device(device) if config: @@ -64,48 +72,18 @@ def __init__( def to(self, device): self._model.to(device=device) - def _start_training(self): - for epoch in range(self._start_epoch, self._max_epoch): - if self._model.get_lr() is not None: - self._meter_interface["lr"].add(self._model.get_lr()[0]) - self.train_loop(train_loader=self._train_loader, epoch=epoch) - with torch.no_grad(): - current_score = self.eval_loop(self._val_loader, epoch) - self._model.schedulerStep() - # save meters and checkpoints - self._meter_interface.step() - self.save_checkpoint(self.state_dict(), epoch, current_score) - self._meter_interface.summary().to_csv(self._save_dir / "wholeMeter.csv") - - def start_training(self): - with SummaryWriter(log_dir=self._save_dir) as self.writer: - return self._start_training() - @abstractmethod - def _train_loop( - self, - train_loader: Union[DataLoader, _BaseDataLoaderIter] = None, - epoch: int = 0, - mode=None, - *args, - **kwargs, - ): + def _start_training(self): pass - def train_loop(self, *args, **kwargs): - return self._train_loop(*args, **kwargs) + def start_training(self): + return self._start_training() - @abstractmethod - def _eval_loop( - self, - val_loader: Union[DataLoader, _BaseDataLoaderIter] = None, - epoch: int = 0, - mode=None, - ) -> float: + def _start_single_epoch(self) -> EpochResult: pass - def eval_loop(self, *args, **kwargs): - return self._eval_loop(*args, **kwargs) + def start_single_epoch(self): + return self._start_single_epoch() def inference(self, identifier="best.pth", *args, **kwargs): """ @@ -228,3 +206,57 @@ def clean_up(self, wait_time=3): shutil.rmtree(save_dir, ignore_errors=True) shutil.move(str(self._save_dir), str(save_dir)) shutil.rmtree(str(self._save_dir), ignore_errors=True) + + +class _TrainerMixin: + + def 
__init__(self, *args, **kwargs, ) -> None: + super().__init__(*args, **kwargs) + self._callbacks: List[_TrainerCallback] = [] + self._epoch_callbacks = EpochCallBacks(None, None, None) + + def register_callbacks(self, callbacks: List[_TrainerCallback]): + if not isinstance(callbacks, list): + callbacks = [callbacks, ] + for i, c in enumerate(callbacks): + if not isinstance(c, _TrainerCallback): + raise TypeError(f"callbacks [{i}] should be an instance of {_TrainerCallback.__name__}, " + f"given {c.__class__.__name__}.") + c.set_trainer(self) + self._callbacks.append(c) + + def register_epoch_callbacks(self, epoch_callback: EpochCallBacks = None): + if epoch_callback: + self._epoch_callbacks = epoch_callback + + def _before_train_start(self, *args, **kwargs): + for c in self._callbacks: + c.before_train(*args, **kwargs) + + def _after_train_end(self, *args, **kwargs): + for c in self._callbacks: + c.after_train(*args, **kwargs) + + def _before_epoch_start(self, *args, **kwargs): + for c in self._callbacks: + c.before_epoch(*args, **kwargs) + + def _after_epoch_end(self, *args, **kwargs): + for c in self._callbacks: + c.after_epoch(*args, **kwargs) + + def start_running(self): + self._before_train_start() + train_result = self._start_running() + self._after_train_end(train_result=train_result) + return train_result + + def start_single_epoch(self): + self._before_epoch_start() + epoch_result = self._start_single_epoch() + self._after_epoch_end(epoch_result=epoch_result) + return epoch_result + + +class Trainer(_TrainerMixin, _Trainer): + pass diff --git a/contrastyou/writer/tensorboard.py b/contrastyou/writer/tensorboard.py index 0dab83be..44df124f 100644 --- a/contrastyou/writer/tensorboard.py +++ b/contrastyou/writer/tensorboard.py @@ -3,6 +3,9 @@ from tensorboardX import SummaryWriter as _SummaryWriter +from contrastyou.callbacks._callback import _EpochCallack +from ..helper import flatten_dict + def path2Path(path) -> Path: assert isinstance(path, (str, Path)), 
path @@ -15,6 +18,7 @@ def __init__(self, log_dir=None, comment="", **kwargs): log_dir = path2Path(log_dir) assert log_dir.exists() and log_dir.is_dir(), log_dir super().__init__(str(log_dir / "tensorboard"), comment, **kwargs) + atexit.register(self.close) def add_scalar_with_tag( self, tag, tag_scalar_dict, global_step=None, walltime=None @@ -35,5 +39,4 @@ def add_scalar_with_tag( def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): - atexit.register(self.close) + diff --git a/demo/demo.py b/demo/demo.py new file mode 100644 index 00000000..9258a0ae --- /dev/null +++ b/demo/demo.py @@ -0,0 +1,167 @@ +# demo for the framework +from typing import Union, Callable, Optional + +import torch +from deepclustering.loss import KL_div +from haxiolearn.utils import class2one_hot +from torch import Tensor +from torch.utils.data.dataloader import DataLoader, _BaseDataLoaderIter +from torchvision.transforms import ToTensor + +from contrastyou import ModelState +from contrastyou.epoch._epoch import Epoch, REPORT_DICT +from contrastyou.meters2 import AverageValueMeter, ConfusionMatrix, MeterInterface +from contrastyou.modules.model import Model +from contrastyou.trainer._trainer import Trainer, EpochResult + + +# define epocher +class TrainEpoch(Epoch): + + def __init__(self, model: Model, loader: Union[DataLoader, _BaseDataLoaderIter], num_batches: Optional[int], + criterion: Callable[[Tensor, Tensor], Tensor], device="cpu", cur_epoch: int = 0) -> None: + super().__init__(model=model, device=device, cur_epoch=cur_epoch) + assert isinstance(loader, (DataLoader, _BaseDataLoaderIter)), type(loader) + self._loader = loader + self._num_batches = num_batches + self._criterion = criterion + + @classmethod + def create_from_trainer(cls, trainer, loader=None): + model = trainer._model + device = trainer._device + cur_epoch = trainer._cur_epoch + criterion = trainer._criterion + num_batches = trainer._num_batches + return cls(model=model, cur_epoch=cur_epoch, 
device=device, num_batches=num_batches, + criterion=criterion, loader=loader) + + def _configure_meters(self, meters: MeterInterface) -> MeterInterface: + meters.register_meter("lr", AverageValueMeter()) + meters.register_meter("loss", AverageValueMeter()) + meters.register_meter("confusion_mx", ConfusionMatrix(10)) + return meters + + def _run(self, mode=ModelState.TRAIN): + self._model.set_mode(mode) + self.assert_model_status() + if "lr" in self.meters.meters: + self.meters["lr"].add(self._model.get_lr()[0]) + epoch_iterator = enumerate(self._loader) + if isinstance(self._num_batches, int) and self._num_batches > 1: + epoch_iterator = zip(range(self._num_batches), self._loader) + # main loop + report_dict = None + for self._cur_batch, data in epoch_iterator: + report_dict = self.step(data) + return report_dict + + def _step(self, data) -> REPORT_DICT: + image, target = self._prepare_batch(data, self._device) + onehot_target = class2one_hot(target, 10) + predict_simplex = self._model(image.repeat(1, 3, 1, 1), force_simplex=True) + loss = self._criterion(predict_simplex, onehot_target) + self._model.zero_grad() + loss.backward() + self._model.step() + self.meters["loss"].add(loss.item()) + self.meters["confusion_mx"].add(predict_simplex.max(1)[1], target) + report_dict = self.meters.tracking_status() + return report_dict + + def _prepare_batch(self, data, device="cpu"): + return data[0].to(device), data[1].to(device) + + +class ValEpoch(TrainEpoch): + + def __init__(self, model: Model, loader: DataLoader, criterion: Callable[[Tensor, Tensor], Tensor], device="cpu", + cur_epoch=0 + ) -> None: + assert isinstance(loader, DataLoader), f"{self.__class__.__name__} requires DataLoader type loader, " \ + f"given {loader.__class__.__name__}." 
+ super().__init__(model, loader, None, criterion, device, cur_epoch=cur_epoch) + + @classmethod + def create_from_trainer(cls, trainer, loader=None): + model = trainer._model + device = trainer._device + cur_epoch = trainer._cur_epoch + criterion = trainer._criterion + return cls(model=model, cur_epoch=cur_epoch, device=device, + criterion=criterion, loader=loader) + + def _configure_meters(self, meters: MeterInterface) -> MeterInterface: + meters = super(ValEpoch, self)._configure_meters(meters) + meters.delete_meter("lr") + return meters + + def _step(self, data) -> REPORT_DICT: + image, target = self._prepare_batch(data, self._device) + onehot_target = class2one_hot(target, 10) + predict_simplex = self._model(image.repeat(1, 3, 1, 1), force_simplex=True) + loss = self._criterion(predict_simplex, onehot_target, disable_assert=True) + self.meters["loss"].add(loss.item()) + self.meters["confusion_mx"].add(predict_simplex.max(1)[1], target) + report_dict = self.meters.tracking_status() + return report_dict + + def assert_model_status(self): + assert not self._model.training, self._model.training + + +class Trainer(Trainer): + + def __init__(self, model: Model, train_loader: Union[DataLoader, _BaseDataLoaderIter], val_loader: DataLoader, + max_epoch: int = 100, save_dir: str = "base", device="cpu", num_batches=10, + criterion: Callable[[Tensor, Tensor], Tensor] = None, + config: dict = None) -> None: + super().__init__(model, train_loader, val_loader, max_epoch, save_dir, device, config) + self._criterion = criterion + self._num_batches = num_batches + + def _start_training(self): + for self._cur_epoch in range(self._start_epoch, self._max_epoch): + epoch_result = self.start_single_epoch() + return self._storage.summary() + + def _start_single_epoch(self): + tra_epoch = TrainEpoch.create_from_trainer(self, self._train_loader) + tra_epoch.register_callbacks(self._epoch_callbacks._train_callbacks) + tra_result = tra_epoch.run(mode=ModelState.TRAIN) + with 
torch.no_grad(): + val_epoch = ValEpoch.create_from_trainer(self, self._val_loader) + val_epoch.register_callbacks(self._epoch_callbacks._val_callbacks) + val_result = val_epoch.run(mode=ModelState.TEST) + return EpochResult(train_result=tra_result, val_result=val_result) + + +if __name__ == '__main__': + from torchvision.models import resnet18 + from torchvision.datasets import MNIST + from contrastyou.callbacks import TQDMCallback, PrintResultCallback, SummaryCallback, SchedulerCallback, \ + EpochCallBacks, StorageCallback + + arch = resnet18(num_classes=10) + optim = torch.optim.Adam(arch.parameters()) + scheduler = torch.optim.lr_scheduler.StepLR(optim, 10, 0.1) + model = Model(arch, optim, scheduler) + dataset = MNIST(root="./", transform=ToTensor(), download=True) + val_dataset = MNIST(root="./", transform=ToTensor(), download=True, train=False) + dataloader = DataLoader(dataset, batch_size=100, num_workers=4, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=1000, num_workers=4) + trainer = Trainer(model, dataloader, val_loader, max_epoch=30, num_batches=len(dataloader), save_dir="123", device="cuda", + criterion=KL_div()) + # trainer callback + scheduler_cb = SchedulerCallback() + writer = SummaryCallback(str(trainer._save_dir)) + storage_cb = StorageCallback() + printable_callback = PrintResultCallback() + trainer.register_callbacks([scheduler_cb, writer, storage_cb, printable_callback]) + # epoch callback + printable_callback = PrintResultCallback() + tqdm_indicator = TQDMCallback(frequency_print=10) + trainer.register_epoch_callbacks( + EpochCallBacks([printable_callback, tqdm_indicator], + [printable_callback, tqdm_indicator])) + trainer.start_training() diff --git a/demo/demo.yaml b/demo/demo.yaml new file mode 100644 index 00000000..8266b39c --- /dev/null +++ b/demo/demo.yaml @@ -0,0 +1,13 @@ +Arch: + name: enet + num_classes: 4 + +Optim: + name: Adam + lr: 0.0001 + weight_decay: 0.00005 + +Scheduler: + name: StepLR + 
step_size: 30 + gamma: 0.1 \ No newline at end of file diff --git a/test/contextmanager.py b/test/contextmanager.py new file mode 100644 index 00000000..2aa020cb --- /dev/null +++ b/test/contextmanager.py @@ -0,0 +1,56 @@ +from abc import abstractmethod, ABCMeta +from typing import Union, Dict + +from torch.utils.data import DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter +from contrastyou.modules.model import Model +from contrastyou.meters2 import MeterInterface, AverageValueMeter +from contrastyou import ModelState +from contextlib import contextmanager + + +class _Epoch(metaclass=ABCMeta): + + @contextmanager + def register_meters(self): + meters: MeterInterface = MeterInterface() + meters = self._configure_meters(meters) + yield meters + + @abstractmethod + def _configure_meters(self, meters: MeterInterface): + # to be overridden to add or delete individual meters + return meters + + +class TrainEpoch(_Epoch): + + def _configure_meters(self, meters: MeterInterface): + meters.register_meter("meter1", AverageValueMeter()) + return meters + + def _run_epoch(self, *args, **kwargs): + with self.register_meters() as meters: + print(meters.meter_names) + meters["meter1"].add(1) + meters["meter1"].add(2) + print(meters.tracking_status()) + + def run_epoch(self, *args, **kwargs): + return self._run_epoch(*args, **kwargs) + + +class ValEpoch(TrainEpoch): + + def _configure_meters(self, meters: MeterInterface): + meters = super(ValEpoch, self)._configure_meters(meters) + meters.register_meter("meter2", AverageValueMeter()) + return meters + + +if __name__ == '__main__': + tepoch = TrainEpoch() + tepoch.run_epoch() + + vepoch = ValEpoch() + vepoch.run_epoch()