diff --git a/aten/src/ATen/core/boxing/KernelFunction.h b/aten/src/ATen/core/boxing/KernelFunction.h
index 07e4cb2bb33259..1deb641b4e839a 100644
--- a/aten/src/ATen/core/boxing/KernelFunction.h
+++ b/aten/src/ATen/core/boxing/KernelFunction.h
@@ -219,6 +219,12 @@ class TORCH_API KernelFunction final {
   static KernelFunction makeAmbiguousAutogradOther();
   static KernelFunction makeNamedNotSupported();
 
+  template<BoxedKernelFunction* func>
+  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);
+
+  template<BoxedKernelFunction_withDispatchKeys* func>
+  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);
+
   /**
    * Create a KernelFunction from an unboxed lambda.
    *
@@ -240,12 +246,6 @@ class TORCH_API KernelFunction final {
 
   explicit KernelFunction(std::unique_ptr<OperatorKernel> functor, InternalBoxedKernelFunction* boxed_kernel_func, void* unboxed_kernel_func);
 
-  template<BoxedKernelFunction* func>
-  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);
-
-  template<BoxedKernelFunction_withDispatchKeys* func>
-  static void make_boxed_function(OperatorKernel*, const OperatorHandle& opHandle, DispatchKeySet, Stack* stack);
-
   OperatorKernel* getFunctor_() const;
 
   std::shared_ptr<OperatorKernel> functor_;
diff --git a/aten/src/ATen/native/CPUFallback.cpp b/aten/src/ATen/native/CPUFallback.cpp
new file mode 100644
index 00000000000000..415222ab181ddb
--- /dev/null
+++ b/aten/src/ATen/native/CPUFallback.cpp
@@ -0,0 +1,157 @@
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace at { namespace native {
+
+// convenience helper for converting tensors to cpu
+
+std::vector<at::Tensor> to_cpu(const at::TensorList& tensors) {
+  // We can't just call at::_to_cpu() on the entire list of Tensors,
+  // because it will break on undefined tensors. Separate out undefined tensors first.
+  std::vector<at::Tensor> cpu_tensors(tensors.size());
+  std::vector<at::Tensor> valid_tensors;
+  std::vector<bool> to_translate(tensors.size());
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    const at::Tensor& tensor = tensors[i];
+    // Explicitly handling undefined tensors here instead of letting `at::_to_cpu` handle it.
+    // Otherwise, we'd need to require all backends with their own implementation of _to_cpu
+    // to properly handle undefined tensors.
+    if (tensor.defined()) {
+      to_translate[i] = true;
+      valid_tensors.push_back(tensor);
+    } else {
+      cpu_tensors[i] = tensor;
+    }
+  }
+  auto cpu_valid_tensors = at::_to_cpu(valid_tensors);
+  for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) {
+    if (to_translate[i]) {
+      cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]);
+    }
+  }
+  return cpu_tensors;
+}
+
+
+void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
+  auto& schema_args = op.schema().arguments();
+  const auto num_arguments = schema_args.size();
+  auto arguments = torch::jit::last(stack, num_arguments);
+  const auto arguments_begin = stack->size() - num_arguments;
+
+  std::vector<at::Tensor> tensor_args;
+  std::vector<int64_t> tensor_args_indices;
+
+  // Step 1: Convert all non-CPU tensor inputs into CPU tensors
+  // and put them on the stack at the correct indices.
+  for (int64_t idx = 0; idx < arguments.size(); ++idx) {
+    const auto& ivalue = arguments[idx];
+    if (ivalue.isTensor()) {
+      tensor_args.push_back(ivalue.toTensor());
+      tensor_args_indices.push_back(idx);
+    } else if (ivalue.isTensorList()) {
+      // Note: we copy each TensorList argument to CPU individually out of convenience,
+      // but XLA would benefit from materializing all tensor and TensorList args onto the CPU at the same time.
+      // We can improve this if we need better perf for XLA's CPU fallbacks.
+      auto cpu_ivalue = c10::IValue(c10::List<at::Tensor>(to_cpu(ivalue.toTensorList().vec())));
+      (*stack)[arguments_begin + idx] = std::move(cpu_ivalue);
+    }
+  }
+  // XLA requires all of the tensor arguments to be gathered up and converted to CPU together.
+  auto cpu_tensors = to_cpu(tensor_args);
+
+  for (auto i = 0; i < tensor_args_indices.size(); ++i) {
+    auto idx = tensor_args_indices[i];
+    (*stack)[arguments_begin + idx] = c10::IValue(cpu_tensors[i]);
+  }
+
+  // Step 2: Call the underlying CPU implementation of the operator
+  op.redispatchBoxed(c10::DispatchKeySet(c10::DispatchKey::CPU), stack);
+
+  // Step 3: We need to take special care to handle mutable aliases properly:
+  // If any input tensors are mutable aliases, we need to
+  // directly copy the updated data on the CPU tensors back to the original inputs.
+  for (int64_t i = 0; i < tensor_args_indices.size(); ++i) {
+    auto tensor_idx = tensor_args_indices[i];
+    const auto& alias_info = schema_args[tensor_idx].alias_info();
+    if (alias_info.has_value() && alias_info.value().isWrite()) {
+      at::_copy_from_and_resize(cpu_tensors[i], tensor_args[i]);
+    }
+  }
+
+  // Step 4: Convert any CPU output tensors back to the original input device.
+  // For mutable alias'd outputs, we also need to take special care
+  // to move the ORIGINAL input tensor back onto the stack, in place of
+  // the temporary CPU output tensor that we created.
+  //
+  // Note [CPU Fallback Does Not Handle View Operators]
+  // Also note that we are incapable of handling immutable aliases properly.
+  // Why?
+  // Schemas with immutable alias'd tensor outputs correspond to view operators.
+  // For example, the `view_as` schema from native_functions.yaml:
+  // `view_as(Tensor(a) self, Tensor other) -> Tensor(a)`
+  // We can't handle these ops properly, because view ops are supposed to return
+  // a NEW tensor that shares the SAME storage as the original tensor.
+  // However, the new tensor that we created cannot share the same storage,
+  // since it lives on CPU and the original tensor lives on a different device.
+  // Because of that, we warn if someone attempts to call the
+  // CPU fallback on a view operator (this is to maintain BC for view ops for XLA
+  // that fall back to CPU).
+  const auto& schema_returns = op.schema().returns();
+  const auto& num_returns = schema_returns.size();
+  auto returns = torch::jit::last(stack, num_returns);
+  const auto returns_begin = stack->size() - num_returns;
+
+  for (int64_t idx = 0; idx < returns.size(); ++idx) {
+    if (returns[idx].isTensor()) {
+      const auto& return_tens = returns[idx].toTensor();
+      if (return_tens.defined()) {
+        const auto& alias_info = schema_returns[idx].alias_info();
+        if (alias_info.has_value() && alias_info.value().isWrite()) {
+          // Case (1): mutable alias case. Move the input ivalue directly onto the stack
+          // in place of the existing cpu output tensor.
+          bool found_alias = false;
+          // We could store some extra metadata on the function schema to avoid the loop here
+          // if we need to improve perf.
+          for (int64_t i = 0; i < tensor_args_indices.size(); ++i) {
+            auto input_tensor_idx = tensor_args_indices[i];
+            const auto& input_tensor = cpu_tensors[i];
+            const auto& input_alias_info = schema_args[input_tensor_idx].alias_info();
+            if (input_tensor.defined() && alias_info == input_alias_info) {
+              // We've found the original input tensor that aliases with the current output.
+              // Wrap it in an IValue and put it directly on the stack.
+              (*stack)[returns_begin + idx] = c10::IValue(tensor_args[i]);
+              found_alias = true;
+              break;
+            }
+          }
+          TORCH_CHECK(found_alias, "The operator ", op.schema().operator_name(), " appears to have invalid alias information. ",
+              "Found a return tensor argument with a mismatched mutable alias: ", schema_returns[idx]);
+        } else {
+          if (alias_info.has_value() && !alias_info.value().isWrite()) {
+            // immutable alias (view) case: warn here, since we're copying and not creating a view.
+            // If this operator is needed, the backend should provide a kernel for it.
+            // See Note [CPU Fallback Does Not Handle View Operators]
+            auto tgt_device = tensor_args[0].device();
+            TORCH_WARN(false, "The operator ", op.schema().operator_name(), " appears to be a view operator, ",
+                "but it has no implementation for the backend \"", tgt_device, "\". View operators don't support ",
+                "falling back to run on the CPU, since the tensor's storage cannot be shared across devices.");
+          }
+          // Case (2): copy case. Copy the cpu output tensor to the original device.
+          auto tgt_device = tensor_args[0].device();
+          (*stack)[returns_begin + idx] = c10::IValue(returns[idx].toTensor().to(tgt_device));
+        }
+      }
+    }
+  }
+}
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/native/CPUFallback.h b/aten/src/ATen/native/CPUFallback.h
new file mode 100644
index 00000000000000..b8c54b01774868
--- /dev/null
+++ b/aten/src/ATen/native/CPUFallback.h
@@ -0,0 +1,50 @@
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace at { namespace native {
+
+// This function implements a boxed fallback to CPU.
+// External backends can add their own custom logging on top of it to customize their own CPU fallbacks.
+TORCH_API void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack);
+
+// This is a helper function that backends can use to directly call their boxed CPU fallback.
+// TODO: update and add a usage example after https://github.com/pytorch/pytorch/pull/58092 lands.
+template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, class FuncType>
+struct _call_fallback_fn final {};
+
+template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op, class ReturnType, class... ParameterTypes>
+struct _call_fallback_fn<fallback_fn, Op, ReturnType(ParameterTypes...)> final {
+  static_assert(std::is_same<ReturnType, typename guts::infer_function_traits_t<typename Op::schema>::return_type>::value,
+      "Return type mismatch");
+  static_assert(std::is_same<guts::typelist::typelist<ParameterTypes...>, typename guts::infer_function_traits_t<typename Op::schema>::parameter_types>::value,
+      "Parameter types mismatch");
+
+  static ReturnType call(ParameterTypes... args) {
+    auto op = c10::Dispatcher::singleton()
+        // TODO: figure out how to make compiler happy without dynamic casts
+        .findSchemaOrThrow((const char*) Op::name, (const char*) Op::overload_name)
+        //.findSchemaOrThrow("a", "b")
+        .typed<ReturnType (ParameterTypes...)>();
+    return c10::impl::BoxedKernelWrapper<ReturnType (ParameterTypes...)>::call(
+        c10::KernelFunction::make_boxed_function<fallback_fn>,
+        nullptr,
+        op,
+        c10::DispatchKeySet(), // we know that the cpu_fallback doesn't use the dispatch keyset.
+        //std::forward(args...)
+        // TODO: get std::forward<> to work
+        args...
+    );
+  }
+};
+
+template<c10::KernelFunction::BoxedKernelFunction* fallback_fn, class Op>
+using call_fallback_fn = _call_fallback_fn<fallback_fn, Op, typename Op::schema>;
+
+} // namespace native
+} // namespace at
diff --git a/aten/src/ATen/templates/aten_xla_type_default.cpp b/aten/src/ATen/templates/aten_xla_type_default.cpp
deleted file mode 100644
index 040a752156eac9..00000000000000
--- a/aten/src/ATen/templates/aten_xla_type_default.cpp
+++ /dev/null
@@ -1,113 +0,0 @@
-// ${generated_comment}
-#include
-
-#include
-#include
-#include
-
-#include
-#include
-#include
-#include
-#include
-#include
-
-namespace ${cpp_namespace} {
-
-// convenience helpers for extracting out an optional c10::Device
-
-c10::optional<c10::Device> get_device_arg(at::Tensor tensor) {
-  return tensor.device();
-}
-
-c10::optional<c10::Device> get_device_arg(c10::optional<at::Tensor> tensor) {
-  return tensor ? c10::optional<c10::Device>((*tensor).device()) : c10::nullopt;
-}
-
-c10::optional<c10::Device> get_device_arg(std::vector<at::Tensor> tensors) {
-  return tensors.size() > 0 ? c10::optional<c10::Device>(tensors[0].device()) : c10::nullopt;
-}
-
-c10::optional<c10::Device> get_device_arg(at::TensorList tensors) {
-  return tensors.size() > 0 ? c10::optional<c10::Device>(tensors[0].device()) : c10::nullopt;
-}
-
-c10::optional<c10::Device> get_device_arg(c10::optional<c10::Device> device) {
-  return device;
-}
-
-c10::optional<c10::Device> get_device_arg(c10::Device device) {
-  return c10::optional<c10::Device>(device);
-}
-
-// convenience helpers for converting tensors to an optional device
-
-at::Tensor to_device_opt(const at::Tensor tensor, c10::optional<c10::Device> device) {
-  return device ? tensor.to(*device) : tensor;
-}
-
-std::vector<at::Tensor> to_device_opt(const std::vector<at::Tensor>& tensors, c10::optional<c10::Device> device) {
-  std::vector<at::Tensor> output_tensors;
-  for (const auto& t : tensors) {
-    output_tensors.push_back(to_device_opt(t, device));
-  }
-  return output_tensors;
-}
-
-// convenience helper for converting tensors to cpu
-
-std::vector<at::Tensor> to_cpu(const at::TensorList& tensors) {
-  // We can't just call at::to_cpu() on the entire list of Tensors
-  // Because it will break on undefined tensors. Separate out undefined tensors first.
- std::vector cpu_tensors(tensors.size()); - std::vector valid_tensors; - std::vector to_translate(tensors.size()); - for (size_t i = 0; i < tensors.size(); ++i) { - const at::Tensor& tensor = tensors[i]; - if (tensor.defined()) { - to_translate[i] = true; - valid_tensors.push_back(tensor); - } else { - cpu_tensors[i] = tensor; - } - } - auto cpu_valid_tensors = at::_to_cpu(valid_tensors); - for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) { - if (to_translate[i]) { - cpu_tensors[i] = std::move(cpu_valid_tensors[defined_pos++]); - } - } - return cpu_tensors; -} - -std::vector> to_cpu(const std::vector>& tensors) { - std::vector> opt_tensors(tensors.size()); - std::vector materialized_tensors; - std::vector to_translate(tensors.size()); - for (size_t i = 0; i < tensors.size(); ++i) { - auto tensor = tensors[i]; - if (tensor.has_value()) { - to_translate[i] = true; - materialized_tensors.push_back(*tensor); - } - } - auto aten_materialized_tensors = to_cpu(materialized_tensors); - for (size_t i = 0, defined_pos = 0; i < tensors.size(); ++i) { - if (to_translate[i]) { - opt_tensors[i] = - std::move(aten_materialized_tensors[defined_pos++]); - } - } - return opt_tensors; -} - -${dispatch_aten_fallback_definitions} - - - -TORCH_LIBRARY_IMPL(aten, XLA, m) { -${dispatch_registrations} - -} - -} // namespace torch_xla diff --git a/aten/src/ATen/templates/aten_xla_type_default.h b/aten/src/ATen/templates/aten_xla_type_default.h deleted file mode 100644 index 6d1e84bdf491ef..00000000000000 --- a/aten/src/ATen/templates/aten_xla_type_default.h +++ /dev/null @@ -1,19 +0,0 @@ -// ${generated_comment} - -#include -#include - -using c10::Stream; - -namespace ${cpp_namespace} { - -class AtenXlaTypeDefault { - public: -${dispatch_aten_fallback_declarations} - -}; - -// TODO: maybe kill this, doesn't look like XLA actually calls it anywhere -void RegisterAtenTypeFunctions(); - -} // namespace torch_xla diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl index f95b9fcf04ae6b..d972be33ac5f7b 100644 --- a/tools/build_variables.bzl +++ b/tools/build_variables.bzl @@ -974,6 +974,7 @@ aten_native_source_non_codegen_list = [ "aten/src/ATen/native/ConvolutionMM3d.cpp", "aten/src/ATen/native/ConvolutionTBC.cpp", "aten/src/ATen/native/Copy.cpp", + "aten/src/ATen/native/CPUFallback.cpp", "aten/src/ATen/native/Cross.cpp", "aten/src/ATen/native/DilatedMaxPool2d.cpp", "aten/src/ATen/native/DilatedMaxPool3d.cpp", diff --git a/tools/codegen/dest/__init__.py b/tools/codegen/dest/__init__.py index 173fccba1924fc..ab4bada2775720 100644 --- a/tools/codegen/dest/__init__.py +++ b/tools/codegen/dest/__init__.py @@ -1,4 +1,2 @@ from .register_dispatch_key import RegisterDispatchKey as RegisterDispatchKey from .native_functions import compute_native_function_declaration as compute_native_function_declaration -from .gen_external_aten_fallbacks import (has_autogenerated_composite_kernel as has_autogenerated_composite_kernel, - GenExternalAtenFallback as GenExternalAtenFallback) diff --git a/tools/codegen/dest/gen_external_aten_fallbacks.py b/tools/codegen/dest/gen_external_aten_fallbacks.py deleted file mode 100644 index 62fdd800b39e29..00000000000000 --- a/tools/codegen/dest/gen_external_aten_fallbacks.py +++ /dev/null @@ -1,289 +0,0 @@ -from typing import List, Optional, Union, Dict -from typing_extensions import Literal -from dataclasses import dataclass -import re - -from tools.codegen.context import method_with_native_function -from tools.codegen.utils import Target, mapMaybe -from tools.codegen.model 
import (Argument, BackendIndex, SchemaKind, assert_never, - Return, NativeFunction, NativeFunctionsGroup, - ListType, OptionalType, BaseType, BaseTy, Variant, - gets_generated_out_inplace_wrapper) -from tools.codegen.api.types import DispatcherSignature, CppSignatureGroup -import tools.codegen.api.dispatcher as dispatcher -import tools.codegen.api.cpp as cpp - -# TODO: this contains a list of regex for ops that don't get a CPU fallback. -# We should just register fallthroughs when we make the CPU fallback a boxed kernel. -_FN_DENYLIST_REGEX = [ - # ATEN functions - r'[^(]*cudnn', - r'slow_conv_transpose2d_backward.grad_output', - r'slow_conv_transpose3d_backward.grad_output', - r'slow_conv3d_backward.grad_input', - r'thnn_conv2d_backward.grad_input', - r'thnn_conv_depthwise2d_backward.grad_input', - # XLA/TPU functions -] - -# TODO: remove this list. -# Instead, the codegen will figure out which ops to generate _out wrappers for -# entirely from the yaml. Maintaining the same behavior as current XLA codegen for now. -_FN_OUT = [ - 'abs', - 'add', - 'acos', - 'acosh', - 'asin', - 'asinh', - 'atan', - 'atan2', - 'atanh', - 'baddbmm', - 'bernoulli', - 'binary_cross_entropy', - 'binary_cross_entropy_backward', - 'clamp', - 'div', - 'gather', - 'ger', - 'hardsigmoid', - 'kthvalue', - 'index_select', - 'inverse', - 'log', - 'masked_select', - 'maximum', - 'minimum', - 'pow', - 'prod', - 'nonzero', - 'round', - 'normal', - 'std', - 'take', - 'topk', - 'var', -] - -# See Note [Auto generated composite kernels] -def has_autogenerated_composite_kernel(f: NativeFunction) -> bool: - return (f.structured or f.structured_delegate is not None) and \ - (f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace) - -def requires_backend_wrapper(f: NativeFunction, backend_index: BackendIndex) -> bool: - requires_lowering = not f.has_composite_kernel and not has_autogenerated_composite_kernel(f) - has_backend_kernel = backend_index.has_kernel(f) - in_denylist = any([re.match(frx, str(f.func.name)) for frx in _FN_DENYLIST_REGEX]) - return not in_denylist and (requires_lowering or has_backend_kernel) - -def tensor_creation_api( - ret_name: str, - ret: Return, - device_param_name: str, - *, - cpu_result_name: str, - tuple_idx: Optional[int] = None -) -> str: - if (ret.type == BaseType(BaseTy.Tensor) and not ret.is_write) or \ - (isinstance(ret.type, ListType) and ret.type.elem == BaseType(BaseTy.Tensor)): - # Only raw Tensor (non-reference) returns need to be copied back from CPU to the backend device. - # Tensor references can be returned directly, since they already live on the backend device. - # See Note [Tensor Copy Returns] - return f"to_device_opt({cpu_result_name}, get_device_arg({device_param_name}))" - else: - # for non tensor-types, we don't need to convert between devices. - return ret_name - - - - -# Generates aten_xla_type_default.h and aten_xla_type_default.cpp. -# -# - This function registers external backend kernels, and also generates fallbacks to CPU. -# This is useful because pretty much all external backends (e.g. XLA) -# do not have full aten coverage. -# For operators not implemented by the external backend, our codegen -# will register these fallbacks instead. -# - Why do we generate fallback for ALL (non-composite) aten ops, including ops that -# external backends have already implemented? -# Many external backend kernels only work with specific input shapes, -# and are written to call into a cpu fallback when given inputs -# that they cannot handle. 
-@dataclass(frozen=True) -class GenExternalAtenFallback: - target: Union[ - Literal[Target.NAMESPACED_DEFINITION], - Literal[Target.NAMESPACED_DECLARATION], - Literal[Target.REGISTRATION], - ] - backend_index: BackendIndex - - @method_with_native_function - def __call__(self, g: Union[NativeFunctionsGroup, NativeFunction]) -> List[str]: - - def gen_unstructured_external(f: NativeFunction) -> Optional[str]: - if not requires_backend_wrapper(f, self.backend_index): - return None - - def get_device_param(args: List[Argument]) -> str: - # TODO: the XLA codegen has specific precedence rules when determining which tensor argument - # to use as the device argument. - # We should update this to be consistent with how we choose device guards. - const_tensor_or_self = [ - a for a in args if (a.type == BaseType(BaseTy.Tensor) or a.type == OptionalType(BaseType(BaseTy.Tensor))) - and not a.is_write] - if any(const_tensor_or_self): - return const_tensor_or_self[0].name - tensor_like = [a for a in args if a.type.is_tensor_like()] - if any(tensor_like): - return tensor_like[0].name - device_like = [a for a in args if a.type == BaseType(BaseTy.Device) - or a.type == OptionalType(BaseType(BaseTy.Device))] - if any(device_like): - return device_like[0].name - raise AssertionError("Need a tensor-like or device argument in order to determine the output device") - - # XLA appears to have used the dispatcher convention to write their kernel signatures, - # probably because they based their signatures off of our RegistrationDeclarations.h - # See Note [External Backends Follow Dispatcher API] - dispatcher_sig = DispatcherSignature.from_schema(f.func) - name = dispatcher_sig.name() - args = dispatcher_sig.arguments() - - if self.target is Target.NAMESPACED_DECLARATION: - return f" static {dispatcher_sig.decl()};" - - elif self.target is Target.REGISTRATION: - # This codegen is only responsible for registering CPU fallback kernels - # We also skip registrations if there is a functional backend kernel, - # because we generate out/inplace wrappers in that case (handled in register_dispatch_key.py). - if self.backend_index.get_kernel(f) is not None or \ - (isinstance(g, NativeFunctionsGroup) and gets_generated_out_inplace_wrapper(f, g, self.backend_index)): - return '' - payload = f"static_cast<{dispatcher_sig.ptr_type()}>(&AtenXlaTypeDefault::{name})" - return f' m.impl("{f.func.name}", {payload});\n' - - if self.target is not Target.NAMESPACED_DEFINITION: - assert_never(self.target) - - # Everything below here is where we generate the CPU fallback. 
- dispatcher_order_args = dispatcher.jit_arguments(f.func) - - # Map each argument to it's intermediate variable name in the fallback - # We have to do it separately for TensorList/Optional/Tensor - tensorlist_args: Dict[Argument, str] = { - a: f'l_{a.name}' for a in dispatcher_order_args - if isinstance(a.type, ListType) and a.type.elem == BaseType(BaseTy.Tensor)} - - opt_tensors = [ - a for a in dispatcher_order_args - if isinstance(a.type, OptionalType) and a.type.elem == BaseType(BaseTy.Tensor)] - opt_tensor_args: Dict[Argument, str] = {a: f'external_tensors_opt[{i}]' for i, a in enumerate(opt_tensors)} - - tensors = [a for a in dispatcher_order_args if a.type == BaseType(BaseTy.Tensor)] - tensor_args: Dict[Argument, str] = {a: f'external_tensors[{i}]' for i, a in enumerate(tensors)} - annotated_tensor_indices: List[int] = [ - i for i, a in enumerate(tensors) if a.annotation is not None and a.annotation.is_write] - - print_args_str = ''.join([f' << " {a.name}=" << {a.name}.toString()' for a in tensor_args.keys()]) - - tensorlist_intermediates_str = '' - if len(tensorlist_args) > 0: - tensorlist_intermediates_str = '\n'.join([f' auto {updated_name} = to_cpu({arg.name});' - for arg, updated_name in tensorlist_args.items()]) - - opt_tensor_intermediates_str = '' - if len(opt_tensor_args) > 0: - arg_str = ", ".join([a.name for a in opt_tensor_args.keys()]) - opt_tensor_intermediates_str = \ - f'\n std::vector> external_tensors_opt_tensors = {{{arg_str}}};' - opt_tensor_intermediates_str += \ - '\n auto external_tensors_opt = to_cpu(external_tensors_opt_tensors);' - - intermediates = '' - if tensorlist_intermediates_str != '': - intermediates += tensorlist_intermediates_str + '\n' - intermediates += \ - f" std::vector external_tensors_tensors = {{{', '.join([a.name for a in tensor_args.keys()])}}};" - intermediates += "\n auto external_tensors = to_cpu(external_tensors_tensors);" - if opt_tensor_intermediates_str != '': - intermediates += opt_tensor_intermediates_str - - - is_method = Variant.function not in f.variants - func_name = f'AtenXlaTypeDefault::{name}' - - # Gather all of the updated variable names to call into the CPU operator. - # Just use the original binding names for inputs where we didn't create explicit intermediate variables. - updated_bindings: List[str] = [ - tensorlist_args.get(a, opt_tensor_args.get(a, tensor_args.get(a, a.name))) for a in dispatcher_order_args] - - at_call_name = CppSignatureGroup.from_native_function( - f, method=is_method).most_faithful_signature().name() - - # Notice that we don't need to perform a translate: we're technically going from the dispatcher API - # to the faithful C++ API, which are carefuly written to be exactly the same. - cpu_result_name = 'x_result' - if is_method: - at_call = f'{updated_bindings[0]}.{at_call_name}({", ".join(name for name in updated_bindings[1:])});' - else: - at_call = f'at::{at_call_name}({", ".join(name for name in updated_bindings)});' - avoid_warning = '' - if f.func.returns: - at_call = f'auto&& {cpu_result_name} = {at_call}' - avoid_warning = f'\n static_cast({cpu_result_name}); // Avoid warnings in case not used' - - collect_mutated_tensors = '' - update_tensors = '' - if len(annotated_tensor_indices) > 0: - indices_str = ", ".join([str(i) for i in annotated_tensor_indices]) - collect_mutated_tensors = f'\n std::vector external_tensors_update_indices = {{{indices_str}}};' - # TODO: uncomment the resize line below. 
Taken out temporarily for testing - update_tensors = ''' - for (int i : external_tensors_update_indices) { - at::_copy_from_and_resize(external_tensors[i], external_tensors_tensors[i]); - } -''' - - returns = '' - if f.func.returns: - ret_names = cpp.return_names(f, fallback_name=cpu_result_name) - if len(ret_names) == 1: - returns = tensor_creation_api( - ret_names[0], f.func.returns[0], - get_device_param(dispatcher_order_args), cpu_result_name=cpu_result_name) - else: - return_args = [ - tensor_creation_api( - ret_names[i], f.func.returns[i], - get_device_param(dispatcher_order_args), cpu_result_name=f'std::get<{i}>({cpu_result_name})' - ) for i in range(len(f.func.returns))] - returns = f'{dispatcher_sig.returns_type().cpp_type()}({", ".join(return_args)})' - return_str = '' - if returns != '': - return_str = f'\n return {returns};' - - return f"""\ -{dispatcher_sig.defn(name=func_name)} {{ - XLA_FN_TRACK(3); - XLA_COUNTER("aten::{name}", 1); - TF_VLOG(3) << "XLA {name} :"{print_args_str}; -{intermediates} - {at_call}{collect_mutated_tensors}{update_tensors}{avoid_warning}{return_str} -}} - -""" - m = self.backend_index.get_kernel(g) - if isinstance(g, NativeFunctionsGroup): - if m is not None and m.structured: - # We can probably only bother generating fallbacks for one of the variants, for structured - raise AssertionError("Not Implemented") - else: - return list(mapMaybe(gen_unstructured_external, g.functions())) - elif isinstance(g, NativeFunction): - f = g - x = gen_unstructured_external(f) - return [x] if x else [] - else: - assert_never(f) diff --git a/tools/codegen/gen.py b/tools/codegen/gen.py index 063307fec6270a..229426dbe04972 100644 --- a/tools/codegen/gen.py +++ b/tools/codegen/gen.py @@ -760,6 +760,10 @@ def compute_declaration_yaml(f: NativeFunction) -> object: ('has_math_kernel', f.has_composite_implicit_autograd_kernel), ]) +# See Note [Auto generated composite kernels] +def has_autogenerated_composite_kernel(f: NativeFunction) -> bool: + return (f.structured or f.structured_delegate is not None) and \ + (f.func.kind() == SchemaKind.functional or f.func.kind() == SchemaKind.inplace) @with_native_function_and_indices def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[DispatchKey, BackendIndex]) -> str: @@ -771,7 +775,7 @@ def compute_registration_declarations(f: NativeFunction, backend_indices: Dict[D 'schema': f'aten::{f.func}', # TODO: What exactly is the semantics of the 'dispatch' field? 
'dispatch': str({k for k, v in backend_indices.items() if v.has_kernel(f)} != {DispatchKey.CompositeImplicitAutograd}), - 'default': str(f.has_composite_kernel or dest.has_autogenerated_composite_kernel(f)) + 'default': str(f.has_composite_kernel or has_autogenerated_composite_kernel(f)) } return f"""{returns_type} {name}({args_str}); // {json.dumps(comment_data)} """ diff --git a/tools/codegen/gen_backend_stubs.py b/tools/codegen/gen_backend_stubs.py index a3b9ac254d1ce9..8651f00091b9a8 100644 --- a/tools/codegen/gen_backend_stubs.py +++ b/tools/codegen/gen_backend_stubs.py @@ -173,8 +173,7 @@ def make_file_manager(install_dir: str) -> FileManager: fm.write_with_template(f'Register{dispatch_key}.cpp', 'RegisterDispatchKey.cpp', lambda: { 'extra_cuda_headers': '', 'legacy_th_headers': '', - 'external_backend_headers': f'''#include "{output_dir}/{backend_key}NativeFunctions.h" -#include ''', + 'external_backend_headers': f'#include "{output_dir}/{backend_key}NativeFunctions.h"', 'namespaced_headers': '', 'DispatchKey': dispatch_key, 'dispatch_namespace': dispatch_key.lower(), @@ -207,29 +206,5 @@ def make_file_manager(install_dir: str) -> FileManager: )), }) - fm.write('aten_xla_type_default.h', lambda: { - 'generated_comment': generated_comment, - 'cpp_namespace': cpp_namespace, - 'dispatch_aten_fallback_declarations': list(concatMap( - dest.GenExternalAtenFallback(Target.NAMESPACED_DECLARATION, backend_indices[backend_dispatch_key]), - grouped_native_functions - )), - }) - - fm.write('aten_xla_type_default.cpp', lambda: { - 'generated_comment': generated_comment, - 'cpp_namespace': cpp_namespace, - # TODO: after cpu fallbacks are moved to a boxed kernel, - # merge registrations / definitions into RegisterDispatchKey - 'dispatch_aten_fallback_definitions': list(concatMap( - dest.GenExternalAtenFallback(Target.NAMESPACED_DEFINITION, backend_indices[backend_dispatch_key]), - grouped_native_functions - )), - 'dispatch_registrations': list(concatMap( - dest.GenExternalAtenFallback(Target.REGISTRATION, backend_indices[backend_dispatch_key]), - grouped_native_functions - )), - }) - if __name__ == '__main__': main()
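
Example usage — a minimal sketch of how an external backend might wire up the new boxed fallback. The `XLA` dispatch key, the `xla_cpu_fallback` wrapper, and the `ATEN_OP2(add, Tensor)` operator class (only available once the follow-up PR referenced in the CPUFallback.h TODO lands) are illustrative assumptions, not something this patch adds:

```cpp
#include <ATen/native/CPUFallback.h>
#include <torch/library.h>

// Backend-specific wrapper: a backend can layer its own logging/counters here
// before delegating to the generic boxed CPU fallback.
void xla_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

// Catch-all registration: every op the backend has no kernel for is routed to the fallback.
TORCH_LIBRARY_IMPL(_, XLA, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&xla_cpu_fallback>());
}

// Direct call into the fallback for one specific op, via the new call_fallback_fn helper.
// ATEN_OP2(add, Tensor) is assumed here; it relies on the per-operator classes added later.
at::Tensor add_via_cpu_fallback(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
  return at::native::call_fallback_fn<&xla_cpu_fallback, ATEN_OP2(add, Tensor)>::call(self, other, alpha);
}
```

Registering through `m.fallback` covers every op the backend has not implemented, while `call_fallback_fn` lets an individual backend kernel punt to CPU only for inputs it can't handle — the same pattern the deleted codegen'd fallbacks supported.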