#pragma once
#include <c10/core/DispatchKey.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/op_registration/infer_schema.h>
#if defined(EXPOSE_C2_OPS) || !defined(CAFFE2_IS_XPLAT_BUILD)
#include <torch/csrc/jit/frontend/function_schema_parser.h>
#endif
// Just for inferFunctionSchemaFromFunctor
#include <ATen/core/op_registration/op_registration.h>
namespace torch {
template <class CurClass>
class class_;
// A quick tour of a few usage examples:
//
// // Define a library whose operators live in the namespace 'aten'.
// // You must define all of the operators for this library in
// // this namespace.
// TORCH_LIBRARY(aten, m) {
// // Define a schema for an operator, but provide no implementation
// m.def("mul(Tensor self, Tensor other) -> Tensor");
//
// // Define an operator with exactly one implementation for all backends.
// m.def("add(Tensor self, Tensor other) -> Tensor", &add_impl);
//
// // Provide an implementation for a defined operator (you can
// // provide multiple; one per backend). We'll take care of calling
// // the correct implementation depending on whether we get a CPU
// // tensor or a CUDA tensor.
// m.impl("mul", torch::kCPU, &mul_cpu_impl);
// m.impl("mul", torch::kCUDA, &mul_cuda_impl);
// }
//
// // Define implementations for operators for a non-standard backend,
// // e.g., XLA (valid values are entries of DispatchKey). These
// // operator names are not namespaced; you can define implementations
// // for any namespace.
// TORCH_LIBRARY_IMPL(aten, XLA, m) {
// m.impl("mul", &mul_xla_impl);
// }
// Represents a C++ function that implements an operator. Most users won't
// interact directly with this class, except via error messages: the
// constructors of this class define the set of permissible "function"-like
// things you can bind via the interface.
//
// This class erases the type of the passed in function, but durably records
// the type via an inferred schema for the function.
//
// TODO: This is morally the same thing as KernelRegistrationConfig, but it's
// opaque to the user.
class CAFFE2_API CppFunction final {
public:
// This overload accepts function pointers, e.g., CppFunction(&add_impl)
template <typename Func>
explicit CppFunction(Func* f, std::enable_if_t<c10::guts::is_function_type<Func>::value, std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f))
, cpp_signature_(c10::impl::CppSignature::make<Func>())
// TODO: Don't go through WrapRuntimeKernelFunctor
, schema_(c10::detail::inferFunctionSchemaFromFunctor<c10::impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Func>>>())
, debug_()
{}
// This overload accepts compile time function pointers, e.g., CppFunction(TORCH_FN(add_impl))
template <typename FuncPtr>
explicit CppFunction(FuncPtr f, std::enable_if_t<c10::is_compile_time_function_pointer<FuncPtr>::value, std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedRuntimeFunction(f.func_ptr()))
, cpp_signature_(c10::impl::CppSignature::make<typename FuncPtr::FuncType>())
// TODO: Don't go through WrapRuntimeKernelFunctor
, schema_(c10::detail::inferFunctionSchemaFromFunctor<c10::impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<typename FuncPtr::FuncType>>>())
, debug_()
{}
// This overload accepts lambdas, e.g., CppFunction([](const Tensor& self) { ... })
template <typename Lambda>
explicit CppFunction(Lambda&& f, std::enable_if_t<c10::guts::is_functor<std::decay_t<Lambda>>::value, std::nullptr_t> = nullptr)
: func_(c10::KernelFunction::makeFromUnboxedLambda(std::forward<Lambda>(f)))
, cpp_signature_(c10::impl::CppSignature::make<Lambda>())
// TODO: Don't go through WrapRuntimeKernelFunctor
, schema_(c10::detail::inferFunctionSchemaFromFunctor<c10::impl::WrapFunctionIntoRuntimeFunctor<std::decay_t<Lambda>>>())
, debug_()
{}
// This static factory lets you create CppFunctions that (1) don't have boxing
// wrappers (because we don't support it yet) and (2) don't have schema
// inference (because some ops don't support it).
//
// TODO: Eliminate the necessity for this function entirely.
template <typename Func>
static CppFunction makeUnboxedOnly(Func* f) {
return CppFunction(
c10::KernelFunction::makeFromUnboxedOnlyRuntimeFunction(f),
/* cpp_signature */ c10::impl::CppSignature::make<Func>(),
/* schema */ nullptr
);
}
// TODO: more user friendly API
static CppFunction makeFallthrough() {
return CppFunction(
c10::KernelFunction::makeFallthrough(),
/* cpp_signature */ c10::nullopt, // not known for fallthroughs
/* schema */ nullptr
);
}
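// Usage sketch (not part of this header): a fallthrough kernel is typically
// registered from a TORCH_LIBRARY_IMPL block so that the given dispatch key
// simply defers to the next applicable kernel, e.g.,
//
//   TORCH_LIBRARY_IMPL(_, XLA, m) {
//     m.fallback(torch::CppFunction::makeFallthrough());
//   }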
static CppFunction makeNamedNotSupported() {
return CppFunction(
c10::KernelFunction::makeNamedNotSupported(),
/* cpp_signature */ c10::nullopt, // not known for this special kernel
/* schema */ nullptr
);
}
// TODO: more user friendly API
template<c10::KernelFunction::BoxedKernelFunction* func>
static CppFunction makeFromBoxedFunction() {
return CppFunction(
c10::KernelFunction::makeFromBoxedFunction<func>(),
/* cpp_signature */ c10::nullopt, // not known for boxed functions
/* schema */ nullptr
);
}
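// Usage sketch (my_error_fallback is a hypothetical function): a boxed
// function has the BoxedKernelFunction signature and works on the IValue
// stack directly, which lets a single kernel serve operators of any schema,
// e.g.,
//
//   void my_error_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
//     TORCH_CHECK(false, "this backend does not support ", op.schema().name());
//   }
//   TORCH_LIBRARY_IMPL(_, XLA, m) {
//     m.fallback(torch::CppFunction::makeFromBoxedFunction<&my_error_fallback>());
//   }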
CppFunction&& debug(std::string d) && {
debug_ = std::move(d);
return std::move(*this);
}
private:
c10::optional<c10::DispatchKey> dispatch_key_;
c10::KernelFunction func_;
c10::optional<c10::impl::CppSignature> cpp_signature_;
std::unique_ptr<c10::FunctionSchema> schema_;
std::string debug_;
// The "setter" for dispatch_key_
template <typename Func>
friend CppFunction dispatch(c10::DispatchKey, Func&&);
// The only class which actually pulls out values from CppFunction (does so
// destructively, felt too lazy to write accessors that I don't even
// want users to use)
friend class Library;
CppFunction(c10::KernelFunction func, c10::optional<c10::impl::CppSignature> cpp_signature, std::unique_ptr<c10::FunctionSchema> schema);
};
// Create a CppFunction which is associated with a specific dispatch key.
// CppFunctions that are tagged with a DispatchKey don't get invoked /unless/
// the dispatcher determines that the DispatchKey is the best choice for
// a function.
template <typename Func>
inline CppFunction dispatch(c10::DispatchKey k, Func&& raw_f) {
CppFunction f(std::forward<Func>(raw_f));
if (k == c10::DispatchKey::CatchAll) {
f.dispatch_key_ = c10::nullopt;
} else {
f.dispatch_key_ = k;
}
return f;
}
// Convenience overload of dispatch which accepts DeviceType
template <typename Func>
inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
auto deviceTypeToDispatchKey = [](c10::DeviceType t){
switch (t) {
// This list is synchronized with the k-constants in c10/core/DeviceType.h
case c10::DeviceType::CPU:
return c10::DispatchKey::CPU;
case c10::DeviceType::CUDA:
return c10::DispatchKey::CUDA;
case c10::DeviceType::XLA:
return c10::DispatchKey::XLA;
case c10::DeviceType::HIP:
return c10::DispatchKey::HIP;
case c10::DeviceType::MSNPU:
return c10::DispatchKey::MSNPU;
default:
TORCH_CHECK(false,
"Device type ", t, " cannot be overloaded at dispatch time, "
"please file a bug report explaining what you were trying to do.");
}
};
return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
}
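// Usage sketch (mul_cpu_impl and mul_xla_impl are hypothetical kernels):
// torch::dispatch is normally composed with impl(), and accepts either a
// DeviceType or a DispatchKey, e.g.,
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("mul(Tensor self, Tensor other) -> Tensor");
//     m.impl("mul", torch::dispatch(torch::kCPU, &mul_cpu_impl));
//     m.impl("mul", torch::dispatch(c10::DispatchKey::XLA, &mul_xla_impl));
//   }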
inline c10::FunctionSchema schema(const char* str, c10::AliasAnalysisKind k) {
c10::FunctionSchema s = torch::jit::parseSchema(str);
s.setAliasAnalysis(k);
return s;
}
inline c10::FunctionSchema schema(const char* s) {
return schema(s, c10::AliasAnalysisKind::FROM_SCHEMA);
}
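// Usage sketch (the operator itself is hypothetical): pass an explicit
// AliasAnalysisKind when the default FROM_SCHEMA behavior is not what you
// want, e.g.,
//
//   TORCH_LIBRARY(myops, m) {
//     m.def(torch::schema(
//         "myop(Tensor self) -> Tensor",
//         c10::AliasAnalysisKind::PURE_FUNCTION));
//   }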
inline c10::FunctionSchema&& schema(c10::FunctionSchema&& s) { return std::move(s); }
namespace detail {
inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::FunctionSchema&& s) {
return c10::make_right<c10::OperatorName, c10::FunctionSchema>(std::move(s));
}
inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(c10::OperatorName&& n) {
return c10::make_left<c10::OperatorName, c10::FunctionSchema>(std::move(n));
}
inline c10::either<c10::OperatorName, c10::FunctionSchema> constructSchemaOrName(const char* str) {
auto s = torch::jit::parseSchemaOrName(str);
if (s.is_right()) {
s.right().setAliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA);
}
return s;
}
class TorchLibraryInit;
}
// This is the "handle" by which functions defined in TORCH_LIBRARY
// and TORCH_LIBRARY_IMPL can define operators and override implementations
// at certain backends.
//
// Conventionally, you get access to it using those two macros:
//
// TORCH_LIBRARY(torchvision, m) {
// // m is a torch::Library
// m.def("roi_align", ...);
// ...
// }
//
// TORCH_LIBRARY_IMPL(aten, XLA, m) {
// // m is a torch::Library
// m.impl("add", ...);
// ...
// }
//
// In some cases, you need to define something that applies to all namespaces,
// not just one namespace (usually a fallback). In that case, use the reserved
// namespace _, e.g.,
//
// TORCH_LIBRARY_IMPL(_, XLA, m) {
// m.fallback(xla_fallback);
// }
//
class CAFFE2_API Library final {
public:
// Which type of macro produced this Library
enum Kind {
DEF, // from TORCH_LIBRARY (no qualifier)
IMPL,
FRAGMENT,
};
// Use TORCH_LIBRARY/TORCH_LIBRARY_IMPL instead of these constructors directly
Library(Kind kind, std::string ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line);
Library(const Library&) = delete;
Library& operator=(const Library&) = delete;
Library(Library&&) = default;
Library& operator=(Library&&) = default;
// Some notes about the API design here. We had the following constraints:
//
// - We need to support multiple "types" of arguments for schema and
// functions (e.g., unnamed lambda types, regular functions, const char*,
// fully instantiated schemas)
// - We don't want to write exponentially many overloads
// - We don't want to rely on implicit conversion to a common type,
// because the C++ compiler will only be willing to do a single
// implicit conversion (reducing the set of valid types which you
// can invoke with); also error messages are worse when an implicit
// conversion is not selected (as the compiler will not explain
// why it didn't select an implicit conversion; this is different
// from overloads where it will explain each candidate overload and
// why it didn't apply)
//
// To solve all of these constraints at the same time, we use a trick taken
// from the pybind11 library: template over the argument in the user visible
// API, and inside of the templated function explicitly call an overloaded
// function to resolve the argument to a real type. You get the good error
// messages from overloads, but at the same time you only need to write the
// overload for any given argument type once.
// Declare an operator with a schema, but don't provide any implementations
// for it. You're expected to then provide implementations using the
// impl() method.
template <typename Schema>
Library& def(Schema&& raw_schema) & {
c10::FunctionSchema s = schema(std::forward<Schema>(raw_schema));
return _def(std::move(s));
}
// Convenience method to define an operator for a schema and then register
// an implementation for it. def(n, f) is almost equivalent to def(n).impl(f),
// except that if n is not a schema, then the schema is inferred from the
// static type of f.
template <typename NameOrSchema, typename Func>
Library& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
CppFunction f(std::forward<Func>(raw_f));
auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
return _def(std::move(name_or_schema), std::move(f));
}
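// Usage sketch (add_impl and mul_impl are hypothetical kernels with matching
// C++ signatures):
//
//   TORCH_LIBRARY(myops, m) {
//     // Declaration only; implementations are registered later via impl().
//     m.def("add(Tensor self, Tensor other) -> Tensor");
//     // Name plus function; the schema is inferred from mul_impl's C++ type.
//     m.def("mul", &mul_impl);
//   }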
// Register an implementation for an operator. You may register multiple
// implementations for a single operator at different dispatch keys
// (see torch::dispatch). Implementations must have a corresponding
// declaration (from def), otherwise they are invalid.
template <typename Func>
Library& impl(const char* name, Func&& raw_f) & {
CppFunction f(std::forward<Func>(raw_f));
return _impl(name, std::move(f));
}
// Convenience overload for directly specifying the dispatch key. Dispatch
// can validly be either DeviceType or DispatchKey; check torch::dispatch for
// the canonical list of accepted overloads.
template <typename Dispatch, typename Func>
Library& impl(const char* name, Dispatch&& key, Func&& raw_f) & {
return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
}
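// Usage sketch (the *_impl kernels are hypothetical):
//
//   TORCH_LIBRARY(myops, m) {
//     m.def("add(Tensor self, Tensor other) -> Tensor");
//     m.impl("add", torch::kCPU, &add_cpu_impl);   // dispatch argument given explicitly
//     m.impl("add", torch::kCUDA, &add_cuda_impl);
//   }
//   TORCH_LIBRARY_IMPL(myops, XLA, m) {
//     m.impl("add", &add_xla_impl);                // dispatch key supplied by the macro
//   }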
// Convenience overload for unboxed only kernels. These are quite common
// but will be eventually eliminated; this function makes it easy to grep for
// them.
//
// TODO: Remove this overload once the makeUnboxedOnly incidence rate
// goes way down
template <typename Func>
Library& impl_UNBOXED(const char* name, Func* raw_f) & {
return impl(name, CppFunction::makeUnboxedOnly(raw_f));
}
// Register a fallback implementation for all operators which will be used
// if there is not a specific implementation for an operator available.
// Providing a DispatchKey is MANDATORY for fallback at the moment; i.e.,
// only call this from TORCH_LIBRARY_IMPL.
template <typename Func>
Library& fallback(Func&& raw_f) & {
CppFunction f((std::forward<Func>(raw_f)));
return _fallback(std::move(f));
}
template <class CurClass>
inline class_<CurClass> class_(const std::string& className);
private:
Kind kind_;
c10::optional<std::string> ns_;
c10::optional<c10::DispatchKey> dispatch_key_;
const char* file_;
uint32_t line_;
std::vector<c10::RegistrationHandleRAII> registrars_;
friend class detail::TorchLibraryInit;
// Non-user visible actual implementations of functions. These aren't
// public because we only implement & qualifier and not && qualifier
Library& _def(c10::FunctionSchema&& schema, c10::OperatorName* out_name = nullptr) &;
Library& _def(c10::either<c10::OperatorName, c10::FunctionSchema>&&, CppFunction&& f) &;
Library& _impl(const char* name, CppFunction&& f) &;
Library& _fallback(CppFunction&& f) &;
};
namespace detail {
class TorchLibraryInit final {
private:
using InitFn = void(Library&);
Library lib_;
public:
TorchLibraryInit(Library::Kind kind, InitFn* fn, const char* ns, c10::optional<c10::DispatchKey> k, const char* file, uint32_t line)
: lib_(kind, ns, k, file, line) {
fn(lib_);
}
};
} // namespace detail
} // namespace torch
// NB: The EXACT NAMING of the initializer functions (e.g.,
// TORCH_LIBRARY_init_aten) matters for the code analyzer;
// see the regexes at tools/code_analyzer/run_analyzer.sh
#define TORCH_LIBRARY(ns, m) \
static void TORCH_LIBRARY_init_ ## ns (torch::Library&); \
static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_ ## ns ( \
torch::Library::DEF, \
&TORCH_LIBRARY_init_ ## ns, \
#ns, c10::nullopt, __FILE__, __LINE__ \
); \
void TORCH_LIBRARY_init_ ## ns (torch::Library& m)
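// For reference, a sketch of what TORCH_LIBRARY(myops, m) { ... } expands to:
//
//   static void TORCH_LIBRARY_init_myops(torch::Library&);
//   static torch::detail::TorchLibraryInit TORCH_LIBRARY_static_init_myops(
//       torch::Library::DEF, &TORCH_LIBRARY_init_myops,
//       "myops", c10::nullopt, __FILE__, __LINE__);
//   void TORCH_LIBRARY_init_myops(torch::Library& m) { /* user body */ }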
// This macro is a version of TORCH_LIBRARY that doesn't enforce that there
// is only one library (it is a "fragment"). This should ONLY be used
// with PerOpRegistration (as its name suggests).
#define TORCH_LIBRARY_FRAGMENT_THIS_API_IS_FOR_PER_OP_REGISTRATION_ONLY(ns, m) \
static void TORCH_LIBRARY_FRAGMENT_init_ ## ns (torch::Library&); \
static torch::detail::TorchLibraryInit TORCH_LIBRARY_FRAGMENT_static_init_ ## ns ( \
torch::Library::FRAGMENT, \
&TORCH_LIBRARY_FRAGMENT_init_ ## ns, \
#ns, c10::nullopt, __FILE__, __LINE__ \
); \
void TORCH_LIBRARY_FRAGMENT_init_ ## ns (torch::Library& m)
#define TORCH_LIBRARY_IMPL(ns, k, m) \
static void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library&); \
static torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_ ## ns ## _ ## k ( \
torch::Library::IMPL, \
& TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k, \
#ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__ \
); \
void TORCH_LIBRARY_IMPL_init_ ## ns ## _ ## k (torch::Library& m)
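// Similarly, a sketch of what TORCH_LIBRARY_IMPL(myops, XLA, m) { ... } expands to:
//
//   static void TORCH_LIBRARY_IMPL_init_myops_XLA(torch::Library&);
//   static torch::detail::TorchLibraryInit TORCH_LIBRARY_IMPL_static_init_myops_XLA(
//       torch::Library::IMPL, &TORCH_LIBRARY_IMPL_init_myops_XLA,
//       "myops", c10::make_optional(c10::DispatchKey::XLA), __FILE__, __LINE__);
//   void TORCH_LIBRARY_IMPL_init_myops_XLA(torch::Library& m) { /* user body */ }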
// These are variants of the macros above which are to be used for testing (they
// don't set up the static initializer, so you can control the visibility of
// the allocated library yourself).
//
// DO NOT use these in production code, they are NOT understood by the
// code analyzer and will be incorrectly analyzed in those situations.
#define MAKE_TORCH_LIBRARY(ns) torch::Library(torch::Library::DEF, #ns, c10::nullopt, __FILE__, __LINE__)
#define MAKE_TORCH_LIBRARY_IMPL(ns, k) torch::Library(torch::Library::IMPL, #ns, c10::make_optional(c10::DispatchKey::k), __FILE__, __LINE__)
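// Usage sketch for tests (add_cpu_impl is hypothetical): the returned Library
// is an ordinary local object, so its registrations are dropped (via the RAII
// handles it holds) when it goes out of scope, e.g.,
//
//   auto m = MAKE_TORCH_LIBRARY(myops);
//   m.def("add(Tensor self, Tensor other) -> Tensor");
//   auto m_cpu = MAKE_TORCH_LIBRARY_IMPL(myops, CPU);
//   m_cpu.impl("add", &add_cpu_impl);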
#include <torch/custom_class.h>