diff --git a/src/Adapt.jl b/src/Adapt.jl index 13916eb..82ebf66 100644 --- a/src/Adapt.jl +++ b/src/Adapt.jl @@ -43,6 +43,7 @@ adapt_structure(to, x) = adapt_storage(to, x) adapt_storage(to, x) = x include("base.jl") +include("macro.jl") include("wrappers.jl") end # module diff --git a/src/macro.jl b/src/macro.jl new file mode 100644 index 0000000..0a9ce0e --- /dev/null +++ b/src/macro.jl @@ -0,0 +1,14 @@ +""" + Adapt.@adapt_structure T + +Define a method `adapt_structure(to, obj::T)` which calls `adapt_structure` on each field +of `obj` and constructs a new instance of `T` using the default constuctor `T(...)`. +""" +macro adapt_structure(T) + names = fieldnames(Core.eval(__module__, T)) + quote + function Adapt.adapt_structure(to, obj::$(esc(T))) + $(esc(T))($([:(Adapt.adapt_structure(to, obj.$name)) for name in names]...)) + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 73ab0ef..112ee48 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -168,3 +168,21 @@ end @test Adapt.parent(LinearAlgebra.Transpose{Float64,Array{Float64,1}}) == Array @test Adapt.parent(Adapt.WrappedSubArray{Float64,3,Array{Float64,3}}) == Array end + + +struct MyStruct{A,B} + a::A + b::B +end + +@testset "@adapt_structure" begin + + Adapt.@adapt_structure MyStruct + + u = ones(3) + v = zeros(5) + + @test_adapt CustomArray MyStruct(u,v) MyStruct(CustomArray(u), CustomArray(v)) + @test_adapt CustomArray MyStruct(u,1.0) MyStruct(CustomArray(u), 1.0) + +end