Skip to content

Commit

Permalink
Move out vertex mapping from to_einexpr to tensor_inds_to_vertex
Browse files Browse the repository at this point in the history
  • Loading branch information
mofeing committed Feb 7, 2024
1 parent 6c48bbd commit cc19993
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions ext/ITensorNetworksEinExprsExt/src/ITensorNetworksEinExprsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,37 @@ using ITensorNetworks:
contraction_sequence
using EinExprs: EinExprs, EinExpr, einexpr, SizedEinExpr

function to_einexpr(tn::ITensorNetwork)
function to_einexpr(tn::AbstractITensorNetwork)
IndexType = Any
VertexType = vertextype(tn)

tensor_exprs = EinExpr{IndexType}[]
tensor_inds_to_vertex = Dict{Set{IndexType},VertexType}()
inds_dims = Dict{IndexType,Int}()

for v in vertices(tn)
tensor_v = tn[v]
inds_v = collect(inds(tensor_v))
push!(tensor_exprs, EinExpr{IndexType}(; head=inds_v))
tensor_inds_to_vertex[Set(inds_v)] = key
merge!(inds_dims, Dict(inds_v .=> size(tensor_v)))
end

externalinds_tn = collect(externalinds(tn))
expr = SizedEinExpr(sum(tensor_exprs; skip=externalinds_tn), inds_dims)
return SizedEinExpr(sum(tensor_exprs; skip=externalinds_tn), inds_dims)
end

function tensor_inds_to_vertex(tn::AbstractITensorNetwork)
mapping = Dict{Set{IndexType},VertexType}()

for v in vertices(tn)
tensor_v = tn[v]
inds_v = collect(inds(tensor_v))
mapping[Set(inds_v)] = v
end

return expr, tensor_inds_to_vertex
return mapping
end

function EinExprs.einexpr(tn::ITensorNetwork; optimizer::EinExprs.Optimizer)
expr, _ = to_einexpr(tn)
expr = to_einexpr(tn)
return einexpr(optimizer, expr)
end

Expand All @@ -46,8 +53,8 @@ end
function ITensorNetworks.contraction_sequence(
::Algorithm"einexpr", tn::ITensorNetwork{T}; optimizer=EinExprs.Exhaustive()
)
expr, tensor_inds_to_vertex = to_einexpr(tn)
return to_contraction_sequence(expr, tensor_inds_to_vertex)
expr = einexpr(tn; optimizer)
return to_contraction_sequence(expr, tensor_inds_to_vertex(tn))
end

function to_contraction_sequence(expr, tensor_inds_to_vertex)
Expand Down

0 comments on commit cc19993

Please sign in to comment.