diff --git a/src/Overlay.jl b/src/Overlay.jl index fc035f94f..eac99c25f 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -132,7 +132,10 @@ for (cT, aT, bT) in ( C .= C2 end else - LinearAlgebra.mul!(C, A, B, α, β) + # Inference barrier is required when calling function recursively within overload + # This is required since otherwise type inference will think this is a recursive edge + # rather than a call to the base method + Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β) end return C end @@ -150,6 +153,14 @@ end if use_overlayed_version(iter) return TracedRArrayOverrides.overloaded_stack(dims, iter) else - return Base._stack(dims, Base.IteratorSize(iter), iter) + iter2 = collect(iter) + if any(use_overlayed_version, iter2) + return TracedRArrayOverrides.overloaded_stack(dims, iter2) + else + # Inference barrier is required when calling function recursively within overload + # This is required since otherwise type inference will think this is a recursive edge + # rather than a call to the base method + return Base.inferencebarrier(Base._stack)(dims, iter2) + end end end diff --git a/src/utils.jl b/src/utils.jl index d674fcc77..e2f518cdf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -169,6 +169,8 @@ function is_reactant_method(mi::Core.MethodInstance) return mt === REACTANT_METHOD_TABLE end +struct MustThrowError end + @generated function applyiterate_with_reactant( iteratefn, applyfn, args::Vararg{Any,N} ) where {N} @@ -183,7 +185,29 @@ end end end -function rewrite_inst(inst, ir, interp) +@generated function applyiterate_with_reactant( + mt::MustThrowError, iteratefn, applyfn, args::Vararg{Any,N} +) where {N} + @assert iteratefn == typeof(Base.iterate) + newargs = Vector{Expr}(undef, N) + for i in 1:N + @inbounds newargs[i] = :(args[$i]...) + end + quote + Base.@_inline_meta + call_with_reactant(mt, applyfn, $(newargs...)) + end +end + +function certain_error() + throw( + AssertionError( + "The inferred code was guaranteed to throw this error. And yet, it didn't. So here we are...", + ), + ) +end + +function rewrite_inst(inst, ir, interp, RT, guaranteed_error) if Meta.isexpr(inst, :call) # Even if type unstable we do not want (or need) to replace intrinsic # calls or builtins with our version. @@ -194,12 +218,27 @@ function rewrite_inst(inst, ir, interp) if ft == typeof(Core._apply_iterate) ft = Core.Compiler.widenconst(maybe_argextype(inst.args[3], ir)) if should_rewrite_ft(ft) - rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...) - return true, rep + if RT === Union{} + rep = Expr( + :call, + applyiterate_with_reactant, + MustThrowError(), + inst.args[2:end]..., + ) + return true, rep, Union{} + else + rep = Expr(:call, applyiterate_with_reactant, inst.args[2:end]...) + return true, rep, Any + end end elseif should_rewrite_ft(ft) - rep = Expr(:call, call_with_reactant, inst.args...) - return true, rep + if RT === Union{} + rep = Expr(:call, call_with_reactant, MustThrowError(), inst.args...) + return true, rep, Union{} + else + rep = Expr(:call, call_with_reactant, inst.args...) + return true, rep, Any + end end end if Meta.isexpr(inst, :invoke) @@ -215,8 +254,16 @@ function rewrite_inst(inst, ir, interp) min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) + # RT = Any + if !method.isva || !Base.isvarargtype(sig.parameters[end]) - sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} + if RT === Union{} + sig2 = Tuple{ + typeof(call_with_reactant),MustThrowError,sig.parameters... + } + else + sig2 = Tuple{typeof(call_with_reactant),sig.parameters...} + end else vartup = inst.args[end] ns = Type[] @@ -224,9 +271,18 @@ function rewrite_inst(inst, ir, interp) for i in 1:(length(inst.args) - 1 - (length(sig.parameters) - 1)) push!(ns, eT) end - sig2 = Tuple{ - typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... - } + if RT === Union{} + sig2 = Tuple{ + typeof(call_with_reactant), + MustThrowError, + sig.parameters[1:(end - 1)]..., + ns..., + } + else + sig2 = Tuple{ + typeof(call_with_reactant),sig.parameters[1:(end - 1)]...,ns... + } + end end lookup_result = lookup_world( @@ -244,11 +300,41 @@ function rewrite_inst(inst, ir, interp) match.sparams, ) n_method_args = method.nargs - rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) - return true, rep + if RT === Union{} + rep = Expr( + :invoke, mi, call_with_reactant, MustThrowError(), inst.args[2:end]... + ) + return true, rep, Union{} + else + rep = Expr(:invoke, mi, call_with_reactant, inst.args[2:end]...) + return true, rep, Any + end end end - return false, inst + if isa(inst, Core.ReturnNode) && (!isdefined(inst, :val) || guaranteed_error) + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + + sig2 = Tuple{typeof(certain_error)} + + lookup_result = lookup_world( + sig2, interp.world, Core.Compiler.method_table(interp), min_world, max_world + ) + + match = lookup_result::Core.MethodMatch + # look up the method and code instance + mi = ccall( + :jl_specializations_get_linfo, + Ref{Core.MethodInstance}, + (Any, Any, Any), + match.method, + match.spec_types, + match.sparams, + ) + rep = Expr(:invoke, mi, certain_error) + return true, rep, Union{} + end + return false, inst, RT end const oc_capture_vec = Vector{Any}() @@ -334,19 +420,22 @@ const DEBUG_INTERP = Ref(false) # to Any if our interpreter would change the return type of any result. # Also rewrite invoke (type stable call) to be :call, since otherwise apparently # screws up type inference after this (TODO this should be fixed). -function rewrite_insts!(ir, interp) +function rewrite_insts!(ir, interp, guaranteed_error) any_changed = false for (i, inst) in enumerate(ir.stmts) + # Explicitly skip any code which returns Union{} so that we throw the error + # instead of risking a segfault + RT = inst[:type] @static if VERSION < v"1.11" - changed, next = rewrite_inst(inst[:inst], ir, interp) + changed, next, RT = rewrite_inst(inst[:inst], ir, interp, RT, guaranteed_error) Core.Compiler.setindex!(ir.stmts[i], next, :inst) else - changed, next = rewrite_inst(inst[:stmt], ir, interp) + changed, next, RT = rewrite_inst(inst[:stmt], ir, interp, RT, guaranteed_error) Core.Compiler.setindex!(ir.stmts[i], next, :stmt) end if changed any_changed = true - Core.Compiler.setindex!(ir.stmts[i], Any, :type) + Core.Compiler.setindex!(ir.stmts[i], RT, :type) end end return ir, any_changed @@ -372,21 +461,30 @@ function call_with_reactant_generator( identity, Core.svec(:call_with_reactant, REDUB_ARGUMENTS_NAME), Core.svec() ) + fn = args[1] + sig = Tuple{args...} + + guaranteed_error = false + if fn === MustThrowError + guaranteed_error = true + fn = args[2] + sig = Tuple{args[2:end]...} + end + # look up the method match builtin_error = - :(throw(AssertionError("Unsupported call_with_reactant of builtin $(args[1])"))) + :(throw(AssertionError("Unsupported call_with_reactant of builtin $fn"))) - if args[1] <: Core.Builtin + if fn <: Core.Builtin return stub(world, source, builtin_error) end + method_error = :(throw( MethodError($REDUB_ARGUMENTS_NAME[1], $REDUB_ARGUMENTS_NAME[2:end], $world) )) interp = ReactantInterpreter(; world) - sig = Tuple{args...} - min_world = Ref{UInt}(typemin(UInt)) max_world = Ref{UInt}(typemax(UInt)) @@ -427,8 +525,19 @@ function call_with_reactant_generator( ir, rt = CC.typeinf_ircode(interp, mi, nothing) end - if !is_reactant_method(mi::Core.MethodInstance) - ir, any_changed = rewrite_insts!(ir, interp) + if guaranteed_error + if rt !== Union{} + safe_print("Inconsistent guaranteed error IR", ir) + end + rt = Union{} + end + + if DEBUG_INTERP[] + safe_print("ir", ir) + end + + if !is_reactant_method(mi::Core.MethodInstance) || guaranteed_error + ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error) end src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ()) @@ -474,6 +583,10 @@ function call_with_reactant_generator( fn_args = Any[] n_method_args = method.nargs n_actual_args = length(redub_arguments) + if guaranteed_error + offset += 1 + n_actual_args -= 1 + end tys = [] @@ -490,7 +603,7 @@ function call_with_reactant_generator( push!(overdubbed_codelocs, code_info.codelocs[1]) offset += 1 push!(fn_args, Core.SSAValue(length(overdubbed_code))) - push!(tys, redub_arguments[i]) + push!(tys, redub_arguments[i + (guaranteed_error ? 1 : 0)]) if DEBUG_INTERP[] push!( @@ -523,7 +636,12 @@ function call_with_reactant_generator( push!(overdubbed_code, trailing_arguments) push!(overdubbed_codelocs, code_info.codelocs[1]) push!(fn_args, Core.SSAValue(length(overdubbed_code))) - push!(tys, Tuple{redub_arguments[n_method_args:n_actual_args]...}) + push!( + tys, + Tuple{ + redub_arguments[(n_method_args:n_actual_args) .+ (guaranteed_error ? 1 : 0)]..., + }, + ) if DEBUG_INTERP[] push!( @@ -554,7 +672,7 @@ function call_with_reactant_generator( # Opaque closures also require taking the function argument. We can work around the latter # if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure - dict, make_oc = if Base.issingletontype(args[1]) + dict, make_oc = if Base.issingletontype(fn) Base.Ref{Core.OpaqueClosure}(), make_oc_ref else Dict{args[1],Core.OpaqueClosure}(), make_oc_dict @@ -562,9 +680,9 @@ function call_with_reactant_generator( push!(oc_capture_vec, dict) - oc = if false && Base.issingletontype(args[1]) + oc = if false && Base.issingletontype(fn) res = Core._call_in_world_total( - world, make_oc, dict, octup, rt, src, ocnargs, ocva, args[1].instance + world, make_oc, dict, octup, rt, src, ocnargs, ocva, fn.instance )::Core.OpaqueClosure else @@ -576,10 +694,16 @@ function call_with_reactant_generator( end push!(overdubbed_code, Expr(:call, oc, fn_args[2:end]...)) - push!(overdubbed_codelocs, code_info.codelocs[1]) - push!(overdubbed_code, Core.ReturnNode(Core.SSAValue(length(overdubbed_code)))) + ocres = Core.SSAValue(length(overdubbed_code)) + + if DEBUG_INTERP[] + push!(overdubbed_code, Expr(:call, safe_print, "ocres", ocres)) + push!(overdubbed_codelocs, code_info.codelocs[1]) + end + + push!(overdubbed_code, Core.ReturnNode(ocres)) push!(overdubbed_codelocs, code_info.codelocs[1]) #=== set `code_info`/`reflection` fields accordingly ===# diff --git a/test/basic.jl b/test/basic.jl index d05d3d968..231a54754 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -880,6 +880,50 @@ end @test @jit(s4(x, y)) isa Any end +@testset "unstable stack" begin + x = rand(4, 4) + y = rand(4, 4) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + function s1(x) + xs = [] + push!(xs, x) + push!(xs, x) + return stack(xs) + end + function s2(x) + xs = [] + push!(xs, x) + push!(xs, x) + return stack(xs; dims=2) + end + function s3(x, y) + xs = [] + push!(xs, x) + push!(xs, y) + return stack(xs; dims=2) + end + function s4(x, y) + xs = [] + push!(xs, x) + push!(xs, y) + push!(xs, x) + return stack(xs; dims=2) + end + + @test @jit(s1(x_ra)) ≈ s1(x) + @test @jit(s2(x_ra)) ≈ s2(x) + @test @jit(s3(x_ra, y_ra)) ≈ s3(x, y) + @test @jit(s4(x_ra, y_ra)) ≈ s4(x, y) + + # Test that we don't hit illegal instruction; `x` is intentionally not a traced array + @test @jit(s1(x)) isa Any + @test @jit(s2(x)) isa Any + @test @jit(s3(x, y)) isa Any + @test @jit(s4(x, y)) isa Any +end + @testset "Boolean Indexing" begin x_ra = Reactant.to_rarray(rand(Float32, 4, 16)) idxs_ra = Reactant.to_rarray(rand(Bool, 16))