Skip to content

Commit

Permalink
wip: runtime analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed May 12, 2023
1 parent 30078f0 commit c94714a
Show file tree
Hide file tree
Showing 4 changed files with 281 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ export
# optanalyzer
@report_opt, report_opt, @test_opt, test_opt,
# configurations
LastFrameModule, AnyFrameModule
LastFrameModule, AnyFrameModule,
@analysispass

let README = normpath(dirname(@__DIR__), "README.md")
s = read(README, String)
Expand Down Expand Up @@ -79,12 +80,12 @@ using .CC:
argextype, argtype_by_index, argtype_tail, argtypes_to_type, compute_basic_blocks,
get_compileable_sig, hasintersect, has_free_typevars, ignorelimited, inlining_enabled,
instanceof_tfunc, is_throw_call, isType, isconstType, issingletontype,
may_invoke_generator, singleton_type, slot_id, specialize_method, switchtupleunion,
tmerge, widenconst,
may_invoke_generator, retrieve_code_info, singleton_type, slot_id, specialize_method,
switchtupleunion, tmerge, widenconst,

using Base:
@invoke, @invokelatest, IdSet, default_tt, destructure_callex, parse_input_line,
@invoke, @invokelatest, IdSet, default_tt, parse_input_line,
rewrap_unionall, uniontypes, unwrap_unionall

using Base.Meta:
Expand Down Expand Up @@ -152,13 +153,20 @@ end

@static isdefined(CC, :StmtInfo) && import .CC: StmtInfo

@static if VERSION v"1.10.0-DEV.96"
# TODO investigate why `pass_generator` is called with `world === typemax(UInt)`
# and hit the error in the `Base._which` version
@static if false # VERSION ≥ v"1.10.0-DEV.96"
using Base: _which
else
# HACK This definition is same as the one defined in
# https://github.com/JuliaLang/julia/blob/38d24e574caab20529a61a6f7444c9e473724ccc/base/reflection.jl#L1565
# modulo that this version allows us to use it within a `@generated` context
# (see the commented out `world == typemax(UInt) && error(...)` line below).
function _which(@nospecialize(tt::Type);
method_table::Union{Nothing,MethodTable,Core.Compiler.MethodTableView}=nothing,
world::UInt=get_world_counter(),
raise::Bool=false)
# world == typemax(UInt) && error("code reflection cannot be used from generated functions")
if method_table === nothing
table = Core.Compiler.InternalMethodTable(world)
elseif isa(method_table, MethodTable)
Expand Down Expand Up @@ -1397,4 +1405,6 @@ using PrecompileTools
end
end

include("runtime.jl")

end # module JET
233 changes: 233 additions & 0 deletions src/runtime.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,233 @@
using Core.IR
using Base: to_tuple_type

abstract type AnalysisPass end
function get_constructor end
function get_jetconfigs end

# JuliaLang/julia#48611: world age is exposed to generated functions, and should be used
const has_generated_worlds = let
v = VERSION v"1.10.0-DEV.873"
v && @assert fieldcount(Core.GeneratedFunctionStub) == 3
v
end

function analyze_and_generate(world::UInt, source::LineNumberNode, passtype, fargtypes)
tt = to_tuple_type(fargtypes)
match = _which(tt; raise=false, world)
match === nothing && return nothing

mi = specialize_method(match)
Analyzer = get_constructor(passtype)
jetconfigs = get_jetconfigs(passtype)
analyzer = Analyzer(world; jetconfigs...)
analyzer, result = analyze_method_instance!(analyzer, mi)
analyzername = nameof(typeof(analyzer))
sig = LazyPrinter(io::IO->Base.show_tuple_as_call(io, Symbol(""), tt))
src = lazy"$analyzername: $sig"
res = JETCallResult(result, analyzer, src; jetconfigs...)

isempty(get_reports(res)) || return generate_report_error_ex(world, source, mi, res)

src = @static if has_generated_worlds
copy(retrieve_code_info(mi, world)::CodeInfo)
else
copy(retrieve_code_info(mi)::CodeInfo)
end
analysispass_transform!(src, mi, length(fargtypes))
return src
end

struct JETRuntimeError <: Exception
mi::MethodInstance
res::JETCallResult
end
function Base.showerror(io::IO, err::JETRuntimeError)
n = length(get_reports(err.res))
print(io, "JETRuntimeError raised by `$(err.res.source)`:")
println(io)
show(io, err.res)
end

function generate_report_error_ex(world::UInt, source::LineNumberNode,
mi::MethodInstance, res::JETCallResult)
args = Core.svec(:pass, :fargs)
sparams = Core.svec()
ex = :(throw($JETRuntimeError($mi, $res)))
return generate_lambda_ex(world, source, args, sparams, ex)
end

function generate_lambda_ex(world::UInt, source::LineNumberNode,
args::SimpleVector, sparams::SimpleVector, body::Expr)
stub = Core.GeneratedFunctionStub(identity, args, sparams)
return stub(world, source, body)
end

# TODO share this code with CassetteOverlay
function analysispass_transform!(src::CodeInfo, mi::MethodInstance, nargs::Int)
method = mi.def::Method
mnargs = Int(method.nargs)

src.slotnames = Symbol[Symbol("#self#"), :fargs, src.slotnames[mnargs+1:end]...]
src.slotflags = UInt8[ 0x00, 0x00, src.slotflags[mnargs+1:end]...]

code = src.code
fargsslot = SlotNumber(2)
precode = Any[]
local ssaid = 0
for i = 1:mnargs
if method.isva && i == mnargs
args = map(i:nargs) do j
push!(precode, Expr(:call, getfield, fargsslot, j))
ssaid += 1
return SSAValue(ssaid)
end
push!(precode, Expr(:call, tuple, args...))
else
push!(precode, Expr(:call, getfield, fargsslot, i))
end
ssaid += 1
end
prepend!(code, precode)
prepend!(src.codelocs, [0 for i = 1:ssaid])
prepend!(src.ssaflags, [0x00 for i = 1:ssaid])
src.ssavaluetypes += ssaid

function map_slot_number(slot::Int)
@assert slot 1
if 1 slot mnargs
if method.isva && slot == mnargs
return SSAValue(ssaid)
else
return SSAValue(slot)
end
else
return SlotNumber(slot - mnargs + 2)
end
end
map_ssa_value(id::Int) = id + ssaid
for i = (ssaid+1:length(code))
code[i] = transform_stmt(code[i], map_slot_number, map_ssa_value, mi.def.sig, mi.sparam_vals)
end

src.edges = MethodInstance[mi]
src.method_for_inference_limit_heuristics = method

return src
end

function transform_stmt(@nospecialize(x), map_slot_number, map_ssa_value, @nospecialize(spsig), sparams::SimpleVector)
transform(@nospecialize x′) = transform_stmt(x′, map_slot_number, map_ssa_value, spsig, sparams)

if isa(x, Expr)
head = x.head
if head === :call
return Expr(:call, SlotNumber(1), map(transform, x.args)...)
elseif head === :foreigncall
# first argument of :foreigncall is a magic tuple and should be preserved
arg2 = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), x.args[2], spsig, sparams)
arg3 = Core.svec(Any[
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, sparams)
for argt in x.args[3]::SimpleVector ]...)
return Expr(:foreigncall, x.args[1], arg2, arg3, map(transform, x.args[4:end])...)
elseif head === :enter
return Expr(:enter, map_ssa_value(x.args[1]::Int))
elseif head === :static_parameter
return sparams[x.args[1]::Int]
end
return Expr(x.head, map(transform, x.args)...)
elseif isa(x, GotoNode)
return GotoNode(map_ssa_value(x.label))
elseif isa(x, GotoIfNot)
return GotoIfNot(transform(x.cond), map_ssa_value(x.dest))
elseif isa(x, ReturnNode)
return ReturnNode(transform(x.val))
elseif isa(x, SlotNumber)
return map_slot_number(x.id)
elseif isa(x, NewvarNode)
return NewvarNode(map_slot_number(x.slot.id))
elseif isa(x, SSAValue)
return SSAValue(map_ssa_value(x.id))
else
return x
end
end

function pass_generator(world::UInt, source::LineNumberNode, pass, fargs)
src = analyze_and_generate(world, source, pass, fargs)
if src === nothing
# code generation failed – make it raise a proper MethodError
stub = Core.GeneratedFunctionStub(identity, Core.svec(:pass, :fargs), Core.svec())
return stub(world, source, :(return first(fargs)(Base.tail(fargs)...)))
end
return src
end

"""
@analysispass Analyzer [jetconfigs...]
TODO docs.
"""
macro analysispass(args...)
isempty(args) && throw(ArgumentError("`@analysispass` expected more than one argument."))
analyzertype = args[1]
params = Expr(:parameters)
append!(params.args, args[2:end])
jetconfigs = Expr(:tuple, params)

PassName = esc(gensym(string(analyzertype)))

blk = quote
let analyzertypetype = Core.Typeof($(esc(analyzertype)))
if !(analyzertypetype <: Type{<:$(@__MODULE__).AbstractAnalyzer})
throw(ArgumentError(
"`@analysispass` expected a subtype of `JET.AbstractAnalyzer`, but got object of `$analyzertypetype`."))
end
end

struct $PassName <: $AnalysisPass end

$(@__MODULE__).get_constructor(::Type{$PassName}) = $(esc(analyzertype))
$(@__MODULE__).get_jetconfigs(::Type{$PassName}) = $(esc(jetconfigs))

@inline function (::$PassName)(f::Union{Core.Builtin,Core.IntrinsicFunction}, args...)
@nospecialize f args
return f(args...)
end
@inline function (self::$PassName)(::typeof(Core.Compiler.return_type), tt::DataType)
# return Core.Compiler.return_type(self, tt)
return Core.Compiler.return_type(tt)
end
@inline function (self::$PassName)(::typeof(Core.Compiler.return_type), f, tt::DataType)
newtt = Base.signature_type(f, tt)
# return Core.Compiler.return_type(self, newtt)
return Core.Compiler.return_type(newtt)
end
@inline function (self::$PassName)(::typeof(Core._apply_iterate), iterate, f, args...)
@nospecialize args
return Core.Compiler._apply_iterate(iterate, self, (f,), args...)
end

@static if $has_generated_worlds
function (pass::$PassName)(fargs...)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, pass_generator))
end
else
@generated function (pass::$PassName)($(esc(:fargs))...)
world = Base.get_world_counter()
source = LineNumberNode(@__LINE__, @__FILE__)
src = $analyze_and_generate(world, pass, fargs)
if src === nothing
# a code generation failed – make it raise a proper MethodError
return :(first(fargs)(Base.tail(fargs)...))
end
return src
end
end

return $PassName()
end

return Expr(:toplevel, blk.args...)
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ using Test, JET
@testset "OptAnalyzer" include("analyzers/test_optanalyzer.jl")
end

@testset "runtime" include("runtime.jl")

@testset "performance" include("performance.jl")

@testset "sanity check" include("sanity_check.jl")
Expand Down
31 changes: 31 additions & 0 deletions test/test_runtime.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
module test_runtime

using JET, Test

call_xs(f, xs) = f(xs[])

@test_throws "Type{$Int}" @analysispass Int

pass1 = @analysispass JET.OptAnalyzer
@test pass1() do
call_xs(sin, Ref(42))
end == sin(42)
@test_throws JET.JETRuntimeError pass1() do
call_xs(sin, Ref{Any}(42))
end

function_filter(@nospecialize f) = f !== sin
pass2 = @analysispass JET.OptAnalyzer function_filter
@test pass2() do
call_xs(sin, Ref(42))
end == sin(42)
@test pass2() do
call_xs(sin, Ref{Any}(42))
end

pass3 = @analysispass JET.JETAnalyzer
@test pass3() do
collect(1:10)
end == collect(1:10)

end # module test_runtime

0 comments on commit c94714a

Please sign in to comment.