diff --git a/doc/flags.md b/doc/flags.md
index b4c8c059..6354a137 100644
--- a/doc/flags.md
+++ b/doc/flags.md
@@ -8,7 +8,16 @@ The preprocessing script `scripts/preprocess.py` accepts the following command-l
 - `--val_frac`: What fraction of the data to use as a validation set; default is `0.1`.
 - `--test_frac`: What fraction of the data to use as a test set; default is `0.1`.
 - `--quiet`: If you pass this flag then no output will be printed to the console.
+- `--syllabic`: Predict syllables instead of letters. You must specify a dictionary (e.g., `en_US`) to use for syllable separation.
+- `--install_syllabic_dict`: Install a new dictionary for syllable separation (e.g., `en_US`, `fr_FR`, `pt_BR`) and exit.
+
+Syllabic prediction transforms the input file: all letters are converted to lower-case, runs of whitespace are collapsed into a single space, and all characters other than letters, numerals, punctuation, and newlines are ignored. The input is assumed to be Unicode, and each character's Unicode General Category is used to determine its type.
+
+The PyHyphen library must be installed in order to use the `--syllabic` and `--install_syllabic_dict` flags. You can install it with:
+```bash
+pip install PyHyphen
+```
 
 # Training
 
 The training script `train.lua` accepts the following command-line flags:
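For reference, here is what typical usage of the two new flags looks like. This is a minimal sketch, not part of the patch: it assumes the script's existing `--input_txt` and `--output_h5` options (both referenced in the code below), and the corpus path is illustrative.

```bash
# One-time dictionary install; the script exits right after installing.
python scripts/preprocess.py --install_syllabic_dict en_US

# Preprocess at the syllable level instead of the character level.
python scripts/preprocess.py \
  --input_txt data/my-corpus.txt \
  --output_h5 data/my-corpus.h5 \
  --syllabic en_US
```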
diff --git a/scripts/preprocess.py b/scripts/preprocess.py
index 90b834b6..4c0ead50 100644
--- a/scripts/preprocess.py
+++ b/scripts/preprocess.py
@@ -4,6 +4,7 @@
 import numpy as np
 import h5py
 import codecs
+import sys
 
 parser = argparse.ArgumentParser()
 
@@ -13,6 +14,8 @@
 parser.add_argument('--val_frac', type=float, default=0.1)
 parser.add_argument('--test_frac', type=float, default=0.1)
 parser.add_argument('--quiet', action='store_true')
+parser.add_argument('--syllabic', default='none')
+parser.add_argument('--install_syllabic_dict', default='none')
 parser.add_argument('--encoding', default='utf-8')
 args = parser.parse_args()
 
@@ -20,15 +23,82 @@
 if __name__ == '__main__':
   if args.encoding == 'bytes': args.encoding = None
 
+  if args.install_syllabic_dict != 'none':
+    # Note that this step is unnecessary with pyhyphen>=3.0.0 as language
+    # dictionaries are now installed on-the-fly.
+    from hyphen import dictools
+    dictools.install(args.install_syllabic_dict)
+    sys.exit(0)
+
   # First go the file once to see how big it is and to build the vocab
-  token_to_idx = {}
-  total_size = 0
-  with codecs.open(args.input_txt, 'r', args.encoding) as f:
-    for line in f:
-      total_size += len(line)
-      for char in line:
-        if char not in token_to_idx:
-          token_to_idx[char] = len(token_to_idx) + 1
+  if args.syllabic == 'none':
+    syllabic = False
+    token_to_idx = {}
+    total_size = 0
+    with codecs.open(args.input_txt, 'r', args.encoding) as f:
+      for line in f:
+        total_size += len(line)
+        for char in line:
+          if char not in token_to_idx:
+            token_to_idx[char] = len(token_to_idx) + 1
+  else:
+    syllabic = True
+
+    import unicodedata
+    from hyphen import dictools
+    if not dictools.is_installed(args.syllabic):
+      # Note that with more recent versions of pyhyphen it is not necessary
+      # to bail out here, as the language dictionary is downloaded
+      # automatically by Hyphenator.
+      print 'Syllabic dictionary', args.syllabic, 'not installed'
+      print 'Installed dictionaries:', ' '.join(dictools.list_installed())
+      sys.exit(1)
+    from hyphen import Hyphenator
+    separator = Hyphenator(args.syllabic)
+
+    def scanSyllables(filename, encoding, processing):
+      # Stream the file token by token: letters accumulate into words, which
+      # are split into syllables; spaces, numerals, punctuation, and newlines
+      # are emitted as single tokens; all other characters are dropped.
+      word = ''
+      space = False
+      with codecs.open(filename, 'r', encoding) as f:
+        for line in f:
+          for char in line:
+            cat = unicodedata.category(char)
+            if cat[0] == 'L':
+              word = word + char
+              space = False
+              continue
+            if len(word) > 0:
+              syls = separator.syllables(word.lower())
+              if len(syls) == 0:
+                syls = [word.lower()]
+              word = ''
+            else:
+              syls = []
+            if cat[0] == 'Z':
+              if not space: syls.append(u' ')
+              space = True
+            elif cat[0] == 'N' or cat[0] == 'P':
+              syls.append(char)
+              space = False
+            elif char == u'\n':
+              syls.append(char)
+              space = False
+            for syl in syls:
+              processing(syl)
+
+    def createVocab(syl):
+      global token_to_idx
+      global total_size
+      total_size += 1
+      if syl not in token_to_idx:
+        token_to_idx[syl] = len(token_to_idx) + 1
+
+    token_to_idx = {u'\n': 1}
+    total_size = 0
+    scanSyllables(args.input_txt, args.encoding, createVocab)
 
   # Now we can figure out the split sizes
   val_size = int(args.val_frac * total_size)
@@ -58,14 +128,39 @@
 
   # Go through the file again and write data to numpy arrays
   split_idx, cur_idx = 0, 0
-  with codecs.open(args.input_txt, 'r', args.encoding) as f:
-    for line in f:
-      for char in line:
-        splits[split_idx][cur_idx] = token_to_idx[char]
-        cur_idx += 1
-        if cur_idx == splits[split_idx].size:
-          split_idx += 1
-          cur_idx = 0
+  if not syllabic:
+    with codecs.open(args.input_txt, 'r', args.encoding) as f:
+      for line in f:
+        for char in line:
+          splits[split_idx][cur_idx] = token_to_idx[char]
+          cur_idx += 1
+          if cur_idx == splits[split_idx].size:
+            split_idx += 1
+            cur_idx = 0
+  else:
+
+    def convertInput(syl):
+      global check_size
+      global splits
+      global split_idx
+      global cur_idx
+      global token_to_idx
+      check_size += 1
+      # print check_size, syl
+      splits[split_idx][cur_idx] = token_to_idx[syl]
+      cur_idx += 1
+      if cur_idx == splits[split_idx].size:
+        split_idx += 1
+        cur_idx = 0
+
+    check_size = 0
+    scanSyllables(args.input_txt, args.encoding, convertInput)
+
+    if total_size != check_size:
+      print 'WARNING: token count mismatch between vocabulary building (', total_size, ') and conversion (', check_size, ')'
+    if cur_idx != 0:
+      print 'ERROR: splits were not filled exactly; cur_idx =', cur_idx
+      sys.exit(1)
 
   # Write data to HDF5 file
   with h5py.File(args.output_h5, 'w') as f:
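A note on the `len(syls) == 0` fallback in `scanSyllables` above: PyHyphen's `Hyphenator.syllables()` returns an empty list when it finds no hyphenation points in a word (most monosyllables, for instance), so the whole lower-cased word must be emitted as a single token. A minimal sketch of that behavior, assuming an installed `en_US` dictionary; the example words and their splits are illustrative:

```python
# -*- coding: utf-8 -*-
from hyphen import Hyphenator

h = Hyphenator('en_US')              # assumes the en_US dictionary is installed

print h.syllables(u'preprocessing')  # e.g. [u'pre', u'pro', u'cess', u'ing']
print h.syllables(u'cat')            # [] -- no hyphenation points found

# Mirror of the fallback in scanSyllables:
word = u'cat'
syls = h.syllables(word.lower())
if len(syls) == 0:
  syls = [word.lower()]
print syls                           # prints [u'cat']
```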