From e6f564a3946ceb8da1c7fdd11fab078e8b056c22 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Stig-Arne=20Gr=C3=B6nroos?=
Date: Mon, 12 Aug 2024 17:13:17 +0300
Subject: [PATCH] WIP: translation is broken

---
 mammoth/bin/translate.py        | 45 ++++++++++++++++--
 mammoth/distributed/tasks.py    |  5 +-
 mammoth/model_builder.py        | 45 ------------------
 mammoth/translate/translator.py | 82 ++++++++++++++++++++++++++------
 mammoth/utils/model_saver.py    | 53 +++++++++++++++++++--
 5 files changed, 160 insertions(+), 70 deletions(-)

diff --git a/mammoth/bin/translate.py b/mammoth/bin/translate.py
index 4deb9e2e..830863a9 100644
--- a/mammoth/bin/translate.py
+++ b/mammoth/bin/translate.py
@@ -7,9 +7,11 @@
 from mammoth.transforms import get_transforms_cls, make_transforms, TransformPipe
 import mammoth.opts as opts
-from mammoth.distributed import TaskSpecs
+from mammoth.distributed import TaskSpecs, TaskQueueManager
+from mammoth.distributed.contexts import WorldContext, DeviceContextEnum
 from mammoth.distributed.tasks import get_adapter_ids
 from mammoth.utils.parse import ArgumentParser
+from mammoth.utils.misc import use_gpu
 
 
 def translate(opts):
@@ -26,12 +28,24 @@ def translate(opts):
     if 'adapters' in corpus_opts:
         encoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'encoder')
         decoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'decoder')
+        uses_adapters = True
     else:
         encoder_adapter_ids = None
         decoder_adapter_ids = None
+        uses_adapters = False
+
+    node_rank = 0
+    local_rank = 0
+    if use_gpu(opts):
+        context_enum = DeviceContextEnum.SINGLE_GPU
+        gpus_per_node = 1
+    else:
+        context_enum = DeviceContextEnum.CPU
+        gpus_per_node = 0
+
     task = TaskSpecs(
-        node_rank=None,
-        local_rank=None,
+        node_rank=node_rank,
+        local_rank=local_rank,
         src_lang=src_lang,
         tgt_lang=tgt_lang,
         encoder_id=encoder_id,
@@ -46,7 +60,30 @@ def translate(opts):
         decoder_adapter_ids=decoder_adapter_ids,
     )
 
-    translator = build_translator(opts, task, logger=logger, report_score=True)
+    world_context = WorldContext(
+        context=context_enum,
+        n_nodes=1,
+        gpus_per_node=gpus_per_node,
+    )
+
+    task_queue_manager = TaskQueueManager(
+        tasks=[task],
+        accum_count=1,
+        world_context=world_context,
+        task_distribution_strategy_cls=None,
+        uses_adapters=uses_adapters,
+    ).global_to_local(
+        node_rank=node_rank,
+        local_rank=local_rank,
+        opts=opts,
+    )
+    # FIXME: fix the attention bridge in translation
+    task_queue_manager.create_all_distributed_components(
+        use_attention_bridge=False,  # (opts.ab_layers is not None and len(opts.ab_layers) != 0),
+        new_group_func=lambda: None,
+    )
+
+    translator = build_translator(opts, task_queue_manager, task, logger=logger, report_score=True)
 
     # data_reader = InferenceDataReader(opts.src, opts.tgt, opts.src_feats)
     src_shards = split_corpus(opts.src, opts.shard_size)
diff --git a/mammoth/distributed/tasks.py b/mammoth/distributed/tasks.py
index 9cf67897..44358bb6 100644
--- a/mammoth/distributed/tasks.py
+++ b/mammoth/distributed/tasks.py
@@ -288,7 +288,10 @@ def global_to_local(self, node_rank, local_rank, opts):
         assert node_rank is not None
         assert local_rank is not None
         device_context = self.world_context.global_to_local(node_rank, local_rank)
-        task_distribution_strategy = self.task_distribution_strategy_cls(seed=opts.seed)
+        if self.task_distribution_strategy_cls is not None:
+            task_distribution_strategy = self.task_distribution_strategy_cls(seed=opts.seed)
+        else:
+            task_distribution_strategy = None
         return LocalTaskQueueManager(
             self.tasks,
             accum_count=self.accum_count,
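
Taken together, the new code path in translate.py amounts to a single-task
bootstrap of the TaskQueueManager. The sketch below condenses the hunks above
into one helper; make_single_task_tqm is a hypothetical name used only for
illustration, and the TaskSpecs construction (src_lang, corpus_id, adapter
ids, and so on) is assumed to happen exactly as in the hunk above.

    from mammoth.distributed import TaskSpecs, TaskQueueManager
    from mammoth.distributed.contexts import WorldContext, DeviceContextEnum

    def make_single_task_tqm(task: TaskSpecs, opts, uses_adapters: bool):
        # CPU inference: one node, zero GPUs.
        # With use_gpu(opts) this becomes SINGLE_GPU / gpus_per_node=1.
        world_context = WorldContext(
            context=DeviceContextEnum.CPU,
            n_nodes=1,
            gpus_per_node=0,
        )
        tqm = TaskQueueManager(
            tasks=[task],                         # exactly one task at inference time
            accum_count=1,
            world_context=world_context,
            task_distribution_strategy_cls=None,  # no sampling strategy needed for one task
            uses_adapters=uses_adapters,
        ).global_to_local(node_rank=0, local_rank=0, opts=opts)
        # The attention bridge is disabled for now (see the FIXME above).
        tqm.create_all_distributed_components(
            use_attention_bridge=False,
            new_group_func=lambda: None,
        )
        return tqm
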
diff --git a/mammoth/model_builder.py b/mammoth/model_builder.py
index 018d30b1..4f70aec5 100644
--- a/mammoth/model_builder.py
+++ b/mammoth/model_builder.py
@@ -47,51 +47,6 @@ def uses_adapters(opts):
     return 'adapters' in opts and opts.adapters
 
 
-def load_test_multitask_model(opts, task_queue_manager, task=None, model_path=None):
-    if task is None:
-        raise ValueError('Must set task')
-    if model_path is None:
-        model_path = opts.models[0]
-
-    # Load only the frame
-    frame, ckpt_path = load_frame_checkpoint(ckpt_path=opts.train_from)
-
-    vocabs_dict = {
-        'src': frame["vocab"].get(('src', task.src_lang)),
-        'tgt': frame["vocab"].get(('tgt', task.tgt_lang)),
-    }
-
-    model_opts = ArgumentParser.ckpt_model_opts(frame['opts'])
-    # Avoid functionality on inference
-    # model_opts.update_vocab = False
-    model = build_model(
-        model_opts,
-        opts,
-        vocabs_dict,
-        task_queue_manager,
-        single_task=task.corpus_id,
-    )
-
-    # FIXME: load the model parameters
-
-    model_params = {name for name, p in model.named_parameters()}
-    model_params.update(name for name, p in model.named_buffers())
-    for key in set(combined_state_dict.keys()):
-        if key not in model_params:
-            print(f'Deleting unnecessary key: {key}')
-            del combined_state_dict[key]
-    for key in model_params:
-        if key not in combined_state_dict:
-            print(f'Key missing {key}')
-    model.load_state_dict(combined_state_dict)
-    device = torch.device("cuda" if use_gpu(opts) else "cpu")
-    model.to(device)
-
-    model.eval()
-
-    return vocabs_dict, model, model_opts
-
-
 def get_attention_layers_kwargs(
     side: Side,
     layer_stack_index,
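
The function removed here is not lost: it re-appears in reworked form as
load_model_for_translation in mammoth/translate/translator.py (next file),
with the unfinished combined_state_dict logic replaced by the checkpoint
helpers from mammoth.utils.model_saver. Note that the removed call site
invoked it as load_test_model(opts, task), which appears to bind task to the
task_queue_manager parameter and leave task=None, so that path could only
raise ValueError('Must set task'). At the call level the change is roughly
the following (a sketch only; model_path=None falls back to opts.models[0]):

    # Before (removed above), with the arguments mis-bound as described:
    # vocabs, model, model_opts = load_test_multitask_model(opts, task)

    # After (added in the next file), everything passed by keyword:
    vocabs, model, model_opts = load_model_for_translation(
        opts=opts,
        task_queue_manager=task_queue_manager,
        task=task,
        model_path=None,
    )
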
diff --git a/mammoth/translate/translator.py b/mammoth/translate/translator.py
index 6689c36e..7179dee5 100644
--- a/mammoth/translate/translator.py
+++ b/mammoth/translate/translator.py
@@ -8,19 +8,21 @@
 
 import torch
 
-import mammoth.model_builder
-import mammoth.modules.decoder_ensemble
 # from mammoth.inputters.text_dataset import InferenceDataIterator
-from mammoth.translate.beam_search import BeamSearch, BeamSearchLM
-from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM
-from mammoth.utils.misc import tile, set_random_seed, report_matrix
-from mammoth.utils.alignment import extract_alignment, build_align_pharaoh
 from mammoth.constants import ModelTask, DefaultTokens
-from mammoth.inputters.dataset import ParallelCorpus
 from mammoth.inputters.dataloader import build_dataloader
+from mammoth.inputters.dataset import ParallelCorpus
+from mammoth.model_builder import build_model
+from mammoth.translate.beam_search import BeamSearch, BeamSearchLM, GNMTGlobalScorer
+from mammoth.translate.greedy_search import GreedySearch, GreedySearchLM
+from mammoth.translate.translation import TranslationBuilder
+from mammoth.utils.alignment import extract_alignment, build_align_pharaoh
+from mammoth.utils.misc import tile, set_random_seed, report_matrix, use_gpu
+from mammoth.utils.model_saver import load_frame_checkpoint, load_parameters_from_checkpoint
+from mammoth.utils.parse import ArgumentParser
 
 
-def build_translator(opts, task, report_score=True, logger=None, out_file=None):
+def build_translator(opts, task_queue_manager, task, report_score=True, logger=None, out_file=None):
     if out_file is None:
         outdir = os.path.dirname(opts.output)
         if outdir and not os.path.isdir(outdir):
@@ -29,15 +31,19 @@ def build_translator(opts, task, report_score=True, logger=None, out_file=None):
         os.makedirs(os.path.dirname(opts.output), exist_ok=True)
         out_file = codecs.open(opts.output, "w+", "utf-8")
 
-    load_test_model = (
-        mammoth.modules.decoder_ensemble.load_test_model if len(opts.models) > 3
-        else mammoth.model_builder.load_test_multitask_model
-    )
+    # TODO: reimplement ensemble decoding
+    load_model_for_translation_func = load_model_for_translation
     if logger:
         logger.info(str(task))
-    vocabs, model, model_opts = load_test_model(opts, task)
+    model_path = None
+    vocabs, model, model_opts = load_model_for_translation_func(
+        opts=opts,
+        task_queue_manager=task_queue_manager,
+        task=task,
+        model_path=model_path,
+    )
 
-    scorer = mammoth.translate.GNMTGlobalScorer.from_opts(opts)
+    scorer = GNMTGlobalScorer.from_opts(opts)
 
     translator = Translator.from_opts(
         model,
@@ -54,6 +60,49 @@ def build_translator(opts, task, report_score=True, logger=None, out_file=None):
     return translator
 
 
+def load_model_for_translation(opts, task_queue_manager, task=None, model_path=None):
+    if task is None:
+        raise ValueError('Must set task')
+    if model_path is None:
+        model_path = opts.models[0]
+
+    # Load only the frame
+    frame, frame_ckpt_path = load_frame_checkpoint(ckpt_path=model_path)
+
+    vocabs_dict = {
+        ('src', task.src_lang): frame["vocab"].get(('src', task.src_lang)),
+        ('tgt', task.tgt_lang): frame["vocab"].get(('tgt', task.tgt_lang)),
+        'src': frame["vocab"].get(('src', task.src_lang)),
+        'tgt': frame["vocab"].get(('tgt', task.tgt_lang)),
+    }
+    print(f'vocabs_dict {vocabs_dict}')
+    print(f'my components {task_queue_manager.get_my_distributed_components()}')
+
+    model_opts = ArgumentParser.ckpt_model_opts(frame['opts'])
+
+    model = build_model(
+        model_opts,
+        opts,
+        vocabs_dict,
+        task_queue_manager,
+        single_task=task.corpus_id,
+    )
+
+    load_parameters_from_checkpoint(
+        frame_ckpt_path,
+        model,
+        optim=None,
+        task_queue_manager=task_queue_manager,
+        reset_optim=True,
+    )
+
+    device = torch.device("cuda" if use_gpu(opts) else "cpu")
+    model.to(device)
+    model.eval()
+
+    return vocabs_dict, model, model_opts
+
+
 def max_tok_len(new, count, sofar):
     """
     In token batching scheme, the number of sequences is limited
@@ -153,7 +202,7 @@ def __init__(
         self.model = model
         self.vocabs = vocabs
-        tgt_vocab = dict(self.vocabs)["tgt"]
+        tgt_vocab = dict(self.vocabs)[("tgt", task.tgt_lang)]
         self._tgt_vocab = tgt_vocab
         self._tgt_eos_idx = self._tgt_vocab.stoi[DefaultTokens.EOS]
         self._tgt_pad_idx = self._tgt_vocab.stoi[DefaultTokens.PAD]
@@ -480,7 +529,7 @@ def _translate(
         # )
         # data_iter = None
 
-        xlation_builder = mammoth.translate.TranslationBuilder(
+        xlation_builder = TranslationBuilder(
             corpus,
             self.vocabs,
             self.n_best,
@@ -813,6 +862,7 @@ def _translate_batch_with_strategy(self, batch, src_vocabs, decode_strategy):
         batch_size = batch.batch_size
 
         # (0.5) Activate adapters
+        # FIXME: translation is broken, fix is WIP
         metadata = self.task.get_serializable_metadata()
         self.model.encoder.activate(metadata)
         self.model.decoder.activate(metadata)
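
The new loading path splits checkpoint restoration into two stages, both
provided by mammoth.utils.model_saver (next file): load_frame_checkpoint
fetches the shared frame (vocabulary and saved opts), and
load_parameters_from_checkpoint restores the per-component weight files next
to it. A condensed sketch of load_model_for_translation's flow, assuming that
distributed-checkpoint layout; vocabs_dict, task, opts and task_queue_manager
are assumed to be in scope as in the function above:

    import torch
    from mammoth.model_builder import build_model
    from mammoth.utils.misc import use_gpu
    from mammoth.utils.model_saver import load_frame_checkpoint, load_parameters_from_checkpoint
    from mammoth.utils.parse import ArgumentParser

    # 1) Frame: the vocabulary and the opts the model was trained with.
    frame, frame_ckpt_path = load_frame_checkpoint(ckpt_path=opts.models[0])
    model_opts = ArgumentParser.ckpt_model_opts(frame['opts'])

    # 2) Skeleton: build the single-task model for this corpus.
    model = build_model(model_opts, opts, vocabs_dict, task_queue_manager, single_task=task.corpus_id)

    # 3) Weights: restore each distributed component; no optimizer at inference.
    load_parameters_from_checkpoint(
        frame_ckpt_path,
        model,
        optim=None,
        task_queue_manager=task_queue_manager,
        reset_optim=True,
    )
    model.to(torch.device("cuda" if use_gpu(opts) else "cpu")).eval()
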
diff --git a/mammoth/utils/model_saver.py b/mammoth/utils/model_saver.py
index ab80a737..000ec2d3 100644
--- a/mammoth/utils/model_saver.py
+++ b/mammoth/utils/model_saver.py
@@ -88,11 +88,13 @@ def load_parameters_from_checkpoint(
                 logger.info(f'Module {name} incompatible keys: {incompatible_keys}')
                 all_ok = False
             else:
-                logger.warning(f'Could not find model checkpoint file {checkpoint_path}. Affected parameters are reinitialized.')
+                logger.warning(
+                    f'Could not find model checkpoint file {checkpoint_path}. Affected parameters are reinitialized.'
+                )
                 all_ok = False
 
         if not reset_optim:
-            optimizer_path = f'{ckpt_prefix}_{name}_optim.pt'
+            optimizer_path = f'{ckpt_prefix}_{name}_optim.pt'
             if os.path.isfile(optimizer_path):
                 # The optimizer parameters are distributed the same way as the components
                 optim_state_dict = torch.load(optimizer_path)
@@ -101,17 +103,60 @@ def load_parameters_from_checkpoint(
                 logger.info(f'Optim {name} incompatible keys: {incompatible_keys}')
                 all_ok = False
             else:
-                logger.warning(f'Could not find optim checkpoint file {optimizer_path}. Affected parameters are reinitialized.')
+                logger.warning(
+                    f'Could not find optim checkpoint file {optimizer_path}. Affected parameters are reinitialized.'
+                )
                 all_ok = False
 
     if all_ok:
         if reset_optim:
             logger.info(f'All modules restored from checkpoint {ckpt_prefix}')
-            logger.info('Optimizer was reset')
+            if optim is not None:
+                logger.info('Optimizer was reset')
         else:
             logger.info(f'All modules and optimizer restored from checkpoint {ckpt_prefix}')
     # TODO: barf unless a flag --yes-i-messed-with-the-checkpoint is set
 
 
+def load_model_for_translation(opts, task_queue_manager, task=None, model_path=None):
+    if task is None:
+        raise ValueError('Must set task')
+    if model_path is None:
+        model_path = opts.models[0]
+
+    # Load only the frame
+    frame, frame_ckpt_path = load_frame_checkpoint(ckpt_path=model_path)
+
+    vocabs_dict = {
+        'src': frame["vocab"].get(('src', task.src_lang)),
+        'tgt': frame["vocab"].get(('tgt', task.tgt_lang)),
+    }
+
+    model_opts = ArgumentParser.ckpt_model_opts(frame['opts'])
+
+    model = build_model(
+        model_opts,
+        opts,
+        vocabs_dict,
+        task_queue_manager,
+        single_task=task.corpus_id,
+    )
+
+    load_parameters_from_checkpoint(
+        frame_ckpt_path,
+        model,
+        optim=None,
+        task_queue_manager=task_queue_manager,
+        reset_optim=True,
+    )
+
+    device = torch.device("cuda" if use_gpu(opts) else "cpu")
+    model.to(device)
+    model.eval()
+
+    return vocabs_dict, model, model_opts
+
+
 class ModelSaverBase(object):
     """Base class for model saving operations