diff --git a/docs/source/config_config.md b/docs/source/config_config.md index 05b0ba88..aa096296 100644 --- a/docs/source/config_config.md +++ b/docs/source/config_config.md @@ -14,7 +14,7 @@ To ease the creation of configs, the config-config tool reads in a human-writabl ## Command ```bash -python3 mammoth/tools/config_config.py config_all --in_config input.yaml --out_config output.yaml +mammoth_config_config config_all --in_config input.yaml --out_config output.yaml ``` ## Inputs diff --git a/docs/source/quickstart.md b/docs/source/quickstart.md index c945e5a8..ce88fbe8 100644 --- a/docs/source/quickstart.md +++ b/docs/source/quickstart.md @@ -7,7 +7,7 @@ MAMMOTH is specifically designed for distributed training of modular systems in In the example below, we will show you how to configure Mammoth. We will use two small experiments as examples -1. A simple set of toy tasks with synthetic data. Easy and fast, requiring no resources except for the Mammoth git repo. +1. A simple set of toy tasks with synthetic data. Easy and fast, requiring no resources except for the Mammoth package. 2. A machine translation model with language-specific encoders and decoders. ### Step 0: Install mammoth @@ -25,45 +25,37 @@ Easy and fast, requiring no resources except for the Mammoth git repo. This example uses a very small vocabulary, so we can use a "word level" model without sentencepiece. The opts `--n_nodes`, `--n_gpus_per_node`, `--node_rank`, and `--gpu_rank` are set to use a single GPU. -### Step 1: Set the locations - -Set the locations to the directory in which you want to work on this project, and the path to the mammoth git repo. - -```bash -export PROJECT_DIR="/path/to/work/dir" -export MAMMOTH_DIR="/path/to/mammoth" -``` - -### Step 2: Activate your virtual env +### Step 1: Activate your virtual env ```bash source ~/venvs/mammoth/bin/activate ``` -### Step 3: Copy the config template from the Mammoth repo +### Step 2: Copy the config template from the Mammoth repo ```bash -cd $PROJECT_DIR mkdir config -cp -i ${MAMMOTH_DIR}/examples/synthdata.template.yaml config/synthdata.template.yaml +pushd config +wget "https://raw.githubusercontent.com/Helsinki-NLP/mammoth/refs/heads/main/examples/synthdata.template.yaml" +popd ``` -### Step 4: Generate synthetic data +### Step 3: Generate synthetic data (this might take about 5 min) ```bash -python ${MAMMOTH_DIR}/tools/generate_synth_data.py \ +mammoth_generate_synth_data \ --config_path config/synthdata.template.yaml \ --shared_vocab data/synthdata/shared_vocab ``` -### Step 5: Generate the actual config from the config template +### Step 4: Generate the actual config from the config template (this should only take a few seconds) ```bash -python ${MAMMOTH_DIR}/tools/config_config.py \ +mammoth_config_config \ config_all \ --in_config config/synthdata.template.yaml \ --out_config config/synthdata.yaml \ @@ -71,7 +63,7 @@ python ${MAMMOTH_DIR}/tools/config_config.py \ --n_gpus_per_node 1 ``` -### Step 6: Train the model +### Step 5: Train the model (This might take about 1h. To speed things up, train for a shorter time, e.g. 
`--train_steps 5000 --warmup_steps 600`) @@ -79,7 +71,7 @@ python ${MAMMOTH_DIR}/tools/config_config.py \ mammoth_train --config config/synthdata.yaml --node_rank 0 --gpu_rank 0 ``` -### Step 7: Translate +### Step 6: Translate (this might take a few minutes) @@ -101,25 +93,15 @@ mammoth_translate \ ## Experiment 2: Machine translation with multi30k -### Step 1: Set the locations - -Set the locations to the directory in which you want to work on this project, and the path to the mammoth git repo. - -```bash -export PROJECT_DIR="/path/to/work/dir" -export MAMMOTH_DIR="/path/to/mammoth" -``` - -### Step 2: Activate your virtual env +### Step 1: Activate your virtual env ```bash source ~/venvs/mammoth/bin/activate ``` -### Step 3: Download data +### Step 2: Download data ```bash -cd $PROJECT_DIR mkdir data/multi30k pushd data/multi30k @@ -131,7 +113,7 @@ done popd ``` -### Step 4: Train sentencepiece models +### Step 3: Train sentencepiece models ```bash mkdir -p models/spm @@ -142,19 +124,21 @@ for language in cs en de fr; do done ``` -### Step 5: Copy the config template from the Mammoth repo +### Step 4: Copy the config template from the Mammoth repo ```bash mkdir config -cp -i ${MAMMOTH_DIR}/examples/multi30k.template.yaml config/multi30k.template.yaml +pushd config +wget "https://raw.githubusercontent.com/Helsinki-NLP/mammoth/refs/heads/main/examples/multi30k.template.yaml" +popd ``` -### Step 6: Generate the actual config from the config template +### Step 5: Generate the actual config from the config template (this should only take a few seconds) ```bash -python ${MAMMOTH_DIR}/tools/config_config.py \ +mammoth_config_config \ config_all \ --in_config config/multi30k.template.yaml \ --out_config config/multi30k.yaml \ @@ -162,7 +146,7 @@ python ${MAMMOTH_DIR}/tools/config_config.py \ --n_gpus_per_node 1 ``` -### Step 7: Train the model +### Step 6: Train the model (this might take a while) @@ -170,7 +154,7 @@ python ${MAMMOTH_DIR}/tools/config_config.py \ mammoth_train --config config/multi30k.yaml --node_rank 0 --gpu_rank 0 ``` -### Step 8: Translate +### Step 7: Translate (this might take a while) @@ -189,7 +173,7 @@ MODEL="models/${EXP_NAME}_step_${STEP}" # Translate all language pairs mkdir -p "translations/${EXP_NAME}/" -python ${MAMMOTH_DIR}/tools/iterate_tasks.py --config ${CONFIG} \ +mammoth_iterate_tasks --config ${CONFIG} \ --src "data/${EXP_NAME}/test_2016_flickr.{src_lang}.gz" \ --output "translations/${EXP_NAME}/test_2016_flickr.{task_id}.greedy.trans" \ | while read task_flags; do \ diff --git a/tools/config_config.py b/mammoth/bin/config_config.py old mode 100644 new mode 100755 similarity index 99% rename from tools/config_config.py rename to mammoth/bin/config_config.py index bfd4cccc..1e9d715c --- a/tools/config_config.py +++ b/mammoth/bin/config_config.py @@ -13,7 +13,7 @@ from itertools import compress from sklearn.cluster import AgglomerativeClustering -from gpu_assignment import optimize_gpu_assignment +from mammoth.utils.gpu_assignment import optimize_gpu_assignment logger = logging.getLogger('config_config') @@ -918,12 +918,12 @@ def extra_copy_gpu_assignment(opts): opts.in_config[0]['gpu_ranks'] = opts.copy_from[0]['gpu_ranks'] -if __name__ == '__main__': +def main(): init_logging() opts = get_opts() # if not opts.out_config: # opts.out_config = opts.in_config[1] - main = { + command = { func.__name__: func for func in ( complete_language_pairs, @@ -941,5 +941,9 @@ def extra_copy_gpu_assignment(opts): extra_copy_gpu_assignment, ) }[opts.command] - main(opts)
+ command(opts) save_yaml(opts) + + +if __name__ == '__main__': + main() diff --git a/tools/generate_synth_data.py b/mammoth/bin/generate_synth_data.py similarity index 100% rename from tools/generate_synth_data.py rename to mammoth/bin/generate_synth_data.py diff --git a/tools/iterate_tasks.py b/mammoth/bin/iterate_tasks.py old mode 100644 new mode 100755 similarity index 93% rename from tools/iterate_tasks.py rename to mammoth/bin/iterate_tasks.py index ea5a137d..c0b3005f --- a/tools/iterate_tasks.py +++ b/mammoth/bin/iterate_tasks.py @@ -22,7 +22,7 @@ '--output', type=str, default=None, - help='Template for source file paths. Use varibles src_lang, tgt_lang, and task_id.', + help='Template for translation output file paths. Use variables src_lang, tgt_lang, and task_id.', ) @click.option( '--flag', diff --git a/tools/gpu_assignment.py b/mammoth/utils/gpu_assignment.py similarity index 100% rename from tools/gpu_assignment.py rename to mammoth/utils/gpu_assignment.py diff --git a/setup.py b/setup.py index 9cf03467..71167ab5 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,9 @@ # "onmt_server=mammoth.bin.server:main", "mammoth_train=mammoth.bin.train:main", "mammoth_translate=mammoth.bin.translate:main", + "mammoth_config_config=mammoth.bin.config_config:main", + "mammoth_iterate_tasks=mammoth.bin.iterate_tasks:main", + "mammoth_generate_synth_data=mammoth.bin.generate_synth_data:main", # "onmt_release_model=mammoth.bin.release_model:main", # "onmt_average_models=mammoth.bin.average_models:main", # "onmt_build_vocab=mammoth.bin.build_vocab:main", diff --git a/tools/README.md b/tools/README.md index 0869b807..4594e577 100644 --- a/tools/README.md +++ b/tools/README.md @@ -1,3 +1 @@ -This directly contains scripts and tools adopted from other open source projects such as Apache Joshua and Moses Decoder. - -TODO: credit the authors and resolve license issues (if any) +This directory contains legacy scripts that were not deemed worth installing as part of the mammoth package. diff --git a/tools/apply_bpe.py b/tools/apply_bpe.py deleted file mode 100755 index d03c0916..00000000 --- a/tools/apply_bpe.py +++ /dev/null @@ -1,350 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Author: Rico Sennrich -# flake8: noqa - -"""Use operations learned with learn_bpe.py to encode a new text. -The text will not be smaller, but use only a fixed vocabulary, with rare words -encoded as variable-length sequences of subword units. - -Reference: -Rico Sennrich, Barry Haddow and Alexandra Birch (2015). Neural Machine Translation of Rare Words with Subword Units. -Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany.
-""" -# This file is retrieved from https://github.com/rsennrich/subword-nmt - -from __future__ import unicode_literals, division - -import sys -import codecs -import io -import argparse -import json -import re -from collections import defaultdict - -# hack for python2/3 compatibility -from io import open - -argparse.open = open - - -class BPE(object): - def __init__(self, codes, separator='@@', vocab=None, glossaries=None): - - # check version information - firstline = codes.readline() - if firstline.startswith('#version:'): - self.version = tuple([int(x) for x in re.sub(r'(\.0+)*$', '', firstline.split()[-1]).split(".")]) - else: - self.version = (0, 1) - codes.seek(0) - - self.bpe_codes = [tuple(item.split()) for item in codes] - - # some hacking to deal with duplicates (only consider first instance) - self.bpe_codes = dict([(code, i) for (i, code) in reversed(list(enumerate(self.bpe_codes)))]) - - self.bpe_codes_reverse = dict([(pair[0] + pair[1], pair) for pair, i in self.bpe_codes.items()]) - - self.separator = separator - - self.vocab = vocab - - self.glossaries = glossaries if glossaries else [] - - self.cache = {} - - def segment(self, sentence): - """segment single sentence (whitespace-tokenized string) with BPE encoding""" - output = [] - for word in sentence.split(): - new_word = [ - out - for segment in self._isolate_glossaries(word) - for out in encode( - segment, - self.bpe_codes, - self.bpe_codes_reverse, - self.vocab, - self.separator, - self.version, - self.cache, - self.glossaries, - ) - ] - - for item in new_word[:-1]: - output.append(item + self.separator) - output.append(new_word[-1]) - - return ' '.join(output) - - def _isolate_glossaries(self, word): - word_segments = [word] - for gloss in self.glossaries: - word_segments = [ - out_segments for segment in word_segments for out_segments in isolate_glossary(segment, gloss) - ] - return word_segments - - -def create_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter, description="learn BPE-based word segmentation" - ) - - parser.add_argument( - '--input', - '-i', - type=argparse.FileType('r'), - default=sys.stdin, - metavar='PATH', - help="Input file (default: standard input).", - ) - parser.add_argument( - '--codes', - '-c', - type=argparse.FileType('r'), - metavar='PATH', - required=True, - help="File with BPE codes (created by learn_bpe.py).", - ) - parser.add_argument( - '--output', - '-o', - type=argparse.FileType('w'), - default=sys.stdout, - metavar='PATH', - help="Output file (default: standard output)", - ) - parser.add_argument( - '--separator', - '-s', - type=str, - default='@@', - metavar='STR', - help="Separator between non-final subword units (default: '%(default)s'))", - ) - parser.add_argument( - '--vocabulary', - type=argparse.FileType('r'), - default=None, - metavar="PATH", - help="Vocabulary file (built with get_vocab.py). If provided, this script reverts any merge operations that produce an OOV.", - ) - parser.add_argument( - '--vocabulary-threshold', - type=int, - default=None, - metavar="INT", - help="Vocabulary threshold. If vocabulary is provided, any word with frequency < threshold will be treated as OOV", - ) - parser.add_argument( - '--glossaries', - type=str, - nargs='+', - default=None, - metavar="STR", - help="Glossaries. The strings provided in glossaries will not be affected" - + "by the BPE (i.e. 
they will neither be broken into subwords, nor concatenated with other subwords", - ) - - return parser - - -def get_pairs(word): - """Return set of symbol pairs in a word. - - word is represented as tuple of symbols (symbols being variable-length strings) - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries=None): - """Encode word based on list of BPE merge operations, which are applied consecutively""" - - if orig in cache: - return cache[orig] - - if orig in glossaries: - cache[orig] = (orig,) - return (orig,) - - if version == (0, 1): - word = tuple(orig) + ('',) - elif version == (0, 2): # more consistent handling of word-final segments - word = tuple(orig[:-1]) + (orig[-1] + '',) - else: - raise NotImplementedError - - pairs = get_pairs(word) - - if not pairs: - return orig - - while True: - bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf'))) - if bigram not in bpe_codes: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except BaseException: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word) - 1 and word[i + 1] == second: - new_word.append(first + second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - - # don't print end-of-word symbols - if word[-1] == '': - word = word[:-1] - elif word[-1].endswith(''): - word = word[:-1] + (word[-1].replace('', ''),) - - if vocab: - word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator) - - cache[orig] = word - return word - - -def recursive_split(segment, bpe_codes, vocab, separator, final=False): - """Recursively split segment into smaller units (by reversing BPE merges) - until all units are either in-vocabulary, or cannot be split futher.""" - - try: - if final: - left, right = bpe_codes[segment + ''] - right = right[:-4] - else: - left, right = bpe_codes[segment] - except BaseException: - # sys.stderr.write('cannot split {0} further.\n'.format(segment)) - yield segment - return - - if left + separator in vocab: - yield left - else: - for item in recursive_split(left, bpe_codes, vocab, separator, False): - yield item - - if (final and right in vocab) or (not final and right + separator in vocab): - yield right - else: - for item in recursive_split(right, bpe_codes, vocab, separator, final): - yield item - - -def check_vocab_and_split(orig, bpe_codes, vocab, separator): - """Check for each segment in word if it is in-vocabulary, - and segment OOV segments into smaller units by reversing the BPE merge operations""" - - out = [] - - for segment in orig[:-1]: - if segment + separator in vocab: - out.append(segment) - else: - # sys.stderr.write('OOV: {0}\n'.format(segment)) - for item in recursive_split(segment, bpe_codes, vocab, separator, False): - out.append(item) - - segment = orig[-1] - if segment in vocab: - out.append(segment) - else: - # sys.stderr.write('OOV: {0}\n'.format(segment)) - for item in recursive_split(segment, bpe_codes, vocab, separator, True): - out.append(item) - - return out - - -def read_vocabulary(vocab_file, threshold): - """read vocabulary file produced by get_vocab.py, and filter according to frequency threshold.""" - - vocabulary = set() - - for line in vocab_file: 
- word, freq = line.split() - freq = int(freq) - if threshold is None or freq >= threshold: - vocabulary.add(word) - - return vocabulary - - -def isolate_glossary(word, glossary): - """ - Isolate a glossary present inside a word. - - Returns a list of subwords. In which all 'glossary' glossaries are isolated - - For example, if 'USA' is the glossary and '1934USABUSA' the word, the return value is: - ['1934', 'USA', 'B', 'USA'] - """ - if word == glossary or glossary not in word: - return [word] - else: - splits = word.split(glossary) - segments = [segment.strip() for split in splits[:-1] for segment in [split, glossary] if segment != ''] - return segments + [splits[-1].strip()] if splits[-1] != '' else segments - - -if __name__ == '__main__': - - # python 2/3 compatibility - if sys.version_info < (3, 0): - sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) - sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) - sys.stdin = codecs.getreader('UTF-8')(sys.stdin) - else: - sys.stdin = io.TextIOWrapper(sys.stdin.buffer, encoding='utf-8') - sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8') - sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', write_through=True, line_buffering=True) - - parser = create_parser() - args = parser.parse_args() - - # read/write files as UTF-8 - args.codes = codecs.open(args.codes.name, encoding='utf-8') - if args.input.name != '': - args.input = codecs.open(args.input.name, encoding='utf-8') - if args.output.name != '': - args.output = codecs.open(args.output.name, 'w', encoding='utf-8') - if args.vocabulary: - args.vocabulary = codecs.open(args.vocabulary.name, encoding='utf-8') - - if args.vocabulary: - vocabulary = read_vocabulary(args.vocabulary, args.vocabulary_threshold) - else: - vocabulary = None - - bpe = BPE(args.codes, args.separator, vocabulary, args.glossaries) - - for line in args.input: - args.output.write(bpe.segment(line).strip()) - args.output.write('\n') diff --git a/tools/average_models.py b/tools/average_models.py deleted file mode 100755 index ce714f92..00000000 --- a/tools/average_models.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -from mammoth.bin.average_models import main - - -if __name__ == "__main__": - main() diff --git a/tools/bpe_pipeline.sh b/tools/bpe_pipeline.sh deleted file mode 100755 index 4c9138b7..00000000 --- a/tools/bpe_pipeline.sh +++ /dev/null @@ -1,162 +0,0 @@ -#!/usr/bin/env bash -# Author : Thamme Gowda -# Created : Nov 06, 2017 - -ONMT="$( cd "$( dirname "${BASH_SOURCE[0]}" )/.." && pwd )" - -#======= EXPERIMENT SETUP ====== -# Activate python environment if needed -source ~/.bashrc -# source activate py3 - -# update these variables -NAME="run1" -OUT="onmt-runs/$NAME" - -DATA="$ONMT/onmt-runs/data" -TRAIN_SRC=$DATA/*train.src -TRAIN_TGT=$DATA/*train.tgt -VALID_SRC=$DATA/*dev.src -VALID_TGT=$DATA/*dev.tgt -TEST_SRC=$DATA/*test.src -TEST_TGT=$DATA/*test.tgt - -BPE="" # default -BPE="src" # src, tgt, src+tgt - -# applicable only when BPE="src" or "src+tgt" -BPE_SRC_OPS=10000 - -# applicable only when BPE="tgt" or "src+tgt" -BPE_TGT_OPS=10000 - -GPUARG="" # default -GPUARG="0" - - -#====== EXPERIMENT BEGIN ====== - -# Check if input exists -for f in $TRAIN_SRC $TRAIN_TGT $VALID_SRC $VALID_TGT $TEST_SRC $TEST_TGT; do - if [[ ! -f "$f" ]]; then - echo "Input File $f doesnt exist. 
Please fix the paths" - exit 1 - fi -done - -function lines_check { - l1=`wc -l $1` - l2=`wc -l $2` - if [[ $l1 != $l2 ]]; then - echo "ERROR: Record counts doesnt match between: $1 and $2" - exit 2 - fi -} -lines_check $TRAIN_SRC $TRAIN_TGT -lines_check $VALID_SRC $VALID_TGT -lines_check $TEST_SRC $TEST_TGT - - -echo "Output dir = $OUT" -[ -d $OUT ] || mkdir -p $OUT -[ -d $OUT/data ] || mkdir -p $OUT/data -[ -d $OUT/models ] || mkdir $OUT/models -[ -d $OUT/test ] || mkdir -p $OUT/test - - -echo "Step 1a: Preprocess inputs" -if [[ "$BPE" == *"src"* ]]; then - echo "BPE on source" - # Here we could use more monolingual data - $ONMT/tools/learn_bpe.py -s $BPE_SRC_OPS < $TRAIN_SRC > $OUT/data/bpe-codes.src - - $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.src < $TRAIN_SRC > $OUT/data/train.src - $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.src < $VALID_SRC > $OUT/data/valid.src - $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.src < $TEST_SRC > $OUT/data/test.src -else - ln -sf $TRAIN_SRC $OUT/data/train.src - ln -sf $VALID_SRC $OUT/data/valid.src - ln -sf $TEST_SRC $OUT/data/test.src -fi - - -if [[ "$BPE" == *"tgt"* ]]; then - echo "BPE on target" - # Here we could use more monolingual data - $ONMT/tools/learn_bpe.py -s $BPE_SRC_OPS < $TRAIN_TGT > $OUT/data/bpe-codes.tgt - - $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.tgt < $TRAIN_TGT > $OUT/data/train.tgt - $ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.tgt < $VALID_TGT > $OUT/data/valid.tgt - #$ONMT/tools/apply_bpe.py -c $OUT/data/bpe-codes.tgt < $TEST_TGT > $OUT/data/test.tgt - # We dont touch the test References, No BPE on them! - ln -sf $TEST_TGT $OUT/data/test.tgt -else - ln -sf $TRAIN_TGT $OUT/data/train.tgt - ln -sf $VALID_TGT $OUT/data/valid.tgt - ln -sf $TEST_TGT $OUT/data/test.tgt -fi - - -#: < maxv) {maxv=score; max=$0}} END{ print max}'` -echo "Chosen Model = $model" -if [[ -z "$model" ]]; then - echo "Model not found. Looked in $OUT/models/" - exit 1 -fi - -GPU_OPTS="" -if [ ! 
-z $GPUARG ]; then - GPU_OPTS="-gpu $GPUARG" -fi - -echo "Step 3a: Translate Test" -python $ONMT/translate.py -model $model \ - -src $OUT/data/test.src \ - -output $OUT/test/test.out \ - -replace_unk -verbose $GPU_OPTS > $OUT/test/test.log - -echo "Step 3b: Translate Dev" -python $ONMT/translate.py -model $model \ - -src $OUT/data/valid.src \ - -output $OUT/test/valid.out \ - -replace_unk -verbose $GPU_OPTS > $OUT/test/valid.log - -if [[ "$BPE" == *"tgt"* ]]; then - echo "BPE decoding/detokenising target to match with references" - mv $OUT/test/test.out{,.bpe} - mv $OUT/test/valid.out{,.bpe} - cat $OUT/test/valid.out.bpe | sed -E 's/(@@ )|(@@ ?$)//g' > $OUT/test/valid.out - cat $OUT/test/test.out.bpe | sed -E 's/(@@ )|(@@ ?$)//g' > $OUT/test/test.out -fi - -echo "Step 4a: Evaluate Test" -$ONMT/tools/multi-bleu-detok.perl $OUT/data/test.tgt < $OUT/test/test.out > $OUT/test/test.tc.bleu -$ONMT/tools/multi-bleu-detok.perl -lc $OUT/data/test.tgt < $OUT/test/test.out > $OUT/test/test.lc.bleu - -echo "Step 4b: Evaluate Dev" -$ONMT/tools/multi-bleu-detok.perl $OUT/data/valid.tgt < $OUT/test/valid.out > $OUT/test/valid.tc.bleu -$ONMT/tools/multi-bleu-detok.perl -lc $OUT/data/valid.tgt < $OUT/test/valid.out > $OUT/test/valid.lc.bleu - -#===== EXPERIMENT END ====== diff --git a/tools/embeddings_to_torch.py b/tools/embeddings_to_torch.py deleted file mode 100755 index 00f2d981..00000000 --- a/tools/embeddings_to_torch.py +++ /dev/null @@ -1,132 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -from __future__ import division -import six -import argparse -import torch -from mammoth.utils.logging import init_logger, logger - - -# FIXME haven't touched that file yet... - - -def get_vocabs(dict_path): - fields = torch.load(dict_path) - - vocs = [] - for side in ['src', 'tgt']: - try: - vocab = fields[side].base_field.vocab - except AttributeError: - vocab = fields[side].vocab - vocs.append(vocab) - enc_vocab, dec_vocab = vocs - - logger.info("From: %s" % dict_path) - logger.info("\t* source vocab: %d words" % len(enc_vocab)) - logger.info("\t* target vocab: %d words" % len(dec_vocab)) - - return enc_vocab, dec_vocab - - -def read_embeddings(file_enc, skip_lines=0, filter_set=None): - embs = dict() - total_vectors_in_file = 0 - with open(file_enc, 'rb') as f: - for i, line in enumerate(f): - if i < skip_lines: - continue - if not line: - break - if len(line) == 0: - # is this reachable? - continue - - l_split = line.decode('utf8').strip().split(' ') - if len(l_split) == 2: - continue - total_vectors_in_file += 1 - if filter_set is not None and l_split[0] not in filter_set: - continue - embs[l_split[0]] = [float(em) for em in l_split[1:]] - return embs, total_vectors_in_file - - -def convert_to_torch_tensor(word_to_float_list_dict, vocab): - dim = len(six.next(six.itervalues(word_to_float_list_dict))) - tensor = torch.zeros((len(vocab), dim)) - for word, values in word_to_float_list_dict.items(): - tensor[vocab.stoi[word]] = torch.Tensor(values) - return tensor - - -def calc_vocab_load_stats(vocab, loaded_embed_dict): - matching_count = len(set(vocab.stoi.keys()) & set(loaded_embed_dict.keys())) - missing_count = len(vocab) - matching_count - percent_matching = matching_count / len(vocab) * 100 - return matching_count, missing_count, percent_matching - - -def main(): - parser = argparse.ArgumentParser(description='embeddings_to_torch.py') - parser.add_argument( - '-emb_file_both', required=False, help="loads Embeddings for both source and target " "from this file." 
- ) - parser.add_argument('-emb_file_enc', required=False, help="source Embeddings from this file") - parser.add_argument('-emb_file_dec', required=False, help="target Embeddings from this file") - parser.add_argument('-output_file', required=True, help="Output file for the prepared data") - parser.add_argument('-dict_file', required=True, help="Dictionary file") - parser.add_argument('-verbose', action="store_true", default=False) - parser.add_argument('-skip_lines', type=int, default=0, help="Skip first lines of the embedding file") - parser.add_argument('-type', choices=["GloVe", "word2vec"], default="GloVe") - opts = parser.parse_args() - - enc_vocab, dec_vocab = get_vocabs(opts.dict_file) - - # Read in embeddings - skip_lines = 1 if opts.type == "word2vec" else opts.skip_lines - if opts.emb_file_both is not None: - if opts.emb_file_enc is not None: - raise ValueError("If --emb_file_both is passed in, you should not" "set --emb_file_enc.") - if opts.emb_file_dec is not None: - raise ValueError("If --emb_file_both is passed in, you should not" "set --emb_file_dec.") - set_of_src_and_tgt_vocab = set(enc_vocab.stoi.keys()) | set(dec_vocab.stoi.keys()) - logger.info("Reading encoder and decoder embeddings from {}".format(opts.emb_file_both)) - src_vectors, total_vec_count = read_embeddings(opts.emb_file_both, skip_lines, set_of_src_and_tgt_vocab) - tgt_vectors = src_vectors - logger.info("\tFound {} total vectors in file".format(total_vec_count)) - else: - if opts.emb_file_enc is None: - raise ValueError( - "If --emb_file_enc not provided. Please specify " - "the file with encoder embeddings, or pass in " - "--emb_file_both" - ) - if opts.emb_file_dec is None: - raise ValueError( - "If --emb_file_dec not provided. Please specify " - "the file with encoder embeddings, or pass in " - "--emb_file_both" - ) - logger.info("Reading encoder embeddings from {}".format(opts.emb_file_enc)) - src_vectors, total_vec_count = read_embeddings(opts.emb_file_enc, skip_lines, filter_set=enc_vocab.stoi) - logger.info("\tFound {} total vectors in file.".format(total_vec_count)) - logger.info("Reading decoder embeddings from {}".format(opts.emb_file_dec)) - tgt_vectors, total_vec_count = read_embeddings(opts.emb_file_dec, skip_lines, filter_set=dec_vocab.stoi) - logger.info("\tFound {} total vectors in file".format(total_vec_count)) - logger.info("After filtering to vectors in vocab:") - logger.info("\t* enc: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(enc_vocab, src_vectors)) - logger.info("\t* dec: %d match, %d missing, (%.2f%%)" % calc_vocab_load_stats(dec_vocab, tgt_vectors)) - - # Write to file - enc_output_file = opts.output_file + ".enc.pt" - dec_output_file = opts.output_file + ".dec.pt" - logger.info("\nSaving embedding as:\n\t* enc: %s\n\t* dec: %s" % (enc_output_file, dec_output_file)) - torch.save(convert_to_torch_tensor(src_vectors, enc_vocab), enc_output_file) - torch.save(convert_to_torch_tensor(tgt_vectors, dec_vocab), dec_output_file) - logger.info("\nDone.") - - -if __name__ == "__main__": - init_logger('embeddings_to_torch.log') - main() diff --git a/tools/extract_embeddings.py b/tools/extract_embeddings.py deleted file mode 100644 index 49ecbca9..00000000 --- a/tools/extract_embeddings.py +++ /dev/null @@ -1,75 +0,0 @@ -import argparse - -import torch - -import mammoth -import mammoth.model_builder - -from mammoth.utils.parse import ArgumentParser -import mammoth.opts - -from mammoth.utils.misc import use_gpu -from mammoth.utils.logging import init_logger, logger - -parser = 
argparse.ArgumentParser(description='translate.py') - -parser.add_argument('-model', required=True, help='Path to model .pt file') -parser.add_argument('-output_dir', default='.', help="""Path to output the embeddings""") -parser.add_argument('-gpu', type=int, default=-1, help="Device to run on") - - -def write_embeddings(filename, dict, embeddings): - with open(filename, 'wb') as file: - for i in range(min(len(embeddings), len(dict.itos))): - str = dict.itos[i].encode("utf-8") - for j in range(len(embeddings[0])): - str = str + (" %5f" % (embeddings[i][j])).encode("utf-8") - file.write(str + b"\n") - - -def main(): - dummy_parser = argparse.ArgumentParser(description='train.py') - mammoth.opts.model_opts(dummy_parser) - dummy_opt = dummy_parser.parse_known_args([])[0] - opts = parser.parse_args() - opts.cuda = opts.gpu > -1 - if opts.cuda: - torch.cuda.set_device(opts.gpu) - - # Add in default model arguments, possibly added since training. - checkpoint = torch.load(opts.model, map_location=lambda storage, loc: storage) - model_opts = checkpoint['opts'] - - fields = checkpoint['vocab'] - src_dict = fields['src'].base_field.vocab # assumes src is text - tgt_dict = fields['tgt'].base_field.vocab - - model_opts = checkpoint['opts'] - for arg in dummy_opt.__dict__: - if arg not in model_opts: - model_opts.__dict__[arg] = dummy_opt.__dict__[arg] - - # build_base_model expects updated and validated opts - ArgumentParser.update_model_opts(model_opts) - ArgumentParser.validate_model_opts(model_opts) - - model = mammoth.model_builder.build_base_model(model_opts, fields, use_gpu(opts), checkpoint) - encoder = model.encoder # no encoder for LM task - decoder = model.decoder - - encoder_embeddings = encoder.embeddings.word_lut.weight.data.tolist() - decoder_embeddings = decoder.embeddings.word_lut.weight.data.tolist() - - logger.info("Writing source embeddings") - write_embeddings(opts.output_dir + "/src_embeddings.txt", src_dict, encoder_embeddings) - - logger.info("Writing target embeddings") - write_embeddings(opts.output_dir + "/tgt_embeddings.txt", tgt_dict, decoder_embeddings) - - logger.info('... done.') - logger.info('Converting model...') - - -if __name__ == "__main__": - init_logger('extract_embeddings.log') - main() diff --git a/tools/extract_vocabulary.py b/tools/extract_vocabulary.py deleted file mode 100644 index 3c062e54..00000000 --- a/tools/extract_vocabulary.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -import argparse -import sys - - -def read_files_batch(file_list): - """Reads the provided files in batches""" - batch = [] # Keep batch for each file - fd_list = [] # File descriptor list - - exit = False # Flag used for quitting the program in case of error - try: - for filename in file_list: - fd_list.append(open(filename)) - - for lines in zip(*fd_list): - for i, line in enumerate(lines): - line = line.rstrip("\n").split(" ") - batch.append(line) - - yield batch - batch = [] # Reset batch - - except IOError: - print("Error reading file " + filename + ".") - exit = True # Flag to exit the program - - finally: - for fd in fd_list: - fd.close() - - if exit: # An error occurred, end execution - sys.exit(-1) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument( - '-file_type', - default='text', - choices=['text', 'field'], - required=True, - help="""Options for vocabulary extraction. - The default is 'text' where the user passes - a corpus or a list of corpora files for which - they want to create a vocabulary from. 
- If choosing the option 'field', we assume - the file passed is a torch file created during - the preprocessing stage of an already - preprocessed corpus. The vocabulary file created - will just be the vocabulary inside the field - corresponding to the argument 'side'.""", - ) - parser.add_argument("-file", type=str, nargs="+", required=True) - parser.add_argument("-out_file", type=str, required=True) - parser.add_argument( - "-side", - choices=['src', 'tgt'], - help="""Specifies 'src' or 'tgt' side for 'field' file_type.""", - ) - - opts = parser.parse_args() - - vocabulary = {} - if opts.file_type == 'text': - print("Reading input file...") - for batch in read_files_batch(opts.file): - for sentence in batch: - for w in sentence: - if w in vocabulary: - vocabulary[w] += 1 - else: - vocabulary[w] = 1 - - print("Writing vocabulary file...") - with open(opts.out_file, "w") as f: - for w, count in sorted(vocabulary.items(), key=lambda x: x[1], reverse=True): - f.write("{0}\n".format(w)) - else: - if opts.side not in ['src', 'tgt']: - raise ValueError("If using -file_type='field', specifies 'src' or 'tgt' argument for -side.") - import torch - - print("Reading input file...") - if not len(opts.file) == 1: - raise ValueError("If using -file_type='field', only pass one argument for -file.") - vocabs = torch.load(opts.file[0]) - voc = dict(vocabs)[opts.side] - - try: - word_list = voc[0][1].base_field.vocab.itos - except AttributeError: - word_list = voc[0][1].vocab.itos - - print("Writing vocabulary file...") - with open(opts.out_file, "wb") as f: - for w in word_list: - f.write(u"{0}\n".format(w).encode("utf-8")) - - -if __name__ == "__main__": - main() diff --git a/tools/learn_bpe.py b/tools/learn_bpe.py deleted file mode 100755 index 0dc1c97f..00000000 --- a/tools/learn_bpe.py +++ /dev/null @@ -1,278 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# Author: Rico Sennrich -# flake8: noqa - -"""Use byte pair encoding (BPE) to learn a variable-length encoding of the vocabulary in a text. -Unlike the original BPE, it does not compress the plain text, but can be used to reduce the vocabulary -of a text to a configurable number of symbols, with only a small increase in the number of tokens. - -Reference: -Rico Sennrich, Barry Haddow and Alexandra Birch (2016). Neural Machine Translation of Rare Words with Subword Units. -Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (ACL 2016). Berlin, Germany. 
-""" -# This file is retrieved from https://github.com/rsennrich/subword-nmt - -from __future__ import unicode_literals - -import sys -import codecs -import re -import copy -import argparse -from collections import defaultdict, Counter - -# hack for python2/3 compatibility -from io import open - -argparse.open = open - - -def create_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter, description="learn BPE-based word segmentation" - ) - - parser.add_argument( - '--input', - '-i', - type=argparse.FileType('r'), - default=sys.stdin, - metavar='PATH', - help="Input text (default: standard input).", - ) - - parser.add_argument( - '--output', - '-o', - type=argparse.FileType('w'), - default=sys.stdout, - metavar='PATH', - help="Output file for BPE codes (default: standard output)", - ) - parser.add_argument( - '--symbols', - '-s', - type=int, - default=10000, - help="Create this many new symbols (each representing a character n-gram) (default: %(default)s))", - ) - parser.add_argument( - '--min-frequency', - type=int, - default=2, - metavar='FREQ', - help='Stop if no symbol pair has frequency >= FREQ (default: %(default)s))', - ) - parser.add_argument( - '--dict-input', - action="store_true", - help="If set, input file is interpreted as a dictionary where each line contains a word-count pair", - ) - parser.add_argument('--verbose', '-v', action="store_true", help="verbose mode.") - - return parser - - -def get_vocabulary(fobj, is_dict=False): - """Read text and return dictionary that encodes vocabulary""" - vocab = Counter() - for line in fobj: - if is_dict: - word, count = line.strip().split() - vocab[word] = int(count) - else: - for word in line.split(): - vocab[word] += 1 - return vocab - - -def update_pair_statistics(pair, changed, stats, indices): - """Minimally update the indices and frequency of symbol pairs - - if we merge a pair of symbols, only pairs that overlap with occurrences - of this pair are affected, and need to be updated. - """ - stats[pair] = 0 - indices[pair] = defaultdict(int) - first, second = pair - new_pair = first + second - for j, word, old_word, freq in changed: - - # find all instances of pair, and update frequency/indices around it - i = 0 - while True: - # find first symbol - try: - i = old_word.index(first, i) - except ValueError: - break - # if first symbol is followed by second symbol, we've found an occurrence of pair (old_word[i:i+2]) - if i < len(old_word) - 1 and old_word[i + 1] == second: - # assuming a symbol sequence "A B C", if "B C" is merged, reduce the frequency of "A B" - if i: - prev = old_word[i - 1: i + 1] - stats[prev] -= freq - indices[prev][j] -= 1 - if i < len(old_word) - 2: - # assuming a symbol sequence "A B C B", if "B C" is merged, reduce the frequency of "C B". 
- # however, skip this if the sequence is A B C B C, because the frequency - # of "C B" will be reduced by the previous code block - if old_word[i + 2] != first or i >= len(old_word) - 3 or old_word[i + 3] != second: - nex = old_word[i + 1: i + 3] - stats[nex] -= freq - indices[nex][j] -= 1 - i += 2 - else: - i += 1 - - i = 0 - while True: - try: - # find new pair - i = word.index(new_pair, i) - except ValueError: - break - # assuming a symbol sequence "A BC D", if "B C" is merged, increase the frequency of "A BC" - if i: - prev = word[i - 1: i + 1] - stats[prev] += freq - indices[prev][j] += 1 - # assuming a symbol sequence "A BC B", if "B C" is merged, increase the frequency of "BC B" - # however, if the sequence is A BC BC, skip this step because the count of - # "BC BC" will be incremented by the previous code block - if i < len(word) - 1 and word[i + 1] != new_pair: - nex = word[i: i + 2] - stats[nex] += freq - indices[nex][j] += 1 - i += 1 - - -def get_pair_statistics(vocab): - """Count frequency of all symbol pairs, and create index""" - - # data structure of pair frequencies - stats = defaultdict(int) - - # index from pairs to words - indices = defaultdict(lambda: defaultdict(int)) - - for i, (word, freq) in enumerate(vocab): - prev_char = word[0] - for char in word[1:]: - stats[prev_char, char] += freq - indices[prev_char, char][i] += 1 - prev_char = char - - return stats, indices - - -def replace_pair(pair, vocab, indices): - """Replace all occurrences of a symbol pair ('A', 'B') with a new symbol 'AB'""" - first, second = pair - pair_str = ''.join(pair) - pair_str = pair_str.replace('\\', '\\\\') - changes = [] - pattern = re.compile(r'(?'); - # version numbering allows bckward compatibility - outfile.write('#version: 0.2\n') - - vocab = get_vocabulary(infile, is_dict) - vocab = dict([(tuple(x[:-1]) + (x[-1] + '',), y) for (x, y) in vocab.items()]) - sorted_vocab = sorted(vocab.items(), key=lambda x: x[1], reverse=True) - - stats, indices = get_pair_statistics(sorted_vocab) - big_stats = copy.deepcopy(stats) - # threshold is inspired by Zipfian assumption, but should only affect speed - threshold = max(stats.values()) / 10 - for i in range(num_symbols): - if stats: - most_frequent = max(stats, key=lambda x: (stats[x], x)) - - # we probably missed the best pair because of pruning; go back to full statistics - if not stats or (i and stats[most_frequent] < threshold): - prune_stats(stats, big_stats, threshold) - stats = copy.deepcopy(big_stats) - most_frequent = max(stats, key=lambda x: (stats[x], x)) - # threshold is inspired by Zipfian assumption, but should only affect speed - threshold = stats[most_frequent] * i / (i + 10000.0) - prune_stats(stats, big_stats, threshold) - - if stats[most_frequent] < min_frequency: - sys.stderr.write('no pair has frequency >= {0}. 
Stopping\n'.format(min_frequency)) - break - - if verbose: - sys.stderr.write( - 'pair {0}: {1} {2} -> {1}{2} (frequency {3})\n'.format( - i, most_frequent[0], most_frequent[1], stats[most_frequent] - ) - ) - outfile.write('{0} {1}\n'.format(*most_frequent)) - changes = replace_pair(most_frequent, sorted_vocab, indices) - update_pair_statistics(most_frequent, changes, stats, indices) - stats[most_frequent] = 0 - if not i % 100: - prune_stats(stats, big_stats, threshold) - - -if __name__ == '__main__': - - # python 2/3 compatibility - if sys.version_info < (3, 0): - sys.stderr = codecs.getwriter('UTF-8')(sys.stderr) - sys.stdout = codecs.getwriter('UTF-8')(sys.stdout) - sys.stdin = codecs.getreader('UTF-8')(sys.stdin) - else: - sys.stderr = codecs.getwriter('UTF-8')(sys.stderr.buffer) - sys.stdout = codecs.getwriter('UTF-8')(sys.stdout.buffer) - sys.stdin = codecs.getreader('UTF-8')(sys.stdin.buffer) - - parser = create_parser() - args = parser.parse_args() - - # read/write files as UTF-8 - if args.input.name != '': - args.input = codecs.open(args.input.name, encoding='utf-8') - if args.output.name != '': - args.output = codecs.open(args.output.name, 'w', encoding='utf-8') - - main(args.input, args.output, args.symbols, args.min_frequency, args.verbose, is_dict=args.dict_input) diff --git a/tools/multi-bleu-detok.perl b/tools/multi-bleu-detok.perl deleted file mode 100755 index 9d8edd73..00000000 --- a/tools/multi-bleu-detok.perl +++ /dev/null @@ -1,212 +0,0 @@ -#!/usr/bin/env perl -# -# This file is part of moses. Its use is licensed under the GNU Lesser General -# Public License version 2.1 or, at your option, any later version. - -# This file uses the internal tokenization of mteval-v13a.pl, -# giving the exact same (case-sensitive) results on untokenized text. -# Using this script with detokenized output and untokenized references is -# preferrable over multi-bleu.perl, since scores aren't affected by tokenization differences. -# -# like multi-bleu.perl , it supports plain text input and multiple references. 
- -# This file is retrieved from Moses Decoder :: https://github.com/moses-smt/mosesdecoder -# $Id$ -use warnings; -use strict; - -my $lowercase = 0; -if ($ARGV[0] eq "-lc") { - $lowercase = 1; - shift; -} - -my $stem = $ARGV[0]; -if (!defined $stem) { - print STDERR "usage: multi-bleu-detok.pl [-lc] reference < hypothesis\n"; - print STDERR "Reads the references from reference or reference0, reference1, ...\n"; - exit(1); -} - -$stem .= ".ref" if !-e $stem && !-e $stem."0" && -e $stem.".ref0"; - -my @REF; -my $ref=0; -while(-e "$stem$ref") { - &add_to_ref("$stem$ref",\@REF); - $ref++; -} -&add_to_ref($stem,\@REF) if -e $stem; -die("ERROR: could not find reference file $stem") unless scalar @REF; - -# add additional references explicitly specified on the command line -shift; -foreach my $stem (@ARGV) { - &add_to_ref($stem,\@REF) if -e $stem; -} - - - -sub add_to_ref { - my ($file,$REF) = @_; - my $s=0; - if ($file =~ /.gz$/) { - open(REF,"gzip -dc $file|") or die "Can't read $file"; - } else { - open(REF,$file) or die "Can't read $file"; - } - while() { - chop; - $_ = tokenization($_); - push @{$$REF[$s++]}, $_; - } - close(REF); -} - -my(@CORRECT,@TOTAL,$length_translation,$length_reference); -my $s=0; -while() { - chop; - $_ = lc if $lowercase; - $_ = tokenization($_); - my @WORD = split; - my %REF_NGRAM = (); - my $length_translation_this_sentence = scalar(@WORD); - my ($closest_diff,$closest_length) = (9999,9999); - foreach my $reference (@{$REF[$s]}) { -# print "$s $_ <=> $reference\n"; - $reference = lc($reference) if $lowercase; - my @WORD = split(' ',$reference); - my $length = scalar(@WORD); - my $diff = abs($length_translation_this_sentence-$length); - if ($diff < $closest_diff) { - $closest_diff = $diff; - $closest_length = $length; - # print STDERR "$s: closest diff ".abs($length_translation_this_sentence-$length)." = abs($length_translation_this_sentence-$length), setting len: $closest_length\n"; - } elsif ($diff == $closest_diff) { - $closest_length = $length if $length < $closest_length; - # from two references with the same closeness to me - # take the *shorter* into account, not the "first" one. - } - for(my $n=1;$n<=4;$n++) { - my %REF_NGRAM_N = (); - for(my $start=0;$start<=$#WORD-($n-1);$start++) { - my $ngram = "$n"; - for(my $w=0;$w<$n;$w++) { - $ngram .= " ".$WORD[$start+$w]; - } - $REF_NGRAM_N{$ngram}++; - } - foreach my $ngram (keys %REF_NGRAM_N) { - if (!defined($REF_NGRAM{$ngram}) || - $REF_NGRAM{$ngram} < $REF_NGRAM_N{$ngram}) { - $REF_NGRAM{$ngram} = $REF_NGRAM_N{$ngram}; -# print "$i: REF_NGRAM{$ngram} = $REF_NGRAM{$ngram}
\n"; - } - } - } - } - $length_translation += $length_translation_this_sentence; - $length_reference += $closest_length; - for(my $n=1;$n<=4;$n++) { - my %T_NGRAM = (); - for(my $start=0;$start<=$#WORD-($n-1);$start++) { - my $ngram = "$n"; - for(my $w=0;$w<$n;$w++) { - $ngram .= " ".$WORD[$start+$w]; - } - $T_NGRAM{$ngram}++; - } - foreach my $ngram (keys %T_NGRAM) { - $ngram =~ /^(\d+) /; - my $n = $1; - # my $corr = 0; -# print "$i e $ngram $T_NGRAM{$ngram}
\n"; - $TOTAL[$n] += $T_NGRAM{$ngram}; - if (defined($REF_NGRAM{$ngram})) { - if ($REF_NGRAM{$ngram} >= $T_NGRAM{$ngram}) { - $CORRECT[$n] += $T_NGRAM{$ngram}; - # $corr = $T_NGRAM{$ngram}; -# print "$i e correct1 $T_NGRAM{$ngram}
\n"; - } - else { - $CORRECT[$n] += $REF_NGRAM{$ngram}; - # $corr = $REF_NGRAM{$ngram}; -# print "$i e correct2 $REF_NGRAM{$ngram}
\n"; - } - } - # $REF_NGRAM{$ngram} = 0 if !defined $REF_NGRAM{$ngram}; - # print STDERR "$ngram: {$s, $REF_NGRAM{$ngram}, $T_NGRAM{$ngram}, $corr}\n" - } - } - $s++; -} -my $brevity_penalty = 1; -my $bleu = 0; - -my @bleu=(); - -for(my $n=1;$n<=4;$n++) { - if (defined ($TOTAL[$n])){ - $bleu[$n]=($TOTAL[$n])?$CORRECT[$n]/$TOTAL[$n]:0; - # print STDERR "CORRECT[$n]:$CORRECT[$n] TOTAL[$n]:$TOTAL[$n]\n"; - }else{ - $bleu[$n]=0; - } -} - -if ($length_reference==0){ - printf "BLEU = 0, 0/0/0/0 (BP=0, ratio=0, hyp_len=0, ref_len=0)\n"; - exit(1); -} - -if ($length_translation<$length_reference) { - $brevity_penalty = exp(1-$length_reference/$length_translation); -} -$bleu = $brevity_penalty * exp((my_log( $bleu[1] ) + - my_log( $bleu[2] ) + - my_log( $bleu[3] ) + - my_log( $bleu[4] ) ) / 4) ; -printf "BLEU = %.2f, %.1f/%.1f/%.1f/%.1f (BP=%.3f, ratio=%.3f, hyp_len=%d, ref_len=%d)\n", - 100*$bleu, - 100*$bleu[1], - 100*$bleu[2], - 100*$bleu[3], - 100*$bleu[4], - $brevity_penalty, - $length_translation / $length_reference, - $length_translation, - $length_reference; - -sub my_log { - return -9999999999 unless $_[0]; - return log($_[0]); -} - - - -sub tokenization -{ - my ($norm_text) = @_; - -# language-independent part: - $norm_text =~ s///g; # strip "skipped" tags - $norm_text =~ s/-\n//g; # strip end-of-line hyphenation and join lines - $norm_text =~ s/\n/ /g; # join lines - $norm_text =~ s/"/"/g; # convert SGML tag for quote to " - $norm_text =~ s/&/&/g; # convert SGML tag for ampersand to & - $norm_text =~ s/</ - $norm_text =~ s/>/>/g; # convert SGML tag for greater-than to < - -# language-dependent part (assuming Western languages): - $norm_text = " $norm_text "; - $norm_text =~ s/([\{-\~\[-\` -\&\(-\+\:-\@\/])/ $1 /g; # tokenize punctuation - $norm_text =~ s/([^0-9])([\.,])/$1 $2 /g; # tokenize period and comma unless preceded by a digit - $norm_text =~ s/([\.,])([^0-9])/ $1 $2/g; # tokenize period and comma unless followed by a digit - $norm_text =~ s/([0-9])(-)/$1 $2 /g; # tokenize dash when preceded by a digit - $norm_text =~ s/\s+/ /g; # one space only between words - $norm_text =~ s/^\s+//; # no leading space - $norm_text =~ s/\s+$//; # no trailing space - - return $norm_text; -} diff --git a/tools/release_model.py b/tools/release_model.py deleted file mode 100644 index dd437517..00000000 --- a/tools/release_model.py +++ /dev/null @@ -1,6 +0,0 @@ -#!/usr/bin/env python -from mammoth.bin.release_model import main - - -if __name__ == "__main__": - main() diff --git a/tools/spm_to_vocab.py b/tools/spm_to_vocab.py deleted file mode 100644 index f2371727..00000000 --- a/tools/spm_to_vocab.py +++ /dev/null @@ -1,23 +0,0 @@ -# converts a SentencePiece vocabulary to the format expected by dynamic data -# (essentially converts float expected counts to "fixed precision" int pseudo -# counts) -import sys -import math -from mammoth.constants import DefaultTokens - -OMIT = (DefaultTokens.UNK, DefaultTokens.BOS, DefaultTokens.EOS) - - -def convert(lines): - for line in lines: - w, c = line.rstrip('\n').split(None, 1) - if w in OMIT: - continue - c = math.exp(float(c)) * 1000000 - c = int(c) + 1 - yield w, c - - -if __name__ == '__main__': - for c, w in convert(sys.stdin): - print('{}\t{}'.format(c, w)) diff --git a/tools/test_rouge.py b/tools/test_rouge.py deleted file mode 100644 index 12ccc35d..00000000 --- a/tools/test_rouge.py +++ /dev/null @@ -1,72 +0,0 @@ -# -*- encoding: utf-8 -*- -import argparse -import os -import time -import pyrouge -import shutil -import sys -import codecs - -from 
mammoth.utils.logging import init_logger, logger - - -def eval_rouge(cand, ref): - """Calculate ROUGE scores of sequences passed as an iterator - e.g. a list of str, an open file, StringIO or even sys.stdin - """ - current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) - tmp_dir = ".rouge-tmp-{}".format(current_time) - try: - if not os.path.isdir(tmp_dir): - os.mkdir(tmp_dir) - os.mkdir(tmp_dir + "/candidate") - os.mkdir(tmp_dir + "/reference") - candidates = [line.strip() for line in cand] - references = [line.strip() for line in ref] - assert len(candidates) == len(references) - cnt = len(candidates) - for i in range(cnt): - if len(references[i]) < 1: - continue - with open(tmp_dir + "/candidate/cand.{}.txt".format(i), "w", encoding="utf-8") as f: - f.write(candidates[i]) - with open(tmp_dir + "/reference/ref.{}.txt".format(i), "w", encoding="utf-8") as f: - f.write(references[i]) - r = pyrouge.Rouge155() - r.model_dir = tmp_dir + "/reference/" - r.system_dir = tmp_dir + "/candidate/" - r.model_filename_pattern = 'ref.#ID#.txt' - r.system_filename_pattern = r'cand.(\d+).txt' - rouge_results = r.convert_and_evaluate() - results_dict = r.output_to_dict(rouge_results) - return results_dict - finally: - pass - if os.path.isdir(tmp_dir): - shutil.rmtree(tmp_dir) - - -def rouge_results_to_str(results_dict): - return ">> ROUGE(1/2/3/L/SU4): {:.2f}/{:.2f}/{:.2f}/{:.2f}/{:.2f}".format( - results_dict["rouge_1_f_score"] * 100, - results_dict["rouge_2_f_score"] * 100, - results_dict["rouge_3_f_score"] * 100, - results_dict["rouge_l_f_score"] * 100, - results_dict["rouge_su*_f_score"] * 100, - ) - - -if __name__ == "__main__": - init_logger('test_rouge.log') - parser = argparse.ArgumentParser() - parser.add_argument('-c', type=str, default="candidate.txt", help='candidate file') - parser.add_argument('-r', type=str, default="reference.txt", help='reference file') - args = parser.parse_args() - if args.c.upper() == "STDIN": - candidates = sys.stdin - else: - candidates = codecs.open(args.c, encoding="utf-8") - references = codecs.open(args.r, encoding="utf-8") - - results_dict = eval_rouge(candidates, references) - logger.info(rouge_results_to_str(results_dict))