diff --git a/src/build_function.jl b/src/build_function.jl index 457078c95..3a144e6c4 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -27,7 +27,7 @@ ShardedForm() = ShardedForm(80, 4) const MultithreadedForm = ShardedForm{true} -MultithreadedForm() = MultithreadedForm(nothing, 2*nthreads()) +MultithreadedForm() = MultithreadedForm(80, 2*nthreads()) function throw_missing_specialization(n) throw(ArgumentError("Missing specialization for $n arguments. Check `iip_config`.")) @@ -138,6 +138,10 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; linenumbers = true, cse = false, kwargs...) + if parallel == nothing && _nnz(rhss) >= 1000 + parallel = ShardedForm() # by default switch for arrays longer than 1000 exprs + end + dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...])) @@ -179,7 +183,7 @@ function _build_and_inject_function(mod::Module, ex) # XXX: Workaround to specify the module as both the cache module AND context module. # Currently, the @RuntimeGeneratedFunction macro only sets the context module. module_tag = getproperty(mod, RuntimeGeneratedFunctions._tagname) - RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex; opaque_closures=false) + RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex; opaque_closures=true) end toexpr(n::Num, st) = toexpr(value(n), st) @@ -273,7 +277,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; oop, iip = iip_config oop_body = if oop - postprocess_fbody(make_array(parallel, dargs, rhss, similarto, cse)) + postprocess_fbody(make_array(parallel, rhss, similarto, cse)) else term(throw_missing_specialization, length(dargs)) end @@ -286,7 +290,6 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; out = Sym{Any}(:ˍ₋out) ip_body = if iip postprocess_fbody(set_array(parallel, - dargs, out, outputidxs, rhss, @@ -314,16 +317,16 @@ _nnz(x::AbstractArray) = length(x) _nnz(x::AbstractSparseArray) = nnz(x) _nnz(x::Union{Base.ReshapedArray, LinearAlgebra.Transpose}) = _nnz(parent(x)) -function make_array(s, dargs, arr, similarto, cse) +function make_array(s, arr, similarto, cse) s !== nothing && Base.@warn("Parallel form of $(typeof(s)) not implemented") _make_array(arr, similarto, cse) end -function make_array(s::SerialForm, dargs, arr, similarto, cse) +function make_array(s::SerialForm, arr, similarto, cse) _make_array(arr, similarto, cse) end -function make_array(s::ShardedForm, closed_args, arr, similarto, cse) +function make_array(s::ShardedForm, arr, similarto, cse) if arr isa AbstractSparseArray return LiteralExpr(quote @@ -332,32 +335,25 @@ function make_array(s::ShardedForm, closed_args, arr, similarto, cse) copy($(arr.colptr)), copy($(arr.rowval)), $(make_array(s, - closed_args, arr.nzval, Vector,cse))) end) end per_task = ceil(Int, length(arr) / s.ncalls) slices = collect(Iterators.partition(arr, per_task)) - arrays = map(slices) do slice - Func(closed_args, [], _make_array(slice, similarto, cse)), closed_args + funcs = map(slices) do slice + Func([], [], _make_array(slice, similarto, cse)) end - SpawnFetch{typeof(s)}(first.(arrays), last.(arrays), vcat) + SpawnFetch{typeof(s)}(funcs, vcat) end -struct Funcall{F, T} - f::F - args::T -end - -(f::Funcall)() = f.f(f.args...) - function toexpr(p::SpawnFetch{MultithreadedForm}, st) - args = isnothing(p.args) ? - Iterators.repeated((), length(p.exprs)) : p.args - spawns = map(p.exprs, args) do thunk, a - ex = :($Funcall($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), false))), - ($(toexpr.(a, (st,))...),))) + spawns = map(p.exprs) do thunk + if isnothing(p.args) + ex = toexpr(thunk, st) + else + ex = :(()->$(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) + end quote let task = Base.Task($ex) @@ -373,10 +369,12 @@ function toexpr(p::SpawnFetch{MultithreadedForm}, st) end function toexpr(p::SpawnFetch{ShardedForm{false}}, st) - args = isnothing(p.args) ? - Iterators.repeated((), length(p.exprs)) : p.args - spawns = map(p.exprs, args) do thunk, a - :($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), false)))($(toexpr.(a, (st,))...),)) + spawns = map(p.exprs) do thunk + if isnothing(p.args) + ex = :($(toexpr(thunk, st))()) + else + ex = :($(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) + end end quote $(toexpr(p.combine, st))($(spawns...)) @@ -438,18 +436,17 @@ _make_array(x, similarto, cse) = x ## In-place version -function set_array(p, closed_vars, args...) +function set_array(p, args...) p !== nothing && Base.@warn("Parallel form of $(typeof(p)) not implemented") _set_array(args...) end -function set_array(s::SerialForm, closed_vars, args...) +function set_array(s::SerialForm, args...) _set_array(args...) end function recursive_split(leaf_f, s, out, args, outputidxs, xs) - cutoff = isnothing(s.cutoff) ? ceil(Int, length(xs) / (2*s.ncalls)) : s.cutoff - if length(xs) <= cutoff + if length(xs) <= s.cutoff return leaf_f(outputidxs, xs) else per_part = ceil(Int, length(xs) / s.ncalls) @@ -458,16 +455,13 @@ function recursive_split(leaf_f, s, out, args, outputidxs, xs) recursive_split(leaf_f, s, out, args, first.(slice), last.(slice)) end return Func(args, [], - SpawnFetch{typeof(s)}(fs, [args for f in fs], - (@inline noop(x...) = nothing)), - []) + SpawnFetch{typeof(s)}(fs, (@inline noop(x...) = nothing)), []) end end -function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkbounds, skipzeros, cse) +function set_array(s::ShardedForm, out, outputidxs, rhss, checkbounds, skipzeros, cse) if rhss isa AbstractSparseArray return set_array(s, - closed_args, LiteralExpr(:($out.nzval)), nothing, rhss.nzval, @@ -481,9 +475,8 @@ function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkboun if outputidxs === nothing outputidxs = collect(eachindex(rhss)) end - all_args = [outvar, closed_args...] - ex = recursive_split(s, outvar, all_args, outputidxs, rhss) do idxs, xs - Func(all_args, [], + ex = recursive_split(s, outvar, [], outputidxs, rhss) do idxs, xs + Func([], [], _set_array(outvar, idxs, xs, checkbounds, skipzeros, cse), []) end.body