forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
CachedTensorUtils.cpp
49 lines (36 loc) · 1.37 KB
/
CachedTensorUtils.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include <ATen/ATen.h>
#include <ATen/CachedTensorUtils.h>
#include <c10/util/flat_hash_map.h>

#include <atomic>
#include <cstddef>
#include <mutex>
namespace at {
namespace caching {
using weakref_type = c10::weak_intrusive_ptr<TensorImpl, UndefinedTensorImpl>;
bool cached_tensorimpls_enabled = false;
// Like `cached_casts` in autocast_mode, we hash on the TensorImpl*
// and keep the pointer alive with a weakref value.
ska::flat_hash_map<TensorImpl*, weakref_type> cached_tensorimpls;
std::mutex cached_tensorimpl_mutex;
bool is_cached_tensor(const at::Tensor& t) {
if (!cached_tensorimpls_enabled) {
return false;
}
const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
return cached_tensorimpls.count(t.unsafeGetTensorImpl());
}
void add_cached_tensor(const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
cached_tensorimpls.emplace(t.unsafeGetTensorImpl(), weakref_type(t.getIntrusivePtr()));
}
void remove_cached_tensor(const at::Tensor& t) {
TORCH_INTERNAL_ASSERT(cached_tensorimpls_enabled);
const std::lock_guard<std::mutex> lock(cached_tensorimpl_mutex);
cached_tensorimpls.erase(t.unsafeGetTensorImpl());
}
void set_cached_tensors_enabled(bool enabled) {
cached_tensorimpls_enabled = enabled;
}
size_t adjusted_use_count(const at::Tensor& t) {
return t.use_count() - (is_cached_tensor(t) ? 1 : 0);
}
}
}