WIP: translation is broken
Waino committed Aug 12, 2024
1 parent b69750c commit e6f564a
Showing 5 changed files with 160 additions and 70 deletions.
45 changes: 41 additions & 4 deletions mammoth/bin/translate.py
@@ -7,9 +7,11 @@
 from mammoth.transforms import get_transforms_cls, make_transforms, TransformPipe
 
 import mammoth.opts as opts
-from mammoth.distributed import TaskSpecs
+from mammoth.distributed import TaskSpecs, TaskQueueManager
+from mammoth.distributed.contexts import WorldContext, DeviceContextEnum
 from mammoth.distributed.tasks import get_adapter_ids
 from mammoth.utils.parse import ArgumentParser
+from mammoth.utils.misc import use_gpu
 
 
 def translate(opts):
@@ -26,12 +28,24 @@ def translate(opts):
     if 'adapters' in corpus_opts:
         encoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'encoder')
         decoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'decoder')
+        uses_adapters = True
     else:
         encoder_adapter_ids = None
         decoder_adapter_ids = None
+        uses_adapters = False
+
+    node_rank = 0
+    local_rank = 0
+    if use_gpu(opts):
+        context_enum = DeviceContextEnum.SINGLE_GPU
+        gpus_per_node = 1
+    else:
+        context_enum = DeviceContextEnum.CPU
+        gpus_per_node = 0
 
     task = TaskSpecs(
-        node_rank=None,
-        local_rank=None,
+        node_rank=node_rank,
+        local_rank=local_rank,
         src_lang=src_lang,
         tgt_lang=tgt_lang,
         encoder_id=encoder_id,
@@ -46,7 +60,30 @@ def translate(opts):
         decoder_adapter_ids=decoder_adapter_ids,
     )
 
-    translator = build_translator(opts, task, logger=logger, report_score=True)
+    world_context = WorldContext(
+        context=context_enum,
+        n_nodes=1,
+        gpus_per_node=gpus_per_node,
+    )
+
+    task_queue_manager = TaskQueueManager(
+        tasks=[task],
+        accum_count=1,
+        world_context=world_context,
+        task_distribution_strategy_cls=None,
+        uses_adapters=uses_adapters,
+    ).global_to_local(
+        node_rank=node_rank,
+        local_rank=local_rank,
+        opts=opts,
+    )
+    # FIXME: fix the attention bridge in translation
+    task_queue_manager.create_all_distributed_components(
+        use_attention_bridge=False,  # (opts.ab_layers is not None and len(opts.ab_layers) != 0),
+        new_group_func=lambda: None,
+    )
+
+    translator = build_translator(opts, task_queue_manager, task, logger=logger, report_score=True)
 
     # data_reader = InferenceDataReader(opts.src, opts.tgt, opts.src_feats)
     src_shards = split_corpus(opts.src, opts.shard_size)
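Note: the hunk above hard-codes a single-device setup (node_rank = 0, local_rank = 0, one node). As a reading aid, the device-selection logic it adds can be summarized in a hypothetical helper (not part of the commit; use_gpu, WorldContext and DeviceContextEnum are the imports added above):

    # Hypothetical sketch, not in the commit: pick a single-device
    # WorldContext for local inference, mirroring the added code.
    def single_device_world_context(opts):
        if use_gpu(opts):
            # one node with a single visible GPU
            return WorldContext(context=DeviceContextEnum.SINGLE_GPU, n_nodes=1, gpus_per_node=1)
        # CPU-only fallback
        return WorldContext(context=DeviceContextEnum.CPU, n_nodes=1, gpus_per_node=0)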
5 changes: 4 additions & 1 deletion mammoth/distributed/tasks.py
@@ -288,7 +288,10 @@ def global_to_local(self, node_rank, local_rank, opts):
         assert node_rank is not None
         assert local_rank is not None
         device_context = self.world_context.global_to_local(node_rank, local_rank)
-        task_distribution_strategy = self.task_distribution_strategy_cls(seed=opts.seed)
+        if self.task_distribution_strategy_cls is not None:
+            task_distribution_strategy = self.task_distribution_strategy_cls(seed=opts.seed)
+        else:
+            task_distribution_strategy = None
         return LocalTaskQueueManager(
             self.tasks,
             accum_count=self.accum_count,
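The None branch added here supports the inference path in mammoth/bin/translate.py above, where TaskQueueManager is constructed with task_distribution_strategy_cls=None. A minimal sketch of that call sequence, assuming a single task and the world_context built in that file (tqm and local_tqm are placeholder names):

    # Sketch, not part of the commit: single-task inference setup.
    tqm = TaskQueueManager(
        tasks=[task],
        accum_count=1,
        world_context=world_context,
        task_distribution_strategy_cls=None,  # no task sampling during translation
        uses_adapters=False,
    )
    local_tqm = tqm.global_to_local(node_rank=0, local_rank=0, opts=opts)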
45 changes: 0 additions & 45 deletions mammoth/model_builder.py
@@ -47,51 +47,6 @@ def uses_adapters(opts):
     return 'adapters' in opts and opts.adapters
 
 
-def load_test_multitask_model(opts, task_queue_manager, task=None, model_path=None):
-    if task is None:
-        raise ValueError('Must set task')
-    if model_path is None:
-        model_path = opts.models[0]
-
-    # Load only the frame
-    frame, ckpt_path = load_frame_checkpoint(ckpt_path=opts.train_from)
-
-    vocabs_dict = {
-        'src': frame["vocab"].get(('src', task.src_lang)),
-        'tgt': frame["vocab"].get(('tgt', task.tgt_lang)),
-    }
-
-    model_opts = ArgumentParser.ckpt_model_opts(frame['opts'])
-    # Avoid functionality on inference
-    # model_opts.update_vocab = False
-    model = build_model(
-        model_opts,
-        opts,
-        vocabs_dict,
-        task_queue_manager,
-        single_task=task.corpus_id,
-    )
-
-    # FIXME: load the model parameters
-
-    model_params = {name for name, p in model.named_parameters()}
-    model_params.update(name for name, p in model.named_buffers())
-    for key in set(combined_state_dict.keys()):
-        if key not in model_params:
-            print(f'Deleting unnecessary key: {key}')
-            del combined_state_dict[key]
-    for key in model_params:
-        if key not in combined_state_dict:
-            print(f'Key missing {key}')
-    model.load_state_dict(combined_state_dict)
-    device = torch.device("cuda" if use_gpu(opts) else "cpu")
-    model.to(device)
-
-    model.eval()
-
-    return vocabs_dict, model, model_opts
-
-
 def get_attention_layers_kwargs(
     side: Side,
     layer_stack_index,
82 changes: 66 additions & 16 deletions mammoth/translate/translator.py
@@ -8,19 +8,21 @@
 
 import torch
 
-import mammoth.model_builder
-import mammoth.modules.decoder_ensemble
 # from mammoth.inputters.text_dataset import InferenceDataIterator
-from mammoth.translate.beam_search import BeamSearch, BeamSearchLM
-from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM
-from mammoth.utils.misc import tile, set_random_seed, report_matrix
-from mammoth.utils.alignment import extract_alignment, build_align_pharaoh
 from mammoth.constants import ModelTask, DefaultTokens
-from mammoth.inputters.dataset import ParallelCorpus
 from mammoth.inputters.dataloader import build_dataloader
+from mammoth.inputters.dataset import ParallelCorpus
+from mammoth.model_builder import build_model
+from mammoth.translate.beam_search import BeamSearch, BeamSearchLM, GNMTGlobalScorer
+from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM
+from mammoth.translate.translation import TranslationBuilder
+from mammoth.utils.alignment import extract_alignment, build_align_pharaoh
+from mammoth.utils.misc import tile, set_random_seed, report_matrix, use_gpu
+from mammoth.utils.model_saver import load_frame_checkpoint, load_parameters_from_checkpoint
+from mammoth.utils.parse import ArgumentParser
 
 
-def build_translator(opts, task, report_score=True, logger=None, out_file=None):
+def build_translator(opts, task_queue_manager, task, report_score=True, logger=None, out_file=None):
     if out_file is None:
         outdir = os.path.dirname(opts.output)
         if outdir and not os.path.isdir(outdir):
@@ -29,15 +31,19 @@ def build_translator(opts, task, report_score=True, logger=None, out_file=None):
         os.makedirs(os.path.dirname(opts.output), exist_ok=True)
         out_file = codecs.open(opts.output, "w+", "utf-8")
 
-    load_test_model = (
-        mammoth.modules.decoder_ensemble.load_test_model if len(opts.models) > 3
-        else mammoth.model_builder.load_test_multitask_model
-    )
+    # TODO: reimplement ensemble decoding
+    load_model_for_translation_func = load_model_for_translation
     if logger:
         logger.info(str(task))
-    vocabs, model, model_opts = load_test_model(opts, task)
+    model_path = None
+    vocabs, model, model_opts = load_model_for_translation_func(
+        opts=opts,
+        task_queue_manager=task_queue_manager,
+        task=task,
+        model_path=model_path,
+    )
 
-    scorer = mammoth.translate.GNMTGlobalScorer.from_opts(opts)
+    scorer = GNMTGlobalScorer.from_opts(opts)
 
     translator = Translator.from_opts(
         model,
@@ -54,6 +60,49 @@ def build_translator(opts, task, report_score=True, logger=None, out_file=None):
     return translator
 
 
+def load_model_for_translation(opts, task_queue_manager, task=None, model_path=None):
+    if task is None:
+        raise ValueError('Must set task')
+    if model_path is None:
+        model_path = opts.models[0]
+
+    # Load only the frame
+    frame, frame_ckpt_path = load_frame_checkpoint(ckpt_path=model_path)
+
+    vocabs_dict = {
+        ('src', task.src_lang): frame["vocab"].get(('src', task.src_lang)),
+        ('tgt', task.tgt_lang): frame["vocab"].get(('tgt', task.tgt_lang)),
+        'src': frame["vocab"].get(('src', task.src_lang)),
+        'tgt': frame["vocab"].get(('tgt', task.tgt_lang)),
+    }
+    print(f'vocabs_dict {vocabs_dict}')
+    print(f'my components {task_queue_manager.get_my_distributed_components()}')
+
+    model_opts = ArgumentParser.ckpt_model_opts(frame['opts'])
+
+    model = build_model(
+        model_opts,
+        opts,
+        vocabs_dict,
+        task_queue_manager,
+        single_task=task.corpus_id,
+    )
+
+    load_parameters_from_checkpoint(
+        frame_ckpt_path,
+        model,
+        optim=None,
+        task_queue_manager=task_queue_manager,
+        reset_optim=True,
+    )
+
+    device = torch.device("cuda" if use_gpu(opts) else "cpu")
+    model.to(device)
+    model.eval()
+
+    return vocabs_dict, model, model_opts
+
+
 def max_tok_len(new, count, sofar):
     """
     In token batching scheme, the number of sequences is limited
@@ -153,7 +202,7 @@ def __init__(
 
         self.model = model
         self.vocabs = vocabs
-        tgt_vocab = dict(self.vocabs)["tgt"]
+        tgt_vocab = dict(self.vocabs)[("tgt", task.tgt_lang)]
         self._tgt_vocab = tgt_vocab
         self._tgt_eos_idx = self._tgt_vocab.stoi[DefaultTokens.EOS]
         self._tgt_pad_idx = self._tgt_vocab.stoi[DefaultTokens.PAD]
@@ -480,7 +529,7 @@ def _translate(
         # )
         # data_iter = None
 
-        xlation_builder = mammoth.translate.TranslationBuilder(
+        xlation_builder = TranslationBuilder(
             corpus,
             self.vocabs,
             self.n_best,
@@ -813,6 +862,7 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy):
         batch_size = batch.batch_size
 
         # (0.5) Activate adapters
+        # FIXME: translation is broken, fix is WIP
         metadata = self.task.get_serializable_metadata()
         self.model.encoder.activate(metadata)
         self.model.decoder.activate(metadata)
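Taken together, the translator-side changes replace load_test_multitask_model with load_model_for_translation and thread the task_queue_manager through. A minimal sketch of the new call contract, with opts, task and task_queue_manager assumed to be prepared as in mammoth/bin/translate.py above:

    # Sketch, not part of the commit: the updated entry point.
    translator = build_translator(
        opts,
        task_queue_manager,  # new required positional argument
        task,
        logger=logger,
        report_score=True,
    )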
53 changes: 49 additions & 4 deletions mammoth/utils/model_saver.py
@@ -88,11 +88,13 @@ def load_parameters_from_checkpoint(
                 logger.info(f'Module {name} incompatible keys: {incompatible_keys}')
                 all_ok = False
         else:
-            logger.warning(f'Could not find model checkpoint file {checkpoint_path}. Affected parameters are reinitialized.')
+            logger.warning(
+                f'Could not find model checkpoint file {checkpoint_path}. Affected parameters are reinitialized.'
+            )
             all_ok = False
 
         if not reset_optim:
-            optimizer_path = f'{ckpt_prefix}_{name}_optim.pt'
+            optimizer_path = f'{ckpt_prefix}_{name}_optim.pt'
             if os.path.isfile(optimizer_path):
                 # The optimizer parameters are distributed the same way as the components
                 optim_state_dict = torch.load(optimizer_path)
@@ -101,17 +103,60 @@ def load_parameters_from_checkpoint(
                 logger.info(f'Optim {name} incompatible keys: {incompatible_keys}')
                 all_ok = False
             else:
-                logger.warning(f'Could not find optim checkpoint file {optimizer_path}. Affected parameters are reinitialized.')
+                logger.warning(
+                    f'Could not find optim checkpoint file {optimizer_path}. Affected parameters are reinitialized.'
+                )
                 all_ok = False
     if all_ok:
         if reset_optim:
             logger.info(f'All modules restored from checkpoint {ckpt_prefix}')
-            logger.info('Optimizer was reset')
+            if optim is not None:
+                logger.info('Optimizer was reset')
         else:
             logger.info(f'All modules and optimizer restored from checkpoint {ckpt_prefix}')
+    # TODO: barf unless a flag --yes-i-messed-with-the-checkpoint is set
+
+
+def load_model_for_translation(opts, task_queue_manager, task=None, model_path=None):
+    if task is None:
+        raise ValueError('Must set task')
+    if model_path is None:
+        model_path = opts.models[0]
+
+    # Load only the frame
+    frame, frame_ckpt_path = load_frame_checkpoint(ckpt_path=opts.train_from)
+
+    vocabs_dict = {
+        'src': frame["vocab"].get(('src', task.src_lang)),
+        'tgt': frame["vocab"].get(('tgt', task.tgt_lang)),
+    }
+
+    model_opts = ArgumentParser.ckpt_model_opts(frame['opts'])
+
+    model = build_model(
+        model_opts,
+        opts,
+        vocabs_dict,
+        task_queue_manager,
+        single_task=task.corpus_id,
+    )
+
+    load_parameters_from_checkpoint(
+        frame_ckpt_path,
+        model,
+        optim=None,
+        task_queue_manager=task_queue_manager,
+        reset_optim=True,
+    )
+
+    device = torch.device("cuda" if use_gpu(opts) else "cpu")
+    model.to(device)
+    model.eval()
+
+    return vocabs_dict, model, model_opts
 
 
 class ModelSaverBase(object):
     """Base class for model saving operations
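For orientation, a sketch of how the loading entry point above is meant to be driven at inference time (mirroring load_model_for_translation; every name here appears in this diff):

    # Sketch, not part of the commit: restore module parameters only,
    # discarding any optimizer state.
    load_parameters_from_checkpoint(
        frame_ckpt_path,
        model,
        optim=None,                # no optimizer at inference time
        task_queue_manager=task_queue_manager,
        reset_optim=True,          # skip the *_optim.pt files entirely
    )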
