-
Notifications
You must be signed in to change notification settings - Fork 1
/
patterns.py
349 lines (298 loc) · 12.5 KB
/
patterns.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
from dataclasses import dataclass
from typing import Iterable, Callable
from fnmatch import fnmatch
from pyparsing import nestedExpr
from trees import *
def match_str(patt: str, word: str) -> bool:
    """Tell whether ``word`` matches the shell-style wildcard pattern ``patt``.

    The comparison is delegated to ``fnmatch.fnmatch``, so ``*`` may appear
    anywhere in the pattern (``?`` and ``[seq]`` are understood as well).
    """
    is_match = fnmatch(word, patt)
    return is_match
def intpred(n, x):
"condition compared with a number: =8, <8, >8, !8"
number = int(n[1:])
match n[0]:
case '=': return x == number
case '!': return x != number
case '<': return x < number
case '>': return x > number
def combos(altlists):
match altlists:
case [[], _]:
return []
case [xs, *yss]:
return [[x] + ys
for x in xs
for ys in combos([[y for y in ys if y != x] for ys in yss])]
case _:
return [[]]
def different_matches(mf, ps, xs):
    """Find assignments where every pattern in ``ps`` gets its own object from ``xs``.

    ``mf(p, x)`` is the match function.  For each pattern the matching objects
    are collected, and ``combos`` then enumerates the selections in which no
    object is used for more than one pattern.
    """
    candidates_per_pattern = []
    for pattern in ps:
        candidates_per_pattern.append([obj for obj in xs if mf(pattern, obj)])
    return combos(candidates_per_pattern)
@dataclass
class Pattern(Tree):
    """A pattern expression: a ``Tree`` that prints in parenthesised prefix form."""

    def __str__(self):
        # A pattern without arguments is printed as its bare root.
        if not self.subtrees:
            return self.root
        parts = [self.root]
        parts.extend(str(child) for child in self.subtrees)
        return f"({' '.join(parts)})"
def match_wordline(patt: Pattern, word: WordLine) ->bool:
"matching individual wordlines"
match patt:
case Pattern(field, ['IN', *forms]) if field in WORDLINE_FIELDS:
wfield = word.as_dict()[field]
return any(match_str(form, wfield) for form in forms)
case Pattern(field, [form]) if field in WORDLINE_FIELDS:
return match_str(form, word.as_dict()[field])
case Pattern('HEAD_DISTANCE', [n]):
return intpred(n, int(word.HEAD) - int(word.ID)) if word.ID.isdigit() else False
case Pattern('AND', patts):
return all(match_wordline(p, word) for p in patts)
case Pattern('OR', patts):
return any(match_wordline(p, word) for p in patts)
case Pattern('NOT', [patt]):
return not (match_wordline(patt, word))
case _:
return False
def match_deptree(patt: Pattern, tree: DepTree) -> bool:
"matching entire trees - either their root wordline or the whole tree"
if match_wordline(patt, tree.root):
return True
else:
match patt:
case Pattern('LENGTH', [n]):
return intpred(n, len(tree))
case Pattern('DEPTH', [n]):
return intpred(n, tree.depth())
case Pattern('METADATA', [strpatt]):
return match_str(strpatt, '\n'.join(tree.comments))
case Pattern ('IS_NONPROJECTIVE', []):
return nonprojective(tree)
case Pattern('TREE', [pt, *patts]):
return (len(patts) == len(sts := tree.subtrees)
and match_deptree(pt, tree)
and all(match_deptree(*pt) for pt in zip(patts, sts)))
case Pattern('TREE_', [pt, *patts]):
return (match_deptree(pt, tree)
and different_matches(match_deptree, patts, tree.subtrees))
# this could use the same subtree twice:
# and all(any(match_deptree(p, t) for t in tree.subtrees) for p in patts))
case Pattern('SEQUENCE', patts):
return (len(patts) == len(sts := tree.wordlines())
and all(match_wordline(*pt) for pt in zip(patts, sts)))
case Pattern('SUBSEQUENCE', patts):
for i in range(len(tree.wordlines())-len(patts)):
sts = tree.wordlines()[i:i+len(patts)]
if all(match_wordline(*pt) for pt in zip(patts, sts)):
return True
return False
case Pattern('SEQUENCE_', patts):
return (all(any(match_wordline(p, t) for t in tree.wordlines()) for p in patts))
case Pattern('HAS_SUBTREE', patts):
return any(all(match_deptree(p, st) for p in patts) for st in tree.subtrees)
case Pattern('HAS_NO_SUBTREE', patts):
return not any(all(match_deptree(p, st) for p in patts) for st in tree.subtrees)
case Pattern('CONTAINS_SUBTREE', patts): ## to revisit
return (all(match_deptree(p, tree) for p in patts) or
any(match_deptree(patt, st) for st in tree.subtrees))
case Pattern('AND', patts): # must be defined again for tree patterns
return all(match_deptree(p, tree) for p in patts)
case Pattern('OR', patts):
return any(match_deptree(p, tree) for p in patts)
case Pattern('NOT', [patt]):
return not (match_deptree(patt, tree))
case _:
return False
def matches_of_deptree(patt: Pattern, tree: DepTree) -> list[DepTree]:
    """Return ``[tree]`` when the tree itself matches ``patt``, else ``[]``."""
    return [tree] if match_deptree(patt, tree) else []
def matches_in_deptree(patt: Pattern, tree: DepTree) -> list[DepTree]:
    """Collect the tree and every nested subtree that matches ``patt``.

    The result is in pre-order: a tree comes before the matches found inside
    its subtrees.
    """
    found = [tree] if match_deptree(patt, tree) else []
    for child in tree.subtrees:
        found += matches_in_deptree(patt, child)
    return found
def match_found_in_deptree(patt: Pattern, tree: DepTree) -> list[DepTree]:
    """Return ``[tree]`` if the tree has at least one matching subtree, else ``[]``.

    Every matching (sub)tree gets ``add_misc('MATCH')`` called on it, so the
    input tree is modified.  Success is then detected by looking for a
    wordline whose MISC ends with ``+MATCH``; presumably that is what
    ``add_misc`` produces (defined elsewhere - verify).
    """
    def mark(node):
        if match_deptree(patt, node):
            node.add_misc('MATCH')
        for child in node.subtrees:
            mark(child)

    mark(tree)
    ### +MATCH should not appear in MISC for another reason
    for wordline in tree.wordlines():
        if wordline.MISC.endswith('+MATCH'):
            return [tree]
    return []
def len_segment_pattern(patt: Pattern) -> int:
match patt:
case Pattern('SEGMENT', patts):
return sum(len_segment_pattern(pt) for pt in patts)
case Pattern('REPEAT', [n, patt]):
return int(n) * len_segment_pattern(patt)
case _:
return 1
def match_segment(patt: Pattern, trees: list[DepTree]) -> bool:
"matching a contiguous segment of trees"
if len_segment_pattern(patt) == len(trees):
match patt:
case Pattern('SEGMENT', patts):
for pt in patts:
if match_segment(pt, trees[:len_segment_pattern(pt)]):
trees = trees[len_segment_pattern(pt):]
continue
else:
return False
return True
case Pattern('REPEAT', [n, pt]):
for _ in range(int(n)):
if match_segment(pt, trees[:len_segment_pattern(pt)]):
trees = trees[len_segment_pattern(pt):]
continue
else:
return False
return True
case _:
return match_deptree(patt, trees[0]) # must be a singleton
def matches_in_tree_stream(patt: Pattern,
trees: Iterable[DepTree]) -> Iterable[list[DepTree]]:
match patt:
case Pattern('REPEAT', [n, pt]) if n[0] == '>':
segment = []
while trees:
try:
tr = next(trees)
except StopIteration:
break
while match_deptree(pt, tr):
segment.append(tr)
if trees:
try:
tr = next(trees)
except StopIteration:
break
else:
break
if len(segment) > int(n[1:]):
yield segment
segment = []
case _:
lenp = len_segment_pattern(patt)
try:
segment = [next(trees) for _ in range(lenp)]
except StopIteration:
return
while trees:
if match_segment(patt, segment):
yield segment
try:
segment = [next(trees) for _ in range(lenp)] # segments may not overlap
except StopIteration:
return
else:
try:
segment.pop(0)
segment.append(next(trees))
except StopIteration:
return
def change_wordline(patt: Pattern, word: WordLine) ->WordLine:
"changing the value of some field in accordance with a pattern"
match patt:
case Pattern('IF', [condpatt, changepatt]):
if match_wordline(condpatt, word):
return change_wordline(changepatt, word)
else:
return word
case Pattern(field, [oldval, newval]) if field in WORDLINE_FIELDS:
wdict = word.as_dict()
if match_str(oldval, wdict[field]):
wdict[field] = newval
return WordLine(**wdict)
else:
return word
case Pattern('AND', patts): # cumulative changes in the order of patts
for patt in patts:
word = change_wordline(patt, word)
return word
case _:
return word
def change_deptree(patt: Pattern, tree: DepTree) -> DepTree:
"change a tree in accordance with a pattern"
match patt:
case Pattern('IF', [condpatt, changepatt]):
if match_deptree(condpatt, tree):
return change_deptree(changepatt, tree)
else:
return tree
case Pattern('PRUNE', [depth]):
depth = int(depth)
return prune_subtrees_below(tree, depth)
case Pattern('FILTER_SUBTREES', [condpatt]): ## to revisit
tree.subtrees = [t for t in tree.subtrees if match_deptree(condpatt, t)]
return tree
case Pattern('AND', patts):
for patt in patts:
tree = change_deptree(patt, tree)
return tree
case _:
tree.root = change_wordline(patt, tree.root)
return tree
def changes_in_deptree(patt: Pattern, tree: DepTree) -> DepTree:
    """Apply a change pattern to a tree and, recursively, to all its subtrees.

    The tree itself is changed first with ``change_deptree``; the same
    procedure is then repeated on each subtree of the *changed* tree, so
    subtrees removed by the change are not visited.

    Previously the subtrees were processed with the non-recursive
    ``change_deptree``, which left everything below depth 1 untouched.

    Returns:
        The changed tree (``change_deptree`` may modify its argument in place).
    """
    tree = change_deptree(patt, tree)
    tree.subtrees = [changes_in_deptree(patt, subtree)
                     for subtree in tree.subtrees]
    return tree
def find_paths_in_tree(patts: list[Pattern], tree: DepTree) -> list[DepTree]:
    """Find downward paths from the root whose wordlines match ``patts`` in turn.

    ``patts[0]`` is matched against the root, ``patts[1]`` against the root of
    one of its subtrees, and so on.  Each path found is returned as a
    non-branching DepTree with empty comments.  An empty pattern list gives
    no paths.
    """
    if not patts:
        return []

    head, rest = patts[0], patts[1:]
    if not rest:
        # Last pattern of the path: the root alone decides.
        if match_wordline(head, tree.root):
            return [DepTree(tree.root, [], [])]
        return []

    paths = []
    for child in tree.subtrees:
        for tail in find_paths_in_tree(rest, child):
            if match_wordline(head, tree.root):
                paths.append(DepTree(tree.root, [tail], []))
    return paths
def find_paths_in_subtrees(patts: list[Pattern], tree: DepTree) -> list[DepTree]:
    """Find matching paths starting at the root of the tree or of any nested subtree."""
    collected = find_paths_in_tree(patts, tree)
    for child in tree.subtrees:
        collected.extend(find_paths_in_subtrees(patts, child))
    return collected
def find_partial_local_trees(patts: list[Pattern], tree: DepTree) -> list[DepTree]:
    """Find partial one-level trees at the root of ``tree``.

    ``patts[0]`` must match the root wordline; the remaining patterns must
    each match a different immediate subtree.  Every such assignment gives
    one result: the root with just the roots of the chosen subtrees as
    childless children, and empty comments.
    """
    if not patts or not match_wordline(patts[0], tree.root):
        return []

    results = []
    for chosen in different_matches(match_deptree, patts[1:], tree.subtrees):
        leaves = [DepTree(sub.root, [], []) for sub in chosen]
        results.append(DepTree(tree.root, leaves, []))
    return results
def find_partial_local_subtrees(patts: list[Pattern], tree: DepTree) -> list[DepTree]:
    """Find partial one-level trees at the root of the tree and of every nested subtree."""
    collected = find_partial_local_trees(patts, tree)
    for child in tree.subtrees:
        collected.extend(find_partial_local_subtrees(patts, child))
    return collected
class ParseError(Exception):
    """Exception type declared next to the pattern parser.

    NOTE(review): nothing in the visible code raises it - confirm its users.
    """
def parse_pattern(s: str) ->Pattern:
"to get a pattern from a string"
if not s.startswith('('): # add outer parentheses if missing
s = '(' + s + ')'
parse = nestedExpr().parseString(s)
def to_pattern(lisp):
match lisp:
case [fun, *args]:
args = [to_pattern(arg) for arg in args]
return Pattern(fun, args)
case tok:
return tok
return to_pattern(parse[0])