-
Notifications
You must be signed in to change notification settings - Fork 80
/
Copy pathoptimise_toy.py
82 lines (67 loc) · 2.75 KB
/
optimise_toy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from typing import cast
from xdsl.dialects.builtin import (
DenseIntOrFPElementsAttr,
)
from xdsl.ir import OpResult
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.utils.hints import isa
from ..dialects.toy import ConstantOp, ReshapeOp, TensorTypeF64, TransposeOp
class SimplifyRedundantTranspose(RewritePattern):
    """Cancel out two transposes applied back to back."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: TransposeOp, rewriter: PatternRewriter):
        """
        Fold transpose(transpose(x)) -> x

        The outer transpose (the matched ``op``) is replaced, without creating
        any new operation, by the operand of the inner transpose. The inner
        transpose itself is not erased by this pattern.
        """
        source = op.arg
        # A block argument (e.g. a function parameter) has no defining op, so
        # only an OpResult can possibly come from another transpose.
        if isinstance(source, OpResult):
            producer = source.op
            if isinstance(producer, TransposeOp):
                # No new ops; the outer transpose's result is rewired to x.
                rewriter.replace_op(op, [], [producer.arg])
class ReshapeReshapeOpPattern(RewritePattern):
    """Collapse a chain of two reshapes into a single reshape."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ReshapeOp, rewriter: PatternRewriter):
        """
        Fold reshape(reshape(x)) -> reshape(x)

        Only the outer reshape's result type matters, so the matched (outer)
        reshape is replaced by a single reshape of ``x`` straight to that
        type. The inner reshape is not erased by this pattern.
        """
        # Look at the input of the current reshape.
        reshape_input = op.arg
        if not isinstance(reshape_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return
        reshape_input_op = reshape_input.op
        if not isinstance(reshape_input_op, ReshapeOp):
            # Input defined by another reshape? If not, no match.
            return
        # The outer reshape's result type is the final shape of the chain.
        # `cast` only informs the type checker; it performs no runtime check.
        result_type = cast(TensorTypeF64, op.res.type)
        new_op = ReshapeOp.from_input_and_type(reshape_input_op.arg, result_type)
        rewriter.replace_matched_op(new_op)
class FoldConstantReshapeOpPattern(RewritePattern):
    """Fold a reshape of a constant into a new constant of the reshaped type."""

    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: ReshapeOp, rewriter: PatternRewriter):
        """
        Fold reshape(constant) -> constant

        Reshaping a constant can be done at compile time: the constant's
        element values are reused as-is, paired with the reshape's result
        type, and the reshape is replaced by a new ConstantOp. The original
        constant is not erased by this pattern.
        """
        # Look at the input of the current reshape.
        reshape_input = op.arg
        if not isinstance(reshape_input, OpResult):
            # Input was not produced by an operation, could be a function argument
            return
        reshape_input_op = reshape_input.op
        if not isinstance(reshape_input_op, ConstantOp):
            # Input defined by a constant? If not, no match.
            return
        # Narrows the result type for the type checker. It is an assert rather
        # than a bail-out because the result is presumably always an f64 tensor
        # per ReshapeOp's definition (TODO confirm); it is stripped under -O.
        assert isa(op.res.type, TensorTypeF64)
        # Same element values in the same order, reinterpreted with the
        # reshape's result type.
        new_value = DenseIntOrFPElementsAttr.create_dense_float(
            type=op.res.type, data=reshape_input_op.value.get_values()
        )
        new_op = ConstantOp(new_value)
        rewriter.replace_matched_op(new_op)