Skip to content

Commit

Permalink
inference: run find_throw_blocks once before inference (#42149)
Browse files Browse the repository at this point in the history
With this commit, we mark `src.ssaflags` before inference, and optimizer
will just observe them.

Fixes #42148
  • Loading branch information
aviatesk authored Sep 8, 2021
1 parent 15b9851 commit ed3691f
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 46 deletions.
2 changes: 1 addition & 1 deletion base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ end
function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
fargs::Union{Nothing,Vector{Any}}, argtypes::Vector{Any}, @nospecialize(atype),
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
if sv.params.unoptimize_throw_blocks && sv.currpc in sv.throw_blocks
if sv.params.unoptimize_throw_blocks && is_stmt_throw_block(get_curr_ssaflag(sv))
add_remark!(interp, sv, "Skipped call in throw block")
return CallMeta(Any, false)
end
Expand Down
12 changes: 7 additions & 5 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ mutable struct InferenceState
handler_at::Vector{LineNum}
# ssavalue sparsity and restart info
ssavalue_uses::Vector{BitSet}
throw_blocks::BitSet

cycle_backedges::Vector{Tuple{InferenceState, LineNum}} # call-graph backedges connecting from callee to caller
callers_in_cycle::Vector{InferenceState}
Expand All @@ -57,6 +56,8 @@ mutable struct InferenceState
(; def) = linfo = result.linfo
code = src.code::Vector{Any}

params = InferenceParams(interp)

sp = sptypes_from_meth_instance(linfo::MethodInstance)

nssavalues = src.ssavaluetypes::Int
Expand All @@ -81,26 +82,28 @@ mutable struct InferenceState
s_types[1] = s_argtypes

ssavalue_uses = find_ssavalue_uses(code, nssavalues)
throw_blocks = find_throw_blocks(code)

# exception handlers
ip = BitSet()
handler_at = compute_trycatch(src.code, ip)
push!(ip, 1)

# `throw` block deoptimization
params.unoptimize_throw_blocks && mark_throw_blocks!(src, handler_at)

mod = isa(def, Method) ? def.module : def
valid_worlds = WorldRange(src.min_world,
src.max_world == typemax(UInt) ? get_world_counter() : src.max_world)

@assert cache === :no || cache === :local || cache === :global
frame = new(
InferenceParams(interp), result, linfo,
params, result, linfo,
sp, slottypes, mod, 0,
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
Union{}, ip, 1, n, handler_at,
ssavalue_uses, throw_blocks,
ssavalue_uses,
Vector{Tuple{InferenceState,LineNum}}(), # cycle_backedges
Vector{InferenceState}(), # callers_in_cycle
#=parent=#nothing,
Expand Down Expand Up @@ -197,7 +200,6 @@ function compute_trycatch(code::Vector{Any}, ip::BitSet)
return handler_at
end


"""
Iterate through all callers of the given InferenceState in the abstract
interpretation stack (including the given InferenceState itself), vising
Expand Down
25 changes: 12 additions & 13 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,15 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError

# This statement is marked as @inbounds by user.
# Ff replaced by inlining, any contained boundschecks may be removed.
const IR_FLAG_INBOUNDS = 0x01 << 0
const IR_FLAG_INBOUNDS = 0x01 << 0
# This statement is marked as @inline by user
const IR_FLAG_INLINE = 0x01 << 1
const IR_FLAG_INLINE = 0x01 << 1
# This statement is marked as @noinline by user
const IR_FLAG_NOINLINE = 0x01 << 2
const IR_FLAG_NOINLINE = 0x01 << 2
const IR_FLAG_THROW_BLOCK = 0x01 << 3
# This statement may be removed if its result is unused. In particular it must
# thus be both pure and effect free.
const IR_FLAG_EFFECT_FREE = 0x01 << 4
const IR_FLAG_EFFECT_FREE = 0x01 << 4

# known to be always effect-free (in particular nothrow)
const _PURE_BUILTINS = Any[tuple, svec, ===, typeof, nfields]
Expand Down Expand Up @@ -194,8 +195,9 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
return inlineable
end

is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE != 0
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE != 0
is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE 0
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE 0
is_stmt_throw_block(stmt_flag::UInt8) = stmt_flag & IR_FLAG_THROW_BLOCK 0

# These affect control flow within the function (so may not be removed
# if there is no usage within the function), but don't affect the purity
Expand Down Expand Up @@ -533,13 +535,12 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
end

function statement_or_branch_cost(@nospecialize(stmt), line::Int, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any},
slottypes::Vector{Any}, union_penalties::Bool, params::OptimizationParams,
throw_blocks::Union{Nothing,BitSet})
slottypes::Vector{Any}, union_penalties::Bool, params::OptimizationParams)
thiscost = 0
dst(tgt) = isa(src, IRCode) ? first(src.cfg.blocks[tgt].stmts) : tgt
if stmt isa Expr
thiscost = statement_cost(stmt, line, src, sptypes, slottypes, union_penalties, params,
throw_blocks !== nothing && line in throw_blocks)::Int
is_stmt_throw_block(isa(src, IRCode) ? src.stmts.flag[line] : src.ssaflags[line]))::Int
elseif stmt isa GotoNode
# loops are generally always expensive
# but assume that forward jumps are already counted for from
Expand All @@ -554,24 +555,22 @@ end
function inline_worthy(ir::IRCode,
params::OptimizationParams, union_penalties::Bool=false, cost_threshold::Integer=params.inline_cost_threshold)
bodycost::Int = 0
throw_blocks = params.unoptimize_throw_blocks ? find_throw_blocks(ir.stmts.inst, RefValue(ir)) : nothing
for line = 1:length(ir.stmts)
stmt = ir.stmts[line][:inst]
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, ir.argtypes, union_penalties, params, throw_blocks)
thiscost = statement_or_branch_cost(stmt, line, ir, ir.sptypes, ir.argtypes, union_penalties, params)
bodycost = plus_saturate(bodycost, thiscost)
bodycost > cost_threshold && return false
end
return true
end

function statement_costs!(cost::Vector{Int}, body::Vector{Any}, src::Union{CodeInfo, IRCode}, sptypes::Vector{Any}, unionpenalties::Bool, params::OptimizationParams)
throw_blocks = params.unoptimize_throw_blocks ? find_throw_blocks(body) : nothing
maxcost = 0
for line = 1:length(body)
stmt = body[line]
thiscost = statement_or_branch_cost(stmt, line, src, sptypes,
src isa CodeInfo ? src.slottypes : src.argtypes,
unionpenalties, params, throw_blocks)
unionpenalties, params)
cost[line] = thiscost
if thiscost > maxcost
maxcost = thiscost
Expand Down
4 changes: 0 additions & 4 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,6 @@ struct OptimizationParams
MAX_TUPLE_SPLAT::Int
MAX_UNION_SPLITTING::Int

unoptimize_throw_blocks::Bool

function OptimizationParams(;
inlining::Bool = inlining_enabled(),
inline_cost_threshold::Int = 100,
Expand All @@ -64,7 +62,6 @@ struct OptimizationParams
max_methods::Int = 3,
tuple_splat::Int = 32,
union_splitting::Int = 4,
unoptimize_throw_blocks::Bool = true,
)
return new(
inlining,
Expand All @@ -75,7 +72,6 @@ struct OptimizationParams
max_methods,
tuple_splat,
union_splitting,
unoptimize_throw_blocks,
)
end
end
Expand Down
38 changes: 15 additions & 23 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,25 +319,27 @@ function is_throw_call(e::Expr)
return false
end

function find_throw_blocks(code::Vector{Any}, ir = RefValue{IRCode}())
function mark_throw_blocks!(src::CodeInfo, handler_at::Vector{Int})
for stmt in find_throw_blocks(src.code, handler_at)
src.ssaflags[stmt] |= IR_FLAG_THROW_BLOCK
end
return nothing
end

function find_throw_blocks(code::Vector{Any}, handler_at::Vector{Int})
stmts = BitSet()
n = length(code)
try_depth = 0
for i in n:-1:1
s = code[i]
if isa(s, Expr)
if s.head === :enter
try_depth -= 1
elseif s.head === :leave
try_depth += (s.args[1]::Int)
elseif s.head === :gotoifnot
tgt = s.args[2]::Int
if i+1 in stmts && tgt in stmts
if s.head === :gotoifnot
if i+1 in stmts && s.args[2]::Int in stmts
push!(stmts, i)
end
elseif s.head === :return
# see `ReturnNode` handling
elseif is_throw_call(s)
if try_depth == 0
if handler_at[i] == 0
push!(stmts, i)
end
elseif i+1 in stmts
Expand All @@ -348,22 +350,12 @@ function find_throw_blocks(code::Vector{Any}, ir = RefValue{IRCode}())
# (where !isdefined(s, :val)) as `throw` points, but that can cause
# worse codegen around the call site (issue #37558)
elseif isa(s, GotoNode)
tgt = s.label
if isassigned(ir)
tgt = first(ir[].cfg.blocks[tgt].stmts)
end
if tgt in stmts
if s.label in stmts
push!(stmts, i)
end
elseif isa(s, GotoIfNot)
if i+1 in stmts
tgt = s.dest::Int
if isassigned(ir)
tgt = first(ir[].cfg.blocks[tgt].stmts)
end
if tgt in stmts
push!(stmts, i)
end
if i+1 in stmts && s.dest in stmts
push!(stmts, i)
end
elseif i+1 in stmts
push!(stmts, i)
Expand Down
1 change: 1 addition & 0 deletions test/compiler/contextual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ module MiniCassette
# Insert one SSAValue for every argument statement
prepend!(code, [Expr(:call, getfield, SlotNumber(4), i) for i = 1:nargs])
prepend!(ci.codelocs, [0 for i = 1:nargs])
prepend!(ci.ssaflags, [0x00 for i = 1:nargs])
ci.ssavaluetypes += nargs
function map_slot_number(slot)
if slot == 1
Expand Down

0 comments on commit ed3691f

Please sign in to comment.