diff --git a/examples/server.json b/examples/server.json new file mode 100644 index 00000000..d7bedaf6 --- /dev/null +++ b/examples/server.json @@ -0,0 +1,14 @@ +{ + "models_root": "./models", + "models": [ + { + "id": 0, + "opts": { + "config": "out.modernize.cpu.yaml", + "model": "models/smoketest.modernize_step_10", + "task_id": "train_de-fr_split0" + }, + "models": ["smoketest.modernize_step_10"] + } + ] +} diff --git a/mammoth/bin/server.py b/mammoth/bin/server.py index 7dae4712..4f168355 100755 --- a/mammoth/bin/server.py +++ b/mammoth/bin/server.py @@ -78,10 +78,16 @@ def unload_model(model_id): @app.route('/translate', methods=['POST']) def translate(): - inputs = request.get_json(force=True) + out = {} + inputs = request.get_json(force=True, silent=True) + if inputs is None: + error = f'Unable to parse {request.data}' + logger.warning(error) + out['error'] = str(error) + out['status'] = STATUS_ERROR + return jsonify(out) if debug: logger.info(inputs) - out = {} try: trans, scores, n_best, _, aligns = translation_server.run(inputs) assert len(trans) == len(inputs) * n_best diff --git a/mammoth/inputters/dataset.py b/mammoth/inputters/dataset.py index 22687905..e7d60638 100644 --- a/mammoth/inputters/dataset.py +++ b/mammoth/inputters/dataset.py @@ -3,6 +3,7 @@ import itertools from functools import partial import gzip +from io import IOBase import torch from torch.nn.utils.rnn import pad_sequence @@ -54,12 +55,16 @@ def _make_example_dict(packed): 'line_idx': next(line_idx_generator) } - if src_path.endswith('.gz'): + if isinstance(src_path, IOBase): + src_fh = src_path + elif src_path.endswith('.gz'): src_fh = gzip.open(src_path, 'rt') else: src_fh = open(src_path, 'rt') if tgt_path is None: tgt_fh = itertools.repeat(None) + elif isinstance(tgt_path, IOBase): + tgt_fh = src_path elif tgt_path.endswith('.gz'): tgt_fh = gzip.open(tgt_path, 'rt') else: diff --git a/mammoth/tests/test_data_prepare.py b/mammoth/tests/test_data_prepare.py index 7b2db323..ec289ec1 100644 --- a/mammoth/tests/test_data_prepare.py +++ b/mammoth/tests/test_data_prepare.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# # # FIXME: Broken in FoTraNMT +# # # FIXME: Broken in MAMMOTH # # import copy # import unittest diff --git a/mammoth/tests/test_models.py b/mammoth/tests/test_models.py index b9400791..550838e7 100644 --- a/mammoth/tests/test_models.py +++ b/mammoth/tests/test_models.py @@ -141,7 +141,7 @@ def test_method(self): ''' opts.brnn = False -# FIXME: Most tests disabled: FoTraNMT only supports Transformer +# FIXME: Most tests disabled: MAMMOTH only supports Transformer test_embeddings = [ # [], [('decoder_type', 'transformer')] @@ -150,7 +150,7 @@ def test_method(self): for p in test_embeddings: _add_test(p, 'embeddings_forward') -# FIXME: All tests disabled: FoTraNMT only supports Transformer, and the test for Transformer is broken +# FIXME: All tests disabled: MAMMOTH only supports Transformer, and the test for Transformer is broken tests_encoder = [ # [], # [('encoder_type', 'mean')], @@ -161,7 +161,7 @@ def test_method(self): for p in tests_encoder: _add_test(p, 'encoder_forward') -# FIXME: Most tests disabled: FoTraNMT only supports Transformer +# FIXME: Most tests disabled: MAMMOTH only supports Transformer tests_nmtmodel = [ # [('rnn_type', 'GRU')], # [('layers', 10)], @@ -193,6 +193,6 @@ def test_method(self): ] -# ## FIXME: Broken in FoTraNMT +# ## FIXME: Broken in MAMMOTH # for p in tests_nmtmodel: # _add_test(p, 'nmtmodel_forward') diff --git a/mammoth/tests/test_text_dataset.py b/mammoth/tests/test_text_dataset.py index 9d6f3b56..74eedd5c 100644 --- a/mammoth/tests/test_text_dataset.py +++ b/mammoth/tests/test_text_dataset.py @@ -1,4 +1,4 @@ -# FIXME, or rather TODO in FoTran +# FIXME, or rather TODO in MAMMOTH # import unittest # # import itertools diff --git a/mammoth/tests/test_translation_server.py b/mammoth/tests/test_translation_server.py index 61ae316f..ad0d4217 100644 --- a/mammoth/tests/test_translation_server.py +++ b/mammoth/tests/test_translation_server.py @@ -13,7 +13,7 @@ class TestServerModel(unittest.TestCase): - @unittest.skip('Broken in FoTraNMT') # FIXME + @unittest.skip('Broken in MAMMOTH') # FIXME def test_deferred_loading_model_and_unload(self): model_id = 0 opts = {"models": ["test_model.pt"]} @@ -26,7 +26,7 @@ def test_deferred_loading_model_and_unload(self): sm.unload() self.assertFalse(sm.loaded) - @unittest.skip('Broken in FoTraNMT') # FIXME + @unittest.skip('Broken in MAMMOTH') # FIXME def test_load_model_on_init_and_unload(self): model_id = 0 opts = {"models": ["test_model.pt"]} @@ -37,7 +37,7 @@ def test_load_model_on_init_and_unload(self): sm.unload() self.assertFalse(sm.loaded) - @unittest.skip('Broken in FoTraNMT') # FIXME + @unittest.skip('Broken in MAMMOTH') # FIXME def test_tokenizing_with_no_tokenizer_fails(self): model_id = 0 opts = {"models": ["test_model.pt"]} @@ -46,7 +46,7 @@ def test_tokenizing_with_no_tokenizer_fails(self): with self.assertRaises(ValueError): sm.tokenize("hello world") - @unittest.skip('Broken in FoTraNMT') # FIXME + @unittest.skip('Broken in MAMMOTH') # FIXME def test_detokenizing_with_no_tokenizer_fails(self): model_id = 0 opts = {"models": ["test_model.pt"]} @@ -57,7 +57,7 @@ def test_detokenizing_with_no_tokenizer_fails(self): if torch.cuda.is_available(): - @unittest.skip('Broken in FoTraNMT') # FIXME + @unittest.skip('Broken in MAMMOTH') # FIXME def test_moving_to_gpu_and_back(self): torch.cuda.set_device(torch.device("cuda", 0)) model_id = 0 @@ -74,7 +74,7 @@ def test_moving_to_gpu_and_back(self): for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cpu") - @unittest.skip('Broken in FoTraNMT') # FIXME + @unittest.skip('Broken in MAMMOTH') # FIXME def test_initialize_on_gpu_and_move_back(self): torch.cuda.set_device(torch.device("cuda", 0)) model_id = 0 @@ -111,7 +111,7 @@ def test_initialize_on_nonzero_gpu_and_back(self): for p in sm.translator.model.parameters(): self.assertEqual(p.device.type, "cpu") - @unittest.skip('Broken in FoTraNMT') # FIXME + @unittest.skip('Broken in MAMMOTH') # FIXME def test_run(self): model_id = 0 opts = {"models": ["test_model.pt"]} diff --git a/mammoth/translate/translation_server.py b/mammoth/translate/translation_server.py index 5c008a67..d96bc784 100644 --- a/mammoth/translate/translation_server.py +++ b/mammoth/translate/translation_server.py @@ -12,16 +12,19 @@ import torch import mammoth.opts +from io import StringIO from itertools import islice, zip_longest from copy import deepcopy from mammoth.constants import DefaultTokens +from mammoth.distributed import TaskSpecs +from mammoth.distributed.tasks import get_adapter_ids +from mammoth.translate.translator import build_translator +from mammoth.transforms import get_transforms_cls, make_transforms, TransformPipe +from mammoth.utils.alignment import to_word_align from mammoth.utils.logging import init_logger from mammoth.utils.misc import set_random_seed -from mammoth.utils.misc import check_model_config -from mammoth.utils.alignment import to_word_align from mammoth.utils.parse import ArgumentParser -from mammoth.translate.translator import build_translator def critical(func): @@ -175,7 +178,7 @@ def start(self, config_file): parameter for model #%d""" % i ) - check_model_config(conf, self.models_root) + # check_model_config(conf, self.models_root) kwargs = { 'timeout': conf.get('timeout', None), 'load': conf.get('load', None), @@ -383,7 +386,11 @@ def parse_opt(self, opts): prec_argv = sys.argv sys.argv = sys.argv[:1] parser = ArgumentParser() - mammoth.opts.translate_opts(parser) + + parser.translation = True + + mammoth.opts.dynamic_prepare_opts(parser) + mammoth.opts.translate_opts(parser, dynamic=True) models = opts['models'] if not isinstance(models, (list, tuple)): @@ -397,11 +404,19 @@ def parse_opt(self, opts): sys.argv += [str(model) for model in v] elif isinstance(v, bool): sys.argv += ['-%s' % k] + elif k == 'transforms': + if type(v) is str: + sys.argv += ['-transforms', v] + else: + assert type(v) is list + sys.argv += ['-transforms', *v] else: sys.argv += ['-%s' % k, str(v)] opts = parser.parse_args() + ArgumentParser.validate_prepare_opts(opts) ArgumentParser.validate_translate_opts(opts) + ArgumentParser.validate_translate_opts_dynamic(opts) opts.cuda = opts.gpu > -1 sys.argv = prec_argv @@ -431,9 +446,27 @@ def load(self, preload=False): preload=preload, ) else: + task = self.make_task(self.opts) self.translator = build_translator( - self.opts, report_score=False, out_file=codecs.open(os.devnull, "w", "utf-8") + self.opts, + task, + report_score=False, + out_file=codecs.open(os.devnull, "w", "utf-8"), ) + + # Build transforms + transforms_cls = get_transforms_cls(self.opts._all_transform) + transforms = make_transforms( + self.opts, + transforms_cls, + self.translator.vocabs, + task=task + ) + data_transform = [ + transforms[name] for name in self.opts.transforms if name in transforms + ] + + self.transform = TransformPipe.build_from(data_transform) except RuntimeError as e: raise ServerModelError("Runtime Error: %s" % str(e)) @@ -442,6 +475,36 @@ def load(self, preload=False): self.reset_unload_timer() self.loading_lock.set() + def make_task(self, opts): + corpus_id = opts.task_id + print(f'opts.tasks {type(opts.tasks)}') + corpus_opts = opts.tasks[corpus_id] + src_lang, tgt_lang = corpus_opts['src_tgt'].split('-', 1) + encoder_id = corpus_opts.get('enc_sharing_group', [src_lang]) + decoder_id = corpus_opts.get('dec_sharing_group', [tgt_lang]) + if 'adapters' in corpus_opts: + encoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'encoder') + decoder_adapter_ids = get_adapter_ids(opts, corpus_opts, 'decoder') + else: + encoder_adapter_ids = None + decoder_adapter_ids = None + task = TaskSpecs( + node_rank=None, + local_rank=None, + src_lang=src_lang, + tgt_lang=tgt_lang, + encoder_id=encoder_id, + decoder_id=decoder_id, + corpus_id=corpus_id, + weight=1.0, + corpus_opts=corpus_opts, + src_vocab=None, + tgt_vocab=None, + encoder_adapter_ids=encoder_adapter_ids, + decoder_adapter_ids=decoder_adapter_ids, + ) + return task + @critical def run(self, inputs): """Translate `inputs` using this model @@ -514,8 +577,9 @@ def run(self, inputs): if len(texts_to_translate) > 0: try: - scores, predictions = self.translator.translate( - texts_to_translate, + scores, predictions = self.translator.translate_dynamic( + src=StringIO("\n".join(texts_to_translate)), + transform=self.transform, tgt=texts_ref, batch_size=len(texts_to_translate) if self.opts.batch_size == 0 else self.opts.batch_size, ) diff --git a/mammoth/translate/translator.py b/mammoth/translate/translator.py index 81434075..463eb47f 100644 --- a/mammoth/translate/translator.py +++ b/mammoth/translate/translator.py @@ -449,11 +449,11 @@ def _translate( of `n_best` predictions """ - self.logger.info("src vocab: {}".format(self.vocabs['src'])) - self.logger.info("transforms: {}".format(transforms)) + self._log("src vocab: {}".format(self.vocabs['src'])) + self._log("transforms: {}".format(transforms)) corpus = ParallelCorpus( - self.src_file_path, - self.tgt_file_path, # may be None + src, + tgt, self.vocabs['src'], self.vocabs['tgt'], transforms=transforms, # I suppose you might want *some* transforms diff --git a/mammoth/utils/misc.py b/mammoth/utils/misc.py index e78d64b8..243bb31d 100644 --- a/mammoth/utils/misc.py +++ b/mammoth/utils/misc.py @@ -5,6 +5,7 @@ import inspect import numpy as np from itertools import islice, repeat +from io import StringIO import os @@ -30,16 +31,16 @@ def split_corpus(path, shard_size, default=None): def _split_corpus(path, shard_size): - """Yield a `list` containing `shard_size` line of `path`.""" - with open(path, "rb") as f: + """Yield io's with `shard_size` lines each.""" + with open(path, "rt") as f: if shard_size <= 0: - yield f.readlines() + yield f else: while True: - shard = list(islice(f, shard_size)) + shard = "\n".join(islice(f, shard_size)) if not shard: break - yield shard + yield StringIO(shard) def aeq(*args): diff --git a/tools/demo/README.md b/tools/demo/README.md new file mode 100644 index 00000000..2b7e64fc --- /dev/null +++ b/tools/demo/README.md @@ -0,0 +1,9 @@ +Run the backend (translation server) + + python mammoth/bin/server.py --config server.json + +Run the frontend + + streamlit run mammoth_demo.py + +Start the backend first, otherwise the frontend will show an error until refreshed. diff --git a/tools/demo/configs/ab-neg/config.yaml b/tools/demo/configs/ab-neg/config.yaml new file mode 100644 index 00000000..14c56c1a --- /dev/null +++ b/tools/demo/configs/ab-neg/config.yaml @@ -0,0 +1,620 @@ +accum_count: 4 +batch_size: 4096 +batch_type: tokens +dec_layers: +- 6 +decay_method: none +decoder_type: transformer +denoising_objective: bart +dropout: 0.1 +enc_layers: 6 +encoder_type: transformer +gpu_ranks: +- 0 +- 1 +- 2 +- 3 +- 4 +- 5 +heads: 8 +keep_checkpoint: -1 +label_smoothing: 0.1 +learning_rate: 3.0 +mask_ratio: 0.2 +max_generator_batches: 2 +max_grad_norm: 0.0 +model_dim: 512 +model_type: text +n_nodes: 1 +normalization: tokens +optim: adafactor +overwrite: false +param_init: 0.0 +param_init_glorot: true +position_encoding: true +replace_length: 1 +report_every: 50 +save_all_gpus: false +save_checkpoint_steps: 5000 +save_model: /scratch/project_462000447/members/attiehjo/1.ab-negative-refactored/experiments_tags/unpc/all-centric/encoder-shared/no_bridge/123/models/model +seed: 123 +skip_empty_level: silent +src_seq_length: 200 +src_subword_type: sentencepiece +src_vocab: + ar: tools/demo/configs/ab-neg/vocab/opusTC.ar.vocab.onmt + en: tools/demo/configs/ab-neg/vocab/opusTC.en.vocab.onmt + es: tools/demo/configs/ab-neg/vocab/opusTC.es.vocab.onmt + fr: tools/demo/configs/ab-neg/vocab/opusTC.fr.vocab.onmt + ru: tools/demo/configs/ab-neg/vocab/opusTC.ru.vocab.onmt + zh: tools/demo/configs/ab-neg/vocab/opusTC.zh.vocab.onmt +src_vocab_size: 100000 +tasks: + train_ar-ar: + dec_sharing_group: + - ar + enc_sharing_group: + - all + node_gpu: 0:0 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/ar/train.ar.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/ar/train.ar.sp + src_prefix: + src_tgt: ar-ar + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + - denoising + train_ar-en: + dec_sharing_group: + - en + enc_sharing_group: + - all + node_gpu: 0:1 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-en/train.ar-en.ar.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-en/train.ar-en.en.sp + src_prefix: + src_tgt: ar-en + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ar-es: + dec_sharing_group: + - es + enc_sharing_group: + - all + node_gpu: 0:2 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-es/train.ar-es.ar.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-es/train.ar-es.es.sp + src_prefix: + src_tgt: ar-es + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ar-fr: + dec_sharing_group: + - fr + enc_sharing_group: + - all + node_gpu: 0:3 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-fr/train.ar-fr.ar.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-fr/train.ar-fr.fr.sp + src_prefix: + src_tgt: ar-fr + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ar-ru: + dec_sharing_group: + - ru + enc_sharing_group: + - all + node_gpu: 0:4 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-ru/train.ar-ru.ar.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-ru/train.ar-ru.ru.sp + src_prefix: + src_tgt: ar-ru + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ar-zh: + dec_sharing_group: + - zh + enc_sharing_group: + - all + node_gpu: 0:5 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-zh/train.ar-zh.ar.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-zh/train.ar-zh.zh.sp + src_prefix: + src_tgt: ar-zh + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_en-ar: + dec_sharing_group: + - ar + enc_sharing_group: + - all + node_gpu: 0:0 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-en/train.ar-en.en.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-en/train.ar-en.ar.sp + src_prefix: + src_tgt: en-ar + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_en-en: + dec_sharing_group: + - en + enc_sharing_group: + - all + node_gpu: 0:1 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/en/train.en.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/en/train.en.sp + src_prefix: + src_tgt: en-en + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + - denoising + train_en-es: + dec_sharing_group: + - es + enc_sharing_group: + - all + node_gpu: 0:2 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-es/train.en-es.en.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-es/train.en-es.es.sp + src_prefix: + src_tgt: en-es + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_en-fr: + dec_sharing_group: + - fr + enc_sharing_group: + - all + node_gpu: 0:3 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-fr/train.en-fr.en.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-fr/train.en-fr.fr.sp + src_prefix: + src_tgt: en-fr + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_en-ru: + dec_sharing_group: + - ru + enc_sharing_group: + - all + node_gpu: 0:4 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-ru/train.en-ru.en.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-ru/train.en-ru.ru.sp + src_prefix: + src_tgt: en-ru + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_en-zh: + dec_sharing_group: + - zh + enc_sharing_group: + - all + node_gpu: 0:5 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-zh/train.en-zh.en.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-zh/train.en-zh.zh.sp + src_prefix: + src_tgt: en-zh + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_es-ar: + dec_sharing_group: + - ar + enc_sharing_group: + - all + node_gpu: 0:0 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-es/train.ar-es.es.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-es/train.ar-es.ar.sp + src_prefix: + src_tgt: es-ar + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_es-en: + dec_sharing_group: + - en + enc_sharing_group: + - all + node_gpu: 0:1 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-es/train.en-es.es.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-es/train.en-es.en.sp + src_prefix: + src_tgt: es-en + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_es-es: + dec_sharing_group: + - es + enc_sharing_group: + - all + node_gpu: 0:2 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/es/train.es.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/es/train.es.sp + src_prefix: + src_tgt: es-es + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + - denoising + train_es-fr: + dec_sharing_group: + - fr + enc_sharing_group: + - all + node_gpu: 0:3 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-fr/train.es-fr.es.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-fr/train.es-fr.fr.sp + src_prefix: + src_tgt: es-fr + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_es-ru: + dec_sharing_group: + - ru + enc_sharing_group: + - all + node_gpu: 0:4 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-ru/train.es-ru.es.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-ru/train.es-ru.ru.sp + src_prefix: + src_tgt: es-ru + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_es-zh: + dec_sharing_group: + - zh + enc_sharing_group: + - all + node_gpu: 0:5 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-zh/train.es-zh.es.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-zh/train.es-zh.zh.sp + src_prefix: + src_tgt: es-zh + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_fr-ar: + dec_sharing_group: + - ar + enc_sharing_group: + - all + node_gpu: 0:0 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-fr/train.ar-fr.fr.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-fr/train.ar-fr.ar.sp + src_prefix: + src_tgt: fr-ar + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_fr-en: + dec_sharing_group: + - en + enc_sharing_group: + - all + node_gpu: 0:1 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-fr/train.en-fr.fr.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-fr/train.en-fr.en.sp + src_prefix: + src_tgt: fr-en + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_fr-es: + dec_sharing_group: + - es + enc_sharing_group: + - all + node_gpu: 0:2 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-fr/train.es-fr.fr.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-fr/train.es-fr.es.sp + src_prefix: + src_tgt: fr-es + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_fr-fr: + dec_sharing_group: + - fr + enc_sharing_group: + - all + node_gpu: 0:3 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/fr/train.fr.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/fr/train.fr.sp + src_prefix: + src_tgt: fr-fr + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + - denoising + train_fr-ru: + dec_sharing_group: + - ru + enc_sharing_group: + - all + node_gpu: 0:4 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-ru/train.fr-ru.fr.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-ru/train.fr-ru.ru.sp + src_prefix: + src_tgt: fr-ru + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_fr-zh: + dec_sharing_group: + - zh + enc_sharing_group: + - all + node_gpu: 0:5 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-zh/train.fr-zh.fr.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-zh/train.fr-zh.zh.sp + src_prefix: + src_tgt: fr-zh + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ru-ar: + dec_sharing_group: + - ar + enc_sharing_group: + - all + node_gpu: 0:0 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-ru/train.ar-ru.ru.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-ru/train.ar-ru.ar.sp + src_prefix: + src_tgt: ru-ar + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ru-en: + dec_sharing_group: + - en + enc_sharing_group: + - all + node_gpu: 0:1 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-ru/train.en-ru.ru.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-ru/train.en-ru.en.sp + src_prefix: + src_tgt: ru-en + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ru-es: + dec_sharing_group: + - es + enc_sharing_group: + - all + node_gpu: 0:2 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-ru/train.es-ru.ru.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-ru/train.es-ru.es.sp + src_prefix: + src_tgt: ru-es + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ru-fr: + dec_sharing_group: + - fr + enc_sharing_group: + - all + node_gpu: 0:3 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-ru/train.fr-ru.ru.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-ru/train.fr-ru.fr.sp + src_prefix: + src_tgt: ru-fr + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_ru-ru: + dec_sharing_group: + - ru + enc_sharing_group: + - all + node_gpu: 0:4 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/ru/train.ru.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/ru/train.ru.sp + src_prefix: + src_tgt: ru-ru + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + - denoising + train_ru-zh: + dec_sharing_group: + - zh + enc_sharing_group: + - all + node_gpu: 0:5 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ru-zh/train.ru-zh.ru.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ru-zh/train.ru-zh.zh.sp + src_prefix: + src_tgt: ru-zh + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_zh-ar: + dec_sharing_group: + - ar + enc_sharing_group: + - all + node_gpu: 0:0 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-zh/train.ar-zh.zh.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ar-zh/train.ar-zh.ar.sp + src_prefix: + src_tgt: zh-ar + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_zh-en: + dec_sharing_group: + - en + enc_sharing_group: + - all + node_gpu: 0:1 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-zh/train.en-zh.zh.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/en-zh/train.en-zh.en.sp + src_prefix: + src_tgt: zh-en + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_zh-es: + dec_sharing_group: + - es + enc_sharing_group: + - all + node_gpu: 0:2 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-zh/train.es-zh.zh.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/es-zh/train.es-zh.es.sp + src_prefix: + src_tgt: zh-es + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_zh-fr: + dec_sharing_group: + - fr + enc_sharing_group: + - all + node_gpu: 0:3 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-zh/train.fr-zh.zh.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/fr-zh/train.fr-zh.fr.sp + src_prefix: + src_tgt: zh-fr + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_zh-ru: + dec_sharing_group: + - ru + enc_sharing_group: + - all + node_gpu: 0:4 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ru-zh/train.ru-zh.zh.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/ru-zh/train.ru-zh.ru.sp + src_prefix: + src_tgt: zh-ru + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + train_zh-zh: + dec_sharing_group: + - zh + enc_sharing_group: + - all + node_gpu: 0:5 + path_src: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/zh/train.zh.sp + path_tgt: /scratch/project_462000447/members/attiehjo/data/unpc/generated/monolingual/zh/train.zh.sp + src_prefix: + src_tgt: zh-zh + tgt_prefix: + transforms: + - sentencepiece + - filtertoolong + - prefix + - denoising +tensorboard: true +tensorboard_log_dir: /scratch/project_462000447/members/attiehjo/1.ab-negative-refactored/experiments_tags/unpc/all-centric/encoder-shared/no_bridge/123/tb/ +tgt_seq_length: 200 +tgt_subword_type: sentencepiece +tgt_vocab: + ar: tools/demo/configs/ab-neg/vocab/opusTC.ar.vocab.onmt + en: tools/demo/configs/ab-neg/vocab/opusTC.en.vocab.onmt + es: tools/demo/configs/ab-neg/vocab/opusTC.es.vocab.onmt + fr: tools/demo/configs/ab-neg/vocab/opusTC.fr.vocab.onmt + ru: tools/demo/configs/ab-neg/vocab/opusTC.ru.vocab.onmt + zh: tools/demo/configs/ab-neg/vocab/opusTC.zh.vocab.onmt +tgt_vocab_size: 100000 +train_steps: 300000 +transformer_ff: 2048 +valid_batch_size: 4096 +valid_steps: 10000 +warmup_steps: 10000 +weight_decay: 0.05 +world_size: 6 + +src_subword_model: tools/demo/configs/ab-neg/vocab/{src_lang}.spm +tgt_subword_model: tools/demo/configs/ab-neg/vocab/{tgt_lang}.spm diff --git a/tools/demo/configs/ab-neg/server.json b/tools/demo/configs/ab-neg/server.json new file mode 100644 index 00000000..9ce1035d --- /dev/null +++ b/tools/demo/configs/ab-neg/server.json @@ -0,0 +1,185 @@ +{ + "models_root": "tools/demo/configs/ab-neg/models/", + "models": [ + { + "id": 0, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_en-ar" + }, + "models": ["model_step_15000"] + }, + { + "id": 1, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_en-es" + }, + "models": ["model_step_15000"] + }, + { + "id": 2, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_en-fr" + }, + "models": ["model_step_15000"] + }, + { + "id": 3, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_en-ru" + }, + "models": ["model_step_15000"] + }, + { + "id": 4, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_en-zh" + }, + "models": ["model_step_15000"] + }, + { + "id": 5, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_es-ar" + }, + "models": ["model_step_15000"] + }, + { + "id": 6, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_es-en" + }, + "models": ["model_step_15000"] + }, + { + "id": 7, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_es-fr" + }, + "models": ["model_step_15000"] + }, + { + "id": 8, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_es-ru" + }, + "models": ["model_step_15000"] + }, + { + "id": 9, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_es-zh" + }, + "models": ["model_step_15000"] + }, + { + "id": 10, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_fr-ar" + }, + "models": ["model_step_15000"] + }, + { + "id": 11, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_fr-en" + }, + "models": ["model_step_15000"] + }, + { + "id": 12, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_fr-es" + }, + "models": ["model_step_15000"] + }, + { + "id": 13, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_fr-ru" + }, + "models": ["model_step_15000"] + }, + { + "id": 14, + "opts": { + "config": "tools/demo/configs/ab-neg/config.yaml", + "model": "tools/demo/configs/ab-neg/models/model_step_15000", + "transforms": ["sentencepiece", "prefix"], + "src_prefix": "", + "tgt_prefix": "", + "task_id": "train_fr-zh" + }, + "models": ["model_step_15000"] + } + ] +} diff --git a/tools/demo/configs/hydra/hydra-L-train-config.yml b/tools/demo/configs/hydra/hydra-L-train-config.yml new file mode 100755 index 00000000..9353945e --- /dev/null +++ b/tools/demo/configs/hydra/hydra-L-train-config.yml @@ -0,0 +1,184 @@ +src_vocab: + 'all': tools/demo/configs/hydra/vocab/mammoth-hydra.64k.spm.vocab +tgt_vocab: + 'en': tools/demo/configs/hydra/vocab/mammoth-hydra.64k.spm.vocab + 'fr': tools/demo/configs/hydra/vocab/mammoth-hydra.64k.spm.vocab + 'ru': tools/demo/configs/hydra/vocab/mammoth-hydra.64k.spm.vocab + +overwrite: False +tasks: + # GPU 0:0 + defmod_en: + src_tgt: all-en + enc_sharing_group: [all] + dec_sharing_group: [en1, dm, en2] + node_gpu: 0:0 + path_src: /scratch/project_2005099/data/mammoth-hydra/codwoe/en.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/codwoe/en.tgt.sp + transforms: [filtertoolong] + pargen_en: + src_tgt: all-en + enc_sharing_group: [all] + dec_sharing_group: [en1, pg, en2] + node_gpu: 0:0 + path_src: /scratch/project_2005099/data/mammoth-hydra/tapaco/en.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/tapaco/en.tgt.sp + transforms: [filtertoolong] + texsim_en: + src_tgt: all-en + enc_sharing_group: [all] + dec_sharing_group: [en1, ts, en2] + node_gpu: 0:0 + path_src: /scratch/project_2005099/data/mammoth-hydra/wikilarge/en.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/wikilarge/en.tgt.sp + transforms: [filtertoolong] + translate_fr-en: + src_tgt: all-en + enc_sharing_group: [all] + dec_sharing_group: [en1, mt, en2] + node_gpu: 0:0 + path_src: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.fr.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.en.sp + transforms: [filtertoolong] + translate_ru-en: + src_tgt: all-en + enc_sharing_group: [all] + dec_sharing_group: [en1, mt, en2] + node_gpu: 0:0 + path_src: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.ru.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.en.sp + transforms: [filtertoolong] + + # GPU 0:1 + defmod_fr: + src_tgt: all-fr + enc_sharing_group: [all] + dec_sharing_group: [fr1, dm, fr2] + node_gpu: 0:1 + path_src: /scratch/project_2005099/data/mammoth-hydra/codwoe/fr.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/codwoe/fr.tgt.sp + transforms: [filtertoolong] + pargen_fr: + src_tgt: all-fr + enc_sharing_group: [all] + dec_sharing_group: [fr1, pg, fr2] + node_gpu: 0:1 + path_src: /scratch/project_2005099/data/mammoth-hydra/tapaco/fr.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/tapaco/fr.tgt.sp + transforms: [filtertoolong] + texsim_fr: + src_tgt: all-fr + enc_sharing_group: [all] + dec_sharing_group: [fr1, ts, fr2] + node_gpu: 0:1 + path_src: /scratch/project_2005099/data/mammoth-hydra/wikilarge/fr.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/wikilarge/fr.tgt.sp + transforms: [filtertoolong] + translate_en-fr: + src_tgt: all-fr + enc_sharing_group: [all] + dec_sharing_group: [fr1, mt, fr2] + node_gpu: 0:1 + path_src: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.en.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.fr.sp + transforms: [filtertoolong] + translate_ru-fr: + src_tgt: all-fr + enc_sharing_group: [all] + dec_sharing_group: [fr1, mt, fr2] + node_gpu: 0:1 + path_src: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.ru.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.fr.sp + transforms: [filtertoolong] + + # GPU 0:2 + defmod_ru: + src_tgt: all-ru + enc_sharing_group: [all] + dec_sharing_group: [ru1, dm, ru2] + node_gpu: 0:2 + path_src: /scratch/project_2005099/data/mammoth-hydra/codwoe/ru.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/codwoe/ru.tgt.sp + transforms: [filtertoolong] + pargen_ru: + src_tgt: all-ru + enc_sharing_group: [all] + dec_sharing_group: [ru1, pg, ru2] + node_gpu: 0:2 + path_src: /scratch/project_2005099/data/mammoth-hydra/tapaco/ru.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/tapaco/ru.tgt.sp + transforms: [filtertoolong] + texsim_ru: + src_tgt: all-ru + enc_sharing_group: [all] + dec_sharing_group: [ru1, ts, ru2] + node_gpu: 0:2 + path_src: /scratch/project_2005099/data/mammoth-hydra/ruadapt/ru.src.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/ruadapt/ru.tgt.sp + transforms: [filtertoolong] + translate_fr-ru: + src_tgt: all-ru + enc_sharing_group: [all] + dec_sharing_group: [ru1, mt, ru2] + node_gpu: 0:2 + path_src: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.fr.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.ru.sp + transforms: [filtertoolong] + translate_en-ru: + src_tgt: all-ru + enc_sharing_group: [all] + dec_sharing_group: [ru1, mt, ru2] + node_gpu: 0:1 + path_src: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.en.sp + path_tgt: /scratch/project_2005099/data/mammoth-hydra/unpc/UNv1.0.6way.ru.sp + transforms: [filtertoolong] + + +### Transform related opts: +#### Filter +src_seq_length: 200 +tgt_seq_length: 200 +#### Bart +src_subword_type: sentencepiece +tgt_subword_type: sentencepiece +mask_ratio: 0.2 +replace_length: 1 + +batch_size: 4096 +batch_type: tokens +normalization: tokens +valid_batch_size: 4096 +max_generator_batches: 2 +src_vocab_size: 100000 +tgt_vocab_size: 100000 +encoder_type: transformer +decoder_type: transformer +model_dim: 512 +transformer_ff: 2048 +heads: 8 +enc_layers: [12] +dec_layers: [2, 2, 2] +dropout: 0.1 +label_smoothing: 0.1 +param_init: 0.0 +param_init_glorot: true +position_encoding: true +valid_steps: 10000 +warmup_steps: 10000 +report_every: 100 +save_checkpoint_steps: 10000 +keep_checkpoint: -1 +accum_count: 1 +optim: adafactor +decay_method: none +learning_rate: 3.0 +max_grad_norm: 0.0 +seed: 3435 +save_all_gpus: false + +world_size: 3 +gpu_ranks: [0, 1, 2] +node_rank: 0 + +src_subword_model: tools/demo/configs/hydra/vocab/mammoth-hydra.64k.spm.model +tgt_subword_model: tools/demo/configs/hydra/vocab/mammoth-hydra.64k.spm.model diff --git a/tools/demo/configs/hydra/server.json b/tools/demo/configs/hydra/server.json new file mode 100644 index 00000000..501a1cbd --- /dev/null +++ b/tools/demo/configs/hydra/server.json @@ -0,0 +1,165 @@ +{ + "models_root": "tools/demo/configs/hydra/models/", + "models": [ + { + "id": 0, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "defmod_en" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 1, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "defmod_fr" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 2, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "defmod_ru" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 3, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "pargen_en" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 4, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "pargen_fr" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 5, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "pargen_ru" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 6, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "texsim_en" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 7, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "texsim_fr" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 8, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "texsim_ru" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 9, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "texsim_fr" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 10, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "translate_fr-en" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 11, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "translate_ru-en" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 12, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "translate_en-fr" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 13, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "translate_ru-fr" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 14, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "translate_fr-ru" + }, + "models": ["hydra-L-ckpt_step_460000"] + }, + { + "id": 15, + "opts": { + "config": "tools/demo/configs/hydra/hydra-L-train-config.yml", + "model": "models/hydra-L-ckpt_step_460000", + "transforms": "sentencepiece", + "task_id": "translate_en-ru" + }, + "models": ["hydra-L-ckpt_step_460000"] + } + ] +} diff --git a/tools/demo/mammoth_demo.py b/tools/demo/mammoth_demo.py new file mode 100644 index 00000000..50cc4fdf --- /dev/null +++ b/tools/demo/mammoth_demo.py @@ -0,0 +1,221 @@ +from dataclasses import dataclass +import requests +import streamlit as st # type: ignore + +st.set_page_config(layout="wide") + +MAMMOTH = '🦣' +FAT_UNDER = '▁' + +ARCHITECTURE_HTML_HYDRA = """ +

Decoder

+
+
+
en
+
fr
+
ru
+
+
+
defmod
+
pargen
+
texsim
+
translate
+
+
+
en
+
fr
+
ru
+
+
+

Encoder

+
+
+
fully shared
+
+
+ + +""" + +ARCHITECTURE_HTML_ABNEG = """ +

Decoder

+
+
+
ar
+
en
+
es
+
fr
+
ru
+
zh
+
+
+

Encoder

+
+
+
fully shared
+
+
+ + +""" + + +def render(template, model_task): + task, lang = model_task.split('_') + if task == 'translate' or task == 'train': + _, lang = lang.split('-') + template = template.replace('__TASK__', task) + template = template.replace('__LANG__', lang) + return template + + +@dataclass +class ModelSpecs: + id: int + task: str + loaded: bool + + @staticmethod + def format_model(model): + return model.task + + +class Translator: + def __call__(self): + st.title(f'{MAMMOTH} MAMMOTH translation demo') + col1, col2 = st.columns([0.6, 0.4], gap="large") + with col1: + model = st.selectbox( + 'Model', + st.session_state.models, + format_func=ModelSpecs.format_model, + ) + source = st.text_area( + 'Source text', + height=None, + ) + submitted = st.button('▶️ Translate') + if source or submitted: + target_text = self.submit(source, model.id) + else: + target_text = '' + st.text_area( + 'Target text', + value=target_text, + height=None, + ) + with col2: + architecture_html = ARCHITECTURE_HTML_HYDRA if 'train' not in model.task else ARCHITECTURE_HTML_ABNEG + st.markdown( + render(architecture_html, model.task), + unsafe_allow_html=True, + ) + + def submit(self, query, model): + try: + response = requests.request( + 'POST', + 'http://127.0.0.1:5000/translator/translate', + json=[{ + 'src': query, + 'id': model, + }], + ) + data = response.json() + except Exception as e: + print(response.content) + raise e + tokenized = data[0][0]['tgt'] + return self.detokenize(tokenized) + + def detokenize(self, tokenized): + result = tokenized.replace(' ', '').replace(FAT_UNDER, ' ') + return result + + def get_models(self): + data = requests.request( + 'GET', + 'http://127.0.0.1:5000/translator/models', + ).json() + models = [ + ModelSpecs(model_specs['model_id'], model_specs['opts']['task_id'], model_specs['loaded']) + for model_specs in data + ] + st.session_state.models = models + return models + + +translator = Translator() +if 'models' not in st.session_state: + st.session_state.models = translator.get_models() +translator() diff --git a/tools/demo/requirements.txt b/tools/demo/requirements.txt new file mode 100644 index 00000000..a40921a7 --- /dev/null +++ b/tools/demo/requirements.txt @@ -0,0 +1,49 @@ +GitPython==3.1.41 +Jinja2==3.1.3 +MarkupSafe==2.1.4 +altair==5.2.0 +attrs==23.2.0 +backports.zoneinfo==0.2.1;python_version<"3.9" +blinker==1.7.0 +cachetools==5.3.2 +certifi==2023.11.17 +charset-normalizer==3.3.2 +click==8.1.7 +gitdb==4.0.11 +idna==3.6 +importlib-metadata==7.0.1 +importlib-resources==6.1.1 +jsonschema-specifications==2023.12.1 +jsonschema==4.21.1 +markdown-it-py==3.0.0 +mdurl==0.1.2 +numpy==1.24.4 +packaging==23.2 +pandas==2.0.3 +pillow==10.2.0 +pkgutil-resolve-name==1.3.10 +protobuf==4.25.2 +pyarrow==15.0.0 +pydeck==0.8.1b0 +pygments==2.17.2 +python-dateutil==2.8.2 +pytz==2023.4 +referencing==0.33.0 +requests==2.31.0 +rich==13.7.0 +rpds-py==0.17.1 +six==1.16.0 +smmap==5.0.1 +streamlit==1.30.0 +tenacity==8.2.3 +toml==0.10.2 +toolz==0.12.1 +tornado==6.4 +typing-extensions==4.9.0 +tzdata==2023.4 +tzlocal==5.2 +urllib3==2.1.0 +validators==0.22.0 +watchdog==3.0.0 +werkzeug==2.3.8 +zipp==3.17.0