Skip to content

Commit

Permalink
Refactor generated function implementation to provide bindings/method
Browse files Browse the repository at this point in the history
This PR refactors the generated function implementation in multiple ways:

1. Rather than allocating a new LineNumber node to pass to the generator,
   we just pass the original method from which this LineNumberNode was
   constructed. This has been a bit of a longer-standing annoyance of mine,
   since the generator needs to know properties of the original method to
   properly interpret the return value from the generator, but this
   information was only available on the C side.

2. Move the handling of `Expr` returns fully into Julia. Right not things
   were a bit split with the julia code post-processing an `Expr` return,
   but then handing it back to C for lowering. By moving it fully into
   Julia, we can keep the C-side interface simpler by always getting a
   `CodeInfo`.

With these refactorings done, amend the post-processing code to provide
binding edges for `Expr` returns. Ordinarily, bindings in lowered
code do not need edges, because we will scan the lowered code of
the method to find them. However, generated functions are different,
because we do not in general have the lowered code available.
To still give them binding edges, we simply scan through the
post-lowered code and all of the bindings we find into the edges array.

I will note that both of these will require minor adjustments to
`@generated` functions that use the CodeInfo interface (N.B.: this
interface is not considered stable and we've broken it in almost
every release so far). In particular, the following adjustments
need to be made:

1. Adjusting the `source` argument to the new `Method` ABI
2. If necessary, adding any edges that correspond to GlobalRefs used -
   the code will treat the returned CodeInfo mostly opaquely and
   (unlike in the `Expr` case) will not automatically compute these edges.
  • Loading branch information
Keno committed Feb 2, 2025
1 parent 12698af commit 30c6ff9
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 83 deletions.
2 changes: 1 addition & 1 deletion base/Base_compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,13 +234,13 @@ include("abstractarray.jl")
include("baseext.jl")

include("c.jl")
include("ntuple.jl")
include("abstractset.jl")
include("bitarray.jl")
include("bitset.jl")
include("abstractdict.jl")
include("iddict.jl")
include("idset.jl")
include("ntuple.jl")
include("iterators.jl")
using .Iterators: zip, enumerate, only
using .Iterators: Flatten, Filter, product # for generators
Expand Down
21 changes: 0 additions & 21 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -777,27 +777,6 @@ struct GeneratedFunctionStub
spnames::SimpleVector
end

# invoke and wrap the results of @generated expression
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
# args is (spvals..., argtypes...)
body = g.gen(args...)
file = source.file
file isa Symbol || (file = :none)
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
Expr(:var"scope-block",
Expr(:block,
source,
Expr(:meta, :push_loc, file, :var"@generated body"),
Expr(:return, body),
Expr(:meta, :pop_loc))))
spnames = g.spnames
if spnames === svec()
return lam
else
return Expr(Symbol("with-static-parameters"), lam, spnames...)
end
end

# If the generator is a subtype of this trait, inference caches the generated unoptimized
# code, sacrificing memory space to improve the performance of subsequent inferences.
# This tradeoff is not appropriate in general cases (e.g., for `GeneratedFunctionStub`s
Expand Down
43 changes: 43 additions & 0 deletions base/expr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1654,3 +1654,46 @@ end
function quoted(@nospecialize(x))
return is_self_quoting(x) ? x : QuoteNode(x)
end

# Implementation of generated functions
function generated_body_to_codeinfo(ex::Expr, defmod::Module, isva::Bool)
ci = ccall(:jl_expand, Any, (Any, Any), ex, defmod)
if !isa(ci, CodeInfo)
if isa(ci, Expr) && ci.head === :error
error("syntax: $(ci.args[1])")
end
error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.")
end
ci.isva = isva
code = ci.code
bindings = IdSet{Core.Binding}()
for i = 1:length(code)
stmt = code[i]
if isa(stmt, GlobalRef)
push!(bindings, convert(Core.Binding, stmt))
end
end
if !isempty(bindings)
ci.edges = Core.svec(bindings...)
end
return ci
end

# invoke and wrap the results of @generated expression
function (g::Core.GeneratedFunctionStub)(world::UInt, source::Method, @nospecialize args...)
# args is (spvals..., argtypes...)
body = g.gen(args...)
file = source.file
file isa Symbol || (file = :none)
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
Expr(:var"scope-block",
Expr(:block,
LineNumberNode(Int(source.line), source.file),
Expr(:meta, :push_loc, file, :var"@generated body"),
Expr(:return, body),
Expr(:meta, :pop_loc))))
spnames = g.spnames
return generated_body_to_codeinfo(spnames === Core.svec() ? lam : Expr(Symbol("with-static-parameters"), lam, spnames...),
typename(typeof(g.gen)).module,
source.isva)
end
22 changes: 12 additions & 10 deletions base/invalidation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,26 +93,28 @@ function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding)
end

function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
invalidate_all = false
binding = convert(Core.Binding, gr)
if isdefined(method, :source)
src = _uncompressed_ir(method)
binding = convert(Core.Binding, gr)
old_stmts = src.code
invalidate_all = should_invalidate_code_for_globalref(gr, src)
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
for mi in specializations(method)
isdefined(mi, :cache) || continue
ci = mi.cache
while true
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
end
isdefined(ci, :next) || break
ci = ci.next
end
end
end

function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
b = convert(Core.Binding, gr)
try
valid_in_valuepos = false
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
Expand Down
1 change: 0 additions & 1 deletion src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@
XX(jl_exit_on_sigint) \
XX(jl_exit_threaded_region) \
XX(jl_expand) \
XX(jl_expand_and_resolve) \
XX(jl_expand_stmt) \
XX(jl_expand_stmt_with_loc) \
XX(jl_expand_with_loc) \
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void);
JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
int binding_effects);
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);

int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;
int set_next_edge(jl_array_t *list, int i, jl_value_t *invokesig, jl_code_instance_t *caller);
Expand Down
46 changes: 9 additions & 37 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,7 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator,
size_t totargs = 2 + n_sparams + def->nargs;
JL_GC_PUSHARGS(gargs, totargs);
gargs[0] = jl_box_ulong(world);
gargs[1] = jl_box_long(def->line);
gargs[1] = jl_new_struct(jl_linenumbernode_type, gargs[1], def->file);
gargs[1] = (jl_value_t*)def;
memcpy(&gargs[2], jl_svec_data(sparam_vals), n_sparams * sizeof(void*));
memcpy(&gargs[2 + n_sparams], args, (def->nargs - def->isva) * sizeof(void*));
if (def->isva)
Expand All @@ -615,23 +614,6 @@ static jl_value_t *jl_call_staged(jl_method_t *def, jl_value_t *generator,
return code;
}

// Lower `ex` into Julia IR, and (if it expands into a CodeInfo) resolve global-variable
// references in light of the provided type parameters.
// Like `jl_expand`, if there is an error expanding the provided expression, the return value
// will be an error expression (an `Expr` with `error_sym` as its head), which should be eval'd
// in the caller's context.
JL_DLLEXPORT jl_code_info_t *jl_expand_and_resolve(jl_value_t *ex, jl_module_t *module,
jl_svec_t *sparam_vals) {
jl_code_info_t *func = (jl_code_info_t*)jl_expand((jl_value_t*)ex, module);
JL_GC_PUSH1(&func);
if (jl_is_code_info(func)) {
jl_array_t *stmts = (jl_array_t*)func->code;
jl_resolve_definition_effects_in_ir(stmts, module, sparam_vals, NULL, 1);
}
JL_GC_POP();
return func;
}

JL_DLLEXPORT jl_code_instance_t *jl_cached_uninferred(jl_code_instance_t *codeinst, size_t world)
{
for (; codeinst; codeinst = jl_atomic_load_relaxed(&codeinst->next)) {
Expand Down Expand Up @@ -703,25 +685,12 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
ex = jl_call_staged(def, generator, world, mi->sparam_vals, jl_svec_data(ttdt->parameters), jl_nparams(ttdt));

// do some post-processing
if (jl_is_code_info(ex)) {
func = (jl_code_info_t*)ex;
jl_array_t *stmts = (jl_array_t*)func->code;
jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1);
}
else {
// Lower the user's expression and resolve references to the type parameters
func = jl_expand_and_resolve(ex, def->module, mi->sparam_vals);
if (!jl_is_code_info(func)) {
if (jl_is_expr(func) && ((jl_expr_t*)func)->head == jl_error_sym) {
ct->ptls->in_pure_callback = 0;
jl_toplevel_eval(def->module, (jl_value_t*)func);
}
jl_error("The function body AST defined by this @generated function is not pure. This likely means it contains a closure, a comprehension or a generator.");
}
// TODO: This should ideally be in the lambda expression,
// but currently our isva determination is non-syntactic
func->isva = def->isva;
if (!jl_is_code_info(ex)) {
jl_error("As of Julia 1.12, generated functions must return `CodeInfo`. See `Base.generated_body_to_codeinfo`.");
}
func = (jl_code_info_t*)ex;
jl_array_t *stmts = (jl_array_t*)func->code;
jl_resolve_definition_effects_in_ir(stmts, def->module, mi->sparam_vals, NULL, 1);
ex = NULL;

// If this generated function has an opaque closure, cache it for
Expand Down Expand Up @@ -778,6 +747,9 @@ JL_DLLEXPORT jl_code_info_t *jl_code_for_staged(jl_method_instance_t *mi, size_t
if (jl_is_method_instance(kind)) {
jl_method_instance_add_backedge((jl_method_instance_t*)kind, jl_nothing, ci);
}
else if (jl_is_binding(kind)) {
jl_add_binding_backedge((jl_binding_t*)kind, (jl_value_t*)ci);
}
else if (jl_is_mtable(kind)) {
assert(i < l);
ex = data[i++];
Expand Down
25 changes: 15 additions & 10 deletions src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,20 @@ void jl_invalidate_binding_refs(jl_globalref_t *ref, jl_binding_partition_t *inv
JL_GC_POP();
}

JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge)
{
if (!b->backedges) {
b->backedges = jl_alloc_vec_any(0);
jl_gc_wb(b, b->backedges);
} else if (jl_array_len(b->backedges) > 0 &&
jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) {
// Optimization: Deduplicate repeated insertion of the same edge (e.g. during
// definition of a method that contains many references to the same global)
return;
}
jl_array_ptr_1d_push(b->backedges, edge);
}

// Called for all GlobalRefs found in lowered code. Adds backedges for cross-module
// GlobalRefs.
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
Expand All @@ -1114,16 +1128,7 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t
jl_binding_t *b = gr->binding;
if (!b)
b = jl_get_module_binding(gr->mod, gr->name, 1);
if (!b->backedges) {
b->backedges = jl_alloc_vec_any(0);
jl_gc_wb(b, b->backedges);
} else if (jl_array_len(b->backedges) > 0 &&
jl_array_ptr_ref(b->backedges, jl_array_len(b->backedges)-1) == edge) {
// Optimization: Deduplicate repeated insertion of the same edge (e.g. during
// definition of a method that contains many references to the same global)
return;
}
jl_array_ptr_1d_push(b->backedges, edge);
jl_add_binding_backedge(b, edge);
}

JL_DLLEXPORT void jl_disable_binding(jl_globalref_t *gr)
Expand Down
7 changes: 7 additions & 0 deletions test/rebinding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,13 @@ module Rebinding
@test f_return_delete_me_indirect() == 3
Base.delete_binding(@__MODULE__, :delete_me)
@test_throws UndefVarError f_return_delete_me_indirect()

# + via generated function
const delete_me = 4
@generated f_generated_return_delete_me() = return :(delete_me)
@test f_generated_return_delete_me() == 4
Base.delete_binding(@__MODULE__, :delete_me)
@test_throws UndefVarError f_generated_return_delete_me()
end

module RebindingPrecompile
Expand Down
6 changes: 3 additions & 3 deletions test/staged.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,15 +381,15 @@ let
@test length(ir.cfg.blocks) == 1
end

function generate_lambda_ex(world::UInt, source::LineNumberNode,
function generate_lambda_ex(world::UInt, source::Method,
argnames, spnames, @nospecialize body)
stub = Core.GeneratedFunctionStub(identity, Core.svec(argnames...), Core.svec(spnames...))
return stub(world, source, body)
end

# Test that `Core.CachedGenerator` works as expected
struct Generator54916 <: Core.CachedGenerator end
function (::Generator54916)(world::UInt, source::LineNumberNode, args...)
function (::Generator54916)(world::UInt, source::Method, args...)
return generate_lambda_ex(world, source,
(:doit54916, :func, :arg), (), :(func(arg)))
end
Expand Down Expand Up @@ -432,7 +432,7 @@ function overdubbee54341(a, b)
a + b
end
const overdubee_codeinfo54341 = code_lowered(overdubbee54341, Tuple{Any, Any})[1]
function overdub_generator54341(world::UInt, source::LineNumberNode, selftype, fargtypes)
function overdub_generator54341(world::UInt, source::Method, selftype, fargtypes)
if length(fargtypes) != 2
return generate_lambda_ex(world, source,
(:overdub54341, :args), (), :(error("Wrong number of arguments")))
Expand Down

0 comments on commit 30c6ff9

Please sign in to comment.