Skip to content

Commit

Permalink
Revert
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeyT1994 committed Jun 12, 2024
1 parent 0e5e5d8 commit 4bc0183
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 85 deletions.
27 changes: 1 addition & 26 deletions src/solvers/alternating_update/alternating_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using NamedGraphs.GraphsExtensions: GraphsExtensions

function alternating_update(
operator,
init_state;
init_state::AbstractTTN;
nsweeps, # define default for each solver implementation
nsites, # define default for each level of solver implementation
updater, # this specifies the update performed locally
Expand All @@ -21,7 +21,6 @@ function alternating_update(
inserter=default_inserter(),
transform_operator_kwargs=(;),
transform_operator=default_transform_operator(),
sweep_plan_func=default_sweep_plan,
kwargs...,
)
inserter_kwargs = (; inserter_kwargs..., kwargs...)
Expand Down Expand Up @@ -104,30 +103,6 @@ function alternating_update(operator::AbstractTTN, init_state::AbstractTTN; kwar
return alternating_update(projected_operator, init_state; kwargs...)
end

#TODO: Generalise your BP alternating update to operator::Sum{AbstractITensorNetwork}. Shouldn't this
# account for the environment correctly and put the BP error precisely on the state only, which is better
# conditioned?!
function alternating_update(
operators::Vector{ITensorNetwork},
init_state::AbstractITensorNetwork,
sweep_plans;
kwargs...,
)
cache_update_kwargs = is_tree(init_state) ? (;) : (; maxiter=10, tol=1e-5)
ψOψs = QuadraticFormNetwork[
QuadraticFormNetwork(operator, init_state) for operator in operators
]
ψIψ = QuadraticFormNetwork(init_state)
ψOψ_bpcs = BeliefPropagationCache[BeliefPropagationCache(ψOψ) for ψOψ in ψOψs]
ψIψ_bpc = BeliefPropagationCache(ψIψ)
ψOψ_bpcs = BeliefPropagationCache[
update(ψOψ_bpc; cache_update_kwargs...) for ψOψ_bpc in ψOψ_bpcs
]
ψIψ_bpc = update(ψIψ_bpc; cache_update_kwargs...)
projected_operators = (ψOψ_bpcs, ψIψ_bpc)
return alternating_update(projected_operators, init_state, sweep_plans; kwargs...)
end

function alternating_update(
operator::AbstractTTN, init_state::AbstractTTN, sweep_plans; kwargs...
)
Expand Down
61 changes: 2 additions & 59 deletions src/solvers/alternating_update/region_update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ end

function region_update(
projected_operator,
state::AbstractTTN;
state;
outputlevel,
which_sweep,
sweep_plan,
Expand Down Expand Up @@ -117,61 +117,4 @@ function region_update(
update_observer!(region_observer!; all_kwargs...)
!(isnothing(region_printer)) && region_printer(; all_kwargs...)
return state, projected_operator
end

function region_update(
projected_operators,
state;
outputlevel,
which_sweep,
sweep_plan,
which_region_update,
region_printer,
(region_observer!),
)
(region, region_kwargs) = sweep_plan[which_region_update]
(;
extracter,
extracter_kwargs,
updater,
updater_kwargs,
inserter,
inserter_kwargs,
transform_operator,
transform_operator_kwargs,
internal_kwargs,
) = region_kwargs
ψOψ_bpcs, ψIψ_bpc = first(projected_operators), last(projected_operators)

#Fix extracter, update and inserter to work with sum of ψOψ_bpcs
local_state, ∂ψOψ_bpc_∂rs, sqrt_mts, inv_sqrt_mts = extracter(
state, ψOψ_bpcs, ψIψ_bpc, region; extracter_kwargs...
)

local_state, _ = updater(
local_state, ∂ψOψ_bpc_∂rs, sqrt_mts, inv_sqrt_mts; updater_kwargs...
)

state, ψOψ_bpcs, ψIψ_bpc, spec, info = inserter(
state, ψOψ_bpcs, ψIψ_bpc, local_state, region; inserter_kwargs...
)

all_kwargs = (;
which_region_update,
sweep_plan,
total_sweep_steps=length(sweep_plan),
end_of_sweep=(which_region_update == length(sweep_plan)),
state,
region,
which_sweep,
spec,
outputlevel,
info...,
region_kwargs...,
internal_kwargs...,
)
update_observer!(region_observer!; all_kwargs...)
!(isnothing(region_printer)) && region_printer(; all_kwargs...)

return state, (ψOψ_bpcs, ψIψ_bpc)
end
end

0 comments on commit 4bc0183

Please sign in to comment.