From 39e1f76efbe88c795039841fc963670ac7ca3a6f Mon Sep 17 00:00:00 2001
From: Mickus Timothee
Date: Tue, 19 Sep 2023 21:33:41 +0300
Subject: [PATCH] ensure transforms are ordered as declared

---
 onmt/inputters/dataset.py | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

diff --git a/onmt/inputters/dataset.py b/onmt/inputters/dataset.py
index 5b32fc80..c0f1b2e6 100644
--- a/onmt/inputters/dataset.py
+++ b/onmt/inputters/dataset.py
@@ -165,17 +165,26 @@ def get_corpus(opts, task, src_vocab: Vocab, tgt_vocab: Vocab, is_train: bool =
     """build an iterable Dataset object"""
     # get transform classes to infer special tokens
     # FIXME ensure TQM properly initializes transform with global if necessary
+    vocabs = {'src': src_vocab, 'tgt': tgt_vocab}
     corpus_opts = opts.data[task.corpus_id]
-    transforms_cls = get_transforms_cls(corpus_opts.get('transforms', opts.transforms))
+    transforms_to_apply = corpus_opts.get('transforms', None)
+    transforms_to_apply = transforms_to_apply or opts.get('transforms', None)
+    transforms_to_apply = transforms_to_apply or []
+    transforms_cls = make_transforms(
+        opts,
+        get_transforms_cls(transforms_to_apply),
+        vocabs,
+        task=task,
+    )
+    transforms_to_apply = [transforms_cls[trf_name] for trf_name in transforms_to_apply]
-    vocabs = {'src': src_vocab, 'tgt': tgt_vocab}

     # build Dataset proper
     dataset = ParallelCorpus(
         corpus_opts["path_src"] if is_train else corpus_opts["path_valid_src"],
         corpus_opts["path_tgt"] if is_train else corpus_opts["path_valid_tgt"],
         src_vocab,
         tgt_vocab,
-        TransformPipe(opts, make_transforms(opts, transforms_cls, vocabs, task=task).values()),
+        TransformPipe(opts, transforms_to_apply),
         stride=corpus_opts.get('stride', None),
         offset=corpus_opts.get('offset', None),
         is_train=is_train,
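
Note on the ordering fix, with a minimal standalone sketch (the helper names below are illustrative stand-ins, not the repository's API): the pipeline is now built by indexing the constructed transforms with the declared name list, so transforms are applied in the order the corpus config lists them, rather than in whatever order the mapping returned by make_transforms happens to yield its values.

    # sketch.py -- illustrative only; build_transforms/transform_pipe stand in
    # for the real make_transforms/TransformPipe referenced in the diff above.
    def build_transforms(names):
        # map each declared transform name to a callable that tags the example
        return {name: (lambda ex, n=name: ex + [n]) for name in names}

    def transform_pipe(transforms):
        # apply the transforms left to right, in the order of the given list
        def run(example):
            for transform in transforms:
                example = transform(example)
            return example
        return run

    declared = ['sentencepiece', 'filtertoolong', 'prefix']
    built = build_transforms(declared)
    # the point of the patch: index by the declared list instead of .values()
    pipeline = transform_pipe([built[name] for name in declared])
    print(pipeline([]))  # -> ['sentencepiece', 'filtertoolong', 'prefix']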