Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor generated function implementation to provide bindings/method #57230

Merged
merged 1 commit into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,13 @@ include("abstractarray.jl")
include("baseext.jl")

include("c.jl")
include("ntuple.jl")
include("abstractset.jl")
include("bitarray.jl")
include("bitset.jl")
include("abstractdict.jl")
include("iddict.jl")
include("idset.jl")
include("ntuple.jl")
include("iterators.jl")
using .Iterators: zip, enumerate, only
using .Iterators: Flatten, Filter, product # for generators
Expand Down
21 changes: 0 additions & 21 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -777,27 +777,6 @@ struct GeneratedFunctionStub
spnames::SimpleVector
end

# Run the generator behind an `@generated` function and package the expression it
# returns as a `:lambda` (optionally wrapped with its static parameter names).
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
    # `args` is laid out as (spvals..., argtypes...)
    generated = g.gen(args...)
    # Use `:none` as the file name when the location node does not carry a Symbol
    file = source.file isa Symbol ? source.file : :none
    argnames = Expr(:argnames, g.argnames...).args
    block = Expr(:block,
                 source,
                 Expr(:meta, :push_loc, file, :var"@generated body"),
                 Expr(:return, generated),
                 Expr(:meta, :pop_loc))
    lam = Expr(:lambda, argnames, Expr(:var"scope-block", block))
    # Without static parameters the bare lambda is the result
    g.spnames === svec() && return lam
    return Expr(Symbol("with-static-parameters"), lam, g.spnames...)
end

# If the generator is a subtype of this trait, inference caches the generated unoptimized
# code, sacrificing memory space to improve the performance of subsequent inferences.
# This tradeoff is not appropriate in general cases (e.g., for `GeneratedFunctionStub`s
Expand Down
43 changes: 43 additions & 0 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1654,3 +1654,46 @@ end
function quoted(@nospecialize(x))
return is_self_quoting(x) ? x : QuoteNode(x)
end

# Implementation of generated functions
#
# Lower the lambda expression `ex` produced by a generator in module `defmod` and
# return the resulting `CodeInfo`, with `isva` stored on it and with the bindings of
# every top-level `GlobalRef` statement recorded as edges. Throws an error when
# lowering reports a syntax error or does not yield a `CodeInfo`.
function generated_body_to_codeinfo(ex::Expr, defmod::Module, isva::Bool)
    lowered = ccall(:jl_expand, Any, (Any, Any), ex, defmod)
    if !(lowered isa CodeInfo)
        # Lowering failures come back as an `Expr(:error, msg)`; surface the message
        if lowered isa Expr && lowered.head === :error
            error("syntax: $(lowered.args[1])")
        end
        error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.")
    end
    lowered.isva = isva
    # Collect each distinct binding referenced by a statement that is itself a GlobalRef
    globals = IdSet{Core.Binding}()
    for stmt in lowered.code
        stmt isa GlobalRef || continue
        push!(globals, convert(Core.Binding, stmt))
    end
    # Only overwrite the edges when at least one binding was found
    isempty(globals) || (lowered.edges = Core.svec(globals...))
    return lowered
end

# invoke and wrap the results of @generated expression
#
# Calls the stored generator `g.gen` with `args`, wraps the returned body in a
# `:lambda` expression, and hands that expression to `generated_body_to_codeinfo`,
# whose result is returned. `world` is accepted but not read in this body.
function (g::Core.GeneratedFunctionStub)(world::UInt, source::Method, @nospecialize args...)
# args is (spvals..., argtypes...)
body = g.gen(args...)
# Fall back to `:none` when the method's file is not a Symbol
file = source.file
file isa Symbol || (file = :none)
# The lambda body is: a line node built from the method's line/file, a push_loc
# marker, the `return` of the generated body, and the matching pop_loc marker.
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
Expr(:var"scope-block",
Expr(:block,
LineNumberNode(Int(source.line), source.file),
Expr(:meta, :push_loc, file, :var"@generated body"),
Expr(:return, body),
Expr(:meta, :pop_loc))))
spnames = g.spnames
# The lambda is wrapped in `with-static-parameters` only when `spnames` is non-empty.
# The module passed for lowering is the one that defines the type of `g.gen`, and
# the vararg flag is taken from `source.isva`.
# NOTE(review): presumably this module usually matches `source.module` - confirm.
return generated_body_to_codeinfo(spnames === Core.svec() ? lam : Expr(Symbol("with-static-parameters"), lam, spnames...),
typename(typeof(g.gen)).module,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why aren't we using source.module here anymore? I suppose they probably usually resolve to the same thing, but might be good to document the intent of this change anyways?

source.isva)
end
22 changes: 12 additions & 10 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,28 @@ function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding)
end

function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
invalidate_all = false
binding = convert(Core.Binding, gr)
if isdefined(method, :source)
src = _uncompressed_ir(method)
binding = convert(Core.Binding, gr)
old_stmts = src.code
invalidate_all = should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end

function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
b = convert(Core.Binding, gr)
try
valid_in_valuepos = false
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
XX(jl_exit_on_sigint) \
XX(jl_exit_threaded_region) \
XX(jl_expand) \
XX(jl_expand_and_resolve) \
XX(jl_expand_stmt) \
XX(jl_expand_stmt_with_loc) \
XX(jl_expand_with_loc) \
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void);
JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
int binding_effects);
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);

int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;
int set_next_edge(jl_array_t *list, int i, jl_value_t *invokesig, jl_code_instance_t *caller);
Expand Down
46 changes: 9 additions & 37 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,7 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator,
size_t totargs = 2 + n_sparams + def->nargs;
JL_GC_PUSHARGS(gargs, totargs);
gargs[0] = jl_box_ulong(world);
gargs[1] = jl_box_long(def->line);
gargs[1] = jl_new_struct(jl_linenumbernode_type, gargs[1], def->file);
gargs[1] = (jl_value_t*)def;
memcpy(&gargs[2], jl_svec_data(sparam_vals), n_sparams * sizeof(void*));
memcpy(&gargs[2 + n_sparams], args, (def->nargs - def->isva) * sizeof(void*));
if (def->isva)
Expand All @@ -615,23 +614,6 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator,
return code;
}

// Expand `ex` to Julia IR in `module`. When the expansion produces a CodeInfo, its
// statements are additionally passed through global-variable resolution using the
// supplied type parameters.
// As with `jl_expand`, a failed expansion is reported by returning an error expression
// (an `Expr` whose head is `error_sym`); the caller is expected to eval it in its own
// context.
JL_DLLEXPORT jl_code_info_t *jl_expand_and_resolve(jl_value_t *ex, jl_module_t *module,
                                                   jl_svec_t *sparam_vals)
{
    jl_code_info_t *lowered = (jl_code_info_t*)jl_expand(ex, module);
    // Keep the freshly lowered value rooted while the resolver runs
    JL_GC_PUSH1(&lowered);
    if (jl_is_code_info(lowered))
        jl_resolve_definition_effects_in_ir((jl_array_t*)lowered->code, module, sparam_vals, NULL, 1);
    JL_GC_POP();
    return lowered;
}

JL_DLLEXPORT jl_code_instance_t *jl_cached_uninferred(jl_code_instance_t *codeinst, size_t world)
{
for (; codeinst; codeinst = jl_atomic_load_relaxed(&codeinst->next)) {
Expand Down Expand Up @@ -703,25 +685,12 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt));

// do some post-processing
if (jl_is_code_info(ex)) {
func = (jl_code_info_t*)ex;
jl_array_t *stmts = (jl_array_t*)func->code;
jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1);
}
else {
// Lower the user's expression and resolve references to the type parameters
func = jl_expand_and_resolve(ex, def->module, mi->sparam_vals);
if (!jl_is_code_info(func)) {
if (jl_is_expr(func) && ((jl_expr_t*)func)->head == jl_error_sym) {
ct->ptls->in_pure_callback = 0;
jl_toplevel_eval(def->module, (jl_value_t*)func);
}
jl_error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.");
}
// TODO: This should ideally be in the lambda expression,
// but currently our isva determination is non-syntactic
func->isva = def->isva;
if (!jl_is_code_info(ex)) {
jl_error("As of Julia 1.12, generated functions must return `CodeInfo`. See `Base.generated_body_to_codeinfo`.");
}
func = (jl_code_info_t*)ex;
jl_array_t *stmts = (jl_array_t*)func->code;
jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1);
ex = NULL;

// If this generated function has an opaque closure, cache it for
Expand Down Expand Up @@ -778,6 +747,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
if (jl_is_method_instance(kind)) {
jl_method_instance_add_backedge((jl_method_instance_t*)kind, jl_nothing, ci);
}
else if (jl_is_binding(kind)) {
jl_add_binding_backedge((jl_binding_t*)kind, (jl_value_t*)ci);
}
else if (jl_is_mtable(kind)) {
assert(i < l);
ex = data[i++];
Expand Down
25 changes: 15 additions & 10 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,20 @@ void jl_invalidate_binding_refs(jl_globalref_t *ref, jl_binding_partition_t *inv
JL_GC_POP();
}

// Append `edge` to the backedge list of binding `b`.
// The list is allocated on first use. If the list is non-empty and its last entry is
// already `edge`, nothing is appended.
// NOTE(review): no lock is acquired inside this function; presumably callers must
// serialize access to `b->backedges` - confirm against the call sites.
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge)
{
// First backedge for this binding: create an empty list (write barrier follows the store)
if (!b->backedges) {
b->backedges = jl_alloc_vec_any(0);
Copy link
Member

@vtjnash vtjnash Feb 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there is no lock on this, so any uses of ./julia (which now defaults to using threads) and which run any code now may trigger this UB memory corruption?

jl_gc_wb(b, b->backedges);
} else if (jl_array_len(b->backedges) > 0 &&
jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) {
// Optimization: Deduplicate repeated insertion of the same edge (e.g. during
// definition of a method that contains many references to the same global)
return;
}
jl_array_ptr_1d_push(b->backedges, edge);
}

// Called for all GlobalRefs found in lowered code. Adds backedges for cross-module
// GlobalRefs.
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
Expand All @@ -1114,16 +1128,7 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t
jl_binding_t *b = gr->binding;
if (!b)
b = jl_get_module_binding(gr->mod, gr->name, 1);
if (!b->backedges) {
b->backedges = jl_alloc_vec_any(0);
jl_gc_wb(b, b->backedges);
} else if (jl_array_len(b->backedges) > 0 &&
jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) {
// Optimization: Deduplicate repeated insertion of the same edge (e.g. during
// definition of a method that contains many references to the same global)
return;
}
jl_array_ptr_1d_push(b->backedges, edge);
jl_add_binding_backedge(b, edge);
}

JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
Expand Down
7 changes: 7 additions & 0 deletions test/rebinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ module Rebinding
@test f_return_delete_me_indirect() == 3
Base.delete_binding(@__MODULE__, :delete_me)
@test_throws UndefVarError f_return_delete_me_indirect()

# + via generated function
const delete_me = 4
@generated f_generated_return_delete_me() = return :(delete_me)
@test f_generated_return_delete_me() == 4
Base.delete_binding(@__MODULE__, :delete_me)
@test_throws UndefVarError f_generated_return_delete_me()
end

module RebindingPrecompile
Expand Down
6 changes: 3 additions & 3 deletions test/staged.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,15 @@ let
@test length(ir.cfg.blocks) == 1
end

function generate_lambda_ex(world::UInt, source::LineNumberNode,
function generate_lambda_ex(world::UInt, source::Method,
argnames, spnames, @nospecialize body)
stub = Core.GeneratedFunctionStub(identity, Core.svec(argnames...), Core.svec(spnames...))
return stub(world, source, body)
end

# Test that `Core.CachedGenerator` works as expected
struct Generator54916 <: Core.CachedGenerator end
function (::Generator54916)(world::UInt, source::LineNumberNode, args...)
function (::Generator54916)(world::UInt, source::Method, args...)
return generate_lambda_ex(world, source,
(:doit54916, :func, :arg), (), :(func(arg)))
end
Expand Down Expand Up @@ -432,7 +432,7 @@ function overdubbee54341(a, b)
a + b
end
const overdubee_codeinfo54341 = code_lowered(overdubbee54341, Tuple{Any, Any})[1]
function overdub_generator54341(world::UInt, source::LineNumberNode, selftype, fargtypes)
function overdub_generator54341(world::UInt, source::Method, selftype, fargtypes)
if length(fargtypes) != 2
return generate_lambda_ex(world, source,
(:overdub54341, :args), (), :(error("Wrong number of arguments")))
Expand Down