Skip to content

Commit

Permalink
Add support for forcing the GC to run when using Distributed
Browse files Browse the repository at this point in the history
  • Loading branch information
mtfishman authored Aug 2, 2023
2 parents e0a4a6b + 070c530 commit be76312
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 17 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ version = "0.1.0"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Folds = "41a02a25-b8f0-4f67-bc48-60067656b558"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"

[compat]
Accessors = "0.1.26"
Compat = "4.8"
Folds = "0.2.8"
ITensors = "0.3.27"
MPI = "0.20"
Expand Down
2 changes: 2 additions & 0 deletions src/ITensorParallel.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ITensorParallel

using Accessors
using Compat
using Distributed
using Folds
using MPI
Expand All @@ -27,6 +28,7 @@ import ITensors:
include(joinpath("partition", "partition.jl"))
include(joinpath("partition", "partition_sum_split.jl"))
include(joinpath("partition", "partition_chain_split.jl"))
include("force_gc.jl")
include("foldssum.jl")
include("distributedsum.jl")
include("mpi_extensions.jl")
Expand Down
55 changes: 42 additions & 13 deletions src/distributedsum.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,61 @@
# Functionality for distributing the terms of a sum
# onto separate processes.

# Generic term: spawn it on a process picked by Distributed (`:any`) and
# keep the resulting `Future` as the handle to the remote value.
distribute(term) = @spawnat(:any, term)

# A term that is already a `Future` is returned unchanged.
distribute(term::Future) = term

# An `MPO` is wrapped in a `ProjMPO` before being distributed.
distribute(term::MPO) = distribute(ProjMPO(term))

# A vector of terms is distributed element by element, so each element is
# dispatched to the matching single-term method.
distribute(terms::Vector) = distribute.(terms)

# Build a `FoldsSum` whose terms have been passed through `distribute`.
# The terms are paired with a `SequentialEx` executor, to which
# `executor_kwargs` are forwarded.
function DistributedSum(terms::Vector; executor_kwargs...)
  remote_terms = distribute(terms)
  ex = SequentialEx(; executor_kwargs...)
  return FoldsSum(remote_terms, ex)
end

# A vector whose element type is `Future` is returned as is.
function distribute(terms::Vector{Future})
  return terms
end

# Calling a `Future` on an `ITensor` forwards to `product`.
function (term::Future)(v::ITensor)
  return product(term, v)
end

# Wrap every `MPO` in a `ProjMPO`, then distribute the resulting vector.
function distribute(terms::Vector{MPO})
  return distribute(ProjMPO.(terms))
end

# Apply the remote term to `v`, running `force_gc()` as the callback of the
# callback-taking `product` method once the result has been computed.
function product(term::Future, v::ITensor)
  return product(() -> force_gc(), term, v)
end

# Apply the value behind `term` to `v` on the process that owns it
# (`term.where`) and fetch the result back to the caller.
# `callback()` is invoked inside the remote block after the result has been
# computed and before it is returned.
function product(callback, term::Future, v::ITensor)
  return @fetchfrom term.where begin
    res = fetch(term)(v)
    callback()
    return res
  end
end

# Reposition the remote term, running `force_gc()` as the callback of the
# callback-taking `position!` method once the term has been repositioned.
function position!(term::Future, v::MPS, pos::Int)
  return position!(() -> force_gc(), term, v, pos)
end

# Call `position!` on the value behind `term`, on the process that owns it
# (`term.where`). The work is spawned with `@spawnat`, so the return value is
# a new `Future` for the repositioned term rather than the term itself.
# `callback()` is invoked inside the remote block after `position!` finishes.
function position!(callback, term::Future, v::MPS, pos::Int)
  return @spawnat term.where begin
    res = position!(fetch(term), v, pos)
    callback()
    return res
  end
end

# Compute the noise term of the remote term, running `force_gc()` as the
# callback of the callback-taking `noiseterm` method afterwards.
function noiseterm(term::Future, v::ITensor, dir::String)
  return noiseterm(() -> force_gc(), term, v, dir)
end

# Evaluate `noiseterm` for the value behind `term` on the process that owns
# it (`term.where`) and fetch the result. `callback()` is invoked in the
# remote block once the noise term has been computed.
function noiseterm(callback, term::Future, v::ITensor, dir::String)
  owner = term.where
  return @fetchfrom owner begin
    noise = noiseterm(fetch(term), v, dir)
    callback()
    return noise
  end
end

# Call `disk` on the remote term, running `force_gc()` as the callback of the
# callback-taking `disk` method afterwards. `disk_kwargs` are forwarded.
function disk(term::Future; disk_kwargs...)
  return disk(() -> force_gc(), term; disk_kwargs...)
end

# Call `disk` (forwarding `disk_kwargs`) on the value behind `term`, on the
# process that owns it (`term.where`). The work is spawned with `@spawnat`,
# so a new `Future` is returned. `callback()` is invoked in the remote block
# after `disk` finishes.
function disk(callback, term::Future; disk_kwargs...)
  owner = term.where
  return @spawnat owner begin
    on_disk = disk(fetch(term); disk_kwargs...)
    callback()
    return on_disk
  end
end
38 changes: 35 additions & 3 deletions src/foldssum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,48 @@ end

## Necessary operations
# Apply every term of `sum` to `v` and add up the results.
# Delegates to the callback-taking method with a no-op callback
# (`Returns(nothing)`), so the summation logic lives in one place.
function product(sum::FoldsSum, v::ITensor)
  return product(Returns(nothing), sum, v)
end

# Specialization for sums whose executor is a `DistributedEx`: `force_gc()`
# is passed as the per-term callback.
function product(sum::FoldsSum{<:Any,<:DistributedEx}, v::ITensor)
  gc_callback = () -> force_gc()
  return product(gc_callback, sum, v)
end

# Sum `term(v)` over `terms(sum)` with `Folds.sum`, using `executor(sum)`.
# `callback()` is invoked after each term has been applied.
function product(callback, sum::FoldsSum, v::ITensor)
  function apply_term(term)
    contribution = term(v)
    callback()
    return contribution
  end
  return Folds.sum(apply_term, terms(sum), executor(sum))
end

# Reposition every term of `sum` for state `v` at site `pos`.
# Delegates to the callback-taking method with a no-op callback
# (`Returns(nothing)`), so the terms are repositioned exactly once.
function position!(sum::FoldsSum, v::MPS, pos::Int)
  return position!(Returns(nothing), sum, v, pos)
end

# Specialization for sums whose executor is a `DistributedEx`: `force_gc()`
# is passed as the per-term callback.
function position!(sum::FoldsSum{<:Any,<:DistributedEx}, v::MPS, pos::Int)
  gc_callback = () -> force_gc()
  return position!(gc_callback, sum, v, pos)
end

# Map `position!` over `terms(sum)` with `Folds.map`, using `executor(sum)`,
# and return `sum` with its terms replaced by the results (`set_terms`).
# `callback()` is invoked after each term has been repositioned.
function position!(callback, sum::FoldsSum, v::MPS, pos::Int)
  function reposition(term)
    moved = position!(term, v, pos)
    callback()
    return moved
  end
  return set_terms(sum, Folds.map(reposition, terms(sum), executor(sum)))
end

# Sum the noise terms of every term of `sum`.
# Delegates to the callback-taking method with a no-op callback
# (`Returns(nothing)`), mirroring `product` and `position!`.
function noiseterm(sum::FoldsSum, v::ITensor, dir::String)
  return noiseterm(Returns(nothing), sum, v, dir)
end

# Specialization for sums whose executor is a `DistributedEx`: `force_gc()`
# is passed as the per-term callback, as `product` and `position!` do.
function noiseterm(sum::FoldsSum{<:Any,<:DistributedEx}, v::ITensor, dir::String)
  return noiseterm(() -> force_gc(), sum, v, dir)
end

# Sum `noiseterm(term, v, dir)` over `terms(sum)` with `Folds.sum`, using
# `executor(sum)`. `callback()` is invoked after each term's noise term has
# been computed.
function noiseterm(callback, sum::FoldsSum, v::ITensor, dir::String)
  function term_noise(term)
    noise = noiseterm(term, v, dir)
    callback()
    return noise
  end
  return Folds.sum(term_noise, terms(sum), executor(sum))
end

const ThreadedSum{T} = FoldsSum{T,ThreadedEx}
Expand Down
16 changes: 16 additions & 0 deletions src/force_gc.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Free-memory threshold (in GB) below which `force_gc` triggers a collection.
# The default is 6 GB.
function default_gc_gb_threshold()
  return 6.0
end

# Module-level setting holding the current threshold.
const gc_gb_threshold = Ref(default_gc_gb_threshold())

# Read the current threshold.
function get_gc_gb_threshold()
  return gc_gb_threshold[]
end

# Overwrite the current threshold; always returns `nothing`.
function set_gc_gb_threshold!(gb_threshold)
  setindex!(gc_gb_threshold, gb_threshold)
  return nothing
end

# https://discourse.julialang.org/t/from-multithreading-to-distributed/101984/6
#
# Run `GC.gc()` when the free system memory reported by `Sys.free_memory()`
# (in bytes) is below `gb_threshold` multiplied by 2^30 bytes. The threshold
# defaults to the current module-level setting. Always returns `nothing`.
function force_gc(gb_threshold::Real=get_gc_gb_threshold())
  bytes_threshold = gb_threshold * 2^30
  Sys.free_memory() < bytes_threshold && GC.gc()
  return nothing
end
2 changes: 1 addition & 1 deletion src/mpisumterm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ set_term(sumterm::MPISumTerm, term) = (@set sumterm.term = term)
set_comm(sumterm::MPISumTerm, comm) = (@set sumterm.comm = comm)

MPISumTerm(mpo::MPO, comm::MPI.Comm) = MPISumTerm(ProjMPO(mpo), comm)
# A vector of `MPO`s is wrapped in a single `ProjMPOSum` before construction.
MPISumTerm(mpos::Vector{MPO}, comm::MPI.Comm) = MPISumTerm(ProjMPOSum(mpos), comm)

nsite(sumterm::MPISumTerm) = nsite(term(sumterm))

Expand Down

0 comments on commit be76312

Please sign in to comment.