Skip to content

Commit

Permalink
Separate ucxx::DelayedSubmissionCollection implementations (rapidsa…
Browse files Browse the repository at this point in the history
…i#89)

Separate `ucxx::DelayedSubmissionCollection` implementations for requests and generic callbacks, thus reducing some code duplication and making collection logic more self-contained.

This may help tackling issues that are only reproducible in CentOS 7.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Lawrence Mitchell (https://github.com/wence-)

URL: rapidsai#89
  • Loading branch information
pentschev authored Oct 12, 2023
1 parent 25daeea commit 8e8b764
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 80 deletions.
143 changes: 136 additions & 7 deletions cpp/include/ucxx/delayed_submission.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <functional>
#include <memory>
#include <mutex>
#include <string_view>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -56,15 +57,139 @@ class DelayedSubmission {
const ucs_memory_type_t memoryType = UCS_MEMORY_TYPE_UNKNOWN);
};

template <typename T>
class BaseDelayedSubmissionCollection {
protected:
std::string_view _name{
"undefined"}; ///< The human-readable name of the collection, used for logging
bool _enabled{true}; ///< Whether the resource required to process the collection is enabled.
std::vector<T> _collection{}; ///< The collection.
std::mutex _mutex{}; ///< Mutex to provide access to `_collection`.

/**
* @brief Log message during `schedule()`.
*
* Log a specialized message while `schedule()` is being executed.
*
* @param[in] item the callback that was passed as argument to `schedule()`.
*/
virtual void scheduleLog(T item) = 0;

/**
* @brief Process a single item during `process()`.
*
* Method called by `process()` to process a single item of the collection.
*
* @param[in] item the callback that was passed as argument to `schedule()` when
* the first registered.
*/
virtual void processItem(T item) = 0;

public:
/**
* @brief Constructor for a thread-safe delayed submission collection.
*
* Construct a thread-safe delayed submission collection. A delayed submission collection
* provides two operations: schedule and process. The `schedule()` method will push an
* operation into the collection, whereas the `process()` will invoke all callbacks that
* were previously pushed into the collection and clear the collection.
*
* @param[in] name human-readable name of the collection, used for logging.
*/
explicit BaseDelayedSubmissionCollection(const std::string_view name, const bool enabled)
: _name{name}, _enabled{enabled}
{
}

BaseDelayedSubmissionCollection() = delete;
BaseDelayedSubmissionCollection(const BaseDelayedSubmissionCollection&) = delete;
BaseDelayedSubmissionCollection& operator=(BaseDelayedSubmissionCollection const&) = delete;
BaseDelayedSubmissionCollection(BaseDelayedSubmissionCollection&& o) = delete;
BaseDelayedSubmissionCollection& operator=(BaseDelayedSubmissionCollection&& o) = delete;

/**
* @brief Register a callable or complex-type for delayed submission.
*
* Register a simple callback, or complex-type with a callback (requires specialization),
* for delayed submission that will be executed when the request is in fact submitted when
* `process()` is called.
*
* Raise an exception if `false` was specified as the `enabled` argument to the constructor.
*
* @throws std::runtime_error if `_enabled` is `false`.
*
* @param[in] item the callback that will be executed by `process()` when the
* operation is submitted.
* @param[in] resourceEnabled whether the resource is enabled.
*/
virtual void schedule(T item)
{
if (!_enabled) throw std::runtime_error("Resource is disabled.");

{
std::lock_guard<std::mutex> lock(_mutex);
_collection.push_back(item);
}
scheduleLog(item);
}

/**
* @brief Process all pending callbacks.
*
* Process all pending generic. Generic callbacks are deemed completed when their
* execution completes.
*/
void process()
{
decltype(_collection) itemsToProcess;
{
std::lock_guard<std::mutex> lock(_mutex);
// Move _collection to a local copy in order to to hold the lock for as
// short as possible
itemsToProcess = std::move(_collection);
}

if (itemsToProcess.size() > 0) {
ucxx_trace_req("Submitting %lu %s callbacks", itemsToProcess.size(), _name);
for (auto& item : itemsToProcess)
processItem(item);
}
}
};

class RequestDelayedSubmissionCollection
: public BaseDelayedSubmissionCollection<
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType>> {
protected:
void scheduleLog(
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item) override;

void processItem(
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item) override;

public:
explicit RequestDelayedSubmissionCollection(const std::string_view name, const bool enabled);
};

class GenericDelayedSubmissionCollection
: public BaseDelayedSubmissionCollection<DelayedSubmissionCallbackType> {
protected:
void scheduleLog(DelayedSubmissionCallbackType item) override;

void processItem(DelayedSubmissionCallbackType callback) override;

public:
explicit GenericDelayedSubmissionCollection(const std::string_view name);
};

class DelayedSubmissionCollection {
private:
std::vector<DelayedSubmissionCallbackType>
_genericPre{}; ///< The collection of all known generic pre-progress operations.
std::vector<DelayedSubmissionCallbackType>
_genericPost{}; ///< The collection of all known generic post-progress operations.
std::vector<std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType>>
_requests{}; ///< The collection of all known delayed request submission operations.
std::mutex _mutex{}; ///< Mutex to provide access to the collection.
GenericDelayedSubmissionCollection _genericPre{
"generic pre"}; ///< The collection of all known generic pre-progress operations.
GenericDelayedSubmissionCollection _genericPost{
"generic post"}; ///< The collection of all known generic post-progress operations.
RequestDelayedSubmissionCollection _requests{
"request", false}; ///< The collection of all known delayed request submission operations.
bool _enableDelayedRequestSubmission{false};

public:
Expand Down Expand Up @@ -96,6 +221,10 @@ class DelayedSubmissionCollection {
* operation, only that it has been submitted. The completion of each delayed request
* submission is handled externally by the implementation of the object being processed,
* for example by checking the result of `ucxx::Request::isCompleted()`.
*
* Generic callbacks may be used to to pass information between threads on the subject
* that requests have been in fact processed, therefore, requests are processed first,
* then generic callbacks are.
*/
void processPre();

Expand Down
123 changes: 50 additions & 73 deletions cpp/src/delayed_submission.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,50 @@ DelayedSubmission::DelayedSubmission(const bool send,
{
}

RequestDelayedSubmissionCollection::RequestDelayedSubmissionCollection(const std::string_view name,
const bool enabled)
: BaseDelayedSubmissionCollection<
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType>>{name, enabled}
{
}

void RequestDelayedSubmissionCollection::scheduleLog(
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item)
{
ucxx_trace_req("Registered %s: %p", _name, item.first.get());
}

void RequestDelayedSubmissionCollection::processItem(
std::pair<std::shared_ptr<Request>, DelayedSubmissionCallbackType> item)
{
auto& req = item.first;
auto& callback = item.second;

ucxx_trace_req("Submitting %s callbacks: %p", _name, req.get());

if (callback) callback();
}

GenericDelayedSubmissionCollection::GenericDelayedSubmissionCollection(const std::string_view name)
: BaseDelayedSubmissionCollection<DelayedSubmissionCallbackType>{name, true}
{
}

void GenericDelayedSubmissionCollection::scheduleLog(DelayedSubmissionCallbackType item)
{
ucxx_trace_req("Registered %s", _name);
}

void GenericDelayedSubmissionCollection::processItem(DelayedSubmissionCallbackType callback)
{
ucxx_trace_req("Submitting %s callback", _name);

if (callback) callback();
}

DelayedSubmissionCollection::DelayedSubmissionCollection(bool enableDelayedRequestSubmission)
: _enableDelayedRequestSubmission(enableDelayedRequestSubmission)
: _enableDelayedRequestSubmission(enableDelayedRequestSubmission),
_requests(RequestDelayedSubmissionCollection{"request", enableDelayedRequestSubmission})
{
}

Expand All @@ -34,92 +76,27 @@ bool DelayedSubmissionCollection::isDelayedRequestSubmissionEnabled() const

void DelayedSubmissionCollection::processPre()
{
decltype(_requests) requestsToProcess;
{
std::lock_guard<std::mutex> lock(_mutex);
// Move _requests to a local copy in order to to hold the lock for as
// short as possible
requestsToProcess = std::move(_requests);
}
if (requestsToProcess.size() > 0) {
ucxx_trace_req("Submitting %lu requests", requestsToProcess.size());
for (auto& pair : requestsToProcess) {
auto& req = pair.first;
auto& callback = pair.second;

ucxx_trace_req("Submitting request: %p", req.get());

if (callback) callback();
}
}
decltype(_genericPre) callbacks;
{
std::lock_guard<std::mutex> lock(_mutex);
// Move _genericPre to a local copy in order to to hold the lock for as
// short as possible
callbacks = std::move(_genericPre);
}

if (callbacks.size() > 0) {
ucxx_trace_req("Submitting %lu generic", callbacks.size());

for (auto& callback : callbacks) {
ucxx_trace_req("Submitting generic");

if (callback) callback();
}
}
}
_requests.process();

void DelayedSubmissionCollection::processPost()
{
decltype(_genericPost) callbacks;
{
std::lock_guard<std::mutex> lock(_mutex);
// Move _genericPost to a local copy in order to to hold the lock for as
// short as possible
callbacks = std::move(_genericPost);
}

if (callbacks.size() > 0) {
ucxx_trace_req("Submitting %lu generic", callbacks.size());

for (auto& callback : callbacks) {
ucxx_trace_req("Submitting generic");

if (callback) callback();
}
}
_genericPre.process();
}

void DelayedSubmissionCollection::processPost() { _genericPost.process(); }

void DelayedSubmissionCollection::registerRequest(std::shared_ptr<Request> request,
DelayedSubmissionCallbackType callback)
{
if (!isDelayedRequestSubmissionEnabled()) throw std::runtime_error("Context not initialized");

{
std::lock_guard<std::mutex> lock(_mutex);
_requests.push_back({request, callback});
}
ucxx_trace_req("Registered submit request: %p", request.get());
_requests.schedule({request, callback});
}

void DelayedSubmissionCollection::registerGenericPre(DelayedSubmissionCallbackType callback)
{
{
std::lock_guard<std::mutex> lock(_mutex);
_genericPre.push_back({callback});
}
ucxx_trace_req("Registered generic");
_genericPre.schedule(callback);
}

void DelayedSubmissionCollection::registerGenericPost(DelayedSubmissionCallbackType callback)
{
{
std::lock_guard<std::mutex> lock(_mutex);
_genericPost.push_back({callback});
}
ucxx_trace_req("Registered generic");
_genericPost.schedule(callback);
}

} // namespace ucxx

0 comments on commit 8e8b764

Please sign in to comment.