2
2
#include < faabric/mpi/MpiMessage.h>
3
3
#include < faabric/mpi/MpiWorld.h>
4
4
#include < faabric/planner/PlannerClient.h>
5
+ #include < faabric/transport/PointToPointMessage.h>
5
6
#include < faabric/transport/macros.h>
6
7
#include < faabric/util/ExecGraph.h>
7
8
#include < faabric/util/batch.h>
@@ -59,14 +60,16 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
59
60
serializeMpiMsg (serialisedBuffer, msg);
60
61
61
62
try {
62
- broker.sendMessage (
63
- thisRankMsg->groupid (),
64
- sendRank,
65
- recvRank,
66
- reinterpret_cast <const uint8_t *>(serialisedBuffer.data ()),
67
- serialisedBuffer.size (),
68
- dstHost,
69
- true );
63
+ // It is safe to send a pointer to a stack-allocated object
64
+ // because the broker will make an additional copy (and so will NNG!)
65
+ faabric::transport::PointToPointMessage msg (
66
+ { .groupId = thisRankMsg->groupid (),
67
+ .sendIdx = sendRank,
68
+ .recvIdx = recvRank,
69
+ .dataSize = serialisedBuffer.size (),
70
+ .dataPtr = (void *)serialisedBuffer.data () });
71
+
72
+ broker.sendMessage (msg, dstHost, true );
70
73
} catch (std::runtime_error& e) {
71
74
SPDLOG_ERROR (" {}:{}:{} Timed out with: MPI - send {} -> {}" ,
72
75
thisRankMsg->appid (),
@@ -80,10 +83,12 @@ void MpiWorld::sendRemoteMpiMessage(std::string dstHost,
80
83
81
84
MpiMessage MpiWorld::recvRemoteMpiMessage (int sendRank, int recvRank)
82
85
{
83
- std::vector<uint8_t > msg;
86
+ faabric::transport::PointToPointMessage msg (
87
+ { .groupId = thisRankMsg->groupid (),
88
+ .sendIdx = sendRank,
89
+ .recvIdx = recvRank });
84
90
try {
85
- msg =
86
- broker.recvMessage (thisRankMsg->groupid (), sendRank, recvRank, true );
91
+ broker.recvMessage (msg, true );
87
92
} catch (std::runtime_error& e) {
88
93
SPDLOG_ERROR (" {}:{}:{} Timed out with: MPI - recv (remote) {} -> {}" ,
89
94
thisRankMsg->appid (),
@@ -96,7 +101,8 @@ MpiMessage MpiWorld::recvRemoteMpiMessage(int sendRank, int recvRank)
96
101
97
102
// TODO(mpi-opt): make sure we minimze copies here
98
103
MpiMessage parsedMsg;
99
- parseMpiMsg (msg, &parsedMsg);
104
+ std::vector<uint8_t > msgBytes ((uint8_t *) msg.dataPtr , (uint8_t *) msg.dataPtr + msg.dataSize );
105
+ parseMpiMsg (msgBytes, &parsedMsg);
100
106
101
107
return parsedMsg;
102
108
}
0 commit comments