Skip to content

Commit

Permalink
Fix type unstable stack (#478)
Browse files Browse the repository at this point in the history
* Fix type unstable stack

* fix

* Update src/Overlay.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update Overlay.jl

* with test

* fix

* Fix stack recursive inference issue

* fix

* fix

* fix

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
wsmoses and github-actions[bot] authored Jan 6, 2025
1 parent 03b8363 commit 75c3140
Show file tree
Hide file tree
Showing 3 changed files with 210 additions and 31 deletions.
15 changes: 13 additions & 2 deletions src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,10 @@ for (cT, aT, bT) in (
C .= C2
end
else
LinearAlgebra.mul!(C, A, B, α, β)
# Inference barrier is required when calling function recursively within overload
# This is required since otherwise type inference will think this is a recursive edge
# rather than a call to the base method
Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β)
end
return C
end
Expand All @@ -150,6 +153,14 @@ end
if use_overlayed_version(iter)
return TracedRArrayOverrides.overloaded_stack(dims, iter)
else
return Base._stack(dims, Base.IteratorSize(iter), iter)
iter2 = collect(iter)
if any(use_overlayed_version, iter2)
return TracedRArrayOverrides.overloaded_stack(dims, iter2)
else
# Inference barrier is required when calling function recursively within overload
# This is required since otherwise type inference will think this is a recursive edge
# rather than a call to the base method
return Base.inferencebarrier(Base._stack)(dims, iter2)
end
end
end
182 changes: 153 additions & 29 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,8 @@ function is_reactant_method(mi::Core.MethodInstance)
return mt === REACTANT_METHOD_TABLE
end

struct MustThrowError end

@generated function applyiterate_with_reactant(
iteratefn, applyfn, args::Vararg{Any,N}
) where {N}
Expand All @@ -183,7 +185,29 @@ end
end
end

function rewrite_inst(inst, ir, interp)
@generated function applyiterate_with_reactant(
mt::MustThrowError, iteratefn, applyfn, args::Vararg{Any,N}
) where {N}
@assert iteratefn == typeof(Base.iterate)
newargs = Vector{Expr}(undef, N)
for i in 1:N
@inbounds newargs[i] = :(args[$i]...)
end
quote
Base.@_inline_meta
call_with_reactant(mt, applyfn, $(newargs...))
end
end

function certain_error()
throw(
AssertionError(
"The inferred code was guaranteed to throw this error. And yet, it didn't. So here we are...",
),
)
end

function rewrite_inst(inst, ir, interp, RT, guaranteed_error)
if Meta.isexpr(inst, :call)
# Even if type unstable we do not want (or need) to replace intrinsic
# calls or builtins with our version.
Expand All @@ -194,12 +218,27 @@ function rewrite_inst(inst, ir, interp)
if ft == typeof(Core._apply_iterate)
ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir))
if should_rewrite_ft(ft)
rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...)
return true, rep
if RT === Union{}
rep = Expr(
:call,
applyiterate_with_reactant,
MustThrowError(),
inst.args[2:end]...,
)
return true, rep, Union{}
else
rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...)
return true, rep, Any
end
end
elseif should_rewrite_ft(ft)
rep = Expr(:call, call_with_reactant, inst.args...)
return true, rep
if RT === Union{}
rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...)
return true, rep, Union{}
else
rep = Expr(:call, call_with_reactant, inst.args...)
return true, rep, Any
end
end
end
if Meta.isexpr(inst, :invoke)
Expand All @@ -215,18 +254,35 @@ function rewrite_inst(inst, ir, interp)
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))

# RT = Any

if !method.isva || !Base.isvarargtype(sig.parameters[end])
sig2 = Tuple{typeof(call_with_reactant),sig.parameters...}
if RT === Union{}
sig2 = Tuple{
typeof(call_with_reactant),MustThrowError,sig.parameters...
}
else
sig2 = Tuple{typeof(call_with_reactant),sig.parameters...}
end
else
vartup = inst.args[end]
ns = Type[]
eT = sig.parameters[end].T
for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1))
push!(ns, eT)
end
sig2 = Tuple{
typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns...
}
if RT === Union{}
sig2 = Tuple{
typeof(call_with_reactant),
MustThrowError,
sig.parameters[1:(end - 1)]...,
ns...,
}
else
sig2 = Tuple{
typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns...
}
end
end

lookup_result = lookup_world(
Expand All @@ -244,11 +300,41 @@ function rewrite_inst(inst, ir, interp)
match.sparams,
)
n_method_args = method.nargs
rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...)
return true, rep
if RT === Union{}
rep = Expr(
:invoke, mi, call_with_reactant, MustThrowError(), inst.args[2:end]...
)
return true, rep, Union{}
else
rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...)
return true, rep, Any
end
end
end
return false, inst
if isa(inst, Core.ReturnNode) && (!isdefined(inst, :val) || guaranteed_error)
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))

sig2 = Tuple{typeof(certain_error)}

lookup_result = lookup_world(
sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world
)

match = lookup_result::Core.MethodMatch
# look up the method and code instance
mi = ccall(
:jl_specializations_get_linfo,
Ref{Core.MethodInstance},
(Any, Any, Any),
match.method,
match.spec_types,
match.sparams,
)
rep = Expr(:invoke, mi, certain_error)
return true, rep, Union{}
end
return false, inst, RT
end

const oc_capture_vec = Vector{Any}()
Expand Down Expand Up @@ -334,19 +420,22 @@ const DEBUG_INTERP = Ref(false)
# to Any if our interpreter would change the return type of any result.
# Also rewrite invoke (type stable call) to be :call, since otherwise apparently
# screws up type inference after this (TODO this should be fixed).
function rewrite_insts!(ir, interp)
function rewrite_insts!(ir, interp, guaranteed_error)
any_changed = false
for (i, inst) in enumerate(ir.stmts)
# Explicitly skip any code which returns Union{} so that we throw the error
# instead of risking a segfault
RT = inst[:type]
@static if VERSION < v"1.11"
changed, next = rewrite_inst(inst[:inst], ir, interp)
changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error)
Core.Compiler.setindex!(ir.stmts[i], next, :inst)
else
changed, next = rewrite_inst(inst[:stmt], ir, interp)
changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error)
Core.Compiler.setindex!(ir.stmts[i], next, :stmt)
end
if changed
any_changed = true
Core.Compiler.setindex!(ir.stmts[i], Any, :type)
Core.Compiler.setindex!(ir.stmts[i], RT, :type)
end
end
return ir, any_changed
Expand All @@ -372,21 +461,30 @@ function call_with_reactant_generator(
identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec()
)

fn = args[1]
sig = Tuple{args...}

guaranteed_error = false
if fn === MustThrowError
guaranteed_error = true
fn = args[2]
sig = Tuple{args[2:end]...}
end

# look up the method match
builtin_error =
:(throw(AssertionError("Unsupported call_with_reactant of builtin $(args[1])")))
:(throw(AssertionError("Unsupported call_with_reactant of builtin $fn")))

if args[1] <: Core.Builtin
if fn <: Core.Builtin
return stub(world, source, builtin_error)
end

method_error = :(throw(
MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world)
))

interp = ReactantInterpreter(; world)

sig = Tuple{args...}

min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))

Expand Down Expand Up @@ -427,8 +525,19 @@ function call_with_reactant_generator(
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
end

if !is_reactant_method(mi::Core.MethodInstance)
ir, any_changed = rewrite_insts!(ir, interp)
if guaranteed_error
if rt !== Union{}
safe_print("Inconsistent guaranteed error IR", ir)
end
rt = Union{}
end

if DEBUG_INTERP[]
safe_print("ir", ir)
end

if !is_reactant_method(mi::Core.MethodInstance) || guaranteed_error
ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error)
end

src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
Expand Down Expand Up @@ -474,6 +583,10 @@ function call_with_reactant_generator(
fn_args = Any[]
n_method_args = method.nargs
n_actual_args = length(redub_arguments)
if guaranteed_error
offset += 1
n_actual_args -= 1
end

tys = []

Expand All @@ -490,7 +603,7 @@ function call_with_reactant_generator(
push!(overdubbed_codelocs, code_info.codelocs[1])
offset += 1
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
push!(tys, redub_arguments[i])
push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)])

if DEBUG_INTERP[]
push!(
Expand Down Expand Up @@ -523,7 +636,12 @@ function call_with_reactant_generator(
push!(overdubbed_code, trailing_arguments)
push!(overdubbed_codelocs, code_info.codelocs[1])
push!(fn_args, Core.SSAValue(length(overdubbed_code)))
push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...})
push!(
tys,
Tuple{
redub_arguments[(n_method_args:n_actual_args) .+ (guaranteed_error ? 1 : 0)]...,
},
)

if DEBUG_INTERP[]
push!(
Expand Down Expand Up @@ -554,17 +672,17 @@ function call_with_reactant_generator(
# Opaque closures also require taking the function argument. We can work around the latter
# if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure

dict, make_oc = if Base.issingletontype(args[1])
dict, make_oc = if Base.issingletontype(fn)
Base.Ref{Core.OpaqueClosure}(), make_oc_ref
else
Dict{args[1],Core.OpaqueClosure}(), make_oc_dict
end

push!(oc_capture_vec, dict)

oc = if false && Base.issingletontype(args[1])
oc = if false && Base.issingletontype(fn)
res = Core._call_in_world_total(
world, make_oc, dict, octup, rt, src, ocnargs, ocva, args[1].instance
world, make_oc, dict, octup, rt, src, ocnargs, ocva, fn.instance
)::Core.OpaqueClosure

else
Expand All @@ -576,10 +694,16 @@ function call_with_reactant_generator(
end

push!(overdubbed_code, Expr(:call, oc, fn_args[2:end]...))

push!(overdubbed_codelocs, code_info.codelocs[1])

push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code))))
ocres = Core.SSAValue(length(overdubbed_code))

if DEBUG_INTERP[]
push!(overdubbed_code, Expr(:call, safe_print, "ocres", ocres))
push!(overdubbed_codelocs, code_info.codelocs[1])
end

push!(overdubbed_code, Core.ReturnNode(ocres))
push!(overdubbed_codelocs, code_info.codelocs[1])

#=== set `code_info`/`reflection` fields accordingly ===#
Expand Down
44 changes: 44 additions & 0 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,50 @@ end
@test @jit(s4(x, y)) isa Any
end

@testset "unstable stack" begin
x = rand(4, 4)
y = rand(4, 4)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

function s1(x)
xs = []
push!(xs, x)
push!(xs, x)
return stack(xs)
end
function s2(x)
xs = []
push!(xs, x)
push!(xs, x)
return stack(xs; dims=2)
end
function s3(x, y)
xs = []
push!(xs, x)
push!(xs, y)
return stack(xs; dims=2)
end
function s4(x, y)
xs = []
push!(xs, x)
push!(xs, y)
push!(xs, x)
return stack(xs; dims=2)
end

@test @jit(s1(x_ra)) s1(x)
@test @jit(s2(x_ra)) s2(x)
@test @jit(s3(x_ra, y_ra)) s3(x, y)
@test @jit(s4(x_ra, y_ra)) s4(x, y)

# Test that we don't hit illegal instruction; `x` is intentionally not a traced array
@test @jit(s1(x)) isa Any
@test @jit(s2(x)) isa Any
@test @jit(s3(x, y)) isa Any
@test @jit(s4(x, y)) isa Any
end

@testset "Boolean Indexing" begin
x_ra = Reactant.to_rarray(rand(Float32, 4, 16))
idxs_ra = Reactant.to_rarray(rand(Bool, 16))
Expand Down

2 comments on commit 75c3140

@wsmoses
Copy link
Member Author

@wsmoses wsmoses commented on 75c3140 Jan 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/122439

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.16 -m "<description of version>" 75c31406631d3a31502fafee39bd36d734684865
git push origin v0.2.16

Please sign in to comment.