-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
29 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,11 @@ | |
Smallest pieces are "atoms" (e.g. characters or words) | ||
Lexicon contains "items" (e.g. morphs or phrases) | ||
""" | ||
|
||
__all__ = ['InputFormatError','batch_train','online_train', | ||
'corpus_segmentation_dict', 'Lexicon','BaselineModel', 'Corpus', | ||
'Annotations'] | ||
|
||
__version__ = '2.0.0pre1' | ||
__author__ = 'Sami Virpioja, Peter Smit' | ||
__author_email__ = "[email protected]" | ||
|
@@ -191,7 +196,7 @@ def __init__(self, forcesplit_list = [], corpusweight = 1.0, | |
self.counter = {} # Counter for random skipping | ||
self.corpuscostweight = corpusweight | ||
self.forcesplit_list = forcesplit_list | ||
if annotations != None: | ||
if annotations is not None: | ||
self.set_annotations(annotations, supervisedcorpusweight) | ||
else: | ||
self.supervised = False | ||
|
@@ -208,7 +213,7 @@ def set_annotations(self, annotations, supervisedcorpusweight): | |
self.supervisedtokens = 0 | ||
self.supervisedlogtokensum = 0.0 | ||
self.supervisedcorpuscost = 0.0 | ||
if supervisedcorpusweight == None: | ||
if supervisedcorpusweight is None: | ||
self.supervisedcorpusweight = 1.0 | ||
self.sweightbalance = True | ||
else: | ||
|
@@ -225,7 +230,7 @@ def load_segmentations(self, segfile): | |
<count> <item1> + <item2> + ... + <itemN> | ||
""" | ||
if segfile[-3:] == '.gz': | ||
if segfile.endswith('.gz'): | ||
fobj = gzip.open(segfile, 'r') | ||
else: | ||
fobj = open(segfile, 'r') | ||
|
@@ -253,7 +258,7 @@ def save_segmentations(self, segfile): | |
<count> <item1> + <item2> + ... + <itemN> | ||
""" | ||
if segfile[-3:] == '.gz': | ||
if segfile.endswith('.gz'): | ||
fobj = gzip.open(segfile, 'w') | ||
else: | ||
fobj = open(segfile, 'w') | ||
|
@@ -458,7 +463,7 @@ def best_analysis(self, choices): | |
math.log(self.analyses[m][1]) | ||
else: | ||
cost -= self.penaltylogprob # penaltylogprob is negative | ||
if bestcost == None or cost < bestcost: | ||
if bestcost is None or cost < bestcost: | ||
bestcost = cost | ||
bestanalysis = analysis | ||
return bestanalysis, bestcost | ||
|
@@ -614,7 +619,7 @@ def get_viterbi_segments(self, compound, allow_new_items = True): | |
bestpath = None | ||
bestcost = None | ||
for pt in range(0, t): | ||
if grid[pt][0] == None: | ||
if grid[pt][0] is None: | ||
continue | ||
cost = grid[pt][0] | ||
item = compound[pt:t] | ||
|
@@ -631,14 +636,14 @@ def get_viterbi_segments(self, compound, allow_new_items = True): | |
cost += badlikelihood | ||
else: | ||
continue | ||
if bestcost == None or cost < bestcost: | ||
if bestcost is None or cost < bestcost: | ||
bestcost = cost | ||
bestpath = pt | ||
grid.append((bestcost, bestpath)) | ||
items = [] | ||
path = grid[-1][1] | ||
lt = clen + 1 | ||
while path != None: | ||
while path is not None: | ||
t = path | ||
items.append(compound[t:lt]) | ||
path = grid[t][1] | ||
|
@@ -682,7 +687,7 @@ def get_compound_str(self, i): | |
|
||
def get_compound_atoms(self, i): | ||
"""Return the atom representation of the compound at index i.""" | ||
if self.atom_sep == None: | ||
if self.atom_sep is None: | ||
return self.compounds[i] # string | ||
else: | ||
return tuple(re.split(self.atom_sep, self.compounds[i])) # tuple | ||
|
@@ -715,7 +720,7 @@ def get_max_compound_len(self): | |
|
||
def get_compound_len(self, c): | ||
"""Return the number of atoms in the compound.""" | ||
if self.atom_sep == None: | ||
if self.atom_sep is None: | ||
return len(c) | ||
else: | ||
return len(re.split(self.atom_sep, c)) | ||
|
@@ -740,7 +745,7 @@ def load(self, datafile, compound_sep = ' *', comment_re = "^#"): | |
for line in fobj: | ||
if re.search(comment_re, line): | ||
continue | ||
if compound_sep == None or compound_sep == '': | ||
if compound_sep is None or compound_sep == '': | ||
# Line is one compound | ||
compounds = [line.rstrip()] | ||
else: | ||
|
@@ -789,7 +794,7 @@ def generator(self, datafiles, compound_sep = ' *', comment_re = "^#"): | |
for line in fobj: | ||
if re.search(comment_re, line): | ||
continue | ||
if compound_sep == None or compound_sep == '': | ||
if compound_sep is None or compound_sep == '': | ||
# Line is one compound | ||
compounds = [line.rstrip()] | ||
else: | ||
|
@@ -978,14 +983,14 @@ def online_train(model, corpusiter, epochinterval = 10000, dampfunc = None): | |
""" | ||
model.epoch_update(0) | ||
if dampfunc != None: | ||
if dampfunc is not None: | ||
counts = {} | ||
_vprint("Starting online training\n", 1) | ||
i = 0 | ||
epochs = 0 | ||
dotfreq = int(math.ceil(epochinterval / 70.0)) | ||
for w in corpusiter: | ||
if dampfunc != None: | ||
if dampfunc is not None: | ||
if not counts.has_key(w): | ||
c = 0 | ||
counts[w] = 1 | ||
|
@@ -1175,30 +1180,30 @@ def main(argv): | |
global _verboselevel | ||
_verboselevel = args.verbose | ||
|
||
if args.loadfile == None and args.loadsegfile == None and \ | ||
if args.loadfile is None and args.loadsegfile is None and \ | ||
len(args.trainfiles) == 0: | ||
parser.error("either model file or training data should be defined") | ||
|
||
if args.randseed != None: | ||
if args.randseed is not None: | ||
random.seed(args.randseed) | ||
|
||
# Load annotated data if specified | ||
if args.annofile != None: | ||
if args.annofile is not None: | ||
annotations = Annotations() | ||
annotations.load(args.annofile) | ||
else: | ||
annotations = None | ||
|
||
# Load existing model or create a new one | ||
if args.loadfile != None: | ||
if args.loadfile is not None: | ||
_vprint("Loading model from '%s'..." % args.loadfile, 1) | ||
with open(args.loadfile, 'rb') as fobj: | ||
model = pickle.load(fobj) | ||
_vprint(" Done.\n", 1) | ||
if annotations != None: | ||
if annotations is not None: | ||
# Add annotated data to model | ||
model.set_annotations(annotations, args.scorpusweight) | ||
elif args.loadsegfile != None: | ||
elif args.loadsegfile is not None: | ||
_vprint("Loading model from '%s'..." % args.loadsegfile, 1) | ||
model = BaselineModel(forcesplit_list = args.forcesplit, | ||
corpusweight = args.corpusweight, | ||
|
@@ -1240,7 +1245,7 @@ def main(argv): | |
data.load(f, args.cseparator) | ||
_vprint(" Done.\n", 1) | ||
model.batch_init(data, args.freqthreshold, dampfunc) | ||
if args.splitprob != None: | ||
if args.splitprob is not None: | ||
model.random_split_init(data, args.splitprob) | ||
e, c = batch_train(model, data, freqthreshold = args.freqthreshold) | ||
elif args.trainmode == 'online': | ||
|
@@ -1259,20 +1264,20 @@ def main(argv): | |
(e, c, te-ts), 1) | ||
|
||
# Save model | ||
if args.savefile != None: | ||
if args.savefile is not None: | ||
_vprint("Saving model to '%s'..." % args.savefile, 1) | ||
with open(args.savefile, 'wb') as fobj: | ||
pickle.dump(model, fobj, pickle.HIGHEST_PROTOCOL) | ||
_vprint(" Done.\n", 1) | ||
|
||
if args.savesegfile != None: | ||
if args.savesegfile is not None: | ||
_vprint("Saving model segmentations to '%s'..." % | ||
args.savesegfile, 1) | ||
model.save_segmentations(args.savesegfile) | ||
_vprint(" Done.\n", 1) | ||
|
||
# Output lexicon | ||
if args.lexfile != None: | ||
if args.lexfile is not None: | ||
if args.lexfile == '-': | ||
fobj = sys.stdout | ||
elif args.lexfile[-3:] == '.gz': | ||
|