Skip to content

Commit

Permalink
Add options to disable various multithreading features
Browse files Browse the repository at this point in the history
  • Loading branch information
lkdvos committed Oct 3, 2023
1 parent 205b0b0 commit 11a9fca
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 70 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OptimKit = "77e91f04-9b3b-57a6-a776-40b61faaebe0"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
TensorKit = "07d1fe3e-3e46-537d-9eac-e9e13d0d4cec"
TensorKitManifolds = "11fa318c-39cb-4a83-b1ed-cdc7ba1e3684"
Expand Down
17 changes: 4 additions & 13 deletions src/MPSKit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ using LinearAlgebra: diag, Diagonal
using LinearAlgebra: LinearAlgebra
using Base: @kwdef

using Preferences

#bells and whistles for mpses
export InfiniteMPS, FiniteMPS, WindowMPS, MPSMultiline
export PeriodicArray, Window
Expand Down Expand Up @@ -53,19 +55,8 @@ export transfer_left, transfer_right
@deprecate params(args...) environments(args...)
@deprecate InfiniteMPO(args...) DenseMPO(args...)

#default settings
module Defaults
const eltype = ComplexF64
const maxiter = 100
const tolgauge = 1e-14
const tol = 1e-12
const verbose = true
_finalize(iter, state, opp, envs) = (state, envs)

import KrylovKit: GMRES, Arnoldi
const linearsolver = GMRES(; tol, maxiter)
const eigsolver = Arnoldi(; tol, maxiter)
end

include("utility/defaults.jl")

include("utility/periodicarray.jl")
include("utility/utility.jl") #random utility functions
Expand Down
62 changes: 42 additions & 20 deletions src/algorithms/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,41 @@ end
"""

function ∂AC(x::MPSTensor, ham::SparseMPOSlice, leftenv, rightenv)::typeof(x)
local toret
local y
if Defaults.parallelize_derivatives
@floop WorkStealingEx() for (i, j) in keys(ham)
if isscal(ham, i, j)
@plansor t[-1 -2; -3] :=
leftenv[i][-1 5; 4] * x[4 6; 1] * τ[6 5; 7 -2] * rightenv[j][1 7; -3]
lmul!(ham.Os[i, j], t)
else
@plansor t[-1 -2; -3] :=
leftenv[i][-1 5; 4] *
x[4 2; 1] *
ham[i, j][5 -2; 2 3] *
rightenv[j][1 3; -3]
end

@floop WorkStealingEx() for (i, j) in keys(ham)
if isscal(ham, i, j)
@plansor t[-1 -2; -3] :=
leftenv[i][-1 5; 4] * x[4 6; 1] * τ[6 5; 7 -2] * rightenv[j][1 7; -3]
lmul!(ham.Os[i, j], t)
else
@plansor t[-1 -2; -3] :=
leftenv[i][-1 5; 4] *
x[4 2; 1] *
ham[i, j][5 -2; 2 3] *
rightenv[j][1 3; -3]
@reduce(y = inplace_add!(nothing, t))
end
else
y = zerovector(x)
for (i, j) in keys(ham)
if isscal(ham, i, j)
h = ham.Os[i, j]
@plansor y[-1 -2; -3] += h *
(leftenv[i][-1 5; 4] * x[4 6; 1] * τ[6 5; 7 -2] * rightenv[j][1 7; -3])
else
@plansor y[-1 -2; -3] +=
leftenv[i][-1 5; 4] *
x[4 2; 1] *
ham[i, j][5 -2; 2 3] *
rightenv[j][1 3; -3]
end
end

@reduce(toret = inplace_add!(nothing, t))
end

return toret
return y
end

function ∂AC(x::MPSTensor, opp::MPOTensor, leftenv, rightenv)::typeof(x)
Expand Down Expand Up @@ -184,13 +200,19 @@ end
Zero-site derivative (the C matrix to the right of pos)
"""
function ∂C(x::MPSBondTensor, leftenv::AbstractVector, rightenv::AbstractVector)::typeof(x)
@floop WorkStealingEx() for (le, re) in zip(leftenv, rightenv)
t = ∂C(x, le, re)

@reduce(s = inplace_add!(nothing, t))
if Defaults.parallelize_derivatives
@floop WorkStealingEx() for (le, re) in zip(leftenv, rightenv)
t = ∂C(x, le, re)
@reduce(y = inplace_add!(nothing, t))
end
else
y = zerovector(x)
for (le, re) in zip(leftenv, rightenv)
VectorInterface.add!(y, ∂C(x, le, re))
end
end

return s
return y
end

function ∂C(x::MPSBondTensor, leftenv::MPSTensor, rightenv::MPSTensor)
Expand Down
33 changes: 23 additions & 10 deletions src/algorithms/groundstate/vumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,36 @@ function find_groundstate(Ψ::InfiniteMPS, H, alg::VUMPS, envs=environments(Ψ,

while true
eigalg = Arnoldi(; tol=galerkin / (4 * sqrt(iter)))

@sync for (loc, (ac, c)) in enumerate(zip.AC, Ψ.CR))
Threads.@spawn begin
_, acvecs = eigsolve(∂∂AC($loc, $Ψ, $H, $envs), $ac, 1, :SR, eigalg)
$temp_ACs[loc] = acvecs[1]

if Defaults.parallelize_sites
@sync begin
for (loc, ac) in enumerate.AC)
Threads.@spawn begin
_, acvecs = eigsolve(∂∂AC($loc, $Ψ, $H, $envs), $ac, 1, :SR, eigalg)
$temp_ACs[loc] = acvecs[1]
end
end
for (loc, c) in enumerate.CR)
Threads.@spawn begin
_, crvecs = eigsolve(∂∂C($loc, $Ψ, $H, $envs), $c, 1, :SR, eigalg)
$temp_Cs[loc] = crvecs[1]
end
end
end

Threads.@spawn begin
_, crvecs = eigsolve(∂∂C($loc, $Ψ, $H, $envs), $c, 1, :SR, eigalg)
$temp_Cs[loc] = crvecs[1]
else
for (loc, ac) in enumerate.AC)
_, acvecs = eigsolve(∂∂AC(loc, Ψ, H, envs), ac, 1, :SR, eigalg)
temp_ACs[loc] = acvecs[1]
end
for (loc, c) in enumerate.CR)
_, crvecs = eigsolve(∂∂C(loc, Ψ, H, envs), c, 1, :SR, eigalg)
temp_Cs[loc] = crvecs[1]
end
end

for (i, (ac, c)) in enumerate(zip(temp_ACs, temp_Cs))
QAc, _ = TensorKit.leftorth!(ac; alg=QRpos())
Qc, _ = TensorKit.leftorth!(c; alg=QRpos())

temp_ACs[i] = QAc * adjoint(Qc)
end

Expand Down
90 changes: 63 additions & 27 deletions src/transfermatrix/transfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,49 +190,85 @@ end

function transfer_left(RetType, vec, ham::SparseMPOSlice, A, Ab)
toret = similar(vec, RetType, length(vec))

@threads for k in 1:length(vec)
els = keys(ham, :, k)

@floop WorkStealingEx() for j in els
if isscal(ham, j, k)
t = lmul!(ham.Os[j, k], transfer_left(vec[j], A, Ab))
else
t = transfer_left(vec[j], ham[j, k], A, Ab)

if Defaults.parallelize_transfers
@threads for k in 1:length(vec)
els = keys(ham, :, k)
@floop WorkStealingEx() for j in els
if isscal(ham, j, k)
t = lmul!(ham.Os[j, k], transfer_left(vec[j], A, Ab))
else
t = transfer_left(vec[j], ham[j, k], A, Ab)
end

@reduce(s = inplace_add!(nothing, t))
end

@reduce(s = inplace_add!(nothing, t))
if isnothing(s)
s = transfer_left(vec[1], ham[1, k], A, Ab)
end
toret[k] = s
end

if isnothing(s)
s = transfer_left(vec[1], ham[1, k], A, Ab)
else
for k in 1:length(vec)
els = keys(ham, :, k)
if isempty(els)
toret[k] = transfer_left(vec[1], ham[1, k], A, Ab)
else
zerovector!(toret[k])
for j in els
if isscal(ham, j, k)
add!(toret[k], transfer_left(vec[k], A, Ab), ham.Os[j, k])
else
add!(toret[k], transfer_left(vec[k], ham[j, k], A, Ab))
end
end
end
end
toret[k] = s
end

return toret
end
function transfer_right(RetType, vec, ham::SparseMPOSlice, A, Ab)
toret = similar(vec, RetType, length(vec))

@threads for j in 1:length(vec)
els = keys(ham, j, :)
if Defaults.parallelize_transfers
@threads for j in 1:length(vec)
els = keys(ham, j, :)

@floop WorkStealingEx() for k in els
if isscal(ham, j, k)
t = lmul!(ham.Os[j, k], transfer_right(vec[k], A, Ab))
else
t = transfer_right(vec[k], ham[j, k], A, Ab)
@floop WorkStealingEx() for k in els
if isscal(ham, j, k)
t = lmul!(ham.Os[j, k], transfer_right(vec[k], A, Ab))
else
t = transfer_right(vec[k], ham[j, k], A, Ab)
end

@reduce(s = inplace_add!(nothing, t))
end

@reduce(s = inplace_add!(nothing, t))
end
if isnothing(s)
s = transfer_right(vec[1], ham[j, 1], A, Ab)
end

if isnothing(s)
s = transfer_right(vec[1], ham[j, 1], A, Ab)
toret[j] = s
end
else
for j in 1:length(vec)
els = keys(ham, j, :)
if isempty(els)
toret[j] = transfer_left(vec[1], ham[j, 1], A, Ab)
else
zerovector!(toret[j])
for k in els
if isscal(ham, j, k)
add!(toret[j], transfer_right(vec[k], A, Ab), ham.Os[j, k])
else
add!(toret[j], transfer_right(vec[k], ham[j, k], A, Ab))
end
end
end

end

toret[j] = s
end

return toret
Expand Down
45 changes: 45 additions & 0 deletions src/utility/defaults.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""
module Defaults
Some default values and settings for MPSKit.
"""
module Defaults

using Preferences
import KrylovKit: GMRES, Arnoldi

const eltype = ComplexF64
const maxiter = 100
const tolgauge = 1e-14
const tol = 1e-12
const verbose = true

_finalize(iter, state, opp, envs) = (state, envs)

const linearsolver = GMRES(; tol, maxiter)
const eigsolver = Arnoldi(; tol, maxiter)

# Preferences
# -----------

function set_parallelization(options::Pair{String, Bool}...)
for (key, val) in options
if !(key in ("sites", "derivatives", "transfers"))
throw(ArgumentError("Invalid option: \"$(key)\""))
end

@set_preferences!("parallelize_$key" => val)
end

sites = @load_preference("parallelize_sites", nothing)
derivatives = @load_preference("parallelize_derivatives", nothing)
transfers = @load_preference("parallelize_derivatives", nothing)
@info "Parallelization changed; restart your Julia session for this change to take effect!" sites derivatives transfers
return nothing
end

const parallelize_sites = @load_preference("parallelize_sites", Threads.nthreads() > 1)
const parallelize_derivatives = @load_preference("parallelize_derivatives", Threads.nthreads() > 1)
const parallelize_transfers = @load_preference("parallelize_transfers", Threads.nthreads() > 1)

end

0 comments on commit 11a9fca

Please sign in to comment.