Skip to content

Commit

Permalink
Fix _setindex error, fix rotl90(::CTMRGEnv), add gauge-fixing test sc…
Browse files Browse the repository at this point in the history
…ript, replace global phase fix by normalizing with sign
  • Loading branch information
pbrehmer committed Feb 27, 2024
1 parent daa1dd7 commit cc8d217
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 33 deletions.
6 changes: 3 additions & 3 deletions examples/heisenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using PEPSKit

# Square lattice Heisenberg Hamiltonian
# Sublattice-rotate to get (1, 1, 1) → (-1, 1, -1), transformed to GS with single-site unit cell
function square_lattice_heisenberg(; Jx=-1.0, Jy=1.0, Jz=-1.0)
function square_lattice_heisenberg(; Jx=-1, Jy=1, Jz=-1)
Sx, Sy, Sz, _ = spinmatrices(1//2)
Vphys =^2
σx = TensorMap(Sx, Vphys, Vphys)
Expand All @@ -28,10 +28,10 @@ function init_peps(d, D, Lx, Ly, finit=randn, dtype=ComplexF64)
end

# Parameters
H = square_lattice_heisenberg()
H = square_lattice_heisenberg(; Jx=-1, Jy=1, Jz=-1)
χbond = 2
χenv = 20
ctmalg = CTMRG(; trscheme=truncdim(χenv), tol=1e-12, miniter=4, maxiter=100, verbosity=2)
ctmalg = CTMRG(; trscheme=truncdim(χenv), tol=1e-10, miniter=4, maxiter=100, verbosity=2)
optalg = PEPSOptimize{LinSolve}(;
optimizer=LBFGS(4; maxiter=100, gradtol=1e-4, verbosity=2),
fpgrad_tol=1e-6,
Expand Down
43 changes: 43 additions & 0 deletions examples/test_gauge_fixing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
using LinearAlgebra
using TensorKit, MPSKitModels, OptimKit
using PEPSKit

# Square lattice Heisenberg Hamiltonian
function square_lattice_heisenberg(; Jx=-1.0, Jy=1.0, Jz=-1.0)
Sx, Sy, Sz, _ = spinmatrices(1//2)
Vphys =^2
σx = TensorMap(Sx, Vphys, Vphys)
σy = TensorMap(Sy, Vphys, Vphys)
σz = TensorMap(Sz, Vphys, Vphys)

@tensor H[-1 -3; -2 -4] :=
Jx * σx[-1, -2] * σx[-3, -4] +
Jy * σy[-1, -2] * σy[-3, -4] +
Jz * σz[-1, -2] * σz[-3, -4]

return H
end

# Initialize InfinitePEPS with random & complex entries by default
function init_peps(d, D, Lx, Ly, finit=randn, dtype=ComplexF64)
Pspaces = fill(ℂ^d, Lx, Ly)
Nspaces = fill(ℂ^D, Lx, Ly)
Espaces = fill(ℂ^D, Lx, Ly)
return InfinitePEPS(finit, dtype, Pspaces, Nspaces, Espaces)
end

# Initialize PEPS and environment
H = square_lattice_heisenberg()
χbond = 2
χenv = 20
ctmalg = CTMRG(; trscheme=truncdim(χenv), tol=1e-10, miniter=4, maxiter=100, verbosity=2)
ψ = init_peps(2, χbond, 1, 1)
env, = leading_boundary(ψ, ctmalg, CTMRGEnv(ψ; Venv=^χenv))

println("\nBefore gauge-fixing:")
env′, = PEPSKit.ctmrg_iter(ψ, env, ctmalg)
PEPSKit.check_elementwise_conv(env, env′)

println("\nAfter gauge-fixing:")
envfix = PEPSKit.gauge_fix(env, env′)
PEPSKit.check_elementwise_conv(env, envfix)
4 changes: 2 additions & 2 deletions examples/test_gradients.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ end
H = square_lattice_heisenberg()
χbond = 2
χenv = 20
ctmalg = CTMRG(; trscheme=truncdim(χenv), tol=1e-12, miniter=4, maxiter=100, verbosity=1)
ctmalg = CTMRG(; trscheme=truncdim(χenv), tol=1e-12, miniter=4, maxiter=100, verbosity=2)
ψ = init_peps(2, χbond, 1, 1)
env, = leading_boundary(ψinit, ctmalg, CTMRGEnv(ψinit; Venv=^χenv))
env, = leading_boundary(ψ, ctmalg, CTMRGEnv(ψ; Venv=^χenv))

# Compute CTM gradient in four different ways (set reuse_env=false to not mutate environment)
println("\nFP gradient using naive AD:")
Expand Down
67 changes: 41 additions & 26 deletions src/algorithms/ctmrg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ end
# Compute CTMRG environment for a given state
function MPSKit.leading_boundary(state, alg::CTMRG, envinit=CTMRGEnv(state))
normold = 1.0
CSold = tsvd(envinit.corners[NORTHWEST]; alg=TensorKit.SVD())[2]
TSold = tsvd(envinit.edges[NORTH]; alg=TensorKit.SVD())[2]
CSold = map(x -> tsvd(x; alg=TensorKit.SVD())[2], envinit.corners)
TSold = map(x -> tsvd(x; alg=TensorKit.SVD())[2], envinit.edges)
ϵold = 1.0
env = deepcopy(envinit)
Pleft, Pright = projector_type(eltype(env.edges), (4, size(state)...))
Expand All @@ -24,10 +24,10 @@ function MPSKit.leading_boundary(state, alg::CTMRG, envinit=CTMRGEnv(state))
Δϵ = abs((ϵold - ϵ) / ϵold)
normnew = norm(state, env)
Δnorm = abs(normold - normnew)
CSnew = tsvd(env.corners[NORTHWEST]; alg=TensorKit.SVD())[2]
ΔCS = norm(CSnew - CSold)
TSnew = tsvd(env.edges[NORTH]; alg=TensorKit.SVD())[2]
ΔTS = norm(TSnew - TSold)
CSnew = map(c -> tsvd(c; alg=TensorKit.SVD())[2], env.corners)
ΔCS = maximum(norm.(CSnew - CSold))
TSnew = map(t -> tsvd(t; alg=TensorKit.SVD())[2], env.edges)
ΔTS = maximum(norm.(TSnew - TSold))
(max(Δnorm, ΔCS, ΔTS) < alg.tol && i > alg.miniter) && break # Converge if maximal Δ falls below tolerance

# Print verbose info
Expand Down Expand Up @@ -75,22 +75,30 @@ function gauge_fix(
end

# Correct relative phases
# TODO: assign corners/edges to correct columns
# TODO: respect difference between left/right projector rows
cornersfix = map(Iterators.product(axes(envfinal.corners)...)) do (dir, r, c)
@tensor Cfix[-1; -2] :=
conj(signs[_prev(dir, 4), r, c][1 -1]) *
envfinal.corners[dir, r, c][1; 2] *
signs[dir, r, c][2, -2]
end
edgesfix = map(zip(signs, envfinal.edges)) do (σ, edge)
@tensor Tfix[-1 -2 -3; -4] := conj(σ[1 -1]) * edge[1 -2 -3; 2] * σ[2, -4]
# edgesfix = similar(envfinal.edges)
# for (dir, r, c) in Iterators.product(axes(envfinal.edges)...)
# @tensor Tfix[-1 -2 -3; -4] :=
# conj(signs[dir, r, c][1 -1]) *
# envfinal.edges[dir, r, c][1 -2 -3; 2] *
# signs[dir, _prev(r, end), c][2, -4]
# @diffset edgesfix[dir, r, _next(c, size(edgesfix, 2))] = Tfix
# end
edgesfix = map(Iterators.product(axes(envfinal.edges)...)) do (dir, r, c)
@tensor Tfix[-1 -2 -3; -4] :=
conj(signs[dir, r, c][1 -1]) *
envfinal.edges[dir, r, c][1 -2 -3; 2] *
signs[dir, _prev(r, end), c][2, -4]
end

# Fix global phase
cornersgfix = map(zip(envprev.corners, cornersfix)) do (Cprev, Cfix)
φ = tr(Cprev) / tr(Cfix) # Extract phase via trace to make it differentiable
return φ * Cfix
end
envfix = CTMRGEnv(cornersgfix, edgesfix)
edgesfix = circshift(edgesfix, (0, 1))
envfix = CTMRGEnv(cornersfix, edgesfix)

# Gauge projectors for correct backpropagation
if !isnothing(Pleft) && !isnothing(Pright)
Expand Down Expand Up @@ -146,26 +154,26 @@ end
# One CTMRG iteration x′ = f(A, x)
function ctmrg_iter(state, env::CTMRGEnv{C,T}, alg::CTMRG) where {C,T}
ϵ = 0.0
Pleft, Pright = projector_type(T, (4, size(state)...))
Pleft, Pright = Zygote.Buffer.(projector_type(T, (4, size(state)...)))

for i in 1:4
env, Pl, Pr, ϵ₀ = left_move(state, env, alg)
state = rotate_north(state, EAST)
env = rotate_north(env, EAST)
ϵ = max(ϵ, ϵ₀)
@diffset Pleft[i, :, :] .= Pl
@diffset Pright[i, :, :] .= Pr
Pleft[i, :, :] = Pl
Pright[i, :, :] = Pr
end

return env, Pleft, Pright, ϵ
return env, copy(Pleft), copy(Pright), ϵ
end

# Grow environment, compute projectors and renormalize
function left_move(state, env::CTMRGEnv{C,T}, alg::CTMRG) where {C,T}
corners::typeof(env.corners) = copy(env.corners)
edges::typeof(env.edges) = copy(env.edges)
ϵ = 0.0
Pleft, Pright = projector_type(T, size(state))
Pleft, Pright = Zygote.Buffer.(projector_type(T, size(state))) # Use Zygote.Buffer instead of @diffset to avoid ZeroTangent errors in _setindex

for col in 1:size(state, 2)
cnext = _next(col, size(state, 2))
Expand Down Expand Up @@ -210,8 +218,8 @@ function left_move(state, env::CTMRGEnv{C,T}, alg::CTMRG) where {C,T}

# Compute projectors
Pl, Pr = build_projectors(U, S, V, Q_sw, Q_nw)
@diffset Pleft[row, col] = Pl
@diffset Pright[row, col] = Pr
Pleft[row, col] = Pl
Pright[row, col] = Pr
end

# Use projectors to grow the corners & edges
Expand All @@ -233,12 +241,19 @@ function left_move(state, env::CTMRGEnv{C,T}, alg::CTMRG) where {C,T}
@diffset edges[WEST, row, cnext] = T_w
end

@diffset corners[SOUTHWEST, :, cnext] ./= norm.(corners[SOUTHWEST, :, cnext])
@diffset corners[NORTHWEST, :, cnext] ./= norm.(corners[NORTHWEST, :, cnext])
@diffset edges[WEST, :, cnext] ./= norm.(edges[WEST, :, cnext])
# Normalize with signed norm to avoid fixing global phase in gauge-fixing
@diffset corners[SOUTHWEST, :, cnext] .= map(corners[SOUTHWEST, :, cnext]) do C
C / (sign(tr(C)) * norm(C))
end
@diffset corners[NORTHWEST, :, cnext] .= map(corners[NORTHWEST, :, cnext]) do C
C / (sign(tr(C)) * norm(C))
end
@diffset edges[WEST, :, cnext] .= map(edges[WEST, :, cnext]) do E
E / (sign(@tensor E[1 2 2; 1]) * norm(E))
end
end

return CTMRGEnv(corners, edges), Pleft, Pright, ϵ
return CTMRGEnv(corners, edges), copy(Pleft), copy(Pright), ϵ
end

# Compute enlarged NW corner
Expand Down
3 changes: 3 additions & 0 deletions src/algorithms/peps_opt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ function ctmrg_gradient(x, H, ctmalg::CTMRG, optalg::PEPSOptimize)
cfun = optalg.reuse_env ? costfun! : costfun
E = cfun(peps, env, H, ctmalg, optalg)
∂E∂A = gradient(cfun, peps, env, H, ctmalg, optalg)[1]
if !(typeof(∂E∂A) <: InfinitePEPS) # NaiveAD returns NamedTuple as gradient instead of InfinitePEPS
∂E∂A = InfinitePEPS(∂E∂A.A)
end
@assert !isnan(norm(∂E∂A))
return E, ∂E∂A
end
Expand Down
5 changes: 3 additions & 2 deletions src/environments/ctmrgenv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ end

# Rotate corners & edges counter-clockwise
function Base.rotl90(env::CTMRGEnv{C,T}) where {C,T}
corners′ = similar(env.corners)
edges′ = similar(env.edges)
# Initialize rotated corners & edges with rotated sizes
corners′ = Array{C,3}(undef, 4, size(env.corners, 3), size(env.corners, 2))
edges′ = Array{T,3}(undef, 4, size(env.edges, 3), size(env.edges, 2))

for dir in 1:4
@diffset corners′[_prev(dir, 4), :, :] .= rotl90(env.corners[dir, :, :])
Expand Down
1 change: 1 addition & 0 deletions src/utility/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ function ChainRulesCore.rrule(::typeof(_setindex), a::AbstractArray, tv, args...
else
v
end
# TODO: Fix this for ZeroTangents
v = typeof(v) != typeof(a) ? convert(typeof(a), v) : v
#v = convert(typeof(a),v);
backwards_tv = v[args...]
Expand Down

0 comments on commit cc8d217

Please sign in to comment.