From 8e159a54233d52641efc727bcaf3e9ae84657e4e Mon Sep 17 00:00:00 2001
From: Christopher Fox <31865717+cdfox-asapp@users.noreply.github.com>
Date: Mon, 30 Dec 2019 16:46:57 -0500
Subject: [PATCH] Recreate train iterator as needed (#190)

* get batch via for loop rather than next()

* off by 1

* use try-except instead

* fix too long line

* trailing whitespace

* recreate train iter on StopIteration

* get batch after recreating train iter

* remove debug line that was added by mistake
---
 flambe/learn/train.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/flambe/learn/train.py b/flambe/learn/train.py
index c767cf2f..4c18a892 100644
--- a/flambe/learn/train.py
+++ b/flambe/learn/train.py
@@ -146,9 +146,12 @@ def __init__(self,
         self._best_model: Dict[str, torch.Tensor] = dict()
         self.register_attrs('_step', '_best_metric', '_best_model')
 
-        n_epochs = math.ceil(epoch_per_step * max_steps)
+        self.n_epochs = math.ceil(epoch_per_step * max_steps)
 
-        self._train_iterator = self.train_sampler.sample(dataset.train, n_epochs)
+        self._create_train_iterator()
+
+    def _create_train_iterator(self):
+        self._train_iterator = self.train_sampler.sample(self.dataset.train, self.n_epochs)
 
     def _batch_to_device(self, batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
         """Move the current batch on the correct device.
@@ -191,7 +194,11 @@ def _train_step(self) -> None:
         accumulated_loss = 0.0
         for _ in range(self.batches_per_iter):
             # Get next batch
-            batch = next(self._train_iterator)
+            try:
+                batch = next(self._train_iterator)
+            except StopIteration:
+                self._create_train_iterator()
+                batch = next(self._train_iterator)
             batch = self._batch_to_device(batch)
 
             # Compute loss
@@ -209,8 +216,10 @@ def _train_step(self) -> None:
                 clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)
 
         log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
-        log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step)
-        log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step)
+        log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm,
+            global_step)
+        log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm,
+            global_step)
 
         # Optimize
         self.optimizer.step()
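
The diff replaces the single bare next(self._train_iterator) call with a try/except that rebuilds the iterator via _create_train_iterator() whenever it is exhausted, so _train_step can keep fetching batches_per_iter batches even if the sampler runs out before max_steps is reached. The following is a minimal, self-contained sketch of that pattern only; ToyTrainer and toy_sampler are hypothetical stand-ins, not flambe's actual Trainer or train_sampler API.

    # Sketch of the recreate-on-StopIteration pattern applied in the patch.
    # ToyTrainer / toy_sampler are illustrative stand-ins; only the control
    # flow mirrors the diff.
    from typing import Iterator, List


    def toy_sampler(data: List[int], n_epochs: int) -> Iterator[int]:
        """Yield every example n_epochs times, then become exhausted."""
        for _ in range(n_epochs):
            for example in data:
                yield example


    class ToyTrainer:
        def __init__(self, data: List[int], n_epochs: int) -> None:
            self.data = data
            self.n_epochs = n_epochs
            self._create_train_iterator()

        def _create_train_iterator(self) -> None:
            # Iterator creation factored into its own method so it can be
            # rebuilt later, as in the patched Trainer.
            self._train_iterator = toy_sampler(self.data, self.n_epochs)

        def next_batch(self) -> int:
            try:
                return next(self._train_iterator)
            except StopIteration:
                # Iterator exhausted: recreate it and fetch the batch again,
                # instead of letting StopIteration escape mid-step.
                self._create_train_iterator()
                return next(self._train_iterator)


    if __name__ == "__main__":
        trainer = ToyTrainer(data=[1, 2, 3], n_epochs=1)
        # Request more batches than one pass provides; the iterator is
        # rebuilt transparently rather than raising StopIteration.
        print([trainer.next_batch() for _ in range(5)])  # [1, 2, 3, 1, 2]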