diff --git a/base/task.jl b/base/task.jl index 0d37e1fe9eae2..f46d4dd75addd 100644 --- a/base/task.jl +++ b/base/task.jl @@ -591,10 +591,11 @@ function enq_work(t::Task) else tid = 0 if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0 - # if multiq is full, give to a random thread (TODO fix) - tid = mod(time_ns() % Int, Threads.nthreads()) + 1 - ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1) - push!(Workqueues[tid], t) + tid = ccall(:jl_get_random_thread_for_spawned_task, Cint, ()) + ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid) + + # Note that tid is obtained from c, and is therefore 0-indexed. + push!(Workqueues[tid+1], t) end end ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16) diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index 27096e1ba8be6..7f68e235e3721 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -181,6 +181,14 @@ macro spawn(expr) end end +function disable_spawning_on_this_thread() + ccall(:jl_accept_spawned_tasks, Cint, (Cint, Cint), threadid() - 1, 0) +end + +function enable_spawning_on_this_thread() + ccall(:jl_accept_spawned_tasks, Cint, (Cint, Cint), threadid() - 1, 0) +end + # This is a stub that can be overloaded for downstream structures like `Channel` function foreach end diff --git a/src/partr.c b/src/partr.c index a0dbf0f415a73..b41d279b4619d 100644 --- a/src/partr.c +++ b/src/partr.c @@ -77,11 +77,20 @@ static int32_t heap_p; /* unbias state for the RNG */ static uint64_t cong_unbias; +/* If set to 0, thread will stop pulling new work from the multiq when its workqueue is empty. + * This effectively makes it a high-priority protected thread. Work can be put on its workqueue + * directly. + */ +static uint8_t *thread_accepts_spawned_tasks = NULL; +jl_mutex_t thread_scheduling_lock; static inline void multiq_init(void) { heap_p = heap_c * jl_n_threads; heaps = (taskheap_t *)calloc(heap_p, sizeof(taskheap_t)); + thread_accepts_spawned_tasks = (uint8_t*)realloc(thread_accepts_spawned_tasks, jl_n_threads * sizeof(*thread_accepts_spawned_tasks)); + memset(thread_accepts_spawned_tasks, 1, jl_n_threads * sizeof(*thread_accepts_spawned_tasks)); + for (int32_t i = 0; i < heap_p; ++i) { jl_mutex_init(&heaps[i].lock); heaps[i].tasks = (jl_task_t **)calloc(tasks_per_heap, sizeof(jl_task_t*)); @@ -390,13 +399,13 @@ static jl_task_t *get_next_task(jl_value_t *trypoptask, jl_value_t *q) jl_gc_safepoint(); jl_value_t *args[2] = { trypoptask, q }; jl_task_t *task = (jl_task_t*)jl_apply(args, 2); + int self_tid = jl_get_ptls_states()->tid; if (jl_typeis(task, jl_task_type)) { - int self = jl_get_ptls_states()->tid; - jl_set_task_tid(task, self); + jl_set_task_tid(task, self_tid); return task; } jl_gc_safepoint(); - return multiq_deletemin(); + return thread_accepts_spawned_tasks[self_tid] ? multiq_deletemin() : NULL; } static int may_sleep(jl_ptls_t ptls) @@ -409,6 +418,50 @@ static int may_sleep(jl_ptls_t ptls) extern volatile unsigned _threadedregion; +// Get random thread id for a thread that accepts spawned tasks. +JL_DLLEXPORT int jl_get_random_thread_for_spawned_task() +{ + jl_ptls_t ptls = jl_get_ptls_states(); + uint64_t random_tid; + // Multiple cycles may be necessary if there are many threads + // that do not accept work but this should be fairly rare. + do { + random_tid = cong(jl_n_threads, cong_unbias, &ptls->rngseed); + } while(!thread_accepts_spawned_tasks[random_tid]); + return (int)random_tid; +} + + +JL_DLLEXPORT int jl_accept_spawned_tasks(int tid, int accept) +{ + if (tid < 0 || tid >= jl_n_threads) + return 1; + + int failed = 0; + JL_LOCK(&thread_scheduling_lock); + if (accept) { + thread_accepts_spawned_tasks[tid] = 1; + } else if (thread_accepts_spawned_tasks[tid]) { + // Ensure that there is at least one more thread that still + // accepts spawned tasks. + int other_available = 0; + for (int i = 0; i < jl_n_threads; i++) { + if (i != tid && thread_accepts_spawned_tasks[i]) { + other_available++; + break; + } + } + if (other_available) { + thread_accepts_spawned_tasks[tid] = 0; + } else { + jl_printf(JL_STDERR, "WARNING: can't disable task processing on all threads. Ignoring."); + failed = 1; + } + } + JL_UNLOCK(&thread_scheduling_lock); + return failed; +} + JL_DLLEXPORT jl_task_t *jl_task_get_next(jl_value_t *trypoptask, jl_value_t *q) { jl_ptls_t ptls = jl_get_ptls_states();