From 64a246bde76401901163777bfa8d4e765766e4d8 Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Mon, 18 Mar 2024 16:33:22 +0100
Subject: [PATCH] Load data from the correct position when resuming from a
 checkpoint

---
 onmt/inputters/dynamic_iterator.py |  13 +++-
 onmt/inputters/text_corpus.py      |  28 +++++++-
 onmt/models/model_saver.py         | 107 +++++++++++++++++++++++++++++
 onmt/opts.py                       |   7 ++
 onmt/train_single.py               |   9 ++-
 onmt/trainer.py                    |   2 +
 6 files changed, 158 insertions(+), 8 deletions(-)

diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py
index 8289d1a3ce..44e866a94f 100644
--- a/onmt/inputters/dynamic_iterator.py
+++ b/onmt/inputters/dynamic_iterator.py
@@ -129,6 +129,7 @@ def __init__(
         batch_type,
         batch_size,
         batch_size_multiple,
+        resume_corpora_info={},
         data_type="text",
         bucket_size=2048,
         bucket_size_init=-1,
@@ -144,6 +145,7 @@ def __init__(
         self.transforms = transforms
         self.vocabs = vocabs
         self.corpora_info = corpora_info
+        self.resume_corpora_info = resume_corpora_info
         self.task = task
         self.init_iterators = False
         self.batch_size = batch_size
@@ -171,7 +173,8 @@ def __init__(

     @classmethod
     def from_opt(
-        cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
+        cls, corpora, transforms, vocabs, opt, task, copy, device,
+        resume_corpora_info={}, stride=1, offset=0
     ):
         """Initilize `DynamicDatasetIter` with options parsed from `opt`."""
         corpora_info = {}
@@ -206,6 +209,7 @@ def from_opt(
             opt.batch_type,
             batch_size,
             batch_size_multiple,
+            resume_corpora_info=resume_corpora_info,
             data_type=opt.data_type,
             bucket_size=bucket_size,
             bucket_size_init=bucket_size_init,
@@ -388,6 +392,7 @@ def build_dynamic_dataset_iter(
     vocabs,
     copy=False,
     task=CorpusTask.TRAIN,
+    resume_corpora_info={},
     stride=1,
     offset=0,
     src=None,
@@ -412,7 +417,10 @@ def build_dynamic_dataset_iter(
     advance to avoid the GPU waiting during the refilling of the bucket.
     """
     transforms = make_transforms(opt, transforms_cls, vocabs)
-    corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
+    corpora = get_corpora(
+        opt, task, src=src, tgt=tgt, align=align,
+        resume_corpora_info=resume_corpora_info
+    )
     if corpora is None:
         assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
         return None
@@ -442,6 +450,7 @@ def build_dynamic_dataset_iter(
         vocabs,
         opt,
         task,
+        resume_corpora_info=resume_corpora_info,
         copy=copy,
         stride=stride,
         offset=offset,
diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py
index ca32cbbf0e..db4960ccb6 100644
--- a/onmt/inputters/text_corpus.py
+++ b/onmt/inputters/text_corpus.py
@@ -99,7 +99,7 @@ class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""

     def __init__(
-        self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None
+        self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None, resumed_line=0
     ):
         """Initialize src & tgt side file path."""
         self.id = name
@@ -108,6 +108,12 @@ def __init__(
         self.align = align
         self.n_src_feats = n_src_feats
         self.src_feats_defaults = src_feats_defaults
+        self.resumed_line = resumed_line
+        self.can_read_file = False
+
+    def activate_reading_mode(self, line_index):
+        if (line_index >= self.resumed_line):
+            self.can_read_file = True

     def load(self, offset=0, stride=1):
         """
@@ -145,6 +151,9 @@ def make_ex(sline, tline, align):
             for i, (sline, tline, align) in enumerate(
                 itertools.zip_longest(fs, ft, fa)
             ):
+                self.activate_reading_mode(line_index=i)
+                if not self.can_read_file:
+                    continue
                 if (i // stride) % stride == offset:
                     yield make_ex(sline, tline, align)
         else:
@@ -152,6 +161,9 @@ def make_ex(sline, tline, align):
                 self.tgt, mode="rb"
             ) as ft, exfile_open(self.align, mode="rb") as fa:
                 for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
+                    self.activate_reading_mode(line_index=i)
+                    if not self.can_read_file:
+                        continue
                     if (i // stride) % stride == offset:
                         if tline is not None:
                             tline = tline.decode("utf-8")
@@ -169,12 +181,20 @@ def __str__(self):
         )


-def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
+def get_corpora(
+    opts,
+    task=CorpusTask.TRAIN,
+    src=None, tgt=None, align=None,
+    resume_corpora_info={}
+):
     corpora_dict = {}
     if task == CorpusTask.TRAIN:
         for corpus_id, corpus_dict in opts.data.items():
             if corpus_id != CorpusName.VALID:
                 if corpus_dict.get("path_txt", None) is None:
+                    resume_line = 0
+                    if (corpus_id in resume_corpora_info):
+                        resume_line = resume_corpora_info[corpus_id]["cid_line_number"]
                     corpora_dict[corpus_id] = ParallelCorpus(
                         corpus_id,
                         corpus_dict["path_src"],
@@ -182,6 +202,7 @@ def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
                         corpus_dict["path_align"],
                         n_src_feats=opts.n_src_feats,
                         src_feats_defaults=opts.src_feats_defaults,
+                        resumed_line=resume_line
                     )
                 else:
                     corpora_dict[corpus_id] = BlockwiseCorpus(
@@ -282,7 +303,8 @@ def __iter__(self):


 def build_corpora_iters(
-    corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0
+    corpora, transforms, corpora_info,
+    skip_empty_level="warning", stride=1, offset=0,
 ):
     """Return `ParallelCorpusIterator` for all corpora defined in opts."""
     corpora_iters = dict()
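The ParallelCorpus change above resumes a corpus by skipping every example whose line index is below the saved position: activate_reading_mode() keeps counting indices from the start of the file, so the existing stride/offset sharding is untouched and only the yield is suppressed for already-consumed lines. A minimal, standalone sketch of that skip-until-resumed-line pattern (the file handling and the sharding test are simplified here relative to the actual loader):

def iter_from_resumed_line(lines, resumed_line=0, stride=1, offset=0):
    """Yield (index, line) pairs, skipping everything before resumed_line."""
    for i, line in enumerate(lines):
        if i < resumed_line:
            # Not yet at the saved position: count the line but do not emit it.
            continue
        if i % stride == offset:  # simplified sharding test
            yield i, line

# Example: resume a six-line corpus from line 3.
print(list(iter_from_resumed_line(["s0", "s1", "s2", "s3", "s4", "s5"], resumed_line=3)))
# -> [(3, 's3'), (4, 's4'), (5, 's5')]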
diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py
index 986ca7ae99..2cc1c78c2e 100644
--- a/onmt/models/model_saver.py
+++ b/onmt/models/model_saver.py
@@ -1,6 +1,7 @@
 import os
 import torch
 import re
+import subprocess
 from collections import deque
 from onmt.utils.logging import logger
 from onmt.inputters.inputter import vocabs_to_dict
@@ -12,6 +13,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
     save_model_path = os.path.abspath(opt.save_model)
     os.makedirs(os.path.dirname(save_model_path), exist_ok=True)

+    corpora_info_updater = CorpusInfoUpdater(opts=opt)
     model_saver = ModelSaver(
         opt.save_model,
         model,
@@ -21,6 +23,7 @@
         opt.keep_checkpoint,
         opt.save_format,
         device_id,
+        corpora_info_updater
     )
     return model_saver

@@ -81,6 +84,97 @@ def fix_key(s):
     return checkpoint


+def load_corpora_info(opts, checkpoint):
+    message_resume_from_beginning = \
+        "The training will resume from the beginning of each corpus."
+    # Check if resume_from_corpora is True
+    if not opts.resume_from_corpora:
+        logger.info(
+            "No resume from corpora is specified. " + \
+            message_resume_from_beginning
+        )
+        return {}
+
+    # Check that the corpus lists from the last training
+    # and the new training are identical.
+    checkpoint_corpora = checkpoint.get("data", None)
+    if (checkpoint_corpora is None):
+        logger.info(
+            "No corpora information was found in the checkpoint, " + \
+            "so the saved positions cannot be restored. " + \
+            message_resume_from_beginning
+        )
+        return {}
+
+    checkpoint_corpus_names = [name for name in checkpoint_corpora]
+    new_corpus_names = [name for name in opts.data]
+    if (set(checkpoint_corpus_names) != set(new_corpus_names)):
+        logger.info(
+            "Incoherent info: Some corpora in the last training " + \
+            "and in the new list do not match. " + \
+            message_resume_from_beginning
+        )
+        return {}
+
+    # For each corpus, check if the last line number to resume
+    # is smaller than or equal to the number of text lines.
+    message_incoherent_line_number = "Incoherent info: Some text line numbers " + \
+        "to resume do not exist or are greater than the total number of text lines. " + \
+        message_resume_from_beginning
+    corpora_info = {}
+    for c_name, corpus in checkpoint_corpora.items():
+        new_corpora_info = {}
+        if ("cid_line_number" not in corpus):
+            logger.info(message_incoherent_line_number)
+            return {}
+
+        new_corpora_info["cid_line_number"] = corpus["cid_line_number"]
+        number_of_text_lines = int(
+            subprocess.getoutput(
+                "wc -l " + \
+                opts.data[c_name]["path_src"] + \
+                " | awk '{print $1}'"
+            )
+        )
+        if (new_corpora_info["cid_line_number"] > number_of_text_lines-1):
+            logger.info(message_incoherent_line_number)
+            return {}
+
+        corpora_info[c_name] = new_corpora_info
+
+    logger.info(
+        "The training will resume from the saved text line in each corpus."
+    )
+    return corpora_info
+
+
+class CorpusInfoUpdater(object):
+    def __init__(
+        self,
+        opts=None
+    ):
+        self.opts = opts
+
+    def update_corpus_info_from_batches(self, batches):
+        # Update the last text line of each corpus
+        new_corpus_info = {}
+        for batch in batches:
+            for c_name, cid_line_number in zip(batch["cid"], batch["cid_line_number"]):
+                if (c_name not in new_corpus_info):
+                    new_corpus_info[c_name] = cid_line_number
+                else:
+                    new_corpus_info[c_name] = max(
+                        new_corpus_info[c_name],
+                        cid_line_number
+                    )
+        for c_name, corpus in self.opts.data.items():
+            if (c_name in new_corpus_info):
+                corpus["cid_line_number"] = new_corpus_info[c_name]
+
+    def get_corpus_info_dict(self):
+        return {"data": self.opts.data}
+
+
 class ModelSaverBase(object):
     """Base class for model saving operations

@@ -99,6 +193,7 @@ def __init__(
         keep_checkpoint=-1,
         save_format="pytorch",
         device_id=0,
+        corpora_info_updater=None
     ):
         self.base_path = base_path
         self.model = model
@@ -109,6 +204,7 @@ def __init__(
         self.keep_checkpoint = keep_checkpoint
         self.save_format = save_format
         self.device_id = device_id
+        self.corpora_info_updater = corpora_info_updater

         if keep_checkpoint > 0:
             self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
@@ -171,6 +267,15 @@ def _save(self, step, model):

         raise NotImplementedError()

+    def update_corpora_info(self, batches):
+        if (self.corpora_info_updater is not None):
+            self.corpora_info_updater.update_corpus_info_from_batches(batches)
+
+    def get_corpora_info_to_save(self):
+        if (self.corpora_info_updater is not None):
+            return self.corpora_info_updater.get_corpus_info_dict()
+        return {}
+
     def _rm_checkpoint(self, name):
         """Remove a checkpoint

@@ -267,6 +372,7 @@ def _save(self, step, model):
             "opt": self.model_opt,
             "optim": self.optim.state_dict(),
         }
+        checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}
         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
             logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
             ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
@@ -356,6 +462,7 @@ def _st_save(self, step, model):
             "opt": self.model_opt,
             "optim": self.optim.state_dict(),
         }
+        checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}

         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
             logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
diff --git a/onmt/opts.py b/onmt/opts.py
index 21abd96a3d..6576246aa9 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -1263,6 +1263,13 @@ def _add_train_general_opts(parser):
         help="If training from a checkpoint then this is the "
         "path to the pretrained model's state_dict.",
     )
+    group.add(
+        "--resume_from_corpora",
+        "-resume_from_corpora",
+        action="store_true",
+        help="If training from a checkpoint and this is set to True, "
+        "then the data generator will resume from the last line of each corpus.",
+    )
     group.add(
         "--reset_optim",
         "-reset_optim",
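With the new option, a run that resumes both the optimizer state and the data position could be launched along these lines (the config and checkpoint names below are placeholders, not files from this patch):

onmt_train -config my_config.yaml -train_from run/model_step_10000.pt -resume_from_corpora

Without -resume_from_corpora, training from a checkpoint keeps the previous behaviour and re-reads every corpus from its first line.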
diff --git a/onmt/train_single.py b/onmt/train_single.py
index 76ab3bef66..0eb4ca149b 100644
--- a/onmt/train_single.py
+++ b/onmt/train_single.py
@@ -17,7 +17,7 @@
 from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.inputters.text_corpus import save_transformed_sample
 from onmt.model_builder import build_model
-from onmt.models.model_saver import load_checkpoint
+from onmt.models.model_saver import load_checkpoint, load_corpora_info
 from onmt.utils.optimizers import Optimizer
 from onmt.utils.misc import set_random_seed
 from onmt.trainer import build_trainer
@@ -80,6 +80,7 @@ def _init_train(opt):
     if opt.train_from:
         # Load checkpoint if we resume from a previous training.
         checkpoint = load_checkpoint(ckpt_path=opt.train_from)
+        resume_corpora_info = load_corpora_info(opt, checkpoint)
         vocabs = dict_to_vocabs(checkpoint["vocab"])
         if (
             hasattr(checkpoint["opt"], "_all_transform")
@@ -105,8 +106,9 @@ def _init_train(opt):
     else:
         checkpoint = None
         vocabs = prepare_transforms_vocabs(opt, transforms_cls)
+        resume_corpora_info = {}

-    return checkpoint, vocabs, transforms_cls
+    return checkpoint, resume_corpora_info, vocabs, transforms_cls


 def configure_process(opt, device_id):
@@ -159,7 +161,7 @@ def main(opt, device_id):
     configure_process(opt, device_id)
     init_logger(opt.log_file)

-    checkpoint, vocabs, transforms_cls = _init_train(opt)
+    checkpoint, resume_corpora_info, vocabs, transforms_cls = _init_train(opt)
     model_opt = _get_model_opts(opt, checkpoint=checkpoint)

     # Build model.
@@ -211,6 +213,7 @@ def main(opt, device_id):
             transforms_cls,
             vocabs,
             task=CorpusTask.TRAIN,
+            resume_corpora_info=resume_corpora_info,
             copy=opt.copy_attn,
             stride=stride,
             offset=offset,
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 6916ec3ba9..7e1e4ab771 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -320,6 +320,8 @@ def train(
                 batches, normalization, total_stats, report_stats
             )

+            self.model_saver.update_corpora_info(batches)
+
             if self.average_decay > 0 and i % self.average_every == 0:
                 self._update_average(step)
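Taken together, the saving side now stores each corpus' last read position inside the checkpoint's "data" entry (a "cid_line_number" field next to the usual corpus settings), and load_corpora_info converts that entry back into the resume_corpora_info mapping that get_corpora uses to give every ParallelCorpus its resumed_line. A rough sketch of that round trip, with hypothetical corpus names and paths and with the validation performed by load_corpora_info left out:

import torch

# Saving side: a reduced view of the checkpoint once get_corpora_info_to_save()
# has been merged in (model weights, vocab and options are omitted here).
checkpoint = {
    "model": {},
    "optim": {},
    "data": {
        "corpus_1": {"path_src": "data/corpus_1.src", "cid_line_number": 12345},
    },
}
torch.save(checkpoint, "model_step_1000.pt")

# Resuming side: the mapping that load_corpora_info rebuilds from the checkpoint.
ckpt = torch.load("model_step_1000.pt")
resume_corpora_info = {
    name: {"cid_line_number": info["cid_line_number"]}
    for name, info in ckpt.get("data", {}).items()
    if "cid_line_number" in info
}
print(resume_corpora_info)
# {'corpus_1': {'cid_line_number': 12345}}
# This mapping flows through build_dynamic_dataset_iter and get_corpora,
# which pass resumed_line to each ParallelCorpus so iteration restarts
# at the saved position.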