Add DispatchKey impl overload; remove use of torch::dispatch (pytorch#35706)

Summary:
Pull Request resolved: pytorch#35706

It is extremely common to define implementations of operators at a
specific dispatch key, so we add an overload to impl specifically for
this case. I then delete most uses of torch::dispatch.

dispatch_autograd call sites can't make use of this overload. So
instead, the new preferred way to specify something as autograd is to
pass kAutograd as the dispatch key (short form, analogous to kCPU/kCUDA,
which we support today).

I flip-flopped about whether kAutograd should have the type
DispatchKey or some other type (to help better encapsulate the
DispatchKey enum); this is more direct, and I can't think of any
BC problems from this usage.
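For illustration, the shape of the new overload can be sketched with toy stand-ins. None of this is the real c10 API; the enum, CppFunction, and Module below are simplified placeholders that only show how the three-argument impl forwards through dispatch():

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Toy stand-ins mirroring the names in this PR; the real c10 types are
// much richer, this only illustrates the overload shape.
enum class DispatchKey { CPUTensorId, CUDATensorId, VariableTensorId };
constexpr DispatchKey kCPU = DispatchKey::CPUTensorId;
constexpr DispatchKey kAutograd = DispatchKey::VariableTensorId;

struct CppFunction {
  std::function<int(int)> func;
  DispatchKey key;
};

// Pre-existing wrapper: attach a dispatch key to a function.
template <typename Func>
CppFunction dispatch(DispatchKey k, Func&& f) {
  return CppFunction{std::forward<Func>(f), k};
}

struct Module {
  std::map<std::pair<std::string, DispatchKey>, CppFunction> impls;

  // Old style: impl(name, dispatch(key, f))
  Module& impl(const char* name, CppFunction&& f) {
    auto k = f.key;  // read before moving f
    impls.emplace(std::make_pair(std::string(name), k), std::move(f));
    return *this;
  }

  // The convenience overload this PR adds: impl(name, key, f)
  // simply forwards through dispatch().
  template <typename Func>
  Module& impl(const char* name, DispatchKey key, Func&& f) {
    return impl(name, dispatch(key, std::forward<Func>(f)));
  }
};
```

With this, registration reads `m.impl("aten::mul", kCPU, fn)` instead of `m.impl("aten::mul", dispatch(kCPU, fn))`.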

Some other reorganization I did:
- I renamed all of the worker functions in op_registration to have
  a leading underscore and made them private, just to make it clearer
  what the public versus private APIs are (the private API
  shouldn't be used by users because it doesn't come with && overloads).
- In a few places where I was touching lines already, I replaced
  fully typed-out DispatchKey enums with shorter kFoo names, similar
  to kAutograd, but I didn't publish these globally.
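The && point is the crux of the underscore split: the private workers are lvalue-ref-qualified (&) only, so the public templated wrappers must supply both & and && overloads for chaining off a temporary to work. A toy sketch of the pattern (this Builder is hypothetical, not the real Module):

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Toy builder illustrating the &/&& qualifier pattern: the private worker
// (_def) is &-qualified only; the public wrappers forward for both value
// categories so `Builder().def("a").def("b")` chains on a temporary.
class Builder {
 private:
  Builder& _def(std::string name) & {
    defs_.push_back(std::move(name));
    return *this;
  }

 public:
  std::vector<std::string> defs_;  // exposed here only for testing

  // Lvalue overload: delegates to the private worker.
  Builder& def(const char* name) & {
    return _def(name);
  }
  // Rvalue overload: without this, calling def() on a temporary would
  // not compile, since the & overload requires an lvalue.
  Builder&& def(const char* name) && {
    def(name);  // *this is an lvalue here, so this calls the & overload
    return std::move(*this);
  }
};
```

Exposing only the wrappers keeps users off the &-only workers, which is exactly why the PR makes _def/_impl/_fallback private.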

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D20775783

Pulled By: ezyang

fbshipit-source-id: e45b289e5d1f86c180b24cf14c63cf4459ab5337
ezyang authored and facebook-github-bot committed Apr 2, 2020
1 parent c3abcf8 commit 2db6119
Showing 14 changed files with 152 additions and 112 deletions.
25 changes: 11 additions & 14 deletions aten/src/ATen/autocast_mode.cpp
@@ -332,8 +332,7 @@ I think Option 2 is the right answer for all ops, not just convolutions. Option
*****************************************************************************************************************/

auto register_fallthrough = c10::import()
.fallback(c10::dispatch(c10::DispatchKey::AutocastTensorId,
c10::CppFunction::makeFallthrough()));
.fallback(c10::DispatchKey::AutocastTensorId, c10::CppFunction::makeFallthrough());

/********************************************************************************************************************
Explicit registration for out-of-place ops
@@ -361,17 +360,17 @@ Therefore, for the moment, this is all copy pasted in from VariableTypeEverythin
// Common cases where registration signature matches redispatch signature
// (that's why SIGNATURE is repeated in the WrapFunction instantiation)
#define KERNEL(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
.impl(REGISTER_NAME, c10::dispatch(DispatchKey::AutocastTensorId, \
&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call))
.impl(REGISTER_NAME, DispatchKey::AutocastTensorId, \
&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call)

#define KERNEL_UNBOXED_ONLY(FUNC, REGISTER_NAME, SIGNATURE, POLICY) \
.impl(REGISTER_NAME, c10::dispatch(DispatchKey::AutocastTensorId, \
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call)))
.impl(REGISTER_NAME, DispatchKey::AutocastTensorId, \
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, SIGNATURE, SIGNATURE, &FUNC>::type::call))

// Less-common but still useful case: redispatching to a function with a new signature (e.g. appending a dtype)
#define KERNEL_UNBOXED_ONLY_DIFFERENT_REDISPATCH_SIGNATURE(REDISPATCH_FUNC, REGISTER_NAME, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, POLICY) \
.impl(REGISTER_NAME, c10::dispatch(DispatchKey::AutocastTensorId, \
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call)))
.impl(REGISTER_NAME, DispatchKey::AutocastTensorId, \
c10::CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::POLICY, REGISTER_SIGNATURE, REDISPATCH_SIGNATURE, &REDISPATCH_FUNC>::type::call))

/*****************************************
Explicit registration for out-of-place ops
@@ -426,9 +425,8 @@ auto register_out_of_place = c10::import()
KERNEL(ADD_NS(gelu), "aten::gelu", Tensor (const Tensor &), fp32)
KERNEL_UNBOXED_ONLY(ADD_NS(layer_norm), "aten::layer_norm", Tensor (const Tensor &, IntArrayRef, const Tensor &, const Tensor &, double, bool), fp32)
// The macro doesn't like this one so I had to write it out manually.
.impl("aten::native_layer_norm",
c10::dispatch(DispatchKey::AutocastTensorId,
CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call)))
.impl("aten::native_layer_norm", DispatchKey::AutocastTensorId,
CppFunction::makeUnboxedOnly(&WrapFunction<CastPolicy::fp32, std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), std::tuple<Tensor,Tensor,Tensor> (const Tensor &, const Tensor &, const Tensor &, int64_t, int64_t, double), &ADD_NS(native_layer_norm)>::type::call))
KERNEL_UNBOXED_ONLY(ADD_NS(group_norm), "aten::group_norm", Tensor (const Tensor &, int64_t, const Tensor &, const Tensor &, double, bool), fp32)
KERNEL_UNBOXED_ONLY(ADD_NS(frobenius_norm), "aten::frobenius_norm", Tensor (const Tensor &), fp32)
KERNEL_UNBOXED_ONLY(ADD_NS(frobenius_norm), "aten::frobenius_norm.dim", Tensor (const Tensor &, IntArrayRef, bool), fp32)
@@ -496,9 +494,8 @@ auto register_out_of_place = c10::import()
;

auto register_banned = torch::import()
.impl("aten::binary_cross_entropy",
torch::dispatch(DispatchKey::AutocastTensorId,
CppFunction::makeUnboxedOnly(&at::autocast::binary_cross_entropy_banned)));
.impl("aten::binary_cross_entropy", DispatchKey::AutocastTensorId,
CppFunction::makeUnboxedOnly(&at::autocast::binary_cross_entropy_banned));
}
#endif

2 changes: 1 addition & 1 deletion aten/src/ATen/core/BackendSelectFallbackKernel.cpp
@@ -3,7 +3,7 @@
namespace {

static auto registry = c10::import()
.fallback(c10::dispatch(c10::DispatchKey::BackendSelect, c10::CppFunction::makeFallthrough()))
.fallback(c10::DispatchKey::BackendSelect, c10::CppFunction::makeFallthrough())
;

}
8 changes: 4 additions & 4 deletions aten/src/ATen/core/op_registration/op_registration.cpp
@@ -128,13 +128,13 @@ Module& Module::operator=(Module&&) = default;
// TODO: Error if an operator is def'ed multiple times. Right now we just
// merge everything

Module& Module::def(FunctionSchema&& schema) & {
Module& Module::_def(FunctionSchema&& schema) & {
if (ns_.has_value()) schema.setNamespaceIfNotSet(ns_->c_str());
registrars_.emplace_back(Dispatcher::singleton().registerDef(std::move(schema)));
return *this;
}

Module& Module::def(c10::either<OperatorName, FunctionSchema>&& name_or_schema, CppFunction&& f) & {
Module& Module::_def(c10::either<OperatorName, FunctionSchema>&& name_or_schema, CppFunction&& f) & {
FunctionSchema schema = [&] {
if (name_or_schema.is_right()) {
return std::move(name_or_schema).right();
@@ -156,7 +156,7 @@ Module& Module::def(c10::either<OperatorName, FunctionSchema>&& name_or_schema,
return *this;
}

Module& Module::impl(const char* name_str, CppFunction&& f) & {
Module& Module::_impl(const char* name_str, CppFunction&& f) & {
auto name = torch::jit::parseName(name_str);
if (ns_.has_value()) name.setNamespaceIfNotSet(ns_->c_str());
registrars_.emplace_back(
@@ -171,7 +171,7 @@ Module& Module::impl(const char* name_str, CppFunction&& f) & {
return *this;
}

Module& Module::fallback(CppFunction&& f) & {
Module& Module::_fallback(CppFunction&& f) & {
TORCH_CHECK(!ns_, "Cannot define fallbacks from namespaces, use c10::import().fallback() instead");
TORCH_CHECK(f.dispatch_key_, "Fallback for catch all function not supported");
registrars_.emplace_back(Dispatcher::singleton().registerFallback(*f.dispatch_key_, std::move(f.func_)));
67 changes: 47 additions & 20 deletions aten/src/ATen/core/op_registration/op_registration.h
@@ -609,8 +609,8 @@ class CAFFE2_API RegisterOperators final {
// // provide multiple; one per backend). We'll take care of calling
// // the correct implementation depending on if we get a CPU
// // tensor or a CUDA tensor
// .impl("aten::mul", torch::dispatch(torch::kCPU, &mul_cpu_impl))
// .impl("aten::mul", torch::dispatch(torch::kCUDA, &mul_cuda_impl))
// .impl("aten::mul", torch::kCPU, &mul_cpu_impl)
// .impl("aten::mul", torch::kCUDA, &mul_cuda_impl)
//
// Also, you can omit the top level namespace and specify it explicitly in
// the sub-definitions, e.g., torch::import().impl("aten::mul", ...)
@@ -718,12 +718,17 @@ template <typename Func>
inline CppFunction dispatch(DeviceType type, Func&& raw_f) {
auto deviceTypeToDispatchKey = [](DeviceType t){
switch (t) {
// This list is synchronized with the k-constants in c10/core/DeviceType.h
case DeviceType::CPU:
return c10::DispatchKey::CPUTensorId;
case DeviceType::CUDA:
return c10::DispatchKey::CUDATensorId;
case DeviceType::XLA:
return c10::DispatchKey::XLATensorId;
case DeviceType::HIP:
return c10::DispatchKey::HIPTensorId;
case DeviceType::MSNPU:
return c10::DispatchKey::MSNPUTensorId;
default:
TORCH_CHECK(false,
"Device type ", t, " cannot be overloaded at dispatch time, "
@@ -733,12 +738,6 @@ inline CppFunction dispatch(DeviceType type, Func&& raw_f) {
return dispatch(deviceTypeToDispatchKey(type), std::forward<Func>(raw_f));
}

// Convenience for overriding autograd functionality
template <typename Func>
inline CppFunction dispatch_autograd(Func&& raw_f) {
return dispatch(c10::DispatchKey::VariableTensorId, std::forward<Func>(raw_f));
}

inline FunctionSchema schema(const char* str, AliasAnalysisKind k) {
FunctionSchema s = torch::jit::parseSchema(str);
s.setAliasAnalysis(k);
@@ -786,6 +785,14 @@ class CAFFE2_API Module final {
friend Module _import_DOES_NOT_WORK_WITH_MOBILE_CUSTOM_BUILD(std::string ns);
friend Module import();

private:
// Non-user visible actual implementations of functions. These aren't
// public because we only implement & qualifier and not && qualifier
Module& _def(FunctionSchema&& schema) &;
Module& _def(c10::either<OperatorName, FunctionSchema>&&, CppFunction&& f) &;
Module& _impl(const char* name, CppFunction&& f) &;
Module& _fallback(CppFunction&& f) &;

public:
Module(const Module&) = delete;
Module& operator=(const Module&) = delete;
@@ -824,11 +831,10 @@ class CAFFE2_API Module final {
// Declare an operator with a schema, but don't provide any implementations
// for it. You're expected to then provide implementations using the
// impl() method.
Module& def(FunctionSchema&& schema) &;
template <typename Schema>
Module& def(Schema&& raw_schema) & {
FunctionSchema s = schema(std::forward<Schema>(raw_schema));
return def(std::move(s));
return _def(std::move(s));
}
template <typename Schema>
Module&& def(Schema&& raw_schema) && {
@@ -840,12 +846,11 @@
// an implementation for it. def(n, f) is almost equivalent to def(n).impl(f),
// except that if n is not a schema, then the schema is inferred from the
// static type of f.
Module& def(c10::either<OperatorName, FunctionSchema>&&, CppFunction&& f) &;
template <typename NameOrSchema, typename Func>
Module& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) & {
CppFunction f(std::forward<Func>(raw_f));
auto name_or_schema = detail::constructSchemaOrName(std::forward<NameOrSchema>(raw_name_or_schema));
return def(std::move(name_or_schema), std::move(f));
return _def(std::move(name_or_schema), std::move(f));
}
template <typename NameOrSchema, typename Func>
Module&& def(NameOrSchema&& raw_name_or_schema, Func&& raw_f) && {
@@ -857,27 +862,50 @@
// implementations for a single operator at different dispatch keys
// (see torch::dispatch). Implementations must have a corresponding
// declaration (from def), otherwise they are invalid.
Module& impl(const char* name, CppFunction&& f) &;
template <typename Func>
Module& impl(const char* name, Func&& raw_f) & {
CppFunction f(std::forward<Func>(raw_f));
return impl(name, std::move(f));
return _impl(name, std::move(f));
}
template <typename Func>
Module&& impl(const char* name, Func&& raw_f) && {
impl(name, std::forward<Func>(raw_f));
return std::move(*this);
}
// Convenience overload for directly specifying the dispatch key. Dispatch
// can validly be either DeviceType or DispatchKey; check torch::dispatch for
// the canonical list of accepted overloads.
template <typename Dispatch, typename Func>
Module& impl(const char* name, Dispatch&& key, Func&& raw_f) & {
return impl(name, dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
}
template <typename Dispatch, typename Func>
Module&& impl(const char* name, Dispatch&& key, Func&& raw_f) && {
impl(name, std::forward<Dispatch>(key), std::forward<Func>(raw_f));
return std::move(*this);
}

// Register a fallback implementation for all operators which will be used
// if there is not a specific implementation for an operator available.
// At the moment, you must specify a dispatch key (see torch::dispatch) for
// your fallback.
Module& fallback(CppFunction&& f) &;
// Providing a DispatchKey is MANDATORY for fallback at the moment.
//
// Dispatch can validly be either DeviceType or DispatchKey; check
// torch::dispatch for the canonical list of accepted overloads.
template <typename Dispatch, typename Func>
Module& fallback(Dispatch&& key, Func&& raw_f) & {
return fallback(c10::dispatch(std::forward<Dispatch>(key), std::forward<Func>(raw_f)));
}
template <typename Dispatch, typename Func>
Module&& fallback(Dispatch&& key, Func&& raw_f) && {
fallback(std::forward<Dispatch>(key), std::forward<Func>(raw_f));
return std::move(*this);
}
// NB: these overloads are here for completeness, but you'll probably want to
// use the direct Dispatch overload
template <typename Func>
Module& fallback(Func&& raw_f) & {
CppFunction f(std::forward<Func>(raw_f));
return fallback(std::move(f));
CppFunction f((std::forward<Func>(raw_f)));
return _fallback(std::move(f));
}
template <typename Func>
Module&& fallback(Func&& raw_f) && {
@@ -913,7 +941,6 @@ namespace torch {

// New-style API
using c10::dispatch;
using c10::dispatch_autograd;
using c10::schema;
using c10::import;
}
26 changes: 14 additions & 12 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -1297,11 +1297,11 @@ TEST(NewOperatorRegistrationTest, testBasics) {
.def("_test::dummy2(Tensor self) -> Tensor")
.def("_test::dummy3(Tensor self, Tensor other) -> Tensor", [](const Tensor& self, const Tensor& other) { return self; })
.def("_test::dummy4", [](const Tensor& self, const Tensor& other) { return other; })
.impl("_test::dummy", c10::dispatch(c10::DeviceType::CPU, [](const Tensor& self) { return self; }))
.impl("_test::dummy", c10::dispatch(c10::DeviceType::XLA, [](const Tensor& self) { return self; }))
.impl("_test::dummy", c10::DeviceType::CPU, [](const Tensor& self) { return self; })
.impl("_test::dummy", c10::DeviceType::XLA, [](const Tensor& self) { return self; })
// Internal API
.impl("_test::dummy2", c10::dispatch(c10::DispatchKey::CPUTensorId, [](const Tensor& self) { return self; }))
.impl("_test::dummy2", c10::dispatch(c10::DispatchKey::XLATensorId, [](const Tensor& self) { return self; }));
.impl("_test::dummy2", c10::DispatchKey::CPUTensorId, [](const Tensor& self) { return self; })
.impl("_test::dummy2", c10::DispatchKey::XLATensorId, [](const Tensor& self) { return self; });

ASSERT_TRUE(Dispatcher::singleton().findSchema({"_test::dummy", ""}).has_value());
// Should have a schema even if there are no impls
@@ -1382,7 +1382,7 @@ TEST(NewOperatorRegistrationTest, dispatch) {
auto registrar = c10::import()
.def("test::fn_cpu", torch::dispatch(c10::DispatchKey::CPUTensorId, [&](const Tensor& x) { cpu_called = true; return x; }))
.def("test::fn_cuda", torch::dispatch(c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; }))
.def("test::fn_autograd", torch::dispatch_autograd([&](const Tensor& x) { autograd_called = true; return x; }));
.def("test::fn_autograd", torch::dispatch(c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; }));

{
auto op = Dispatcher::singleton().findSchema({"test::fn_cpu", ""});
@@ -1415,9 +1415,11 @@ TEST(NewOperatorRegistrationTest, dispatchMultiple) {
bool autograd_called = false;
auto registrar = c10::import()
.def("test::fn(Tensor self) -> Tensor")
.impl("test::fn", torch::dispatch(c10::DispatchKey::CPUTensorId, [&](const Tensor& x) { cpu_called = true; return x; }))
.impl("test::fn", torch::dispatch(c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; }))
.impl("test::fn", torch::dispatch_autograd([&](const Tensor& x) { autograd_called = true; return x; }));
// NB: Direct use of DispatchKey is discouraged; use the DeviceType
// k-synonyms instead
.impl("test::fn", c10::DispatchKey::CPUTensorId, [&](const Tensor& x) { cpu_called = true; return x; })
.impl("test::fn", c10::kCUDA, [&](const Tensor& x) { cuda_called = true; return x; })
.impl("test::fn", c10::kAutograd, [&](const Tensor& x) { autograd_called = true; return x; });

auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());
@@ -1440,7 +1442,7 @@

TEST(NewOperatorRegistrationTest, fallback) {
auto registrar = c10::import()
.fallback(torch::dispatch(c10::kCPU, c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>()));
.fallback(c10::kCPU, c10::CppFunction::makeFromBoxedFunction<&backend_fallback_kernel>());

auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()");
auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""});
@@ -1454,12 +1456,12 @@ TEST(NewOperatorRegistrationTest, BackendSelectRedispatchesToCPU) {
bool backend_generic_called = false;
auto registrar = c10::import()
.def("test::fn(Tensor self) -> Tensor")
.impl("test::fn", torch::dispatch(c10::kCPU, [&](const Tensor& x) { cpu_called = true; return x; }))
.impl("test::fn", torch::dispatch(c10::DispatchKey::BackendSelect, [&](const Tensor& x) {
.impl("test::fn", c10::kCPU, [&](const Tensor& x) { cpu_called = true; return x; })
.impl("test::fn", c10::DispatchKey::BackendSelect, [&](const Tensor& x) {
backend_generic_called = true;
auto op = c10::Dispatcher::singleton().findSchema({"test::fn", ""});
return c10::Dispatcher::singleton().callUnboxedRedispatch<Tensor, const Tensor&>(*op, c10::DispatchKey::BackendSelect, x);
}))
})
;
auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());
6 changes: 3 additions & 3 deletions aten/src/ATen/function_wrapper.py
@@ -117,16 +117,16 @@ def TypedDict(name, attrs, total=True): # type: ignore
CppFunction::makeUnboxedOnly(TypeDefault::${type_wrapper_name}))
""")
BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
.impl("${operator_name_with_overload}", torch::dispatch(
.impl("${operator_name_with_overload}",
DispatchKey::${Backend}TensorId,
CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name})))
CppFunction::makeUnboxedOnly(${Type}::${type_wrapper_name}))
""")
DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\
.impl("${operator_name_with_overload}", &TypeDefault::${type_wrapper_name})
""")
BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\
.impl("${operator_name_with_overload}",
torch::dispatch(DispatchKey::${Backend}TensorId, &${Type}::${type_wrapper_name}))
DispatchKey::${Backend}TensorId, &${Type}::${type_wrapper_name})
""")

# add non-virtual declaration to TensorBody.h
9 changes: 4 additions & 5 deletions aten/src/ATen/test/backend_fallback_test.cpp
@@ -138,12 +138,11 @@ TEST(BackendFallbackTest, TestBackendFallbackWithWrapper) {
TEST(BackendFallbackTest, TestFallthroughBackendFallback) {
// By default fallthrough
auto registry = c10::import()
.fallback(
c10::dispatch(DispatchKey::TESTING_ONLY_GenericModeTensorId,
c10::CppFunction::makeFallthrough()))
.fallback(DispatchKey::TESTING_ONLY_GenericModeTensorId,
c10::CppFunction::makeFallthrough())
.impl("aten::mul.Tensor",
c10::dispatch(DispatchKey::TESTING_ONLY_GenericModeTensorId,
c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>()));
DispatchKey::TESTING_ONLY_GenericModeTensorId,
c10::CppFunction::makeFromBoxedFunction<&generic_mode_fallback>());

c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericModeTensorId);
