Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

This PR makes Gradient Descent parallelized using Threads.@spawn #179

Closed
wants to merge 12 commits into from
78 changes: 62 additions & 16 deletions src/algorithms/grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
module GrassmannMPS

using ..MPSKit
using ..Defaults
using TensorKit
import TensorKitManifolds.Grassmann

Expand Down Expand Up @@ -69,22 +70,40 @@

"""
    ManifoldPoint(state::Union{InfiniteMPS,FiniteMPS}, envs)

Construct the `ManifoldPoint` for `state` with environments `envs`.

For every site `i` the following are computed:

- `al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]`,
- `g[i] = Grassmann.project(al_d[i], state.AL[i])`,
- `Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))`, with
  `δmin = sqrt(eps(real(scalartype(state))))`.

When `Defaults.parallelize_sites` is `true` the per-site work is run in tasks created
with `Threads.@spawn`; otherwise it is done in plain serial loops.
"""
function ManifoldPoint(state::Union{InfiniteMPS,FiniteMPS}, envs)
    nsites = length(state)
    al_d = similar(state.AL)

    @static if Defaults.parallelize_sites
        # Each task returns its projected gradient. Collecting the results with
        # `fetch` lets `map` pick a concrete element type for `g`; a container
        # created with `fill(nothing, n)` is a `Vector{Nothing}` and cannot store
        # the gradients at all.
        tasks = Vector{Task}(undef, nsites)
        @sync for i in 1:nsites
            tasks[i] = Threads.@spawn begin
                al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
                Grassmann.project(al_d[i], state.AL[i])
            end
        end
        g = map(fetch, tasks)
    else
        for i in 1:nsites
            al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
        end
        g = Grassmann.project.(al_d, state.AL)
    end

    Rhoreg = Vector{eltype(state.CR)}(undef, nsites)
    # Lower bound for the regularization parameter passed to `regularize`.
    δmin = sqrt(eps(real(scalartype(state))))
    @static if Defaults.parallelize_sites
        # Every task writes to its own slot of `Rhoreg`, so no locking is needed.
        @sync for i in 1:nsites
            Threads.@spawn begin
                Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
            end
        end
    else
        for i in 1:nsites
            Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
        end
    end

    return ManifoldPoint(state, envs, g, Rhoreg)
end

function ManifoldPoint(state::MPSMultiline, envs)
#TODO : support parallelize_sites
# FIXME: add support for unitcells
@assert length(state.AL) == 1 "GradientGrassmann only supports MPSMultiline without unitcells for now"

Expand Down Expand Up @@ -115,9 +134,16 @@
function fg(x::ManifoldPoint{T}) where {T<:Union{InfiniteMPS,FiniteMPS}}
# the gradient I want to return is the preconditioned gradient!
g_prec = Vector{PrecGrad{eltype(x.g),eltype(x.Rhoreg)}}(undef, length(x.g))

for i in 1:length(x.state)
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
@static if Defaults.parallelize_sites

Check warning on line 137 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L137

Added line #L137 was not covered by tests
@sync for i in 1:length(x.state)
Threads.@spawn begin
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
end
end
else
for i in 1:length(x.state)
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
end
end

# TODO: the operator really should not be part of the environments, and this should
Expand All @@ -128,6 +154,7 @@
return real(f), g_prec
end
function fg(x::ManifoldPoint{<:MPSMultiline})
#TODO : support parallelize_sites
@assert length(x.state) == 1 "GradientGrassmann only supports MPSMultiline without unitcells for now"
# the gradient I want to return is the preconditioned gradient!
g_prec = map(enumerate(x.g)) do (i, cg)
Expand All @@ -147,6 +174,7 @@
Retract a left-canonical MPSMultiline along Grassmann tangent `g` by distance `alpha`.
"""
function retract(x::ManifoldPoint{<:MPSMultiline}, tg, alpha)
#TODO : support parallelize_sites
g = reshape(tg, size(x.state))

nal = similar(x.state.AL)
Expand All @@ -170,11 +198,19 @@
envs = x.envs
nal = similar(state.AL)
h = similar(g) # The tangent at the end-point
for i in 1:length(g)
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
@static if Defaults.parallelize_sites

Check warning on line 201 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L201

Added line #L201 was not covered by tests
@sync for i in 1:length(g)
Threads.@spawn begin
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
end
end
else
for i in 1:length(g)
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
end
end

nstate = InfiniteMPS(nal, state.CR[end])

newpoint = ManifoldPoint(nstate, envs)
Expand All @@ -186,6 +222,7 @@
Retract a left-canonical finite MPS along Grassmann tangent `g` by distance `alpha`.
"""
function retract(x::ManifoldPoint{<:FiniteMPS}, g, alpha)
#TODO : support parallelize_sites.
state = x.state
envs = x.envs

Expand All @@ -208,9 +245,18 @@
`alpha`. `xp` is the end-point of the retraction.
"""
function transport!(h, x, g, alpha, xp)
    # Each `h[i]` is replaced by a `PrecGrad` wrapping the result of
    # `Grassmann.transport!` on `h[i].Pg`. The transport must be applied exactly
    # once per site, so only one of the two branches below is compiled in.
    @static if Defaults.parallelize_sites
        # Sites are handled in separate tasks; task `i` only touches `h[i]`.
        @sync for i in 1:length(h)
            Threads.@spawn begin
                h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg,
                                                     alpha, xp.state.AL[i]))
            end
        end
    else
        for i in 1:length(h)
            h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg,
                                                 alpha, xp.state.AL[i]))
        end
    end
    return h
end
Expand Down
Loading