Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use opaque closure to fix ShardedForm #954

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
71 changes: 32 additions & 39 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ShardedForm() = ShardedForm(80, 4)

const MultithreadedForm = ShardedForm{true}

MultithreadedForm() = MultithreadedForm(nothing, 2*nthreads())
MultithreadedForm() = MultithreadedForm(80, 2*nthreads())

function throw_missing_specialization(n)
throw(ArgumentError("Missing specialization for $n arguments. Check `iip_config`."))
Expand Down Expand Up @@ -138,6 +138,10 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp}, args...;
linenumbers = true,
cse = false, kwargs...)

if parallel == nothing && _nnz(rhss) >= 1000
parallel = ShardedForm() # by default switch for arrays longer than 1000 exprs
end

dargs = map((x) -> destructure_arg(x[2], !checkbounds,
Symbol("ˍ₋arg$(x[1])")), enumerate([args...]))

Expand Down Expand Up @@ -179,7 +183,7 @@ function _build_and_inject_function(mod::Module, ex)
# XXX: Workaround to specify the module as both the cache module AND context module.
# Currently, the @RuntimeGeneratedFunction macro only sets the context module.
module_tag = getproperty(mod, RuntimeGeneratedFunctions._tagname)
RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex; opaque_closures=false)
RuntimeGeneratedFunctions.RuntimeGeneratedFunction(module_tag, module_tag, ex; opaque_closures=true)
end

toexpr(n::Num, st) = toexpr(value(n), st)
Expand Down Expand Up @@ -273,7 +277,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;

oop, iip = iip_config
oop_body = if oop
postprocess_fbody(make_array(parallel, dargs, rhss, similarto, cse))
postprocess_fbody(make_array(parallel, rhss, similarto, cse))
else
term(throw_missing_specialization, length(dargs))
end
Expand All @@ -286,7 +290,6 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...;
out = Sym{Any}(:ˍ₋out)
ip_body = if iip
postprocess_fbody(set_array(parallel,
dargs,
out,
outputidxs,
rhss,
Expand Down Expand Up @@ -314,16 +317,16 @@ _nnz(x::AbstractArray) = length(x)
_nnz(x::AbstractSparseArray) = nnz(x)
_nnz(x::Union{Base.ReshapedArray, LinearAlgebra.Transpose}) = _nnz(parent(x))

function make_array(s, dargs, arr, similarto, cse)
function make_array(s, arr, similarto, cse)
s !== nothing && Base.@warn("Parallel form of $(typeof(s)) not implemented")
_make_array(arr, similarto, cse)
end

function make_array(s::SerialForm, dargs, arr, similarto, cse)
function make_array(s::SerialForm, arr, similarto, cse)
_make_array(arr, similarto, cse)
end

function make_array(s::ShardedForm, closed_args, arr, similarto, cse)
function make_array(s::ShardedForm, arr, similarto, cse)
if arr isa AbstractSparseArray

return LiteralExpr(quote
Expand All @@ -332,32 +335,25 @@ function make_array(s::ShardedForm, closed_args, arr, similarto, cse)
copy($(arr.colptr)),
copy($(arr.rowval)),
$(make_array(s,
closed_args,
arr.nzval,
Vector,cse)))
end)
end
per_task = ceil(Int, length(arr) / s.ncalls)
slices = collect(Iterators.partition(arr, per_task))
arrays = map(slices) do slice
Func(closed_args, [], _make_array(slice, similarto, cse)), closed_args
funcs = map(slices) do slice
Func([], [], _make_array(slice, similarto, cse))
end
SpawnFetch{typeof(s)}(first.(arrays), last.(arrays), vcat)
SpawnFetch{typeof(s)}(funcs, vcat)
end

struct Funcall{F, T}
f::F
args::T
end

(f::Funcall)() = f.f(f.args...)

function toexpr(p::SpawnFetch{MultithreadedForm}, st)
args = isnothing(p.args) ?
Iterators.repeated((), length(p.exprs)) : p.args
spawns = map(p.exprs, args) do thunk, a
ex = :($Funcall($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), false))),
($(toexpr.(a, (st,))...),)))
spawns = map(p.exprs) do thunk
if isnothing(p.args)
ex = toexpr(thunk, st)
else
ex = :(()->$(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...))
end
quote
let
task = Base.Task($ex)
Expand All @@ -373,10 +369,12 @@ function toexpr(p::SpawnFetch{MultithreadedForm}, st)
end

function toexpr(p::SpawnFetch{ShardedForm{false}}, st)
args = isnothing(p.args) ?
Iterators.repeated((), length(p.exprs)) : p.args
spawns = map(p.exprs, args) do thunk, a
:($(drop_expr(@RuntimeGeneratedFunction(@__MODULE__, toexpr(thunk, st), false)))($(toexpr.(a, (st,))...),))
spawns = map(p.exprs) do thunk
if isnothing(p.args)
ex = :($(toexpr(thunk, st))())
else
ex = :($(toexpr(thunk, st))($(map(x->toexpr(x, st), p.args))...))
end
end
quote
$(toexpr(p.combine, st))($(spawns...))
Expand Down Expand Up @@ -438,18 +436,17 @@ _make_array(x, similarto, cse) = x

## In-place version

function set_array(p, closed_vars, args...)
function set_array(p, args...)
p !== nothing && Base.@warn("Parallel form of $(typeof(p)) not implemented")
_set_array(args...)
end

function set_array(s::SerialForm, closed_vars, args...)
function set_array(s::SerialForm, args...)
_set_array(args...)
end

function recursive_split(leaf_f, s, out, args, outputidxs, xs)
cutoff = isnothing(s.cutoff) ? ceil(Int, length(xs) / (2*s.ncalls)) : s.cutoff
if length(xs) <= cutoff
if length(xs) <= s.cutoff
return leaf_f(outputidxs, xs)
else
per_part = ceil(Int, length(xs) / s.ncalls)
Expand All @@ -458,16 +455,13 @@ function recursive_split(leaf_f, s, out, args, outputidxs, xs)
recursive_split(leaf_f, s, out, args, first.(slice), last.(slice))
end
return Func(args, [],
SpawnFetch{typeof(s)}(fs, [args for f in fs],
(@inline noop(x...) = nothing)),
[])
SpawnFetch{typeof(s)}(fs, (@inline noop(x...) = nothing)), [])
end
end

function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkbounds, skipzeros, cse)
function set_array(s::ShardedForm, out, outputidxs, rhss, checkbounds, skipzeros, cse)
if rhss isa AbstractSparseArray
return set_array(s,
closed_args,
LiteralExpr(:($out.nzval)),
nothing,
rhss.nzval,
Expand All @@ -481,9 +475,8 @@ function set_array(s::ShardedForm, closed_args, out, outputidxs, rhss, checkboun
if outputidxs === nothing
outputidxs = collect(eachindex(rhss))
end
all_args = [outvar, closed_args...]
ex = recursive_split(s, outvar, all_args, outputidxs, rhss) do idxs, xs
Func(all_args, [],
ex = recursive_split(s, outvar, [], outputidxs, rhss) do idxs, xs
Func([], [],
_set_array(outvar, idxs, xs, checkbounds, skipzeros, cse),
[])
end.body
Expand Down