Skip to content

Commit

Permalink
Contraction path optimization with EinExprs (#120)
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing authored Feb 13, 2024
1 parent da7636e commit fbb4e53
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 1 deletion.
9 changes: 9 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19"
Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
Expand All @@ -31,6 +32,12 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"

[weakdeps]
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"

[extensions]
ITensorNetworksEinExprsExt = "EinExprs"

[compat]
AbstractTrees = "0.4.4"
Combinatorics = "1"
Expand All @@ -40,6 +47,7 @@ DataStructures = "0.18"
Dictionaries = "0.4"
Distributions = "0.25.86"
DocStringExtensions = "0.8, 0.9"
EinExprs = "0.6.4"
Graphs = "1.8"
GraphsFlows = "0.1.1"
ITensors = "0.3.23"
Expand All @@ -59,6 +67,7 @@ TupleTools = "1.4"
julia = "1.7"

[extras]
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
53 changes: 53 additions & 0 deletions ext/ITensorNetworksEinExprsExt/ITensorNetworksEinExprsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module ITensorNetworksEinExprsExt

using ITensors: Index, ITensor, @Algorithm_str, inds, noncommoninds
using ITensorNetworks:
ITensorNetworks, ITensorNetwork, vertextype, vertex_data, contraction_sequence
using EinExprs: EinExprs, EinExpr, einexpr, SizedEinExpr

function to_einexpr(ts::Vector{ITensor})
IndexType = Any

tensor_exprs = EinExpr{IndexType}[]
inds_dims = Dict{IndexType,Int}()

for tensor_v in ts
inds_v = collect(inds(tensor_v))
push!(tensor_exprs, EinExpr{IndexType}(; head=inds_v))
merge!(inds_dims, Dict(inds_v .=> size(tensor_v)))
end

externalinds_tn = reduce(noncommoninds, ts)
return SizedEinExpr(sum(tensor_exprs; skip=externalinds_tn), inds_dims)
end

function tensor_inds_to_vertex(ts::Vector{ITensor})
IndexType = Any
VertexType = Int

mapping = Dict{Set{IndexType},VertexType}()

for (v, tensor_v) in enumerate(ts)
inds_v = collect(inds(tensor_v))
mapping[Set(inds_v)] = v
end

return mapping
end

function ITensorNetworks.contraction_sequence(
::Algorithm"einexpr", tn::Vector{ITensor}; optimizer=EinExprs.Exhaustive()
)
expr = to_einexpr(tn)
path = einexpr(optimizer, expr)
return to_contraction_sequence(path, tensor_inds_to_vertex(tn))
end

function to_contraction_sequence(expr, tensor_inds_to_vertex)
EinExprs.nargs(expr) == 0 && return tensor_inds_to_vertex[Set(expr.head)]
return map(
expr -> to_contraction_sequence(expr, tensor_inds_to_vertex), EinExprs.args(expr)
)
end

end
2 changes: 2 additions & 0 deletions src/ITensorNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using LinearAlgebra
using NamedGraphs
using Observers
using Observers.DataFrames: select!
using PackageExtensionCompat
using Printf
using Requires
using SimpleTraits
Expand Down Expand Up @@ -130,6 +131,7 @@ include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl"))
include("exports.jl")

function __init__()
@require_extensions
@require OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" include(
joinpath("requires", "omeinsumcontractionorders.jl")
)
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889"
Expand Down
14 changes: 13 additions & 1 deletion test/test_contraction_sequence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using ITensorNetworks
using OMEinsumContractionOrders
using Random
using Test
using EinExprs: Exhaustive, Greedy, HyPar

Random.seed!(1234)

Expand All @@ -23,7 +24,15 @@ ITensors.disable_warn_order()
res_tree_sa = contract(tn; sequence=seq_tree_sa)[]
seq_sa_bipartite = contraction_sequence(tn; alg="sa_bipartite")
res_sa_bipartite = contract(tn; sequence=seq_sa_bipartite)[]
@test res_optimal res_greedy res_tree_sa res_sa_bipartite
seq_einexprs_exhaustive = contraction_sequence(tn; alg="einexpr", optimizer=Exhaustive())
res_einexprs_exhaustive = contract(tn; sequence=seq_einexprs_exhaustive)[]
seq_einexprs_greedy = contraction_sequence(tn; alg="einexpr", optimizer=Greedy())
res_einexprs_greedy = contract(tn; sequence=seq_einexprs_exhaustive)[]
@test res_greedy res_optimal
@test res_tree_sa res_optimal
@test res_sa_bipartite res_optimal
@test res_einexprs_exhaustive res_optimal
@test res_einexprs_greedy res_optimal

if !Sys.iswindows()
# KaHyPar doesn't work on Windows
Expand All @@ -34,5 +43,8 @@ ITensors.disable_warn_order()
seq_kahypar_bipartite = contraction_sequence(tn; alg="kahypar_bipartite", sc_target=200)
res_kahypar_bipartite = contract(tn; sequence=seq_kahypar_bipartite)[]
@test res_optimal res_kahypar_bipartite
seq_einexprs_kahypar = contraction_sequence(tn; alg="einexpr", optimizer=HyPar())
res_einexprs_kahypar = contract(tn; sequence=seq_einexprs_kahypar)[]
@test res_einexprs_kahypar res_optimal
end
end

0 comments on commit fbb4e53

Please sign in to comment.