diff --git a/include/ygm/comm.hpp b/include/ygm/comm.hpp index 85a34394..8601c5b8 100644 --- a/include/ygm/comm.hpp +++ b/include/ygm/comm.hpp @@ -191,6 +191,10 @@ class comm { void flush_send_buffer(int dest); + void handle_completed_send(mpi_isend_request &req_buffer); + + void check_completed_sends(); + void check_if_production_halt_required(); void flush_all_local_and_process_incoming(); diff --git a/include/ygm/detail/comm.ipp b/include/ygm/detail/comm.ipp index b1fab497..1025c3f5 100644 --- a/include/ygm/detail/comm.ipp +++ b/include/ygm/detail/comm.ipp @@ -525,6 +525,7 @@ inline std::pair comm::barrier_reduce_counts() { inline void comm::flush_send_buffer(int dest) { static size_t counter = 0; if (m_vec_send_buffers[dest].size() > 0) { + check_completed_sends(); mpi_isend_request request; if (m_free_send_buffers.empty()) { request.buffer = std::make_shared(); @@ -552,6 +553,36 @@ inline void comm::flush_send_buffer(int dest) { } } +/** + * @brief Handle a completed send by putting the buffer on the free list or + * allowing it to be freed + */ +inline void comm::handle_completed_send(mpi_isend_request &req_buffer) { + m_pending_isend_bytes -= req_buffer.buffer->size(); + if (m_free_send_buffers.size() < config.send_buffer_free_list_len) { + req_buffer.buffer->clear(); + m_free_send_buffers.push_back(req_buffer.buffer); + } +} + +/** + * @brief Test completed sends + */ +inline void comm::check_completed_sends() { + if (!m_send_queue.empty()) { + int flag(1); + while (flag && not m_send_queue.empty()) { + YGM_ASSERT_MPI( + MPI_Test(&(m_send_queue.front().request), &flag, MPI_STATUS_IGNORE)); + stats.isend_test(); + if (flag) { + handle_completed_send(m_send_queue.front()); + m_send_queue.pop_front(); + } + } + } +} + inline void comm::check_if_production_halt_required() { while (m_enable_interrupts && !m_in_process_receive_queue && m_pending_isend_bytes > config.buffer_size) { @@ -947,9 +978,7 @@ inline bool comm::process_receive_queue() { } for (int i = 0; i < outcount; ++i) { if (twin_indices[i] == 0) { // completed a iSend - m_pending_isend_bytes -= m_send_queue.front().buffer->size(); - m_send_queue.front().buffer->clear(); - m_free_send_buffers.push_back(m_send_queue.front().buffer); + handle_completed_send(m_send_queue.front()); m_send_queue.pop_front(); } else { // completed an iRecv -- COPIED FROM BELOW received_to_return = true; @@ -962,18 +991,7 @@ inline bool comm::process_receive_queue() { } } } else { - if (!m_send_queue.empty()) { - int flag(0); - YGM_ASSERT_MPI( - MPI_Test(&(m_send_queue.front().request), &flag, MPI_STATUS_IGNORE)); - stats.isend_test(); - if (flag) { - m_pending_isend_bytes -= m_send_queue.front().buffer->size(); - m_send_queue.front().buffer->clear(); - m_free_send_buffers.push_back(m_send_queue.front().buffer); - m_send_queue.pop_front(); - } - } + check_completed_sends(); } received_to_return |= local_process_incoming(); diff --git a/include/ygm/detail/comm_environment.hpp b/include/ygm/detail/comm_environment.hpp index 3b833a7d..3c2c0aca 100644 --- a/include/ygm/detail/comm_environment.hpp +++ b/include/ygm/detail/comm_environment.hpp @@ -54,6 +54,9 @@ class comm_environment { if (const char* cc = std::getenv("YGM_COMM_ISSEND_FREQ")) { freq_issend = convert(cc); } + if (const char* cc = std::getenv("YGM_COMM_SEND_BUFFER_FREE_LIST_LEN")) { + send_buffer_free_list_len = convert(cc); + } if (const char* cc = std::getenv("YGM_COMM_ROUTING")) { if (std::string(cc) == "NONE") { routing = routing_type::NONE; @@ -96,8 +99,9 @@ class comm_environment { size_t irecv_size = 1024 * 1024 * 1024; size_t num_irecvs = 8; - size_t num_isends_wait = 4; - size_t freq_issend = 8; + size_t num_isends_wait = 4; + size_t freq_issend = 8; + size_t send_buffer_free_list_len = 32; routing_type routing = routing_type::NONE;