further defer jl_insert_backedges after loading (JuliaLang#56447)
Finish fully breaking the dependency between method insertions and
inferring whether the cache is valid. The cache should be inferable in
parallel and in aggregate after all loading is finished. This prepares
us for moving this code into Julia (Core.Compiler) next.
vtjnash authored Nov 7, 2024
1 parent 671cd5e commit 4278ded
Showing 2 changed files with 103 additions and 80 deletions.
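
Background for the diffs below: every cached CodeInstance carries a min_world/max_world validity window, and revalidating a loaded cache amounts to recomputing that window against the current method tables. The following is a minimal, self-contained sketch of the window check only; the toy_codeinst_t struct and usable_at helper are illustrative stand-ins, not Julia's actual types or API.

#include <stddef.h>
#include <stdio.h>

/* Hypothetical, simplified stand-in for a cached CodeInstance: the real
 * struct carries (atomic) min_world/max_world fields with the same meaning. */
typedef struct {
    size_t min_world;
    size_t max_world; /* 0 means "known invalid"; ~(size_t)0 means "valid in all current worlds" */
} toy_codeinst_t;

/* A cache entry may be used at world `w` only if `w` falls inside its window. */
static int usable_at(const toy_codeinst_t *ci, size_t w)
{
    return ci->min_world <= w && w <= ci->max_world;
}

int main(void)
{
    toy_codeinst_t ci = { .min_world = 5, .max_world = ~(size_t)0 };
    printf("usable at world 4: %d\n", usable_at(&ci, 4));  /* 0 */
    printf("usable at world 7: %d\n", usable_at(&ci, 7));  /* 1 */
    ci.max_world = 0;                                      /* invalidated */
    printf("usable at world 7: %d\n", usable_at(&ci, 7));  /* 0 */
    return 0;
}
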
src/staticdata.c (5 changes: 3 additions & 2 deletions)
@@ -4034,12 +4034,13 @@ static jl_value_t *jl_restore_package_image_from_stream(void* pkgimage_handle, i
// allocate a world for the new methods, and insert them there, invalidating content as needed
size_t world = jl_atomic_load_relaxed(&jl_world_counter) + 1;
jl_activate_methods(extext_methods, internal_methods, world);
// TODO: inject new_ext_cis into caches here, so the system can see them immediately as potential candidates (before validation)
// allow users to start running in this updated world
jl_atomic_store_release(&jl_world_counter, world);
// but one of those immediate users is going to be our cache updates
jl_insert_backedges((jl_array_t*)edges, (jl_array_t*)new_ext_cis, world); // restore external backedges (needs to be last)
// now permit more methods to be added again
JL_UNLOCK(&world_counter_lock);
// but one of those immediate users is going to be our cache insertions
jl_insert_backedges((jl_array_t*)edges, (jl_array_t*)new_ext_cis); // restore existing caches (needs to be last)
// reinit ccallables
jl_reinit_ccallable(&ccallable_list, base, pkgimage_handle);
arraylist_free(&ccallable_list);
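
The hunk above reorders package-image loading: the new world is published and the world_counter lock released before jl_insert_backedges revalidates the cached code. A toy model of that ordering, with stand-in functions in place of the real runtime calls (all names below are illustrative, not Julia's API):

#include <stddef.h>
#include <stdio.h>

/* Stand-ins for the real runtime calls; each just reports the step it models,
 * in the order the updated loading code performs them. */
static void activate_methods(size_t world)  { printf("insert new methods for world %zu\n", world); }
static void publish_world(size_t world)     { printf("publish world counter = %zu\n", world); }
static void unlock_world_counter(void)      { printf("release world_counter lock\n"); }
static void validate_caches(void)           { printf("revalidate cached code (deferred, runs last)\n"); }

int main(void)
{
    size_t world = 42;        /* previous world counter + 1 */
    activate_methods(world);  /* new methods become part of world `world` */
    publish_world(world);     /* users may start running in the updated world */
    unlock_world_counter();   /* more methods may be added again */
    validate_caches();        /* cache validation now happens after publication */
    return 0;
}
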
src/staticdata_utils.c (178 changes: 100 additions & 78 deletions)
@@ -751,56 +751,58 @@ static void jl_copy_roots(jl_array_t *method_roots_list, uint64_t key)
}
}

static size_t verify_invokesig(jl_value_t *invokesig, jl_method_t *expected, size_t minworld)
static void verify_invokesig(jl_value_t *invokesig, jl_method_t *expected, size_t world, size_t *minworld, size_t *maxworld)
{
assert(jl_is_type(invokesig));
assert(jl_is_method(expected));
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
if (jl_egal(invokesig, expected->sig)) {
// the invoke match is `expected` for `expected->sig`, unless `expected` is invalid
if (jl_atomic_load_relaxed(&expected->deleted_world) < max_valid)
max_valid = 0;
*minworld = jl_atomic_load_relaxed(&expected->primary_world);
*maxworld = jl_atomic_load_relaxed(&expected->deleted_world);
assert(*minworld <= world);
if (*maxworld < world)
*maxworld = 0;
}
else {
*minworld = 1;
*maxworld = ~(size_t)0;
jl_methtable_t *mt = jl_method_get_table(expected);
if ((jl_value_t*)mt == jl_nothing) {
max_valid = 0;
*maxworld = 0;
}
else {
jl_value_t *matches = jl_gf_invoke_lookup_worlds(invokesig, (jl_value_t*)mt, minworld, &min_valid, &max_valid);
jl_value_t *matches = jl_gf_invoke_lookup_worlds(invokesig, (jl_value_t*)mt, world, minworld, maxworld);
if (matches == jl_nothing) {
max_valid = 0;
*maxworld = 0;
}
else {
if (((jl_method_match_t*)matches)->method != expected) {
max_valid = 0;
*maxworld = 0;
}
}
}
}
return max_valid;
}

static size_t verify_call(jl_value_t *sig, jl_svec_t *expecteds, size_t i, size_t n, size_t minworld, jl_value_t **matches JL_REQUIRE_ROOTED_SLOT)
static void verify_call(jl_value_t *sig, jl_svec_t *expecteds, size_t i, size_t n, size_t world, size_t *minworld, size_t *maxworld, jl_value_t **matches JL_REQUIRE_ROOTED_SLOT)
{
// verify that these edges intersect with the same methods as before
size_t min_valid = 0;
size_t max_valid = ~(size_t)0;
*minworld = 1;
*maxworld = ~(size_t)0;
int ambig = 0;
// TODO: possibly need to include ambiguities too (for the optimizer correctness)?
jl_value_t *result = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing,
_jl_debug_method_invalidation ? INT32_MAX : n,
0, minworld, &min_valid, &max_valid, &ambig);
0, world, minworld, maxworld, &ambig);
*matches = result;
if (result == jl_nothing) {
max_valid = 0;
*maxworld = 0;
}
else {
// setdiff!(result, expected)
size_t j, k, ins = 0;
if (jl_array_nrows(result) != n) {
max_valid = 0;
*maxworld = 0;
}
for (k = 0; k < jl_array_nrows(result); k++) {
jl_method_t *match = ((jl_method_match_t*)jl_array_ptr_ref(result, k))->method;
@@ -822,29 +824,33 @@ static size_t verify_call(jl_value_t *sig, jl_svec_t *expecteds, size_t i, size_
// intersection has a new method or a method was
// deleted--this is now probably no good, just invalidate
// everything about it now
max_valid = 0;
*maxworld = 0;
if (!_jl_debug_method_invalidation)
break;
jl_array_ptr_set(result, ins++, match);
}
}
if (max_valid != ~(size_t)0 && _jl_debug_method_invalidation)
if (*maxworld != ~(size_t)0 && _jl_debug_method_invalidation)
jl_array_del_end((jl_array_t*)result, jl_array_nrows(result) - ins);
}
return max_valid;
}

// Test all edges relevant to a method:
//// Visit the entire call graph, starting from edges[idx] to determine if that method is valid
//// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
//// and slightly modified with an early termination option once the computation reaches its minimum
static int jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, size_t *maxworld, arraylist_t *stack, htable_t *visiting)
static int jl_verify_method(jl_code_instance_t *codeinst, size_t *minworld, size_t *maxworld, arraylist_t *stack, htable_t *visiting)
{
size_t world = jl_atomic_load_relaxed(&codeinst->min_world);
size_t max_valid2 = jl_atomic_load_relaxed(&codeinst->max_world);
if (max_valid2 != WORLD_AGE_REVALIDATION_SENTINEL) {
*minworld = world;
*maxworld = max_valid2;
return 0;
}
*minworld = 1;
size_t current_world = jl_atomic_load_relaxed(&jl_world_counter);
*maxworld = current_world;
assert(jl_is_method_instance(codeinst->def) && jl_is_method(codeinst->def->def.method));
void **bp = ptrhash_bp(visiting, codeinst);
if (*bp != HT_NOTFOUND)
@@ -862,21 +868,22 @@ static int jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, size_
// verify current edges
for (size_t j = 0; j < jl_svec_len(callees); ) {
jl_value_t *edge = jl_svecref(callees, j);
size_t min_valid2;
size_t max_valid2;
assert(!jl_is_method(edge)); // `Method`-edge isn't allowed for the optimized one-edge format
if (jl_is_code_instance(edge))
edge = (jl_value_t*)((jl_code_instance_t*)edge)->def;
if (jl_is_method_instance(edge)) {
jl_method_instance_t *mi = (jl_method_instance_t*)edge;
sig = jl_type_intersection(mi->def.method->sig, (jl_value_t*)mi->specTypes); // TODO: ??
max_valid2 = verify_call(sig, callees, j, 1, minworld, &matches);
verify_call(sig, callees, j, 1, world, &min_valid2, &max_valid2, &matches);
sig = NULL;
j += 1;
}
else if (jl_is_long(edge)) {
jl_value_t *sig = jl_svecref(callees, j + 1);
size_t nedges = jl_unbox_long(edge);
max_valid2 = verify_call(sig, callees, j + 2, nedges, minworld, &matches);
verify_call(sig, callees, j + 2, nedges, world, &min_valid2, &max_valid2, &matches);
j += 2 + nedges;
edge = sig;
}
@@ -896,9 +903,11 @@ static int jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, size_
assert(jl_is_method(callee));
meth = (jl_method_t*)callee;
}
max_valid2 = verify_invokesig(edge, meth, minworld);
verify_invokesig(edge, meth, world, &min_valid2, &max_valid2);
j += 2;
}
if (*minworld < min_valid2)
*minworld = min_valid2;
if (*maxworld > max_valid2)
*maxworld = max_valid2;
if (max_valid2 != ~(size_t)0 && _jl_debug_method_invalidation) {
@@ -917,14 +926,19 @@ static int jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, size_
// verify recursive edges (if valid, or debugging)
size_t cycle = depth;
jl_code_instance_t *cause = codeinst;
if (*maxworld == ~(size_t)0 || _jl_debug_method_invalidation) {
if (*maxworld != 0 || _jl_debug_method_invalidation) {
for (size_t j = 0; j < jl_svec_len(callees); j++) {
jl_value_t *edge = jl_svecref(callees, j);
if (!jl_is_code_instance(edge))
continue;
jl_code_instance_t *callee = (jl_code_instance_t*)edge;
size_t max_valid2 = ~(size_t)0;
size_t child_cycle = jl_verify_method(callee, minworld, &max_valid2, stack, visiting);
size_t min_valid2;
size_t max_valid2;
size_t child_cycle = jl_verify_method(callee, &min_valid2, &max_valid2, stack, visiting);
if (*minworld < min_valid2)
*minworld = min_valid2;
if (*minworld > max_valid2)
max_valid2 = 0;
if (*maxworld > max_valid2) {
cause = callee;
*maxworld = max_valid2;
@@ -947,12 +961,18 @@ static int jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, size_
// cycle as also having a failed edge.
while (stack->len >= depth) {
jl_code_instance_t *child = (jl_code_instance_t*)arraylist_pop(stack);
if (*maxworld != jl_atomic_load_relaxed(&child->max_world))
jl_atomic_store_relaxed(&child->max_world, *maxworld);
if (jl_atomic_load_relaxed(&jl_n_threads) == 1) {
// a different thread might simultaneously come to a different, but equally valid, alternative result
assert(jl_atomic_load_relaxed(&child->max_world) == WORLD_AGE_REVALIDATION_SENTINEL);
assert(*minworld <= jl_atomic_load_relaxed(&child->min_world));
}
if (*maxworld != 0)
jl_atomic_store_relaxed(&child->min_world, *minworld);
jl_atomic_store_relaxed(&child->max_world, *maxworld);
void **bp = ptrhash_bp(visiting, codeinst);
assert(*bp == (char*)HT_NOTFOUND + stack->len + 1);
*bp = HT_NOTFOUND;
if (_jl_debug_method_invalidation && *maxworld != ~(size_t)0) {
if (_jl_debug_method_invalidation && *maxworld < current_world) {
jl_array_ptr_1d_push(_jl_debug_method_invalidation, (jl_value_t*)child);
loctag = jl_cstr_to_string("verify_methods");
JL_GC_PUSH1(&loctag);
@@ -966,26 +986,30 @@ static int jl_verify_method(jl_code_instance_t *codeinst, size_t minworld, size_
return 0;
}

static size_t jl_verify_method_graph(jl_code_instance_t *codeinst, size_t minworld, arraylist_t *stack, htable_t *visiting)
static void jl_verify_method_graph(jl_code_instance_t *codeinst, arraylist_t *stack, htable_t *visiting)
{
size_t minworld;
size_t maxworld;
assert(stack->len == 0);
for (size_t i = 0, hsz = visiting->size; i < hsz; i++)
assert(visiting->table[i] == HT_NOTFOUND);
size_t maxworld = ~(size_t)0;
int child_cycle = jl_verify_method(codeinst, minworld, &maxworld, stack, visiting);
int child_cycle = jl_verify_method(codeinst, &minworld, &maxworld, stack, visiting);
assert(child_cycle == 0); (void)child_cycle;
assert(stack->len == 0);
for (size_t i = 0, hsz = visiting->size / 2; i < hsz; i++) {
assert(visiting->table[2 * i + 1] == HT_NOTFOUND);
visiting->table[2 * i] = HT_NOTFOUND;
}
return maxworld;
if (jl_atomic_load_relaxed(&jl_n_threads) == 1) { // a different thread might simultaneously come to a different, but equally valid, alternative result
assert(maxworld == 0 || jl_atomic_load_relaxed(&codeinst->min_world) == minworld);
assert(jl_atomic_load_relaxed(&codeinst->max_world) == maxworld);
}
}

// Restore backedges to external targets
// `edges` = [caller1, ...], the list of worklist-owned code instances internally
// `ext_ci_list` = [caller1, ...], the list of worklist-owned code instances externally
static void jl_insert_backedges(jl_array_t *edges, jl_array_t *ext_ci_list, size_t minworld)
static void jl_insert_backedges(jl_array_t *edges, jl_array_t *ext_ci_list)
{
// determine which CodeInstance objects are still valid in our image
// to enable any applicable new codes
@@ -1001,61 +1025,59 @@ static void jl_insert_backedges(jl_array_t *edges, jl_array_t *ext_ci_list, size
jl_code_instance_t *codeinst = (jl_code_instance_t*)jl_array_ptr_ref(edges, i);
jl_svec_t *callees = jl_atomic_load_relaxed(&codeinst->edges);
jl_method_instance_t *caller = codeinst->def;
if (jl_atomic_load_relaxed(&codeinst->min_world) != minworld) {
if (external && jl_atomic_load_relaxed(&codeinst->max_world) != WORLD_AGE_REVALIDATION_SENTINEL) {
assert(jl_atomic_load_relaxed(&codeinst->min_world) == 1);
assert(jl_atomic_load_relaxed(&codeinst->max_world) == ~(size_t)0);
}
else {
continue;
}
}
size_t maxvalid = jl_verify_method_graph(codeinst, minworld, &stack, &visiting);
assert(jl_atomic_load_relaxed(&codeinst->max_world) == maxvalid);
if (maxvalid == ~(size_t)0) {
// if this callee is still valid, add all the backedges
for (size_t j = 0; j < jl_svec_len(callees); ) {
jl_value_t *edge = jl_svecref(callees, j);
if (jl_is_long(edge)) {
j += 2; // skip over signature and count but not methods
continue;
}
else if (jl_is_method(edge)) {
j += 1;
continue;
}
if (jl_is_code_instance(edge))
edge = (jl_value_t*)((jl_code_instance_t*)edge)->def;
if (jl_is_method_instance(edge)) {
jl_method_instance_add_backedge((jl_method_instance_t*)edge, NULL, codeinst);
j += 1;
}
else if (jl_is_mtable(edge)) {
jl_methtable_t *mt = (jl_methtable_t*)edge;
jl_value_t *sig = jl_svecref(callees, j + 1);
jl_method_table_add_backedge(mt, sig, codeinst);
j += 2;
}
else {
jl_value_t *callee = jl_svecref(callees, j + 1);
if (jl_is_code_instance(callee))
callee = (jl_value_t*)((jl_code_instance_t*)callee)->def;
else if (jl_is_method(callee)) {
j += 2;
jl_verify_method_graph(codeinst, &stack, &visiting);
size_t minvalid = jl_atomic_load_relaxed(&codeinst->min_world);
size_t maxvalid = jl_atomic_load_relaxed(&codeinst->max_world);
if (maxvalid >= minvalid) {
if (jl_atomic_load_relaxed(&jl_world_counter) == maxvalid) {
// if this callee is still valid, add all the backedges
for (size_t j = 0; j < jl_svec_len(callees); ) {
jl_value_t *edge = jl_svecref(callees, j);
if (jl_is_long(edge)) {
j += 2; // skip over signature and count but not methods
continue;
}
jl_method_instance_add_backedge((jl_method_instance_t*)callee, edge, codeinst);
j += 2;
else if (jl_is_method(edge)) {
j += 1;
continue;
}
if (jl_is_code_instance(edge))
edge = (jl_value_t*)((jl_code_instance_t*)edge)->def;
if (jl_is_method_instance(edge)) {
jl_method_instance_add_backedge((jl_method_instance_t*)edge, NULL, codeinst);
j += 1;
}
else if (jl_is_mtable(edge)) {
jl_methtable_t *mt = (jl_methtable_t*)edge;
jl_value_t *sig = jl_svecref(callees, j + 1);
jl_method_table_add_backedge(mt, sig, codeinst);
j += 2;
}
else {
jl_value_t *callee = jl_svecref(callees, j + 1);
if (jl_is_code_instance(callee))
callee = (jl_value_t*)((jl_code_instance_t*)callee)->def;
else if (jl_is_method(callee)) {
j += 2;
continue;
}
jl_method_instance_add_backedge((jl_method_instance_t*)callee, edge, codeinst);
j += 2;
}
}
}
if (jl_atomic_load_relaxed(&jl_world_counter) == maxvalid) {
maxvalid = ~(size_t)0;
jl_atomic_store_relaxed(&codeinst->max_world, maxvalid);
}
if (external) {
jl_value_t *owner = codeinst->owner;
JL_GC_PROMISE_ROOTED(owner);

// See #53586, #53109
assert(jl_atomic_load_relaxed(&codeinst->inferred));

if (jl_rettype_inferred(owner, caller, minworld, maxvalid) != jl_nothing) {
if (jl_rettype_inferred(owner, caller, minvalid, maxvalid) != jl_nothing) {
// We already got a code instance for this world age range from somewhere else - we don't need
// this one.
}
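
Most of the staticdata_utils.c changes above thread a (minworld, maxworld) pair through the verification walk: each CodeInstance's validity window becomes the intersection of the window implied by its own call edges with the windows of the code instances it invokes, and an empty window marks it invalid. A heavily simplified sketch of that propagation, assuming an acyclic toy graph (the real jl_verify_method additionally collapses cycles with the Tarjan-style pass described in its comment; the types and names here are illustrative only, not Julia's):

#include <stddef.h>
#include <stdio.h>

#define MAX_EDGES 4
#define WORLD_MAX (~(size_t)0)

/* Toy call-graph node: a cached piece of code with a validity window and callees. */
typedef struct toy_node {
    const char *name;
    size_t min_world, max_world;        /* window from verifying this node's own edges */
    struct toy_node *callees[MAX_EDGES];
    int ncallees;
    int done;                           /* memoization flag */
} toy_node_t;

/* Intersect a node's window with its callees' windows, recursively. */
static void verify(toy_node_t *n)
{
    if (n->done)
        return;
    n->done = 1;
    for (int i = 0; i < n->ncallees; i++) {
        toy_node_t *c = n->callees[i];
        verify(c);
        if (n->min_world < c->min_world)
            n->min_world = c->min_world;
        if (n->max_world > c->max_world)
            n->max_world = c->max_world;
    }
    if (n->min_world > n->max_world)
        n->max_world = 0; /* empty window: this node is invalid */
}

int main(void)
{
    toy_node_t callee = { "callee", 3, 10, {0}, 0, 0 };
    toy_node_t caller = { "caller", 1, WORLD_MAX, { &callee }, 1, 0 };
    verify(&caller);
    printf("%s valid for worlds [%zu, %zu]\n", caller.name, caller.min_world, caller.max_world);
    return 0;
}
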
