Skip to content

Commit

Permalink
mpi: inline small messages
Browse files Browse the repository at this point in the history
In this PR we add support for inlining the body of small (for a definition
of small) messages inside the message struct itself. This increases the size
of _all_ messages being moved around, but hopefully spares the need to
malloc/free small messages.
  • Loading branch information
csegarragonz committed Mar 28, 2024
1 parent c3dbe3b commit 294c100
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 47 deletions.
20 changes: 18 additions & 2 deletions include/faabric/mpi/MpiMessage.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
#include <cstdint>
#include <vector>

// Constant copied from OpenMPI's SM implementation. It indicates the maximum
// number of Bytes that we may inline in a message (rather than malloc-ing)
// https://github.com/open-mpi/ompi/blob/main/opal/mca/btl/sm/btl_sm_component.c#L153
#define MPI_MAX_INLINE_SEND 256

namespace faabric::mpi {

enum MpiMessageType : int32_t
Expand Down Expand Up @@ -49,7 +54,11 @@ struct MpiMessage
// struct 8-aligned
int32_t requestId;
MpiMessageType messageType;
void* buffer;
union
{
void* buffer;
uint8_t inlineMsg[MPI_MAX_INLINE_SEND];
};
};
static_assert((sizeof(MpiMessage) % 8) == 0, "MPI message must be 8-aligned!");

Expand All @@ -60,7 +69,14 @@ inline size_t payloadSize(const MpiMessage& msg)

inline size_t msgSize(const MpiMessage& msg)
{
return sizeof(MpiMessage) + payloadSize(msg);
size_t payloadSz = payloadSize(msg);

// If we can inline the message, we do not need to add anything else
if (payloadSz < MPI_MAX_INLINE_SEND) {
return sizeof(MpiMessage);
}

return sizeof(MpiMessage) + payloadSz;
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg);
Expand Down
13 changes: 7 additions & 6 deletions src/mpi/MpiMessage.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,25 @@ void parseMpiMsg(const std::vector<uint8_t>& bytes, MpiMessage* msg)
assert(msg != nullptr);
assert(bytes.size() >= sizeof(MpiMessage));
std::memcpy(msg, bytes.data(), sizeof(MpiMessage));
size_t thisPayloadSize = bytes.size() - sizeof(MpiMessage);
assert(thisPayloadSize == payloadSize(*msg));
size_t thisPayloadSize = payloadSize(*msg);

if (thisPayloadSize == 0) {
msg->buffer = nullptr;
return;
}

msg->buffer = faabric::util::malloc(thisPayloadSize);
std::memcpy(
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
if (thisPayloadSize > MPI_MAX_INLINE_SEND) {
msg->buffer = faabric::util::malloc(thisPayloadSize);
std::memcpy(
msg->buffer, bytes.data() + sizeof(MpiMessage), thisPayloadSize);
}
}

void serializeMpiMsg(std::vector<uint8_t>& buffer, const MpiMessage& msg)
{
std::memcpy(buffer.data(), &msg, sizeof(MpiMessage));
size_t payloadSz = payloadSize(msg);
if (payloadSz > 0 && msg.buffer != nullptr) {
if (payloadSz > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
std::memcpy(buffer.data() + sizeof(MpiMessage), msg.buffer, payloadSz);
}
}
Expand Down
70 changes: 52 additions & 18 deletions src/mpi/MpiWorld.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,8 @@ void MpiWorld::send(int sendRank,
MpiMessageType messageType)
{
// Sanity-check input parameters
// TODO: should we just make these assertions and wait for something else
// to seg-fault down the line?
checkRanksRange(sendRank, recvRank);
if (getHostForRank(sendRank) != thisHost) {
SPDLOG_ERROR("Trying to send message from a non-local rank: {}",
Expand All @@ -609,34 +611,45 @@ void MpiWorld::send(int sendRank,
.recvRank = recvRank,
.typeSize = dataType->size,
.count = count,
.messageType = messageType,
.buffer = nullptr };
.messageType = messageType };

// Mock the message sending in tests
// TODO: can we get rid of this atomic in the hot path?
if (faabric::util::isMockMode()) {
mpiMockedMessages[sendRank].push_back(msg);
return;
}

bool mustSendData = count > 0 && buffer != nullptr;
size_t dataSize = count * dataType->size;
bool mustSendData = dataSize > 0 && buffer != nullptr;

// Dispatch the message locally or globally
if (isLocal) {
// Take control over the buffer data if we are gonna move it to
// the in-memory queues for local messaging
if (mustSendData) {
void* bufferPtr = faabric::util::malloc(count * dataType->size);
std::memcpy(bufferPtr, buffer, count * dataType->size);
if (dataSize < MPI_MAX_INLINE_SEND) {
std::memcpy(msg.inlineMsg, buffer, count * dataType->size);
} else {
void* bufferPtr = faabric::util::malloc(count * dataType->size);
std::memcpy(bufferPtr, buffer, count * dataType->size);

msg.buffer = bufferPtr;
msg.buffer = bufferPtr;
}
} else {
msg.buffer = nullptr;
}

SPDLOG_TRACE(
"MPI - send {} -> {} ({})", sendRank, recvRank, messageType);
getLocalQueue(sendRank, recvRank)->enqueue(msg);
} else {
if (mustSendData) {
msg.buffer = (void*)buffer;
if (dataSize < MPI_MAX_INLINE_SEND) {
std::memcpy(msg.inlineMsg, buffer, count * dataType->size);
} else {
msg.buffer = (void*)buffer;

Check warning on line 651 in src/mpi/MpiWorld.cpp

View check run for this annotation

Codecov / codecov/patch

src/mpi/MpiWorld.cpp#L648-L651

Added lines #L648 - L651 were not covered by tests
}
}

SPDLOG_TRACE(
Expand Down Expand Up @@ -704,17 +717,25 @@ void MpiWorld::doRecv(const MpiMessage& m,
}
assert(m.messageType == messageType);
assert(m.count <= count);
size_t dataSize = m.count * dataType->size;

// We must copy the data into the application-provided buffer
if (m.count > 0 && m.buffer != nullptr) {
if (dataSize > 0) {
// Make sure we do not overflow the recipient buffer
auto bytesToCopy =
std::min<size_t>(m.count * dataType->size, count * dataType->size);
std::memcpy(buffer, m.buffer, bytesToCopy);

// This buffer has been malloc-ed either as part of a local `send`
// or as part of a remote `parseMpiMsg`
faabric::util::free((void*)m.buffer);
if (dataSize > MPI_MAX_INLINE_SEND) {
assert(m.buffer != nullptr);

std::memcpy(buffer, m.buffer, bytesToCopy);

// This buffer has been malloc-ed either as part of a local `send`
// or as part of a remote `parseMpiMsg`
faabric::util::free((void*)m.buffer);
} else {
std::memcpy(buffer, m.inlineMsg, bytesToCopy);
}
}

// Set status values if required
Expand Down Expand Up @@ -1886,21 +1907,34 @@ MpiMessage MpiWorld::recvBatchReturnLast(int sendRank,
// Copy the request id so that it is not overwritten
int tmpRequestId = itr->requestId;

// Copy into current slot in the list, but keep a copy to the
// app-provided buffer to read data into
// Copy the app-provided buffer to recv data into so that it is
// not overwritten too. Note that, irrespective of whether the
// message is inlined or not, we always use the buffer pointer to
// point to the app-provided recv-buffer
void* providedBuffer = itr->buffer;

// Copy into current slot in the list
*itr = getLocalQueue(sendRank, recvRank)->dequeue();
itr->requestId = tmpRequestId;

if (itr->buffer != nullptr) {
// If we have received a non-inlined message, copy the data into the
// provided buffer and free the one in the queue
size_t dataSize = itr->count * itr->typeSize;
if (dataSize > MPI_MAX_INLINE_SEND) {
assert(itr->buffer != nullptr);
assert(providedBuffer != nullptr);
// If buffers are not null, we must have a non-zero size
assert((itr->count * itr->typeSize) > 0);
std::memcpy(
providedBuffer, itr->buffer, itr->count * itr->typeSize);

faabric::util::free(itr->buffer);

itr->buffer = providedBuffer;
} else if (dataSize > 0) {
std::memcpy(
providedBuffer, itr->inlineMsg, itr->count * itr->typeSize);
} else {
itr->buffer = providedBuffer;
}
itr->buffer = providedBuffer;
}
assert(itr->messageType != MpiMessageType::UNACKED_MPI_MESSAGE);

Expand Down
3 changes: 0 additions & 3 deletions tests/dist/mpi/examples/mpi_isendrecv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ int iSendRecv()
}
printf("Rank %i - async working properly\n", rank);

delete sendRequest;
delete recvRequest;

MPI_Finalize();

return 0;
Expand Down
3 changes: 0 additions & 3 deletions tests/dist/mpi/examples/mpi_send_sync_async.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ int sendSyncAsync()
MPI_Send(&r, 1, MPI_INT, r, 0, MPI_COMM_WORLD);
MPI_Wait(&sendRequest, MPI_STATUS_IGNORE);
}
delete sendRequest;
} else {
// Asynchronously receive twice from rank 0
int recvValue1 = -1;
Expand All @@ -47,8 +46,6 @@ int sendSyncAsync()
rank);
return 1;
}
delete recvRequest1;
delete recvRequest2;
}
printf("Rank %i - send sync and async working properly\n", rank);

Expand Down
53 changes: 41 additions & 12 deletions tests/test/mpi/test_mpi_message.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
return false;
}

// First, compare the message body (excluding the pointer, which we
// know is at the end)
if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - sizeof(void*)) != 0) {
// First, compare the message body (excluding the union at the end)
size_t unionSize = sizeof(uint8_t) * MPI_MAX_INLINE_SEND;
if (std::memcmp(&msgA, &msgB, sizeof(MpiMessage) - unionSize) != 0) {
return false;
}

Expand All @@ -35,7 +35,11 @@ bool areMpiMsgEqual(const MpiMessage& msgA, const MpiMessage& msgB)
// Assert, as this should pass given the previous comparisons
assert(payloadSizeA == payloadSizeB);

return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
if (payloadSizeA > MPI_MAX_INLINE_SEND) {
return std::memcmp(msgA.buffer, msgB.buffer, payloadSizeA) == 0;
}

return std::memcmp(msgA.inlineMsg, msgB.inlineMsg, payloadSizeA) == 0;
}

TEST_CASE("Test getting a message size", "[mpi]")
Expand All @@ -59,11 +63,23 @@ TEST_CASE("Test getting a message size", "[mpi]")
expectedPayloadSize = 0;
}

SECTION("Non-empty message")
SECTION("Non-empty (small) message")
{
std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
msg.count = nums.size();
msg.typeSize = sizeof(int);
std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));

expectedPayloadSize = sizeof(int) * nums.size();
expectedMsgSize = sizeof(MpiMessage);
}

SECTION("Non-empty (large) message")
{
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
std::vector<int32_t> nums(maxNumInts + 3, 3);
msg.count = nums.size();
msg.typeSize = sizeof(int);
msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));

Expand All @@ -74,7 +90,7 @@ TEST_CASE("Test getting a message size", "[mpi]")
REQUIRE(expectedMsgSize == msgSize(msg));
REQUIRE(expectedPayloadSize == payloadSize(msg));

if (msg.buffer != nullptr) {
if (expectedPayloadSize > MPI_MAX_INLINE_SEND && msg.buffer != nullptr) {
faabric::util::free(msg.buffer);
}
}
Expand All @@ -95,11 +111,22 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")
msg.buffer = nullptr;
}

SECTION("Non-empty message")
SECTION("Non-empty (small) message")
{
std::vector<int> nums = { 1, 2, 3, 4, 5, 6, 6 };
msg.count = nums.size();
msg.typeSize = sizeof(int);
std::memcpy(msg.inlineMsg, nums.data(), nums.size() * sizeof(int));
}

SECTION("Non-empty (large) message")
{
// Make sure we send more ints than the maximum inline
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
std::vector<int32_t> nums(maxNumInts + 3, 3);
msg.count = nums.size();
msg.typeSize = sizeof(int);
REQUIRE(payloadSize(msg) > MPI_MAX_INLINE_SEND);
msg.buffer = faabric::util::malloc(msg.count * msg.typeSize);
std::memcpy(msg.buffer, nums.data(), nums.size() * sizeof(int));
}
Expand All @@ -113,11 +140,13 @@ TEST_CASE("Test (de)serialising an MPI message", "[mpi]")

REQUIRE(areMpiMsgEqual(msg, parsedMsg));

if (msg.buffer != nullptr) {
faabric::util::free(msg.buffer);
}
if (parsedMsg.buffer != nullptr) {
faabric::util::free(parsedMsg.buffer);
if (msg.count * msg.typeSize > MPI_MAX_INLINE_SEND) {
if (msg.buffer != nullptr) {
faabric::util::free(msg.buffer);
}
if (parsedMsg.buffer != nullptr) {
faabric::util::free(parsedMsg.buffer);
}
}
}
}
31 changes: 28 additions & 3 deletions tests/test/mpi/test_mpi_world.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,11 +239,17 @@ TEST_CASE_METHOD(MpiTestFixture, "Test send and recv on same host", "[mpi]")
int rankA2 = 1;
std::vector<int> messageData;

SECTION("Non-empty message")
SECTION("Non-empty (small) message")
{
messageData = { 0, 1, 2 };
}

SECTION("Non-empty (large) message")
{
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
messageData = std::vector<int>(maxNumInts + 3, 3);
}

SECTION("Empty message")
{
messageData = {};
Expand Down Expand Up @@ -273,8 +279,27 @@ TEST_CASE_METHOD(MpiTestFixture, "Test sendrecv", "[mpi]")
int rankA = 1;
int rankB = 2;
MPI_Status status{};
std::vector<int> messageDataAB = { 0, 1, 2 };
std::vector<int> messageDataBA = { 3, 2, 1, 0 };
std::vector<int> messageDataAB;
std::vector<int> messageDataBA;

SECTION("Empty messages")
{
messageDataAB = {};
messageDataBA = {};
}

SECTION("Small messages")
{
messageDataAB = { 0, 1, 2 };
messageDataBA = { 3, 2, 1, 0 };
}

SECTION("Large messages")
{
int32_t maxNumInts = MPI_MAX_INLINE_SEND / sizeof(int32_t);
messageDataAB = std::vector<int>(maxNumInts + 3, 3);
messageDataBA = std::vector<int>(maxNumInts + 4, 4);
}

// Results
std::vector<int> recvBufferA(messageDataBA.size(), 0);
Expand Down

0 comments on commit 294c100

Please sign in to comment.