From a6dcfd5fa1579e3b73119d36355f30691c3c1bb2 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Thu, 3 Aug 2023 17:46:27 -0400 Subject: [PATCH 1/7] set opaque closures to true --- src/build_function.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index 457078c95..f57b166ee 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -179,7 +179,7 @@ function _build_and_inject_function(mod::Module, ex) # XXX: Workaround to specify the module as both the cache module AND context module. # Currently, the @RuntimeGeneratedFunction macro only sets the context module. module_tag = getproperty(mod, RuntimeGeneratedFunctions._tagname) - RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex; opaque_closures=false) + RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex; opaque_closures=true) end toexpr(n::Num, st) = toexpr(value(n), st) @@ -356,7 +356,7 @@ function toexpr(p::SpawnFetch{MultithreadedForm}, st) args = isnothing(p.args) ? Iterators.repeated((), length(p.exprs)) : p.args spawns = map(p.exprs, args) do thunk, a - ex = :($Funcall($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), false))), + ex = :($Funcall($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), true))), ($(toexpr.(a, (st,))...),))) quote let @@ -376,7 +376,7 @@ function toexpr(p::SpawnFetch{ShardedForm{false}}, st) args = isnothing(p.args) ? Iterators.repeated((), length(p.exprs)) : p.args spawns = map(p.exprs, args) do thunk, a - :($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), false)))($(toexpr.(a, (st,))...),)) + :($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), true)))($(toexpr.(a, (st,))...),)) end quote $(toexpr(p.combine, st))($(spawns...)) From 7261b86838d446ffdbad332c4d942b43ee11436b Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 4 Aug 2023 18:02:27 -0400 Subject: [PATCH 2/7] actually just leave the expression in there, don't pass any extra args --- src/build_function.jl | 48 +++++++++++++++---------------------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index f57b166ee..dd80ef5a6 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -273,7 +273,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; oop, iip = iip_config oop_body = if oop - postprocess_fbody(make_array(parallel, dargs, rhss, similarto, cse)) + postprocess_fbody(make_array(parallel, rhss, similarto, cse)) else term(throw_missing_specialization, length(dargs)) end @@ -286,7 +286,6 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; out = Sym{Any}(:ˍ₋out) ip_body = if iip postprocess_fbody(set_array(parallel, - dargs, out, outputidxs, rhss, @@ -314,16 +313,16 @@ _nnz(x::AbstractArray) = length(x) _nnz(x::AbstractSparseArray) = nnz(x) _nnz(x::Union{Base.ReshapedArray, LinearAlgebra.Transpose}) = _nnz(parent(x)) -function make_array(s, dargs, arr, similarto, cse) +function make_array(s, arr, similarto, cse) s !== nothing && Base.@warn("Parallel form of $(typeof(s)) not implemented") _make_array(arr, similarto, cse) end -function make_array(s::SerialForm, dargs, arr, similarto, cse) +function make_array(s::SerialForm, arr, similarto, cse) _make_array(arr, similarto, cse) end -function make_array(s::ShardedForm, closed_args, arr, similarto, cse) +function make_array(s::ShardedForm, arr, similarto, cse) if arr isa AbstractSparseArray return LiteralExpr(quote @@ -332,32 +331,21 @@ function make_array(s::ShardedForm, closed_args, arr, similarto, cse) copy($(arr.colptr)), copy($(arr.rowval)), $(make_array(s, - closed_args, arr.nzval, Vector,cse))) end) end per_task = ceil(Int, length(arr) / s.ncalls) slices = collect(Iterators.partition(arr, per_task)) - arrays = map(slices) do slice - Func(closed_args, [], _make_array(slice, similarto, cse)), closed_args + funcs = map(slices) do slice + Func([], [], _make_array(slice, similarto, cse)) end - SpawnFetch{typeof(s)}(first.(arrays), last.(arrays), vcat) + SpawnFetch{typeof(s)}(funcs, nothing, vcat) end -struct Funcall{F, T} - f::F - args::T -end - -(f::Funcall)() = f.f(f.args...) - function toexpr(p::SpawnFetch{MultithreadedForm}, st) - args = isnothing(p.args) ? - Iterators.repeated((), length(p.exprs)) : p.args - spawns = map(p.exprs, args) do thunk, a - ex = :($Funcall($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), true))), - ($(toexpr.(a, (st,))...),))) + spawns = map(p.exprs) do thunk + ex = :(()->$(toexpr(thunk, st))) quote let task = Base.Task($ex) @@ -373,10 +361,9 @@ function toexpr(p::SpawnFetch{MultithreadedForm}, st) end function toexpr(p::SpawnFetch{ShardedForm{false}}, st) - args = isnothing(p.args) ? - Iterators.repeated((), length(p.exprs)) : p.args - spawns = map(p.exprs, args) do thunk, a - :($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), true)))($(toexpr.(a, (st,))...),)) + spawns = map(p.exprs) do thunk + @show thunk + ex = :(($(toexpr(thunk, st)))()) end quote $(toexpr(p.combine, st))($(spawns...)) @@ -438,12 +425,12 @@ _make_array(x, similarto, cse) = x ## In-place version -function set_array(p, closed_vars, args...) +function set_array(p, args...) p !== nothing && Base.@warn("Parallel form of $(typeof(p)) not implemented") _set_array(args...) end -function set_array(s::SerialForm, closed_vars, args...) +function set_array(s::SerialForm, args...) _set_array(args...) end @@ -464,10 +451,9 @@ function recursive_split(leaf_f, s, out, args, outputidxs, xs) end end -function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkbounds, skipzeros, cse) +function set_array(s::ShardedForm, out, outputidxs, rhss, checkbounds, skipzeros, cse) if rhss isa AbstractSparseArray return set_array(s, - closed_args, LiteralExpr(:($out.nzval)), nothing, rhss.nzval, @@ -481,9 +467,9 @@ function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkboun if outputidxs === nothing outputidxs = collect(eachindex(rhss)) end - all_args = [outvar, closed_args...] + all_args = [outvar] ex = recursive_split(s, outvar, all_args, outputidxs, rhss) do idxs, xs - Func(all_args, [], + Func([], [], _set_array(outvar, idxs, xs, checkbounds, skipzeros, cse), []) end.body From b8e46ffe84bec38542a8da54f1fdc4c7a27fa04a Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 18 Aug 2023 15:08:08 -0400 Subject: [PATCH 3/7] use .args for completeness fix multithreading bug --- src/build_function.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index dd80ef5a6..ea28e93a7 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -345,7 +345,7 @@ end function toexpr(p::SpawnFetch{MultithreadedForm}, st) spawns = map(p.exprs) do thunk - ex = :(()->$(toexpr(thunk, st))) + ex = :(()->$(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) quote let task = Base.Task($ex) @@ -362,8 +362,7 @@ end function toexpr(p::SpawnFetch{ShardedForm{false}}, st) spawns = map(p.exprs) do thunk - @show thunk - ex = :(($(toexpr(thunk, st)))()) + ex = :($(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) end quote $(toexpr(p.combine, st))($(spawns...)) From da29f2e993627046afcd8c78037b625ae42b3025 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 18 Aug 2023 15:28:20 -0400 Subject: [PATCH 4/7] isnothing on args --- src/build_function.jl | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index ea28e93a7..2b5788ada 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -345,7 +345,11 @@ end function toexpr(p::SpawnFetch{MultithreadedForm}, st) spawns = map(p.exprs) do thunk - ex = :(()->$(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) + if isnothing(p.args) + ex = toexpr(thunk, st) + else + ex = :(()->$(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) + end quote let task = Base.Task($ex) @@ -362,7 +366,11 @@ end function toexpr(p::SpawnFetch{ShardedForm{false}}, st) spawns = map(p.exprs) do thunk - ex = :($(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) + if isnothing(p.args) + ex = :($(toexpr(thunk, st))()) + else + ex = :($(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...)) + end end quote $(toexpr(p.combine, st))($(spawns...)) From ca7da64e6d4a8bb0de75eb022b83bb5a320e0fb3 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 18 Aug 2023 15:34:35 -0400 Subject: [PATCH 5/7] drop args --- src/build_function.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/build_function.jl b/src/build_function.jl index 2b5788ada..d47153115 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -452,7 +452,7 @@ function recursive_split(leaf_f, s, out, args, outputidxs, xs) recursive_split(leaf_f, s, out, args, first.(slice), last.(slice)) end return Func(args, [], - SpawnFetch{typeof(s)}(fs, [args for f in fs], + SpawnFetch{typeof(s)}(fs, (@inline noop(x...) = nothing)), []) end From 2cded7f524ee6ac5832bcf29fb4bd0de2a2c2dc4 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 18 Aug 2023 15:43:23 -0400 Subject: [PATCH 6/7] revert #954 -- switches to ShardedForm if more than 1000 exprs --- src/build_function.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/build_function.jl b/src/build_function.jl index d47153115..7f117f51b 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -138,6 +138,10 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...; linenumbers = true, cse = false, kwargs...) + if parallel == nothing && _nnz(rhss) >= 1000 + parallel = ShardedForm() # by default switch for arrays longer than 1000 exprs + end + dargs = map((x) -> destructure_arg(x[2], !checkbounds, Symbol("ˍ₋arg$(x[1])")), enumerate([args...])) From a4570c8c26650faf599779cf2b67e6e50c231a54 Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 18 Aug 2023 16:33:02 -0400 Subject: [PATCH 7/7] fix bug when cutoff is nothing; i.e. when MultithreadedForm is constructed without args --- src/build_function.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/build_function.jl b/src/build_function.jl index d97b1a36d..3a144e6c4 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -27,7 +27,7 @@ ShardedForm() = ShardedForm(80, 4) const MultithreadedForm = ShardedForm{true} -MultithreadedForm() = MultithreadedForm(nothing, 2*nthreads()) +MultithreadedForm() = MultithreadedForm(80, 2*nthreads()) function throw_missing_specialization(n) throw(ArgumentError("Missing specialization for $n arguments. Check `iip_config`.")) @@ -446,8 +446,7 @@ function set_array(s::SerialForm, args...) end function recursive_split(leaf_f, s, out, args, outputidxs, xs) - cutoff = isnothing(s.cutoff) ? ceil(Int, length(xs) / (2*s.ncalls)) : s.cutoff - if length(xs) <= cutoff + if length(xs) <= s.cutoff return leaf_f(outputidxs, xs) else per_part = ceil(Int, length(xs) / s.ncalls)