forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathengine.h
393 lines (327 loc) · 14.6 KB
/
engine.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
#pragma once
// Engine implements backpropagation from output variables and their gradients
// to "root" variables (variables created by the user with requires_grad=True).
#include <ATen/Tensor.h>
#include <ATen/ThreadLocalState.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/input_buffer.h>
#include <torch/csrc/utils/future.h>

#include <atomic>
#include <condition_variable>
#include <cstddef>
#include <cstdint>
#include <deque>
#include <exception>
#include <functional>
#include <memory>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
namespace torch { namespace autograd {
// Forward declaration: GraphTask (below) stores a std::shared_ptr<ReadyQueue>
// before the full ReadyQueue definition appears later in this header.
struct ReadyQueue;
}} // namespace torch::autograd
namespace torch { namespace autograd {
// Future type used to report the variable_list result of a graph task
// (see GraphTask::future_result_).
using FutureVariableList = torch::utils::Future<variable_list>;
// Sentinel device indices. GraphTask::owner_ is initialized to NO_DEVICE.
// NOTE(review): CPU_DEVICE presumably identifies the CPU worker, with
// non-negative values being accelerator device indices -- confirm against
// the engine implementation.
static constexpr int NO_DEVICE = -2;
static constexpr int CPU_DEVICE = -1;
// Maximum reentrant backward depth before switching to a new thread.
// This limit is based on TSAN's deadlock detector, which fails if a program
// holds more than 65 locks in one thread at once. Because we hold a mutex in
// each of our custom C++ autograd Nodes, we stay below that limit to avoid
// TSAN complaints when doing reentrant backwards.
// For reference, see https://github.com/google/sanitizers/issues/950
static constexpr int MAX_DEPTH = 60;
// NOTE(review): presumably makes `device` the current device for the calling
// worker thread; the definition is not in this header -- confirm.
void set_device(int device);
// Validates `grads` against `edges`. `grads` is taken by non-const reference,
// so the implementation may modify it in place. `format_error` receives a
// message string and returns the string that should be reported.
// NOTE(review): the exact checks performed live in the implementation, which
// is not visible here -- confirm before relying on specifics.
void validate_outputs(
const edge_list& edges,
variable_list& grads,
const std::function<std::string(const std::string&)>& format_error);
// GraphTask holds metadata needed for a single execution of backward()
struct GraphTask: std::enable_shared_from_this<GraphTask> {
std::atomic<uint64_t> outstanding_tasks_{0};
// Indicates if an error occurred while executing any task. When this is
// true, it signals all threads to stop executing.
std::atomic_bool has_error_{false};
std::atomic_bool future_completed_{false};
// It is safe to read grad_mode_ and keep_graph_ without synchronization
bool keep_graph_;
bool grad_mode_;
// To protect reads/writes to not_ready_, dependencies_, captured_vars_,
// has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
std::mutex mutex_;
std::unordered_map<Node*, InputBuffer> not_ready_;
std::unordered_map<Node*, int> dependencies_;
struct ExecInfo {
struct Capture {
Capture(const Capture&) = delete;
Capture(Capture&&) = default;
Capture(int input_idx, int output_idx)
: input_idx_(input_idx), output_idx_(output_idx) {}
int input_idx_; // within Node inputs
int output_idx_; // within the output vector of a GraphTask
// This hook will be executed after a grad is captured. The captured
// grad will be replaced by the return value of the hook.
struct GradCaptureHook {
virtual ~GradCaptureHook() = default;
virtual at::Tensor operator()(const at::Tensor& grad) = 0;
};
// The hooks will be called one by one in the order as they were added.
// The input grad of a hook will be the output of its preceding hook. The
// first hook will take the captured grad as the input. The output of the
// last hook will replace the captured grad.
std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
};
bool should_execute() const {
return needed_ || captures_;
}
bool needed_ = false;
std::unique_ptr<std::vector<Capture>> captures_;
};
// Exec info has a bit complicated semantics. If it's empty, it means the task
// is run in a "default" mode, which means that all next_edges we encounter
// should get executed. If it's not empty, only functions that have an entry
// and this entry has needed == True should be executed. exec_info_.empty()
// means it's .backward(), otherwise it's .grad(). exec_info_ is safe to read
// without synchronization
std::unordered_map<Node*, ExecInfo> exec_info_;
// Captures variables are grads captured that we return to the user. After
// execution of the GraphTask is completed, the captured_vars_ are moved
// out of the GraphTask and are no longer valid.
std::vector<Variable> captured_vars_;
at::ThreadLocalState thread_locals_ =
at::ThreadLocalState(/* keep_grad_mode */ false);
std::unordered_set<c10::Stream> leaf_streams;
void init_to_execute(Node& graph_root, const edge_list& outputs);
// The value of worker_device in the thread that created this task.
// See Note [Reentrant backwards]
// Safe to read owner_ and reentrant_depth_ without synchronizaton
int owner_;
// The number of parent graph tasks for this graph task
const int reentrant_depth_;
bool can_checkpoint() {
return exec_info_.empty();
}
// check if the GraphTask is completed or not
bool completed();
// mark the graph task as completed and trigger post processing
void mark_as_completed_and_run_post_processing();
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function.
void set_exception(std::exception& e, const std::shared_ptr<Node>& fn);
// Set an appropriate exception on this graph_task which was encountered while
// running the provided function. But doesn't signal completion on
// 'future_result_' right away. The user needs to explicitly mark
// 'future_result_' completed with an appropriate exception.
void set_exception_without_signal(const std::shared_ptr<Node>& fn);
// Whether or not to stop execution for this GraphTask when an error is
// encountered. When set to true, this would cause Engine::execute() to throw
// an exception as soon as the autograd engine receives an exception.
bool exit_on_error_;
// CPU threads are dedicated to processing CPU work for the backward they invoked.
// So any given graph task maintains its own cpu_ready_queue_ where you should send
// work for it to be done. We memorize the cpu_ready_queue_ per GraphTask so that
// we know which ready queue we should push to if we are on device thread (i.e. GPU)
// and but next NodeTask should be run on CPU.
std::shared_ptr<ReadyQueue> cpu_ready_queue_;
// Future representing the completion of the graph task. Notified when all
// tasks are done.
std::shared_ptr<FutureVariableList> future_result_;
// Final callbacks installed during execution of this GraphTask
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_. Intentionally no reusing
// mutex_ as the two are protecting different data structures.
std::mutex final_callbacks_lock_;
GraphTask(
bool keep_graph,
bool grad_mode,
int reentrant_depth,
std::shared_ptr<ReadyQueue> cpu_ready_queue,
bool exit_on_error = false)
: keep_graph_(keep_graph),
grad_mode_(grad_mode),
owner_(NO_DEVICE),
reentrant_depth_(reentrant_depth),
exit_on_error_(exit_on_error),
cpu_ready_queue_(std::move(cpu_ready_queue)),
future_result_(std::make_shared<FutureVariableList>()) {}
private:
// run GraphTask post processing
void exec_post_processing();
};
// The guard that sets and restores current_graph_task.
// NOTE(review): the implementation is not in this header. Presumably the
// constructor installs `graph_task` as the current graph task and remembers
// the previous one in last_graph_task_, and the destructor restores it --
// confirm in the .cpp file.
class GraphTaskGuard {
public:
explicit GraphTaskGuard(std::shared_ptr<GraphTask> graph_task);
~GraphTaskGuard();
// Restores the current graph task explicitly.
// NOTE(review): presumably the same action the destructor performs -- confirm.
void restore_current_graph_task();
private:
// Graph task to restore (see the note on the class).
std::shared_ptr<GraphTask> last_graph_task_;
};
// A unit of work for the engine: a function (fn_) to run with its buffered
// inputs (inputs_), on behalf of the GraphTask referenced by base_.
struct NodeTask {
  // The GraphTask this work belongs to. Held weakly, so a NodeTask does not
  // keep its GraphTask alive.
  std::weak_ptr<GraphTask> base_;
  // The function to run. The ReadyQueue comparator handles a null fn_
  // ("empty" NodeTask) explicitly.
  std::shared_ptr<Node> fn_;
  // This buffer serves as an implicit "addition" node for all of the
  // gradients flowing here. Once all the dependencies are finished, we
  // use the contents of this buffer to run the function.
  InputBuffer inputs_;
  // When a worker receives a task with isShutdownTask_ = true, it will
  // immediately exit. The engine sends a shutdown task to every queue upon
  // its destruction.
  bool isShutdownTask_;

  int getReentrantDepth() const;

  // All by-value arguments are sinks and are moved into the members, so the
  // caller pays for at most one copy of each.
  NodeTask(
      std::weak_ptr<GraphTask> base,
      std::shared_ptr<Node> fn,
      InputBuffer inputs,
      bool isShutdownTask = false)
      : base_(std::move(base)),
        fn_(std::move(fn)),
        inputs_(std::move(inputs)),
        isShutdownTask_(isShutdownTask) {}
};
// Priority queue of NodeTasks guarded by a mutex, with a condition variable
// for threads waiting on new work.
struct ReadyQueue {
 private:
  // Returns true when t2 should come out of the queue BEFORE t1.
  // Shutdown tasks are first and then empty NodeTask are next; remaining
  // tasks are ordered by reentrant depth, then by sequence number, larger
  // values first.
  //
  // std::priority_queue requires a strict weak ordering, so comparing a task
  // with itself (or two shutdown tasks with each other) must return false.
  struct CompareNodeTaskTime {
    bool operator()(NodeTask const & t1, NodeTask const & t2) const {
      if (t2.isShutdownTask_) {
        // A shutdown task outranks everything except another shutdown task,
        // which is equivalent to it.
        return !t1.isShutdownTask_;
      } else if (!t1.fn_ || t1.isShutdownTask_) {
        // t1 is a shutdown or empty task and t2 is not a shutdown task:
        // t2 never goes ahead of t1.
        return false;
      } else if (!t2.fn_) {
        // t2 is an empty task, t1 is a regular one.
        return true;
      } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
        return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
      } else {
        return t1.getReentrantDepth() < t2.getReentrantDepth();
      }
    }
  };

  // To notify threads waiting on the ReadyQueue of available tasks on the heap_
  std::condition_variable not_empty_;
  // To protect read and writes to heap_
  mutable std::mutex mutex_;

  std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime> heap_;

 public:
  // incrementOutstandingTasks indicates whether or not we should increment
  // 'outstanding_tasks_' for the associated GraphTask. This should mostly
  // always be true, see the doc for 'enqueue_blocked_task_on_cpu' for when we
  // might set this to false.
  void push(NodeTask item, bool incrementOutstandingTasks = true);
  // Enqueues a shutdown task (see NodeTask::isShutdownTask_).
  void pushShutdownTask();
  // Removes and returns the highest-priority task.
  // NOTE(review): presumably blocks on not_empty_ while the queue is empty --
  // the definition is not in this header; confirm.
  NodeTask pop();
  bool empty() const;
  size_t size() const;
};
// A single instance of this struct should be created through the whole process lifetime.
// The worker thread creation logic and Engine's destructor rely on this.
struct TORCH_API Engine {
/// Returns a reference to a static `Engine` instance.
static Engine& get_default_engine();
// NOTE(review): presumably returns the built-in engine regardless of any stub
// installed through set_default_engine_stub -- confirm in the implementation.
static Engine& get_base_engine();
// Engines can be neither copied nor moved.
Engine(const Engine&) = delete;
Engine(Engine&&) = delete;
virtual ~Engine();
// Given a list of (Node, input number) pairs computes the value of the graph
// by following next_edge references.
// `outputs` defaults to an empty edge list.
virtual variable_list execute(
const edge_list& roots,
const variable_list& inputs,
bool keep_graph,
bool create_graph,
const edge_list& outputs = {});
// Given a pre-populated GraphTask and GraphRoot, computes the backward pass
// for the graph.
//
// NB: This API should only be used by internal autograd specific
// machinery and shouldn't be exposed to users in any way.
virtual std::shared_ptr<FutureVariableList> execute_with_graph_task(
const std::shared_ptr<GraphTask>& graph_task,
std::shared_ptr<Node> graph_root);
// The base implementation returns nullptr; the function is virtual so that
// derived engines can supply their own AnomalyMetadata.
virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
return nullptr;
}
// We pass cpu_ready_queue to evaluate_function, so that it knows
// the correct ready queue to push to after a NodeTask is ready
void evaluate_function(
std::shared_ptr<GraphTask>& graph_task,
Node* func,
InputBuffer& inputs,
const std::shared_ptr<ReadyQueue>& cpu_ready_queue);
// NOTE(review): presumably starts the device worker threads once (see
// start_device_threads_flag_ / start_device_threads below) -- confirm.
void initialize_device_threads_pool();
// NOTE(review): judging by the signature, called when running `fn` on behalf
// of `graph_task` raised `e`; virtual so derived engines can customize the
// handling -- confirm in the implementation.
virtual void thread_on_exception(
std::shared_ptr<GraphTask> graph_task,
const std::shared_ptr<Node>& fn,
std::exception& e);
// NOTE(review): presumably registers `callback` as a final callback (see
// final_callbacks_ here and on GraphTask) -- confirm which list is used.
void queue_callback(std::function<void()> callback);
bool is_checkpoint_valid();
// Size of the ready queue used for `device` in the context of `graph_task`.
// NOTE(review): inferred from the name and signature -- confirm.
size_t ready_queue_size(const std::shared_ptr<GraphTask>& graph_task, at::Device device);
// Should be called after fork to notify that worker threads are gone
void release_workers();
// Initializes a device thread for the autograd engine.
virtual void thread_init(
int device,
const std::shared_ptr<ReadyQueue>& ready_queue,
bool should_increment = true);
protected:
// Protected: instances are obtained through get_default_engine() /
// get_base_engine() or created by derived classes.
Engine();
void compute_dependencies(Node* root, GraphTask& task);
// initialize the thread local ready queue with the ready queue that is created
// elsewhere (i.e. thread_init, Engine::execute, etc), or create a new
// ready queue if ready_queue is not provided.
void init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue = nullptr);
// Select the ready queue to use for `device` / `device_index`, given the
// CPU ready queue of the current graph task.
// NOTE(review): selection rules are in the implementation -- confirm.
std::shared_ptr<ReadyQueue> ready_queue(
std::shared_ptr<ReadyQueue> cpu_ready_queue,
at::Device device);
std::shared_ptr<ReadyQueue> ready_queue_by_index(
std::shared_ptr<ReadyQueue> cpu_ready_queue,
int device_index);
// start device threads (CUDA, XLA, etc.) in Engine,
// note that it does NOT start CPU thread.
void start_device_threads();
// Maintain non_reentrant_device_thread_count_ (declared below).
void increment_non_reentrant_thread_count();
void decrement_non_reentrant_thread_count();
virtual void thread_main(const std::shared_ptr<GraphTask>& task);
void reentrant_thread_init();
void add_thread_pool_task(const std::weak_ptr<GraphTask>& graph_task);
// Ensures device_ready_queues_ are initialized only once
std::once_flag start_device_threads_flag_;
// Safe to read device_ready_queues_ without synchronization after initialization
std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;
std::vector<std::function<void()>> final_callbacks_;
// To protect reads and writes to final_callbacks_
std::mutex post_callbacks_lock_;
// How many nested reentrant calls are allowed until a new thread is used
int max_recursion_depth_;
struct ThreadPoolShared {
// Data structures used by the threads for executing reentrant backwards
// tasks. See Note [Reentrant backwards]
// Number of available threads for processing new GraphTasks.
unsigned int num_workers_;
// The threads will wait on work_ to be notified of GraphTasks
std::condition_variable work_;
// To protect reads and writes to graphtasks_queue_ and num_workers_
// and for synchronizing creating new threads when needed
std::mutex mutex_;
// Workers will process the GraphTasks added to this queue. A GraphTask is
// allocated inside Engine::execute and lives for the duration of execute
std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;
ThreadPoolShared() : num_workers_(0) {}
};
// Temporary workaround until shutting down threads is done
// We need shared ownership of all these objects because the threads are leaked
// when Engine shuts down, so there may be threads waiting on work_
// for the graphtasks_queue_ to be nonempty.
std::shared_ptr<ThreadPoolShared> thread_pool_shared_;
private:
// Number of non-reentrant threads
std::atomic<uint32_t> non_reentrant_device_thread_count_;
// Destructor will wait for non-reentrant threads to finish
std::condition_variable non_reentrant_device_thread_condvar_;
std::mutex non_reentrant_device_thread_mutex_;
};
// allow python_engine to override the default engine when it loads
// EngineStub is a pointer to a function that takes no arguments and returns
// a reference to an Engine.
using EngineStub = Engine& (*)();
// Installs `stub` as the function used to obtain the default engine.
// NOTE(review): presumably consulted by Engine::get_default_engine() --
// confirm in the implementation.
TORCH_API void set_default_engine_stub(EngineStub stub);
}} // namespace torch::autograd