From fbb4e535a2988201f9c51ab7798a8058957c0304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sergio=20S=C3=A1nchez=20Ram=C3=ADrez?= <15837247+mofeing@users.noreply.github.com> Date: Tue, 13 Feb 2024 19:30:20 +0000 Subject: [PATCH] Contraction path optimization with EinExprs (#120) --- Project.toml | 9 ++++ .../ITensorNetworksEinExprsExt.jl | 53 +++++++++++++++++++ src/ITensorNetworks.jl | 2 + test/Project.toml | 1 + test/test_contraction_sequence.jl | 14 ++++- 5 files changed, 78 insertions(+), 1 deletion(-) create mode 100644 ext/ITensorNetworksEinExprsExt/ITensorNetworksEinExprsExt.jl diff --git a/Project.toml b/Project.toml index 432678e2..874c0c1b 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" NamedGraphs = "678767b0-92e7-4007-89e4-4527a8725b19" Observers = "338f10d5-c7f1-4033-a7d1-f9dec39bcaa0" +PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Requires = "ae029012-a4dd-5104-9daa-d747884805df" SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d" @@ -31,6 +32,12 @@ Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" +[weakdeps] +EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" + +[extensions] +ITensorNetworksEinExprsExt = "EinExprs" + [compat] AbstractTrees = "0.4.4" Combinatorics = "1" @@ -40,6 +47,7 @@ DataStructures = "0.18" Dictionaries = "0.4" Distributions = "0.25.86" DocStringExtensions = "0.8, 0.9" +EinExprs = "0.6.4" Graphs = "1.8" GraphsFlows = "0.1.1" ITensors = "0.3.23" @@ -59,6 +67,7 @@ TupleTools = "1.4" julia = "1.7" [extras] +EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/ext/ITensorNetworksEinExprsExt/ITensorNetworksEinExprsExt.jl b/ext/ITensorNetworksEinExprsExt/ITensorNetworksEinExprsExt.jl new file mode 100644 index 00000000..73d7a847 --- /dev/null +++ b/ext/ITensorNetworksEinExprsExt/ITensorNetworksEinExprsExt.jl @@ -0,0 +1,53 @@ +module ITensorNetworksEinExprsExt + +using ITensors: Index, ITensor, @Algorithm_str, inds, noncommoninds +using ITensorNetworks: + ITensorNetworks, ITensorNetwork, vertextype, vertex_data, contraction_sequence +using EinExprs: EinExprs, EinExpr, einexpr, SizedEinExpr + +function to_einexpr(ts::Vector{ITensor}) + IndexType = Any + + tensor_exprs = EinExpr{IndexType}[] + inds_dims = Dict{IndexType,Int}() + + for tensor_v in ts + inds_v = collect(inds(tensor_v)) + push!(tensor_exprs, EinExpr{IndexType}(; head=inds_v)) + merge!(inds_dims, Dict(inds_v .=> size(tensor_v))) + end + + externalinds_tn = reduce(noncommoninds, ts) + return SizedEinExpr(sum(tensor_exprs; skip=externalinds_tn), inds_dims) +end + +function tensor_inds_to_vertex(ts::Vector{ITensor}) + IndexType = Any + VertexType = Int + + mapping = Dict{Set{IndexType},VertexType}() + + for (v, tensor_v) in enumerate(ts) + inds_v = collect(inds(tensor_v)) + mapping[Set(inds_v)] = v + end + + return mapping +end + +function ITensorNetworks.contraction_sequence( + ::Algorithm"einexpr", tn::Vector{ITensor}; optimizer=EinExprs.Exhaustive() +) + expr = to_einexpr(tn) + path = einexpr(optimizer, expr) + return to_contraction_sequence(path, tensor_inds_to_vertex(tn)) +end + +function to_contraction_sequence(expr, tensor_inds_to_vertex) + EinExprs.nargs(expr) == 0 && return tensor_inds_to_vertex[Set(expr.head)] + return map( + expr -> to_contraction_sequence(expr, tensor_inds_to_vertex), EinExprs.args(expr) + ) +end + +end diff --git a/src/ITensorNetworks.jl b/src/ITensorNetworks.jl index d4c993a2..dbd1e515 100644 --- a/src/ITensorNetworks.jl +++ b/src/ITensorNetworks.jl @@ -22,6 +22,7 @@ using LinearAlgebra using NamedGraphs using Observers using Observers.DataFrames: select! +using PackageExtensionCompat using Printf using Requires using SimpleTraits @@ -130,6 +131,7 @@ include(joinpath("treetensornetworks", "solvers", "tree_sweeping.jl")) include("exports.jl") function __init__() + @require_extensions @require OMEinsumContractionOrders = "6f22d1fd-8eed-4bb7-9776-e7d684900715" include( joinpath("requires", "omeinsumcontractionorders.jl") ) diff --git a/test/Project.toml b/test/Project.toml index f9f1d11f..05386c5c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,6 +4,7 @@ Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5" Glob = "c27321d9-0574-5035-807b-f59d2c89b15c" Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6" GraphsFlows = "06909019-6f44-4949-96fc-b9d9aaa02889" diff --git a/test/test_contraction_sequence.jl b/test/test_contraction_sequence.jl index fd732f3f..5470dabe 100644 --- a/test/test_contraction_sequence.jl +++ b/test/test_contraction_sequence.jl @@ -3,6 +3,7 @@ using ITensorNetworks using OMEinsumContractionOrders using Random using Test +using EinExprs: Exhaustive, Greedy, HyPar Random.seed!(1234) @@ -23,7 +24,15 @@ ITensors.disable_warn_order() res_tree_sa = contract(tn; sequence=seq_tree_sa)[] seq_sa_bipartite = contraction_sequence(tn; alg="sa_bipartite") res_sa_bipartite = contract(tn; sequence=seq_sa_bipartite)[] - @test res_optimal ≈ res_greedy ≈ res_tree_sa ≈ res_sa_bipartite + seq_einexprs_exhaustive = contraction_sequence(tn; alg="einexpr", optimizer=Exhaustive()) + res_einexprs_exhaustive = contract(tn; sequence=seq_einexprs_exhaustive)[] + seq_einexprs_greedy = contraction_sequence(tn; alg="einexpr", optimizer=Greedy()) + res_einexprs_greedy = contract(tn; sequence=seq_einexprs_exhaustive)[] + @test res_greedy ≈ res_optimal + @test res_tree_sa ≈ res_optimal + @test res_sa_bipartite ≈ res_optimal + @test res_einexprs_exhaustive ≈ res_optimal + @test res_einexprs_greedy ≈ res_optimal if !Sys.iswindows() # KaHyPar doesn't work on Windows @@ -34,5 +43,8 @@ ITensors.disable_warn_order() seq_kahypar_bipartite = contraction_sequence(tn; alg="kahypar_bipartite", sc_target=200) res_kahypar_bipartite = contract(tn; sequence=seq_kahypar_bipartite)[] @test res_optimal ≈ res_kahypar_bipartite + seq_einexprs_kahypar = contraction_sequence(tn; alg="einexpr", optimizer=HyPar()) + res_einexprs_kahypar = contract(tn; sequence=seq_einexprs_kahypar)[] + @test res_einexprs_kahypar ≈ res_optimal end end