diff --git a/src/treetensornetworks/solvers/contract.jl b/src/treetensornetworks/solvers/contract.jl index 8ce97aec..3799eda7 100644 --- a/src/treetensornetworks/solvers/contract.jl +++ b/src/treetensornetworks/solvers/contract.jl @@ -74,6 +74,11 @@ end Overload of `ITensors.apply`. """ function apply(tn1::AbstractTTN, tn2::AbstractTTN; init, kwargs...) + if !isone(plev_diff(flatten_external_indsnetwork(tn1, tn2), external_indsnetwork(init))) + error( + "Initial guess `init` needs to primelevel one less than the contraction tn1 and tn2." + ) + end init = init' tn12 = contract(tn1, tn2; init, kwargs...) return replaceprime(tn12, 1 => 0) @@ -82,8 +87,25 @@ end function sum_apply( tns::Vector{<:Tuple{<:AbstractTTN,<:AbstractTTN}}; alg="fit", init, kwargs... ) + if !isone( + plev_diff( + flatten_external_indsnetwork(first(first(tns)), last(first(tns))), + external_indsnetwork(init), + ), + ) + error( + "Initial guess `init` needs to primelevel one less than the contraction tn1 and tn2." + ) + end + init = init' alg != "fit" && error("sum_apply not implemented for other algorithms than fit.") tn12 = contract(Algorithm(alg), tns; init, kwargs...) return replaceprime(tn12, 1 => 0) end + +function plev_diff(a::IndsNetwork, b::IndsNetwork) + pla = plev(only(a[first(vertices(a))])) + plb = plev(only(b[first(vertices(b))])) + return pla - plb +end