Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

This PR makes Gradient Descent parallelized using Threads.@spawn #179

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
74 changes: 59 additions & 15 deletions src/algorithms/grassmann.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
module GrassmannMPS

using ..MPSKit
using ..Defaults
using TensorKit
import TensorKitManifolds.Grassmann

Expand Down Expand Up @@ -69,22 +70,38 @@

function ManifoldPoint(state::Union{InfiniteMPS,FiniteMPS}, envs)
al_d = similar(state.AL)
for i in 1:length(state)
al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
@static if Defaults.parallelize_sites

Check warning on line 73 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L73

Added line #L73 was not covered by tests
@sync for i in 1:length(state)
Threads.@spawn begin
al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
end
end

Check warning on line 78 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L78

Added line #L78 was not covered by tests
else
for i in 1:length(state)
al_d[i] = MPSKit.∂∂AC(i, state, envs.opp, envs) * state.AC[i]
end
end

g = Grassmann.project.(al_d, state.AL)

Gertian marked this conversation as resolved.
Show resolved Hide resolved
Rhoreg = Vector{eltype(state.CR)}(undef, length(state))
δmin = sqrt(eps(real(scalartype(state))))
for i in 1:length(state)
Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
@static if Defaults.parallelize_sites

Check warning on line 88 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L88

Added line #L88 was not covered by tests
Gertian marked this conversation as resolved.
Show resolved Hide resolved
@sync for i in 1:length(state)
Threads.@spawn begin
Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
end
end

Check warning on line 93 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L93

Added line #L93 was not covered by tests
else
for i in 1:length(state)
Rhoreg[i] = regularize(state.CR[i], max(norm(g[i]) / 10, δmin))
end
end

return ManifoldPoint(state, envs, g, Rhoreg)
end

function ManifoldPoint(state::MPSMultiline, envs)
#TODO : support parralelize_sites
# FIXME: add support for unitcells
@assert length(state.AL) == 1 "GradientGrassmann only supports MPSMultiline without unitcells for now"

Expand Down Expand Up @@ -115,9 +132,16 @@
function fg(x::ManifoldPoint{T}) where {T<:Union{InfiniteMPS,FiniteMPS}}
# the gradient I want to return is the preconditioned gradient!
g_prec = Vector{PrecGrad{eltype(x.g),eltype(x.Rhoreg)}}(undef, length(x.g))

for i in 1:length(x.state)
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
@static if Defaults.parallelize_sites

Check warning on line 135 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L135

Added line #L135 was not covered by tests
@sync for i in 1:length(x.state)
Threads.@spawn begin
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
end
end

Check warning on line 140 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L140

Added line #L140 was not covered by tests
else
for i in 1:length(x.state)
g_prec[i] = PrecGrad(rmul!(copy(x.g[i]), x.state.CR[i]'), x.Rhoreg[i])
end
end

# TODO: the operator really should not be part of the environments, and this should
Expand All @@ -128,6 +152,7 @@
return real(f), g_prec
end
function fg(x::ManifoldPoint{<:MPSMultiline})
#TODO : support parralelize_sites
@assert length(x.state) == 1 "GradientGrassmann only supports MPSMultiline without unitcells for now"
# the gradient I want to return is the preconditioned gradient!
g_prec = map(enumerate(x.g)) do (i, cg)
Expand All @@ -147,6 +172,7 @@
Retract a left-canonical MPSMultiline along Grassmann tangent `g` by distance `alpha`.
"""
function retract(x::ManifoldPoint{<:MPSMultiline}, tg, alpha)
#TODO : support parralelize_sites
Gertian marked this conversation as resolved.
Show resolved Hide resolved
g = reshape(tg, size(x.state))

nal = similar(x.state.AL)
Expand All @@ -170,11 +196,19 @@
envs = x.envs
nal = similar(state.AL)
h = similar(g) # The tangent at the end-point
for i in 1:length(g)
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
@static if Defaults.parallelize_sites

Check warning on line 199 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L199

Added line #L199 was not covered by tests
@sync for i in 1:length(g)
Threads.@spawn begin
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
end
end

Check warning on line 205 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L205

Added line #L205 was not covered by tests
else
for i in 1:length(g)
nal[i], th = Grassmann.retract(state.AL[i], g[i].Pg, alpha)
h[i] = PrecGrad(th)
end
end

nstate = InfiniteMPS(nal, state.CR[end])

newpoint = ManifoldPoint(nstate, envs)
Expand All @@ -186,6 +220,7 @@
Retract a left-canonical finite MPS along Grassmann tangent `g` by distance `alpha`.
"""
function retract(x::ManifoldPoint{<:FiniteMPS}, g, alpha)
#TODO : support parralelize_sites.
state = x.state
envs = x.envs

Expand All @@ -208,9 +243,18 @@
`alpha`. `xp` is the end-point of the retraction.
"""
function transport!(h, x, g, alpha, xp)
for i in 1:length(h)
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
@static if Defaults.parallelize_sites

Check warning on line 246 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L246

Added line #L246 was not covered by tests
@sync for i in 1:length(h)
Threads.@spawn begin
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
end
end

Check warning on line 252 in src/algorithms/grassmann.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms/grassmann.jl#L252

Added line #L252 was not covered by tests
else
for i in 1:length(h)
h[i] = PrecGrad(Grassmann.transport!(h[i].Pg, x.state.AL[i], g[i].Pg, alpha,
xp.state.AL[i]))
end
end
return h
end
Expand Down
Loading