diff --git a/base/Base_compiler.jl b/base/Base_compiler.jl index 8cc7096ee26a7..4ec6bae171d8f 100644 --- a/base/Base_compiler.jl +++ b/base/Base_compiler.jl @@ -234,13 +234,13 @@ include("abstractarray.jl") include("baseext.jl") include("c.jl") -include("ntuple.jl") include("abstractset.jl") include("bitarray.jl") include("bitset.jl") include("abstractdict.jl") include("iddict.jl") include("idset.jl") +include("ntuple.jl") include("iterators.jl") using .Iterators: zip, enumerate, only using .Iterators: Flatten, Filter, product # for generators diff --git a/base/boot.jl b/base/boot.jl index 9b386f90d4abe..e50d74659d399 100644 --- a/base/boot.jl +++ b/base/boot.jl @@ -777,27 +777,6 @@ struct GeneratedFunctionStub spnames::SimpleVector end -# invoke and wrap the results of @generated expression -function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...) - # args is (spvals..., argtypes...) - body = g.gen(args...) - file = source.file - file isa Symbol || (file = :none) - lam = Expr(:lambda, Expr(:argnames, g.argnames...).args, - Expr(:var"scope-block", - Expr(:block, - source, - Expr(:meta, :push_loc, file, :var"@generated body"), - Expr(:return, body), - Expr(:meta, :pop_loc)))) - spnames = g.spnames - if spnames === svec() - return lam - else - return Expr(Symbol("with-static-parameters"), lam, spnames...) - end -end - # If the generator is a subtype of this trait, inference caches the generated unoptimized # code, sacrificing memory space to improve the performance of subsequent inferences. # This tradeoff is not appropriate in general cases (e.g., for `GeneratedFunctionStub`s diff --git a/base/expr.jl b/base/expr.jl index 84078829f77ed..d71723ee26f1f 100644 --- a/base/expr.jl +++ b/base/expr.jl @@ -1654,3 +1654,46 @@ end function quoted(@nospecialize(x)) return is_self_quoting(x) ? x : QuoteNode(x) end + +# Implementation of generated functions +function generated_body_to_codeinfo(ex::Expr, defmod::Module, isva::Bool) + ci = ccall(:jl_expand, Any, (Any, Any), ex, defmod) + if !isa(ci, CodeInfo) + if isa(ci, Expr) && ci.head === :error + error("syntax: $(ci.args[1])") + end + error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.") + end + ci.isva = isva + code = ci.code + bindings = IdSet{Core.Binding}() + for i = 1:length(code) + stmt = code[i] + if isa(stmt, GlobalRef) + push!(bindings, convert(Core.Binding, stmt)) + end + end + if !isempty(bindings) + ci.edges = Core.svec(bindings...) + end + return ci +end + +# invoke and wrap the results of @generated expression +function (g::Core.GeneratedFunctionStub)(world::UInt, source::Method, @nospecialize args...) + # args is (spvals..., argtypes...) + body = g.gen(args...) + file = source.file + file isa Symbol || (file = :none) + lam = Expr(:lambda, Expr(:argnames, g.argnames...).args, + Expr(:var"scope-block", + Expr(:block, + LineNumberNode(Int(source.line), source.file), + Expr(:meta, :push_loc, file, :var"@generated body"), + Expr(:return, body), + Expr(:meta, :pop_loc)))) + spnames = g.spnames + return generated_body_to_codeinfo(spnames === Core.svec() ? lam : Expr(Symbol("with-static-parameters"), lam, spnames...), + typename(typeof(g.gen)).module, + source.isva) +end diff --git a/base/invalidation.jl b/base/invalidation.jl index d9a0bd95c5159..36b867ede2868 100644 --- a/base/invalidation.jl +++ b/base/invalidation.jl @@ -93,26 +93,28 @@ function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding) end function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt) + invalidate_all = false + binding = convert(Core.Binding, gr) if isdefined(method, :source) src = _uncompressed_ir(method) - binding = convert(Core.Binding, gr) old_stmts = src.code invalidate_all = should_invalidate_code_for_globalref(gr, src) - for mi in specializations(method) - isdefined(mi, :cache) || continue - ci = mi.cache - while true - if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding)) - ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) - end - isdefined(ci, :next) || break - ci = ci.next + end + for mi in specializations(method) + isdefined(mi, :cache) || continue + ci = mi.cache + while true + if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding)) + ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world) end + isdefined(ci, :next) || break + ci = ci.next end end end function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt) + b = convert(Core.Binding, gr) try valid_in_valuepos = false foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 9e221420aa9f4..4d1ab94644e39 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -127,7 +127,6 @@ XX(jl_exit_on_sigint) \ XX(jl_exit_threaded_region) \ XX(jl_expand) \ - XX(jl_expand_and_resolve) \ XX(jl_expand_stmt) \ XX(jl_expand_stmt_with_loc) \ XX(jl_expand_with_loc) \ diff --git a/src/julia_internal.h b/src/julia_internal.h index 8d09861ba9cd5..9817c8cc8263b 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -722,6 +722,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void); JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge, int binding_effects); JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge); +JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge); int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT; int set_next_edge(jl_array_t *list, int i, jl_value_t *invokesig, jl_code_instance_t *caller); diff --git a/src/method.c b/src/method.c index a79c46d9dab42..8a14eb00182b1 100644 --- a/src/method.c +++ b/src/method.c @@ -604,8 +604,7 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, size_t totargs = 2 + n_sparams + def->nargs; JL_GC_PUSHARGS(gargs, totargs); gargs[0] = jl_box_ulong(world); - gargs[1] = jl_box_long(def->line); - gargs[1] = jl_new_struct(jl_linenumbernode_type, gargs[1], def->file); + gargs[1] = (jl_value_t*)def; memcpy(&gargs[2], jl_svec_data(sparam_vals), n_sparams * sizeof(void*)); memcpy(&gargs[2 + n_sparams], args, (def->nargs - def->isva) * sizeof(void*)); if (def->isva) @@ -615,23 +614,6 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator, return code; } -// Lower `ex` into Julia IR, and (if it expands into a CodeInfo) resolve global-variable -// references in light of the provided type parameters. -// Like `jl_expand`, if there is an error expanding the provided expression, the return value -// will be an error expression (an `Expr` with `error_sym` as its head), which should be eval'd -// in the caller's context. -JL_DLLEXPORT jl_code_info_t *jl_expand_and_resolve(jl_value_t *ex, jl_module_t *module, - jl_svec_t *sparam_vals) { - jl_code_info_t *func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, module); - JL_GC_PUSH1(&func); - if (jl_is_code_info(func)) { - jl_array_t *stmts = (jl_array_t*)func->code; - jl_resolve_definition_effects_in_ir(stmts, module, sparam_vals, NULL, 1); - } - JL_GC_POP(); - return func; -} - JL_DLLEXPORT jl_code_instance_t *jl_cached_uninferred(jl_code_instance_t *codeinst, size_t world) { for (; codeinst; codeinst = jl_atomic_load_relaxed(&codeinst->next)) { @@ -703,25 +685,12 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt)); // do some post-processing - if (jl_is_code_info(ex)) { - func = (jl_code_info_t*)ex; - jl_array_t *stmts = (jl_array_t*)func->code; - jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1); - } - else { - // Lower the user's expression and resolve references to the type parameters - func = jl_expand_and_resolve(ex, def->module, mi->sparam_vals); - if (!jl_is_code_info(func)) { - if (jl_is_expr(func) && ((jl_expr_t*)func)->head == jl_error_sym) { - ct->ptls->in_pure_callback = 0; - jl_toplevel_eval(def->module, (jl_value_t*)func); - } - jl_error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator."); - } - // TODO: This should ideally be in the lambda expression, - // but currently our isva determination is non-syntactic - func->isva = def->isva; + if (!jl_is_code_info(ex)) { + jl_error("As of Julia 1.12, generated functions must return `CodeInfo`. See `Base.generated_body_to_codeinfo`."); } + func = (jl_code_info_t*)ex; + jl_array_t *stmts = (jl_array_t*)func->code; + jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1); ex = NULL; // If this generated function has an opaque closure, cache it for @@ -778,6 +747,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t if (jl_is_method_instance(kind)) { jl_method_instance_add_backedge((jl_method_instance_t*)kind, jl_nothing, ci); } + else if (jl_is_binding(kind)) { + jl_add_binding_backedge((jl_binding_t*)kind, (jl_value_t*)ci); + } else if (jl_is_mtable(kind)) { assert(i < l); ex = data[i++]; diff --git a/src/module.c b/src/module.c index 6ff43da5f9fc4..b2a4018519fca 100644 --- a/src/module.c +++ b/src/module.c @@ -1099,6 +1099,20 @@ void jl_invalidate_binding_refs(jl_globalref_t *ref, jl_binding_partition_t *inv JL_GC_POP(); } +JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge) +{ + if (!b->backedges) { + b->backedges = jl_alloc_vec_any(0); + jl_gc_wb(b, b->backedges); + } else if (jl_array_len(b->backedges) > 0 && + jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) { + // Optimization: Deduplicate repeated insertion of the same edge (e.g. during + // definition of a method that contains many references to the same global) + return; + } + jl_array_ptr_1d_push(b->backedges, edge); +} + // Called for all GlobalRefs found in lowered code. Adds backedges for cross-module // GlobalRefs. JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge) @@ -1114,16 +1128,7 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t jl_binding_t *b = gr->binding; if (!b) b = jl_get_module_binding(gr->mod, gr->name, 1); - if (!b->backedges) { - b->backedges = jl_alloc_vec_any(0); - jl_gc_wb(b, b->backedges); - } else if (jl_array_len(b->backedges) > 0 && - jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) { - // Optimization: Deduplicate repeated insertion of the same edge (e.g. during - // definition of a method that contains many references to the same global) - return; - } - jl_array_ptr_1d_push(b->backedges, edge); + jl_add_binding_backedge(b, edge); } JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr) diff --git a/test/rebinding.jl b/test/rebinding.jl index ed8e96ba30c62..aee866facaf02 100644 --- a/test/rebinding.jl +++ b/test/rebinding.jl @@ -55,6 +55,13 @@ module Rebinding @test f_return_delete_me_indirect() == 3 Base.delete_binding(@__MODULE__, :delete_me) @test_throws UndefVarError f_return_delete_me_indirect() + + # + via generated function + const delete_me = 4 + @generated f_generated_return_delete_me() = return :(delete_me) + @test f_generated_return_delete_me() == 4 + Base.delete_binding(@__MODULE__, :delete_me) + @test_throws UndefVarError f_generated_return_delete_me() end module RebindingPrecompile diff --git a/test/staged.jl b/test/staged.jl index 6cb99950a7bb2..f3dbdcd73d811 100644 --- a/test/staged.jl +++ b/test/staged.jl @@ -381,7 +381,7 @@ let @test length(ir.cfg.blocks) == 1 end -function generate_lambda_ex(world::UInt, source::LineNumberNode, +function generate_lambda_ex(world::UInt, source::Method, argnames, spnames, @nospecialize body) stub = Core.GeneratedFunctionStub(identity, Core.svec(argnames...), Core.svec(spnames...)) return stub(world, source, body) @@ -389,7 +389,7 @@ end # Test that `Core.CachedGenerator` works as expected struct Generator54916 <: Core.CachedGenerator end -function (::Generator54916)(world::UInt, source::LineNumberNode, args...) +function (::Generator54916)(world::UInt, source::Method, args...) return generate_lambda_ex(world, source, (:doit54916, :func, :arg), (), :(func(arg))) end @@ -432,7 +432,7 @@ function overdubbee54341(a, b) a + b end const overdubee_codeinfo54341 = code_lowered(overdubbee54341, Tuple{Any, Any})[1] -function overdub_generator54341(world::UInt, source::LineNumberNode, selftype, fargtypes) +function overdub_generator54341(world::UInt, source::Method, selftype, fargtypes) if length(fargtypes) != 2 return generate_lambda_ex(world, source, (:overdub54341, :args), (), :(error("Wrong number of arguments")))