-
Notifications
You must be signed in to change notification settings - Fork 1
/
midsection.jl
94 lines (77 loc) · 2.99 KB
/
midsection.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
load("utils/req.jl")
req("dag/transforms.jl")
# -- scattering fusion --------------------------------------------------------
# Empty tag type; used only as the parameter of Context to select the
# scattering-fusion `evaluate` methods defined below.
type Scatterer; end
typealias ScatterContext Context{Scatterer}

# Rewrite `node` by evaluating it under a fresh ScatterContext.
function scattered(node::Node)
    return evaluate(ScatterContext(), node)
end
# Scatter a bare symbol node.
# An ellipsis over the local symbol `indvars` is built first. A symbol whose
# name equals :(... ) (presumably the `...` symbol itself) is replaced by that
# ellipsis; any other symbol is wrapped as a RefNode indexed by it.
function evaluate(c::ScatterContext, node::SymNode)
    # todo: eliminate duplicate indvars?
    over_indvars = EllipsisNode(SymNode(:indvars, :local))
    if node.val.name == :(... )
        return over_indvars
    end
    return RefNode(node, over_indvars)
end
# Scatter a call or ref node.
# The first argument (operator / ref target) must be a SymNode and is kept
# unchanged; the remaining arguments are scattered. The node is rebuilt from
# the head followed by the scattered arguments.
function evaluate(c::ScatterContext, node::Union(CallNode,RefNode))
    head = node.args[1]::SymNode
    # todo: pass node.args[2:end] to evaluate directly once the ref result
    # type bug is fixed; for now the slice is re-collected into a Node[] array.
    rest = Node[node.args[2:end]...]
    return Node(node, { head, evaluate(c, rest)... })
end
# == Front midsection =========================================================
# -- scatter propagation ------------------------------------------------------
# Empty tag type; used only as the parameter of Context to select the
# scatter-propagation `evaluate`/`scattered` methods defined below.
type ScatterPropagator; end
typealias ScatterPropContext Context{ScatterPropagator}

# Evaluate the node graph rooted at `sink` under a fresh ScatterPropContext.
function scatter_propagated(sink::Node)
    return evaluate(ScatterPropContext(), sink)
end
# Scatter propagation for call nodes.
# A call to `scatter` must have exactly one argument and is replaced by the
# scattered form of that argument (computed via @cached, presumably memoized
# in `c` - confirm). Every other call goes to default_evaluate.
function evaluate(c::ScatterPropContext, node::CallNode)
    is_scatter = get_op(node).val == SymbolEx(:scatter, :call)
    if !is_scatter
        return default_evaluate(c, node)
    end
    args = get_callargs(node)
    @expect length(args)==1
    return (@cached scattered(c, args[1]))
end
# Scatter every node in `ns` (expected to be Nodes), each through @cached.
# Returns a cell array in the same order as `ns`.
function scattered(c::ScatterPropContext, ns::Vector)
    return { (@cached scattered(c, node)) for node in ns }
end
# Scattered form of a call node inside a ScatterPropContext.
#  - scatter(x): scatter of a scatter is just a scatter, so the result is the
#    scattered form of the single argument x (must have exactly one argument).
#  - any other call: the operator is mapped through scattered_op (which raises
#    an error for *, / and \) and every argument is scattered; a new CallNode
#    is built from the results.
function scattered(c::ScatterPropContext, node::CallNode)
    op = get_op(node).val
    if op == SymbolEx(:scatter, :call)
        # let scatter*scatter = scatter
        # get_callargs is a function: it must be called, not indexed.
        args = get_callargs(node)
        @expect length(args)==1
        return scattered(c, args[1])
    else
        newop = scattered_op(op)
        args = scattered(c, get_callargs(node))
        return CallNode(newop, args...)
    end
end
# todo: scattered for other node types: RefNode, ...more?
# Scattered form of a symbol node: the symbol must be an input, and it is
# wrapped as a RefNode indexed by the `...` symbol.
function scattered(c::ScatterPropContext, node::SymNode)
    @expect node.val.kind == :input "expected kind == :input, got $(node.val.kind)"
    ellipsis_ind = SymNode(:..., :symbol)
    return RefNode(node, ellipsis_ind)
end
# Map a call operator to its scattered counterpart: a :call SymNode with the
# same name. The operators *, / and \ cannot be scattered and raise an error.
function scattered_op(ex::SymbolEx)
    @expect ex.kind == :call
    opname = ex.name
    contains([:*, :/, :\ ], opname) && error("scattered: cannot scatter op = ", opname)
    return SymNode(opname, :call)
end
# == Back midsection ==========================================================
# Rewrite the DAG rooted at `sink` with expandell_rewrite.
# Each name in `indvars` is wrapped as a :local SymNode. Refs whose only
# index is `...` get that index replaced by these nodes.
function expand_ellipsis_indexing(sink::Node, indvars::Vector{Symbol})
    indnodes = { SymNode(name, :local) for name in indvars }
    rewrite = (node, args)->expandell_rewrite(node, args, indnodes)
    return rewrite_dag(sink, rewrite)
end
# Fallback rewrite: rebuild `oldnode` from `args` (presumably its children as
# already rewritten by rewrite_dag); `indvars` is unused here.
expandell_rewrite(oldnode::Node, args::Vector, indvars::Vector) = Node(oldnode, args)
# Rewrite for ref nodes. If the original ref has exactly one index and that
# index is the `...` symbol, the node is rebuilt from args[1] (the ref target)
# followed by `indvars`, dropping the `...` index. Otherwise it is rebuilt
# from `args` unchanged.
function expandell_rewrite(oldnode::RefNode, args::Vector, indvars::Vector)
    inds = get_inds(oldnode)
    lone_ellipsis = (length(inds)==1) && (inds[1].val==SymbolEx(:..., :symbol))
    if !lone_ellipsis
        return Node(oldnode, args)
    end
    return Node(oldnode, {args[1], indvars...})
end