Showing 1 changed file with 50 additions and 53 deletions.
@@ -14,15 +14,16 @@
 __author__ = 'Sami Virpioja, Peter Smit'
 __author_email__ = "[email protected]"
 
+import array
+import datetime
+import gzip
+import itertools
+import logging
 import math
 import random
-import time
 import re
 import sys
-import gzip
-import array
-import itertools
-import datetime
+import time
 
 try:
     # In Python2 import cPickle for better performance
@@ -35,6 +36,9 @@
 except ImportError:
     pass
 
+_logger = logging.getLogger(__name__)
+
+
 class Error(Exception):
     """Base class for exceptions in this module."""
     pass
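The hunk above introduces the module-level logger that every later hunk uses. A minimal standalone sketch of this standard pattern (names chosen for illustration, not taken from the diff):

```python
import logging

_logger = logging.getLogger(__name__)  # logger named after the module

def example():
    _logger.info("shown when the effective level is INFO or lower")
    _logger.debug("shown only when the effective level is DEBUG")
```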
@@ -54,14 +58,6 @@ def __init__(self, filename, line):
     def __str__(self):
         return "illegal format in file '%s'" % self.file
 
-_verboselevel = 1
-
-def _vprint(s, l = 1, maxl = 999):
-    """Internal function for printing to standard error stream"""
-    global _verboselevel
-    if _verboselevel >= l and _verboselevel <= maxl:
-        sys.stderr.write(s)
-
 _log2pi = math.log(2*math.pi)
 
 def logfactorial(n):
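The removed `_vprint` helper wrote a message to stderr whenever the global verbosity satisfied `l <= _verboselevel <= maxl`. The rest of the commit maps `l = 1` calls to `_logger.info` and `l = 2` calls to `_logger.debug`. A sketch of that correspondence (a standalone example, not code from the commit):

```python
import logging

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# old: _vprint("Starting batch training\n", 1)   # needed _verboselevel >= 1
log.info("Starting batch training")

# old: _vprint("#%s: %s\n" % (1, ["seg"]), 2)    # needed _verboselevel >= 2
log.debug("#%s: %s", 1, ["seg"])
```

Note that `maxl` has no direct logging counterpart: `_vprint(".", 1, 1)` printed only when the verbosity was exactly 1, whereas `_logger.info(".")` also fires when the level is DEBUG.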
@@ -603,8 +599,8 @@ def epoch_update(self, epoch_num):
             # data sets
             self.supervisedcorpusweight = self.corpuscostweight * \
                 float(self.boundaries) / self.annotations.get_types()
-            _vprint("Corpus weight of annotated data set to %s\n" %
-                    self.supervisedcorpusweight, 2)
+            _logger.debug("Corpus weight of annotated data set to %s\n" %
+                          self.supervisedcorpusweight)
 
     def get_viterbi_segments(self, compound, allow_new_items = True):
        """Find optimal segmentation using the Viterbi algorithm."""
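Two small points about the converted calls, left as-is by the commit: the messages keep their trailing `"\n"` even though logging's `StreamHandler` already appends a newline per record, and the strings are formatted eagerly with `%`. Logging also accepts the format arguments lazily, deferring the work until the record is actually emitted. A sketch of the tidier form (not part of this diff):

```python
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)

weight = 1.5
# no trailing "\n" needed; %-formatting is deferred until emission
log.debug("Corpus weight of annotated data set to %s", weight)
```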
@@ -940,11 +936,11 @@ def batch_train(model, corpus, freqthreshold = 1, finishthreshold = 0.005):
     newcost = model.get_cost()
     wordstoprocess = len(filter(lambda x: x >= freqthreshold,
                                 corpus.get_counts()))
-    _vprint("Found %s compounds in training data\n" % wordstoprocess, 1)
+    _logger.info("Found %s compounds in training data\n" % wordstoprocess)
     dotfreq = int(math.ceil(wordstoprocess / 70.0))
     epochs = 0
-    _vprint("Starting batch training\n", 1)
-    _vprint("Epochs: %s\tCost: %s\n" % (epochs, newcost), 1)
+    _logger.info("Starting batch training\n")
+    _logger.info("Epochs: %s\tCost: %s\n" % (epochs, newcost))
     while True:
         # One epoch
         indices = range(corpus.get_type_count())
@@ -955,20 +951,20 @@ def batch_train(model, corpus, freqthreshold = 1, finishthreshold = 0.005):
                 continue
             w = corpus.get_compound_atoms(j)
             segments = model.optimize(w)
-            _vprint("#%s: %s\n" % (i, segments), 2)
+            _logger.debug("#%s: %s\n" % (i, segments))
             i += 1
             if i % dotfreq == 0:
-                _vprint(".", 1, 1)
+                _logger.info(".")
         epochs += 1
-        _vprint("\n", 1, 1)
-        _vprint("Cost before epoch update: %s\n" % model.get_cost(), 2)
+
+        _logger.debug("Cost before epoch update: %s\n" % model.get_cost())
         model.epoch_update(epochs)
         oldcost = newcost
         newcost = model.get_cost()
-        _vprint("Epochs: %s\tCost: %s\n" % (epochs, newcost), 1)
+        _logger.info("Epochs: %s\tCost: %s\n" % (epochs, newcost))
         if epochs > 1 and newcost >= oldcost - finishthreshold * wordstoprocess:
             break
-    _vprint("Done.\n", 1)
+    _logger.info("Done.\n")
     return epochs, newcost
 
 def online_train(model, corpusiter, epochinterval = 10000, dampfunc = None):
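One behavioral difference worth noting: `_vprint(".", 1, 1)` appended dots to a single progress line on stderr, while `_logger.info(".")` emits one full log record (its own line, plus any configured prefix) per dot; similarly, the `" Done.\n"` messages now land on their own lines instead of completing the preceding "Loading ..." line. On Python 3 one could keep single-line dots with a dedicated handler whose terminator is empty. A sketch under that assumption (`StreamHandler.terminator` exists from Python 3.2, so this would not work on the Python 2 path this file still supports):

```python
import logging
import sys

progress = logging.getLogger("progress")    # illustrative logger name
handler = logging.StreamHandler(sys.stderr)
handler.terminator = ""                     # suppress the newline after each record
progress.addHandler(handler)
progress.propagate = False
progress.setLevel(logging.INFO)

for _ in range(5):
    progress.info(".")                      # prints "....." on one line
```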
@@ -985,7 +981,7 @@ def online_train(model, corpusiter, epochinterval = 10000, dampfunc = None):
     model.epoch_update(0)
     if dampfunc is not None:
         counts = {}
-    _vprint("Starting online training\n", 1)
+    _logger.info("Starting online training\n")
     i = 0
     epochs = 0
     dotfreq = int(math.ceil(epochinterval / 70.0))
@@ -1004,20 +1000,19 @@ def online_train(model, corpusiter, epochinterval = 10000, dampfunc = None):
         else:
             model.add(w, 1)
         segments = model.optimize(w)
-        _vprint("#%s: %s\n" % (i, segments), 2)
+        _logger.debug("#%s: %s\n" % (i, segments))
         i += 1
         if i % dotfreq == 0:
-            _vprint(".", 1, 1)
+            _logger.info(".")
         if i % epochinterval == 0:
-            _vprint("\n", 1, 1)
             epochs += 1
             model.epoch_update(epochs)
             newcost = model.get_cost()
-            _vprint("Tokens processed: %s\tCost: %s\n" % (i, newcost), 1)
+            _logger.info("Tokens processed: %s\tCost: %s\n" % (i, newcost))
     epochs += 1
     model.epoch_update(epochs)
     newcost = model.get_cost()
-    _vprint("\nTokens processed: %s\tCost: %s\n" % (i, newcost), 1)
+    _logger.info("\nTokens processed: %s\tCost: %s\n" % (i, newcost))
     return epochs, newcost
 
 def corpus_segmentation_dict(model, corpus):
@@ -1177,8 +1172,12 @@ def main(argv):
                         metavar='<file>')
     args = parser.parse_args(argv)
 
-    global _verboselevel
-    _verboselevel = args.verbose
+    if args.verbose >= 2:
+        logging.basicConfig(level=logging.DEBUG)
+    elif args.verbose >= 1:
+        logging.basicConfig(level=logging.INFO)
+    else:
+        logging.basicConfig(level=logging.WARNING)
 
     if args.loadfile is None and args.loadsegfile is None and \
             len(args.trainfiles) == 0:
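The hunk above replaces the global `_verboselevel` with a one-time `logging.basicConfig` call: `-v` maps to INFO, `-vv` (or higher) to DEBUG, and WARNING is the quiet default. Since `basicConfig` does nothing once the root logger already has handlers, calling it once at the top of `main` is the right place. An equivalent table-driven sketch (illustrative, not the committed code):

```python
import logging

_LEVELS = {0: logging.WARNING, 1: logging.INFO}  # 2 and above fall through to DEBUG

def configure_logging(verbose):
    logging.basicConfig(level=_LEVELS.get(verbose, logging.DEBUG))
```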
@@ -1196,22 +1195,22 @@ def main(argv):
 
     # Load existing model or create a new one
     if args.loadfile is not None:
-        _vprint("Loading model from '%s'..." % args.loadfile, 1)
+        _logger.info("Loading model from '%s'..." % args.loadfile)
         with open(args.loadfile, 'rb') as fobj:
             model = pickle.load(fobj)
-        _vprint(" Done.\n", 1)
+        _logger.info(" Done.\n")
         if annotations is not None:
             # Add annotated data to model
             model.set_annotations(annotations, args.scorpusweight)
     elif args.loadsegfile is not None:
-        _vprint("Loading model from '%s'..." % args.loadsegfile, 1)
+        _logger.info("Loading model from '%s'..." % args.loadsegfile)
         model = BaselineModel(forcesplit_list = args.forcesplit,
                               corpusweight = args.corpusweight,
                               annotations = annotations,
                               supervisedcorpusweight = args.scorpusweight,
                               use_skips = args.skips)
         model.load_segmentations(args.loadsegfile)
-        _vprint(" Done.\n", 1)
+        _logger.info(" Done.")
     else:
         model = BaselineModel(forcesplit_list = args.forcesplit,
                               corpusweight = args.corpusweight,
@@ -1235,15 +1234,14 @@ def main(argv):
             data = Corpus(args.separator)
             for f in args.trainfiles:
                 if f == '-':
-                    _vprint("Loading training data from standard "+
-                            "input\n", 1)
+                    _logger.info("Loading training data from standard input")
                 else:
-                    _vprint("Loading training data file '%s'..." % f, 1)
+                    _logger.info("Loading training data file '%s'..." % f)
                 if args.list:
                     data.load_from_list(f)
                 else:
                     data.load(f, args.cseparator)
-                _vprint(" Done.\n", 1)
+                _logger.info(" Done.\n")
             model.batch_init(data, args.freqthreshold, dampfunc)
             if args.splitprob is not None:
                 model.random_split_init(data, args.splitprob)
@@ -1260,21 +1258,21 @@ def main(argv):
         else:
             parser.error("unknown training mode '%s'" % args.trainmode)
         te = time.time()
-        _vprint("Epochs: %s\nFinal cost: %s\nTime: %.3fs\n" %
-                (e, c, te-ts), 1)
+        _logger.info("Epochs: %s\nFinal cost: %s\nTime: %.3fs\n" %
+                     (e, c, te-ts))
 
     # Save model
     if args.savefile is not None:
-        _vprint("Saving model to '%s'..." % args.savefile, 1)
+        _logger.info("Saving model to '%s'..." % args.savefile)
         with open(args.savefile, 'wb') as fobj:
             pickle.dump(model, fobj, pickle.HIGHEST_PROTOCOL)
-        _vprint(" Done.\n", 1)
+        _logger.info(" Done.\n")
 
     if args.savesegfile is not None:
-        _vprint("Saving model segmentations to '%s'..." %
-                args.savesegfile, 1)
+        _logger.info("Saving model segmentations to '%s'..." %
+                     args.savesegfile)
         model.save_segmentations(args.savesegfile)
-        _vprint(" Done.\n", 1)
+        _logger.info(" Done.\n")
 
     # Output lexicon
     if args.lexfile is not None:
@@ -1285,17 +1283,16 @@ def main(argv):
         else:
             fobj = open(args.lexfile, 'w')
         if args.lexfile != '-':
-            _vprint("Saving model lexicon to '%s'..." %
-                    args.lexfile, 1)
+            _logger.info("Saving model lexicon to '%s'..." % args.lexfile)
         for item in sorted(model.get_lexicon().get_items()):
             fobj.write("%s %s\n" % (model.get_item_count(item), item))
         if args.lexfile != '-':
             fobj.close()
-            _vprint(" Done.\n", 1)
+            _logger.info(" Done.\n")
 
     # Segment test data
     if len(args.testfiles) > 0:
-        _vprint("Segmenting test data...", 1)
+        _logger.info("Segmenting test data...")
         if args.outfile == '-':
             fobj = sys.stdout
         elif args.outfile[-3:] == '.gz':
@@ -1310,10 +1307,10 @@ def main(argv):
                 fobj.write("%s\n" % ' '.join(items))
                 i += 1
                 if i % 10000 == 0:
-                    _vprint(".", 1, 1)
+                    _logger.info(".")
         if args.outfile != '-':
             fobj.close()
-        _vprint(" Done.\n", 1)
+        _logger.info(" Done.\n")
 
 if __name__ == "__main__":
     main(sys.argv[1:])
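A practical consequence of this commit: applications importing this module can now tune or silence its output through the standard logging hierarchy instead of a module-private verbosity flag. For example (assuming the module is importable as `morfessor`; the logger name is an assumption, not stated in the diff):

```python
import logging

# silence INFO-level progress chatter from the library, keep warnings
logging.getLogger("morfessor").setLevel(logging.WARNING)
```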