Skip to content

Commit fdbeea4

Browse files
authored
Add distributed MPI support for solving classes (#248)
1 parent dc7fd8c commit fdbeea4

12 files changed

+628
-58
lines changed

.clang-format

+1-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ ReferenceAlignment: Pointer
159159
ReflowComments: true
160160
RemoveBracesLLVM: false
161161
RemoveSemicolon: false
162-
RequiresClausePosition: OwnLine
162+
RequiresClausePosition: SingleLine
163163
RequiresExpressionIndentation: OuterScope
164164
SeparateDefinitionBlocks: Leave
165165
ShortNamespaceLines: 1

CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ if (USE_MPI)
141141
elseif (MPI_C_INCLUDE_DIRS_LOWER MATCHES "mpich")
142142
set(MKL_MPI mpich)
143143
endif ()
144+
add_compile_definitions(SUANPAN_DISTRIBUTED)
144145
endif ()
145146

146147
if (USE_MKL)

Domain/DomainState.cpp

+311-32
Large diffs are not rendered by default.

Domain/MetaMat/MetaMat.hpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,6 @@
3434
#include "ILU.hpp"
3535
#include "Jacobi.hpp"
3636

37-
template<typename T, typename U> concept ArmaContainer = std::is_floating_point_v<U> && (std::is_convertible_v<T, Mat<U>> || std::is_convertible_v<T, SpMat<U>>) ;
38-
3937
template<sp_d T> class MetaMat;
4038

4139
template<sp_d T> class op_add {
@@ -199,9 +197,9 @@ template<sp_d T> class MetaMat {
199197

200198
virtual void operator*=(T) = 0;
201199

202-
template<ArmaContainer<T> C> int solve(Mat<T>& X, C&& B) { return IterativeSolver::NONE == this->setting.iterative_solver ? this->direct_solve(X, std::forward<C>(B)) : this->iterative_solve(X, std::forward<C>(B)); }
200+
template<typename C> requires is_arma_mat<T, C> int solve(Mat<T>& X, C&& B) { return IterativeSolver::NONE == this->setting.iterative_solver ? this->direct_solve(X, std::forward<C>(B)) : this->iterative_solve(X, std::forward<C>(B)); }
203201

204-
template<ArmaContainer<T> C> Mat<T> solve(C&& B) {
202+
template<typename C> requires is_arma_mat<T, C> Mat<T> solve(C&& B) {
205203
Mat<T> X;
206204

207205
if(SUANPAN_SUCCESS != this->solve(X, std::forward<C>(B))) throw std::runtime_error("fail to solve the system");

Element/Element.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -257,25 +257,32 @@ Element::Element(const unsigned T, const unsigned NN, const unsigned ND, uvec&&
257257
Element::Element(const unsigned T, const unsigned NN, const unsigned ND, uvec&& NT, uvec&& MT, const bool F, const MaterialType MTP, std::vector<DOF>&& DI)
258258
: DataElement{std::move(NT), std::move(MT), uvec{}, F}
259259
, ElementBase(T)
260+
, Distributed(T)
260261
, num_node(NN)
261262
, num_dof(ND)
262263
, material_type(MTP)
263264
, section_type(SectionType::D0)
264-
, dof_identifier(std::move(DI)) { suanpan_assert([&] { if(!dof_identifier.empty() && num_dof != dof_identifier.size()) throw invalid_argument("size of dof identifier must meet number of dofs"); }); }
265+
, dof_identifier(std::move(DI)) {
266+
suanpan_assert([&] { if(!dof_identifier.empty() && num_dof != dof_identifier.size()) throw invalid_argument("size of dof identifier must meet number of dofs"); });
267+
}
265268

266269
Element::Element(const unsigned T, const unsigned NN, const unsigned ND, uvec&& NT, uvec&& ST, const bool F, const SectionType STP, std::vector<DOF>&& DI)
267270
: DataElement{std::move(NT), uvec{}, std::move(ST), F}
268271
, ElementBase(T)
272+
, Distributed(T)
269273
, num_node(NN)
270274
, num_dof(ND)
271275
, material_type(MaterialType::D0)
272276
, section_type(STP)
273-
, dof_identifier(std::move(DI)) { suanpan_assert([&] { if(!dof_identifier.empty() && num_dof != dof_identifier.size()) throw invalid_argument("size of dof identifier must meet number of dofs"); }); }
277+
, dof_identifier(std::move(DI)) {
278+
suanpan_assert([&] { if(!dof_identifier.empty() && num_dof != dof_identifier.size()) throw invalid_argument("size of dof identifier must meet number of dofs"); });
279+
}
274280

275281
// for contact elements that use node groups
276282
Element::Element(const unsigned T, const unsigned ND, uvec&& GT)
277283
: DataElement{std::move(GT), {}, {}, false}
278284
, ElementBase(T)
285+
, Distributed(T)
279286
, num_node(static_cast<unsigned>(-1))
280287
, num_dof(ND)
281288
, use_group(true)
@@ -286,6 +293,7 @@ Element::Element(const unsigned T, const unsigned ND, uvec&& GT)
286293
Element::Element(const unsigned T, const unsigned ND, const unsigned ET, const unsigned NT)
287294
: DataElement{{NT}, {}, {}, false}
288295
, ElementBase(T)
296+
, Distributed(ET)
289297
, num_node(static_cast<unsigned>(-1))
290298
, num_dof(ND)
291299
, use_other(ET)

Element/Element.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#define ELEMENT_H
3030

3131
#include <Element/ElementBase.h>
32+
#include <MPI/Distributed.h>
3233

3334
enum class MaterialType : unsigned;
3435
enum class SectionType : unsigned;
@@ -114,7 +115,7 @@ struct DataElement {
114115
const double characteristic_length = 1.;
115116
};
116117

117-
class Element : protected DataElement, public ElementBase {
118+
class Element : protected DataElement, public ElementBase, public Distributed {
118119
const unsigned num_node; // number of nodes
119120
const unsigned num_dof; // number of DoFs
120121
const unsigned num_size = num_dof * num_node; // number of size

MPI/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,6 @@ add_dependencies(tester.pardiso solver.pardiso)
1717
add_dependencies(suanPan solver.pardiso)
1818

1919
install(TARGETS solver.pardiso DESTINATION bin)
20+
21+
add_executable(distributed_obj distributed_obj.cpp)
22+
target_link_libraries(distributed_obj MKL::MKL_SCALAPACK MPI::MPI_C MPI::MPI_CXX MPI::MPI_Fortran)

MPI/Distributed.h

+63
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*******************************************************************************
2+
* Copyright (C) 2017-2025 Theodore Chang
3+
*
4+
* This program is free software: you can redistribute it and/or modify
5+
* it under the terms of the GNU General Public License as published by
6+
* the Free Software Foundation, either version 3 of the License, or
7+
* (at your option) any later version.
8+
*
9+
* This program is distributed in the hope that it will be useful,
10+
* but WITHOUT ANY WARRANTY; without even the implied warranty of
11+
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12+
* GNU General Public License for more details.
13+
*
14+
* You should have received a copy of the GNU General Public License
15+
* along with this program. If not, see <http://www.gnu.org/licenses/>.
16+
******************************************************************************/
17+
18+
#ifndef DISTRIBUTED_H
19+
#define DISTRIBUTED_H
20+
21+
#include <suanPan.h>
22+
23+
class Distributed {
24+
static constexpr int root_rank{0};
25+
26+
static auto assign_process(const int obj_tag) { return obj_tag % comm_size; }
27+
28+
public:
29+
const int tag, process_rank;
30+
31+
const bool is_local;
32+
33+
explicit Distributed(const int obj_tag)
34+
: tag(obj_tag)
35+
, process_rank(assign_process(tag))
36+
, is_local(comm_rank == process_rank) {}
37+
38+
#ifdef SUANPAN_DISTRIBUTED
39+
/**
40+
* @brief Performs a gather operation on a distributed matrix object.
41+
*
42+
* This function initiates a non-blocking gather operation. If the calling process is the root process,
43+
* it receives data from the specified process. If the calling process is not the root process, it sends
44+
* its data to the root process.
45+
*
46+
* @tparam DT The data type of the matrix elements.
47+
* @param object The matrix object to be gathered.
48+
* @return An optional non-blocking request handle for the gather operation.
49+
*/
50+
template<mpl_data_t DT> std::optional<mpl::irequest> gather(const Mat<DT>& object) {
51+
if(root_rank == comm_rank) {
52+
if(!is_local) return comm_world.irecv(const_cast<DT*>(object.memptr()), mpl::contiguous_layout<DT>{object.n_elem}, process_rank, mpl::tag_t{tag});
53+
}
54+
else if(is_local) return comm_world.isend(object.memptr(), mpl::contiguous_layout<DT>{object.n_elem}, root_rank, mpl::tag_t{tag});
55+
56+
return {};
57+
}
58+
#else
59+
template<mpl_data_t DT> static auto gather(const Mat<DT>&) {}
60+
#endif
61+
};
62+
63+
#endif // DISTRIBUTED_H

MPI/distributed_obj.cpp

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#include "Distributed.h"
2+
3+
#include <mpl/mpl.hpp>
4+
5+
class Object : public Distributed {
6+
mat data{-999, -999};
7+
8+
public:
9+
explicit Object(int tag)
10+
: Distributed(tag) {}
11+
12+
auto generate() {
13+
if(is_local) data.fill(fill::randn);
14+
}
15+
16+
auto gather_data() { return gather(data); }
17+
18+
auto print() { printf("Object %d on process %d data: %+.6f %+.6f.\n", tag, mpl::environment::comm_world().rank(), data(0), data(1)); }
19+
};
20+
21+
int main() {
22+
arma_rng::set_seed_random();
23+
24+
std::vector<Object> objects;
25+
26+
for(auto i = 0; i < 100; ++i) objects.emplace_back(i);
27+
28+
for(auto& i : objects) i.generate();
29+
30+
mpl::irequest_pool requests;
31+
for(auto& i : objects)
32+
if(auto req = i.gather_data(); req.has_value()) requests.push(std::move(req).value());
33+
requests.waitall();
34+
35+
if(0 == mpl::environment::comm_world().rank())
36+
for(auto& i : objects) i.print();
37+
38+
return 0;
39+
}

Solver/Integrator/Integrator.cpp

+124-9
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@
1616
******************************************************************************/
1717

1818
#include "Integrator.h"
19+
1920
#include <Domain/DomainBase.h>
2021
#include <Domain/Factory.hpp>
2122

2223
Integrator::Integrator(const unsigned T)
2324
: Tag(T) {}
2425

25-
void Integrator::set_domain(const weak_ptr<DomainBase>& D) { if(database.lock() != D.lock()) database = D; }
26+
void Integrator::set_domain(const weak_ptr<DomainBase>& D) {
27+
if(database.lock() != D.lock()) database = D;
28+
}
2629

2730
shared_ptr<DomainBase> Integrator::get_domain() const { return database.lock(); }
2831

@@ -260,13 +263,69 @@ mat Integrator::solve(sp_mat&& B) {
260263
return X;
261264
}
262265

263-
int Integrator::solve(mat& X, const mat& B) { return database.lock()->get_factory()->get_stiffness()->solve(X, B); }
266+
int Integrator::solve(mat& X, const mat& B) {
267+
int info{0};
268+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
269+
if(0 == comm_rank) info = database.lock()->get_factory()->get_stiffness()->solve(X, B);
270+
271+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
272+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
273+
// ReSharper disable once CppDFAUnreachableCode
274+
if(0 != comm_rank) X.set_size(B.n_rows, B.n_cols);
275+
bcast_from_root(X);
276+
}
277+
278+
return info;
279+
}
280+
281+
int Integrator::solve(mat& X, const sp_mat& B) {
282+
int info{0};
283+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
284+
if(0 == comm_rank) info = database.lock()->get_factory()->get_stiffness()->solve(X, B);
264285

265-
int Integrator::solve(mat& X, const sp_mat& B) { return database.lock()->get_factory()->get_stiffness()->solve(X, B); }
286+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
287+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
288+
// ReSharper disable once CppDFAUnreachableCode
289+
if(0 != comm_rank) X.set_size(B.n_rows, B.n_cols);
290+
bcast_from_root(X);
291+
}
292+
293+
return info;
294+
}
266295

267-
int Integrator::solve(mat& X, mat&& B) { return database.lock()->get_factory()->get_stiffness()->solve(X, std::move(B)); }
296+
int Integrator::solve(mat& X, mat&& B) {
297+
const auto n_rows = B.n_rows, n_cols = B.n_cols;
298+
299+
int info{0};
300+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
301+
if(0 == comm_rank) info = database.lock()->get_factory()->get_stiffness()->solve(X, std::move(B));
302+
303+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
304+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
305+
// ReSharper disable once CppDFAUnreachableCode
306+
if(0 != comm_rank) X.set_size(n_rows, n_cols);
307+
bcast_from_root(X);
308+
}
309+
310+
return info;
311+
}
312+
313+
int Integrator::solve(mat& X, sp_mat&& B) {
314+
const auto n_rows = B.n_rows, n_cols = B.n_cols;
315+
316+
int info{0};
317+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
318+
if(0 == comm_rank) info = database.lock()->get_factory()->get_stiffness()->solve(X, std::move(B));
319+
320+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
321+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
322+
// ReSharper disable once CppDFAUnreachableCode
323+
if(0 != comm_rank) X.set_size(n_rows, n_cols);
324+
bcast_from_root(X);
325+
}
268326

269-
int Integrator::solve(mat& X, sp_mat&& B) { return database.lock()->get_factory()->get_stiffness()->solve(X, std::move(B)); }
327+
return info;
328+
}
270329

271330
/**
272331
* Avoid machine error accumulation.
@@ -363,13 +422,69 @@ void ExplicitIntegrator::update_from_ninja() {
363422
W->update_trial_acceleration_by(W->get_ninja());
364423
}
365424

366-
int ExplicitIntegrator::solve(mat& X, const mat& B) { return get_domain()->get_factory()->get_mass()->solve(X, B); }
425+
int ExplicitIntegrator::solve(mat& X, const mat& B) {
426+
int info{0};
427+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
428+
if(0 == comm_rank) info = get_domain()->get_factory()->get_mass()->solve(X, B);
429+
430+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
431+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
432+
// ReSharper disable once CppDFAUnreachableCode
433+
if(0 != comm_rank) X.set_size(B.n_rows, B.n_cols);
434+
bcast_from_root(X);
435+
}
436+
437+
return info;
438+
}
439+
440+
int ExplicitIntegrator::solve(mat& X, const sp_mat& B) {
441+
int info{0};
442+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
443+
if(0 == comm_rank) info = get_domain()->get_factory()->get_mass()->solve(X, B);
444+
445+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
446+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
447+
// ReSharper disable once CppDFAUnreachableCode
448+
if(0 != comm_rank) X.set_size(B.n_rows, B.n_cols);
449+
bcast_from_root(X);
450+
}
451+
452+
return info;
453+
}
367454

368-
int ExplicitIntegrator::solve(mat& X, const sp_mat& B) { return get_domain()->get_factory()->get_mass()->solve(X, B); }
455+
int ExplicitIntegrator::solve(mat& X, mat&& B) {
456+
const auto n_rows = B.n_rows, n_cols = B.n_cols;
369457

370-
int ExplicitIntegrator::solve(mat& X, mat&& B) { return get_domain()->get_factory()->get_mass()->solve(X, std::move(B)); }
458+
int info{0};
459+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
460+
if(0 == comm_rank) info = get_domain()->get_factory()->get_mass()->solve(X, std::move(B));
371461

372-
int ExplicitIntegrator::solve(mat& X, sp_mat&& B) { return get_domain()->get_factory()->get_mass()->solve(X, std::move(B)); }
462+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
463+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
464+
// ReSharper disable once CppDFAUnreachableCode
465+
if(0 != comm_rank) X.set_size(n_rows, n_cols);
466+
bcast_from_root(X);
467+
}
468+
469+
return info;
470+
}
471+
472+
int ExplicitIntegrator::solve(mat& X, sp_mat&& B) {
473+
const auto n_rows = B.n_rows, n_cols = B.n_cols;
474+
475+
int info{0};
476+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
477+
if(0 == comm_rank) info = get_domain()->get_factory()->get_mass()->solve(X, std::move(B));
478+
479+
if(SUANPAN_SUCCESS == bcast_from_root(info)) {
480+
// ReSharper disable once CppIfCanBeReplacedByConstexprIf
481+
// ReSharper disable once CppDFAUnreachableCode
482+
if(0 != comm_rank) X.set_size(n_rows, n_cols);
483+
bcast_from_root(X);
484+
}
485+
486+
return info;
487+
}
373488

374489
vec ExplicitIntegrator::from_incre_velocity(const vec&, const uvec&) { throw invalid_argument("support velocity cannot be used with explicit integrator"); }
375490

0 commit comments

Comments
 (0)