-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathtransition_system.py
149 lines (130 loc) · 5.55 KB
/
transition_system.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
#!/usr/bin/env python
#coding=utf-8
'''
Definition of TransitionSystem class. It initializes the state and provides a method to advance
the state as a result of the application of an action. Actions can be determined either by an
oracle (training and oracle mode parsing) or by a supervised trained classifier (actual parsing).
It also provides a method to return the sequence of {state, action} pairs used to construct the
classifier dataset as well as the relations recovered by the transition system.
@author: Marco Damonte ([email protected])
@since: 03-10-16
'''
from oracle import Oracle
from state import State
from action import Action
from variables import Variables
from history import History
from node import Node
from rules import Rules
from relations import Relations
import PyTorch
import PyTorchHelpers
import numpy as np
import copy
# Python-side handle for the Lua class named 'Classify', loaded from the Lua source file
# 'nnets/classify.lua' through PyTorchHelpers.load_lua_class. It is instantiated with a model
# directory by TransitionSystem.__init__.
# NOTE(review): the Lua path is relative -- presumably resolved against the current working
# directory at import time; confirm before running from another directory.
Classify = PyTorchHelpers.load_lua_class('nnets/classify.lua', 'Classify')
class TransitionSystem:
    '''
    Runs the transition system over one input, from the initial state until it stops.

    The constructor builds the initial State and then repeatedly chooses an action and applies
    it. Actions come from an Oracle when gold relations are supplied (stages "ORACLETEST",
    "TRAIN" and "COLLECT") and from the trained classifier otherwise (parsing). Every applied
    action is recorded in a History, which backs statesactions() and alignments(); the
    relations built on the stack are exposed through relations().
    '''

    def __init__(self, embs, data, stage, model_dir = None):
        '''
        Build the initial state and run the transition loop to completion.

        embs: passed unchanged to State.
        data: for stages "ORACLETEST", "TRAIN" and "COLLECT" a 4-item sequence
            (tokens, dependencies, relations, alignments); for any other stage (parsing)
            a 2-item sequence (tokens, dependencies).
        stage: "ORACLETEST", "TRAIN", "COLLECT", or anything else for parsing. Features are
            only extracted into the history when stage is "TRAIN".
        model_dir: optional model directory. When given, the classifier is loaded from it and
            the relation labels are read from "<model_dir>/relations.txt" (one per line).

        Raises AssertionError if data has the wrong number of items, or if the stack and the
        buffer are not both empty once the loop has stopped.
        '''
        if model_dir is not None:
            self._classify = Classify(model_dir)
            # Context manager so the label file is closed instead of leaking the handle.
            with open(model_dir + "/relations.txt") as labels_file:
                self._labels = [item.strip() for item in labels_file.read().splitlines()]
        else:
            self._labels = None

        if stage in ("ORACLETEST", "TRAIN", "COLLECT"):
            # Gold relations are available: actions will be dictated by the oracle.
            assert(len(data) == 4)
            hooks = False
            tokens, dependencies, relations, alignments = data
            if stage == "ORACLETEST":
                # Only oracle-test mode keeps the unfiltered gold relations around.
                self.gold = relations
            relations2 = self._oracle_relations(relations)
            oracle = Oracle(relations2)
            if stage == "ORACLETEST":
                self.variables = Variables()
            else:
                self.variables = None
        else: #PARSING
            assert(len(data) == 2)
            hooks = True
            tokens, dependencies = data
            relations2 = None
            alignments = None
            oracle = None
            self.variables = Variables()

        self.state = State(embs, relations2, tokens, dependencies, alignments, oracle, hooks, self.variables, stage, Rules(self._labels))
        self.history = History()

        # Explicit comparison kept on purpose: isTerminal() is defined outside this file.
        while self.state.isTerminal() == False:
            # Snapshot the front of the buffer before the action is applied to the state.
            tok = copy.deepcopy(self.state.buffer.peek())
            if oracle is not None:
                action = oracle.valid_actions(self.state)
            else:
                action = self.classifier()
            if action is None:
                # No action available: stop early; the assertion below still has to hold.
                break

            f_rel = []
            f_lab = []
            f_reentr = []
            if stage == "TRAIN":
                # Features are read from the state *before* the action changes it.
                f_rel = self.state.rel_features()
                if action.name in ("larc", "rarc"):
                    f_lab = self.state.lab_features()
                if action.name == "reduce":
                    f_reentr = self.state.reentr_features()

            self.state.apply(action)
            self.history.add((f_rel, f_lab, f_reentr), action, tok)

        assert (self.state.stack.isEmpty() == True and self.state.buffer.isEmpty() == True)

    @staticmethod
    def _oracle_relations(relations):
        '''
        Return the (head, label, dependent) triples that are handed to the Oracle.

        Relations whose label starts with ":snt" are rewritten to (Node(True), ":top",
        dependent). A triple is kept only if its dependent has a token and either its head
        has a token or its label is ":top".
        '''
        kept = []
        for r in relations:
            if r[1].startswith(":snt"):
                r2 = (Node(True), ":top", r[2])
            else:
                r2 = (r[0], r[1], r[2])
            if (r2[0].token is not None or r2[1] == ":top") and r2[2].token is not None:
                kept.append(r2)
        return kept

    def classifier(self):
        '''
        Ask the trained classifier for the next action in the current state.

        The action type predicted by the classifier is an integer in 1..4:
        1 -> "shift", 2 -> "reduce", 3 -> "larc", 4 -> "rarc". Returns the matching Action.
        Requires the instance to have been built with a model_dir.
        '''
        digits, words, pos, deps = self.state.rel_features()
        constr = self.state.legal_actions()
        acttype = int(self._classify.action(digits, words, pos, deps, constr))
        assert(acttype > 0 and acttype < 5)

        if acttype == 1:
            sg = self.state.nextSubgraph()
            return Action("shift", sg)

        if acttype == 2:
            reentr_features = self.state.reentr_features()
            # Other children of the parents of the node on top of the stack.
            siblings = [
                item[0]
                for p in self.state.stack.relations.parents[self.state.stack.top()]
                for item in self.state.stack.relations.children[p[0]]
                if item[0] != self.state.stack.top()
            ]
            # Position of ":ARG0" in the vector returned by legal_rel_labels.
            # NOTE(review): presumably tied to the label order of relations.txt -- confirm.
            arg0_idx = 9
            for s, feats in zip(siblings, reentr_features):
                r_words, r_pos, r_deps = feats
                pred = int(self._classify.reentrancy(r_words, r_pos, r_deps))
                if pred == 1:
                    if self.state.legal_rel_labels("reent", (self.state.stack.top(), s))[arg0_idx] == 1:
                        return Action("reduce", (s, ":ARG0", None))
                    # Only the first sibling predicted as a reentrancy is considered.
                    break
            return Action("reduce", None)

        if acttype == 3:
            rel = "larc"
        elif acttype == 4:
            rel = "rarc"
        constr = self.state.legal_rel_labels(rel, 1)
        digits, words, pos, deps = self.state.lab_features()
        pred = int(self._classify.label(digits, words, pos, deps, constr))
        # The predicted label id is 1-based; self._labels is a 0-based Python list.
        return Action(rel, self._labels[pred - 1])

    def statesactions(self):
        '''Return whatever the history reports from its statesactions() method.'''
        return self.history.statesactions()

    def relations(self):
        '''Return the triples of the relations held by the state's stack.'''
        return self.state.stack.relations.triples()

    def alignments(self):
        '''Return the alignments attribute of the history.'''
        return self.history.alignments