From ce6909e435ab7208155be59cd7e0ba171aae412b Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 30 May 2024 12:01:52 -0400 Subject: [PATCH 1/4] Add support for Adapt --- Project.toml | 4 ++++ .../ITensorNetworksAdaptExt.jl | 14 +++++++++++ test/Project.toml | 1 + test/test_ext/Project.toml | 5 ++++ test/test_ext/test_itensornetworksadaptext.jl | 24 +++++++++++++++++++ 5 files changed, 48 insertions(+) create mode 100644 ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl create mode 100644 test/test_ext/Project.toml create mode 100644 test/test_ext/test_itensornetworksadaptext.jl diff --git a/Project.toml b/Project.toml index a2cd740d..f36640bc 100644 --- a/Project.toml +++ b/Project.toml @@ -36,12 +36,14 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" [weakdeps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" [extensions] +ITensorNetworksAdaptExt = "Adapt" ITensorNetworksEinExprsExt = "EinExprs" ITensorNetworksGraphsFlowsExt = "GraphsFlows" ITensorNetworksOMEinsumContractionOrdersExt = "OMEinsumContractionOrders" @@ -49,6 +51,7 @@ ITensorNetworksObserversExt = "Observers" [compat] AbstractTrees = "0.4.4" +Adapt = "4" Combinatorics = "1" Compat = "3, 4" DataGraphs = "0.2.3" @@ -82,6 +85,7 @@ TupleTools = "1.4" julia = "1.10" [extras] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" diff --git a/ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl b/ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl new file mode 100644 index 00000000..0051da93 --- /dev/null +++ b/ext/ITensorNetworksAdaptExt/ITensorNetworksAdaptExt.jl @@ -0,0 +1,14 @@ +module ITensorNetworksAdaptExt +using Adapt: Adapt, adapt +using ITensorNetworks: AbstractITensorNetwork, map_vertex_data_preserve_graph +function Adapt.adapt_structure(to, tn::AbstractITensorNetwork) + # TODO: Define and use: + # + # @preserve_graph map_vertex_data(adapt(to), tn) + # + # or just: + # + # @preserve_graph map(adapt(to), tn) + return map_vertex_data_preserve_graph(adapt(to), tn) +end +end diff --git a/test/Project.toml b/test/Project.toml index 3dd73b41..70eb14d3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" DataGraphs = "b5a273c3-7e6c-41f6-98bd-8d7f1525a36a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" diff --git a/test/test_ext/Project.toml b/test/test_ext/Project.toml new file mode 100644 index 00000000..4d124728 --- /dev/null +++ b/test/test_ext/Project.toml @@ -0,0 +1,5 @@ +[deps] +Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7" +ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5" +NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" diff --git a/test/test_ext/test_itensornetworksadaptext.jl b/test/test_ext/test_itensornetworksadaptext.jl new file mode 100644 index 00000000..ac4366d6 --- /dev/null +++ b/test/test_ext/test_itensornetworksadaptext.jl @@ -0,0 +1,24 @@ +@eval module $(gensym()) +using Adapt: Adapt, adapt +using NamedGraphs.NamedGraphGenerators: named_grid +using ITensorNetworks: random_tensornetwork, siteinds +using ITensors: ITensors +using Test: @test + +struct SinglePrecisionAdaptor end +single_precision(::Type{<:AbstractFloat}) = Float32 +single_precision(type::Type{<:Complex}) = complex(single_precision(real(type))) +Adapt.adapt_storage(::SinglePrecisionAdaptor, x) = single_precision(eltype(x)).(x) + +@testset "Test ITensorNetworksAdaptExt (eltype=$elt)" for elt in ( + Float32, Float64, Complex{Float32}, Complex{Float64} +) + g = named_grid((2, 2)) + s = siteinds("S=1/2", g) + tn = random_tensornetwork(elt, s) + @test ITensors.scalartype(tn) === elt + tn′ = adapt(SinglePrecisionAdaptor(), tn) + @show ITensors.scalartype(tn), ITensors.scalartype(tn′) + @test ITensors.scalartype(tn′) === single_precision(elt) +end +end From 33c1101032ea4d0f418424c91b3bbea2fd8b31ef Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 30 May 2024 12:05:37 -0400 Subject: [PATCH 2/4] Fix test --- test/test_ext/test_itensornetworksadaptext.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_ext/test_itensornetworksadaptext.jl b/test/test_ext/test_itensornetworksadaptext.jl index ac4366d6..e781f313 100644 --- a/test/test_ext/test_itensornetworksadaptext.jl +++ b/test/test_ext/test_itensornetworksadaptext.jl @@ -3,7 +3,7 @@ using Adapt: Adapt, adapt using NamedGraphs.NamedGraphGenerators: named_grid using ITensorNetworks: random_tensornetwork, siteinds using ITensors: ITensors -using Test: @test +using Test: @test, @testset struct SinglePrecisionAdaptor end single_precision(::Type{<:AbstractFloat}) = Float32 From b694199cfb99b9f9f9f5946719e569771aa964f5 Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 30 May 2024 12:07:37 -0400 Subject: [PATCH 3/4] Remove show --- test/test_ext/test_itensornetworksadaptext.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_ext/test_itensornetworksadaptext.jl b/test/test_ext/test_itensornetworksadaptext.jl index e781f313..0578511a 100644 --- a/test/test_ext/test_itensornetworksadaptext.jl +++ b/test/test_ext/test_itensornetworksadaptext.jl @@ -18,7 +18,6 @@ Adapt.adapt_storage(::SinglePrecisionAdaptor, x) = single_precision(eltype(x)).( tn = random_tensornetwork(elt, s) @test ITensors.scalartype(tn) === elt tn′ = adapt(SinglePrecisionAdaptor(), tn) - @show ITensors.scalartype(tn), ITensors.scalartype(tn′) @test ITensors.scalartype(tn′) === single_precision(elt) end end From 3926993b5c5e3e468c39a7a6612d123ab907236d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Thu, 30 May 2024 12:18:02 -0400 Subject: [PATCH 4/4] Bump to v0.11.13 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index f36640bc..877465cc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.11.12" +version = "0.11.13" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"