diff --git a/Project.toml b/Project.toml index 2765a618..e20bc55b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ITensorNetworks" uuid = "2919e153-833c-4bdc-8836-1ea460a35fc7" authors = ["Matthew Fishman , Joseph Tindall and contributors"] -version = "0.11.10" +version = "0.11.11" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -19,6 +19,7 @@ IsApprox = "28f27b66-4bd8-47e7-9110-e2746eb8bed7" IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" @@ -63,6 +64,7 @@ ITensors = "0.6.8" IsApprox = "0.1" IterTools = "1.4.0" KrylovKit = "0.6, 0.7" +MacroTools = "0.5" NDTensors = "0.3" NamedGraphs = "0.6.0" OMEinsumContractionOrders = "0.8.3" diff --git a/src/abstractitensornetwork.jl b/src/abstractitensornetwork.jl index b7d75327..6f6ee164 100644 --- a/src/abstractitensornetwork.jl +++ b/src/abstractitensornetwork.jl @@ -39,6 +39,7 @@ using ITensors: using ITensorMPS: ITensorMPS, add, linkdim, linkinds, siteinds using .ITensorsExtensions: ITensorsExtensions, indtype, promote_indtype using LinearAlgebra: LinearAlgebra, factorize +using MacroTools: @capture using NamedGraphs: NamedGraphs, NamedGraph, not_implemented using NamedGraphs.GraphsExtensions: ⊔, directed_graph, incident_edges, rename_vertices, vertextype @@ -138,6 +139,30 @@ function setindex_preserve_graph!(tn::AbstractITensorNetwork, value, vertex) return tn end +# TODO: Move to `BaseExtensions` module. +function is_setindex!_expr(expr::Expr) + return is_assignment_expr(expr) && is_getindex_expr(first(expr.args)) +end +is_setindex!_expr(x) = false +is_getindex_expr(expr::Expr) = (expr.head === :ref) +is_getindex_expr(x) = false +is_assignment_expr(expr::Expr) = (expr.head === :(=)) +is_assignment_expr(expr) = false + +# TODO: Define this in terms of a function mapping +# preserve_graph_function(::typeof(setindex!)) = setindex!_preserve_graph +# preserve_graph_function(::typeof(map_vertex_data)) = map_vertex_data_preserve_graph +# Also allow annotating codeblocks like `@views`. +macro preserve_graph(expr) + if !is_setindex!_expr(expr) + error( + "preserve_graph must be used with setindex! syntax (as @preserve_graph a[i,j,...] = value)", + ) + end + @capture(expr, array_[indices__] = value_) + return :(setindex_preserve_graph!($(esc(array)), $(esc(value)), $(esc.(indices)...))) +end + function ITensors.hascommoninds(tn::AbstractITensorNetwork, edge::Pair) return hascommoninds(tn, edgetype(tn)(edge)) end @@ -148,7 +173,7 @@ end function Base.setindex!(tn::AbstractITensorNetwork, value, v) # v = to_vertex(tn, index...) - setindex_preserve_graph!(tn, value, v) + @preserve_graph tn[v] = value for edge in incident_edges(tn, v) rem_edge!(tn, edge) end @@ -297,12 +322,12 @@ function ITensors.replaceinds( @assert underlying_graph(is) == underlying_graph(is′) for v in vertices(is) isassigned(is, v) || continue - setindex_preserve_graph!(tn, replaceinds(tn[v], is[v] => is′[v]), v) + @preserve_graph tn[v] = replaceinds(tn[v], is[v] => is′[v]) end for e in edges(is) isassigned(is, e) || continue for v in (src(e), dst(e)) - setindex_preserve_graph!(tn, replaceinds(tn[v], is[e] => is′[e]), v) + @preserve_graph tn[v] = replaceinds(tn[v], is[e] => is′[e]) end end return tn @@ -361,13 +386,31 @@ end LinearAlgebra.adjoint(tn::Union{IndsNetwork,AbstractITensorNetwork}) = prime(tn) -#dag(tn::AbstractITensorNetwork) = map_vertex_data(dag, tn) -function ITensors.dag(tn::AbstractITensorNetwork) - tndag = copy(tn) - for v in vertices(tndag) - setindex_preserve_graph!(tndag, dag(tndag[v]), v) +function map_vertex_data(f, tn::AbstractITensorNetwork) + tn = copy(tn) + for v in vertices(tn) + tn[v] = f(tn[v]) end - return tndag + return tn +end + +# TODO: Define `@preserve_graph map_vertex_data(f, tn)` +function map_vertex_data_preserve_graph(f, tn::AbstractITensorNetwork) + tn = copy(tn) + for v in vertices(tn) + @preserve_graph tn[v] = f(tn[v]) + end + return tn +end + +function Base.conj(tn::AbstractITensorNetwork) + # TODO: Use `@preserve_graph map_vertex_data(f, tn)` + return map_vertex_data_preserve_graph(conj, tn) +end + +function ITensors.dag(tn::AbstractITensorNetwork) + # TODO: Use `@preserve_graph map_vertex_data(f, tn)` + return map_vertex_data_preserve_graph(dag, tn) end # TODO: should this make sure that internal indices @@ -442,9 +485,7 @@ function NDTensors.contract( for n_dst in neighbors_dst add_edge!(tn, merged_vertex => n_dst) end - - setindex_preserve_graph!(tn, new_itensor, merged_vertex) - + @preserve_graph tn[merged_vertex] = new_itensor return tn end @@ -533,13 +574,8 @@ function LinearAlgebra.factorize( add_edge!(tn, X_vertex => nX) end add_edge!(tn, Y_vertex => dst(edge)) - - # tn[X_vertex] = X - setindex_preserve_graph!(tn, X, X_vertex) - - # tn[Y_vertex] = Y - setindex_preserve_graph!(tn, Y, Y_vertex) - + @preserve_graph tn[X_vertex] = X + @preserve_graph tn[Y_vertex] = Y return tn end diff --git a/test/test_itensornetwork.jl b/test/test_itensornetwork.jl index 7e97c6a3..ba3caa01 100644 --- a/test/test_itensornetwork.jl +++ b/test/test_itensornetwork.jl @@ -175,6 +175,17 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) return inds -> itensor(randn(rng, elt, dim.(inds)...), inds) end @test eltype(ψ[first(vertices(ψ))]) == elt + + ψc = conj(ψ) + for v in vertices(ψ) + @test ψc[v] == conj(ψ[v]) + end + + ψd = dag(ψ) + for v in vertices(ψ) + @test ψd[v] == dag(ψ[v]) + end + rng = StableRNG(1234) ψ = ITensorNetwork(g; kwargs...) do v return inds -> itensor(randn(rng, dim.(inds)...), inds)