Skip to content

Commit

Permalink
use is_load_forwardable instead of is_sroa_eligible
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Jan 11, 2022
1 parent 6180254 commit dd0ec5b
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 41 deletions.
22 changes: 15 additions & 7 deletions src/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export
has_return_escape,
has_thrown_escape,
has_all_escape,
is_load_forwardable,
is_sroa_eligible,
can_elide_finalizer

Expand Down Expand Up @@ -158,24 +159,31 @@ has_thrown_escape(x::EscapeLattice) = x.ThrownEscape
has_thrown_escape(x::EscapeLattice, pc::Int) = has_thrown_escape(x) && pc in x.EscapeSites
has_all_escape(x::EscapeLattice) = AllEscape() x

# utility lattice constructors
ignore_aliasescapes(x::EscapeLattice) = EscapeLattice(x, BOT_ALIAS_ESCAPES)
has_aliasescapes(x::EscapeLattice) = !isa(x.AliasEscapes, Bool)

# TODO is_sroa_eligible: consider throwness?

"""
is_sroa_eligible(x::EscapeLattice) -> Bool
is_load_forwardable(x::EscapeLattice) -> Bool
Queries allocation eliminability by SROA.
Queries if `x` is elibigle for store-to-load forwarding optimization.
"""
function is_sroa_eligible(x::EscapeLattice)
function is_load_forwardable(x::EscapeLattice)
if x.AliasEscapes === false || # allows this query to work for immutables since we don't impose escape on them
isa(x.AliasEscapes, FieldEscapes)
return !has_return_escape(x) # TODO technically we also need to check !has_thrown_escape(x) as well
# NOTE technically we also need to check `!has_thrown_escape(x)` here as well,
# but we can also do equivalent check during forwarding
return true
end
return false
end

"""
is_sroa_eligible(x::EscapeLattice) -> Bool
Queries allocation eliminability by SROA.
"""
is_sroa_eligible(x::EscapeLattice) = is_load_forwardable(x) && !has_return_escape(x)

"""
can_elide_finalizer(x::EscapeLattice, pc::Int) -> Bool
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ function get_name_color(x::EscapeLattice, symbol::Bool = false)
name, color = (nothing, "*"), :red
end
name = symbol ? last(name) : first(name)
if name !== nothing && EA.has_aliasescapes(x)
if name !== nothing && !isa(x.AliasEscapes, Bool)
name = string(name, "")
end
return name, color
Expand Down
66 changes: 33 additions & 33 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ end
i = only(findall(isT(SafeRef{String}), result.ir.stmts.type))
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
let result = code_escapes((String,)) do a
t = (a,)
Expand All @@ -534,7 +534,7 @@ end
i = only(findall(t->t<:Tuple, result.ir.stmts.type))
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
let result = code_escapes((String, String)) do a, b
obj = SafeRefs(a, b)
Expand All @@ -546,7 +546,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r) # a
@test has_return_escape(result.state[Argument(3)], r) # b
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end

# field escape should propagate to `setfield!` argument
Expand All @@ -559,7 +559,7 @@ end
i = only(findall(isT(SafeRef{String}), result.ir.stmts.type))
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
# propagate escape information imposed on return value of `setfield!` call
let result = code_escapes((String,)) do a
Expand All @@ -569,7 +569,7 @@ end
i = only(findall(isT(SafeRef{String}), result.ir.stmts.type))
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end

# nested allocations
Expand All @@ -584,7 +584,7 @@ end
if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i])
@test has_return_escape(result.state[SSAValue(i)], r)
elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i])
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
end
Expand All @@ -599,7 +599,7 @@ end
if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i])
@test has_return_escape(result.state[SSAValue(i)], r)
elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i])
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
end
Expand All @@ -613,7 +613,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r)
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
let result = code_escapes() do
Expand Down Expand Up @@ -651,7 +651,7 @@ end
i = only(findall(isnew, result.ir.stmts.inst))
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end

# ϕ-node allocations
Expand All @@ -667,9 +667,9 @@ end
@test has_return_escape(result.state[Argument(3)], r) # x
@test has_return_escape(result.state[Argument(4)], r) # y
i = only(findall(isϕ, result.ir.stmts.inst))
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
let result = code_escapes((Bool,Any,Any)) do cond, x, y
Expand All @@ -684,10 +684,10 @@ end
@test has_return_escape(result.state[Argument(3)], r) # x
@test has_return_escape(result.state[Argument(4)], r) # y
for i in findall(isϕ, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
# when ϕ-node merges values with different types
Expand Down Expand Up @@ -725,7 +725,7 @@ end
end
i = findfirst(isnew, result.ir.stmts.inst)
@test has_all_escape(result.state[Argument(2)])
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
# alias via getfield & setfield!
let result = @eval EATModule() begin
Expand All @@ -740,7 +740,7 @@ end
end
i = findfirst(isnew, result.ir.stmts.inst)
@test has_all_escape(result.state[Argument(3)])
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
# alias via typeassert
let result = code_escapes((Any,)) do a
Expand Down Expand Up @@ -774,10 +774,10 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(3)], r) # x
for i in findall(isϕ, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x
Expand All @@ -792,10 +792,10 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(4)], r) # x
for i in findall(isϕ, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
# alias via π-node
Expand Down Expand Up @@ -852,7 +852,7 @@ end
i = only(findall(isT(SafeRef{String}), result.ir.stmts.type))
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r) # a
@test is_sroa_eligible(result.state[SSAValue(i)]) # obj
@test is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Symbol)) do a, b, fld
obj = SafeRefs(a, b)
Expand All @@ -862,7 +862,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r) # a
@test has_return_escape(result.state[Argument(3)], r) # b
@test is_sroa_eligible(result.state[SSAValue(i)]) # obj
@test is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Int)) do a, b, idx
obj = SafeRefs(a, b)
Expand All @@ -872,7 +872,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r) # a
@test has_return_escape(result.state[Argument(3)], r) # b
@test is_sroa_eligible(result.state[SSAValue(i)]) # obj
@test is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Symbol)) do a, b, fld
obj = SafeRefs("a", "b")
Expand All @@ -883,7 +883,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r) # a
@test !has_return_escape(result.state[Argument(3)], r) # b
@test is_sroa_eligible(result.state[SSAValue(i)]) # obj
@test is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, Symbol)) do a, fld
obj = SafeRefs("a", "b")
Expand All @@ -893,7 +893,7 @@ end
i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type))
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r) # a
@test is_sroa_eligible(result.state[SSAValue(i)]) # obj
@test is_load_forwardable(result.state[SSAValue(i)]) # obj
end
let result = code_escapes((String, String, Int)) do a, b, idx
obj = SafeRefs("a", "b")
Expand All @@ -904,7 +904,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r) # a
@test !has_return_escape(result.state[Argument(3)], r) # b
@test is_sroa_eligible(result.state[SSAValue(i)]) # obj
@test is_load_forwardable(result.state[SSAValue(i)]) # obj
end

# interprocedural
Expand All @@ -922,7 +922,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test has_return_escape(result.state[Argument(2)], r)
# NOTE we can't scalar replace `obj`, but still we may want to stack allocate it
@test_broken is_sroa_eligible(result.state[SSAValue(i)])
@test_broken is_load_forwardable(result.state[SSAValue(i)])
end

# TODO interprocedural field analysis
Expand All @@ -946,7 +946,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test_broken !has_return_escape(result.state[Argument(2)], r) # a
@test has_return_escape(result.state[Argument(3)], r) # b
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
let result = code_escapes((Any,Any)) do a, b
r = SafeRef{Any}(:init)
Expand All @@ -958,7 +958,7 @@ end
r = only(findall(isreturn, result.ir.stmts.inst))
@test_broken !has_return_escape(result.state[Argument(2)], r) # a
@test has_return_escape(result.state[Argument(3)], r) # b
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
let result = code_escapes((Any,Any,Bool)) do a, b, cond
r = SafeRef{Any}(:init)
Expand All @@ -971,7 +971,7 @@ end
end
end
i = only(findall(isnew, result.ir.stmts.inst))
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
r = only(findall(result.ir.stmts.inst) do @nospecialize x
isreturn(x) && isa(x.val, Core.SSAValue)
end)
Expand All @@ -998,7 +998,7 @@ end
@test has_return_escape(result.state[Argument(3)], r) # baz
@test has_return_escape(result.state[Argument(4)], r) # qux
for new in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(new)])
@test is_load_forwardable(result.state[SSAValue(new)])
end
end
let result = code_escapes((Bool,String,String,)) do cnd, baz, qux
Expand Down Expand Up @@ -1050,17 +1050,17 @@ function compute!(a, b)
end
let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im)
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im))
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end
let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im))
for i in findall(isnew, result.ir.stmts.inst)
@test is_sroa_eligible(result.state[SSAValue(i)])
@test is_load_forwardable(result.state[SSAValue(i)])
end
end

Expand Down

0 comments on commit dd0ec5b

Please sign in to comment.