diff --git a/src/StructJuMP.jl b/src/StructJuMP.jl index 77cce99..8bdee6f 100644 --- a/src/StructJuMP.jl +++ b/src/StructJuMP.jl @@ -23,58 +23,13 @@ end # --------------- # StructureData # --------------- -if isdefined(:MPI) - type MPIWrapper - comm::MPI.Comm - init::Function - - function MPIWrapper() - instance = new(MPI.Comm(-1)) - finalizer(instance, freeMPIWrapper) - - instance.init = function(ucomm::MPI.Comm) - if isdefined(:MPI) && MPI.Initialized() && ucomm.val == -1 - instance.comm = MPI.COMM_WORLD - elseif isdefined(:MPI) && !MPI.Initialized() - MPI.Init() - instance.comm = MPI.COMM_WORLD - elseif isdefined(:MPI) && MPI.Initialized() && ucomm.val != -1 - instance.comm = ucomm - elseif isdefined(:MPI) && MPI.Finalized() - error("MPI is already finalized!") - else - #doing nothing - end - end - - return instance - end - - end - function freeMPIWrapper(instance::MPIWrapper) - if isdefined(:MPI) && MPI.Initialized() && !MPI.Finalized() - MPI.Finalize() - end - end - - const mpiWrapper = MPIWrapper(); - - type StructureData - probability::Dict{Int,Float64} - children::Dict{Int,JuMP.Model} - parent - num_scen::Int - othermap::Dict{JuMP.Variable,JuMP.Variable} - mpiWrapper::MPIWrapper - end -else - type StructureData - probability::Dict{Int,Float64} - children::Dict{Int,JuMP.Model} - parent - num_scen::Int - othermap::Dict{JuMP.Variable,JuMP.Variable} - end +type StructureData + probability::Dict{Int,Float64} + children::Dict{Int,JuMP.Model} + parent + num_scen::Int + othermap::Dict{JuMP.Variable,JuMP.Variable} + MPIWrapper # Empty unless StructJuMPwithMPI fills it end default_probability(m::JuMP.Model) = 1 / num_scenarios(m) default_probability(::Void) = 1.0 @@ -96,14 +51,21 @@ function structprinthook(io::IO, m::Model) end end +type DummyMPIWrapper + comm::Int + init::Function + + DummyMPIWrapper() = new(-1,identity) +end +const dummy_mpi_wrapper = DummyMPIWrapper() + # Constructor with the number of scenarios -function StructuredModel(;solver=JuMP.UnsetSolver(), parent=nothing, same_children_as=nothing, id=0, comm=isdefined(:MPI) ? MPI.Comm(-1) : -1, num_scenarios::Int=0, prob::Float64=default_probability(parent)) +function StructuredModel(;solver=JuMP.UnsetSolver(), parent=nothing, same_children_as=nothing, id=0, comm=nothing, num_scenarios::Int=0, prob::Float64=default_probability(parent), mpi_wrapper=dummy_mpi_wrapper) + _comm = (comm == nothing ? mpi_wrapper.comm : comm) m = JuMP.Model(solver=solver) if parent === nothing id = 0 - if isdefined(:MPI) - mpiWrapper.init(comm) - end + mpi_wrapper.init(_comm) if isdefined(:StructJuMPSolverInterface) JuMP.setsolvehook(m,StructJuMPSolverInterface.sj_solve) end @@ -115,20 +77,16 @@ function StructuredModel(;solver=JuMP.UnsetSolver(), parent=nothing, same_childr end if same_children_as !== nothing - if !isa(same_children_as, JuMP.Model) || !haskey(same_children_as.ext, :Stochastic) - error("The JuMP model given for the argument `same_children_as' is not valid. Please create it using the `StructuredModel' function.") - end - probability = same_children_as.ext[:Stochastic].probability - children = same_children_as.ext[:Stochastic].children - else - probability = Dict{Int, Float64}() - children = Dict{Int, JuMP.Model}() - end - if isdefined(:MPI) - m.ext[:Stochastic] = StructureData(probability, children, parent, num_scenarios, Dict{JuMP.Variable,JuMP.Variable}(), mpiWrapper) + if !isa(same_children_as, JuMP.Model) || !haskey(same_children_as.ext, :Stochastic) + error("The JuMP model given for the argument `same_children_as' is not valid. Please create it using the `StructuredModel' function.") + end + probability = same_children_as.ext[:Stochastic].probability + children = same_children_as.ext[:Stochastic].children else - m.ext[:Stochastic] = StructureData(probability, children, parent, num_scenarios, Dict{JuMP.Variable,JuMP.Variable}()) + probability = Dict{Int, Float64}() + children = Dict{Int, JuMP.Model}() end + m.ext[:Stochastic] = StructureData(probability, children, parent, num_scenarios, Dict{JuMP.Variable,JuMP.Variable}(), mpi_wrapper) # Printing children is important as well JuMP.setprinthook(m, structprinthook) @@ -146,37 +104,12 @@ getchildren(m::JuMP.Model) = getStructure(m).children::Dict{Int,JuMP.Model} getprobability(m::JuMP.Model) = getStructure(m).probability::Dict{Int, Float64} num_scenarios(m::JuMP.Model) = getStructure(m).num_scen::Int - -function getMyRank() - myrank = 0; - mysize = 1; - if isdefined(:MPI) && MPI.Initialized() && !MPI.Finalized() - comm = MPI.COMM_WORLD - mysize = MPI.Comm_size(comm) - myrank = MPI.Comm_rank(comm) - end - return myrank,mysize -end - -function getProcIdxSet(numScens::Integer) - mysize = 1; - myrank = 0; - if isdefined(:MPI) == true && MPI.Initialized() == true - comm = MPI.COMM_WORLD - mysize = MPI.Comm_size(comm) - myrank = MPI.Comm_rank(comm) - end - # Why don't we just take a round-and-robin? - proc_idx_set = Int[]; - for s = myrank:mysize:(numScens-1) - push!(proc_idx_set, s+1); - end - return proc_idx_set; -end +getProcIdxSet(dummy_mpi_wrapper::DummyMPIWrapper, num_scenarios) = 1:num_scenarios function getProcIdxSet(m::JuMP.Model) + haskey(m.ext[:Stochastic]) || error("Cannot use @second_stage without using the StructuredModel constructor") numScens = num_scenarios(m) - return getProcIdxSet(numScens) + return getProcIdxSet(getStructure(m).mpi_wrapper, numScens) end macro second_stage(m,ind,code)