Recreate train iterator as needed (#190)
* get batch via for loop rather than next()

* off by 1

* use try-except instead

* fix too long line

* trailing whitespace

* recreate train iter on StopIteration

* get batch after recreating train iter

* remove debug line that was added by mistake
cdfox-asapp authored and jeremyasapp committed Dec 30, 2019
1 parent 81be4e9 commit 8e159a5
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions flambe/learn/train.py
@@ -146,9 +146,12 @@ def __init__(self,
         self._best_model: Dict[str, torch.Tensor] = dict()
         self.register_attrs('_step', '_best_metric', '_best_model')
 
-        n_epochs = math.ceil(epoch_per_step * max_steps)
+        self.n_epochs = math.ceil(epoch_per_step * max_steps)
 
-        self._train_iterator = self.train_sampler.sample(dataset.train, n_epochs)
+        self._create_train_iterator()
+
+    def _create_train_iterator(self):
+        self._train_iterator = self.train_sampler.sample(self.dataset.train, self.n_epochs)
 
     def _batch_to_device(self, batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]:
         """Move the current batch on the correct device.
@@ -191,7 +194,11 @@ def _train_step(self) -> None:
                 accumulated_loss = 0.0
                 for _ in range(self.batches_per_iter):
                     # Get next batch
-                    batch = next(self._train_iterator)
+                    try:
+                        batch = next(self._train_iterator)
+                    except StopIteration:
+                        self._create_train_iterator()
+                        batch = next(self._train_iterator)
                     batch = self._batch_to_device(batch)
 
                     # Compute loss
@@ -209,8 +216,10 @@
                     clip_grad_value_(self.model.parameters(), self.max_grad_abs_val)
 
                 log(f'{tb_prefix}Training/Loss', accumulated_loss, global_step)
-                log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm, global_step)
-                log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm, global_step)
+                log(f'{tb_prefix}Training/Gradient_Norm', self.model.gradient_norm,
+                    global_step)
+                log(f'{tb_prefix}Training/Parameter_Norm', self.model.parameter_norm,
+                    global_step)
 
                 # Optimize
                 self.optimizer.step()
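The change above boils down to one pattern: the training iterator produced by the sampler is finite, so when it is exhausted mid-run the trainer catches StopIteration, rebuilds the iterator from its factory, and fetches the batch again. The sketch below is a minimal, self-contained illustration of that pattern, not flambe's Trainer; the BatchStream class and make_batches factory names are hypothetical.

# Standalone sketch of the commit's pattern: keep a factory that can rebuild
# the finite iterator, and fall back to it when next() raises StopIteration.
from typing import Callable, Iterator, List


class BatchStream:
    """Yields batches indefinitely by recreating an exhausted iterator."""

    def __init__(self, make_iterator: Callable[[], Iterator[List[int]]]) -> None:
        self._make_iterator = make_iterator
        self._iterator = self._make_iterator()

    def next_batch(self) -> List[int]:
        try:
            return next(self._iterator)
        except StopIteration:
            # The underlying iterator ran out; rebuild it and fetch again,
            # mirroring the try/except added in _train_step above.
            self._iterator = self._make_iterator()
            return next(self._iterator)


if __name__ == "__main__":
    data = list(range(10))
    # Factory producing a fresh iterator of 2-element batches over the data.
    make_batches = lambda: iter([data[i:i + 2] for i in range(0, len(data), 2)])

    stream = BatchStream(make_batches)
    # Request more batches than one pass provides; the stream recreates the
    # iterator transparently instead of raising StopIteration.
    for step in range(8):
        print(step, stream.next_batch())

Recreating the iterator inside the except branch (rather than pre-counting epochs exactly) avoids off-by-one issues when max_steps does not divide evenly into whole passes over the data, which is what the earlier "off by 1" commit in this series was addressing.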