diff --git a/Project.toml b/Project.toml index a2cd740d..877465cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.11.12" +version = "0.11.13" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -36,12 +36,14 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" [extensions] +ITensorNetworksAdaptExt = "Adapt" ITensorNetworksEinExprsExt = "EinExprs" ITensorNetworksGraphsFlowsExt = "GraphsFlows" ITensorNetworksOMEinsumContractionOrdersExt = "OMEinsumContractionOrders" @@ -49,6 +51,7 @@ ITensorNetworksObserversExt = "Observers" [compat] AbstractTrees = "0.4.4" +Adapt = "4" Combinatorics = "1" Compat = "3, 4" DataGraphs = "0.2.3" @@ -82,6 +85,7 @@ TupleTools = "1.4" julia = "1.10" [extras] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" diff --git a/ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl b/ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl new file mode 100644 index 00000000..0051da93 --- /dev/null +++ b/ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl @@ -0,0 +1,14 @@ +module ITensorNetworksAdaptExt +using Adapt: Adapt, adapt +using ITensorNetworks: AbstractITensorNetwork, map_vertex_data_preserve_graph +function Adapt.adapt_structure(to, tn::AbstractITensorNetwork) + # TODO: Define and use: + # + # @preserve_graph map_vertex_data(adapt(to), tn) + # + # or just: + # + # @preserve_graph map(adapt(to), tn) + return map_vertex_data_preserve_graph(adapt(to), tn) +end +end diff --git a/test/Project.toml b/test/Project.toml index 3dd73b41..70eb14d3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" diff --git a/test/test_ext/Project.toml b/test/test_ext/Project.toml new file mode 100644 index 00000000..4d124728 --- /dev/null +++ b/test/test_ext/Project.toml @@ -0,0 +1,5 @@ +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7" +ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" +NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" diff --git a/test/test_ext/test_itensornetworksadaptext.jl b/test/test_ext/test_itensornetworksadaptext.jl new file mode 100644 index 00000000..0578511a --- /dev/null +++ b/test/test_ext/test_itensornetworksadaptext.jl @@ -0,0 +1,23 @@ +@eval module $(gensym()) +using Adapt: Adapt, adapt +using NamedGraphs.NamedGraphGenerators: named_grid +using ITensorNetworks: random_tensornetwork, siteinds +using ITensors: ITensors +using Test: @test, @testset + +struct SinglePrecisionAdaptor end +single_precision(::Type{<:AbstractFloat}) = Float32 +single_precision(type::Type{<:Complex}) = complex(single_precision(real(type))) +Adapt.adapt_storage(::SinglePrecisionAdaptor, x) = single_precision(eltype(x)).(x) + +@testset "Test ITensorNetworksAdaptExt (eltype=$elt)" for elt in ( + Float32, Float64, Complex{Float32}, Complex{Float64} +) + g = named_grid((2, 2)) + s = siteinds("S=1/2", g) + tn = random_tensornetwork(elt, s) + @test ITensors.scalartype(tn) === elt + tn′ = adapt(SinglePrecisionAdaptor(), tn) + @test ITensors.scalartype(tn′) === single_precision(elt) +end +end