From 4756d2869105cd51c66e164816292ce1fdd5b71d Mon Sep 17 00:00:00 2001 From: mtfishman Date: Fri, 5 Apr 2024 16:31:24 -0400 Subject: [PATCH] Allow more customization in ITensorNetwork constructor --- src/itensornetwork.jl | 42 +++++++++++++++++------------------ src/specialitensornetworks.jl | 6 ++--- test/test_itensornetwork.jl | 6 ++--- 3 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/itensornetwork.jl b/src/itensornetwork.jl index 8a6a26cd..bdc7374d 100644 --- a/src/itensornetwork.jl +++ b/src/itensornetwork.jl @@ -1,5 +1,5 @@ using DataGraphs: DataGraphs, DataGraph -using Dictionaries: dictionary +using Dictionaries: Indices, dictionary using ITensors: ITensors, ITensor using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype @@ -153,20 +153,26 @@ function ITensorNetwork(inds_network::IndsNetwork; kwargs...) end # TODO: Handle `eltype` and `undef` through `generic_state`. -function generic_state(f, inds...) - return f(inds...) +# `inds` are stored in a `NamedTuple` +function generic_state(f, inds::NamedTuple) + return generic_state(f, reduce(vcat, inds.linkinds; init=inds.siteinds)) end -function generic_state(a::AbstractArray, inds...) - return itensor(a, inds...) + +function generic_state(f, inds::Vector) + return f(inds) +end +function generic_state(a::AbstractArray, inds::Vector) + return itensor(a, inds) end -function generic_state(s::AbstractString, inds...) - tensor = state(s, inds[1]) - # TODO: Remove this and handle with `insert_missing_linkinds`. - return contract(tensor, onehot.(vcat(inds[2:end]...) .=> 1)...) +function generic_state(s::AbstractString, inds::NamedTuple) + # TODO: Handle the case of multiple site indices. + site_tensors = [state(s, only(inds.siteinds))] + link_tensors = [[onehot(i => 1) for i in inds.linkinds[e]] for e in keys(inds.linkinds)] + return contract(reduce(vcat, link_tensors; init=site_tensors)) end -# TODO: This is repeated from `ModelHamiltonians`, put into a -# single location (such as a `MakeCallable` submodule). +# TODO: This is similar to `ModelHamiltonians.to_callable`, +# try merging the two. to_callable(value::Type) = value to_callable(value::Function) = value to_callable(value::AbstractDict) = Base.Fix1(getindex, value) @@ -194,7 +200,7 @@ function ITensorNetwork( itensor_constructor::Function, inds_network::IndsNetwork; link_space=1, kwargs... ) if isnothing(link_space) - # Make the the link space is set + # Make sure the link space is set link_space = 1 end # Graphs.jl uses `zero` to create a graph of the same type @@ -212,15 +218,9 @@ function ITensorNetwork( end for v in vertices(tn) siteinds = get(inds_network, v, indtype(inds_network)[]) - linkinds = [ - get(inds_network, edgetype(inds_network)(v, nv), indtype(inds_network)[]) for - nv in neighbors(inds_network, v) - ] - # TODO: Come up with a better interface besides flattening the indices. - # Maybe a namedtuple or dictionary, which indicates which indices - # are site and links, which edges the links are associated with, etc. - tensor_v = generic_state(itensor_constructor(v), siteinds..., vcat(linkinds...)...) - # TODO: Call `insert_missing_linkinds` instead. + edges = [edgetype(inds_network)(v, nv) for nv in neighbors(inds_network, v)] + linkinds = map(e -> get(inds_network, e, indtype(inds_network)[]), Indices(edges)) + tensor_v = generic_state(itensor_constructor(v), (; siteinds, linkinds)) setindex_preserve_graph!(tn, tensor_v, v) end return tn diff --git a/src/specialitensornetworks.jl b/src/specialitensornetworks.jl index 6416f12d..015cab39 100644 --- a/src/specialitensornetworks.jl +++ b/src/specialitensornetworks.jl @@ -9,7 +9,7 @@ Note that passing a link_space will mean the indices of the resulting network do """ function delta_network(eltype::Type, s::IndsNetwork; link_space=nothing) return ITensorNetwork(s; link_space) do v - return (inds...) -> delta(eltype, inds...) + return inds -> delta(eltype, inds) end end @@ -30,7 +30,7 @@ Build an ITensor network on a graph specified by the inds network s. Bond_dim is """ function random_tensornetwork(eltype::Type, s::IndsNetwork; link_space=nothing) return ITensorNetwork(s; link_space) do v - return (inds...) -> itensor(randn(eltype, dim(inds)...), inds...) + return inds -> itensor(randn(eltype, dim.(inds)...), inds) end end @@ -57,7 +57,7 @@ function random_tensornetwork( distribution::Distribution, s::IndsNetwork; link_space=nothing ) return ITensorNetwork(s; link_space) do v - return (inds...) -> itensor(rand(distribution, dim(inds)...), inds...) + return inds -> itensor(rand(distribution, dim.(inds)...), inds) end end diff --git a/test/test_itensornetwork.jl b/test/test_itensornetwork.jl index 7351fb97..c44f80da 100644 --- a/test/test_itensornetwork.jl +++ b/test/test_itensornetwork.jl @@ -29,7 +29,7 @@ using ITensors: order, sim, uniqueinds -using ITensors.NDTensors: dims +using ITensors.NDTensors: dim using ITensorNetworks: ITensorNetworks, ⊗, @@ -160,11 +160,11 @@ using Test: @test, @test_broken, @testset ) ψ = ITensorNetwork(g; link_space) do v - return (inds...) -> itensor(randn(elt, dims(inds)...), inds...) + return inds -> itensor(randn(elt, dim.(inds)...), inds) end @test eltype(ψ[first(vertices(ψ))]) == elt ψ = ITensorNetwork(g; link_space) do v - return (inds...) -> itensor(randn(dims(inds)...), inds...) + return inds -> itensor(randn(dim.(inds)...), inds) end @test eltype(ψ[first(vertices(ψ))]) == Float64 ψ = random_tensornetwork(elt, g; link_space)