diff --git a/src/BroadcastMapConversion.jl b/src/BroadcastMapConversion.jl new file mode 100644 index 0000000..6edf0ae --- /dev/null +++ b/src/BroadcastMapConversion.jl @@ -0,0 +1,48 @@ +module BroadcastMapConversion +# Convert broadcast call to map call by capturing array arguments +# with `map_args` and creating a map function with `map_function`. +# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl. + +using Base.Broadcast: Broadcasted + +const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}} + +function map_args(bc::Broadcasted, rest...) + return (map_args(bc.args...)..., map_args(rest...)...) +end +map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...) +map_args(a, rest...) = map_args(rest...) +map_args() = () + +struct MapFunction{F,Args<:Tuple} + f::F + args::Args +end +struct Arg end + +# construct MapFunction +function map_function(bc::Broadcasted) + args = map_function_tuple(bc.args) + return MapFunction(bc.f, args) +end +map_function_tuple(t::Tuple{}) = t +map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...) +map_function(a::WrappedScalarArgs) = a[] +map_function(a::AbstractArray) = Arg() +map_function(a) = a + +# Evaluate MapFunction +(f::MapFunction)(args...) = apply(f, args)[1] +function apply(f::MapFunction, args) + args, newargs = apply_tuple(f.args, args) + return f.f(args...), newargs +end +apply(a::Arg, args::Tuple) = args[1], Base.tail(args) +apply(a, args) = a, args +apply_tuple(t::Tuple{}, args) = t, args +function apply_tuple(t::Tuple, args) + t1, newargs1 = apply(t[1], args) + ttail, newargs = apply_tuple(Base.tail(t), newargs1) + return (t1, ttail...), newargs +end +end diff --git a/test/runtests.jl b/test/runtests.jl new file mode 100644 index 0000000..92e5d6a --- /dev/null +++ b/test/runtests.jl @@ -0,0 +1,14 @@ +@eval module $(gensym()) +using Test: @test, @testset +using NDTensors.BroadcastMapConversion: map_function, map_args +@testset "BroadcastMapConversion" begin + using Base.Broadcast: Broadcasted + c = 2.2 + a = randn(2, 3) + b = randn(2, 3) + bc = Broadcasted(*, (c, a)) + @test copy(bc) ≈ c * a ≈ map(map_function(bc), map_args(bc)...) + bc = Broadcasted(+, (a, b)) + @test copy(bc) ≈ a + b ≈ map(map_function(bc), map_args(bc)...) +end +end