From 30c6ff9c606a7428249968e65369e57e0f95c7ad Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sat, 1 Feb 2025 00:50:00 +0000 Subject: [PATCH] Refactor generated function implementation to provide bindings/method This PR refactors the generated function implementation in multiple ways: 1. Rather than allocating a new LineNumber node to pass to the generator, we just pass the original method from which this LineNumberNode was constructed. This has been a bit of a longer-standing annoyance of mine, since the generator needs to know properties of the original method to properly interpret the return value from the generator, but this information was only available on the C side. 2. Move the handling of `Expr` returns fully into Julia. Right not things were a bit split with the julia code post-processing an `Expr` return, but then handing it back to C for lowering. By moving it fully into Julia, we can keep the C-side interface simpler by always getting a `CodeInfo`. With these refactorings done, amend the post-processing code to provide binding edges for `Expr` returns. Ordinarily, bindings in lowered code do not need edges, because we will scan the lowered code of the method to find them. However, generated functions are different, because we do not in general have the lowered code available. To still give them binding edges, we simply scan through the post-lowered code and all of the bindings we find into the edges array. I will note that both of these will require minor adjustments to `@generated` functions that use the CodeInfo interface (N.B.: this interface is not considered stable and we've broken it in almost every release so far). In particular, the following adjustments need to be made: 1. Adjusting the `source` argument to the new `Method` ABI 2. If necessary, adding any edges that correspond to GlobalRefs used - the code will treat the returned CodeInfo mostly opaquely and (unlike in the `Expr` case) will not automatically compute these edges. --- base/Base_compiler.jl | 2 +- base/boot.jl | 21 ------------------ base/expr.jl | 43 ++++++++++++++++++++++++++++++++++++ base/invalidation.jl | 22 ++++++++++--------- src/jl_exported_funcs.inc | 1 - src/julia_internal.h | 1 + src/method.c | 46 ++++++++------------------------------- src/module.c | 25 ++++++++++++--------- test/rebinding.jl | 7 ++++++ test/staged.jl | 6 ++--- 10 files changed, 91 insertions(+), 83 deletions(-) diff --git a/base/Base_compiler.jl b/base/Base_compiler.jl index 8cc7096ee26a7..4ec6bae171d8f 100644 --- a/base/Base_compiler.jl +++ b/base/Base_compiler.jl @@ -234,13 +234,13 @@ include("abstractarray.jl") include("baseext.jl") include("c.jl") -include("ntuple.jl") include("abstractset.jl") include("bitarray.jl") include("bitset.jl") include("abstractdict.jl") include("iddict.jl") include("idset.jl") +include("ntuple.jl") include("iterators.jl") using .Iterators: zip, enumerate, only using .Iterators: Flatten, Filter, product # for generators diff --git a/base/boot.jl b/base/boot.jl index 9b386f90d4abe..e50d74659d399 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -777,27 +777,6 @@ struct GeneratedFunctionStub spnames::SimpleVector end -# invoke and wrap the results of @generated expression -function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...) - # args is (spvals..., argtypes...) - body = g.gen(args...) - file = source.file - file isa Symbol || (file = :none) - lam = Expr(:lambda, Expr(:argnames, g.argnames...).args, - Expr(:var"scope-block", - Expr(:block, - source, - Expr(:meta, :push_loc, file, :var"@generated body"), - Expr(:return, body), - Expr(:meta, :pop_loc)))) - spnames = g.spnames - if spnames === svec() - return lam - else - return Expr(Symbol("with-static-parameters"), lam, spnames...) - end -end - # If the generator is a subtype of this trait, inference caches the generated unoptimized # code, sacrificing memory space to improve the performance of subsequent inferences. # This tradeoff is not appropriate in general cases (e.g., for `GeneratedFunctionStub`s diff --git a/base/expr.jl b/base/expr.jl index 84078829f77ed..d71723ee26f1f 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -1654,3 +1654,46 @@ end function quoted(@nospecialize(x)) return is_self_quoting(x) ? x : QuoteNode(x) end + +# Implementation of generated functions +function generated_body_to_codeinfo(ex::Expr, defmod::Module, isva::Bool) + ci = ccall(:jl_expand, Any, (Any, Any), ex, defmod) + if !isa(ci, CodeInfo) + if isa(ci, Expr) && ci.head === :error + error("syntax: $(ci.args[1])") + end + error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.") + end + ci.isva = isva + code = ci.code + bindings = IdSet{Core.Binding}() + for i = 1:length(code) + stmt = code[i] + if isa(stmt, GlobalRef) + push!(bindings, convert(Core.Binding, stmt)) + end + end + if !isempty(bindings) + ci.edges = Core.svec(bindings...) + end + return ci +end + +# invoke and wrap the results of @generated expression +function (g::Core.GeneratedFunctionStub)(world::UInt, source::Method, @nospecialize args...) + # args is (spvals..., argtypes...) + body = g.gen(args...) + file = source.file + file isa Symbol || (file = :none) + lam = Expr(:lambda, Expr(:argnames, g.argnames...).args, + Expr(:var"scope-block", + Expr(:block, + LineNumberNode(Int(source.line), source.file), + Expr(:meta, :push_loc, file, :var"@generated body"), + Expr(:return, body), + Expr(:meta, :pop_loc)))) + spnames = g.spnames + return generated_body_to_codeinfo(spnames === Core.svec() ? lam : Expr(Symbol("with-static-parameters"), lam, spnames...), + typename(typeof(g.gen)).module, + source.isva) +end diff --git a/base/invalidation.jl b/base/invalidation.jl index d9a0bd95c5159..36b867ede2868 100644 --- a/base/invalidation.jl +++ b/base/invalidation.jl @@ -93,26 +93,28 @@ function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding) end function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt) + invalidate_all = false + binding = convert(Core.Binding, gr) if isdefined(method, :source) src = _uncompressed_ir(method) - binding = convert(Core.Binding, gr) old_stmts = src.code invalidate_all = should_invalidate_code_for_globalref(gr, src) - for mi in specializations(method) - isdefined(mi, :cache) || continue - ci = mi.cache - while true - if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding)) - ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) - end - isdefined(ci, :next) || break - ci = ci.next + end + for mi in specializations(method) + isdefined(mi, :cache) || continue + ci = mi.cache + while true + if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding)) + ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) end + isdefined(ci, :next) || break + ci = ci.next end end end function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt) + b = convert(Core.Binding, gr) try valid_in_valuepos = false foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 9e221420aa9f4..4d1ab94644e39 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -127,7 +127,6 @@ XX(jl_exit_on_sigint) \ XX(jl_exit_threaded_region) \ XX(jl_expand) \ - XX(jl_expand_and_resolve) \ XX(jl_expand_stmt) \ XX(jl_expand_stmt_with_loc) \ XX(jl_expand_with_loc) \ diff --git a/src/julia_internal.h b/src/julia_internal.h index 8d09861ba9cd5..9817c8cc8263b 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -722,6 +722,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void); JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge, int binding_effects); JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge); +JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge); int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT; int set_next_edge(jl_array_t *list, int i, jl_value_t *invokesig, jl_code_instance_t *caller); diff --git a/src/method.c b/src/method.c index a79c46d9dab42..8a14eb00182b1 100644 --- a/src/method.c +++ b/src/method.c @@ -604,8 +604,7 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, size_t totargs = 2 + n_sparams + def->nargs; JL_GC_PUSHARGS(gargs, totargs); gargs[0] = jl_box_ulong(world); - gargs[1] = jl_box_long(def->line); - gargs[1] = jl_new_struct(jl_linenumbernode_type, gargs[1], def->file); + gargs[1] = (jl_value_t*)def; memcpy(&gargs[2], jl_svec_data(sparam_vals), n_sparams * sizeof(void*)); memcpy(&gargs[2 + n_sparams], args, (def->nargs - def->isva) * sizeof(void*)); if (def->isva) @@ -615,23 +614,6 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, return code; } -// Lower `ex` into Julia IR, and (if it expands into a CodeInfo) resolve global-variable -// references in light of the provided type parameters. -// Like `jl_expand`, if there is an error expanding the provided expression, the return value -// will be an error expression (an `Expr` with `error_sym` as its head), which should be eval'd -// in the caller's context. -JL_DLLEXPORT jl_code_info_t *jl_expand_and_resolve(jl_value_t *ex, jl_module_t *module, - jl_svec_t *sparam_vals) { - jl_code_info_t *func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, module); - JL_GC_PUSH1(&func); - if (jl_is_code_info(func)) { - jl_array_t *stmts = (jl_array_t*)func->code; - jl_resolve_definition_effects_in_ir(stmts, module, sparam_vals, NULL, 1); - } - JL_GC_POP(); - return func; -} - JL_DLLEXPORT jl_code_instance_t *jl_cached_uninferred(jl_code_instance_t *codeinst, size_t world) { for (; codeinst; codeinst = jl_atomic_load_relaxed(&codeinst->next)) { @@ -703,25 +685,12 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt)); // do some post-processing - if (jl_is_code_info(ex)) { - func = (jl_code_info_t*)ex; - jl_array_t *stmts = (jl_array_t*)func->code; - jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1); - } - else { - // Lower the user's expression and resolve references to the type parameters - func = jl_expand_and_resolve(ex, def->module, mi->sparam_vals); - if (!jl_is_code_info(func)) { - if (jl_is_expr(func) && ((jl_expr_t*)func)->head == jl_error_sym) { - ct->ptls->in_pure_callback = 0; - jl_toplevel_eval(def->module, (jl_value_t*)func); - } - jl_error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator."); - } - // TODO: This should ideally be in the lambda expression, - // but currently our isva determination is non-syntactic - func->isva = def->isva; + if (!jl_is_code_info(ex)) { + jl_error("As of Julia 1.12, generated functions must return `CodeInfo`. See `Base.generated_body_to_codeinfo`."); } + func = (jl_code_info_t*)ex; + jl_array_t *stmts = (jl_array_t*)func->code; + jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1); ex = NULL; // If this generated function has an opaque closure, cache it for @@ -778,6 +747,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t if (jl_is_method_instance(kind)) { jl_method_instance_add_backedge((jl_method_instance_t*)kind, jl_nothing, ci); } + else if (jl_is_binding(kind)) { + jl_add_binding_backedge((jl_binding_t*)kind, (jl_value_t*)ci); + } else if (jl_is_mtable(kind)) { assert(i < l); ex = data[i++]; diff --git a/src/module.c b/src/module.c index 6ff43da5f9fc4..b2a4018519fca 100644 --- a/src/module.c +++ b/src/module.c @@ -1099,6 +1099,20 @@ void jl_invalidate_binding_refs(jl_globalref_t *ref, jl_binding_partition_t *inv JL_GC_POP(); } +JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge) +{ + if (!b->backedges) { + b->backedges = jl_alloc_vec_any(0); + jl_gc_wb(b, b->backedges); + } else if (jl_array_len(b->backedges) > 0 && + jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) { + // Optimization: Deduplicate repeated insertion of the same edge (e.g. during + // definition of a method that contains many references to the same global) + return; + } + jl_array_ptr_1d_push(b->backedges, edge); +} + // Called for all GlobalRefs found in lowered code. Adds backedges for cross-module // GlobalRefs. JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge) @@ -1114,16 +1128,7 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t jl_binding_t *b = gr->binding; if (!b) b = jl_get_module_binding(gr->mod, gr->name, 1); - if (!b->backedges) { - b->backedges = jl_alloc_vec_any(0); - jl_gc_wb(b, b->backedges); - } else if (jl_array_len(b->backedges) > 0 && - jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) { - // Optimization: Deduplicate repeated insertion of the same edge (e.g. during - // definition of a method that contains many references to the same global) - return; - } - jl_array_ptr_1d_push(b->backedges, edge); + jl_add_binding_backedge(b, edge); } JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) diff --git a/test/rebinding.jl b/test/rebinding.jl index ed8e96ba30c62..aee866facaf02 100644 --- a/test/rebinding.jl +++ b/test/rebinding.jl @@ -55,6 +55,13 @@ module Rebinding @test f_return_delete_me_indirect() == 3 Base.delete_binding(@__MODULE__, :delete_me) @test_throws UndefVarError f_return_delete_me_indirect() + + # + via generated function + const delete_me = 4 + @generated f_generated_return_delete_me() = return :(delete_me) + @test f_generated_return_delete_me() == 4 + Base.delete_binding(@__MODULE__, :delete_me) + @test_throws UndefVarError f_generated_return_delete_me() end module RebindingPrecompile diff --git a/test/staged.jl b/test/staged.jl index 6cb99950a7bb2..f3dbdcd73d811 100644 --- a/test/staged.jl +++ b/test/staged.jl @@ -381,7 +381,7 @@ let @test length(ir.cfg.blocks) == 1 end -function generate_lambda_ex(world::UInt, source::LineNumberNode, +function generate_lambda_ex(world::UInt, source::Method, argnames, spnames, @nospecialize body) stub = Core.GeneratedFunctionStub(identity, Core.svec(argnames...), Core.svec(spnames...)) return stub(world, source, body) @@ -389,7 +389,7 @@ end # Test that `Core.CachedGenerator` works as expected struct Generator54916 <: Core.CachedGenerator end -function (::Generator54916)(world::UInt, source::LineNumberNode, args...) +function (::Generator54916)(world::UInt, source::Method, args...) return generate_lambda_ex(world, source, (:doit54916, :func, :arg), (), :(func(arg))) end @@ -432,7 +432,7 @@ function overdubbee54341(a, b) a + b end const overdubee_codeinfo54341 = code_lowered(overdubbee54341, Tuple{Any, Any})[1] -function overdub_generator54341(world::UInt, source::LineNumberNode, selftype, fargtypes) +function overdub_generator54341(world::UInt, source::Method, selftype, fargtypes) if length(fargtypes) != 2 return generate_lambda_ex(world, source, (:overdub54341, :args), (), :(error("Wrong number of arguments")))