Commit
Allow more customization in ITensorNetwork constructor
mtfishman committed Apr 5, 2024
1 parent b2db922 commit 4756d28
Showing 3 changed files with 27 additions and 27 deletions.
42 changes: 21 additions & 21 deletions src/itensornetwork.jl
@@ -1,5 +1,5 @@
using DataGraphs: DataGraphs, DataGraph
using Dictionaries: dictionary
using Dictionaries: Indices, dictionary
using ITensors: ITensors, ITensor
using NamedGraphs: NamedGraphs, NamedEdge, NamedGraph, vertextype

@@ -153,20 +153,26 @@ function ITensorNetwork(inds_network::IndsNetwork; kwargs...)
end

# TODO: Handle `eltype` and `undef` through `generic_state`.
function generic_state(f, inds...)
return f(inds...)
# `inds` are stored in a `NamedTuple`
function generic_state(f, inds::NamedTuple)
return generic_state(f, reduce(vcat, inds.linkinds; init=inds.siteinds))
end
function generic_state(a::AbstractArray, inds...)
return itensor(a, inds...)

function generic_state(f, inds::Vector)
return f(inds)
end
function generic_state(a::AbstractArray, inds::Vector)
return itensor(a, inds)
end
function generic_state(s::AbstractString, inds...)
tensor = state(s, inds[1])
# TODO: Remove this and handle with `insert_missing_linkinds`.
return contract(tensor, onehot.(vcat(inds[2:end]...) .=> 1)...)
function generic_state(s::AbstractString, inds::NamedTuple)
# TODO: Handle the case of multiple site indices.
site_tensors = [state(s, only(inds.siteinds))]
link_tensors = [[onehot(i => 1) for i in inds.linkinds[e]] for e in keys(inds.linkinds)]
return contract(reduce(vcat, link_tensors; init=site_tensors))
end

# TODO: This is repeated from `ModelHamiltonians`, put into a
# single location (such as a `MakeCallable` submodule).
# TODO: This is similar to `ModelHamiltonians.to_callable`,
# try merging the two.
to_callable(value::Type) = value
to_callable(value::Function) = value
to_callable(value::AbstractDict) = Base.Fix1(getindex, value)
@@ -194,7 +200,7 @@ function ITensorNetwork(
itensor_constructor::Function, inds_network::IndsNetwork; link_space=1, kwargs...
)
if isnothing(link_space)
# Make the the link space is set
# Make sure the link space is set
link_space = 1
end
# Graphs.jl uses `zero` to create a graph of the same type
@@ -212,15 +218,9 @@
end
for v in vertices(tn)
siteinds = get(inds_network, v, indtype(inds_network)[])
linkinds = [
get(inds_network, edgetype(inds_network)(v, nv), indtype(inds_network)[]) for
nv in neighbors(inds_network, v)
]
# TODO: Come up with a better interface besides flattening the indices.
# Maybe a namedtuple or dictionary, which indicates which indices
# are site and links, which edges the links are associated with, etc.
tensor_v = generic_state(itensor_constructor(v), siteinds..., vcat(linkinds...)...)
# TODO: Call `insert_missing_linkinds` instead.
edges = [edgetype(inds_network)(v, nv) for nv in neighbors(inds_network, v)]
linkinds = map(e -> get(inds_network, e, indtype(inds_network)[]), Indices(edges))
tensor_v = generic_state(itensor_constructor(v), (; siteinds, linkinds))
setindex_preserve_graph!(tn, tensor_v, v)
end
return tn
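For reference, a minimal usage sketch of the reworked constructor (not part of this commit; the grid graph, the "S=1/2" site indices, and the import locations are illustrative assumptions for this version of the packages). The per-vertex callback now returns a function of a single Vector of indices, assembled internally from the `(; siteinds, linkinds)` NamedTuple, instead of a splatted list of indices.

using ITensors: dim, itensor
using NamedGraphs: named_grid
using ITensorNetworks: ITensorNetwork, siteinds

g = named_grid((2, 2))        # illustrative 2x2 grid graph
s = siteinds("S=1/2", g)      # IndsNetwork with one site index per vertex

# The callback for vertex `v` returns `inds -> ...`, where `inds` is the
# Vector of site indices followed by link indices for that vertex.
ψ = ITensorNetwork(s; link_space=2) do v
  return inds -> itensor(randn(Float64, dim.(inds)...), inds)
end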
6 changes: 3 additions & 3 deletions src/specialitensornetworks.jl
@@ -9,7 +9,7 @@ Note that passing a link_space will mean the indices of the resulting network do
"""
function delta_network(eltype::Type, s::IndsNetwork; link_space=nothing)
return ITensorNetwork(s; link_space) do v
return (inds...) -> delta(eltype, inds...)
return inds -> delta(eltype, inds)
end
end

@@ -30,7 +30,7 @@ Build an ITensor network on a graph specified by the inds network s. Bond_dim is
"""
function random_tensornetwork(eltype::Type, s::IndsNetwork; link_space=nothing)
return ITensorNetwork(s; link_space) do v
return (inds...) -> itensor(randn(eltype, dim(inds)...), inds...)
return inds -> itensor(randn(eltype, dim.(inds)...), inds)
end
end

@@ -57,7 +57,7 @@ function random_tensornetwork(
distribution::Distribution, s::IndsNetwork; link_space=nothing
)
return ITensorNetwork(s; link_space) do v
return (inds...) -> itensor(rand(distribution, dim(inds)...), inds...)
return inds -> itensor(rand(distribution, dim.(inds)...), inds)
end
end

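From the caller's side these helpers behave as before; only the internal tensor-constructor closure changed to take a single index Vector. A hedged usage sketch follows (graph and site setup are illustrative assumptions; delta_network is called fully qualified in case it is not exported).

using NamedGraphs: named_grid
using ITensorNetworks: ITensorNetworks, random_tensornetwork, siteinds

g = named_grid((3, 1))
s = siteinds("S=1/2", g)

# Gaussian-random tensor network over `s` with bond dimension 2.
ψ = random_tensornetwork(Float64, s; link_space=2)

# Copy-tensor ("delta") network over the same index network.
d = ITensorNetworks.delta_network(Float64, s)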
6 changes: 3 additions & 3 deletions test/test_itensornetwork.jl
@@ -29,7 +29,7 @@ using ITensors:
order,
sim,
uniqueinds
using ITensors.NDTensors: dims
using ITensors.NDTensors: dim
using ITensorNetworks:
ITensorNetworks,
⊗,
@@ -160,11 +160,11 @@ using Test: @test, @test_broken, @testset
)

ψ = ITensorNetwork(g; link_space) do v
return (inds...) -> itensor(randn(elt, dims(inds)...), inds...)
return inds -> itensor(randn(elt, dim.(inds)...), inds)
end
@test eltype(ψ[first(vertices(ψ))]) == elt
ψ = ITensorNetwork(g; link_space) do v
return (inds...) -> itensor(randn(dims(inds)...), inds...)
return inds -> itensor(randn(dim.(inds)...), inds)
end
@test eltype(ψ[first(vertices(ψ))]) == Float64
ψ = random_tensornetwork(elt, g; link_space)
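The NamedTuple dispatch of generic_state added in this commit also covers string inputs, building a product state and contracting it with onehot link tensors. A hedged sketch of that path through the same constructor (the graph, site type, and "↑" state name are illustrative assumptions):

using NamedGraphs: named_grid
using ITensorNetworks: ITensorNetwork, siteinds

g = named_grid((2, 2))
s = siteinds("S=1/2", g)

# The per-vertex callback returns a state name; internally
# generic_state(::AbstractString, ::NamedTuple) builds the site tensor
# with `state` and contracts it with onehot(link => 1) tensors.
ψ = ITensorNetwork(v -> "↑", s; link_space=1)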
