Readme with instructions to generate and evaluate with a 12B model (facebookresearch#1351)

Summary: Pull Request resolved: fairinternal/fairseq-py#1351

Reviewed By: edunov

Differential Revision: D24386349

Pulled By: huihuifan

fbshipit-source-id: ade362d7cb64e24e6b2689ba87c53636073d2246
shruti-bh authored and facebook-github-bot committed Oct 19, 2020
1 parent a48f235 commit 65e11a3
Showing 5 changed files with 401 additions and 11 deletions.
213 changes: 202 additions & 11 deletions examples/m2m_100/README.md
@@ -1,18 +1,209 @@
# MMMT Tokenizer
# Beyond English-Centric Multilingual Machine Translation

We apply different tokenization strategies for different languages, following the existing literature. Here we provide tok.sh, a tokenizer that can be used to reproduce our results.
## Introduction
In this work, we create a true Many-to-Many multilingual translation model that can translate directly between any pair of 100 languages. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly translating between non-English directions while performing competitively with the best single systems of WMT.

To reproduce the results, follow these steps:
If you are new to using fairseq, read the following walkthrough. Otherwise, skip to the sections below.

0. **Generation Data**

To download the generation data, run the commands below. Note that all datasets need to be detokenized *before* applying SPM in the data preprocessing step. If you use these evaluation datasets, please cite their associated papers.
```bash
# WMT - use sacrebleu, example here:
sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr
sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en

# WAT
wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2019.my-en.zip
unzip wat2019.my-en.zip

# FLORES
# download from: https://github.com/facebookresearch/flores

# TED - need to detokenize with Moses!
# from: https://github.com/neulab/word-embeddings-for-nmt
wget http://phontron.com/data/ted_talks.tar.gz

# Autshumato
# request to download: https://repo.sadilar.org/handle/20.500.12185/397

# Tatoeba Challenge
# available here: https://github.com/Helsinki-NLP/Tatoeba-Challenge
```
tgt_lang=...
reference_translation=...
cat generation_output | grep -P "^H" |sort -V |cut -f 3- |sh tok.sh $tgt_lang > hyp
cat $reference_translation |sh tok.sh $tgt_lang > ref
sacrebleu -tok 'none' ref < hyp

1. **Training Data**

To produce the training data, we use a combination of [CCMatrix](https://arxiv.org/abs/1911.04944) and [CCAligned](https://arxiv.org/abs/1911.06154). Check out the instructions [here](https://github.com/facebookresearch/LASER/tree/master/tasks/CCMatrix) to download the raw data.

2. **Preprocess Data**

After downloading the raw data, postprocess it, apply SPM, and then binarize. It is important to run the postprocessing scripts, because they remove any instance of the evaluation data from the mined training data.

```bash
# preprocess data

# remove sentences with more than 50% punctuation
python /path/to/fairseq/examples/m2m_100/process_data/remove_too_much_punc.py

# deduplicate training data
paste /path/to/datadir/train.$src /path/to/datadir/train.$tgt | awk '!x[$0]++' > /path/to/datadir/train.dedup
echo "keeping $(wc -l /path/to/datadir/train.dedup) bitext out of $(wc -l /path/to/datadir/train.$src)"
cut -f1 /path/to/datadir/train.dedup > /path/to/datadir/train.$src
cut -f2 /path/to/datadir/train.dedup > /path/to/datadir/train.$tgt

# remove all instances of evaluation data from the training data
python /path/to/fairseq/examples/m2m_100/process_data/dedup_data.py

# frequency cleaning
wget https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz
tar -xvzf histograms.tar.gz
python /path/to/fairseq/examples/m2m_100/process_data/clean_histogram.py --src $src --tgt $tgt --src-file /path/to/source/file --tgt-file /path/to/target/file --src-output-file source_output.$src --tgt-output-file target_output.$tgt --histograms /path/to/histograms

# apply SPM
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
python /path/to/fairseq/scripts/spm_encode.py \
--model spm.128k.model \
--output_format=piece \
--inputs=/path/to/input/file/here \
--outputs=/path/to/output/file/here

# length ratio cleaning
perl mosesdecoder/scripts/training/clean-corpus-n.perl --ratio 3 /path/to/training/data/train.spm.$src-$tgt $src $tgt /path/to/output/directory/train.spm.$src-$tgt 1 250

# binarize data
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
fairseq-preprocess \
--source-lang $src --target-lang $tgt \
--testpref spm.$src.$tgt \
--thresholdsrc 0 --thresholdtgt 0 \
--destdir data_bin \
--srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
```

# Installation
3. **Training Scripts**

To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/master/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale).

4. **Generation**

To generate from our models, follow the commands in the generation section below.
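
The deduplication and length-ratio steps from the preprocessing walkthrough can be sketched in Python. This is only an illustration of the logic, not the scripts used for the paper: `dedup_pairs` and `length_ok` are hypothetical helper names, tokens are assumed to be whitespace-separated (as they are after SPM encoding), and the length check approximates what `clean-corpus-n.perl --ratio 3 ... 1 250` does rather than reproducing the Moses script exactly.

```python
def dedup_pairs(src_lines, tgt_lines):
    """Keep the first occurrence of each (source, target) pair, preserving order.

    Mirrors `paste src tgt | awk '!x[$0]++'`: a pair is dropped only when
    both sides together have been seen before.
    """
    seen = set()
    kept = []
    for pair in zip(src_lines, tgt_lines):
        if pair not in seen:
            seen.add(pair)
            kept.append(pair)
    return kept


def length_ok(src, tgt, min_len=1, max_len=250, max_ratio=3.0):
    """Approximate length and length-ratio check on whitespace-separated tokens."""
    n_src, n_tgt = len(src.split()), len(tgt.split())
    if not (min_len <= n_src <= max_len and min_len <= n_tgt <= max_len):
        return False
    if min(n_src, n_tgt) == 0:
        # Only reachable when min_len is 0; avoid dividing by zero.
        return False
    return max(n_src, n_tgt) / min(n_src, n_tgt) <= max_ratio


if __name__ == "__main__":
    pairs = dedup_pairs(["a b", "a b", "c"], ["x y", "x y", "z"])
    print([p for p in pairs if length_ok(*p)])
```

Working on pairs rather than on each side separately keeps the source and target files aligned, which is why the shell version pastes the two files together before deduplicating.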


If you use any of the resources listed here, please cite:
```bibtex
@article{fan2020beyond,
title={Beyond English-Centric Multilingual Machine Translation},
author={Fan, Angela and Bhosale, Shruti and Schwenk, Holger and Ma, Zhiyi and El-Kishky, Ahmed and Goyal, Siddharth and Baines, Mandeep and Celebi, Onur and Wenzek, Guillaume and Chaudhary, Vishrav and Goyal, Naman and Birch, Tom and Liptchinsky, Vitaliy and Edunov, Sergey and Grave, Edouard and Auli, Michael and Joulin, Armand},
journal={arXiv preprint},
year={2020}
}
@article{schwenk2019ccmatrix,
title={Ccmatrix: Mining billions of high-quality parallel sentences on the web},
author={Schwenk, Holger and Wenzek, Guillaume and Edunov, Sergey and Grave, Edouard and Joulin, Armand},
journal={arXiv preprint arXiv:1911.04944},
year={2019}
}
@article{el2019massive,
title={A Massive Collection of Cross-Lingual Web-Document Pairs},
author={El-Kishky, Ahmed and Chaudhary, Vishrav and Guzman, Francisco and Koehn, Philipp},
journal={arXiv preprint arXiv:1911.06154},
year={2019}
}
```

Tools needed for all the languages except Arabic can be installed by running install_dependencies.sh.
If you want to evaluate Arabic models, please follow the installation instructions provided here: http://alt.qcri.org/tools/arabic-normalizer/

## Trained Models

Looking for other trained models? Check back soon.

Model | Description | Download
---|---|---
`12b_last_checkpoint` | 12B parameter model trained on many-to-many training data for 100 languages | [12b_last_checkpoint](https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt)


## SentencePiece Model

```bash
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
```

## Generation with M2M-100

### Encode using our SentencePiece Model

Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)

```bash
fairseq=/path/to/fairseq
cd $fairseq
sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
for lang in de fr ; do
python scripts/spm_encode.py \
--model spm.128k.model \
--output_format=piece \
--inputs=raw_input.de-fr.${lang} \
--outputs=spm.de-fr.${lang}
done
```

### Binarization

```bash
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
fairseq-preprocess \
--source-lang de --target-lang fr \
--testpref spm.de-fr \
--thresholdsrc 0 --thresholdtgt 0 \
--destdir data_bin \
--srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
```

### Generation on a V100 GPU

```bash
wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt
wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs.txt
wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt
fairseq-generate \
data_bin \
--batch-size 1 \
--path 12b_last_checkpoint.pt \
--fixed-dictionary model_dict.128k.txt \
-s de -t fr \
--remove-bpe 'sentencepiece' \
--beam 5 \
--task translation_multi_simple_epoch \
--lang-pairs language_pairs.txt \
--decoder-langtok --encoder-langtok src \
--gen-subset test \
--fp16 \
--dataset-impl mmap \
--distributed-world-size 1 --distributed-no-spawn \
--pipeline-model-parallel \
--pipeline-chunks 1 \
--pipeline-encoder-balance '[26]' \
--pipeline-encoder-devices '[0]' \
--pipeline-decoder-balance '[1,24,1]' \
--pipeline-decoder-devices '[0,1,0]' > gen_out
```
## Evaluation with M2M-100

### Tokenization

Note: Refer to tokenizers/README.md for more details on tokenization.

```bash
cd ${fairseq}/examples/m2m_100
cat ${fairseq}/gen_out | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh fr > hyp
cat ${fairseq}/raw_input.de-fr.fr | sh tok.sh fr > ref
```

### BLEU

```bash
sacrebleu -tok 'none' ref < hyp
```
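
For readers who prefer to post-process the generation output in Python, the `grep -P "^H" | sort -V | cut -f 3-` part of the tokenization command can be sketched as below. It assumes hypothesis lines have the tab-separated form `H-<id>`, score, text; `extract_hypotheses` is a hypothetical helper name, and the result would still need to be passed through `tok.sh` before scoring.

```python
def extract_hypotheses(lines):
    """Return hypothesis texts from generation output, ordered by sentence id."""
    hyps = []
    for line in lines:
        if not line.startswith("H-"):
            continue
        fields = line.rstrip("\n").split("\t")
        sent_id = int(fields[0][2:])  # "H-12" -> 12
        # Everything from the third field onward is the hypothesis text.
        hyps.append((sent_id, "\t".join(fields[2:])))
    return [text for _, text in sorted(hyps)]


if __name__ == "__main__":
    sample = [
        "S-1\tsource sentence\n",
        "H-10\t-0.2\tdix\n",
        "H-2\t-0.5\tdeux\n",
        "H-1\t-0.1\tun\n",
    ]
    for hyp in extract_hypotheses(sample):
        print(hyp)
```

Sorting by the numeric id matters because generation output is not necessarily in the original sentence order, and the hypotheses must line up with the reference file.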
52 changes: 52 additions & 0 deletions examples/m2m_100/process_data/clean_histogram.py
@@ -0,0 +1,52 @@
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--src', type=str, help='Source language')
parser.add_argument('--tgt', type=str, help='Target language')
parser.add_argument('--src-file', type=str, help='Input source file')
parser.add_argument('--tgt-file', type=str, help='Input target file')
parser.add_argument('--src-output-file', type=str, help='Output source file')
parser.add_argument('--tgt-output-file', type=str, help='Output target file')
parser.add_argument('--threshold', type=float, default=0.5, help='Threshold')
parser.add_argument('--threshold-character', type=str, default=']', help='Threshold character')
parser.add_argument('--histograms', type=str, help='Path to histograms')

args = parser.parse_args()


def read_hist(f):
    """Read the leading character of each histogram line, up to the threshold character."""
    ch = []
    for line in f:
        c = line[0]
        if c == args.threshold_character:
            break
        ch.append(c)
    return ch


with open("{}/{}".format(args.histograms, args.src), 'r', encoding='utf8') as f:
    ch1 = read_hist(f)

with open("{}/{}".format(args.histograms, args.tgt), 'r', encoding='utf8') as f:
    ch2 = read_hist(f)

print("Accepted characters for {}: {}".format(args.src, ch1))
print("Accepted characters for {}: {}".format(args.tgt, ch2))

with open(args.src_file, 'r', encoding='utf8') as fs1, open(args.tgt_file, 'r', encoding='utf8') as fs2, open(args.src_output_file, 'w', encoding='utf8') as fos1, open(args.tgt_output_file, 'w', encoding='utf8') as fos2:
    ls1 = fs1.readline()
    ls2 = fs2.readline()

    # Stop at the end of the shorter file. readline() returns '' at EOF,
    # and dividing by len('') below would raise ZeroDivisionError.
    while ls1 and ls2:
        # Count the characters that appear in the accepted set for each language.
        cnt1 = len([c for c in ls1.strip() if c in ch1])
        cnt2 = len([c for c in ls2.strip() if c in ch2])

        # Keep the pair only if both sides are mostly accepted characters.
        if cnt1 / len(ls1) > args.threshold and cnt2 / len(ls2) > args.threshold:
            fos1.write(ls1)
            fos2.write(ls2)
        else:
            print("{} {} {} \n{} {} {}".format(args.src, cnt1 / len(ls1), ls1.strip(), args.tgt, cnt2 / len(ls2), ls2.strip()))

        ls1 = fs1.readline()
        ls2 = fs2.readline()
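
The filtering rule in clean_histogram.py can be illustrated on in-memory strings. The sketch below is not part of the commit: `accepted_ratio` and `keep_pair` are hypothetical names, and unlike the script it divides by the length of the stripped sentence (the script divides by the raw line length, which includes the trailing newline).

```python
def accepted_ratio(sentence, accepted_chars):
    """Fraction of characters in `sentence` that belong to `accepted_chars`."""
    text = sentence.strip()
    if not text:
        return 0.0
    return sum(c in accepted_chars for c in text) / len(text)


def keep_pair(src, tgt, src_chars, tgt_chars, threshold=0.5):
    """Keep a sentence pair only if both sides exceed the accepted-character threshold."""
    return (accepted_ratio(src, src_chars) > threshold
            and accepted_ratio(tgt, tgt_chars) > threshold)


if __name__ == "__main__":
    latin = set("abcdefghijklmnopqrstuvwxyz ")
    print(keep_pair("hello world", "bonjour", latin, latin))
    print(keep_pair("hello", "привет", latin, latin))
```

Passing the accepted characters as a `set` keeps each membership test constant-time, which matters when the histograms contain many characters and the corpus is large.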

91 changes: 91 additions & 0 deletions examples/m2m_100/process_data/dedup_data.py
@@ -0,0 +1,91 @@
import argparse
from collections import namedtuple
import os

DATADIR = "/path/to/train_data"
DEDUP_FROM_DIR = "/path/to/eval/data"
OUTPUT_DIR = "/path/to/output/data"


def main(args):
    # Collect language pairs from directory names of the form "<src>_<tgt>".
    languages = set()
    for language_directory in os.listdir(DATADIR):
        if "_" in language_directory:
            src, tgt = language_directory.split("_")
            languages.add(LanguagePair(src=src, tgt=tgt))

    data = existing_data()
    train_languages = sorted(languages)
    # Process only the requested slice of pairs so the work can be split across jobs.
    for language_pair in train_languages[args.start_index:args.start_index + args.size]:
        print(language_pair)
        dedup(language_pair, data)


LanguagePair = namedtuple("LanguagePair", ["src", "tgt"])


def existing_data():
    # Load every line of every evaluation file into a single set.
    data = set()
    for file in os.listdir(DEDUP_FROM_DIR):
        with open(os.path.join(DEDUP_FROM_DIR, file)) as f:
            data |= set(f.readlines())
    return data


def dedup(language_pair, data, verbose=True, output=True):
    train_filenames = LanguagePair(
        src=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.src}",
        tgt=f"{DATADIR}/{language_pair.src}_{language_pair.tgt}/train.{language_pair.tgt}",
    )

    output_filenames = LanguagePair(
        src=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.src}",
        tgt=f"{OUTPUT_DIR}/train.dedup.{language_pair.src}-{language_pair.tgt}.{language_pair.tgt}"
    )

    # If output exists, skip this pair. It has already been done.
    if (os.path.exists(output_filenames.src) and
            os.path.exists(output_filenames.tgt)):
        if verbose:
            print(f"{language_pair.src}-{language_pair.tgt} already done.")
        return

    if verbose:
        print(f"{language_pair.src}-{language_pair.tgt} ready, will check dups.")

    # If there is no output, no need to actually do the loop.
    if not output:
        return

    if os.path.exists(train_filenames.src) and os.path.exists(train_filenames.tgt):
        with open(train_filenames.src) as f:
            train_source = f.readlines()

        with open(train_filenames.tgt) as f:
            train_target = f.readlines()

        # Check alignment before the loop, which indexes train_target by source position.
        assert len(train_source) == len(train_target)

        # do dedup: drop a pair if either side appears in the evaluation data
        new_train_source = []
        new_train_target = []
        for i, train_line in enumerate(train_source):
            if train_line not in data and train_target[i] not in data:
                new_train_source.append(train_line)
                new_train_target.append(train_target[i])

        assert len(new_train_source) == len(new_train_target)
        assert len(new_train_source) <= len(train_source)

        with open(output_filenames.src, "w") as o:
            for line in new_train_source:
                o.write(line)

        with open(output_filenames.tgt, "w") as o:
            for line in new_train_target:
                o.write(line)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("-s", "--start-index", required=True, type=int)
    parser.add_argument("-n", "--size", required=True, type=int)
    main(parser.parse_args())
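
The core of dedup_data.py, removing training pairs that overlap with evaluation data, can be tried without any files on disk. The sketch below is an illustration under stated assumptions rather than code from the commit: `remove_eval_overlap` is a hypothetical name, and it raises a `ValueError` on misaligned inputs where the script uses an `assert`.

```python
def remove_eval_overlap(train_source, train_target, eval_lines):
    """Drop every training pair whose source or target line appears in `eval_lines`.

    `eval_lines` should be a set of full lines (including newlines if the
    training lines include them), matching how the script compares lines.
    """
    if len(train_source) != len(train_target):
        raise ValueError("source and target must have the same number of lines")

    kept_source, kept_target = [], []
    for src_line, tgt_line in zip(train_source, train_target):
        if src_line not in eval_lines and tgt_line not in eval_lines:
            kept_source.append(src_line)
            kept_target.append(tgt_line)
    return kept_source, kept_target


if __name__ == "__main__":
    eval_lines = {"Guten Tag\n"}
    src = ["Guten Tag\n", "Danke\n"]
    tgt = ["Bonjour\n", "Merci\n"]
    print(remove_eval_overlap(src, tgt, eval_lines))
```

Because the comparison is an exact string match on whole lines, differences in whitespace or normalization between the evaluation and training files will prevent a match; the evaluation files need to be in the same form as the training data at this stage.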