Skip to content
This repository has been archived by the owner on Jan 20, 2025. It is now read-only.

Commit

Permalink
[ITensors] ITensor wrapping NamedDimsArray (#1268)
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Nov 25, 2023
0 parents commit ecf9824
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
48 changes: 48 additions & 0 deletions src/BroadcastMapConversion.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module BroadcastMapConversion
# Convert broadcast call to map call by capturing array arguments
# with `map_args` and creating a map function with `map_function`.
# Logic from https://github.com/Jutho/Strided.jl/blob/v2.0.4/src/broadcast.jl.

using Base.Broadcast: Broadcasted

const WrappedScalarArgs = Union{AbstractArray{<:Any,0},Ref{<:Any}}

function map_args(bc::Broadcasted, rest...)
return (map_args(bc.args...)..., map_args(rest...)...)
end
map_args(a::AbstractArray, rest...) = (a, map_args(rest...)...)
map_args(a, rest...) = map_args(rest...)
map_args() = ()

struct MapFunction{F,Args<:Tuple}
f::F
args::Args
end
struct Arg end

# construct MapFunction
function map_function(bc::Broadcasted)
args = map_function_tuple(bc.args)
return MapFunction(bc.f, args)
end
map_function_tuple(t::Tuple{}) = t
map_function_tuple(t::Tuple) = (map_function(t[1]), map_function_tuple(Base.tail(t))...)
map_function(a::WrappedScalarArgs) = a[]
map_function(a::AbstractArray) = Arg()
map_function(a) = a

# Evaluate MapFunction
(f::MapFunction)(args...) = apply(f, args)[1]
function apply(f::MapFunction, args)
args, newargs = apply_tuple(f.args, args)
return f.f(args...), newargs
end
apply(a::Arg, args::Tuple) = args[1], Base.tail(args)
apply(a, args) = a, args
apply_tuple(t::Tuple{}, args) = t, args
function apply_tuple(t::Tuple, args)
t1, newargs1 = apply(t[1], args)
ttail, newargs = apply_tuple(Base.tail(t), newargs1)
return (t1, ttail...), newargs
end
end
14 changes: 14 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@eval module $(gensym())
using Test: @test, @testset
using NDTensors.BroadcastMapConversion: map_function, map_args
@testset "BroadcastMapConversion" begin
using Base.Broadcast: Broadcasted
c = 2.2
a = randn(2, 3)
b = randn(2, 3)
bc = Broadcasted(*, (c, a))
@test copy(bc) c * a map(map_function(bc), map_args(bc)...)
bc = Broadcasted(+, (a, b))
@test copy(bc) a + b map(map_function(bc), map_args(bc)...)
end
end

0 comments on commit ecf9824

Please sign in to comment.