Commit

Fix error to load data at the correct position when resuming from a checkpoint
Thai Chau Truong committed Mar 28, 2024
1 parent 4d45118 commit 0342592
Showing 10 changed files with 218 additions and 18 deletions.
48 changes: 48 additions & 0 deletions .github/workflows/deploy_docs.yml
@@ -0,0 +1,48 @@
# This is a basic workflow that is manually triggered

name: Manual workflow

# Controls when the action will run. Workflow runs when manually triggered using the UI
# or API.
on:
workflow_dispatch:
# Inputs the workflow accepts.
inputs:
name:
# Friendly description to be shown in the UI instead of 'name'
description: 'Person to greet'
# Default value if no value is explicitly provided
default: 'World'
# Input has to be provided for the workflow to run
required: true
# The data type of the input
type: string

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
deploy-docs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install --upgrade setuptools
pip install -e .
pip install -r docs/requirements.txt
- name: Build docs
run: |
set -e
# Check that docs are built without errors
cd docs/ && make html && cd ..
- name: Deploy docs
uses: JamesIves/[email protected]
with:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
BRANCH: gh-pages
FOLDER: docs/build/html
CLEAN: true
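
Because the workflow above only runs on workflow_dispatch, it has to be triggered by hand from the Actions tab or through the REST API. A minimal sketch of the API route follows; OWNER/REPO and the token variable are placeholders, not part of this commit, and the token must be a personal access token with workflow scope:

import os
import requests

url = (
    "https://api.github.com/repos/OWNER/REPO/actions/workflows/"
    "deploy_docs.yml/dispatches"
)
resp = requests.post(
    url,
    headers={
        "Accept": "application/vnd.github+json",
        "Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}",
    },
    # "ref" is the branch to run on; "inputs" must match the workflow inputs.
    json={"ref": "main", "inputs": {"name": "World"}},
)
resp.raise_for_status()  # the API returns 204 No Content on success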
8 changes: 4 additions & 4 deletions .github/workflows/release.yml
@@ -9,10 +9,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v2
with:
- python-version: 3.8
+ python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
@@ -35,10 +35,10 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- - name: Set up Python 3.8
+ - name: Set up Python 3.9
uses: actions/setup-python@v2
with:
- python-version: 3.8
+ python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
13 changes: 11 additions & 2 deletions onmt/inputters/dynamic_iterator.py
@@ -129,6 +129,7 @@ def __init__(
batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info={},
data_type="text",
bucket_size=2048,
bucket_size_init=-1,
@@ -144,6 +145,7 @@ def __init__(
self.transforms = transforms
self.vocabs = vocabs
self.corpora_info = corpora_info
self.resume_corpora_info = resume_corpora_info
self.task = task
self.init_iterators = False
self.batch_size = batch_size
@@ -171,7 +173,8 @@ def __init__(

@classmethod
def from_opt(
- cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
+ cls, corpora, transforms, vocabs, opt, task, copy, device,
+ resume_corpora_info={}, stride=1, offset=0
):
"""Initilize `DynamicDatasetIter` with options parsed from `opt`."""
corpora_info = {}
@@ -206,6 +209,7 @@ def from_opt(
opt.batch_type,
batch_size,
batch_size_multiple,
resume_corpora_info=resume_corpora_info,
data_type=opt.data_type,
bucket_size=bucket_size,
bucket_size_init=bucket_size_init,
@@ -388,6 +392,7 @@ def build_dynamic_dataset_iter(
vocabs,
copy=False,
task=CorpusTask.TRAIN,
resume_corpora_info={},
stride=1,
offset=0,
src=None,
@@ -412,7 +417,10 @@
advance to avoid the GPU waiting during the refilling of the bucket.
"""
transforms = make_transforms(opt, transforms_cls, vocabs)
- corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
+ corpora = get_corpora(
+     opt, task, src=src, tgt=tgt, align=align,
+     resume_corpora_info=resume_corpora_info
+ )
if corpora is None:
assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
return None
Expand Down Expand Up @@ -442,6 +450,7 @@ def build_dynamic_dataset_iter(
vocabs,
opt,
task,
resume_corpora_info=resume_corpora_info,
copy=copy,
stride=stride,
offset=offset,
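Taken together, these changes let a caller thread the corpus positions stored in a checkpoint into the training iterator. A minimal sketch, assuming `opt`, `transforms_cls`, and `vocabs` are built the usual way, the checkpoint path is illustrative, and `load_corpora_info` is the helper this commit adds to onmt/models/model_saver.py further down:

import torch

from onmt.constants import CorpusTask
from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
from onmt.models.model_saver import load_corpora_info

ckpt = torch.load("model_step_1000.pt", map_location="cpu")
# {} on any mismatch; otherwise corpus id -> {"cid_line_number": ...}.
resume_corpora_info = load_corpora_info(opt, ckpt)
train_iter = build_dynamic_dataset_iter(
    opt,
    transforms_cls,
    vocabs,
    task=CorpusTask.TRAIN,
    resume_corpora_info=resume_corpora_info,
)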
28 changes: 25 additions & 3 deletions onmt/inputters/text_corpus.py
@@ -99,7 +99,7 @@ class ParallelCorpus(object):
"""A parallel corpus file pair that can be loaded to iterate."""

def __init__(
- self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None
+ self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None, resumed_line=0
):
"""Initialize src & tgt side file path."""
self.id = name
@@ -108,6 +108,12 @@ def __init__(
self.align = align
self.n_src_feats = n_src_feats
self.src_feats_defaults = src_feats_defaults
self.resumed_line = resumed_line
self.can_read_file = False

def activate_reading_mode(self, line_index):
if (line_index >= self.resumed_line):
self.can_read_file = True

def load(self, offset=0, stride=1):
"""
@@ -145,13 +151,19 @@ def make_ex(sline, tline, align):
for i, (sline, tline, align) in enumerate(
itertools.zip_longest(fs, ft, fa)
):
self.activate_reading_mode(line_index=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
yield make_ex(sline, tline, align)
else:
with exfile_open(self.src, mode="rb") as fs, exfile_open(
self.tgt, mode="rb"
) as ft, exfile_open(self.align, mode="rb") as fa:
for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)):
self.activate_reading_mode(line_index=i)
if not self.can_read_file:
continue
if (i // stride) % stride == offset:
if tline is not None:
tline = tline.decode("utf-8")
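
The effect of activate_reading_mode is easiest to see in isolation: every line before `resumed_line` is consumed but never yielded, so iteration restarts exactly where the checkpoint left off. A self-contained sketch with toy data standing in for the corpus files (the stride test mirrors the code above):

def load_resumable(lines, resumed_line=0, stride=1, offset=0):
    # Lines below `resumed_line` are read but skipped, as in ParallelCorpus.
    for i, line in enumerate(lines):
        if i < resumed_line:
            continue
        if (i // stride) % stride == offset:
            yield i, line

corpus = ["sent %d" % n for n in range(10)]
assert list(load_resumable(corpus, resumed_line=6)) == [
    (6, "sent 6"), (7, "sent 7"), (8, "sent 8"), (9, "sent 9")
]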
@@ -169,19 +181,28 @@ def __str__(self):
)


- def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None):
+ def get_corpora(
+     opts,
+     task=CorpusTask.TRAIN,
+     src=None, tgt=None, align=None,
+     resume_corpora_info={}
+ ):
corpora_dict = {}
if task == CorpusTask.TRAIN:
for corpus_id, corpus_dict in opts.data.items():
if corpus_id != CorpusName.VALID:
if corpus_dict.get("path_txt", None) is None:
resume_line = 0
if (corpus_id in resume_corpora_info):
resume_line = resume_corpora_info[corpus_id]["cid_line_number"]
corpora_dict[corpus_id] = ParallelCorpus(
corpus_id,
corpus_dict["path_src"],
corpus_dict["path_tgt"],
corpus_dict["path_align"],
n_src_feats=opts.n_src_feats,
src_feats_defaults=opts.src_feats_defaults,
resumed_line=resume_line
)
else:
corpora_dict[corpus_id] = BlockwiseCorpus(
@@ -282,7 +303,8 @@ def __iter__(self):


def build_corpora_iters(
- corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0
+ corpora, transforms, corpora_info,
+ skip_empty_level="warning", stride=1, offset=0,
):
"""Return `ParallelCorpusIterator` for all corpora defined in opts."""
corpora_iters = dict()
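For reference, the mapping that get_corpora now accepts carries one entry per corpus id from `opts.data`; corpora absent from the mapping simply start at line 0. A hedged sketch, assuming `opts` is the usual parsed options object and the corpus id is illustrative:

from onmt.constants import CorpusTask
from onmt.inputters.text_corpus import get_corpora

resume_corpora_info = {"corpus_1": {"cid_line_number": 12345}}  # illustrative
corpora = get_corpora(
    opts, task=CorpusTask.TRAIN, resume_corpora_info=resume_corpora_info
)
# corpora["corpus_1"].resumed_line is now 12345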
107 changes: 107 additions & 0 deletions onmt/models/model_saver.py
@@ -1,6 +1,7 @@
import os
import torch
import re
import subprocess
from collections import deque
from onmt.utils.logging import logger
from onmt.inputters.inputter import vocabs_to_dict
@@ -12,6 +13,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
save_model_path = os.path.abspath(opt.save_model)
os.makedirs(os.path.dirname(save_model_path), exist_ok=True)

corpora_info_updater = CorpusInfoUpdater(opts=opt)
model_saver = ModelSaver(
opt.save_model,
model,
@@ -21,6 +23,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
opt.keep_checkpoint,
opt.save_format,
device_id,
corpora_info_updater
)
return model_saver

@@ -81,6 +84,97 @@ def fix_key(s):
return checkpoint


def load_corpora_info(opts, checkpoint):
message_resume_from_beginning = \
"The training will resume from the beginning of each corpus."
# Check if resume_from_corpora is True
if not opts.resume_from_corpora:
logger.info(
"No resume from corpora is specified. " + \
message_resume_from_beginning
)
return {}

# Check if the corpus list from the last training
# and in the new training are identical.
checkpoint_corpora = checkpoint.get("data", None)
if checkpoint_corpora is None:
logger.info(
"The last checkpoint does not contain corpus information. " + \
message_resume_from_beginning
)
return {}

checkpoint_corpus_names = [name for name in checkpoint_corpora]
new_corpus_names = [name for name in opts.data]
if (set(checkpoint_corpus_names) != set(new_corpus_names)):
logger.info(
"Incoherent info: Some corpora in the last training " + \
"and in the new list do not match. " + \
message_resume_from_beginning
)
return {}

# For each corpus, check if the last line number to resume
# is smaller than or equal to the number of text lines.
message_incoherent_line_number = "Incoherent info: Some text line numbers " + \
"to resume do not exist or are greater than the total numbers of text lines. " + \
message_resume_from_beginning
corpora_info = {}
for c_name, corpus in checkpoint_corpora.items():
new_corpora_info = {}
if ("cid_line_number" not in corpus):
logger.info(message_incoherent_line_number)
return {}

new_corpora_info["cid_line_number"] = corpus["cid_line_number"]
number_of_text_lines = int(
subprocess.getoutput(
"wc -l " + \
opts.data[c_name]["path_src"] + \
" | awk '{print $1}'"
)
)
if (new_corpora_info["cid_line_number"] > number_of_text_lines-1):
logger.info(message_incoherent_line_number)
return {}

corpora_info[c_name] = new_corpora_info

logger.info(
"The training will resume from the saved text line in each corpus."
)
return corpora_info
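
The line count above shells out to `wc -l`, which ties the check to a Unix environment. A portable sketch of the same count in pure Python, counting newlines in binary chunks; this is an alternative, not what the commit does:

def count_lines(path, chunk_size=1 << 20):
    total = 0
    with open(path, "rb") as f:
        while True:
            chunk = f.read(chunk_size)
            if not chunk:
                return total
            total += chunk.count(b"\n")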


class CorpusInfoUpdater(object):
def __init__(
self,
opts=None
):
self.opts = opts

def update_corpus_info_from_batches(self, batches):
# Update the last text line of each corpus
new_corpus_info = {}
for batch in batches:
for c_name, cid_line_number in zip(batch["cid"], batch["cid_line_number"]):
if (c_name not in new_corpus_info):
new_corpus_info[c_name] = cid_line_number
else:
new_corpus_info[c_name] = max(
new_corpus_info[c_name],
cid_line_number
)
for c_name, corpus in self.opts.data.items():
if (c_name in new_corpus_info):
corpus["cid_line_number"] = new_corpus_info[c_name]

def get_corpus_info_dict(self):
return {"data": self.opts.data}


class ModelSaverBase(object):
"""Base class for model saving operations
@@ -99,6 +193,7 @@ def __init__(
keep_checkpoint=-1,
save_format="pytorch",
device_id=0,
corpora_info_updater=None
):
self.base_path = base_path
self.model = model
@@ -109,6 +204,7 @@ def __init__(
self.keep_checkpoint = keep_checkpoint
self.save_format = save_format
self.device_id = device_id
self.corpora_info_updater = corpora_info_updater

if keep_checkpoint > 0:
self.checkpoint_queue = deque([], maxlen=keep_checkpoint)
@@ -171,6 +267,15 @@ def _save(self, step, model):

raise NotImplementedError()

def update_corpora_info(self, batches):
if (self.corpora_info_updater is not None):
self.corpora_info_updater.update_corpus_info_from_batches(batches)

def get_corpora_info_to_save(self):
if (self.corpora_info_updater is not None):
return self.corpora_info_updater.get_corpus_info_dict()
return {}

def _rm_checkpoint(self, name):
"""Remove a checkpoint
@@ -267,6 +372,7 @@ def _save(self, step, model):
"opt": self.model_opt,
"optim": self.optim.state_dict(),
}
checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
ckpt_path = "%s_step_%d.pt" % (self.base_path, step)
@@ -356,6 +462,7 @@ def _st_save(self, step, model):
"opt": self.model_opt,
"optim": self.optim.state_dict(),
}
checkpoint = {**checkpoint, **self.get_corpora_info_to_save()}

if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step))
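With the dictionary merge above, the extra "data" entry rides inside the checkpoint next to the model and optimizer state. A quick inspection sketch; the path is illustrative:

import torch

ckpt = torch.load("model_step_2000.pt", map_location="cpu")
for name, corpus in ckpt.get("data", {}).items():
    print(name, "will resume at line", corpus.get("cid_line_number", 0))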
4 changes: 3 additions & 1 deletion onmt/modules/moe.py
@@ -60,7 +60,9 @@ def forward(self, x):
y = torch.empty_like(x)
for i, expert in enumerate(self.experts):
if torch.any(flat_expert_indices == i):
- y[flat_expert_indices == i] = expert(x[flat_expert_indices == i])
+ y[flat_expert_indices == i] = expert(
+     x[flat_expert_indices == i].unsqueeze(0)
+ )
y = (y.view(*expert_weights.shape, -1) * expert_weights.unsqueeze(-1)).sum(
dim=1
)
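The fix wraps each expert's input in an extra batch dimension via unsqueeze(0), for experts that expect batched input. The mask-based dispatch itself is the standard pattern; a minimal sketch with plain nn.Linear experts, which accept 2-D input directly so no unsqueeze is needed here (all names and sizes are illustrative):

import torch
import torch.nn as nn

num_experts, d_model, n_tokens = 4, 8, 16
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
x = torch.randn(n_tokens, d_model)
flat_expert_indices = torch.randint(0, num_experts, (n_tokens,))

y = torch.empty_like(x)
for i, expert in enumerate(experts):
    mask = flat_expert_indices == i
    if torch.any(mask):
        y[mask] = expert(x[mask])  # each expert sees only its own tokens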
(The remaining changed files are not shown.)
