Put the corpus info updater into a function in ModelSaver class
Thai Chau Truong committed Mar 31, 2024
1 parent 64a246b commit 2f74466
Showing 4 changed files with 59 additions and 73 deletions.
23 changes: 11 additions & 12 deletions onmt/inputters/text_corpus.py
@@ -99,7 +99,7 @@ class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""
 
     def __init__(
-        self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None, resumed_line=0
+        self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None, line_number_to_resume=0
     ):
         """Initialize src & tgt side file path."""
         self.id = name
@@ -108,11 +108,11 @@ def __init__(
         self.align = align
         self.n_src_feats = n_src_feats
         self.src_feats_defaults = src_feats_defaults
-        self.resumed_line = resumed_line
+        self.line_number_to_resume = line_number_to_resume
         self.can_read_file = False
 
-    def activate_reading_mode(self, line_index):
-        if (line_index >= self.resumed_line):
+    def activate_reading_mode(self, line_number):
+        if (line_number >= self.line_number_to_resume):
             self.can_read_file = True
 
     def load(self, offset=0, stride=1):
@@ -122,7 +122,7 @@ def load(self, offset=0, stride=1):
         `stride` example, starting from `offset`.
         """
 
-        def make_ex(sline, tline, align):
+        def make_ex(sline, tline, align, line_number):
            sline, sfeats = parse_features(
                sline,
                n_feats=self.n_src_feats,
@@ -137,6 +137,7 @@ def make_ex(sline, tline, align):
                 "tgt": tline,
                 "src_original": sline,
                 "tgt_original": tline,
+                "cid_line_number": line_number
             }
             if align is not None:
                 example["align"] = align
@@ -155,21 +156,21 @@ def make_ex(sline, tline, align):
                 if not self.can_read_file:
                     continue
                 if (i // stride) % stride == offset:
-                    yield make_ex(sline, tline, align)
+                    yield make_ex(sline, tline, align, i)
         else:
             with exfile_open(self.src, mode="rb") as fs, exfile_open(
                 self.tgt, mode="rb"
             ) as ft, exfile_open(self.align, mode="rb") as fa:
                 for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
-                    self.activate_reading_mode(line_index=i)
+                    self.activate_reading_mode(line_number=i)
                     if not self.can_read_file:
                         continue
                     if (i // stride) % stride == offset:
                         if tline is not None:
                             tline = tline.decode("utf-8")
                         if align is not None:
                             align = align.decode("utf-8")
-                        yield make_ex(sline.decode("utf-8"), tline, align)
+                        yield make_ex(sline.decode("utf-8"), tline, align, i)
 
     def __str__(self):
         cls_name = type(self).__name__
@@ -194,15 +195,15 @@ def get_corpora(
         if corpus_dict.get("path_txt", None) is None:
             resume_line = 0
             if (corpus_id in resume_corpora_info):
-                resume_line = resume_corpora_info[corpus_id]["cid_line_number"]
+                resume_line = resume_corpora_info[corpus_id]
             corpora_dict[corpus_id] = ParallelCorpus(
                 corpus_id,
                 corpus_dict["path_src"],
                 corpus_dict["path_tgt"],
                 corpus_dict["path_align"],
                 n_src_feats=opts.n_src_feats,
                 src_feats_defaults=opts.src_feats_defaults,
-                resumed_line=resume_line
+                line_number_to_resume=resume_line
             )
         else:
             corpora_dict[corpus_id] = BlockwiseCorpus(
@@ -265,8 +266,6 @@ def _process(self, stream):
                 example["src_feats"] = [
                     feat.strip().split(" ") for feat in example["src_feats"]
                 ]
-            line_number = i * self.stride + self.offset
-            example["cid_line_number"] = line_number
             example["cid"] = self.cid
             if "align" in example:
                 example["align"] = example["align"].strip().split(" ")
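
Note: the mechanism above boils down to a simple gate — the reader still
enumerates every line, but yields nothing until the saved resume line is
reached, and each yielded example carries its absolute line number. A
minimal standalone sketch of that pattern (illustrative names, not the
actual ParallelCorpus code):

    # Sketch: skip lines consumed before the checkpoint, then tag each
    # example with its absolute line number so it can be checkpointed.
    def read_from(path, line_number_to_resume=0):
        with open(path, encoding="utf-8") as f:
            for i, line in enumerate(f):
                if i < line_number_to_resume:
                    continue
                yield {"src": line.rstrip("\n"), "cid_line_number": i}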
100 changes: 42 additions & 58 deletions onmt/models/model_saver.py
@@ -3,17 +3,17 @@
 import re
 import subprocess
 from collections import deque
+import onmt.utils
 from onmt.utils.logging import logger
 from onmt.inputters.inputter import vocabs_to_dict
 from onmt.modules.lora import lora_state_dict
 
 
-def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
+def build_model_saver(model_opt, opt, model, vocabs, optim, resume_corpora_info, device_id):
     # _check_save_model_path
     save_model_path = os.path.abspath(opt.save_model)
     os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
 
-    corpora_info_updater = CorpusInfoUpdater(opts=opt)
     model_saver = ModelSaver(
         opt.save_model,
         model,
@@ -22,8 +22,8 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
         optim,
         opt.keep_checkpoint,
         opt.save_format,
+        resume_corpora_info,
         device_id,
-        corpora_info_updater
     )
     return model_saver
 
@@ -97,7 +97,7 @@ def load_corpora_info(opts, checkpoint):
 
     # Check if the corpus list from the last training
     # and in the new training are identical.
-    checkpoint_corpora = checkpoint.get("data", None)
+    checkpoint_corpora = checkpoint.get("corpus_info", None)
     if (checkpoint_corpora is None):
         logger.info(
             "Incoherent info: Some corpora in the last training " + \
@@ -118,61 +118,30 @@ def load_corpora_info(opts, checkpoint):
 
     # For each corpus, check if the last line number to resume
     # is smaller than or equal to the number of text lines.
-    message_incoherent_line_number = "Incoherent info: Some text line numbers " + \
-        "to resume do not exist or are greater than the total numbers of text lines. " + \
+    message_incoherent_line_number = "Incoherent info: text line numbers " + \
+        "to resume in some corpora exceed their total numbers of lines. " + \
         message_resume_from_beginning
-    corpora_info = {}
-    for c_name, corpus in checkpoint_corpora.items():
-        new_corpora_info = {}
-        if ("cid_line_number" not in corpus):
-            logger.info(message_incoherent_line_number)
-            return {}
-
-        new_corpora_info["cid_line_number"] = corpus["cid_line_number"]
+    for c_name in checkpoint_corpora:
         number_of_text_lines = int(
             subprocess.getoutput(
                 "wc -l " + \
                 opts.data[c_name]["path_src"] + \
                 " | awk '{print $1}'"
             )
         )
-        if (new_corpora_info["cid_line_number"] > number_of_text_lines-1):
+        if (checkpoint_corpora[c_name] > number_of_text_lines-1):
             logger.info(message_incoherent_line_number)
             return {}
 
-        corpora_info[c_name] = new_corpora_info
+        # To set the text lines to resume, we increase all text lines by 1
+        # (and return to the beginning if the end is reached)
+        checkpoint_corpora[c_name] = \
+            (checkpoint_corpora[c_name] + 1) % number_of_text_lines
 
     logger.info(
         "The training will resume from the saved text line in each corpus."
     )
-    return corpora_info
-
-
-class CorpusInfoUpdater(object):
-    def __init__(
-        self,
-        opts=None
-    ):
-        self.opts = opts
-
-    def update_corpus_info_from_batches(self, batches):
-        # Update the last text line of each corpus
-        new_corpus_info = {}
-        for batch in batches:
-            for c_name, cid_line_number in zip(batch["cid"], batch["cid_line_number"]):
-                if (c_name not in new_corpus_info):
-                    new_corpus_info[c_name] = cid_line_number
-                else:
-                    new_corpus_info[c_name] = max(
-                        new_corpus_info[c_name],
-                        cid_line_number
-                    )
-        for c_name, corpus in self.opts.data.items():
-            if (c_name in new_corpus_info):
-                corpus["cid_line_number"] = new_corpus_info[c_name]
-
-    def get_corpus_info_dict(self):
-        return {"data": self.opts.data}
+    return checkpoint_corpora
 
 
 class ModelSaverBase(object):
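
Note: the resume arithmetic in this hunk is easy to miss — the checkpoint
stores the last line consumed per corpus, so training resumes at the next
line, wrapping back to line 0 when the end of the file was reached. A
self-contained sketch with a hypothetical corpus size (standing in for
the real wc -l call):

    # Sketch: advance each saved line by one, modulo the corpus length.
    def advance_resume_lines(checkpoint_corpora, corpus_sizes):
        for c_name, last_line in checkpoint_corpora.items():
            n = corpus_sizes[c_name]
            if last_line > n - 1:
                return {}  # incoherent checkpoint: restart all corpora at 0
            checkpoint_corpora[c_name] = (last_line + 1) % n
        return checkpoint_corpora

    # Last line of a 42-line corpus wraps to the beginning:
    assert advance_resume_lines({"europarl": 41}, {"europarl": 42}) == {"europarl": 0}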
@@ -192,8 +161,8 @@ def __init__(
         optim,
         keep_checkpoint=-1,
         save_format="pytorch",
+        resume_corpora_info={},
         device_id=0,
-        corpora_info_updater=None
     ):
         self.base_path = base_path
         self.model = model
@@ -203,15 +172,39 @@ def __init__(
         self.last_saved_step = None
         self.keep_checkpoint = keep_checkpoint
         self.save_format = save_format
+        self.corpus_info = resume_corpora_info
         self.device_id = device_id
-        self.corpora_info_updater = corpora_info_updater
 
         if keep_checkpoint > 0:
             self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
             if save_format == "safetensors":
                 self.model_queue = deque([], maxlen=keep_checkpoint)
 
-    def save(self, step, moving_average=None):
+    def update_corpus_info_from_batches(self, batches, distributed=False):
+        # Update the last text line of each corpus
+        if batches is not None:
+            # Gather corpus line numbers to save to checkpoints
+            batch_cids = sum([batch["cid"] for batch in batches], [])
+            batch_cid_line_numbers = sum(
+                [batch["cid_line_number"] for batch in batches], []
+            )
+            if distributed:
+                batch_cids = sum(
+                    onmt.utils.distributed.all_gather_list(batch_cids),
+                    []
+                )
+                batch_cid_line_numbers = sum(
+                    onmt.utils.distributed.all_gather_list(batch_cid_line_numbers),
+                    []
+                )
+            # Save the last processed line number of each corpus
+            new_corpus_info = {
+                c_name: cid_line_number \
+                for c_name, cid_line_number in zip(batch_cids, batch_cid_line_numbers)
+            }
+            self.corpus_info = {**self.corpus_info, **new_corpus_info}
+
+    def save(self, step, moving_average=None, batches=None, distributed=False):
         """Main entry point for model saver
 
         It wraps the `_save` method with checks and apply `keep_checkpoint`
@@ -267,15 +260,6 @@ def _save(self, step, model):
 
         raise NotImplementedError()
 
-    def update_corpora_info(self, batches):
-        if (self.corpora_info_updater is not None):
-            self.corpora_info_updater.update_corpus_info_from_batches(batches)
-
-    def get_corpora_info_to_save(self):
-        if (self.corpora_info_updater is not None):
-            return self.corpora_info_updater.get_corpus_info_dict()
-        return {}
-
     def _rm_checkpoint(self, name):
         """Remove a checkpoint
@@ -371,8 +355,8 @@ def _save(self, step, model):
             "vocab": vocabs_to_dict(self.vocabs),
             "opt": self.model_opt,
             "optim": self.optim.state_dict(),
+            "corpus_info": self.corpus_info,
         }
-        checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}
         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
             logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
             ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
@@ -461,8 +445,8 @@ def _st_save(self, step, model):
             "vocab": vocabs_to_dict(self.vocabs),
             "opt": self.model_opt,
             "optim": self.optim.state_dict(),
+            "corpus_info": self.corpus_info,
         }
-        checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}
 
         if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
             logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
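
Note: in update_corpus_info_from_batches the merge is a plain dict
comprehension, so when the same corpus appears several times in the
flattened lists the pair seen last wins — unlike the removed
CorpusInfoUpdater, which kept the max line number per corpus. A small
illustration with hypothetical values:

    # Sketch: later (cid, line_number) pairs overwrite earlier ones.
    batch_cids = ["news", "europarl", "news"]
    batch_cid_line_numbers = [120, 7, 133]
    new_corpus_info = {
        c_name: line_no
        for c_name, line_no in zip(batch_cids, batch_cid_line_numbers)
    }
    assert new_corpus_info == {"news": 133, "europarl": 7}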
2 changes: 1 addition & 1 deletion onmt/train_single.py
@@ -199,7 +199,7 @@ def main(opt, device_id):
     del checkpoint
 
     # Build model saver
-    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)
+    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, resume_corpora_info, device_id)
 
     trainer = build_trainer(
         opt, device_id, model, vocabs, optim, model_saver=model_saver
7 changes: 5 additions & 2 deletions onmt/trainer.py
@@ -320,8 +320,6 @@ def train(
                 batches, normalization, total_stats, report_stats
             )
 
-            self.model_saver.update_corpora_info(batches)
-
             if self.average_decay > 0 and i % self.average_every == 0:
                 self._update_average(step)
 
@@ -351,6 +349,11 @@ def train(
                     logger.info("earlystopper has_stopped!")
                     break
 
+            self.model_saver.update_corpus_info_from_batches(
+                batches,
+                distributed=(self.n_gpu > 1 and self.parallel_mode == "data_parallel")
+            )
+
             if self.model_saver is not None and (
                 save_checkpoint_steps != 0 and step % save_checkpoint_steps == 0
             ):
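
Note: in data-parallel training each rank only sees its own shard of the
batches, hence the distributed flag — the line-number lists are gathered
from all ranks and flattened before the merge. What the gather amounts
to, simulated here without torch (all_gather_list is stubbed by
hardcoded per-rank lists):

    # Sketch: all_gather_list returns one list per rank; sum(..., [])
    # flattens them, after which the corpus_info merge proceeds as usual.
    per_rank_cids = [["news"], ["europarl"]]       # rank 0, rank 1
    per_rank_line_numbers = [[120], [7]]
    batch_cids = sum(per_rank_cids, [])            # ["news", "europarl"]
    batch_cid_line_numbers = sum(per_rank_line_numbers, [])
    new_corpus_info = dict(zip(batch_cids, batch_cid_line_numbers))
    assert new_corpus_info == {"news": 120, "europarl": 7}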
