Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensorboard #19

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
48 changes: 31 additions & 17 deletions examples/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

from builtins import range
from torch.utils.tensorboard import SummaryWriter

np.random.seed(1234)

Expand Down Expand Up @@ -43,6 +44,7 @@ def oracle(data, target):

# some hyper-parameters of the experiment
use_visdom = True
use_tensorboard = True
lr = 0.01
n_epochs = 10

Expand All @@ -52,32 +54,37 @@ def oracle(data, target):

# log the hyperparameters of the experiment
if use_visdom:
plotter = mlogger.VisdomPlotter({'env': 'my_experiment', 'server': 'http://localhost', 'port': 8097},
visdom_plotter = mlogger.VisdomPlotter({'env': 'my_experiment', 'server': 'http://localhost', 'port': 8097},
manual_update=True)
else:
plotter = None
visdom_plotter = None

if use_tensorboard:
summary_writer = SummaryWriter()
else:
summary_writer = None

xp = mlogger.Container()

xp.config = mlogger.Config(plotter=plotter)
xp.config = mlogger.Config(plotter=visdom_plotter, summary_writer=summary_writer)
lberrada marked this conversation as resolved.
Show resolved Hide resolved
xp.config.update(lr=lr, n_epochs=n_epochs)

xp.epoch = mlogger.metric.Simple()

xp.train = mlogger.Container()
xp.train.acc1 = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy@1", plot_legend="training")
xp.train.acck = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy@k", plot_legend="training")
xp.train.loss = mlogger.metric.Average(plotter=plotter, plot_title="Objective")
xp.train.timer = mlogger.metric.Timer(plotter=plotter, plot_title="Time", plot_legend="training")
xp.train.acc1 = mlogger.metric.Average(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Accuracy@1", plot_legend="training")
xp.train.acck = mlogger.metric.Average(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Accuracy@k", plot_legend="training")
xp.train.loss = mlogger.metric.Average(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Objective")
xp.train.timer = mlogger.metric.Timer(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Time", plot_legend="training")

xp.val = mlogger.Container()
xp.val.acc1 = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy@1", plot_legend="validation")
xp.val.acck = mlogger.metric.Average(plotter=plotter, plot_title="Accuracy@k", plot_legend="validation")
xp.val.timer = mlogger.metric.Timer(plotter=plotter, plot_title="Time", plot_legend="validation")
xp.val.acc1 = mlogger.metric.Average(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Accuracy@1", plot_legend="validation")
xp.val.acck = mlogger.metric.Average(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Accuracy@k", plot_legend="validation")
xp.val.timer = mlogger.metric.Timer(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Time", plot_legend="validation")

xp.val_best = mlogger.Container()
xp.val_best.acc1 = mlogger.metric.Maximum(plotter=plotter, plot_title="Accuracy@1", plot_legend="validation-best")
xp.val_best.acck = mlogger.metric.Maximum(plotter=plotter, plot_title="Accuracy@k", plot_legend="validation-best")
xp.val_best.acc1 = mlogger.metric.Maximum(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Accuracy@1", plot_legend="validation-best")
xp.val_best.acck = mlogger.metric.Maximum(visdom_plotter=visdom_plotter, summary_writer=summary_writer, plot_title="Accuracy@k", plot_legend="validation-best")


#----------------------------------------------------------
Expand Down Expand Up @@ -129,7 +136,8 @@ def oracle(data, target):
print("Prec@1: \t {0:.2f}%".format(xp.val_best.acc1.value))
print("Prec@k: \t {0:.2f}%".format(xp.val_best.acck.value))

plotter.update_plots()
if use_visdom:
visdom_plotter.update_plots()

#----------------------------------------------------------
# Save & load experiment
Expand All @@ -141,12 +149,18 @@ def oracle(data, target):

xp.save_to('state.json')

new_plotter = mlogger.VisdomPlotter(visdom_opts={'env': 'my_experiment', 'server': 'http://localhost', 'port': 8097},
manual_update=True)


new_xp = mlogger.load_container('state.json')
new_xp.plot_on(new_plotter)
new_plotter.update_plots()

if use_visdom:
new_visdom_plotter = mlogger.VisdomPlotter(visdom_opts={'env': 'my_experiment', 'server': 'http://localhost', 'port': 8097},
manual_update=True)
new_xp.plot_on_visdom(new_visdom_plotter)
new_visdom_plotter.update_plots()
if use_tensorboard:
new_summary_writer = SummaryWriter()
new_xp.plot_on_tensorboard(new_summary_writer)

print('Current train loss value: {}'.format(new_xp.train.loss.value))
new_xp.train.loss.update(2)
Expand Down
49 changes: 39 additions & 10 deletions mlogger/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,29 @@
import time
import sys
import mlogger
import warnings


class Config(object):
def __init__(self, plotter=None, plot_title=None,
get_general_info=True, get_git_info=False, **kwargs):
def __init__(self, plotter=None, plot_title=None, get_general_info=True,
get_git_info=False, visdom_plotter=None, summary_writer=None, **kwargs):

object.__setattr__(self, '_state', {})

if plotter is not None:
self.plot_on(plotter, plot_title)
else:
object.__setattr__(self, '_plotter', plotter)
object.__setattr__(self, '_plot_title', plot_title)
warnings.warn("Argument `plotter` is deprecated. Please use `visdom_plotter` instead.", FutureWarning)
visdom_plotter = plotter
del plotter

object.__setattr__(self, '_visdom_plotter', visdom_plotter)
object.__setattr__(self, '_summary_writer', summary_writer)
object.__setattr__(self, '_plot_title', plot_title)

if visdom_plotter is not None:
self.plot_on_visdom(visdom_plotter, plot_title)

if summary_writer is not None:
self.plot_on_tensorboard(summary_writer)

if get_general_info:
self.update_general_info()
Expand Down Expand Up @@ -55,8 +65,10 @@ def state_dict(self):

def update(self, **kwargs):
self._state.update(kwargs)
if self._plotter is not None:
self._plotter._update_text(self._plot_title, kwargs)
if self._visdom_plotter is not None:
self._visdom_plotter._update_text(self._plot_title, kwargs)
if self._summary_writer is not None:
self._summary_writer.add_hparams(kwargs, {})
return self

def load_state_dict(self, state):
Expand All @@ -69,16 +81,33 @@ def __repr__(self):
return repr_

def plot_on(self, plotter, plot_title):
warnings.warn("Argument `plotter` is deprecated. Please use `visdom_plotter` instead.", FutureWarning)
return self.plot_on_visdom(plotter, plot_title)

def plot_on_visdom(self, visdom_plotter, plot_title):
# plot current state
if len(self._state):
plotter._update_text(plot_title, self._state)
visdom_plotter._update_text(plot_title, self._state)

# store for future logs
object.__setattr__(self, '_plotter', plotter)
object.__setattr__(self, '_visdom_plotter', visdom_plotter)
object.__setattr__(self, '_plot_title', plot_title)

return self

def plot_on_tensorboard(self, summary_writer, plot_title=None):
if plot_title:
warnings.warn("warning argument ignored", RuntimeWarning)

# plot current state
if len(self._state):
summary_writer.add_hparams(self._state, {})

# store for future logs
object.__setattr__(self, '_summary_writer', summary_writer)

return self

def __getattr__(self, key):
if key == '_state':
return object.__getattr__(self, _state)
Expand Down
23 changes: 20 additions & 3 deletions mlogger/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
from builtins import dict
from collections import defaultdict, OrderedDict
import warnings


class Container(object):
Expand Down Expand Up @@ -73,17 +74,33 @@ def named_metrics(self):
return named_metrics_list

def plot_on(self, plotter):
warnings.warn("Argument `plotter` is deprecated. Please use `visdom_plotter` instead.", FutureWarning)
self.plot_on_visdom(plotter)

def plot_on_visdom(self, visdom_plotter):
for child in self.children():
if isinstance(child, Container):
child.plot_on(plotter)
child.plot_on_visdom(visdom_plotter)
elif isinstance(child, mlogger.Config):
plot_title = child._plot_title
child.plot_on(plotter, plot_title)
child.plot_on_visdom(visdom_plotter, plot_title)
elif isinstance(child, mlogger.metric.Base):
plot_title = child._plot_title
plot_legend = child._plot_legend
if plot_title is not None:
child.plot_on_visdom(visdom_plotter, plot_title, plot_legend)

def plot_on_tensorboard(self, summary_writer):
for child in self.children():
if isinstance(child, Container):
child.plot_on_tensorboard(summary_writer)
elif isinstance(child, mlogger.Config):
child.plot_on_tensorboard(summary_writer)
elif isinstance(child, mlogger.metric.Base):
plot_title = child._plot_title
plot_legend = child._plot_legend
if plot_title is not None:
child.plot_on(plotter, plot_title, plot_legend)
child.plot_on_tensorboard(summary_writer, plot_title, plot_legend)

def __repr__(self):
_repr = "Container()"
Expand Down
72 changes: 56 additions & 16 deletions mlogger/metric/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
import time
import warnings

from collections import defaultdict, OrderedDict
from future.utils import viewitems
Expand All @@ -13,9 +14,22 @@
]


def _deprecate_plotter_argument(plotter, visdom_plotter):
if plotter is not None:
if visdom_plotter is None or visdom_plotter == plotter:
lberrada marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn("Argument `plotter` is deprecated. Please use `visdom_plotter` instead.", FutureWarning)
lberrada marked this conversation as resolved.
Show resolved Hide resolved
else:
raise ValueError("Arguments 'plotter', and 'visdom_plotter', are different and both not None")

return visdom_plotter


class Simple(Base):
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None):
super(Simple, self).__init__(time_indexing, plotter, plot_title, plot_legend)
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None,
visdom_plotter=None, summary_writer=None):
visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)
super(Simple, self).__init__(time_indexing=time_indexing, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)

def reset(self):
self._val = 0.
Expand Down Expand Up @@ -44,8 +58,13 @@ def __repr__(self):


class TNT(Base):
def __init__(self, tnt_meter, time_indexing=None, plotter=None, plot_title=None, plot_legend=None):
super(TNT, self).__init__(time_indexing, plotter, plot_title, plot_legend)
def __init__(self, tnt_meter, time_indexing=None, plotter=None, plot_title=None,
plot_legend=None, visdom_plotter=None, summary_writer=None):

visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)

super(TNT, self).__init__(time_indexing=time_indexing, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)
self._tnt_meter = tnt_meter

def reset(self):
Expand All @@ -70,8 +89,12 @@ def __repr__(self):


class Timer(Base):
def __init__(self, plotter=None, plot_title=None, plot_legend=None):
super(Timer, self).__init__(False, plotter, plot_title, plot_legend)
def __init__(self, plotter=None, plot_title=None, plot_legend=None, visdom_plotter=None, summary_writer=None):

visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)

super(Timer, self).__init__(time_indexing=False, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)

def reset(self):
self.start = time.time()
Expand Down Expand Up @@ -105,8 +128,12 @@ def __repr__(self):


class Maximum(Base):
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None):
super(Maximum, self).__init__(time_indexing, plotter, plot_title, plot_legend)
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None,
visdom_plotter=None, summary_writer=None):

visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)
super(Maximum, self).__init__(time_indexing=time_indexing, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)

def reset(self):
self._val = -np.inf
Expand Down Expand Up @@ -143,8 +170,12 @@ def __repr__(self):


class Minimum(Base):
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None):
super(Minimum, self).__init__(time_indexing, plotter, plot_title, plot_legend)
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None,
visdom_plotter=None, summary_writer=None):

visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)
super(Minimum, self).__init__(time_indexing=time_indexing, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)

def reset(self):
self._val = np.inf
Expand Down Expand Up @@ -184,8 +215,11 @@ class Accumulator_(Base):
"""
Credits to the authors of pytorch/tnt for this.
"""
def __init__(self, time_indexing, plotter=None, plot_title=None, plot_legend=None):
super(Accumulator_, self).__init__(time_indexing, plotter, plot_title, plot_legend)
def __init__(self, time_indexing, plotter=None, plot_title=None, plot_legend=None,
visdom_plotter=None, summary_writer=None):
visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)
super(Accumulator_, self).__init__(time_indexing=time_indexing, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)

def reset(self):
self._avg = 0
Expand Down Expand Up @@ -213,8 +247,11 @@ def value(self):


class Average(Accumulator_):
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None):
super(Average, self).__init__(time_indexing, plotter, plot_title, plot_legend)
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None,
visdom_plotter=None, summary_writer=None):
visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)
super(Average, self).__init__(time_indexing=time_indexing, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)

@property
def value(self):
Expand All @@ -230,8 +267,11 @@ def __repr__(self):


class Sum(Accumulator_):
def __init__(self, time_indexing=None, plotter=None, plot_title=None, plot_legend=None):
super(Sum, self).__init__(time_indexing, plotter, plot_title, plot_legend)
def __init__(self, time_indexing=None, plotter=None, plot_title=None,
plot_legend=None, visdom_plotter=None, summary_writer=None):
visdom_plotter = _deprecate_plotter_argument(plotter, visdom_plotter)
super(Sum, self).__init__(time_indexing=time_indexing, plot_title=plot_title, plot_legend=plot_legend,
visdom_plotter=visdom_plotter, summary_writer=summary_writer)

@property
def value(self):
Expand Down
Loading