File tree 3 files changed +19
-12
lines changed
torch/csrc/jit/mobile/model_tracer
3 files changed +19
-12
lines changed Original file line number Diff line number Diff line change @@ -15,9 +15,8 @@ KernelDTypeTracer::KernelDTypeTracer() {
15
15
std::string kernel_tag = name.substr (0 , dollar_pos);
16
16
std::string dtype = name.substr (dollar_pos + 1 );
17
17
18
- getCalledKernelTags ().withLock ([&](kernel_tags_type& kernel_tags) {
19
- kernel_tags[kernel_tag].insert (dtype);
20
- });
18
+ std::lock_guard<std::mutex> guard (getMutex ());
19
+ getCalledKernelTags ()[kernel_tag].insert (dtype);
21
20
return nullptr ;
22
21
};
23
22
@@ -26,12 +25,16 @@ KernelDTypeTracer::KernelDTypeTracer() {
26
25
.scopes ({at::RecordScope::KERNEL_FUNCTION_DTYPE}));
27
26
}
28
27
29
- c10::Synchronized<KernelDTypeTracer::kernel_tags_type>& KernelDTypeTracer::
30
- getCalledKernelTags () {
31
- static c10::Synchronized<kernel_tags_type> called_kernel_tags;
28
+ KernelDTypeTracer::kernel_tags_type& KernelDTypeTracer::getCalledKernelTags () {
29
+ static kernel_tags_type called_kernel_tags;
32
30
return called_kernel_tags;
33
31
}
34
32
33
+ std::mutex& KernelDTypeTracer::getMutex () {
34
+ static std::mutex m;
35
+ return m;
36
+ }
37
+
35
38
} // namespace mobile
36
39
} // namespace jit
37
40
} // namespace torch
Original file line number Diff line number Diff line change 1
1
#pragma once
2
2
3
3
#include < ATen/record_function.h>
4
- #include < c10/util/Synchronized.h>
5
4
#include < map>
5
+ #include < mutex>
6
6
#include < set>
7
7
#include < string>
8
8
@@ -30,7 +30,9 @@ struct KernelDTypeTracer final {
30
30
typedef std::map<std::string, std::set<std::string>> kernel_tags_type;
31
31
32
32
KernelDTypeTracer ();
33
- static c10::Synchronized<kernel_tags_type>& getCalledKernelTags ();
33
+ static kernel_tags_type& getCalledKernelTags ();
34
+ /* Protect concurrent writes into the map. */
35
+ static std::mutex& getMutex ();
34
36
35
37
~KernelDTypeTracer () {
36
38
at::removeCallback (handle_);
Original file line number Diff line number Diff line change @@ -295,10 +295,12 @@ TracerResult trace_run(const std::string& input_module_path) {
295
295
296
296
recordCustomClassesFromOpSchemas (root_ops, traced_operators, loaded_classes);
297
297
298
- kdtype_tracer.getCalledKernelTags ().withLock (
299
- [&](KernelDTypeTracer::kernel_tags_type& kernel_tags) {
300
- called_kernel_tags.insert (kernel_tags.begin (), kernel_tags.end ());
301
- });
298
+ {
299
+ std::lock_guard<std::mutex> guard (KernelDTypeTracer::getMutex ());
300
+ called_kernel_tags.insert (
301
+ kdtype_tracer.getCalledKernelTags ().begin (),
302
+ kdtype_tracer.getCalledKernelTags ().end ());
303
+ }
302
304
303
305
traced_operators.insert (
304
306
always_included_traced_ops.begin (), always_included_traced_ops.end ());
You can’t perform that action at this time.
0 commit comments