Skip to content

Commit f12703c

Browse files
suopytorchmergebot
authored andcommitted
Revert D34604068: [PyTorch] [Model Tracer] Use c10::Synchronized<T> for kernel dtype tracer
Test Plan: revert-hammer Differential Revision: D34604068 (pytorch@6fd6fe0) Original commit changeset: 1ec50ada8112 Original Phabricator Diff: D34604068 (pytorch@6fd6fe0) fbshipit-source-id: 8b80bfd947c96108306e4472505c1af62c4fe8cb (cherry picked from commit 62d6d29)
1 parent 689cd22 commit f12703c

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.cpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@ KernelDTypeTracer::KernelDTypeTracer() {
1515
std::string kernel_tag = name.substr(0, dollar_pos);
1616
std::string dtype = name.substr(dollar_pos + 1);
1717

18-
getCalledKernelTags().withLock([&](kernel_tags_type& kernel_tags) {
19-
kernel_tags[kernel_tag].insert(dtype);
20-
});
18+
std::lock_guard<std::mutex> guard(getMutex());
19+
getCalledKernelTags()[kernel_tag].insert(dtype);
2120
return nullptr;
2221
};
2322

@@ -26,12 +25,16 @@ KernelDTypeTracer::KernelDTypeTracer() {
2625
.scopes({at::RecordScope::KERNEL_FUNCTION_DTYPE}));
2726
}
2827

29-
c10::Synchronized<KernelDTypeTracer::kernel_tags_type>& KernelDTypeTracer::
30-
getCalledKernelTags() {
31-
static c10::Synchronized<kernel_tags_type> called_kernel_tags;
28+
KernelDTypeTracer::kernel_tags_type& KernelDTypeTracer::getCalledKernelTags() {
29+
static kernel_tags_type called_kernel_tags;
3230
return called_kernel_tags;
3331
}
3432

33+
std::mutex& KernelDTypeTracer::getMutex() {
34+
static std::mutex m;
35+
return m;
36+
}
37+
3538
} // namespace mobile
3639
} // namespace jit
3740
} // namespace torch

torch/csrc/jit/mobile/model_tracer/KernelDTypeTracer.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#pragma once
22

33
#include <ATen/record_function.h>
4-
#include <c10/util/Synchronized.h>
54
#include <map>
5+
#include <mutex>
66
#include <set>
77
#include <string>
88

@@ -30,7 +30,9 @@ struct KernelDTypeTracer final {
3030
typedef std::map<std::string, std::set<std::string>> kernel_tags_type;
3131

3232
KernelDTypeTracer();
33-
static c10::Synchronized<kernel_tags_type>& getCalledKernelTags();
33+
static kernel_tags_type& getCalledKernelTags();
34+
/* Protect concurrent writes into the map. */
35+
static std::mutex& getMutex();
3436

3537
~KernelDTypeTracer() {
3638
at::removeCallback(handle_);

torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -295,10 +295,12 @@ TracerResult trace_run(const std::string& input_module_path) {
295295

296296
recordCustomClassesFromOpSchemas(root_ops, traced_operators, loaded_classes);
297297

298-
kdtype_tracer.getCalledKernelTags().withLock(
299-
[&](KernelDTypeTracer::kernel_tags_type& kernel_tags) {
300-
called_kernel_tags.insert(kernel_tags.begin(), kernel_tags.end());
301-
});
298+
{
299+
std::lock_guard<std::mutex> guard(KernelDTypeTracer::getMutex());
300+
called_kernel_tags.insert(
301+
kdtype_tracer.getCalledKernelTags().begin(),
302+
kdtype_tracer.getCalledKernelTags().end());
303+
}
302304

303305
traced_operators.insert(
304306
always_included_traced_ops.begin(), always_included_traced_ops.end());

0 commit comments

Comments
 (0)