
Commit

Fix more tests
mtfishman committed Apr 17, 2024
1 parent 8b4e8a5 commit 2cc2721
Showing 8 changed files with 25 additions and 29 deletions.
8 changes: 4 additions & 4 deletions src/approx_itensornetwork/partition.jl
@@ -5,10 +5,10 @@ using ITensors: ITensor, noncommoninds
using NamedGraphs: NamedGraph, subgraph

function _partition(g::AbstractGraph, subgraph_vertices)
partitioned_graph = DataGraph(
NamedGraph(eachindex(subgraph_vertices)),
map(vs -> subgraph(g, vs), Dictionary(subgraph_vertices)),
)
partitioned_graph = DataGraph(NamedGraph(eachindex(subgraph_vertices)))
for v in vertices(partitioned_graph)
partitioned_graph[v] = subgraph(g, subgraph_vertices[v])
end
for e in edges(g)
s1 = findfirst_on_vertices(subgraph -> src(e) ∈ vertices(subgraph), partitioned_graph)
s2 = findfirst_on_vertices(subgraph -> dst(e) ∈ vertices(subgraph), partitioned_graph)
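
For context, a minimal sketch of the new vertex-by-vertex construction pattern used in `_partition` above, on a hypothetical toy grid (the graph and partition below are illustrative, not part of this commit):

using DataGraphs: DataGraph
using Graphs: vertices
using NamedGraphs: NamedGraph, subgraph
using NamedGraphs.NamedGraphGenerators: named_grid

# Toy input: a 1x4 grid split into two blocks of two vertices each.
g = named_grid((1, 4))
subgraph_vertices = [[(1, 1), (1, 2)], [(1, 3), (1, 4)]]

# New pattern: build a DataGraph over the partition labels and fill in each
# block's subgraph one vertex at a time, instead of passing a Dictionary of
# subgraphs to the DataGraph constructor.
partitioned_graph = DataGraph(NamedGraph(eachindex(subgraph_vertices)))
for v in vertices(partitioned_graph)
  partitioned_graph[v] = subgraph(g, subgraph_vertices[v])
end
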
26 changes: 10 additions & 16 deletions src/caches/beliefpropagationcache.jl
@@ -105,10 +105,8 @@ function message(bp_cache::BeliefPropagationCache, edge::PartitionEdge)
mts = messages(bp_cache)
return get(mts, edge, default_message(bp_cache, edge))
end
function messages(
bp_cache::BeliefPropagationCache, edges::Vector{<:PartitionEdge}; kwargs...
)
return [message(bp_cache, edge; kwargs...) for edge in edges]
function messages(bp_cache::BeliefPropagationCache, edges; kwargs...)
return map(edge -> message(bp_cache, edge; kwargs...), edges)
end

function Base.copy(bp_cache::BeliefPropagationCache)
@@ -256,21 +254,18 @@ end
"""
Update the tensornetwork inside the cache
"""
function update_factors(
bp_cache::BeliefPropagationCache, vertices::Vector, factors::Vector{ITensor}
)
function update_factors(bp_cache::BeliefPropagationCache, factors)
bp_cache = copy(bp_cache)
tn = tensornetwork(bp_cache)

for (vertex, factor) in zip(vertices, factors)
for vertex in eachindex(factors)
# TODO: Add a check that this preserves the graph structure.
setindex_preserve_graph!(tn, factor, vertex)
setindex_preserve_graph!(tn, factors[vertex], vertex)
end
return bp_cache
end

function update_factor(bp_cache, vertex, factor)
return update_factors(bp_cache, [vertex], ITensor[factor])
return update_factors(bp_cache, Dictionary([vertex], [factor]))
end

function region_scalar(bp_cache::BeliefPropagationCache, pv::PartitionVertex)
@@ -285,16 +280,15 @@ end

function vertex_scalars(
bp_cache::BeliefPropagationCache,
pvs::Vector=partitionvertices(partitioned_tensornetwork(bp_cache)),
pvs=partitionvertices(partitioned_tensornetwork(bp_cache)),
)
return [region_scalar(bp_cache, pv) for pv in pvs]
return map(pv -> region_scalar(bp_cache, pv), pvs)
end

function edge_scalars(
bp_cache::BeliefPropagationCache,
pes::Vector=partitionedges(partitioned_tensornetwork(bp_cache)),
bp_cache::BeliefPropagationCache, pes=partitionedges(partitioned_tensornetwork(bp_cache))
)
return [region_scalar(bp_cache, pe) for pe in pes]
return map(pe -> region_scalar(bp_cache, pe), pes)
end

function scalar_factors_quotient(bp_cache::BeliefPropagationCache)
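
As a usage note, the factors passed to `update_factors` are now a single vertex-keyed container (e.g. a `Dictionary`) rather than parallel `vertices`/`factors` vectors, and `messages`, `vertex_scalars`, and `edge_scalars` now accept generic iterables. A hedged sketch, where `bp_cache`, `new_A`, and `new_B` are hypothetical placeholders (an existing cache and two replacement ITensors for vertices (1, 1) and (1, 2)):

using Dictionaries: Dictionary

# Replace the factors on two vertices at once; keys are vertices of the
# tensor network, values are the new ITensors.
factors = Dictionary([(1, 1), (1, 2)], [new_A, new_B])
bp_cache = update_factors(bp_cache, factors)

# A single-vertex update still goes through `update_factor`, which wraps the
# call above in a one-element Dictionary.
bp_cache = update_factor(bp_cache, (1, 1), new_A)
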
4 changes: 2 additions & 2 deletions src/edge_sequences.jl
@@ -1,7 +1,7 @@
using Graphs: IsDirected, connected_components, edges, edgetype
using ITensors.NDTensors: Algorithm, @Algorithm_str
using NamedGraphs: NamedGraphs
using NamedGraphs.GraphsExtensions: GraphsExtensions, undirected_graph
using NamedGraphs.GraphsExtensions: GraphsExtensions, forest_cover, undirected_graph
using NamedGraphs.PartitionedGraphs: PartitionEdge, PartitionedGraph, partitioned_graph
using SimpleTraits: SimpleTraits, @traitfn, Not
using SimpleTraits
@@ -30,7 +30,7 @@ end
g::::(!IsDirected);
root_vertex=GraphsExtensions.default_root_vertex,
)
forests = NamedGraphs.forest_cover(g)
forests = forest_cover(g)
edges = edgetype(g)[]
for forest in forests
trees = [forest[vs] for vs in connected_components(forest)]
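
A small sketch of the relocated `forest_cover` (now imported from `NamedGraphs.GraphsExtensions` rather than the top-level `NamedGraphs` module), on a hypothetical toy graph:

using Graphs: connected_components
using NamedGraphs.GraphsExtensions: forest_cover
using NamedGraphs.NamedGraphGenerators: named_grid

g = named_grid((2, 2))
# `forest_cover(g)` returns a collection of forests whose edges together
# cover the edges of `g`; each forest splits into trees via its connected
# components, as in the loop above.
forests = forest_cover(g)
trees = [forest[vs] for forest in forests for vs in connected_components(forest)]
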
4 changes: 3 additions & 1 deletion src/gauging.jl
@@ -3,7 +3,9 @@ using ITensors.NDTensors: dense, scalartype
using NamedGraphs.PartitionedGraphs: partitionedge

function default_bond_tensors(ψ::ITensorNetwork)
return DataGraph{vertextype(ψ),Nothing,ITensor}(underlying_graph(ψ))
return DataGraph(
underlying_graph(ψ); edge_data_eltype=Nothing, vertex_data_eltype=ITensor
)
end

struct VidalITensorNetwork{V,BTS} <: AbstractITensorNetwork{V}
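
For reference, the `DataGraph` element types are now passed as keywords instead of type parameters; a minimal sketch on a hypothetical toy graph:

using DataGraphs: DataGraph
using ITensors: ITensor
using NamedGraphs.NamedGraphGenerators: named_grid

g = named_grid((1, 3))
# Previously written with type parameters, e.g. `DataGraph{vertextype(g),Nothing,ITensor}(g)`;
# the element types of the vertex and edge data are now keyword arguments.
bond_tensors = DataGraph(g; edge_data_eltype=Nothing, vertex_data_eltype=ITensor)
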
4 changes: 2 additions & 2 deletions test/test_binary_tree_partition.jl
@@ -17,7 +17,7 @@ using ITensorNetworks:
binary_tree_structure,
path_graph_structure,
random_tensornetwork
using NamedGraphs: NamedEdge
using NamedGraphs: NamedEdge, NamedGraph
using NamedGraphs.NamedGraphGenerators: named_grid
using NamedGraphs.GraphsExtensions: post_order_dfs_vertices
using OMEinsumContractionOrders: OMEinsumContractionOrders
@@ -130,7 +130,7 @@ end
underlying_tree = underlying_graph(input_partition)
# Change type of each partition[v] since they will be updated
# with potential data type change.
p = DataGraph()
p = DataGraph(NamedGraph())
for v in vertices(input_partition)
add_vertex!(p, v)
p[v] = ITensorNetwork{Any}(input_partition[v])
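
As a side note, constructing an empty `DataGraph` now takes an explicit empty `NamedGraph` as the underlying graph; a hedged sketch of the incremental-build pattern used in the test above (toy vertex and data values, not the actual test data):

using DataGraphs: DataGraph
using Graphs: add_vertex!
using NamedGraphs: NamedGraph

p = DataGraph(NamedGraph())
# Vertices and their data can then be added one at a time, as in the test.
add_vertex!(p, "a")
p["a"] = 1.0
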
4 changes: 2 additions & 2 deletions test/test_opsum_to_ttn.jl
@@ -1,6 +1,6 @@
@eval module $(gensym())
using DataGraphs: vertex_data
using Dictionaries: Dictionary
using Dictionaries: Dictionary, getindices
using Graphs: add_vertex!, rem_vertex!, add_edge!, rem_edge!, vertices
using ITensors:
ITensors,
@@ -236,7 +236,7 @@ end

# linearized version
linear_order = [4, 1, 2, 5, 3, 6]
vmap = Dictionary(vertices(is)[linear_order], 1:length(linear_order))
vmap = Dictionary(getindices(vertices(is), linear_order), eachindex(linear_order))
sites = only.(filter(d -> !isempty(d), collect(vertex_data(is_missing_site))))[linear_order]

J1 = -1
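
For context, `getindices` from Dictionaries.jl selects several entries at once, replacing the `vertices(is)[linear_order]` array-style indexing; a small sketch with hypothetical stand-in data:

using Dictionaries: Dictionary, getindices

verts = [(1, 1), (2, 1), (1, 2), (2, 2)]
linear_order = [3, 1, 2, 4]
# Map each selected vertex to its position in the linear order.
vmap = Dictionary(getindices(verts, linear_order), eachindex(linear_order))
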
2 changes: 1 addition & 1 deletion test/test_sitetype.jl
@@ -5,7 +5,7 @@ using Graphs: nv, vertices
using ITensorNetworks: IndsNetwork, siteinds
using ITensors: SiteType, hastags, space
using ITensors.NDTensors: dim
using NamedGraphs: named_grid
using NamedGraphs.NamedGraphGenerators: named_grid
using Test: @test, @testset

@testset "Site ind system" begin
2 changes: 1 addition & 1 deletion test/test_ttno.jl
@@ -18,7 +18,7 @@ using Test: @test, @testset
# operator site inds
is_isp = union_all_inds(is, prime(is; links=[]))
# specify random linear vertex ordering of graph vertices
vertex_order = shuffle(vertices(c))
vertex_order = shuffle(collect(vertices(c)))

@testset "Construct TTN operator from ITensor or Array" begin
cutoff = 1e-10
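
A short note on the `collect` added above: `shuffle` needs an array, and `vertices` appears to no longer return a plain `Vector` here, so the vertices are collected first. A toy sketch:

using Graphs: vertices
using NamedGraphs.NamedGraphGenerators: named_grid
using Random: shuffle

g = named_grid((2, 2))
# Collect the (possibly lazy) vertex iterator into a Vector before shuffling.
vertex_order = shuffle(collect(vertices(g)))
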
