Skip to content

Commit

Permalink
add adapt_structure macro (#38)
Browse files Browse the repository at this point in the history
* add adapt_structure macro

* add tests

* clarify doc

* move struct outside of testset for old scope rules
  • Loading branch information
simonbyrne authored Jan 27, 2021
1 parent a2f9aa4 commit 1b19759
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/Adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ adapt_structure(to, x) = adapt_storage(to, x)
adapt_storage(to, x) = x

include("base.jl")
include("macro.jl")
include("wrappers.jl")

end # module
14 changes: 14 additions & 0 deletions src/macro.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
Adapt.@adapt_structure T
Define a method `adapt_structure(to, obj::T)` which calls `adapt_structure` on each field
of `obj` and constructs a new instance of `T` using the default constuctor `T(...)`.
"""
macro adapt_structure(T)
names = fieldnames(Core.eval(__module__, T))
quote
function Adapt.adapt_structure(to, obj::$(esc(T)))
$(esc(T))($([:(Adapt.adapt_structure(to, obj.$name)) for name in names]...))
end
end
end
18 changes: 18 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,21 @@ end
@test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array
@test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array
end


struct MyStruct{A,B}
a::A
b::B
end

@testset "@adapt_structure" begin

Adapt.@adapt_structure MyStruct

u = ones(3)
v = zeros(5)

@test_adapt CustomArray MyStruct(u,v) MyStruct(CustomArray(u), CustomArray(v))
@test_adapt CustomArray MyStruct(u,1.0) MyStruct(CustomArray(u), 1.0)

end

0 comments on commit 1b19759

Please sign in to comment.