Skip to content

Commit

Permalink
Introduce mapped/Mapped (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Dec 25, 2024
1 parent 33096dc commit 3e63462
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 19 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BroadcastMapConversion"
uuid = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.1.1"
version = "0.1.2"

[compat]
julia = "1.10"
16 changes: 13 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,21 @@ julia> Pkg.add("BroadcastMapConversion")
## Examples

````julia
using BroadcastMapConversion: BroadcastMapConversion
using Base.Broadcast: broadcasted
using BroadcastMapConversion: Mapped, mapped
using Test: @test

a = randn(2, 2)
bc = broadcasted(*, 2, a)
m = Mapped(bc)
m′ = mapped(x -> 2x, a)
@test copy(m) map(m.f, m.args...)
@test copy(m) copy(m′)
@test copy(m) copy(bc)
@test axes(m) == axes(bc)
@test copyto!(similar(m, Float64), m) copyto!(similar(bc, Float64), bc)
````

Examples go here.

---

*This page was generated using [Literate.jl](https://github.com/fredrikekre/Literate.jl).*
Expand Down
15 changes: 13 additions & 2 deletions examples/README.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,16 @@ julia> Pkg.add("BroadcastMapConversion")

# ## Examples

using BroadcastMapConversion: BroadcastMapConversion
# Examples go here.
using Base.Broadcast: broadcasted
using BroadcastMapConversion: Mapped, mapped
using Test: @test

a = randn(2, 2)
bc = broadcasted(*, 2, a)
m = Mapped(bc)
m′ = mapped(x -> 2x, a)
@test copy(m) map(m.f, m.args...)
@test copy(m) copy(m′)
@test copy(m) copy(bc)
@test axes(m) == axes(bc)
@test copyto!(similar(m, Float64), m) copyto!(similar(bc, Float64), bc)
51 changes: 48 additions & 3 deletions src/BroadcastMapConversion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,29 @@ module BroadcastMapConversion
# with `map_args` and creating a map function with `map_function`.
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl.

using Base.Broadcast: Broadcasted
using Base.Broadcast:
Broadcast, BroadcastStyle, Broadcasted, broadcasted, combine_eltypes, instantiate

const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}

# Get the arguments of the map expression that
# is equivalent to the broadcast expression.
function map_args(bc::Broadcasted, rest...)
return (map_args(bc.args...)..., map_args(rest...)...)
end
map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...)
map_args(a, rest...) = map_args(rest...)
map_args() = ()

struct MapFunction{F,Args<:Tuple}
struct MapFunction{F,Args<:Tuple} <: Function
f::F
args::Args
end
struct Arg end

# construct MapFunction
# Get the function of the map expression that
# is equivalent to the broadcast expression.
# Returns a `MapFunction`.
function map_function(bc::Broadcasted)
args = map_function_tuple(bc.args)
return MapFunction(bc.f, args)
Expand All @@ -45,4 +50,44 @@ function apply_tuple(t::Tuple, args)
ttail, newargs = apply_tuple(Base.tail(t), newargs1)
return (t1, ttail...), newargs
end

abstract type AbstractMapped <: Base.AbstractBroadcasted end

struct Mapped{Style<:Union{Nothing,BroadcastStyle},Axes,F,Args<:Tuple} <: AbstractMapped
style::Style
f::F
args::Args
axes::Axes
end

function Mapped(bc::Broadcasted)
return Mapped(bc.style, map_function(bc), map_args(bc), bc.axes)
end
function Broadcast.Broadcasted(m::Mapped)
return Broadcasted(m.style, m.f, m.args, m.axes)
end

# Convert `Broadcasted` to `Mapped` when `Broadcasted`
# is known to already be a map expression.
function map_broadcast_to_mapped(bc::Broadcasted)
return Mapped(bc.style, bc.f, bc.args, bc.axes)
end

mapped(f, args...) = Mapped(broadcasted(f, args...))

Base.similar(m::Mapped, elt::Type) = similar(Broadcasted(m), elt)
Base.similar(m::Mapped, elt::Type, ax::Tuple) = similar(Broadcasted(m), elt, ax)
Base.axes(m::Mapped) = axes(Broadcasted(m))
# Equivalent to:
# map(m.f, m.args...)
# copy(Broadcasted(m))
function Base.copy(m::Mapped)
elt = combine_eltypes(m.f, m.args)
# TODO: Handle case of non-concrete eltype.
@assert Base.isconcretetype(elt)
return copyto!(similar(m, elt), m)
end
Base.copyto!(dest::AbstractArray, m::Mapped) = map!(m.f, dest, m.args...)
Broadcast.instantiate(m::Mapped) = map_broadcast_to_mapped(instantiate(Broadcasted(m)))

end
32 changes: 22 additions & 10 deletions test/test_basics.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
using Base.Broadcast: Broadcasted
using BroadcastMapConversion: map_function, map_args
using Base.Broadcast: broadcasted
using BroadcastMapConversion: Mapped, mapped
using Test: @test, @testset

@testset "BroadcastMapConversion" begin
c = 2.2
a = randn(2, 3)
b = randn(2, 3)
bc = Broadcasted(*, (c, a))
@test copy(bc) c * a map(map_function(bc), map_args(bc)...)
bc = Broadcasted(+, (a, b))
@test copy(bc) a + b map(map_function(bc), map_args(bc)...)
@testset "BroadcastMapConversion (eltype=$elt)" for elt in (
Float32, Float64, Complex{Float32}, Complex{Float64}
)
c = elt(2.2)
a = randn(elt, 2, 3)
b = randn(elt, 2, 3)
for (bc, m′, ref) in (
(broadcasted(*, c, a), mapped(x -> c * x, a), c * a),
(broadcasted(+, a, broadcasted(*, c, b)), mapped((x, y) -> x + c * y, a, b), a + c * b),
)
m = Mapped(bc)
@test copy(m) ref
@test copy(m′) ref
@test map(m.f, m.args...) ref
@test map(m′.f, m′.args...) ref
@test axes(m) == axes(bc)
@test axes(m′) == axes(bc)
@test copyto!(similar(m, elt), m) ref
@test copyto!(similar(m′, elt), m) ref
end
end

0 comments on commit 3e63462

Please sign in to comment.