Skip to content

Commit

Permalink
add a boxed CPU fallback kernel (pytorch#58065)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#58065

This PR replaces the existing code-generated CPU fallback kernels that XLA uses with a single boxed CPU fallback.

Current state: there are a couple different design ideas that I want to point out, but the logic for the actually kernel is mostly done and passing tests.

### Design

To preface, I'm not 100% tied to the current design and I'm putting the PR up now for opinions and totally open to alternatives, some of which I listed below. Actually after writing this description, I'm leaning toward the following changes:
* Confirm whether or not we can remove all C++ logging info directly in the yaml.

**Current Design**

All of the CPU fallback codegen is deleted. In its place, XLA (and other external backends, later) can choose to opt into a CPU fallback by adding the following code in a C++ file. I have an corresponding [xla-side PR with the xla changes](https://github.com/pytorch/xla/pull/2945/files#diff-1a005c10039f0cb11130a3b740f5de716d2f10acaea121017016025861886798R1).

There's no actual requirement to split up the code into a .h and .cpp file, but that's necessary in the XLA case because they sometimes need to call the fallback directly from their handcrafted kernels.

```
// xla_cpu_fallback.h
#include <ATen/native/CPUFallback.h>
...
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
...
```
```
// xla_cpu_fallback.cpp
#include "torch_xla/csrc/aten_cpu_fallback.h"
...
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // Do custom logging here
  ...
  // Call the actual boxed CPU fallback.
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(_, XLA, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_cpu_fallback>());
}
```

Now that the fallback is exposed in the backend, they can call it directly. Doing so requires converting from an unboxed to a boxed context, which we provide a utility function before. E.g.:
```
#include <ATen/native/CPUFallback.h>

at::Tensor addmm(const at::Tensor& self,const at::Tensor& mat1,const at::Tensor& mat2,const at::Scalar& beta,const at::Scalar& alpha) {
  ....
  if (...call_fallback...) {
    return at::native::call_fallback_fn<&xla_cpu_fallback, decltype(at::addmm)>::call("aten::addmm", self, mat1, mat2, beta, alpha);
  }
  ...
}
```

That `decltype(at::addmm)` logic isn't actually used everywhere in the xla-side PR yet, since you hit issues with overloads. I could use it everywhere once pytorch#58092 lands.

**Alternatives: The API for calling the CPU fallback directly is ugly, can we make it nicer?**
We could change the api to use `at::redispatch`, which would make it look something like this:
```
at::Tensor addmm(const at::Tensor& self,const at::Tensor& mat1,const at::Tensor& mat2,const at::Scalar& beta,const at::Scalar& alpha) {
  ....
  if (...call_fallback...) {
    return at::redispatch::addmm(c10::DispatchKeySet(c10::DispatchKey::CPUFallback), self, mat1, mat2, beta, alpha);
  }
  ...
}
```
Which definitely feels cleaner, but also requires adding a new DispatchKey just for this use case. Conditionally calling the CPU fallback doesn't sound like a hugely important use case, so I don't know if giving up one of our 64 dispatch key slots is worth the API improvement. Totally open to other opinions though!

Another more mild improvement that would avoid having to pass operator string names (including overloads) around would be to codegen (yet another) namespaced API. Something like this:
```
at::Tensor addmm(const at::Tensor& self,const at::Tensor& mat1,const at::Tensor& mat2,const at::Scalar& beta,const at::Scalar& alpha) {
  ....
  if (...call_fallback...) {
    return at::fallback::addmm<&xla_cpu_fallback>(self, mat1, mat2, beta, alpha);
  }
  ...
}
```

Writing that out actually I actually like it more (I think it'll let us get rid of `decltype(...)`). Maybe that is nice enough to warrant a new codegen API - I haven't tried adding that yet, but if people like it I'm happy to try it out.

**More alternatives**
The current design also involves the backend manually writing and registering the boxed fallback themselves, but an alternative would be for us to do it in codegen too: they would just need to pass in all of the C++ logging that they want done in the fallback, directly through the yaml. The main downsides:
* Backend code that wants to call the fallback needs to abide by whatever convention our codegen uses to name the generated boxed fallback.
* Passing custom C++ logging through yaml is just more fragile: right now xla uses an `iostream` to log each tensor arg in the operator, so we'd have to either force other backends into the same convention or figure something else out later.

To be fair, we actually already do that: XLA has custom per-tensor-arg logging for all of the generated `out` wrappers in the codegen, which we do by passing their C++ logging info through the yaml. This seems unnecessary though, since `out` wrappers just call into a functional kernel, which is hand written with its own custom logging. So my take is: try to remove custom C++ logging from the yaml, and if it turns out to be really necessary, then we may as well take advantage of that to codegen the fallback.

### Performance impact

While ops that fall back to CPU aren't exactly hot path, we probably don't want to use a boxed fallback if it turns out to be an absolute perf killer.

I ran my benchmarks using callgrind, benchmarking both `at::add` and `at::add_out` run on XLA. My callgrind benchmark for `at::add` can be found here (the add_out benchmark looks basically the same): https://www.internalfb.com/phabricator/paste/view/P415418587. I created the benchmark by hacking the existing xla C++ test build scripts and throwing in a reference to callgrind.

I also attached the full callgrind output for each benchmark; the full output is actually pretty noise and hard to parse, but I focused on everything underneath the `at::add()` call in the output, which was much more stable. My guess is that it's due to some heavyweight async startup processing that xla does.

`at::add`:
before: 88,505,130 instructions. Full output: https://www.internalfb.com/phabricator/paste/view/P415421001
after: 102,185,654 instructions. Full output: https://www.internalfb.com/phabricator/paste/view/P415421273
delta: ~15.5% increase

`at::add_out`:
before: 63,897,395 instructions. Full output: https://www.internalfb.com/intern/everpaste/?handle=GBrrKwtAPlix9wUEAOZtrFXpdO5UbsIXAAAz
after: 73,170,346 instructions. Full output: https://www.internalfb.com/phabricator/paste/view/P415423227
delta: ~14.5% increase

High level takeaway: A framework overhead increase of 10-20% doesn't seem too horrible for the CPU fallback use case.

For structured, functional ops that requires a CPU fallback, we're actually in an unfortunate situation: we're doing even more work than necessary. Our codegen automatically creates a `CompositeExplicitAutograd` kernel which calls into the `out` operator. So the extra work that we end up doing is:
* An extra dispatcher hop: (at::add -> CompositeExplicitAutograd -> CPUFallback -> at::native::add) instead of (at::add -> CPUFallback -> at::native::add)
* An unnecessary tensor allocation (the CompositeExplicitAutograd kernel uses at::empty() to create an output tensor, which is immediately overwritten by the CPU fallback)
* An unnecessary meta() call (the CompositeExplicitAutograd kernel calls it to create the output tensor, but we call it again in the CPU kernel).
* unboxing->boxing->unboxing logic (this is the only strictly required piece)

There are definitely ways to avoid the unnecessary work explained above: one would be to give the boxed fallback higher priority than composite keys (there's [an issue for it here](pytorch#55104)), and codegen fallthroughs for all composite ops. It'll require more infra to set up, so I see it as more of a perf knob that we can apply if we need it later.

Unfortunately I couldn't dig much deeper into the differences aside from the aggregate change in instructions, since it looks like callgrind fudged some of the instruction attribution (`at::to_cpu` takes up a ton of instructions, but I don't see any attribution for the `at::native::add` kernel anywhere).

Test Plan: Imported from OSS

Reviewed By: jbschlosser

Differential Revision: D28833085

Pulled By: bdhirsh

fbshipit-source-id: 537ebd5d7fb5858f1158764ff47132d503c3b92b
  • Loading branch information
bdhirsh authored and facebook-github-bot committed Jun 25, 2021
1 parent ad69e2f commit 9134b0e
Show file tree
Hide file tree
Showing 10 changed files with 220 additions and 456 deletions.
12 changes: 6 additions & 6 deletions aten/src/ATen/core/boxing/KernelFunction.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ class TORCH_API KernelFunction final {
static KernelFunction makeAmbiguousAutogradOther();
static KernelFunction makeNamedNotSupported();

template<BoxedKernelFunction* func>
static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

template<BoxedKernelFunction_withDispatchKeys* func>
static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

/**
* Create a KernelFunction from an unboxed lambda.
*
Expand All @@ -240,12 +246,6 @@ class TORCH_API KernelFunction final {

explicit KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func);

template<BoxedKernelFunction* func>
static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

template<BoxedKernelFunction_withDispatchKeys* func>
static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);

OperatorKernel* getFunctor_() const;

std::shared_ptr<OperatorKernel> functor_;
Expand Down
157 changes: 157 additions & 0 deletions aten/src/ATen/native/CPUFallback.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#include <ATen/native/CPUFallback.h>

#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <ATen/Functions.h>
#include <torch/library.h>

namespace at { namespace native {

// convenience helper for converting tensors to cpu

std::vector<at::Tensor> to_cpu(const at::TensorList& tensors) {
// We can't just call at::to_cpu() on the entire list of Tensors
// Because it will break on undefined tensors. Separate out undefined tensors first.
std::vector<at::Tensor> cpu_tensors(tensors.size());
std::vector<at::Tensor> valid_tensors;
std::vector<bool> to_translate(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
const at::Tensor& tensor = tensors[i];
// Explicitly handling undefined tensors here instead of letting `at::_to_cpu` handle it.
// Otherwise, we'd need to require all backends with their own implementation of _to_cpu
// to properly handle undefined tensors.
if (tensor.defined()) {
to_translate[i] = true;
valid_tensors.push_back(tensor);
} else {
cpu_tensors[i] = tensor;
}
}
auto cpu_valid_tensors = at::_to_cpu(valid_tensors);
for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
if (to_translate[i]) {
cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]);
}
}
return cpu_tensors;
}


void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
auto& schema_args = op.schema().arguments();
const auto num_arguments = schema_args.size();
auto arguments = torch::jit::last(stack, num_arguments);
const auto arguments_begin = stack->size() - num_arguments;

std::vector<at::Tensor> tensor_args;
std::vector<int> tensor_args_indices;

// Step 1: Convert all non-CPU tensor inputs into CPU tensors
// and put them on the stack at the correct indices.
for (int64_t idx = 0; idx < arguments.size(); ++idx) {
const auto& ivalue = arguments[idx];
if (ivalue.isTensor()) {
tensor_args.push_back(ivalue.toTensor());
tensor_args_indices.push_back(idx);
} else if (ivalue.isTensorList()) {
// Note: we copy each TensorList argument to CPU individually out of convenience,
// but XLA would benefit from materializing all tensor and TensorList args onto the CPU at the same time.
// We can improve this if we need better perf for XLA's CPU fallbacks.
auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorList().vec())));
(*stack)[arguments_begin + idx] = std::move(cpu_ivalue);
}
}
// XLA requires all of the tensor arguments to be gathered up and converted to CPU together.
auto cpu_tensors = to_cpu(tensor_args);

for (auto i = 0; i < tensor_args_indices.size(); ++i) {
auto idx = tensor_args_indices[i];
(*stack)[arguments_begin + idx] = c10::IValue(cpu_tensors[i]);
}

// Step 2: Call the underlying CPU implementation of the operator
op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);

// Step 3: We need to take special care to handle mutable aliases properly:
// If any input tensors are mutable aliases, we need to
// directly copy the updated data on the CPU tensors back to the original inputs.
for (int64_t i = 0; i < tensor_args_indices.size(); ++i) {
auto tensor_idx = tensor_args_indices[i];
const auto& alias_info = schema_args[tensor_idx].alias_info();
if (alias_info.has_value() && alias_info.value().isWrite()) {
at::_copy_from_and_resize(cpu_tensors[i], tensor_args[i]);
}
}

// Step 4: Convert any CPU output tensors back to the original input device.
// For mutable alias'd outputs, we also need to take special care
// to move the ORIGINAL input tensor back onto the stack, in place of
// the temporary CPU output tensor that we created.
//
// Note [CPU Fallback Does Not Handle View Operators]
// Also note that we are incapable of handling immutable alises properly.
// Why?
// Schemas with an immutable alias'd tensor outputs correspond to view operators.
// For example, the `view_as` schema from native_functions.yaml:
// `view_as(Tensor(a) self, Tensor other) -> Tensor(a)`
// We can't handle these ops properly, because view ops are supposed to return
// a NEW tensor that shares the SAME storage as the original tensor.
// However, the new tensor that we created cannot share the same storage,
// since it lives on CPU and the original tensor lives on a different device.
// Because of that, we warn if someone attempts to call the
// CPU fallback on a view operator (this is to maintain BC for view ops for XLA
// that fall back to CPU).
const auto& schema_returns = op.schema().returns();
const auto& num_returns = schema_returns.size();
auto returns = torch::jit::last(stack, num_returns);
const auto returns_begin = stack->size() - num_returns;

for (int64_t idx = 0; idx < returns.size(); ++idx) {
if (returns[idx].isTensor()) {
const auto& return_tens = returns[idx].toTensor();
if (return_tens.defined()) {
const auto& alias_info = schema_returns[idx].alias_info();
if (alias_info.has_value() && alias_info.value().isWrite()) {
// Case (1): mutable alias case. Move the input ivalue directly onto the stack
// in place of the existing cpu output tensor.
bool found_alias = false;
// We could store some extra metadata on the function schema to avoid the loop here
// if we need to improve perf.
for (int64_t i = 0; i < tensor_args_indices.size(); ++i) {
auto input_tensor_idx = tensor_args_indices[i];
const auto& input_tensor = cpu_tensors[i];
const auto& input_alias_info = schema_args[input_tensor_idx].alias_info();
if (input_tensor.defined() && alias_info == input_alias_info) {
// We've found the original input tensor that aliases with the current output.
// Wrap it in an IValue and put it directly on the stack.
(*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]);
found_alias = true;
break;
}
}
TORCH_CHECK(found_alias, "The operator ", op.schema().operator_name(), " appears to have invalid alias information. ",
"Found a return tensor argument with a mismatched mutable alias: ", schema_returns[idx]);
} else {
if (alias_info.has_value() && !alias_info.value().isWrite()) {
// immutable alias (view) case: Warn here, since we're copying and not creating a view.
//If this operator is needed, the backend should provide a kernel for it.
// See Note [CPU Fallback Does Not Handle View Operators]
auto tgt_device = tensor_args[0].device();
TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
"but it has no implementation for the backend \"", tgt_device, "\". View operators don't support ",
"falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
}
// Case (2): copy case. Copy the cpu output tensor to the original device.
auto tgt_device = tensor_args[0].device();
(*stack)[returns_begin + idx] = c10::IValue(returns[idx].toTensor().to(tgt_device));
}
}
}
}
}

} // namespace native
} // namespace at
50 changes: 50 additions & 0 deletions aten/src/ATen/native/CPUFallback.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#pragma once

#include <ATen/core/ivalue.h>
#include <ATen/core/stack.h>
#include <ATen/core/boxing/KernelFunction.h>
#include <ATen/core/dispatch/Dispatcher.h>
#include <c10/util/Metaprogramming.h>
#include <torch/library.h>

namespace at { namespace native {

// This function implements a boxed fallback to CPU.
// External backends can add their own custom logging on top if it to customize their own CPU fallbacks.
TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);

// This is a helper function that backends can use to directly call their boxed CPU fallback
// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn final {};

template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, class ReturnType, class... ParameterTypes>
struct _call_fallback_fn<fallback_fn, Op, ReturnType(ParameterTypes...)> final {
static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<typename Op::schema>::return_type>::value,
"Return type mismatch");
static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<typename Op::schema>::parameter_types>::value,
"Parameter types mismatch");

static ReturnType call(ParameterTypes... args) {
auto op = c10::Dispatcher::singleton()
// TODO: figure out how to make compiler happy without dynamic casts
.findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
//.findSchemaOrThrow("a", "b")
.typed<ReturnType (ParameterTypes...)>();
return c10::impl::BoxedKernelWrapper<ReturnType (ParameterTypes...)>::call(
c10::KernelFunction::make_boxed_function<fallback_fn>,
nullptr,
op,
c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
//std::forward<ParameterTypes...>(args...)
// TODO: get std::forward<> to work
args...
);
}
};

template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, typename Op::schema>;

} // namespace native
} // namespace at
113 changes: 0 additions & 113 deletions aten/src/ATen/templates/aten_xla_type_default.cpp

This file was deleted.

19 changes: 0 additions & 19 deletions aten/src/ATen/templates/aten_xla_type_default.h

This file was deleted.

1 change: 1 addition & 0 deletions tools/build_variables.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/ConvolutionMM3d.cpp",
"aten/src/ATen/native/ConvolutionTBC.cpp",
"aten/src/ATen/native/Copy.cpp",
"aten/src/ATen/native/CPUFallback.cpp",
"aten/src/ATen/native/Cross.cpp",
"aten/src/ATen/native/DilatedMaxPool2d.cpp",
"aten/src/ATen/native/DilatedMaxPool3d.cpp",
Expand Down
2 changes: 0 additions & 2 deletions tools/codegen/dest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey
from .native_functions import compute_native_function_declaration as compute_native_function_declaration
from .gen_external_aten_fallbacks import (has_autogenerated_composite_kernel as has_autogenerated_composite_kernel,
GenExternalAtenFallback as GenExternalAtenFallback)
Loading

0 comments on commit 9134b0e

Please sign in to comment.