diff --git a/src/EscapeAnalysis.jl b/src/EscapeAnalysis.jl index b5ee82d..4dceeab 100644 --- a/src/EscapeAnalysis.jl +++ b/src/EscapeAnalysis.jl @@ -8,6 +8,7 @@ export has_return_escape, has_thrown_escape, has_all_escape, + is_load_forwardable, is_sroa_eligible, can_elide_finalizer @@ -158,24 +159,31 @@ has_thrown_escape(x::EscapeLattice) = x.ThrownEscape has_thrown_escape(x::EscapeLattice, pc::Int) = has_thrown_escape(x) && pc in x.EscapeSites has_all_escape(x::EscapeLattice) = AllEscape() ⊑ x +# utility lattice constructors ignore_aliasescapes(x::EscapeLattice) = EscapeLattice(x, BOT_ALIAS_ESCAPES) -has_aliasescapes(x::EscapeLattice) = !isa(x.AliasEscapes, Bool) - -# TODO is_sroa_eligible: consider throwness? """ - is_sroa_eligible(x::EscapeLattice) -> Bool + is_load_forwardable(x::EscapeLattice) -> Bool -Queries allocation eliminability by SROA. +Queries if `x` is elibigle for store-to-load forwarding optimization. """ -function is_sroa_eligible(x::EscapeLattice) +function is_load_forwardable(x::EscapeLattice) if x.AliasEscapes === false || # allows this query to work for immutables since we don't impose escape on them isa(x.AliasEscapes, FieldEscapes) - return !has_return_escape(x) # TODO technically we also need to check !has_thrown_escape(x) as well + # NOTE technically we also need to check `!has_thrown_escape(x)` here as well, + # but we can also do equivalent check during forwarding + return true end return false end +""" + is_sroa_eligible(x::EscapeLattice) -> Bool + +Queries allocation eliminability by SROA. +""" +is_sroa_eligible(x::EscapeLattice) = is_load_forwardable(x) && !has_return_escape(x) + """ can_elide_finalizer(x::EscapeLattice, pc::Int) -> Bool diff --git a/src/utils.jl b/src/utils.jl index 373155c..0aa912f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -202,7 +202,7 @@ function get_name_color(x::EscapeLattice, symbol::Bool = false) name, color = (nothing, "*"), :red end name = symbol ? last(name) : first(name) - if name !== nothing && EA.has_aliasescapes(x) + if name !== nothing && !isa(x.AliasEscapes, Bool) name = string(name, "′") end return name, color diff --git a/test/runtests.jl b/test/runtests.jl index d52d230..e6d88fd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -524,7 +524,7 @@ end i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end let result = code_escapes((String,)) do a t = (a,) @@ -534,7 +534,7 @@ end i = only(findall(t->t<:Tuple, result.ir.stmts.type)) r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end let result = code_escapes((String, String)) do a, b obj = SafeRefs(a, b) @@ -546,7 +546,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # a @test has_return_escape(result.state[Argument(3)], r) # b - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end # field escape should propagate to `setfield!` argument @@ -559,7 +559,7 @@ end i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end # propagate escape information imposed on return value of `setfield!` call let result = code_escapes((String,)) do a @@ -569,7 +569,7 @@ end i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end # nested allocations @@ -584,7 +584,7 @@ end if isnew(result.ir.stmts.inst[i]) && isT(SafeRef{String})(result.ir.stmts.type[i]) @test has_return_escape(result.state[SSAValue(i)], r) elseif isnew(result.ir.stmts.inst[i]) && isT(SafeRef{SafeRef{String}})(result.ir.stmts.type[i]) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end end @@ -599,7 +599,7 @@ end if isnew(result.ir.stmts.inst[i]) && isT(Tuple{String})(result.ir.stmts.type[i]) @test has_return_escape(result.state[SSAValue(i)], r) elseif isnew(result.ir.stmts.inst[i]) && isT(Tuple{Tuple{String}})(result.ir.stmts.type[i]) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end end @@ -613,7 +613,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end let result = code_escapes() do @@ -651,7 +651,7 @@ end i = only(findall(isnew, result.ir.stmts.inst)) r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end # ϕ-node allocations @@ -667,9 +667,9 @@ end @test has_return_escape(result.state[Argument(3)], r) # x @test has_return_escape(result.state[Argument(4)], r) # y i = only(findall(isϕ, result.ir.stmts.inst)) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end let result = code_escapes((Bool,Any,Any)) do cond, x, y @@ -684,10 +684,10 @@ end @test has_return_escape(result.state[Argument(3)], r) # x @test has_return_escape(result.state[Argument(4)], r) # y for i in findall(isϕ, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end # when ϕ-node merges values with different types @@ -725,7 +725,7 @@ end end i = findfirst(isnew, result.ir.stmts.inst) @test has_all_escape(result.state[Argument(2)]) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end # alias via getfield & setfield! let result = @eval EATModule() begin @@ -740,7 +740,7 @@ end end i = findfirst(isnew, result.ir.stmts.inst) @test has_all_escape(result.state[Argument(3)]) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end # alias via typeassert let result = code_escapes((Any,)) do a @@ -774,10 +774,10 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(3)], r) # x for i in findall(isϕ, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end let result = code_escapes((Bool,Bool,String)) do cond1, cond2, x @@ -792,10 +792,10 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(4)], r) # x for i in findall(isϕ, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end # alias via π-node @@ -852,7 +852,7 @@ end i = only(findall(isT(SafeRef{String}), result.ir.stmts.type)) r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # a - @test is_sroa_eligible(result.state[SSAValue(i)]) # obj + @test is_load_forwardable(result.state[SSAValue(i)]) # obj end let result = code_escapes((String, String, Symbol)) do a, b, fld obj = SafeRefs(a, b) @@ -862,7 +862,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # a @test has_return_escape(result.state[Argument(3)], r) # b - @test is_sroa_eligible(result.state[SSAValue(i)]) # obj + @test is_load_forwardable(result.state[SSAValue(i)]) # obj end let result = code_escapes((String, String, Int)) do a, b, idx obj = SafeRefs(a, b) @@ -872,7 +872,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # a @test has_return_escape(result.state[Argument(3)], r) # b - @test is_sroa_eligible(result.state[SSAValue(i)]) # obj + @test is_load_forwardable(result.state[SSAValue(i)]) # obj end let result = code_escapes((String, String, Symbol)) do a, b, fld obj = SafeRefs("a", "b") @@ -883,7 +883,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # a @test !has_return_escape(result.state[Argument(3)], r) # b - @test is_sroa_eligible(result.state[SSAValue(i)]) # obj + @test is_load_forwardable(result.state[SSAValue(i)]) # obj end let result = code_escapes((String, Symbol)) do a, fld obj = SafeRefs("a", "b") @@ -893,7 +893,7 @@ end i = only(findall(isT(SafeRefs{String,String}), result.ir.stmts.type)) r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # a - @test is_sroa_eligible(result.state[SSAValue(i)]) # obj + @test is_load_forwardable(result.state[SSAValue(i)]) # obj end let result = code_escapes((String, String, Int)) do a, b, idx obj = SafeRefs("a", "b") @@ -904,7 +904,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # a @test !has_return_escape(result.state[Argument(3)], r) # b - @test is_sroa_eligible(result.state[SSAValue(i)]) # obj + @test is_load_forwardable(result.state[SSAValue(i)]) # obj end # interprocedural @@ -922,7 +922,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test has_return_escape(result.state[Argument(2)], r) # NOTE we can't scalar replace `obj`, but still we may want to stack allocate it - @test_broken is_sroa_eligible(result.state[SSAValue(i)]) + @test_broken is_load_forwardable(result.state[SSAValue(i)]) end # TODO interprocedural field analysis @@ -946,7 +946,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test_broken !has_return_escape(result.state[Argument(2)], r) # a @test has_return_escape(result.state[Argument(3)], r) # b - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end let result = code_escapes((Any,Any)) do a, b r = SafeRef{Any}(:init) @@ -958,7 +958,7 @@ end r = only(findall(isreturn, result.ir.stmts.inst)) @test_broken !has_return_escape(result.state[Argument(2)], r) # a @test has_return_escape(result.state[Argument(3)], r) # b - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end let result = code_escapes((Any,Any,Bool)) do a, b, cond r = SafeRef{Any}(:init) @@ -971,7 +971,7 @@ end end end i = only(findall(isnew, result.ir.stmts.inst)) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) r = only(findall(result.ir.stmts.inst) do @nospecialize x isreturn(x) && isa(x.val, Core.SSAValue) end) @@ -998,7 +998,7 @@ end @test has_return_escape(result.state[Argument(3)], r) # baz @test has_return_escape(result.state[Argument(4)], r) # qux for new in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(new)]) + @test is_load_forwardable(result.state[SSAValue(new)]) end end let result = code_escapes((Bool,String,String,)) do cnd, baz, qux @@ -1050,17 +1050,17 @@ function compute!(a, b) end let result = @code_escapes compute(MPoint, 1+.5im, 2+.5im, 2+.25im, 4+.75im) for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end let result = @code_escapes compute(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end let result = @code_escapes compute!(MPoint(1+.5im, 2+.5im), MPoint(2+.25im, 4+.75im)) for i in findall(isnew, result.ir.stmts.inst) - @test is_sroa_eligible(result.state[SSAValue(i)]) + @test is_load_forwardable(result.state[SSAValue(i)]) end end