Skip to content

Commit

Permalink
make CollectTrace for profiling by iteration async (#966)
Browse files Browse the repository at this point in the history
Summary:
This fix issue #953. Makes `libkineto::api().client()->stop()` and `stopTraceInternal` run in `profilerThread_`  so that, the training  process will not be blocked.

Pull Request resolved: #966

Reviewed By: sanrise

Differential Revision: D64214259

Pulled By: sraikund16

fbshipit-source-id: a27398e266df8502579d49b2c7d65c1863788008
  • Loading branch information
staugust authored and facebook-github-bot committed Oct 15, 2024
1 parent ed052ea commit 7a2a167
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 19 deletions.
71 changes: 54 additions & 17 deletions libkineto/src/CuptiActivityProfiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ std::ostream& operator<<(
return oss;
}

CuptiActivityProfiler::~CuptiActivityProfiler() {
if (collectTraceThread_ && collectTraceThread_->joinable()) {
collectTraceThread_->join();
}
}

void CuptiActivityProfiler::transferCpuTrace(
std::unique_ptr<libkineto::CpuTraceBuffer> cpuTrace) {
std::lock_guard<std::recursive_mutex> guard(mutex_);
Expand Down Expand Up @@ -1120,6 +1126,33 @@ void CuptiActivityProfiler::configure(
currentRunloopState_ = RunloopState::Warmup;
}

void CuptiActivityProfiler::collectTrace(
bool collection_done,
const std::chrono::time_point<std::chrono::system_clock>& now) {
if (libkineto::api().client()) {
libkineto::api().client()->stop();
}

#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER)
if (cupti_.stopCollection) {
ecs_.cupti_stopped_early = cupti_.stopCollection;
LOG(ERROR)
<< "State: CollectTrace stopped by CUPTI. (Buffer size configured is "
<< config_->activitiesMaxGpuBufferSize() / 1024 / 1024 << "MB)";
}
#endif // HAS_CUPTI || HAS_ROCTRACER
std::lock_guard<std::recursive_mutex> guard(mutex_);
stopTraceInternal(now);
VLOG_IF(0, collection_done) << "Reached profile end time";
UST_LOGGER_MARK_COMPLETED(kCollectionStage);
}

void CuptiActivityProfiler::ensureCollectTraceDone() {
if (collectTraceThread_ && collectTraceThread_->joinable()) {
collectTraceThread_->join();
collectTraceThread_.reset(nullptr);
}
}
void CuptiActivityProfiler::toggleCollectionDynamic(const bool enable) {
#ifdef HAS_CUPTI
if (enable) {
Expand Down Expand Up @@ -1266,26 +1299,26 @@ const time_point<system_clock> CuptiActivityProfiler::performRunLoopStep(
) {
// Update runloop state first to prevent further updates to shared state
LOG(INFO) << "Tracing complete.";
VLOG_IF(1, currentIter > 0)
VLOG_IF(1, currentIter >= 0)
<< "This state change was invoked by application's step() call";

if (libkineto::api().client()) {
libkineto::api().client()->stop();
}

#if defined(HAS_CUPTI) || defined(HAS_ROCTRACER)
if (cupti_.stopCollection) {
ecs_.cupti_stopped_early = cupti_.stopCollection;
LOG(ERROR)
<< "State: CollectTrace stopped by CUPTI. (Buffer size configured is "
<< config_->activitiesMaxGpuBufferSize() / 1024 / 1024 << "MB)";
// currentIter >= 0 means this is an iteration-based collection,
// triggered by pytorch main thread, it should be executed in another
// thread in case pytorch main thread is blocked
if (currentIter >= 0) {
// if collectTraceThread_ is already running, there's no need to
// execute collectTrace twice.
if (!collectTraceThread_) {
std::lock_guard<std::recursive_mutex> guard(mutex_);
collectTraceThread_ = std::make_unique<std::thread>(
&CuptiActivityProfiler::collectTrace,
this,
collection_done,
now);
}
break;
}
#endif // HAS_CUPTI || HAS_ROCTRACER

std::lock_guard<std::recursive_mutex> guard(mutex_);
stopTraceInternal(now);
VLOG_IF(0, collection_done) << "Reached profile end time";
UST_LOGGER_MARK_COMPLETED(kCollectionStage);
collectTrace(collection_done, now);
} else if (derivedConfig_->isProfilingByIteration()) {
// nothing to do here
} else if (
Expand All @@ -1305,6 +1338,10 @@ const time_point<system_clock> CuptiActivityProfiler::performRunLoopStep(
if (currentIter >= 0) {
return new_wakeup_time;
}

// Before processing, we should wait for collectTrace thread to be done.
ensureCollectTraceDone();

// FIXME: Probably want to allow interruption here
// for quickly handling trace request via synchronous API
std::lock_guard<std::recursive_mutex> guard(mutex_);
Expand Down
14 changes: 13 additions & 1 deletion libkineto/src/CuptiActivityProfiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class CuptiActivityProfiler {
CuptiActivityProfiler(RoctracerActivityApi& rai, bool cpuOnly);
CuptiActivityProfiler(const CuptiActivityProfiler&) = delete;
CuptiActivityProfiler& operator=(const CuptiActivityProfiler&) = delete;

~CuptiActivityProfiler();
bool isActive() const {
return currentRunloopState_ != RunloopState::WaitForRequest;
}
Expand Down Expand Up @@ -170,6 +170,13 @@ class CuptiActivityProfiler {
stopTraceInternal(now);
}

// Collect CPU and GPU traces
void collectTrace(
bool collectionDone,
const std::chrono::time_point<std::chrono::system_clock>& now);

// Ensure collectTrace is done
void ensureCollectTraceDone();
// Process CPU and GPU traces
void processTrace(ActivityLogger& logger) {
std::lock_guard<std::recursive_mutex> guard(mutex_);
Expand Down Expand Up @@ -483,6 +490,11 @@ class CuptiActivityProfiler {
// Mutex to protect non-atomic access to below state
std::recursive_mutex mutex_;

// Add a thread to collect both cpu and gpu traces in case torch main thread
// is blocked when profiling by iterations is enabled. Issue #953 shows
// details.
std::unique_ptr<std::thread> collectTraceThread_{nullptr};

// Runloop phase
std::atomic<RunloopState> currentRunloopState_{RunloopState::WaitForRequest};

Expand Down
2 changes: 1 addition & 1 deletion libkineto/test/CuptiActivityProfilerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,8 @@ TEST(CuptiActivityProfiler, AsyncTraceUsingIter) {
EXPECT_TRUE(profiler.isActive());

auto nextnext = next + milliseconds(1000);

profiler.performRunLoopStep(nextnext, nextnext);
profiler.ensureCollectTraceDone();
profiler.performRunLoopStep(nextnext, nextnext);

// Assert that tracing has completed
Expand Down

0 comments on commit 7a2a167

Please sign in to comment.