diff --git a/.ci/docker/ci_commit_pins/executorch.txt b/.ci/docker/ci_commit_pins/executorch.txt index 639bfcdd0420e..9aaea8851d475 100644 --- a/.ci/docker/ci_commit_pins/executorch.txt +++ b/.ci/docker/ci_commit_pins/executorch.txt @@ -1 +1 @@ -cd1c833b079adb324871dcbbe75b43d42ffc0ade +export-D64151426 diff --git a/.ci/docker/ci_commit_pins/triton-cpu.txt b/.ci/docker/ci_commit_pins/triton-cpu.txt index 36e9eea06792a..09e347149d1d9 100644 --- a/.ci/docker/ci_commit_pins/triton-cpu.txt +++ b/.ci/docker/ci_commit_pins/triton-cpu.txt @@ -1 +1 @@ -6a333f1b05671f6fada4ba7bbfae4a02a9d96f4f +c7711371cace304afe265c1ffa906415ab82fc66 diff --git a/.ci/docker/common/install_clang.sh b/.ci/docker/common/install_clang.sh index 59d1520bed5ff..f7ef2fb374e4f 100755 --- a/.ci/docker/common/install_clang.sh +++ b/.ci/docker/common/install_clang.sh @@ -20,9 +20,10 @@ if [ -n "$CLANG_VERSION" ]; then fi sudo apt-get update - apt-get install -y --no-install-recommends clang-"$CLANG_VERSION" llvm-"$CLANG_VERSION" - if [[ $CLANG_VERSION == 18 ]]; then - apt-get install -y --no-install-recommends libomp-18-dev + if [[ $CLANG_VERSION -ge 18 ]]; then + apt-get install -y libomp-${CLANG_VERSION}-dev libclang-rt-${CLANG_VERSION}-dev clang-"$CLANG_VERSION" llvm-"$CLANG_VERSION" + else + apt-get install -y --no-install-recommends clang-"$CLANG_VERSION" llvm-"$CLANG_VERSION" fi # Install dev version of LLVM. diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index f3c198044ffed..9357967fbf044 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -65,23 +65,10 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then # Install PyTorch conda deps, as per https://github.com/pytorch/pytorch README if [[ $(uname -m) == "aarch64" ]]; then - CONDA_COMMON_DEPS="astunparse pyyaml setuptools openblas==0.3.25=*openmp* ninja==1.11.1 scons==4.5.2" - - if [ "$ANACONDA_PYTHON_VERSION" = "3.8" ]; then - NUMPY_VERSION=1.24.4 - else - NUMPY_VERSION=1.26.2 - fi + conda_install "openblas==0.3.25=*openmp*" else - CONDA_COMMON_DEPS="astunparse pyyaml mkl=2021.4.0 mkl-include=2021.4.0 setuptools" - - if [ "$ANACONDA_PYTHON_VERSION" = "3.11" ] || [ "$ANACONDA_PYTHON_VERSION" = "3.12" ] || [ "$ANACONDA_PYTHON_VERSION" = "3.13" ]; then - NUMPY_VERSION=1.26.0 - else - NUMPY_VERSION=1.21.2 - fi + conda_install "mkl=2021.4.0 mkl-include=2021.4.0" fi - conda_install ${CONDA_COMMON_DEPS} # Install llvm-8 as it is required to compile llvmlite-0.30.0 from source # and libpython-static for torch deploy @@ -103,8 +90,6 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then # Install some other packages, including those needed for Python test reporting pip_install -r /opt/conda/requirements-ci.txt - pip_install numpy=="$NUMPY_VERSION" - pip_install -U scikit-learn if [ -n "$DOCS" ]; then apt-get update diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index 7df1a5794f92c..ac6ebd828e77c 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -137,6 +137,39 @@ function install_124 { ldconfig } +function install_126 { + echo "Installing CUDA 12.6.2 and cuDNN ${CUDNN_VERSION} and NCCL ${NCCL_VERSION} and cuSparseLt-0.6.2" + rm -rf /usr/local/cuda-12.6 /usr/local/cuda + # install CUDA 12.6.2 in the same container + wget -q https://developer.download.nvidia.com/compute/cuda/12.6.2/local_installers/cuda_12.6.2_560.35.03_linux.run + chmod +x cuda_12.6.2_560.35.03_linux.run + ./cuda_12.6.2_560.35.03_linux.run --toolkit --silent + rm -f cuda_12.6.2_560.35.03_linux.run + rm -f /usr/local/cuda && ln -s /usr/local/cuda-12.6 /usr/local/cuda + + # cuDNN license: https://developer.nvidia.com/cudnn/license_agreement + mkdir tmp_cudnn && cd tmp_cudnn + wget -q https://developer.download.nvidia.com/compute/cudnn/redist/cudnn/linux-x86_64/cudnn-linux-x86_64-${CUDNN_VERSION}_cuda12-archive.tar.xz -O cudnn-linux-x86_64-${CUDNN_VERSION}_cuda12-archive.tar.xz + tar xf cudnn-linux-x86_64-${CUDNN_VERSION}_cuda12-archive.tar.xz + cp -a cudnn-linux-x86_64-${CUDNN_VERSION}_cuda12-archive/include/* /usr/local/cuda/include/ + cp -a cudnn-linux-x86_64-${CUDNN_VERSION}_cuda12-archive/lib/* /usr/local/cuda/lib64/ + cd .. + rm -rf tmp_cudnn + + # NCCL license: https://docs.nvidia.com/deeplearning/nccl/#licenses + # Follow build: https://github.com/NVIDIA/nccl/tree/master?tab=readme-ov-file#build + git clone -b $NCCL_VERSION --depth 1 https://github.com/NVIDIA/nccl.git + cd nccl && make -j src.build + cp -a build/include/* /usr/local/cuda/include/ + cp -a build/lib/* /usr/local/cuda/lib64/ + cd .. + rm -rf nccl + + install_cusparselt_062 + + ldconfig +} + function prune_118 { echo "Pruning CUDA 11.8 and cuDNN" ##################################################################################### @@ -227,12 +260,46 @@ function prune_124 { $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a ##################################################################################### - # CUDA 12.1 prune visual tools + # CUDA 12.4 prune visual tools ##################################################################################### export CUDA_BASE="/usr/local/cuda-12.4/" rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2024.1.0 $CUDA_BASE/nsight-systems-2023.4.4/ } +function prune_126 { + echo "Pruning CUDA 12.6" + ##################################################################################### + # CUDA 12.6 prune static libs + ##################################################################################### + export NVPRUNE="/usr/local/cuda-12.6/bin/nvprune" + export CUDA_LIB_DIR="/usr/local/cuda-12.6/lib64" + + export GENCODE="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + export GENCODE_CUDNN="-gencode arch=compute_50,code=sm_50 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -gencode arch=compute_86,code=sm_86 -gencode arch=compute_90,code=sm_90" + + if [[ -n "$OVERRIDE_GENCODE" ]]; then + export GENCODE=$OVERRIDE_GENCODE + fi + if [[ -n "$OVERRIDE_GENCODE_CUDNN" ]]; then + export GENCODE_CUDNN=$OVERRIDE_GENCODE_CUDNN + fi + + # all CUDA libs except CuDNN and CuBLAS + ls $CUDA_LIB_DIR/ | grep "\.a" | grep -v "culibos" | grep -v "cudart" | grep -v "cudnn" | grep -v "cublas" | grep -v "metis" \ + | xargs -I {} bash -c \ + "echo {} && $NVPRUNE $GENCODE $CUDA_LIB_DIR/{} -o $CUDA_LIB_DIR/{}" + + # prune CuDNN and CuBLAS + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublas_static.a -o $CUDA_LIB_DIR/libcublas_static.a + $NVPRUNE $GENCODE_CUDNN $CUDA_LIB_DIR/libcublasLt_static.a -o $CUDA_LIB_DIR/libcublasLt_static.a + + ##################################################################################### + # CUDA 12.6 prune visual tools + ##################################################################################### + export CUDA_BASE="/usr/local/cuda-12.6/" + rm -rf $CUDA_BASE/libnvvp $CUDA_BASE/nsightee_plugins $CUDA_BASE/nsight-compute-2024.3.2 $CUDA_BASE/nsight-systems-2024.5.1/ +} + # idiomatic parameter and option handling in sh while test $# -gt 0 do @@ -243,6 +310,8 @@ do ;; 12.4) install_124; prune_124 ;; + 12.6) install_126; prune_126 + ;; *) echo "bad argument $1"; exit 1 ;; esac diff --git a/.ci/docker/common/install_onnx.sh b/.ci/docker/common/install_onnx.sh index b99ecac283d64..24071e4bb0729 100755 --- a/.ci/docker/common/install_onnx.sh +++ b/.ci/docker/common/install_onnx.sh @@ -32,7 +32,7 @@ pip_install coloredlogs packaging pip_install onnxruntime==1.18.1 pip_install onnx==1.16.2 -pip_install onnxscript==0.1.0.dev20240831 --no-deps +pip_install onnxscript==0.1.0.dev20241009 --no-deps # required by onnxscript pip_install ml_dtypes diff --git a/.ci/docker/common/install_xpu.sh b/.ci/docker/common/install_xpu.sh index 0be00d3341522..e4a44b0c962b6 100644 --- a/.ci/docker/common/install_xpu.sh +++ b/.ci/docker/common/install_xpu.sh @@ -41,13 +41,16 @@ function install_ubuntu() { libegl-mesa0 libegl1-mesa libegl1-mesa-dev libgbm1 libgl1-mesa-dev libgl1-mesa-dri \ libglapi-mesa libgles2-mesa-dev libglx-mesa0 libigdgmm12 libxatracker2 mesa-va-drivers \ mesa-vdpau-drivers mesa-vulkan-drivers va-driver-all vainfo hwinfo clinfo + if [[ "${XPU_DRIVER_TYPE,,}" == "rolling" ]]; then + apt-get install -y intel-ocloc + fi # Development Packages apt-get install -y libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev # Install Intel Support Packages if [ -n "$XPU_VERSION" ]; then - apt-get install -y intel-for-pytorch-gpu-dev-${XPU_VERSION} intel-pti-dev + apt-get install -y intel-for-pytorch-gpu-dev-${XPU_VERSION} intel-pti-dev-0.9 else - apt-get install -y intel-for-pytorch-gpu-dev intel-pti-dev + apt-get install -y intel-for-pytorch-gpu-dev-0.5 intel-pti-dev-0.9 fi # Cleanup @@ -97,7 +100,7 @@ EOF intel-igc-opencl-devel level-zero-devel intel-gsc-devel libmetee-devel \ level-zero-devel # Install Intel Support Packages - yum install -y intel-for-pytorch-gpu-dev intel-pti-dev + yum install -y intel-for-pytorch-gpu-dev-0.5 intel-pti-dev-0.9 # Cleanup dnf clean all @@ -131,7 +134,7 @@ function install_sles() { zypper install -y libigdfcl-devel intel-igc-cm libigfxcmrt-devel level-zero-devel # Install Intel Support Packages - zypper install -y intel-for-pytorch-gpu-dev intel-pti-dev + zypper install -y intel-for-pytorch-gpu-dev-0.5 intel-pti-dev-0.9 } diff --git a/.ci/docker/conda/Dockerfile b/.ci/docker/conda/Dockerfile index 93fef77f07ff9..5bfc3a37d506c 100644 --- a/.ci/docker/conda/Dockerfile +++ b/.ci/docker/conda/Dockerfile @@ -70,6 +70,10 @@ FROM cuda as cuda12.4 RUN bash ./install_cuda.sh 12.4 ENV DESIRED_CUDA=12.4 +FROM cuda as cuda12.6 +RUN bash ./install_cuda.sh 12.6 +ENV DESIRED_CUDA=12.6 + # Install MNIST test data FROM base as mnist ADD ./common/install_mnist.sh install_mnist.sh @@ -79,6 +83,7 @@ FROM base as all_cuda COPY --from=cuda11.8 /usr/local/cuda-11.8 /usr/local/cuda-11.8 COPY --from=cuda12.1 /usr/local/cuda-12.1 /usr/local/cuda-12.1 COPY --from=cuda12.4 /usr/local/cuda-12.4 /usr/local/cuda-12.4 +COPY --from=cuda12.6 /usr/local/cuda-12.6 /usr/local/cuda-12.6 # Final step FROM ${BASE_TARGET} as final diff --git a/.ci/docker/libtorch/Dockerfile b/.ci/docker/libtorch/Dockerfile index 2c73f55aff319..187e47724aa87 100644 --- a/.ci/docker/libtorch/Dockerfile +++ b/.ci/docker/libtorch/Dockerfile @@ -66,6 +66,11 @@ RUN bash ./install_cuda.sh 12.4 RUN bash ./install_magma.sh 12.4 RUN ln -sf /usr/local/cuda-12.4 /usr/local/cuda +FROM cuda as cuda12.6 +RUN bash ./install_cuda.sh 12.6 +RUN bash ./install_magma.sh 12.6 +RUN ln -sf /usr/local/cuda-12.6 /usr/local/cuda + FROM cpu as rocm ARG PYTORCH_ROCM_ARCH ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} diff --git a/.ci/docker/manywheel/build_scripts/ssl-check.py b/.ci/docker/manywheel/build_scripts/ssl-check.py index b1df3e1346f38..0fd7eb363144a 100644 --- a/.ci/docker/manywheel/build_scripts/ssl-check.py +++ b/.ci/docker/manywheel/build_scripts/ssl-check.py @@ -1,10 +1,12 @@ # cf. https://github.com/pypa/manylinux/issues/53 +import sys +from urllib.request import urlopen + + GOOD_SSL = "https://google.com" BAD_SSL = "https://self-signed.badssl.com" -import sys - print("Testing SSL certificate checking for Python:", sys.version) @@ -12,14 +14,8 @@ print("This version never checks SSL certs; skipping tests") sys.exit(0) -if sys.version_info[0] >= 3: - from urllib.request import urlopen - - EXC = OSError -else: - from urllib import urlopen - EXC = IOError +EXC = OSError print(f"Connecting to {GOOD_SSL} should work") urlopen(GOOD_SSL) diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt index a1b644910f01a..f530c42d09f6d 100644 --- a/.ci/docker/requirements-ci.txt +++ b/.ci/docker/requirements-ci.txt @@ -5,7 +5,7 @@ #Pinned versions: 1.6 #test that import: -boto3==1.19.12 +boto3==1.35.42 #Description: AWS SDK for python #Pinned versions: 1.19.12, 1.16.34 #test that import: @@ -118,7 +118,7 @@ numba==0.55.2 ; python_version == "3.10" #numpy #Description: Provides N-dimensional arrays and linear algebra -#Pinned versions: 1.20 +#Pinned versions: 1.26.2 #test that import: test_view_ops.py, test_unary_ufuncs.py, test_type_promotion.py, #test_type_info.py, test_torch.py, test_tensorexpr_pybind.py, test_tensorexpr.py, #test_tensorboard.py, test_tensor_creation_ops.py, test_static_runtime.py, @@ -128,6 +128,10 @@ numba==0.55.2 ; python_version == "3.10" #test_nn.py, test_namedtensor.py, test_linalg.py, test_jit_cuda_fuser.py, #test_jit.py, test_indexing.py, test_datapipe.py, test_dataloader.py, #test_binary_ufuncs.py +numpy==1.21.2; python_version == "3.9" +numpy==1.22.4; python_version == "3.10" +numpy==1.26.2; python_version == "3.11" or python_version == "3.12" +numpy==2.1.2; python_version >= "3.13" #onnxruntime #Description: scoring engine for Open Neural Network Exchange (ONNX) models @@ -253,7 +257,7 @@ tb-nightly==2.13.0a20230426 #test that import: # needed by torchgen utils -typing-extensions +typing-extensions>=4.10.0 #Description: type hints for python #Pinned versions: #test that import: @@ -322,13 +326,12 @@ lxml==5.0.0 PyGithub==2.3.0 -sympy==1.12.1 ; python_version == "3.8" sympy==1.13.1 ; python_version >= "3.9" #Description: Required by coremltools, also pinned in .github/requirements/pip-requirements-macOS.txt #Pinned versions: #test that import: -onnx==1.16.1 +onnx==1.17.0 #Description: Required by mypy and test_public_bindings.py when checking torch.onnx._internal #Pinned versions: #test that import: @@ -342,3 +345,26 @@ parameterized==0.8.1 #Description: Parameterizes unittests, both the tests themselves and the entire testing class #Pinned versions: #test that import: + +#Description: required for testing torch/distributed/_tools/sac_estimator.py +#Pinned versions: 1.24.0 +#test that import: test_sac_estimator.py + +pwlf==2.2.1 ; python_version >= "3.8" +#Description: required for testing torch/distributed/_tools/sac_estimator.py +#Pinned versions: 2.2.1 +#test that import: test_sac_estimator.py + + +# To build PyTorch itself +astunparse +PyYAML +setuptools + +ninja==1.11.1 ; platform_machine == "aarch64" +scons==4.5.2 ; platform_machine == "aarch64" + +pulp==2.9.0 ; python_version >= "3.8" +#Description: required for testing ilp formulaiton under torch/distributed/_tools +#Pinned versions: 2.9.0 +#test that import: test_sac_ilp.py diff --git a/.ci/libtorch/build.sh b/.ci/libtorch/build.sh new file mode 100644 index 0000000000000..e822feb2674d9 --- /dev/null +++ b/.ci/libtorch/build.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env bash + +# This is mostly just a shim to manywheel/build.sh +# TODO: Make this a dedicated script to build just libtorch + +set -ex + +SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +USE_CUSPARSELT=0 BUILD_PYTHONLESS=1 DESIRED_PYTHON="3.9" ${SCRIPTPATH}/../manywheel/build.sh diff --git a/.ci/manywheel/LICENSE b/.ci/manywheel/LICENSE new file mode 100644 index 0000000000000..7d8f7841a6197 --- /dev/null +++ b/.ci/manywheel/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2016 manylinux + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/.ci/manywheel/build.sh b/.ci/manywheel/build.sh new file mode 100755 index 0000000000000..e79083ee0cdc9 --- /dev/null +++ b/.ci/manywheel/build.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +set -ex + +SCRIPTPATH="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +case "${GPU_ARCH_TYPE:-BLANK}" in + BLANK) + # Legacy behavior for CircleCI + bash "${SCRIPTPATH}/build_cuda.sh" + ;; + cuda) + bash "${SCRIPTPATH}/build_cuda.sh" + ;; + rocm) + bash "${SCRIPTPATH}/build_rocm.sh" + ;; + cpu | cpu-cxx11-abi | cpu-s390x | xpu) + bash "${SCRIPTPATH}/build_cpu.sh" + ;; + *) + echo "Un-recognized GPU_ARCH_TYPE '${GPU_ARCH_TYPE}', exiting..." + exit 1 + ;; +esac diff --git a/.ci/manywheel/build_common.sh b/.ci/manywheel/build_common.sh new file mode 100644 index 0000000000000..c2823b01865b9 --- /dev/null +++ b/.ci/manywheel/build_common.sh @@ -0,0 +1,505 @@ +#!/usr/bin/env bash +# meant to be called only from the neighboring build.sh and build_cpu.sh scripts + +set -ex +SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + + +# Require only one python installation +if [[ -z "$DESIRED_PYTHON" ]]; then + echo "Need to set DESIRED_PYTHON env variable" + exit 1 +fi +if [[ -n "$BUILD_PYTHONLESS" && -z "$LIBTORCH_VARIANT" ]]; then + echo "BUILD_PYTHONLESS is set, so need LIBTORCH_VARIANT to also be set" + echo "LIBTORCH_VARIANT should be one of shared-with-deps shared-without-deps static-with-deps static-without-deps" + exit 1 +fi + +# Function to retry functions that sometimes timeout or have flaky failures +retry () { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +# TODO move this into the Docker images +OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release) +if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then + retry yum install -q -y zip openssl +elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then + retry yum install -q -y zip openssl +elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then + retry dnf install -q -y zip openssl +elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then + # TODO: Remove this once nvidia package repos are back online + # Comment out nvidia repositories to prevent them from getting apt-get updated, see https://github.com/pytorch/pytorch/issues/74968 + # shellcheck disable=SC2046 + sed -i 's/.*nvidia.*/# &/' $(find /etc/apt/ -type f -name "*.list") + + retry apt-get update + retry apt-get -y install zip openssl +fi + +# We use the package name to test the package by passing this to 'pip install' +# This is the env variable that setup.py uses to name the package. Note that +# pip 'normalizes' the name first by changing all - to _ +if [[ -z "$TORCH_PACKAGE_NAME" ]]; then + TORCH_PACKAGE_NAME='torch' +fi + +if [[ -z "$TORCH_NO_PYTHON_PACKAGE_NAME" ]]; then + TORCH_NO_PYTHON_PACKAGE_NAME='torch_no_python' +fi + +TORCH_PACKAGE_NAME="$(echo $TORCH_PACKAGE_NAME | tr '-' '_')" +TORCH_NO_PYTHON_PACKAGE_NAME="$(echo $TORCH_NO_PYTHON_PACKAGE_NAME | tr '-' '_')" +echo "Expecting the built wheels to all be called '$TORCH_PACKAGE_NAME' or '$TORCH_NO_PYTHON_PACKAGE_NAME'" + +# Version: setup.py uses $PYTORCH_BUILD_VERSION.post$PYTORCH_BUILD_NUMBER if +# PYTORCH_BUILD_NUMBER > 1 +build_version="$PYTORCH_BUILD_VERSION" +build_number="$PYTORCH_BUILD_NUMBER" +if [[ -n "$OVERRIDE_PACKAGE_VERSION" ]]; then + # This will be the *exact* version, since build_number<1 + build_version="$OVERRIDE_PACKAGE_VERSION" + build_number=0 +fi +if [[ -z "$build_version" ]]; then + build_version=1.0.0 +fi +if [[ -z "$build_number" ]]; then + build_number=1 +fi +export PYTORCH_BUILD_VERSION=$build_version +export PYTORCH_BUILD_NUMBER=$build_number + +export CMAKE_LIBRARY_PATH="/opt/intel/lib:/lib:$CMAKE_LIBRARY_PATH" +export CMAKE_INCLUDE_PATH="/opt/intel/include:$CMAKE_INCLUDE_PATH" + +if [[ -e /opt/openssl ]]; then + export OPENSSL_ROOT_DIR=/opt/openssl + export CMAKE_INCLUDE_PATH="/opt/openssl/include":$CMAKE_INCLUDE_PATH +fi + +# If given a python version like 3.6m or 2.7mu, convert this to the format we +# expect. The binary CI jobs pass in python versions like this; they also only +# ever pass one python version, so we assume that DESIRED_PYTHON is not a list +# in this case +if [[ -n "$DESIRED_PYTHON" && $DESIRED_PYTHON =~ ([0-9].[0-9]+)t ]]; then + python_digits="$(echo $DESIRED_PYTHON | tr -cd [:digit:])" + py_majmin="${DESIRED_PYTHON}" + DESIRED_PYTHON="cp${python_digits}-cp${python_digits}t" +elif [[ -n "$DESIRED_PYTHON" && "$DESIRED_PYTHON" != cp* ]]; then + python_nodot="$(echo $DESIRED_PYTHON | tr -d m.u)" + DESIRED_PYTHON="cp${python_nodot}-cp${python_nodot}" + if [[ ${python_nodot} -ge 310 ]]; then + py_majmin="${DESIRED_PYTHON:2:1}.${DESIRED_PYTHON:3:2}" + else + py_majmin="${DESIRED_PYTHON:2:1}.${DESIRED_PYTHON:3:1}" + fi +fi + +pydir="/opt/python/$DESIRED_PYTHON" +export PATH="$pydir/bin:$PATH" +echo "Will build for Python version: ${DESIRED_PYTHON} with ${python_installation}" + +mkdir -p /tmp/$WHEELHOUSE_DIR + +export PATCHELF_BIN=/usr/local/bin/patchelf +patchelf_version=$($PATCHELF_BIN --version) +echo "patchelf version: " $patchelf_version +if [[ "$patchelf_version" == "patchelf 0.9" ]]; then + echo "Your patchelf version is too old. Please use version >= 0.10." + exit 1 +fi + +######################################################## +# Compile wheels as well as libtorch +####################################################### +if [[ -z "$PYTORCH_ROOT" ]]; then + echo "Need to set PYTORCH_ROOT env variable" + exit 1 +fi +pushd "$PYTORCH_ROOT" +python setup.py clean +retry pip install -qr requirements.txt +case ${DESIRED_PYTHON} in + cp31*) + retry pip install -q --pre numpy==2.1.0 + ;; + # Should catch 3.9+ + *) + retry pip install -q --pre numpy==2.0.2 + ;; +esac + +if [[ "$DESIRED_DEVTOOLSET" == *"cxx11-abi"* ]]; then + export _GLIBCXX_USE_CXX11_ABI=1 +else + export _GLIBCXX_USE_CXX11_ABI=0 +fi + +if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then + echo "Calling build_amd.py at $(date)" + python tools/amd_build/build_amd.py +fi + +# This value comes from binary_linux_build.sh (and should only be set to true +# for master / release branches) +BUILD_DEBUG_INFO=${BUILD_DEBUG_INFO:=0} + +if [[ $BUILD_DEBUG_INFO == "1" ]]; then + echo "Building wheel and debug info" +else + echo "BUILD_DEBUG_INFO was not set, skipping debug info" +fi + +if [[ "$DISABLE_RCCL" = 1 ]]; then + echo "Disabling NCCL/RCCL in pyTorch" + USE_RCCL=0 + USE_NCCL=0 + USE_KINETO=0 +else + USE_RCCL=1 + USE_NCCL=1 + USE_KINETO=1 +fi + +echo "Calling setup.py bdist at $(date)" + +if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + echo "Calling setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" + time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ + BUILD_LIBTORCH_WHL=1 BUILD_PYTHON_ONLY=0 \ + BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ + USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ + python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR + echo "Finished setup.py bdist_wheel for split build (BUILD_LIBTORCH_WHL)" + echo "Calling setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" + time EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ + BUILD_LIBTORCH_WHL=0 BUILD_PYTHON_ONLY=1 \ + BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ + USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ + python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR --cmake + echo "Finished setup.py bdist_wheel for split build (BUILD_PYTHON_ONLY)" +else + time CMAKE_ARGS=${CMAKE_ARGS[@]} \ + EXTRA_CAFFE2_CMAKE_FLAGS=${EXTRA_CAFFE2_CMAKE_FLAGS[@]} \ + BUILD_LIBTORCH_CPU_WITH_DEBUG=$BUILD_DEBUG_INFO \ + USE_NCCL=${USE_NCCL} USE_RCCL=${USE_RCCL} USE_KINETO=${USE_KINETO} \ + python setup.py bdist_wheel -d /tmp/$WHEELHOUSE_DIR +fi +echo "Finished setup.py bdist at $(date)" + +# Build libtorch packages +if [[ -n "$BUILD_PYTHONLESS" ]]; then + # Now build pythonless libtorch + # Note - just use whichever python we happen to be on + python setup.py clean + + if [[ $LIBTORCH_VARIANT = *"static"* ]]; then + STATIC_CMAKE_FLAG="-DTORCH_STATIC=1" + fi + + mkdir -p build + pushd build + echo "Calling tools/build_libtorch.py at $(date)" + time CMAKE_ARGS=${CMAKE_ARGS[@]} \ + EXTRA_CAFFE2_CMAKE_FLAGS="${EXTRA_CAFFE2_CMAKE_FLAGS[@]} $STATIC_CMAKE_FLAG" \ + python ../tools/build_libtorch.py + echo "Finished tools/build_libtorch.py at $(date)" + popd + + mkdir -p libtorch/{lib,bin,include,share} + cp -r build/build/lib libtorch/ + + # for now, the headers for the libtorch package will just be copied in + # from one of the wheels (this is from when this script built multiple + # wheels at once) + ANY_WHEEL=$(ls /tmp/$WHEELHOUSE_DIR/torch*.whl | head -n1) + unzip -d any_wheel $ANY_WHEEL + if [[ -d any_wheel/torch/include ]]; then + cp -r any_wheel/torch/include libtorch/ + else + cp -r any_wheel/torch/lib/include libtorch/ + fi + cp -r any_wheel/torch/share/cmake libtorch/share/ + rm -rf any_wheel + + echo $PYTORCH_BUILD_VERSION > libtorch/build-version + echo "$(pushd $PYTORCH_ROOT && git rev-parse HEAD)" > libtorch/build-hash + + mkdir -p /tmp/$LIBTORCH_HOUSE_DIR + + if [[ "$DESIRED_DEVTOOLSET" == *"cxx11-abi"* ]]; then + LIBTORCH_ABI="cxx11-abi-" + else + LIBTORCH_ABI= + fi + + zip -rq /tmp/$LIBTORCH_HOUSE_DIR/libtorch-$LIBTORCH_ABI$LIBTORCH_VARIANT-$PYTORCH_BUILD_VERSION.zip libtorch + cp /tmp/$LIBTORCH_HOUSE_DIR/libtorch-$LIBTORCH_ABI$LIBTORCH_VARIANT-$PYTORCH_BUILD_VERSION.zip \ + /tmp/$LIBTORCH_HOUSE_DIR/libtorch-$LIBTORCH_ABI$LIBTORCH_VARIANT-latest.zip +fi + +popd + +####################################################################### +# ADD DEPENDENCIES INTO THE WHEEL +# +# auditwheel repair doesn't work correctly and is buggy +# so manually do the work of copying dependency libs and patchelfing +# and fixing RECORDS entries correctly +###################################################################### + +fname_with_sha256() { + HASH=$(sha256sum $1 | cut -c1-8) + DIRNAME=$(dirname $1) + BASENAME=$(basename $1) + # Do not rename nvrtc-builtins.so as they are dynamically loaded + # by libnvrtc.so + # Similarly don't mangle libcudnn and libcublas library names + if [[ $BASENAME == "libnvrtc-builtins.s"* || $BASENAME == "libcudnn"* || $BASENAME == "libcublas"* ]]; then + echo $1 + else + INITNAME=$(echo $BASENAME | cut -f1 -d".") + ENDNAME=$(echo $BASENAME | cut -f 2- -d".") + echo "$DIRNAME/$INITNAME-$HASH.$ENDNAME" + fi +} + +fname_without_so_number() { + LINKNAME=$(echo $1 | sed -e 's/\.so.*/.so/g') + echo "$LINKNAME" +} + +make_wheel_record() { + FPATH=$1 + if echo $FPATH | grep RECORD >/dev/null 2>&1; then + # if the RECORD file, then + echo "$FPATH,," + else + HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g') + FSIZE=$(ls -nl $FPATH | awk '{print $5}') + echo "$FPATH,sha256=$HASH,$FSIZE" + fi +} + +replace_needed_sofiles() { + find $1 -name '*.so*' | while read sofile; do + origname=$2 + patchedname=$3 + if [[ "$origname" != "$patchedname" ]] || [[ "$DESIRED_CUDA" == *"rocm"* ]]; then + set +e + origname=$($PATCHELF_BIN --print-needed $sofile | grep "$origname.*") + ERRCODE=$? + set -e + if [ "$ERRCODE" -eq "0" ]; then + echo "patching $sofile entry $origname to $patchedname" + $PATCHELF_BIN --replace-needed $origname $patchedname $sofile + fi + fi + done +} + +echo 'Built this wheel:' +ls /tmp/$WHEELHOUSE_DIR +mkdir -p "/$WHEELHOUSE_DIR" +mv /tmp/$WHEELHOUSE_DIR/torch*linux*.whl /$WHEELHOUSE_DIR/ + +if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + mv /tmp/$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/ || true +fi + +if [[ -n "$BUILD_PYTHONLESS" ]]; then + mkdir -p /$LIBTORCH_HOUSE_DIR + mv /tmp/$LIBTORCH_HOUSE_DIR/*.zip /$LIBTORCH_HOUSE_DIR + rm -rf /tmp/$LIBTORCH_HOUSE_DIR +fi +rm -rf /tmp/$WHEELHOUSE_DIR +rm -rf /tmp_dir +mkdir /tmp_dir +pushd /tmp_dir + +for pkg in /$WHEELHOUSE_DIR/torch_no_python*.whl /$WHEELHOUSE_DIR/torch*linux*.whl /$LIBTORCH_HOUSE_DIR/libtorch*.zip; do + + # if the glob didn't match anything + if [[ ! -e $pkg ]]; then + continue + fi + + rm -rf tmp + mkdir -p tmp + cd tmp + cp $pkg . + + unzip -q $(basename $pkg) + rm -f $(basename $pkg) + + if [[ -d torch ]]; then + PREFIX=torch + else + PREFIX=libtorch + fi + + if [[ $pkg != *"without-deps"* ]]; then + # copy over needed dependent .so files over and tag them with their hash + patched=() + for filepath in "${DEPS_LIST[@]}"; do + filename=$(basename $filepath) + destpath=$PREFIX/lib/$filename + if [[ "$filepath" != "$destpath" ]]; then + cp $filepath $destpath + fi + + # ROCm workaround for roctracer dlopens + if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then + patchedpath=$(fname_without_so_number $destpath) + # Keep the so number for XPU dependencies + elif [[ "$DESIRED_CUDA" == *"xpu"* ]]; then + patchedpath=$destpath + else + patchedpath=$(fname_with_sha256 $destpath) + fi + patchedname=$(basename $patchedpath) + if [[ "$destpath" != "$patchedpath" ]]; then + mv $destpath $patchedpath + fi + patched+=("$patchedname") + echo "Copied $filepath to $patchedpath" + done + + echo "patching to fix the so names to the hashed names" + for ((i=0;i<${#DEPS_LIST[@]};++i)); do + replace_needed_sofiles $PREFIX ${DEPS_SONAME[i]} ${patched[i]} + # do the same for caffe2, if it exists + if [[ -d caffe2 ]]; then + replace_needed_sofiles caffe2 ${DEPS_SONAME[i]} ${patched[i]} + fi + done + + # copy over needed auxiliary files + for ((i=0;i<${#DEPS_AUX_SRCLIST[@]};++i)); do + srcpath=${DEPS_AUX_SRCLIST[i]} + dstpath=$PREFIX/${DEPS_AUX_DSTLIST[i]} + mkdir -p $(dirname $dstpath) + cp $srcpath $dstpath + done + fi + + # set RPATH of _C.so and similar to $ORIGIN, $ORIGIN/lib + find $PREFIX -maxdepth 1 -type f -name "*.so*" | while read sofile; do + echo "Setting rpath of $sofile to ${C_SO_RPATH:-'$ORIGIN:$ORIGIN/lib'}" + $PATCHELF_BIN --set-rpath ${C_SO_RPATH:-'$ORIGIN:$ORIGIN/lib'} ${FORCE_RPATH:-} $sofile + $PATCHELF_BIN --print-rpath $sofile + done + + # set RPATH of lib/ files to $ORIGIN + find $PREFIX/lib -maxdepth 1 -type f -name "*.so*" | while read sofile; do + echo "Setting rpath of $sofile to ${LIB_SO_RPATH:-'$ORIGIN'}" + $PATCHELF_BIN --set-rpath ${LIB_SO_RPATH:-'$ORIGIN'} ${FORCE_RPATH:-} $sofile + $PATCHELF_BIN --print-rpath $sofile + done + + # regenerate the RECORD file with new hashes + record_file=$(echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/RECORD/g') + if [[ -e $record_file ]]; then + echo "Generating new record file $record_file" + : > "$record_file" + # generate records for folders in wheel + find * -type f | while read fname; do + make_wheel_record "$fname" >>"$record_file" + done + fi + + if [[ $BUILD_DEBUG_INFO == "1" ]]; then + pushd "$PREFIX/lib" + + # Duplicate library into debug lib + cp libtorch_cpu.so libtorch_cpu.so.dbg + + # Keep debug symbols on debug lib + strip --only-keep-debug libtorch_cpu.so.dbg + + # Remove debug info from release lib + strip --strip-debug libtorch_cpu.so + + objcopy libtorch_cpu.so --add-gnu-debuglink=libtorch_cpu.so.dbg + + # Zip up debug info + mkdir -p /tmp/debug + mv libtorch_cpu.so.dbg /tmp/debug/libtorch_cpu.so.dbg + CRC32=$(objcopy --dump-section .gnu_debuglink=>(tail -c4 | od -t x4 -An | xargs echo) libtorch_cpu.so) + + pushd /tmp + PKG_NAME=$(basename "$pkg" | sed 's/\.whl$//g') + zip /tmp/debug-whl-libtorch-"$PKG_NAME"-"$CRC32".zip /tmp/debug/libtorch_cpu.so.dbg + cp /tmp/debug-whl-libtorch-"$PKG_NAME"-"$CRC32".zip "$PYTORCH_FINAL_PACKAGE_DIR" + popd + + popd + fi + + # zip up the wheel back + zip -rq $(basename $pkg) $PREIX* + + # replace original wheel + rm -f $pkg + mv $(basename $pkg) $pkg + cd .. + rm -rf tmp +done + +# Copy wheels to host machine for persistence before testing +if [[ -n "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then + mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" || true + if [[ -n "$BUILD_PYTHONLESS" ]]; then + cp /$LIBTORCH_HOUSE_DIR/libtorch*.zip "$PYTORCH_FINAL_PACKAGE_DIR" + else + cp /$WHEELHOUSE_DIR/torch*.whl "$PYTORCH_FINAL_PACKAGE_DIR" + fi +fi + +# remove stuff before testing +rm -rf /opt/rh +if ls /usr/local/cuda* >/dev/null 2>&1; then + rm -rf /usr/local/cuda* +fi + + +# Test that all the wheels work +if [[ -z "$BUILD_PYTHONLESS" ]]; then + export OMP_NUM_THREADS=4 # on NUMA machines this takes too long + pushd $PYTORCH_ROOT/test + + # Install the wheel for this Python version + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + pip uninstall -y "$TORCH_NO_PYTHON_PACKAGE_NAME" || true + fi + + pip uninstall -y "$TORCH_PACKAGE_NAME" + + if [[ "$USE_SPLIT_BUILD" == "true" ]]; then + pip install "$TORCH_NO_PYTHON_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v + fi + + pip install "$TORCH_PACKAGE_NAME" --no-index -f /$WHEELHOUSE_DIR --no-dependencies -v + + # Print info on the libraries installed in this wheel + # Rather than adjust find command to skip non-library files with an embedded *.so* in their name, + # since this is only for reporting purposes, we add the || true to the ldd command. + installed_libraries=($(find "$pydir/lib/python${py_majmin}/site-packages/torch/" -name '*.so*')) + echo "The wheel installed all of the libraries: ${installed_libraries[@]}" + for installed_lib in "${installed_libraries[@]}"; do + ldd "$installed_lib" || true + done + + # Run the tests + echo "$(date) :: Running tests" + pushd "$PYTORCH_ROOT" + + #TODO: run_tests.sh and check_binary.sh should be moved to pytorch/pytorch project + LD_LIBRARY_PATH=/usr/local/nvidia/lib64 \ + "/builder/run_tests.sh" manywheel "${py_majmin}" "$DESIRED_CUDA" + popd + echo "$(date) :: Finished tests" +fi diff --git a/.ci/manywheel/build_cpu.sh b/.ci/manywheel/build_cpu.sh new file mode 100755 index 0000000000000..5b8277e44f9e6 --- /dev/null +++ b/.ci/manywheel/build_cpu.sh @@ -0,0 +1,99 @@ +#!/usr/bin/env bash + +set -ex + +GPU_ARCH_TYPE=${GPU_ARCH_TYPE:-cpu} + +export TH_BINARY_BUILD=1 +export USE_CUDA=0 + +# Keep an array of cmake variables to add to +if [[ -z "$CMAKE_ARGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build() + CMAKE_ARGS=() +fi +if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build_caffe2() + EXTRA_CAFFE2_CMAKE_FLAGS=() +fi + +DIR_SUFFIX=cpu +if [[ "$GPU_ARCH_TYPE" == "xpu" ]]; then + DIR_SUFFIX=xpu + # Refer https://www.intel.com/content/www/us/en/developer/articles/tool/pytorch-prerequisites-for-intel-gpu/2-5.html + source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh + source /opt/intel/oneapi/pti/latest/env/vars.sh + export USE_STATIC_MKL=1 +fi + +WHEELHOUSE_DIR="wheelhouse$DIR_SUFFIX" +LIBTORCH_HOUSE_DIR="libtorch_house$DIR_SUFFIX" +if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then + if [[ -z "$BUILD_PYTHONLESS" ]]; then + PYTORCH_FINAL_PACKAGE_DIR="/remote/wheelhouse$DIR_SUFFIX" + else + PYTORCH_FINAL_PACKAGE_DIR="/remote/libtorch_house$DIR_SUFFIX" + fi +fi +mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" || true + +OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release) +if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" +elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" +elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" +elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then + if [[ "$(uname -m)" == "s390x" ]]; then + LIBGOMP_PATH="/usr/lib/s390x-linux-gnu/libgomp.so.1" + else + LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1" + fi +fi + +DEPS_LIST=( + "$LIBGOMP_PATH" +) + +DEPS_SONAME=( + "libgomp.so.1" +) + +if [[ "$GPU_ARCH_TYPE" == "xpu" ]]; then + echo "Bundling with xpu support package libs." + DEPS_LIST+=( + "/opt/intel/oneapi/compiler/latest/lib/libsycl-preview.so.7" + "/opt/intel/oneapi/compiler/latest/lib/libOpenCL.so.1" + "/opt/intel/oneapi/compiler/latest/lib/libxptifw.so" + "/opt/intel/oneapi/compiler/latest/lib/libsvml.so" + "/opt/intel/oneapi/compiler/latest/lib/libirng.so" + "/opt/intel/oneapi/compiler/latest/lib/libimf.so" + "/opt/intel/oneapi/compiler/latest/lib/libintlc.so.5" + "/opt/intel/oneapi/compiler/latest/lib/libpi_level_zero.so" + "/opt/intel/oneapi/pti/latest/lib/libpti_view.so.0.9" + "/opt/intel/oneapi/pti/latest/lib/libpti.so.0.9" + ) + DEPS_SONAME+=( + "libsycl-preview.so.7" + "libOpenCL.so.1" + "libxptifw.so" + "libsvml.so" + "libirng.so" + "libimf.so" + "libintlc.so.5" + "libpi_level_zero.so" + "libpti_view.so.0.9" + "libpti.so.0.9" + ) +fi + +rm -rf /usr/local/cuda* + +SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" +if [[ -z "$BUILD_PYTHONLESS" ]]; then + BUILD_SCRIPT=build_common.sh +else + BUILD_SCRIPT=build_libtorch.sh +fi +source ${SOURCE_DIR}/${BUILD_SCRIPT} diff --git a/.ci/manywheel/build_cuda.sh b/.ci/manywheel/build_cuda.sh new file mode 100644 index 0000000000000..4eda14a393da7 --- /dev/null +++ b/.ci/manywheel/build_cuda.sh @@ -0,0 +1,290 @@ +#!/usr/bin/env bash + +set -ex + +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P ))" + +export TORCH_NVCC_FLAGS="-Xfatbin -compress-all" +export NCCL_ROOT_DIR=/usr/local/cuda +export TH_BINARY_BUILD=1 +export USE_STATIC_CUDNN=1 +export USE_STATIC_NCCL=1 +export ATEN_STATIC_CUDA=1 +export USE_CUDA_STATIC_LINK=1 +export INSTALL_TEST=0 # dont install test binaries into site-packages +export USE_CUPTI_SO=0 +export USE_CUSPARSELT=${USE_CUSPARSELT:-1} # Enable if not disabled by libtorch build + +# Keep an array of cmake variables to add to +if [[ -z "$CMAKE_ARGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build() + CMAKE_ARGS=() +fi +if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build_caffe2() + EXTRA_CAFFE2_CMAKE_FLAGS=() +fi + +# Determine CUDA version and architectures to build for +# +# NOTE: We should first check `DESIRED_CUDA` when determining `CUDA_VERSION`, +# because in some cases a single Docker image can have multiple CUDA versions +# on it, and `nvcc --version` might not show the CUDA version we want. +if [[ -n "$DESIRED_CUDA" ]]; then + # If the DESIRED_CUDA already matches the format that we expect + if [[ ${DESIRED_CUDA} =~ ^[0-9]+\.[0-9]+$ ]]; then + CUDA_VERSION=${DESIRED_CUDA} + else + # cu90, cu92, cu100, cu101 + if [[ ${#DESIRED_CUDA} -eq 4 ]]; then + CUDA_VERSION="${DESIRED_CUDA:2:1}.${DESIRED_CUDA:3:1}" + elif [[ ${#DESIRED_CUDA} -eq 5 ]]; then + CUDA_VERSION="${DESIRED_CUDA:2:2}.${DESIRED_CUDA:4:1}" + fi + fi + echo "Using CUDA $CUDA_VERSION as determined by DESIRED_CUDA" + + # There really has to be a better way to do this - eli + # Possibly limiting builds to specific cuda versions be delimiting images would be a choice + if [[ "$OS_NAME" == *"Ubuntu"* ]]; then + echo "Switching to CUDA version ${DESIRED_CUDA}" + /builder/conda/switch_cuda_version.sh "${DESIRED_CUDA}" + fi +else + CUDA_VERSION=$(nvcc --version|grep release|cut -f5 -d" "|cut -f1 -d",") + echo "CUDA $CUDA_VERSION Detected" +fi + +cuda_version_nodot=$(echo $CUDA_VERSION | tr -d '.') + +TORCH_CUDA_ARCH_LIST="5.0;6.0;7.0;7.5;8.0;8.6" +case ${CUDA_VERSION} in + 12.4) + if [[ "$GPU_ARCH_TYPE" = "cuda-aarch64" ]]; then + TORCH_CUDA_ARCH_LIST="9.0" + else + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0+PTX" + fi + EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") + ;; + 12.1) + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};9.0" + EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") + ;; + 11.8) + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};3.7;9.0" + EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") + ;; + 11.[67]) + TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST};3.7" + EXTRA_CAFFE2_CMAKE_FLAGS+=("-DATEN_NO_TEST=ON") + ;; + *) + echo "unknown cuda version $CUDA_VERSION" + exit 1 + ;; +esac + +export TORCH_CUDA_ARCH_LIST=${TORCH_CUDA_ARCH_LIST} +echo "${TORCH_CUDA_ARCH_LIST}" + +# Package directories +WHEELHOUSE_DIR="wheelhouse$cuda_version_nodot" +LIBTORCH_HOUSE_DIR="libtorch_house$cuda_version_nodot" +if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then + if [[ -z "$BUILD_PYTHONLESS" ]]; then + PYTORCH_FINAL_PACKAGE_DIR="/remote/wheelhouse$cuda_version_nodot" + else + PYTORCH_FINAL_PACKAGE_DIR="/remote/libtorch_house$cuda_version_nodot" + fi +fi +mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" || true + +OS_NAME=$(awk -F= '/^NAME/{print $2}' /etc/os-release) +if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" +elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" +elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" +elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then + LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1" +fi + +DEPS_LIST=( + "$LIBGOMP_PATH" +) +DEPS_SONAME=( + "libgomp.so.1" +) + +if [[ $USE_CUSPARSELT == "1" ]]; then + DEPS_SONAME+=( + "libcusparseLt.so.0" + ) + DEPS_LIST+=( + "/usr/local/cuda/lib64/libcusparseLt.so.0" + ) +fi + +if [[ $CUDA_VERSION == "12.1" || $CUDA_VERSION == "12.4" ]]; then + export USE_STATIC_CUDNN=0 + # Try parallelizing nvcc as well + export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2" + + if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then + echo "Bundling with cudnn and cublas." + DEPS_LIST+=( + "/usr/local/cuda/lib64/libcudnn_adv.so.9" + "/usr/local/cuda/lib64/libcudnn_cnn.so.9" + "/usr/local/cuda/lib64/libcudnn_graph.so.9" + "/usr/local/cuda/lib64/libcudnn_ops.so.9" + "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9" + "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9" + "/usr/local/cuda/lib64/libcudnn_heuristic.so.9" + "/usr/local/cuda/lib64/libcudnn.so.9" + "/usr/local/cuda/lib64/libcublas.so.12" + "/usr/local/cuda/lib64/libcublasLt.so.12" + "/usr/local/cuda/lib64/libcudart.so.12" + "/usr/local/cuda/lib64/libnvToolsExt.so.1" + "/usr/local/cuda/lib64/libnvrtc.so.12" + "/usr/local/cuda/lib64/libnvrtc-builtins.so" + ) + DEPS_SONAME+=( + "libcudnn_adv.so.9" + "libcudnn_cnn.so.9" + "libcudnn_graph.so.9" + "libcudnn_ops.so.9" + "libcudnn_engines_runtime_compiled.so.9" + "libcudnn_engines_precompiled.so.9" + "libcudnn_heuristic.so.9" + "libcudnn.so.9" + "libcublas.so.12" + "libcublasLt.so.12" + "libcudart.so.12" + "libnvToolsExt.so.1" + "libnvrtc.so.12" + "libnvrtc-builtins.so" + ) + else + echo "Using nvidia libs from pypi." + CUDA_RPATHS=( + '$ORIGIN/../../nvidia/cublas/lib' + '$ORIGIN/../../nvidia/cuda_cupti/lib' + '$ORIGIN/../../nvidia/cuda_nvrtc/lib' + '$ORIGIN/../../nvidia/cuda_runtime/lib' + '$ORIGIN/../../nvidia/cudnn/lib' + '$ORIGIN/../../nvidia/cufft/lib' + '$ORIGIN/../../nvidia/curand/lib' + '$ORIGIN/../../nvidia/cusolver/lib' + '$ORIGIN/../../nvidia/cusparse/lib' + '$ORIGIN/../../nvidia/nccl/lib' + '$ORIGIN/../../nvidia/nvtx/lib' + ) + CUDA_RPATHS=$(IFS=: ; echo "${CUDA_RPATHS[*]}") + export C_SO_RPATH=$CUDA_RPATHS':$ORIGIN:$ORIGIN/lib' + export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' + export FORCE_RPATH="--force-rpath" + export USE_STATIC_NCCL=0 + export USE_SYSTEM_NCCL=1 + export ATEN_STATIC_CUDA=0 + export USE_CUDA_STATIC_LINK=0 + export USE_CUPTI_SO=1 + export NCCL_INCLUDE_DIR="/usr/local/cuda/include/" + export NCCL_LIB_DIR="/usr/local/cuda/lib64/" + fi +elif [[ $CUDA_VERSION == "11.8" ]]; then + export USE_STATIC_CUDNN=0 + # Try parallelizing nvcc as well + export TORCH_NVCC_FLAGS="-Xfatbin -compress-all --threads 2" + # Bundle ptxas into the wheel, see https://github.com/pytorch/pytorch/pull/119750 + export BUILD_BUNDLE_PTXAS=1 + + if [[ -z "$PYTORCH_EXTRA_INSTALL_REQUIREMENTS" ]]; then + echo "Bundling with cudnn and cublas." + DEPS_LIST+=( + "/usr/local/cuda/lib64/libcudnn_adv.so.9" + "/usr/local/cuda/lib64/libcudnn_cnn.so.9" + "/usr/local/cuda/lib64/libcudnn_graph.so.9" + "/usr/local/cuda/lib64/libcudnn_ops.so.9" + "/usr/local/cuda/lib64/libcudnn_engines_runtime_compiled.so.9" + "/usr/local/cuda/lib64/libcudnn_engines_precompiled.so.9" + "/usr/local/cuda/lib64/libcudnn_heuristic.so.9" + "/usr/local/cuda/lib64/libcudnn.so.9" + "/usr/local/cuda/lib64/libcublas.so.11" + "/usr/local/cuda/lib64/libcublasLt.so.11" + "/usr/local/cuda/lib64/libcudart.so.11.0" + "/usr/local/cuda/lib64/libnvToolsExt.so.1" + "/usr/local/cuda/lib64/libnvrtc.so.11.2" # this is not a mistake, it links to more specific cuda version + "/usr/local/cuda/lib64/libnvrtc-builtins.so.11.8" + ) + DEPS_SONAME+=( + "libcudnn_adv.so.9" + "libcudnn_cnn.so.9" + "libcudnn_graph.so.9" + "libcudnn_ops.so.9" + "libcudnn_engines_runtime_compiled.so.9" + "libcudnn_engines_precompiled.so.9" + "libcudnn_heuristic.so.9" + "libcudnn.so.9" + "libcublas.so.11" + "libcublasLt.so.11" + "libcudart.so.11.0" + "libnvToolsExt.so.1" + "libnvrtc.so.11.2" + "libnvrtc-builtins.so.11.8" + ) + else + echo "Using nvidia libs from pypi." + CUDA_RPATHS=( + '$ORIGIN/../../nvidia/cublas/lib' + '$ORIGIN/../../nvidia/cuda_cupti/lib' + '$ORIGIN/../../nvidia/cuda_nvrtc/lib' + '$ORIGIN/../../nvidia/cuda_runtime/lib' + '$ORIGIN/../../nvidia/cudnn/lib' + '$ORIGIN/../../nvidia/cufft/lib' + '$ORIGIN/../../nvidia/curand/lib' + '$ORIGIN/../../nvidia/cusolver/lib' + '$ORIGIN/../../nvidia/cusparse/lib' + '$ORIGIN/../../nvidia/nccl/lib' + '$ORIGIN/../../nvidia/nvtx/lib' + ) + CUDA_RPATHS=$(IFS=: ; echo "${CUDA_RPATHS[*]}") + export C_SO_RPATH=$CUDA_RPATHS':$ORIGIN:$ORIGIN/lib' + export LIB_SO_RPATH=$CUDA_RPATHS':$ORIGIN' + export FORCE_RPATH="--force-rpath" + export USE_STATIC_NCCL=0 + export USE_SYSTEM_NCCL=1 + export ATEN_STATIC_CUDA=0 + export USE_CUDA_STATIC_LINK=0 + export USE_CUPTI_SO=1 + export NCCL_INCLUDE_DIR="/usr/local/cuda/include/" + export NCCL_LIB_DIR="/usr/local/cuda/lib64/" + fi +else + echo "Unknown cuda version $CUDA_VERSION" + exit 1 +fi + +# builder/test.sh requires DESIRED_CUDA to know what tests to exclude +export DESIRED_CUDA="$cuda_version_nodot" + +# Switch `/usr/local/cuda` to the desired CUDA version +rm -rf /usr/local/cuda || true +ln -s "/usr/local/cuda-${CUDA_VERSION}" /usr/local/cuda + +# Switch `/usr/local/magma` to the desired CUDA version +rm -rf /usr/local/magma || true +ln -s /usr/local/cuda-${CUDA_VERSION}/magma /usr/local/magma + +export CUDA_VERSION=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) # 10.0.130 +export CUDA_VERSION_SHORT=$(ls /usr/local/cuda/lib64/libcudart.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev | cut -f1,2 -d".") # 10.0 +export CUDNN_VERSION=$(ls /usr/local/cuda/lib64/libcudnn.so.*|sort|tac | head -1 | rev | cut -d"." -f -3 | rev) + +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +if [[ -z "$BUILD_PYTHONLESS" ]]; then + BUILD_SCRIPT=build_common.sh +else + BUILD_SCRIPT=build_libtorch.sh +fi +source $SCRIPTPATH/${BUILD_SCRIPT} diff --git a/.ci/manywheel/build_libtorch.sh b/.ci/manywheel/build_libtorch.sh new file mode 100644 index 0000000000000..fd330f6435c8c --- /dev/null +++ b/.ci/manywheel/build_libtorch.sh @@ -0,0 +1,353 @@ +#!/usr/bin/env bash +# meant to be called only from the neighboring build.sh and build_cpu.sh scripts + +set -e pipefail +SOURCE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null && pwd )" + +# Require only one python installation +if [[ -z "$DESIRED_PYTHON" ]]; then + echo "Need to set DESIRED_PYTHON env variable" + exit 1 +fi +if [[ -n "$BUILD_PYTHONLESS" && -z "$LIBTORCH_VARIANT" ]]; then + echo "BUILD_PYTHONLESS is set, so need LIBTORCH_VARIANT to also be set" + echo "LIBTORCH_VARIANT should be one of shared-with-deps shared-without-deps static-with-deps static-without-deps" + exit 1 +fi + +# Function to retry functions that sometimes timeout or have flaky failures +retry () { + $* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*) +} + +# TODO move this into the Docker images +OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release` +if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then + retry yum install -q -y zip openssl +elif [[ "$OS_NAME" == *"AlmaLinux"* ]]; then + retry yum install -q -y zip openssl +elif [[ "$OS_NAME" == *"Red Hat Enterprise Linux"* ]]; then + retry dnf install -q -y zip openssl +elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then + # TODO: Remove this once nvidia package repos are back online + # Comment out nvidia repositories to prevent them from getting apt-get updated, see https://github.com/pytorch/pytorch/issues/74968 + # shellcheck disable=SC2046 + sed -i 's/.*nvidia.*/# &/' $(find /etc/apt/ -type f -name "*.list") + retry apt-get update + retry apt-get -y install zip openssl +fi + +# Version: setup.py uses $PYTORCH_BUILD_VERSION.post$PYTORCH_BUILD_NUMBER if +# PYTORCH_BUILD_NUMBER > 1 +build_version="$PYTORCH_BUILD_VERSION" +build_number="$PYTORCH_BUILD_NUMBER" +if [[ -n "$OVERRIDE_PACKAGE_VERSION" ]]; then + # This will be the *exact* version, since build_number<1 + build_version="$OVERRIDE_PACKAGE_VERSION" + build_number=0 +fi +if [[ -z "$build_version" ]]; then + build_version=1.0.0 +fi +if [[ -z "$build_number" ]]; then + build_number=1 +fi +export PYTORCH_BUILD_VERSION=$build_version +export PYTORCH_BUILD_NUMBER=$build_number + +export CMAKE_LIBRARY_PATH="/opt/intel/lib:/lib:$CMAKE_LIBRARY_PATH" +export CMAKE_INCLUDE_PATH="/opt/intel/include:$CMAKE_INCLUDE_PATH" + +# set OPENSSL_ROOT_DIR=/opt/openssl if it exists +if [[ -e /opt/openssl ]]; then + export OPENSSL_ROOT_DIR=/opt/openssl + export CMAKE_INCLUDE_PATH="/opt/openssl/include":$CMAKE_INCLUDE_PATH +fi + +# If given a python version like 3.6m or 2.7mu, convert this to the format we +# expect. The binary CI jobs pass in python versions like this; they also only +# ever pass one python version, so we assume that DESIRED_PYTHON is not a list +# in this case +if [[ -n "$DESIRED_PYTHON" && "$DESIRED_PYTHON" != cp* ]]; then + python_nodot="$(echo $DESIRED_PYTHON | tr -d m.u)" + DESIRED_PYTHON="cp${python_nodot}-cp${python_nodot}" +fi +pydir="/opt/python/$DESIRED_PYTHON" +export PATH="$pydir/bin:$PATH" + +export PATCHELF_BIN=/usr/local/bin/patchelf +patchelf_version=`$PATCHELF_BIN --version` +echo "patchelf version: " $patchelf_version +if [[ "$patchelf_version" == "patchelf 0.9" ]]; then + echo "Your patchelf version is too old. Please use version >= 0.10." + exit 1 +fi + +######################################################## +# Compile wheels as well as libtorch +####################################################### +if [[ -z "$PYTORCH_ROOT" ]]; then + echo "Need to set PYTORCH_ROOT env variable" + exit 1 +fi +pushd "$PYTORCH_ROOT" +python setup.py clean +retry pip install -qr requirements.txt +retry pip install -q numpy==2.0.1 + +if [[ "$DESIRED_DEVTOOLSET" == *"cxx11-abi"* ]]; then + export _GLIBCXX_USE_CXX11_ABI=1 +else + export _GLIBCXX_USE_CXX11_ABI=0 +fi + +if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then + echo "Calling build_amd.py at $(date)" + python tools/amd_build/build_amd.py + # TODO remove this work-around once pytorch sources are updated + export ROCclr_DIR=/opt/rocm/rocclr/lib/cmake/rocclr +fi + +echo "Calling setup.py install at $(date)" + +if [[ $LIBTORCH_VARIANT = *"static"* ]]; then + STATIC_CMAKE_FLAG="-DTORCH_STATIC=1" +fi + +( + set -x + + mkdir -p build + + time CMAKE_ARGS=${CMAKE_ARGS[@]} \ + EXTRA_CAFFE2_CMAKE_FLAGS="${EXTRA_CAFFE2_CMAKE_FLAGS[@]} $STATIC_CMAKE_FLAG" \ + # TODO: Remove this flag once https://github.com/pytorch/pytorch/issues/55952 is closed + CFLAGS='-Wno-deprecated-declarations' \ + BUILD_LIBTORCH_CPU_WITH_DEBUG=1 \ + python setup.py install + + mkdir -p libtorch/{lib,bin,include,share} + + # Make debug folder separate so it doesn't get zipped up with the rest of + # libtorch + mkdir debug + + # Copy over all lib files + cp -rv build/lib/* libtorch/lib/ + cp -rv build/lib*/torch/lib/* libtorch/lib/ + + # Copy over all include files + cp -rv build/include/* libtorch/include/ + cp -rv build/lib*/torch/include/* libtorch/include/ + + # Copy over all of the cmake files + cp -rv build/lib*/torch/share/* libtorch/share/ + + # Split libtorch into debug / release version + cp libtorch/lib/libtorch_cpu.so libtorch/lib/libtorch_cpu.so.dbg + + # Keep debug symbols on debug lib + strip --only-keep-debug libtorch/lib/libtorch_cpu.so.dbg + + # Remove debug info from release lib + strip --strip-debug libtorch/lib/libtorch_cpu.so + + # Add a debug link to the release lib to the debug lib (debuggers will then + # search for symbols in a file called libtorch_cpu.so.dbg in some + # predetermined locations) and embed a CRC32 of the debug library into the .so + cd libtorch/lib + + objcopy libtorch_cpu.so --add-gnu-debuglink=libtorch_cpu.so.dbg + cd ../.. + + # Move the debug symbols to its own directory so it doesn't get processed / + # zipped with all the other libraries + mv libtorch/lib/libtorch_cpu.so.dbg debug/libtorch_cpu.so.dbg + + echo "${PYTORCH_BUILD_VERSION}" > libtorch/build-version + echo "$(pushd $PYTORCH_ROOT && git rev-parse HEAD)" > libtorch/build-hash + +) + +if [[ "$DESIRED_DEVTOOLSET" == *"cxx11-abi"* ]]; then + LIBTORCH_ABI="cxx11-abi-" +else + LIBTORCH_ABI= +fi + +( + set -x + + mkdir -p /tmp/$LIBTORCH_HOUSE_DIR + + # objcopy installs a CRC32 into libtorch_cpu above so, so add that to the name here + CRC32=$(objcopy --dump-section .gnu_debuglink=>(tail -c4 | od -t x4 -An | xargs echo) libtorch/lib/libtorch_cpu.so) + + # Zip debug symbols + zip /tmp/$LIBTORCH_HOUSE_DIR/debug-libtorch-$LIBTORCH_ABI$LIBTORCH_VARIANT-$PYTORCH_BUILD_VERSION-$CRC32.zip debug/libtorch_cpu.so.dbg + + # Zip and copy libtorch + zip -rq /tmp/$LIBTORCH_HOUSE_DIR/libtorch-$LIBTORCH_ABI$LIBTORCH_VARIANT-$PYTORCH_BUILD_VERSION.zip libtorch + cp /tmp/$LIBTORCH_HOUSE_DIR/libtorch-$LIBTORCH_ABI$LIBTORCH_VARIANT-$PYTORCH_BUILD_VERSION.zip \ + /tmp/$LIBTORCH_HOUSE_DIR/libtorch-$LIBTORCH_ABI$LIBTORCH_VARIANT-latest.zip +) + + +popd + +####################################################################### +# ADD DEPENDENCIES INTO THE WHEEL +# +# auditwheel repair doesn't work correctly and is buggy +# so manually do the work of copying dependency libs and patchelfing +# and fixing RECORDS entries correctly +###################################################################### + +fname_with_sha256() { + HASH=$(sha256sum $1 | cut -c1-8) + DIRNAME=$(dirname $1) + BASENAME=$(basename $1) + if [[ $BASENAME == "libnvrtc-builtins.so" || $BASENAME == "libcudnn"* ]]; then + echo $1 + else + INITNAME=$(echo $BASENAME | cut -f1 -d".") + ENDNAME=$(echo $BASENAME | cut -f 2- -d".") + echo "$DIRNAME/$INITNAME-$HASH.$ENDNAME" + fi +} + +fname_without_so_number() { + LINKNAME=$(echo $1 | sed -e 's/\.so.*/.so/g') + echo "$LINKNAME" +} + +make_wheel_record() { + FPATH=$1 + if echo $FPATH | grep RECORD >/dev/null 2>&1; then + # if the RECORD file, then + echo "$FPATH,," + else + HASH=$(openssl dgst -sha256 -binary $FPATH | openssl base64 | sed -e 's/+/-/g' | sed -e 's/\//_/g' | sed -e 's/=//g') + FSIZE=$(ls -nl $FPATH | awk '{print $5}') + echo "$FPATH,sha256=$HASH,$FSIZE" + fi +} + +echo 'Built this package:' +( + set -x + mkdir -p /$LIBTORCH_HOUSE_DIR + mv /tmp/$LIBTORCH_HOUSE_DIR/*.zip /$LIBTORCH_HOUSE_DIR + rm -rf /tmp/$LIBTORCH_HOUSE_DIR +) +TMP_DIR=$(mktemp -d) +trap "rm -rf ${TMP_DIR}" EXIT +pushd "${TMP_DIR}" + +for pkg in /$LIBTORCH_HOUSE_DIR/libtorch*.zip; do + + # if the glob didn't match anything + if [[ ! -e $pkg ]]; then + continue + fi + + rm -rf tmp + mkdir -p tmp + cd tmp + cp $pkg . + + unzip -q $(basename $pkg) + rm -f $(basename $pkg) + + PREFIX=libtorch + + if [[ $pkg != *"without-deps"* ]]; then + # copy over needed dependent .so files over and tag them with their hash + patched=() + for filepath in "${DEPS_LIST[@]}"; do + filename=$(basename $filepath) + destpath=$PREFIX/lib/$filename + if [[ "$filepath" != "$destpath" ]]; then + cp $filepath $destpath + fi + + if [[ "$DESIRED_CUDA" == *"rocm"* ]]; then + patchedpath=$(fname_without_so_number $destpath) + else + patchedpath=$(fname_with_sha256 $destpath) + fi + patchedname=$(basename $patchedpath) + if [[ "$destpath" != "$patchedpath" ]]; then + mv $destpath $patchedpath + fi + patched+=("$patchedname") + echo "Copied $filepath to $patchedpath" + done + + echo "patching to fix the so names to the hashed names" + for ((i=0;i<${#DEPS_LIST[@]};++i)); do + find $PREFIX -name '*.so*' | while read sofile; do + origname=${DEPS_SONAME[i]} + patchedname=${patched[i]} + if [[ "$origname" != "$patchedname" ]] || [[ "$DESIRED_CUDA" == *"rocm"* ]]; then + set +e + origname=$($PATCHELF_BIN --print-needed $sofile | grep "$origname.*") + ERRCODE=$? + set -e + if [ "$ERRCODE" -eq "0" ]; then + echo "patching $sofile entry $origname to $patchedname" + $PATCHELF_BIN --replace-needed $origname $patchedname $sofile + fi + fi + done + done + + # copy over needed auxiliary files + for ((i=0;i<${#DEPS_AUX_SRCLIST[@]};++i)); do + srcpath=${DEPS_AUX_SRCLIST[i]} + dstpath=$PREFIX/${DEPS_AUX_DSTLIST[i]} + mkdir -p $(dirname $dstpath) + cp $srcpath $dstpath + done + fi + + # set RPATH of _C.so and similar to $ORIGIN, $ORIGIN/lib + find $PREFIX -maxdepth 1 -type f -name "*.so*" | while read sofile; do + echo "Setting rpath of $sofile to " '$ORIGIN:$ORIGIN/lib' + $PATCHELF_BIN --set-rpath '$ORIGIN:$ORIGIN/lib' $sofile + $PATCHELF_BIN --print-rpath $sofile + done + + # set RPATH of lib/ files to $ORIGIN + find $PREFIX/lib -maxdepth 1 -type f -name "*.so*" | while read sofile; do + echo "Setting rpath of $sofile to " '$ORIGIN' + $PATCHELF_BIN --set-rpath '$ORIGIN' $sofile + $PATCHELF_BIN --print-rpath $sofile + done + + # regenerate the RECORD file with new hashes + record_file=`echo $(basename $pkg) | sed -e 's/-cp.*$/.dist-info\/RECORD/g'` + if [[ -e $record_file ]]; then + echo "Generating new record file $record_file" + rm -f $record_file + # generate records for folders in wheel + find * -type f | while read fname; do + echo $(make_wheel_record $fname) >>$record_file + done + fi + + # zip up the wheel back + zip -rq $(basename $pkg) $PREFIX* + + # replace original wheel + rm -f $pkg + mv $(basename $pkg) $pkg + cd .. + rm -rf tmp +done + +# Copy wheels to host machine for persistence before testing +if [[ -n "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then + cp /$LIBTORCH_HOUSE_DIR/libtorch*.zip "$PYTORCH_FINAL_PACKAGE_DIR" + cp /$LIBTORCH_HOUSE_DIR/debug-libtorch*.zip "$PYTORCH_FINAL_PACKAGE_DIR" +fi diff --git a/.ci/manywheel/build_rocm.sh b/.ci/manywheel/build_rocm.sh new file mode 100755 index 0000000000000..1e14c9d81d246 --- /dev/null +++ b/.ci/manywheel/build_rocm.sh @@ -0,0 +1,263 @@ +#!/usr/bin/env bash + +set -ex + +export ROCM_HOME=/opt/rocm +export MAGMA_HOME=$ROCM_HOME/magma +# TODO: libtorch_cpu.so is broken when building with Debug info +export BUILD_DEBUG_INFO=0 + +# TODO Are these all used/needed? +export TH_BINARY_BUILD=1 +export USE_STATIC_CUDNN=1 +export USE_STATIC_NCCL=1 +export ATEN_STATIC_CUDA=1 +export USE_CUDA_STATIC_LINK=1 +export INSTALL_TEST=0 # dont install test binaries into site-packages +# Set RPATH instead of RUNPATH when using patchelf to avoid LD_LIBRARY_PATH override +export FORCE_RPATH="--force-rpath" + +# Keep an array of cmake variables to add to +if [[ -z "$CMAKE_ARGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build() + CMAKE_ARGS=() +fi +if [[ -z "$EXTRA_CAFFE2_CMAKE_FLAGS" ]]; then + # These are passed to tools/build_pytorch_libs.sh::build_caffe2() + EXTRA_CAFFE2_CMAKE_FLAGS=() +fi + +# Determine ROCm version and architectures to build for +# +# NOTE: We should first check `DESIRED_CUDA` when determining `ROCM_VERSION` +if [[ -n "$DESIRED_CUDA" ]]; then + if ! echo "${DESIRED_CUDA}"| grep "^rocm" >/dev/null 2>/dev/null; then + export DESIRED_CUDA="rocm${DESIRED_CUDA}" + fi + # rocm3.7, rocm3.5.1 + ROCM_VERSION="$DESIRED_CUDA" + echo "Using $ROCM_VERSION as determined by DESIRED_CUDA" +else + echo "Must set DESIRED_CUDA" + exit 1 +fi + +# Package directories +WHEELHOUSE_DIR="wheelhouse$ROCM_VERSION" +LIBTORCH_HOUSE_DIR="libtorch_house$ROCM_VERSION" +if [[ -z "$PYTORCH_FINAL_PACKAGE_DIR" ]]; then + if [[ -z "$BUILD_PYTHONLESS" ]]; then + PYTORCH_FINAL_PACKAGE_DIR="/remote/wheelhouse$ROCM_VERSION" + else + PYTORCH_FINAL_PACKAGE_DIR="/remote/libtorch_house$ROCM_VERSION" + fi +fi +mkdir -p "$PYTORCH_FINAL_PACKAGE_DIR" || true + +# To make version comparison easier, create an integer representation. +ROCM_VERSION_CLEAN=$(echo ${ROCM_VERSION} | sed s/rocm//) +save_IFS="$IFS" +IFS=. ROCM_VERSION_ARRAY=(${ROCM_VERSION_CLEAN}) +IFS="$save_IFS" +if [[ ${#ROCM_VERSION_ARRAY[@]} == 2 ]]; then + ROCM_VERSION_MAJOR=${ROCM_VERSION_ARRAY[0]} + ROCM_VERSION_MINOR=${ROCM_VERSION_ARRAY[1]} + ROCM_VERSION_PATCH=0 +elif [[ ${#ROCM_VERSION_ARRAY[@]} == 3 ]]; then + ROCM_VERSION_MAJOR=${ROCM_VERSION_ARRAY[0]} + ROCM_VERSION_MINOR=${ROCM_VERSION_ARRAY[1]} + ROCM_VERSION_PATCH=${ROCM_VERSION_ARRAY[2]} +else + echo "Unhandled ROCM_VERSION ${ROCM_VERSION}" + exit 1 +fi +ROCM_INT=$(($ROCM_VERSION_MAJOR * 10000 + $ROCM_VERSION_MINOR * 100 + $ROCM_VERSION_PATCH)) + +# Required ROCm libraries +ROCM_SO_FILES=( + "libMIOpen.so" + "libamdhip64.so" + "libhipblas.so" + "libhipfft.so" + "libhiprand.so" + "libhipsolver.so" + "libhipsparse.so" + "libhsa-runtime64.so" + "libamd_comgr.so" + "libmagma.so" + "librccl.so" + "librocblas.so" + "librocfft.so" + "librocm_smi64.so" + "librocrand.so" + "librocsolver.so" + "librocsparse.so" + "libroctracer64.so" + "libroctx64.so" + "libhipblaslt.so" + "libhiprtc.so" +) + +if [[ $ROCM_INT -ge 60100 ]]; then + ROCM_SO_FILES+=("librocprofiler-register.so") +fi + +if [[ $ROCM_INT -ge 60200 ]]; then + ROCM_SO_FILES+=("librocm-core.so") +fi + +OS_NAME=`awk -F= '/^NAME/{print $2}' /etc/os-release` +if [[ "$OS_NAME" == *"CentOS Linux"* ]]; then + LIBGOMP_PATH="/usr/lib64/libgomp.so.1" + LIBNUMA_PATH="/usr/lib64/libnuma.so.1" + LIBELF_PATH="/usr/lib64/libelf.so.1" + LIBTINFO_PATH="/usr/lib64/libtinfo.so.5" + LIBDRM_PATH="/opt/amdgpu/lib64/libdrm.so.2" + LIBDRM_AMDGPU_PATH="/opt/amdgpu/lib64/libdrm_amdgpu.so.1" + if [[ $ROCM_INT -ge 60100 ]]; then + # Below libs are direct dependencies of libhipsolver + LIBSUITESPARSE_CONFIG_PATH="/lib64/libsuitesparseconfig.so.4" + LIBCHOLMOD_PATH="/lib64/libcholmod.so.2" + # Below libs are direct dependencies of libcholmod + LIBAMD_PATH="/lib64/libamd.so.2" + LIBCAMD_PATH="/lib64/libcamd.so.2" + LIBCCOLAMD_PATH="/lib64/libccolamd.so.2" + LIBCOLAMD_PATH="/lib64/libcolamd.so.2" + LIBSATLAS_PATH="/lib64/atlas/libsatlas.so.3" + # Below libs are direct dependencies of libsatlas + LIBGFORTRAN_PATH="/lib64/libgfortran.so.3" + LIBQUADMATH_PATH="/lib64/libquadmath.so.0" + fi + MAYBE_LIB64=lib64 +elif [[ "$OS_NAME" == *"Ubuntu"* ]]; then + LIBGOMP_PATH="/usr/lib/x86_64-linux-gnu/libgomp.so.1" + LIBNUMA_PATH="/usr/lib/x86_64-linux-gnu/libnuma.so.1" + LIBELF_PATH="/usr/lib/x86_64-linux-gnu/libelf.so.1" + if [[ $ROCM_INT -ge 50300 ]]; then + LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.6" + else + LIBTINFO_PATH="/lib/x86_64-linux-gnu/libtinfo.so.5" + fi + LIBDRM_PATH="/usr/lib/x86_64-linux-gnu/libdrm.so.2" + LIBDRM_AMDGPU_PATH="/usr/lib/x86_64-linux-gnu/libdrm_amdgpu.so.1" + if [[ $ROCM_INT -ge 60100 ]]; then + # Below libs are direct dependencies of libhipsolver + LIBCHOLMOD_PATH="/lib/x86_64-linux-gnu/libcholmod.so.3" + # Below libs are direct dependencies of libcholmod + LIBSUITESPARSE_CONFIG_PATH="/lib/x86_64-linux-gnu/libsuitesparseconfig.so.5" + LIBAMD_PATH="/lib/x86_64-linux-gnu/libamd.so.2" + LIBCAMD_PATH="/lib/x86_64-linux-gnu/libcamd.so.2" + LIBCCOLAMD_PATH="/lib/x86_64-linux-gnu/libccolamd.so.2" + LIBCOLAMD_PATH="/lib/x86_64-linux-gnu/libcolamd.so.2" + LIBMETIS_PATH="/lib/x86_64-linux-gnu/libmetis.so.5" + LIBLAPACK_PATH="/lib/x86_64-linux-gnu/liblapack.so.3" + LIBBLAS_PATH="/lib/x86_64-linux-gnu/libblas.so.3" + # Below libs are direct dependencies of libblas + LIBGFORTRAN_PATH="/lib/x86_64-linux-gnu/libgfortran.so.5" + LIBQUADMATH_PATH="/lib/x86_64-linux-gnu/libquadmath.so.0" + fi + MAYBE_LIB64=lib +fi +OS_SO_PATHS=($LIBGOMP_PATH $LIBNUMA_PATH\ + $LIBELF_PATH $LIBTINFO_PATH\ + $LIBDRM_PATH $LIBDRM_AMDGPU_PATH\ + $LIBSUITESPARSE_CONFIG_PATH\ + $LIBCHOLMOD_PATH $LIBAMD_PATH\ + $LIBCAMD_PATH $LIBCCOLAMD_PATH\ + $LIBCOLAMD_PATH $LIBSATLAS_PATH\ + $LIBGFORTRAN_PATH $LIBQUADMATH_PATH\ + $LIBMETIS_PATH $LIBLAPACK_PATH\ + $LIBBLAS_PATH) +OS_SO_FILES=() +for lib in "${OS_SO_PATHS[@]}" +do + file_name="${lib##*/}" # Substring removal of path to get filename + OS_SO_FILES[${#OS_SO_FILES[@]}]=$file_name # Append lib to array +done + +# PyTorch-version specific +# AOTriton dependency only for PyTorch >= 2.4 +if (( $(echo "${PYTORCH_VERSION} 2.4" | awk '{print ($1 >= $2)}') )); then + ROCM_SO_FILES+=("libaotriton_v2.so") +fi + +# rocBLAS library files +ROCBLAS_LIB_SRC=$ROCM_HOME/lib/rocblas/library +ROCBLAS_LIB_DST=lib/rocblas/library +ARCH=$(echo $PYTORCH_ROCM_ARCH | sed 's/;/|/g') # Replace ; seperated arch list to bar for grep +ARCH_SPECIFIC_FILES=$(ls $ROCBLAS_LIB_SRC | grep -E $ARCH) +OTHER_FILES=$(ls $ROCBLAS_LIB_SRC | grep -v gfx) +ROCBLAS_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES) + +# hipblaslt library files +HIPBLASLT_LIB_SRC=$ROCM_HOME/lib/hipblaslt/library +HIPBLASLT_LIB_DST=lib/hipblaslt/library +ARCH_SPECIFIC_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -E $ARCH) +OTHER_FILES=$(ls $HIPBLASLT_LIB_SRC | grep -v gfx) +HIPBLASLT_LIB_FILES=($ARCH_SPECIFIC_FILES $OTHER_FILES) + +# ROCm library files +ROCM_SO_PATHS=() +for lib in "${ROCM_SO_FILES[@]}" +do + file_path=($(find $ROCM_HOME/lib/ -name "$lib")) # First search in lib + if [[ -z $file_path ]]; then + if [ -d "$ROCM_HOME/lib64/" ]; then + file_path=($(find $ROCM_HOME/lib64/ -name "$lib")) # Then search in lib64 + fi + fi + if [[ -z $file_path ]]; then + file_path=($(find $ROCM_HOME/ -name "$lib")) # Then search in ROCM_HOME + fi + if [[ -z $file_path ]]; then + echo "Error: Library file $lib is not found." >&2 + exit 1 + fi + ROCM_SO_PATHS[${#ROCM_SO_PATHS[@]}]="$file_path" # Append lib to array +done + +DEPS_LIST=( + ${ROCM_SO_PATHS[*]} + ${OS_SO_PATHS[*]} +) + +DEPS_SONAME=( + ${ROCM_SO_FILES[*]} + ${OS_SO_FILES[*]} +) + +DEPS_AUX_SRCLIST=( + "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_SRC/}" + "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_SRC/}" + "/opt/amdgpu/share/libdrm/amdgpu.ids" +) + +DEPS_AUX_DSTLIST=( + "${ROCBLAS_LIB_FILES[@]/#/$ROCBLAS_LIB_DST/}" + "${HIPBLASLT_LIB_FILES[@]/#/$HIPBLASLT_LIB_DST/}" + "share/libdrm/amdgpu.ids" +) + +# MIOpen library files +MIOPEN_SHARE_SRC=$ROCM_HOME/share/miopen/db +MIOPEN_SHARE_DST=share/miopen/db +MIOPEN_SHARE_FILES=($(ls $MIOPEN_SHARE_SRC | grep -E $ARCH)) +DEPS_AUX_SRCLIST+=(${MIOPEN_SHARE_FILES[@]/#/$MIOPEN_SHARE_SRC/}) +DEPS_AUX_DSTLIST+=(${MIOPEN_SHARE_FILES[@]/#/$MIOPEN_SHARE_DST/}) + +# RCCL library files +RCCL_SHARE_SRC=$ROCM_HOME/share/rccl/msccl-algorithms +RCCL_SHARE_DST=share/rccl/msccl-algorithms +RCCL_SHARE_FILES=($(ls $RCCL_SHARE_SRC)) +DEPS_AUX_SRCLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_SRC/}) +DEPS_AUX_DSTLIST+=(${RCCL_SHARE_FILES[@]/#/$RCCL_SHARE_DST/}) + +echo "PYTORCH_ROCM_ARCH: ${PYTORCH_ROCM_ARCH}" + +SCRIPTPATH="$( cd "$(dirname "$0")" ; pwd -P )" +if [[ -z "$BUILD_PYTHONLESS" ]]; then + BUILD_SCRIPT=build_common.sh +else + BUILD_SCRIPT=build_libtorch.sh +fi +source $SCRIPTPATH/${BUILD_SCRIPT} diff --git a/.ci/manywheel/test_wheel.sh b/.ci/manywheel/test_wheel.sh new file mode 100755 index 0000000000000..1ee7cd167d903 --- /dev/null +++ b/.ci/manywheel/test_wheel.sh @@ -0,0 +1,26 @@ +#!/usr/bin/env bash +set -e + +yum install -y wget git + +rm -rf /usr/local/cuda* + +# Install Anaconda +if ! ls /py +then + echo "Miniconda needs to be installed" + wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh + bash ~/miniconda.sh -b -p /py +else + echo "Miniconda is already installed" +fi + +export PATH="/py/bin:$PATH" + +# Anaconda token +if ls /remote/token +then + source /remote/token +fi + +conda install -y conda-build anaconda-client diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index fccb0362d8da0..a612442462ffe 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -178,7 +178,7 @@ fi # sccache will fail for CUDA builds if all cores are used for compiling # gcc 7 with sccache seems to have intermittent OOM issue if all cores are used if [ -z "$MAX_JOBS" ]; then - if { [[ "$BUILD_ENVIRONMENT" == *cuda* ]] || [[ "$BUILD_ENVIRONMENT" == *gcc7* ]]; } && which sccache > /dev/null; then + if { [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; } && which sccache > /dev/null; then export MAX_JOBS=$(($(nproc) - 1)) fi fi @@ -203,10 +203,12 @@ if [[ "${BUILD_ENVIRONMENT}" == *clang* ]]; then fi if [[ "$BUILD_ENVIRONMENT" == *-clang*-asan* ]]; then - export LDSHARED="clang --shared" - export USE_CUDA=0 + if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then + export USE_CUDA=1 + fi export USE_ASAN=1 - export UBSAN_FLAGS="-fno-sanitize-recover=all;-fno-sanitize=float-divide-by-zero;-fno-sanitize=float-cast-overflow" + export REL_WITH_DEB_INFO=1 + export UBSAN_FLAGS="-fno-sanitize-recover=all" unset USE_LLVM fi @@ -218,10 +220,6 @@ if [[ "${BUILD_ENVIRONMENT}" == *-pch* ]]; then export USE_PRECOMPILED_HEADERS=1 fi -if [[ "${BUILD_ENVIRONMENT}" == *linux-focal-py3.7-gcc7-build* ]]; then - export USE_GLOO_WITH_OPENSSL=ON -fi - if [[ "${BUILD_ENVIRONMENT}" != *android* && "${BUILD_ENVIRONMENT}" != *cuda* ]]; then export BUILD_STATIC_RUNTIME_BENCHMARK=ON fi @@ -278,7 +276,6 @@ else # set only when building other architectures # or building non-XLA tests. if [[ "$BUILD_ENVIRONMENT" != *rocm* && - "$BUILD_ENVIRONMENT" != *s390x* && "$BUILD_ENVIRONMENT" != *xla* ]]; then if [[ "$BUILD_ENVIRONMENT" != *py3.8* ]]; then # Install numpy-2.0.2 for builds which are backward compatible with 1.X diff --git a/.ci/pytorch/create_test_cert.py b/.ci/pytorch/create_test_cert.py index 071df513fbd2e..f2be0c13227d1 100644 --- a/.ci/pytorch/create_test_cert.py +++ b/.ci/pytorch/create_test_cert.py @@ -45,8 +45,7 @@ def create_cert(path, C, ST, L, O, key): .not_valid_before(datetime.now(timezone.utc)) .not_valid_after( # Our certificate will be valid for 10 days - datetime.now(timezone.utc) - + timedelta(days=10) + datetime.now(timezone.utc) + timedelta(days=10) ) .add_extension( x509.BasicConstraints(ca=True, path_length=None), @@ -91,8 +90,7 @@ def sign_certificate_request(path, csr_cert, ca_cert, private_ca_key): .not_valid_before(datetime.now(timezone.utc)) .not_valid_after( # Our certificate will be valid for 10 days - datetime.now(timezone.utc) - + timedelta(days=10) + datetime.now(timezone.utc) + timedelta(days=10) # Sign our certificate with our private key ) .sign(private_ca_key, hashes.SHA256()) diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index f95a4ee8749f5..61a6dbef015c8 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -49,16 +49,16 @@ NUM_TEST_SHARDS="${NUM_TEST_SHARDS:=1}" export VALGRIND=ON # export TORCH_INDUCTOR_INSTALL_GXX=ON if [[ "$BUILD_ENVIRONMENT" == *clang9* ]]; then - # clang9 appears to miscompile code involving c10::optional, + # clang9 appears to miscompile code involving std::optional, # such that valgrind complains along these lines: # # Conditional jump or move depends on uninitialised value(s) # at 0x40303A: ~optional_base (Optional.h:281) # by 0x40303A: call (Dispatcher.h:448) - # by 0x40303A: call(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, c10::optional) (basic.cpp:10) + # by 0x40303A: call(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, std::optional) (basic.cpp:10) # by 0x403700: main (basic.cpp:16) # Uninitialised value was created by a stack allocation - # at 0x402AAA: call(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, c10::optional) (basic.cpp:6) + # at 0x402AAA: call(at::Tensor const&, c10::ArrayRef, c10::ArrayRef, std::optional) (basic.cpp:6) # # The problem does not appear with gcc or newer versions of clang (we tested # clang14). So we suppress valgrind testing for clang9 specifically. @@ -72,7 +72,7 @@ if [[ "$BUILD_ENVIRONMENT" == *clang9* ]]; then # # using namespace at; # - # Tensor call(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, c10::optional storage_offset) { + # Tensor call(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, std::optional storage_offset) { # auto op = c10::Dispatcher::singleton() # .findSchemaOrThrow(at::_ops::as_strided::name, at::_ops::as_strided::overload_name) # .typed(); @@ -81,7 +81,7 @@ if [[ "$BUILD_ENVIRONMENT" == *clang9* ]]; then # # int main(int argv) { # Tensor b = empty({3, 4}); - # auto z = call(b, b.sym_sizes(), b.sym_strides(), c10::nullopt); + # auto z = call(b, b.sym_sizes(), b.sym_strides(), std::nullopt); # } export VALGRIND=OFF fi @@ -196,6 +196,9 @@ install_tlparse # ASAN test is not working if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then export ASAN_OPTIONS=detect_leaks=0:symbolize=1:detect_stack_use_after_return=true:strict_init_order=true:detect_odr_violation=1:detect_container_overflow=0:check_initialization_order=true:debug=true + if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then + export ASAN_OPTIONS="${ASAN_OPTIONS}:protect_shadow_gap=0" + fi export UBSAN_OPTIONS=print_stacktrace=1:suppressions=$PWD/ubsan.supp export PYTORCH_TEST_WITH_ASAN=1 export PYTORCH_TEST_WITH_UBSAN=1 @@ -233,8 +236,8 @@ if [[ "$BUILD_ENVIRONMENT" == *asan* ]]; then # it depends on a ton of dynamic libraries that most programs aren't gonna # have, and it applies to child processes. - # TODO: get rid of the hardcoded path - export LD_PRELOAD=/usr/lib/llvm-15/lib/clang/15.0.7/lib/linux/libclang_rt.asan-x86_64.so + LD_PRELOAD=$(clang --print-file-name=libclang_rt.asan-x86_64.so) + export LD_PRELOAD # Disable valgrind for asan export VALGRIND=OFF @@ -281,7 +284,7 @@ test_python_shard() { # modify LD_LIBRARY_PATH to ensure it has the conda env. # This set of tests has been shown to be buggy without it for the split-build - time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION + time python test/run_test.py --exclude-jit-executor --exclude-distributed-tests $INCLUDE_CLAUSE --shard "$1" "$NUM_TEST_SHARDS" --verbose $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } @@ -307,7 +310,8 @@ test_dynamo_shard() { --exclude-distributed-tests \ --exclude-torch-export-tests \ --shard "$1" "$NUM_TEST_SHARDS" \ - --verbose + --verbose \ + --upload-artifacts-while-running assert_git_not_dirty } @@ -320,6 +324,7 @@ test_inductor_distributed() { python test/run_test.py -i distributed/test_c10d_functional_native.py --verbose python test/run_test.py -i distributed/_tensor/test_dtensor_compile.py --verbose python test/run_test.py -i distributed/tensor/parallel/test_micro_pipeline_tp.py --verbose + python test/run_test.py -i distributed/_composable/test_replicate_with_compiler.py --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_comm.py --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_multi_group --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_with_activation_checkpointing --verbose @@ -331,11 +336,12 @@ test_inductor_distributed() { python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_compute_dtype --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_mixed_precision.py -k test_reduce_dtype --verbose python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_clip_grad_norm_.py -k test_clip_grad_norm_2d --verbose + python test/run_test.py -i distributed/_composable/fsdp/test_fully_shard_compile.py --verbose python test/run_test.py -i distributed/fsdp/test_fsdp_tp_integration.py -k test_fsdp_tp_integration --verbose # this runs on both single-gpu and multi-gpu instance. It should be smart about skipping tests that aren't supported # with if required # gpus aren't available - python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives --verbose + python test/run_test.py --include distributed/test_dynamo_distributed distributed/test_inductor_collectives distributed/test_compute_comm_reordering --verbose assert_git_not_dirty } @@ -369,21 +375,27 @@ test_inductor_aoti() { CPP_TESTS_DIR="${BUILD_BIN_DIR}" LD_LIBRARY_PATH="${TORCH_LIB_DIR}" python test/run_test.py --cpp --verbose -i cpp/test_aoti_abi_check cpp/test_aoti_inference } -test_inductor_cpp_wrapper_abi_compatible() { - export TORCHINDUCTOR_ABI_COMPATIBLE=1 +test_inductor_cpp_wrapper() { + export TORCHINDUCTOR_CPP_WRAPPER=1 TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" - echo "Testing Inductor cpp wrapper mode with TORCHINDUCTOR_ABI_COMPATIBLE=1" - PYTORCH_TESTING_DEVICE_ONLY_FOR="" python test/run_test.py --include inductor/test_cpu_cpp_wrapper - python test/run_test.py --include inductor/test_cuda_cpp_wrapper inductor/test_cpu_repro inductor/test_extension_backend - - TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \ + python benchmarks/dynamo/timm_models.py --device cuda --accuracy --amp \ --training --inductor --disable-cudagraphs --only vit_base_patch16_224 \ --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" python benchmarks/dynamo/check_accuracy.py \ --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_training.csv" \ --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_timm_training.csv" + + python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only llama --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ + --bfloat16 --inference --inductor --only moco --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" + python benchmarks/dynamo/check_accuracy.py \ + --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \ + --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv" } # "Global" flags for inductor benchmarking controlled by TEST_CONFIG @@ -403,7 +415,7 @@ pr_time_benchmarks() { PYTHONPATH=$(pwd)/benchmarks/dynamo/pr_time_benchmarks source benchmarks/dynamo/pr_time_benchmarks/benchmark_runner.sh "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" "benchmarks/dynamo/pr_time_benchmarks/benchmarks" echo "benchmark results on current PR: " cat "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" - PYTHONPATH=$(pwd)/benchmarks/dynamo/pr_time_benchmarks python benchmarks/dynamo/pr_time_benchmarks/check_results.py "benchmarks/dynamo/pr_time_benchmarks/expected_results.csv" "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" + PYTHONPATH=$(pwd)/benchmarks/dynamo/pr_time_benchmarks python benchmarks/dynamo/pr_time_benchmarks/check_results.py "benchmarks/dynamo/pr_time_benchmarks/expected_results.csv" "$TEST_REPORTS_DIR/pr_time_benchmarks_results.csv" "$TEST_REPORTS_DIR/new_expected_results.csv" } if [[ "${TEST_CONFIG}" == *pr_time_benchmarks* ]]; then @@ -511,7 +523,7 @@ test_perf_for_dashboard() { "${target_flag[@]}" --"$mode" --"$dtype" --export --disable-cudagraphs "$@" \ --output "$TEST_REPORTS_DIR/${backend}_export_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi - TORCHINDUCTOR_ABI_COMPATIBLE=1 $TASKSET python "benchmarks/dynamo/$suite.py" \ + $TASKSET python "benchmarks/dynamo/$suite.py" \ "${target_flag[@]}" --"$mode" --"$dtype" --export-aot-inductor --disable-cudagraphs "$@" \ --output "$TEST_REPORTS_DIR/${backend}_aot_inductor_${suite}_${dtype}_${mode}_${device}_${target}.csv" fi @@ -566,13 +578,6 @@ test_single_dynamo_benchmark() { test_perf_for_dashboard "$suite" \ "${DYNAMO_BENCHMARK_FLAGS[@]}" "$@" "${partition_flags[@]}" else - if [[ "${TEST_CONFIG}" == *aot_inductor* && "${TEST_CONFIG}" != *cpu_aot_inductor* ]]; then - # Test AOTInductor with the ABI-compatible mode on CI - # This can be removed once the ABI-compatible mode becomes default. - # For CPU device, we perfer non ABI-compatible mode on CI when testing AOTInductor. - export TORCHINDUCTOR_ABI_COMPATIBLE=1 - fi - if [[ "${TEST_CONFIG}" == *_avx2* ]]; then TEST_CONFIG=${TEST_CONFIG//_avx2/} fi @@ -648,17 +653,6 @@ test_inductor_torchbench_smoketest_perf() { TEST_REPORTS_DIR=$(pwd)/test/test-reports mkdir -p "$TEST_REPORTS_DIR" - # Test some models in the cpp wrapper mode - TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ - --bfloat16 --inference --inductor --only hf_T5 --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" - TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ - --bfloat16 --inference --inductor --only llama --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" - TORCHINDUCTOR_ABI_COMPATIBLE=1 TORCHINDUCTOR_CPP_WRAPPER=1 python benchmarks/dynamo/torchbench.py --device cuda --accuracy \ - --bfloat16 --inference --inductor --only moco --output "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" - python benchmarks/dynamo/check_accuracy.py \ - --actual "$TEST_REPORTS_DIR/inductor_cpp_wrapper_inference.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv" - python benchmarks/dynamo/torchbench.py --device cuda --performance --backend inductor --float16 --training \ --batch-size-file "$(realpath benchmarks/dynamo/torchbench_models_list.txt)" --only hf_Bert \ --output "$TEST_REPORTS_DIR/inductor_training_smoketest.csv" @@ -748,19 +742,9 @@ test_inductor_torchbench_cpu_smoketest_perf(){ fi cat "$output_name" # The threshold value needs to be actively maintained to make this check useful. - python benchmarks/dynamo/check_perf_csv.py -f "$output_name" -t "$speedup_target" + # Allow 1% variance for CPU perf to accommodate perf fluctuation + python benchmarks/dynamo/check_perf_csv.py -f "$output_name" -t "$speedup_target" -s 0.99 done - - # Add a few ABI-compatible accuracy tests for CPU. These can be removed once we turn on ABI-compatible as default. - TORCHINDUCTOR_ABI_COMPATIBLE=1 python benchmarks/dynamo/timm_models.py --device cpu --accuracy \ - --bfloat16 --inference --export-aot-inductor --disable-cudagraphs --only adv_inception_v3 \ - --output "$TEST_REPORTS_DIR/aot_inductor_smoke_test.csv" - TORCHINDUCTOR_ABI_COMPATIBLE=1 python benchmarks/dynamo/timm_models.py --device cpu --accuracy \ - --bfloat16 --inference --export-aot-inductor --disable-cudagraphs --only beit_base_patch16_224 \ - --output "$TEST_REPORTS_DIR/aot_inductor_smoke_test.csv" - python benchmarks/dynamo/check_accuracy.py \ - --actual "$TEST_REPORTS_DIR/aot_inductor_smoke_test.csv" \ - --expected "benchmarks/dynamo/ci_expected_accuracy/aot_inductor_timm_inference.csv" } test_torchbench_gcp_smoketest(){ @@ -1371,7 +1355,7 @@ test_executorch() { echo "Run ExecuTorch regression tests for some models" # TODO(huydhn): Add more coverage here using ExecuTorch's gather models script # shellcheck disable=SC1091 - source .ci/scripts/test.sh mv3 cmake xnnpack-quantization-delegation '' + source .ci/scripts/test_model.sh mv3 cmake xnnpack-quantization-delegation '' popd @@ -1453,7 +1437,6 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then else install_torchaudio cuda fi - install_torchtext install_torchvision TORCH_CUDA_ARCH_LIST="8.0;8.6" pip_install git+https://github.com/pytorch/ao.git id=$((SHARD_NUMBER-1)) @@ -1479,9 +1462,11 @@ elif [[ "${TEST_CONFIG}" == *torchbench* ]]; then fi PYTHONPATH=$(pwd)/torchbench test_dynamo_benchmark torchbench "$id" fi -elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper_abi_compatible* ]]; then +elif [[ "${TEST_CONFIG}" == *inductor_cpp_wrapper* ]]; then + install_torchaudio cuda install_torchvision - test_inductor_cpp_wrapper_abi_compatible + checkout_install_torchbench hf_T5 llama moco + PYTHONPATH=$(pwd)/torchbench test_inductor_cpp_wrapper elif [[ "${TEST_CONFIG}" == *inductor* ]]; then install_torchvision test_inductor_shard "${SHARD_NUMBER}" diff --git a/.ci/pytorch/win-test.sh b/.ci/pytorch/win-test.sh index 09b624183c7ae..dd9acdfaaa96b 100755 --- a/.ci/pytorch/win-test.sh +++ b/.ci/pytorch/win-test.sh @@ -46,6 +46,9 @@ python -m pip install tlparse==0.3.25 # Install parameterized python -m pip install parameterized==0.8.1 +# Install pulp for testing ilps under torch\distributed\_tools +python -m pip install pulp==2.9.0 + run_tests() { # Run nvidia-smi if available for path in '/c/Program Files/NVIDIA Corporation/NVSMI/nvidia-smi.exe' /c/Windows/System32/nvidia-smi.exe; do diff --git a/.circleci/scripts/binary_populate_env.sh b/.circleci/scripts/binary_populate_env.sh index 106d0917ca68c..046dc7ef9b1e7 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -114,6 +114,12 @@ if [[ "$PACKAGE_TYPE" =~ .*wheel.* && -n "$PYTORCH_BUILD_VERSION" && "$PYTORCH_B fi fi +USE_GLOO_WITH_OPENSSL="ON" +if [[ "$GPU_ARCH_TYPE" =~ .*aarch64.* ]]; then + USE_GLOO_WITH_OPENSSL="OFF" + USE_GOLD_LINKER="OFF" +fi + cat >"$envfile" < /dev/null 2>&1; then - zip -r "logs-${FILE_SUFFIX}.zip" test -i '*.log' + if find "test/test-reports" -name "*.log" 2>/dev/null | grep -q .; then + zip -r "logs-${FILE_SUFFIX}.zip" test/test-reports -i '*.log' fi - name: Zip debugging artifacts for upload @@ -77,7 +77,7 @@ runs: FILE_SUFFIX: ${{ inputs.file-suffix }} run: | # -ir => recursive include all files in pattern - 7z a "test-jsons-$Env:FILE_SUFFIX.zip" -ir'!test\*.json' + 7z a "test-jsons-$Env:FILE_SUFFIX.zip" -ir'!test\test-reports\*.json' - name: Zip test reports for upload if: runner.os == 'Windows' && !inputs.use-gha @@ -86,7 +86,7 @@ runs: FILE_SUFFIX: ${{ inputs.file-suffix }} run: | # -ir => recursive include all files in pattern - 7z a "test-reports-$Env:FILE_SUFFIX.zip" -ir'!test\*.xml' -ir'!test\*.csv' + 7z a "test-reports-$Env:FILE_SUFFIX.zip" -ir'!test\test-reports\*.xml' -ir'!test\test-reports\*.csv' - name: Zip usage log for upload if: runner.os == 'Windows' && !inputs.use-gha @@ -96,7 +96,7 @@ runs: FILE_SUFFIX: ${{ inputs.file-suffix }} run: | # -ir => recursive include all files in pattern - 7z a "logs-$Env:FILE_SUFFIX.zip" 'usage_log.txt' -ir'!test\*.log' + 7z a "logs-$Env:FILE_SUFFIX.zip" 'usage_log.txt' -ir'!test\test-reports\*.log' # S3 upload - name: Store Test Downloaded JSONs on S3 diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 39ea87cdc6b9d..3789810cfb5ab 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -3f0569939c4369bec943fc27d1c9d8dfbc828c26 +79047bf6bdec9e32c4cffd0f9835b347781fefbf diff --git a/.github/ci_commit_pins/torchbench.txt b/.github/ci_commit_pins/torchbench.txt index dcf750b7fae06..21b3c3481f398 100644 --- a/.github/ci_commit_pins/torchbench.txt +++ b/.github/ci_commit_pins/torchbench.txt @@ -1 +1 @@ -23512dbebd44a11eb84afbf53c3c071dd105297e +e522b45cd4535b9dfe067aa68d7315755df38f48 diff --git a/.github/labeler.yml b/.github/labeler.yml index c6b6cc8118b42..12511ee8651bc 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -98,3 +98,9 @@ "module: distributed_checkpoint": - torch/distributed/checkpoint/** - test/distributed/checkpoint/** + +"module: compiled autograd": +- torch/csrc/dynamo/python_compiled_autograd.cpp +- torch/csrc/dynamo/compiled_autograd.h +- torch/_dynamo/compiled_autograd.py +- torch/inductor/test_compiled_autograd.py diff --git a/.github/lf-canary-scale-config.yml b/.github/lf-canary-scale-config.yml deleted file mode 100644 index 26ac07d190821..0000000000000 --- a/.github/lf-canary-scale-config.yml +++ /dev/null @@ -1,251 +0,0 @@ - -# This file is generated by .github/scripts/validate_scale_config.py in test-infra -# It defines runner types that will be provisioned by by LF Self-hosted runners - -# scale-config.yml: -# Powers what instance types are available for GHA auto-scaled -# runners. Runners listed here will be available as self hosted -# runners, configuration is directly pulled from the main branch. -# -# -# NOTES: -# - Linux runners are by default non-ephemeral to reduce the amount of CreateInstaces calls -# to avoid RequestLimitExceeded issues -# - When updating this file, run the following command to validate the YAML and to generate -# corresponding versions of scale-config for the pytorch/pytorch repo and merge the -# pytorch/pytorch changes before merging these changes. -# `python .github/scripts/validate_scale_config.py --test-infra-repo-root [path_to_test-infra_root] --pytorch-repo-root [path_to_pytorch_root]`` -# -# TODO: Add some documentation on how the auto-scaling works -# -# NOTE: Default values, -# -# runner_types: -# runner_label: -# instance_type: m4.large -# os: linux -# max_available: 20 -# disk_size: 50 -# is_ephemeral: true - -runner_types: - lf.c.linux.12xlarge: - disk_size: 200 - instance_type: c5.12xlarge - is_ephemeral: false - max_available: 1000 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.10xlarge.avx2: - disk_size: 200 - instance_type: m4.10xlarge - is_ephemeral: false - max_available: 450 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.24xl.spr-metal: - disk_size: 200 - instance_type: c7i.metal-24xl - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.16xlarge.spr: - disk_size: 200 - instance_type: c7i.16xlarge - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.9xlarge.ephemeral: - disk_size: 200 - instance_type: c5.9xlarge - is_ephemeral: true - max_available: 50 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - variants: - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs - lf.c.linux.12xlarge.ephemeral: - disk_size: 200 - instance_type: c5.12xlarge - is_ephemeral: true - max_available: 300 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.16xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g3.16xlarge - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.24xlarge: - disk_size: 150 - instance_type: c5.24xlarge - is_ephemeral: false - max_available: 500 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.24xlarge.ephemeral: - disk_size: 150 - instance_type: c5.24xlarge - is_ephemeral: true - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.2xlarge: - disk_size: 150 - instance_type: c5.2xlarge - is_ephemeral: false - max_available: 3120 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.4xlarge: - disk_size: 150 - instance_type: c5.4xlarge - is_ephemeral: false - max_available: 1000 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.4xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g3.4xlarge - is_ephemeral: false - max_available: 1000 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.8xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g3.8xlarge - is_ephemeral: false - max_available: 400 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.g4dn.12xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g4dn.12xlarge - is_ephemeral: false - max_available: 250 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.g4dn.metal.nvidia.gpu: - disk_size: 150 - instance_type: g4dn.metal - is_ephemeral: false - max_available: 300 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.g5.48xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g5.48xlarge - is_ephemeral: false - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.g5.12xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g5.12xlarge - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.g5.4xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g5.4xlarge - is_ephemeral: false - max_available: 2400 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.g6.4xlarge.experimental.nvidia.gpu: - disk_size: 150 - instance_type: g6.4xlarge - is_ephemeral: false - max_available: 50 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.large: - max_available: 1200 - disk_size: 15 - instance_type: c5.large - is_ephemeral: false - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.c.linux.arm64.2xlarge: - disk_size: 256 - instance_type: t4g.2xlarge - is_ephemeral: false - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.c.linux.arm64.m7g.4xlarge: - disk_size: 256 - instance_type: m7g.4xlarge - is_ephemeral: false - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.c.linux.arm64.2xlarge.ephemeral: - disk_size: 256 - instance_type: t4g.2xlarge - is_ephemeral: true - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.c.linux.arm64.m7g.4xlarge.ephemeral: - disk_size: 256 - instance_type: m7g.4xlarge - is_ephemeral: true - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.c.linux.arm64.m7g.metal: - disk_size: 256 - instance_type: m7g.metal - is_ephemeral: false - max_available: 100 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.c.windows.g4dn.xlarge: - disk_size: 256 - instance_type: g4dn.xlarge - is_ephemeral: true - max_available: 100 - os: windows - lf.c.windows.g4dn.xlarge.nonephemeral: - disk_size: 256 - instance_type: g4dn.xlarge - is_ephemeral: false - max_available: 100 - os: windows - lf.c.windows.4xlarge: - disk_size: 256 - instance_type: c5d.4xlarge - is_ephemeral: true - max_available: 420 - os: windows - lf.c.windows.4xlarge.nonephemeral: - disk_size: 256 - instance_type: c5d.4xlarge - is_ephemeral: false - max_available: 420 - os: windows - lf.c.windows.8xlarge.nvidia.gpu: - disk_size: 256 - instance_type: p3.2xlarge - is_ephemeral: true - max_available: 300 - os: windows - lf.c.windows.8xlarge.nvidia.gpu.nonephemeral: - disk_size: 256 - instance_type: p3.2xlarge - is_ephemeral: false - max_available: 150 - os: windows - lf.c.windows.g5.4xlarge.nvidia.gpu: - disk_size: 256 - instance_type: g5.4xlarge - is_ephemeral: false - max_available: 250 - os: windows diff --git a/.github/lf-scale-config.yml b/.github/lf-scale-config.yml deleted file mode 100644 index cd2ee5fee2aec..0000000000000 --- a/.github/lf-scale-config.yml +++ /dev/null @@ -1,251 +0,0 @@ - -# This file is generated by .github/scripts/validate_scale_config.py in test-infra -# It defines runner types that will be provisioned by by LF Self-hosted runners - -# scale-config.yml: -# Powers what instance types are available for GHA auto-scaled -# runners. Runners listed here will be available as self hosted -# runners, configuration is directly pulled from the main branch. -# -# -# NOTES: -# - Linux runners are by default non-ephemeral to reduce the amount of CreateInstaces calls -# to avoid RequestLimitExceeded issues -# - When updating this file, run the following command to validate the YAML and to generate -# corresponding versions of scale-config for the pytorch/pytorch repo and merge the -# pytorch/pytorch changes before merging these changes. -# `python .github/scripts/validate_scale_config.py --test-infra-repo-root [path_to_test-infra_root] --pytorch-repo-root [path_to_pytorch_root]`` -# -# TODO: Add some documentation on how the auto-scaling works -# -# NOTE: Default values, -# -# runner_types: -# runner_label: -# instance_type: m4.large -# os: linux -# max_available: 20 -# disk_size: 50 -# is_ephemeral: true - -runner_types: - lf.linux.12xlarge: - disk_size: 200 - instance_type: c5.12xlarge - is_ephemeral: false - max_available: 1000 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.10xlarge.avx2: - disk_size: 200 - instance_type: m4.10xlarge - is_ephemeral: false - max_available: 450 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.24xl.spr-metal: - disk_size: 200 - instance_type: c7i.metal-24xl - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.16xlarge.spr: - disk_size: 200 - instance_type: c7i.16xlarge - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.9xlarge.ephemeral: - disk_size: 200 - instance_type: c5.9xlarge - is_ephemeral: true - max_available: 50 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - variants: - am2: - ami: amzn2-ami-hvm-2.0.20240306.2-x86_64-ebs - lf.linux.12xlarge.ephemeral: - disk_size: 200 - instance_type: c5.12xlarge - is_ephemeral: true - max_available: 300 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.16xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g3.16xlarge - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.24xlarge: - disk_size: 150 - instance_type: c5.24xlarge - is_ephemeral: false - max_available: 500 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.24xlarge.ephemeral: - disk_size: 150 - instance_type: c5.24xlarge - is_ephemeral: true - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.2xlarge: - disk_size: 150 - instance_type: c5.2xlarge - is_ephemeral: false - max_available: 3120 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.4xlarge: - disk_size: 150 - instance_type: c5.4xlarge - is_ephemeral: false - max_available: 1000 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.4xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g3.4xlarge - is_ephemeral: false - max_available: 1000 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.8xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g3.8xlarge - is_ephemeral: false - max_available: 400 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.g4dn.12xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g4dn.12xlarge - is_ephemeral: false - max_available: 250 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.g4dn.metal.nvidia.gpu: - disk_size: 150 - instance_type: g4dn.metal - is_ephemeral: false - max_available: 300 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.g5.48xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g5.48xlarge - is_ephemeral: false - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.g5.12xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g5.12xlarge - is_ephemeral: false - max_available: 150 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.g5.4xlarge.nvidia.gpu: - disk_size: 150 - instance_type: g5.4xlarge - is_ephemeral: false - max_available: 2400 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.g6.4xlarge.experimental.nvidia.gpu: - disk_size: 150 - instance_type: g6.4xlarge - is_ephemeral: false - max_available: 50 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.large: - max_available: 1200 - disk_size: 15 - instance_type: c5.large - is_ephemeral: false - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-x86_64 - lf.linux.arm64.2xlarge: - disk_size: 256 - instance_type: t4g.2xlarge - is_ephemeral: false - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.linux.arm64.m7g.4xlarge: - disk_size: 256 - instance_type: m7g.4xlarge - is_ephemeral: false - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.linux.arm64.2xlarge.ephemeral: - disk_size: 256 - instance_type: t4g.2xlarge - is_ephemeral: true - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.linux.arm64.m7g.4xlarge.ephemeral: - disk_size: 256 - instance_type: m7g.4xlarge - is_ephemeral: true - max_available: 200 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.linux.arm64.m7g.metal: - disk_size: 256 - instance_type: m7g.metal - is_ephemeral: false - max_available: 100 - os: linux - ami: al2023-ami-2023.5.202*-kernel-6.1-arm64 - lf.windows.g4dn.xlarge: - disk_size: 256 - instance_type: g4dn.xlarge - is_ephemeral: true - max_available: 100 - os: windows - lf.windows.g4dn.xlarge.nonephemeral: - disk_size: 256 - instance_type: g4dn.xlarge - is_ephemeral: false - max_available: 100 - os: windows - lf.windows.4xlarge: - disk_size: 256 - instance_type: c5d.4xlarge - is_ephemeral: true - max_available: 420 - os: windows - lf.windows.4xlarge.nonephemeral: - disk_size: 256 - instance_type: c5d.4xlarge - is_ephemeral: false - max_available: 420 - os: windows - lf.windows.8xlarge.nvidia.gpu: - disk_size: 256 - instance_type: p3.2xlarge - is_ephemeral: true - max_available: 300 - os: windows - lf.windows.8xlarge.nvidia.gpu.nonephemeral: - disk_size: 256 - instance_type: p3.2xlarge - is_ephemeral: false - max_available: 150 - os: windows - lf.windows.g5.4xlarge.nvidia.gpu: - disk_size: 256 - instance_type: g5.4xlarge - is_ephemeral: false - max_available: 250 - os: windows diff --git a/.github/pytorch-probot.yml b/.github/pytorch-probot.yml index 0905e41becd99..c92f35305dbce 100644 --- a/.github/pytorch-probot.yml +++ b/.github/pytorch-probot.yml @@ -22,6 +22,7 @@ ciflow_push_tags: - ciflow/unstable - ciflow/xpu - ciflow/torchbench +- ciflow/autoformat retryable_workflows: - pull - trunk diff --git a/.github/requirements-gha-cache.txt b/.github/requirements-gha-cache.txt index 5d1e45160564a..e24a81cbfbc01 100644 --- a/.github/requirements-gha-cache.txt +++ b/.github/requirements-gha-cache.txt @@ -4,7 +4,7 @@ # docs/cpp/requirements.txt # functorch/docs/requirements.txt # .ci/docker/requirements-ci.txt -boto3==1.19.12 +boto3==1.35.42 jinja2==3.1.4 lintrunner==0.10.7 ninja==1.10.0.post1 diff --git a/.github/requirements/pip-requirements-macOS.txt b/.github/requirements/pip-requirements-macOS.txt index 03107d5416499..f33bb515bd632 100644 --- a/.github/requirements/pip-requirements-macOS.txt +++ b/.github/requirements/pip-requirements-macOS.txt @@ -1,4 +1,4 @@ -boto3==1.19.12 +boto3==1.35.42 hypothesis==6.56.4 expecttest==0.2.1 fbscribelogger==0.1.6 diff --git a/.github/scripts/generate_binary_build_matrix.py b/.github/scripts/generate_binary_build_matrix.py index 42fd537b0cf74..423cf0248cec7 100644 --- a/.github/scripts/generate_binary_build_matrix.py +++ b/.github/scripts/generate_binary_build_matrix.py @@ -77,6 +77,7 @@ "nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | " + "nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | " "nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64'" @@ -368,13 +369,14 @@ def generate_wheels_matrix( # TODO: Enable python 3.13 on rocm, aarch64, windows if ( - gpu_arch_type == "rocm" or (os != "linux" and os != "linux-s390x") - ) and (python_version == "3.13" or python_version == "3.13t"): + gpu_arch_type == "rocm" + or os not in ["linux", "linux-s390x", "macos-arm64"] + ) and python_version in ["3.13", "3.13t"]: continue - # TODO: Enable python 3.13t on xpu and cpu-s390x + # TODO: Enable python 3.13t on xpu and cpu-s390x or MacOS if ( - gpu_arch_type == "xpu" or gpu_arch_type == "cpu-s390x" + gpu_arch_type in ["xpu", "cpu-s390x"] or os == "macos-arm64" ) and python_version == "3.13t": continue @@ -409,7 +411,7 @@ def generate_wheels_matrix( "container_image": WHEEL_CONTAINER_IMAGES[arch_version], "package_type": package_type, "pytorch_extra_install_requirements": ( - PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version] # fmt: skip + PYTORCH_EXTRA_INSTALL_REQUIREMENTS[arch_version] if os != "linux-aarch64" else "" ), @@ -457,7 +459,7 @@ def generate_wheels_matrix( ".", "_" ), "pytorch_extra_install_requirements": ( - PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.1"] # fmt: skip + PYTORCH_EXTRA_INSTALL_REQUIREMENTS["12.4"] if os != "linux" and gpu_arch_type != "xpu" else "" ), diff --git a/.github/scripts/generate_ci_workflows.py b/.github/scripts/generate_ci_workflows.py index f9c857a3ed9cb..e99f95944245c 100755 --- a/.github/scripts/generate_ci_workflows.py +++ b/.github/scripts/generate_ci_workflows.py @@ -114,20 +114,21 @@ class OperatingSystem: isolated_workflow=True, ), ), - BinaryBuildWorkflow( - os=OperatingSystem.LINUX, - package_type="manywheel", - build_configs=generate_binary_build_matrix.generate_wheels_matrix( - OperatingSystem.LINUX, - use_split_build=True, - arches=["11.8", "12.1", "12.4", "cpu"], - ), - ciflow_config=CIFlowConfig( - labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, - isolated_workflow=True, - ), - use_split_build=True, - ), + # See https://github.com/pytorch/pytorch/issues/138750 + # BinaryBuildWorkflow( + # os=OperatingSystem.LINUX, + # package_type="manywheel", + # build_configs=generate_binary_build_matrix.generate_wheels_matrix( + # OperatingSystem.LINUX, + # use_split_build=True, + # arches=["11.8", "12.1", "12.4", "cpu"], + # ), + # ciflow_config=CIFlowConfig( + # labels={LABEL_CIFLOW_BINARIES, LABEL_CIFLOW_BINARIES_WHEEL}, + # isolated_workflow=True, + # ), + # use_split_build=True, + # ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="conda", @@ -180,21 +181,22 @@ class OperatingSystem: ), branches="main", ), - BinaryBuildWorkflow( - os=OperatingSystem.LINUX, - package_type="manywheel", - build_configs=generate_binary_build_matrix.generate_wheels_matrix( - OperatingSystem.LINUX, - arches=["11.8", "12.1", "12.4"], - python_versions=["3.9"], - use_split_build=True, - ), - ciflow_config=CIFlowConfig( - labels={LABEL_CIFLOW_PERIODIC}, - ), - branches="main", - use_split_build=True, - ), + # See https://github.com/pytorch/pytorch/issues/138750 + # BinaryBuildWorkflow( + # os=OperatingSystem.LINUX, + # package_type="manywheel", + # build_configs=generate_binary_build_matrix.generate_wheels_matrix( + # OperatingSystem.LINUX, + # arches=["11.8", "12.1", "12.4"], + # python_versions=["3.9"], + # use_split_build=True, + # ), + # ciflow_config=CIFlowConfig( + # labels={LABEL_CIFLOW_PERIODIC}, + # ), + # branches="main", + # use_split_build=True, + # ), BinaryBuildWorkflow( os=OperatingSystem.LINUX, package_type="libtorch", diff --git a/.github/scripts/lintrunner.sh b/.github/scripts/lintrunner.sh index ffc98182b5b56..a988c7ac807d1 100755 --- a/.github/scripts/lintrunner.sh +++ b/.github/scripts/lintrunner.sh @@ -41,7 +41,8 @@ RC=0 if ! lintrunner --force-color --tee-json=lint.json ${ADDITIONAL_LINTRUNNER_ARGS} 2> /dev/null; then echo "" echo -e "\e[1m\e[36mYou can reproduce these results locally by using \`lintrunner -m origin/main\`. (If you don't get the same results, run \'lintrunner init\' to update your local linter)\e[0m" - echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions.\e[0m" + echo -e "\e[1m\e[36mSee https://github.com/pytorch/pytorch/wiki/lintrunner for setup instructions. To apply suggested patches automatically, use the -a flag. Before pushing another commit,\e[0m" + echo -e "\e[1m\e[36mplease verify locally and ensure everything passes.\e[0m" RC=1 fi diff --git a/.github/scripts/runner_determinator.py b/.github/scripts/runner_determinator.py index df7bb4d867332..bd1c9a37980a5 100644 --- a/.github/scripts/runner_determinator.py +++ b/.github/scripts/runner_determinator.py @@ -39,7 +39,8 @@ experiments: lf: rollout_percent: 25 - + all_branches: false + default: true --- # Opt-ins: @@ -57,7 +58,7 @@ import random from argparse import ArgumentParser from logging import LogRecord -from typing import Any, Dict, Iterable, List, NamedTuple, Tuple +from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Tuple import yaml from github import Auth, Github @@ -86,6 +87,9 @@ class Experiment(NamedTuple): all_branches: bool = ( False # If True, the experiment is also enabled on the exception branches ) + default: bool = ( + True # If True, the experiment is enabled by default for all queries + ) # Add more fields as needed @@ -140,6 +144,12 @@ def set_github_output(key: str, value: str) -> None: f.write(f"{key}={value}\n") +def _str_comma_separated_to_set(value: str) -> FrozenSet[str]: + return frozenset( + filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(","))) + ) + + def parse_args() -> Any: parser = ArgumentParser("Get dynamic rollout settings") parser.add_argument("--github-token", type=str, required=True, help="GitHub token") @@ -174,6 +184,13 @@ def parse_args() -> Any: required=True, help="Current GitHub ref type, branch or tag", ) + parser.add_argument( + "--eligible-experiments", + type=_str_comma_separated_to_set, + required=False, + default="", + help="comma separated list of experiments to check, if omitted all experiments marked with default=True are checked", + ) return parser.parse_args() @@ -348,6 +365,7 @@ def get_runner_prefix( rollout_state: str, workflow_requestors: Iterable[str], branch: str, + eligible_experiments: FrozenSet[str] = frozenset(), is_canary: bool = False, ) -> str: settings = parse_settings(rollout_state) @@ -356,14 +374,25 @@ def get_runner_prefix( fleet_prefix = "" prefixes = [] for experiment_name, experiment_settings in settings.experiments.items(): - enabled = False - if not experiment_settings.all_branches and is_exception_branch(branch): log.info( f"Branch {branch} is an exception branch. Not enabling experiment {experiment_name}." ) continue + if eligible_experiments: + if experiment_name not in eligible_experiments: + exp_list = ", ".join(eligible_experiments) + log.info( + f"Skipping experiment '{experiment_name}', as it is not in the eligible_experiments list: {exp_list}" + ) + continue + elif not experiment_settings.default: + log.info( + f"Skipping experiment '{experiment_name}', as it is not a default experiment" + ) + continue + # Is any workflow_requestor opted in to this experiment? opted_in_users = [ requestor @@ -371,11 +400,13 @@ def get_runner_prefix( if is_user_opted_in(requestor, user_optins, experiment_name) ] + enabled = False if opted_in_users: log.info( f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) enabled = True + elif experiment_settings.rollout_perc: # If no user is opted in, then we randomly enable the experiment based on the rollout percentage if random.uniform(0, 100) <= experiment_settings.rollout_perc: @@ -444,6 +475,7 @@ def main() -> None: rollout_state, (args.github_issue_owner, username), args.github_branch, + args.eligible_experiments, is_canary, ) diff --git a/.github/scripts/test_runner_determinator.py b/.github/scripts/test_runner_determinator.py index 086d2732c2141..b3e3ec55c3486 100644 --- a/.github/scripts/test_runner_determinator.py +++ b/.github/scripts/test_runner_determinator.py @@ -16,6 +16,7 @@ def test_parse_settings(self) -> None: rollout_perc: 25 otherExp: rollout_perc: 0 + default: false --- Users: @@ -32,7 +33,7 @@ def test_parse_settings(self) -> None: "lf settings not parsed correctly", ) self.assertTupleEqual( - rd.Experiment(rollout_perc=0), + rd.Experiment(rollout_perc=0, default=False), settings.experiments["otherExp"], "otherExp settings not parsed correctly", ) @@ -46,7 +47,7 @@ def test_parse_settings_in_code_block(self) -> None: rollout_perc: 25 otherExp: rollout_perc: 0 - + default: false ``` --- @@ -65,7 +66,7 @@ def test_parse_settings_in_code_block(self) -> None: "lf settings not parsed correctly", ) self.assertTupleEqual( - rd.Experiment(rollout_perc=0), + rd.Experiment(rollout_perc=0, default=False), settings.experiments["otherExp"], "otherExp settings not parsed correctly", ) @@ -177,6 +178,64 @@ def test_opted_in_user_two_experiments(self) -> None: prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + def test_opted_in_user_two_experiments_default(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + default: false + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix(settings_text, ["User2"], USER_BRANCH) + self.assertEqual("lf.", prefix, "Runner prefix not correct for User2") + + def test_opted_in_user_two_experiments_default_exp(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + default: false + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix( + settings_text, ["User2"], USER_BRANCH, frozenset(["lf", "otherExp"]) + ) + self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for User2") + + def test_opted_in_user_two_experiments_default_exp_2(self) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 0 + otherExp: + rollout_perc: 0 + default: false + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + prefix = rd.get_runner_prefix( + settings_text, ["User2"], USER_BRANCH, frozenset(["otherExp"]) + ) + self.assertEqual("otherExp.", prefix, "Runner prefix not correct for User2") + @patch("random.uniform", return_value=50) def test_opted_out_user(self, mock_uniform: Mock) -> None: settings_text = """ @@ -215,6 +274,77 @@ def test_opted_out_user_was_pulled_in_by_rollout(self, mock_uniform: Mock) -> No prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) self.assertEqual("lf.otherExp.", prefix, "Runner prefix not correct for user") + @patch("random.uniform", return_value=10) + def test_opted_out_user_was_pulled_in_by_rollout_excl_nondefault( + self, mock_uniform: Mock + ) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + default: false + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + # User3 is opted out, but is pulled into default experiments by the 10% rollout + prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("lf.", prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=10) + def test_opted_out_user_was_pulled_in_by_rollout_filter_exp( + self, mock_uniform: Mock + ) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 25 + otherExp: + rollout_perc: 25 + default: false + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + # User3 is opted out, but is pulled into default experiments by the 10% rollout + prefix = rd.get_runner_prefix( + settings_text, ["User3"], USER_BRANCH, frozenset(["otherExp"]) + ) + self.assertEqual("otherExp.", prefix, "Runner prefix not correct for user") + + @patch("random.uniform", return_value=25) + def test_opted_out_user_was_pulled_out_by_rollout_filter_exp( + self, mock_uniform: Mock + ) -> None: + settings_text = """ + experiments: + lf: + rollout_perc: 10 + otherExp: + rollout_perc: 50 + default: false + --- + + Users: + @User1,lf + @User2,lf,otherExp + + """ + + # User3 is opted out, but is pulled into default experiments by the 10% rollout + prefix = rd.get_runner_prefix(settings_text, ["User3"], USER_BRANCH) + self.assertEqual("", prefix, "Runner prefix not correct for user") + def test_lf_prefix_always_comes_first(self) -> None: settings_text = """ experiments: diff --git a/.github/scripts/trymerge.py b/.github/scripts/trymerge.py index 340a2810e8577..adefb2fb34a9e 100755 --- a/.github/scripts/trymerge.py +++ b/.github/scripts/trymerge.py @@ -1506,7 +1506,7 @@ def checks_to_str(checks: List[Tuple[str, Optional[str]]]) -> str: def checks_to_markdown_bullets( - checks: List[Tuple[str, Optional[str], Optional[int]]] + checks: List[Tuple[str, Optional[str], Optional[int]]], ) -> List[str]: return [ f"- [{c[0]}]({c[1]})" if c[1] is not None else f"- {c[0]}" for c in checks[:5] diff --git a/.github/templates/common.yml.j2 b/.github/templates/common.yml.j2 index 8db7da9456a6b..5330b3a4c612d 100644 --- a/.github/templates/common.yml.j2 +++ b/.github/templates/common.yml.j2 @@ -25,7 +25,7 @@ concurrency: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -40,6 +40,16 @@ concurrency: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell diff --git a/.github/templates/linux_binary_build_workflow.yml.j2 b/.github/templates/linux_binary_build_workflow.yml.j2 index 3d1c1ace0617f..3c68977a40017 100644 --- a/.github/templates/linux_binary_build_workflow.yml.j2 +++ b/.github/templates/linux_binary_build_workflow.yml.j2 @@ -54,7 +54,7 @@ env: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/templates/windows_binary_build_workflow.yml.j2 b/.github/templates/windows_binary_build_workflow.yml.j2 index 9ba9af06a2ef4..f8006581ee79d 100644 --- a/.github/templates/windows_binary_build_workflow.yml.j2 +++ b/.github/templates/windows_binary_build_workflow.yml.j2 @@ -55,7 +55,7 @@ env: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/_binary-build-linux.yml b/.github/workflows/_binary-build-linux.yml index 509312c30bdfe..120439c7c114d 100644 --- a/.github/workflows/_binary-build-linux.yml +++ b/.github/workflows/_binary-build-linux.yml @@ -271,7 +271,9 @@ jobs: ) docker exec -t -w "${PYTORCH_ROOT}" "${container_name}" bash -c "bash .circleci/scripts/binary_populate_env.sh" if [[ ${BUILD_ENVIRONMENT} == *"aarch64"* ]]; then - docker exec -t "${container_name}" bash -c "bash /builder/aarch64_linux/aarch64_ci_build.sh" + docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /builder/aarch64_linux/aarch64_ci_build.sh" + elif [[ ${{ inputs.PACKAGE_TYPE }} == "manywheel" || ${{ inputs.PACKAGE_TYPE }} == "libtorch" ]]; then + docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /pytorch/.ci/${{ inputs.PACKAGE_TYPE }}/build.sh" else docker exec -t "${container_name}" bash -c "source ${BINARY_ENV_FILE} && bash /builder/${{ inputs.PACKAGE_TYPE }}/build.sh" fi diff --git a/.github/workflows/_linux-test.yml b/.github/workflows/_linux-test.yml index cb577a5b09485..eed89ae6ffadc 100644 --- a/.github/workflows/_linux-test.yml +++ b/.github/workflows/_linux-test.yml @@ -230,7 +230,7 @@ jobs: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} SCRIBE_GRAPHQL_ACCESS_TOKEN: ${{ secrets.SCRIBE_GRAPHQL_ACCESS_TOKEN }} IS_A100_RUNNER: ${{ contains(matrix.runner, 'a100') && '1' || '0' }} - + ARTIFACTS_FILE_SUFFIX: ${{ github.job }}-${{ matrix.config }}-${{ matrix.shard }}-${{ matrix.num_shards }}-${{ matrix.runner }}_${{ steps.get-job-id.outputs.job-id }} run: | set -x @@ -289,6 +289,7 @@ jobs: -e SCRIBE_GRAPHQL_ACCESS_TOKEN \ -e DASHBOARD_TAG \ -e IS_A100_RUNNER \ + -e ARTIFACTS_FILE_SUFFIX \ --env-file="/tmp/github_env_${GITHUB_RUN_ID}" \ --security-opt seccomp=unconfined \ --cap-add=SYS_PTRACE \ diff --git a/.github/workflows/_runner-determinator.yml b/.github/workflows/_runner-determinator.yml index b1a07af53bb4c..0959e31c844e8 100644 --- a/.github/workflows/_runner-determinator.yml +++ b/.github/workflows/_runner-determinator.yml @@ -3,6 +3,11 @@ name: Check whether the workflow owner can use ARC runners on: workflow_call: inputs: + check_experiments: + required: false + type: string + description: | + List of experiments for this workfow. If not defined, all default experiments are included. triggering_actor: required: true type: string @@ -35,6 +40,8 @@ on: jobs: runner-determinator: + # Don't run on forked repos + if: github.repository_owner == 'pytorch' runs-on: ubuntu-latest outputs: label-type: ${{ steps.set-condition.outputs.label-type }} @@ -43,6 +50,7 @@ jobs: ISSUE_NUMBER: ${{ inputs.issue_number }} TRIGGERING_ACTOR: ${{ inputs.triggering_actor }} ISSUE_OWNER: ${{ inputs.issue_owner }} + CHECK_EXPERIMENTS: ${{ inputs.check_experiments }} steps: # - name: Checkout PyTorch # uses: pytorch/pytorch/.github/actions/checkout-pytorch@main @@ -98,7 +106,8 @@ jobs: experiments: lf: rollout_percent: 25 - + all_branches: false + default: true --- # Opt-ins: @@ -116,7 +125,7 @@ jobs: import random from argparse import ArgumentParser from logging import LogRecord - from typing import Any, Dict, Iterable, List, NamedTuple, Tuple + from typing import Any, Dict, FrozenSet, Iterable, List, NamedTuple, Tuple import yaml from github import Auth, Github @@ -145,6 +154,9 @@ jobs: all_branches: bool = ( False # If True, the experiment is also enabled on the exception branches ) + default: bool = ( + True # If True, the experiment is enabled by default for all queries + ) # Add more fields as needed @@ -199,6 +211,12 @@ jobs: f.write(f"{key}={value}\n") + def _str_comma_separated_to_set(value: str) -> FrozenSet[str]: + return frozenset( + filter(lambda itm: itm != "", map(str.strip, value.strip(" \n\t").split(","))) + ) + + def parse_args() -> Any: parser = ArgumentParser("Get dynamic rollout settings") parser.add_argument("--github-token", type=str, required=True, help="GitHub token") @@ -233,6 +251,13 @@ jobs: required=True, help="Current GitHub ref type, branch or tag", ) + parser.add_argument( + "--eligible-experiments", + type=_str_comma_separated_to_set, + required=False, + default="", + help="comma separated list of experiments to check, if omitted all experiments marked with default=True are checked", + ) return parser.parse_args() @@ -407,6 +432,7 @@ jobs: rollout_state: str, workflow_requestors: Iterable[str], branch: str, + eligible_experiments: FrozenSet[str] = frozenset(), is_canary: bool = False, ) -> str: settings = parse_settings(rollout_state) @@ -415,14 +441,25 @@ jobs: fleet_prefix = "" prefixes = [] for experiment_name, experiment_settings in settings.experiments.items(): - enabled = False - if not experiment_settings.all_branches and is_exception_branch(branch): log.info( f"Branch {branch} is an exception branch. Not enabling experiment {experiment_name}." ) continue + if eligible_experiments: + if experiment_name not in eligible_experiments: + exp_list = ", ".join(eligible_experiments) + log.info( + f"Skipping experiment '{experiment_name}', as it is not in the eligible_experiments list: {exp_list}" + ) + continue + elif not experiment_settings.default: + log.info( + f"Skipping experiment '{experiment_name}', as it is not a default experiment" + ) + continue + # Is any workflow_requestor opted in to this experiment? opted_in_users = [ requestor @@ -430,11 +467,13 @@ jobs: if is_user_opted_in(requestor, user_optins, experiment_name) ] + enabled = False if opted_in_users: log.info( f"{', '.join(opted_in_users)} have opted into experiment {experiment_name}." ) enabled = True + elif experiment_settings.rollout_perc: # If no user is opted in, then we randomly enable the experiment based on the rollout percentage if random.uniform(0, 100) <= experiment_settings.rollout_perc: @@ -503,6 +542,7 @@ jobs: rollout_state, (args.github_issue_owner, username), args.github_branch, + args.eligible_experiments, is_canary, ) @@ -538,4 +578,5 @@ jobs: --github-actor "$TRIGGERING_ACTOR" \ --github-issue-owner "$ISSUE_OWNER" \ --github-ref-type "$curr_ref_type" \ - --github-repo "$GITHUB_REPOSITORY" + --github-repo "$GITHUB_REPOSITORY" \ + --eligible-experiments "$CHECK_EXPERIMENTS" \ diff --git a/.github/workflows/_win-build.yml b/.github/workflows/_win-build.yml index f0a1bb003de76..37c1e11212ce0 100644 --- a/.github/workflows/_win-build.yml +++ b/.github/workflows/_win-build.yml @@ -68,9 +68,10 @@ jobs: shell: bash steps: # Duplicated in win-test because this MUST go before a checkout - - name: Enable git symlinks on Windows and disable fsmonitor daemon + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon shell: bash run: | + git config --global core.longpaths true git config --global core.symlinks true # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock diff --git a/.github/workflows/_win-test.yml b/.github/workflows/_win-test.yml index 4991d897b61e8..41ceb05e046d7 100644 --- a/.github/workflows/_win-test.yml +++ b/.github/workflows/_win-test.yml @@ -46,9 +46,10 @@ jobs: shell: bash steps: # Duplicated in win-build because this MUST go before a checkout - - name: Enable git symlinks on Windows and disable fsmonitor daemon + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon shell: bash run: | + git config --global core.longpaths true git config --global core.symlinks true # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock diff --git a/.github/workflows/build-conda-images.yml b/.github/workflows/build-conda-images.yml index 4d2f146a7577d..9fd670873a4e3 100644 --- a/.github/workflows/build-conda-images.yml +++ b/.github/workflows/build-conda-images.yml @@ -35,7 +35,7 @@ jobs: runs-on: linux.9xlarge.ephemeral strategy: matrix: - cuda_version: ["11.8", "12.1", "12.4", "cpu"] + cuda_version: ["11.8", "12.1", "12.4", "12.6", "cpu"] env: CUDA_VERSION: ${{ matrix.cuda_version }} steps: @@ -62,5 +62,11 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/conda/build.sh conda-builder${{ matrix.cuda_version == 'cpu' && ':' || ':cuda' }}${{matrix.cuda_version}} + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/conda/build.sh conda-builder${{ matrix.cuda_version == 'cpu' && ':' || ':cuda' }}${{matrix.cuda_version}} diff --git a/.github/workflows/build-libtorch-images.yml b/.github/workflows/build-libtorch-images.yml index 5146e7593a5fd..abacbda450559 100644 --- a/.github/workflows/build-libtorch-images.yml +++ b/.github/workflows/build-libtorch-images.yml @@ -31,7 +31,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -44,7 +44,7 @@ jobs: runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: - cuda_version: ["12.4", "12.1", "11.8"] + cuda_version: ["12.6", "12.4", "12.1", "11.8"] env: GPU_ARCH_TYPE: cuda GPU_ARCH_VERSION: ${{ matrix.cuda_version }} @@ -72,8 +72,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/libtorch/build.sh libtorch-cxx11-builder:cuda${{matrix.cuda_version}} + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/libtorch/build.sh libtorch-cxx11-builder:cuda${{matrix.cuda_version}} build-docker-rocm: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -108,8 +114,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/libtorch/build.sh libtorch-cxx11-builder:rocm${{matrix.rocm_version}} + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/libtorch/build.sh libtorch-cxx11-builder:rocm${{matrix.rocm_version}} build-docker-cpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -138,5 +150,11 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/libtorch/build.sh libtorch-cxx11-builder:cpu + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/libtorch/build.sh libtorch-cxx11-builder:cpu diff --git a/.github/workflows/build-manywheel-images.yml b/.github/workflows/build-manywheel-images.yml index 7ecf278c58575..4c77c669994ea 100644 --- a/.github/workflows/build-manywheel-images.yml +++ b/.github/workflows/build-manywheel-images.yml @@ -35,7 +35,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -48,7 +48,7 @@ jobs: runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.9xlarge.ephemeral" strategy: matrix: - cuda_version: ["12.4", "12.1", "11.8"] + cuda_version: ["12.6", "12.4", "12.1", "11.8"] env: GPU_ARCH_TYPE: cuda GPU_ARCH_VERSION: ${{ matrix.cuda_version }} @@ -78,8 +78,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinux-builder:cuda${{matrix.cuda_version}} + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinux-builder:cuda${{matrix.cuda_version}} # NOTE: manylinux_2_28 are still experimental, see https://github.com/pytorch/pytorch/issues/123649 build-docker-cuda-manylinux_2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} @@ -117,8 +123,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinux2_28-builder:cuda${{matrix.cuda_version}} + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinux2_28-builder:cuda${{matrix.cuda_version}} build-docker-cuda-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -151,8 +163,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cuda${{matrix.cuda_version}} + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cuda${{matrix.cuda_version}} build-docker-rocm: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -187,8 +205,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinux-builder:rocm${{matrix.rocm_version}} + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinux-builder:rocm${{matrix.rocm_version}} build-docker-cpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -217,8 +241,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinux-builder:cpu + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinux-builder:cpu build-docker-cpu-manylinux_2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -249,8 +279,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinux2_28-builder:cpu + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinux2_28-builder:cpu build-docker-cpu-aarch64: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -281,8 +317,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cpu-aarch64 + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinuxaarch64-builder:cpu-aarch64 build-docker-cpu-aarch64-2_28: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -316,8 +358,14 @@ jobs: env: DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }} DOCKER_ID: ${{ secrets.DOCKER_ID }} - run: | - .ci/docker/manywheel/build.sh manylinux2_28_aarch64-builder:cpu-aarch64 + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinux2_28_aarch64-builder:cpu-aarch64 build-docker-cpu-cxx11-abi: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -348,8 +396,14 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinuxcxx11-abi-builder:cpu-cxx11-abi + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinuxcxx11-abi-builder:cpu-cxx11-abi build-docker-xpu: environment: ${{ (github.ref == 'refs/heads/main' || startsWith(github.event.ref, 'refs/tags/v')) && 'docker-build' || '' }} needs: get-label-type @@ -380,5 +434,11 @@ jobs: fi - name: Build Docker Image if: env.WITH_PUSH == 'true' - run: | - .ci/docker/manywheel/build.sh manylinux2_28-builder:xpu + uses: nick-fields/retry@v3.0.0 + with: + shell: bash + timeout_minutes: 90 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + .ci/docker/manywheel/build.sh manylinux2_28-builder:xpu diff --git a/.github/workflows/build-triton-wheel.yml b/.github/workflows/build-triton-wheel.yml index 332223a611e49..1482ec0d7ac5b 100644 --- a/.github/workflows/build-triton-wheel.yml +++ b/.github/workflows/build-triton-wheel.yml @@ -29,7 +29,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/create_release.yml b/.github/workflows/create_release.yml index 2c83b8cb57196..b09c91a560ace 100644 --- a/.github/workflows/create_release.yml +++ b/.github/workflows/create_release.yml @@ -18,7 +18,7 @@ on: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 7aca5ffbd626d..06c320021b31b 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -32,7 +32,7 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -123,7 +123,7 @@ jobs: IMAGE_NAME: ${{ matrix.docker-image-name }} with: shell: bash - timeout_minutes: 15 + timeout_minutes: 30 max_attempts: 5 retry_wait_seconds: 90 command: | diff --git a/.github/workflows/docker-release.yml b/.github/workflows/docker-release.yml index 41c5b40860303..4b663c5b14c84 100644 --- a/.github/workflows/docker-release.yml +++ b/.github/workflows/docker-release.yml @@ -36,7 +36,7 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml index a29bb0288924f..47a0a5d1bb652 100644 --- a/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-aarch64-binary-manywheel-nightly.yml @@ -39,7 +39,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -65,7 +65,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_9-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-aarch64-test: # Testing @@ -185,7 +185,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_10-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-aarch64-test: # Testing @@ -305,7 +305,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_11-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-aarch64-test: # Testing @@ -425,7 +425,7 @@ jobs: ALPINE_IMAGE: "arm64v8/alpine" build_name: manywheel-py3_12-cpu-aarch64 build_environment: linux-aarch64-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-aarch64-test: # Testing diff --git a/.github/workflows/generated-linux-binary-conda-nightly.yml b/.github/workflows/generated-linux-binary-conda-nightly.yml index e4451fb1f9b74..c1fc0de5cb378 100644 --- a/.github/workflows/generated-linux-binary-conda-nightly.yml +++ b/.github/workflows/generated-linux-binary-conda-nightly.yml @@ -39,7 +39,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml index ad1098bf7d170..8a2179090672c 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-main.yml @@ -34,7 +34,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml index 408106d0096ab..ae94b978eed11 100644 --- a/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-cxx11-abi-nightly.yml @@ -39,7 +39,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml index 06c26961e9894..cece36980b13d 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-main.yml @@ -34,7 +34,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml index ee9f94c8ac6c2..30dd464830e7c 100644 --- a/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml +++ b/.github/workflows/generated-linux-binary-libtorch-pre-cxx11-nightly.yml @@ -39,7 +39,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/generated-linux-binary-manywheel-main.yml b/.github/workflows/generated-linux-binary-manywheel-main.yml index d87b832bf03cb..b7a97835141d7 100644 --- a/.github/workflows/generated-linux-binary-manywheel-main.yml +++ b/.github/workflows/generated-linux-binary-manywheel-main.yml @@ -34,7 +34,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -153,7 +153,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-nightly.yml index d211fad70b3a9..aa002aed3fea2 100644 --- a/.github/workflows/generated-linux-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-binary-manywheel-nightly.yml @@ -39,7 +39,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -343,7 +343,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_9-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cuda12_4-test: # Testing @@ -1029,7 +1029,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_10-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cuda12_4-test: # Testing @@ -1785,7 +1785,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_11-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cuda12_4-test: # Testing @@ -2471,7 +2471,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_12-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cuda12_4-test: # Testing @@ -3157,7 +3157,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cuda12_4-test: # Testing @@ -3623,7 +3623,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build_name: manywheel-py3_13t-cuda12_4 build_environment: linux-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13t-cuda12_4-test: # Testing diff --git a/.github/workflows/generated-linux-binary-manywheel-split-main.yml b/.github/workflows/generated-linux-binary-manywheel-split-main.yml deleted file mode 100644 index 9c2456e4632c7..0000000000000 --- a/.github/workflows/generated-linux-binary-manywheel-split-main.yml +++ /dev/null @@ -1,182 +0,0 @@ -# @generated DO NOT EDIT MANUALLY - -# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 -# Generation script: .github/scripts/generate_ci_workflows.py -name: linux-binary-manywheel-split - - -on: - push: - branches: - - main - tags: - - 'ciflow/periodic/*' - workflow_dispatch: - -env: - # Needed for conda builds - ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" - ANACONDA_USER: pytorch - AWS_DEFAULT_REGION: us-east-1 - BINARY_ENV_FILE: /tmp/env - BUILD_ENVIRONMENT: linux-binary-manywheel-split - BUILDER_ROOT: /builder - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PR_NUMBER: ${{ github.event.pull_request.number }} - PYTORCH_FINAL_PACKAGE_DIR: /artifacts - PYTORCH_ROOT: /pytorch - SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - SKIP_ALL_TESTS: 0 -concurrency: - group: linux-binary-manywheel-split-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true - -jobs: - get-label-type: - name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml - with: - triggering_actor: ${{ github.triggering_actor }} - issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} - curr_branch: ${{ github.head_ref || github.ref_name }} - curr_ref_type: ${{ github.ref_type }} - manywheel-py3_9-cuda11_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda11_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - - manywheel-py3_9-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_4-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/generated-linux-binary-manywheel-split-nightly.yml b/.github/workflows/generated-linux-binary-manywheel-split-nightly.yml deleted file mode 100644 index 5d33a0b59d674..0000000000000 --- a/.github/workflows/generated-linux-binary-manywheel-split-nightly.yml +++ /dev/null @@ -1,1796 +0,0 @@ -# @generated DO NOT EDIT MANUALLY - -# Template is at: .github/templates/linux_binary_build_workflow.yml.j2 -# Generation script: .github/scripts/generate_ci_workflows.py -name: linux-binary-manywheel-split - - -on: - push: - # NOTE: Meta Employees can trigger new nightlies using: https://fburl.com/trigger_pytorch_nightly_build - branches: - - nightly - tags: - # NOTE: Binary build pipelines should only get triggered on release candidate builds - # Release candidate tags look like: v1.11.0-rc1 - - v[0-9]+.[0-9]+.[0-9]+-rc[0-9]+ - - 'ciflow/binaries/*' - - 'ciflow/binaries_wheel/*' - workflow_dispatch: - -env: - # Needed for conda builds - ALPINE_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/tool/alpine" - ANACONDA_USER: pytorch - AWS_DEFAULT_REGION: us-east-1 - BINARY_ENV_FILE: /tmp/env - BUILD_ENVIRONMENT: linux-binary-manywheel-split - BUILDER_ROOT: /builder - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - PR_NUMBER: ${{ github.event.pull_request.number }} - PYTORCH_FINAL_PACKAGE_DIR: /artifacts - PYTORCH_ROOT: /pytorch - SHA1: ${{ github.event.pull_request.head.sha || github.sha }} - SKIP_ALL_TESTS: 0 -concurrency: - group: linux-binary-manywheel-split-${{ github.event.pull_request.number || github.ref_name }}-${{ github.ref_type == 'branch' && github.sha }}-${{ github.event_name == 'workflow_dispatch' }} - cancel-in-progress: true - -jobs: - get-label-type: - name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml - with: - triggering_actor: ${{ github.triggering_actor }} - issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} - curr_branch: ${{ github.head_ref || github.ref_name }} - curr_ref_type: ${{ github.ref_type }} - manywheel-py3_9-cuda11_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda11_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda11_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda11_8-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_9-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_9-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cuda12_4-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cuda12_4-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cuda12_4-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cuda12_4 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_9-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.9" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_9-cpu - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cpu-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_9-cpu-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cpu - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_9-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_9-cpu-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.9" - build_name: manywheel-py3_9-cpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda11_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda11_8 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda11_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda11_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda11_8-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_1 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cuda12_4 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cuda12_4-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_4 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cuda12_4-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cuda12_4-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cuda12_4 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_10-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.10" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_10-cpu - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cpu-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_10-cpu-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cpu - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_10-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_10-cpu-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.10" - build_name: manywheel-py3_10-cpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda11_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda11_8 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda11_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda11_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda11_8-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_1 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda12_1-full-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_1-full - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-full-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_1-full-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1-full - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_1-full-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_1-full-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_1-full - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cuda12_4 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cuda12_4-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cuda12_4-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cuda12_4-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cuda12_4 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_11-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.11" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_11-cpu - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cpu-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_11-cpu-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cpu - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_11-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_11-cpu-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.11" - build_name: manywheel-py3_11-cpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda11_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda11_8 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda11_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda11_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda11_8-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cuda12_4 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cuda12_4-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cuda12_4-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cuda12_4-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cuda12_4 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_12-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.12" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_12-cpu - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cpu-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_12-cpu-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cpu - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_12-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_12-cpu-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.12" - build_name: manywheel-py3_12-cpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13-cuda11_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda11_8 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda11_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda11_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda11_8-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_1 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_1 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cuda12_4 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cuda12_4-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_4 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cuda12_4-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cuda12_4-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cuda12_4 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.13" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13-cpu - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cpu-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13-cpu-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cpu - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13-cpu-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.13" - build_name: manywheel-py3_13-cpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13t-cuda11_8-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda11_8 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu11==11.8.89; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu11==11.8.87; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu11==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu11==11.11.3.6; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu11==10.9.0.58; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu11==10.3.0.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu11==11.4.1.48; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu11==11.7.5.86; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu11==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu11==11.8.86; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda11_8-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13t-cuda11_8-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda11_8 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda11_8-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13t-cuda11_8-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu118 - GPU_ARCH_VERSION: 11.8 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda11.8-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda11_8 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13t-cuda12_1-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda12_1 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_1-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13t-cuda12_1-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_1 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_1-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13t-cuda12_1-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu121 - GPU_ARCH_VERSION: 12.1 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.1-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_1 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13t-cuda12_4-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cuda12_4 - build_environment: linux-binary-manywheel-split - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_4-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13t-cuda12_4-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_4 - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge.nvidia.gpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cuda12_4-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13t-cuda12_4-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cu124 - GPU_ARCH_VERSION: 12.4 - GPU_ARCH_TYPE: cuda - DOCKER_IMAGE: pytorch/manylinux-builder:cuda12.4-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cuda12_4 - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml - - manywheel-py3_13t-cpu-build: - if: ${{ github.repository_owner == 'pytorch' }} - uses: ./.github/workflows/_binary-build-linux.yml - needs: get-label-type - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - build_name: manywheel-py3_13t-cpu - build_environment: linux-binary-manywheel-split - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cpu-test: # Testing - if: ${{ github.repository_owner == 'pytorch' }} - needs: - - manywheel-py3_13t-cpu-build - - get-label-type - uses: ./.github/workflows/_binary-test-linux.yml - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cpu - build_environment: linux-binary-manywheel-split - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - runs_on: linux.4xlarge - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - manywheel-py3_13t-cpu-upload: # Uploading - if: ${{ github.repository_owner == 'pytorch' }} - permissions: - id-token: write - contents: read - needs: manywheel-py3_13t-cpu-test - with: - PYTORCH_ROOT: /pytorch - BUILDER_ROOT: /builder - PACKAGE_TYPE: manywheel - # TODO: This is a legacy variable that we eventually want to get rid of in - # favor of GPU_ARCH_VERSION - DESIRED_CUDA: cpu - GPU_ARCH_TYPE: cpu - DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main - use_split_build: True - DESIRED_PYTHON: "3.13t" - build_name: manywheel-py3_13t-cpu - secrets: - github-token: ${{ secrets.GITHUB_TOKEN }} - conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} - conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} - uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml index 22468055434e8..114f83569bd01 100644 --- a/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml +++ b/.github/workflows/generated-linux-s390x-binary-manywheel-nightly.yml @@ -39,7 +39,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -64,7 +64,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_9-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_9-cpu-s390x-test: # Testing @@ -133,7 +133,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_10-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_10-cpu-s390x-test: # Testing @@ -202,7 +202,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_11-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_11-cpu-s390x-test: # Testing @@ -271,7 +271,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_12-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_12-cpu-s390x-test: # Testing @@ -340,7 +340,7 @@ jobs: ALPINE_IMAGE: "docker.io/s390x/alpine" build_name: manywheel-py3_13-cpu-s390x build_environment: linux-s390x-binary-manywheel - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' secrets: github-token: ${{ secrets.GITHUB_TOKEN }} manywheel-py3_13-cpu-s390x-test: # Testing diff --git a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml index 0a3716c7019b2..687e716eae471 100644 --- a/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml +++ b/.github/workflows/generated-macos-arm64-binary-wheel-nightly.yml @@ -46,7 +46,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -162,7 +162,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -278,7 +278,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -394,7 +394,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: # NOTE: These environment variables are put here so that they can be applied on every job equally # They are also here because setting them at a workflow level doesn't give us access to the @@ -496,3 +496,119 @@ jobs: conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} uses: ./.github/workflows/_binary-upload.yml + wheel-py3_13-cpu-build: + if: ${{ github.repository_owner == 'pytorch' }} + runs-on: macos-14-xlarge + timeout-minutes: 240 + env: + PYTORCH_ROOT: ${{ github.workspace }}/pytorch + BUILDER_ROOT: ${{ github.workspace }}/builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + SKIP_ALL_TESTS: 1 + DESIRED_PYTHON: "3.13" + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' + steps: + # NOTE: These environment variables are put here so that they can be applied on every job equally + # They are also here because setting them at a workflow level doesn't give us access to the + # runner.temp variable, which we need. + - name: Populate binary env + shell: bash + run: | + # shellcheck disable=SC2129 + echo "BINARY_ENV_FILE=${RUNNER_TEMP}/env" >> "${GITHUB_ENV}" + # shellcheck disable=SC2129 + echo "PYTORCH_FINAL_PACKAGE_DIR=${RUNNER_TEMP}/artifacts" >> "${GITHUB_ENV}" + # shellcheck disable=SC2129 + echo "MAC_PACKAGE_WORK_DIR=${RUNNER_TEMP}" >> "${GITHUB_ENV}" + - name: Install conda and dependencies + run: | + # Install conda, setup-miniconda messes with the path that messes with the ruby stuff we do later on + curl --retry 3 --retry-all-errors -o "${RUNNER_TEMP}/conda.sh" "https://repo.anaconda.com/miniconda/Miniconda3-py310_23.5.2-0-MacOSX-$(uname -m).sh" + chmod +x "${RUNNER_TEMP}/conda.sh" + /bin/bash "${RUNNER_TEMP}/conda.sh" -b -p "${RUNNER_TEMP}/anaconda" + echo "${RUNNER_TEMP}/anaconda/bin" >> "${GITHUB_PATH}" + if [ -d "/Applications/Xcode_14.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_14.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + elif [ -d "/Applications/Xcode_13.3.1.app" ]; then + echo "DEVELOPER_DIR=/Applications/Xcode_13.3.1.app/Contents/Developer" >> "${GITHUB_ENV}" + fi + - name: Checkout PyTorch + uses: malfet/checkout@silent-checkout + with: + ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }} + submodules: recursive + path: pytorch + quiet-checkout: true + - name: Clean PyTorch checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: pytorch + - name: Checkout pytorch/builder + uses: malfet/checkout@silent-checkout + with: + ref: main + submodules: recursive + repository: pytorch/builder + path: builder + quiet-checkout: true + - name: Clean pytorch/builder checkout + run: | + # Remove any artifacts from the previous checkouts + git clean -fxd + working-directory: builder + - name: Install sccache (only for non-forked PRs, and pushes to trunk) + uses: nick-fields/retry@v3.0.0 + if: ${{ github.event_name == 'push' || github.event.pull_request.head.repo.full_name == github.repository }} + with: + timeout_minutes: 5 + max_attempts: 3 + retry_wait_seconds: 90 + command: | + sudo curl --retry 3 --retry-all-errors https://s3.amazonaws.com/ossci-macos/sccache_v2.15 --output /usr/local/bin/sccache + sudo chmod +x /usr/local/bin/sccache + echo "SCCACHE_BUCKET=ossci-compiler-cache-circleci-v2" >> "${GITHUB_ENV}" + - name: Populate binary env + run: | + # shellcheck disable=SC1091 + source "${RUNNER_TEMP}/anaconda/bin/activate" + "${PYTORCH_ROOT}/.circleci/scripts/binary_populate_env.sh" + - name: Build PyTorch binary + run: | + # shellcheck disable=SC1091 + source "${RUNNER_TEMP}/anaconda/bin/activate" + "${PYTORCH_ROOT}/.circleci/scripts/binary_macos_build.sh" + - uses: actions/upload-artifact@v4.4.0 + if: always() + with: + name: wheel-py3_13-cpu + retention-days: 14 + if-no-files-found: error + path: "${{ env.PYTORCH_FINAL_PACKAGE_DIR }}" + wheel-py3_13-cpu-upload: # Uploading + if: ${{ github.repository_owner == 'pytorch' }} + permissions: + id-token: write + contents: read + needs: wheel-py3_13-cpu-build + with: + PYTORCH_ROOT: /pytorch + BUILDER_ROOT: /builder + PACKAGE_TYPE: wheel + # TODO: This is a legacy variable that we eventually want to get rid of in + # favor of GPU_ARCH_VERSION + DESIRED_CUDA: cpu + GPU_ARCH_TYPE: cpu + DOCKER_IMAGE: pytorch/manylinux-builder:cpu-main + DESIRED_PYTHON: "3.13" + build_name: wheel-py3_13-cpu + use_s3: False + secrets: + github-token: ${{ secrets.GITHUB_TOKEN }} + conda-pytorchbot-token: ${{ secrets.CONDA_PYTORCHBOT_TOKEN }} + conda-pytorchbot-token-test: ${{ secrets.CONDA_PYTORCHBOT_TOKEN_TEST }} + uses: ./.github/workflows/_binary-upload.yml diff --git a/.github/workflows/generated-windows-binary-conda-nightly.yml b/.github/workflows/generated-windows-binary-conda-nightly.yml index bcadb5d0fc450..30273a358ae03 100644 --- a/.github/workflows/generated-windows-binary-conda-nightly.yml +++ b/.github/workflows/generated-windows-binary-conda-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -64,7 +64,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -75,6 +75,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -178,7 +188,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -189,6 +199,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -310,7 +330,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -321,6 +341,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -425,7 +455,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -436,6 +466,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -558,7 +598,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -569,6 +609,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -673,7 +723,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -684,6 +734,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -806,7 +866,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -817,6 +877,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -921,7 +991,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -932,6 +1002,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1053,7 +1133,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1064,6 +1144,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1167,7 +1257,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1178,6 +1268,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1299,7 +1399,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1310,6 +1410,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1414,7 +1524,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1425,6 +1535,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1547,7 +1667,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1558,6 +1678,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1662,7 +1792,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1673,6 +1803,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1795,7 +1935,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1806,6 +1946,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1910,7 +2060,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1921,6 +2071,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2042,7 +2202,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2053,6 +2213,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2156,7 +2326,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2167,6 +2337,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2288,7 +2468,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2299,6 +2479,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2403,7 +2593,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2414,6 +2604,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2536,7 +2736,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2547,6 +2747,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2651,7 +2861,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2662,6 +2872,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2784,7 +3004,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2795,6 +3015,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2899,7 +3129,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2910,6 +3140,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3031,7 +3271,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3042,6 +3282,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3145,7 +3395,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3156,6 +3406,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3277,7 +3537,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3288,6 +3548,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3392,7 +3662,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3403,6 +3673,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3525,7 +3805,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3536,6 +3816,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3640,7 +3930,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3651,6 +3941,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3773,7 +4073,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3784,6 +4084,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3888,7 +4198,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3899,6 +4209,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml index 85e2564d612f4..228cc881795b1 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-main.yml @@ -27,7 +27,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -61,7 +61,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -72,6 +72,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -179,7 +189,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -190,6 +200,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell diff --git a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml index 215dbe681896e..d4a9ea942b65e 100644 --- a/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-debug-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -68,7 +68,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -79,6 +79,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -186,7 +196,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -197,6 +207,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -326,7 +346,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -337,6 +357,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -445,7 +475,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -456,6 +486,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -586,7 +626,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -597,6 +637,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -705,7 +755,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -716,6 +766,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -846,7 +906,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -857,6 +917,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -965,7 +1035,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -976,6 +1046,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell diff --git a/.github/workflows/generated-windows-binary-libtorch-release-main.yml b/.github/workflows/generated-windows-binary-libtorch-release-main.yml index 7fd315028bb70..311bbc02a458b 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-main.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-main.yml @@ -27,7 +27,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -61,7 +61,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -72,6 +72,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -179,7 +189,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -190,6 +200,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell diff --git a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml index c3ce65daff709..b6dec6bbd9c96 100644 --- a/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml +++ b/.github/workflows/generated-windows-binary-libtorch-release-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -68,7 +68,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -79,6 +79,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -186,7 +196,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -197,6 +207,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -326,7 +346,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -337,6 +357,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -445,7 +475,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -456,6 +486,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -586,7 +626,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -597,6 +637,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -705,7 +755,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -716,6 +766,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -846,7 +906,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -857,6 +917,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -965,7 +1035,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -976,6 +1046,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell diff --git a/.github/workflows/generated-windows-binary-wheel-nightly.yml b/.github/workflows/generated-windows-binary-wheel-nightly.yml index 316329b46870f..b9c259701db6a 100644 --- a/.github/workflows/generated-windows-binary-wheel-nightly.yml +++ b/.github/workflows/generated-windows-binary-wheel-nightly.yml @@ -34,7 +34,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -55,7 +55,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -65,7 +65,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -76,6 +76,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -179,7 +189,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -190,6 +200,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -302,7 +322,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -312,7 +332,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -323,6 +343,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -427,7 +457,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -438,6 +468,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -551,7 +591,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -561,7 +601,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -572,6 +612,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -676,7 +726,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -687,6 +737,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -800,7 +860,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.9" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -810,7 +870,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -821,6 +881,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -925,7 +995,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -936,6 +1006,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1057,7 +1137,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1068,6 +1148,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1171,7 +1261,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1182,6 +1272,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1293,7 +1393,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1303,7 +1403,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1314,6 +1414,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1417,7 +1527,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1428,6 +1538,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1540,7 +1660,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1550,7 +1670,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1561,6 +1681,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1665,7 +1795,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1676,6 +1806,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1789,7 +1929,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -1799,7 +1939,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1810,6 +1950,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -1914,7 +2064,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -1925,6 +2075,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2038,7 +2198,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.10" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2048,7 +2208,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2059,6 +2219,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2163,7 +2333,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2174,6 +2344,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2295,7 +2475,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2306,6 +2486,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2409,7 +2599,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2420,6 +2610,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2531,7 +2731,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2541,7 +2741,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2552,6 +2752,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2655,7 +2865,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2666,6 +2876,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2778,7 +2998,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -2788,7 +3008,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2799,6 +3019,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -2903,7 +3133,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -2914,6 +3144,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3027,7 +3267,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3037,7 +3277,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3048,6 +3288,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3152,7 +3402,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3163,6 +3413,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3276,7 +3536,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.11" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3286,7 +3546,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3297,6 +3557,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3401,7 +3671,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3412,6 +3682,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3533,7 +3813,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3544,6 +3824,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3647,7 +3937,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3658,6 +3948,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3769,7 +4069,7 @@ jobs: GPU_ARCH_TYPE: cpu SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -3779,7 +4079,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3790,6 +4090,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -3893,7 +4203,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -3904,6 +4214,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4016,7 +4336,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4026,7 +4346,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4037,6 +4357,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4141,7 +4471,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4152,6 +4482,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4265,7 +4605,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4275,7 +4615,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4286,6 +4626,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4390,7 +4740,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4401,6 +4751,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4514,7 +4874,7 @@ jobs: GPU_ARCH_TYPE: cuda SKIP_ALL_TESTS: 1 DESIRED_PYTHON: "3.12" - PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.1.3.1; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.0.2.54; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.2.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.4.5.107; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.1.0.106; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.1.105; platform_system == 'Linux' and platform_machine == 'x86_64' + PYTORCH_EXTRA_INSTALL_REQUIREMENTS: nvidia-cuda-nvrtc-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-runtime-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cuda-cupti-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cudnn-cu12==9.1.0.70; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cublas-cu12==12.4.5.8; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cufft-cu12==11.2.1.3; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-curand-cu12==10.3.5.147; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusolver-cu12==11.6.1.9; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparse-cu12==12.3.1.170; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-cusparselt-cu12==0.6.2; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nccl-cu12==2.21.5; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvtx-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' | nvidia-nvjitlink-cu12==12.4.127; platform_system == 'Linux' and platform_machine == 'x86_64' steps: - name: Display EC2 information shell: bash @@ -4524,7 +4884,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4535,6 +4895,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4639,7 +5009,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4650,6 +5020,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4771,7 +5151,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4782,6 +5162,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell @@ -4885,7 +5275,7 @@ jobs: # Pulled from instance metadata endpoint for EC2 # see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/instancedata-data-retrieval.html category=$1 - curl -fsSL "http://169.254.169.254/latest/meta-data/${category}" + curl -H "X-aws-ec2-metadata-token: $(curl -s -X PUT "http://169.254.169.254/latest/api/token" -H "X-aws-ec2-metadata-token-ttl-seconds: 30")" -fsSL "http://169.254.169.254/latest/meta-data/${category}" } echo "ami-id: $(get_ec2_metadata ami-id)" echo "instance-id: $(get_ec2_metadata instance-id)" @@ -4896,6 +5286,16 @@ jobs: continue-on-error: true with: github-secret: ${{ secrets.GITHUB_TOKEN }} + - name: Enable git long paths and symlinks on Windows and disable fsmonitor daemon + shell: bash + run: | + git config --global core.longpaths true + git config --global core.symlinks true + + # https://git-scm.com/docs/git-fsmonitor--daemon. The daemon could lock + # the directory on Windows and prevent GHA from checking out as reported + # in https://github.com/actions/checkout/issues/1018 + git config --global core.fsmonitor false # Needed for binary builds, see: https://github.com/pytorch/pytorch/issues/73339#issuecomment-1058981560 - name: Enable long paths on Windows shell: powershell diff --git a/.github/workflows/inductor-cu124.yml b/.github/workflows/inductor-cu124.yml index 950afbf0b591e..bddbc9c730af4 100644 --- a/.github/workflows/inductor-cu124.yml +++ b/.github/workflows/inductor-cu124.yml @@ -20,7 +20,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -56,7 +57,7 @@ jobs: { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_cpp_wrapper", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/.github/workflows/inductor-micro-benchmark-x86.yml b/.github/workflows/inductor-micro-benchmark-x86.yml index d31dbc5951ea1..cbd9a5dace798 100644 --- a/.github/workflows/inductor-micro-benchmark-x86.yml +++ b/.github/workflows/inductor-micro-benchmark-x86.yml @@ -17,6 +17,7 @@ permissions: read-all jobs: linux-jammy-cpu-py3_9-gcc11-inductor-build: + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' name: linux-jammy-cpu-py3.9-gcc11-inductor uses: ./.github/workflows/_linux-build.yml with: diff --git a/.github/workflows/inductor-micro-benchmark.yml b/.github/workflows/inductor-micro-benchmark.yml index fad0538d10755..e8270abd469aa 100644 --- a/.github/workflows/inductor-micro-benchmark.yml +++ b/.github/workflows/inductor-micro-benchmark.yml @@ -18,7 +18,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor-perf-compare.yml b/.github/workflows/inductor-perf-compare.yml index a38bcadf7e5f7..0f459f42107fc 100644 --- a/.github/workflows/inductor-perf-compare.yml +++ b/.github/workflows/inductor-perf-compare.yml @@ -13,30 +13,43 @@ concurrency: permissions: read-all jobs: - get-label-type: - name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + get-default-label-prefix: + name: get-default-label-prefix + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} curr_branch: ${{ github.head_ref || github.ref_name }} curr_ref_type: ${{ github.ref_type }} + get-test-label-type: + name: get-test-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + check_experiments: "awsa100" + linux-focal-cuda12_1-py3_10-gcc9-inductor-build: name: cuda12.1-py3.10-gcc9-sm80 uses: ./.github/workflows/_linux-build.yml - needs: get-label-type + needs: + - get-default-label-prefix + - get-test-label-type with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + runner_prefix: "${{ needs.get-default-label-prefix.outputs.label-type }}" build-environment: linux-focal-cuda12.1-py3.10-gcc9-sm80 docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9-inductor-benchmarks cuda-arch-list: '8.0' test-matrix: | { include: [ - { config: "inductor_huggingface_perf_compare", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, - { config: "inductor_timm_perf_compare", shard: 1, num_shards: 2, runner: "linux.gcp.a100" }, - { config: "inductor_timm_perf_compare", shard: 2, num_shards: 2, runner: "linux.gcp.a100" }, - { config: "inductor_torchbench_perf_compare", shard: 1, num_shards: 1, runner: "linux.gcp.a100" }, + { config: "inductor_huggingface_perf_compare", shard: 1, num_shards: 1, runner: "${{ needs.get-test-label-type.outputs.label-type }}linux.gcp.a100" }, + { config: "inductor_timm_perf_compare", shard: 1, num_shards: 2, runner: "${{ needs.get-test-label-type.outputs.label-type }}linux.gcp.a100" }, + { config: "inductor_timm_perf_compare", shard: 2, num_shards: 2, runner: "${{ needs.get-test-label-type.outputs.label-type }}linux.gcp.a100" }, + { config: "inductor_torchbench_perf_compare", shard: 1, num_shards: 1, runner: "${{ needs.get-test-label-type.outputs.label-type }}linux.gcp.a100" }, ]} secrets: HUGGING_FACE_HUB_TOKEN: ${{ secrets.HUGGING_FACE_HUB_TOKEN }} diff --git a/.github/workflows/inductor-perf-test-nightly-a10g.yml b/.github/workflows/inductor-perf-test-nightly-a10g.yml index e42d7d4148c22..cd208bfde262d 100644 --- a/.github/workflows/inductor-perf-test-nightly-a10g.yml +++ b/.github/workflows/inductor-perf-test-nightly-a10g.yml @@ -70,7 +70,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index 11479acfcd992..e51950ca74ad9 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -50,7 +50,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor-perf-test-nightly-x86.yml b/.github/workflows/inductor-perf-test-nightly-x86.yml index 83e8b26dd628e..0d9d79332f945 100644 --- a/.github/workflows/inductor-perf-test-nightly-x86.yml +++ b/.github/workflows/inductor-perf-test-nightly-x86.yml @@ -50,7 +50,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor-perf-test-nightly.yml b/.github/workflows/inductor-perf-test-nightly.yml index 5c7651d516d8b..84a935e196a76 100644 --- a/.github/workflows/inductor-perf-test-nightly.yml +++ b/.github/workflows/inductor-perf-test-nightly.yml @@ -68,7 +68,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor-periodic.yml b/.github/workflows/inductor-periodic.yml index 6bcb1be5ef094..2abc1e40a3699 100644 --- a/.github/workflows/inductor-periodic.yml +++ b/.github/workflows/inductor-periodic.yml @@ -20,7 +20,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor-rocm.yml b/.github/workflows/inductor-rocm.yml index 9bcc40ddbb6b2..faf386881734b 100644 --- a/.github/workflows/inductor-rocm.yml +++ b/.github/workflows/inductor-rocm.yml @@ -24,7 +24,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 905e600e1bea0..92e09623dbb51 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -20,7 +20,8 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -38,25 +39,25 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.12xlarge.nvidia.gpu" }, - { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_cpp_wrapper_abi_compatible", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" }, + { config: "inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "dynamic_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "aot_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_cpp_wrapper", shard: 1, num_shards: 1, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -81,8 +82,8 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -106,7 +107,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor-halide", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "inductor-halide", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, ]} secrets: inherit @@ -130,7 +131,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor-triton-cpu", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, + { config: "inductor-triton-cpu", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, ]} linux-jammy-cpu-py3_12-inductor-triton-cpu-test: @@ -155,8 +156,8 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 1, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "inductor_timm", shard: 2, num_shards: 2, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -181,47 +182,47 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_avx512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "inductor_avx512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_inductor_amp_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_amp_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_amp_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "cpu_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.16xlarge.spr" }, - { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.12xlarge" }, - { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, - { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_freezing_avx2_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_freezing_avx2_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_freezing_avx2_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "inductor_avx512", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "inductor_avx512", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, + { config: "cpu_inductor_timm", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, + { config: "cpu_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_inductor_amp_freezing_huggingface", shard: 1, num_shards: 1, runner: "linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_timm", shard: 1, num_shards: 2, runner: "linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_timm", shard: 2, num_shards: 2, runner: "linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.16xlarge.spr" }, + { config: "cpu_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.16xlarge.spr" }, + { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_huggingface", shard: 1, num_shards: 1, runner: "linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_timm", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_timm", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 1, num_shards: 2, runner: "linux.12xlarge" }, + { config: "dynamic_cpu_aot_inductor_amp_freezing_torchbench", shard: 2, num_shards: 2, runner: "linux.12xlarge" }, + { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "linux.24xl.spr-metal" }, + { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "linux.10xlarge.avx2" }, + { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_huggingface", shard: 1, num_shards: 1, runner: "linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_torchbench", shard: 1, num_shards: 2, runner: "linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_torchbench", shard: 2, num_shards: 2, runner: "linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_timm", shard: 1, num_shards: 2, runner: "linux.10xlarge.avx2" }, + { config: "cpu_inductor_freezing_avx2_timm", shard: 2, num_shards: 2, runner: "linux.10xlarge.avx2" }, ]} secrets: inherit diff --git a/.github/workflows/lint-autoformat.yml b/.github/workflows/lint-autoformat.yml index ad9ffd521751e..a20e5737857f2 100644 --- a/.github/workflows/lint-autoformat.yml +++ b/.github/workflows/lint-autoformat.yml @@ -2,8 +2,9 @@ name: Apply lint suggestions on: - pull_request: - types: [opened, synchronize, reopened] + push: + tags: + - ciflow/autoformat/* jobs: lintrunner-autoformat: @@ -11,7 +12,7 @@ jobs: contents: read pull-requests: write runs-on: lf.linux.2xlarge - if: ${{ github.repository_owner == 'pytorch' }} + if: ${{ github.repository_owner == 'pytorch' && github.event.pull_request.user.login != 'ezyang' && github.event.pull_request.user.login != 'malfet' && !startsWith(github.head_ref, 'export-') }} steps: - name: Checkout pytorch uses: pytorch/pytorch/.github/actions/checkout-pytorch@main diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 660a054b10691..a49de16d0b6f6 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -17,7 +17,7 @@ permissions: read-all jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index da01f1b1d733b..e7506d68cb440 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -14,7 +14,7 @@ jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/llm_td_retrieval.yml b/.github/workflows/llm_td_retrieval.yml index 64fbd1d4ccfdd..3be1c98ec6d01 100644 --- a/.github/workflows/llm_td_retrieval.yml +++ b/.github/workflows/llm_td_retrieval.yml @@ -8,10 +8,11 @@ permissions: contents: read jobs: - get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + # Don't run on forked repos + if: github.repository_owner == 'pytorch' + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -19,6 +20,8 @@ jobs: curr_ref_type: ${{ github.ref_type }} llm-retrieval: + # Don't run on forked repos + if: github.repository_owner == 'pytorch' runs-on: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" continue-on-error: true needs: get-label-type diff --git a/.github/workflows/nightly-rockset-uploads.yml b/.github/workflows/nightly-rockset-uploads.yml index 4bcf6548a6b82..b80c9d1c91787 100644 --- a/.github/workflows/nightly-rockset-uploads.yml +++ b/.github/workflows/nightly-rockset-uploads.yml @@ -32,7 +32,7 @@ jobs: cache: pip - run: | - pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.19.12 + pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.35.42 - name: Upload external contribution stats uses: nick-fields/retry@v3.0.0 diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 5057e9da2d1dd..c806b525c6425 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -19,7 +19,8 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/periodic.yml b/.github/workflows/periodic.yml index 455aa7ce84ecc..3711b4dd68cf2 100644 --- a/.github/workflows/periodic.yml +++ b/.github/workflows/periodic.yml @@ -40,7 +40,8 @@ jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -57,10 +58,10 @@ jobs: docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-test: @@ -89,10 +90,10 @@ jobs: { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, - { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} @@ -333,6 +334,7 @@ jobs: name: linux-focal-cuda12.1-py3.10-gcc9-experimental-split-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type + if: false # See https://github.com/pytorch/pytorch/issues/138750 with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true @@ -340,10 +342,10 @@ jobs: docker-image-name: pytorch-linux-focal-cuda12.1-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, ]} @@ -363,6 +365,7 @@ jobs: name: linux-focal-cuda11.8-py3.9-gcc9-experimental-split-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type + if: false # See https://github.com/pytorch/pytorch/issues/138750 with: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" use_split_build: true @@ -385,3 +388,33 @@ jobs: build-environment: linux-focal-cuda11.8-py3.9-gcc9-experimental-split-build docker-image: ${{ needs.linux-focal-cuda11_8-py3_9-gcc9-experimental-split-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda11_8-py3_9-gcc9-experimental-split-build.outputs.test-matrix }} + + linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build: + name: linux-focal-cuda11.8-py3.10-gcc9-experimental-split-build + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + if: false # See https://github.com/pytorch/pytorch/issues/138750 + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + use_split_build: true + build-environment: linux-focal-cuda11.8-py3.10-gcc9 + docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 + cuda-arch-list: '7.5' + test-matrix: | + { include: [ + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + ]} + + linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build-test: + name: linux-focal-cuda11.8-py3.10-gcc9-experimental-split-build-test + uses: ./.github/workflows/_linux-test.yml + needs: + - linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build + - target-determination + with: + timeout-minutes: 360 + build-environment: linux-focal-cuda11.8-py3.10-gcc9-experimental-split-build + docker-image: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build.outputs.test-matrix }} diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index da1730f52b379..206c350aa42fb 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -37,7 +37,8 @@ jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -53,10 +54,11 @@ jobs: docker-image-name: pytorch-linux-jammy-py3.9-gcc11 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "docs_test", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "backwards_compat", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -185,10 +187,11 @@ jobs: docker-image-name: pytorch-linux-focal-py3.9-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -217,10 +220,11 @@ jobs: docker-image-name: pytorch-linux-focal-py3.11-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "crossref", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "crossref", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -251,10 +255,11 @@ jobs: docker-image-name: pytorch-linux-focal-py3.12-clang10 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 2, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 3, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "default", shard: 4, num_shards: 4, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 3, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 4, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "default", shard: 5, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "dynamo", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, { config: "dynamo", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, @@ -280,11 +285,12 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" build-environment: linux-focal-cuda11.8-py3.10-gcc9 docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 + cuda-arch-list: '7.5' test-matrix: | { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, + { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, + { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g4dn.12xlarge.nvidia.gpu" }, ]} secrets: inherit @@ -578,6 +584,7 @@ jobs: secrets: inherit linux-focal-py3_12-clang10-experimental-split-build: + if: false # See https://github.com/pytorch/pytorch/issues/138750 name: linux-focal-py3.12-clang10-experimental-split-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type diff --git a/.github/workflows/rocm.yml b/.github/workflows/rocm.yml index 051a5eb1a9b71..6aa01ff179eff 100644 --- a/.github/workflows/rocm.yml +++ b/.github/workflows/rocm.yml @@ -26,6 +26,7 @@ jobs: contents: read linux-focal-rocm6_2-py3_10-build: + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' name: linux-focal-rocm6.2-py3.10 uses: ./.github/workflows/_linux-build.yml with: diff --git a/.github/workflows/slow.yml b/.github/workflows/slow.yml index 66d346c372c6a..2aab56e971f81 100644 --- a/.github/workflows/slow.yml +++ b/.github/workflows/slow.yml @@ -38,7 +38,8 @@ jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -56,14 +57,14 @@ jobs: cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "default", shard: 1, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 2, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 3, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 4, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 5, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 6, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 7, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "default", shard: 8, num_shards: 8, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 1, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 2, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 3, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 4, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 5, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 6, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 7, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "default", shard: 8, num_shards: 8, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3-gcc9-slow-gradcheck-test: @@ -89,9 +90,9 @@ jobs: cuda-arch-list: 8.6 test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, - { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 1, num_shards: 3, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 2, num_shards: 3, runner: "linux.g5.4xlarge.nvidia.gpu" }, + { config: "slow", shard: 3, num_shards: 3, runner: "linux.g5.4xlarge.nvidia.gpu" }, ]} linux-focal-cuda12_1-py3_10-gcc9-sm86-test: @@ -115,8 +116,8 @@ jobs: docker-image-name: pytorch-linux-focal-py3.9-clang10 test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "slow", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "slow", shard: 1, num_shards: 2, runner: "linux.2xlarge" }, + { config: "slow", shard: 2, num_shards: 2, runner: "linux.2xlarge" }, ]} linux-focal-py3_9-clang10-test: @@ -168,9 +169,9 @@ jobs: docker-image-name: pytorch-linux-jammy-py3-clang15-asan test-matrix: | { include: [ - { config: "slow", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "slow", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, - { config: "slow", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "slow", shard: 1, num_shards: 3, runner: "linux.4xlarge" }, + { config: "slow", shard: 2, num_shards: 3, runner: "linux.4xlarge" }, + { config: "slow", shard: 3, num_shards: 3, runner: "linux.4xlarge" }, ]} sync-tag: asan-build secrets: inherit diff --git a/.github/workflows/target-determination-indexer.yml b/.github/workflows/target-determination-indexer.yml index 373f464eae139..1025ba9b2c234 100644 --- a/.github/workflows/target-determination-indexer.yml +++ b/.github/workflows/target-determination-indexer.yml @@ -12,7 +12,7 @@ permissions: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/target_determination.yml b/.github/workflows/target_determination.yml index f7b2f383f314e..523f816fa49f5 100644 --- a/.github/workflows/target_determination.yml +++ b/.github/workflows/target_determination.yml @@ -7,7 +7,9 @@ jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + # Don't run on forked repos + if: github.repository_owner == 'pytorch' + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -70,7 +72,7 @@ jobs: PR_NUMBER: ${{ github.event.pull_request.number }} run: | unzip -o .additional_ci_files/llm_results/mappings.zip -d .additional_ci_files/llm_results || true - python3 -m pip install boto3==1.19.12 + python3 -m pip install boto3==1.35.42 python3 tools/testing/do_target_determination_for_s3.py - name: Upload TD results to s3 diff --git a/.github/workflows/torchbench.yml b/.github/workflows/torchbench.yml index 4e2098e589238..6a599809060e7 100644 --- a/.github/workflows/torchbench.yml +++ b/.github/workflows/torchbench.yml @@ -13,7 +13,7 @@ concurrency: jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.github/workflows/trunk.yml b/.github/workflows/trunk.yml index 112c35c4d7acd..655fb72b20e69 100644 --- a/.github/workflows/trunk.yml +++ b/.github/workflows/trunk.yml @@ -36,7 +36,8 @@ jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main + if: github.event_name != 'schedule' || github.repository == 'pytorch/pytorch' with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} @@ -256,6 +257,7 @@ jobs: tests-to-include: "test_nn test_torch test_cuda test_ops test_unary_ufuncs test_binary_ufuncs test_autograd inductor/test_torchinductor distributed/test_c10d_common distributed/test_c10d_nccl" linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build: + if: false # See https://github.com/pytorch/pytorch/issues/138750 name: linux-focal-cuda12.4-py3.10-gcc9-experimental-split-build uses: ./.github/workflows/_linux-build.yml needs: get-label-type @@ -266,10 +268,10 @@ jobs: docker-image-name: pytorch-linux-focal-cuda12.4-cudnn9-py3-gcc9 test-matrix: | { include: [ - { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, - { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge" }, + { config: "nogpu_AVX512", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_AVX512", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, + { config: "nogpu_NO_AVX2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge" }, { config: "jit_legacy", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 1, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, { config: "default", shard: 2, num_shards: 5, runner: "${{ needs.get-label-type.outputs.label-type }}linux.4xlarge.nvidia.gpu" }, @@ -288,31 +290,3 @@ jobs: build-environment: linux-focal-cuda12.4-py3.10-gcc9-experimental-split-build docker-image: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build.outputs.docker-image }} test-matrix: ${{ needs.linux-focal-cuda12_4-py3_10-gcc9-experimental-split-build.outputs.test-matrix }} - - linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build: - name: linux-focal-cuda11.8-py3.10-gcc9-experimental-split-build - uses: ./.github/workflows/_linux-build.yml - needs: get-label-type - with: - runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" - use_split_build: true - build-environment: linux-focal-cuda11.8-py3.10-gcc9 - docker-image-name: pytorch-linux-focal-cuda11.8-cudnn9-py3-gcc9 - test-matrix: | - { include: [ - { config: "distributed", shard: 1, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 2, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, - { config: "distributed", shard: 3, num_shards: 3, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.nvidia.gpu" }, - ]} - - linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build-test: - name: linux-focal-cuda11.8-py3.10-gcc9-experimental-split-build-test - uses: ./.github/workflows/_linux-test.yml - needs: - - linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build - - target-determination - with: - timeout-minutes: 360 - build-environment: linux-focal-cuda11.8-py3.10-gcc9-experimental-split-build - docker-image: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build.outputs.docker-image }} - test-matrix: ${{ needs.linux-focal-cuda11_8-py3_10-gcc9-experimental-split-build.outputs.test-matrix }} diff --git a/.github/workflows/update-viablestrict.yml b/.github/workflows/update-viablestrict.yml index f6b3e39817171..bf179e50766a2 100644 --- a/.github/workflows/update-viablestrict.yml +++ b/.github/workflows/update-viablestrict.yml @@ -25,7 +25,9 @@ jobs: stable-branch: viable/strict requires: '[\"pull\", \"trunk\", \"lint\", \"linux-binary\"]' secret-bot-token: ${{ secrets.MERGEBOT_TOKEN }} - rockset-api-key: ${{ secrets.ROCKSET_API_KEY }} + clickhouse-url: ${{ secrets.CLICKHOUSE_URL }} + clickhouse-username: ${{ secrets.CLICKHOUSE_VIABLESTRICT_USERNAME }} + clickhouse-password: ${{ secrets.CLICKHOUSE_VIABLESTRICT_PASSWORD }} - name: Authenticate to AWS with OIDC uses: aws-actions/configure-aws-credentials@v4 diff --git a/.github/workflows/update_pytorch_labels.yml b/.github/workflows/update_pytorch_labels.yml index db09474fb2120..7e01727895578 100644 --- a/.github/workflows/update_pytorch_labels.yml +++ b/.github/workflows/update_pytorch_labels.yml @@ -29,5 +29,5 @@ jobs: aws-region: us-east-1 - name: Update PyTorch labels list in S3 run: | - python3 -m pip install boto3==1.19.12 + python3 -m pip install boto3==1.35.42 .github/scripts/export_pytorch_labels.py pytorch pytorch diff --git a/.github/workflows/upload-test-stats.yml b/.github/workflows/upload-test-stats.yml index f9e5593bf66ff..8d5072e054f6d 100644 --- a/.github/workflows/upload-test-stats.yml +++ b/.github/workflows/upload-test-stats.yml @@ -53,7 +53,7 @@ jobs: cache: pip - run: | - pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.19.12 + pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.35.42 - name: Upload test artifacts id: upload-s3 diff --git a/.github/workflows/upload-torch-dynamo-perf-stats.yml b/.github/workflows/upload-torch-dynamo-perf-stats.yml index b4b55a7b473ea..27a39ec342482 100644 --- a/.github/workflows/upload-torch-dynamo-perf-stats.yml +++ b/.github/workflows/upload-torch-dynamo-perf-stats.yml @@ -49,7 +49,7 @@ jobs: cache: pip - run: | - pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.19.12 + pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.35.42 - name: Upload torch dynamo performance stats to S3 id: upload-s3 diff --git a/.github/workflows/upload_test_stats_intermediate.yml b/.github/workflows/upload_test_stats_intermediate.yml index d560f619db43d..0c02e3c372338 100644 --- a/.github/workflows/upload_test_stats_intermediate.yml +++ b/.github/workflows/upload_test_stats_intermediate.yml @@ -28,7 +28,7 @@ jobs: cache: pip - run: | - pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.19.12 + pip3 install requests==2.32.2 rockset==1.0.3 boto3==1.35.42 - name: Upload test stats env: diff --git a/.github/workflows/xpu.yml b/.github/workflows/xpu.yml index 17fd3e4dfc6b7..10b4008abbd23 100644 --- a/.github/workflows/xpu.yml +++ b/.github/workflows/xpu.yml @@ -14,7 +14,7 @@ jobs: get-label-type: name: get-label-type - uses: ./.github/workflows/_runner-determinator.yml + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@main with: triggering_actor: ${{ github.triggering_actor }} issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} diff --git a/.gitmodules b/.gitmodules index 26b47a3a85c3c..36d5becb57c3b 100644 --- a/.gitmodules +++ b/.gitmodules @@ -127,3 +127,7 @@ [submodule "third_party/NVTX"] path = third_party/NVTX url = https://github.com/NVIDIA/NVTX.git +[submodule "third_party/composable_kernel"] + path = third_party/composable_kernel + url = https://github.com/ROCm/composable_kernel.git + branch = develop diff --git a/.lintrunner.toml b/.lintrunner.toml index 930c4779b0e4d..d82ee315e73a5 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -136,11 +136,9 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'numpy==1.24.3 ; python_version == "3.8"', 'numpy==1.26.0 ; python_version >= "3.9"', 'expecttest==0.2.1', 'mypy==1.11.2', - 'sympy==1.12.1 ; python_version == "3.8"', 'sympy==1.13.0 ; python_version >= "3.9"', 'types-requests==2.27.25', 'types-PyYAML==6.0.7', @@ -243,7 +241,9 @@ exclude_patterns = [ 'c10/util/*inl.h', 'c10/test/**/*.h', 'third_party/**/*', - 'torch/csrc/api/**', + 'torch/csrc/api/include/torch/nn/modules/common.h', + 'torch/csrc/api/include/torch/linalg.h', + 'torch/csrc/api/include/torch/nn/pimpl-inl.h', 'torch/csrc/autograd/generated/**', 'torch/csrc/distributed/**/*', 'torch/csrc/dynamo/eval_frame.h', @@ -376,17 +376,6 @@ command = [ ] is_formatter = true -[[linter]] -code = 'CONSTEXPR' -include_patterns=['aten/src/ATen/native/cuda/*.cu'] -command = [ - 'python3', - 'tools/linter/adapters/constexpr_linter.py', - '--', - '@{{PATHSFILE}}', -] -is_formatter = true - [[linter]] code = 'SPACES' include_patterns = ['**'] @@ -1232,87 +1221,6 @@ exclude_patterns = [ 'torch/fft/__init__.py', 'torch/func/__init__.py', 'torch/futures/__init__.py', - 'torch/fx/__init__.py', - 'torch/fx/_compatibility.py', - 'torch/fx/_symbolic_trace.py', - 'torch/fx/annotate.py', - 'torch/fx/config.py', - 'torch/fx/experimental/__init__.py', - 'torch/fx/experimental/accelerator_partitioner.py', - 'torch/fx/experimental/const_fold.py', - 'torch/fx/experimental/debug.py', - 'torch/fx/experimental/graph_gradual_typechecker.py', - 'torch/fx/experimental/merge_matmul.py', - 'torch/fx/experimental/meta_tracer.py', - 'torch/fx/experimental/migrate_gradual_types/__init__.py', - 'torch/fx/experimental/migrate_gradual_types/constraint.py', - 'torch/fx/experimental/migrate_gradual_types/constraint_generator.py', - 'torch/fx/experimental/migrate_gradual_types/constraint_transformation.py', - 'torch/fx/experimental/migrate_gradual_types/operation.py', - 'torch/fx/experimental/migrate_gradual_types/transform_to_z3.py', - 'torch/fx/experimental/migrate_gradual_types/util.py', - 'torch/fx/experimental/migrate_gradual_types/z3_types.py', - 'torch/fx/experimental/normalize.py', - 'torch/fx/experimental/optimization.py', - 'torch/fx/experimental/partitioner_utils.py', - 'torch/fx/experimental/refinement_types.py', - 'torch/fx/experimental/rewriter.py', - 'torch/fx/experimental/schema_type_annotation.py', - 'torch/fx/experimental/unification/__init__.py', - 'torch/fx/experimental/unification/core.py', - 'torch/fx/experimental/unification/dispatch.py', - 'torch/fx/experimental/unification/match.py', - 'torch/fx/experimental/unification/more.py', - 'torch/fx/experimental/unification/multipledispatch/__init__.py', - 'torch/fx/experimental/unification/multipledispatch/conflict.py', - 'torch/fx/experimental/unification/multipledispatch/core.py', - 'torch/fx/experimental/unification/multipledispatch/dispatcher.py', - 'torch/fx/experimental/unification/multipledispatch/utils.py', - 'torch/fx/experimental/unification/multipledispatch/variadic.py', - 'torch/fx/experimental/unification/unification_tools.py', - 'torch/fx/experimental/unification/utils.py', - 'torch/fx/experimental/unification/variable.py', - 'torch/fx/experimental/unify_refinements.py', - 'torch/fx/graph.py', - 'torch/fx/graph_module.py', - 'torch/fx/interpreter.py', - 'torch/fx/node.py', - 'torch/fx/operator_schemas.py', - 'torch/fx/passes/__init__.py', - 'torch/fx/passes/annotate_getitem_nodes.py', - 'torch/fx/passes/backends/__init__.py', - 'torch/fx/passes/backends/cudagraphs.py', - 'torch/fx/passes/dialect/__init__.py', - 'torch/fx/passes/dialect/common/__init__.py', - 'torch/fx/passes/dialect/common/cse_pass.py', - 'torch/fx/passes/fake_tensor_prop.py', - 'torch/fx/passes/graph_drawer.py', - 'torch/fx/passes/graph_manipulation.py', - 'torch/fx/passes/infra/__init__.py', - 'torch/fx/passes/infra/partitioner.py', - 'torch/fx/passes/infra/pass_base.py', - 'torch/fx/passes/infra/pass_manager.py', - 'torch/fx/passes/net_min_base.py', - 'torch/fx/passes/operator_support.py', - 'torch/fx/passes/param_fetch.py', - 'torch/fx/passes/pass_manager.py', - 'torch/fx/passes/reinplace.py', - 'torch/fx/passes/shape_prop.py', - 'torch/fx/passes/split_module.py', - 'torch/fx/passes/split_utils.py', - 'torch/fx/passes/splitter_base.py', - 'torch/fx/passes/tests/__init__.py', - 'torch/fx/passes/tests/test_pass_manager.py', - 'torch/fx/passes/tools_common.py', - 'torch/fx/passes/utils/__init__.py', - 'torch/fx/passes/utils/common.py', - 'torch/fx/passes/utils/fuser_utils.py', - 'torch/fx/passes/utils/matcher_utils.py', - 'torch/fx/passes/utils/source_matcher_utils.py', - 'torch/fx/proxy.py', - 'torch/fx/subgraph_rewriter.py', - 'torch/fx/tensor_type.py', - 'torch/fx/traceback.py', 'torch/linalg/__init__.py', 'torch/monitor/__init__.py', 'torch/nested/__init__.py', @@ -1483,7 +1391,7 @@ init_command = [ 'black==23.12.1', 'usort==1.0.8.post1', 'isort==5.13.2', - 'ruff==0.6.3', # sync with RUFF + 'ruff==0.7.0', # sync with RUFF ] is_formatter = true @@ -1568,7 +1476,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'ruff==0.6.3', # sync with PYFMT + 'ruff==0.7.0', # sync with PYFMT ] is_formatter = true diff --git a/CMakeLists.txt b/CMakeLists.txt index a104c705b12b0..30377996b3e39 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -341,19 +341,6 @@ cmake_dependent_option( cmake_dependent_option(USE_SYSTEM_UCC "Use system-wide UCC" OFF "USE_UCC" OFF) cmake_dependent_option(USE_C10D_UCC "USE C10D UCC" ON "USE_DISTRIBUTED;USE_UCC" OFF) -cmake_dependent_option( - USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON - "USE_DISTRIBUTED" OFF) -cmake_dependent_option( - USE_GLOO_WITH_OPENSSL - "Use Gloo with OpenSSL. Only available if USE_GLOO is on." OFF - "USE_GLOO AND LINUX AND NOT INTERN_BUILD_MOBILE" OFF) -cmake_dependent_option(USE_C10D_GLOO "USE C10D GLOO" ON - "USE_DISTRIBUTED;USE_GLOO" OFF) -cmake_dependent_option(USE_C10D_NCCL "USE C10D NCCL" ON - "USE_DISTRIBUTED;USE_NCCL" OFF) -cmake_dependent_option(USE_C10D_MPI "USE C10D MPI" ON "USE_DISTRIBUTED;USE_MPI" - OFF) cmake_dependent_option( USE_GLOO "Use Gloo. Only available if USE_DISTRIBUTED is on." ON "USE_DISTRIBUTED" OFF) @@ -394,8 +381,10 @@ cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler" option(USE_MIMALLOC "Use mimalloc" OFF) # Enable third party mimalloc library to improve memory allocation performance # on Windows. +option(USE_MIMALLOC_ON_MKL "Use mimalloc on MKL" OFF) if(WIN32) set(USE_MIMALLOC ON) + set(USE_MIMALLOC_ON_MKL ON) endif() if(USE_CCACHE) @@ -469,6 +458,7 @@ option(USE_SYSTEM_FXDIV "Use system-provided fxdiv." OFF) option(USE_SYSTEM_BENCHMARK "Use system-provided google benchmark." OFF) option(USE_SYSTEM_ONNX "Use system-provided onnx." OFF) option(USE_SYSTEM_XNNPACK "Use system-provided xnnpack." OFF) +OPTION(USE_SYSTEM_NVTX "Use system-provided nvtx." OFF) option(USE_GOLD_LINKER "Use ld.gold to link" OFF) if(USE_SYSTEM_LIBS) set(USE_SYSTEM_CPUINFO ON) @@ -487,6 +477,7 @@ if(USE_SYSTEM_LIBS) if(USE_NCCL) set(USE_SYSTEM_NCCL ON) endif() + set(USE_SYSTEM_NVTX ON) endif() # /Z7 override option When generating debug symbols, CMake default to use the @@ -1096,6 +1087,10 @@ if(NOT MSVC) append_cxx_flag_if_supported("-fno-math-errno" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-fno-trapping-math" CMAKE_CXX_FLAGS) append_cxx_flag_if_supported("-Werror=format" CMAKE_CXX_FLAGS) + if(CMAKE_COMPILER_IS_GNUCXX AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + append_cxx_flag_if_supported("-Wno-error=dangling-reference" CMAKE_CXX_FLAGS) + append_cxx_flag_if_supported("-Wno-error=redundant-move" CMAKE_CXX_FLAGS) + endif() else() # skip unwanted includes from windows.h add_compile_definitions(WIN32_LEAN_AND_MEAN) @@ -1244,6 +1239,10 @@ if(USE_MIMALLOC) include_directories(third_party/mimalloc/include) endif() +if(USE_MIMALLOC AND USE_MIMALLOC_ON_MKL) + add_definitions(-DUSE_MIMALLOC_ON_MKL) +endif() + # ---[ Main build add_subdirectory(c10) add_subdirectory(caffe2) diff --git a/CODEOWNERS b/CODEOWNERS index 4a908d79f99ac..0cee05d4b8351 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -102,6 +102,10 @@ test/functorch/test_vmap.py @zou3519 @chillee @kshitij12345 torch/_higher_order_ops/*.py @zou3519 torch/_dynamo/variables/higher_order_ops.py @zou3519 +# AOTAutograd +torch/_functorch/_aot_autograd/*.py @bdhirsh +torch/_functorch/aot_autograd.py @bdhirsh + # torch MPS test/test_mps.py @kulinseth @malfet aten/src/ATen/mps/ @kulinseth @malfet @@ -112,10 +116,10 @@ aten/src/ATen/detail/MTIAHooksInterface.h @egienvalue torch/csrc/mtia/ @egienvalue # Profiler -torch/csrc/autograd/profiler* @aaronenyeshi @sraikund16 -torch/autograd/profiler* @aaronenyeshi @sraikund16 -torch/csrc/profiler/ @aaronenyeshi @sraikund16 -torch/profiler/ @aaronenyeshi @sraikund16 +torch/csrc/autograd/profiler* @sraikund16 +torch/autograd/profiler* @sraikund16 +torch/csrc/profiler/ @sraikund16 +torch/profiler/ @sraikund16 # AOTDispatch tests test/functorch/test_aotdispatch.py @ezyang @Chillee diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 99e47ef502998..c2eab67762074 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -286,6 +286,11 @@ The following packages should be installed with either `conda` or `pip`: - `expecttest` and `hypothesis` - required to run tests - `mypy` - recommended for linting - `pytest` - recommended to run tests more selectively +Running +``` +pip install -r requirements +``` +will install these dependencies for you. All PyTorch test suites are located in the `test` folder and start with `test_`. Run the entire test @@ -878,7 +883,7 @@ Process 87741 stopped * thread #1, queue = 'com.apple.main-thread', stop reason = breakpoint 1.1 frame #0: 0x00000001024e2628 libtorch_python.dylib`at::indexing::impl::applySelect(self=0x00000001004ee8a8, dim=0, index=(data_ = 3), real_dim=0, (null)=0x000000016fdfe535, self_sizes= Has Value=true ) at TensorIndexing.h:239:7 236 const at::Device& /*self_device*/, - 237 const c10::optional& self_sizes) { + 237 const std::optional& self_sizes) { 238 // See NOTE [nested tensor size for indexing] -> 239 if (self_sizes.has_value()) { 240 auto maybe_index = index.maybe_as_int(); @@ -1081,10 +1086,6 @@ Here are a few well known pitfalls and workarounds: catch all of these problems: stay vigilant to the possibility that your crash is due to a real memory problem. -* (NVCC) `c10::optional` does not work when used from device code. Don't use - it from kernels. Upstream issue: https://github.com/akrzemi1/Optional/issues/58 - and our local issue #10329. - * `constexpr` generally works less well on MSVC. * The idiom `static_assert(f() == f())` to test if `f` is constexpr diff --git a/README.md b/README.md index c7dd72ccc77c6..ead539403d9ae 100644 --- a/README.md +++ b/README.md @@ -208,6 +208,8 @@ If you want to compile with ROCm support, install - [AMD ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) 4.0 and above installation - ROCm is currently supported only for Linux systems. +By default the build system expects ROCm to be installed in `/opt/rocm`. If ROCm is installed in a different directory, the `ROCM_PATH` environment variable must be set to the ROCm installation directory. The build system automatically detects the AMD GPU architecture. Optionally, the AMD GPU architecture can be explicitly set with the `PYTORCH_ROCM_ARCH` environment variable [AMD GPU architecture](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/reference/system-requirements.html#supported-gpus) + If you want to disable ROCm support, export the environment variable `USE_ROCM=0`. Other potentially useful environment variables may be found in `setup.py`. @@ -289,7 +291,7 @@ python tools/amd_build/build_amd.py Install PyTorch ```bash -export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" python setup.py develop ``` @@ -371,14 +373,14 @@ with such a step. On Linux ```bash -export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" python setup.py build --cmake-only ccmake build # or cmake-gui build ``` On macOS ```bash -export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"} +export CMAKE_PREFIX_PATH="${CONDA_PREFIX:-'$(dirname $(which conda))/../'}:${CMAKE_PREFIX_PATH}" MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++ python setup.py build --cmake-only ccmake build # or cmake-gui build ``` diff --git a/aten/src/ATen/BlasBackend.h b/aten/src/ATen/BlasBackend.h index 7f8c321ad9fa2..521addefc5ee1 100644 --- a/aten/src/ATen/BlasBackend.h +++ b/aten/src/ATen/BlasBackend.h @@ -7,7 +7,7 @@ namespace at { -enum class BlasBackend : int8_t { Cublas, Cublaslt }; +enum class BlasBackend : int8_t { Cublas, Cublaslt, Ck }; inline std::string BlasBackendToString(at::BlasBackend backend) { switch (backend) { @@ -15,6 +15,8 @@ inline std::string BlasBackendToString(at::BlasBackend backend) { return "at::BlasBackend::Cublas"; case BlasBackend::Cublaslt: return "at::BlasBackend::Cublaslt"; + case BlasBackend::Ck: + return "at::BlasBackend::Ck"; default: TORCH_CHECK(false, "Unknown blas backend"); } diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 16e4641ddf205..a0a845eed6562 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER) endif() EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS}) -file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") +file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec128/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h") file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp") file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h") file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp") @@ -266,6 +266,9 @@ endif() if(USE_CUDA) list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/cuda) + # Next two lines are needed because TunableOp uses third-party/fmt + list(APPEND ATen_CUDA_INCLUDE $) + list(APPEND ATen_CUDA_DEPENDENCY_LIBS fmt::fmt-header-only) list(APPEND ATen_CUDA_CU_SRCS ${cuda_cu} ${native_cuda_cu} @@ -309,6 +312,9 @@ if(USE_ROCM) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include) list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include) + # Next two lines are needed because TunableOp uses third-party/fmt + list(APPEND ATen_HIP_INCLUDE $) + list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only) list(APPEND ATen_HIP_SRCS ${ATen_HIP_SRCS} ${hip_hip} @@ -422,7 +428,7 @@ if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$") list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo) endif() -if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE) +if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE AND NOT (MSVC AND CMAKE_SYSTEM_PROCESSOR STREQUAL "ARM64")) if(NOT MSVC) # Bump up optimization level for sleef to -O1, since at -O0 the compiler # excessively spills intermediate vector registers to the stack diff --git a/aten/src/ATen/CPUApplyUtils.h b/aten/src/ATen/CPUApplyUtils.h index 5c524ef97c475..780510579a7ef 100644 --- a/aten/src/ATen/CPUApplyUtils.h +++ b/aten/src/ATen/CPUApplyUtils.h @@ -64,11 +64,15 @@ struct strided_tensor_iter_fixed { int64_t strides_[N] = {0}; strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete; - void operator=(strided_tensor_iter_fixed const& x) = delete; - strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) = default; + strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed const& x) = + delete; + strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) noexcept = default; + strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed&& x) noexcept = + default; + ~strided_tensor_iter_fixed() noexcept = default; strided_tensor_iter_fixed( Tensor& tensor, - C10_UNUSED bool sort_strides = false) + [[maybe_unused]] bool sort_strides = false) : data_(tensor.data_ptr()) { std::memset(counter_, 0, sizeof(int64_t) * N); if (tensor.dim() > 0) { @@ -93,8 +97,10 @@ struct strided_tensor_iter { std::vector strides_; strided_tensor_iter(strided_tensor_iter const&) = delete; - void operator=(strided_tensor_iter const& x) = delete; - strided_tensor_iter(strided_tensor_iter&&) = default; + strided_tensor_iter& operator=(strided_tensor_iter const& x) = delete; + strided_tensor_iter(strided_tensor_iter&&) noexcept = default; + strided_tensor_iter& operator=(strided_tensor_iter&&) noexcept = default; + ~strided_tensor_iter() noexcept = default; strided_tensor_iter(Tensor& tensor) : data_(tensor.data_ptr()), dim_(tensor.ndimension()), @@ -136,7 +142,7 @@ inline bool _apply_preamble(ArrayRef tensors) { checkDeviceType("CPU_tensor_apply", tensors, kCPU); checkLayout("CPU_tensor_apply", tensors, kStrided); if (!_all_equal_numel(tensors)) - AT_ERROR(_all_equal_numel_error(tensors)); + TORCH_CHECK(false, _all_equal_numel_error(tensors)); // An empty tensor has no elements for (auto& t : tensors) if (t.numel() == 0) diff --git a/aten/src/ATen/CPUFixedAllocator.h b/aten/src/ATen/CPUFixedAllocator.h index cf621f34cc637..e4429867f254b 100644 --- a/aten/src/ATen/CPUFixedAllocator.h +++ b/aten/src/ATen/CPUFixedAllocator.h @@ -12,11 +12,11 @@ namespace at { static cpu_fixed_malloc(void*, ptrdiff_t) { - AT_ERROR("attempting to resize a tensor view of an external blob"); + TORCH_CHECK(false, "attempting to resize a tensor view of an external blob"); } static cpu_fixed_realloc(void*, void*, ptrdiff_t) { - AT_ERROR("attempting to resize a tensor view of an external blob"); + TORCH_CHECK(false, "attempting to resize a tensor view of an external blob"); } static cpu_fixed_free(void* state, void* allocation) { diff --git a/aten/src/ATen/CPUGeneratorImpl.cpp b/aten/src/ATen/CPUGeneratorImpl.cpp index 0fcf14bab464d..313069ce3336f 100644 --- a/aten/src/ATen/CPUGeneratorImpl.cpp +++ b/aten/src/ATen/CPUGeneratorImpl.cpp @@ -189,7 +189,7 @@ void CPUGeneratorImpl::set_state(const c10::TensorImpl& new_state) { double_normal_sample = std::optional(legacy_pod->normal_y); } } else { - AT_ERROR("Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy, + TORCH_CHECK(false, "Expected either a CPUGeneratorImplStateLegacy of size ", size_legacy, " or a CPUGeneratorImplState of size ", size_current, " but found the input RNG state size to be ", new_state_size); } diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index ff8ceed7e8de8..c151cffdf045f 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -290,7 +290,12 @@ at::BlasBackend Context::blasPreferredBackend() { #ifdef USE_ROCM if (blas_preferred_backend == at::BlasBackend::Cublaslt) { static const bool hipblaslt_unsupported = []() { - static const std::vector archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; + static const std::vector archs = { + "gfx90a", "gfx940", "gfx941", "gfx942", +#if ROCM_VERSION >= 60300 + "gfx1100", "gfx1101" +#endif + }; for (auto index: c10::irange(getNumGPUs())) { if (!detail::getCUDAHooks().isGPUArch(index, archs)) { TORCH_WARN_ONCE( @@ -316,6 +321,8 @@ void Context::setBlasPreferredBackend(at::BlasBackend b) { #else TORCH_CHECK((b != at::BlasBackend::Cublaslt) || hasCuBLASLt(), "Cannot set preferred backend to cuBLASLt if PyTorch has not been compiled with cuBLASLt."); + TORCH_CHECK((b != at::BlasBackend::Ck) || hasROCM(), + "Cannot set preferred backend to Ck if PyTorch has not been compiled for ROCm."); if (b != at::BlasBackend::Cublas) { TORCH_WARN_ONCE( "torch.backends.cuda.preferred_blas_library is an experimental feature. " diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h index dcdb129521c05..f59f83b08aae2 100644 --- a/aten/src/ATen/Context.h +++ b/aten/src/ATen/Context.h @@ -39,8 +39,8 @@ class TORCH_API Context { const Generator& defaultGenerator(Device device) { c10::DeviceType device_type = device.type(); - initCUDAIfNeeded(device_type); - initHIPIfNeeded(device_type); + lazyInitDevice(device_type); + if (device_type == at::kCPU) { return at::detail::getDefaultCPUGenerator(); } else if (device_type == at::kCUDA) { @@ -58,6 +58,7 @@ class TORCH_API Context { AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); } } + const AcceleratorHooksInterface& getAcceleratorHooksInterface( std::optional opt_device_type = std::nullopt) { c10::DeviceType device_type = opt_device_type.has_value() @@ -76,26 +77,23 @@ class TORCH_API Context { } else if (device_type == at::kHIP) { return at::detail::getHIPHooks(); } else { - AT_ERROR( - c10::DeviceTypeName(device_type), " device type not an accelerator."); + TORCH_CHECK( + false, + c10::DeviceTypeName(device_type), + " device type not an accelerator."); } } + Device getDeviceFromPtr(void* data, c10::DeviceType device_type) { - initCUDAIfNeeded(device_type); - initHIPIfNeeded(device_type); - initXPUIfNeeded(device_type); + lazyInitDevice(device_type); + if (device_type == at::kCPU) { return c10::DeviceType::CPU; - } else if (device_type == at::kCUDA) { - return at::detail::getCUDAHooks().getDeviceFromPtr(data); - } else if (device_type == at::kXPU) { - return at::detail::getXPUHooks().getDeviceFromPtr(data); - } else if (device_type == at::kPrivateUse1) { - return at::detail::getPrivateUse1Hooks().getDeviceFromPtr(data); } else { - AT_ERROR(c10::DeviceTypeName(device_type), " device type not enabled."); + return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data); } } + bool isPinnedPtr( const void* data, std::optional device_type = std::nullopt) { @@ -108,10 +106,20 @@ class TORCH_API Context { } return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data); } + Allocator* getPinnedMemoryAllocator( std::optional device_type = std::nullopt) { return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator(); } + + void lazyInitDevice(c10::DeviceType device_type) { + if (device_type != at::kCPU) { + c10::call_once(init_[static_cast(device_type)], [&] { + getAcceleratorHooksInterface(device_type).init(); + }); + } + } + static bool hasOpenMP(); static bool hasMKL(); static bool hasLAPACK(); @@ -143,6 +151,9 @@ class TORCH_API Context { static bool hasCuBLASLt() { return detail::getCUDAHooks().hasCuBLASLt(); } + static bool hasROCM() { + return detail::getCUDAHooks().hasROCM(); + } static bool hasHIP() { return detail::getHIPHooks().hasHIP(); } @@ -164,27 +175,6 @@ class TORCH_API Context { static bool hasMAIA() { return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA); } - // defined in header so that getNonVariableType has ability to inline - // call_once check. getNonVariableType is called fairly frequently - void lazyInitCUDA() { - c10::call_once(thc_init, [&] { detail::getCUDAHooks().initCUDA(); }); - } - void lazyInitHIP() { - c10::call_once(thh_init, [&] { detail::getHIPHooks().initHIP(); }); - } - void lazyInitXPU() { - c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); }); - } - void lazyInitMTIA() { - c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); }); - } - void lazyInitPrivateUse1() { - c10::call_once(thp_init, [&] { - if (isPrivateUse1HooksRegistered()) { - at::detail::getPrivateUse1Hooks().initPrivateUse1(); - } - }); - } static const at::cuda::NVRTC& getNVRTC() { return detail::getCUDAHooks().nvrtc(); } @@ -359,28 +349,36 @@ class TORCH_API Context { bool allowFP16ReductionCPU() const; void setAllowFP16ReductionCPU(bool); - private: - void initCUDAIfNeeded(c10::DeviceType p) { - if (p == c10::DeviceType::CUDA) { - lazyInitCUDA(); - } + // Preserved for BC + void lazyInitCUDA() { + TORCH_WARN_DEPRECATION( + "lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.") + lazyInitDevice(at::kCUDA); } - void initHIPIfNeeded(c10::DeviceType p) { - if (p == c10::DeviceType::HIP) { - lazyInitHIP(); - } + void lazyInitHIP() { + TORCH_WARN_DEPRECATION( + "lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.") + lazyInitDevice(at::kHIP); } - void initXPUIfNeeded(c10::DeviceType p) { - if (p == c10::DeviceType::XPU) { - lazyInitXPU(); - } + void lazyInitXPU() { + TORCH_WARN_DEPRECATION( + "lazyInitXPU is deprecated. Please use lazyInitDevice(at::kXPU) instead.") + lazyInitDevice(at::kXPU); + } + void lazyInitMTIA() { + TORCH_WARN_DEPRECATION( + "lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.") + lazyInitDevice(at::kMTIA); } + void lazyInitPrivateUse1() { + TORCH_WARN_DEPRECATION( + "lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.") + lazyInitDevice(at::kPrivateUse1); + } + + private: static bool checkCuBLASConfigDeterministic(); - c10::once_flag thc_init; - c10::once_flag thh_init; - c10::once_flag thx_init; - c10::once_flag th_mtia_init; - c10::once_flag thp_init; + std::array init_; bool enabled_cudnn = true; bool deterministic_cudnn = false; bool deterministic_mkldnn = false; diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp index 07b13ee10a9d5..0c844003eb153 100644 --- a/aten/src/ATen/DLConvertor.cpp +++ b/aten/src/ATen/DLConvertor.cpp @@ -22,6 +22,13 @@ DLDataType getDLDataType(const Tensor& t) { case ScalarType::UInt64: dtype.code = DLDataTypeCode::kDLUInt; break; + case ScalarType::Int1: + case ScalarType::Int2: + case ScalarType::Int3: + case ScalarType::Int4: + case ScalarType::Int5: + case ScalarType::Int6: + case ScalarType::Int7: case ScalarType::Char: dtype.code = DLDataTypeCode::kDLInt; break; @@ -49,11 +56,7 @@ DLDataType getDLDataType(const Tensor& t) { dtype.code = DLDataTypeCode::kDLBool; break; case ScalarType::ComplexHalf: - dtype.code = DLDataTypeCode::kDLComplex; - break; case ScalarType::ComplexFloat: - dtype.code = DLDataTypeCode::kDLComplex; - break; case ScalarType::ComplexDouble: dtype.code = DLDataTypeCode::kDLComplex; break; @@ -90,7 +93,7 @@ DLDataType getDLDataType(const Tensor& t) { static DLDevice getDLDevice(const Tensor& tensor, c10::DeviceIndex device_id) { DLDevice ctx; - ctx.device_id = static_cast(device_id); + ctx.device_id = static_cast(static_cast(device_id)); switch (tensor.device().type()) { case DeviceType::CPU: ctx.device_type = DLDeviceType::kDLCPU; @@ -253,10 +256,12 @@ ScalarType toScalarType(const DLDataType& dtype) { } // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) +namespace { struct ATenDLMTensor { Tensor handle; - DLManagedTensor tensor; + DLManagedTensor tensor{}; }; +} // namespace static void deleter(DLManagedTensor* arg) { delete static_cast(arg->manager_ctx); diff --git a/aten/src/ATen/Dispatch.h b/aten/src/ATen/Dispatch.h index db2eccf7954be..a13b85e319d15 100644 --- a/aten/src/ATen/Dispatch.h +++ b/aten/src/ATen/Dispatch.h @@ -38,11 +38,9 @@ inline constexpr bool should_include_kernel_dtype( * binary. */ #if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE -namespace at { -namespace detail { +namespace at::detail { TORCH_API void record_kernel_function_dtype(std::string name); -} -} // namespace at +} // namespace at::detail #define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \ at::detail::record_kernel_function_dtype( \ @@ -55,7 +53,8 @@ TORCH_API void record_kernel_function_dtype(std::string name); do { \ if constexpr (!at::should_include_kernel_dtype( \ at_dispatch_name, enum_type)) { \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ "dtype '", \ toString(enum_type), \ "' not selected for kernel tag ", \ @@ -63,38 +62,38 @@ TORCH_API void record_kernel_function_dtype(std::string name); } \ } while (0) -#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ - case enum_type: { \ - AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ - using HINT C10_UNUSED = c10::impl::ScalarTypeToCPPTypeT; \ - return __VA_ARGS__(); \ +#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT; \ + return __VA_ARGS__(); \ } #define AT_DISPATCH_CASE(enum_type, ...) \ AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__) -#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \ - case enum_type: { \ - AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ - using scalar_t = scalar_type; \ - using underlying_t C10_UNUSED = typename scalar_t::underlying; \ - const auto& SCALAR_TYPE C10_UNUSED = enum_type; \ - const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \ - return __VA_ARGS__(); \ +#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \ + [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \ + [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \ + return __VA_ARGS__(); \ } -#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ - enum_type, scalar_type, bitwidth, qmin, qmax, ...) \ - case enum_type: { \ - AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ - using scalar_t = scalar_type; \ - using underlying_t C10_UNUSED = typename scalar_t::underlying; \ - const auto& SCALAR_TYPE C10_UNUSED = enum_type; \ - const auto& UNDERLYING_TYPE C10_UNUSED = toUnderlying(enum_type); \ - C10_UNUSED int bit_width = bitwidth; \ - C10_UNUSED int64_t quant_min = qmin; \ - C10_UNUSED int64_t quant_max = qmax; \ - return __VA_ARGS__(); \ +#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + enum_type, scalar_type, bitwidth, qmin, qmax, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \ + [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \ + [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \ + [[maybe_unused]] int bit_width = bitwidth; \ + [[maybe_unused]] int64_t quant_min = qmin; \ + [[maybe_unused]] int64_t quant_max = qmax; \ + return __VA_ARGS__(); \ } namespace detail { @@ -220,7 +219,8 @@ inline void deprecated_AT_DISPATCH_ALL_TYPES_AND_HALF_AND_COMPLEX() {} switch (_st) { \ __VA_ARGS__ \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ '"', \ at_dispatch_name, \ "\" not implemented for '", \ diff --git a/aten/src/ATen/Dispatch_v2.h b/aten/src/ATen/Dispatch_v2.h index e0764834c02fd..31dd12f8de9b8 100644 --- a/aten/src/ATen/Dispatch_v2.h +++ b/aten/src/ATen/Dispatch_v2.h @@ -112,12 +112,12 @@ // Ensure we never have too many scalar types for the expansion here to // support. To bump this, you must regenerate the macros below. -static_assert(static_cast(c10::ScalarType::NumOptions) < 45); +static_assert(static_cast(c10::ScalarType::NumOptions) < 60); // Python code to regenerate generate code below: #if 0 -num_args = 45 +num_args = 60 nums = ', '.join(str(i) for i in reversed(range(num_args+1))) args = ', '.join(f'_{i}' for i in range(1, num_args+1)) @@ -135,8 +135,8 @@ for i in range(1, num_args+1): // Begin generated code // clang-format off -#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) -#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, N, ...) N +#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) +#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N #define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N) #define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) #define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) @@ -182,5 +182,21 @@ for i in range(1, num_args+1): #define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) #define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) #define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) +#define AT_AP46(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) +#define AT_AP47(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) +#define AT_AP48(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) +#define AT_AP49(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) +#define AT_AP50(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) +#define AT_AP51(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) +#define AT_AP52(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) +#define AT_AP53(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) +#define AT_AP54(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) +#define AT_AP55(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) +#define AT_AP56(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) +#define AT_AP57(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) +#define AT_AP58(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) +#define AT_AP59(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) +#define AT_AP60(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) AT_DISPATCH_CASE(_60, N) + // End generated code // clang-format on diff --git a/aten/src/ATen/ExpandUtils.h b/aten/src/ATen/ExpandUtils.h index 50e8478951d13..e9abc85b59c30 100644 --- a/aten/src/ATen/ExpandUtils.h +++ b/aten/src/ATen/ExpandUtils.h @@ -78,7 +78,7 @@ inline void check_defined( const char* api_name) { for (auto& t : tensors) { if (!t.get().defined()) { - AT_ERROR(api_name, "(...) called with an undefined Tensor"); + TORCH_CHECK(false, api_name, "(...) called with an undefined Tensor"); } } } diff --git a/aten/src/ATen/FunctionalInverses.cpp b/aten/src/ATen/FunctionalInverses.cpp index b5581b71e7678..60c86bad733a8 100644 --- a/aten/src/ATen/FunctionalInverses.cpp +++ b/aten/src/ATen/FunctionalInverses.cpp @@ -231,6 +231,7 @@ Tensor FunctionalInverses::slice_Tensor_inverse(const Tensor& base, const Tensor } } +// NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor FunctionalInverses::split_Tensor_inverse(const Tensor& base, const Tensor& mutated_view, InverseReturnMode inverse_return_mode, int64_t mutated_view_idx, c10::SymInt split_size, int64_t dim) { // It would be nice if this logic could be re-used from autograd's split_backward(), but I don't think it can. // For functionalization, we have only have one of the tensors from the TensorList outputed by split(), and we want to layer i @@ -452,6 +453,7 @@ Tensor FunctionalInverses::chunk_inverse(const at::Tensor & base, const at::Tens return split_with_sizes_inverse(base, mutated_view, inverse_return_mode, mutated_view_idx, split_sizes, dim); } +// NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor FunctionalInverses::narrow_inverse(const at::Tensor & base, const at::Tensor & mutated_view, InverseReturnMode inverse_return_mode, int dim, c10::SymInt start, c10::SymInt length) { if (inverse_return_mode == InverseReturnMode::AlwaysView) { // NB: assumes mutated_view is a narrowed view of base. diff --git a/aten/src/ATen/FunctionalTensorWrapper.cpp b/aten/src/ATen/FunctionalTensorWrapper.cpp index 6f66e8065731a..c16c29ed58aed 100644 --- a/aten/src/ATen/FunctionalTensorWrapper.cpp +++ b/aten/src/ATen/FunctionalTensorWrapper.cpp @@ -638,7 +638,7 @@ void replace_(const ITensorListRef functional_tensor, ITensorListRef other) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size()); auto functional_tensor_it = functional_tensor.begin(); auto other_it = other.begin(); - for (C10_UNUSED const auto i : c10::irange(functional_tensor.size())) { + for ([[maybe_unused]] const auto i : c10::irange(functional_tensor.size())) { replace_(*functional_tensor_it++, *other_it++); } } @@ -655,7 +655,7 @@ void propagate_xla_data(const ITensorListRef functional_tensor, ITensorListRef o TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_tensor.size() == other.size()); auto functional_tensor_it = functional_tensor.begin(); auto other_it = other.begin(); - for (C10_UNUSED const auto i : c10::irange(functional_tensor.size())) { + for ([[maybe_unused]] const auto i : c10::irange(functional_tensor.size())) { propagate_xla_data(*functional_tensor_it++, *other_it++); } } @@ -670,7 +670,7 @@ void propagate_xla_data_direct(const ITensorListRef tensor, ITensorListRef other) { auto tensor_it = tensor.begin(); auto other_it = other.begin(); - for (C10_UNUSED const auto i : c10::irange(tensor.size())) { + for ([[maybe_unused]] const auto i : c10::irange(tensor.size())) { propagate_xla_data_direct(*tensor_it++, *other_it++); } } diff --git a/aten/src/ATen/InferSize.h b/aten/src/ATen/InferSize.h index 53d1b395453ea..3bcccfad971cc 100644 --- a/aten/src/ATen/InferSize.h +++ b/aten/src/ATen/InferSize.h @@ -33,7 +33,7 @@ inline void infer_size_impl( } else if (shape[dim] >= 0) { newsize *= shape[dim]; } else { - AT_ERROR("invalid shape dimension ", shape[dim]); + TORCH_CHECK(false, "invalid shape dimension ", shape[dim]); } } diff --git a/aten/src/ATen/MapAllocator.h b/aten/src/ATen/MapAllocator.h index db1258beee525..51fb4674edd04 100644 --- a/aten/src/ATen/MapAllocator.h +++ b/aten/src/ATen/MapAllocator.h @@ -112,6 +112,10 @@ class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck, size_t size); static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&); + RefcountedMapAllocator(const RefcountedMapAllocator&) = delete; + RefcountedMapAllocator(RefcountedMapAllocator&&) = delete; + RefcountedMapAllocator& operator=(const RefcountedMapAllocator&) = delete; + RefcountedMapAllocator& operator=(RefcountedMapAllocator&&) = delete; static at::DataPtr makeDataPtr( const char* filename, int flags, diff --git a/aten/src/ATen/MemoryOverlap.cpp b/aten/src/ATen/MemoryOverlap.cpp index 2e6792d5ca698..0ed36ebfc8dda 100644 --- a/aten/src/ATen/MemoryOverlap.cpp +++ b/aten/src/ATen/MemoryOverlap.cpp @@ -61,7 +61,7 @@ MemOverlapStatus get_overlap_status(const TensorImpl* a, const TensorImpl* b) { // same pointer across multiple storages there are many // similar situations (e.g., storage().data() == storage().data()+1) // which we will miss. - auto a_storage = a->unsafe_storage(); + const auto& a_storage = a->unsafe_storage(); if (a_storage && a_storage.is_alias_of(b->unsafe_storage())) { const auto a_begin = static_cast(a->data()); const auto a_end = a_begin + a->numel() * a->itemsize(); diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index f71ae5358f299..f9f69aa3c42bd 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -45,15 +45,15 @@ struct TORCH_API OpaqueTensorImpl : public TensorImpl { } void set_size(int64_t dim, int64_t new_size) override { - AT_ERROR("opaque tensors do not have set_size"); + TORCH_CHECK(false, "opaque tensors do not have set_size"); } void set_stride(int64_t dim, int64_t new_stride) override { - AT_ERROR("opaque tensors do not have set_stride"); + TORCH_CHECK(false, "opaque tensors do not have set_stride"); } void set_storage_offset(int64_t storage_offset) override { - AT_ERROR("opaque tensors do not have set_storage_offset"); + TORCH_CHECK(false, "opaque tensors do not have set_storage_offset"); } #ifdef DEBUG diff --git a/aten/src/ATen/ParallelCommon.cpp b/aten/src/ATen/ParallelCommon.cpp index 82d5e994fb798..49b83d9157db7 100644 --- a/aten/src/ATen/ParallelCommon.cpp +++ b/aten/src/ATen/ParallelCommon.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -23,17 +24,17 @@ namespace at { namespace { -const char* get_env_var( +std::string get_env_var( const char* var_name, const char* def_value = nullptr) { - const char* value = std::getenv(var_name); - return value ? value : def_value; + auto env = c10::utils::get_env(var_name); + return env.has_value() ? env.value() : def_value; } #ifndef C10_MOBILE size_t get_env_num_threads(const char* var_name, size_t def_value = 0) { try { - if (auto* value = std::getenv(var_name)) { - int nthreads = std::stoi(value); + if (auto value = c10::utils::get_env(var_name)) { + int nthreads = std::stoi(value.value()); TORCH_CHECK(nthreads > 0); return nthreads; } diff --git a/aten/src/ATen/SavedTensorHooks.cpp b/aten/src/ATen/SavedTensorHooks.cpp index b5733305ad069..871d9df0c924c 100644 --- a/aten/src/ATen/SavedTensorHooks.cpp +++ b/aten/src/ATen/SavedTensorHooks.cpp @@ -74,7 +74,7 @@ std::pair SavedTensorDefaultHooks::pop_hooks() { std::optional> SavedTensorDefaultHooks::get_hooks() { // For tls.is_tracing, see NOTE: [Deferring tensor pack/unpack hooks until runtime] if (!is_initialized || tls.stack.empty() || tls.is_tracing) { - return c10::nullopt; + return std::nullopt; } return tls.stack.top(); } diff --git a/aten/src/ATen/SparseCsrTensorUtils.h b/aten/src/ATen/SparseCsrTensorUtils.h index f4095c9bfa044..2ec973013c494 100644 --- a/aten/src/ATen/SparseCsrTensorUtils.h +++ b/aten/src/ATen/SparseCsrTensorUtils.h @@ -23,7 +23,8 @@ case kSparseBsc: \ return __VA_ARGS__(); \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ NAME, \ " expected sparse compressed tensor layout but got ", \ the_layout); \ @@ -42,7 +43,8 @@ case kSparseBsc: \ return (COLUMN_DIM_ACTION)(); \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ NAME, \ " expected sparse compressed tensor layout but got ", \ the_layout); \ @@ -61,7 +63,8 @@ case kSparseBsc: \ return (BLOCK_ACTION)(); \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ NAME, \ " expected sparse compressed tensor layout but got ", \ the_layout); \ @@ -77,7 +80,8 @@ case kSparseBsr: \ return (ROW_DIM_ACTION)(); \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ NAME, \ " expected sparse row compressed tensor layout but got ", \ the_layout); \ @@ -93,7 +97,8 @@ case kSparseBsc: \ return (COL_DIM_ACTION)(); \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ NAME, \ " expected sparse column compressed tensor layout but got ", \ the_layout); \ @@ -108,7 +113,8 @@ case kSparseCsc: \ return (ACTION)(); \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ NAME, \ " expected sparse compressed (non-block) tensor layout but got ", \ the_layout); \ @@ -123,7 +129,8 @@ case kSparseBsc: \ return (ACTION)(); \ default: \ - AT_ERROR( \ + TORCH_CHECK( \ + false, \ NAME, \ " expected sparse compressed block tensor layout but got ", \ the_layout); \ @@ -144,8 +151,8 @@ class CheckSparseTensorInvariants { bool old_state; public: - CheckSparseTensorInvariants(bool state) { - old_state = at::globalContext().checkSparseTensorInvariants(); + CheckSparseTensorInvariants(bool state) + : old_state(at::globalContext().checkSparseTensorInvariants()) { at::globalContext().setCheckSparseTensorInvariants(state); } diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index bda88a3ee54a6..2a3b9481255f5 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -57,13 +57,13 @@ void SparseTensorImpl::release_resources() { } void SparseTensorImpl::set_size(int64_t dim, int64_t new_size) { - AT_ERROR("sparse tensors do not have set_size"); + TORCH_CHECK(false, "sparse tensors do not have set_size"); } void SparseTensorImpl::set_stride(int64_t dim, int64_t new_stride) { - AT_ERROR("sparse tensors do not have set_stride"); + TORCH_CHECK(false, "sparse tensors do not have set_stride"); } void SparseTensorImpl::set_storage_offset(int64_t storage_offset) { - AT_ERROR("sparse tensors do not have set_storage_offset"); + TORCH_CHECK(false, "sparse tensors do not have set_storage_offset"); } #ifdef DEBUG bool SparseTensorImpl::has_storage() const { diff --git a/aten/src/ATen/TensorUtils.cpp b/aten/src/ATen/TensorUtils.cpp index 36371ad682460..7b2a1cbe62fe3 100644 --- a/aten/src/ATen/TensorUtils.cpp +++ b/aten/src/ATen/TensorUtils.cpp @@ -155,7 +155,7 @@ void checkSameGPU(CheckedFrom c, const TensorArg& t1, const TensorArg& t2) { } oss << "but expected " << ((!t1->is_cpu() && !t2->is_cpu()) ? "them" : "it") << " to be on GPU (while checking arguments for " << c << ")"; - AT_ERROR(oss.str()); + TORCH_CHECK(false, oss.str()); } TORCH_CHECK( t1->get_device() == t2->get_device(), @@ -200,7 +200,7 @@ void checkScalarTypes(CheckedFrom c, const TensorArg& t, } oss << "; but got " << t->toString() << " instead (while checking arguments for " << c << ")"; - AT_ERROR(oss.str()); + TORCH_CHECK(false, oss.str()); } } diff --git a/aten/src/ATen/ThreadLocalState.h b/aten/src/ATen/ThreadLocalState.h index 721ea9957513b..2469cb1c3c47e 100644 --- a/aten/src/ATen/ThreadLocalState.h +++ b/aten/src/ATen/ThreadLocalState.h @@ -82,7 +82,7 @@ class TORCH_API ThreadLocalState { !defined(BUILD_LITE_INTERPRETER) // TLS for autocast dtypes std::array - autocast_dtypes_; + autocast_dtypes_{}; #endif friend class ThreadLocalStateGuard; diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index 74845113a0774..95a35bd5563a0 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -36,7 +36,8 @@ inline std::vector checked_dense_tensor_list_unwrap( for (const auto i : c10::irange(tensors.size())) { const auto& expr = tensors[i]; if (expr.layout() != Layout::Strided) { - AT_ERROR( + TORCH_CHECK( + false, "Expected dense tensor but got ", expr.layout(), " for sequence element ", @@ -48,7 +49,8 @@ inline std::vector checked_dense_tensor_list_unwrap( "'"); } if (expr.device().type() != device_type) { - AT_ERROR( + TORCH_CHECK( + false, "Expected object of device type ", device_type, " but got device type ", @@ -62,7 +64,8 @@ inline std::vector checked_dense_tensor_list_unwrap( "'"); } if (expr.scalar_type() != scalar_type) { - AT_ERROR( + TORCH_CHECK( + false, "Expected object of scalar type ", scalar_type, " but got scalar type ", @@ -96,7 +99,8 @@ std::array check_intlist( return res; } if (list.size() != N) { - AT_ERROR( + TORCH_CHECK( + false, "Expected a list of ", N, " ints but got ", diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 8ae66a30dcaf0..1129892dd25f5 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -149,7 +149,7 @@ Banned functions *******************************/ static Tensor binary_cross_entropy_banned(const Tensor &, const Tensor &, const std::optional&, int64_t) { - AT_ERROR("torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n" + TORCH_CHECK(false, "torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.\n" "Many models use a sigmoid layer right before the binary cross entropy layer.\n" "In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits\n" "or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are\n" @@ -212,13 +212,13 @@ TORCH_LIBRARY_IMPL(_, AutocastMPS, m) { TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { // lower_precision_fp - KERNEL_MPS2(_convolution, deprecated, lower_precision_fp) + KERNEL_MPS(_convolution, deprecated, lower_precision_fp) KERNEL_MPS(_convolution, lower_precision_fp) KERNEL_MPS(conv1d, lower_precision_fp) KERNEL_MPS(conv2d, lower_precision_fp) KERNEL_MPS(conv_tbc, lower_precision_fp) KERNEL_MPS(conv_transpose1d, lower_precision_fp) - KERNEL_MPS2(conv_transpose2d, input, lower_precision_fp) + KERNEL_MPS(conv_transpose2d, input, lower_precision_fp) KERNEL_MPS(convolution, lower_precision_fp) KERNEL_MPS(_mps_convolution, lower_precision_fp) KERNEL_MPS(prelu, lower_precision_fp) @@ -252,16 +252,16 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { KERNEL_MPS(rsqrt, fp32) KERNEL_MPS(sinh, fp32) KERNEL_MPS(tan, fp32) - KERNEL_MPS2(pow, Tensor_Scalar, fp32) - KERNEL_MPS2(pow, Tensor_Tensor, fp32) - KERNEL_MPS2(pow, Scalar, fp32) + KERNEL_MPS(pow, Tensor_Scalar, fp32) + KERNEL_MPS(pow, Tensor_Tensor, fp32) + KERNEL_MPS(pow, Scalar, fp32) KERNEL_MPS(softplus, fp32) KERNEL_MPS(layer_norm, fp32) KERNEL_MPS(native_layer_norm, fp32) KERNEL_MPS(group_norm, fp32) - KERNEL_MPS2(frobenius_norm, dim, fp32) + KERNEL_MPS(frobenius_norm, dim, fp32) KERNEL_MPS(nuclear_norm, fp32) - KERNEL_MPS2(nuclear_norm, dim, fp32) + KERNEL_MPS(nuclear_norm, dim, fp32) KERNEL_MPS(batch_norm, fp32) KERNEL_MPS(cosine_similarity, fp32) KERNEL_MPS(poisson_nll_loss, fp32) @@ -288,22 +288,22 @@ TORCH_LIBRARY_IMPL(aten, AutocastMPS, m) { // fp32_set_opt_dtype KERNEL_MPS(prod, fp32) - KERNEL_MPS2(prod, dim_int, fp32) - KERNEL_MPS2(prod, dim_Dimname, fp32) - KERNEL_MPS2(softmax, int, fp32) - KERNEL_MPS2(softmax, Dimname, fp32) - KERNEL_MPS2(log_softmax, int, fp32) - KERNEL_MPS2(log_softmax, Dimname, fp32) + KERNEL_MPS(prod, dim_int, fp32) + KERNEL_MPS(prod, dim_Dimname, fp32) + KERNEL_MPS(softmax, int, fp32) + KERNEL_MPS(softmax, Dimname, fp32) + KERNEL_MPS(log_softmax, int, fp32) + KERNEL_MPS(log_softmax, Dimname, fp32) KERNEL_MPS(cumprod, fp32) - KERNEL_MPS2(cumprod, dimname, fp32) + KERNEL_MPS(cumprod, dimname, fp32) KERNEL_MPS(cumsum, fp32) - KERNEL_MPS2(cumsum, dimname, fp32) + KERNEL_MPS(cumsum, dimname, fp32) KERNEL_MPS(linalg_vector_norm, fp32) KERNEL_MPS(linalg_matrix_norm, fp32) - KERNEL_MPS2(linalg_matrix_norm, str_ord, fp32) + KERNEL_MPS(linalg_matrix_norm, str_ord, fp32) KERNEL_MPS(sum, fp32) - KERNEL_MPS2(sum, dim_IntList, fp32) - KERNEL_MPS2(sum, dim_DimnameList, fp32) + KERNEL_MPS(sum, dim_IntList, fp32) + KERNEL_MPS(sum, dim_DimnameList, fp32) // // promote KERNEL_MPS(addcdiv, promote) diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h index 95f1dd2ca0c00..fbd9121d38516 100644 --- a/aten/src/ATen/autocast_mode.h +++ b/aten/src/ATen/autocast_mode.h @@ -211,7 +211,7 @@ inline at::ScalarType prioritize( const Tensor& nextArg, c10::DeviceType device_type = c10::DeviceType::CUDA) { if (current == at::kDouble) { - AT_ERROR("promote type is double in at::autocast::prioritize"); + TORCH_CHECK(false, "promote type is double in at::autocast::prioritize"); return current; } at::ScalarType lower_precision_fp = @@ -225,7 +225,8 @@ inline at::ScalarType prioritize( } else if (current == lower_precision_fp && next == lower_precision_fp) { return lower_precision_fp; } else { - AT_ERROR("Unexpected floating ScalarType in at::autocast::prioritize"); + TORCH_CHECK( + false, "Unexpected floating ScalarType in at::autocast::prioritize"); return current; } } else { @@ -749,26 +750,9 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. REDISPATCH_SIGNATURE, \ POLICY) -// KERNEL_MPS registration for AutocastMPS -#define KERNEL_MPS(OP, POLICY) \ - m.impl( \ - TORCH_SELECTIVE_NAME("aten::" #OP), \ - &WrapFunction< \ - CastPolicy::POLICY, \ - DeviceType::MPS, \ - decltype(ATEN_FN(OP)), \ - decltype(ATEN_FN(OP)), \ - &ATEN_FN(OP)>::type::call); - -#define KERNEL_MPS2(OP, OVERLOAD, POLICY) \ - m.impl( \ - TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ - &WrapFunction< \ - CastPolicy::POLICY, \ - DeviceType::MPS, \ - decltype(ATEN_FN2(OP, OVERLOAD)), \ - decltype(ATEN_FN2(OP, OVERLOAD)), \ - &ATEN_FN2(OP, OVERLOAD)>::type::call); +// KERNEL_MPS +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMPS +#define KERNEL_MPS(...) KERNEL(c10::DeviceType::MPS, __VA_ARGS__) // Op lists for different policies. // To make sure other backends can reuse the policy op list. diff --git a/aten/src/ATen/code_template.h b/aten/src/ATen/code_template.h index b9dfc618e54aa..ee7488b4e348c 100644 --- a/aten/src/ATen/code_template.h +++ b/aten/src/ATen/code_template.h @@ -205,7 +205,7 @@ struct CodeTemplate { // or trailing newlines. It's the responsibility of the calling function // to indent correctly in the context. void emitIndent(std::ostream& out, size_t indent) const { - for (C10_UNUSED const auto i : c10::irange(indent)) { + for ([[maybe_unused]] const auto i : c10::irange(indent)) { out << " "; } } diff --git a/aten/src/ATen/core/Dict.h b/aten/src/ATen/core/Dict.h index a9befba8276ce..b1f4ebe62e732 100644 --- a/aten/src/ATen/core/Dict.h +++ b/aten/src/ATen/core/Dict.h @@ -314,7 +314,7 @@ class Dict final { * * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't. */ - C10_NODISCARD size_t erase(const Key& key) const; + [[nodiscard]] size_t erase(const Key& key) const; /** * Returns the mapped value of the element with key equivalent to key. diff --git a/aten/src/ATen/core/Dict_inl.h b/aten/src/ATen/core/Dict_inl.h index 0419b3bd49e91..c48d7ec38ae5a 100644 --- a/aten/src/ATen/core/Dict_inl.h +++ b/aten/src/ATen/core/Dict_inl.h @@ -142,8 +142,8 @@ void Dict::erase(iterator iter) const { impl_->dict.erase(iter.entryRef_.iterator_); } -template -C10_NODISCARD size_t Dict::erase(const Key& key) const { +template +[[nodiscard]] size_t Dict::erase(const Key& key) const { return impl_->dict.erase(key); } diff --git a/aten/src/ATen/core/DistributionsHelper.h b/aten/src/ATen/core/DistributionsHelper.h index 18588ee00a36b..39004008d0070 100644 --- a/aten/src/ATen/core/DistributionsHelper.h +++ b/aten/src/ATen/core/DistributionsHelper.h @@ -95,11 +95,9 @@ struct uniform_int_distribution { template struct uniform_real_distribution { - C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) { + C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) : from_(from), to_(to) { TORCH_CHECK_IF_NOT_ON_CUDA(from <= to); TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits::max()); - from_ = from; - to_ = to; } template @@ -186,10 +184,8 @@ DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float); template struct normal_distribution { - C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) { + C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) { TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in); - mean = mean_in; - stdv = stdv_in; } template @@ -236,9 +232,8 @@ template <> struct DiscreteDistributionType { using type = double; }; template struct bernoulli_distribution { - C10_HOST_DEVICE inline bernoulli_distribution(T p_in) { + C10_HOST_DEVICE inline bernoulli_distribution(T p_in) : p(p_in) { TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1); - p = p_in; } template @@ -257,9 +252,8 @@ struct bernoulli_distribution { template struct geometric_distribution { - C10_HOST_DEVICE inline geometric_distribution(T p_in) { + C10_HOST_DEVICE inline geometric_distribution(T p_in) : p(p_in) { TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1); - p = p_in; } template @@ -317,10 +311,8 @@ struct cauchy_distribution { template struct lognormal_distribution { - C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) { + C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) { TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0); - mean = mean_in; - stdv = stdv_in; } template diff --git a/aten/src/ATen/core/Formatting.cpp b/aten/src/ATen/core/Formatting.cpp index 824640705238a..22f1490fc4908 100644 --- a/aten/src/ATen/core/Formatting.cpp +++ b/aten/src/ATen/core/Formatting.cpp @@ -37,7 +37,7 @@ std::ostream& operator<<(std::ostream & out, const Scalar& s) { std::string toString(const Scalar& s) { std::stringstream out; out << s; - return out.str(); + return std::move(out).str(); } } namespace at { @@ -153,7 +153,7 @@ static std::tuple __printFormat(std::ostream& stream, const Tensor& static void __printIndent(std::ostream &stream, int64_t indent) { - for (C10_UNUSED const auto i : c10::irange(indent)) { + for ([[maybe_unused]] const auto i : c10::irange(indent)) { stream << " "; } } diff --git a/aten/src/ATen/core/boxing/KernelFunction_impl.h b/aten/src/ATen/core/boxing/KernelFunction_impl.h index 8ba3049157d21..d505b30575834 100644 --- a/aten/src/ATen/core/boxing/KernelFunction_impl.h +++ b/aten/src/ATen/core/boxing/KernelFunction_impl.h @@ -8,6 +8,17 @@ namespace c10 { +namespace detail { +template +std::enable_if_t< + !std::is_array_v && !std::is_array_v && + std::is_base_of_v, + std::unique_ptr> +make_unique_base(Args&&... args) { + return std::unique_ptr(new Child(std::forward(args)...)); +} +} + inline KernelFunction::KernelFunction() : boxed_kernel_func_() , unboxed_kernel_func_(nullptr) @@ -174,12 +185,16 @@ template inline KernelFunction KernelFunction::makeFromUnboxedFunction(FuncPtr func_ptr) { static_assert(is_compile_time_function_pointer::value, "Tried to call KernelFunction::makeFromUnboxedFunction with an invalid parameter. It must be a function pointer created with TORCH_FN."); static_assert(!std::is_same::value, "Tried to call KernelFunction::makeFromUnboxedFunction with a boxed function pointer. Please use KernelFunction::makeFromBoxedFunction instead."); +#if defined(__GNUC__) && defined(__SANITIZE_ADDRESS__) && !defined(__CUDACC__) + TORCH_INTERNAL_ASSERT(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); +#else static_assert(FuncPtr::func_ptr() != nullptr, "Kernel function cannot be nullptr"); +#endif #if !defined(C10_MOBILE) (void)func_ptr; // Suppress unused variable warning return makeFromUnboxedFunctor::type>( - guts::make_unique_base::type>() + detail::make_unique_base::type>() ); #else // On mobile, we rather want to optimize for binary size than for performance, @@ -196,7 +211,7 @@ inline KernelFunction KernelFunction::makeFromUnboxedRuntimeFunction(FuncType* f TORCH_INTERNAL_ASSERT(func != nullptr, "Kernel function cannot be nullptr"); return makeFromUnboxedFunctor>>( - guts::make_unique_base>>(func) + detail::make_unique_base>>(func) ); } @@ -206,7 +221,7 @@ inline std::enable_if_t>::value, #if !defined(C10_MOBILE) return makeFromUnboxedFunctor>>( - guts::make_unique_base>>(std::forward(lambda)) + detail::make_unique_base>>(std::forward(lambda)) ); #else // On mobile, we rather want to optimize for binary size than for performance, @@ -222,7 +237,7 @@ inline std::enable_if_t>::value, static_assert(guts::is_functor>::value, "Tried to call KernelFunction::makeFromUnboxedLambda with a non-lambda type."); return makeFromUnboxedFunctor>>( - guts::make_unique_base>>(std::forward(lambda)) + detail::make_unique_base>>(std::forward(lambda)) ); } diff --git a/aten/src/ATen/core/class_type.h b/aten/src/ATen/core/class_type.h index 67d0bae4c83c7..c4223443274f5 100644 --- a/aten/src/ATen/core/class_type.h +++ b/aten/src/ATen/core/class_type.h @@ -390,7 +390,8 @@ struct TORCH_API ClassType : public NamedType { std::string doc_string = "", std::vector unresolved_class_attributes = {}); - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { const auto& n = name().value(); return n.qualifiedName(); } diff --git a/aten/src/ATen/core/dynamic_type.cpp b/aten/src/ATen/core/dynamic_type.cpp index ab01730da33e2..091d07bdaaaf5 100644 --- a/aten/src/ATen/core/dynamic_type.cpp +++ b/aten/src/ATen/core/dynamic_type.cpp @@ -376,8 +376,8 @@ DynamicTypePtr ivalue::TupleTypeFactory::fallback( return nullptr; } -TORCH_API TupleTypePtr -ivalue::TupleTypeFactory::fallback(C10_UNUSED const Type& type) { +TORCH_API TupleTypePtr ivalue::TupleTypeFactory::fallback( + [[maybe_unused]] const Type& type) { #ifdef C10_MOBILE return nullptr; #else @@ -398,5 +398,4 @@ ivalue::TupleTypeFactory::fallback(C10_UNUSED const Type& type) { #endif } - } // namespace c10 diff --git a/aten/src/ATen/core/enum_type.h b/aten/src/ATen/core/enum_type.h index 136fe59e22fb5..4d61be51e0476 100644 --- a/aten/src/ATen/core/enum_type.h +++ b/aten/src/ATen/core/enum_type.h @@ -28,7 +28,7 @@ struct TORCH_API EnumType : public NamedType { std::move(enum_names_values), std::move(cu))); default: - AT_ERROR( + TORCH_CHECK(false, "Cannot create Enum with value type '", value->str(), "', only int, float and string are supported"); @@ -88,7 +88,7 @@ struct TORCH_API EnumType : public NamedType { cu_(std::move(cu)) {} std::string annotation_str_impl( - C10_UNUSED const TypePrinter& printer = nullptr) const override { + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { const auto& n = name().value(); return n.qualifiedName(); } diff --git a/aten/src/ATen/core/function.h b/aten/src/ATen/core/function.h index 01e395bcf6106..ec14d4a03efea 100644 --- a/aten/src/ATen/core/function.h +++ b/aten/src/ATen/core/function.h @@ -56,7 +56,7 @@ struct TORCH_API Function { virtual c10::intrusive_ptr runAsync( Stack& /*stack*/, // NOLINTNEXTLINE(performance-unnecessary-value-param) - C10_UNUSED TaskLauncher taskLauncher = at::launch) { + [[maybe_unused]] TaskLauncher taskLauncher = at::launch) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(false); return {}; } diff --git a/aten/src/ATen/core/function_schema.h b/aten/src/ATen/core/function_schema.h index 8dab896b1411d..081e38e49b867 100644 --- a/aten/src/ATen/core/function_schema.h +++ b/aten/src/ATen/core/function_schema.h @@ -108,7 +108,7 @@ struct TORCH_API Argument { return is_out_; } - C10_NODISCARD const AliasInfo* alias_info() const { + [[nodiscard]] const AliasInfo* alias_info() const { return alias_info_.get(); } diff --git a/aten/src/ATen/core/function_schema_inl.h b/aten/src/ATen/core/function_schema_inl.h index a2fff1c130cb5..7e07785eb05a4 100644 --- a/aten/src/ATen/core/function_schema_inl.h +++ b/aten/src/ATen/core/function_schema_inl.h @@ -55,7 +55,7 @@ inline void FunctionSchema::checkAndNormalizeInputs( inputs.push_back(*argument.default_value()); continue; } - AT_ERROR( + TORCH_CHECK(false, name(), "() is missing value for argument '", argument.name(), diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 7dd98769024b3..b3f5ab69782b6 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -756,7 +756,7 @@ IValueComparator getLessThanComparator(const IValue& v) { torch::jit::Function* lt_func = checkObjectSortSchema(v.type()->expect(), why_not); if (!lt_func) { - AT_ERROR(why_not.str()); + TORCH_CHECK(false, why_not.str()); } return [lt_func](const IValue& a, const IValue& b) { @@ -772,7 +772,7 @@ IValueComparator getLessThanComparator(const IValue& v) { }; } - AT_ERROR("IValues of type: ", v.tagKind(), " are not comparable"); + TORCH_CHECK(false, "IValues of type: ", v.tagKind(), " are not comparable"); } IValueComparator getGreaterThanComparator(const IValue& v) { @@ -967,7 +967,7 @@ IValue IValue::deepcopy( copy = *this; } break; default: { - AT_ERROR("Can't deepcopy IValue with tag: ", tagKind()); + TORCH_CHECK(false, "Can't deepcopy IValue with tag: ", tagKind()); } } // NB: this doesn't work if an object contains itself, and it may @@ -1050,7 +1050,7 @@ c10::intrusive_ptr ivalue::Object::deepcopy( } err << ". Please define serialization methods via def_pickle() for " "this class."; - AT_ERROR(err.str()); + TORCH_CHECK(false, err.str()); } object->setSlot(i, slots_[i].deepcopy(memo, device)); } diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 98cb2baae1f4d..42a03ea946027 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -522,7 +522,7 @@ struct TORCH_API IValue final { } c10::intrusive_ptr toTuple() &&; c10::intrusive_ptr toTuple() const&; - C10_NODISCARD ivalue::Tuple& toTupleRef() const; + [[nodiscard]] ivalue::Tuple& toTupleRef() const; // Double IValue(double d) : tag(Tag::Double) { @@ -1163,7 +1163,7 @@ struct TORCH_API IValue final { // this value different (e.g. using NaN boxing), and this would make it more // costly to determine the tag for all types vs just determining if something // is a particular type. Instead we want clients to use the `isX` methods when - // possible. If for perf. reasons you really, absolutely, must have a jump + // possible. If for performance reasons you really, absolutely, must have a jump // table, then we can revisit this. enum class Tag : uint32_t { #define DEFINE_TAG(x) x, diff --git a/aten/src/ATen/core/ivalue_inl.h b/aten/src/ATen/core/ivalue_inl.h index 2d30d3ba5cafe..87460a7664115 100644 --- a/aten/src/ATen/core/ivalue_inl.h +++ b/aten/src/ATen/core/ivalue_inl.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -500,7 +501,7 @@ struct TORCH_API TupleElements { return *this; } - C10_NODISCARD c10::ArrayRef asArrayRef() const { + [[nodiscard]] c10::ArrayRef asArrayRef() const { if (inlineSize_) { return c10::ArrayRef(elementsInline_, inlineSize_); } else { @@ -527,15 +528,15 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD bool empty() const { + [[nodiscard]] bool empty() const { return inlineSize_ ? false : elementsVector_.empty(); } - C10_NODISCARD size_t size() const { + [[nodiscard]] size_t size() const { return inlineSize_ ? inlineSize_ : elementsVector_.size(); } - C10_NODISCARD IValue& operator[](size_t idx) { + [[nodiscard]] IValue& operator[](size_t idx) { if (inlineSize_) { return elementsInline_[idx]; } else { @@ -543,7 +544,7 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD const IValue& operator[](size_t idx) const { + [[nodiscard]] const IValue& operator[](size_t idx) const { if (inlineSize_) { return elementsInline_[idx]; } else { @@ -551,7 +552,7 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD IValue& at(size_t idx) { + [[nodiscard]] IValue& at(size_t idx) { if (inlineSize_) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); @@ -561,7 +562,7 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD const IValue& at(size_t idx) const { + [[nodiscard]] const IValue& at(size_t idx) const { if (inlineSize_) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY(inlineSize_ <= 3); TORCH_CHECK(idx < inlineSize_, "TupleElements: invalid index Index = ", idx, "; Length = ", inlineSize_); @@ -572,7 +573,7 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD iterator begin() { + [[nodiscard]] iterator begin() { if (inlineSize_) { return elementsInline_; } else { @@ -580,7 +581,7 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD iterator end() { + [[nodiscard]] iterator end() { if (inlineSize_) { return elementsInline_ + inlineSize_; } else { @@ -588,7 +589,7 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD const_iterator begin() const { + [[nodiscard]] const_iterator begin() const { if (inlineSize_) { return elementsInline_; } else { @@ -596,7 +597,7 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD const_iterator end() const { + [[nodiscard]] const_iterator end() const { if (inlineSize_) { return elementsInline_ + inlineSize_; } else { @@ -604,27 +605,27 @@ struct TORCH_API TupleElements { } } - C10_NODISCARD const_iterator cbegin() const { + [[nodiscard]] const_iterator cbegin() const { return begin(); } - C10_NODISCARD const_iterator cend() const { + [[nodiscard]] const_iterator cend() const { return end(); } - C10_NODISCARD std::vector vec() const & { + [[nodiscard]] std::vector vec() const& { return asArrayRef().vec(); } - C10_NODISCARD IValue& back() { + [[nodiscard]] IValue& back() { return *(end() - 1); } - C10_NODISCARD const IValue& back() const { + [[nodiscard]] const IValue& back() const { return *(end() - 1); } - C10_NODISCARD std::vector vec() && { + [[nodiscard]] std::vector vec() && { std::vector result; result.reserve(size()); for (auto&& iv : *this) { @@ -863,6 +864,19 @@ struct C10_EXPORT ivalue::Future final : c10::intrusive_ptr_target { Future& operator=(const Future&) = delete; Future& operator=(Future&&) = delete; + // Destructor + // Explicitly destroy events under device guard, otherwise it can lead to + // extra context being created on device 0. Reason: python garbage collector + // calls this destructor, but python GC does not have a device context, so a + // "default" one (usually on device 0) could be created when we go down the + // line of event destroy. + ~Future() override { + while (!events_.empty()) { + c10::OptionalDeviceGuard deviceGuard(events_.back().device()); + events_.pop_back(); + } + } + struct TORCH_API FutureError final : public std::exception { explicit FutureError(std::string&& error_msg_) : error_msg(std::move(error_msg_)) {} diff --git a/aten/src/ATen/core/jit_type.h b/aten/src/ATen/core/jit_type.h index 01839231db36d..86cfb17ed4059 100644 --- a/aten/src/ATen/core/jit_type.h +++ b/aten/src/ATen/core/jit_type.h @@ -938,7 +938,7 @@ struct TORCH_API DictType : public SharedType { case TypeKind::DeviceObjType: return DictTypePtr(new DictType(std::move(key), std::move(value))); default: - AT_ERROR( + TORCH_CHECK(false, "Cannot create dict for key type '", key->str(), "', only int, float, complex, Tensor, device and string keys are supported"); @@ -1278,7 +1278,8 @@ struct TORCH_API NumberType : public Type { protected: NumberType(TypeKind kind = TypeKind::NumberType) : Type(kind) {} - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { return "number"; // technically not a valid python type, but // we need to use it when parsing back in annotations // for implicit conversions @@ -1305,7 +1306,8 @@ struct TORCH_API FloatType : public NumberType { private: FloatType() : NumberType(TypeKind::FloatType) {} - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { return "float"; } }; @@ -1330,7 +1332,8 @@ struct TORCH_API ComplexType : public NumberType { private: ComplexType() : NumberType(TypeKind::ComplexType) {} - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { return "complex"; } }; @@ -1419,7 +1422,8 @@ struct TORCH_API IntType : public NumberType { private: IntType() : NumberType(TypeKind::IntType) {} - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { return "int"; } }; @@ -1453,7 +1457,8 @@ struct TORCH_API StringType : public Type { // we only use "str" (not "string") in both FunctionSchema and script return annotation_str(); } - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { return "str"; } static const TypeKind Kind = TypeKind::StringType; @@ -1473,7 +1478,8 @@ struct TORCH_API StorageType : public Type { std::string str() const override { return annotation_str(); } - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { return "Storage"; } static const TypeKind Kind = TypeKind::StorageType; @@ -1508,7 +1514,8 @@ struct TORCH_API FunctionType : public NamedType { private: FunctionType(torch::jit::Function* function); - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { const auto& n = name().value(); return n.qualifiedName(); } @@ -2199,7 +2206,8 @@ struct TORCH_API InterfaceType : public NamedType { const InterfaceType& rhs, std::ostream* why_not); - std::string annotation_str_impl(C10_UNUSED const TypePrinter& printer = nullptr) const override { + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { return name()->qualifiedName(); } diff --git a/aten/src/ATen/core/jit_type_base.h b/aten/src/ATen/core/jit_type_base.h index 462323510ca97..34038cfa0070d 100644 --- a/aten/src/ATen/core/jit_type_base.h +++ b/aten/src/ATen/core/jit_type_base.h @@ -585,7 +585,7 @@ struct TORCH_API Type { virtual TypePtr createWithContained( // NOLINTNEXTLINE(performance-unnecessary-value-param) std::vector /*contained_types*/) const { - AT_ERROR( + TORCH_CHECK(false, "type with contained types did not overload createWithContained: ", str()); } diff --git a/aten/src/ATen/core/op_registration/README.md b/aten/src/ATen/core/op_registration/README.md index 61b41b48c4a67..45b9bfa7b4199 100644 --- a/aten/src/ATen/core/op_registration/README.md +++ b/aten/src/ATen/core/op_registration/README.md @@ -140,7 +140,7 @@ Or with annotations: ``` namespace { - Tensor my_kernel_cpu(const Tensor& a, int64_t b, at::optional c) {...} + Tensor my_kernel_cpu(const Tensor& a, int64_t b, std::optional c) {...} } static auto registry = torch::RegisterOperators() @@ -176,7 +176,7 @@ The kernel function can take any of the following types as inputs or outputs: * `bool` * `c10::string_view` * `at::Scalar` (this is a type that can hold either an integer or a floating point value) -* `at::optional` with T being any type from the list above +* `std::optional` with T being any type from the list above The kernel function can take and return list inputs by using `torch::List`. `T` must be one of the supported types from above excluding `at::Scalar`. diff --git a/aten/src/ATen/core/stack.h b/aten/src/ATen/core/stack.h index 6372a3ccb556f..4fd4c2659790c 100644 --- a/aten/src/ATen/core/stack.h +++ b/aten/src/ATen/core/stack.h @@ -103,6 +103,9 @@ inline void drop(Stack* stack, size_t n) { drop(*stack, n); } inline IValue pop(Stack& stack) { + if (stack.empty()) { + throw std::runtime_error("pop() called on empty stack"); + } auto r = std::move(stack.back()); stack.pop_back(); return r; diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index fd0c3ae5170a1..88a6cd8ff6f5c 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -629,7 +629,7 @@ MatchTypeReturn matchTypeVariables( } } - AT_ERROR("Unhandled free variable container: ", formal->repr_str()); + TORCH_CHECK(false, "Unhandled free variable container: ", formal->repr_str()); } // change return types like List[List[t]] into List[List[int]] diff --git a/aten/src/ATen/cpu/Utils.cpp b/aten/src/ATen/cpu/Utils.cpp index 4455d4c117731..60ad78143f733 100644 --- a/aten/src/ATen/cpu/Utils.cpp +++ b/aten/src/ATen/cpu/Utils.cpp @@ -84,6 +84,14 @@ bool init_amx() { #endif } +bool is_arm_sve_supported() { +#if !defined(__s390x__) && !defined(__powerpc__) + return cpuinfo_initialize() && cpuinfo_has_arm_sve(); +#else + return false; +#endif +} + static uint32_t get_cache_size(int level) { #if !defined(__s390x__) && !defined(__powerpc__) if (!cpuinfo_initialize()) { diff --git a/aten/src/ATen/cpu/Utils.h b/aten/src/ATen/cpu/Utils.h index ad918dde7e059..27f9be3b3ffd0 100644 --- a/aten/src/ATen/cpu/Utils.h +++ b/aten/src/ATen/cpu/Utils.h @@ -21,6 +21,9 @@ TORCH_API bool is_amx_tile_supported(); // Enable the system to use AMX instructions. TORCH_API bool init_amx(); +// Detect if CPU supports Arm(R) architecture SVE ISA +TORCH_API bool is_arm_sve_supported(); + // Get the L1 cache size per core in Byte TORCH_API uint32_t L1d_cache_size(); diff --git a/aten/src/ATen/cpu/vec/functional_base.h b/aten/src/ATen/cpu/vec/functional_base.h index e54440ed6eedd..4d1d05ea8d326 100644 --- a/aten/src/ATen/cpu/vec/functional_base.h +++ b/aten/src/ATen/cpu/vec/functional_base.h @@ -85,28 +85,47 @@ struct VecReduceAllSIMD { using Vec = Vectorized; Vec v = acc_vec; - // 128-bit shuffle: [a1, a2, a3, a4, a5, a6, a7, a8] -> [a5, a6, a7, a8, a1, a2, a3, a4] - Vec v1 = {v.get_high(), v.get_low()}; - // [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] ('+' stands for the reduction function. Note that the last 4 elements are not required) - v = vec_fun(v, v1); - // 64-bit shuffle: [a1+a5, a2+a6, a3+a7, a4+a8, -, -, -, -] -> [a3+a7, a4+a8, a1+a5, a2+a6, -, -, -, -] - float32x4_t v1_1 = vextq_f32(v.get_low(), v.get_low(), 2); - v1 = {v1_1, v1_1}; + float32x4_t v1_1 = vextq_f32(v, v, 2); + Vec v1 = v1_1; // [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] v = vec_fun(v, v1); // 32-bit shuffle: [a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, -, -, -, -] -> [a2+a4+a6+a8, a1+a3+a5+a7, a2+a4+a6+a8, a1+a3+a5+a7, -, -, -, -] - v1_1 = vrev64q_f32(v.get_low()); - v1 = {v1_1, v1_1}; + v1_1 = vrev64q_f32(v); + v1 = v1_1; // [a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, a1+a2+a3+a4+a5+a6+a7+a8, -, -, -, -] v = vec_fun(v, v1); - return v.get_low()[0]; + return v[0]; + } +}; +#endif // defined(__aarch64__) + +#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && defined(CPU_CAPABILITY_SVE256) +template +struct VecReduceAllSIMD { + static inline float apply(const Op& vec_fun, const Vectorized& acc_vec) { + using Vec = Vectorized; + Vec v = acc_vec; + // 128-bit shuffle + svuint32_t ind = svdupq_n_u32(4, 5, 6, 7); + Vec v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 64-bit shuffle + ind = svdupq_n_u32(2, 3, 0, 1); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + // 32-bit shuffle + ind = svdupq_n_u32(1, 0, 2, 3); + v1 = svtbl_f32(v, ind); + v = vec_fun(v, v1); + return svlasta(svpfalse(), v); } }; #endif // defined(__aarch64__) + template inline scalar_t vec_reduce_all(const Op& vec_fun, const Vectorized& acc_vec) { return VecReduceAllSIMD::apply(vec_fun, acc_vec); diff --git a/aten/src/ATen/cpu/vec/vec.h b/aten/src/ATen/cpu/vec/vec.h index 234431068a40b..e4b0c4b95d845 100644 --- a/aten/src/ATen/cpu/vec/vec.h +++ b/aten/src/ATen/cpu/vec/vec.h @@ -3,6 +3,7 @@ #if defined(CPU_CAPABILITY_AVX512) #include #else +#include #include #endif diff --git a/aten/src/ATen/cpu/vec/vec128/vec128.h b/aten/src/ATen/cpu/vec/vec128/vec128.h new file mode 100644 index 0000000000000..0d0108a1f6e1f --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec128/vec128.h @@ -0,0 +1,9 @@ +#pragma once +// ARM NEON uses 128-bit vector registers. + +#include + +#if !defined(CPU_CAPABILITY_SVE) +#include +#include +#endif diff --git a/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h new file mode 100644 index 0000000000000..7476159221178 --- /dev/null +++ b/aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h @@ -0,0 +1,590 @@ +#pragma once + +// DO NOT DEFINE STATIC DATA IN THIS HEADER! +// See Note [Do not compile initializers with AVX] + +#include +#include +#include + +#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#include +#endif + +// Sleef offers vectorized versions of some transcedentals +// such as sin, cos, tan etc.. +// However for now opting for STL, since we are not building +// with Sleef for mobile yet. + +namespace at::vec { +// See Note [CPU_CAPABILITY namespace] +inline namespace CPU_CAPABILITY { + +// Right now contains only aarch64 implementation. +// Due to follow two reasons aarch32 is not currently supported. +// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics +// that work for aarch64 dont work for aarch32. +// 2. Android NDK r21 has problems with compiling aarch32. +// Clang seg faults. +// https://github.com/android/ndk/issues/1248 +// https://bugs.llvm.org/show_bug.cgi?id=45824 +// Most likely we will do aarch32 support with inline asm. +#if defined(__aarch64__) + +#ifdef __BIG_ENDIAN__ +#error "Big endian is not supported." +#endif + +#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) +#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code +#else +#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code +#endif + +template +struct BlendRegs { + static float32x4_t impl( + const float32x4_t& a, const float32x4_t& b, float32x4_t& res); +}; + +template +struct BlendRegs{ + static float32x4_t impl( + const float32x4_t& a, const float32x4_t& b, float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index); + } +}; + +template +struct BlendRegs{ + static float32x4_t impl( + const float32x4_t& a, const float32x4_t& b, float32x4_t& res) { + return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index); + } +}; + +template <> class Vectorized { +private: + float32x4_t values; +public: + using value_type = float; + using size_type = int; + static constexpr size_type size() { + return 4; + } + Vectorized() {} + Vectorized(float32x4_t v) : values(v) {} + Vectorized(float val) : values{vdupq_n_f32(val)} {} + Vectorized(float val0, float val1, float val2, float val3) : + values{val0, val1, val2, val3} {} + Vectorized(float (&arr)[4]) : Vectorized(arr[0], arr[1], arr[2], arr[3]) {} + operator float32x4_t() const { + return values; + } + template + static Vectorized blend(const Vectorized& a, const Vectorized& b) { + Vectorized vec; + vec.values = + BlendRegs<0, (mask & 0x01)!=0>::impl( + a.values, b.values, vec.values); + vec.values = + BlendRegs<1, (mask & 0x02)!=0>::impl( + a.values, b.values, vec.values); + vec.values = + BlendRegs<2, (mask & 0x04)!=0>::impl( + a.values, b.values, vec.values); + vec.values = + BlendRegs<3, (mask & 0x08)!=0>::impl( + a.values, b.values, vec.values); + return vec; + } + static Vectorized blendv(const Vectorized& a, const Vectorized& b, + const Vectorized& mask) { + // TODO + // NB: This requires that each value, i.e., each uint value, + // of the mask either all be zeros or all be 1s. + // We perhaps need some kind of an assert? + // But that will affect performance. + Vectorized vec(mask.values); + vec.values = vbslq_f32( + vreinterpretq_u32_f32(vec.values), + b.values, + a.values); + return vec; + } + template + static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) { + const Vectorized base_vec(base); + const Vectorized step_vec(step); + const Vectorized step_sizes(0, 1, 2, 3); + return fmadd(step_sizes, step_vec, base_vec); + } + static Vectorized set(const Vectorized& a, const Vectorized& b, + int64_t count = size()) { + switch (count) { + case 0: + return a; + case 1: + { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = vbslq_f32( + vreinterpretq_u32_f32(vec.values), + b.values, + a.values); + return vec; + } + case 2: + { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = vbslq_f32( + vreinterpretq_u32_f32(vec.values), + b.values, + a.values); + return vec; + } + case 3: + { + Vectorized vec; + static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0}; + vec.values = vreinterpretq_f32_u32(mask_low); + vec.values = vbslq_f32( + vreinterpretq_u32_f32(vec.values), + b.values, + a.values); + return vec; + } + } + return b; + } + static Vectorized loadu(const void* ptr, int64_t count = size()) { + if (count == size()) { + return vld1q_f32(reinterpret_cast(ptr)); + } else { + __at_align__ float tmp_values[size()]; + for (const auto i : c10::irange(size())) { + tmp_values[i] = 0.0; + } + std::memcpy( + tmp_values, + reinterpret_cast(ptr), + count * sizeof(float)); + return vld1q_f32(reinterpret_cast(tmp_values)); + } + } + void store(void* ptr, int64_t count = size()) const { + if (count == size()) { + vst1q_f32(reinterpret_cast(ptr), values); + } else { + float tmp_values[size()]; + vst1q_f32(reinterpret_cast(tmp_values), values); + std::memcpy(ptr, tmp_values, count * sizeof(float)); + } + } + // Very slow implementation of indexing. + // Only required because vec256_qint refers to this. + // Once we specialize that implementation for ARM + // this should be removed. TODO (kimishpatel) + float operator[](int idx) const { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + float operator[](int idx) { + __at_align__ float tmp[size()]; + store(tmp); + return tmp[idx]; + } + // For boolean version where we want to if any 1/all zero + // etc. can be done faster in a different way. + int zero_mask() const { + __at_align__ float tmp[size()]; + store(tmp); + int mask = 0; + for (int i = 0; i < size(); ++ i) { + if (tmp[i] == 0.f) { + mask |= (1 << i); + } + } + return mask; + } + Vectorized isnan() const { + __at_align__ float tmp[size()]; + __at_align__ float res[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if (_isnan(tmp[i])) { + std::memset(static_cast(&res[i]), 0xFF, sizeof(float)); + } else { + std::memset(static_cast(&res[i]), 0, sizeof(float)); + } + } + return loadu(res); + }; + bool has_inf_nan() const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + if(_isnan(tmp[i]) || _isinf(tmp[i])) { + return true; + } + } + return false; + } + Vectorized map(float (*const f)(float)) const { + __at_align__ float tmp[size()]; + store(tmp); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i]); + } + return loadu(tmp); + } + Vectorized map2( + const Vectorized& second, + float (*const f)(float, float)) const { + __at_align__ float tmp[size()]; + __at_align__ float tmp_second[size()]; + store(tmp); + second.store(tmp_second); + for (const auto i : c10::irange(size())) { + tmp[i] = f(tmp[i], tmp_second[i]); + } + return loadu(tmp); + } + Vectorized abs() const { + return Vectorized(vabsq_f32(values)); + } + Vectorized angle() const { + auto zero = Vectorized(0); + auto pi = Vectorized(c10::pi); + auto tmp = blendv(zero, pi, *this < zero); + return blendv(tmp, *this, isnan()); + } + Vectorized real() const { + return *this; + } + Vectorized imag() const { + return Vectorized(0.f); + } + Vectorized conj() const { + return *this; + } +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(name, sleef_name) \ + Vectorized name() const { \ + return USE_SLEEF( \ + Vectorized(sleef_name(values)), \ + map(std::name) \ + ); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(acosh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(asin) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(atanh) + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(name, sleef_name) \ + Vectorized name(const Vectorized &arg) const { \ + return USE_SLEEF( \ + Vectorized(sleef_name(values, arg.values)), \ + map2(arg, std::name) \ + ); \ + } + +#define DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(name) \ + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(name, Sleef_##name##f4_u10) + + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(atan2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(copysign, Sleef_copysignf4) + Vectorized erf() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(erfc, Sleef_erfcf4_u15) + Vectorized erfinv() const { + return map(calc_erfinv); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(exp2) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(expm1) + Vectorized exp_u20() const { + return exp(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(fmod, Sleef_fmodf4); + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(hypot, Sleef_hypotf4_u05); + Vectorized i0() const { + return map(calc_i0); + } + Vectorized i0e() const { + return map(calc_i0e); + } + Vectorized digamma() const { + return map(calc_digamma); + } + Vectorized igamma(const Vectorized &x) const { + return map2(x, calc_igamma); + } + Vectorized igammac(const Vectorized &x) const { + return map2(x, calc_igammac); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log10) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log1p) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(log2) + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC_WITH_SLEEF_NAME(nextafter, Sleef_nextafterf4) + Vectorized frac() const; + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sin) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(sinh) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cos) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(cosh) + Vectorized ceil() const { + return map(at::native::ceil_impl); + } + Vectorized floor() const { + return map(at::native::floor_impl); + } + Vectorized neg() const { + return Vectorized( + vnegq_f32(values)); + } + Vectorized round() const { + // We do not use std::round because we would like to round midway numbers to the nearest even integer. + return map(at::native::round_impl); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tan) + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(tanh) + Vectorized trunc() const { + return Vectorized(vrndq_f32(values)); + } + DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC(lgamma) + Vectorized sqrt() const { + return Vectorized(vsqrtq_f32(values)); + } + Vectorized reciprocal() const { + return Vectorized(vdivq_f32(vdupq_n_f32(1.0f), values)); + } + Vectorized rsqrt() const { + return this->sqrt().reciprocal(); + } + DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC(pow) + Vectorized operator==(const Vectorized& other) const { + return Vectorized(vreinterpretq_f32_u32(vceqq_f32(values, other.values))); + } + + Vectorized operator!=(const Vectorized& other) const { + float32x4_t r0 = vreinterpretq_f32_u32( + vmvnq_u32(vceqq_f32(values, other.values))); + return Vectorized(r0); + } + + Vectorized operator<(const Vectorized& other) const { + return Vectorized(vreinterpretq_f32_u32(vcltq_f32(values, other.values))); + } + + Vectorized operator<=(const Vectorized& other) const { + return Vectorized(vreinterpretq_f32_u32(vcleq_f32(values, other.values))); + } + + Vectorized operator>(const Vectorized& other) const { + return Vectorized(vreinterpretq_f32_u32(vcgtq_f32(values, other.values))); + } + + Vectorized operator>=(const Vectorized& other) const { + return Vectorized(vreinterpretq_f32_u32(vcgeq_f32(values, other.values))); + } + + Vectorized eq(const Vectorized& other) const; + Vectorized ne(const Vectorized& other) const; + Vectorized gt(const Vectorized& other) const; + Vectorized ge(const Vectorized& other) const; + Vectorized lt(const Vectorized& other) const; + Vectorized le(const Vectorized& other) const; +}; + +template <> +Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { + return Vectorized(vaddq_f32(a, b)); +} + +template <> +Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { + return Vectorized(vsubq_f32(a, b)); +} + +template <> +Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { + return Vectorized(vmulq_f32(a, b)); +} + +template <> +Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { + return Vectorized(vdivq_f32(a, b)); +} + +// frac. Implement this here so we can use subtraction +inline Vectorized Vectorized::frac() const { + return *this - this->trunc(); +} + +//Added sleef Implementation for Maximum +Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { + if(!a.has_inf_nan() && !b.has_inf_nan()){ + return USE_SLEEF( + Vectorized(Sleef_fmaxf4(a, b)), + Vectorized(vmaxq_f32(a,b))); + } + else{ + return Vectorized(vmaxq_f32(a, b)); + } +} + +// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if +// either input is a NaN. +template <> +Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { + return Vectorized(vminq_f32(a, b)); +} + +template <> +Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { + return minimum(max, maximum(min, a)); +} + +template <> +Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { + return minimum(max, a); +} + +template <> +Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { + return maximum(min, a); +} + +template <> +Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32(vandq_u32( + vreinterpretq_u32_f32(a), + vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32(vorrq_u32( + vreinterpretq_u32_f32(a), + vreinterpretq_u32_f32(b)))); +} + +template <> +Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { + return Vectorized(vreinterpretq_f32_u32(veorq_u32( + vreinterpretq_u32_f32(a), + vreinterpretq_u32_f32(b)))); +} + +inline Vectorized Vectorized::eq(const Vectorized& other) const { + return (*this == other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ne(const Vectorized& other) const { + return (*this != other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::gt(const Vectorized& other) const { + return (*this > other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::ge(const Vectorized& other) const { + return (*this >= other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::lt(const Vectorized& other) const { + return (*this < other) & Vectorized(1.0f); +} + +inline Vectorized Vectorized::le(const Vectorized& other) const { + return (*this <= other) & Vectorized(1.0f); +} + +template <> +inline void convert(const float* src, int32_t* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +inline void convert(const int32_t* src, float* dst, int64_t n) { + int64_t i; +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { + vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); + } +#ifndef __msvc_cl__ +#pragma unroll +#endif + for (; i < n; i++) { + dst[i] = static_cast(src[i]); + } +} + +template <> +Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return Vectorized(vfmaq_f32(c, a, b)); +} + +template <> +Vectorized inline fmsub(const Vectorized& a, const Vectorized& b, const Vectorized& c) { + return Vectorized(vfmsq_f32(c, a, b)); +} + +inline Vectorized Vectorized::erf() const{ + // constants + const Vectorized neg_zero_vec(-0.f); + const Vectorized one_vec(1.0f); + const Vectorized p(0.3275911f); + const Vectorized p1(0.254829592f); + const Vectorized p2(-0.284496736f); + const Vectorized p3(1.421413741f); + const Vectorized p4(-1.453152027f); + const Vectorized p5(1.061405429f); + // sign(x) + auto sign_mask = neg_zero_vec & *this; + auto abs_vec = this->abs(); + // t = 1 / (p * abs(x) + 1) + auto tmp0 = fmadd(p, abs_vec, one_vec); + auto t = one_vec / tmp0; + // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 + auto tmp1 = fmadd(p5, t, p4); + auto tmp2 = fmadd(tmp1, t, p3); + auto tmp3 = fmadd(tmp2, t, p2); + auto r = fmadd(tmp3, t, p1); + // - exp(- x * x) + auto pow_2 = (*this) * (*this); + auto neg_pow_2 = pow_2 ^ neg_zero_vec; + auto tmp4 = neg_pow_2.map(std::exp); // This can be swapped for a faster implementation of exp. + auto tmp5 = tmp4 ^ neg_zero_vec; + // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) + auto tmp6 = t * tmp5; + auto tmp7 = fmadd(tmp6, r, one_vec); + return tmp7 ^ sign_mask; +} +#undef DEFINE_SLEEF_COMPATIBLE_BINARY_ELEMENTWISE_FUNC +#undef DEFINE_SLEEF_COMPATIBLE_UNARY_ELEMENTWISE_FUNC +#endif /* defined(aarch64) */ + +}} // namespace at::vec::CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h similarity index 61% rename from aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h rename to aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h index 0b51972a029b4..c3f45d930fa9a 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_half_neon.h +++ b/aten/src/ATen/cpu/vec/vec128/vec128_half_neon.h @@ -4,7 +4,7 @@ // See Note [Do not compile initializers with AVX] #include -#include +#include #include #include #include @@ -61,15 +61,15 @@ struct BlendHalfRegs { template <> class Vectorized { private: - float16x8x2_t values; + float16x8_t values; public: // value_type should be c10::Half to fit interface with vec_base.h using value_type = c10::Half; using size_type = int; static constexpr size_type size() { - static_assert(sizeof(float16x8x2_t) == 16 * sizeof(value_type)); - return 16; + static_assert(sizeof(float16x8_t) == 8 * sizeof(value_type)); + return 8; } private: @@ -89,69 +89,43 @@ class Vectorized { Vectorized map_with_vec_float_method( Vectorized (Vectorized::*m)() const) const { - // Convert low float16x8_t to 2 float32x4_t variables, apply m, and convert - // back - float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values.val[0])); - float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values.val[0])); - Vectorized mv0 = (Vectorized(v00, v01).*m)(); - float16x4_t r00 = vcvt_f16_f32(mv0.get_low()); - float16x4_t r01 = vcvt_f16_f32(mv0.get_high()); - - // Convert high float16x8_t to 2 float32x4_t variables, apply m, and convert - // back - float32x4_t v10 = vcvt_f32_f16(vget_low_f16(values.val[1])); - float32x4_t v11 = vcvt_f32_f16(vget_high_f16(values.val[1])); - Vectorized mv1 = (Vectorized(v10, v11).*m)(); - float16x4_t r10 = vcvt_f16_f32(mv1.get_low()); - float16x4_t r11 = vcvt_f16_f32(mv1.get_high()); - - // Pack result into Vectorized - return Vectorized( - vcombine_f16(r00, r01), vcombine_f16(r10, r11)); + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + Vectorized mv0 = (Vectorized(v00).*m)(); + Vectorized mv1 = (Vectorized(v01).*m)(); + float16x4_t r00 = vcvt_f16_f32(mv0); + float16x4_t r01 = vcvt_f16_f32(mv1); + return Vectorized(vcombine_f16(r00, r01)); } Vectorized map2_with_vec_float_method( const Vectorized& second, Vectorized (Vectorized::*m)(const Vectorized&) const) const { - // Convert low float16x8_t to 2 float32x4_t variables, apply m, and convert - // back - float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values.val[0])); - float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values.val[0])); - float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.get_low())); - float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.get_low())); - Vectorized mv0 = (Vectorized(v00, v01).*m)( - Vectorized(second_v00, second_v01)); - float16x4_t r00 = vcvt_f16_f32(mv0.get_low()); - float16x4_t r01 = vcvt_f16_f32(mv0.get_high()); - - // Convert high float16x8_t to 2 float32x4_t variables, apply m, and convert - // back - float32x4_t v10 = vcvt_f32_f16(vget_low_f16(values.val[1])); - float32x4_t v11 = vcvt_f32_f16(vget_high_f16(values.val[1])); - float32x4_t second_v10 = vcvt_f32_f16(vget_low_f16(second.get_high())); - float32x4_t second_v11 = vcvt_f32_f16(vget_high_f16(second.get_high())); - Vectorized mv1 = (Vectorized(v10, v11).*m)( - Vectorized(second_v10, second_v11)); - float16x4_t r10 = vcvt_f16_f32(mv1.get_low()); - float16x4_t r11 = vcvt_f16_f32(mv1.get_high()); + float32x4_t v00 = vcvt_f32_f16(vget_low_f16(values)); + float32x4_t v01 = vcvt_f32_f16(vget_high_f16(values)); + float32x4_t second_v00 = vcvt_f32_f16(vget_low_f16(second.values)); + float32x4_t second_v01 = vcvt_f32_f16(vget_high_f16(second.values)); + Vectorized mv0 = (Vectorized(v00).*m)(Vectorized(second_v00)); + Vectorized mv1 = (Vectorized(v01).*m)(Vectorized(second_v01)); + float16x4_t r00 = vcvt_f16_f32(mv0); + float16x4_t r01 = vcvt_f16_f32(mv1); // Pack result into Vectorized - return Vectorized( - vcombine_f16(r00, r01), vcombine_f16(r10, r11)); + return Vectorized(vcombine_f16(r00, r01)); } public: // constructor Vectorized() {} - Vectorized(float16x8x2_t v) : values(v) {} + Vectorized(float16x8_t v) : values(v) {} // A ctor that accepts c10::Half is needed to fit interface with vec_base.h // A second constructor that takes float16_t is also included Vectorized(c10::Half val) - : values{vdupq_n_f16((float16_t)val), vdupq_n_f16((float16_t)val)} { + : values{vdupq_n_f16((float16_t)val)} { } - Vectorized(float16_t val) : values{vdupq_n_f16(val), vdupq_n_f16(val)} {} + Vectorized(float16_t val) : values{vdupq_n_f16(val)} {} Vectorized( float16_t val0, float16_t val1, @@ -160,15 +134,7 @@ class Vectorized { float16_t val4, float16_t val5, float16_t val6, - float16_t val7, - float16_t val8, - float16_t val9, - float16_t val10, - float16_t val11, - float16_t val12, - float16_t val13, - float16_t val14, - float16_t val15) + float16_t val7) : values{ val0, val1, @@ -177,17 +143,8 @@ class Vectorized { val4, val5, val6, - val7, - val8, - val9, - val10, - val11, - val12, - val13, - val14, - val15} {} - Vectorized(float16x8_t val0, float16x8_t val1) : values{val0, val1} {} - operator float16x8x2_t() const { + val7} {} + operator float16x8_t() const { return values; } template @@ -196,42 +153,23 @@ class Vectorized { const Vectorized& b) { Vectorized vec; // 0. - vec.values.val[0] = BlendHalfRegs<0, (mask & 0x01) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = BlendHalfRegs<1, (mask & 0x02) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = BlendHalfRegs<2, (mask & 0x04) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = BlendHalfRegs<3, (mask & 0x08) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - - vec.values.val[0] = BlendHalfRegs<4, (mask & 0x10) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = BlendHalfRegs<5, (mask & 0x20) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = BlendHalfRegs<6, (mask & 0x40) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = BlendHalfRegs<7, (mask & 0x80) != 0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - - // 1. - vec.values.val[1] = BlendHalfRegs<0, (mask & 0x10) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = BlendHalfRegs<1, (mask & 0x20) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = BlendHalfRegs<2, (mask & 0x40) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = BlendHalfRegs<3, (mask & 0x80) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - - vec.values.val[1] = BlendHalfRegs<4, (mask & 0x10) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = BlendHalfRegs<5, (mask & 0x20) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = BlendHalfRegs<6, (mask & 0x40) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = BlendHalfRegs<7, (mask & 0x80) != 0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); + vec.values = BlendHalfRegs<0, (mask & 0x01) != 0>::impl( + a.values, b.values, vec.values); + vec.values = BlendHalfRegs<1, (mask & 0x02) != 0>::impl( + a.values, b.values, vec.values); + vec.values = BlendHalfRegs<2, (mask & 0x04) != 0>::impl( + a.values, b.values, vec.values); + vec.values = BlendHalfRegs<3, (mask & 0x08) != 0>::impl( + a.values, b.values, vec.values); + + vec.values = BlendHalfRegs<4, (mask & 0x10) != 0>::impl( + a.values, b.values, vec.values); + vec.values = BlendHalfRegs<5, (mask & 0x20) != 0>::impl( + a.values, b.values, vec.values); + vec.values = BlendHalfRegs<6, (mask & 0x40) != 0>::impl( + a.values, b.values, vec.values); + vec.values = BlendHalfRegs<7, (mask & 0x80) != 0>::impl( + a.values, b.values, vec.values); return vec; } @@ -249,14 +187,10 @@ class Vectorized { // We perhaps need some kind of an assert? // But that will affect performance. Vectorized vec(mask.values); - vec.values.val[0] = vbslq_f16( - vreinterpretq_u16_f16(vec.values.val[0]), - b.values.val[0], - a.values.val[0]); - vec.values.val[1] = vbslq_f16( - vreinterpretq_u16_f16(vec.values.val[1]), - b.values.val[1], - a.values.val[1]); + vec.values = vbslq_f16( + vreinterpretq_u16_f16(vec.values), + b.values, + a.values); return vec; } template @@ -266,40 +200,32 @@ class Vectorized { const Vectorized base_vec(base); const Vectorized step_vec(step); const Vectorized step_sizes( - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15); + 0, 1, 2, 3, 4, 5, 6, 7); return fmadd(step_sizes, step_vec, base_vec); } static Vectorized set( const Vectorized& a, const Vectorized& b, int64_t count = size()) { - uint16_t pre_mask[16] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + uint16_t pre_mask[size()] = {0}; for (int i = 0; i < count; i++) { pre_mask[i] = 0xFFFF; } - uint16x8x2_t mask = vld1q_u16_x2(pre_mask); + uint16x8_t mask = vld1q_u16(pre_mask); // Using blendv is awkward because 0xFFFF is one of many NaN's in FP16 // so we directly use vbslq_f16 instead Vectorized vec( vbslq_f16( - // Low bits - mask.val[0], - b.values.val[0], - a.values.val[0]), - // High bits - vbslq_f16(mask.val[1], b.values.val[1], a.values.val[1])); + mask, + b.values, + a.values)); return vec; } static Vectorized loadu(const void* ptr, int64_t count = size()) { if (count == size()) { - return vld1q_f16_x2(reinterpret_cast(ptr)); - } else if (count == (size() >> 1)) { - Vectorized res; - res.values.val[0] = vld1q_f16(reinterpret_cast(ptr)); - std::memset(&res.values.val[1], 0, sizeof(res.values.val[1])); - return res; + return vld1q_f16(reinterpret_cast(ptr)); } __at_align__ float16_t tmp_values[size()]; for (const auto i : c10::irange(size())) { @@ -309,32 +235,18 @@ class Vectorized { tmp_values, reinterpret_cast(ptr), count * sizeof(float16_t)); - return vld1q_f16_x2(reinterpret_cast(tmp_values)); + return vld1q_f16(reinterpret_cast(tmp_values)); } void store(void* ptr, int64_t count = size()) const { if (count == size()) { - vst1q_f16_x2(reinterpret_cast(ptr), values); + vst1q_f16(reinterpret_cast(ptr), values); return; - } else if (count == (size() >> 1)) { - vst1q_f16(reinterpret_cast(ptr), values.val[0]); } else { float16_t tmp_values[size()]; - vst1q_f16_x2(reinterpret_cast(tmp_values), values); + vst1q_f16(reinterpret_cast(tmp_values), values); std::memcpy(ptr, tmp_values, count * sizeof(float16_t)); } } - inline const float16x8_t& get_low() const { - return values.val[0]; - } - inline float16x8_t& get_low() { - return values.val[0]; - } - inline const float16x8_t& get_high() const { - return values.val[1]; - } - inline float16x8_t& get_high() { - return values.val[1]; - } // Very slow implementation of indexing. // Only required because vec256_qint refers to this. // Once we specialize that implementation for ARM @@ -394,8 +306,7 @@ class Vectorized { return loadu(tmp); } Vectorized abs() const { - return Vectorized( - vabsq_f16(values.val[0]), vabsq_f16(values.val[1])); + return Vectorized(vabsq_f16(values)); } Vectorized angle() const { auto zero = Vectorized(0); @@ -518,8 +429,7 @@ class Vectorized { return map(at::native::floor_impl); } Vectorized neg() const { - return Vectorized( - vnegq_f16(values.val[0]), vnegq_f16(values.val[1])); + return Vectorized(vnegq_f16(values)); } inline Vectorized round() const { // This function is questionable with a conversion, so we use map @@ -532,22 +442,17 @@ class Vectorized { return map_with_vec_float_method(&Vectorized::tanh); } Vectorized trunc() const { - float16x8_t r0 = vrndq_f16(values.val[0]); - float16x8_t r1 = vrndq_f16(values.val[1]); - return Vectorized(r0, r1); + return Vectorized(vrndq_f16(values)); } Vectorized lgamma() const { return map_with_vec_float_method(&Vectorized::lgamma); } Vectorized sqrt() const { - return Vectorized( - vsqrtq_f16(values.val[0]), vsqrtq_f16(values.val[1])); + return Vectorized(vsqrtq_f16(values)); } Vectorized reciprocal() const { auto ones = vdupq_n_f16(1.0f); - auto r0 = vdivq_f16(ones, values.val[0]); - auto r1 = vdivq_f16(ones, values.val[1]); - return Vectorized(r0, r1); + return Vectorized(vdivq_f16(ones, values)); } Vectorized rsqrt() const { return this->sqrt().reciprocal(); @@ -556,51 +461,28 @@ class Vectorized { return map2_with_vec_float_method(exp, &Vectorized::pow); } Vectorized operator==(const Vectorized& other) const { - float16x8_t r0 = - vreinterpretq_f16_u16(vceqq_f16(values.val[0], other.values.val[0])); - float16x8_t r1 = - vreinterpretq_f16_u16(vceqq_f16(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(vceqq_f16(values, other.values))); } Vectorized operator!=(const Vectorized& other) const { - float16x8_t r0 = vreinterpretq_f16_u16( - vmvnq_u16(vceqq_f16(values.val[0], other.values.val[0]))); - float16x8_t r1 = vreinterpretq_f16_u16( - vmvnq_u16(vceqq_f16(values.val[1], other.values.val[1]))); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16( + vmvnq_u16(vceqq_f16(values, other.values)))); } Vectorized operator<(const Vectorized& other) const { - float16x8_t r0 = - vreinterpretq_f16_u16(vcltq_f16(values.val[0], other.values.val[0])); - float16x8_t r1 = - vreinterpretq_f16_u16(vcltq_f16(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(vcltq_f16(values, other.values))); } Vectorized operator<=(const Vectorized& other) const { - float16x8_t r0 = - vreinterpretq_f16_u16(vcleq_f16(values.val[0], other.values.val[0])); - float16x8_t r1 = - vreinterpretq_f16_u16(vcleq_f16(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(vcleq_f16(values, other.values))); } Vectorized operator>(const Vectorized& other) const { - float16x8_t r0 = - vreinterpretq_f16_u16(vcgtq_f16(values.val[0], other.values.val[0])); - float16x8_t r1 = - vreinterpretq_f16_u16(vcgtq_f16(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(vcgtq_f16(values, other.values))); } Vectorized operator>=(const Vectorized& other) const { - float16x8_t r0 = - vreinterpretq_f16_u16(vcgeq_f16(values.val[0], other.values.val[0])); - float16x8_t r1 = - vreinterpretq_f16_u16(vcgeq_f16(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(vcgeq_f16(values, other.values))); } Vectorized eq(const Vectorized& other) const; @@ -615,36 +497,28 @@ template <> Vectorized inline operator+( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vaddq_f16(a.get_low(), b.get_low()); - float16x8_t r1 = vaddq_f16(a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vaddq_f16(a, b)); } template <> Vectorized inline operator-( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vsubq_f16(a.get_low(), b.get_low()); - float16x8_t r1 = vsubq_f16(a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vsubq_f16(a, b)); } template <> Vectorized inline operator*( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vmulq_f16(a.get_low(), b.get_low()); - float16x8_t r1 = vmulq_f16(a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vmulq_f16(a, b)); } template <> Vectorized inline operator/( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vdivq_f16(a.get_low(), b.get_low()); - float16x8_t r1 = vdivq_f16(a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vdivq_f16(a, b)); } // frac. Implement this here so we can use subtraction @@ -658,9 +532,7 @@ template <> Vectorized inline maximum( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vmaxq_f16(a.get_low(), b.get_low()); - float16x8_t r1 = vmaxq_f16(a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vmaxq_f16(a, b)); } // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if @@ -669,9 +541,7 @@ template <> Vectorized inline minimum( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vminq_f16(a.get_low(), b.get_low()); - float16x8_t r1 = vminq_f16(a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vminq_f16(a, b)); } template <> @@ -700,36 +570,24 @@ template <> Vectorized inline operator&( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vreinterpretq_f16_u16(vandq_u16( - vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low()))); - float16x8_t r1 = vreinterpretq_f16_u16(vandq_u16( - vreinterpretq_u16_f16(a.get_high()), - vreinterpretq_u16_f16(b.get_high()))); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(vandq_u16( + vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); } template <> Vectorized inline operator|( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vreinterpretq_f16_u16(vorrq_u16( - vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low()))); - float16x8_t r1 = vreinterpretq_f16_u16(vorrq_u16( - vreinterpretq_u16_f16(a.get_high()), - vreinterpretq_u16_f16(b.get_high()))); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(vorrq_u16( + vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); } template <> Vectorized inline operator^( const Vectorized& a, const Vectorized& b) { - float16x8_t r0 = vreinterpretq_f16_u16(veorq_u16( - vreinterpretq_u16_f16(a.get_low()), vreinterpretq_u16_f16(b.get_low()))); - float16x8_t r1 = vreinterpretq_f16_u16(veorq_u16( - vreinterpretq_u16_f16(a.get_high()), - vreinterpretq_u16_f16(b.get_high()))); - return Vectorized(r0, r1); + return Vectorized(vreinterpretq_f16_u16(veorq_u16( + vreinterpretq_u16_f16(a), vreinterpretq_u16_f16(b)))); } inline Vectorized Vectorized::eq( @@ -771,7 +629,6 @@ inline void convert(const float16_t* src, int16_t* dst, int64_t n) { for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { vst1q_s16(dst + i, vcvtq_s16_f16(vld1q_f16(src + i))); - vst1q_s16(dst + i + 8, vcvtq_s16_f16(vld1q_f16(src + i + 8))); } #ifndef __msvc_cl__ #pragma unroll @@ -790,7 +647,6 @@ inline void convert(const int16_t* src, float16_t* dst, int64_t n) { for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { vst1q_f16(dst + i, vcvtq_f16_s16(vld1q_s16(src + i))); - vst1q_f16(dst + i + 8, vcvtq_f16_s16(vld1q_s16(src + i + 8))); } #ifndef __msvc_cl__ #pragma unroll @@ -805,9 +661,7 @@ Vectorized inline fmadd( const Vectorized& a, const Vectorized& b, const Vectorized& c) { - float16x8_t r0 = vfmaq_f16(c.get_low(), a.get_low(), b.get_low()); - float16x8_t r1 = vfmaq_f16(c.get_high(), a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vfmaq_f16(c, a, b)); } template <> @@ -815,9 +669,7 @@ Vectorized inline fmsub( const Vectorized& a, const Vectorized& b, const Vectorized& c) { - float16x8_t r0 = vfmsq_f16(c.get_low(), a.get_low(), b.get_low()); - float16x8_t r1 = vfmsq_f16(c.get_high(), a.get_high(), b.get_high()); - return Vectorized(r0, r1); + return Vectorized(vfmsq_f16(c, a, b)); } #endif /* defined(aarch64) && defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(C10_MOBILE) */ diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h index 68367b81bd8a0..f88e852303912 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256.h @@ -9,9 +9,6 @@ #if !(defined(__VSX__) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_ZVECTOR)) #if defined(CPU_CAPABILITY_SVE256) #include -#else -#include -#include #endif #include #include diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h index 12c11abb748de..ec84c7bfa5356 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h @@ -1101,35 +1101,27 @@ CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16); inline std::tuple, Vectorized> convert_half_float(const Vectorized& a) { static_assert(Vectorized::size() == 2 * Vectorized::size()); #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - float16x8x2_t arr = a; - float16x8_t x = arr.val[0]; - float16x8_t y = arr.val[1]; + float16x8_t x = a; #else auto arr = reinterpret_cast(a.operator const Half*()); float16x8_t x = vld1q_f16(arr); - float16x8_t y = vld1q_f16(arr + Vectorized::size()); #endif float32x4_t x1 = vcvt_f32_f16(vget_low_f16(x)); float32x4_t x2 = vcvt_f32_f16(vget_high_f16(x)); - float32x4_t y1 = vcvt_f32_f16(vget_low_f16(y)); - float32x4_t y2 = vcvt_f32_f16(vget_high_f16(y)); - return { Vectorized(x1, x2), Vectorized(y1, y2) }; + return { Vectorized(x1), Vectorized(x2) }; } inline Vectorized convert_float_half(const Vectorized& a, const Vectorized& b) { static_assert(Vectorized::size() == 2 * Vectorized::size()); - float32x4x2_t x = a; - float32x4x2_t y = b; - float16x4_t x1 = vcvt_f16_f32(x.val[0]); - float16x4_t x2 = vcvt_f16_f32(x.val[1]); - float16x4_t y1 = vcvt_f16_f32(y.val[0]); - float16x4_t y2 = vcvt_f16_f32(y.val[1]); + float32x4_t x = a; + float32x4_t y = b; + float16x4_t x1 = vcvt_f16_f32(x); + float16x4_t x2 = vcvt_f16_f32(y); #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - return Vectorized(vcombine_f16(x1, x2), vcombine_f16(y1, y2)); + return Vectorized(vcombine_f16(x1, x2)); #else Vectorized rc; auto arr = reinterpret_cast(rc.operator Half*()); vst1q_f16(arr, vcombine_f16(x1, x2)); - vst1q_f16(arr + Vectorized::size(), vcombine_f16(y1, y2)); return rc; #endif } diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h index 9950d606c21c1..7ae2e8168c74d 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_convert.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_convert.h @@ -271,9 +271,9 @@ struct VecConvert< 1, int64_t, 2, - typename std::enable_if< + std::enable_if_t< std::is_same_v || - std::is_same_v>::type> { + std::is_same_v>> { static inline VectorizedN apply( const VectorizedN& src) { return VecConvert::apply( @@ -284,7 +284,7 @@ struct VecConvert< #endif /* defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER) */ -#if (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)) || defined(CPU_CAPABILITY_NEON) +#if (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)) template struct VecConvert< float, @@ -298,19 +298,59 @@ struct VecConvert< } }; #endif +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +template +struct VecConvert< + float, + 1, + src_t, + 1, + typename std::enable_if_t, + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + return convert_int8_half_register_to_float(src[0]); + } +}; +template +struct VecConvert< + float, + 2, + src_t, + 1, + typename std::enable_if_t, + void>> { + static inline VectorizedN apply(const VectorizedN& src) { + const auto [v0, v1] = convert_int8_to_float(src[0]); + return VectorizedN(v0, v1); + } +}; +#endif -#if defined(CPU_CAPABILITY_NEON) +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) template <> -struct VecConvert { - static inline VectorizedN apply( +struct VecConvert { + static inline VectorizedN apply( const VectorizedN& src) { - VectorizedN result; + VectorizedN result; uint16x8_t u16_8 = vld1q_u16(reinterpret_cast(&src[0])); auto u16_low1 = vget_low_u16(u16_8); auto u16_high1 = vget_high_u16(u16_8); float32x4_t f32x4_0 = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_low1), 16)); float32x4_t f32x4_1 = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_high1), 16)); - result[0] = {f32x4_0, f32x4_1}; + result[0] = f32x4_0; + result[1] = f32x4_1; + return result; + } +}; +// Half register to full register. +template <> +struct VecConvert { + static inline VectorizedN apply( + const VectorizedN& src) { + VectorizedN result; + uint16x4_t u16_8 = vld1_u16(reinterpret_cast(&src[0])); + float32x4_t f32x4_0 = vreinterpretq_f32_u32(vshlq_n_u32(vmovl_u16(u16_8), 16)); + result[0] = f32x4_0; return result; } }; diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h index dab1790b26ab0..687dc71ef8691 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h @@ -35,6 +35,8 @@ template <> class Vectorized { float val5, float val6, float val7, float val8) { values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8); } + Vectorized(const float (&arr)[8]) + : Vectorized(arr[0], arr[1], arr[2], arr[3], arr[4], arr[5], arr[6], arr[7]) {} operator __m256() const { return values; } @@ -216,27 +218,27 @@ template <> class Vectorized { } Vectorized exp_u20() const { // A faster version of exp with ULP=20 - static __m256 vec_factorial_1 = + const __m256 vec_factorial_1 = _mm256_set1_ps(0.999999701f); // 1/factorial(1) - static __m256 vec_factorial_2 = + const __m256 vec_factorial_2 = _mm256_set1_ps(0.499991506f); // 1/factorial(2) - static __m256 vec_factorial_3 = + const __m256 vec_factorial_3 = _mm256_set1_ps(0.166676521f); // 1/factorial(3) - static __m256 vec_factorial_4 = + const __m256 vec_factorial_4 = _mm256_set1_ps(0.0418978221f); // 1/factorial(4) - static __m256 vec_factorial_5 = + const __m256 vec_factorial_5 = _mm256_set1_ps(0.00828929059f); // 1/factorial(5) - static __m256 vec_exp_log2ef = + const __m256 vec_exp_log2ef = _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e) - static __m256 vec_half = _mm256_set1_ps(0.5f); - static __m256 vec_one = _mm256_set1_ps(1.f); - static __m256 vec_zero = _mm256_set1_ps(0.f); - static __m256 vec_two = _mm256_set1_ps(2.f); - static __m256 vec_ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2) - static __m256 vec_ln_flt_min = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); - static __m256 vec_ln_flt_max = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); - static __m256i vec_127 = _mm256_set1_epi32(0x0000007f); - static int n_mantissa_bits = 23; + const __m256 vec_half = _mm256_set1_ps(0.5f); + const __m256 vec_one = _mm256_set1_ps(1.f); + const __m256 vec_zero = _mm256_set1_ps(0.f); + const __m256 vec_two = _mm256_set1_ps(2.f); + const __m256 vec_ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2) + const __m256 vec_ln_flt_min = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50)); + const __m256 vec_ln_flt_max = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218)); + const __m256i vec_127 = _mm256_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; // exp(x) = // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h deleted file mode 100644 index fdf9d66898646..0000000000000 --- a/aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h +++ /dev/null @@ -1,909 +0,0 @@ -#pragma once - -// DO NOT DEFINE STATIC DATA IN THIS HEADER! -// See Note [Do not compile initializers with AVX] - -#include -#include -#include - -#if defined(__aarch64__) && defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) -#include -#endif - -// Sleef offers vectorized versions of some transcedentals -// such as sin, cos, tan etc.. -// However for now opting for STL, since we are not building -// with Sleef for mobile yet. - -namespace at::vec { -// See Note [CPU_CAPABILITY namespace] -inline namespace CPU_CAPABILITY { - -// Right now contains only aarch64 implementation. -// Due to follow two reasons aarch32 is not currently supported. -// 1. Due to difference in ISA been aarch32 and aarch64, intrinsics -// that work for aarch64 dont work for aarch32. -// 2. Android NDK r21 has problems with compiling aarch32. -// Clang seg faults. -// https://github.com/android/ndk/issues/1248 -// https://bugs.llvm.org/show_bug.cgi?id=45824 -// Most likely we will do aarch32 support with inline asm. -#if defined(__aarch64__) - -#ifdef __BIG_ENDIAN__ -#error "Big endian is not supported." -#endif - -#if defined(AT_BUILD_ARM_VEC256_WITH_SLEEF) -#define USE_SLEEF(sleef_code, non_sleef_code) sleef_code -#else -#define USE_SLEEF(sleef_code, non_sleef_code) non_sleef_code -#endif - -template -struct BlendRegs { - static float32x4_t impl( - const float32x4_t& a, const float32x4_t& b, float32x4_t& res); -}; - -template -struct BlendRegs{ - static float32x4_t impl( - const float32x4_t& a, const float32x4_t& b, float32x4_t& res) { - return vsetq_lane_f32(vgetq_lane_f32(b, index), res, index); - } -}; - -template -struct BlendRegs{ - static float32x4_t impl( - const float32x4_t& a, const float32x4_t& b, float32x4_t& res) { - return vsetq_lane_f32(vgetq_lane_f32(a, index), res, index); - } -}; - -template <> class Vectorized { -private: - float32x4x2_t values; -public: - using value_type = float; - using size_type = int; - static constexpr size_type size() { - return 8; - } - Vectorized() {} - Vectorized(float32x4x2_t v) : values(v) {} - Vectorized(float val) : values{vdupq_n_f32(val), vdupq_n_f32(val) } {} - Vectorized(float val0, float val1, float val2, float val3, - float val4, float val5, float val6, float val7) : - values{val0, val1, val2, val3, val4, val5, val6, val7} {} - Vectorized(float32x4_t val0, float32x4_t val1) : values{val0, val1} {} - operator float32x4x2_t() const { - return values; - } - template - static Vectorized blend(const Vectorized& a, const Vectorized& b) { - Vectorized vec; - // 0. - vec.values.val[0] = - BlendRegs<0, (mask & 0x01)!=0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = - BlendRegs<1, (mask & 0x02)!=0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = - BlendRegs<2, (mask & 0x04)!=0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - vec.values.val[0] = - BlendRegs<3, (mask & 0x08)!=0>::impl( - a.values.val[0], b.values.val[0], vec.values.val[0]); - // 1. - vec.values.val[1] = - BlendRegs<0, (mask & 0x10)!=0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = - BlendRegs<1, (mask & 0x20)!=0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = - BlendRegs<2, (mask & 0x40)!=0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - vec.values.val[1] = - BlendRegs<3, (mask & 0x80)!=0>::impl( - a.values.val[1], b.values.val[1], vec.values.val[1]); - return vec; - } - static Vectorized blendv(const Vectorized& a, const Vectorized& b, - const Vectorized& mask) { - // TODO - // NB: This requires that each value, i.e., each uint value, - // of the mask either all be zeros or all be 1s. - // We perhaps need some kind of an assert? - // But that will affect performance. - Vectorized vec(mask.values); - vec.values.val[0] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[0]), - b.values.val[0], - a.values.val[0]); - vec.values.val[1] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[1]), - b.values.val[1], - a.values.val[1]); - return vec; - } - template - static Vectorized arange(float base = 0.f, step_t step = static_cast(1)) { - const Vectorized base_vec(base); - const Vectorized step_vec(step); - const Vectorized step_sizes(0, 1, 2, 3, 4, 5, 6, 7); - return fmadd(step_sizes, step_vec, base_vec); - } - static Vectorized set(const Vectorized& a, const Vectorized& b, - int64_t count = size()) { - switch (count) { - case 0: - return a; - case 1: - { - Vectorized vec; - static uint32x4_t mask_low = {0xFFFFFFFF, 0x0, 0x0, 0x0}; - vec.values.val[0] = vreinterpretq_f32_u32(mask_low); - vec.values.val[1] = a.values.val[1]; - vec.values.val[0] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[0]), - b.values.val[0], - a.values.val[0]); - return vec; - } - case 2: - { - Vectorized vec; - static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0}; - vec.values.val[0] = vreinterpretq_f32_u32(mask_low); - vec.values.val[1] = a.values.val[1]; - vec.values.val[0] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[0]), - b.values.val[0], - a.values.val[0]); - return vec; - } - case 3: - { - Vectorized vec; - static uint32x4_t mask_low = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0}; - vec.values.val[0] = vreinterpretq_f32_u32(mask_low); - vec.values.val[1] = a.values.val[1]; - vec.values.val[0] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[0]), - b.values.val[0], - a.values.val[0]); - return vec; - } - case 4: - return Vectorized(b.values.val[0], a.values.val[1]); - case 5: - { - Vectorized vec; - static uint32x4_t mask_high = {0xFFFFFFFF, 0x0, 0x0, 0x0}; - vec.values.val[0] = b.values.val[0]; - vec.values.val[1] = vreinterpretq_f32_u32(mask_high); - vec.values.val[1] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[1]), - b.values.val[1], - a.values.val[1]); - return vec; - } - case 6: - { - Vectorized vec; - static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0x0, 0x0}; - vec.values.val[0] = b.values.val[0]; - vec.values.val[1] = vreinterpretq_f32_u32(mask_high); - vec.values.val[1] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[1]), - b.values.val[1], - a.values.val[1]); - return vec; - } - case 7: - { - Vectorized vec; - static uint32x4_t mask_high = {0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x0}; - vec.values.val[0] = b.values.val[0]; - vec.values.val[1] = vreinterpretq_f32_u32(mask_high); - vec.values.val[1] = vbslq_f32( - vreinterpretq_u32_f32(vec.values.val[1]), - b.values.val[1], - a.values.val[1]); - return vec; - } - } - return b; - } - static Vectorized loadu(const void* ptr, int64_t count = size()) { - if (count == size()) { - return vld1q_f32_x2(reinterpret_cast(ptr)); - } - else if (count == (size() >> 1)) { - Vectorized res; - res.values.val[0] = vld1q_f32(reinterpret_cast(ptr)); - res.values.val[1] = vdupq_n_f32(0.f); - return res; - } - else { - __at_align__ float tmp_values[size()]; - for (const auto i : c10::irange(size())) { - tmp_values[i] = 0.0; - } - std::memcpy( - tmp_values, - reinterpret_cast(ptr), - count * sizeof(float)); - return vld1q_f32_x2(reinterpret_cast(tmp_values)); - } - } - void store(void* ptr, int64_t count = size()) const { - if (count == size()) { - vst1q_f32_x2(reinterpret_cast(ptr), values); - } - else if (count == (size() >> 1)) { - vst1q_f32(reinterpret_cast(ptr), values.val[0]); - } - else { - float tmp_values[size()]; - vst1q_f32_x2(reinterpret_cast(tmp_values), values); - std::memcpy(ptr, tmp_values, count * sizeof(float)); - } - } - inline const float32x4_t& get_low() const { - return values.val[0]; - } - inline float32x4_t& get_low() { - return values.val[0]; - } - inline const float32x4_t& get_high() const { - return values.val[1]; - } - inline float32x4_t& get_high() { - return values.val[1]; - } - // Very slow implementation of indexing. - // Only required because vec256_qint refers to this. - // Once we specialize that implementation for ARM - // this should be removed. TODO (kimishpatel) - float operator[](int idx) const { - __at_align__ float tmp[size()]; - store(tmp); - return tmp[idx]; - } - float operator[](int idx) { - __at_align__ float tmp[size()]; - store(tmp); - return tmp[idx]; - } - // For boolean version where we want to if any 1/all zero - // etc. can be done faster in a different way. - int zero_mask() const { - __at_align__ float tmp[size()]; - store(tmp); - int mask = 0; - for (int i = 0; i < size(); ++ i) { - if (tmp[i] == 0.f) { - mask |= (1 << i); - } - } - return mask; - } - Vectorized isnan() const { - __at_align__ float tmp[size()]; - __at_align__ float res[size()]; - store(tmp); - for (const auto i : c10::irange(size())) { - if (_isnan(tmp[i])) { - std::memset(static_cast(&res[i]), 0xFF, sizeof(float)); - } else { - std::memset(static_cast(&res[i]), 0, sizeof(float)); - } - } - return loadu(res); - }; - bool has_inf_nan() const { - __at_align__ float tmp[size()]; - store(tmp); - for (const auto i : c10::irange(size())) { - if(_isnan(tmp[i]) || _isinf(tmp[i])) { - return true; - } - } - return false; - } - Vectorized map(float (*const f)(float)) const { - __at_align__ float tmp[size()]; - store(tmp); - for (const auto i : c10::irange(size())) { - tmp[i] = f(tmp[i]); - } - return loadu(tmp); - } - Vectorized abs() const { - return Vectorized(vabsq_f32(values.val[0]), vabsq_f32(values.val[1])); - } - Vectorized angle() const { - auto zero = Vectorized(0); - auto pi = Vectorized(c10::pi); - auto tmp = blendv(zero, pi, *this < zero); - return blendv(tmp, *this, isnan()); - } - Vectorized real() const { - return *this; - } - Vectorized imag() const { - return Vectorized(0.f); - } - Vectorized conj() const { - return *this; - } - Vectorized acos() const { - return USE_SLEEF( - Vectorized(Sleef_acosf4_u10(values.val[0]), Sleef_acosf4_u10(values.val[1])), - map(std::acos) - ); - } - Vectorized acosh() const { - return USE_SLEEF( - Vectorized(Sleef_acoshf4_u10(values.val[0]), Sleef_acoshf4_u10(values.val[1])), - map(std::acosh) - ); - } - Vectorized asin() const { - return USE_SLEEF( - Vectorized(Sleef_asinf4_u10(values.val[0]), Sleef_asinf4_u10(values.val[1])), - map(std::asin) - ); - } - Vectorized atan() const { - return USE_SLEEF( - Vectorized(Sleef_atanf4_u10(values.val[0]), Sleef_atanf4_u10(values.val[1])), - map(std::atan) - ); - } - Vectorized atanh() const { - return USE_SLEEF( - Vectorized(Sleef_atanhf4_u10(values.val[0]), Sleef_atanhf4_u10(values.val[1])), - map(std::atanh) - ); - } - Vectorized atan2(const Vectorized &exp) const { - USE_SLEEF( - { - return Vectorized(Sleef_atan2f4_u10(values.val[0], exp.values.val[0]), - Sleef_atan2f4_u10(values.val[1], exp.values.val[1])); - }, - { - __at_align__ float tmp[size()]; - __at_align__ float tmp_exp[size()]; - store(tmp); - exp.store(tmp_exp); - for (const auto i : c10::irange(size())) { - tmp[i] = std::atan2(tmp[i], tmp_exp[i]); - } - return loadu(tmp); - } - ) - } - Vectorized copysign(const Vectorized &sign) const { - USE_SLEEF( - { - return Vectorized(Sleef_copysignf4(values.val[0], sign.values.val[0]), - Sleef_copysignf4(values.val[1], sign.values.val[1])); - }, - { - __at_align__ float tmp[size()]; - __at_align__ float tmp_sign[size()]; - store(tmp); - sign.store(tmp_sign); - for (size_type i = 0; i < size(); i++) { - tmp[i] = std::copysign(tmp[i], tmp_sign[i]); - } - return loadu(tmp); - } - ) - } - Vectorized erf() const; - Vectorized erfc() const { - return USE_SLEEF( - Vectorized(Sleef_erfcf4_u15(values.val[0]), Sleef_erfcf4_u15(values.val[1])), - map(std::erfc) - ); - } - Vectorized erfinv() const { - return map(calc_erfinv); - } - Vectorized exp() const { - return USE_SLEEF( - Vectorized(Sleef_expf4_u10(values.val[0]), Sleef_expf4_u10(values.val[1])), - map(std::exp) - ); - } - Vectorized exp2() const { - return USE_SLEEF( - Vectorized(Sleef_exp2f4_u10(values.val[0]), Sleef_exp2f4_u10(values.val[1])), - map(std::exp2) - ); - } - Vectorized expm1() const { - return USE_SLEEF( - Vectorized(Sleef_expm1f4_u10(values.val[0]), Sleef_expm1f4_u10(values.val[1])), - map(std::expm1) - ); - } - Vectorized exp_u20() const { - return exp(); - } - Vectorized fmod(const Vectorized& q) const { - USE_SLEEF( - { - return Vectorized(Sleef_fmodf4(values.val[0], q.values.val[0]), - Sleef_fmodf4(values.val[1], q.values.val[1])); - }, - { - __at_align__ float tmp[size()]; - __at_align__ float tmp_q[size()]; - store(tmp); - q.store(tmp_q); - for (const auto i : c10::irange(size())) { - tmp[i] = std::fmod(tmp[i], tmp_q[i]); - } - return loadu(tmp); - } - ) - } - Vectorized hypot(const Vectorized &b) const { - USE_SLEEF( - { - return Vectorized(Sleef_hypotf4_u05(values.val[0], b.values.val[0]), - Sleef_hypotf4_u05(values.val[1], b.values.val[1])); - }, - { - __at_align__ float tmp[size()]; - __at_align__ float tmp_b[size()]; - store(tmp); - b.store(tmp_b); - for (const auto i : c10::irange(size())) { - tmp[i] = std::hypot(tmp[i], tmp_b[i]); - } - return loadu(tmp); - } - ) - } - Vectorized i0() const { - return map(calc_i0); - } - Vectorized i0e() const { - return map(calc_i0e); - } - Vectorized digamma() const { - return map(calc_digamma); - } - Vectorized igamma(const Vectorized &x) const { - __at_align__ float tmp[size()]; - __at_align__ float tmp_x[size()]; - store(tmp); - x.store(tmp_x); - for (const auto i : c10::irange(size())) { - tmp[i] = calc_igamma(tmp[i], tmp_x[i]); - } - return loadu(tmp); - } - Vectorized igammac(const Vectorized &x) const { - __at_align__ float tmp[size()]; - __at_align__ float tmp_x[size()]; - store(tmp); - x.store(tmp_x); - for (const auto i : c10::irange(size())) { - tmp[i] = calc_igammac(tmp[i], tmp_x[i]); - } - return loadu(tmp); - } - Vectorized log() const { - return USE_SLEEF( - Vectorized(Sleef_logf4_u10(values.val[0]), Sleef_logf4_u10(values.val[1])), - map(std::log) - ); - } - Vectorized log10() const { - return USE_SLEEF( - Vectorized(Sleef_log10f4_u10(values.val[0]), Sleef_log10f4_u10(values.val[1])), - map(std::log10) - ); - } - Vectorized log1p() const { - return USE_SLEEF( - Vectorized(Sleef_log1pf4_u10(values.val[0]), Sleef_log1pf4_u10(values.val[1])), - map(std::log1p) - ); - } - Vectorized log2() const { - return USE_SLEEF( - Vectorized(Sleef_log2f4_u10(values.val[0]), Sleef_log2f4_u10(values.val[1])), - map(std::log2) - ); - } - Vectorized nextafter(const Vectorized &b) const { - USE_SLEEF( - { - return Vectorized(Sleef_nextafterf4(values.val[0], b.values.val[0]), - Sleef_nextafterf4(values.val[1], b.values.val[1])); - }, - { - __at_align__ float tmp[size()]; - __at_align__ float tmp_b[size()]; - store(tmp); - b.store(tmp_b); - for (const auto i : c10::irange(size())) { - tmp[i] = std::nextafter(tmp[i], tmp_b[i]); - } - return loadu(tmp); - } - ) - } - Vectorized frac() const; - Vectorized sin() const { - return USE_SLEEF( - Vectorized(Sleef_sinf4_u10(values.val[0]), Sleef_sinf4_u10(values.val[1])), - map(std::sin) - ); - } - Vectorized sinh() const { - return USE_SLEEF( - Vectorized(Sleef_sinhf4_u10(values.val[0]), Sleef_sinhf4_u10(values.val[1])), - map(std::sinh) - ); - } - Vectorized cos() const { - return USE_SLEEF( - Vectorized(Sleef_cosf4_u10(values.val[0]), Sleef_cosf4_u10(values.val[1])), - map(std::cos) - ); - } - Vectorized cosh() const { - return USE_SLEEF( - Vectorized(Sleef_coshf4_u10(values.val[0]), Sleef_coshf4_u10(values.val[1])), - map(std::cosh) - ); - } - Vectorized ceil() const { - return map(at::native::ceil_impl); - } - Vectorized floor() const { - return map(at::native::floor_impl); - } - Vectorized neg() const { - return Vectorized( - vnegq_f32(values.val[0]), - vnegq_f32(values.val[1])); - } - Vectorized round() const { - // We do not use std::round because we would like to round midway numbers to the nearest even integer. - return map(at::native::round_impl); - } - Vectorized tan() const { - return USE_SLEEF( - Vectorized(Sleef_tanf4_u10(values.val[0]), Sleef_tanf4_u10(values.val[1])), - map(std::tan) - ); - } - Vectorized tanh() const { - return USE_SLEEF( - Vectorized(Sleef_tanhf4_u10(values.val[0]), Sleef_tanhf4_u10(values.val[1])), - map(std::tanh) - ); - } - Vectorized trunc() const { - float32x4_t r0 = vrndq_f32(values.val[0]); - float32x4_t r1 = vrndq_f32(values.val[1]); - return Vectorized(r0, r1); - } - Vectorized lgamma() const { - return USE_SLEEF( - Vectorized(Sleef_lgammaf4_u10(values.val[0]), Sleef_lgammaf4_u10(values.val[1])), - map(std::lgamma) - ); - } - Vectorized sqrt() const { - return Vectorized( - vsqrtq_f32(values.val[0]), - vsqrtq_f32(values.val[1])); - } - Vectorized reciprocal() const { - auto r0 = vdivq_f32(vdupq_n_f32(1.0f), values.val[0]); - auto r1 = vdivq_f32(vdupq_n_f32(1.0f), values.val[1]); - return Vectorized(r0, r1); - } - Vectorized rsqrt() const { - return this->sqrt().reciprocal(); - } - Vectorized pow(const Vectorized &exp) const { - USE_SLEEF( - { - return Vectorized(Sleef_powf4_u10(values.val[0], exp.values.val[0]), - Sleef_powf4_u10(values.val[1], exp.values.val[1])); - }, - { - __at_align__ float tmp[size()]; - __at_align__ float tmp_exp[size()]; - store(tmp); - exp.store(tmp_exp); - for (const auto i : c10::irange(size())) { - tmp[i] = std::pow(tmp[i], tmp_exp[i]); - } - return loadu(tmp); - } - ) - } - Vectorized operator==(const Vectorized& other) const { - float32x4_t r0 = - vreinterpretq_f32_u32(vceqq_f32(values.val[0], other.values.val[0])); - float32x4_t r1 = - vreinterpretq_f32_u32(vceqq_f32(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); - } - - Vectorized operator!=(const Vectorized& other) const { - float32x4_t r0 = vreinterpretq_f32_u32( - vmvnq_u32(vceqq_f32(values.val[0], other.values.val[0]))); - float32x4_t r1 = vreinterpretq_f32_u32( - vmvnq_u32(vceqq_f32(values.val[1], other.values.val[1]))); - return Vectorized(r0, r1); - } - - Vectorized operator<(const Vectorized& other) const { - float32x4_t r0 = - vreinterpretq_f32_u32(vcltq_f32(values.val[0], other.values.val[0])); - float32x4_t r1 = - vreinterpretq_f32_u32(vcltq_f32(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); - } - - Vectorized operator<=(const Vectorized& other) const { - float32x4_t r0 = - vreinterpretq_f32_u32(vcleq_f32(values.val[0], other.values.val[0])); - float32x4_t r1 = - vreinterpretq_f32_u32(vcleq_f32(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); - } - - Vectorized operator>(const Vectorized& other) const { - float32x4_t r0 = - vreinterpretq_f32_u32(vcgtq_f32(values.val[0], other.values.val[0])); - float32x4_t r1 = - vreinterpretq_f32_u32(vcgtq_f32(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); - } - - Vectorized operator>=(const Vectorized& other) const { - float32x4_t r0 = - vreinterpretq_f32_u32(vcgeq_f32(values.val[0], other.values.val[0])); - float32x4_t r1 = - vreinterpretq_f32_u32(vcgeq_f32(values.val[1], other.values.val[1])); - return Vectorized(r0, r1); - } - - Vectorized eq(const Vectorized& other) const; - Vectorized ne(const Vectorized& other) const; - Vectorized gt(const Vectorized& other) const; - Vectorized ge(const Vectorized& other) const; - Vectorized lt(const Vectorized& other) const; - Vectorized le(const Vectorized& other) const; -}; - -template <> -Vectorized inline operator+(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vaddq_f32(a.get_low(), b.get_low()); - float32x4_t r1 = vaddq_f32(a.get_high(), b.get_high()); - return Vectorized(r0, r1); -} - -template <> -Vectorized inline operator-(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vsubq_f32(a.get_low(), b.get_low()); - float32x4_t r1 = vsubq_f32(a.get_high(), b.get_high()); - return Vectorized(r0, r1); -} - -template <> -Vectorized inline operator*(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vmulq_f32(a.get_low(), b.get_low()); - float32x4_t r1 = vmulq_f32(a.get_high(), b.get_high()); - return Vectorized(r0, r1); -} - -template <> -Vectorized inline operator/(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vdivq_f32(a.get_low(), b.get_low()); - float32x4_t r1 = vdivq_f32(a.get_high(), b.get_high()); - return Vectorized(r0, r1); -} - -// frac. Implement this here so we can use subtraction -inline Vectorized Vectorized::frac() const { - return *this - this->trunc(); -} - -//Added sleef Implementation for Maximum -Vectorized inline maximum(const Vectorized& a, const Vectorized& b) { - if(!a.has_inf_nan() && !b.has_inf_nan()){ - return USE_SLEEF( - Vectorized(Sleef_fmaxf4(a.get_low(), b.get_low()),Sleef_fmaxf4(a.get_high(), b.get_high())), - Vectorized(vmaxq_f32(a.get_low(), b.get_low()),vmaxq_f32(a.get_high(), b.get_high()))); - } - else{ - return Vectorized(vmaxq_f32(a.get_low(), b.get_low()),vmaxq_f32(a.get_high(), b.get_high())); - } - } - -// Implements the IEEE 754 201X `minimum` operation, which propagates NaN if -// either input is a NaN. -template <> -Vectorized inline minimum(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vminq_f32(a.get_low(), b.get_low()); - float32x4_t r1 = vminq_f32(a.get_high(), b.get_high()); - return Vectorized(r0, r1); -} - -template <> -Vectorized inline clamp(const Vectorized& a, const Vectorized& min, const Vectorized& max) { - return minimum(max, maximum(min, a)); -} - -template <> -Vectorized inline clamp_max(const Vectorized& a, const Vectorized& max) { - return minimum(max, a); -} - -template <> -Vectorized inline clamp_min(const Vectorized& a, const Vectorized& min) { - return maximum(min, a); -} - -template <> -Vectorized inline operator&(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vreinterpretq_f32_u32(vandq_u32( - vreinterpretq_u32_f32(a.get_low()), - vreinterpretq_u32_f32(b.get_low()))); - float32x4_t r1 = vreinterpretq_f32_u32(vandq_u32( - vreinterpretq_u32_f32(a.get_high()), - vreinterpretq_u32_f32(b.get_high()))); - return Vectorized(r0, r1); -} - -template <> -Vectorized inline operator|(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vreinterpretq_f32_u32(vorrq_u32( - vreinterpretq_u32_f32(a.get_low()), - vreinterpretq_u32_f32(b.get_low()))); - float32x4_t r1 = vreinterpretq_f32_u32(vorrq_u32( - vreinterpretq_u32_f32(a.get_high()), - vreinterpretq_u32_f32(b.get_high()))); - return Vectorized(r0, r1); -} - -template <> -Vectorized inline operator^(const Vectorized& a, const Vectorized& b) { - float32x4_t r0 = vreinterpretq_f32_u32(veorq_u32( - vreinterpretq_u32_f32(a.get_low()), - vreinterpretq_u32_f32(b.get_low()))); - float32x4_t r1 = vreinterpretq_f32_u32(veorq_u32( - vreinterpretq_u32_f32(a.get_high()), - vreinterpretq_u32_f32(b.get_high()))); - return Vectorized(r0, r1); -} - -inline Vectorized Vectorized::eq(const Vectorized& other) const { - return (*this == other) & Vectorized(1.0f); -} - -inline Vectorized Vectorized::ne(const Vectorized& other) const { - return (*this != other) & Vectorized(1.0f); -} - -inline Vectorized Vectorized::gt(const Vectorized& other) const { - return (*this > other) & Vectorized(1.0f); -} - -inline Vectorized Vectorized::ge(const Vectorized& other) const { - return (*this >= other) & Vectorized(1.0f); -} - -inline Vectorized Vectorized::lt(const Vectorized& other) const { - return (*this < other) & Vectorized(1.0f); -} - -inline Vectorized Vectorized::le(const Vectorized& other) const { - return (*this <= other) & Vectorized(1.0f); -} - -template <> -inline void convert(const float* src, int32_t* dst, int64_t n) { - int64_t i; -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { - vst1q_s32(dst + i, vcvtq_s32_f32(vld1q_f32(src + i))); - vst1q_s32(dst + i + 4, vcvtq_s32_f32(vld1q_f32(src + i + 4))); - } -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (; i < n; i++) { - dst[i] = static_cast(src[i]); - } -} - -template <> -inline void convert(const int32_t* src, float* dst, int64_t n) { - int64_t i; -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (i = 0; i <= (n - Vectorized::size()); i += Vectorized::size()) { - vst1q_f32(dst + i, vcvtq_f32_s32(vld1q_s32(src + i))); - vst1q_f32(dst + i + 4, vcvtq_f32_s32(vld1q_s32(src + i + 4))); - } -#ifndef __msvc_cl__ -#pragma unroll -#endif - for (; i < n; i++) { - dst[i] = static_cast(src[i]); - } -} - -template <> -Vectorized inline fmadd(const Vectorized& a, const Vectorized& b, const Vectorized& c) { - float32x4_t r0 = vfmaq_f32(c.get_low(), a.get_low(), b.get_low()); - float32x4_t r1 = vfmaq_f32(c.get_high(), a.get_high(), b.get_high()); - return Vectorized(r0, r1); -} - -template <> -Vectorized inline fmsub(const Vectorized& a, const Vectorized& b, const Vectorized& c) { - float32x4_t r0 = vfmsq_f32(c.get_low(), a.get_low(), b.get_low()); - float32x4_t r1 = vfmsq_f32(c.get_high(), a.get_high(), b.get_high()); - return Vectorized(r0, r1); -} - -inline Vectorized Vectorized::erf() const{ - // constants - const Vectorized neg_zero_vec(-0.f); - const Vectorized one_vec(1.0f); - const Vectorized p(0.3275911f); - const Vectorized p1(0.254829592f); - const Vectorized p2(-0.284496736f); - const Vectorized p3(1.421413741f); - const Vectorized p4(-1.453152027f); - const Vectorized p5(1.061405429f); - // sign(x) - auto sign_mask = neg_zero_vec & *this; - auto abs_vec = this->abs(); - // t = 1 / (p * abs(x) + 1) - auto tmp0 = fmadd(p, abs_vec, one_vec); - auto t = one_vec / tmp0; - // r = p5 * t ^ 4 + p4 * t ^ 3 + p3 * t ^ 2 + p2 * t + p1 - auto tmp1 = fmadd(p5, t, p4); - auto tmp2 = fmadd(tmp1, t, p3); - auto tmp3 = fmadd(tmp2, t, p2); - auto r = fmadd(tmp3, t, p1); - // - exp(- x * x) - auto pow_2 = (*this) * (*this); - auto neg_pow_2 = pow_2 ^ neg_zero_vec; - auto tmp4 = neg_pow_2.map(std::exp); // This can be swapped for a faster implementation of exp. - auto tmp5 = tmp4 ^ neg_zero_vec; - // erf(x) = sign(x) * (1 - r * t * exp(- x * x)) - auto tmp6 = t * tmp5; - auto tmp7 = fmadd(tmp6, r, one_vec); - return tmp7 ^ sign_mask; -} -#endif /* defined(aarch64) */ - -}} // namespace at::vec::CPU_CAPABILITY diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h index a5dcc6dbd9a02..9b900cd0f63ee 100644 --- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h +++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h @@ -258,19 +258,21 @@ __FORCE_INLINE void QuantizeAvx2( template<> struct Vectorized : public Vectorizedqi { using size_type = int; + static constexpr size_type kSize = Vectorized::size(); static constexpr size_type size() { - return 8; + return kSize; } + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); static constexpr int float_num_vecs() { - return 1; + return kFloatNumVecs; } static constexpr int int_num_vecs() { return 1; } - using float_vec_return_type = std::array, 1>; + using float_vec_return_type = std::array, kFloatNumVecs>; using int_vec_return_type = std::array, 1>; using value_type = c10::qint32::underlying; @@ -334,7 +336,7 @@ struct Vectorized : public Vectorizedqi { Vectorized retval; auto rhs_data = (__m256)rhs[0]; at::native::quantize_vec( - scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, 8); + scale, zero_point, (float*)&rhs_data, (c10::qint32*)&retval.vals, size()); return retval; } @@ -447,20 +449,23 @@ __m256i RequantizeAvx2( template<> struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; static constexpr int size() { - return 32; + return kSize; } + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); static constexpr int float_num_vecs() { - return 4; + return kFloatNumVecs; } + static constexpr int kIntNumVecs = kSize / Vectorized::size(); static constexpr int int_num_vecs() { - return 4; + return kIntNumVecs; } - using float_vec_return_type = std::array, 4>; - using int_vec_return_type = std::array, 4>; + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; using value_type = typename c10::qint8::underlying; public: @@ -647,20 +652,23 @@ Vectorized inline maximum(const Vectorized& a, const Vec template<> struct Vectorized : public Vectorizedqi { + static constexpr int kSize = VECTOR_WIDTH; static constexpr int size() { - return 32; + return kSize; } + static constexpr int kFloatNumVecs = kSize / Vectorized::size(); static constexpr int float_num_vecs() { - return 4; + return kFloatNumVecs; } + static constexpr int kIntNumVecs = kSize / Vectorized::size(); static constexpr int int_num_vecs() { - return 4; + return kIntNumVecs; } - using float_vec_return_type = std::array, 4>; - using int_vec_return_type = std::array, 4>; + using float_vec_return_type = std::array, kFloatNumVecs>; + using int_vec_return_type = std::array, kIntNumVecs>; using value_type = typename c10::quint8::underlying; public: @@ -864,11 +872,11 @@ struct VectorizedQuantizedConverter { } static constexpr int float_num_vecs() { - return size() / 8; + return size_ / Vectorized::size(); } static constexpr int int_num_vecs() { - return size() / 8; + return size_ / Vectorized::size(); } using float_vec_return_type = float_vec_return_type_; @@ -897,19 +905,12 @@ struct VectorizedQuantizedConverter { Vectorized /*scale_zp_premul*/) const { float_vec_return_type rv; for (const auto i : c10::irange(float_num_vecs())) { - float tmp_vals[8]; - for (const auto j : c10::irange(8)) { + float tmp_vals[Vectorized::size()]; + for (const auto j : c10::irange(Vectorized::size())) { tmp_vals[j] = at::native::dequantize_val( - scale[j], zero_point[j], T(vals[8 * i + j])); + scale[j], zero_point[j], T(vals[Vectorized::size() * i + j])); } - rv[i] = Vectorized(tmp_vals[0], - tmp_vals[1], - tmp_vals[2], - tmp_vals[3], - tmp_vals[4], - tmp_vals[5], - tmp_vals[6], - tmp_vals[7]); + rv[i] = Vectorized(tmp_vals); } return rv; } @@ -930,25 +931,8 @@ struct Vectorized : public VectorizedQuantizedConverter< c10::qint32, std::array, 1>, std::array, 1>, - 8> { - Vectorized() - : VectorizedQuantizedConverter< - c10::qint32, - std::array, 1>, - std::array, 1>, - 8>() {} - Vectorized(c10::qint32 val) - : VectorizedQuantizedConverter< - c10::qint32, - std::array, 1>, - std::array, 1>, - 8>(val) {} - Vectorized(const void* ptr) - : VectorizedQuantizedConverter< - c10::qint32, - std::array, 1>, - std::array, 1>, - 8>(ptr) {} + Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; static Vectorized loadu(const void* ptr) { return Vectorized(ptr); @@ -973,10 +957,10 @@ struct Vectorized : public VectorizedQuantizedConverter< int32_t zero_point, float /*inverse_scale*/) { std::array qvals; - std::array float_vals; + std::array::size()> float_vals; for (const auto i : c10::irange(float_num_vecs())) { - rhs[i].store(&float_vals[i * 8], 8); + rhs[i].store(&float_vals[i * Vectorized::size()]); } at::native::quantize_vec( @@ -984,7 +968,7 @@ struct Vectorized : public VectorizedQuantizedConverter< zero_point, float_vals.data(), (c10::qint32*)qvals.data(), - 8 * float_num_vecs()); + float_vals.size()); return Vectorized::loadu(qvals.data()); } @@ -1075,25 +1059,8 @@ struct Vectorized : public VectorizedQuantizedConverter< c10::qint8, std::array, 4>, std::array, 4>, - 32> { - Vectorized() - : VectorizedQuantizedConverter< - c10::qint8, - std::array, 4>, - std::array, 4>, - 32>() {} - Vectorized(c10::qint8 val) - : VectorizedQuantizedConverter< - c10::qint8, - std::array, 4>, - std::array, 4>, - 32>(val) {} - Vectorized(const void* ptr) - : VectorizedQuantizedConverter< - c10::qint8, - std::array, 4>, - std::array, 4>, - 32>(ptr) {} + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; static Vectorized loadu(const void* ptr) { return Vectorized(ptr); @@ -1118,10 +1085,10 @@ struct Vectorized : public VectorizedQuantizedConverter< int32_t zero_point, float /*inverse_scale*/) { std::array qvals; - std::array float_vals; + std::array::size()> float_vals; for (const auto i : c10::irange(float_num_vecs())) { - rhs[i].store(&float_vals[i * 8], 8); + rhs[i].store(&float_vals[i * Vectorized::size()]); } at::native::quantize_vec( @@ -1129,7 +1096,7 @@ struct Vectorized : public VectorizedQuantizedConverter< zero_point, float_vals.data(), (c10::qint8*)qvals.data(), - 8 * float_num_vecs()); + float_vals.size()); return Vectorized::loadu(qvals.data()); } @@ -1208,25 +1175,8 @@ struct Vectorized : public VectorizedQuantizedConverter< c10::quint8, std::array, 4>, std::array, 4>, - 32> { - Vectorized() - : VectorizedQuantizedConverter< - c10::quint8, - std::array, 4>, - std::array, 4>, - 32>() {} - Vectorized(c10::quint8 val) - : VectorizedQuantizedConverter< - c10::quint8, - std::array, 4>, - std::array, 4>, - 32>(val) {} - Vectorized(const void* ptr) - : VectorizedQuantizedConverter< - c10::quint8, - std::array, 4>, - std::array, 4>, - 32>(ptr) {} + 4 * Vectorized::size()> { + using VectorizedQuantizedConverter::VectorizedQuantizedConverter; static Vectorized loadu(const void* ptr) { return Vectorized(ptr); @@ -1251,10 +1201,10 @@ struct Vectorized : public VectorizedQuantizedConverter< int32_t zero_point, float /*inverse_scale*/) { std::array qvals; - std::array float_vals; + std::array::size()> float_vals; for (const auto i : c10::irange(float_num_vecs())) { - rhs[i].store(&float_vals[i * 8], 8); + rhs[i].store(&float_vals[i * Vectorized::size()]); } at::native::quantize_vec( @@ -1262,7 +1212,7 @@ struct Vectorized : public VectorizedQuantizedConverter< zero_point, float_vals.data(), (c10::quint8*)qvals.data(), - 8 * float_num_vecs()); + float_vals.size()); return Vectorized::loadu(qvals.data()); } @@ -1339,30 +1289,45 @@ Vectorized inline maximum(const Vectorized& a, const V #endif // if defined(CPU_CAPABILITY_AVX2) -#if defined(CPU_CAPABILITY_NEON) -template -typename std::enable_if_t, at::vec::Vectorized> -inline convert_int8_to_float(at::vec::Vectorized src) { - // Note: this function only convert inputs number of elements equal to at::vec::Vectorized.size() +#if (defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE256)) +std::pair, Vectorized> +inline convert_int8_to_float(at::vec::Vectorized src) { auto s8x8 = vld1_s8(src.operator const int8_t*()); auto s16x8 = vmovl_s8(s8x8); auto s32x4_hi = vmovl_s16(vget_high_s16(s16x8)); auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); - return Vectorized(vcvtq_f32_s32(s32x4_lo), vcvtq_f32_s32(s32x4_hi)); + return std::make_pair(Vectorized(vcvtq_f32_s32(s32x4_lo)), Vectorized(vcvtq_f32_s32(s32x4_hi))); } -template -typename std::enable_if_t, at::vec::Vectorized> -inline convert_int8_to_float(at::vec::Vectorized src) { - // Note: this function only convert inputs number of elements equal to at::vec::Vectorized.size() +std::pair, Vectorized> +inline convert_int8_to_float(at::vec::Vectorized src) { auto u8x8 = vld1_u8(src.operator const uint8_t*()); auto u16x8 = vmovl_u8(u8x8); auto u32x4_hi = vmovl_u16(vget_high_u16(u16x8)); auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); - return Vectorized(vcvtq_f32_u32(u32x4_lo), vcvtq_f32_u32(u32x4_hi)); + return std::make_pair(Vectorized(vcvtq_f32_u32(u32x4_lo)), Vectorized(vcvtq_f32_u32(u32x4_hi))); +} + +Vectorized +inline convert_int8_half_register_to_float(at::vec::Vectorized src) { + auto s8x8 = vld1_s8(src.operator const int8_t*()); + auto s16x8 = vmovl_s8(s8x8); + + auto s32x4_lo = vmovl_s16(vget_low_s16(s16x8)); + + return Vectorized(vcvtq_f32_s32(s32x4_lo)); +} + +Vectorized +inline convert_int8_half_register_to_float(at::vec::Vectorized src) { + auto u8x8 = vld1_u8(src.operator const uint8_t*()); + auto u16x8 = vmovl_u8(u8x8); + auto u32x4_lo = vmovl_u16(vget_low_u16(u16x8)); + + return Vectorized(vcvtq_f32_u32(u32x4_lo)); } #endif diff --git a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h index 4ca57363ee4b4..931da5678437b 100644 --- a/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h +++ b/aten/src/ATen/cpu/vec/vec256/zarch/vec256_zarch.h @@ -1190,8 +1190,8 @@ struct Vectorized()>> { typename U = T, std::enable_if_t::value, int> = 0> Vectorized swapped() const { - vtype v0 = vec_permi(_vec0, _vec0, 2); - vtype v1 = vec_permi(_vec1, _vec1, 2); + vtype v0 = {_vec0[1], _vec0[0]}; + vtype v1 = {_vec1[1], _vec1[0]}; return {v0, v1}; } @@ -1685,6 +1685,7 @@ std::pair, Vectorized> unpack(const Vectorized& x) { return {Vectorized{vec0, vec1}, Vectorized{vec2, vec3}}; } +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") template <> std::pair, Vectorized> unpack( const Vectorized& x) { @@ -1702,6 +1703,7 @@ std::pair, Vectorized> unpack( cast_zvector(Vectorized{vec0, vec1}), cast_zvector(Vectorized{vec2, vec3})}; } +C10_DIAGNOSTIC_POP() template ::type> Vectorized pack(const Vectorized& first, const Vectorized& second) { @@ -1710,6 +1712,7 @@ Vectorized pack(const Vectorized& first, const Vectorized& second) { return Vectorized{vec0, vec1}; } +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") template <> Vectorized pack( const Vectorized& first, @@ -1718,6 +1721,7 @@ Vectorized pack( auto vec1 = vec_packsu(second.vec0(), second.vec1()); return Vectorized{vec0, vec1}; } +C10_DIAGNOSTIC_POP() } /* unnamed namespace */ @@ -1735,7 +1739,7 @@ struct Vectorized()>> { return VECTOR_WIDTH / sizeof(value_type); } - static constexpr size_t float_num_vecs() { + static constexpr int float_num_vecs() { return size() / Vectorized::size(); } static constexpr int int_num_vecs() { @@ -2419,8 +2423,8 @@ struct Vectorized()>> { static typename Vectorized::vinner_type real_neg(const typename Vectorized::vinner_type &a) { auto a_neg = a.neg(); - auto v0 = vec_permi(a_neg.vec0(), a.vec0(), 1); - auto v1 = vec_permi(a_neg.vec1(), a.vec1(), 1); + vtype v0 = {a_neg.vec0()[0], a.vec0()[1]}; + vtype v1 = {a_neg.vec1()[0], a.vec1()[1]}; return { v0, v1 }; } @@ -2732,10 +2736,10 @@ std::pair, Vectorized> inline inner_interleave2( // a = {a0, a1, a2, a3} // b = {b0, b1, b2, b3} using vtype = typename Vectorized::vtype; - vtype ab00 = vec_permi(a.vec0(), b.vec0(), 0); - vtype ab11 = vec_permi(a.vec0(), b.vec0(), 3); - vtype ab2_00 = vec_permi(a.vec1(), b.vec1(), 0); - vtype ab2_11 = vec_permi(a.vec1(), b.vec1(), 3); + vtype ab00 = {a.vec0()[0], b.vec0()[0]}; + vtype ab11 = {a.vec0()[1], b.vec0()[1]}; + vtype ab2_00 = {a.vec1()[0], b.vec1()[0]}; + vtype ab2_11 = {a.vec1()[1], b.vec1()[1]}; // return {a0, b0, a1, b1} // {a2, b2, a3, b3} return std::make_pair( @@ -2750,11 +2754,11 @@ std::pair, Vectorized> inline inner_deinterleave2( // a = {a0, b0, a1, b1} // b = {a2, b2, a3, b3} using vtype = typename Vectorized::vtype; - vtype aa01 = vec_permi(a.vec0(), a.vec1(), 0); - vtype aa23 = vec_permi(b.vec0(), b.vec1(), 0); + vtype aa01 = {a.vec0()[0], a.vec1()[0]}; + vtype aa23 = {b.vec0()[0], b.vec1()[0]}; - vtype bb_01 = vec_permi(a.vec0(), a.vec1(), 3); - vtype bb_23 = vec_permi(b.vec0(), b.vec1(), 3); + vtype bb_01 = {a.vec0()[1], a.vec1()[1]}; + vtype bb_23 = {b.vec0()[1], b.vec1()[1]}; // swap lanes: // return {a0, a1, a2, a3} @@ -2868,7 +2872,7 @@ std::pair, Vectorized> inline deinterleave2< } template -typename std::enable_if::value, at::vec::Vectorized>::type +std::enable_if_t, at::vec::Vectorized> inline convert_int8_to_float(const Vectorized &src) { // Note: this function only convert inputs number of elements equal to at::vec::Vectorized.size() // Only handle first 64 bits @@ -2878,7 +2882,7 @@ inline convert_int8_to_float(const Vectorized &src) { } template -typename std::enable_if::value, at::vec::Vectorized>::type +std::enable_if_t, at::vec::Vectorized> inline convert_float_to_int8(const Vectorized &src) { constexpr auto min_val = std::numeric_limits::min(); constexpr auto max_val = std::numeric_limits::max(); diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h index cfb4ddb13732c..af4801cccf488 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_convert.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_convert.h @@ -281,9 +281,9 @@ struct VecConvert< 1, int64_t, 2, - typename std::enable_if< + std::enable_if_t< std::is_same_v || - std::is_same_v>::type> { + std::is_same_v>> { static inline VectorizedN apply( const VectorizedN& src) { return VecConvert::apply( diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h index 4e21eae91cb24..843e2dfcb8795 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h @@ -40,6 +40,9 @@ template <> class Vectorized { values = _mm512_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8, val9, val10, val11, val12, val13, val14, val15, val16); } + Vectorized(const float (&arr)[16]) + : Vectorized(arr[0], arr[1], arr[2], arr[3], arr[4], arr[5], arr[6], arr[7], + arr[8], arr[9], arr[10], arr[11], arr[12], arr[13], arr[14], arr[15]) {} operator __m512() const { return values; } @@ -236,27 +239,27 @@ template <> class Vectorized { } Vectorized exp_u20() const { // A faster version of exp with ULP=20 - static __m512 vec_factorial_1 = + const __m512 vec_factorial_1 = _mm512_set1_ps(0.999999701f); // 1/factorial(1) - static __m512 vec_factorial_2 = + const __m512 vec_factorial_2 = _mm512_set1_ps(0.499991506f); // 1/factorial(2) - static __m512 vec_factorial_3 = + const __m512 vec_factorial_3 = _mm512_set1_ps(0.166676521f); // 1/factorial(3) - static __m512 vec_factorial_4 = + const __m512 vec_factorial_4 = _mm512_set1_ps(0.0418978221f); // 1/factorial(4) - static __m512 vec_factorial_5 = + const __m512 vec_factorial_5 = _mm512_set1_ps(0.00828929059f); // 1/factorial(5) - static __m512 vec_exp_log2ef = + const __m512 vec_exp_log2ef = _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e) - static __m512 vec_half = _mm512_set1_ps(0.5f); - static __m512 vec_one = _mm512_set1_ps(1.f); - static __m512 vec_zero = _mm512_set1_ps(0.f); - static __m512 vec_two = _mm512_set1_ps(2.f); - static __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) - static __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); - static __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); - static __m512i vec_127 = _mm512_set1_epi32(0x0000007f); - static int n_mantissa_bits = 23; + const __m512 vec_half = _mm512_set1_ps(0.5f); + const __m512 vec_one = _mm512_set1_ps(1.f); + const __m512 vec_zero = _mm512_set1_ps(0.f); + const __m512 vec_two = _mm512_set1_ps(2.f); + const __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2) + const __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50)); + const __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218)); + const __m512i vec_127 = _mm512_set1_epi32(0x0000007f); + const int n_mantissa_bits = 23; // exp(x) = // = exp(n * ln(2) + r) // divide x by ln(2) and get quot and rem diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_mask.h b/aten/src/ATen/cpu/vec/vec512/vec512_mask.h index cdb433af25254..d32e1da1cf72c 100644 --- a/aten/src/ATen/cpu/vec/vec512/vec512_mask.h +++ b/aten/src/ATen/cpu/vec/vec512/vec512_mask.h @@ -84,9 +84,9 @@ struct VecMaskLoad< dst_n, mask_t, dst_n, - typename std::enable_if< + std::enable_if_t< std::is_same_v || - std::is_same_v>::type> { + std::is_same_v>> { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { @@ -151,9 +151,9 @@ struct VecMaskLoad< 1, mask_t, 1, - typename std::enable_if< + std::enable_if_t< std::is_same_v || - std::is_same_v>::type> { + std::is_same_v>> { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { @@ -173,9 +173,9 @@ struct VecMaskLoad< 2, mask_t, 1, - typename std::enable_if< + std::enable_if_t< std::is_same_v || - std::is_same_v>::type> { + std::is_same_v>> { static inline VectorizedN apply( const data_t* ptr, const VecMask& vec_mask) { diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h index ba7865cb522f2..2b29caf5edd61 100644 --- a/aten/src/ATen/cpu/vec/vec_base.h +++ b/aten/src/ATen/cpu/vec/vec_base.h @@ -62,6 +62,16 @@ Windows llvm will not have this defination. #endif #define VECTOR_WIDTH 64 #define int_vector __m512i +#elif defined(__aarch64__) && !defined(CPU_CAPABILITY_SVE) // CPU_CAPABILITY_AVX512 +// SVE code expects 256-vectors; leave that set for SVE? +#if defined(__GNUC__) +#define __at_align__ __attribute__((aligned(16))) +#elif defined(_WIN32) +#define __at_align__ __declspec(align(16)) +#else +#define __at_align__ +#endif +#define VECTOR_WIDTH 16 #else // CPU_CAPABILITY_AVX512 #if defined(__GNUC__) #define __at_align__ __attribute__((aligned(32))) @@ -138,40 +148,10 @@ struct Vectorized { public: using value_type = T; using size_type = int; - // Note [constexpr static function to avoid odr-usage compiler bug] - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - // Why, you might ask, is size defined to be a static constexpr function, - // rather than a more ordinary 'static constexpr int size;' variable? - // The problem lies within ODR rules for static constexpr members versus - // static constexpr functions. First, recall that this class (along with all - // of its derivations) live in an anonymous namespace: they are intended to be - // *completely* inlined at their use-sites, because we need to compile it - // multiple times for different instruction sets. - // - // Because of this constraint, we CANNOT provide a single definition for - // any static members in this class; since we want to compile the class - // multiple times, there wouldn't actually be any good place to put the - // definition. Now here is the problem: if we ODR-use a static constexpr - // member, we are *obligated* to provide a definition. Without the - // definition, you get a compile error like: - // - // relocation R_X86_64_PC32 against undefined symbol - // `_ZN2at6vec25612_GLOBAL__N_16VectorizedIdE4sizeE' can not be used when making - // a shared object; recompile with -fPIC - // - // If this were C++17, we could replace a static constexpr variable with - // an inline variable which doesn't require one definition. But we are not - // C++17. So the next best thing is to replace the member with a static - // constexpr (and therefore inline) function, which does not require ODR - // either. - // - // Also, technically according to the C++ standard, we don't have to define - // a constexpr variable if we never odr-use it. But it seems that some - // versions GCC/Clang have buggy determinations on whether or not an - // identifier is odr-used or not, and in any case it's hard to tell if - // a variable is odr-used or not. So best to just cut the problem at the root. + + static constexpr size_type kSize = VECTOR_WIDTH / sizeof(T); static constexpr size_type size() { - return VECTOR_WIDTH / sizeof(T); + return kSize; } Vectorized() : values{static_cast(0)} {} Vectorized(T val) { @@ -183,6 +163,9 @@ struct Vectorized { typename = std::enable_if_t<(sizeof...(Args) == size())>> Vectorized(Args... vals) : values{vals...}{ } + Vectorized(const T(&arr)[kSize]) { + std::memcpy(values, arr, sizeof(values)); + } // This also implies const T& operator[](int idx) const inline operator const T*() const { return values; @@ -209,8 +192,13 @@ struct Vectorized { } return vector; } - static Vectorized blendv(const Vectorized& a, const Vectorized& b, - const Vectorized& mask) { +// Workaround for https: //gcc.gnu.org/bugzilla/show_bug.cgi?id=117001 +#if __GNUC__ <= 12 && defined(__ARM_FEATURE_SVE) + static Vectorized __attribute__ ((optimize("-fno-tree-loop-vectorize"))) blendv(const Vectorized& a, +#else + static Vectorized blendv(const Vectorized& a, +#endif + const Vectorized& b, const Vectorized& mask) { Vectorized vector; int_same_size_t buffer[size()]; mask.store(buffer); @@ -290,6 +278,19 @@ struct Vectorized { } return false; } +// TODO: Remove this once the issue with MSVC is fixed +// See https://developercommunity.visualstudio.com/t/MSVC-loop-unrolling-problem-194033813-/10720692 +#if defined(_WIN32) && defined(__aarch64__) + Vectorized map(T (*const f)(T)) const { + Vectorized ret; + for (int64_t i = 0; i < size(); i++) { + ret[i] = f(values[i]); + if (++i < size()) + ret[i] = f(values[i]); + } + return ret; + } +#else Vectorized map(T (*const f)(T)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { @@ -297,6 +298,7 @@ struct Vectorized { } return ret; } +#endif Vectorized map(T (*const f)(const T &)) const { Vectorized ret; for (int64_t i = 0; i != size(); i++) { @@ -1116,7 +1118,7 @@ inline void convert(const src_T *src, dst_T *dst, int64_t n) { #ifndef _MSC_VER # pragma unroll #endif - for (C10_UNUSED const auto i : c10::irange(n)) { + for ([[maybe_unused]] const auto i : c10::irange(n)) { *dst = c10::convert(c10::load(src)); src++; dst++; diff --git a/aten/src/ATen/cpu/vec/vec_mask.h b/aten/src/ATen/cpu/vec/vec_mask.h index a39ffa3090b8e..c547e5911ecbd 100644 --- a/aten/src/ATen/cpu/vec/vec_mask.h +++ b/aten/src/ATen/cpu/vec/vec_mask.h @@ -279,6 +279,7 @@ VEC_MASK_DEFINE_UNARY_OP_GLOBAL(operator~) VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator&) VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator|) VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator^) +VEC_MASK_DEFINE_BINARY_OP_GLOBAL(operator*) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator>, a & ~b) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator<, ~a& b) VEC_MASK_DEFINE_BINARY_OP_WITH_EXPR_GLOBAL(operator==, ~(a ^ b)) diff --git a/aten/src/ATen/cpu/vec/vec_n.h b/aten/src/ATen/cpu/vec/vec_n.h index 8c4e622682a28..ec17ab0e45e51 100644 --- a/aten/src/ATen/cpu/vec/vec_n.h +++ b/aten/src/ATen/cpu/vec/vec_n.h @@ -77,6 +77,21 @@ class VectorizedN { return result; } + template + inline VectorizedN ternary_op( + const VectorizedN& other, + const VectorizedN& other2, + Op op) const { + VectorizedN result; +#ifndef _MSC_VER +#pragma unroll +#endif + for (int i = 0; i < N; ++i) { + result.values[i] = op(values[i], other.values[i], other2.values[i]); + } + return result; + } + VectorizedN() = default; explicit VectorizedN(T val) { @@ -89,7 +104,8 @@ class VectorizedN { VectorizedN(const Vectorized& val) : values({val}) {} template = 0> - VectorizedN(const Vectorized& val_0, const Vectorized& val_1) : values({val_0, val_1}) {} + VectorizedN(const Vectorized& val_0, const Vectorized& val_1) + : values({val_0, val_1}) {} template = 0> inline operator Vectorized() const { @@ -110,7 +126,8 @@ class VectorizedN { const VectorizedN& b) { VectorizedN result; for (int i = 0; i < N; ++i) { - result.values[i] = Vectorized::template blend(a.values[i], b.values[i]); + result.values[i] = + Vectorized::template blend(a.values[i], b.values[i]); } return result; } @@ -306,6 +323,20 @@ class VectorizedN { }); \ } +#define VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(op) \ + template \ + inline VectorizedN op( \ + const VectorizedN& a, \ + const VectorizedN& b, \ + const VectorizedN& c) { \ + return a.ternary_op( \ + b, \ + c, \ + [](const Vectorized& a, \ + const Vectorized& b, \ + const Vectorized& c) { return op(a, b, c); }); \ + } + #define VECTORIZEDN_DEFINE_BINARY_OP_INPLACE_GLOBAL(op) \ template \ inline VectorizedN& op( \ @@ -326,9 +357,9 @@ VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator<<) VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator>>) VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(maximum) VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(minimum) -VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmadd) -VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(fmsub) -VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmadd) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(fmsub) +VECTORIZEDN_DEFINE_TERNARY_OP_GLOBAL(clamp) VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_max) VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(clamp_min) VECTORIZEDN_DEFINE_BINARY_OP_GLOBAL(operator&) @@ -357,5 +388,17 @@ inline T vec_reduce_all(const OpVec& vec_fun, VectorizedN acc_vec) { return vec_reduce_all(vec_fun, vec_result); } +template +std::ostream& operator<<(std::ostream& stream, const VectorizedN& vec_n) { + stream << "vec_n["; + for (int i = 0; i < N; ++i) { + if (i != 0) { + stream << ", "; + } + stream << vec_n[i]; + } + stream << ']'; + return stream; +} } // namespace CPU_CAPABILITY } // namespace at::vec diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index 9b3fd5dc6e4dd..bb0ff0917859f 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #ifdef USE_ROCM @@ -18,6 +19,7 @@ // until hipblas has an API to accept flags, we must use rocblas here #include #include +#include #define PYTORCH_ROCBLAS_VERSION_DECIMAL (ROCBLAS_VERSION_MAJOR * 100 + ROCBLAS_VERSION_MINOR) #define USE_GEMM_FLAGS_FP16_ALT_IMPL (PYTORCH_ROCBLAS_VERSION_DECIMAL >= 242) // needed to work around calling rocblas API instead of hipblas API @@ -32,7 +34,7 @@ static rocblas_operation hipOperationToRocOperation(hipblasOperation_t op) case HIPBLAS_OP_C: return rocblas_operation_conjugate_transpose; } - AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM"); + TORCH_CHECK(false, "HIPBLAS_STATUS_INVALID_ENUM"); } static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) { @@ -55,7 +57,7 @@ static hipblasStatus_t rocBLASStatusToHIPStatus(rocblas_status error) case rocblas_status_internal_error: return HIPBLAS_STATUS_INTERNAL_ERROR; } - AT_ERROR("HIPBLAS_STATUS_INVALID_ENUM"); + TORCH_CHECK(false, "HIPBLAS_STATUS_INVALID_ENUM"); } // hipblas does not have hipblasSetMathMode #define hipblasSetMathMode(handle, flags) HIPBLAS_STATUS_SUCCESS @@ -114,7 +116,7 @@ static cublasOperation_t _cublasOpFromChar(char op) { case 'C': return CUBLAS_OP_C; } - AT_ERROR( + TORCH_CHECK(false, "_cublasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); } @@ -180,17 +182,17 @@ uint32_t _getAlignment(uintptr_t address) { #endif static size_t _parseChosenWorkspaceSize() { - const char * val = getenv("CUBLASLT_WORKSPACE_SIZE"); + auto val = c10::utils::get_env("CUBLASLT_WORKSPACE_SIZE"); #ifdef USE_ROCM - if (!val) { + if (!val.has_value()) { // accept either env var - val = getenv("HIPBLASLT_WORKSPACE_SIZE"); + val = c10::utils::get_env("HIPBLASLT_WORKSPACE_SIZE"); } #endif size_t workspace_size = 1024; /* default size in KiB according to #73328 */ - if (val) { + if (val.has_value()) { try { - workspace_size = std::stoi(val); + workspace_size = std::stoi(val.value()); } catch(std::invalid_argument const& e) { TORCH_WARN("invalid CUBLASLT_WORKSPACE_SIZE,", " using default workspace size of ", workspace_size, " KiB."); @@ -792,6 +794,7 @@ inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) { static_assert(false && sizeof(Dtype), "at::cuda::blas::gemm_internal_cublas: not implemented"); } + template <> void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(double)) { // See Note [Writing Nondeterministic Operations] @@ -1000,6 +1003,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(double)); #endif } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(double)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(double)); } @@ -1011,6 +1019,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(float)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(float)); } @@ -1054,6 +1067,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::Half)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::Half)); } @@ -1065,6 +1083,11 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); } +#ifdef USE_ROCM + else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) { + at::native::gemm_internal_ck(CUDABLAS_GEMM_ARGS(at::BFloat16)); + } +#endif else { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::BFloat16)); } diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h index e6f0c5a9a373b..989dd34633e73 100644 --- a/aten/src/ATen/cuda/CUDABlas.h +++ b/aten/src/ATen/cuda/CUDABlas.h @@ -34,7 +34,7 @@ class PointerModeGuard { private: cublasHandle_t handle; - cublasPointerMode_t previous_mode; + cublasPointerMode_t previous_mode{}; }; /* LEVEL 3 BLAS FUNCTIONS */ diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp index 861c9f634e261..6505fcfdd077d 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.cpp @@ -31,7 +31,7 @@ static std::vector default_gens_cuda; * Warning: this function must only be called once! */ static void initCUDAGenVector() { - num_gpus = c10::cuda::device_count(); + num_gpus = static_cast(c10::cuda::device_count()); cuda_gens_init_flag.resize(num_gpus); default_gens_cuda.resize(num_gpus); } diff --git a/aten/src/ATen/cuda/CUDAGeneratorImpl.h b/aten/src/ATen/cuda/CUDAGeneratorImpl.h index 0fe664e35f54c..b0b77cb822a85 100644 --- a/aten/src/ATen/cuda/CUDAGeneratorImpl.h +++ b/aten/src/ATen/cuda/CUDAGeneratorImpl.h @@ -5,7 +5,6 @@ #include #include #include -#include #include #include namespace at { @@ -168,7 +167,7 @@ struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl { CUDAGeneratorImpl* clone_impl() const override; c10::intrusive_ptr state_; - std::atomic_flag no_reset_rnn_state_; + std::atomic_flag no_reset_rnn_state_{}; }; namespace cuda::detail { diff --git a/aten/src/ATen/cuda/CUDAGraph.cpp b/aten/src/ATen/cuda/CUDAGraph.cpp index aa7ed65ff2093..34067a3197e59 100644 --- a/aten/src/ATen/cuda/CUDAGraph.cpp +++ b/aten/src/ATen/cuda/CUDAGraph.cpp @@ -7,9 +7,7 @@ #include #include -#include #include -#include namespace at::cuda { @@ -19,8 +17,7 @@ constexpr int kSynchronizeBusyWaitMillis = 10; MempoolId_t graph_pool_handle() { // Sets just the second value, to distinguish it from MempoolId_ts created from // cudaStreamGetCaptureInfo id_s in capture_begin. - auto new_pool = c10::cuda::MemPool(); - return new_pool.id(); + return c10::cuda::MemPool::graph_pool_handle(); } /** @@ -115,8 +112,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt } else { // User did not ask us to share a mempool. Create graph pool handle using is_user_created=false. // Sets just the first value, to distinguish it from MempoolId_ts created by graph_pool_handle(). - auto mempool = c10::cuda::MemPool({}, false); - mempool_id_ = mempool.id(); + mempool_id_ = c10::cuda::MemPool::graph_pool_handle(false); TORCH_INTERNAL_ASSERT(mempool_id_.first > 0); } @@ -124,8 +120,8 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt // autograd thread's free() call triggering an invalid cudaEventRecord in the caching allocator // due to the capture status being updated _after_ a capture had already started. c10::cuda::CUDACachingAllocator::beginAllocateToPool(capture_dev_, mempool_id_, [this](cudaStream_t stream) { - cudaStreamCaptureStatus status; - CaptureId_t stream_capture_id; + cudaStreamCaptureStatus status{}; + CaptureId_t stream_capture_id = 0; AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &stream_capture_id)); return status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive && stream_capture_id == capture_id_; }); @@ -144,7 +140,7 @@ void CUDAGraph::capture_begin(MempoolId_t pool/*=0*/, cudaStreamCaptureMode capt // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__STREAM.html#group__CUDART__STREAM_1g9d0535d93a214cbf126835257b16ba85 AT_CUDA_CHECK(cudaStreamBeginCapture(capture_stream_, capture_mode)); - cudaStreamCaptureStatus status; + cudaStreamCaptureStatus status{}; AT_CUDA_CHECK(cudaStreamGetCaptureInfo(stream, &status, &capture_id_)); TORCH_INTERNAL_ASSERT(status == cudaStreamCaptureStatus::cudaStreamCaptureStatusActive); diff --git a/aten/src/ATen/cuda/EmptyTensor.cpp b/aten/src/ATen/cuda/EmptyTensor.cpp index ad4f854a05ccc..108b7be47de17 100644 --- a/aten/src/ATen/cuda/EmptyTensor.cpp +++ b/aten/src/ATen/cuda/EmptyTensor.cpp @@ -10,7 +10,7 @@ TensorBase empty_cuda( ScalarType dtype, std::optional device_opt, std::optional memory_format_opt) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); const auto device = device_or_default(device_opt); TORCH_INTERNAL_ASSERT(device.is_cuda()); const DeviceGuard device_guard(device); @@ -50,7 +50,7 @@ TensorBase empty_strided_cuda( IntArrayRef stride, ScalarType dtype, std::optional device_opt) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); const auto device = device_or_default(device_opt); TORCH_INTERNAL_ASSERT(device.is_cuda()); const DeviceGuard device_guard(device); diff --git a/aten/src/ATen/cuda/Exceptions.h b/aten/src/ATen/cuda/Exceptions.h index 47d64e2bf3126..7387224f7ab81 100644 --- a/aten/src/ATen/cuda/Exceptions.h +++ b/aten/src/ATen/cuda/Exceptions.h @@ -157,18 +157,19 @@ constexpr const char* _cusolver_backend_suggestion = \ // See NOTE [ USE OF NVRTC AND DRIVER API ]. #if !defined(USE_ROCM) -#define AT_CUDA_DRIVER_CHECK(EXPR) \ - do { \ - CUresult __err = EXPR; \ - if (__err != CUDA_SUCCESS) { \ - const char* err_str; \ - CUresult get_error_str_err C10_UNUSED = at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \ - if (get_error_str_err != CUDA_SUCCESS) { \ - AT_ERROR("CUDA driver error: unknown error"); \ - } else { \ - AT_ERROR("CUDA driver error: ", err_str); \ - } \ - } \ +#define AT_CUDA_DRIVER_CHECK(EXPR) \ + do { \ + CUresult __err = EXPR; \ + if (__err != CUDA_SUCCESS) { \ + const char* err_str; \ + [[maybe_unused]] CUresult get_error_str_err = \ + at::globalContext().getNVRTC().cuGetErrorString(__err, &err_str); \ + if (get_error_str_err != CUDA_SUCCESS) { \ + TORCH_CHECK(false, "CUDA driver error: unknown error"); \ + } else { \ + TORCH_CHECK(false, "CUDA driver error: ", err_str); \ + } \ + } \ } while (0) #else @@ -177,7 +178,7 @@ constexpr const char* _cusolver_backend_suggestion = \ do { \ CUresult __err = EXPR; \ if (__err != CUDA_SUCCESS) { \ - AT_ERROR("CUDA driver error: ", static_cast(__err)); \ + TORCH_CHECK(false, "CUDA driver error: ", static_cast(__err)); \ } \ } while (0) @@ -197,9 +198,9 @@ constexpr const char* _cusolver_backend_suggestion = \ nvrtcResult __err = EXPR; \ if (__err != NVRTC_SUCCESS) { \ if (static_cast(__err) != 7) { \ - AT_ERROR("CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \ + TORCH_CHECK(false, "CUDA NVRTC error: ", at::globalContext().getNVRTC().nvrtcGetErrorString(__err)); \ } else { \ - AT_ERROR("CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \ + TORCH_CHECK(false, "CUDA NVRTC error: NVRTC_ERROR_BUILTIN_OPERATION_FAILURE"); \ } \ } \ } while (0) diff --git a/aten/src/ATen/cuda/PeerToPeerAccess.cpp b/aten/src/ATen/cuda/PeerToPeerAccess.cpp index e9ce2d9d3a604..e56d2f3ee229d 100644 --- a/aten/src/ATen/cuda/PeerToPeerAccess.cpp +++ b/aten/src/ATen/cuda/PeerToPeerAccess.cpp @@ -34,7 +34,7 @@ void init_p2p_access_cache(int64_t num_devices) { } // namespace detail bool get_p2p_access(int dev, int dev_to_access) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); TORCH_CHECK(dev >= 0 || dev < num_devices_, dev, " is not a device"); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 0bdd865d88d25..d5b4c3ae62b41 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -39,7 +39,6 @@ #include #include -#include #include namespace c10::cuda::_internal { @@ -61,7 +60,7 @@ namespace { bool _hasPrimaryContext(DeviceIndex device_index) { TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(), "hasPrimaryContext expects a valid device index, but got device_index=", device_index); - unsigned int ctx_flags; + unsigned int ctx_flags = 0; // In standalone tests of cuDevicePrimaryCtxGetState, I've seen the "active" argument end up with weird // (garbage-looking nonzero) values when the context is not active, unless I initialize it to zero. int ctx_is_active = 0; @@ -84,7 +83,7 @@ struct _Initializer { // NB: deleter is dynamic, because we need it to live in a separate // compilation unit (alt is to have another method in hooks, but // let's not if we don't need to!) -void CUDAHooks::initCUDA() const { +void CUDAHooks::init() const { C10_LOG_API_USAGE_ONCE("aten.init.cuda"); // Force the update to enable unit testing. This code get executed before unit tests // have a chance to enable vitals. @@ -124,7 +123,7 @@ bool CUDAHooks::isPinnedPtr(const void* data) const { if (primary_ctx_device_index.has_value()) { device_guard.reset_device(at::Device(at::DeviceType::CUDA, *primary_ctx_device_index)); } - cudaPointerAttributes attr; + cudaPointerAttributes attr{}; // We do not believe that CUDA needs mutable access to the data // here. cudaError_t err = cudaPointerGetAttributes(&attr, data); @@ -300,7 +299,7 @@ long CUDAHooks::versionCuDNN() const { #if AT_CUDNN_ENABLED() return CUDNN_VERSION; #else - AT_ERROR("Cannot query CuDNN version if ATen_cuda is not built with CuDNN"); + TORCH_CHECK(false, "Cannot query CuDNN version if ATen_cuda is not built with CuDNN"); #endif } @@ -325,10 +324,10 @@ bool CUDAHooks::hasCUDART() const { std::string CUDAHooks::showConfig() const { std::ostringstream oss; - int runtimeVersion; + int runtimeVersion = 0; cudaRuntimeGetVersion(&runtimeVersion); - auto printCudaStyleVersion = [&](int v) { + auto printCudaStyleVersion = [&](size_t v) { #ifdef USE_ROCM // HIP_VERSION value format was changed after ROCm v4.2 to include the patch number if(v < 500) { @@ -369,7 +368,7 @@ std::string CUDAHooks::showConfig() const { #if AT_CUDNN_ENABLED() - auto printCudnnStyleVersion = [&](int v) { + auto printCudnnStyleVersion = [&](size_t v) { oss << (v / 1000) << "." << (v / 100 % 10); if (v % 100 != 0) { oss << "." << (v % 100); @@ -408,7 +407,7 @@ double CUDAHooks::batchnormMinEpsilonCuDNN() const { #if AT_CUDNN_ENABLED() return CUDNN_BN_MIN_EPSILON; #else - AT_ERROR( + TORCH_CHECK(false, "Cannot query CUDNN_BN_MIN_EPSILON if ATen_cuda is not built with CuDNN"); #endif } diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.h b/aten/src/ATen/cuda/detail/CUDAHooks.h index 11401701e44c0..2dbc336778c35 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.h +++ b/aten/src/ATen/cuda/detail/CUDAHooks.h @@ -3,7 +3,6 @@ #include #include -#include // TODO: No need to have this whole header, we can just put it all in // the cpp file @@ -19,7 +18,7 @@ TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)()); // The real implementation of CUDAHooksInterface struct CUDAHooks : public at::CUDAHooksInterface { CUDAHooks(at::CUDAHooksArgs) {} - void initCUDA() const override; + void init() const override; Device getDeviceFromPtr(void* data) const override; bool isPinnedPtr(const void* data) const override; const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override; diff --git a/aten/src/ATen/cuda/detail/IndexUtils.cu b/aten/src/ATen/cuda/detail/IndexUtils.cu index fda742f5cdfc2..9207b577f9443 100644 --- a/aten/src/ATen/cuda/detail/IndexUtils.cu +++ b/aten/src/ATen/cuda/detail/IndexUtils.cu @@ -37,7 +37,7 @@ within the next one. bool maybeOverlappingIndices(const TensorBase& t) { /* Extract size/stride arrays; only consider size >1 dims. */ std::vector info(t.dim()); - int dims = t.dim(); + auto dims = t.dim(); int nonSize1Dims = 0; for (int i = 0; i < dims; ++i) { int64_t size = t.size(i); diff --git a/aten/src/ATen/cuda/jiterator.cu b/aten/src/ATen/cuda/jiterator.cu index db751e33c43d2..6474395953351 100644 --- a/aten/src/ATen/cuda/jiterator.cu +++ b/aten/src/ATen/cuda/jiterator.cu @@ -8,7 +8,6 @@ #include #include -#include namespace at { namespace native { diff --git a/aten/src/ATen/cuda/tunable/GemmCommon.h b/aten/src/ATen/cuda/tunable/GemmCommon.h index c3a171e8d9251..da6dcc51e6661 100644 --- a/aten/src/ATen/cuda/tunable/GemmCommon.h +++ b/aten/src/ATen/cuda/tunable/GemmCommon.h @@ -22,6 +22,7 @@ #include #include #endif +#include namespace at::cuda::tunable { @@ -30,15 +31,15 @@ enum class BlasOp { T = 1 }; -inline std::string BlasOpToString(BlasOp op) { +inline char BlasOpToString(BlasOp op) { switch (op) { case BlasOp::N: - return "N"; + return 'N'; case BlasOp::T: - return "T"; + return 'T'; } TORCH_CHECK(false, "unrecognized BlasOp"); - return "N"; + return 'N'; } namespace detail { @@ -81,7 +82,7 @@ struct GemmParams : OpParams { } std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k); + return fmt::sprintf("%c%c_%ld_%ld_%ld", transa, transb, m, n, k); } size_t GetSizeA() const { @@ -158,7 +159,7 @@ struct GemmParams : OpParams { template struct GemmAndBiasParams : OpParams { std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k); + return fmt::sprintf("%c%c_%ld_%ld_%ld", transa, transb, m, n, k); } size_t GetSize(bool duplicate_inputs) const { @@ -228,7 +229,7 @@ struct GemmStridedBatchedParams : OpParams { } std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch); + return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld", transa, transb, m, n, k, batch); } size_t GetSizeA() const { @@ -313,7 +314,7 @@ struct ScaledGemmParams : OpParams { } std::string Signature() const override { - return c10::str(transa, transb, "_", m, "_", n, "_", k); + return fmt::sprintf("%c%c_%ld_%ld_%ld", transa, transb, m, n, k); } size_t GetSizeA() const { diff --git a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h index 483b4fb7a91a0..aba653f76db8e 100644 --- a/aten/src/ATen/cuda/tunable/GemmHipblaslt.h +++ b/aten/src/ATen/cuda/tunable/GemmHipblaslt.h @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -309,7 +310,7 @@ static hipblasOperation_t _hipblasOpFromChar(char op) { case 'C': return HIPBLAS_OP_C; } - AT_ERROR( + TORCH_CHECK(false, "_hipblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); } @@ -322,7 +323,7 @@ static char _charFromhipblasOp(hipblasOperation_t op) { case HIPBLAS_OP_C: return 'C'; } - AT_ERROR( + TORCH_CHECK(false, "_charFromhipblasOp input should be HIPBLAS_OP_N/T/C but got `", op, "`"); } @@ -578,8 +579,7 @@ auto GetHipBlasLtTypeStringAndOps() { auto algo = heuristic_result[i].algo; int algo_index = hipblaslt_ext::getIndexFromAlgo(algo); auto callable = std::make_unique>(algo); - std::string type_string = c10::str( - "Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index); + std::string type_string = fmt::sprintf("Gemm_Hipblaslt_%c%c_%d", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), algo_index); ret.emplace_back(type_string, std::move(callable)); } diff --git a/aten/src/ATen/cuda/tunable/GemmRocblas.h b/aten/src/ATen/cuda/tunable/GemmRocblas.h index f096ff00fd9b4..026836fc73ccd 100644 --- a/aten/src/ATen/cuda/tunable/GemmRocblas.h +++ b/aten/src/ATen/cuda/tunable/GemmRocblas.h @@ -7,6 +7,7 @@ #include #include #include +#include #define ROCBLAS_BETA_FEATURES_API #include @@ -129,7 +130,7 @@ static rocblas_operation _rocblasOpFromChar(char op) { case 'C': return rocblas_operation_conjugate_transpose; } - AT_ERROR( + TORCH_CHECK(false, "_rocblasOpFromChar input should be 't', 'n' or 'c' but got `", op, "`"); } @@ -197,7 +198,7 @@ auto GetRocBlasGemmTypeStringAndOps() { std::vector>>>> ret; for (size_t i = 0; i < solutions.size(); ++i) { auto callable = std::make_unique>(solutions[i]); - ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable))); + ret.emplace_back(std::make_pair(fmt::sprintf("Gemm_Rocblas_%d", solutions[i]), std::move(callable))); } return ret; } diff --git a/aten/src/ATen/cuda/tunable/README.md b/aten/src/ATen/cuda/tunable/README.md index e17ff71f3004e..a2a0d0b8d77f0 100644 --- a/aten/src/ATen/cuda/tunable/README.md +++ b/aten/src/ATen/cuda/tunable/README.md @@ -77,6 +77,31 @@ default, now called through TunableOp. Any call to at::cuda::blas::gemm() or ::b when enabled. Calling gemm() for a given set of input arguments (transa, transb, m, n, k) will attempt to use the fastest available implementation across both rocblas and hipblaslt. +## Offline Tuning + +### Motivation +Basically it is used for workload with high-memory utilization where one might run out of memory with regular tuning. + +### Workflow +There are basically two steps: +1) Set the environment variables to collect the untuned GEMM and this will generate `tunableop_untuned?.csv` ("?" is placeholder for the GPU ID), like: +``` +PYTORCH_TUNABLEOP_ENABLED=1 +PYTORCH_TUNABLEOP_TUNING=0 +PYTORCH_TUNABLEOP_RECORD_UNTUNED=1 +... +``` +2) Run a Python script that reads the `tunableop_untuned?.csv` and generates the `tunableop_results?.csv`, like: +``` +import torch.cuda.tunable as tunable +import os + +os.putenv('PYTORCH_TUNABLEOP_ENABLED', '1') +os.putenv('PYTORCH_TUNABLEOP_TUNING', '1') +os.putenv('PYTORCH_TUNABLEOP_RECORD_UNTUNED', '0') +tunable.tune_gemm_in_file("tunableop_results?.csv") +``` + ## Tuning Context The behavior of TunableOp is currently manipulated through environment variables, the C++ interface of at::cuda::tunable::getTuningContext(), or the `torch.cuda.tunable` python interfaces. The environment variables take @@ -90,6 +115,8 @@ programmatically since the settings become fixed. Use the C++ or Python APIs ins | -------------------- | ----------- | | PYTORCH_TUNABLEOP_ENABLED | Default is 0. Set to 1 to enable. | | PYTORCH_TUNABLEOP_TUNING | Default is 1. Set to 0 to disable. | +| PYTORCH_TUNABLEOP_RECORD_UNTUNED | Default is 0. Set to 1 to enable. | +| PYTORCH_TUNABLEOP_UNTUNED_FILENAME | Default is 'tunableop_untuned.csv'. | | PYTORCH_TUNABLEOP_VERBOSE | Default is 0. Set to 1 to enable basic logging. 2 for basic tuning status. 3 for full trace. | | PYTORCH_TUNABLEOP_VERBOSE_FILENAME | Default is "err" for stderr. Set to "out" for stdout or a filename for capturing verbose logging. | | PYTORCH_TUNABLEOP_FILENAME | Default is 'tunableop_results.csv'. | @@ -112,6 +139,8 @@ All python APIs exist in the `torch.cuda.tunable` module. | is_enabled() -> bool | | | tuning_enable(val: bool = True) -> None | Default is True. | | tuning_is_enabled() -> bool | | +| record_untuned_enable(val: bool = True) -> None | Default is True. | +| record_untuned_is_enabled() -> bool | | | set_max_tuning_duration(duration: int) -> None | | | get_max_tuning_duration() -> int | | | set_max_tuning_iterations(iterations: int) -> None | | @@ -123,6 +152,7 @@ All python APIs exist in the `torch.cuda.tunable` module. | write_file_on_exit(val: bool) -> None | Default is True. | | write_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). | | read_file(filename: Optional[str] = None) -> None | If filename not given, it will call get_filename(). | +| tune_gemm_in_file(filename: str) -> None | read an untuned file and tune GEMMs in it. | ### C++ Interface Example: diff --git a/aten/src/ATen/cuda/tunable/StreamTimer.h b/aten/src/ATen/cuda/tunable/StreamTimer.h index c70cb1a908d9d..36b8d72a4953b 100644 --- a/aten/src/ATen/cuda/tunable/StreamTimer.h +++ b/aten/src/ATen/cuda/tunable/StreamTimer.h @@ -18,7 +18,7 @@ namespace at::cuda::tunable { class StreamTimer : public ITimer { public: StreamTimer(); - virtual ~StreamTimer() override; + ~StreamTimer() override; void Start() override; diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp index 1b7c898758558..318d08189f4e0 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.cpp +++ b/aten/src/ATen/cuda/tunable/Tunable.cpp @@ -19,16 +19,10 @@ #include #endif -#include #include -#include -#include -#include #include #include #include -#include -#include #include #include #include @@ -83,7 +77,7 @@ ResultEntry TuningResultsManager::Lookup(const std::string& op_signature, const return it->second; } -inline void TuningResultsManager::AddImpl(const std::string& op_signature, +void TuningResultsManager::AddImpl(const std::string& op_signature, const std::string& params_signature, ResultEntry best, KernelMap& kernel_map) { @@ -98,7 +92,7 @@ inline void TuningResultsManager::AddImpl(const std::string& op_signature, } TUNABLE_LOG2(op_signature, "(", params_signature, ") -> ", best); - kernel_map.emplace(params_signature, best); + kernel_map.emplace(params_signature, std::move(best)); } void TuningResultsManager::Add(const std::string& op_signature, const std::string& params_signature, ResultEntry best) { @@ -109,7 +103,33 @@ void TuningResultsManager::Add(const std::string& op_signature, const std::strin it = results_.insert({op_signature, {}}).first; } - AddImpl(op_signature, params_signature, best, it->second); + AddImpl(op_signature, params_signature, std::move(best), it->second); +} + +void TuningResultsManager::RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature, const std::string& params_signature) { + std::scoped_lock l{lock_}; + if (!untuned_file.good()) { + TORCH_WARN_ONCE("failed to open file for writing; untuned gemm will not be saved"); + return; + } else { + bool isNew = false; + auto it = untuned_results_.find(op_signature); + if (it == untuned_results_.end()) { + it = untuned_results_.insert({op_signature, {}}).first; + isNew = true; + } + + auto it_kernel_map = it->second.find(params_signature); + if (it_kernel_map == it->second.end()) { + it->second.insert(params_signature); + isNew = true; + } + + if (isNew) { + untuned_file << op_signature << "," << params_signature << std::endl; + TUNABLE_LOG3("Untuned,", op_signature, ",", params_signature); + } + } } void TuningResultsManager::Delete(const std::string& op_signature, const std::string& params_signature) { @@ -129,7 +149,7 @@ void TuningResultsManager::Delete(const std::string& op_signature, const std::st it->second.erase(it2); } -inline void TuningResultsManager::DisjointMergeImpl( +void TuningResultsManager::DisjointMergeImpl( const std::string& op_signature, const KernelMap& kernel_map, /*out*/ std::unordered_map& results) { @@ -179,7 +199,7 @@ size_t TuningResultsManager::GetSize() { TuningResultsValidator::TuningResultsValidator() { RegisterValidator( "PT_VERSION", - [this]() { return GetPyTorchVersion(); }, + []() { return GetPyTorchVersion(); }, [this](auto&& k) { return ValidatePyTorchVersion(std::forward(k)); }); #ifdef USE_ROCM // rocm @@ -342,7 +362,7 @@ void TuningResultsValidator::RegisterValidator(const std::string& key, const Get } } -std::string TuningResultsValidator::GetPyTorchVersion() const { +std::string TuningResultsValidator::GetPyTorchVersion() { return TORCH_VERSION; } @@ -359,6 +379,7 @@ TuningStatus TuningResultsValidator::ValidatePyTorchVersion(const std::string& v TuningContext::TuningContext() : enable_{false}, tuning_enable_{true}, + record_untuned_enable_{false}, manager_initialized_{false}, write_file_on_exit_{true}, numerics_check_enable_{false}, @@ -369,6 +390,7 @@ TuningContext::TuningContext() : icache_flush_{true}, rotating_buffer_size_{-1}, filename_{}, + untuned_file_{}, results_count_from_input_file_{0} { } @@ -394,6 +416,10 @@ TuningContext::~TuningContext() { } } } + + if (untuned_file_.good()) { + untuned_file_.close(); + } } void TuningContext::EnableTunableOp(bool value) { @@ -424,6 +450,15 @@ void TuningContext::EnableTuning(bool value) { } } +void TuningContext::EnableRecordUntuned(bool value) { + record_untuned_enable_ = value; + if (value) { + TUNABLE_LOG1("Enable Record Untuned for TunableOp"); + } else { + TUNABLE_LOG1("Disable Record Untuned for TunableOp"); + } +} + bool TuningContext::IsTuningEnabled() const { static const char *env = std::getenv("PYTORCH_TUNABLEOP_TUNING"); if (env != nullptr && strcmp(env, "0") == 0) { @@ -432,6 +467,33 @@ bool TuningContext::IsTuningEnabled() const { return tuning_enable_; } +bool TuningContext::IsRecordUntunedEnabled() const { + static const char *env = std::getenv("PYTORCH_TUNABLEOP_RECORD_UNTUNED"); + if (env != nullptr && strcmp(env, "1") == 0) { + return true; + } + return record_untuned_enable_; +} + +std::ofstream& TuningContext::GetUntunedFile(){ + if (!untuned_file_.is_open()) { + const char *env = std::getenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME"); + std::string filename = (env == nullptr) ? "tunableop_untuned.csv" : env; + + std::string device = c10::str(int(c10::cuda::current_device())); + std::size_t found = filename.rfind('.'); + if (found != std::string::npos) { + filename.insert(found, device); + } else { + // all else fails, just append + filename.append(device); + } + + untuned_file_ = std::ofstream(filename, std::ios::out | std::ios::trunc); + } + return untuned_file_; +} + void TuningContext::WriteFileOnExit(bool value) { write_file_on_exit_ = value; } @@ -545,7 +607,7 @@ TuningResultsManager& TuningContext::GetTuningResultsManager() { SetFilename(filename, true); } auto filename = GetFilename(); - if (!filename.empty()) { + if (!filename.empty() && !IsRecordUntunedEnabled()) { ReadFile(filename); // attempt immediately to open file for writing to catch errors early std::ofstream file(filename, std::ios::out | std::ios::app); diff --git a/aten/src/ATen/cuda/tunable/Tunable.h b/aten/src/ATen/cuda/tunable/Tunable.h index 243031cf3da2d..02cc0bc4fdab3 100644 --- a/aten/src/ATen/cuda/tunable/Tunable.h +++ b/aten/src/ATen/cuda/tunable/Tunable.h @@ -10,6 +10,7 @@ #pragma once #include +#include #include #include @@ -17,10 +18,9 @@ #include #include #include -#include #include +#include #include -#include namespace at::cuda::tunable { @@ -33,11 +33,11 @@ struct MaybeDelete { using OstreamPtr = std::unique_ptr; -static OstreamPtr get_stream(std::string filename) { - if (filename.compare("out") == 0) { +inline OstreamPtr get_stream(const std::string& filename) { + if (filename == "out") { return OstreamPtr { &std::cout, MaybeDelete {false} }; } - else if (filename.compare("err") == 0) { + else if (filename == "err") { return OstreamPtr { &std::cerr, MaybeDelete {false} }; } else { @@ -47,16 +47,17 @@ static OstreamPtr get_stream(std::string filename) { } -static void TunableLog(int level, const std::string& msg) { +template +static void TunableLog(int level, Types... args) { static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME"); static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE"); static int level_user = env_verbose ? atoi(env_verbose) : 0; static auto streamptr = detail::get_stream(env_file ? env_file : "err"); if (level_user >= level) { - (*streamptr) << msg < KernelMap; typedef std::unordered_map ResultsMap; +typedef std::unordered_map> UntunedMap; struct TORCH_CUDA_CPP_API TuningResults { // Validates if these results are compatible with the libraries @@ -105,7 +107,7 @@ class TORCH_CUDA_CPP_API TuningResultsManager { ResultEntry Lookup(const std::string& op_signature, const std::string& params_signature); - inline void AddImpl(const std::string& op_signature, + void AddImpl(const std::string& op_signature, const std::string& params_signature, ResultEntry best, KernelMap& kernel_map); @@ -116,7 +118,7 @@ class TORCH_CUDA_CPP_API TuningResultsManager { void Delete(const std::string& op_signature, const std::string& params_signature); - inline void DisjointMergeImpl( + void DisjointMergeImpl( const std::string& op_signature, const KernelMap& kernel_map, /*out*/ ResultsMap& results); @@ -129,9 +131,12 @@ class TORCH_CUDA_CPP_API TuningResultsManager { size_t GetSize(); + void RecordUntuned( std::ofstream& untuned_file, const std::string& op_signature, const std::string& params_signature); private: std::mutex lock_; ResultsMap results_; + UntunedMap untuned_results_; + }; class TORCH_CUDA_CPP_API TuningResultsValidator { @@ -148,7 +153,7 @@ class TORCH_CUDA_CPP_API TuningResultsValidator { void RegisterValidator(const std::string& key, const GetFunc& gf, const ValidateFunc& vf); protected: - std::string GetPyTorchVersion() const; + static std::string GetPyTorchVersion() ; TuningStatus ValidatePyTorchVersion(const std::string& value) const; public: @@ -173,6 +178,10 @@ class TORCH_CUDA_CPP_API TuningContext { void EnableTuning(bool value); bool IsTuningEnabled() const; + void EnableRecordUntuned(bool value); + bool IsRecordUntunedEnabled() const; + std::ofstream& GetUntunedFile(); + void EnableNumericsCheck(bool value); bool IsNumericsCheckEnabled() const; @@ -213,6 +222,7 @@ class TORCH_CUDA_CPP_API TuningContext { private: bool enable_; bool tuning_enable_; + bool record_untuned_enable_; bool manager_initialized_; bool write_file_on_exit_; bool numerics_check_enable_; @@ -226,6 +236,7 @@ class TORCH_CUDA_CPP_API TuningContext { mutable c10::once_flag manager_init_once_; TuningResultsValidator validator_; std::string filename_; + std::ofstream untuned_file_; size_t results_count_from_input_file_; }; diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 00b02e91b4f35..55e072a7f6d7c 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -22,6 +22,7 @@ #include #include #include +#include namespace at::cuda::tunable { @@ -135,57 +136,57 @@ inline bool IsZero(c10::complex v) { } template -inline std::string TypeName(T v) { +inline const char* TypeName(T v) { return "unknown"; } template <> -inline std::string TypeName(float v) { +inline const char* TypeName(float v) { return "float"; } template <> -inline std::string TypeName(double v) { +inline const char* TypeName(double v) { return "double"; } template <> -inline std::string TypeName(BFloat16 v) { +inline const char* TypeName(BFloat16 v) { return "BFloat16"; } template <> -inline std::string TypeName(Half v) { +inline const char* TypeName(Half v) { return "Half"; } template <> -inline std::string TypeName(Float8_e4m3fn v) { +inline const char* TypeName(Float8_e4m3fn v) { return "Float8_e4m3fn"; } template <> -inline std::string TypeName(Float8_e5m2 v) { +inline const char* TypeName(Float8_e5m2 v) { return "Float8_e5m2"; } template <> -inline std::string TypeName(Float8_e4m3fnuz v) { +inline const char* TypeName(Float8_e4m3fnuz v) { return "Float8_e4m3fnuz"; } template <> -inline std::string TypeName(Float8_e5m2fnuz v) { +inline const char* TypeName(Float8_e5m2fnuz v) { return "Float8_e5m2fnuz"; } template <> -inline std::string TypeName(c10::complex v) { +inline const char* TypeName(c10::complex v) { return "c10::complex"; } template <> -inline std::string TypeName(c10::complex v) { +inline const char* TypeName(c10::complex v) { return "c10::complex"; } @@ -196,15 +197,15 @@ class GemmTunableOp : public TunableOp, StreamTimer> { this->RegisterOp(std::string("Default"), std::make_unique>()); #ifdef USE_ROCM - static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); - if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (!env_rocblas.has_value() || env_rocblas.value()) { for (auto&& [name, op] : GetRocBlasGemmTypeStringAndOps()) { this->RegisterOp(std::move(name), std::move(op)); } } - static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { // disallow tuning of hipblaslt with c10::complex if constexpr ( !std::is_same_v> && @@ -218,7 +219,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { } std::string Signature() override { - return c10::str("GemmTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; @@ -229,8 +230,8 @@ class GemmAndBiasTunableOp : public TunableOp, StreamTimer> this->RegisterOp(std::string("Default"), std::make_unique>()); #ifdef USE_ROCM - static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { // disallow tuning of hipblaslt with c10::complex if constexpr ( !std::is_same_v> && @@ -244,7 +245,7 @@ class GemmAndBiasTunableOp : public TunableOp, StreamTimer> } std::string Signature() override { - return c10::str("GemmAndBiasTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; @@ -255,15 +256,15 @@ class GemmStridedBatchedTunableOp : public TunableOp this->RegisterOp(std::string("Default"), std::make_unique>()); #ifdef USE_ROCM - static const char *env_rocblas = std::getenv("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); - if (env_rocblas == nullptr || strcmp(env_rocblas, "1") == 0) { + static const auto env_rocblas = c10::utils::check_env("PYTORCH_TUNABLEOP_ROCBLAS_ENABLED"); + if (!env_rocblas.has_value() || env_rocblas.value()) { for (auto&& [name, op] : GetRocBlasGemmStridedBatchedTypeStringAndOps()) { this->RegisterOp(std::move(name), std::move(op)); } } - static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); - if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { + static const auto env_hipblaslt = c10::utils::check_env("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); + if (!env_hipblaslt.has_value() || env_hipblaslt.value()) { // disallow tuning of hipblaslt with c10::complex if constexpr ( !std::is_same_v> && @@ -277,7 +278,7 @@ class GemmStridedBatchedTunableOp : public TunableOp } std::string Signature() override { - return c10::str("GemmStridedBatchedTunableOp_", TypeName(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; @@ -295,11 +296,11 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> } std::string Signature() override { - return c10::str("ScaledGemmTunableOp", - "_", TypeName(AT{}), - "_", TypeName(BT{}), - "_", TypeName(CT{}), - "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); + return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c", + TypeName(AT{}), + TypeName(BT{}), + TypeName(CT{}), + BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; diff --git a/aten/src/ATen/cuda/tunable/TunableOp.h b/aten/src/ATen/cuda/tunable/TunableOp.h index 6d1065ad01a6a..ed60c59482247 100644 --- a/aten/src/ATen/cuda/tunable/TunableOp.h +++ b/aten/src/ATen/cuda/tunable/TunableOp.h @@ -18,7 +18,6 @@ #endif #include -#include #include #include @@ -54,9 +53,15 @@ class TunableOp { auto params_sig = params->Signature(); result = mgr.Lookup(op_sig, params_sig); // If there is not previous tuning result been found, we do the tuning iff tuning is enabled - if (result == ResultEntry::Null() && ctx->IsTuningEnabled()) { - result = FindFastest(params); - mgr.Add(op_sig, params_sig, result); + if (result == ResultEntry::Null()) { + if (ctx->IsTuningEnabled()) { + result = FindFastest(params); + mgr.Add(op_sig, params_sig, result); + } + else if (ctx->IsRecordUntunedEnabled()) { + // or record the gemm into file + mgr.RecordUntuned(ctx->GetUntunedFile(), op_sig, params_sig); + } } } else { @@ -140,7 +145,7 @@ class TunableOp { bool use_buffer_rotation = (rotating_size > 0); size_t param_size = params->GetSize(use_buffer_rotation); size_t param_count = (rotating_size / param_size) + 1; - constexpr size_t MB = 1024*1024; + constexpr size_t MB = 1024ull*1024; if (use_buffer_rotation) { TUNABLE_LOG2("Rotating buffer ", rotating_size/MB, " MiB. ", "Needed Size: ", param_size/MB, " MiB. ", @@ -260,6 +265,7 @@ class TunableOp { std::string CreateSignature() { #ifndef _WIN32 const auto* name = typeid(*this).name(); + // NOLINTNEXTLINE(*array*) char buf[256]; size_t buf_len = 256; abi::__cxa_demangle(name, buf, &buf_len, nullptr); diff --git a/aten/src/ATen/cudnn/AutocastRNN.cpp b/aten/src/ATen/cudnn/AutocastRNN.cpp index c920e9ce1cf86..71cd199b33790 100644 --- a/aten/src/ATen/cudnn/AutocastRNN.cpp +++ b/aten/src/ATen/cudnn/AutocastRNN.cpp @@ -113,7 +113,7 @@ _cudnn_rnn_cast_reflatten(const Tensor & input, batch_sizes, dropout_state); #else // AT_CUDNN_ENABLED() - AT_ERROR("autocast::_cudnn_rnn_cast_reflatten: ATen not compiled with cuDNN support"); + TORCH_CHECK(false, "autocast::_cudnn_rnn_cast_reflatten: ATen not compiled with cuDNN support"); return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // never reached, placates the compiler #endif // AT_CUDNN_ENABLED() } diff --git a/aten/src/ATen/detail/AcceleratorHooksInterface.h b/aten/src/ATen/detail/AcceleratorHooksInterface.h index 61409db3ac680..4eab4d24f71b3 100644 --- a/aten/src/ATen/detail/AcceleratorHooksInterface.h +++ b/aten/src/ATen/detail/AcceleratorHooksInterface.h @@ -19,6 +19,10 @@ struct TORCH_API AcceleratorHooksInterface { // Whether the device at device_index is fully initialized or not. virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0; + virtual void init() const { + TORCH_CHECK(false, "Backend doesn`t support init()"); + } + virtual DeviceIndex deviceCount() const { return 0; } @@ -50,6 +54,10 @@ struct TORCH_API AcceleratorHooksInterface { TORCH_CHECK(false, "Backend doesn't support getPinnedMemoryAllocator()"); return nullptr; } + + virtual Device getDeviceFromPtr(void* data) const { + TORCH_CHECK(false, "Backend doesn't support getDeviceFromPtr()"); + } }; } // namespace at diff --git a/aten/src/ATen/detail/CUDAHooksInterface.h b/aten/src/ATen/detail/CUDAHooksInterface.h index f9a3fa098508f..144643e52973b 100644 --- a/aten/src/ATen/detail/CUDAHooksInterface.h +++ b/aten/src/ATen/detail/CUDAHooksInterface.h @@ -65,15 +65,19 @@ struct TORCH_API CUDAHooksInterface : AcceleratorHooksInterface { ~CUDAHooksInterface() override = default; // Initialize THCState and, transitively, the CUDA state - virtual void initCUDA() const { + void init() const override { TORCH_CHECK(false, "Cannot initialize CUDA without ATen_cuda library. ", CUDA_HELP); } - virtual const Generator& getDefaultCUDAGenerator(C10_UNUSED DeviceIndex device_index = -1) const { - TORCH_CHECK(false, "Cannot get default CUDA generator without ATen_cuda library. ", CUDA_HELP); + virtual const Generator& getDefaultCUDAGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK( + false, + "Cannot get default CUDA generator without ATen_cuda library. ", + CUDA_HELP); } - virtual Device getDeviceFromPtr(void* /*data*/) const { + Device getDeviceFromPtr(void* /*data*/) const override { TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP); } diff --git a/aten/src/ATen/detail/HIPHooksInterface.h b/aten/src/ATen/detail/HIPHooksInterface.h index b3194668d9512..f852db8d600e6 100644 --- a/aten/src/ATen/detail/HIPHooksInterface.h +++ b/aten/src/ATen/detail/HIPHooksInterface.h @@ -26,9 +26,8 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { // squelch -Werror=non-virtual-dtor ~HIPHooksInterface() override = default; - // Initialize the HIP library state - virtual void initHIP() const { - AT_ERROR("Cannot initialize HIP without ATen_hip library."); + void init() const override { + TORCH_CHECK(false, "Cannot initialize HIP without ATen_hip library."); } virtual std::unique_ptr initHIPGenerator(Context*) const { @@ -48,7 +47,7 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { } Allocator* getPinnedMemoryAllocator() const override { - AT_ERROR("Pinned memory requires HIP."); + TORCH_CHECK(false, "Pinned memory requires HIP."); } virtual void registerHIPTypes(Context*) const { @@ -60,7 +59,7 @@ struct TORCH_API HIPHooksInterface : AcceleratorHooksInterface { } bool hasPrimaryContext(DeviceIndex device_index) const override { - AT_ERROR("Cannot check primary context without ATen_hip library."); + TORCH_CHECK(false, "Cannot check primary context without ATen_hip library."); } }; diff --git a/aten/src/ATen/detail/IPUHooksInterface.h b/aten/src/ATen/detail/IPUHooksInterface.h index 8f24df4fdd2de..20dbb703d571f 100644 --- a/aten/src/ATen/detail/IPUHooksInterface.h +++ b/aten/src/ATen/detail/IPUHooksInterface.h @@ -1,14 +1,25 @@ #pragma once #include +#include + #include #include #include namespace at { -struct TORCH_API IPUHooksInterface { - virtual ~IPUHooksInterface() = default; +struct TORCH_API IPUHooksInterface: AcceleratorHooksInterface { + ~IPUHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot initialize IPU without ATen_ipu library."); + return false; + } virtual const Generator& getDefaultIPUGenerator( DeviceIndex device_index [[maybe_unused]] = -1) const { diff --git a/aten/src/ATen/detail/MAIAHooksInterface.h b/aten/src/ATen/detail/MAIAHooksInterface.h index ad4ef146eccd9..554cc93043fd3 100644 --- a/aten/src/ATen/detail/MAIAHooksInterface.h +++ b/aten/src/ATen/detail/MAIAHooksInterface.h @@ -3,13 +3,24 @@ #include #include +#include + // NB: Class must live in `at` due to limitations of Registry.h. namespace at { -struct TORCH_API MAIAHooksInterface { +struct TORCH_API MAIAHooksInterface : AcceleratorHooksInterface { // This should never actually be implemented, but it is used to // squelch -Werror=non-virtual-dtor - virtual ~MAIAHooksInterface() = default; + ~MAIAHooksInterface() override = default; + + void init() const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + } + + bool hasPrimaryContext(DeviceIndex device_index) const override { + TORCH_CHECK(false, "Cannot initialize MAIA without ATen_maia library."); + return false; + } virtual std::string showConfig() const { TORCH_CHECK(false, "Cannot query detailed MAIA version information."); diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h index 180ff68588edd..e3f8d3132bb8c 100644 --- a/aten/src/ATen/detail/MPSHooksInterface.h +++ b/aten/src/ATen/detail/MPSHooksInterface.h @@ -22,7 +22,7 @@ struct TORCH_API MPSHooksInterface : AcceleratorHooksInterface { ~MPSHooksInterface() override = default; // Initialize the MPS library state - virtual void initMPS() const { + void init() const override { FAIL_MPSHOOKS_FUNC(__func__); } virtual bool hasMPS() const { diff --git a/aten/src/ATen/detail/MTIAHooksInterface.h b/aten/src/ATen/detail/MTIAHooksInterface.h index 1480436fb4f1d..3320bf90108f8 100644 --- a/aten/src/ATen/detail/MTIAHooksInterface.h +++ b/aten/src/ATen/detail/MTIAHooksInterface.h @@ -31,7 +31,7 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface { ~MTIAHooksInterface() override = default; - virtual void initMTIA() const { + void init() const override { // Avoid logging here, since MTIA needs init devices first then it will know // how many devices are available. Make it as no-op if mtia extension is not // dynamically loaded. diff --git a/aten/src/ATen/detail/PrivateUse1HooksInterface.h b/aten/src/ATen/detail/PrivateUse1HooksInterface.h index e321f484deeac..3820c960dfe57 100644 --- a/aten/src/ATen/detail/PrivateUse1HooksInterface.h +++ b/aten/src/ATen/detail/PrivateUse1HooksInterface.h @@ -18,7 +18,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDefaultGenerator`."); } - virtual at::Device getDeviceFromPtr(void* data) const { + at::Device getDeviceFromPtr(void* data) const override { TORCH_CHECK_NOT_IMPLEMENTED( false, "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `getDeviceFromPtr`."); @@ -40,7 +40,7 @@ struct TORCH_API PrivateUse1HooksInterface : AcceleratorHooksInterface { "You should register `PrivateUse1HooksInterface` for PrivateUse1 before call `hasPrimaryContext`."); } - virtual void initPrivateUse1() const {} + void init() const override {} virtual void resizePrivateUse1Bytes( const c10::Storage& storage, size_t newsize) const { diff --git a/aten/src/ATen/detail/XPUHooksInterface.h b/aten/src/ATen/detail/XPUHooksInterface.h index f4cd9a34b5752..8cb5497e62c03 100644 --- a/aten/src/ATen/detail/XPUHooksInterface.h +++ b/aten/src/ATen/detail/XPUHooksInterface.h @@ -14,10 +14,8 @@ namespace at { struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ ~XPUHooksInterface() override = default; - virtual void initXPU() const { - TORCH_CHECK( - false, - "Cannot initialize XPU without ATen_xpu library."); + void init() const override { + TORCH_CHECK(false, "Cannot initialize XPU without ATen_xpu library."); } virtual bool hasXPU() const { @@ -34,12 +32,15 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ TORCH_CHECK(false, "Cannot get XPU global device index without ATen_xpu library."); } - virtual Generator getXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const { + virtual Generator getXPUGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { TORCH_CHECK(false, "Cannot get XPU generator without ATen_xpu library."); } - virtual const Generator& getDefaultXPUGenerator(C10_UNUSED DeviceIndex device_index = -1) const { - TORCH_CHECK(false, "Cannot get default XPU generator without ATen_xpu library."); + virtual const Generator& getDefaultXPUGenerator( + [[maybe_unused]] DeviceIndex device_index = -1) const { + TORCH_CHECK( + false, "Cannot get default XPU generator without ATen_xpu library."); } virtual DeviceIndex getNumGPUs() const { @@ -50,7 +51,7 @@ struct TORCH_API XPUHooksInterface : AcceleratorHooksInterface{ TORCH_CHECK(false, "Cannot get current device on XPU without ATen_xpu library."); } - virtual Device getDeviceFromPtr(void* /*data*/) const { + Device getDeviceFromPtr(void* /*data*/) const override { TORCH_CHECK(false, "Cannot get device of pointer on XPU without ATen_xpu library."); } diff --git a/aten/src/ATen/dlpack.h b/aten/src/ATen/dlpack.h index c77205f962158..6f8e03dd57042 100644 --- a/aten/src/ATen/dlpack.h +++ b/aten/src/ATen/dlpack.h @@ -32,7 +32,9 @@ #define DLPACK_DLL #endif +// NOLINTNEXTLINE(modernize-deprecated-headers) #include +// NOLINTNEXTLINE(modernize-deprecated-headers) #include #ifdef __cplusplus diff --git a/aten/src/ATen/functorch/BatchRulesConvolution.cpp b/aten/src/ATen/functorch/BatchRulesConvolution.cpp index 3cf00f33def55..89de1fc18f5b6 100644 --- a/aten/src/ATen/functorch/BatchRulesConvolution.cpp +++ b/aten/src/ATen/functorch/BatchRulesConvolution.cpp @@ -362,6 +362,7 @@ static std::tuple convolution_backward_plumbing( const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_, const c10::OptionalArrayRef bias_sizes_opt, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, + // NOLINTNEXTLINE(performance-unnecessary-value-param) c10::SymIntArrayRef output_padding, c10::SymInt groups, std::array output_mask) { const auto maybe_layer = maybeCurrentDynamicLayer(); vmap_check_escaped(maybe_layer, "convolution_backward_plumbing"); diff --git a/aten/src/ATen/functorch/BatchRulesIndexing.cpp b/aten/src/ATen/functorch/BatchRulesIndexing.cpp index eb571b2980781..5620d8593ca90 100644 --- a/aten/src/ATen/functorch/BatchRulesIndexing.cpp +++ b/aten/src/ATen/functorch/BatchRulesIndexing.cpp @@ -8,7 +8,7 @@ #include #include -namespace at { namespace functorch { +namespace at::functorch { #define OP_DECOMPOSE(op) m.impl(#op, static_cast(native::op)); #define OP_DECOMPOSE2(op, overload) m.impl(#op"."#overload, static_cast(native::op)); @@ -20,4 +20,4 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { OP_DECOMPOSE(_unsafe_masked_index_put_accumulate); } -}} +} diff --git a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp index fed7fecc217b9..99a589f370224 100644 --- a/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp +++ b/aten/src/ATen/functorch/BatchRulesLinearAlgebra.cpp @@ -494,7 +494,7 @@ _scaled_dot_product_flash_attention_batch_rule( double dropout_p, bool is_causal, bool return_debug_mask, - c10::optional scale + std::optional scale ) { if (dropout_p > 0) { auto maybe_layer = maybeCurrentDynamicLayer(); @@ -543,7 +543,7 @@ fourOutputs _scaled_dot_product_efficient_attention_batch_rule( bool compute_log_sumexp, double dropout_p, bool is_causal, - c10::optional scale + std::optional scale ) { if (dropout_p > 0) { auto maybe_layer = maybeCurrentDynamicLayer(); @@ -585,7 +585,7 @@ _scaled_dot_product_cudnn_attention_batch_rule( double dropout_p, bool is_causal, bool return_debug_mask, - c10::optional scale + std::optional scale ) { if (dropout_p > 0) { auto maybe_layer = maybeCurrentDynamicLayer(); diff --git a/aten/src/ATen/functorch/BatchRulesModules.cpp b/aten/src/ATen/functorch/BatchRulesModules.cpp index 99a5a434d54c7..2572e07debfa2 100644 --- a/aten/src/ATen/functorch/BatchRulesModules.cpp +++ b/aten/src/ATen/functorch/BatchRulesModules.cpp @@ -224,9 +224,9 @@ static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes // but shape inference is not possible. if (self.sym_numel() == 0) { if (num_classes <= 0) { - AT_ERROR("Can not infer total number of classes from empty tensor."); + TORCH_CHECK(false, "Can not infer total number of classes from empty tensor."); } else { - shape.push_back(num_classes); + shape.emplace_back(num_classes); return at::empty_symint(shape, self.options()); } } @@ -246,7 +246,7 @@ static Tensor one_hot_decomposition_hack(const Tensor &self, int64_t num_classes // TORCH_CHECK(num_classes > self.max().item().toLong(), "Class values must be smaller than num_classes."); // } - shape.push_back(num_classes); + shape.emplace_back(num_classes); Tensor ret = at::zeros_symint(shape, self.options()); return ret.scatter(-1, self.unsqueeze(-1), 1); } diff --git a/aten/src/ATen/functorch/BatchRulesRandomness.cpp b/aten/src/ATen/functorch/BatchRulesRandomness.cpp index 2cd175fdcbabd..b6609ebc39b31 100644 --- a/aten/src/ATen/functorch/BatchRulesRandomness.cpp +++ b/aten/src/ATen/functorch/BatchRulesRandomness.cpp @@ -213,7 +213,7 @@ static std::tuple native_dropout_batching_rule(const Tensor& tens return std::make_tuple(output, mask); } -static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, const std::optional generator) { +static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, std::optional generator) { c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchVmapMode); auto maybe_layer = maybeCurrentDynamicLayer(); const auto cur_level = maybe_layer->layerId(); @@ -237,7 +237,7 @@ static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_sa if (is_2D_case) { self_value = reshape_dim_into(0, 0, self_value); } - auto out = multinomial(self_value, num_samples, replacement, generator); + auto out = multinomial(self_value, num_samples, replacement, std::move(generator)); if (is_2D_case) { out = reshape_dim_outof_symint(0, maybe_layer->batchSize(), out); } @@ -249,7 +249,7 @@ static Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_sa // Must be same randomness with unbatched input // 1D case: S -> multinomial(S) -> S // 2D case: MS -> multinomial(MS) -> MS - return multinomial(self_value, num_samples, replacement, generator); + return multinomial(self_value, num_samples, replacement, std::move(generator)); } template diff --git a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp index 8385660be0b38..878ea58bdb2c9 100644 --- a/aten/src/ATen/functorch/BatchRulesReduceOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesReduceOps.cpp @@ -103,7 +103,7 @@ template< // optional cannot be used in a template, otherwise we would use it here. int maybe_keepdim_arg_pos > -void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) { +static void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) { const auto& schema = op.schema(); const auto num_returns = schema.returns().size(); const auto num_arguments = schema.arguments().size(); @@ -357,21 +357,21 @@ static std::tuple> searchsorted_batch_rule( // B<...>D, B<...>V -> no change if (buckets_bdim.has_value() && self_bdim.has_value()) { auto self_ = moveBatchDimToFront(self, self_bdim); - auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); return std::make_tuple(std::move(result), 0); } // B<...>D, <...>V -> B<...>D, B<...>V if (buckets_bdim.has_value() && !self_bdim.has_value()) { auto self_ = moveBatchDimToFront(self, self_bdim); self_ = ensure_has_bdim(self_, self_bdim.has_value(), buckets.size(0)); - auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); return std::make_tuple(std::move(result), 0); } // <...>D, B<...>V -> <...>D, <...>(BV) if (!buckets_bdim.has_value() && self_bdim.has_value()) { auto bdim_size = self.size(*self_bdim); auto self_ = reshape_dim_into(*self_bdim, -1, self); - auto result = at::searchsorted(buckets, self_, out_int32, right, std::move(side), sorter_); + auto result = at::searchsorted(buckets, self_, out_int32, right, side, sorter_); result = reshape_dim_outof(-1, bdim_size, result); return std::make_tuple(result, result.dim() - 2); } @@ -382,7 +382,7 @@ static std::tuple> searchsorted_batch_rule( if (buckets_bdim.has_value() && self_bdim.has_value()) { auto self_ = moveBatchDimToFront(self, self_bdim); auto self_view_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1); - auto result = at::searchsorted(buckets, self_view_, out_int32, right, std::move(side), sorter_); + auto result = at::searchsorted(buckets, self_view_, out_int32, right, side, sorter_); result = self_logical_rank == 0 ? result.squeeze(-1) : result.view(self_.sizes()); return std::make_tuple(std::move(result), 0); } @@ -391,13 +391,13 @@ static std::tuple> searchsorted_batch_rule( auto bdim_size = buckets.size(*buckets_bdim); auto self_ = ensure_has_bdim(self, false, bdim_size); auto self_view_ = self_logical_rank == 0 ? self_.unsqueeze(-1) : self_.flatten(1); - auto result = at::searchsorted(buckets, self_view_, out_int32, right, std::move(side), sorter_); + auto result = at::searchsorted(buckets, self_view_, out_int32, right, side, sorter_); result = self_logical_rank == 0 ? result.squeeze(-1) : result.view(self_.sizes()); return std::make_tuple(std::move(result), 0); } // D, B* -> no change if (!buckets_bdim.has_value() && self_bdim.has_value()) { - auto result = at::searchsorted(buckets, self, out_int32, right, std::move(side), sorter_); + auto result = at::searchsorted(buckets, self, out_int32, right, side, sorter_); return std::make_tuple(std::move(result), self_bdim); } TORCH_INTERNAL_ASSERT(false); diff --git a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp index 496e58d7994fe..8f2738552310d 100644 --- a/aten/src/ATen/functorch/BatchRulesScatterOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesScatterOps.cpp @@ -235,7 +235,7 @@ std::tuple> index_batch_rule( bool advanced_indices_are_adjacent = are_advanced_indices_adjacent(indices); // Step 1 - const auto batched_indices = batchIndices(indices, indices_bdims, self_.size(0), self_bdim); + const auto batched_indices = batchIndices(indices, indices_bdims, self_.sym_size(0), self_bdim); auto num_leading_nones = get_num_leading_nones(indices); auto max_index_dim = get_max_index_logical_dim(indices, indices_bdims); @@ -427,7 +427,7 @@ namespace { // shape of `values` is (N, 2, 3), then following block // will reshape `values` to (N, 1, 1, 2, 3). if ( (int64_t) indexed_shape.size() > values_.dim()) { - auto values_sizes = values_.sizes(); + auto values_sizes = values_.sym_sizes(); // number of unit dims (for broadcasting value to indexed_shape) auto n_unit_dims = indexed_shape.size() - values_sizes.size(); @@ -841,26 +841,26 @@ std::tuple> gather_batch_rule( return std::make_tuple(result, 0); } -Tensor get_expanded_index(const Tensor& index, IntArrayRef self_size, int64_t dim) { +Tensor get_expanded_index(const Tensor& index, SymIntArrayRef self_size, int64_t dim) { if (index.dim() == 0) { - return index.expand(self_size); + return index.expand_symint(self_size); } dim = maybe_wrap_dim(dim, static_cast(self_size.size())); // setup new_index_shape as [BS, 1, ..., idx_size, ..., 1] // to reshape index_ - auto idx_size = index.size(0); // get non-batch size of index tensor + auto idx_size = index.sym_size(0); // get non-batch size of index tensor Tensor index_; { - VmapDimVector new_index_shape(self_size.size(), 1); + VmapSymDimVector new_index_shape(self_size.size(), 1); new_index_shape[dim] = idx_size; - index_ = index.view(new_index_shape); + index_ = index.view_symint(new_index_shape); } // Now apply expand to index_ { - VmapDimVector new_index_shape = {self_size.begin(), self_size.end()}; + VmapSymDimVector new_index_shape = {self_size.begin(), self_size.end()}; new_index_shape[dim] = idx_size; - index_ = index_.expand(new_index_shape); + index_ = index_.expand_symint(new_index_shape); } return index_; } @@ -869,7 +869,7 @@ Tensor index_select_decomp(const Tensor &self, int64_t dim, const Tensor &index) { Tensor index_ = index; if (self.dim() > index.dim()) { - index_ = get_expanded_index(index, self.sizes(), dim); + index_ = get_expanded_index(index, self.sym_sizes(), dim); } auto result = at::gather(self, dim, index_); @@ -893,7 +893,7 @@ Tensor index_copy_decomp( { Tensor index_ = index; if (self.dim() > index.dim()) { - index_ = get_expanded_index(index, self.sizes(), dim); + index_ = get_expanded_index(index, self.sym_sizes(), dim); } return at::scatter(self, dim, index_, source); ; @@ -909,7 +909,7 @@ Tensor slice_scatter_decomp(const Tensor &self, const Tensor &src, std::optional end, int64_t step) { auto idx = at::arange(start.value_or(0), end.value_or(self.size(dim)), step, self.options().dtype(kLong)); - idx = get_expanded_index(idx, self.sizes(), dim); + idx = get_expanded_index(idx, self.sym_sizes(), dim); return at::scatter(self, dim, idx, src); } diff --git a/aten/src/ATen/functorch/LegacyVmapTransforms.cpp b/aten/src/ATen/functorch/LegacyVmapTransforms.cpp index 07b97def63f3a..ace12bc9c4579 100644 --- a/aten/src/ATen/functorch/LegacyVmapTransforms.cpp +++ b/aten/src/ATen/functorch/LegacyVmapTransforms.cpp @@ -29,7 +29,7 @@ static Tensor permuteBatchDimsToFront(const BatchedTensorImpl* batched) { if (is_bdim[ptr]) { continue; } - permutation[idx++] = ptr; + permutation[idx++] = static_cast(ptr); } return physical_tensor.permute(permutation); } @@ -43,7 +43,7 @@ VmapPhysicalView MultiBatchVmapTransform::logicalToPhysical(const Tensor& logica } int64_t VmapPhysicalView::numBatchDims() const { - return levels_.count(); + return static_cast(levels_.count()); } int64_t VmapPhysicalView::numLogicalDims() const { @@ -102,7 +102,7 @@ static Tensor moveDimToFrontAndExpand(Tensor tensor, std::optional dim, } else { tensor = tensor.unsqueeze(0); auto expanded_sizes = tensor.sym_sizes().vec(); - expanded_sizes[0] = size; + expanded_sizes[0] = std::move(size); tensor = tensor.expand_symint(expanded_sizes); } return tensor; @@ -171,7 +171,7 @@ static Tensor moveDimToFrontAndUnsqueeze(Tensor tensor, std::optional d VmapPhysicalViewVec BroadcastingVmapTransform::logicalToPhysical(TensorList logical_tensors) { auto cur_level = maybeCurrentDynamicLayer().value().layerId(); - auto bdim_size = -1; + int64_t bdim_size = -1; // Figure out the batch size first for (const auto& logical_tensor : logical_tensors) { diff --git a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp index e9e7b2a99553b..7bc3a3cbfe44a 100644 --- a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp +++ b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp @@ -4,7 +4,6 @@ #include #include #include -#include #include #include #include @@ -89,7 +88,7 @@ Tensor binary_cross_entropy_with_logits_hack( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] {return Tensor();}); + const Tensor& pos_weight = pos_weight_opt.value_or(Tensor()); Tensor loss; auto max_val = (-input).clamp_min(0); @@ -136,7 +135,7 @@ static Tensor make_feature_noise(const Tensor& input) { sizes.reserve(input.dim()); sizes.push_back(input_sizes[0]); sizes.push_back(input_sizes[1]); - for (C10_UNUSED const auto i : c10::irange(2, input.dim())) { + for ([[maybe_unused]] const auto i : c10::irange(2, input.dim())) { sizes.push_back(1); } // NB: THIS WAS CHANGED FROM THE ORIGINAL diff --git a/aten/src/ATen/metal/Context.cpp b/aten/src/ATen/metal/Context.cpp index f9b745387dc8e..c0d32086d4179 100644 --- a/aten/src/ATen/metal/Context.cpp +++ b/aten/src/ATen/metal/Context.cpp @@ -16,7 +16,7 @@ at::Tensor& metal_copy_(at::Tensor& self, const at::Tensor& src) { if (p) { return p->metal_copy_(self, src); } - AT_ERROR("Metal backend was not linked to the build"); + TORCH_CHECK(false, "Metal backend was not linked to the build"); } } // namespace at::metal diff --git a/aten/src/ATen/miopen/AutocastRNN.cpp b/aten/src/ATen/miopen/AutocastRNN.cpp index a23eb4a1a19b8..69fd575779a82 100644 --- a/aten/src/ATen/miopen/AutocastRNN.cpp +++ b/aten/src/ATen/miopen/AutocastRNN.cpp @@ -46,7 +46,7 @@ miopen_rnn(const Tensor & input_r, fn_dropout_state_opt); #else - AT_ERROR("autocast::miopen_rnn: ATen not compiled with ROCm enabled"); + TORCH_CHECK(false, "autocast::miopen_rnn: ATen not compiled with ROCm enabled"); return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // placate the compiler #endif diff --git a/aten/src/ATen/mps/MPSDevice.mm b/aten/src/ATen/mps/MPSDevice.mm index 37fa105cbee02..1d03128aa391d 100644 --- a/aten/src/ATen/mps/MPSDevice.mm +++ b/aten/src/ATen/mps/MPSDevice.mm @@ -36,7 +36,9 @@ static inline MTLLanguageVersion getMetalLanguageVersion(const id& de if (isMacOS13Plus(MacOSVersion::MACOS_VER_15_0_PLUS)) { options.mathMode = MTLMathModeFast; } else { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-declarations") [options setFastMathEnabled:YES]; + C10_DIAGNOSTIC_POP() } _mtl_indexing_library = [_mtl_device newLibraryWithSource:[NSString stringWithCString:mps::indexing_metal_shaders encoding:NSASCIIStringEncoding] diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h index 4858c0609f56b..20662be436910 100644 --- a/aten/src/ATen/mps/MPSHooks.h +++ b/aten/src/ATen/mps/MPSHooks.h @@ -12,7 +12,7 @@ namespace at::mps { // The real implementation of MPSHooksInterface struct MPSHooks : public at::MPSHooksInterface { MPSHooks(at::MPSHooksArgs) {} - void initMPS() const override; + void init() const override; // MPSDevice interface bool hasMPS() const override; diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm index 5855e16aca8c9..983bb516a31b8 100644 --- a/aten/src/ATen/mps/MPSHooks.mm +++ b/aten/src/ATen/mps/MPSHooks.mm @@ -10,7 +10,7 @@ namespace at::mps { -void MPSHooks::initMPS() const { +void MPSHooks::init() const { C10_LOG_API_USAGE_ONCE("aten.init.mps"); // TODO: initialize MPS devices and streams here } diff --git a/aten/src/ATen/mps/MPSProfiler.mm b/aten/src/ATen/mps/MPSProfiler.mm index 522328277787b..2dd270452fcc6 100644 --- a/aten/src/ATen/mps/MPSProfiler.mm +++ b/aten/src/ATen/mps/MPSProfiler.mm @@ -189,7 +189,7 @@ currentSigint.sa_flags = SA_RESTART; sigfillset(¤tSigint.sa_mask); if (sigaction(SIGINT, ¤tSigint, &previousSigint) == -1) { - AT_ERROR("Cannot install SIGINT handler for MPSProfiler."); + TORCH_CHECK(false, "Cannot install SIGINT handler for MPSProfiler."); } } } @@ -207,7 +207,7 @@ } else if (token == "event") { m_profile_options |= ProfileOptions::ALL_SIGNPOST_EVENTS; } else { - AT_ERROR("Invalid Signpost trace mode: ", token); + TORCH_CHECK(false, "Invalid Signpost trace mode: ", token); } } } @@ -654,7 +654,7 @@ isInfoLoggingEnabled = (m_log_options & LogOptions::CPU_FALLBACK_INFO); break; default: - AT_ERROR("invalid profiling info type"); + TORCH_CHECK(false, "invalid profiling info type"); } if (!isInfoLoggingEnabled) { return false; @@ -685,7 +685,7 @@ os_signpost_event_emit(m_os_log_events, signpost_id, kEvtSignpostCPUFallbacksStr, "%s", msg); break; default: - AT_ERROR("unknown SignpostType in MPS profiler"); + TORCH_CHECK(false, "unknown SignpostType in MPS profiler"); } } @@ -709,7 +709,7 @@ os_signpost_interval_begin(m_os_log_intervals, signpost_id, kIntSignpostCPUFallbacksStr, "%s", msg); break; default: - AT_ERROR("unknown SignpostType in MPS profiler"); + TORCH_CHECK(false, "unknown SignpostType in MPS profiler"); } } @@ -728,7 +728,7 @@ os_signpost_interval_end(m_os_log_intervals, signpost_id, kIntSignpostCPUFallbacksStr); break; default: - AT_ERROR("unknown SignpostType in MPS profiler"); + TORCH_CHECK(false, "unknown SignpostType in MPS profiler"); } } @@ -750,7 +750,7 @@ case BaseInfo::Type::CPU_FALLBACK: return SignpostTypes::CPU_FALLBACK; default: - AT_ERROR("invalid profiling info type"); + TORCH_CHECK(false, "invalid profiling info type"); } } diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 3794af0529fe0..1df22fb451f6e 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -132,11 +132,46 @@ extern "C" void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *inf extern "C" void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info); // potrs +#if defined(_WIN32) && defined(_M_ARM64) + +// The functions zpotrs, cpotrs, dpotrs, and spotrs are not directly available in LAPACKE on Windows on ARM, +// so we need to have wrapper functions to call them. +// The issue on ARM platform can be found below: +// https://community.arm.com/support-forums/f/high-performance-computing-forum/56512/unable-to-use-lapack---potrs-functions + +#define LAPACK_COL_MAJOR 102 +#define LAPACK_ROW_MAJOR 101 + +extern "C" int LAPACKE_zpotrs(int matrix_layout, char uplo, int n, int nrhs, const std::complex *a, int lda, std::complex *b, int ldb); +extern "C" int LAPACKE_cpotrs(int matrix_layout, char uplo, int n, int nrhs, const std::complex *a, int lda, std::complex *b, int ldb); +extern "C" int LAPACKE_dpotrs(int matrix_layout, char uplo, int n, int nrhs, const double *a, int lda, double *b, int ldb); +extern "C" int LAPACKE_spotrs(int matrix_layout, char uplo, int n, int nrhs, const float *a, int lda, float *b, int ldb); + +static inline void zpotrs_(char *uplo, int *n, int *nrhs, std::complex *a, int *lda, std::complex *b, int *ldb, int *info) { + *info = LAPACKE_zpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb); +} + +static inline void cpotrs_(char *uplo, int *n, int *nrhs, std::complex *a, int *lda, std::complex *b, int *ldb, int *info) { + *info = LAPACKE_cpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb); +} + +static inline void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info){ + *info = LAPACKE_dpotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb); +} + +static inline void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info) { + *info = LAPACKE_spotrs(LAPACK_COL_MAJOR, *uplo, *n, *nrhs, a, *lda, b, *ldb); +} + +#else + extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex *a, int *lda, std::complex *b, int *ldb, int *info); extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex *a, int *lda, std::complex *b, int *ldb, int *info); extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info); extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info); +#endif + // potrf extern "C" void zpotrf_(char *uplo, int *n, std::complex *a, int *lda, int *info); extern "C" void cpotrf_(char *uplo, int *n, std::complex *a, int *lda, int *info); @@ -284,11 +319,39 @@ extern "C" void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau extern "C" void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info); // ormqr +#if defined(_WIN32) && defined(_M_ARM64) + +// The functions zunmqr, cunmqr, dormqr, and sormqr are not directly available in LAPACKE on Windows on ARM, +// so we need to have wrapper functions to call them. +// The issue on ARM platform can be found below: +// https://community.arm.com/support-forums/f/high-performance-computing-forum/56512/unable-to-use-lapack---potrs-functions + +extern "C" int LAPACKE_zunmqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const std::complex *a, int lda, const std::complex *tau, std::complex *c, int ldc, std::complex *work, int lwork); +extern "C" int LAPACKE_cunmqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const std::complex *a, int lda, const std::complex *tau, std::complex *c, int ldc, std::complex *work, int lwork); +extern "C" int LAPACKE_dormqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const double *a, int lda, const double *tau, double *c, int ldc, double *work, int lwork); +extern "C" int LAPACKE_sormqr_work(int matrix_layout, char side, char trans, int m, int n, int k, const float *a, int lda, const float *tau, float *c, int ldc, float *work, int lwork); + +static inline void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex *a, int *lda, std::complex *tau, std::complex *c, int *ldc, std::complex *work, int *lwork, int *info) { + *info = LAPACKE_zunmqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork); +} + +static inline void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex *a, int *lda, std::complex *tau, std::complex *c, int *ldc, std::complex *work, int *lwork, int *info) { + *info = LAPACKE_cunmqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork); +} + +static inline void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info) { + *info = LAPACKE_dormqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork); +} + +static inline void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info) { + *info = LAPACKE_sormqr_work(LAPACK_COL_MAJOR, *side, *trans, *m, *n, *k, a, *lda, tau, c, *ldc, work, *lwork); +} +#else extern "C" void zunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex *a, int *lda, std::complex *tau, std::complex *c, int *ldc, std::complex *work, int *lwork, int *info); extern "C" void cunmqr_(char *side, char *trans, int *m, int *n, int *k, std::complex *a, int *lda, std::complex *tau, std::complex *c, int *ldc, std::complex *work, int *lwork, int *info); extern "C" void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info); extern "C" void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info); - +#endif // syevd extern "C" void zheevd_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, double *w, std::complex *work, int *lwork, double *rwork, int *lrwork, int *iwork, int *liwork, int *info); extern "C" void cheevd_(char *jobz, char *uplo, int *n, std::complex *a, int *lda, float *w, std::complex *work, int *lwork, float *rwork, int *lrwork, int *iwork, int *liwork, int *info); @@ -1624,7 +1687,7 @@ Tensor inverse(const Tensor& A) { template static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, Tensor& infos) { #if !AT_BUILD_WITH_LAPACK() - AT_ERROR("cholesky_solve: LAPACK library not found in compilation"); + TORCH_CHECK(false, "cholesky_solve: LAPACK library not found in compilation"); #else char uplo = upper ? 'U' : 'L'; diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index d61c9870f4c52..ab1dd139b1b9a 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -1109,7 +1109,7 @@ void unpack_pivots_cpu_kernel(TensorIterator& iter, const int64_t dim_size, cons auto* perm_ptr = data[0]; const auto* pivots_ptr = data[1]; - for (C10_UNUSED const auto elem : c10::irange(nelems)) { + for ([[maybe_unused]] const auto elem : c10::irange(nelems)) { // WARNING: linalg.lu_factor returns int32 pivots, // this behavior could change in the future. const auto perm_data = reinterpret_cast(perm_ptr); diff --git a/aten/src/ATen/native/BlasKernel.cpp b/aten/src/ATen/native/BlasKernel.cpp index 966beb8a08915..c40539e63cdd7 100644 --- a/aten/src/ATen/native/BlasKernel.cpp +++ b/aten/src/ATen/native/BlasKernel.cpp @@ -15,6 +15,7 @@ #if defined(__aarch64__) && !defined(C10_MOBILE) #include +#include #endif C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-function") @@ -132,30 +133,50 @@ float bf16_dot_with_fp32_arith( #endif template -bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) { +bool scal_use_fast_path( + [[maybe_unused]] int64_t n, + [[maybe_unused]] int64_t incx) { return false; } template -bool gemv_use_fast_path(C10_UNUSED char trans, C10_UNUSED int64_t m, - C10_UNUSED int64_t n, C10_UNUSED scalar_t alpha, - C10_UNUSED int64_t lda, - C10_UNUSED int64_t incx, C10_UNUSED scalar_t beta, - C10_UNUSED int64_t incy) { +bool gemv_use_fast_path( + [[maybe_unused]] char trans, + [[maybe_unused]] int64_t m, + [[maybe_unused]] int64_t n, + [[maybe_unused]] scalar_t alpha, + [[maybe_unused]] int64_t lda, + [[maybe_unused]] int64_t incx, + [[maybe_unused]] scalar_t beta, + [[maybe_unused]] int64_t incy) { return false; } template -void scal_fast_path(C10_UNUSED int *n, C10_UNUSED scalar_t *a, C10_UNUSED scalar_t *x, C10_UNUSED int *incx) { - TORCH_INTERNAL_ASSERT(false, "scal_fast_path shouldn't be called for this configuration"); +void scal_fast_path( + [[maybe_unused]] int* n, + [[maybe_unused]] scalar_t* a, + [[maybe_unused]] scalar_t* x, + [[maybe_unused]] int* incx) { + TORCH_INTERNAL_ASSERT( + false, "scal_fast_path shouldn't be called for this configuration"); } template -void gemv_fast_path(C10_UNUSED const char *trans, C10_UNUSED const int *m, C10_UNUSED const int *n, - C10_UNUSED const scalar_t *alpha, C10_UNUSED const scalar_t *a, C10_UNUSED const int *lda, - C10_UNUSED const scalar_t *x, C10_UNUSED const int *incx, C10_UNUSED const scalar_t *beta, - C10_UNUSED scalar_t *y, C10_UNUSED const int *incy) { - TORCH_INTERNAL_ASSERT(false, "gemv_fast_path shouldn't be called for this configuration"); +void gemv_fast_path( + [[maybe_unused]] const char* trans, + [[maybe_unused]] const int* m, + [[maybe_unused]] const int* n, + [[maybe_unused]] const scalar_t* alpha, + [[maybe_unused]] const scalar_t* a, + [[maybe_unused]] const int* lda, + [[maybe_unused]] const scalar_t* x, + [[maybe_unused]] const int* incx, + [[maybe_unused]] const scalar_t* beta, + [[maybe_unused]] scalar_t* y, + [[maybe_unused]] const int* incy) { + TORCH_INTERNAL_ASSERT( + false, "gemv_fast_path shouldn't be called for this configuration"); } #define INSTANTIATE(scalar_t) \ @@ -187,15 +208,32 @@ void scal_fast_path(int *n, float *a, float *x, int *incx) { } template <> -bool gemv_use_fast_path(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED float alpha, int64_t lda, int64_t incx, C10_UNUSED float beta, int64_t incy) { +bool gemv_use_fast_path( + [[maybe_unused]] char trans, + int64_t m, + int64_t n, + [[maybe_unused]] float alpha, + int64_t lda, + int64_t incx, + [[maybe_unused]] float beta, + int64_t incy) { auto intmax = std::numeric_limits::max(); return (m <= intmax) && (n <= intmax) && (lda <= intmax) && (incx > 0) && (incx <= intmax) && (incy > 0) && (incy <= intmax); } template <> -bool gemv_use_fast_path(C10_UNUSED char trans, int64_t m, int64_t n, C10_UNUSED double alpha, int64_t lda, int64_t incx, C10_UNUSED double beta, int64_t incy) { - return gemv_use_fast_path(trans, m, n, (float)alpha, lda, incx, (float)beta, incy); +bool gemv_use_fast_path( + [[maybe_unused]] char trans, + int64_t m, + int64_t n, + [[maybe_unused]] double alpha, + int64_t lda, + int64_t incx, + [[maybe_unused]] double beta, + int64_t incy) { + return gemv_use_fast_path( + trans, m, n, (float)alpha, lda, incx, (float)beta, incy); } template <> @@ -219,38 +257,40 @@ INSTANTIATE(int); INSTANTIATE(int64_t); #if defined(__aarch64__) && !defined(C10_MOBILE) template <> -bool scal_use_fast_path(C10_UNUSED int64_t n, C10_UNUSED int64_t incx) { +bool scal_use_fast_path( + [[maybe_unused]] int64_t n, + [[maybe_unused]] int64_t incx) { return false; } template <> bool gemv_use_fast_path( - C10_UNUSED char trans, - C10_UNUSED int64_t m, - C10_UNUSED int64_t n, + [[maybe_unused]] char trans, + [[maybe_unused]] int64_t m, + [[maybe_unused]] int64_t n, at::Half alpha, - C10_UNUSED int64_t lda, - C10_UNUSED int64_t incx, + [[maybe_unused]] int64_t lda, + [[maybe_unused]] int64_t incx, at::Half beta, - C10_UNUSED int64_t incy) { + [[maybe_unused]] int64_t incy) { return incx == 1 && c10::detail::fp16_from_bits(alpha.x) == 1.0f && - c10::detail::fp16_from_bits(beta.x) == 0.0f; + c10::detail::fp16_from_bits(beta.x) == 0.0f; } template <> bool gemv_use_fast_path( - C10_UNUSED char trans, - C10_UNUSED int64_t m, - C10_UNUSED int64_t n, + [[maybe_unused]] char trans, + [[maybe_unused]] int64_t m, + [[maybe_unused]] int64_t n, at::BFloat16 alpha, - C10_UNUSED int64_t lda, - C10_UNUSED int64_t incx, + [[maybe_unused]] int64_t lda, + [[maybe_unused]] int64_t incx, at::BFloat16 beta, - C10_UNUSED int64_t incy) { - return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && beta == 0.0; + [[maybe_unused]] int64_t incy) { + return (trans == 'T' || trans == 't') && incx == 1 && alpha == 1.0 && + beta == 0.0; } - #ifdef __ARM_FEATURE_FP16_SCALAR_ARITHMETIC static inline float16_t reduce(float16x4_t x) { auto sum = vpadd_f16(x, x); @@ -301,7 +341,7 @@ static constexpr auto kF16RegistersPerIterationShift = kF16ElementsPerIterationS static constexpr auto kF16RegistersPerIteration = 1 << kF16RegistersPerIterationShift; static_assert(kF16RegistersPerIteration == kF16ElementsPerIteration / kF16ElementsPerRegister); -static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) { +static inline float reduce(float16x8_t x[kF16RegistersPerIteration]) { int offset = kF16RegistersPerIteration; c10::ForcedUnroll{}([&offset, &x](auto idx) { offset /= 2; @@ -311,7 +351,7 @@ static inline double reduce(float16x8_t x[kF16RegistersPerIteration]) { }); const float32x4_t t0 = vcvt_f32_f16(vget_low_f16(x[0])); const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); - return (double)vaddvq_f32(vaddq_f32(t0, t1)); + return vaddvq_f32(vaddq_f32(t0, t1)); } static inline float16x8_t f16_fma(float16x8_t a, float16x8_t b, float16x8_t c) { @@ -333,12 +373,12 @@ static float fp16_dot_with_fp16_arith(const float16_t* x, const float16_t* a, in sum[k] = f16_fma(sum[k], temp_x, temp_a); } } - auto reducedSum = reduce(sum); + auto reduced_sum = reduce(sum); for (int j = len_aligned; j < len; ++j) { - reducedSum += x[j] * a[j]; + reduced_sum += x[j] * a[j]; } - return reducedSum; + return reduced_sum; } // Rather than unrolling to process multiple rows (transposed columns) @@ -352,7 +392,7 @@ static void fp16_gemv_trans_fp16_arith_by_dot_products(const int m, const int n, }); } -#endif +#endif // __ARM_FEATURE_FP16_SCALAR_ARITHMETIC static inline float reduce(float32x4_t x) { auto sum = vpaddq_f32(x, x); @@ -412,7 +452,7 @@ static constexpr auto kF32RegistersPerIterationShift = 3; static_assert(kF32RegistersPerIteration == kF32ElementsPerIteration / kF32ElementsPerRegister); static_assert(kF32RegistersPerIteration == 1 << kF32RegistersPerIterationShift); -static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { +static inline float reduce(float32x4_t x[kF32RegistersPerIteration]) { int offset = kF32RegistersPerIteration; c10::ForcedUnroll{}([&offset, &x](auto idx) { offset /= 2; @@ -423,7 +463,7 @@ static inline double reduce(float32x4_t x[kF32RegistersPerIteration]) { return vaddvq_f32(x[0]); } -static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop( +static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot( const float16_t* vec1, const float16_t* vec2, float32x4_t sum[kF32RegistersPerIteration], @@ -436,86 +476,217 @@ static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop( sum[2 * registerPairIndex + 1] = f32_fma_high_f16(sum[2 * registerPairIndex + 1], temp_vec1, temp_vec2); } -static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( - const float16_t* vec1, - const float16_t* vec2, - float32x4_t* tailSum, - int idx) { +static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot( + const float16_t* vec1, + const float16_t* vec2, + float32x4_t* tail_sum, + int idx) { const auto temp_vec1 = vld1_f16(&vec1[idx]); const auto temp_vec2 = vld1_f16(&vec2[idx]); - *tailSum = f32_fma_f16(*tailSum, temp_vec1, temp_vec2); + *tail_sum = f32_fma_f16(*tail_sum, temp_vec1, temp_vec2); } -static C10_ALWAYS_INLINE float32x4_t to_bfloat16(uint16x4_t u16) { +static float32x4_t to_bfloat16(uint16x4_t u16) { int32x4_t shift = vdupq_n_s32(16); return vreinterpretq_f32_u32(vshlq_u32(vmovl_u16(u16), shift)); } -static C10_ALWAYS_INLINE float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) { +static float32x4_t f32_fma_bf16(float32x4_t a, uint16x4_t b, uint16x4_t c) { return f32_fma(a, to_bfloat16(b), to_bfloat16(c)); } -static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop( - const at::BFloat16* vec1, - const at::BFloat16* vec2, - float32x4_t sum[kF32RegistersPerIteration], - int registerPairIndex) { - // TODO: detect intrinsic availability, use them if they're available. __ARM_FEATURE_BF16 - // Load a pair of f32 registers at a time. - const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast(&vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); - const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast(&vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); +#if defined(__clang__) && __clang_major__ > 15 +// https://godbolt.org/z/z8P4Yncra +#define COMPILER_SUPPORTS_BF16_TARGET 1 +#elif !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 10 +// https://gcc.gnu.org/gcc-10/changes.html +// https://godbolt.org/z/cdGG7vn8o +#define COMPILER_SUPPORTS_BF16_TARGET 1 +#else +#define COMPILER_SUPPORTS_BF16_TARGET 0 +#endif + +#if COMPILER_SUPPORTS_BF16_TARGET +#define TARGET_ARM_BF16_ATTRIBUTE __attribute__((target("arch=armv8.2-a+bf16"))) + +TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE float32x4_t +f32_dot_bf16(float32x4_t a, bfloat16x8_t b, bfloat16x8_t c) { + return vbfdotq_f32(a, b, c); +} - sum[2 * registerPairIndex] = f32_fma_bf16(sum[2 * registerPairIndex], vget_low_u16(temp_vec1), vget_low_u16(temp_vec2)); - sum[2 * registerPairIndex + 1] = f32_fma_bf16(sum[2 * registerPairIndex + 1], vget_high_u16(temp_vec1), vget_high_u16(temp_vec2)); +TARGET_ARM_BF16_ATTRIBUTE static C10_ALWAYS_INLINE void +dot_with_fp32_arith_main_inner_loop_bfdot( + const BFloat16* vec1, + const BFloat16* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + const bfloat16x8_t temp_vec1 = vld1q_bf16(reinterpret_cast( + &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const bfloat16x8_t temp_vec2 = vld1q_bf16(reinterpret_cast( + &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + sum[registerPairIndex] = + f32_dot_bf16(sum[registerPairIndex], temp_vec1, temp_vec2); } -static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop( - const at::BFloat16* vec1, - const at::BFloat16* vec2, - float32x4_t* tailSum, - int idx) { +// See NOTE [GCC code duplication] below for why we have _bfdot and +// _no_bfdot versions of +// dot_with_fp32_arith_vectorized_tail_inner_loop. +TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE +static void dot_with_fp32_arith_vectorized_tail_inner_loop_bfdot( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + float32x4_t* tail_sum, + int idx) { const auto temp_vec1 = vld1_u16(reinterpret_cast(&vec1[idx])); const auto temp_vec2 = vld1_u16(reinterpret_cast(&vec2[idx])); - *tailSum = f32_fma_bf16(*tailSum, temp_vec1, temp_vec2); + *tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2); } -template -float dot_with_fp32_arith(const T* vec1, const T* vec2, int64_t len) { +#else +#define TARGET_ARM_BF16_ATTRIBUTE +#endif // COMPILER_SUPPORTS_BF16_TARGET + +static C10_ALWAYS_INLINE void dot_with_fp32_arith_main_inner_loop_no_bfdot( + const BFloat16* vec1, + const BFloat16* vec2, + float32x4_t sum[kF32RegistersPerIteration], + int registerPairIndex) { + const uint16x8_t temp_vec1 = vld1q_u16(reinterpret_cast( + &vec1[registerPairIndex * 2 * kF32ElementsPerRegister])); + const uint16x8_t temp_vec2 = vld1q_u16(reinterpret_cast( + &vec2[registerPairIndex * 2 * kF32ElementsPerRegister])); + + sum[2 * registerPairIndex] = f32_fma_bf16( + sum[2 * registerPairIndex], + vget_low_u16(temp_vec1), + vget_low_u16(temp_vec2)); + sum[2 * registerPairIndex + 1] = f32_fma_bf16( + sum[2 * registerPairIndex + 1], + vget_high_u16(temp_vec1), + vget_high_u16(temp_vec2)); +} + +static C10_ALWAYS_INLINE void dot_with_fp32_arith_vectorized_tail_inner_loop_no_bfdot( + const at::BFloat16* vec1, + const at::BFloat16* vec2, + float32x4_t* tail_sum, + int idx) { + const auto temp_vec1 = vld1_u16(reinterpret_cast(&vec1[idx])); + const auto temp_vec2 = vld1_u16(reinterpret_cast(&vec2[idx])); + *tail_sum = f32_fma_bf16(*tail_sum, temp_vec1, temp_vec2); +} + +namespace { +#if COMPILER_SUPPORTS_BF16_TARGET +template +struct ForcedUnrollTargetBFloat16 { + template + TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const { + ForcedUnrollTargetBFloat16{}(f); + f(n - 1); + } +}; + +template <> +struct ForcedUnrollTargetBFloat16<1> { + template + TARGET_ARM_BF16_ATTRIBUTE C10_ALWAYS_INLINE void operator()(const Func& f) const { + f(0); + } +}; + +C10_ALWAYS_INLINE TARGET_ARM_BF16_ATTRIBUTE auto +dot_with_fp32_arith_main_loop_bfdot( + const BFloat16* vec1, + const BFloat16* vec2, + int64_t len) { float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { const auto* vec1_ = vec1 + j; const auto* vec2_ = vec2 + j; - c10::ForcedUnroll{}([vec1_, vec2_, &sum](auto k) { - dot_with_fp32_arith_main_inner_loop(vec1_, vec2_, sum, k); + ForcedUnrollTargetBFloat16{}([vec1_, vec2_, &sum](auto k) + C10_ALWAYS_INLINE_ATTRIBUTE TARGET_ARM_BF16_ATTRIBUTE { + dot_with_fp32_arith_main_inner_loop_bfdot(vec1_, vec2_, sum, k); }); } - auto reducedSum = reduce(sum); - - // First-tier tail fixup: make sure we handle workloads that can - // benefit from vectorization, but don't fit into our fully unrolled - // loop above. - float32x4_t tailSum = vdupq_n_f32(0); - const auto len_aligned_4 = len & ~3; - for (int j = len_aligned; j < len_aligned_4; j += 4) { - dot_with_fp32_arith_vectorized_tail_inner_loop(vec1, vec2, &tailSum, j); - } - auto reducedTail = vpaddq_f32(tailSum, tailSum); - reducedSum += vgetq_lane_f32(vpaddq_f32(reducedTail, reducedTail), 0); + return reduce(sum); +} +#endif // COMPILER_SUPPORTS_BF16_TARGET - // Second-tier tail fixup: handle all workloads. - for (int j = len_aligned_4; j < len; ++j) { - reducedSum += vec1[j] * vec2[j]; +template +C10_ALWAYS_INLINE auto +dot_with_fp32_arith_main_loop_no_bfdot( + const T* vec1, + const T* vec2, + int64_t len) { + float32x4_t sum[kF32RegistersPerIteration] = {vdupq_n_f32(0)}; + const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); + for (int j = 0; j < len_aligned ; j += kF32ElementsPerIteration) { + const auto* vec1_ = vec1 + j; + const auto* vec2_ = vec2 + j; + c10::ForcedUnroll{}([vec1_, vec2_, &sum](auto k) C10_ALWAYS_INLINE_ATTRIBUTE { + dot_with_fp32_arith_main_inner_loop_no_bfdot(vec1_, vec2_, sum, k); + }); } - return reducedSum; + return reduce(sum); +} + +// NOTE [GCC code duplication]: The first attempt at landing BFDOT support with +// TARGET_ARM_BF16_ATTRIBUTE failed because unlike clang, GCC will not +// allow inlining a non-bf16-specific function into a bf16-specific +// function. We can work around this by duplicating the code into the +// bfdot and non-bfdot callsites. The code is in this macro to avoid +// actual copy/paste. +#define DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(bfdot_suffix) \ + /* First-tier tail fixup: make sure we handle workloads that can */ \ + /* benefit from vectorization, but don't fit into our fully unrolled */ \ + /* loop above. */ \ + float32x4_t tail_sum = vdupq_n_f32(0); \ + const auto len_aligned = len & ~(kF32ElementsPerIteration - 1); \ + const auto len_aligned_4 = len & ~3; \ + for (int j = len_aligned; j < len_aligned_4; j += 4) { \ + dot_with_fp32_arith_vectorized_tail_inner_loop##bfdot_suffix(vec1, vec2, &tail_sum, j); \ + } \ + auto reduced_tail = vpaddq_f32(tail_sum, tail_sum); \ + reduced_sum += vgetq_lane_f32(vpaddq_f32(reduced_tail, reduced_tail), 0); \ + \ + /* Second-tier tail fixup: handle all workloads. */ \ + for (int j = len_aligned_4; j < len; ++j) { \ + reduced_sum += vec1[j] * vec2[j]; \ + } \ + return reduced_sum + +#if COMPILER_SUPPORTS_BF16_TARGET +TARGET_ARM_BF16_ATTRIBUTE float +dot_with_fp32_arith_bfdot(const BFloat16* vec1, const BFloat16* vec2, int64_t len) { + auto reduced_sum = dot_with_fp32_arith_main_loop_bfdot(vec1, vec2, len); + DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_bfdot); } +#endif // COMPILER_SUPPORTS_BF16_TARGET + +template +C10_ALWAYS_INLINE float +dot_with_fp32_arith_no_bfdot(const T* vec1, const T* vec2, int64_t len) { + auto reduced_sum = dot_with_fp32_arith_main_loop_no_bfdot(vec1, vec2, len); + DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY(_no_bfdot); +} +#undef DOT_WITH_FP32_ARITH_TAIL_AFTER_MAIN_LOOP_BODY +} // namespace float fp16_dot_with_fp32_arith(const float16_t* vec1, const float16_t* vec2, int64_t len) { - return dot_with_fp32_arith(vec1, vec2, len); + return dot_with_fp32_arith_no_bfdot(vec1, vec2, len); } float bf16_dot_with_fp32_arith(const at::BFloat16* vec1, const at::BFloat16* vec2, int64_t len) { - return dot_with_fp32_arith(vec1, vec2, len); +#if COMPILER_SUPPORTS_BF16_TARGET + if (cpuinfo_has_arm_bf16()) { + return dot_with_fp32_arith_bfdot(vec1, vec2, len); + } else +#endif + { + return dot_with_fp32_arith_no_bfdot(vec1, vec2, len); + } } // On my Apple M1 Macbook (which is ARM v8.5 and thus has the diff --git a/aten/src/ATen/native/ConvUtils.h b/aten/src/ATen/native/ConvUtils.h index ad05a820f7b27..f601b0a3be2ec 100644 --- a/aten/src/ATen/native/ConvUtils.h +++ b/aten/src/ATen/native/ConvUtils.h @@ -168,7 +168,7 @@ static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, co ss << arg_name << " should be greater than zero but got ("; std::copy(args.begin(), args.end() - 1, std::ostream_iterator(ss,", ")); ss << args.back() << ")" << " (while checking arguments for " << c << ")"; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } } diff --git a/aten/src/ATen/native/Convolution.cpp b/aten/src/ATen/native/Convolution.cpp index adb531a51df5c..b9354cd610a8a 100644 --- a/aten/src/ATen/native/Convolution.cpp +++ b/aten/src/ATen/native/Convolution.cpp @@ -719,7 +719,7 @@ static void check_shape_forward(const at::Tensor& input, separator = " x "; } - AT_ERROR("Calculated padded input size per channel: (", input_ss.str(), "). " + TORCH_CHECK(false, "Calculated padded input size per channel: (", input_ss.str(), "). " "Kernel size: (", kernel_ss.str(), "). Kernel size can't be greater than actual input size"); } } else { // transposed @@ -1304,7 +1304,7 @@ ConvBackend _select_conv_backend( } // Error out if no suitable backend was found. - AT_ERROR("unsupported ConvNd parameters"); + TORCH_CHECK(false, "unsupported ConvNd parameters"); } // Selects a backend for convolution based on the inputs and params. @@ -1732,8 +1732,8 @@ std::tuple _convolution_double_backward( const std::option // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned ggI_maybe_owned = at::borrow_from_optional_tensor(ggI_opt); const Tensor& ggI = *ggI_maybe_owned; - const Tensor& ggW_r = c10::value_or_else(ggW_r_opt, [] {return Tensor();}); - const Tensor& ggb = c10::value_or_else(ggb_opt, [] {return Tensor();}); + const Tensor& ggW_r = ggW_r_opt.value_or(Tensor()); + const Tensor& ggb = ggb_opt.value_or(Tensor()); auto ggW = ggW_r; diff --git a/aten/src/ATen/native/DispatchStub.cpp b/aten/src/ATen/native/DispatchStub.cpp index b57649c263259..fa43aa886b2f7 100644 --- a/aten/src/ATen/native/DispatchStub.cpp +++ b/aten/src/ATen/native/DispatchStub.cpp @@ -262,7 +262,7 @@ void* DispatchStubImpl::get_call_ptr( false, "DispatchStub: missing kernel for ", device_type); return nullptr; case ErrorType::DeviceNotSupported: - AT_ERROR("DispatchStub: unsupported device type", device_type); + TORCH_CHECK(false, "DispatchStub: unsupported device type", device_type); } } diff --git a/aten/src/ATen/native/Dropout.cpp b/aten/src/ATen/native/Dropout.cpp index 24f9d648f4f31..366a00487ff5f 100644 --- a/aten/src/ATen/native/Dropout.cpp +++ b/aten/src/ATen/native/Dropout.cpp @@ -34,7 +34,7 @@ Tensor make_feature_noise(const Tensor& input) { sizes.reserve(input.dim()); sizes.push_back(input_sizes[0]); sizes.push_back(input_sizes[1]); - for (C10_UNUSED const auto i : c10::irange(2, input.dim())) { + for ([[maybe_unused]] const auto i : c10::irange(2, input.dim())) { sizes.push_back(1); } return input.new_empty_symint(sizes); diff --git a/aten/src/ATen/native/Embedding.cpp b/aten/src/ATen/native/Embedding.cpp index b0c4644e579c2..5a148e8ddb821 100644 --- a/aten/src/ATen/native/Embedding.cpp +++ b/aten/src/ATen/native/Embedding.cpp @@ -81,7 +81,7 @@ Tensor embedding_sparse_backward( // TODO: implement scale_grad_by_freq if (scale_grad_by_freq) { - AT_ERROR( + TORCH_CHECK(false, "embedding_backward: scale_grad_by_freq not supported with sparse gradients"); } diff --git a/aten/src/ATen/native/EmbeddingBag.cpp b/aten/src/ATen/native/EmbeddingBag.cpp index fbff571fececd..58dc1b991d267 100644 --- a/aten/src/ATen/native/EmbeddingBag.cpp +++ b/aten/src/ATen/native/EmbeddingBag.cpp @@ -106,7 +106,7 @@ bool is_fast_path(const Tensor& src, const std::optional& scale, Tensor& // index_add (using add_indices as the index), without creating an intermediary // tensor to hold the selected embeddings template -static typename std::enable_if::value, void>::type +static std::enable_if_t, void> index_select_add( const Tensor& select_indices, const Tensor& add_indices, @@ -184,10 +184,9 @@ void fbgemm_spmdm_report_error_( } // namespace template -typename std::enable_if< - std::is_same::value || - std::is_same::value, - void>::type +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> index_select_add( const Tensor& select_indices, const Tensor& add_indices, @@ -366,7 +365,7 @@ index_select_add( } } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> index_select_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &src, @@ -493,7 +492,7 @@ index_select_add(const Tensor &select_indices, // mul (scaling by per_sample_weights) // index_add (using add_indices as the index) template -static typename std::enable_if::value, void>::type +static std::enable_if_t, void> index_select_scale_add( const Tensor& select_indices, const Tensor& add_indices, @@ -548,10 +547,9 @@ index_select_scale_add( } template -typename std::enable_if< - std::is_same::value || - std::is_same::value, - void>::type +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> index_select_scale_add( const Tensor& select_indices, const Tensor& add_indices, @@ -741,7 +739,7 @@ index_select_scale_add( } } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> index_select_scale_add(const Tensor &select_indices, const Tensor &add_indices, const Tensor &scale, diff --git a/aten/src/ATen/native/Fill.cpp b/aten/src/ATen/native/Fill.cpp index 7225d7de1128b..5ff1e6b61ed20 100644 --- a/aten/src/ATen/native/Fill.cpp +++ b/aten/src/ATen/native/Fill.cpp @@ -104,7 +104,7 @@ Tensor& fill_diagonal_(Tensor& self, const Scalar& fill_value, bool wrap) { int64_t dim1 = height; for (const auto i : c10::irange(1, nDims)) { if (self.size(i) != dim1) { - AT_ERROR("all dimensions of input must be of equal length"); + TORCH_CHECK(false, "all dimensions of input must be of equal length"); } } } diff --git a/aten/src/ATen/native/IndexingUtils.h b/aten/src/ATen/native/IndexingUtils.h index cef21c3fd80d5..c442b2232a967 100644 --- a/aten/src/ATen/native/IndexingUtils.h +++ b/aten/src/ATen/native/IndexingUtils.h @@ -13,9 +13,11 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask, " does not match the shape of the indexed tensor ", self.sizes(), " at index ", idx); } - -static C10_UNUSED std::vector expandTensors(const Tensor & self, IOptTensorListRef indices) { - // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors +[[maybe_unused]] static std::vector expandTensors( + const Tensor& self, + IOptTensorListRef indices) { + // If indices come in as ByteTensor or BoolTensor (masks), expand them into + // the equivalent indexing by LongTensors std::vector result; for (const auto& index_opt : indices) { if (!index_opt.has_value()) { @@ -48,7 +50,9 @@ static C10_UNUSED std::vector expandTensors(const Tensor & self, IOptTen return result; } -static C10_UNUSED void checkIndexTensorTypes(IOptTensorListRef indices, bool allow_int=false) { +[[maybe_unused]] static void checkIndexTensorTypes( + IOptTensorListRef indices, + bool allow_int = false) { for (const auto& tensor : indices) { if (tensor.has_value() && tensor->defined()) { auto scalarType = tensor->scalar_type(); @@ -83,7 +87,7 @@ inline torch::List> toListOfOptionalTensors(ArrayRef> -transposeToFront(const Tensor& self, TensorList indices) { +[[maybe_unused]] static std::tuple> transposeToFront( + const Tensor& self, + TensorList indices) { std::vector dims; std::vector transposedIndices; dims.reserve(self.dim()); diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 1016ed8606b2e..abc65ae5c6772 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -3558,7 +3559,7 @@ Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) } bool dispatched = false; - if (at::globalContext().userEnabledMkldnn()) { + if (at::globalContext().userEnabledMkldnn() && at::cpu::is_avx512_vnni_supported()) { try { mkldnn_matmul_i8i8i32(self, mat2, result); dispatched = true; diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index 48d9c31129654..a0011a9ddf55f 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -241,8 +241,9 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu auto* b_batch_idx_ptr = data[0]; auto* a_batch_idx_ptr = data[1]; - for (const auto elem C10_UNUSED : c10::irange(nelems)) { - auto b_curr_linear_batch_idx = *reinterpret_cast(b_batch_idx_ptr); + for ([[maybe_unused]] const auto elem : c10::irange(nelems)) { + auto b_curr_linear_batch_idx = + *reinterpret_cast(b_batch_idx_ptr); auto a_curr_linear_batch_idx = *reinterpret_cast(a_batch_idx_ptr); check_if_copy_needed_for_a(a_curr_linear_batch_idx); @@ -268,7 +269,7 @@ inline double _get_epsilon(const ScalarType& sc_type) { case at::ScalarType::Double: return std::numeric_limits::epsilon(); default: - AT_ERROR("This function doesn't handle types other than float and double"); + TORCH_CHECK(false, "This function doesn't handle types other than float and double"); } } diff --git a/aten/src/ATen/native/LossMultiLabelMargin.cpp b/aten/src/ATen/native/LossMultiLabelMargin.cpp index a6998175b5d09..d0c2a4adb3d38 100644 --- a/aten/src/ATen/native/LossMultiLabelMargin.cpp +++ b/aten/src/ATen/native/LossMultiLabelMargin.cpp @@ -76,7 +76,7 @@ static void multilabel_margin_loss_forward_out_frame( accscalar_t sum = 0; - for (C10_UNUSED const auto t : c10::irange(nframe)) { + for ([[maybe_unused]] const auto t : c10::irange(nframe)) { sum += multilabel_margin_loss_forward_inner_sum_cpu( input_data, target_data, is_target_data, dim); @@ -180,7 +180,7 @@ static void multilabel_margin_loss_backward_out_frame( reduction == Reduction::Mean ? 1. / (nframe * dim) : 1. / dim); scalar_t* grad_input_row_data = grad_input.mutable_data_ptr(); - for (C10_UNUSED const auto t : c10::irange(nframe)) { + for ([[maybe_unused]] const auto t : c10::irange(nframe)) { for (const auto dt : c10::irange(dim)) { int64_t target_idx = target_data[dt]; if (target_idx < 0) { diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index e86a9aea411af..e04265e44f8e5 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include #include @@ -147,7 +148,7 @@ jiterator_also_stringify_as(jiterator_code( #define CENTRAL_RANGE 0.7 template -inline typename std::enable_if::value, T>::type +inline typename std::enable_if_t, T> calc_erfinv(T y) { /* Function to calculate inverse error function. Rational approximation is used to generate an initial approximation, which is then improved to @@ -1203,22 +1204,30 @@ scalar_t calc_igamma(scalar_t a, scalar_t x) { } template <> -C10_UNUSED inline c10::BFloat16 calc_igamma(c10::BFloat16 a, c10::BFloat16 x) { +[[maybe_unused]] inline c10::BFloat16 calc_igamma( + c10::BFloat16 a, + c10::BFloat16 x) { return calc_igamma(float(a), float(x)); } template <> -C10_UNUSED inline c10::Half calc_igamma(c10::Half a, c10::Half x) { +[[maybe_unused]] inline c10::Half calc_igamma( + c10::Half a, + c10::Half x) { return calc_igamma(float(a), float(x)); } template <> -C10_UNUSED inline c10::BFloat16 calc_igammac(c10::BFloat16 a, c10::BFloat16 x) { +[[maybe_unused]] inline c10::BFloat16 calc_igammac( + c10::BFloat16 a, + c10::BFloat16 x) { return calc_igammac(float(a), float(x)); } template <> -C10_UNUSED inline c10::Half calc_igammac(c10::Half a, c10::Half x) { +[[maybe_unused]] inline c10::Half calc_igammac( + c10::Half a, + c10::Half x) { return calc_igammac(float(a), float(x)); } @@ -1230,12 +1239,12 @@ inline T abs_impl(T v) { } template <> -C10_UNUSED inline uint8_t abs_impl(uint8_t v) { +[[maybe_unused]] inline uint8_t abs_impl(uint8_t v) { return v; } template -inline typename std::enable_if::value, T>::type +inline typename std::enable_if_t, T> calc_gcd(T a, T b) { a = abs_impl(a); b = abs_impl(b); @@ -1284,7 +1293,7 @@ C10_HOST_DEVICE c10::complex exp2_impl(c10::complex x) { * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. */ template -inline typename std::enable_if::value, T>::type +inline typename std::enable_if_t, T> chbevl(const T x, const T array[], size_t len) { T b0, b1, b2; @@ -1361,7 +1370,7 @@ inline std::tuple chebyshev_coefficients_i0e_B() { }; template -inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if_t, std::tuple> chebyshev_coefficients_i1e_A() { /* Chebyshev coefficients for exp(-x) I1(x) * in the interval [0,8]. @@ -1388,7 +1397,7 @@ chebyshev_coefficients_i1e_A() { }; template -inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if_t, std::tuple> chebyshev_coefficients_i1e_A() { /* Chebyshev coefficients for exp(-x) I1(x) * in the interval [0,8]. @@ -1417,7 +1426,7 @@ chebyshev_coefficients_i1e_A() { }; template -inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if_t, std::tuple> chebyshev_coefficients_i1e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) * in the inverted interval [8,infinity]. @@ -1443,7 +1452,7 @@ chebyshev_coefficients_i1e_B() { }; template -inline typename std::enable_if::value, std::tuple>::type +inline typename std::enable_if_t, std::tuple> chebyshev_coefficients_i1e_B() { /* Chebyshev coefficients for exp(-x) sqrt(x) I1(x) * in the inverted interval [8,infinity]. @@ -1463,7 +1472,7 @@ chebyshev_coefficients_i1e_B() { }; template -inline typename std::enable_if::value, T>::type +inline typename std::enable_if_t, T> calc_i0(T _x) { T x = std::abs(_x); @@ -1480,8 +1489,9 @@ calc_i0(T _x) { return std::exp(x) * chbevl(T{32.0} / x - T{2.0}, B, len) / std::sqrt(x); } -// Upcast bfloat16 input to float for numerical accuracy purposes +// Upcast bfloat16/half input to float for numerical accuracy purposes inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } +inline c10::Half calc_i0(c10::Half a) { return calc_i0(static_cast(a)); } /* * This function is derived from the implementation of the i1 function in the Cephes Math Library. @@ -1493,7 +1503,7 @@ inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast -inline typename std::enable_if::value, T>::type +inline typename std::enable_if_t, T> calc_i1(T _x) { T x = std::abs(_x); @@ -1512,6 +1522,11 @@ calc_i1(T _x) { return (_x < T{0.0}) ? -out : out; } +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i1(c10::BFloat16 a) { return calc_i1(static_cast(a)); } +inline c10::Half calc_i1(c10::Half a) { return calc_i1(static_cast(a)); } + + /* * This function is derived from the implementation of the i1e function in the Cephes Math Library. * See note [3-Clause BSD License for the Cephes Math Library]. @@ -1522,7 +1537,7 @@ calc_i1(T _x) { * of all inputs to convert them into the domain of the approximation. */ template -inline typename std::enable_if::value, T>::type +inline typename std::enable_if_t, T> calc_i1e(T _x) { T x = std::abs(_x); @@ -1541,6 +1556,11 @@ calc_i1e(T _x) { return (_x < T{0.0}) ? -out : out; } +// Upcast bfloat16/half input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i1e(c10::BFloat16 a) { return calc_i1e(static_cast(a)); } +inline c10::Half calc_i1e(c10::Half a) { return calc_i1e(static_cast(a)); } + + /* * This function is derived from the implementation of the i1e function in the Cephes Math Library. * See note [3-Clause BSD License for the Cephes Math Library]. @@ -1737,7 +1757,7 @@ inline C10_HOST_DEVICE T calc_ndtri(T y0) { template -C10_HOST_DEVICE inline typename std::enable_if::value, T>::type +C10_HOST_DEVICE inline typename std::enable_if_t, T> erfcx_y100(T y100) { switch (static_cast(y100)) { @@ -2148,7 +2168,7 @@ return 0.97771701335885035464e0 + (0.22000938572830479551e-1 + (0.27951610702682 } template -C10_HOST_DEVICE inline typename std::enable_if::value, T>::type +C10_HOST_DEVICE inline typename std::enable_if_t, T> calc_erfcx(T x) { if (at::_isnan(x)) { @@ -3060,14 +3080,14 @@ inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, int64_t n) { return r; } // hermite_polynomial_h_forward(T x, int64_t n) -template::value, int> = 0> +template, int> = 0> inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { return hermite_polynomial_h_forward(x, static_cast(n)); } // hermite_polynomial_h_forward(T x, T n) -template::value, int> = 0> -inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { - return hermite_polynomial_h_forward(x, ((!std::isinf(n)) && (!std::isnan(n))) ? static_cast(n) : static_cast(-1)); +template, int> = 0> +__ubsan_ignore_float_cast_overflow__ inline C10_HOST_DEVICE T hermite_polynomial_h_forward(T x, T n) { + return hermite_polynomial_h_forward(x, (!std::isinf(n) && !std::isnan(n)) ? static_cast(n) : static_cast(-1)); } // hermite_polynomial_h_forward(T x, T n) template diff --git a/aten/src/ATen/native/MaxUnpooling.cpp b/aten/src/ATen/native/MaxUnpooling.cpp index f7d4355785fb4..0e9294770e32a 100644 --- a/aten/src/ATen/native/MaxUnpooling.cpp +++ b/aten/src/ATen/native/MaxUnpooling.cpp @@ -136,7 +136,7 @@ static void max_unpooling3d_shape_check( if (gradOutput.defined()) { if (oT != gradOutput.size(dimt) || oH != gradOutput.size(dimh) || oW != gradOutput.size(dimw)) { - AT_ERROR( + TORCH_CHECK(false, "Inconsistent gradOutput size. oT= ", oT, ", oH= ", diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp index 7da1ec9b19987..1fe298d9e1f1b 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose2d.cpp @@ -85,7 +85,7 @@ static inline void slow_conv_transpose2d_shape_check( check_dim_size(bias, 1, 0, weight.size(1)); } } else if (!weight_nullable) { - AT_ERROR("weight tensor is expected to be non-nullable"); + TORCH_CHECK(false, "weight tensor is expected to be non-nullable"); } int ndim = input.dim(); @@ -112,7 +112,7 @@ static inline void slow_conv_transpose2d_shape_check( (dilation_width * (kernel_width - 1) + 1) + output_padding_width; if (output_width < 1 || output_height < 1) { - AT_ERROR( + TORCH_CHECK(false, "Given input size per channel: (", input_height, " x ", diff --git a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp index 9ef236d4dab93..82a263840e015 100644 --- a/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp +++ b/aten/src/ATen/native/NaiveConvolutionTranspose3d.cpp @@ -107,7 +107,7 @@ static inline void slow_conv_transpose3d_shape_check( check_dim_size(bias, 1, 0, weight.size(1)); } } else if (!weight_nullable) { - AT_ERROR("weight tensor is expected to be non-nullable"); + TORCH_CHECK(false, "weight tensor is expected to be non-nullable"); } int ndim = input.dim(); @@ -142,7 +142,7 @@ static inline void slow_conv_transpose3d_shape_check( output_padding_width; if (output_depth < 1 || output_width < 1 || output_height < 1) { - AT_ERROR( + TORCH_CHECK(false, "Given input size per channel: (", input_depth, " x ", diff --git a/aten/src/ATen/native/Normalization.cpp b/aten/src/ATen/native/Normalization.cpp index 796eac362b124..8e50d93b0b1ef 100644 --- a/aten/src/ATen/native/Normalization.cpp +++ b/aten/src/ATen/native/Normalization.cpp @@ -549,9 +549,9 @@ std::tuple _batch_norm_impl_index( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); auto num_features = input.sym_sizes()[1]; @@ -573,12 +573,12 @@ std::tuple _batch_norm_impl_index( if (running_mean.defined()) { check_dims_match_num_input_features("running_mean", num_features, running_mean.sym_numel()); } else if (!training) { - AT_ERROR("running_mean must be defined in evaluation mode"); + TORCH_CHECK(false, "running_mean must be defined in evaluation mode"); } if (running_var.defined()) { check_dims_match_num_input_features("running_var", num_features, running_var.sym_numel()); } else if (!training) { - AT_ERROR("running_var must be defined in evaluation mode"); + TORCH_CHECK(false, "running_var must be defined in evaluation mode"); } if (weight.defined()) { check_dims_match_num_input_features("weight", num_features, weight.sym_numel()); @@ -631,10 +631,10 @@ std::tuple _batch_norm_impl_index_backward( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); - const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();}); - const Tensor& save_var_transform = c10::value_or_else(save_var_transform_opt, [] {return Tensor();}); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); + const Tensor& save_mean = save_mean_opt.value_or(Tensor()); + const Tensor& save_var_transform = save_var_transform_opt.value_or(Tensor()); if (input.numel() == 0) { std::vector dims(input.dim() - 1); @@ -675,10 +675,10 @@ Tensor batch_norm( const Tensor& input, const std::optional& weight_opt, const std::optional& bias_opt, const std::optional& running_mean_opt, const std::optional& running_var_opt, bool training, double momentum, double eps, bool cudnn_enabled) { - const Tensor& weight = c10::value_or_else(weight_opt, [] {return Tensor();}); - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& weight = weight_opt.value_or(Tensor()); + const Tensor& bias = bias_opt.value_or(Tensor()); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); return std::get<0>(at::_batch_norm_impl_index(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled)); // TODO: switch to the new stack after the 2 week FC window @@ -713,9 +713,9 @@ Tensor instance_norm( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); TORCH_CHECK(use_input_stats || (running_mean.defined() && running_var.defined()), "Expected running_mean and running_var to be defined when use_input_stats is false"); @@ -750,7 +750,7 @@ std::tuple batch_norm_update_stats_cpu( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt); const Tensor& running_mean = *running_mean_maybe_owned; - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& running_var = running_var_opt.value_or(Tensor()); const bool mixed_type = is_mixed_type(self, running_mean, running_var); return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_update_stats_cpu", [&] { @@ -769,9 +769,9 @@ std::tuple batch_norm_cpu_out(const Tensor& self, con // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); checkBackend("batch_norm_cpu_out", {self, weight, bias, running_mean, running_var}, Backend::CPU); // Resize out @@ -812,9 +812,9 @@ std::tuple batch_norm_cpu(const Tensor& self, const std: // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); checkBackend("batch_norm_cpu", {self, weight, bias, running_mean, running_var}, Backend::CPU); @@ -879,8 +879,8 @@ std::tuple _batch_norm_no_update( const Tensor& input, const std::optional& weight_opt, const std::optional& bias_opt, const std::optional& running_mean_opt, const std::optional& running_var_opt, double momentum, double eps) { - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); auto [output, save_mean, save_var] = batch_norm_cpu(input, weight_opt, bias_opt, const_cast(running_mean), const_cast(running_var), /*update*/false, momentum, eps); Tensor reserve = at::empty({0}, input.options().dtype(kByte)); @@ -927,10 +927,10 @@ std::tuple batch_norm_backward_cpu(const Tensor& grad_ou // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); - const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();}); - const Tensor& save_invstd = c10::value_or_else(save_invstd_opt, [] {return Tensor();}); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); + const Tensor& save_mean = save_mean_opt.value_or(Tensor()); + const Tensor& save_invstd = save_invstd_opt.value_or(Tensor()); const bool mixed_type = is_mixed_type(self, weight, running_mean, running_var, save_mean, save_invstd); return AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "batch_norm_backward_cpu", [&] { diff --git a/aten/src/ATen/native/Onehot.cpp b/aten/src/ATen/native/Onehot.cpp index fcbe7fd1ddc10..2ac513bf08880 100644 --- a/aten/src/ATen/native/Onehot.cpp +++ b/aten/src/ATen/native/Onehot.cpp @@ -34,7 +34,7 @@ Tensor one_hot(const Tensor &self, int64_t num_classes) { // but shape inference is not possible. if (self.numel() == 0) { if (num_classes <= 0) { - AT_ERROR("Can not infer total number of classes from empty tensor."); + TORCH_CHECK(false, "Can not infer total number of classes from empty tensor."); } else { shape.push_back(num_classes); return at::empty(shape, self.options()); diff --git a/aten/src/ATen/native/PackedSequence.cpp b/aten/src/ATen/native/PackedSequence.cpp index 85e24d2275a62..568f7dc1e31ee 100644 --- a/aten/src/ATen/native/PackedSequence.cpp +++ b/aten/src/ATen/native/PackedSequence.cpp @@ -51,7 +51,7 @@ std::tuple _pack_padded_sequence(const Tensor& _input, const Ten // NB: enforce_sorted is implemented at a Python level, but the sortedness // check lives here. If enforce_sorted=False then this error should never // get called. - AT_ERROR("`lengths` array must be sorted in decreasing order when " + TORCH_CHECK(false, "`lengths` array must be sorted in decreasing order when " "`enforce_sorted` is True. You can pass `enforce_sorted=False` " "to pack_padded_sequence and/or pack_sequence to sidestep this " "requirement if you do not need ONNX exportability."); @@ -188,7 +188,7 @@ std::tuple _pad_packed_sequence(const Tensor& data, const Tensor } int64_t dec = prev_batch_size - batch_size; if (dec > 0) { - for (C10_UNUSED const auto j : c10::irange(dec)) { + for ([[maybe_unused]] const auto j : c10::irange(dec)) { (*lengths--) = i; } } diff --git a/aten/src/ATen/native/Pow.h b/aten/src/ATen/native/Pow.h index 76ddda846a59a..f0e03b13f8f23 100644 --- a/aten/src/ATen/native/Pow.h +++ b/aten/src/ATen/native/Pow.h @@ -23,7 +23,7 @@ namespace native { // e.g. since 2**-1==0.5, the truncated integral result is zero. 1**negative_exponent is the // only non-zero result. template ::value, T>::type* = nullptr> + std::enable_if_t, T>* = nullptr> inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { T result = 1; while (b) { @@ -37,13 +37,13 @@ inline HOST_DEVICE __ubsan_ignore_signed_int_overflow__ T powi_impl(T a, T b) { } template ::value && !std::is_signed::value, T>::type* = nullptr> + std::enable_if_t && !std::is_signed_v, T>* = nullptr> inline HOST_DEVICE T powi(T a, T b) { return powi_impl(a, b); } template ::value && std::is_signed::value, T>::type* = nullptr> + std::enable_if_t && std::is_signed_v, T>* = nullptr> inline HOST_DEVICE T powi(T a, T b) { if ( b < 0 ) { if ( a == 1 ) { diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp index 9db7b4cb7da09..00e3739539835 100644 --- a/aten/src/ATen/native/RNN.cpp +++ b/aten/src/ATen/native/RNN.cpp @@ -179,7 +179,7 @@ struct CellParams : public CellParamsBase { const Tensor& _b_ih, const Tensor& _b_hh, const Tensor& _w_hr) - : w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr) {}; + : w_ih(_w_ih), w_hh(_w_hh), b_ih_(_b_ih), b_hh_(_b_hh), w_hr(_w_hr) {} const Tensor& w_ih; const Tensor& w_hh; @@ -825,7 +825,7 @@ struct FullLayer : Layer { using unstacked_output_type = LayerOutput, hidden_type>; FullLayer(Cell& cell) - : cell_(cell) {}; + : cell_(cell) {} unstacked_output_type operator()( const std::vector& step_inputs, @@ -870,7 +870,7 @@ struct FullBidirectionalLayer using output_type = typename Layer::output_type; FullBidirectionalLayer(Cell& cell) - : layer_(cell) {}; + : layer_(cell) {} output_type operator()( const Tensor& input, @@ -922,7 +922,7 @@ struct PackedLayer : Layer { typename Layer::output_type; PackedLayer(Cell& cell) - : cell_(cell) {}; + : cell_(cell) {} output_type operator()( const PackedSequence& input, @@ -983,7 +983,7 @@ struct ReversedPackedLayer : Layer { typename Layer::output_type; ReversedPackedLayer(Cell& cell) - : cell_(cell) {}; + : cell_(cell) {} output_type operator()( const PackedSequence& input, @@ -1040,7 +1040,7 @@ struct PackedBidirectionalLayer typename Layer::output_type; PackedBidirectionalLayer(Cell& cell) - : layer_(cell), rev_layer_(cell) {}; + : layer_(cell), rev_layer_(cell) {} output_type operator()( const PackedSequence& input, @@ -1529,7 +1529,7 @@ std::tuple lstm_cell( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt); const Tensor& b_ih = *b_ih_maybe_owned; - const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();}); + const Tensor& b_hh = b_hh_opt.value_or(Tensor()); TORCH_CHECK(hx.size() == 2, "lstm_cell expects two hidden states"); check_rnn_cell_forward_input(input, w_ih.sym_size(1)); @@ -1549,9 +1549,9 @@ _thnn_differentiable_lstm_cell_backward( const std::optional& grad_hy_op // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned grad_hy_maybe_owned = at::borrow_from_optional_tensor(grad_hy_opt); const Tensor& grad_hy = *grad_hy_maybe_owned; - const Tensor& grad_cy = c10::value_or_else(grad_cy_opt, [] {return Tensor();}); - const Tensor& input_bias = c10::value_or_else(input_bias_opt, [] {return Tensor();}); - const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();}); + const Tensor& grad_cy = grad_cy_opt.value_or(Tensor()); + const Tensor& input_bias = input_bias_opt.value_or(Tensor()); + const Tensor& hidden_bias = hidden_bias_opt.value_or(Tensor()); if (!grad_hy.defined() && !grad_cy.defined()) { return std::tuple(); @@ -1603,7 +1603,7 @@ std::tuple _thnn_differentiable_gru_cell // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt); const Tensor& input_bias = *input_bias_maybe_owned; - const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();}); + const Tensor& hidden_bias = hidden_bias_opt.value_or(Tensor()); Tensor in_g = input_gates; Tensor h_g = hidden_gates; @@ -1643,7 +1643,7 @@ Tensor gru_cell( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt); const Tensor& b_ih = *b_ih_maybe_owned; - const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();}); + const Tensor& b_hh = b_hh_opt.value_or(Tensor()); check_rnn_cell_forward_input(input, w_ih.size(1)); check_rnn_cell_forward_hidden(input, hx, w_hh.size(1), 0); @@ -1657,7 +1657,7 @@ Tensor rnn_tanh_cell( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt); const Tensor& b_ih = *b_ih_maybe_owned; - const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();}); + const Tensor& b_hh = b_hh_opt.value_or(Tensor()); static at::Tensor undefined; check_rnn_cell_forward_input(input, w_ih.size(1)); @@ -1671,7 +1671,7 @@ Tensor rnn_relu_cell( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned b_ih_maybe_owned = at::borrow_from_optional_tensor(b_ih_opt); const Tensor& b_ih = *b_ih_maybe_owned; - const Tensor& b_hh = c10::value_or_else(b_hh_opt, [] {return Tensor();}); + const Tensor& b_hh = b_hh_opt.value_or(Tensor()); static at::Tensor undefined; check_rnn_cell_forward_input(input, w_ih.size(1)); @@ -1889,7 +1889,8 @@ static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple namespace { -static C10_UNUSED auto ensure_linear_params_registered = register_linear_params(); +[[maybe_unused]] static auto ensure_linear_params_registered = + register_linear_params(); static auto cell_params_base_registry = torch::selective_class_("rnn", TORCH_SELECTIVE_CLASS("CellParamsBase")) diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index e9201df2f3365..9dcb7cd3570a6 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -753,11 +753,11 @@ Tensor cumprod_backward(const Tensor& grad, const Tensor& input, int64_t dim, co namespace { #ifdef _MSC_VER template -inline typename std::enable_if::value, bool>::type isnan_(T x) { +inline std::enable_if_t, bool> isnan_(T x) { return false; } template -inline typename std::enable_if::value, bool>::type isnan_(T x) { +inline std::enable_if_t, bool> isnan_(T x) { return std::isnan(x); } #else @@ -931,7 +931,7 @@ static inline Tensor diff_helper(const Tensor& self, int64_t n, int64_t dim) { bool is_kBool = (self.dtype() == at::kBool); n = n > self.sym_size(dim) ? self.sym_size(dim).guard_int(__FILE__, __LINE__) : n; - for (C10_UNUSED const auto i : c10::irange(n)) { + for ([[maybe_unused]] const auto i : c10::irange(n)) { if (is_kBool) { result = at::logical_xor( at::narrow_symint(result, dim, 1, out_len), @@ -2255,7 +2255,7 @@ bool cpu_equal(const Tensor& self, const Tensor& other) { return; } char* self_data = data[0]; - for (C10_UNUSED const auto i : c10::irange(dim_size)) { + for ([[maybe_unused]] const auto i : c10::irange(dim_size)) { if (isnan_(c10::load(self_data))) { result = false; return; @@ -2282,7 +2282,7 @@ bool cpu_equal(const Tensor& self, const Tensor& other) { } char* self_data = data[0]; char* other_data = data[1]; - for (C10_UNUSED const auto i : c10::irange(dim_size)) { + for ([[maybe_unused]] const auto i : c10::irange(dim_size)) { if (c10::load(self_data) != c10::load(other_data)) { result = false; return; diff --git a/aten/src/ATen/native/ReduceOpsUtils.h b/aten/src/ATen/native/ReduceOpsUtils.h index 928853ed44ca5..fa8de9c10a967 100644 --- a/aten/src/ATen/native/ReduceOpsUtils.h +++ b/aten/src/ATen/native/ReduceOpsUtils.h @@ -207,9 +207,13 @@ inline TensorIterator make_reduction( return TensorIterator::reduce_op(viewed_result, self.to(in_dtype)); } -inline C10_UNUSED TensorIterator make_reduction( - const char* name, Tensor& result, const Tensor& self, - at::OptionalIntArrayRef dim, bool keepdim, ScalarType out_dtype) { +[[maybe_unused]] inline TensorIterator make_reduction( + const char* name, + Tensor& result, + const Tensor& self, + at::OptionalIntArrayRef dim, + bool keepdim, + ScalarType out_dtype) { // special case for type promotion in mixed precision, improves computational // efficiency. // not generalize this to common mismatched input/output types to avoid cross @@ -259,9 +263,14 @@ inline TensorIterator make_reduction( return TensorIterator::reduce_op(viewed_result1, viewed_result2, self.to(dtype1)); } -inline C10_UNUSED TensorIterator make_reduction( - const char* name, Tensor& result1, Tensor& result2, const Tensor& self, - at::OptionalIntArrayRef dim, bool keepdim, ScalarType dtype) { +[[maybe_unused]] inline TensorIterator make_reduction( + const char* name, + Tensor& result1, + Tensor& result2, + const Tensor& self, + at::OptionalIntArrayRef dim, + bool keepdim, + ScalarType dtype) { return make_reduction(name, result1, result2, self, dim, keepdim, dtype, dtype); } @@ -313,9 +322,13 @@ inline std::vector get_zero_numel_tensor_size( // This function should be called when you are reducing a zero-numel tensor and want to // resize the output and return it. This function exists for resizing zero-numel // tensors when the size of the reduction dimension is non-zero. -inline C10_UNUSED void zero_numel_tensor_resize(Tensor& result, Tensor& result_indices, - const Tensor& self, const int64_t dim, - const bool keepdim, const char *fn_name) { +[[maybe_unused]] inline void zero_numel_tensor_resize( + Tensor& result, + Tensor& result_indices, + const Tensor& self, + const int64_t dim, + const bool keepdim, + const char* fn_name) { auto sizes = get_zero_numel_tensor_size(self, dim, keepdim, fn_name); at::native::resize_output(result, sizes); at::native::resize_output(result_indices, sizes); @@ -349,11 +362,11 @@ inline ScalarType get_dtype_from_result(Tensor& result, std::optional 2 || self.dim() < 1) { std::ostringstream ss; REPR(ss) << ": expected a 1D or 2D tensor"; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } Tensor input = self; if (self.dim() == 1) { @@ -911,24 +911,24 @@ Tensor stft(const Tensor& self, const int64_t n_fft, const std::optional 0, but got hop_length=" << hop_length; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } if (win_length <= 0 || win_length > n_fft) { std::ostringstream ss; REPR(ss) << ": expected 0 < win_length <= n_fft, but got win_length=" << win_length; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } if (window.defined() && (window.dim() != 1 || window.size(0) != win_length)) { std::ostringstream ss; REPR(ss) << ": expected a 1D window tensor of size equal to win_length=" << win_length << ", but got window with size " << window.sizes(); - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } #undef REPR auto window_ = window; @@ -1063,17 +1063,17 @@ Tensor istft(const Tensor& self, const int64_t n_fft, const std::optional= 0"); + TORCH_CHECK(false, "minlength should be >= 0"); } if (self.dim() == 1 && self.numel() == 0) { return at::zeros({minlength}, kLong); } if (self.dim() != 1 || *self.min().data_ptr() < 0) { - AT_ERROR("bincount only supports 1-d non-negative integral inputs."); + TORCH_CHECK(false, "bincount only supports 1-d non-negative integral inputs."); } // Ensure max_val < 2 ^ 63 - 1 (9223372036854775807) auto max_val = *self.max().data_ptr(); if (max_val >= std::numeric_limits::max()) { - AT_ERROR( + TORCH_CHECK(false, "maximum value of input overflowed, it should be < ", std::numeric_limits::max(), " but got ", @@ -48,7 +48,7 @@ Tensor _bincount_cpu_template( bool has_weights = weights.defined(); if (has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))) { - AT_ERROR("weights should be 1-d and have the same length as input"); + TORCH_CHECK(false, "weights should be 1-d and have the same length as input"); } Tensor output; diff --git a/aten/src/ATen/native/TensorAdvancedIndexing.cpp b/aten/src/ATen/native/TensorAdvancedIndexing.cpp index 776ac9d073974..b2f7d78652552 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexing.cpp +++ b/aten/src/ATen/native/TensorAdvancedIndexing.cpp @@ -662,7 +662,7 @@ Tensor _unsafe_masked_index(const Tensor& self, const Tensor& mask, const torch: // with the main difference being that the when the `mask` is false, the tensor // `self` is not indexed using `indices`. This allows `indices` to be out-of-bounds // when `mask` is false. When `mask` is true, the `indices` are expected to be - // in bounds and is not checked. + // in bounds and is not checked. We also assume that the `indices` are non-negative // // This function is not meant to be executed on eager mode. An unoptimized version // is provided here. @@ -875,12 +875,8 @@ TORCH_IMPL_FUNC(index_copy_out) // See Note [Enabling Deterministic Operations] if (result.is_cuda() && globalContext().deterministicAlgorithms()){ torch::List> indices; - indices.reserve(dim + 1); - for (const auto i: c10::irange(dim)) { - (void)i; - indices.emplace_back(); - } - indices.emplace_back(index); + indices.resize(dim + 1); + indices.set(dim, index); result.index_put_(indices, source, false); return; } @@ -2413,7 +2409,7 @@ Tensor& nonzero_out_cpu(const Tensor& self, Tensor& result) { for (const auto i : c10::irange(n2)) { const char* ptr = data[0] + i * strides[1]; - for (C10_UNUSED const auto j : c10::irange(n1)) { + for ([[maybe_unused]] const auto j : c10::irange(n1)) { const auto& val = c10::load(ptr); // If nonzero, write index if (val != scalar_t(0)) { diff --git a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h index f9b616013ddb2..c6968521ae355 100644 --- a/aten/src/ATen/native/TensorAdvancedIndexingUtils.h +++ b/aten/src/ATen/native/TensorAdvancedIndexingUtils.h @@ -50,7 +50,8 @@ const Tensor& value){ } } } - for (C10_UNUSED const auto i : c10::irange(num_ind, self.ndimension())) { + for ([[maybe_unused]] const auto i : + c10::irange(num_ind, self.ndimension())) { mask = mask.unsqueeze(-1); } return std::make_tuple(true, mask); diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index c82e429621812..841194719c80f 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -584,8 +584,8 @@ std::tuple mode(const Tensor& self, int64_t dim, bool keepdim) { std::tuple mode_out(const Tensor& self, int64_t dim, bool keepdim, Tensor& values, Tensor& indices) { - TORCH_CHECK(self.device().is_cpu() || self.is_cuda(), - "mode only supports CPU AND CUDA device type, got: ", self.device().type()); + TORCH_CHECK(self.device().is_cpu() || self.is_cuda() || self.is_xpu(), + "mode only supports CPU, CUDA and XPU device type, got: ", self.device().type()); TORCH_CHECK(self.layout() == Layout::Strided, "mode only supports strided layout, got: ", self.layout()); TORCH_CHECK(self.device() == values.device(), diff --git a/aten/src/ATen/native/TensorConversions.cpp b/aten/src/ATen/native/TensorConversions.cpp index 22a576408bfbb..0c2ba79493ffa 100644 --- a/aten/src/ATen/native/TensorConversions.cpp +++ b/aten/src/ATen/native/TensorConversions.cpp @@ -588,7 +588,7 @@ Tensor to_dense_backward(const Tensor& grad, const Tensor& input_, std::optional case kMkldnn: return grad.to_mkldnn(input_.scalar_type()); default: - AT_ERROR("to_dense_backward: Unsupported input layout: ", input_layout); + TORCH_CHECK(false, "to_dense_backward: Unsupported input layout: ", input_layout); return Tensor{}; } } @@ -928,23 +928,23 @@ void _to_sparse_check_arguments(const std::string& funcname, const Tensor& self, auto layout_from_valid = layout_from == kStrided || layout_from == kSparse || at::sparse_csr::is_sparse_compressed(layout_from); if (!layout_from_valid) { - AT_ERROR(funcname, ": unexpected source layout ", layout_from); + TORCH_CHECK(false, funcname, ": unexpected source layout ", layout_from); } if (layout_from == kStrided) { if (sparse_dim == 0 && self.dim() > 0) { - AT_ERROR(funcname, ": sparse_dim argument must be in >0 when self.dim()>0"); + TORCH_CHECK(false, funcname, ": sparse_dim argument must be in >0 when self.dim()>0"); } if (sparse_dim < 0 || sparse_dim > self.dim()) { - AT_ERROR(funcname, ": sparse_dim argument must be in [0,", self.dim(), "] range, but ", sparse_dim, " is given"); + TORCH_CHECK(false, funcname, ": sparse_dim argument must be in [0,", self.dim(), "] range, but ", sparse_dim, " is given"); } } else if (layout_from == kSparse) { if (sparse_dim != self.sparse_dim()) { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=self.sparse_dim() is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=self.sparse_dim() is not supported"); } } else if (at::sparse_csr::is_sparse_compressed(layout_from)) { if (sparse_dim != 2) { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=2 is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", kSparse, " with sparse_dim argument !=2 is not supported"); } } } @@ -956,40 +956,40 @@ void _to_sparse_check_arguments(const std::string& funcname, const Tensor& self, auto layout_from_valid = layout_from == kStrided || layout_from == kSparse || at::sparse_csr::is_sparse_compressed(layout_from); if (!layout_from_valid) { - AT_ERROR(funcname, ": unexpected source layout ", layout_from); + TORCH_CHECK(false, funcname, ": unexpected source layout ", layout_from); } auto layout_to_valid = layout_to == kStrided || layout_to == kSparse || at::sparse_csr::is_sparse_compressed(layout_to); if (!layout_to_valid) { - AT_ERROR(funcname, ": unexpected source layout ", layout_from); + TORCH_CHECK(false, funcname, ": unexpected source layout ", layout_from); } if (layout_from == kSparse && layout_to != kSparse) { if (self.sparse_dim() != 2) { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " for input tensors with sparse_dim()!=2 is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", layout_to, " for input tensors with sparse_dim()!=2 is not supported"); } } if ((layout_from == kSparseCsr || layout_from == kSparseCsc) && (layout_to == kSparseBsr || layout_to == kSparseBsc)) { if (sparse_csr::numBatchDimensions(self) > 0) { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " for batched inputs is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", layout_to, " for batched inputs is not supported"); } } if (blocksize.has_value()) { if (blocksize.value().size() != 2) { - AT_ERROR(funcname, ": blocksize needs to be a tuple of size 2, but got ", blocksize.value().size()); + TORCH_CHECK(false, funcname, ": blocksize needs to be a tuple of size 2, but got ", blocksize.value().size()); } auto blocksize_to = *blocksize; if (blocksize_to[0] <= 0 || blocksize_to[1] <= 0) { - AT_ERROR(funcname, ": blocksize needs to be positive, but got ", blocksize_to); + TORCH_CHECK(false, funcname, ": blocksize needs to be positive, but got ", blocksize_to); } if (layout_to == kSparseBsr || layout_to == kSparseBsc) { if (layout_from == kSparseBsr || layout_from == kSparseBsc) { auto blocksize_from = at::sparse_csr::getBlockSize(self); if (!(blocksize_to == blocksize_from)) { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize changed from ", blocksize_from, " to ", blocksize_to, " is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize changed from ", blocksize_from, " to ", blocksize_to, " is not supported"); } } else { auto dense_dim = (layout_from == kStrided) ? dense_dim_opt.value_or(0) : self.dense_dim(); @@ -997,35 +997,35 @@ void _to_sparse_check_arguments(const std::string& funcname, const Tensor& self, auto sparse_col_dim = -(dense_dim + 1); if ((self.size(sparse_row_dim) % blocksize_to[0] != 0) || (self.size(sparse_col_dim) % blocksize_to[1] != 0)) { - AT_ERROR(funcname, ": tensor sparse size (", self.size(sparse_row_dim), ",", self.size(sparse_row_dim), ") must be divisible by given blocksize (", blocksize_to[0], ",", blocksize_to[1], ")"); + TORCH_CHECK(false, funcname, ": tensor sparse size (", self.size(sparse_row_dim), ",", self.size(sparse_row_dim), ") must be divisible by given blocksize (", blocksize_to[0], ",", blocksize_to[1], ")"); } } } else { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize argument given is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", layout_to, " with blocksize argument given is not supported"); } } else { if ((layout_to == kSparseBsr || layout_to == kSparseBsc) && !(layout_from == kSparseBsr && layout_from == kSparseBsc)) { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " without blocksize argument given is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", layout_to, " without blocksize argument given is not supported"); } } if (dense_dim_opt.has_value()) { if (layout_from != kStrided) { - AT_ERROR(funcname, ": conversion from ", layout_from, " to ", layout_to, " with dense_dim argument given is not supported"); + TORCH_CHECK(false, funcname, ": conversion from ", layout_from, " to ", layout_to, " with dense_dim argument given is not supported"); } auto dense_dim = *dense_dim_opt; if (layout_to == kSparse) { if (dense_dim == self.dim() && self.dim() > 0) { - AT_ERROR(funcname, ": dense_dim argument must be !=self.dim() when self.dim()>0"); + TORCH_CHECK(false, funcname, ": dense_dim argument must be !=self.dim() when self.dim()>0"); } if (dense_dim < 0 || dense_dim > self.dim()) { - AT_ERROR(funcname, ": dense_dim argument must be in [0,", self.dim(), "] range, but ", dense_dim, " is given"); + TORCH_CHECK(false, funcname, ": dense_dim argument must be in [0,", self.dim(), "] range, but ", dense_dim, " is given"); } } else { if (dense_dim < 0 || dense_dim > self.dim() - 2) { - AT_ERROR(funcname, ": dense_dim argument must be in [0,", self.dim() - 2, "] range, but ", dense_dim, " is given"); + TORCH_CHECK(false, funcname, ": dense_dim argument must be in [0,", self.dim() - 2, "] range, but ", dense_dim, " is given"); } } } @@ -1129,7 +1129,7 @@ Tensor dense_to_sparse_with_mask(const Tensor& self, const Tensor& mask, std::op break; } - AT_ERROR("dense_to_sparse_with_mask: ", self.layout(), " to ", layout_to, " conversion not supported"); + TORCH_CHECK(false, "dense_to_sparse_with_mask: ", self.layout(), " to ", layout_to, " conversion not supported"); return Tensor{}; } @@ -1181,7 +1181,7 @@ Tensor dense_to_sparse(const Tensor& self, std::optional layout, Op break; } - AT_ERROR("dense_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported"); + TORCH_CHECK(false, "dense_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported"); return Tensor{}; } @@ -1440,7 +1440,7 @@ Tensor sparse_compressed_to_sparse_csr(const Tensor& self, std::optional layou break; } - AT_ERROR("sparse_coo_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported"); + TORCH_CHECK(false, "sparse_coo_to_sparse: ", self.layout(), " to ", layout_to, " conversion not supported"); return Tensor{}; } diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h index 32d0a1dc53561..80c81e3aeb8fc 100644 --- a/aten/src/ATen/native/TensorFactories.h +++ b/aten/src/ATen/native/TensorFactories.h @@ -106,7 +106,7 @@ inline Tensor& fill_empty_deterministic_(Tensor& tensor) { AT_DISPATCH_V2( tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() { tensor.fill_(std::numeric_limits::quiet_NaN()); - }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf); + }), AT_EXPAND(AT_FLOATING_TYPES), AT_EXPAND(AT_COMPLEX_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), kBFloat16, kHalf, kComplexHalf); } else { AT_DISPATCH_V2( tensor.scalar_type(), "fill_empty_deterministic_", AT_WRAP([&]() { diff --git a/aten/src/ATen/native/TensorProperties.cpp b/aten/src/ATen/native/TensorProperties.cpp index 95c88f4572cbd..a7f5352aae890 100644 --- a/aten/src/ATen/native/TensorProperties.cpp +++ b/aten/src/ATen/native/TensorProperties.cpp @@ -101,7 +101,7 @@ bool cudnn_is_acceptable(const Tensor& self) { Tensor & detach_(Tensor & self) { // this just exists to give us a hook in VariableType and an entry in Declarations.yaml - //AT_ERROR("detach_ is not implemented for Tensor"); + //TORCH_CHECK(false, "detach_ is not implemented for Tensor"); return self; } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 59598108f2cf1..5d8af650ccf2d 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -216,15 +217,6 @@ #include namespace at::meta { -inline void cat_check_no_zero_dim(const MaterializedITensorListRef& tensors) { - size_t i = 0; - for (const Tensor& t : tensors) { - TORCH_CHECK( - t.dim() > 0, - "zero-dimensional tensor (at position ", i, ") cannot be concatenated"); - i++; - } -} inline c10::MemoryFormat cat_compute_output_memory_format(const MaterializedITensorListRef& inputs) { std::optional format = std::nullopt; @@ -248,7 +240,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) { // size (i.e. other empty sizes are not skipped). auto materialized = tensors.materialize(); - cat_check_no_zero_dim(materialized); + native::check_cat_no_zero_dim(materialized); dim = at::legacy_cat_wrap_dim(dim, materialized); // Checking names before the actual dimensions. @@ -1954,7 +1946,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in at::parallel_for(0, index_len, at::internal::GRAIN_SIZE, [&](int64_t start, int64_t end) { const auto* src = ptr_index + start; auto* dst = ptr_nneg_index + start; - for (C10_UNUSED const auto _ : c10::irange(start, end)) { + for ([[maybe_unused]] const auto _ : c10::irange(start, end)) { auto idx = *src++; if (idx < -size || idx >= size) { // Mark self and dim as used if code is compiled with STRIP_ERROR_MESSAGES @@ -2060,36 +2052,42 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in const auto* ptr_sorted_start = ptr_sorted; const auto* ptr_sorted_end = ptr_sorted + sorted_len; - at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) { - const auto start = tid * chunk_size_src; - const auto end = std::min(start + chunk_size_src, src_len); - auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).data_ptr(); - auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).data_ptr(); - auto* ptr_tid_int_counts = int_counts.select(0, tid).data_ptr(); - const auto* ptr_src = src.const_data_ptr() + start; - - for (const auto i : c10::irange(start, end)) { - const auto src_val = *ptr_src++; - const auto src_val_lb = std::lower_bound(ptr_sorted_start, ptr_sorted_end, src_val); - // We cannot just use *src_val_lb != src_val because when - // src_val_lb == ptr_sorted_end, dereferencing past-the-end value - // is not well-defined. - if (src_val_lb == ptr_sorted_end || *src_val_lb != src_val) { - ++ptr_tid_src_int_idx; - ++ptr_tid_sorted_int_idx; - ++ptr_tid_int_counts; - continue; + at::parallel_for( + 0, n_threads_src, 1, [&](int64_t tid, [[maybe_unused]] int64_t _) { + const auto start = tid * chunk_size_src; + const auto end = std::min(start + chunk_size_src, src_len); + auto* ptr_tid_src_int_idx = + src_int_idx.select(0, tid).data_ptr(); + auto* ptr_tid_sorted_int_idx = + sorted_int_idx.select(0, tid).data_ptr(); + auto* ptr_tid_int_counts = + int_counts.select(0, tid).data_ptr(); + const auto* ptr_src = src.const_data_ptr() + start; + + for (const auto i : c10::irange(start, end)) { + const auto src_val = *ptr_src++; + const auto src_val_lb = + std::lower_bound(ptr_sorted_start, ptr_sorted_end, src_val); + // We cannot just use *src_val_lb != src_val because when + // src_val_lb == ptr_sorted_end, dereferencing past-the-end + // value is not well-defined. + if (src_val_lb == ptr_sorted_end || *src_val_lb != src_val) { + ++ptr_tid_src_int_idx; + ++ptr_tid_sorted_int_idx; + ++ptr_tid_int_counts; + continue; + } + const auto src_val_ub = + std::upper_bound(ptr_sorted_start, ptr_sorted_end, src_val); + + const int64_t count = src_val_ub - src_val_lb; + const int64_t j = src_val_lb - ptr_sorted_start; + + *ptr_tid_src_int_idx++ = i; + *ptr_tid_sorted_int_idx++ = j; + *ptr_tid_int_counts++ = count; } - const auto src_val_ub = std::upper_bound(ptr_sorted_start, ptr_sorted_end, src_val); - - const int64_t count = src_val_ub - src_val_lb; - const int64_t j = src_val_lb - ptr_sorted_start; - - *ptr_tid_src_int_idx++ = i; - *ptr_tid_sorted_int_idx++ = j; - *ptr_tid_int_counts++ = count; - } - }); + }); } const auto compressed_int_counts = int_counts.sum(-1); @@ -2120,29 +2118,35 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in const auto thread_offsets = compressed_int_counts.cumsum(0).sub_(compressed_int_counts); const auto* ptr_sorted_idx = sorted_idx.const_data_ptr(); - at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) { - const auto start = tid * chunk_size_src; - const auto end = std::min(start + chunk_size_src, src_len); - const auto tid_offset = thread_offsets.const_data_ptr()[tid]; - const auto* ptr_tid_src_int_idx = src_int_idx.select(0, tid).const_data_ptr(); - const auto* ptr_tid_sorted_int_idx = sorted_int_idx.select(0, tid).const_data_ptr(); - const auto* ptr_tid_int_counts = int_counts.select(0, tid).const_data_ptr(); - auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset; - auto* ptr_tid_selected_src = ptr_selected_src + tid_offset; - - for (C10_UNUSED const auto _ : c10::irange(start, end)) { - const auto count = *ptr_tid_int_counts++; - const auto i = *ptr_tid_src_int_idx++; - const auto j = *ptr_tid_sorted_int_idx++; - if (!count) continue; - - std::fill_n(ptr_tid_selected_src, count, i); - std::copy_n(ptr_sorted_idx + j, count, ptr_tid_selected_sorted); - - ptr_tid_selected_sorted += count; - ptr_tid_selected_src += count; - } - }); + at::parallel_for( + 0, n_threads_src, 1, [&](int64_t tid, [[maybe_unused]] int64_t _) { + const auto start = tid * chunk_size_src; + const auto end = std::min(start + chunk_size_src, src_len); + const auto tid_offset = + thread_offsets.const_data_ptr()[tid]; + const auto* ptr_tid_src_int_idx = + src_int_idx.select(0, tid).const_data_ptr(); + const auto* ptr_tid_sorted_int_idx = + sorted_int_idx.select(0, tid).const_data_ptr(); + const auto* ptr_tid_int_counts = + int_counts.select(0, tid).const_data_ptr(); + auto* ptr_tid_selected_sorted = ptr_selected_sorted + tid_offset; + auto* ptr_tid_selected_src = ptr_selected_src + tid_offset; + + for ([[maybe_unused]] const auto _ : c10::irange(start, end)) { + const auto count = *ptr_tid_int_counts++; + const auto i = *ptr_tid_src_int_idx++; + const auto j = *ptr_tid_sorted_int_idx++; + if (!count) + continue; + + std::fill_n(ptr_tid_selected_src, count, i); + std::copy_n(ptr_sorted_idx + j, count, ptr_tid_selected_sorted); + + ptr_tid_selected_sorted += count; + ptr_tid_selected_src += count; + } + }); } return search_in_dim_indices @@ -2201,7 +2205,7 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in else { auto* ptr_counts = counts.data_ptr(); const auto* ptr_vals = t.const_data_ptr(); - for (C10_UNUSED const auto _ : c10::irange(t.numel())) { + for ([[maybe_unused]] const auto _ : c10::irange(t.numel())) { ++ptr_counts[*ptr_vals++]; } } @@ -2221,14 +2225,19 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in const auto run_in_parallel = (n_threads == 1); auto counts_per_thread = at::zeros({n_threads, size}, idx.options()); - at::parallel_for(0, n_threads, 1, [&](int64_t tid, C10_UNUSED int64_t _) { - const auto start = tid * chunk_size; - const auto end = std::min(start + chunk_size, idx_len); - const auto tid_idx = idx.slice(0, start, end); - auto tid_counts = counts_per_thread.select(0, tid); - get_counts(tid_counts, tid_idx, /*bins=*/size, - /*is_sorted=*/is_sorted, /*run_in_parallel=*/run_in_parallel); - }); + at::parallel_for( + 0, n_threads, 1, [&](int64_t tid, [[maybe_unused]] int64_t _) { + const auto start = tid * chunk_size; + const auto end = std::min(start + chunk_size, idx_len); + const auto tid_idx = idx.slice(0, start, end); + auto tid_counts = counts_per_thread.select(0, tid); + get_counts( + tid_counts, + tid_idx, + /*bins=*/size, + /*is_sorted=*/is_sorted, + /*run_in_parallel=*/run_in_parallel); + }); return counts_per_thread; }; @@ -2319,32 +2328,38 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in 1, std::min((src_len + grain_size - 1) / grain_size, at::get_num_threads()) ); const auto chunk_size = (src_len + n_threads_src - 1) / n_threads_src; - at::parallel_for(0, n_threads_src, 1, [&](int64_t tid, C10_UNUSED int64_t _) { - const auto start = tid * chunk_size; - const auto end = std::min(start + chunk_size, src_len); - auto* ptr_src_tid = ptr_src + start; - const auto* ptr_src_counts_per_thread - = src_counts_per_thread.select(0, tid).const_data_ptr(); - const auto* ptr_src_offset_counts_per_thread - = src_offset_counts_per_thread.select(0, tid).const_data_ptr(); - auto tid_counts = at::zeros({size}, src.options()); - auto* ptr_tid_counts = tid_counts.data_ptr(); - - for (const auto i : c10::irange(start, end)) { - const auto idx_val = *ptr_src_tid++; - // skip idx value if not in the intersection - if (!ptr_intersection_counts[idx_val]) continue; - const auto idx_val_offset - = ptr_src_intersection_offsets[idx_val] - - ptr_src_intersection_counts[idx_val]; - const auto idx_val_tid_offset - = ptr_src_offset_counts_per_thread[idx_val] - - ptr_src_counts_per_thread[idx_val]; - auto& idx_val_local_tid_count = ptr_tid_counts[idx_val]; - ptr_src_idx[idx_val_offset + idx_val_tid_offset + idx_val_local_tid_count] = i; - ++idx_val_local_tid_count; - } - }); + at::parallel_for( + 0, n_threads_src, 1, [&](int64_t tid, [[maybe_unused]] int64_t _) { + const auto start = tid * chunk_size; + const auto end = std::min(start + chunk_size, src_len); + auto* ptr_src_tid = ptr_src + start; + const auto* ptr_src_counts_per_thread = + src_counts_per_thread.select(0, tid) + .const_data_ptr(); + const auto* ptr_src_offset_counts_per_thread = + src_offset_counts_per_thread.select(0, tid) + .const_data_ptr(); + auto tid_counts = at::zeros({size}, src.options()); + auto* ptr_tid_counts = tid_counts.data_ptr(); + + for (const auto i : c10::irange(start, end)) { + const auto idx_val = *ptr_src_tid++; + // skip idx value if not in the intersection + if (!ptr_intersection_counts[idx_val]) + continue; + const auto idx_val_offset = + ptr_src_intersection_offsets[idx_val] - + ptr_src_intersection_counts[idx_val]; + const auto idx_val_tid_offset = + ptr_src_offset_counts_per_thread[idx_val] - + ptr_src_counts_per_thread[idx_val]; + auto& idx_val_local_tid_count = ptr_tid_counts[idx_val]; + ptr_src_idx + [idx_val_offset + idx_val_tid_offset + + idx_val_local_tid_count] = i; + ++idx_val_local_tid_count; + } + }); const auto src_idx_offsets = src_intersection_offsets.sub_(src_intersection_counts); @@ -2378,26 +2393,28 @@ Tensor index_select_sparse_cpu(const Tensor& self, int64_t dim, const Tensor& in 1, std::min((idx_len + grain_size - 1) / grain_size, at::get_num_threads()) ); const auto chunk_size = (idx_len + n_threads_idx - 1) / n_threads_idx; - at::parallel_for(0, n_threads_idx, 1, [&](int64_t tid, C10_UNUSED int64_t _) { - const auto start = tid * chunk_size; - const auto end = std::min(start + chunk_size, idx_len); - const auto tid_offset = ptr_thread_offset[tid]; - const auto* ptr_idx_tid = ptr_idx + start; - auto* ptr_idx_selected_tid = ptr_idx_selected + tid_offset; - auto* ptr_src_selected_tid = ptr_src_selected + tid_offset; - - for (const auto i : c10::irange(start, end)) { - const auto idx_val = *ptr_idx_tid++; - // skip if idx_val is not in the intersection - if (!ptr_intersection_counts[idx_val]) continue; - const auto count = ptr_src_counts[idx_val]; - const auto j = ptr_src_idx_offsets[idx_val]; - std::fill_n(ptr_idx_selected_tid, count, i); - std::copy_n(ptr_src_idx + j, count, ptr_src_selected_tid); - ptr_idx_selected_tid += count; - ptr_src_selected_tid += count; - } - }); + at::parallel_for( + 0, n_threads_idx, 1, [&](int64_t tid, [[maybe_unused]] int64_t _) { + const auto start = tid * chunk_size; + const auto end = std::min(start + chunk_size, idx_len); + const auto tid_offset = ptr_thread_offset[tid]; + const auto* ptr_idx_tid = ptr_idx + start; + auto* ptr_idx_selected_tid = ptr_idx_selected + tid_offset; + auto* ptr_src_selected_tid = ptr_src_selected + tid_offset; + + for (const auto i : c10::irange(start, end)) { + const auto idx_val = *ptr_idx_tid++; + // skip if idx_val is not in the intersection + if (!ptr_intersection_counts[idx_val]) + continue; + const auto count = ptr_src_counts[idx_val]; + const auto j = ptr_src_idx_offsets[idx_val]; + std::fill_n(ptr_idx_selected_tid, count, i); + std::copy_n(ptr_src_idx + j, count, ptr_src_selected_tid); + ptr_idx_selected_tid += count; + ptr_src_selected_tid += count; + } + }); return std::make_tuple(idx_selected, src_selected); }(); @@ -4055,29 +4072,41 @@ void split_copy_Tensor_out(const at::Tensor & self, int64_t split_size, int64_t } } -void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { - auto tmp = self.split_with_sizes(split_sizes, dim); +namespace { - TORCH_CHECK(out.size() == tmp.size(), "split_with_sizes_copy_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size()); +void copy_tensor_array_to_out(const char* name, const std::vector& array, at::TensorList out) { + TORCH_CHECK(out.size() == array.size(), name, " expected an out= argument of size ", array.size(), ", got size ", out.size()); for (const auto i : c10::irange(out.size())) { - if (resize_output_check(out[i], tmp[i].sizes())) { - out[i].resize_(tmp[i].sizes()); + if (resize_output_check(out[i], array[i].sizes())) { + out[i].resize_(array[i].sizes()); } - TORCH_CHECK(out[i].dtype() == tmp[i].dtype(), - "Expected out tensor to have dtype ", tmp[i].dtype(), ", but got ", out[i].dtype(), " instead"); - TORCH_CHECK(out[i].device() == tmp[i].device(), - "Expected out tensor to have device ", tmp[i].device(), ", but got ", out[i].device(), " instead"); - out[i].copy_(tmp[i]); + TORCH_CHECK(out[i].dtype() == array[i].dtype(), + "Expected out tensor to have dtype ", array[i].dtype(), ", but got ", out[i].dtype(), " instead"); + TORCH_CHECK(out[i].device() == array[i].device(), + "Expected out tensor to have device ", array[i].device(), ", but got ", out[i].device(), " instead"); + out[i].copy_(array[i]); } } -void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) { - auto tmp = self.unbind(dim); +} - TORCH_CHECK(out.size() == tmp.size(), "unbind_copy_int_out() expected an out= argument of size ", tmp.size(), ", got size ", out.size()); - for (const auto i : c10::irange(out.size())) { - out[i].copy_(tmp[i]); +void split_with_sizes_copy_out(const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { + auto tmp = self.split_with_sizes(split_sizes, dim); + copy_tensor_array_to_out("split_with_sizes_copy_out()", tmp, out); +} + +void unbind_copy_int_out(const at::Tensor & self, int64_t dim, at::TensorList out) { + if (at::GradMode::is_enabled()) { + for (const auto i : c10::irange(out.size())) { + TORCH_CHECK(!out[i].requires_grad(), + "unbind_copy(): functions with out=... arguments don't support automatic differentiation, " + "but one of the arguments requires grad." + ); + } } + + auto tmp = self.unbind(dim); + copy_tensor_array_to_out("unbind_copy_int_out()", tmp, out); } int64_t sparse_dim_default(const Tensor& self) { diff --git a/aten/src/ATen/native/TensorShape.h b/aten/src/ATen/native/TensorShape.h index c35023d076e73..160fe254587d3 100644 --- a/aten/src/ATen/native/TensorShape.h +++ b/aten/src/ATen/native/TensorShape.h @@ -30,7 +30,7 @@ inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & seco } inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) { - int64_t i = 0; + [[maybe_unused]] int64_t i = 0; for(const Tensor& t : tensors) { TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", i, ") cannot be concatenated"); diff --git a/aten/src/ATen/native/UnfoldBackward.h b/aten/src/ATen/native/UnfoldBackward.h index 44e05c125913e..4a675a7623b43 100644 --- a/aten/src/ATen/native/UnfoldBackward.h +++ b/aten/src/ATen/native/UnfoldBackward.h @@ -29,13 +29,12 @@ namespace { // grad_in does not mean that it is a gradient wrt to input, // grad_in/grad_out is just an input/output of unfold_backward kernel. -static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out( - Tensor& grad_out, - const Tensor& grad_in, - int64_t dim, - int64_t size, - int64_t step -) { +[[maybe_unused]] static TensorIterator _make_unfold_backward_iter_over_grad_out( + Tensor& grad_out, + const Tensor& grad_in, + int64_t dim, + int64_t size, + int64_t step) { dim = maybe_wrap_dim(dim, grad_out.dim()); // last dim stores the folds @@ -106,7 +105,6 @@ static C10_UNUSED TensorIterator _make_unfold_backward_iter_over_grad_out( return iter; } - } } // namespace at::native diff --git a/aten/src/ATen/native/UpSample.h b/aten/src/ATen/native/UpSample.h index 033ef2b7fad3e..769201804eafa 100644 --- a/aten/src/ATen/native/UpSample.h +++ b/aten/src/ATen/native/UpSample.h @@ -103,7 +103,9 @@ DECLARE_DISPATCH(upsampling_bicubic2d, upsample_bicubic2d_kernel); DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_kernel); DECLARE_DISPATCH(_upsampling_bicubic2d_aa, _upsample_bicubic2d_aa_backward_kernel); -inline C10_UNUSED std::array upsample_1d_common_check(IntArrayRef input_size, IntArrayRef output_size) { +[[maybe_unused]] inline std::array upsample_1d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 1, "It is expected output_size equals to 1, but got size ", @@ -131,7 +133,9 @@ inline C10_UNUSED std::array upsample_1d_common_check(IntArrayRef in return {nbatch, channels, output_width}; } -inline C10_UNUSED std::array upsample_2d_common_check(IntArrayRef input_size, IntArrayRef output_size) { +[[maybe_unused]] inline std::array upsample_2d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 2, "It is expected output_size equals to 2, but got size ", @@ -167,8 +171,9 @@ inline C10_UNUSED std::array upsample_2d_common_check(IntArrayRef in return {nbatch, channels, output_height, output_width}; } -inline C10_UNUSED -std::array upsample_3d_common_check(IntArrayRef input_size, IntArrayRef output_size) { +[[maybe_unused]] inline std::array upsample_3d_common_check( + IntArrayRef input_size, + IntArrayRef output_size) { TORCH_CHECK( output_size.size() == 3, "It is expected output_size equals to 3, but got size ", diff --git a/aten/src/ATen/native/UpSampleBicubic2d.cpp b/aten/src/ATen/native/UpSampleBicubic2d.cpp index b8c14bcc0731b..44892ebd4aad8 100644 --- a/aten/src/ATen/native/UpSampleBicubic2d.cpp +++ b/aten/src/ATen/native/UpSampleBicubic2d.cpp @@ -150,13 +150,11 @@ static void upsample_bicubic2d_backward_out_frame( opmath_t t_y; guard_index_and_lambda(real_y, input_height, input_y, t_y); - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - opmath_t x_coeffs[4]; - // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays) - opmath_t y_coeffs[4]; + std::array x_coeffs; + std::array y_coeffs; - get_cubic_upsample_coefficients(x_coeffs, t_x); - get_cubic_upsample_coefficients(y_coeffs, t_y); + get_cubic_upsample_coefficients(x_coeffs.data(), t_x); + get_cubic_upsample_coefficients(y_coeffs.data(), t_y); opmath_t out_value = out[output_y * output_width + output_x]; for (const auto ii : c10::irange(4)) { diff --git a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp index cdbfda3c71bb4..a9645e776a025 100644 --- a/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp @@ -40,7 +40,6 @@ int register_linear_params() { } namespace { -static C10_UNUSED auto linear_params = register_linear_params(); -} // namespace - +[[maybe_unused]] static auto linear_params = register_linear_params(); +} // namespace }} // namespace ao::sparse diff --git a/aten/src/ATen/native/cpu/BlasKernel.cpp b/aten/src/ATen/native/cpu/BlasKernel.cpp index dd03756982193..d2475093bdd75 100644 --- a/aten/src/ATen/native/cpu/BlasKernel.cpp +++ b/aten/src/ATen/native/cpu/BlasKernel.cpp @@ -96,7 +96,7 @@ auto sum(int64_t N, Func f) { } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> gemm_notrans_( int64_t m, int64_t n, @@ -132,7 +132,7 @@ gemm_notrans_( // std::is_same || std::is_same template -typename std::enable_if::value, void>::type +std::enable_if_t, void> gemm_notrans_( int64_t m, int64_t n, @@ -222,7 +222,7 @@ void gemm_transb_impl( } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> gemm_transb_( TransposeType transb, int64_t m, @@ -244,7 +244,7 @@ gemm_transb_( // std::is_same || std::is_same template -typename std::enable_if::value, void>::type +std::enable_if_t, void> gemm_transb_( TransposeType transb, int64_t m, diff --git a/aten/src/ATen/native/cpu/CatKernel.cpp b/aten/src/ATen/native/cpu/CatKernel.cpp index 23d9aa1708ba7..f7f3c0c8d6f5b 100644 --- a/aten/src/ATen/native/cpu/CatKernel.cpp +++ b/aten/src/ATen/native/cpu/CatKernel.cpp @@ -2,9 +2,10 @@ #include #include -#include +#include #include #include +#include #include namespace at::native { @@ -16,15 +17,19 @@ struct InputMeta { int64_t inner_size; InputMeta(const Tensor& t, int64_t dim, int64_t inner) - : data_ptr(t.const_data_ptr()) - , inner_size(t.sizes()[dim] * inner) {} + : data_ptr(t.const_data_ptr()), inner_size(t.sizes()[dim] * inner) {} }; template -void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListRef& tensors, int64_t dim) { +void cat_serial_kernel_impl( + const Tensor& result, + const MaterializedITensorListRef& tensors, + int64_t dim) { TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - dim >= 0 && dim < result.dim(), "dim out of range in cat_serial_kernel_impl"); - int64_t outer = result.numel() / (result.sizes()[dim] * result.strides()[dim]); + dim >= 0 && dim < result.dim(), + "dim out of range in cat_serial_kernel_impl"); + int64_t outer = + result.numel() / (result.sizes()[dim] * result.strides()[dim]); scalar_t* result_data = result.data_ptr(); int64_t ninputs = static_cast(tensors.size()); std::vector inputs; @@ -38,15 +43,16 @@ void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListR for (const auto i : c10::irange(outer)) { for (const auto j : c10::irange(ninputs)) { int64_t local_inner = inputs[j].inner_size; - const scalar_t* input_ptr = (const scalar_t*)(inputs[j].data_ptr) + i * local_inner; + const scalar_t* input_ptr = + (const scalar_t*)(inputs[j].data_ptr) + i * local_inner; int64_t d = 0; for (; d < local_inner - (local_inner % Vec::size()); d += Vec::size()) { Vec in_vec = Vec::loadu(input_ptr + d); in_vec.store(result_ptr + d); } - #if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) - # pragma unroll - #endif +#if !defined(_MSC_VER) && !defined(COMPILING_FOR_MIN_SIZE) +#pragma unroll +#endif for (; d < local_inner; d++) { result_ptr[d] = input_ptr[d]; } @@ -55,14 +61,23 @@ void cat_serial_kernel_impl(const Tensor& result, const MaterializedITensorListR } } -void cat_serial_kernel(const Tensor& result, const MaterializedITensorListRef& tensors, int64_t dim) { - AT_DISPATCH_FLOATING_TYPES_AND2(kBFloat16, kHalf, result.scalar_type(), "cat_serial_kernel", [&]() { - cat_serial_kernel_impl(result, tensors, dim); - }); +void cat_serial_kernel( + const Tensor& result, + const MaterializedITensorListRef& tensors, + int64_t dim) { + AT_DISPATCH_V2( + result.scalar_type(), + "cat_serial_kernel", + AT_WRAP( + [&]() { cat_serial_kernel_impl(result, tensors, dim); }), + AT_EXPAND(AT_FLOATING_TYPES), + kBFloat16, + kHalf, + AT_EXPAND(AT_FLOAT8_TYPES)); } } // anonymous namespace REGISTER_DISPATCH(cat_serial_stub, &cat_serial_kernel); -} // at::native +} // namespace at::native diff --git a/aten/src/ATen/native/cpu/CopyKernel.cpp b/aten/src/ATen/native/cpu/CopyKernel.cpp index 906fa8911e884..c3cb5265bf5d4 100644 --- a/aten/src/ATen/native/cpu/CopyKernel.cpp +++ b/aten/src/ATen/native/cpu/CopyKernel.cpp @@ -82,7 +82,7 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne std::copy_n(base, 2, data.data()); const int64_t *outer_strides = &strides[2]; - for (const auto it C10_UNUSED : c10::irange(size1)) { + for ([[maybe_unused]] const auto it : c10::irange(size1)) { Vecd dst_s; if (strides_in[0] == 0) { dst_s = Vecd(dest_t(*((scalar_t*)data[1]))); @@ -151,7 +151,7 @@ static void reduced_float_copy_kernel(TensorIteratorBase &iter, bool requires_ne std::copy_n(base, 2, data.data()); const int64_t *outer_strides = &strides[2]; - for (const auto it C10_UNUSED : c10::irange(size1)) { + for ([[maybe_unused]] const auto it : c10::irange(size1)) { Vecd dst_s; if (strides_in[0] == 0) { dst_s = Vecd(dest_t(*((source_t*)data[1]))); diff --git a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp index 04d82d365baa3..5a96f89891ac2 100644 --- a/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/DistanceOpsKernel.cpp @@ -395,7 +395,7 @@ struct Dist { const scalar_t * t1_end = t1 + l1_size; const scalar_t * t2_end = t2 + l2_size; - for (const auto l C10_UNUSED : c10::irange(d)) { + for ([[maybe_unused]] const auto l : c10::irange(d)) { for (; t1 != t1_end; t1 += m, res += m) { const Vec vec_t1 = Vec::loadu(t1, count); Vec res_vec = Vec::loadu(res, count); diff --git a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp index db1a3c48c90e1..981a470c4457d 100644 --- a/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp +++ b/aten/src/ATen/native/cpu/FlashAttentionKernel.cpp @@ -473,7 +473,7 @@ void cpu_flash_attention( scalar_t* transpose_buffer_ptr = transpose_buffer.get(); std::unique_ptr v_copy_buffer = std::make_unique(ekvSplitSize * packb_size); scalar_t* v_copy_buffer_ptr = v_copy_buffer.get(); - for (C10_UNUSED auto z : c10::irange(begin, end)) { + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { n = l * kvSplitSize; int64_t kvBlockSize = std::min(kvSplitSize, kvSize - n); int64_t ekvBlockSize = kvBlockSize % 2 == 0 ? kvBlockSize : kvBlockSize + 1; @@ -566,7 +566,7 @@ void cpu_flash_attention( ? query_padding_ptr + ompIdx * qSplitSize * eheadSize : nullptr; - for (C10_UNUSED auto z : c10::irange(begin, end)) { + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { int64_t m = k * qSplitSize; int64_t qBlockSize = std::min(qSplitSize, qSize - m); // Initialize max and sum @@ -931,7 +931,7 @@ void cpu_flash_attention_backward( at::Tensor dsum = at::empty({qSplitSize}, query.options().dtype(accumulate_dtype)); accum_t* dsum_data = dsum.data_ptr(); - for (C10_UNUSED auto z : c10::irange(begin, end)) { + for ([[maybe_unused]] auto z : c10::irange(begin, end)) { // rowsum of grad_out * out for (int64_t m = 0; m < qSize; m += qSplitSize) { int64_t qBlockSize = std::min(qSplitSize, qSize - m); diff --git a/aten/src/ATen/native/cpu/FunctionOfAMatrixUtilsKernel.cpp b/aten/src/ATen/native/cpu/FunctionOfAMatrixUtilsKernel.cpp index 92cf41c309e04..c6bd3f8c5681d 100644 --- a/aten/src/ATen/native/cpu/FunctionOfAMatrixUtilsKernel.cpp +++ b/aten/src/ATen/native/cpu/FunctionOfAMatrixUtilsKernel.cpp @@ -30,7 +30,7 @@ void _compute_linear_combination_cpu_kernel( auto* RESTRICT in_ptr = data[1]; auto* RESTRICT coeff_ptr = data[2]; - for (const auto elem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto elem : c10::irange(n)) { auto* RESTRICT out_data = reinterpret_cast(out_ptr); auto* RESTRICT in_data = reinterpret_cast(in_ptr); using primitive_t = typename scalar_value_type::type; diff --git a/aten/src/ATen/native/cpu/FusedAdagradKernel.cpp b/aten/src/ATen/native/cpu/FusedAdagradKernel.cpp index e19915e0a4f2c..24f04111f12cb 100644 --- a/aten/src/ATen/native/cpu/FusedAdagradKernel.cpp +++ b/aten/src/ATen/native/cpu/FusedAdagradKernel.cpp @@ -12,10 +12,10 @@ namespace at::native { namespace{ template -typename std::enable_if< - std::is_same::value || std::is_same::value, - void>:: - type inline adagrad_math( +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> + inline adagrad_math( scalar_t* param_ptr, scalar_t* grad_ptr, scalar_t* state_sum_ptr, @@ -81,10 +81,10 @@ typename std::enable_if< template -typename std::enable_if< - std::is_same::value || std::is_same::value, - void>:: - type inline adagrad_math( +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> + inline adagrad_math( scalar_t* param_ptr, scalar_t* grad_ptr, scalar_t* state_sum_ptr, diff --git a/aten/src/ATen/native/cpu/FusedAdamKernel.cpp b/aten/src/ATen/native/cpu/FusedAdamKernel.cpp index 239cdc3b37ac3..f583e089c6c01 100644 --- a/aten/src/ATen/native/cpu/FusedAdamKernel.cpp +++ b/aten/src/ATen/native/cpu/FusedAdamKernel.cpp @@ -12,10 +12,10 @@ namespace at::native { namespace{ template -typename std::enable_if< - std::is_same::value || std::is_same::value, - void>:: - type inline adam_math( +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> + inline adam_math( scalar_t* param_ptr, scalar_t* exp_avg_ptr, scalar_t* exp_avg_sq_ptr, @@ -155,10 +155,10 @@ typename std::enable_if< template -typename std::enable_if< - std::is_same::value || std::is_same::value, - void>:: - type inline adam_math( +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> + inline adam_math( scalar_t* param_ptr, scalar_t* exp_avg_ptr, scalar_t* exp_avg_sq_ptr, diff --git a/aten/src/ATen/native/cpu/FusedSGDKernel.cpp b/aten/src/ATen/native/cpu/FusedSGDKernel.cpp index 95e96ff5cf55d..023feeb16fe07 100644 --- a/aten/src/ATen/native/cpu/FusedSGDKernel.cpp +++ b/aten/src/ATen/native/cpu/FusedSGDKernel.cpp @@ -12,10 +12,10 @@ namespace at::native { namespace{ template -typename std::enable_if< - std::is_same::value || std::is_same::value, - void>:: - type inline sgd_math( +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> + inline sgd_math( scalar_t* param_ptr, scalar_t* grad_ptr, scalar_t* momentum_buf_ptr, @@ -104,10 +104,10 @@ typename std::enable_if< template -typename std::enable_if< - std::is_same::value || std::is_same::value, - void>:: - type inline sgd_math( +std::enable_if_t< + std::is_same_v || std::is_same_v, + void> + inline sgd_math( scalar_t* param_ptr, scalar_t* grad_ptr, scalar_t* momentum_buf_ptr, diff --git a/aten/src/ATen/native/cpu/IndexKernel.cpp b/aten/src/ATen/native/cpu/IndexKernel.cpp index 1211eda0adb63..0b3d13beb9d58 100644 --- a/aten/src/ATen/native/cpu/IndexKernel.cpp +++ b/aten/src/ATen/native/cpu/IndexKernel.cpp @@ -78,7 +78,7 @@ void cpu_take_put_kernel( auto loop = [&](char** data, const int64_t* strides, int64_t n) { auto* iterated_data_bytes = data[0]; auto* index_data_bytes = data[1]; - for (const auto elem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto elem : c10::irange(n)) { auto idx = *reinterpret_cast(index_data_bytes); auto& iterated = *reinterpret_cast(iterated_data_bytes); @@ -203,7 +203,7 @@ void index_fill_kernel( auto handle_nonzero_idx_stride = [&](char** data, const int64_t* strides, int64_t n) { auto* self_data_bytes = data[0]; auto* index_data_bytes = data[1]; - for (const auto elem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto elem : c10::irange(n)) { auto* self_data = reinterpret_cast(self_data_bytes); auto idx = *reinterpret_cast(index_data_bytes); TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size, @@ -229,7 +229,7 @@ void index_fill_kernel( if (idx < 0) { idx += self_dim_size; } - for (const auto elem C10_UNUSED: c10::irange(n)) { + for ([[maybe_unused]] const auto elem : c10::irange(n)) { auto* self_data = reinterpret_cast(self_data_bytes); self_data[idx * self_dim_stride] = fill_val; @@ -262,7 +262,7 @@ void index_copy_kernel( auto* self_data_bytes = data[0]; auto* index_data_bytes = data[1]; auto* source_data_bytes = data[2]; - for (const auto elem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto elem : c10::irange(n)) { auto* self_data = reinterpret_cast(self_data_bytes); auto idx = *reinterpret_cast(index_data_bytes); auto* source_data = reinterpret_cast(source_data_bytes); @@ -285,7 +285,7 @@ void index_copy_kernel( TORCH_CHECK_INDEX(idx >= 0 && idx < self_dim_size, "index_copy_(): index ", idx, " is out of bounds for dimension ", dim, " with size ", self_dim_size); - for (const auto elem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto elem : c10::irange(n)) { auto* self_data = reinterpret_cast(self_data_bytes); auto* source_data = reinterpret_cast(source_data_bytes); @@ -474,8 +474,7 @@ void cpu_hflip_vec(at::TensorIterator& iter) { constexpr auto stride = sizeof(scalar_t); TORCH_INTERNAL_ASSERT(stride == -strides[0] && stride == strides[1]); - for (const auto j C10_UNUSED : c10::irange(size1)) { - + for ([[maybe_unused]] const auto j : c10::irange(size1)) { // vectorized loop with negative stride for output char** C10_RESTRICT data_ = data_arr.data(); int64_t n = size0; @@ -543,8 +542,7 @@ void cpu_vflip_memcpy(at::TensorIterator& iter) { TORCH_INTERNAL_ASSERT(strides[0] == strides[1]); const int64_t stride = strides[0]; - for (const auto j C10_UNUSED : c10::irange(size1)) { - + for ([[maybe_unused]] const auto j : c10::irange(size1)) { char** C10_RESTRICT data_ = data_arr.data(); int64_t n = size0; diff --git a/aten/src/ATen/native/cpu/IsContiguous.h b/aten/src/ATen/native/cpu/IsContiguous.h index ddbbb6fb8f5af..02d8f5dd78e40 100644 --- a/aten/src/ATen/native/cpu/IsContiguous.h +++ b/aten/src/ATen/native/cpu/IsContiguous.h @@ -31,14 +31,16 @@ struct IsContiguous<0, -1, traits, s> { }; // output and all inputs are contiguous -template ::value>::type* = nullptr> +template < + typename traits, + std::enable_if_t>* = + nullptr> static inline bool is_contiguous(const int64_t* strides) { return IsContiguous::eval(strides); } template ::value>::type* = nullptr> + std::enable_if_t>* = nullptr> static inline bool is_contiguous(const int64_t* strides) { return IsContiguous::eval(strides); } @@ -46,14 +48,14 @@ static inline bool is_contiguous(const int64_t* strides) { // input at `s` is scalar (stride 0); output and other inputs are contiguous // NB: output is typically at strides[0] so first input corresponds to s=1 template ::value>::type* = nullptr> + std::enable_if_t>* = nullptr> static inline bool is_contiguous_scalar(const int64_t* strides) { static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); return IsContiguous::eval(strides); } template ::value>::type* = nullptr> + std::enable_if_t>* = nullptr> static inline bool is_contiguous_scalar(const int64_t* strides) { static_assert(s > 0 && s <= traits::arity, "scalar argument index out of bounds"); return IsContiguous::eval(strides); diff --git a/aten/src/ATen/native/cpu/Loops.h b/aten/src/ATen/native/cpu/Loops.h index a7a567aa915de..a910a329482b8 100644 --- a/aten/src/ATen/native/cpu/Loops.h +++ b/aten/src/ATen/native/cpu/Loops.h @@ -36,6 +36,7 @@ #include #include +#include #include namespace at::native { inline namespace CPU_CAPABILITY { @@ -271,7 +272,7 @@ struct VectorizedLoop2d { const int64_t *outer_strides = &strides[ntensors]; if (is_contiguous(strides)) { - for (const auto i C10_UNUSED : c10::irange(size1)) { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { vectorized_loop(data.data(), size0, 0, op, vop); advance(data, outer_strides); } @@ -279,12 +280,12 @@ struct VectorizedLoop2d { using Indices = std::make_index_sequence; unroll_contiguous_scalar_checks(strides, Indices{}, [&](size_t idx) { if (idx) { - for (const auto i C10_UNUSED : c10::irange(size1)) { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { vectorized_loop(data.data(), size0, idx, op, vop); advance(data, outer_strides); } } else { - for (const auto i C10_UNUSED : c10::irange(size1)) { + for ([[maybe_unused]] const auto i : c10::irange(size1)) { basic_loop(data.data(), strides, 0, size0, op); advance(data, outer_strides); } diff --git a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp index c752106130fe1..15b784f055216 100644 --- a/aten/src/ATen/native/cpu/MaxPoolKernel.cpp +++ b/aten/src/ATen/native/cpu/MaxPoolKernel.cpp @@ -64,7 +64,7 @@ vec::Vectorized is_nan_vec(vec::Vectorized vec) { template inline -typename std::enable_if::value, void>::type +std::enable_if_t, void> compute_internal( const scalar_t* input_data, scalar_t* out_data, @@ -139,7 +139,7 @@ compute_internal( // std::is_same || std::is_same template inline -typename std::enable_if::value, void>::type +std::enable_if_t, void> compute_internal( const scalar_t* input_data, scalar_t* out_data, diff --git a/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp index 44d9a443c2e67..6714a60cbb3d1 100644 --- a/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp +++ b/aten/src/ATen/native/cpu/MaxUnpoolKernel.cpp @@ -37,7 +37,10 @@ void cpu_max_unpool( // treat batch size and channels as one dimension // and the feature map as another dimension - [[maybe_unused]] int64_t channels, output_depth, output_height, output_width; + int64_t channels = 0; + [[maybe_unused]] int64_t output_depth = 0; + [[maybe_unused]] int64_t output_height = 0; + [[maybe_unused]] int64_t output_width = 0; if constexpr (is_3d) { TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d: expect input to be 4d or 5d tensor."); channels = ndim == 4 ? input.size(0) : input.size(0) * input.size(1); @@ -80,11 +83,11 @@ void cpu_max_unpool( if (optional_error_index) { if constexpr (is_3d) { - AT_ERROR("Found an invalid max index: ", optional_error_index.value(), + TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(), " (output volumes are of size ", output_depth, "x", output_height, "x", output_width); } else { - AT_ERROR("Found an invalid max index: ", optional_error_index.value(), + TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(), " (output volumes are of size ", output_height, "x", output_width); } @@ -148,7 +151,7 @@ void cpu_max_unpool_channels_last( }); if (optional_error_index) { - AT_ERROR("Found an invalid max index: ", optional_error_index.value(), + TORCH_CHECK(false, "Found an invalid max index: ", optional_error_index.value(), " (output volumes are of size ", output_height, "x", output_width, ")"); } @@ -174,7 +177,10 @@ void cpu_max_unpool_backward( // treat batch size and channels as one dimension // and the feature map as another dimension - int64_t channels, output_depth, output_height, output_width; + int64_t channels = 0; + [[maybe_unused]] int64_t output_depth = 0; + [[maybe_unused]] int64_t output_height = 0; + [[maybe_unused]] int64_t output_width = 0; if (is_3d) { TORCH_CHECK(ndim == 4 || ndim == 5, "MaxUnpool3d_backward: expect grad_output to be 4d or 5d tensor."); channels = ndim == 4 ? grad_output.size(0) : grad_output.size(0) * grad_output.size(1); @@ -217,12 +223,12 @@ void cpu_max_unpool_backward( if (optional_error_index) { if (is_3d) { - AT_ERROR("invalid max index ", optional_error_index.value(), + TORCH_CHECK(false, "invalid max index ", optional_error_index.value(), ", odepth= ", output_depth, ", owidth= ", output_width, ", oheight= ", output_height); } else { - AT_ERROR("invalid max index ", optional_error_index.value(), + TORCH_CHECK(false, "invalid max index ", optional_error_index.value(), ", owidth= ", output_width, ", oheight= ", output_height); } diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 37bd32d1c4c13..09a8ba3b170fa 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -6,10 +6,9 @@ #include #include -#include #include -namespace at { namespace native { inline namespace CPU_CAPABILITY { +namespace at::native { inline namespace CPU_CAPABILITY { using namespace vec; @@ -70,7 +69,7 @@ inline void vectorized_reduction(char** data, int64_t n, int64_t stride, template inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, F f) { - for (const auto j C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto j : c10::irange(n)) { f(); data[0] += strides[0]; data[1] += strides[1]; @@ -81,7 +80,7 @@ inline void UNARY_OUTER_LOOP(char* data[2], const int64_t strides[2], int64_t n, template inline void vectorized_inner_reduction(char** data, int64_t n, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) - int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); + constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); int64_t count = n / (4 * Vec::size()); if (count > 0) { vectorized_reduction(data, count, vector_stride, op, vop, /*reduce=*/true); @@ -96,12 +95,9 @@ template inline void vectorized_outer_reduction(char** data, int64_t inner_stride, int64_t size0, int64_t size1, func_t op, vec_func_t vop) { VEC_LOOP_HEADER(func_t, data) - // reduce down each column of 4 * Vec::size() elements (128 or 256 bytes) -#if defined(CPU_CAPABILITY_AVX512) - int64_t outer_stride[2] = { 256, 256 }; -#else - int64_t outer_stride[2] = { 128, 128 }; -#endif + // reduce down each column of 4 * Vec::size() elements. + constexpr int64_t vector_stride = 4 * Vec::size() * sizeof(scalar_t); + int64_t outer_stride[2] = { vector_stride, vector_stride }; UNARY_OUTER_LOOP(data, outer_stride, size1 / (4 * Vec::size()), [&] { vectorized_reduction(data, size0, inner_stride, op, vop, /*reduce=*/false); }); @@ -132,13 +128,13 @@ static void set_results(const res_t result, const TensorIteratorBase &iter, cons } template -inline typename std::enable_if::type +inline std::enable_if_t for_each_in_tuple(const std::tuple& /*t*/, const TensorIteratorBase& /*iter*/, const int /*num_outputs*/) { return i; } template -inline typename std::enable_if::type +inline std::enable_if_t for_each_in_tuple(const std::tuple& t, const TensorIteratorBase &iter, const int num_outputs) { if (i < (size_t)num_outputs) { set_result(i, std::get(t), iter, num_outputs); @@ -311,4 +307,4 @@ void binary_kernel_reduce_lastdim(TensorIteratorBase& iter, reduce_func_t reduce sub_iter.for_each(loop, grain_size); } -}}} // namespace at::native:: +}} // namespace at::native:: diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 2eaad7eb5d427..fa42aca950eca 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -62,11 +62,12 @@ static inline void cpu_cum_base_kernel(const Tensor& result, auto* result_data_bytes = data[0]; const auto* self_data_bytes = data[1]; - for (const auto i C10_UNUSED : c10::irange(n)) { - f( - (scalar_t*)result_data_bytes, result_dim_stride, - (scalar_t*)self_data_bytes, self_dim_stride, init_val - ); + for ([[maybe_unused]] const auto i : c10::irange(n)) { + f((scalar_t*)result_data_bytes, + result_dim_stride, + (scalar_t*)self_data_bytes, + self_dim_stride, + init_val); result_data_bytes += strides[0]; self_data_bytes += strides[1]; } diff --git a/aten/src/ATen/native/cpu/ReduceUtils.h b/aten/src/ATen/native/cpu/ReduceUtils.h index 8c6424f8b0eac..fd7c4a2750a6c 100644 --- a/aten/src/ATen/native/cpu/ReduceUtils.h +++ b/aten/src/ATen/native/cpu/ReduceUtils.h @@ -106,7 +106,7 @@ inline void _init(scalar_t* self_ptr, at::opmath_type* buffer_ptr, int } template -inline typename std::enable_if::value, scalar_t>::type +inline std::enable_if_t, scalar_t> _max(const scalar_t& x, const scalar_t& y) { return at::_isnan(y) ? y : std::max(x, y); } @@ -118,14 +118,14 @@ inline Vectorized _max(const Vectorized& x, const Vectorized } template -inline typename std::enable_if::value, Vec2>::type +inline std::enable_if_t, Vec2> _max(const vec_t& x, const vec_t& y) { // vec::maximum propagates NaN return maximum(x, y); } template -inline typename std::enable_if::value, scalar_t>::type +inline std::enable_if_t, scalar_t> _min(const scalar_t& x, const scalar_t& y) { return at::_isnan(y) ? y : std::min(x, y); } @@ -137,7 +137,7 @@ inline Vectorized _min(const Vectorized& x, const Vectorized } template -inline typename std::enable_if::value, Vec2>::type +inline std::enable_if_t, Vec2> _min(const vec_t& x, const vec_t& y) { // vec::minimum propagates NaN return minimum(x, y); diff --git a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp index 6af22033c805e..aaa8d3d438180 100644 --- a/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp +++ b/aten/src/ATen/native/cpu/ScatterGatherKernel.cpp @@ -215,7 +215,7 @@ struct cpu_scatter_gather_base_kernel { // vs dim-TensorIterator loop order depending on // whether dim is the last dimension if (dim== buffer.dim() - 1) { - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { // dim loop is a separate code block // for better performance loop_func.template operator()( @@ -232,7 +232,7 @@ struct cpu_scatter_gather_base_kernel { for (const auto i : c10::irange(index_dim_size)) { auto* self_data = self_data_bytes; auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { int64_t idx_dim = *(int64_t*)index_data; // we are not putting idx_dim in the error message because it disables // loop optimization in clang-7 @@ -306,7 +306,7 @@ struct cpu_scatter_gather_base_kernel { // vs dim-TensorIterator loop order depending on // whether dim is the last dimension if (dim== buffer.dim() - 1) { - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { // dim loop is a separate code block // for better performance loop_func.template operator()( @@ -327,7 +327,7 @@ struct cpu_scatter_gather_base_kernel { auto* self_data = self_data_bytes; auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); auto* src_data = src_data_bytes; - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { int64_t idx_dim = *(int64_t*)index_data; // we are not putting idx_dim in the error message because it disables // loop optimization in clang-7 @@ -402,7 +402,7 @@ struct cpu_scatter_gather_base_kernel { // vs dim-TensorIterator loop order depending on // whether dim is the last dimension if (dim== buffer.dim() - 1) { - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { // dim loop is a separate code block // for better performance loop_func.template operator()( @@ -423,7 +423,7 @@ struct cpu_scatter_gather_base_kernel { auto* self_data = self_data_bytes; auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); auto* src_data = src_data_bytes; - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { int64_t idx_dim = *(int64_t*)index_data; // we are not putting idx_dim in the error message because it disables // loop optimization in clang-7 @@ -497,7 +497,7 @@ struct cpu_scatter_gather_base_kernel { // vs dim-TensorIterator loop order depending on // whether dim is the last dimension if (dim== buffer.dim() - 1) { - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { // dim loop is a separate code block // for better performance loop_func.template operator()( @@ -518,7 +518,7 @@ struct cpu_scatter_gather_base_kernel { auto* self_data = self_data_bytes; auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); auto* src_data = src_data_bytes; - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { int64_t idx_dim = *(int64_t*)index_data; // we are not putting idx_dim in the error message because it disables // loop optimization in clang-7 @@ -593,7 +593,7 @@ struct cpu_scatter_gather_base_kernel { // vs dim-TensorIterator loop order depending on // whether dim is the last dimension if (dim== buffer.dim() - 1) { - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { // dim loop is a separate code block // for better performance loop_func.template operator()( @@ -614,7 +614,7 @@ struct cpu_scatter_gather_base_kernel { auto* self_data = self_data_bytes; auto* index_data = (char*)((int64_t*)index_data_bytes + i * index_dim_stride); auto* src_data = src_data_bytes; - for (const auto nelem C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto nelem : c10::irange(n)) { int64_t idx_dim = *(int64_t*)index_data; // we are not putting idx_dim in the error message because it disables // loop optimization in clang-7 diff --git a/aten/src/ATen/native/cpu/SortingKernel.cpp b/aten/src/ATen/native/cpu/SortingKernel.cpp index 0382668ce1e1e..e9a62e3692b5d 100644 --- a/aten/src/ATen/native/cpu/SortingKernel.cpp +++ b/aten/src/ATen/native/cpu/SortingKernel.cpp @@ -53,14 +53,12 @@ void _dim_apply( return; } - for (const auto i C10_UNUSED : c10::irange(n)) { - f( - reinterpret_cast(values_data_bytes), + for ([[maybe_unused]] const auto i : c10::irange(n)) { + f(reinterpret_cast(values_data_bytes), values_dim_stride, reinterpret_cast(indices_data_bytes), indices_dim_stride, - dim_size - ); + dim_size); values_data_bytes += strides[0]; indices_data_bytes += strides[1]; diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index b374935036dad..f862b50985e2d 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -83,7 +83,7 @@ static inline void compare_base_kernel(const Tensor& result1, const Tensor& resu auto* result1_data_bytes = data[0]; auto* result2_data_bytes = data[1]; const auto* self_data_bytes = data[2]; - for (const auto i C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto i : c10::irange(n)) { f((scalar_t*)result1_data_bytes, (scalar_t_2*)result2_data_bytes, (scalar_t*)self_data_bytes, @@ -253,7 +253,7 @@ static void mode_kernel_impl( std::vector> elements(self_dim_size); - for (const auto k C10_UNUSED : c10::irange(n)) { + for ([[maybe_unused]] const auto k : c10::irange(n)) { scalar_t* values_data = (scalar_t*)values_data_bytes; int64_t* indices_data = (int64_t*)indices_data_bytes; const scalar_t* self_data = (scalar_t*)self_data_bytes; diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 9754b003e19c6..37b84db0aeb88 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -600,14 +600,16 @@ static void i0e_kernel(TensorIteratorBase& iter) { static void i1_kernel(TensorIteratorBase& iter) { TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); - AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, iter.common_dtype(), "i1_cpu", [&]() { cpu_kernel(iter, [](scalar_t x) { return calc_i1(x); }); }); } static void i1e_kernel(TensorIteratorBase& iter) { TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); - AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cpu", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, iter.common_dtype(), "i1e_cpu", [&]() { cpu_kernel(iter, [](scalar_t x) { return calc_i1e(x); }); }); } diff --git a/aten/src/ATen/native/cpu/Unfold2d.cpp b/aten/src/ATen/native/cpu/Unfold2d.cpp index 026cfa812f3c6..da3a77a0f1797 100644 --- a/aten/src/ATen/native/cpu/Unfold2d.cpp +++ b/aten/src/ATen/native/cpu/Unfold2d.cpp @@ -353,8 +353,9 @@ static void unfolded2d_copy_channels_last( int64_t x = 0; data_index_init(start, y, output_height, x, output_width); - for (const auto k C10_UNUSED: c10::irange(start, end)) { - scalar_t* dst = finput_data + y * output_width * kH * kW * n_input_plane + x * kH * kW * n_input_plane; + for (const auto k [[maybe_unused]] : c10::irange(start, end)) { + scalar_t* dst = finput_data + y * output_width * kH * kW * n_input_plane + + x * kH * kW * n_input_plane; const scalar_t* src = input_data; if (padW > 0 || padH > 0) { diff --git a/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp b/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp index 35049ce21d2e7..50ea531d57f07 100644 --- a/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp +++ b/aten/src/ATen/native/cpu/UnfoldBackwardKernel.cpp @@ -76,7 +76,7 @@ void _unfold_backward_internal_kernel( auto* RESTRICT grad_in_ptr = data[1]; auto* RESTRICT idx_dim_ptr = data[2]; - for (const auto elem C10_UNUSED : c10::irange(nelems)) { + for ([[maybe_unused]] const auto elem : c10::irange(nelems)) { auto* RESTRICT grad_out_data = reinterpret_cast(grad_out_ptr); auto* RESTRICT grad_in_data = reinterpret_cast(grad_in_ptr); diff --git a/aten/src/ATen/native/cpu/UpSampleKernel.cpp b/aten/src/ATen/native/cpu/UpSampleKernel.cpp index ca11ffe88aeeb..a617f8a74f05e 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernel.cpp +++ b/aten/src/ATen/native/cpu/UpSampleKernel.cpp @@ -733,8 +733,9 @@ struct HelperInterpBase { auto new_shape = std::vector(ndims, 1); new_shape[reshape_dim] = output_size; - for (const auto j C10_UNUSED : c10::irange(interp_size)) { - output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType()))); + for ([[maybe_unused]] const auto j : c10::irange(interp_size)) { + output.emplace_back( + empty(new_shape, CPU(c10::CppTypeToScalarType()))); output.emplace_back(empty(new_shape, CPU(output_type))); } } @@ -1047,8 +1048,9 @@ struct HelperInterpNearest : public HelperInterpBase { auto new_shape = std::vector(ndims, 1); new_shape[reshape_dim] = output_size; - for (const auto j C10_UNUSED : c10::irange(interp_size)) { - output.emplace_back(empty(new_shape, CPU(c10::CppTypeToScalarType()))); + for ([[maybe_unused]] const auto j : c10::irange(interp_size)) { + output.emplace_back( + empty(new_shape, CPU(c10::CppTypeToScalarType()))); // Defines weights for consistency, but not used output.emplace_back(at::ones(new_shape, CPU(output_type))); } diff --git a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h index 726a83c20963d..5b545509b1d99 100644 --- a/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h +++ b/aten/src/ATen/native/cpu/UpSampleKernelAVXAntialias.h @@ -102,7 +102,7 @@ void pack_rgb( TORCH_INTERNAL_ASSERT(unpacked_increment == 3 || unpacked_increment == 4); - for (const auto i C10_UNUSED : c10::irange(num_pixels)) { + for ([[maybe_unused]] const auto i : c10::irange(num_pixels)) { for (const auto j : c10::irange(num_channels)) { packed[j * packed_stride] = unpacked[j]; } diff --git a/aten/src/ATen/native/cpu/group_norm_kernel.cpp b/aten/src/ATen/native/cpu/group_norm_kernel.cpp index f6b7f2a5d4813..0aee364b49d8c 100644 --- a/aten/src/ATen/native/cpu/group_norm_kernel.cpp +++ b/aten/src/ATen/native/cpu/group_norm_kernel.cpp @@ -85,8 +85,8 @@ void GroupNormKernelImplInternal( } template -typename std::enable_if>::value, - std::tuple>::type +std::enable_if_t>, + std::tuple> ColumnwiseMoments( const T* X_data, int64_t HxW, @@ -118,8 +118,8 @@ ColumnwiseMoments( // std::is_same || std::is_same template -typename std::enable_if>::value, - std::tuple, at::opmath_type>>::type +std::enable_if_t>, + std::tuple, at::opmath_type>> ColumnwiseMoments( const T* X_data, int64_t HxW, @@ -160,7 +160,7 @@ ColumnwiseMoments( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> CalcMeanVar( const T* X_ptr, opmath_t* mean_ptr, @@ -183,7 +183,7 @@ CalcMeanVar( // std::is_same || std::is_same template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> CalcMeanVar( const T* X_ptr, opmath_t* mean_ptr, @@ -227,7 +227,7 @@ CalcMeanVar( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> ApplyScaleBias( T* Y_ptr, const T* X_ptr, @@ -246,7 +246,7 @@ ApplyScaleBias( // std::is_same || std::is_same template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> ApplyScaleBias( T* Y_ptr, const T* X_ptr, @@ -529,7 +529,7 @@ void GroupNormKernelImpl( template -typename std::enable_if::value, void>::type +std::enable_if_t, void> ComputeInternalGradients( int64_t N, int64_t C, @@ -556,7 +556,7 @@ ComputeInternalGradients( } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> ComputeInternalGradients( int64_t N, int64_t C, @@ -603,7 +603,7 @@ ComputeInternalGradients( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> CalcDsDb( const opmath_t* ds_ptr, const opmath_t* db_ptr, @@ -626,7 +626,7 @@ CalcDsDb( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> CalcDsDb( const opmath_t* ds_ptr, const opmath_t* db_ptr, @@ -708,7 +708,7 @@ void GroupNormInputBackward( } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> GammaBackward( int64_t N, int64_t C, @@ -755,7 +755,7 @@ GammaBackward( } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> GammaBackward( int64_t N, int64_t C, @@ -817,7 +817,7 @@ GammaBackward( } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) { using Vec = at::vec::Vectorized; constexpr int64_t K = Vec::size(); @@ -841,7 +841,7 @@ BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) { } template -typename std::enable_if::value, void>::type +std::enable_if_t, void> BetaBackward(int64_t N, int64_t C, const opmath_t* db, PT* dbeta) { using Vec = at::vec::Vectorized; using fVec = at::vec::Vectorized; @@ -937,7 +937,7 @@ void GroupNormBackwardKernelImplInternal( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> DsDbRowwiseMomentsChannelsLast( const T* dY_ptr, const T* X_ptr, @@ -972,7 +972,7 @@ DsDbRowwiseMomentsChannelsLast( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> DsDbRowwiseMomentsChannelsLast( const T* dY_ptr, const T* X_ptr, @@ -1024,10 +1024,10 @@ DsDbRowwiseMomentsChannelsLast( } template -inline typename std::enable_if>::value, +inline std::enable_if_t>, std::tuple< vec::Vectorized, - vec::Vectorized>>::type + vec::Vectorized>> load_util(const T* data_ptr, int64_t n) { using Vec = vec::Vectorized; auto vec0 = Vec::loadu(data_ptr, n > Vec::size() ? Vec::size() : n); @@ -1037,11 +1037,11 @@ load_util(const T* data_ptr, int64_t n) { } template -inline typename std::enable_if>::value, +inline std::enable_if_t>, std::tuple< vec::Vectorized>, vec::Vectorized>> - >::type + > load_util(const T* data_ptr, int64_t n) { using Vec = vec::Vectorized; auto vec = Vec::loadu(data_ptr, n); @@ -1049,7 +1049,7 @@ load_util(const T* data_ptr, int64_t n) { } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> ApplyInputGradientsChannelsLastColMov( const T* dY_data, const T* X_data, @@ -1097,7 +1097,7 @@ ApplyInputGradientsChannelsLastColMov( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> ApplyInputGradientsChannelsLastColMov( const T* dY_data, const T* X_data, @@ -1154,7 +1154,7 @@ ApplyInputGradientsChannelsLastColMov( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> ApplyInputGradientsChannelsLastRowMov( const T* dY_data, const T* X_data, @@ -1190,7 +1190,7 @@ ApplyInputGradientsChannelsLastRowMov( } template -inline typename std::enable_if::value, void>::type +inline std::enable_if_t, void> ApplyInputGradientsChannelsLastRowMov( const T* dY_data, const T* X_data, diff --git a/aten/src/ATen/native/cpu/int4mm_kernel.cpp b/aten/src/ATen/native/cpu/int4mm_kernel.cpp index f46f6625e3a8b..662f3af142554 100644 --- a/aten/src/ATen/native/cpu/int4mm_kernel.cpp +++ b/aten/src/ATen/native/cpu/int4mm_kernel.cpp @@ -723,7 +723,7 @@ void int4pack_mm_kernel_( int mb{0}, nb{0}; data_index_init(begin, mb, MB, nb, NB); - for (C10_UNUSED const auto i : c10::irange(begin, end)) { + for ([[maybe_unused]] const auto i : c10::irange(begin, end)) { int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); int nb_start = nb * BLOCK_N; diff --git a/aten/src/ATen/native/cuda/AbsKernel.cu b/aten/src/ATen/native/cuda/AbsKernel.cu index 980bd6637341e..e2c0a456a232b 100644 --- a/aten/src/ATen/native/cuda/AbsKernel.cu +++ b/aten/src/ATen/native/cuda/AbsKernel.cu @@ -15,7 +15,7 @@ struct AbsFunctor { } }; -CONSTEXPR_EXCEPT_WIN_CUDA char abs_name[] = "abs_kernel"; +constexpr char abs_name[] = "abs_kernel"; void abs_kernel_cuda(TensorIteratorBase& iter) { auto dtype = iter.dtype(); if (at::isComplexType(dtype)) { diff --git a/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu index aa955a9c7e546..a7fa53fcb0abd 100644 --- a/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryDivTrueKernel.cu @@ -16,7 +16,7 @@ namespace at::native { namespace binary_internal { -CONSTEXPR_EXCEPT_WIN_CUDA char div_name[] = "div_kernel"; +constexpr char div_name[] = "div_kernel"; void div_true_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (iter.common_dtype() == kComplexHalf) { diff --git a/aten/src/ATen/native/cuda/BinaryLogicalOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryLogicalOpsKernels.cu index eaa01ac1accc8..918a6ba4e981e 100644 --- a/aten/src/ATen/native/cuda/BinaryLogicalOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryLogicalOpsKernels.cu @@ -11,7 +11,7 @@ namespace at::native { -CONSTEXPR_EXCEPT_WIN_CUDA char logical_and_name[] = "logical_and_kernel"; +constexpr char logical_and_name[] = "logical_and_kernel"; void logical_and_kernel_cuda(TensorIterator& iter) { auto dtype = iter.common_dtype(); if (at::isComplexType(dtype)) { @@ -48,7 +48,7 @@ void logical_and_kernel_cuda(TensorIterator& iter) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char logical_or_name[] = "logical_or_kernel"; +constexpr char logical_or_name[] = "logical_or_kernel"; void logical_or_kernel_cuda(TensorIterator& iter) { auto dtype = iter.common_dtype(); if (at::isComplexType(dtype)) { @@ -84,7 +84,7 @@ void logical_or_kernel_cuda(TensorIterator& iter) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char logical_xor_name[] = "logical_xor_kernel"; +constexpr char logical_xor_name[] = "logical_xor_kernel"; void logical_xor_kernel_cuda(TensorIterator& iter) { auto dtype = iter.common_dtype(); if (at::isComplexType(dtype)) { diff --git a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu index 75d5991f93db5..0cd4c5040fe70 100644 --- a/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu +++ b/aten/src/ATen/native/cuda/BinaryMiscBackwardOpsKernels.cu @@ -15,7 +15,7 @@ namespace at::native { -CONSTEXPR_EXCEPT_WIN_CUDA char sigmoid_backward_name[] = "sigmoid_backward"; +constexpr char sigmoid_backward_name[] = "sigmoid_backward"; void sigmoid_backward_kernel_cuda(TensorIteratorBase& iter) { auto dtype = iter.dtype(); if(isComplexType(dtype)) { @@ -86,7 +86,7 @@ void logit_backward_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scal }); } -CONSTEXPR_EXCEPT_WIN_CUDA char tanh_backward_name[] = "tanh_backward"; +constexpr char tanh_backward_name[] = "tanh_backward"; void tanh_backward_kernel_cuda(TensorIteratorBase& iter) { auto dtype = iter.dtype(); if(isComplexType(dtype)) { diff --git a/aten/src/ATen/native/cuda/BinaryMulKernel.cu b/aten/src/ATen/native/cuda/BinaryMulKernel.cu index 251221f7adcd1..242ff1c7cd52e 100644 --- a/aten/src/ATen/native/cuda/BinaryMulKernel.cu +++ b/aten/src/ATen/native/cuda/BinaryMulKernel.cu @@ -18,7 +18,7 @@ namespace at::native { -CONSTEXPR_EXCEPT_WIN_CUDA char mul_name[] = "mul_kernel"; +constexpr char mul_name[] = "mul_kernel"; void mul_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (common_dtype == kComplexHalf) { diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index f25b17b2db9b1..67efe6dea5c74 100644 --- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -79,6 +79,7 @@ c10::MaybeOwned inline prepare_matrix_for_cublas(const Tensor& tensor, b transpose_tensor = tensor.is_contiguous(); return resolve_conj_if_indicated(tensor, true); } + IntArrayRef tensor_strides = tensor.strides(); IntArrayRef tensor_sizes = tensor.sizes(); if ((tensor_strides[0] == 1) && (tensor_strides[1] >= std::max(1, tensor_sizes[0]))) { @@ -179,29 +180,22 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa static bool getDisableAddmmCudaLt() { static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); -#ifdef USE_ROCM - // allow both CUDA and HIP env var names for ROCm builds - // also, current default for ROCm builds is disable by default - if (env_value == nullptr) { - env_value = std::getenv("DISABLE_ADDMM_HIP_LT"); - } - if (env_value != nullptr && strcmp(env_value, "0") == 0) { - return false; - } - return true; -#else if (env_value != nullptr && strcmp(env_value, "1") == 0) { return true; } return false; -#endif } #ifdef USE_ROCM static bool isSupportedHipLtROCmArch(int index) { hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); std::string device_arch = prop->gcnArchName; - static const std::vector archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; + static const std::vector archs = { + "gfx90a", "gfx940", "gfx941", "gfx942", +#if ROCM_VERSION >= 60300 + "gfx1100", "gfx1101" +#endif + }; for (std::string arch : archs) { size_t substring = device_arch.find(arch); if (substring != std::string::npos) { @@ -322,14 +316,6 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma } self__sizes = self_->sizes(); } else { -#if defined(USE_ROCM) - useLtInterface = !disable_addmm_cuda_lt && - result.dim() == 2 && result.is_contiguous() && - isSupportedHipLtROCmArch(self.device().index()) && - (scalar_type == at::ScalarType::Float || - scalar_type == at::ScalarType::Half || - scalar_type == at::ScalarType::BFloat16); -#endif self_ = c10::MaybeOwned::borrowed(self); self__sizes = self_->sizes(); TORCH_CHECK(result.dim() == 2, "tensors must be 2-D"); diff --git a/aten/src/ATen/native/cuda/CuFFTUtils.h b/aten/src/ATen/native/cuda/CuFFTUtils.h index 4b02f914d7e20..f20baa9568661 100644 --- a/aten/src/ATen/native/cuda/CuFFTUtils.h +++ b/aten/src/ATen/native/cuda/CuFFTUtils.h @@ -66,7 +66,7 @@ static inline void CUFFT_CHECK(cufftResult error) if (error != CUFFT_SUCCESS) { std::ostringstream ss; ss << "cuFFT error: " << _cudaGetErrorEnum(error); - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } } diff --git a/aten/src/ATen/native/cuda/EmbeddingBag.cu b/aten/src/ATen/native/cuda/EmbeddingBag.cu index 97f69c1ccd72e..6514ab6f2dec6 100644 --- a/aten/src/ATen/native/cuda/EmbeddingBag.cu +++ b/aten/src/ATen/native/cuda/EmbeddingBag.cu @@ -462,7 +462,7 @@ Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &ind padding_idx); default: - AT_ERROR( + TORCH_CHECK(false, "Unknown mode for embedding_bag_backward_cuda ", mode); } } diff --git a/aten/src/ATen/native/cuda/GcdLcmKernel.cu b/aten/src/ATen/native/cuda/GcdLcmKernel.cu index c4a8cdfaf1f8e..6b003a6f4fc03 100644 --- a/aten/src/ATen/native/cuda/GcdLcmKernel.cu +++ b/aten/src/ATen/native/cuda/GcdLcmKernel.cu @@ -14,7 +14,7 @@ namespace at::native { // See note [Jiterator] -CONSTEXPR_EXCEPT_WIN_CUDA char gcd_name[] = "gcd"; +constexpr char gcd_name[] = "gcd"; void gcd_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "gcd_cuda", [&]() { @@ -33,7 +33,7 @@ void gcd_kernel_cuda(TensorIteratorBase& iter) { } // See note [Jiterator] -CONSTEXPR_EXCEPT_WIN_CUDA char lcm_name[] = "lcm"; +constexpr char lcm_name[] = "lcm"; void lcm_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_INTEGRAL_TYPES(iter.common_dtype(), "lcm_cuda", [&]() { diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 31a87991e0418..4f1a37ccf5e99 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -457,12 +457,20 @@ void flip_kernel(TensorIterator& iter, const bool quantized) { flip_kernel_impl(iter); }); } else { - AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16, - iter.dtype(), "flip_cuda", - [&] { - using dtype = OpaqueType; - flip_kernel_impl(iter); - }); + AT_DISPATCH_V2( + iter.dtype(), + "flip_cuda", + AT_WRAP([&] { + using dtype = OpaqueType; + flip_kernel_impl(iter); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + AT_EXPAND(AT_FLOAT8_TYPES), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + kComplexHalf, + kHalf, + kBool, + kBFloat16); } } diff --git a/aten/src/ATen/native/cuda/Indexing.cu b/aten/src/ATen/native/cuda/Indexing.cu index 2806cbb56dd8f..822525556bc3f 100644 --- a/aten/src/ATen/native/cuda/Indexing.cu +++ b/aten/src/ATen/native/cuda/Indexing.cu @@ -124,6 +124,55 @@ __global__ void indexing_backward_kernel( } } +#ifdef USE_ROCM +template +__global__ void indexing_backward_kernel_rocm( + const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, + int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim) { + + // This implementation is adopted from indexing_backward_kernel above. + using opmath_t = at::opmath_type; + for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){ + int64_t idx = blockIdx.x * blockDim.y + threadIdx.y; + if (idx < numel && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){ + do { + // if not accumulate, we only keep the last duplicate index so skip those before it + if constexpr (!accumulate) { + if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) { + idx++; + continue; + } + } + const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before; + const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride; + + opmath_t gradient; + opmath_t weight; + + int64_t feature_dim = threadIdx.x + blockIdx.y * blockDim.x; + while (feature_dim < stride) { + gradient = static_cast(grad_output[grad_row + feature_dim]); + if constexpr (accumulate) { + weight = static_cast(grad_weight[weight_row + feature_dim]); + } + + if constexpr (accumulate) { + weight += gradient; + } else { + weight = gradient; + } + + grad_weight[weight_row + feature_dim] = static_cast(weight); + feature_dim += gridDim.y * blockDim.x; + } + + idx++; + } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]); + } + } +} +#endif + template __global__ void indexing_backward_kernel_stride_1( const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight, @@ -470,7 +519,7 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List<<>>( + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + AT_EXPAND(AT_FLOAT8_TYPES), + kComplexHalf, + kHalf, + kBool, + kBFloat16); + } else { + AT_DISPATCH_V2( + expandedValue.scalar_type(), + "indexing_backward", + AT_WRAP([&] { + indexing_backward_kernel_rocm<<>>( + sorted_indices.const_data_ptr(), + orig_indices.const_data_ptr(), + expandedValue.const_data_ptr(), + src_.mutable_data_ptr(), + num_indices, + sliceSize, + strideBefore, + nElemBefore); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + AT_EXPAND(AT_FLOAT8_TYPES), + kComplexHalf, + kHalf, + kBool, + kBFloat16); + } +#endif } else { AT_DISPATCH_V2( expandedValue.scalar_type(), @@ -572,8 +673,8 @@ void index_put_with_sort_kernel(Tensor & self, const c10::List void index_select_out_cuda_impl( Tensor& out, const Tensor& self, - uint64_t dim, + int64_t dim, const Tensor& index) { uint64_t numIndices = index.numel(); auto selfDims = self.dim() == 0 ? 1 : self.dim(); @@ -1506,24 +1607,27 @@ Tensor& index_select_out_cuda( dim = at::maybe_wrap_dim(dim, self); TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING); - if (self.is_quantized()){ + if (self.is_quantized()) { TORCH_CHECK( - self.qscheme() == kPerTensorAffine, - "Only per_tensor quantized quantized tensors are supported by index_select.") + self.qscheme() == kPerTensorAffine, + "Only per_tensor quantized quantized tensors are supported by index_select.") AT_DISPATCH_QINT_TYPES(out.scalar_type(), "index_select_quant_cuda", [&] { - index_select_out_cuda_impl(out, self, (uint64_t) dim, index); + index_select_out_cuda_impl(out, self, dim, index); }); } else { AT_DISPATCH_V2( out.scalar_type(), "index_select_cuda", - AT_WRAP([&] { index_select_out_cuda_impl(out, self, (uint64_t) dim, index); }), - AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), AT_EXPAND(AT_FLOAT8_TYPES), + AT_WRAP([&] { + index_select_out_cuda_impl(out, self, dim, index); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), + AT_EXPAND(AT_FLOAT8_TYPES), kComplexHalf, kHalf, kBool, - kBFloat16 - ); + kBFloat16); } return out; diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 407083fc810ef..45e2415572db0 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -7,8 +7,8 @@ // ROCm 6.3 is planned to have these functions, but until then here they are. #if defined(USE_ROCM) && ROCM_VERSION >= 60201 -#include #include +#include __device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { #if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) && \ diff --git a/aten/src/ATen/native/cuda/Lerp.cu b/aten/src/ATen/native/cuda/Lerp.cu index 01053a3beeabd..25692dcd4c494 100644 --- a/aten/src/ATen/native/cuda/Lerp.cu +++ b/aten/src/ATen/native/cuda/Lerp.cu @@ -9,7 +9,7 @@ namespace at::native { namespace { -CONSTEXPR_EXCEPT_WIN_CUDA char lerp_tensor_name[] = "lerp_tensor"; +constexpr char lerp_tensor_name[] = "lerp_tensor"; void lerp_tensor_kernel(at::TensorIteratorBase& iter) { auto dtype = iter.common_dtype(); if(at::isComplexType(dtype)) { @@ -63,7 +63,7 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char lerp_scalar_name[] = "lerp_scalar"; +constexpr char lerp_scalar_name[] = "lerp_scalar"; void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight) { auto dtype = iter.common_dtype(); if (at::isComplexType(dtype)) { diff --git a/aten/src/ATen/native/cuda/Loops.cuh b/aten/src/ATen/native/cuda/Loops.cuh index f5d9403950aa2..d157d44ade9de 100644 --- a/aten/src/ATen/native/cuda/Loops.cuh +++ b/aten/src/ATen/native/cuda/Loops.cuh @@ -11,6 +11,7 @@ #include +#include namespace at::native { diff --git a/aten/src/ATen/native/cuda/MaxUnpooling.cu b/aten/src/ATen/native/cuda/MaxUnpooling.cu index a7b48fad280a4..2cee1156ed1c9 100644 --- a/aten/src/ATen/native/cuda/MaxUnpooling.cu +++ b/aten/src/ATen/native/cuda/MaxUnpooling.cu @@ -267,7 +267,7 @@ static void max_unpooling3d_shape_check( if (gradOutput.defined()) { if (oT != gradOutput.size(dimt) || oH != gradOutput.size(dimh) || oW != gradOutput.size(dimw)) { - AT_ERROR( + TORCH_CHECK(false, "Inconsistent gradOutput size. oT= ", oT, ", oH= ", @@ -447,7 +447,7 @@ at::Tensor& max_unpooling2d_backward_out_cuda(const Tensor& grad_output_, nInputRows = self.size(dimh); if (oheight != grad_output.size(dimh) || owidth != grad_output.size(dimw)) { - AT_ERROR( + TORCH_CHECK(false, "Inconsistent gradOutput size. output height: ", oheight, ", output width= ", diff --git a/aten/src/ATen/native/cuda/MixedDtypesLinear.cu b/aten/src/ATen/native/cuda/MixedDtypesLinear.cu index f5c36d5694928..42b3dc5545d46 100644 --- a/aten/src/ATen/native/cuda/MixedDtypesLinear.cu +++ b/aten/src/ATen/native/cuda/MixedDtypesLinear.cu @@ -164,7 +164,7 @@ mixed_dtypes_linear_dispatch_bias_activation( ElementInputB, fastertransformer::EpilogueOpNoBias>(input, weight, scale, bias); } - AT_ERROR("mixed_dtypes_linear_dispatch_bias_activation: Activation \"", + TORCH_CHECK(false, "mixed_dtypes_linear_dispatch_bias_activation: Activation \"", activation, "\" is not supported"); return Tensor{}; } @@ -185,7 +185,7 @@ mixed_dtypes_linear_dispatch_bias_activation( ElementInputB, fastertransformer::EpilogueOpBiasSilu>(input, weight, scale, bias); } - AT_ERROR("mixed_dtypes_linear_dispatch_bias_activation: Activation \"", + TORCH_CHECK(false, "mixed_dtypes_linear_dispatch_bias_activation: Activation \"", activation, "\" is not supported"); return Tensor{}; } @@ -198,7 +198,7 @@ _mixed_dtypes_linear(const Tensor& input, const Tensor& weight, const std::optional& bias_opt, const std::optional activation_opt) { #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR("_mixed_dtypes_linear: not compiled for this platform"); + TORCH_CHECK(false, "_mixed_dtypes_linear: not compiled for this platform"); return Tensor{}; #else const auto bias = bias_opt.has_value() ? *bias_opt : Tensor{}; diff --git a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu index 247b1728badea..5e82fe1bd3c21 100644 --- a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu +++ b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose2d.cu @@ -88,7 +88,7 @@ static inline void slow_conv_transpose2d_shape_check( check_dim_size(bias, 1, 0, weight.size(1)); } } else if (!weight_nullable) { - AT_ERROR("weight tensor is expected to be non-nullable"); + TORCH_CHECK(false, "weight tensor is expected to be non-nullable"); } int ndim = input.dim(); @@ -115,7 +115,7 @@ static inline void slow_conv_transpose2d_shape_check( (dilation_width * (kernel_width - 1) + 1) + output_padding_width; if (output_width < 1 || output_height < 1) { - AT_ERROR( + TORCH_CHECK(false, "Given input size per channel: (", input_height, " x ", diff --git a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu index 56b762a051fbf..20f10ed3b264f 100644 --- a/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu +++ b/aten/src/ATen/native/cuda/NaiveConvolutionTranspose3d.cu @@ -106,7 +106,7 @@ static inline void slow_conv_transpose3d_shape_check( check_dim_size(bias, 1, 0, weight.size(1)); } } else if (!weight_nullable) { - AT_ERROR("weight tensor is expected to be non-nullable"); + TORCH_CHECK(false, "weight tensor is expected to be non-nullable"); } int ndim = input.dim(); @@ -140,7 +140,7 @@ static inline void slow_conv_transpose3d_shape_check( (dilation_width * (kernel_width - 1) + 1) + output_padding_width; if (output_depth < 1 || output_width < 1 || output_height < 1) { - AT_ERROR( + TORCH_CHECK(false, "Given input size per channel: (", input_depth, " x ", diff --git a/aten/src/ATen/native/cuda/Normalization.cu b/aten/src/ATen/native/cuda/Normalization.cu index ae0908b3abac6..8db7241dee137 100644 --- a/aten/src/ATen/native/cuda/Normalization.cu +++ b/aten/src/ATen/native/cuda/Normalization.cu @@ -487,7 +487,7 @@ std::tuple _batch_norm_with_update_cuda( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); Tensor output, save_mean, save_var, reserve; BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps); @@ -513,7 +513,7 @@ std::tuple _batch_norm_with_update_cuda_out( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); BatchNormBackend backend = _select_batch_norm_backend(input, weight, bias, running_mean, running_var, /*training*/true, eps); if (backend == BatchNormBackend::Cudnn) { @@ -551,10 +551,10 @@ std::tuple _new_batch_norm_backward_cuda( const std::optional& save_mean_opt, const std::optional& save_var_opt, bool update, double eps, std::array grad_input_mask, const Tensor& reserve) { const Tensor& dummy_bias = at::empty(1); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); - const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();}); - const Tensor& save_var = c10::value_or_else(save_var_opt, [] {return Tensor();}); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); + const Tensor& save_mean = save_mean_opt.value_or(Tensor()); + const Tensor& save_var = save_var_opt.value_or(Tensor()); BatchNormBackend backend = _select_batch_norm_backend(input, weight, dummy_bias, running_mean, running_var, /*training*/true, eps); @@ -694,7 +694,7 @@ std::tuple batch_norm_gather_stats_cuda(const Tensor& self, cons // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt); const Tensor& running_mean = *running_mean_maybe_owned; - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& running_var = running_var_opt.value_or(Tensor()); std::vector counts(mean.size(0), count); Tensor counts_ = at::from_blob((void*)counts.data(), {(int64_t)counts.size()}, self.options().dtype(at::kLong).device(at::kCPU)); @@ -708,7 +708,7 @@ std::tuple batch_norm_gather_stats_with_counts_cuda( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt); const Tensor& running_mean = *running_mean_maybe_owned; - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& running_var = running_var_opt.value_or(Tensor()); auto scalar_type = running_mean.defined() ? running_mean.scalar_type() : self.scalar_type(); diff --git a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu index 4f174bf0874f0..eee0047fd7295 100644 --- a/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu +++ b/aten/src/ATen/native/cuda/PointwiseOpsKernel.cu @@ -12,7 +12,7 @@ namespace at::native { #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050 -CONSTEXPR_EXCEPT_WIN_CUDA char addcmul_name[] = "addcmul"; +constexpr char addcmul_name[] = "addcmul"; #endif void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) { auto dtype = iter.common_dtype(); @@ -59,7 +59,7 @@ void addcmul_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) { #if AT_USE_JITERATOR() && CUDA_VERSION >= 11050 // return a + alpha * (b / static_cast(c)); -CONSTEXPR_EXCEPT_WIN_CUDA char addcdiv_name[] = "addcdiv"; +constexpr char addcdiv_name[] = "addcdiv"; #endif void addcdiv_cuda_kernel(TensorIteratorBase& iter, const Scalar& value) { auto dtype = iter.common_dtype(); diff --git a/aten/src/ATen/native/cuda/PowKernel.cu b/aten/src/ATen/native/cuda/PowKernel.cu index eb56da722fbb8..010818ca213aa 100644 --- a/aten/src/ATen/native/cuda/PowKernel.cu +++ b/aten/src/ATen/native/cuda/PowKernel.cu @@ -38,7 +38,7 @@ void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base } /* complex support impl */ -CONSTEXPR_EXCEPT_WIN_CUDA char pow_scalar_base_name[] = "pow_scalar_base_kernel"; +constexpr char pow_scalar_base_name[] = "pow_scalar_base_kernel"; template <> void pow_scalar_tensor_impl(TensorIteratorBase& iter, c10::complex base) { using scalar_t = c10::complex; @@ -68,7 +68,7 @@ namespace { #if AT_USE_JITERATOR() /* complex support impl */ -CONSTEXPR_EXCEPT_WIN_CUDA char pow_name[] = "pow_kernel"; +constexpr char pow_name[] = "pow_kernel"; static const auto pow_kernel_string = jiterator_stringify(template T pow_kernel(T base, T exp) { return std::pow(base, exp); diff --git a/aten/src/ATen/native/cuda/RNN.cu b/aten/src/ATen/native/cuda/RNN.cu index 3b10a836c409e..53dd49909b1a6 100644 --- a/aten/src/ATen/native/cuda/RNN.cu +++ b/aten/src/ATen/native/cuda/RNN.cu @@ -520,7 +520,7 @@ std::tuple _thnn_fused_lstm_cell_cuda( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt); const Tensor& input_bias = *input_bias_maybe_owned; - const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();}); + const Tensor& hidden_bias = hidden_bias_opt.value_or(Tensor()); checkSizes("_thnn_fused_lstm_cell_cuda", {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2}, @@ -570,7 +570,7 @@ std::tuple _thnn_fused_lstm_cell_backward_impl_cuda( con // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned grad_hy_maybe_owned = at::borrow_from_optional_tensor(grad_hy_opt); const Tensor& grad_hy = *grad_hy_maybe_owned; - const Tensor& grad_cy = c10::value_or_else(grad_cy_opt, [] {return Tensor();}); + const Tensor& grad_cy = grad_cy_opt.value_or(Tensor()); if (!grad_hy.defined() && !grad_cy.defined()) { return std::tuple(); @@ -606,7 +606,7 @@ std::tuple _thnn_fused_gru_cell_cuda( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned input_bias_maybe_owned = at::borrow_from_optional_tensor(input_bias_opt); const Tensor& input_bias = *input_bias_maybe_owned; - const Tensor& hidden_bias = c10::value_or_else(hidden_bias_opt, [] {return Tensor();}); + const Tensor& hidden_bias = hidden_bias_opt.value_or(Tensor()); checkSizes("_thnn_fused_gru_cell_cuda", {input_gates, "input_gates", 1}, {hidden_gates, "hidden_gates", 2}, diff --git a/aten/src/ATen/native/cuda/Reduce.cuh b/aten/src/ATen/native/cuda/Reduce.cuh index 5f03d7b9bda57..4baa3bd560a6d 100644 --- a/aten/src/ATen/native/cuda/Reduce.cuh +++ b/aten/src/ATen/native/cuda/Reduce.cuh @@ -1092,11 +1092,7 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ } constexpr int min_values_per_thread = 16; -#ifndef USE_ROCM constexpr int max_values_per_thread = 256; -#else - constexpr int max_values_per_thread = 1024; -#endif if (config.values_per_thread() >= block_height * 16 || config.values_per_thread() >= max_values_per_thread) { // Divide the input across warps in a thread-block, if that leaves at least @@ -1108,7 +1104,18 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ config.output_mult[1] = config.split_output(block_height); } - const int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor / config.num_threads; + int max_threads_per_mp = + at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor; +#ifdef USE_ROCM + // Control the number of threadblocks by adjusting the maximum number of + // threads per multi-processor. These numbers better reflect the maximum + // theoretical achievable threads per MP for the reduction operation. + if (iter.ndim() == 1) + max_threads_per_mp = 512; + if (iter.ndim() == 2) + max_threads_per_mp = 256; +#endif + const int blocks_per_sm = max_threads_per_mp / config.num_threads; const int num_mp = at::cuda::getCurrentDeviceProperties()->multiProcessorCount; const int target_grid_size = num_mp * blocks_per_sm; int grid = config.grid().x; @@ -1126,6 +1133,23 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){ // a large number of values to deal with. But we don't want values_per_thread to be larger than // max_values_per_thread config.ctas_per_output = std::max(std::min(ctas_per_output1, ctas_per_output2), ctas_per_output3); +#ifdef USE_ROCM + // In cases where a number of threadblocks along the y direction of the grid + // is needed then make sure they are reduced to the number of MPs. For + // smaller sizes, use half the number of MPs. For smaller sizes than half + // the number of MPs use the original value unless the value is less than 16 + // blocks in which case it is more profitable to use just 1 block. + if (config.ctas_per_output > num_mp) + if (num_mp < 128) + config.ctas_per_output = + num_mp * (config.ctas_per_output > 512 ? 4 : 2); + else + config.ctas_per_output = num_mp; + else if (config.ctas_per_output > div_up(num_mp, 2)) + config.ctas_per_output = div_up(num_mp, 2); + else if (config.ctas_per_output < 16) + config.ctas_per_output = 1; +#endif if (config.ctas_per_output > 1) { config.input_mult[2] = config.split_input(config.ctas_per_output); } diff --git a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu index e628e1916f9e6..dc2f0fa492a7a 100644 --- a/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceSumProdKernel.cu @@ -21,7 +21,7 @@ struct sum_functor { }; // jiterated specialization for `complex` -CONSTEXPR_EXCEPT_WIN_CUDA char sum_name[] = "sum"; +constexpr char sum_name[] = "sum"; template <> struct sum_functor> { // jiterator reduction fails on windows @@ -57,7 +57,7 @@ struct nansum_functor { } }; -CONSTEXPR_EXCEPT_WIN_CUDA char nansum_name[] = "nansum"; +constexpr char nansum_name[] = "nansum"; template struct nansum_functor_complex { #if AT_USE_JITERATOR() @@ -79,7 +79,7 @@ struct nansum_functor_complex { #endif }; -CONSTEXPR_EXCEPT_WIN_CUDA char prod_name[] = "prod"; +constexpr char prod_name[] = "prod"; template struct prod_functor { // jiterator reduction fails on windows diff --git a/aten/src/ATen/native/cuda/Resize.cpp b/aten/src/ATen/native/cuda/Resize.cpp index c11dd8dcc960e..e6f050603c641 100644 --- a/aten/src/ATen/native/cuda/Resize.cpp +++ b/aten/src/ATen/native/cuda/Resize.cpp @@ -30,7 +30,7 @@ void resize_bytes_cuda(StorageImpl* storage, size_t size_bytes) { c10::cuda::CUDAGuard guard(device.index()); at::DataPtr data = allocator->allocate(size_bytes); if (storage->data_ptr()) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); C10_CUDA_CHECK( cudaMemcpyAsync( diff --git a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu index f76d6bfb66a72..ccf9cb1bc3031 100644 --- a/aten/src/ATen/native/cuda/RowwiseScaledMM.cu +++ b/aten/src/ATen/native/cuda/RowwiseScaledMM.cu @@ -141,7 +141,8 @@ void f8f8bf16_rowwise_impl( at::Tensor x_scale, at::Tensor w_scale, std::optional bias, - at::Tensor out) { + at::Tensor out, + const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); int K = XQ.size(1); @@ -276,6 +277,9 @@ void f8f8bf16_rowwise_impl( // multiplication computation size_t workspace_size = Gemm::get_workspace_size(arguments); + // Set the swizzle size + arguments.scheduler.max_swizzle_size = swizzle; + // Allocate workspace memory auto workspace = XQ.new_empty( {static_cast(workspace_size)}, @@ -309,7 +313,8 @@ void dispatch_fp8_rowwise_kernel_on_tile_size( at::Tensor x_scale, at::Tensor w_scale, std::optional bias, - at::Tensor out) { + at::Tensor out, + const int swizzle) { int M = XQ.size(0); int N = WQ.size(1); @@ -323,13 +328,13 @@ void dispatch_fp8_rowwise_kernel_on_tile_size( /*TileShape=*/cute::Shape, ClusterShape, /*PingPong=*/std::false_type, - Types...>(XQ, WQ, x_scale, w_scale, bias, out); + Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); } else { return f8f8bf16_rowwise_impl< /*TileShape=*/cute::Shape, ClusterShape, /*PingPong=*/std::true_type, - Types...>(XQ, WQ, x_scale, w_scale, bias, out); + Types...>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); } } @@ -346,7 +351,8 @@ void handle_transposition( at::Tensor x_scale, at::Tensor w_scale, std::optional bias, - at::Tensor out) { + at::Tensor out, + const int swizzle=1) { if constexpr (!Transposed::value) { dispatch_fp8_rowwise_kernel_on_tile_size< ClusterShape, @@ -354,7 +360,7 @@ void handle_transposition( FastAccum, DtypeA, DtypeB, - DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out); + DtypeBias>(XQ, WQ, x_scale, w_scale, bias, out, swizzle); } else { dispatch_fp8_rowwise_kernel_on_tile_size< ClusterShape, @@ -362,7 +368,7 @@ void handle_transposition( FastAccum, DtypeB, DtypeA, - DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t()); + DtypeBias>(WQ.t(), XQ.t(), w_scale.t(), x_scale.t(), bias, out.t(), swizzle); } } @@ -438,6 +444,20 @@ void dispatch_fp8_rowwise_kernel_on_cluster_size_and_transpose( } // General case for large tensors. + + // Large M, N, k + if (M >= 4096 && N >= 4096) { + if (M >= N){ + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::false_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out, 8); + } + return handle_transposition< + /*ClusterShape=*/cute::Shape, + /*Transposed=*/std::true_type, + Types...>(XQ, WQ, x_scale, w_scale, bias, out, 8); + } if ((M <= N) ^ (M >= 2048 && N >= 2048)) { return handle_transposition< /*ClusterShape=*/cute::Shape, diff --git a/aten/src/ATen/native/cuda/ScanKernels.cpp b/aten/src/ATen/native/cuda/ScanKernels.cpp index 463ceb23bade5..65d7254f46912 100644 --- a/aten/src/ATen/native/cuda/ScanKernels.cpp +++ b/aten/src/ATen/native/cuda/ScanKernels.cpp @@ -12,8 +12,11 @@ #include #include #include +#include #include #include +#include +#include #endif namespace at::native { @@ -88,14 +91,44 @@ Tensor _logcumsumexp_cuda(const Tensor& self, int64_t dim) { return _logcumsumexp_out_cuda(self, dim, result); } -void cumsum_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim) { - if (self.is_floating_point() || self.is_complex()) { - // See Note [Writing Nondeterministic Operations] - // Issue reporting nondeterministic behavior: https://github.com/pytorch/pytorch/issues/75240 - globalContext().alertNotDeterministic("cumsum_cuda_kernel"); +int64_t canonicalize_dim(int64_t rank, int64_t idx) { + TORCH_INTERNAL_ASSERT(rank >= 0); + if (rank == 0) { + rank = 1; + } + if (idx >= 0 && idx < rank) { + return idx; + } + int64_t _idx = (idx < 0) ? (idx + rank) : idx; + TORCH_INTERNAL_ASSERT(!(_idx < 0 || _idx >= rank)); + return _idx; +} + +// This function is ported from the cumsum decomp in `torch/_refs`. +Tensor& cumsum_deterministic(Tensor& result, Tensor& self, int64_t dim) { + int64_t ndim = self.dim(); + dim = canonicalize_dim(ndim, dim); + if (ndim == 0) { + return at::sum_out(result, self.unsqueeze(0), /*dim=*/IntArrayRef{0}); } + self = self.unsqueeze(dim + 1); + Tensor rg = at::arange(self.size(dim), c10::TensorOptions().device(self.device())); + Tensor mask = rg.unsqueeze(1).le(rg); + for (int idx = 0; idx < (ndim - dim - 1); idx++) { + mask = mask.unsqueeze(-1); + } + Tensor masked_self = at::where(mask, self, 0); + return at::sum_out(result, masked_self, /*dim=*/IntArrayRef{dim}); +} + +void cumsum_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim) { auto result_ = contiguous_out_arg(result); - launch_cumsum_cuda_kernel(*result_, self, dim); + if ((self.is_floating_point() || self.is_complex()) && globalContext().deterministicAlgorithms()) { + // See Note [Enabling Deterministic Operations] + cumsum_deterministic(const_cast(*result_), const_cast(self), dim); + } else { + launch_cumsum_cuda_kernel(*result_, self, dim); + } if (!result.is_same(*result_)) { result.copy_(*result_); } diff --git a/aten/src/ATen/native/cuda/Shape.cu b/aten/src/ATen/native/cuda/Shape.cu index 61d2bd278981c..35593df59fa92 100644 --- a/aten/src/ATen/native/cuda/Shape.cu +++ b/aten/src/ATen/native/cuda/Shape.cu @@ -500,10 +500,21 @@ TORCH_IMPL_FUNC(cat_out_cuda) parallel_cat(result, materialized, dim, nDims, memory_format); }); } else { - AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() { - using dtype = OpaqueType; - parallel_cat(result, materialized, dim, nDims, memory_format); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + AT_DISPATCH_V2( + result.scalar_type(), + "cat_cuda", + AT_WRAP([&]() { + using dtype = OpaqueType; + parallel_cat( + result, materialized, dim, nDims, memory_format); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + kComplexHalf, + kHalf, + kBool, + kBFloat16, + AT_EXPAND(AT_FLOAT8_TYPES), + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } } else if (materialized.size() > 1 && result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS && @@ -518,10 +529,27 @@ TORCH_IMPL_FUNC(cat_out_cuda) parallel_cat(result, materialized, dim, nDims, memory_format); }); } else { - AT_DISPATCH_V2(result.scalar_type(), "cat_cuda", AT_WRAP([&]() { - using dtype = OpaqueType; - parallel_cat(result, materialized, dim, nDims, memory_format); - }), AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), kComplexHalf, kHalf, kBool, kBFloat16, AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); + AT_DISPATCH_V2( + result.scalar_type(), + "cat_cuda", + AT_WRAP([&]() { + using dtype = OpaqueType; + parallel_cat< + dtype, + CAT_ARRAY_BATCH_SIZE / 2, + CAT_ARRAY_BATCH_SIZE / 2>( + result, materialized, dim, nDims, memory_format); + }), + AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), + kComplexHalf, + kHalf, + kBool, + kBFloat16, + kFloat8_e4m3fn, + kFloat8_e4m3fnuz, + kFloat8_e5m2, + kFloat8_e5m2fnuz, + AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES)); } } else { int64_t offset = 0; diff --git a/aten/src/ATen/native/cuda/Sorting.cu b/aten/src/ATen/native/cuda/Sorting.cu index 290be3926c6ff..75c25603388eb 100644 --- a/aten/src/ATen/native/cuda/Sorting.cu +++ b/aten/src/ATen/native/cuda/Sorting.cu @@ -177,14 +177,14 @@ struct KthValueLauncher { cuda::detail::TensorInfo values_info, int collapse_values_dim, cuda::detail::TensorInfo indices_info, - C10_UNUSED int collapse_indices_dim, + [[maybe_unused]] int collapse_indices_dim, cuda::detail::TensorInfo self_info, int collapse_self_dim, int64_t num_slices, int64_t slice_size) { dim3 grid; if (!getGridFromTiles(num_slices, grid)) { - AT_ERROR("slices are too many"); + TORCH_CHECK(false, "slices are too many"); } dim3 block(std::min( @@ -212,16 +212,16 @@ struct MedianLauncher { template inline void launch( cuda::detail::TensorInfo values_info, - C10_UNUSED int collapse_values_dim, + [[maybe_unused]] int collapse_values_dim, cuda::detail::TensorInfo indices_info, - C10_UNUSED int collapse_indices_dim, + [[maybe_unused]] int collapse_indices_dim, cuda::detail::TensorInfo self_info, int collapse_self_dim, int64_t num_slices, int64_t slice_size) { dim3 grid; if (!getGridFromTiles(num_slices, grid)) { - AT_ERROR("slices are too many"); + TORCH_CHECK(false, "slices are too many"); } dim3 block(std::min( diff --git a/aten/src/ATen/native/cuda/SparseMM.cu b/aten/src/ATen/native/cuda/SparseMM.cu index 78bc554b52e0c..c3fd93ad541e1 100644 --- a/aten/src/ATen/native/cuda/SparseMM.cu +++ b/aten/src/ATen/native/cuda/SparseMM.cu @@ -12,10 +12,10 @@ namespace at::native { // sparse, sparse, sparse, dense, real, real -> sparse Tensor& _sspaddmm_out_only_sparse_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Tensor& result) { - AT_ERROR("tensor.sspaddmm(...) can only be called on sparse tensors"); + TORCH_CHECK(false, "tensor.sspaddmm(...) can only be called on sparse tensors"); } Tensor& _sspaddmm_out_cuda(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Tensor& result) { - AT_ERROR("NYI: CUDA sspaddmm is not implemented"); + TORCH_CHECK(false, "NYI: CUDA sspaddmm is not implemented"); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/SummaryOps.cu b/aten/src/ATen/native/cuda/SummaryOps.cu index b7a45df29d00b..f9ceb9bdf7e0b 100644 --- a/aten/src/ATen/native/cuda/SummaryOps.cu +++ b/aten/src/ATen/native/cuda/SummaryOps.cu @@ -251,7 +251,7 @@ Tensor _bincount_cuda_template( const Tensor& weights, int64_t minlength) { if (minlength < 0) { - AT_ERROR("minlength should be >= 0"); + TORCH_CHECK(false, "minlength should be >= 0"); } if (self.dim() == 1 && self.numel() == 0) { return at::zeros( @@ -264,12 +264,12 @@ Tensor _bincount_cuda_template( if (self.dim() != 1 || (!std::is_same_v && *self.min().cpu().const_data_ptr() < 0)) { - AT_ERROR("bincount only supports 1-d non-negative integral inputs."); + TORCH_CHECK(false, "bincount only supports 1-d non-negative integral inputs."); } bool has_weights = weights.defined(); if (has_weights && (weights.dim() != 1 || weights.size(0) != self.size(0))) { - AT_ERROR("weights should be 1-d and have the same length as input"); + TORCH_CHECK(false, "weights should be 1-d and have the same length as input"); } const int64_t nbins = @@ -312,7 +312,7 @@ Tensor _histc_cuda_template( at::acc_type min, at::acc_type max) { if (nbins <= 0) { - AT_ERROR("bins must be > 0"); + TORCH_CHECK(false, "bins must be > 0"); } Tensor output = at::zeros( {nbins}, @@ -387,7 +387,7 @@ Tensor _histc_cuda( const Scalar& min, const Scalar& max) { if (self.scalar_type() == ScalarType::Half) { - AT_ERROR("HalfTensor is not supported"); + TORCH_CHECK(false, "HalfTensor is not supported"); } // See Note [Writing Nondeterministic Operations] // Nondeterministic because of atomicAdd usage diff --git a/aten/src/ATen/native/cuda/UnaryComplexKernels.cu b/aten/src/ATen/native/cuda/UnaryComplexKernels.cu index 14c4e934c69b5..960414f63cda5 100644 --- a/aten/src/ATen/native/cuda/UnaryComplexKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryComplexKernels.cu @@ -26,7 +26,7 @@ __host__ __device__ static inline c10::complex angle_wrapper(c10::complex } #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char angle_name[] = "angle_kernel"; +constexpr char angle_name[] = "angle_kernel"; #endif void angle_kernel_cuda(TensorIteratorBase& iter) { @@ -63,7 +63,7 @@ void angle_kernel_cuda(TensorIteratorBase& iter) { } // NB: Ignores the negative bit on tensors -CONSTEXPR_EXCEPT_WIN_CUDA char conj_name[] = "conj_kernel"; +constexpr char conj_name[] = "conj_kernel"; void conj_kernel_cuda(TensorIteratorBase& iter) { auto conj_chalf = [&] { using scalar_t = c10::complex; diff --git a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu index 34ccfa298310e..6448335002cdd 100644 --- a/aten/src/ATen/native/cuda/UnaryGammaKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryGammaKernels.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char digamma_name[] = "digamma"; +constexpr char digamma_name[] = "digamma"; #endif // AT_USE_JITERATOR() // See note [Jiterator] void digamma_kernel_cuda(TensorIteratorBase& iter) { @@ -40,7 +40,7 @@ void digamma_kernel_cuda(TensorIteratorBase& iter) { } // See note [Jiterator] -CONSTEXPR_EXCEPT_WIN_CUDA char trigamma_name[] = "trigamma"; +constexpr char trigamma_name[] = "trigamma"; void trigamma_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2( @@ -64,7 +64,7 @@ void trigamma_kernel_cuda(TensorIteratorBase& iter) { #endif // AT_USE_JITERATOR() } -CONSTEXPR_EXCEPT_WIN_CUDA char polygamma_name[] = "polygamma"; +constexpr char polygamma_name[] = "polygamma"; void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) { if (n == 0) { digamma_kernel_cuda(iter); @@ -101,7 +101,7 @@ void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char lgamma_name[] = "lgamma_kernel"; +constexpr char lgamma_name[] = "lgamma_kernel"; void lgamma_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2( diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu index 42ef6a9960cf4..bd779fed2ab43 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricAcosKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if 0 && AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char acos_name[] = "acos_impl"; +constexpr char acos_name[] = "acos_impl"; #endif void acos_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu index d621dd246aa49..ab178f6df1f27 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricAcoshKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if 0 && AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char acosh_name[] = "acosh_impl"; +constexpr char acosh_name[] = "acosh_impl"; #endif void acosh_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu index e9b16dd3d2b6d..97a4e2b46e823 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricAsinKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if 0 && AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char asin_name[] = "asin_impl"; +constexpr char asin_name[] = "asin_impl"; #endif void asin_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu index 7494932f9d538..1a0b2ce9e38c6 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if 0 && AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char asinh_name[] = "asinh_impl"; +constexpr char asinh_name[] = "asinh_impl"; #endif void asinh_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu index 758d7bc5c86de..5018ac8a31257 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricAtanKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char atan_name[] = "atan_impl"; +constexpr char atan_name[] = "atan_impl"; #endif void atan_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu index aad7775219af7..71b65815bfea9 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricAtanhKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char atanh_name[] = "atanh_impl"; +constexpr char atanh_name[] = "atanh_impl"; #endif void atanh_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu index 2a994fb626af4..0cac6ff79c3b5 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricCosKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char cos_name[] = "cos_impl"; +constexpr char cos_name[] = "cos_impl"; #endif // AT_USE_JITERATOR() void cos_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu index 49babec1378a3..a5e390c8ec392 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricCoshKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char cosh_name[] = "cosh_impl"; +constexpr char cosh_name[] = "cosh_impl"; #endif void cosh_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu index d87a190959781..3613192562e44 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricSinKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char sin_name[] = "sin_impl"; +constexpr char sin_name[] = "sin_impl"; #endif void sin_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu index 82b730a0ffbc9..039700c21be02 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricSinhKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char sinh_name[] = "sinh_impl"; +constexpr char sinh_name[] = "sinh_impl"; #endif void sinh_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu index 8f62529e8e095..a71588e551cf0 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char tan_name[] = "tan_impl"; +constexpr char tan_name[] = "tan_impl"; #endif void tan_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu index d5f0172015d5e..6a9f6a4cbdd67 100644 --- a/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryGeometricTanhKernel.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char tanh_name[] = "tanh_impl"; +constexpr char tanh_name[] = "tanh_impl"; #endif void tanh_kernel_cuda(TensorIteratorBase& iter) { diff --git a/aten/src/ATen/native/cuda/UnaryLogKernels.cu b/aten/src/ATen/native/cuda/UnaryLogKernels.cu index 2a2f56670b78b..f213886319d35 100644 --- a/aten/src/ATen/native/cuda/UnaryLogKernels.cu +++ b/aten/src/ATen/native/cuda/UnaryLogKernels.cu @@ -13,7 +13,7 @@ namespace at::native { #if AT_USE_JITERATOR() -CONSTEXPR_EXCEPT_WIN_CUDA char log_name[] = "log_kernel"; +constexpr char log_name[] = "log_kernel"; #endif void log_kernel_cuda(TensorIteratorBase& iter) { @@ -47,7 +47,7 @@ void log_kernel_cuda(TensorIteratorBase& iter) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char log10_name[] = "log10_kernel"; +constexpr char log10_name[] = "log10_kernel"; void log10_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { @@ -84,7 +84,7 @@ void log1p_kernel_cuda(TensorIteratorBase& iter) { }); } -CONSTEXPR_EXCEPT_WIN_CUDA char log2_name[] = "log2_kernel"; +constexpr char log2_name[] = "log2_kernel"; void log2_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index b0d6f549ab24d..5eb64ab57258e 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -34,7 +34,7 @@ void bitwise_not_kernel_cuda(TensorIteratorBase& iter) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char exp_name[] = "exp_kernel"; +constexpr char exp_name[] = "exp_kernel"; void exp_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { @@ -92,7 +92,7 @@ C10_HOST_DEVICE static inline c10::complex rsqrt_wrapper(c10::complex v) { return one / ::sqrt(v); } -CONSTEXPR_EXCEPT_WIN_CUDA char rsqrt_name[] = "rsqrt_kernel"; +constexpr char rsqrt_name[] = "rsqrt_kernel"; void rsqrt_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { @@ -131,7 +131,7 @@ void rsqrt_kernel_cuda(TensorIteratorBase& iter) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char sqrt_name[] = "sqrt_kernel"; +constexpr char sqrt_name[] = "sqrt_kernel"; void sqrt_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { diff --git a/aten/src/ATen/native/cuda/UnarySignKernels.cu b/aten/src/ATen/native/cuda/UnarySignKernels.cu index 83233f3143cba..2a811e314c2cc 100644 --- a/aten/src/ATen/native/cuda/UnarySignKernels.cu +++ b/aten/src/ATen/native/cuda/UnarySignKernels.cu @@ -25,7 +25,7 @@ void logical_not_kernel_cuda(TensorIteratorBase& iter) { } // NB: Ignores the negative bit on tensors -CONSTEXPR_EXCEPT_WIN_CUDA char neg_name[] = "neg_kernel"; +constexpr char neg_name[] = "neg_kernel"; void neg_kernel_cuda(TensorIteratorBase& iter) { auto dtype = iter.dtype(); if (at::isComplexType(dtype)) { @@ -96,7 +96,7 @@ C10_HOST_DEVICE static inline c10::complex sgn_wrapper(c10::complex z) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char sgn_name[] = "sgn_kernel"; +constexpr char sgn_name[] = "sgn_kernel"; void sgn_kernel_cuda(TensorIteratorBase& iter){ auto dtype = iter.dtype(); #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu index 38df2106eddb5..af560d8e9a50a 100644 --- a/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnarySpecialOpsKernel.cu @@ -19,7 +19,7 @@ namespace at::native { -CONSTEXPR_EXCEPT_WIN_CUDA char exp2_name[] = "exp2_kernel"; +constexpr char exp2_name[] = "exp2_kernel"; void exp2_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( @@ -41,7 +41,7 @@ void exp2_kernel_cuda(TensorIteratorBase& iter) { #endif } -CONSTEXPR_EXCEPT_WIN_CUDA char i0_name[] = "i0"; +constexpr char i0_name[] = "i0"; void i0_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() { @@ -63,7 +63,7 @@ void i0_kernel_cuda(TensorIteratorBase& iter) { } // See note [Jiterator] -CONSTEXPR_EXCEPT_WIN_CUDA char i0e_name[] = "calc_i0e"; +constexpr char i0e_name[] = "calc_i0e"; void i0e_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0e_cuda", [&]() { @@ -84,17 +84,17 @@ void i0e_kernel_cuda(TensorIteratorBase& iter) { // See note [Jiterator] -CONSTEXPR_EXCEPT_WIN_CUDA char i1_name[] = "i1"; +constexpr char i1_name[] = "i1"; void i1_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() - AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i1_cuda", [&]() { jitted_gpu_kernel(iter, i1_string); }); #else - AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i1_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_i1(a); }); @@ -102,17 +102,17 @@ void i1_kernel_cuda(TensorIteratorBase& iter) { #endif // AT_USE_JITERATOR() } -CONSTEXPR_EXCEPT_WIN_CUDA char i1e_name[] = "i1e"; +constexpr char i1e_name[] = "i1e"; void i1e_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() - AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i1e_cuda", [&]() { jitted_gpu_kernel(iter, i1e_string); }); #else - AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cuda", [&]() { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i1e_cuda", [&]() { gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_i1e(a); }); @@ -120,7 +120,7 @@ void i1e_kernel_cuda(TensorIteratorBase& iter) { #endif } -CONSTEXPR_EXCEPT_WIN_CUDA char sigmoid_name[] = "sigmoid"; +constexpr char sigmoid_name[] = "sigmoid"; void sigmoid_kernel_cuda(TensorIteratorBase& iter) { auto common_dtype = iter.common_dtype(); if (at::isComplexType(common_dtype)) { @@ -159,7 +159,7 @@ void sigmoid_kernel_cuda(TensorIteratorBase& iter) { } } -CONSTEXPR_EXCEPT_WIN_CUDA char sinc_name[] = "sinc"; +constexpr char sinc_name[] = "sinc"; void sinc_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( @@ -217,7 +217,7 @@ void logit_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) { }); } -CONSTEXPR_EXCEPT_WIN_CUDA char ndtri_name[] = "ndtri"; +constexpr char ndtri_name[] = "ndtri"; void ndtri_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cuda", [&]() { @@ -234,7 +234,7 @@ void ndtri_kernel_cuda(TensorIteratorBase& iter) { #endif } -CONSTEXPR_EXCEPT_WIN_CUDA char log_ndtr_name[] = "log_ndtr"; +constexpr char log_ndtr_name[] = "log_ndtr"; void log_ndtr_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() { @@ -259,7 +259,7 @@ void erf_kernel_cuda(TensorIteratorBase& iter) { }); } -CONSTEXPR_EXCEPT_WIN_CUDA char erfc_name[] = "erfc_kernel"; +constexpr char erfc_name[] = "erfc_kernel"; void erfc_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfc_cuda", [&]() { @@ -278,7 +278,7 @@ void erfc_kernel_cuda(TensorIteratorBase& iter) { #endif } -CONSTEXPR_EXCEPT_WIN_CUDA char erfinv_name[] = "erfinv_kernel"; +constexpr char erfinv_name[] = "erfinv_kernel"; void erfinv_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfinv_cuda", [&]() { @@ -297,7 +297,7 @@ void erfinv_kernel_cuda(TensorIteratorBase& iter) { #endif } -CONSTEXPR_EXCEPT_WIN_CUDA char erfcx_name[] = "erfcx"; +constexpr char erfcx_name[] = "erfcx"; void erfcx_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_cuda", [&]() { @@ -314,7 +314,7 @@ void erfcx_kernel_cuda(TensorIteratorBase& iter) { #endif } -CONSTEXPR_EXCEPT_WIN_CUDA char kaiser_window_name[] = "kaiser_window"; +constexpr char kaiser_window_name[] = "kaiser_window"; void kaiser_window_kernel_cuda(TensorIteratorBase& iter, int64_t window_length, double beta_){ #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){ @@ -348,7 +348,7 @@ void kaiser_window_kernel_cuda(TensorIteratorBase& iter, int64_t window_length, #endif } -CONSTEXPR_EXCEPT_WIN_CUDA char entr_name[] = "entr"; +constexpr char entr_name[] = "entr"; void entr_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "entr_cuda", [&]() { diff --git a/aten/src/ATen/native/cuda/ZetaKernel.cu b/aten/src/ATen/native/cuda/ZetaKernel.cu index 7459504f508cb..da536e8adbdd6 100644 --- a/aten/src/ATen/native/cuda/ZetaKernel.cu +++ b/aten/src/ATen/native/cuda/ZetaKernel.cu @@ -15,7 +15,7 @@ namespace { * See note [3-Clause BSD License for the Cephes Math Library]. */ // See note [Jiterator] -CONSTEXPR_EXCEPT_WIN_CUDA char zeta_name[] = "zeta"; +constexpr char zeta_name[] = "zeta"; void zeta_kernel_cuda(TensorIteratorBase& iter) { #if AT_USE_JITERATOR() AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "zeta_cuda", [&]() { diff --git a/aten/src/ATen/native/cuda/airy_ai.cu b/aten/src/ATen/native/cuda/airy_ai.cu index 35e6b002260c2..05257c99b1b22 100644 --- a/aten/src/ATen/native/cuda/airy_ai.cu +++ b/aten/src/ATen/native/cuda/airy_ai.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { -CONSTEXPR_EXCEPT_WIN_CUDA char airy_ai_name[] = "airy_ai_forward"; +constexpr char airy_ai_name[] = "airy_ai_forward"; void airy_ai_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/bessel_j0.cu b/aten/src/ATen/native/cuda/bessel_j0.cu index 2ebfe676e50b9..a3d9b668e9556 100644 --- a/aten/src/ATen/native/cuda/bessel_j0.cu +++ b/aten/src/ATen/native/cuda/bessel_j0.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { -CONSTEXPR_EXCEPT_WIN_CUDA char bessel_j0_name[] = "bessel_j0_forward"; +constexpr char bessel_j0_name[] = "bessel_j0_forward"; void bessel_j0_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/bessel_j1.cu b/aten/src/ATen/native/cuda/bessel_j1.cu index 42bd43321f40b..674fcadfdff1a 100644 --- a/aten/src/ATen/native/cuda/bessel_j1.cu +++ b/aten/src/ATen/native/cuda/bessel_j1.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { -CONSTEXPR_EXCEPT_WIN_CUDA char bessel_j1_name[] = "bessel_j1_forward"; +constexpr char bessel_j1_name[] = "bessel_j1_forward"; void bessel_j1_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/bessel_y0.cu b/aten/src/ATen/native/cuda/bessel_y0.cu index 631031d4e26c5..344ea38765227 100644 --- a/aten/src/ATen/native/cuda/bessel_y0.cu +++ b/aten/src/ATen/native/cuda/bessel_y0.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char bessel_y0_name[] = "bessel_y0_forward"; + constexpr char bessel_y0_name[] = "bessel_y0_forward"; void bessel_y0_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/bessel_y1.cu b/aten/src/ATen/native/cuda/bessel_y1.cu index 1375061e43e08..32433a22b0bbc 100644 --- a/aten/src/ATen/native/cuda/bessel_y1.cu +++ b/aten/src/ATen/native/cuda/bessel_y1.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char bessel_y1_name[] = "bessel_y1_forward"; + constexpr char bessel_y1_name[] = "bessel_y1_forward"; void bessel_y1_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/chebyshev_polynomial_t.cu b/aten/src/ATen/native/cuda/chebyshev_polynomial_t.cu index 7736d20e01887..a84e0c5050e0c 100644 --- a/aten/src/ATen/native/cuda/chebyshev_polynomial_t.cu +++ b/aten/src/ATen/native/cuda/chebyshev_polynomial_t.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char chebyshev_polynomial_t_name[] = "chebyshev_polynomial_t_forward"; + constexpr char chebyshev_polynomial_t_name[] = "chebyshev_polynomial_t_forward"; void chebyshev_polynomial_t_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/chebyshev_polynomial_u.cu b/aten/src/ATen/native/cuda/chebyshev_polynomial_u.cu index 412479e11f491..9ec870fd130a8 100644 --- a/aten/src/ATen/native/cuda/chebyshev_polynomial_u.cu +++ b/aten/src/ATen/native/cuda/chebyshev_polynomial_u.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char chebyshev_polynomial_u_name[] = "chebyshev_polynomial_u_forward"; + constexpr char chebyshev_polynomial_u_name[] = "chebyshev_polynomial_u_forward"; void chebyshev_polynomial_u_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/chebyshev_polynomial_v.cu b/aten/src/ATen/native/cuda/chebyshev_polynomial_v.cu index ca2e534e641b6..7f393d9d674de 100644 --- a/aten/src/ATen/native/cuda/chebyshev_polynomial_v.cu +++ b/aten/src/ATen/native/cuda/chebyshev_polynomial_v.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char chebyshev_polynomial_v_name[] = "chebyshev_polynomial_v_forward"; + constexpr char chebyshev_polynomial_v_name[] = "chebyshev_polynomial_v_forward"; void chebyshev_polynomial_v_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/chebyshev_polynomial_w.cu b/aten/src/ATen/native/cuda/chebyshev_polynomial_w.cu index 9d5a0e3a7bd33..9897213ee97d2 100644 --- a/aten/src/ATen/native/cuda/chebyshev_polynomial_w.cu +++ b/aten/src/ATen/native/cuda/chebyshev_polynomial_w.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char chebyshev_polynomial_w_name[] = "chebyshev_polynomial_w_forward"; + constexpr char chebyshev_polynomial_w_name[] = "chebyshev_polynomial_w_forward"; void chebyshev_polynomial_w_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/hermite_polynomial_h.cu b/aten/src/ATen/native/cuda/hermite_polynomial_h.cu index f53253bcd0994..d581e38bbefef 100644 --- a/aten/src/ATen/native/cuda/hermite_polynomial_h.cu +++ b/aten/src/ATen/native/cuda/hermite_polynomial_h.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char hermite_polynomial_h_name[] = "hermite_polynomial_h_forward"; + constexpr char hermite_polynomial_h_name[] = "hermite_polynomial_h_forward"; void hermite_polynomial_h_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/hermite_polynomial_he.cu b/aten/src/ATen/native/cuda/hermite_polynomial_he.cu index bab376565858a..b5b1891b80cf8 100644 --- a/aten/src/ATen/native/cuda/hermite_polynomial_he.cu +++ b/aten/src/ATen/native/cuda/hermite_polynomial_he.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char hermite_polynomial_he_name[] = "hermite_polynomial_he_forward"; + constexpr char hermite_polynomial_he_name[] = "hermite_polynomial_he_forward"; void hermite_polynomial_he_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/jit_utils.h b/aten/src/ATen/native/cuda/jit_utils.h index 575c51c96db36..bee02105c0f3b 100644 --- a/aten/src/ATen/native/cuda/jit_utils.h +++ b/aten/src/ATen/native/cuda/jit_utils.h @@ -71,7 +71,7 @@ inline int can_vectorize_up_to(const KernelDescriptor &desc, c10::ArrayRef layer_norm_cuda( for (const auto idx: c10::irange(axis)) { stat_shape.push_back(input_shape[idx]); } - for (const auto C10_UNUSED idx: c10::irange(axis, input.dim())) { + for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { stat_shape.push_back(1); } diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index f85a343d8d685..8d67b6dc080cd 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -1158,7 +1158,7 @@ REGISTER_CUDA_DISPATCH(ldl_solve_stub, &ldl_solve_kernel) template static void apply_cholesky_solve(Tensor& b, Tensor& A, bool upper, int64_t& info) { #if !AT_MAGMA_ENABLED() -AT_ERROR("cholesky_solve: MAGMA library not found in " +TORCH_CHECK(false, "cholesky_solve: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower; @@ -1476,7 +1476,7 @@ template static void apply_lu_factor_looped_magma(const Tensor& input, const Tensor& pivots, const Tensor& infos, bool compute_pivots) { #if !AT_MAGMA_ENABLED() // This should never be thrown if the calling functions are correct. - AT_ERROR("linalg.lu_factor: PyTorch was not compiled with MAGMA support."); + TORCH_CHECK(false, "linalg.lu_factor: PyTorch was not compiled with MAGMA support."); #else // magmaLu and magmaLuNoPiv require infos and pivots tensor to be on CPU // the data is later copied back to the appropriate output tensor @@ -1677,7 +1677,7 @@ REGISTER_CUDA_DISPATCH(lu_factor_stub, &lu_factor); template static void apply_triangular_solve_batched_magma(const Tensor& A, const Tensor& b, bool left, bool upper, TransposeType transpose, bool unitriangular) { #if !AT_MAGMA_ENABLED() -AT_ERROR("triangular_solve: MAGMA library not found in " +TORCH_CHECK(false, "triangular_solve: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower; @@ -2106,7 +2106,7 @@ static void apply_svd_magma(const Tensor& A, const Tensor& Vh, const Tensor& info) { #if !AT_MAGMA_ENABLED() -AT_ERROR("linalg.svd: MAGMA library not found in " +TORCH_CHECK(false, "linalg.svd: MAGMA library not found in " "compilation. Please rebuild with MAGMA."); #else using value_t = typename c10::scalar_value_type::type; diff --git a/aten/src/ATen/native/cuda/linalg/MagmaUtils.h b/aten/src/ATen/native/cuda/linalg/MagmaUtils.h index 6e1b5a2659c23..7a293757c8a59 100644 --- a/aten/src/ATen/native/cuda/linalg/MagmaUtils.h +++ b/aten/src/ATen/native/cuda/linalg/MagmaUtils.h @@ -59,7 +59,7 @@ struct MAGMAQueue { static inline magma_int_t magma_int_cast(int64_t value, const char* varname) { auto result = static_cast(value); if (static_cast(result) != value) { - AT_ERROR("magma: The value of ", varname, "(", (long long)value, + TORCH_CHECK(false, "magma: The value of ", varname, "(", (long long)value, ") is too large to fit into a magma_int_t (", sizeof(magma_int_t), " bytes)"); } return result; diff --git a/aten/src/ATen/native/cuda/modified_bessel_i0.cu b/aten/src/ATen/native/cuda/modified_bessel_i0.cu index 9f1f3ba98c679..5d5e60c132c99 100644 --- a/aten/src/ATen/native/cuda/modified_bessel_i0.cu +++ b/aten/src/ATen/native/cuda/modified_bessel_i0.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char modified_bessel_i0_name[] = "modified_bessel_i0_forward"; + constexpr char modified_bessel_i0_name[] = "modified_bessel_i0_forward"; void modified_bessel_i0_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/modified_bessel_i1.cu b/aten/src/ATen/native/cuda/modified_bessel_i1.cu index d51e7fefb0eb1..4576ce07042e6 100644 --- a/aten/src/ATen/native/cuda/modified_bessel_i1.cu +++ b/aten/src/ATen/native/cuda/modified_bessel_i1.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char modified_bessel_i1_name[] = "modified_bessel_i1_forward"; + constexpr char modified_bessel_i1_name[] = "modified_bessel_i1_forward"; void modified_bessel_i1_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/modified_bessel_k0.cu b/aten/src/ATen/native/cuda/modified_bessel_k0.cu index 574268456c847..17de0d94a69a4 100644 --- a/aten/src/ATen/native/cuda/modified_bessel_k0.cu +++ b/aten/src/ATen/native/cuda/modified_bessel_k0.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char modified_bessel_k0_name[] = "modified_bessel_k0_forward"; + constexpr char modified_bessel_k0_name[] = "modified_bessel_k0_forward"; void modified_bessel_k0_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/modified_bessel_k1.cu b/aten/src/ATen/native/cuda/modified_bessel_k1.cu index b3720d8e1ba98..a858ad52af6a9 100644 --- a/aten/src/ATen/native/cuda/modified_bessel_k1.cu +++ b/aten/src/ATen/native/cuda/modified_bessel_k1.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char modified_bessel_k1_name[] = "modified_bessel_k1_forward"; + constexpr char modified_bessel_k1_name[] = "modified_bessel_k1_forward"; void modified_bessel_k1_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/scaled_modified_bessel_k0.cu b/aten/src/ATen/native/cuda/scaled_modified_bessel_k0.cu index ac2355e409ac2..880b6b54c1873 100644 --- a/aten/src/ATen/native/cuda/scaled_modified_bessel_k0.cu +++ b/aten/src/ATen/native/cuda/scaled_modified_bessel_k0.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char scaled_modified_bessel_k0_name[] = "scaled_modified_bessel_k0_forward"; + constexpr char scaled_modified_bessel_k0_name[] = "scaled_modified_bessel_k0_forward"; void scaled_modified_bessel_k0_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/scaled_modified_bessel_k1.cu b/aten/src/ATen/native/cuda/scaled_modified_bessel_k1.cu index b1d8d2a41b62b..7e5c771dc80b1 100644 --- a/aten/src/ATen/native/cuda/scaled_modified_bessel_k1.cu +++ b/aten/src/ATen/native/cuda/scaled_modified_bessel_k1.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char scaled_modified_bessel_k1_name[] = "scaled_modified_bessel_k1_forward"; + constexpr char scaled_modified_bessel_k1_name[] = "scaled_modified_bessel_k1_forward"; void scaled_modified_bessel_k1_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_t.cu b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_t.cu index d86042030cd69..e08081495ecb0 100644 --- a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_t.cu +++ b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_t.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char shifted_chebyshev_polynomial_t_name[] = "shifted_chebyshev_polynomial_t_forward"; + constexpr char shifted_chebyshev_polynomial_t_name[] = "shifted_chebyshev_polynomial_t_forward"; void shifted_chebyshev_polynomial_t_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_u.cu b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_u.cu index a2e2cd485fdaf..12fe938334a20 100644 --- a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_u.cu +++ b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_u.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char shifted_chebyshev_polynomial_u_name[] = "shifted_chebyshev_polynomial_u_forward"; + constexpr char shifted_chebyshev_polynomial_u_name[] = "shifted_chebyshev_polynomial_u_forward"; void shifted_chebyshev_polynomial_u_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_v.cu b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_v.cu index 6e5404179ab93..19db5a5ed53dd 100644 --- a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_v.cu +++ b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_v.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { -CONSTEXPR_EXCEPT_WIN_CUDA char shifted_chebyshev_polynomial_v_name[] = "shifted_chebyshev_polynomial_v_forward"; +constexpr char shifted_chebyshev_polynomial_v_name[] = "shifted_chebyshev_polynomial_v_forward"; void shifted_chebyshev_polynomial_v_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_w.cu b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_w.cu index 3bfee57d14ee3..d53b026947a62 100644 --- a/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_w.cu +++ b/aten/src/ATen/native/cuda/shifted_chebyshev_polynomial_w.cu @@ -10,7 +10,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char shifted_chebyshev_polynomial_w_name[] = "shifted_chebyshev_polynomial_w_forward"; + constexpr char shifted_chebyshev_polynomial_w_name[] = "shifted_chebyshev_polynomial_w_forward"; void shifted_chebyshev_polynomial_w_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cuda/spherical_bessel_j0.cu b/aten/src/ATen/native/cuda/spherical_bessel_j0.cu index d0bf46e653946..14234b27e54e0 100644 --- a/aten/src/ATen/native/cuda/spherical_bessel_j0.cu +++ b/aten/src/ATen/native/cuda/spherical_bessel_j0.cu @@ -20,7 +20,7 @@ namespace at::native { namespace { - CONSTEXPR_EXCEPT_WIN_CUDA char spherical_bessel_j0_name[] = "spherical_bessel_j0_forward"; + constexpr char spherical_bessel_j0_name[] = "spherical_bessel_j0_forward"; void spherical_bessel_j0_kernel_cuda(TensorIteratorBase& iterator) { #if AT_USE_JITERATOR() diff --git a/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp b/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp index 3ee342a03e19e..f13c16b80312c 100644 --- a/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp +++ b/aten/src/ATen/native/cudnn/AffineGridGenerator.cpp @@ -25,7 +25,8 @@ Tensor cudnn_affine_grid_generator_forward( int64_t C, int64_t H, int64_t W) { - AT_ERROR( + TORCH_CHECK( + false, "cudnn_affine_grid_generator_forward: ATen not compiled with cuDNN support"); } @@ -35,7 +36,8 @@ Tensor cudnn_affine_grid_generator_backward( int64_t C, int64_t H, int64_t W) { - AT_ERROR( + TORCH_CHECK( + false, "cudnn_affine_grid_generator_backward: ATen not compiled with cuDNN support"); } diff --git a/aten/src/ATen/native/cudnn/BatchNorm.cpp b/aten/src/ATen/native/cudnn/BatchNorm.cpp index 460a9b73dd2c5..c9e2fb361297d 100644 --- a/aten/src/ATen/native/cudnn/BatchNorm.cpp +++ b/aten/src/ATen/native/cudnn/BatchNorm.cpp @@ -25,7 +25,7 @@ std::tuple cudnn_batch_norm( bool training, double exponential_average_factor, double epsilon) { - AT_ERROR("cudnn_batch_norm: ATen not compiled with cuDNN support"); + TORCH_CHECK(false, "cudnn_batch_norm: ATen not compiled with cuDNN support"); } std::tuple cudnn_batch_norm_backward( @@ -38,13 +38,15 @@ std::tuple cudnn_batch_norm_backward( const std::optional& save_var_opt, double epsilon, const Tensor& reservedSpace) { - AT_ERROR("cudnn_batch_norm_backward: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, "cudnn_batch_norm_backward: ATen not compiled with cuDNN support"); } size_t _get_cudnn_batch_norm_reserve_space_size( const Tensor& input_t, bool training) { - AT_ERROR( + TORCH_CHECK( + false, "_get_cudnn_batch_norm_reserve_space_size: ATen not compiled with cuDNN support"); } @@ -131,10 +133,8 @@ std::tuple cudnn_batch_norm( c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); const Tensor& bias_t = *bias_t_maybe_owned; - const Tensor& running_mean_t = - c10::value_or_else(running_mean_t_opt, [] { return Tensor(); }); - const Tensor& running_var_t = - c10::value_or_else(running_var_t_opt, [] { return Tensor(); }); + const Tensor& running_mean_t = running_mean_t_opt.value_or(Tensor()); + const Tensor& running_var_t = running_var_t_opt.value_or(Tensor()); TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2}, bias{bias_t, "bias", 3}, running_mean{running_mean_t, "running_mean", 4}, @@ -281,10 +281,8 @@ std::tuple cudnn_batch_norm_backward( double epsilon, const Tensor& reserveSpace) { // See [Note: hacky wrapper removal for optional tensor] - const Tensor& save_mean_t = - c10::value_or_else(save_mean_t_opt, [] { return Tensor(); }); - const Tensor& save_var_t = - c10::value_or_else(save_var_t_opt, [] { return Tensor(); }); + const Tensor& save_mean_t = save_mean_t_opt.value_or(Tensor()); + const Tensor& save_var_t = save_var_t_opt.value_or(Tensor()); // TODO: Is it worth it to have a contiguous call or maybe we should go with // whatever format is given here. diff --git a/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp b/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp index 349999e4544f9..7a6f401ab0203 100644 --- a/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp +++ b/aten/src/ATen/native/cudnn/ConvPlaceholders.cpp @@ -35,7 +35,7 @@ at::Tensor cudnn_convolution( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support"); + TORCH_CHECK(false, "cudnn_convolution: ATen not compiled with cuDNN support"); } at::Tensor& cudnn_convolution_out( @@ -49,7 +49,8 @@ at::Tensor& cudnn_convolution_out( bool deterministic, bool allow_tf32, Tensor& output_t) { - AT_ERROR("cudnn_convolution_out: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, "cudnn_convolution_out: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_backward_input( @@ -63,7 +64,8 @@ at::Tensor cudnn_convolution_backward_input( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR( + TORCH_CHECK( + false, "cudnn_convolution_backward_input: ATen not compiled with cuDNN support"); } @@ -78,7 +80,8 @@ at::Tensor cudnn_convolution_backward_weight( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR( + TORCH_CHECK( + false, "cudnn_convolution_backward_weight: ATen not compiled with cuDNN support"); } @@ -94,7 +97,9 @@ std::tuple cudnn_convolution_backward( bool deterministic, bool allow_tf32, std::array output_mask) { - AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, + "cudnn_convolution_backward: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_transpose( @@ -108,7 +113,9 @@ at::Tensor cudnn_convolution_transpose( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, + "cudnn_convolution_transpose: ATen not compiled with cuDNN support"); } at::Tensor cudnn_convolution_transpose_backward_input( @@ -121,7 +128,8 @@ at::Tensor cudnn_convolution_transpose_backward_input( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR( + TORCH_CHECK( + false, "cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); } @@ -136,7 +144,8 @@ at::Tensor cudnn_convolution_transpose_backward_weight( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR( + TORCH_CHECK( + false, "cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support"); } @@ -153,7 +162,8 @@ std::tuple cudnn_convolution_transpose_backward( bool deterministic, bool allow_tf32, std::array output_mask) { - AT_ERROR( + TORCH_CHECK( + false, "cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support"); } @@ -168,7 +178,8 @@ void raw_cudnn_convolution_forward_out( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR( + TORCH_CHECK( + false, "raw_cudnn_convolution_forward_out: ATen not compiled with cuDNN support"); } @@ -183,7 +194,8 @@ void raw_cudnn_convolution_backward_input_out( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR( + TORCH_CHECK( + false, "raw_cudnn_convolution_backward_input_out: ATen not compiled with cuDNN support"); } @@ -198,7 +210,8 @@ void raw_cudnn_convolution_backward_weight_out( bool benchmark, bool deterministic, bool allow_tf32) { - AT_ERROR( + TORCH_CHECK( + false, "raw_cudnn_convolution_backward_weight_out: ATen not compiled with cuDNN support"); } @@ -210,7 +223,8 @@ Tensor cudnn_convolution_relu( IntArrayRef padding, IntArrayRef dilation, int64_t groups) { - AT_ERROR("cudnn_convolution_relu: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, "cudnn_convolution_relu: ATen not compiled with cuDNN support"); } Tensor cudnn_convolution_add_relu( @@ -223,7 +237,9 @@ Tensor cudnn_convolution_add_relu( IntArrayRef padding, IntArrayRef dilation, int64_t groups) { - AT_ERROR("cudnn_convolution_add_relu: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, + "cudnn_convolution_add_relu: ATen not compiled with cuDNN support"); } #endif // AT_CUDNN_ENABLED diff --git a/aten/src/ATen/native/cudnn/Conv_v8.cpp b/aten/src/ATen/native/cudnn/Conv_v8.cpp index 4bd72735881f1..266e779aa319c 100644 --- a/aten/src/ATen/native/cudnn/Conv_v8.cpp +++ b/aten/src/ATen/native/cudnn/Conv_v8.cpp @@ -74,7 +74,7 @@ cudnn_frontend::Tensor getTensorDescriptorWithTypeVirtual( // Ubuntu-22+ if `libnvrtc.so` is not found on the system, which strictly // speaking is not necessary for usecases below See // https://github.com/pytorch/pytorch/issues/97041 - static C10_UNUSED auto cudnn_cnn_infer_handler = [] { + [[maybe_unused]] static auto cudnn_cnn_infer_handler = [] { void* handle = dlopen("libcudnn_cnn_infer.so.8", RTLD_LAZY); char* err = dlerror(); if (!handle) { diff --git a/aten/src/ATen/native/cudnn/GridSampler.cpp b/aten/src/ATen/native/cudnn/GridSampler.cpp index af6b13567e37c..3b5f5bd218bb5 100644 --- a/aten/src/ATen/native/cudnn/GridSampler.cpp +++ b/aten/src/ATen/native/cudnn/GridSampler.cpp @@ -21,14 +21,18 @@ namespace native { // See Note [ATen preprocessor philosophy] Tensor cudnn_grid_sampler_forward(const Tensor& input_t, const Tensor& grid_t) { - AT_ERROR("cudnn_grid_sampler_forward: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, + "cudnn_grid_sampler_forward: ATen not compiled with cuDNN support"); } std::tuple cudnn_grid_sampler_backward( const Tensor& input_t, const Tensor& grid_t, const Tensor& grad_output_t) { - AT_ERROR("cudnn_grid_sampler_backward: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, + "cudnn_grid_sampler_backward: ATen not compiled with cuDNN support"); } } // namespace native diff --git a/aten/src/ATen/native/cudnn/LossCTC.cpp b/aten/src/ATen/native/cudnn/LossCTC.cpp index d76fe7855e5ca..915fbed0f0660 100644 --- a/aten/src/ATen/native/cudnn/LossCTC.cpp +++ b/aten/src/ATen/native/cudnn/LossCTC.cpp @@ -55,7 +55,8 @@ std::tuple _cudnn_ctc_loss( int64_t BLANK, bool deterministic, bool zero_infinity) { - AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support"); + TORCH_CHECK( + false, "cudnn_ctc_loss: ATen not compiled with cuDNN >= 7 support"); } std::tuple _cudnn_ctc_loss_tensor( @@ -66,7 +67,8 @@ std::tuple _cudnn_ctc_loss_tensor( int64_t BLANK, bool deterministic, bool zero_infinity) { - AT_ERROR("cudnn_ctc_loss: ATen not compiled with cuDNN >= 8 support"); + TORCH_CHECK( + false, "cudnn_ctc_loss: ATen not compiled with cuDNN >= 8 support"); } } // namespace native diff --git a/aten/src/ATen/native/cudnn/RNN.cpp b/aten/src/ATen/native/cudnn/RNN.cpp index 704198cb7849b..f6526acaa61f6 100644 --- a/aten/src/ATen/native/cudnn/RNN.cpp +++ b/aten/src/ATen/native/cudnn/RNN.cpp @@ -44,7 +44,8 @@ Tensor _cudnn_rnn_flatten_weight( int64_t fn_num_layers, bool batch_first, bool fn_bidirectional) { - AT_ERROR("_cudnn_rnn_flatten_weight: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, "_cudnn_rnn_flatten_weight: ATen not compiled with cuDNN support"); } std::tuple _cudnn_rnn( @@ -64,7 +65,7 @@ std::tuple _cudnn_rnn( bool fn_bidirectional, IntArrayRef fn_batch_sizes, const std::optional& fn_dropout_state_opt) { - AT_ERROR("_cudnn_rnn: ATen not compiled with cuDNN support"); + TORCH_CHECK(false, "_cudnn_rnn: ATen not compiled with cuDNN support"); } std::tuple> _cudnn_rnn_backward( @@ -90,7 +91,8 @@ std::tuple> _cudnn_rnn_backward( const std::optional& dropout_state_opt, const Tensor& reserve, std::array output_mask) { - AT_ERROR("_cudnn_rnn_backward: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, "_cudnn_rnn_backward: ATen not compiled with cuDNN support"); } Tensor _cudnn_init_dropout_state( @@ -105,7 +107,8 @@ Tensor _cudnn_init_dropout_state( TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory( pin_memory); - AT_ERROR("_cudnn_init_dropout_state: ATen not compiled with cuDNN support"); + TORCH_CHECK( + false, "_cudnn_init_dropout_state: ATen not compiled with cuDNN support"); } } // namespace native @@ -181,7 +184,7 @@ struct RNNDescriptorParams { default: { std::ostringstream oss; oss << "unrecognized cuDNN RNN mode " << fn_mode; - AT_ERROR(oss.str()); + TORCH_CHECK(false, oss.str()); } } } @@ -583,7 +586,7 @@ int64_t _num_linear_layers(cudnnRNNMode_t mode) { case CUDNN_RNN_TANH: return 2; default: - AT_ERROR("unknown cuDNN RNN mode ", mode); + TORCH_CHECK(false, "unknown cuDNN RNN mode ", mode); } } @@ -1399,9 +1402,8 @@ std::tuple _cudnn_rnn( c10::MaybeOwned weight_buf_r_maybe_owned = at::borrow_from_optional_tensor(weight_buf_r_opt); const Tensor& weight_buf_r = *weight_buf_r_maybe_owned; - const Tensor& cx = c10::value_or_else(cx_opt, [] { return Tensor(); }); - const Tensor& fn_dropout_state = - c10::value_or_else(fn_dropout_state_opt, [] { return Tensor(); }); + const Tensor& cx = cx_opt.value_or(Tensor()); + const Tensor& fn_dropout_state = fn_dropout_state_opt.value_or(Tensor()); check_attributes(input_r, weight, {hx, cx}, /*check_dtype=*/true); auto input = input_r; @@ -2112,14 +2114,10 @@ std::tuple> _cudnn_rnn_backward( c10::MaybeOwned cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt); const Tensor& cx = *cx_maybe_owned; - const Tensor& grad_output_r = - c10::value_or_else(grad_output_r_opt, [] { return Tensor(); }); - const Tensor& grad_hy_r = - c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); }); - const Tensor& grad_cy_r = - c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); }); - const Tensor& dropout_state = - c10::value_or_else(dropout_state_opt, [] { return Tensor(); }); + const Tensor& grad_output_r = grad_output_r_opt.value_or(Tensor()); + const Tensor& grad_hy_r = grad_hy_r_opt.value_or(Tensor()); + const Tensor& grad_cy_r = grad_cy_r_opt.value_or(Tensor()); + const Tensor& dropout_state = dropout_state_opt.value_or(Tensor()); if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) { diff --git a/aten/src/ATen/native/group_norm.cpp b/aten/src/ATen/native/group_norm.cpp index 627fa71382e20..0971ddd3cf0df 100644 --- a/aten/src/ATen/native/group_norm.cpp +++ b/aten/src/ATen/native/group_norm.cpp @@ -72,7 +72,7 @@ std::tuple native_group_norm( c10::MaybeOwned gamma_maybe_owned = at::borrow_from_optional_tensor(gamma_opt); const Tensor& gamma = *gamma_maybe_owned; - const Tensor& beta = c10::value_or_else(beta_opt, [] { return Tensor(); }); + const Tensor& beta = beta_opt.value_or(Tensor()); // repeated check so expanded weights can call native_group_norm directly but // save mean and variance from forward @@ -185,7 +185,7 @@ Tensor group_norm( c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); }); + const Tensor& bias = bias_opt.value_or(Tensor()); const auto N = input.sym_size(0); const auto C = input.sym_size(1); @@ -224,7 +224,7 @@ std::tuple math_group_norm( c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] { return Tensor(); }); + const Tensor& bias = bias_opt.value_or(Tensor()); auto input_shape = input.sizes(); at::Tensor input_reshaped = input.view({1, N * group, N ? -1 : 1}); diff --git a/aten/src/ATen/native/hip/ck_gemm.h b/aten/src/ATen/native/hip/ck_gemm.h new file mode 100644 index 0000000000000..176cbabd5e01c --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include +namespace at::native { + + +template +inline void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + static_assert(false&&sizeof(Dtype),"at::cuda::blas_gemm_internal_ck: not implemented"); +} + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)); +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)); + + + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip new file mode 100644 index 0000000000000..dd1503de89cb1 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_bfloat16.hip @@ -0,0 +1,479 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at::native { + +void dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { + // If any of the shapes cant be tiled, we must use padding. + bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); + // Dispatch to best implementation. + // TODO add more configurations. Optimize. + bool transa_ = std::tolower(transa) != 'n'; + bool transb_ = std::tolower(transb) != 'n'; + + if (use_padding) { + if (m <= 128) { + if(transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + + } else { + if(transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } else { + { + if(transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + true> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::BFloat16, + 256, + 128, + 128, + 64, + 8, + 8, + 32, + 32, + 2, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 8, + 8, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 8, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + false> + (CUDABLAS_GEMM_ARGS(at::BFloat16)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } +} + + + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { + dispatch_bfloat16_gemm(CUDABLAS_GEMM_ARGS(at::BFloat16)); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_float.hip b/aten/src/ATen/native/hip/ck_gemm_float.hip new file mode 100644 index 0000000000000..b8301a47981c6 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_float.hip @@ -0,0 +1,486 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at::native { + +void dispatch_float_gemm(CUDABLAS_GEMM_ARGTYPES(float)) { + // If any of the shapes cant be tiled, we must use padding. + bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); + // Dispatch to best implementation. + // TODO add more configurations. Optimize. + bool transa_ = std::tolower(transa) != 'n'; + bool transb_ = std::tolower(transb) != 'n'; + + if (use_padding) { + if (m <= 128) { + if(transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + + } else { + + if(transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } else { + { + if(transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + true, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + true> + (CUDABLAS_GEMM_ARGS(float)); + } + else if(!transa_ && !transb_) { + gemm_impl< + float, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 2, + 4, + 4, + 0, + S<8,32,1>, + S<0,2,1>, + S<0,2,1>, + 1, + 4, + 4, + 0, + 1, + 1, + S<1, 32, 1, 8>, + S<4>, + false, + false, + false> + (CUDABLAS_GEMM_ARGS(float)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } +} + + + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(float)) { + dispatch_float_gemm(CUDABLAS_GEMM_ARGS(float)); +} + +// temporarily put this here until we implement double support +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(double)) { + return; +} + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_half.hip b/aten/src/ATen/native/hip/ck_gemm_half.hip new file mode 100644 index 0000000000000..60b64ca275c54 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_half.hip @@ -0,0 +1,306 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include + +#include + +template +using S = ck::Sequence; + +namespace at::native { + +void dispatch_half_gemm(CUDABLAS_GEMM_ARGTYPES(at::Half)) { +#if 0 + // If any of the shapes cant be tiled, we must use padding. + bool use_padding = ((m % 256 != 0) || (n % 128 != 0) || (k % 64 != 0)); + // Dispatch to best implementation. + // TODO add more configurations. Optimize. + + bool transa_ = std::tolower(transa) != 'n'; + bool transb_ = std::tolower(transb) != 'n'; + + if (use_padding) { + if (m <= 128) { + if(transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + + + + } else { + + if(transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + true>(CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } else { + { + if(transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + 1, + true, + true, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 1, + true, + true, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2> + 1, + true, + false, + true> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else if(!transa_ && !transb_) { + gemm_impl< + at::Half, + 256, + 256, + 128, + 32, + 4, + 4, + 32, + 32, + 4, + 2, + S<8,32,1>, + S<1,0,2>, + S<1,0,2>, + 1, + true, + false, + false> + (CUDABLAS_GEMM_ARGS(at::Half)); + } + else { + TORCH_CHECK(false, "unreachable"); + } + } + } +#endif +} + +template <> +void gemm_internal_ck(CUDABLAS_GEMM_ARGTYPES(at::Half)) { + dispatch_half_gemm(CUDABLAS_GEMM_ARGS(at::Half)); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/hip/ck_gemm_template.h b/aten/src/ATen/native/hip/ck_gemm_template.h new file mode 100644 index 0000000000000..b9fc84956a06e --- /dev/null +++ b/aten/src/ATen/native/hip/ck_gemm_template.h @@ -0,0 +1,289 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +// Define commonly used types. +template +using S = ck::Sequence; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +namespace at::native { + +template +struct CkMathType { + using dtype = T; +}; + +template <> +struct CkMathType { + using dtype = ck::bhalf_t; +}; + +template <> +struct CkMathType { + using dtype = ck::half_t; +}; + + +template +struct CkTensorLayout { + // default goes to row-wise for now + using a_layout = Row; + using b_layout = Row; +}; + +// True denotes transpose is necessary. Default is Col, so return Row +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Col; +}; + + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Col; +}; + +template <> +struct CkTensorLayout { + using a_layout = Col; + using b_layout = Row; +}; + + +template <> +struct CkTensorLayout { + using a_layout = Row; + using b_layout = Row; +}; + + +// Elementwise Operators +struct AlphaBetaAdd +{ + AlphaBetaAdd(float alpha, float beta) : alpha_(alpha), beta_(beta){}; + + template + __host__ __device__ constexpr void operator()(C& c, const AB& ab) const; + + template<> + __host__ __device__ constexpr void operator() + (float& c, const float& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::bhalf_t& c, const ck::bhalf_t& ab) const + { + c = alpha_ * ab; + }; + + template<> + __host__ __device__ constexpr void operator() + (ck::half_t& c, const ck::half_t& ab) const + { + c = alpha_ * ab; + }; + + float alpha_; + // TODO: Leaving for now, will use later + float beta_; +}; + +template < + typename Dtype, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int AK1, + int BK1, + int MPER_XDL, + int NPER_XDL, + int MPER_WAVE, + int NPER_WAVE, + typename ABLOCK_CLUSTER_LENS, + typename ABLOCK_CLUSTER_ORDER, + typename ABLOCK_SRC_ORDER, + int ABLOCK_VECTOR_DIM, + int ABLOCK_SCALAR_VEC, + int ABLOCK_SCALAR_VEC_AK1, + bool ABLOCK_LDS_EXTRAM, + typename BBLOCK_CLUSTER_LENS, + typename BBLOCK_CLUSTER_ORDER, + typename BBLOCK_SRC_ORDER, + int BBLOCK_VECTOR_DIM, + int BBLOCK_SCALAR_VEC, + int BBLOCK_SCALAR_VEC_AK1, + bool BBLOCK_LDS_EXTRAN, + int CMPER_WAVE, + int CNPER_WAVE, + typename BLOCK_CLUSTER_LENS, + typename CDE_SCALAR_VEC, + bool PADDING = false, + bool TRANSA = false, + bool TRANSB = false> +void gemm_impl(CUDABLAS_GEMM_ARGTYPES(Dtype)) { + // Get input information. + int M = m; + int N = n; + int K = k; + + int StrideA = lda; + int StrideB = ldb; + int StrideC = ldc; + + int KBatch = 1; + + float falpha = alpha; + float fbeta = beta; + + using ADataType = typename CkMathType::dtype; + using BDataType = typename CkMathType::dtype; + using CDataType = typename CkMathType::dtype; + using DDataType = typename CkMathType::dtype; + + using AccDataType = float; + using CShuffleDataType = typename CkMathType::dtype; + + using ALayout = typename CkTensorLayout::a_layout; + using BLayout = typename CkTensorLayout::b_layout; + + using DLayout = Row; + using CLayout = Row; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using CElementOp = AlphaBetaAdd; + + + static constexpr auto GemmDefault = + ck::tensor_operation::device::GemmSpecialization::Default; + static constexpr auto GemmMNKPadding = + ck::tensor_operation::device::GemmSpecialization::MNKPadding; + static constexpr auto GemmSpec = PADDING ? GemmMNKPadding : GemmDefault; + + + using DeviceGemmInstance = + ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3, + CLayout, + ADataType, + BDataType, + ck::Tuple<>, + CDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CElementOp, + GemmSpec, + BLOCK_SIZE, + MBLOCK, + NBLOCK, + KBLOCK, + AK1, + BK1, + MPER_XDL, + NPER_XDL, + MPER_WAVE, + NPER_WAVE, + ABLOCK_CLUSTER_LENS, + ABLOCK_CLUSTER_ORDER, + ABLOCK_SRC_ORDER, + ABLOCK_VECTOR_DIM, + ABLOCK_SCALAR_VEC, + ABLOCK_SCALAR_VEC_AK1, + ABLOCK_LDS_EXTRAM, + BBLOCK_CLUSTER_LENS, + BBLOCK_CLUSTER_ORDER, + BBLOCK_SRC_ORDER, + BBLOCK_VECTOR_DIM, + BBLOCK_SCALAR_VEC, + BBLOCK_SCALAR_VEC_AK1, + BBLOCK_LDS_EXTRAN, + CMPER_WAVE, + CNPER_WAVE, + BLOCK_CLUSTER_LENS, + CDE_SCALAR_VEC>; + + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CElementOp{alpha, beta}; + + + using DDataArrayType = std::array; + DDataArrayType DDataArray; + + // We swap A and B inputs here as a temporary workaround + auto argument = gemm.MakeArgument( + reinterpret_cast(b), + reinterpret_cast(a), + DDataArray, + reinterpret_cast(c), + N, + M, + K, + StrideB, + StrideA, + std::array{}, + StrideC, + KBatch, + a_element_op, + b_element_op, + c_element_op); + + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + + auto stream = at::cuda::getCurrentHIPStream().stream(); + invoker.Run(argument, StreamConfig{stream, false}); +} + +} // namespace at::native diff --git a/aten/src/ATen/native/im2col_shape_check.h b/aten/src/ATen/native/im2col_shape_check.h index 8a6fa47ba10f1..6c830c5c929cb 100644 --- a/aten/src/ATen/native/im2col_shape_check.h +++ b/aten/src/ATen/native/im2col_shape_check.h @@ -56,7 +56,7 @@ inline void col2im_shape_check( int64_t n_input_plane = input.size(batch_dim + 1); if (n_input_plane % (kernel_width * kernel_height) != 0) { - AT_ERROR( + TORCH_CHECK(false, "Expected size of input's dimension 1 to be divisible by the " "product of kernel_size, but got input.size(1)=", n_input_plane, @@ -81,7 +81,7 @@ inline void col2im_shape_check( 1; if (input_length != (n_blocks_height * n_blocks_width)) { - AT_ERROR( + TORCH_CHECK(false, "Given output_size=(", output_height, ", ", @@ -126,7 +126,7 @@ inline void col2im_shape_check( "which is too small (non-positive)"); if (output_width < 1 || output_height < 1) { - AT_ERROR( + TORCH_CHECK(false, "Expected output spatial size to be positive, but got: output_size=(", output_height, ", ", @@ -204,7 +204,7 @@ inline void im2col_shape_check( 1; if (output_height < 1 || output_width < 1) { - AT_ERROR( + TORCH_CHECK(false, "Given input with spatial size (", input_height, ", ", diff --git a/aten/src/ATen/native/layer_norm.cpp b/aten/src/ATen/native/layer_norm.cpp index c739547af9c1a..61be95a81a1c8 100644 --- a/aten/src/ATen/native/layer_norm.cpp +++ b/aten/src/ATen/native/layer_norm.cpp @@ -51,7 +51,7 @@ static void layer_norm_with_mean_rstd_out( for (const auto idx : c10::irange(axis)) { stat_shape.emplace_back(input_shape[idx]); } - for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) { + for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { stat_shape.emplace_back(1); } @@ -256,7 +256,7 @@ std::tuple math_native_layer_norm( for (const auto idx : c10::irange(axis)) { stat_shape.push_back(input_shape[idx]); } - for (const auto idx C10_UNUSED : c10::irange(axis, input.dim())) { + for ([[maybe_unused]] const auto idx : c10::irange(axis, input.dim())) { stat_shape.push_back(1); } mean = mean.view(stat_shape); diff --git a/aten/src/ATen/native/layer_norm.h b/aten/src/ATen/native/layer_norm.h index ba2b356c0b045..339f0b1259b42 100644 --- a/aten/src/ATen/native/layer_norm.h +++ b/aten/src/ATen/native/layer_norm.h @@ -39,7 +39,7 @@ C10_ALWAYS_INLINE void _check_rms_norm_inputs_symint( ss << ", " << size; } ss << "], but got input of size" << input_shape; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } } @@ -83,7 +83,7 @@ C10_ALWAYS_INLINE std::pair _check_layer_norm_inputs( ss << ", " << size; } ss << "], but got input of size" << input_shape; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } const int axis = input_ndim - normalized_ndim; diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h index d26e358a35238..13b1f7ccaae3e 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.h @@ -14,11 +14,10 @@ API_AVAILABLE(ios(11.0), macos(10.13)) @end -using namespace at::native::metal; API_AVAILABLE(ios(11.0), macos(10.13)) @interface MPSCNNConvOp : NSObject -+ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params ++ (MPSCNNConvOp*)conv2d:(const at::native::metal::Conv2DParams&)params weights:(float*)w bias:(float*)b - neuronFilter:(NeuronType)t; + neuronFilter:(at::native::metal::NeuronType)t; @end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm index bf4136aed5db3..a46d1a75f1671 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNConvOp.mm @@ -68,10 +68,10 @@ @implementation MPSCNNConvOp { @synthesize kernel = _kernel; -+ (MPSCNNConvOp*)conv2d:(const Conv2DParams&)params ++ (MPSCNNConvOp*)conv2d:(const at::native::metal::Conv2DParams&)params weights:(float*)w bias:(float*)b - neuronFilter:(NeuronType)t API_AVAILABLE(ios(11.0), macos(10.13)) { + neuronFilter:(at::native::metal::NeuronType)t API_AVAILABLE(ios(11.0), macos(10.13)) { using namespace at::native::metal::mpscnn; TORCH_CHECK( params.DX == params.DY == 1, "Dilated convolution is not supported yet."); diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h b/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h index 04116b54f37a9..a8560bd426305 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.h @@ -5,8 +5,8 @@ API_AVAILABLE(ios(11.0), macos(10.13)) @interface MPSCNNFullyConnectedOp : NSObject -+ (MPSCNNFullyConnectedOp*)linear:(const Conv2DParams&)params ++ (MPSCNNFullyConnectedOp*)linear:(const at::native::metal::Conv2DParams&)params weights:(float*)w bias:(float*)b - neuronFilter:(NeuronType)t; -@end \ No newline at end of file + neuronFilter:(at::native::metal::NeuronType)t; +@end diff --git a/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm b/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm index 353095a8f52f7..19b71da963fdf 100644 --- a/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm +++ b/aten/src/ATen/native/metal/mpscnn/MPSCNNFullyConnectedOp.mm @@ -6,10 +6,10 @@ @implementation MPSCNNFullyConnectedOp @synthesize kernel = _kernel; -+ (MPSCNNFullyConnectedOp*)linear:(const Conv2DParams&)params ++ (MPSCNNFullyConnectedOp*)linear:(const at::native::metal::Conv2DParams&)params weights:(float*)w bias:(float*)b - neuronFilter:(NeuronType)t + neuronFilter:(at::native::metal::NeuronType)t API_AVAILABLE(ios(11.0), macos(10.13)) { MPSCNNNeuron* neuron = at::native::metal::neuron(t); MPSCNNConvolutionDescriptor* desc = [MPSCNNConvolutionDescriptor diff --git a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp index 607f55e058f8d..9002832fc3cc0 100644 --- a/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp +++ b/aten/src/ATen/native/miopen/BatchNorm_miopen.cpp @@ -24,13 +24,13 @@ namespace at { namespace native { std::tuple miopen_batch_norm( const Tensor& input, const Tensor& weight, const std::optional& bias_opt, const std::optional& running_mean_opt, const std::optional& running_var_opt, bool training, double exponential_average_factor, double epsilon) { - AT_ERROR("miopen_batch_norm: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_batch_norm: ATen not compiled with MIOpen support"); } std::tuple miopen_batch_norm_backward( const Tensor& input, const Tensor& grad_output, const Tensor& weight, const std::optional& running_mean_opt, const std::optional& running_var_opt, const std::optional& save_mean_opt, const std::optional& save_var_opt, double epsilon) { - AT_ERROR("miopen_batch_norm_backward: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_batch_norm_backward: ATen not compiled with MIOpen support"); } }} // namespace at::native @@ -64,8 +64,8 @@ std::tuple miopen_batch_norm( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned bias_t_maybe_owned = at::borrow_from_optional_tensor(bias_t_opt); const Tensor& bias_t = *bias_t_maybe_owned; - const Tensor& running_mean_t = c10::value_or_else(running_mean_t_opt, [] {return Tensor();}); - const Tensor& running_var_t = c10::value_or_else(running_var_t_opt, [] {return Tensor();}); + const Tensor& running_mean_t = running_mean_t_opt.value_or(Tensor()); + const Tensor& running_var_t = running_var_t_opt.value_or(Tensor()); TensorArg input{ input_t, "input", 1 }, weight{ weight_t, "weight", 2 }, @@ -169,13 +169,13 @@ std::tuple miopen_batch_norm_backward( double epsilon) { // See [Note: hacky wrapper removal for optional tensor] const Tensor& running_mean = - c10::value_or_else(running_mean_opt, [] { return Tensor(); }); + running_mean_opt.value_or(Tensor()); const Tensor& running_var = - c10::value_or_else(running_var_opt, [] { return Tensor(); }); + running_var_opt.value_or(Tensor()); const Tensor& save_mean_t = - c10::value_or_else(save_mean_t_opt, [] { return Tensor(); }); + save_mean_t_opt.value_or(Tensor()); const Tensor& save_var_t = - c10::value_or_else(save_var_t_opt, [] { return Tensor(); }); + save_var_t_opt.value_or(Tensor()); TensorArg input{ input_t, "input", 1 }, grad_output{ grad_output_t, "grad_output", 2 }, diff --git a/aten/src/ATen/native/miopen/Conv_miopen.cpp b/aten/src/ATen/native/miopen/Conv_miopen.cpp index 94bc728d6084c..0a7081ef0bd15 100644 --- a/aten/src/ATen/native/miopen/Conv_miopen.cpp +++ b/aten/src/ATen/native/miopen/Conv_miopen.cpp @@ -34,89 +34,89 @@ at::Tensor miopen_convolution( const Tensor& input, const Tensor& weight, const std::optional& bias_opt /* optional */, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_convolution: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution: ATen not compiled with MIOpen support"); } at::Tensor miopen_convolution_backward_input( IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_convolution_backward_input: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_backward_input: ATen not compiled with MIOpen support"); } at::Tensor miopen_convolution_backward_weight( IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_convolution_backward_weight: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_backward_weight: ATen not compiled with MIOpen support"); } at::Tensor miopen_convolution_backward_bias( const at::Tensor& grad_output) { - AT_ERROR("miopen_convolution_backward_bias: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_backward_bias: ATen not compiled with MIOpen support"); } std::tuple miopen_convolution_backward( const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, std::array output_mask) { - AT_ERROR("miopen_convolution_backward: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_backward: ATen not compiled with MIOpen support"); } at::Tensor miopen_convolution_transpose( const Tensor& input, const Tensor& weight, const std::optional& bias_opt /* optional */, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_convolution_transpose: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_transpose: ATen not compiled with MIOpen support"); } at::Tensor miopen_convolution_transpose_backward_input( const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_convolution_transpose_backward: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_transpose_backward: ATen not compiled with MIOpen support"); } at::Tensor miopen_convolution_transpose_backward_weight( IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_convolution_transpose_backward_weight: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_transpose_backward_weight: ATen not compiled with MIOpen support"); } std::tuple miopen_convolution_transpose_backward( const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, std::array output_mask) { - AT_ERROR("miopen_convolution_transpose_backward: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_transpose_backward: ATen not compiled with MIOpen support"); } at::Tensor miopen_depthwise_convolution( const Tensor& input, const Tensor& weight, const std::optional& bias_opt /* optional */, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_depthwise_convolution: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_depthwise_convolution: ATen not compiled with MIOpen support"); } at::Tensor miopen_depthwise_convolution_backward_input( IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_depthwise_convolution_backward_input: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_depthwise_convolution_backward_input: ATen not compiled with MIOpen support"); } at::Tensor miopen_depthwise_convolution_backward_weight( IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { - AT_ERROR("miopen_depthwise_convolution_backward_weight: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_depthwise_convolution_backward_weight: ATen not compiled with MIOpen support"); } std::tuple miopen_depthwise_convolution_backward( const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, std::array output_mask) { - AT_ERROR("miopen_depthwise_convolution_backward: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_depthwise_convolution_backward: ATen not compiled with MIOpen support"); } @@ -124,13 +124,13 @@ at::Tensor miopen_convolution_add_relu( const at::Tensor& input, const at::Tensor& weight, const at::Tensor& z, const std::optional& alpha, const std::optional& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) { - AT_ERROR("miopen_convolution_add_relu: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_add_relu: ATen not compiled with MIOpen support"); } at::Tensor miopen_convolution_relu( const at::Tensor& input, const at::Tensor& weight, const std::optional& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) { - AT_ERROR("miopen_convolution_relu: ATen not compiled with MIOpen support"); + TORCH_CHECK(false, "miopen_convolution_relu: ATen not compiled with MIOpen support"); } }} @@ -396,7 +396,7 @@ struct algorithm_search { args.odesc.desc(), &max_solution_count)); if (max_solution_count > AT_MIOPEN_MAX_SOLUTIONS) { - AT_ERROR("miopenConvFwdAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS"); + TORCH_CHECK(false, "miopenConvFwdAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS"); } MIOPEN_CHECK(miopenConvolutionForwardGetSolution( args.handle, @@ -469,7 +469,7 @@ struct algorithm_search { args.idesc.desc(), &max_solution_count)); if (max_solution_count > AT_MIOPEN_MAX_SOLUTIONS) { - AT_ERROR("miopenConvBwdDataAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS"); + TORCH_CHECK(false, "miopenConvBwdDataAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS"); } MIOPEN_CHECK(miopenConvolutionBackwardDataGetSolution( args.handle, @@ -542,7 +542,7 @@ struct algorithm_search { args.wdesc.desc(), &max_solution_count)); if (max_solution_count > AT_MIOPEN_MAX_SOLUTIONS) { - AT_ERROR("miopenConvBwdWeightsAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS"); + TORCH_CHECK(false, "miopenConvBwdWeightsAlgorithm_t getSolution max_solution_count > AT_MIOPEN_MAX_SOLUTIONS"); } MIOPEN_CHECK(miopenConvolutionBackwardWeightsGetSolution( args.handle, diff --git a/aten/src/ATen/native/miopen/RNN_miopen.cpp b/aten/src/ATen/native/miopen/RNN_miopen.cpp index 86ef2fb707d50..e19243f70cdb4 100644 --- a/aten/src/ATen/native/miopen/RNN_miopen.cpp +++ b/aten/src/ATen/native/miopen/RNN_miopen.cpp @@ -34,7 +34,7 @@ namespace at { namespace native { bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const std::optional& fn_dropout_state_opt ) { - AT_ERROR("miopen_rnn : ATen not compiled with MIOpen support."); + TORCH_CHECK(false, "miopen_rnn : ATen not compiled with MIOpen support."); } std::tuple> miopen_rnn_backward( @@ -43,7 +43,7 @@ namespace at { namespace native { double dropout, bool train, bool bidirectional, IntArrayRef batch_sizes, const std::optional& dropout_state_opt, const Tensor& reserve, std::array output_mask ) { - AT_ERROR("miopen_rnn_backward: ATen not compiled with MIOpen support."); + TORCH_CHECK(false, "miopen_rnn_backward: ATen not compiled with MIOpen support."); } }} //namespace at::native @@ -109,7 +109,7 @@ struct RNNDescriptorParams { { std::ostringstream oss; oss << "unrecognized miopen RNN mode " << fn_mode; - AT_ERROR(oss.str()); + TORCH_CHECK(false, oss.str()); } } } @@ -323,7 +323,7 @@ int64_t _num_linear_layers(miopenRNNMode_t mode) { case miopenRNNTANH: return 2; default: - AT_ERROR("Unknown miopen RNN mode : ", mode); + TORCH_CHECK(false, "Unknown miopen RNN mode : ", mode); } } @@ -452,7 +452,7 @@ std::tuple miopen_rnn( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt); const Tensor& cx = *cx_maybe_owned; - const Tensor& fn_dropout_state = c10::value_or_else(fn_dropout_state_opt, [] {return Tensor();}); + const Tensor& fn_dropout_state = fn_dropout_state_opt.value_or(Tensor()); check_attributes(input_r, weight, {hx, cx}); auto input = input_r; @@ -766,10 +766,10 @@ std::tuple> miopen_rnn_backward( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned cx_maybe_owned = at::borrow_from_optional_tensor(cx_opt); const Tensor& cx = *cx_maybe_owned; - const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] {return Tensor();}); - const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] {return Tensor();}); - const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] {return Tensor();}); - const Tensor& dropout_state = c10::value_or_else(dropout_state_opt, [] {return Tensor();}); + const Tensor& grad_output_r = grad_output_r_opt.value_or(Tensor()); + const Tensor& grad_hy_r = grad_hy_r_opt.value_or(Tensor()); + const Tensor& grad_cy_r = grad_cy_r_opt.value_or(Tensor()); + const Tensor& dropout_state = dropout_state_opt.value_or(Tensor()); if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) { return std::tuple>(Tensor(), Tensor(), Tensor(), std::vector(weight.size())); @@ -804,7 +804,7 @@ std::tuple unpack_hidden(const std::tuple& hidde template hidden_type pack_hidden(const Tensor& hx, const Tensor& cx) { static_assert(std::is_same::value, "pack_hidden not implemented for this type"); - AT_ERROR("NOT IMPLEMENTED"); + TORCH_CHECK(false, "NOT IMPLEMENTED"); } template<> diff --git a/aten/src/ATen/native/mkl/MklAllocationHelper.cpp b/aten/src/ATen/native/mkl/MklAllocationHelper.cpp new file mode 100644 index 0000000000000..3ac062fb99776 --- /dev/null +++ b/aten/src/ATen/native/mkl/MklAllocationHelper.cpp @@ -0,0 +1,29 @@ +#include + +#if AT_MKLDNN_ENABLED() +#ifdef USE_MIMALLOC_ON_MKL +#include +#include +#if INTEL_MKL_VERSION > 20230000L +/* +MKL have a method to register memory allocation APIs via i_malloc.h, High +performance memory allocation APIs will help improve MKL performance. +Please check MKL online document: +https://www.intel.com/content/www/us/en/docs/onemkl/developer-guide-windows/2024-2/redefining-memory-functions.html +*/ +#include + +bool register_mimalloc_api_to_mkl() +{ + i_malloc = c10::mi_malloc_wrapper::c10_mi_malloc; + i_calloc = c10::mi_malloc_wrapper::c10_mi_calloc; + i_realloc = c10::mi_malloc_wrapper::c10_mi_realloc; + i_free = c10::mi_malloc_wrapper::c10_mi_free; + + return true; +} + +static bool g_b_registered_mkl_alloction = register_mimalloc_api_to_mkl(); +#endif +#endif +#endif diff --git a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp index b938ccd937a8d..27e21787775e7 100644 --- a/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp +++ b/aten/src/ATen/native/mkl/SparseCsrLinearAlgebra.cpp @@ -19,9 +19,9 @@ Tensor& _sparse_mm_mkl_( const Scalar& alpha, const Scalar& beta) { #if __APPLE__ || __MACH__ - AT_ERROR("sparse_mm_mkl: MKL support is disabled on macos/iOS."); + TORCH_CHECK(false, "sparse_mm_mkl: MKL support is disabled on macos/iOS."); #else - AT_ERROR("sparse_mm_mkl: ATen not compiled with MKL support"); + TORCH_CHECK(false, "sparse_mm_mkl: ATen not compiled with MKL support"); #endif return self; // for stopping compiler warnings. } diff --git a/aten/src/ATen/native/mkl/SpectralOps.cpp b/aten/src/ATen/native/mkl/SpectralOps.cpp index 8ae620ed0028c..7fa9234e0fe8d 100644 --- a/aten/src/ATen/native/mkl/SpectralOps.cpp +++ b/aten/src/ATen/native/mkl/SpectralOps.cpp @@ -241,7 +241,7 @@ T compute_fct(int64_t size, int64_t normalization) { case fft_norm_mode::by_n: return one / static_cast(size); case fft_norm_mode::by_root_n: return one / std::sqrt(static_cast(size)); } - AT_ERROR("Unsupported normalization type", normalization); + TORCH_CHECK(false, "Unsupported normalization type", normalization); } template @@ -578,30 +578,30 @@ namespace at { namespace native { REGISTER_NO_CPU_DISPATCH(fft_fill_with_conjugate_symmetry_stub); Tensor _fft_c2r_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { - AT_ERROR("fft: ATen not compiled with FFT support"); + TORCH_CHECK(false, "fft: ATen not compiled with FFT support"); } Tensor _fft_r2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided) { - AT_ERROR("fft: ATen not compiled with FFT support"); + TORCH_CHECK(false, "fft: ATen not compiled with FFT support"); } Tensor _fft_c2c_mkl(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward) { - AT_ERROR("fft: ATen not compiled with FFT support"); + TORCH_CHECK(false, "fft: ATen not compiled with FFT support"); } Tensor& _fft_r2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool onesided, Tensor& out) { - AT_ERROR("fft: ATen not compiled with FFT support"); + TORCH_CHECK(false, "fft: ATen not compiled with FFT support"); } Tensor& _fft_c2r_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, int64_t last_dim_size, Tensor& out) { - AT_ERROR("fft: ATen not compiled with FFT support"); + TORCH_CHECK(false, "fft: ATen not compiled with FFT support"); } Tensor& _fft_c2c_mkl_out(const Tensor& self, IntArrayRef dim, int64_t normalization, bool forward, Tensor& out) { - AT_ERROR("fft: ATen not compiled with FFT support"); + TORCH_CHECK(false, "fft: ATen not compiled with FFT support"); } }} // namespace at::native diff --git a/aten/src/ATen/native/mkldnn/Linear.cpp b/aten/src/ATen/native/mkldnn/Linear.cpp index cc1d030d7cb8c..e5dc8a6e0c1da 100644 --- a/aten/src/ATen/native/mkldnn/Linear.cpp +++ b/aten/src/ATen/native/mkldnn/Linear.cpp @@ -339,7 +339,7 @@ Tensor mkldnn_linear_pointwise_binary( #if AT_MKL_ENABLED() #include -static Tensor mkl_linear( +Tensor mkl_linear( const Tensor& self, const Tensor& mkl_weight_t, const Tensor& origin_weight_t, diff --git a/aten/src/ATen/native/mkldnn/Linear.h b/aten/src/ATen/native/mkldnn/Linear.h index ef67c42ed533e..ff4f886a5309e 100644 --- a/aten/src/ATen/native/mkldnn/Linear.h +++ b/aten/src/ATen/native/mkldnn/Linear.h @@ -22,6 +22,17 @@ C10_API Tensor mkldnn_linear_pointwise_binary( const std::optional& bias_opt, c10::string_view attr); +#if AT_MKL_ENABLED() + +C10_API Tensor mkl_linear( + const Tensor& self, + const Tensor& mkl_weight_t, + const Tensor& origin_weight_t, + const std::optional& bias_opt, + const int64_t prepack_batch_size); + +#endif// AT_MKL_ENABLED + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/Normalization.cpp b/aten/src/ATen/native/mkldnn/Normalization.cpp index dcc04f68e1848..88636a8b66b7c 100644 --- a/aten/src/ATen/native/mkldnn/Normalization.cpp +++ b/aten/src/ATen/native/mkldnn/Normalization.cpp @@ -138,9 +138,9 @@ std::tuple mkldnn_batch_norm( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); - const Tensor& running_mean = c10::value_or_else(running_mean_opt, [] {return Tensor();}); - const Tensor& running_var = c10::value_or_else(running_var_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); + const Tensor& running_mean = running_mean_opt.value_or(Tensor()); + const Tensor& running_var = running_var_opt.value_or(Tensor()); if (input.scalar_type() == ScalarType::BFloat16) { TORCH_CHECK(mkldnn_bf16_device_check(), @@ -253,8 +253,8 @@ std::tuple mkldnn_batch_norm_backward(const Tensor& grad // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& save_mean = c10::value_or_else(save_mean_opt, [] {return Tensor();}); - const Tensor& save_invstd = c10::value_or_else(save_invstd_opt, [] {return Tensor();}); + const Tensor& save_mean = save_mean_opt.value_or(Tensor()); + const Tensor& save_invstd = save_invstd_opt.value_or(Tensor()); TORCH_CHECK(train, "mkldnn_batch_norm_backward: currently mkldnn only support train model"); ideep::tensor& grady = itensor_from_mkldnn(grad_output); diff --git a/aten/src/ATen/native/mkldnn/RNN.cpp b/aten/src/ATen/native/mkldnn/RNN.cpp index 65f430ef58f5f..cbbae464c7d6a 100644 --- a/aten/src/ATen/native/mkldnn/RNN.cpp +++ b/aten/src/ATen/native/mkldnn/RNN.cpp @@ -41,7 +41,7 @@ const Tensor& input, bool bidirectional, bool batch_first, bool train) { - AT_ERROR("mkldnn_rnn_layer: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_rnn_layer: ATen not compiled with MKLDNN support"); } std::tuple mkldnn_rnn_layer_backward( @@ -68,7 +68,7 @@ std::tuple mkldnn_rnn_la at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor& workspace) { - AT_ERROR("mkldnn_rnn_layer_backward: ATen not compiled with MKLDNN support"); + TORCH_CHECK(false, "mkldnn_rnn_layer_backward: ATen not compiled with MKLDNN support"); } REGISTER_NO_CPU_DISPATCH(lstm_mkldnn_stub); @@ -315,9 +315,9 @@ std::tuple mkldnn_rnn_la at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor& workspace) { - const Tensor& grad_output_r = c10::value_or_else(grad_output_r_opt, [] {return Tensor();}); - const Tensor& grad_hy_r = c10::value_or_else(grad_hy_r_opt, [] {return Tensor();}); - const Tensor& grad_cy_r = c10::value_or_else(grad_cy_r_opt, [] {return Tensor();}); + const Tensor& grad_output_r = grad_output_r_opt.value_or(Tensor()); + const Tensor& grad_hy_r = grad_hy_r_opt.value_or(Tensor()); + const Tensor& grad_cy_r = grad_cy_r_opt.value_or(Tensor()); if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) { return std::make_tuple(Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor(), Tensor()); } diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp index 518ce8a4f1d24..9b03e607e507c 100644 --- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp +++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp @@ -68,7 +68,7 @@ Tensor& addmm_out( // complex/double case if (mat1.is_complex() || mat1.scalar_type() == ScalarType::Double) { - AT_ERROR( + TORCH_CHECK(false, "Double and complex datatype matmul is not supported in oneDNN"); } @@ -148,7 +148,7 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) { } if (self.is_complex() || self.scalar_type() == ScalarType::Double) { - AT_ERROR( + TORCH_CHECK(false, "Double and complex datatype matmul is not supported in oneDNN"); } @@ -203,7 +203,7 @@ Tensor& baddbmm_out( // complex and double case if (batch1.is_complex() || batch2.scalar_type() == ScalarType::Double) { - AT_ERROR( + TORCH_CHECK(false, "Double and complex datatype matmul is not supported in oneDNN"); } @@ -329,7 +329,7 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) { } if (self.is_complex() || self.scalar_type() == ScalarType::Double) { - AT_ERROR( + TORCH_CHECK(false, "Double and complex datatype matmul is not supported in oneDNN"); } onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr()); diff --git a/aten/src/ATen/native/mps/MPSGraphSequoiaOps.h b/aten/src/ATen/native/mps/MPSGraphSequoiaOps.h index 4ec62e33bfb03..70d70f51a9fff 100644 --- a/aten/src/ATen/native/mps/MPSGraphSequoiaOps.h +++ b/aten/src/ATen/native/mps/MPSGraphSequoiaOps.h @@ -31,8 +31,15 @@ typedef NS_ENUM(NSInteger, MTLMathMode) MTLMathModeFast = 2, }; +typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions) +{ + MTLMathFloatingPointFunctionsFast = 0, + MTLMathFloatingPointFunctionsPrecise = 1, +}; + @interface MTLCompileOptions() @property (readwrite, nonatomic) MTLMathMode mathMode; +@property (readwrite, nonatomic) MTLMathFloatingPointFunctions mathFloatingPointFunctions; @end #endif diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h index 1e71d9d8819a8..ce16456c898e0 100644 --- a/aten/src/ATen/native/mps/OperationUtils.h +++ b/aten/src/ATen/native/mps/OperationUtils.h @@ -57,43 +57,43 @@ void runMPSGraph(MPSStream* mpsStream, NSDictionary* results); MPSDataType getMPSDataType(ScalarType scalar_type); -static inline MPSDataType getMPSDataType(const Tensor& t) { +static inline MPSDataType getMPSDataType(const TensorBase& t) { return getMPSDataType(t.scalar_type()); } MPSDataType getMPSScalarType(ScalarType scalar_type); -static inline MPSDataType getMPSScalarType(const Tensor& t) { +static inline MPSDataType getMPSScalarType(const TensorBase& t) { return getMPSScalarType(t.scalar_type()); } MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type); std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false); -static inline std::string getMPSTypeString(const Tensor& t, bool short_name = false) { +static inline std::string getMPSTypeString(const TensorBase& t, bool short_name = false) { return getMPSTypeString(t.scalar_type(), short_name); } std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type); -static inline std::string scalarToMetalTypeString(const Tensor& t) { +static inline std::string scalarToMetalTypeString(const TensorBase& t) { return scalarToMetalTypeString(t.scalar_type()); } -NSArray* getTensorAxes(const Tensor& t); +NSArray* getTensorAxes(const TensorBase& t); NSArray* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim); std::string getMPSShapeString(MPSShape* shape); std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = true, bool exclude_shape = false); std::string getArrayRefString(const IntArrayRef s); // use has_storage() on the returned tensor to determine if src actually is a view -Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst); -Tensor& scatterViewTensor(const at::Tensor& src, at::Tensor& output); -bool canSliceViewTensor(const Tensor& src, MPSShape *mpsShape); -MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape *mpsShape, const MPSDataType mpsDataType); -MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false); -MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const Tensor& input, bool includesInt64 = false); - -MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); -MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes = nil, MPSShape* strides = nil); +Tensor gatherViewTensor(const Tensor& src, Tensor& dst); +Tensor& scatterViewTensor(const Tensor& src, Tensor& output); +bool canSliceViewTensor(const TensorBase& src, MPSShape *mpsShape); +MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src, MPSShape *mpsShape, const MPSDataType mpsDataType); +MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false); +MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false); + +MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {}); +MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil); // The MPSShape could vary based on memory format Tensor getTensorView(const Tensor& t, MPSShape* shape); -MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); +MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); MPSShape* getMPSShape(IntArrayRef sizes, c10::MemoryFormat memory_format = MemoryFormat::Contiguous); -static inline id getMTLBufferStorage(const at::Tensor& tensor) { +static inline id getMTLBufferStorage(const TensorBase& tensor) { return __builtin_bit_cast(id, tensor.storage().data()); } @@ -126,16 +126,16 @@ MPSGraphTensor* trunc_tensor(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); MPSGraphTensor* convertNHWCtoNCHW(MPSGraph *mpsGraph, MPSGraphTensor* tensor); MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, ScalarType toType); MPSGraphTensor* castMPSTensor(MPSGraph *mpsGraph, MPSGraphTensor* tensor, MPSDataType toType); -MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor); +MPSGraphTensorData *getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor); MPSGraphTensorData* getMPSGraphTensorFromScalar(MPSStream* mpsStream, MPSScalar& scalar); MPSGraph* make_mps_graph(); -void printTensorNDArray(const Tensor& t); -MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape *shape, MPSDataType mpsType); +void printTensorNDArray(const TensorBase& t); +MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape *shape, MPSDataType mpsType); MPSGraphTensor* mpsGraphUnrankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType, MPSShape* mpsShape); -MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const Tensor& tensor); +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph *mpsGraph, const TensorBase& tensor); MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, MPSDataType dataType); MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph *mpsGraph, const Scalar& scalar); @@ -326,12 +326,12 @@ MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor); /** * Returns distance from lowest to highest element offset in given tensor. */ -size_t compute_storage_numel_distance(const at::Tensor& t); +size_t compute_storage_numel_distance(const TensorBase& t); /** * Checks whether tensor is mapped to a contiguous area in the storage. */ -inline bool is_dense_in_storage(const at::Tensor& t) { +inline bool is_dense_in_storage(const TensorBase& t) { return compute_storage_numel_distance(t) == static_cast(t.numel()); } @@ -370,7 +370,7 @@ class MetalShaderLibrary { template, encoder_t> || std::is_same_v, encoder_t>>> -static inline void mtl_setBuffer(encoder_t encoder, const Tensor& t, unsigned idx) { +static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) { [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx]; @@ -440,7 +440,7 @@ inline bool supportedFloatingType(ScalarType dtype) { return dtype == kFloat || dtype == kHalf || dtype == kBFloat16; } -inline bool supportedFloatingType(const Tensor& t) { +inline bool supportedFloatingType(const TensorBase& t) { return supportedFloatingType(t.scalar_type()); } @@ -450,7 +450,7 @@ inline bool supportedFloatingOrComplexType(ScalarType dtype) { } return supportedFloatingType(dtype); } -inline bool supportedFloatingOrComplexType(const Tensor& t) { +inline bool supportedFloatingOrComplexType(const TensorBase& t) { return supportedFloatingOrComplexType(t.scalar_type()); } @@ -459,7 +459,7 @@ inline void checkSupportsBFloat16() { "MPS bfloat16 type is supported on MacOS 14.0 or newer."); } -inline bool needsGather(const Tensor& t) { +inline bool needsGather(const TensorBase& t) { static const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS); return !is_macOS_15_0_or_newer && (!t.is_contiguous() || t.storage_offset()) ; } diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm index 90879a026ed09..fc0b66b6ed40a 100644 --- a/aten/src/ATen/native/mps/OperationUtils.mm +++ b/aten/src/ATen/native/mps/OperationUtils.mm @@ -35,7 +35,7 @@ void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)()) { /** * Computes distance from lowest to highest element offset in given tensor. */ -size_t compute_storage_numel_distance(const at::Tensor& t) { +size_t compute_storage_numel_distance(const TensorBase& t) { size_t rc = 1; if (t.numel() == 0) { return 0; @@ -97,7 +97,7 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // types. MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, - const Tensor& input, + const TensorBase& input, bool includesInt64) { MPSDataType dataType = getMPSDataType(input.scalar_type()); bool condition = @@ -117,7 +117,7 @@ MPSDataType getMPSDataType(ScalarType scalar_type) { // types. MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, - const Tensor& input, + const TensorBase& input, bool includesInt64) { MPSDataType dataType = getMPSDataType(input.scalar_type()); bool condition = @@ -240,7 +240,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return axes; } -NSArray* getTensorAxes(const Tensor& t) { +NSArray* getTensorAxes(const TensorBase& t) { return getTensorAxes(t.dim()); } @@ -248,7 +248,7 @@ MPSDataType getMPSScalarType(ScalarType scalar_type) { return getTensorAxes(sizes.size()); } -NSArray* getTensorAxes(const IntArrayRef& sizes, at::OptionalIntArrayRef dim) { +NSArray* getTensorAxes(const IntArrayRef& sizes, OptionalIntArrayRef dim) { if (dim.has_value() && !dim.value().empty()) { IntArrayRef dimValues = dim.value(); int ndim = dimValues.size(); @@ -313,7 +313,7 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { return t.view(res); } -MPSShape* getMPSShape(const Tensor& t, c10::MemoryFormat memory_format) { +MPSShape* getMPSShape(const TensorBase& t, c10::MemoryFormat memory_format) { return getMPSShape(t.sizes(), memory_format); } @@ -339,7 +339,7 @@ Tensor getTensorView(const Tensor& t, MPSShape* shape) { return [NSArray arrayWithObjects:numbers.data() count:numbers.size()]; } -void printTensorNDArray(const Tensor& t) { +void printTensorNDArray(const TensorBase& t) { if (!t.is_mps()) return; if (t.numel() == 0) @@ -360,7 +360,7 @@ void printTensorNDArray(const Tensor& t) { C10_CLANG_DIAGNOSTIC_POP() } -MPSNDArray* ndArrayFromTensor(const Tensor& tensor, MPSShape* shape, MPSDataType mpsType) { +MPSNDArray* ndArrayFromTensor(const TensorBase& tensor, MPSShape* shape, MPSDataType mpsType) { id buffer = getMTLBufferStorage(tensor); MPSGraphTensorData* tmpGraphTensorData = [[[MPSGraphTensorData alloc] initWithMTLBuffer:buffer shape:shape @@ -419,7 +419,7 @@ void printTensorNDArray(const Tensor& t) { return result; } -MPSNDArray* getMPSNDArray(const at::Tensor& t, MPSShape* sizes, MPSShape* strides) { +MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes, MPSShape* strides) { id srcBuf = getMTLBufferStorage(t); MPSDataType mpsDataType = getMPSDataType(t.scalar_type()); @@ -434,11 +434,11 @@ void printTensorNDArray(const Tensor& t) { return srcNDArray; } -MPSNDArray* getMPSNDArray(const at::Tensor& t, const IntArrayRef& sizes, const IntArrayRef& strides) { +MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes, const IntArrayRef& strides) { return getMPSNDArray(t, getMPSShape(sizes.empty() ? t.sizes() : sizes), strides.empty() ? nil : getMPSShape(strides)); } -static MPSNDArray* getStridedMPSNDArray(const at::Tensor& src, MPSNDArray* srcNDArray) { +static MPSNDArray* getStridedMPSNDArray(const TensorBase& src, MPSNDArray* srcNDArray) { auto strides = src.strides(); auto sizes = src.sizes(); auto nStrides = strides.size(); @@ -541,18 +541,9 @@ void printTensorNDArray(const Tensor& t) { MPSShape* mpsShape = getMPSShape(_tensor); MPSShape* mpsStrides = getMPSShape(_tensor.strides()); - IntArrayRef baseShape; - if (src.is_view()) { - baseShape = src._base().sizes(); - } else { - baseShape = getIMPSAllocator()->getBufferShape(src.storage().data()); - } - int flattenedShaped = 1; - for (const auto i : c10::irange(baseShape.size())) { - flattenedShaped *= baseShape[i]; - } - MPSShape* mpsBaseShape = @[ @(flattenedShaped) ]; - MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType shape:mpsBaseShape]; + auto storage_numel = src.storage().nbytes() / src.element_size(); + MPSNDArrayDescriptor* srcTensorDesc = [MPSNDArrayDescriptor descriptorWithDataType:dataType + shape:@[ @(storage_numel) ]]; srcTensorDesc.preferPackedRows = YES; MPSNDArray* srcNDArray = [[[MPSNDArray alloc] initWithBuffer:srcBuf offset:src.storage_offset() * src.element_size() @@ -590,7 +581,7 @@ void printTensorNDArray(const Tensor& t) { _placeholder = mpsGraphTensor; } -MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const Tensor& tensor) { +MPSGraphTensorData* getMPSGraphTensorData(MPSGraph* mpsGraph, MPSStream* mpsStream, const TensorBase& tensor) { auto mpsShape = getMPSShape(tensor); auto dataType = getMPSDataType(tensor.scalar_type()); @@ -614,9 +605,9 @@ MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) { case ScalarType::Float: return {.value.f = scalar.to(), .size = sizeof(float), .type = type}; case ScalarType::Half: - return {.value.h = scalar.to(), .size = sizeof(short), .type = type}; + return {.value.h = scalar.to(), .size = sizeof(short), .type = type}; case ScalarType::BFloat16: - return {.value.bf16 = scalar.to(), .size = sizeof(short), .type = type}; + return {.value.bf16 = scalar.to(), .size = sizeof(short), .type = type}; case ScalarType::Long: return {.value.i = scalar.to(), .size = sizeof(int64_t), .type = type}; case ScalarType::Int: @@ -630,7 +621,7 @@ MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type) { case ScalarType::Bool: return {.value.b = scalar.to(), .size = sizeof(bool), .type = type}; case ScalarType::ComplexHalf: - return {.value.ch = scalar.to>(), .size = sizeof(int32_t), .type = type}; + return {.value.ch = scalar.to>(), .size = sizeof(int32_t), .type = type}; case ScalarType::ComplexFloat: case ScalarType::ComplexDouble: return {.value.cf = scalar.to>(), .size = sizeof(int64_t), .type = type}; @@ -667,7 +658,7 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) { // as MPS doesn't support float64 tensor. Tensor tensor; if (scalar.isFloatingPoint()) { - tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kFloat)); + tensor = at::scalar_tensor(scalar, at::device(device).dtype(kFloat)); } else if (scalar.isBoolean()) { tensor = at::scalar_tensor(scalar, at::device(device).dtype(at::kBool)); } else if (scalar.isComplex()) { @@ -693,8 +684,8 @@ Tensor wrapped_scalar_tensor_mps(const Scalar& scalar, const Device device) { return [mpsGraph placeholderWithShape:mpsShape dataType:dataType name:nil]; } -MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const Tensor& tensor) { - return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor.scalar_type()) name:nil]; +MPSGraphTensor* mpsGraphRankedPlaceHolder(MPSGraph* mpsGraph, const TensorBase& tensor) { + return [mpsGraph placeholderWithShape:getMPSShape(tensor) dataType:getMPSScalarType(tensor) name:nil]; } MPSGraphTensor* mpsGraphScalarPlaceHolder(MPSGraph* mpsGraph, MPSDataType dataType) { @@ -848,7 +839,10 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} } id MetalShaderLibrary::compileLibrary(const std::string& src) { - static const char* fast_math = std::getenv("PYTORCH_MPS_FAST_MATH"); + static auto fast_math = []() { + auto val = std::getenv("PYTORCH_MPS_FAST_MATH"); + return val && std::stoi(val) != 0; + }(); NSError* error = nil; MTLCompileOptions* options = compile_options; if (!options) { @@ -856,7 +850,15 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {} // Need 3.0 for atomic oprations, 3.1 introduces bfloat support [options setLanguageVersion:is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? MTLLanguageVersion3_1 : MTLLanguageVersion3_0]; - [options setFastMathEnabled:(!fast_math || std::stoi(fast_math) == 0) ? NO : YES]; + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { + options.mathMode = fast_math ? MTLMathModeFast : MTLMathModeSafe; + options.mathFloatingPointFunctions = + fast_math ? MTLMathFloatingPointFunctionsFast : MTLMathFloatingPointFunctionsPrecise; + } else { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-declarations") + [options setFastMathEnabled:fast_math ? YES : NO]; + C10_DIAGNOSTIC_POP() + } } const auto str = [NSString stringWithCString:src.c_str() encoding:NSASCIIStringEncoding]; diff --git a/aten/src/ATen/native/mps/operations/Convolution.mm b/aten/src/ATen/native/mps/operations/Convolution.mm index 6f65f08355c38..800a9a4648e19 100644 --- a/aten/src/ATen/native/mps/operations/Convolution.mm +++ b/aten/src/ATen/native/mps/operations/Convolution.mm @@ -167,12 +167,7 @@ static Tensor _mps_convolution_impl(const Tensor& input_t_, // TODO: MPS convolution kernel currently does not support output channels > 2^16 for (auto elem : output_t.sizes()) { - TORCH_CHECK_NOT_IMPLEMENTED( - elem <= (1 << 16), - "Output channels > 65536 not supported at the MPS device. ", - "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ", - "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ", - "on MPS."); + TORCH_CHECK_NOT_IMPLEMENTED(elem <= (1 << 16), "Output channels > 65536 not supported at the MPS device. "); } convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups); @@ -378,12 +373,7 @@ static Tensor mps_convolution_backward_input(IntArrayRef input_size, // TODO: MPS convolution kernel currently does not support output channels > 2^16 for (auto elem : grad_output_t.sizes()) { - TORCH_CHECK_NOT_IMPLEMENTED( - elem <= (1 << 16), - "Output channels > 65536 not supported at the MPS device. ", - "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ", - "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ", - "on MPS."); + TORCH_CHECK_NOT_IMPLEMENTED(elem <= (1 << 16), "Output channels > 65536 not supported at the MPS device. "); } TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types"); diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm index e40454307ac97..2ac8dd3172101 100644 --- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm +++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm @@ -163,7 +163,7 @@ static void linalg_lu_factor_out_mps_impl(const Tensor& A, bool pivot, Tensor& L status_tensors.reserve(batchSize); pivots_list.reserve(batchSize); - for (C10_UNUSED const auto i : c10::irange(batchSize)) { + for ([[maybe_unused]] const auto i : c10::irange(batchSize)) { status_tensors.push_back(at::zeros(1, kInt, std::nullopt, kMPS, std::nullopt)); pivots_list.push_back(at::zeros(numPivots, kInt, std::nullopt, kMPS, std::nullopt)); } diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm index 391422e77b535..f49a0a037ea1e 100644 --- a/aten/src/ATen/native/mps/operations/Normalization.mm +++ b/aten/src/ATen/native/mps/operations/Normalization.mm @@ -718,10 +718,15 @@ static string get_mem_string(c10::MemoryFormat memory_format) { secondaryTensor:epsilonTensor name:nil]; #ifdef __MAC_15_0 - rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; -#else - rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; -#endif + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { + rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; + } else +#endif // __MAC_15_0 + { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-declarations") + rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; + C10_DIAGNOSTIC_POP() + } MPSGraphTensor* bnForwardTensor = [mpsGraph multiplicationWithPrimaryTensor:xMinusMean secondaryTensor:rsqrtTensor name:nil]; @@ -747,10 +752,15 @@ static string get_mem_string(c10::MemoryFormat memory_format) { secondaryTensor:epsilonTensor name:nil]; #ifdef __MAC_15_0 - rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; -#else - rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; -#endif + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { + rsqrtTensor = [mpsGraph reciprocalSquareRootWithTensor:varianceEpsTensor name:nil]; + } else +#endif // __MAC_15_0 + { + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-declarations") + rsqrtTensor = [mpsGraph reverseSquareRootWithTensor:varianceEpsTensor name:nil]; + C10_DIAGNOSTIC_POP() + } } gradInputTensor = [mpsGraph multiplicationWithPrimaryTensor:unitTensor secondaryTensor:rsqrtTensor name:nil]; @@ -912,7 +922,7 @@ static string get_mem_string(c10::MemoryFormat memory_format) { for (const auto idx : c10::irange(axis)) { stat_shape.push_back(input_shape[idx]); } - for (C10_UNUSED auto idx : c10::irange(axis, input.dim())) { + for ([[maybe_unused]] auto idx : c10::irange(axis, input.dim())) { stat_shape.push_back(1); } mean = mean.view(stat_shape); diff --git a/aten/src/ATen/native/mps/operations/Pooling.mm b/aten/src/ATen/native/mps/operations/Pooling.mm index e270879fa9412..570d2024c640c 100644 --- a/aten/src/ATen/native/mps/operations/Pooling.mm +++ b/aten/src/ATen/native/mps/operations/Pooling.mm @@ -76,7 +76,7 @@ static void pool2d_template(const Tensor& input, } else if (suggested_memory_format == at::MemoryFormat::Contiguous) { TORCH_CHECK((ndims == 3 || ndims == 4), "non-empty 3D or 4D (batch mode) tensor expected for input"); } else { - AT_ERROR("Unsupported memory format. Supports only ChannelsLast, Contiguous"); + TORCH_CHECK(false, "Unsupported memory format. Supports only ChannelsLast, Contiguous"); } int padH = safe_downcast(padding[0]); diff --git a/aten/src/ATen/native/mps/operations/RnnOps.mm b/aten/src/ATen/native/mps/operations/RnnOps.mm index 3a773ef221dd6..4e46ea37bbadb 100644 --- a/aten/src/ATen/native/mps/operations/RnnOps.mm +++ b/aten/src/ATen/native/mps/operations/RnnOps.mm @@ -97,7 +97,7 @@ // Projections are not currently supported, raise an error if needed bool has_projections = (hx[0].size(2) != hx[1].size(2)); if (has_projections) { - AT_ERROR("LSTM with projections is not currently supported with MPS."); + TORCH_CHECK(false, "LSTM with projections is not currently supported with MPS."); } std::vector kernel_weights; @@ -358,9 +358,9 @@ using namespace mps; bool is_macos_14_4_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_4_PLUS); - const Tensor& grad_y_r = c10::value_or_else(grad_y_opt, [] { return Tensor(); }); - const Tensor& grad_hy_r = c10::value_or_else(grad_hy_opt, [] { return Tensor(); }); - const Tensor& grad_cy_r = c10::value_or_else(grad_cy_opt, [] { return Tensor(); }); + const Tensor& grad_y_r = grad_y_opt.value_or(Tensor()); + const Tensor& grad_hy_r = grad_hy_opt.value_or(Tensor()); + const Tensor& grad_cy_r = grad_cy_opt.value_or(Tensor()); const auto grad_hy = grad_hy_r.defined() ? grad_hy_r : at::zeros_like(hx[0], input.options()); const auto grad_cy = grad_cy_r.defined() ? grad_cy_r : at::zeros_like(hx[1], input.options()); diff --git a/aten/src/ATen/native/mps/operations/SpecialOps.mm b/aten/src/ATen/native/mps/operations/SpecialOps.mm new file mode 100644 index 0000000000000..d38c258cfe378 --- /dev/null +++ b/aten/src/ATen/native/mps/operations/SpecialOps.mm @@ -0,0 +1,144 @@ +#include +#include + +#define TORCH_ASSERT_NO_OPERATORS +#include + +#include + +namespace at::native { +static mps::MetalShaderLibrary lib(R"SPECIAL_METAL( +#include +using namespace metal; + +/* + * For licensing information and documentation, please refer to the cpu + * implementation located in "ATen/native/Math.h". + */ + +template +T chbevl(T x, const float array[], const int len) { + T b0, b1, b2; + + b0 = array[0]; + b1 = 0; + + for (int i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return T{0.5} * (b0 - b2); +} + +template +T i0(T _x) { + auto x = fabs(_x); + + if (x <= 8.0) { + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + const float A[] = {-4.41534164647933937950E-18, 3.33079451882223809783E-17, + -2.43127984654795469359E-16, 1.71539128555513303061E-15, + -1.16853328779934516808E-14, 7.67618549860493561688E-14, + -4.85644678311192946090E-13, 2.95505266312963983461E-12, + -1.72682629144155570723E-11, 9.67580903537323691224E-11, + -5.18979560163526290666E-10, 2.65982372468238665035E-9, + -1.30002500998624804212E-8, 6.04699502254191894932E-8, + -2.67079385394061173391E-7, 1.11738753912010371815E-6, + -4.41673835845875056359E-6, 1.64484480707288970893E-5, + -5.75419501008210370398E-5, 1.88502885095841655729E-4, + -5.76375574538582365885E-4, 1.63947561694133579842E-3, + -4.32430999505057594430E-3, 1.05464603945949983183E-2, + -2.37374148058994688156E-2, 4.93052842396707084878E-2, + -9.49010970480476444210E-2, 1.71620901522208775349E-1, + -3.04682672343198398683E-1, 6.76795274409476084995E-1}; + + auto y = (x / 2.0) - 2.0; + return static_cast(exp(x) * chbevl(y, A, 30)); + } + + // Handles x > 8 case + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). + */ + const float B[] = {-7.23318048787475395456E-18, -4.83050448594418207126E-18, + 4.46562142029675999901E-17, 3.46122286769746109310E-17, + -2.82762398051658348494E-16, -3.42548561967721913462E-16, + 1.77256013305652638360E-15, 3.81168066935262242075E-15, + -9.55484669882830764870E-15, -4.15056934728722208663E-14, + 1.54008621752140982691E-14, 3.85277838274214270114E-13, + 7.18012445138366623367E-13, -1.79417853150680611778E-12, + -1.32158118404477131188E-11, -3.14991652796324136454E-11, + 1.18891471078464383424E-11, 4.94060238822496958910E-10, + 3.39623202570838634515E-9, 2.26666899049817806459E-8, + 2.04891858946906374183E-7, 2.89137052083475648297E-6, + 6.88975834691682398426E-5, 3.36911647825569408990E-3, + 8.04490411014108831608E-1}; + + return static_cast((exp(x) * chbevl(32.0 / x - 2.0, B, 25)) / sqrt(x)); +} + +template +void kernel +i0(constant T* input, + device Tout* output, + uint index [[thread_position_in_grid]]) { + output[index] = i0(static_cast(input[index])); +} + +#define REGISTER_I0(DTI,DTO) \ +template [[host_name("i0_" #DTI "_" #DTO )]] \ +void kernel i0(constant DTI*, device DTO*, uint) + +REGISTER_I0(float, float); +REGISTER_I0(bool, float); +REGISTER_I0(uchar, float); +REGISTER_I0(char, float); +REGISTER_I0(short, float); +REGISTER_I0(int, float); +REGISTER_I0(long, float); + +REGISTER_I0(half, half); +#if __METAL_VERSION__ >= 310 +REGISTER_I0(bfloat, bfloat); +#endif +)SPECIAL_METAL"); + +static void i0_kernel_mps(TensorIteratorBase& iter) { + using namespace mps; + TORCH_INTERNAL_ASSERT(iter.ntensors() == 2); + auto input = iter.input(); + auto output = iter.output(); + bool needs_copy = !output.is_contiguous(); + if (!input.is_contiguous()) { + input = input.contiguous(); + } + if (needs_copy) { + output = output.contiguous(); + } + auto i0PSO = lib.getPipelineStateForFunc( + fmt::format("i0_{}_{}", scalarToMetalTypeString(input), scalarToMetalTypeString(output))); + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:i0PSO]; + mtl_setBuffer(computeEncoder, input, 0); + mtl_setBuffer(computeEncoder, output, 1); + mtl_dispatch1DJob(computeEncoder, i0PSO, output.numel()); + } + }); + if (needs_copy) { + iter.output().copy_(output); + } +} + +REGISTER_DISPATCH(i0_stub, &i0_kernel_mps); +} // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm index dcea978655b85..138f001dabbe5 100644 --- a/aten/src/ATen/native/mps/operations/TriangularOps.mm +++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm @@ -1,13 +1,18 @@ // Copyright © 2022 Apple Inc. +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +#include #include +#include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else +#include #include +#include #include #endif @@ -15,6 +20,220 @@ namespace at::native { +static mps::MetalShaderLibrary lib(R"TRI_METAL( +#include +using namespace metal; +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +// To find the max integer that does not exceed the root of an int64_t variable, +// we could use a loop to test one bit at a time, which takes up to 31 +// iterations. This would give the accurate result, but is relatively slow and +// is an overkill for most cases where double's precision suffice. +// +// If we directly use sqrt to calculate the root, the conversion from int64_t +// to double would lose 11 bits precision. +// +// The following solution uses sqrt directly for most cases, and would only +// special handle it if there is indeed precision loss. +inline int64_t resolve_root_int( + int64_t b, int64_t cX4, int64_t x, int32_t sign) { + int64_t bXb_cX4 = b*b - cX4; + // precision loss could occur here when casting int64_t (63 bits + // precision) to float (23 bits precision) + float sr = sqrt((float)bXb_cX4); + int64_t res = floor((-b + sign * sr)/2); + + // have to cast double to int64_t, otherwise it would only compare up to the + // precision of a double variable, ignoring the precision loss + if (bXb_cX4 != (int64_t) (sr * sr)) { + // handle precision loss by using binary search + int64_t llsr = floor(sr); + // Use the following math to reduce search space. + // Suppose z is the accurate result of sqrt(bXb_cX4) without precision loss + // let d = abs(bXb_cX4 - llsr * llsr), then we have: + // z = sqrt(bXb_cX4) <= sqrt(llsr * llsr + d) <= llsr + sqrt(d) + // z = sqrt(bXb_cX4) >= sqrt(llsr * llsr - d) >= llsr - sqrt(d) + // Hence, it is sufficient to search range [llsr - sqrt(d), llsr + sqrt(d)). + // And the true value of row would also be with in range, + // [res - sqrt(d), res + sqrt(d) + 1) + // as the denominator would only reduce the precision penalty. + int64_t diff = ceil(sqrt(abs((float)(bXb_cX4 - llsr * llsr)))); + // l never exceeds (could equal to) the target row index + auto l = res > diff ? res - diff : 0; + // r is always larger than the target row index + auto r = res + diff + 1; + + // binary search for the correct answer + x <<= 1; // the loop always compares with 2x, so do it once here + while (l + 1 < r) { + auto m = (l + r) >> 1; + // for tril: + // b = 2f - 1, sign = 1, hence (2f + m - 1) * m / 2 + // for triu: + // b = -2f - 1, sign = -1, hence (2f - m + 1) * m / 2 + if (sign * (b + m) * m > x) { + r = m; + } else { + l = m; + } + } + res = l; + } + + return res; +} + +// f: the number of elements in the first row of the trapezoid. +// x: the index of the target coordinates ordered by row and then column. +// +// View the tril as a top trapezoid stacked on a bottom rectangle. Assume x +// corresponds to the coordinate (row, col) in the trapezoid, where the row and +// the col both start from 0, then we have: +// +// (f + f + row - 1) * row / 2 <= x [1] +// (f + f + row) * (row + 1) / 2 > x [2] +// +// Therefore, row is the maximum integer satisfying the following inequality: +// +// (row + 2f - 1)row <= 2x +// row^2 + (2f-1)row - 2x <= 0. [3] +// +// Based on inequality [3], we have the following coefficients for formula of +// root: +// a = 1 +// b = 2f - 1 +// c = -2x +// There are two roots, and we should use the largest integer that does not +// exceed the root on the right. Intuitively, it is because: +// i) the valid solution range of row is between two roots, as it is <= 0; +// ii) as we count in more rows, the total # of elements should always +// increase, hence so does the left-hand side row^2 + (2f-1)row - 2x. +// Therefore, the valid range of row lies in between the nadir point and +// the larger root on the right. +// Full proof can be derived from inequality [2]. So, we calculate the result +// coordinate as: +// +// row = floor((-b + sqrt(b^2 - 4c)) / 2) +// col = x - (f + f + row - 1) * row / 2 +inline void get_coordinate_in_tril_trapezoid( + int64_t f, int64_t x, thread int64_t & row, thread int64_t & col) { + f <<= 1; // all statements use 2f, so only calculate it once here. + auto b = f - 1; + auto cX4 = - (x << 3); // 4 * c = 4 * (-2x) = -8x; + row = resolve_root_int(b, cX4, x, 1); + col = x - ((f + row - 1) * row >> 1); +} + +// f: the number of elements in the first row of the bottom trapezoid. +// x: the index of the target coordinates ordered by row and then column. +// +// View the triu as a top rectangle stacked on a bottom trapezoid, where the +// trapezoid is upside down. Assume x corresponds to the coordinate (row, col) +// in the bottom trapezoid, where the row and the col start from 0, then we +// have: +// +// (f + f - row + 1) * row / 2 <= x [1] +// (f + f - row) * (row + 1) / 2 > x [2] +// +// Therefore, row is the maximum integer satisfying the following inequality: +// +// (-row + 2f + 1)row <= 2x +// row^2 - (2f+1)row + 2x >= 0. [3] +// +// Based on inequality [3], we have the following coefficients for formula of +// root: +// a = 1 +// b = -1 - 2f +// c = 2x +// There are two roots, and we should use the largest integer that does not +// exceed the root on the left. Intuitively, it is because: +// i) the valid solution range of row is outside of the two roots, as it is < +// > 0; +// ii) as we count in more rows, the total # of elements should always +// increase, hence so does the left-hand side row^2 - (2f+1)row + 2x. +// Therefore, the valid range of row lies to the left of the smaller root +// on the left. +// Full proof can be derived from inequality [2]. So, we calculate the result +// coordinate as: +// +// row = floor((-b - sqrt(b^2 - 4c)) / 2) +// col = x - (f + f - row + 1) * row / 2 +inline void get_coordinate_in_triu_trapezoid( + int64_t f, int64_t x, thread int64_t & row, thread int64_t & col) { + f <<= 1; // all statements use 2f, so only calculate it once here. + auto b = -1 - f; + auto cX4 = x << 3; // 4 * c = 4 * (2x) = 8x; + row = resolve_root_int(b, cX4, x, -1); + col = x - ((f - row + 1) * row >> 1) + row; +} + +template +kernel void tril_indices(device scalar_t * tensor, + constant int64_t& row_offset, + constant int64_t& m_first_row, + constant int64_t& col, + constant int64_t& trapezoid_size, + constant int64_t& tril_size, + uint linear_index [[thread_position_in_grid]]) { + int64_t r, c; + if (linear_index < trapezoid_size) { + // the coordinate is within the top trapezoid + get_coordinate_in_tril_trapezoid(m_first_row, linear_index, r, c); + } else { + // the coordinate falls in the bottom rectangle + auto surplus = linear_index - trapezoid_size; + // add the height of trapezoid: m_last_row (col) - m_first_row + 1 + r = surplus / col + col - m_first_row + 1; + c = surplus % col; + } + r += row_offset; + + tensor[linear_index] = r; + tensor[linear_index + tril_size] = c; +} + +template +kernel void triu_indices(device scalar_t * tensor, + constant int64_t& col_offset, + constant int64_t& m_first_row, + constant int64_t& col, + constant int64_t& rectangle_size, + constant int64_t& triu_size, + uint linear_index [[thread_position_in_grid]]) { + int64_t r, c; + if (linear_index < rectangle_size) { + // the coordinate is within the top rectangle + r = linear_index / col; + c = linear_index % col; + } else { + // the coordinate falls in the bottom trapezoid + get_coordinate_in_triu_trapezoid( + m_first_row, linear_index - rectangle_size, r, c); + r += rectangle_size / col; + } + + c += col_offset; + tensor[linear_index] = r; + tensor[linear_index + triu_size] = c; +} + +#define INSTANTIATE_TRI_INDICES(NAME, DTYPE) \ + template [[host_name(#NAME "_indices_" #DTYPE)]] kernel void \ + NAME ## _indices( \ + device DTYPE * tensor, \ + constant int64_t& col_offset, \ + constant int64_t& m_first_row, \ + constant int64_t& col, \ + constant int64_t& rectangle_size, \ + constant int64_t& triu_size, \ + uint linear_index [[thread_position_in_grid]]) + +INSTANTIATE_TRI_INDICES(triu, long); +INSTANTIATE_TRI_INDICES(triu, int); +INSTANTIATE_TRI_INDICES(tril, long); +INSTANTIATE_TRI_INDICES(tril, int); +)TRI_METAL"); + TORCH_IMPL_FUNC(triu_mps_out) (const Tensor& self, int64_t k, const Tensor& output) { using namespace mps; @@ -111,4 +330,88 @@ } } +Tensor tril_indices_mps(int64_t row, + int64_t col, + int64_t offset, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + check_args(row, col, layout_opt); + + auto tril_size = get_tril_size(row, col, offset); + auto tensor = at::detail::empty_mps({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt, std::nullopt); + if (tril_size <= 0) { + return tensor; + } + auto m_first_row = offset > 0 ? std::min(col, 1 + offset) : // upper bounded by col + row + offset > 0; // either 0 or 1 + auto trapezoid_row_offset = std::max(0, -offset); + auto rectangle_row_offset = trapezoid_row_offset + col - m_first_row + 1; + int64_t rectangle_size = 0; + if (rectangle_row_offset < row) { + rectangle_size = (row - rectangle_row_offset) * col; + } + using namespace mps; + auto trilPSO = lib.getPipelineStateForFunc("tril_indices_" + scalarToMetalTypeString(tensor)); + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:trilPSO]; + mtl_setBuffer(computeEncoder, tensor, 0); + mtl_setBytes(computeEncoder, trapezoid_row_offset, 1); + mtl_setBytes(computeEncoder, m_first_row, 2); + mtl_setBytes(computeEncoder, col, 3); + mtl_setBytes(computeEncoder, tril_size - rectangle_size, 4); + mtl_setBytes(computeEncoder, tril_size, 5); + mtl_dispatch1DJob(computeEncoder, trilPSO, tril_size); + } + }); + + return tensor; +} + +Tensor triu_indices_mps(int64_t row, + int64_t col, + int64_t offset, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + check_args(row, col, layout_opt); + + auto triu_size = row * col - get_tril_size(row, col, offset - 1); + auto tensor = at::detail::empty_mps({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt, std::nullopt); + if (triu_size <= 0) { + return tensor; + } + // # of triu elements in the first row + auto m_first_row = offset > 0 ? std::max(col - offset, 0) : // upper bounded by col + col; + + // size of the top rectangle + int64_t rectangle_size = 0; + if (offset < 0) { + rectangle_size = std::min(row, -offset) * col; + } + using namespace mps; + auto triuPSO = lib.getPipelineStateForFunc("triu_indices_" + scalarToMetalTypeString(tensor)); + auto stream = getCurrentMPSStream(); + dispatch_sync_with_rethrow(stream->queue(), ^() { + @autoreleasepool { + auto computeEncoder = stream->commandEncoder(); + [computeEncoder setComputePipelineState:triuPSO]; + mtl_setBuffer(computeEncoder, tensor, 0); + mtl_setBytes(computeEncoder, std::max(0, offset), 1); + mtl_setBytes(computeEncoder, m_first_row, 2); + mtl_setBytes(computeEncoder, col, 3); + mtl_setBytes(computeEncoder, rectangle_size, 4); + mtl_setBytes(computeEncoder, triu_size, 5); + mtl_dispatch1DJob(computeEncoder, triuPSO, triu_size); + } + }); + + return tensor; +} } // namespace at::native diff --git a/aten/src/ATen/native/mps/operations/UnaryOps.mm b/aten/src/ATen/native/mps/operations/UnaryOps.mm index 334a056cddfb5..4326481f44526 100644 --- a/aten/src/ATen/native/mps/operations/UnaryOps.mm +++ b/aten/src/ATen/native/mps/operations/UnaryOps.mm @@ -225,11 +225,6 @@ static void unary_op(const Tensor& self, CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(exp2_out_mps, exponentBase2) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(reciprocal_out_mps, reciprocal) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(sqrt_out_mps, squareRoot) -#ifdef __MAC_15_0 -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(rsqrt_out_mps, reciprocalSquareRoot) -#else -CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(rsqrt_out_mps, reverseSquareRoot) -#endif CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(neg_out_mps, negative) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(log_out_mps, logarithm) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(log10_out_mps, logarithmBase10) @@ -247,6 +242,19 @@ static void unary_op(const Tensor& self, CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(acosh_out_mps, acosh) CREATE_MPS_STRUCTURED_UNARY_TORCH_IMPL_FUNC(atanh_out_mps, atanh) +TORCH_IMPL_FUNC(rsqrt_out_mps)(const Tensor& self, const Tensor& output) { + mps::unary_op(self, output, "rsqrt_out_mps", ^MPSGraphTensor*(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) { +#ifdef __MAC_15_0 + if (is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS)) { + return [mpsGraph reciprocalSquareRootWithTensor:inputTensor name:nil]; + } +#endif // __MAC_15_0 + C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-declarations") + return [mpsGraph reverseSquareRootWithTensor:inputTensor name:nil]; + C10_DIAGNOSTIC_POP() + }); +} + Tensor& abs_out_mps(const Tensor& self, Tensor& output) { using namespace mps; diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm index 3c1f62f6f2277..db10c56ad6b0c 100644 --- a/aten/src/ATen/native/mps/operations/UpSample.mm +++ b/aten/src/ATen/native/mps/operations/UpSample.mm @@ -76,7 +76,7 @@ static void upsample_out_template(const Tensor& input, centerResults = true; nearestRoundingMode = MPSGraphResizeNearestRoundingModeRoundPreferCeil; } else { - AT_ERROR("Unsupported resize mode ", resize_mode_str); + TORCH_CHECK(false, "Unsupported resize mode ", resize_mode_str); } const int64_t output_width = output_size.size() > 1 ? output_size[1] : output_size[0]; diff --git a/aten/src/ATen/native/mps/operations/View.mm b/aten/src/ATen/native/mps/operations/View.mm index ba9d439b2e64d..66646113f3a89 100644 --- a/aten/src/ATen/native/mps/operations/View.mm +++ b/aten/src/ATen/native/mps/operations/View.mm @@ -435,7 +435,7 @@ return outputTensor; } -static std::vector getViewShape(const Tensor& src, MPSShape* mpsShape, const bool squeeze) { +static std::vector getViewShape(const TensorBase& src, MPSShape* mpsShape, const bool squeeze) { bool hasMPSShape = (mpsShape != nil); std::vector src_view_shape; if (hasMPSShape) { @@ -481,7 +481,7 @@ return src_base_shape; } -bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape) { +bool canSliceViewTensor(const TensorBase& src, MPSShape* mpsShape) { if (!src.is_contiguous()) { return false; } @@ -503,7 +503,9 @@ bool canSliceViewTensor(const Tensor& src, MPSShape* mpsShape) { return true; } -MPSGraphTensorData* getMPSGraphTensorDataForView(const Tensor& src, MPSShape* mpsShape, const MPSDataType mpsDataType) { +MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src, + MPSShape* mpsShape, + const MPSDataType mpsDataType) { IntArrayRef src_base_shape = getIMPSAllocator()->getBufferShape(src.storage().data()); size_t src_ndim_base = src_base_shape.size(); std::vector src_view_shape = getViewShape(src, mpsShape, false); @@ -704,7 +706,7 @@ static IntArrayRef updateTensorBaseShape(const Tensor& self) { // Self is the input tensor we are creating view of newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, inputType, getMPSShape(base_shape)); newCachedGraph->storageOffsetTensor = mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ]); - for (const auto C10_UNUSED i : c10::irange(size.size())) { + for ([[maybe_unused]] const auto i : c10::irange(size.size())) { newCachedGraph->strideTensors.push_back(mpsGraphRankedPlaceHolder(mpsGraph, MPSDataTypeInt32, @[ @1 ])); } if (needsScatter) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8ccafbca1b0a9..3625cd8712496 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8776,12 +8776,14 @@ dispatch: CPU: tril_indices_cpu CUDA: tril_indices_cuda + MPS: tril_indices_mps autogen: tril_indices.out - func: triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor dispatch: CPU: triu_indices_cpu CUDA: triu_indices_cuda + MPS: triu_indices_mps autogen: triu_indices.out - func: trace(Tensor self) -> Tensor @@ -9579,7 +9581,7 @@ structured: True structured_inherits: TensorIteratorBase dispatch: - CPU, CUDA: i0_out + CPU, CUDA, MPS: i0_out tags: pointwise - func: sign(Tensor self) -> Tensor diff --git a/aten/src/ATen/native/nested/NestedTensorBackward.cpp b/aten/src/ATen/native/nested/NestedTensorBackward.cpp index 85c15b603e47d..5bd737259261d 100644 --- a/aten/src/ATen/native/nested/NestedTensorBackward.cpp +++ b/aten/src/ATen/native/nested/NestedTensorBackward.cpp @@ -9,8 +9,6 @@ #include #include #include -#include -#include #include #include diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp index 0b7138ec0ffaf..d8330287bb01f 100644 --- a/aten/src/ATen/native/nested/NestedTensorMath.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp @@ -752,7 +752,7 @@ inline std::tuple NestedTensor_compute_size_stride( } } else { - AT_ERROR("invalid shape dimension ", size_reshaped); + TORCH_CHECK(false, "invalid shape dimension ", size_reshaped); } } // See Note [Special size rule for nested tensor] diff --git a/aten/src/ATen/native/nested/NestedTensorMatmul.cpp b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp index 568f36d4cd01a..8e0a371ba784e 100644 --- a/aten/src/ATen/native/nested/NestedTensorMatmul.cpp +++ b/aten/src/ATen/native/nested/NestedTensorMatmul.cpp @@ -13,7 +13,6 @@ #include #include #include -#include namespace at::native { @@ -224,10 +223,10 @@ Tensor matmul_nested(const Tensor& self, const Tensor& mat2) { return matmul_nested_with_broadcasted_dense(self, mat2); } if (self.is_nested() && !mat2.is_nested()) { - AT_ERROR( + TORCH_CHECK(false, "Expected both to be nested, but got a nested self and non-nested other"); } else if (!self.is_nested() && mat2.is_nested()) { - AT_ERROR( + TORCH_CHECK(false, "Expected both to be nested, but got a non-nested self and nested other"); } // to_padded_tensor only supports contiguous inputs diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.h b/aten/src/ATen/native/nested/NestedTensorUtils.h index 0dd89e74eaa14..b584a7319ff0b 100644 --- a/aten/src/ATen/native/nested/NestedTensorUtils.h +++ b/aten/src/ATen/native/nested/NestedTensorUtils.h @@ -22,6 +22,7 @@ #include #endif +#include #include #include diff --git a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp index 0fdf4709a1139..d559df9fbd10c 100644 --- a/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp +++ b/aten/src/ATen/native/nested/cuda/NestedTensorTransformerFunctions.cpp @@ -13,7 +13,6 @@ #include #endif -#include #include #include #include @@ -111,7 +110,7 @@ Tensor nested_from_padded_cuda( padded_contiguous.sizes()[0]); } } else { - AT_ERROR("Only support fp32/fp16 for padded input"); + TORCH_CHECK(false, "Only support fp32/fp16 for padded input"); } return at::detail::make_tensor(std::move(output), sizes); } else { diff --git a/aten/src/ATen/native/quantized/cpu/Normalization.cpp b/aten/src/ATen/native/quantized/cpu/Normalization.cpp index 846d712fbadc4..2fde5c954e782 100644 --- a/aten/src/ATen/native/quantized/cpu/Normalization.cpp +++ b/aten/src/ATen/native/quantized/cpu/Normalization.cpp @@ -389,7 +389,7 @@ Tensor quantized_batch_norm( // See [Note: hacky wrapper removal for optional tensor] c10::MaybeOwned weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt); const Tensor& weight = *weight_maybe_owned; - const Tensor& bias = c10::value_or_else(bias_opt, [] {return Tensor();}); + const Tensor& bias = bias_opt.value_or(Tensor()); Tensor qy; // TODO: this should arguably support 3d as well diff --git a/aten/src/ATen/native/quantized/cpu/Pooling.cpp b/aten/src/ATen/native/quantized/cpu/Pooling.cpp index 47351d3a5902e..69f9d1283a4c5 100644 --- a/aten/src/ATen/native/quantized/cpu/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cpu/Pooling.cpp @@ -478,6 +478,8 @@ void check_maxpool2d_params( "Expected 1d or 2d padding, got ", padding.size()); TORCH_CHECK(dilation.size() == 1 || dilation.size() == 2, "Expected 1d or 2d dilation, got ", dilation.size()); + TORCH_CHECK(dilation.allMatch([](const auto& ele) { return ele >= 1L; }), + "Expected dilation >= 1"); } void check_maxpool3d_params( @@ -490,6 +492,8 @@ void check_maxpool3d_params( "Expected no strides or 3d strides, got", stride.size()); TORCH_CHECK(padding.size() == 3, "Expected 3d padding, got ", padding.size()); TORCH_CHECK(dilation.size() == 3, "Expected 1d or 3d dilation, got ", dilation.size()); + TORCH_CHECK(dilation.allMatch([](const auto& ele) { return ele >= 1L; }), + "Expected dilation >= 1"); } #ifdef USE_PYTORCH_QNNPACK diff --git a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h index b217c757740b3..a06e6b672ec32 100644 --- a/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QnnpackUtils.h @@ -444,7 +444,7 @@ Tensor qnnpack_avg_pool2d( } // namespace at namespace { -C10_UNUSED std::vector generate_requantization_scales( +[[maybe_unused]] std::vector generate_requantization_scales( const at::Tensor& weight_scales, const float input_scale, const float output_scale, @@ -468,11 +468,11 @@ C10_UNUSED std::vector generate_requantization_scales( return requant_scales; } -C10_UNUSED std::pair, at::Tensor> make_zero_points_and_scales_tensor( +[[maybe_unused]] std::pair, at::Tensor> +make_zero_points_and_scales_tensor( const at::Tensor& weight_contig, bool transpose = false, - uint32_t groups = 1 - ) { + uint32_t groups = 1) { const int out_ch_idx = transpose ? 1 : 0; const auto num_output_channels = weight_contig.size(out_ch_idx) * (transpose ? groups : 1); // Add 8 to account for bufferring needed by QNNPACK. diff --git a/aten/src/ATen/native/quantized/cpu/QuantUtils.h b/aten/src/ATen/native/quantized/cpu/QuantUtils.h index 0b026c739786a..e81b0d87916b2 100644 --- a/aten/src/ATen/native/quantized/cpu/QuantUtils.h +++ b/aten/src/ATen/native/quantized/cpu/QuantUtils.h @@ -186,8 +186,9 @@ inline TensorQuantizationParams ChooseQuantizationParams( // This function helps to convert the Conv1D dimensions usable by the Conv2d op. constexpr int64_t kConv1dSqueezeDim = 0; -static C10_UNUSED torch::List MakeArgForConv1d(const torch::List& arg, - int64_t base_value) { +[[maybe_unused]] static torch::List MakeArgForConv1d( + const torch::List& arg, + int64_t base_value) { TORCH_CHECK(!arg.empty(), "Argument must have elements."); torch::List result({arg.get(0), base_value}); if (arg.size() == 1) { diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp index 72079abd183f5..82c8a3f751dbf 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleBilinear2d.cpp @@ -124,7 +124,7 @@ static void upsample_bilinear2d_out_frame( const auto* pos1 = i_ptr + h1 * input_width + w1; - float result = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) + + const float result = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p]) + h1lambda * (w0lambda * pos1[h1p * input_width] + w1lambda * pos1[h1p * input_width + w1p]) - input_q_zero_point; diff --git a/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp b/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp index 02d23f46ba945..3195afa144eed 100644 --- a/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp +++ b/aten/src/ATen/native/quantized/cpu/UpSampleNearest3d.cpp @@ -71,7 +71,7 @@ static void upsample_nearest3d_out_frame( const auto* pos1 = &i_p[d1 * input_height * input_width + h1 * input_width + w1]; auto* pos2 = &o_p[d2 * output_height * output_width + h2 * output_width + w2]; - for (C10_UNUSED const auto c : c10::irange(channels)) { + for ([[maybe_unused]] const auto c : c10::irange(channels)) { pos2[0] = pos1[0]; pos1 += input_depth * input_height * input_width; pos2 += output_depth * output_height * output_width; diff --git a/aten/src/ATen/native/quantized/cpu/conv_serialization.h b/aten/src/ATen/native/quantized/cpu/conv_serialization.h index 9f2dfd26118ac..214447e20eaaa 100644 --- a/aten/src/ATen/native/quantized/cpu/conv_serialization.h +++ b/aten/src/ATen/native/quantized/cpu/conv_serialization.h @@ -143,7 +143,7 @@ ConvParamsSerializationTypeV3 parse_conv_serialized_state(c10::IValue v) { config_vals.push_back(dilation[0].item()); } // output_padding does not exist in v1, so we fill in a default value - for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { config_vals.push_back(0); } config_vals.push_back(groups[0].item()); @@ -294,21 +294,24 @@ c10::intrusive_ptr> deserialize_conv( torch::List stride, padding, output_padding, dilation; // skip kSpatialDim int idx = 1; - for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { stride.emplace_back(config_vals.at(idx)); idx++; } - for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { padding.emplace_back(config_vals.at(idx)); idx++; } - for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { dilation.emplace_back(config_vals.at(idx)); idx++; } - for (C10_UNUSED const auto i : c10::irange(kSpatialDim)) { - TORCH_INTERNAL_ASSERT(idx < static_cast(config_vals.size()), - "Unexpected index = ", idx, " for config_vals of size ", + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { + TORCH_INTERNAL_ASSERT( + idx < static_cast(config_vals.size()), + "Unexpected index = ", + idx, + " for config_vals of size ", config_vals.size()); output_padding.emplace_back(config_vals.at(idx)); idx++; diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index d6ac157a116b5..abe403dc25508 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -554,9 +554,9 @@ int register_embedding_params() { namespace { -static C10_UNUSED auto conv2d_params = register_conv_params<2>(); -static C10_UNUSED auto conv3d_params = register_conv_params<3>(); -static C10_UNUSED auto linear_params = register_linear_params(); -static C10_UNUSED auto embedding_params = register_embedding_params(); +[[maybe_unused]] static auto conv2d_params = register_conv_params<2>(); +[[maybe_unused]] static auto conv3d_params = register_conv_params<3>(); +[[maybe_unused]] static auto linear_params = register_linear_params(); +[[maybe_unused]] static auto embedding_params = register_embedding_params(); } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp index 3fb49bbd8285e..8f8745c32b8d0 100644 --- a/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp +++ b/aten/src/ATen/native/quantized/cpu/kernels/QuantizedOpKernels.cpp @@ -15,7 +15,6 @@ #include #include #include -#include #ifndef AT_PER_OPERATOR_HEADERS #include @@ -2294,7 +2293,7 @@ void qupsample_bilinear2d_nhwc_kernel( int64_t b{0}, h2{0}, w2{0}; data_index_init(begin, b, nbatch, h2, output_height, w2, output_width); - for (C10_UNUSED const auto i : c10::irange(begin, end)) { + for ([[maybe_unused]] const auto i : c10::irange(begin, end)) { auto* i_p = reinterpret_cast( idata + b * input_height * input_width * channels); auto* o_p = reinterpret_cast( @@ -3819,8 +3818,8 @@ void quantize_tensor_per_channel_impl( // channels_last contig. // If axis = 0 and channels_last contig, implementation for channels // first (NCHW) works. - for (C10_UNUSED const auto b : c10::irange(batches)) { - for (C10_UNUSED const auto e : c10::irange(elements_per_channel)) { + for ([[maybe_unused]] const auto b : c10::irange(batches)) { + for ([[maybe_unused]] const auto e : c10::irange(elements_per_channel)) { uint32_t c = 0; while (c + 8 < channels) { const int32x4_t voffset0123 = vld1q_s32(&zero_points_int32t[c]); @@ -3854,7 +3853,7 @@ void quantize_tensor_per_channel_impl( } } } else { - for (C10_UNUSED const auto b : c10::irange(batches)) { + for ([[maybe_unused]] const auto b : c10::irange(batches)) { for (const auto c : c10::irange(channels)) { uint32_t e = 0; const int32x4_t voffset = vdupq_n_s32(zero_points_int32t[c]); @@ -3901,8 +3900,8 @@ void quantize_tensor_per_channel_impl( // channels_last contig. // If axis = 0 and channels_last contig, implementation for channels // first (NCHW) works. - for (const auto b C10_UNUSED : c10::irange(batches)) { - for (const auto e C10_UNUSED : c10::irange(elements_per_channel)) { + for ([[maybe_unused]] const auto b : c10::irange(batches)) { + for ([[maybe_unused]] const auto e : c10::irange(elements_per_channel)) { uint32_t c = 0; while (c + 8 < channels) { const int16x8_t vzero_point = vld1q_s16(&zero_points_int16t[c]); @@ -3932,8 +3931,8 @@ void quantize_tensor_per_channel_impl( } } } else { - for (const auto b C10_UNUSED : c10::irange(batches)) { - for (const auto c C10_UNUSED : c10::irange(channels)) { + for ([[maybe_unused]] const auto b : c10::irange(batches)) { + for ([[maybe_unused]] const auto c : c10::irange(channels)) { uint32_t e = 0; const int16x8_t vzero_point = vdupq_n_s16(zero_points_int16t[c]); const float32x4_t vinv_scale = vdupq_n_f32(inv_scales[c]); diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index 6895821fc0b53..098547cbb3ac8 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -342,6 +342,9 @@ c10::intrusive_ptr> PackedConvWeightsOnednn< stride.size() == kSpatialDim, "stride should contain ", kSpatialDim, " elements for ", kSpatialDim, "D convolution."); + TORCH_CHECK( + std::all_of(stride.begin(), stride.end(), [](bool s) { return s > 0; }), + "quantized::conv_prepack: stride should be positive."); TORCH_CHECK( padding.size() == kSpatialDim, "Specify front/top/left padding only. " @@ -631,7 +634,7 @@ class QConvPackWeightInt8 final { int64_t groups) { torch::List output_padding; output_padding.reserve(kSpatialDim); - for (C10_UNUSED const auto idx : c10::irange(kSpatialDim)) { + for ([[maybe_unused]] const auto idx : c10::irange(kSpatialDim)) { output_padding.push_back((int64_t)0); } return _run(weight, bias, stride, padding, output_padding, dilation, groups, diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 26fe9cd2ac4cc..9f2cf186e03b3 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1103,6 +1104,73 @@ static at::Tensor linear_int8_with_onednn_weight( namespace at { namespace native { + + Tensor QLinearOnednn::run_pointwise_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view post_op_name, + torch::List> post_op_args, + c10::string_view post_op_algorithm) { +#if AT_MKLDNN_ENABLED() + TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1, + "onednn int8 linear: act scale/zp size should be 1"); + static std::optional other = std::nullopt; + static const c10::string_view binary_post_op = "none"; + return linear_int8_with_onednn_weight( + act, act_scale.item().toDouble(), act_zero_point.item().toLong(), + onednn_weight, weight_scales, weight_zero_points, + bias, output_scale, output_zero_point, output_dtype, + other, /*other scale*/1.0, /*other zp*/0, + binary_post_op, /*binary alpha*/1.0, + post_op_name, post_op_args, post_op_algorithm + ); +#endif + TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)"); + } + + Tensor QLinearOnednn::run_pointwise_binary_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional other, // extra input for binary post-op + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, // e.g. "none", "sum", "add" + double binary_alpha, + c10::string_view unary_post_op, // e.g. "none", "relu" + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm) { +#if AT_MKLDNN_ENABLED() + TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1, + "onednn int8 linear: act scale/zp size should be 1"); + return linear_int8_with_onednn_weight( + act, act_scale.item().toDouble(), act_zero_point.item().toLong(), + onednn_weight, weight_scales, weight_zero_points, + bias, output_scale, output_zero_point, output_dtype, + other, other_scale, other_zero_point, + binary_post_op, binary_alpha, + unary_post_op, unary_post_op_args, unary_post_op_algorithm + ); +#endif + TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)"); + } + + namespace { template @@ -1220,37 +1288,6 @@ class QLinearOnednn final { TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)"); } - static Tensor run_pointwise_tensor( - Tensor act, // int8 CPU tensor, not QTensor - Tensor act_scale, - Tensor act_zero_point, - Tensor onednn_weight, // int8 tensor from MkldnnCPU - Tensor weight_scales, - Tensor weight_zero_points, - std::optional bias, - double output_scale, - int64_t output_zero_point, - std::optional output_dtype, - c10::string_view post_op_name, - torch::List> post_op_args, - c10::string_view post_op_algorithm) { -#if AT_MKLDNN_ENABLED() - TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1, - "onednn int8 linear: act scale/zp size should be 1"); - static std::optional other = std::nullopt; - static const c10::string_view binary_post_op = "none"; - return linear_int8_with_onednn_weight( - act, act_scale.item().toDouble(), act_zero_point.item().toLong(), - onednn_weight, weight_scales, weight_zero_points, - bias, output_scale, output_zero_point, output_dtype, - other, /*other scale*/1.0, /*other zp*/0, - binary_post_op, /*binary alpha*/1.0, - post_op_name, post_op_args, post_op_algorithm - ); -#endif - TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)"); - } - static Tensor run_pointwise_binary( Tensor act, // int8 CPU tensor, not QTensor double act_scale, @@ -1279,40 +1316,6 @@ class QLinearOnednn final { binary_post_op, binary_alpha, unary_post_op, unary_post_op_args, unary_post_op_algorithm ); -#endif - TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)"); - } - - static Tensor run_pointwise_binary_tensor( - Tensor act, // int8 CPU tensor, not QTensor - Tensor act_scale, - Tensor act_zero_point, - Tensor onednn_weight, // int8 tensor from MkldnnCPU - Tensor weight_scales, - Tensor weight_zero_points, - std::optional other, // extra input for binary post-op - std::optional bias, - double output_scale, - int64_t output_zero_point, - std::optional output_dtype, - double other_scale, - int64_t other_zero_point, - c10::string_view binary_post_op, // e.g. "none", "sum", "add" - double binary_alpha, - c10::string_view unary_post_op, // e.g. "none", "relu" - torch::List> unary_post_op_args, - c10::string_view unary_post_op_algorithm) { -#if AT_MKLDNN_ENABLED() - TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1, - "onednn int8 linear: act scale/zp size should be 1"); - return linear_int8_with_onednn_weight( - act, act_scale.item().toDouble(), act_zero_point.item().toLong(), - onednn_weight, weight_scales, weight_zero_points, - bias, output_scale, output_zero_point, output_dtype, - other, other_scale, other_zero_point, - binary_post_op, binary_alpha, - unary_post_op, unary_post_op_args, unary_post_op_algorithm - ); #endif TORCH_CHECK(false, "Unimplemented (int8 linear with packed weight and bias)"); } @@ -1340,11 +1343,11 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) { m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"), TORCH_FN(QLinearOnednn::run_pointwise)); m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"), - TORCH_FN(QLinearOnednn::run_pointwise_tensor)); + TORCH_FN(at::native::QLinearOnednn::run_pointwise_tensor)); m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"), TORCH_FN(QLinearOnednn::run_pointwise_binary)); m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"), - TORCH_FN(QLinearOnednn::run_pointwise_binary_tensor)); + TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor)); } } // namespace diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.h b/aten/src/ATen/native/quantized/cpu/qlinear.h new file mode 100644 index 0000000000000..bc1db01a741c2 --- /dev/null +++ b/aten/src/ATen/native/quantized/cpu/qlinear.h @@ -0,0 +1,47 @@ +#pragma once +#include +#include + +namespace at { +namespace native { + +class QLinearOnednn final { + public: + C10_API static Tensor run_pointwise_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + c10::string_view post_op_name, + torch::List> post_op_args, + c10::string_view post_op_algorithm); + +C10_API static Tensor run_pointwise_binary_tensor( + Tensor act, // int8 CPU tensor, not QTensor + Tensor act_scale, + Tensor act_zero_point, + Tensor onednn_weight, // int8 tensor from MkldnnCPU + Tensor weight_scales, + Tensor weight_zero_points, + std::optional other, // extra input for binary post-op + std::optional bias, + double output_scale, + int64_t output_zero_point, + std::optional output_dtype, + double other_scale, + int64_t other_zero_point, + c10::string_view binary_post_op, // e.g. "none", "sum", "add" + double binary_alpha, + c10::string_view unary_post_op, // e.g. "none", "relu" + torch::List> unary_post_op_args, + c10::string_view unary_post_op_algorithm); +}; + +} // namespace native +} // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt index fd6b7ff551db8..86897fe9f8d01 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/CMakeLists.txt @@ -12,7 +12,6 @@ include(GNUInstallDirs) project(PYTORCH_QNNPACK C CXX ASM) # ---[ Options. -option(PYTORCH_QNNPACK_CUSTOM_THREADPOOL "Build QNNPACK for custom thread pool" OFF) set(PYTORCH_QNNPACK_LIBRARY_TYPE "default" CACHE STRING "Type of library (shared, static, or default) to build") set_property(CACHE PYTORCH_QNNPACK_LIBRARY_TYPE PROPERTY STRINGS default static shared) option(PYTORCH_QNNPACK_BUILD_TESTS "Build QNNPACK unit tests" ON) @@ -373,13 +372,7 @@ elseif(NOT TARGET pthreadpool AND USE_SYSTEM_PTHREADPOOL) IMPORTED_LOCATION "${PTHREADPOOL_LIBRARY}") add_library(pthreadpool_interface INTERFACE) endif() -if(PYTORCH_QNNPACK_CUSTOM_THREADPOOL) - # Depend on pthreadpool interface, but not on implementation. - # This is used when QNNPACK user (e.g. Caffe2) provides its own threadpool implementation. - target_link_libraries(pytorch_qnnpack PUBLIC pthreadpool_interface) -else() - target_link_libraries(pytorch_qnnpack PUBLIC pthreadpool) -endif() +target_link_libraries(pytorch_qnnpack PUBLIC pthreadpool) # ---[ Configure FXdiv if(NOT TARGET fxdiv AND NOT USE_SYSTEM_FXDIV) diff --git a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp index 7083e309f0989..9103bdd0d4149 100644 --- a/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp +++ b/aten/src/ATen/native/quantized/cudnn/ConvPrepack.cpp @@ -139,7 +139,7 @@ class QConvPackWeightInt8Cudnn final { int64_t groups) { torch::List output_padding; output_padding.reserve(kSpatialDim); - for (C10_UNUSED const auto idx : c10::irange(kSpatialDim)) { + for ([[maybe_unused]] const auto idx : c10::irange(kSpatialDim)) { output_padding.push_back((int64_t)0); } return _run(weight, bias, stride, padding, output_padding, dilation, groups, diff --git a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp index 67ea9cf308d20..7e85ae9f468ee 100644 --- a/aten/src/ATen/native/quantized/cudnn/Pooling.cpp +++ b/aten/src/ATen/native/quantized/cudnn/Pooling.cpp @@ -64,7 +64,7 @@ Tensor adaptive_avg_pool2d_quantized_cuda( auto result_fp32 = at::adaptive_avg_pool2d(input_fp32, output_size); return at::quantize_per_tensor(result_fp32, input.q_scale(), input.q_zero_point(), input.scalar_type()); #else // USE_CUDA - AT_ERROR("at::native::adaptive_avg_pool2d_quantized_cuda: ATen not compiled with USE_CUDA support"); + TORCH_CHECK(false, "at::native::adaptive_avg_pool2d_quantized_cuda: ATen not compiled with USE_CUDA support"); return Tensor{}; // never reached, placates the compiler #endif } @@ -209,11 +209,11 @@ Tensor quantized_max_pool2d_cudnn( // recall we casted our input and output to 4D if qx was 3D, so we recast it back to 3D prior to returning return (ndim == 3 ? qy.view(std::vector(output_shape.begin() + 1, output_shape.end())) : qy); #else // AT_CUDNN_ENABLED() - AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN support"); + TORCH_CHECK(false, "at::native::quantized_max_pool2d_cudnn: ATen not compiled with cuDNN support"); return Tensor{}; // never reached, placates the compiler #endif // AT_CUDNN_ENABLED() #else // USE_CUDA - AT_ERROR("at::native::quantized_max_pool2d_cudnn: ATen not compiled with USE_CUDA support"); + TORCH_CHECK(false, "at::native::quantized_max_pool2d_cudnn: ATen not compiled with USE_CUDA support"); return Tensor{}; // never reached, placates the compiler #endif } diff --git a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp index 121307c9bbc66..e9552802082d8 100644 --- a/aten/src/ATen/native/sparse/SparseCsrTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseCsrTensor.cpp @@ -459,7 +459,7 @@ Tensor _sparse_compressed_tensor_unsafe_symint( std::optional device, std::optional pin_memory) { if (!layout) { - AT_ERROR("sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none"); + TORCH_CHECK(false, "sparse_compressed_tensor_unsafe expected sparse compressed tensor layout but got none"); } Layout layout_ = layout.value(); AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor_unsafe", [&]{}); @@ -587,7 +587,7 @@ Tensor sparse_compressed_tensor( std::optional pin_memory) { if (!layout) { - AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none"); + TORCH_CHECK(false, "sparse_compressed_tensor expected sparse compressed tensor layout but got none"); } Layout layout_ = layout.value(); AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{}); @@ -616,7 +616,7 @@ Tensor sparse_compressed_tensor( std::optional pin_memory) { if (!layout) { - AT_ERROR("sparse_compressed_tensor expected sparse compressed tensor layout but got none"); + TORCH_CHECK(false, "sparse_compressed_tensor expected sparse compressed tensor layout but got none"); } Layout layout_ = layout.value(); AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(layout_, "sparse_compressed_tensor", [&]{}); diff --git a/aten/src/ATen/native/sparse/SparseMatMul.cpp b/aten/src/ATen/native/sparse/SparseMatMul.cpp index 39e2d82287503..a480fa2b3c7a7 100644 --- a/aten/src/ATen/native/sparse/SparseMatMul.cpp +++ b/aten/src/ATen/native/sparse/SparseMatMul.cpp @@ -159,8 +159,7 @@ void _csr_matmult( } } - for (C10_UNUSED const auto jj : c10::irange(length)) { - + for ([[maybe_unused]] const auto jj : c10::irange(length)) { // NOTE: the linked list that encodes col indices // is not guaranteed to be sorted. Cj[nnz] = head; diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index de8ee97a77627..075a4a4e4bd32 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -588,7 +588,7 @@ SparseTensor& copy_sparse_wrapper_( { NoNamesGuard guard; if (!self.is_sparse() || !src.is_sparse()) { - AT_ERROR( + TORCH_CHECK(false, "copy_() between dense and sparse Tensors is not implemented! Found self type = ", self.toString(), " and src type = ", diff --git a/aten/src/ATen/native/sparse/SparseTensorMath.cpp b/aten/src/ATen/native/sparse/SparseTensorMath.cpp index 387dcb465d394..45ff374c4736f 100644 --- a/aten/src/ATen/native/sparse/SparseTensorMath.cpp +++ b/aten/src/ATen/native/sparse/SparseTensorMath.cpp @@ -1224,9 +1224,9 @@ void s_addmm_out_sparse_dense_worker(int64_t nnz, int64_t dim_i, int64_t dim_j, r_ptr + row * r_stride0, r_stride1); } else { if (col < 0 || col >= dim_j) { - AT_ERROR("addmm: index out of column bound: ", col, " not between 1 and ", dim_j); + TORCH_CHECK(false, "addmm: index out of column bound: ", col, " not between 1 and ", dim_j); } else { - AT_ERROR("addmm: index out of row bound: ", row, " not between 1 and ", dim_i); + TORCH_CHECK(false, "addmm: index out of row bound: ", row, " not between 1 and ", dim_i); } } } @@ -1577,7 +1577,7 @@ SparseTensor& _sspaddmm_out_cpu( dense_ptr + col * dense_stride0, dense_stride1, newv_ptr + p * newv_stride0, 1); } else { - AT_ERROR("index out of bound. sspmm: ", col, " not between 1 and ", dim_j); + TORCH_CHECK(false, "index out of bound. sspmm: ", col, " not between 1 and ", dim_j); } } // Fill up the indices with the right values @@ -1602,7 +1602,7 @@ SparseTensor& _sspaddmm_out_cpu( // sparse, sparse, sparse, dense, real, real -> sparse Tensor& _sspaddmm_out_only_sparse(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha, Tensor& result) { - AT_ERROR("tensor.sspaddmm(...) can only be called on sparse tensors"); + TORCH_CHECK(false, "tensor.sspaddmm(...) can only be called on sparse tensors"); } // sparse, dense -> sparse diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp index d5f6540976773..faa39af82c7e3 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDABlas.cpp @@ -88,7 +88,7 @@ cusparseOperation_t convertTransToCusparseOperation(char trans) { else if (trans == 'n') return CUSPARSE_OPERATION_NON_TRANSPOSE; else if (trans == 'c') return CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE; else { - AT_ERROR("trans must be one of: t, n, c"); + TORCH_CHECK(false, "trans must be one of: t, n, c"); } } diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu index 8b4d6be5aaac7..925a33b0bbd8e 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredApplyDense.cu @@ -124,7 +124,7 @@ Tensor _sparse_semi_structured_apply_dense( const Tensor& threads_masks) { #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR("_sparse_semi_structured_apply_dense: not supported"); + TORCH_CHECK(false, "_sparse_semi_structured_apply_dense: not supported"); return Tensor{}; #else TORCH_CHECK( diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu index 01aa11dbdecb5..b8a54c01bea57 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredLinear.cu @@ -195,7 +195,7 @@ Tensor two_four_sgemm( meta_dtype = at::kInt; break; default: - AT_ERROR("two_four_sgemm: invalid size of meta tensor datatype " + TORCH_CHECK(false, "two_four_sgemm: invalid size of meta tensor datatype " "encountered"); } TORCH_CHECK(meta.dtype() == meta_dtype, @@ -215,7 +215,7 @@ Tensor two_four_sgemm( } else if constexpr (std::is_same_v) { tensor_d_dtype = at::kFloat; } else { - AT_ERROR("two_four_sgemm: invalid datatype for sparse GEMM output ", + TORCH_CHECK(false, "two_four_sgemm: invalid datatype for sparse GEMM output ", "encountered"); } if constexpr (use_bias) { @@ -424,7 +424,7 @@ Tensor two_four_sgemm_dispatch_layouts( } } - AT_ERROR("two_four_sgemm_dispatch_layouts: Combination of ", + TORCH_CHECK(false, "two_four_sgemm_dispatch_layouts: Combination of ", tensor_a_row_major ? "row-major" : "column_major", " and ", tensor_b_row_major ? "row-major" : "column_major", " layouts for input tensors is not supported"); @@ -573,7 +573,7 @@ Tensor two_four_sgemm_dispatch_layouts_bias_activation( } } - AT_ERROR("two_four_sgemm_dispatch_layouts: Activation \"", activation, + TORCH_CHECK(false, "two_four_sgemm_dispatch_layouts: Activation \"", activation, "\" is not supported for given input tensors"); return Tensor{}; } @@ -608,7 +608,7 @@ Tensor _sparse_semi_structured_linear( "_sparse_semi_structured_mm/_sparse_semi_structured_addmm " "instead."); #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR("_sparse_semi_structured_linear: CUTLASS not supported"); + TORCH_CHECK(false, "_sparse_semi_structured_linear: CUTLASS not supported"); return Tensor{}; #else // No need to check that all tensors are on CUDA device, as this diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu index abd6cf9739c63..9f8fc2ca5a160 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredOps.cu @@ -187,7 +187,7 @@ void spgemm_cutlass( tensor_e_dtype = at::kInt; break; default: - AT_ERROR(__func__, ": invalid size of meta tensor datatype " + TORCH_CHECK(false, __func__, ": invalid size of meta tensor datatype " "encountered"); } TORCH_CHECK(tensor_e.dtype() == tensor_e_dtype, @@ -424,7 +424,7 @@ void spgemm_cutlass_dispatch_layouts( } } - AT_ERROR(__func__, "_dispatch_layouts: Combination of ", + TORCH_CHECK(false, __func__, "_dispatch_layouts: Combination of ", tensor_a_row_major ? "row-major" : "column_major", " and ", tensor_b_row_major ? "row-major" : "column_major", " layouts for input tensors is not supported"); @@ -525,7 +525,7 @@ Tensor sparse_semi_structured_mad_op( const std::optional& input_opt, const Scalar& alpha, const Scalar& beta, const std::optional out_dtype_opt) { #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR(__func__, " : CUTLASS not supported"); + TORCH_CHECK(false, __func__, " : CUTLASS not supported"); return Tensor{}; #else // No need to check that all tensors are on CUDA device, as this @@ -846,7 +846,7 @@ static void reorder_meta(cutlass::TensorRef dest, std::tuple _to_sparse_semi_structured(const Tensor& dense) { #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR(__func__, " : CUTLASS not supported"); + TORCH_CHECK(false, __func__, " : CUTLASS not supported"); return std::make_tuple(Tensor{}, Tensor{}); #else // Check dimensions of the dense matrix. @@ -871,7 +871,7 @@ _to_sparse_semi_structured(const Tensor& dense) { ksparse = 2; dense_elems_per_meta_elem = 8; } else { - AT_ERROR("_to_sparse_semi_structured: Invalid dense argument datatype ", + TORCH_CHECK(false, "_to_sparse_semi_structured: Invalid dense argument datatype ", dense.dtype(), " encountered"); } @@ -879,12 +879,12 @@ _to_sparse_semi_structured(const Tensor& dense) { const auto dense_ncols = dense.size(1); if (dense_nrows % (meta_dtype == at::kShort ? 32 : 16) != 0) { - AT_ERROR("_to_sparse_semi_structured: Number of rows of dense matrix must " + TORCH_CHECK(false, "_to_sparse_semi_structured: Number of rows of dense matrix must " "be divisible by ", (meta_dtype == at::kShort ? 32 : 16), ", but it is ", dense_nrows); } if (dense_ncols % dense_elems_per_meta_elem != 0) { - AT_ERROR("_to_sparse_semi_structured: Number of columns of dense matrix " + TORCH_CHECK(false, "_to_sparse_semi_structured: Number of columns of dense matrix " "must be divisible by ", dense_elems_per_meta_elem, ", but it is ", dense_ncols); } @@ -925,7 +925,7 @@ _to_sparse_semi_structured(const Tensor& dense) { } else if (mask_elems == std::make_tuple(0, 0, 1, 1)) { meta_quadruple = 14; // 1110 } else { - AT_ERROR("_to_sparse_semi_structured: dense argument does not match ", + TORCH_CHECK(false, "_to_sparse_semi_structured: dense argument does not match ", (dense.dtype() != at::kFloat) ? "2:4" : "1:2", "sparsity pattern"); } diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu index b5382b5b08486..7286e9263a05b 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiStructuredTile.cu @@ -281,7 +281,7 @@ std::tuple _sparse_semi_structured_tile( bool use_cutlass) { #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR("_sparse_semi_structured_tile: not supported"); + TORCH_CHECK(false, "_sparse_semi_structured_tile: not supported"); return std::make_tuple(Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}); #else std::string algo(algorithm.data(), algorithm.size()); diff --git a/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu b/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu index 2fbbaa0290703..9b9b1bc0cc60d 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseSemiSturcturedApply.cu @@ -90,7 +90,7 @@ std::tuple _sparse_semi_structured_apply_typed(Tensor input, Ten std::tuple _sparse_semi_structured_apply(const Tensor& input, const Tensor& threads_masks) // Returned by `_sparse_semi_structured_tile` { #if defined(USE_ROCM) || defined(_MSC_VER) || (defined(CUDA_VERSION) && CUDA_VERSION < 11080) - AT_ERROR("_sparse_semi_structured_apply: not supported"); + TORCH_CHECK(false, "_sparse_semi_structured_apply: not supported"); return std::make_tuple(Tensor{}, Tensor{}); #else TORCH_CHECK( diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index 914915da01a03..2a0974bcfad59 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -782,8 +782,8 @@ std::tuple _scaled_dot_product_cudnn_attention_backward_ } else if (bias_dim == 3) { attn_bias_ = attn_bias_.value().expand({batch_size, 1, max_seqlen_batch_q, max_seqlen_batch_k}); } else { - attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); TORCH_CHECK(bias_dim == 4, "cuDNN SDPA expects either a 2D, 3D, or 4D attn_bias but got ", attn_bias_.value().dim(), "D"); + attn_bias_ = attn_bias_.value().expand({batch_size, attn_bias_.value().size(1), max_seqlen_batch_q, max_seqlen_batch_k}); } } diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h index 7e4f11a9e537b..70320a599c4ab 100644 --- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h +++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/epilogue/epilogue_pipelined.h @@ -589,7 +589,6 @@ class EpiloguePipelined : public EpilogueBase< } } - // This should be constexpr, but it's only supported on c++14 constexpr int CUTLASS_HOST_DEVICE getRowOffset(int i) { using ThreadMap = typename OutputTileIterator::ThreadMap; diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index d84d941769216..a30b02335fd64 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -68,16 +68,11 @@ bool check_prefer_cudnn_attention() { std::array priority_order(sdp_params const& params) { constexpr std::array default_order{ SDPBackend::flash_attention, - SDPBackend::cudnn_attention, SDPBackend::efficient_attention, - SDPBackend::math}; - constexpr std::array cudnn_order{ + SDPBackend::math, SDPBackend::cudnn_attention, - SDPBackend::flash_attention, - SDPBackend::efficient_attention, - SDPBackend::math}; - static const bool prefer_cudnn = check_prefer_cudnn_attention(); - return prefer_cudnn ? cudnn_order : default_order; + }; + return default_order; } bool use_tensor_cores(sdp_params const& params, cudaDeviceProp* dprops, bool is_half) { @@ -561,7 +556,7 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { #endif #if defined(CUDNN_VERSION) && CUDNN_VERSION < 90000 if (debug) { - TORCH_WARN(CUDNN_VERSION, "cuDNN version too old to use Flash Attention! (< v9.0.0)"); + TORCH_WARN(CUDNN_VERSION, " cuDNN version too old to use CuDNN Attention (< v9.0.0)"); } return false; #endif @@ -577,7 +572,6 @@ bool can_use_cudnn_attention(const sdp_params& params, bool debug) { check_tensor_shapes, check_cudnn_tensor_shapes, check_cudnn_deterministic, - // check_is_causal, check_dtypes_low_precision, check_attn_mask_shape, check_cudnn_hardware_support diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip index 92e51f85d8e54..7191a5f133312 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/flash_api.hip @@ -96,6 +96,12 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head int window_size_right, const bool return_softmax, std::optional gen_) { + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + // [ROCM specific]: must be at the beginning of the function + // Otherwise check_gpu_arch() checks cuda:0 device. + at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; + auto stream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream(); check_gpu_arch(stream); @@ -155,10 +161,6 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - at::hip::HIPGuardMasqueradingAsCUDA device_guard{(char)q.get_device()}; - // We want to checkpoint and save the RNG state for backward if dropout // We get the default generator and return the seed and offset which will // be used in the backward function @@ -201,14 +203,14 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head at::Tensor v_t = v_padded.permute({0,2,1,3}); at::Tensor output_t = out.permute({0,2,1,3}); - at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, at::dtype(at::kFloat).device(q.device())); // aka softmax_lse + auto opts = q.options(); + at::Tensor M = at::empty({batch_size * num_heads, seqlen_q}, opts.dtype(at::kFloat)); // aka softmax_lse at::Tensor softmax_fa_t; if (return_softmax) { - softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, - at::dtype(q.dtype()).device(q.device())); + softmax_fa_t = at::empty({batch_size, num_heads, seqlen_q, seqlen_k}, opts); } else { - softmax_fa_t = at::empty({ 0, 0, 0, 0 }, at::dtype(q.dtype()).device(q.device())); + softmax_fa_t = at::empty({ 0, 0, 0, 0 }, opts); } hipError_t err; // TODO: Error handling @@ -241,7 +243,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head is_causal, stream); - return {out, q_padded, k_padded, v_padded, M, seed_t, offset_t, softmax_fa_t}; + return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t}; } std::tuple @@ -406,8 +408,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si at::Tensor dv_t = dv.permute({0,2,1,3}); at::Tensor dout_t = dout.permute({0,2,1,3}); - at::Tensor softmax_lse_cont = softmax_lse.contiguous(); - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); + at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, seqlen_q}).contiguous(); + at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); int d_head = head_size_og; hipError_t err; // TODO: Error handling diff --git a/aten/src/ATen/native/utils/ParamUtils.h b/aten/src/ATen/native/utils/ParamUtils.h index adb5f1cfa49f9..c9088c03d81c1 100644 --- a/aten/src/ATen/native/utils/ParamUtils.h +++ b/aten/src/ATen/native/utils/ParamUtils.h @@ -18,7 +18,7 @@ inline std::vector _expand_param_if_needed( ss << "expected " << param_name << " to be a single integer value or a " << "list of " << expected_dim << " values to match the convolution " << "dimensions, but got " << param_name << "=" << list_param; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } else { return list_param.vec(); } diff --git a/aten/src/ATen/native/vulkan/api/Utils.h b/aten/src/ATen/native/vulkan/api/Utils.h index db4e012e23f57..bdc7a95a31b38 100644 --- a/aten/src/ATen/native/vulkan/api/Utils.h +++ b/aten/src/ATen/native/vulkan/api/Utils.h @@ -11,7 +11,7 @@ // Compiler Macros -// Suppress an unused variable. Copied from C10_UNUSED +// Suppress an unused variable. Copied from [[maybe_unused]] #if defined(_MSC_VER) && !defined(__clang__) #define VK_UNUSED __pragma(warning(suppress : 4100 4101)) #else diff --git a/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp b/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp index 94d155cc2f647..1ec6957162cbb 100644 --- a/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp +++ b/aten/src/ATen/native/vulkan/ops/NativeLayerNorm.cpp @@ -48,7 +48,7 @@ void _check_layer_norm_inputs( ss << ", " << size; } ss << "], but got input of size" << input_shape; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } } diff --git a/aten/src/ATen/native/xnnpack/Init.cpp b/aten/src/ATen/native/xnnpack/Init.cpp index 5f8c5ecf89a0c..d8612ef9d7dea 100644 --- a/aten/src/ATen/native/xnnpack/Init.cpp +++ b/aten/src/ATen/native/xnnpack/Init.cpp @@ -31,7 +31,7 @@ bool initialize() { return is_initialized_; } -bool C10_UNUSED deinitialize() { +[[maybe_unused]] bool deinitialize() { using namespace internal; // This implementation allows for retries. diff --git a/aten/src/ATen/record_function.h b/aten/src/ATen/record_function.h index 63fbcb55e96d2..15130c9136752 100644 --- a/aten/src/ATen/record_function.h +++ b/aten/src/ATen/record_function.h @@ -299,6 +299,31 @@ struct TORCH_API RecordFunction { before(fn, current_sequence_nr); } + template + void before( + F fn, + c10::ArrayRef args, + const std::unordered_map* kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(std::move(fn), args, current_sequence_nr); + } + + template + void before( + F fn, + const std::unordered_map* kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(fn, current_sequence_nr); + } + template void before( F fn, @@ -629,6 +654,13 @@ void record_function_with_scope_and_debug_handle( #define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \ RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs) +#define RECORD_USER_SCOPE_WITH_KWARGS_ONLY(fn, kwargs) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::USER_SCOPE, \ + fn, \ + c10::ArrayRef{}, \ + kwargs) + // Helper macro to pass in debug handle that is used to // post process events #define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ diff --git a/aten/src/ATen/test/basic.cpp b/aten/src/ATen/test/basic.cpp index 1d4d644c5f098..94c10f6a14847 100644 --- a/aten/src/ATen/test/basic.cpp +++ b/aten/src/ATen/test/basic.cpp @@ -89,7 +89,7 @@ void TestAdd(DeprecatedTypeProperties& type) { void TestZeros(DeprecatedTypeProperties& type) { auto begin = std::chrono::high_resolution_clock::now(); Tensor a = zeros({1024, 1024}, type); - for (C10_UNUSED const auto i : c10::irange(1, 1000)) { + for ([[maybe_unused]] const auto i : c10::irange(1, 1000)) { a = zeros({128, 128}, type); } auto end = std::chrono::high_resolution_clock::now(); @@ -107,7 +107,7 @@ void TestLoadsOfAdds(DeprecatedTypeProperties& type) { auto begin = std::chrono::high_resolution_clock::now(); Tensor d = ones({3, 4}, type); Tensor r = zeros({3, 4}, type); - for (C10_UNUSED const auto i : c10::irange(1000)) { + for ([[maybe_unused]] const auto i : c10::irange(1000)) { add_out(r, r, d); } auto end = std::chrono::high_resolution_clock::now(); @@ -124,7 +124,7 @@ void TestLoadOfAddsWithCopy(DeprecatedTypeProperties& type) { auto begin = std::chrono::high_resolution_clock::now(); Tensor d = ones({3, 4}, type); Tensor r = zeros({3, 4}, type); - for (C10_UNUSED const auto i : c10::irange(1000)) { + for ([[maybe_unused]] const auto i : c10::irange(1000)) { r = add(r, d); } auto end = std::chrono::high_resolution_clock::now(); diff --git a/aten/src/ATen/test/cpu_generator_test.cpp b/aten/src/ATen/test/cpu_generator_test.cpp index f24ff69250424..5a345473e2693 100644 --- a/aten/src/ATen/test/cpu_generator_test.cpp +++ b/aten/src/ATen/test/cpu_generator_test.cpp @@ -161,7 +161,7 @@ TEST(CPUGeneratorImpl, TestPhiloxEngineOffset1) { // So if you want to skip 8 values, offset would // be 2, since 2*4=8. at::Philox4_32 engine2(123, 1, 2); - for (C10_UNUSED const auto i : c10::irange(8)) { + for ([[maybe_unused]] const auto i : c10::irange(8)) { // Note: instead of using the engine() call 8 times // we could have achieved the same functionality by // calling the incr() function twice. @@ -222,14 +222,14 @@ TEST(CPUGeneratorImpl, TestMT19937EngineReproducibility) { // test with zero seed at::mt19937 engine1(0); std::mt19937 engine2(0); - for (C10_UNUSED const auto i : c10::irange(10000)) { + for ([[maybe_unused]] const auto i : c10::irange(10000)) { ASSERT_EQ(engine1(), engine2()); } // test with large seed engine1 = at::mt19937(2147483647); engine2 = std::mt19937(2147483647); - for (C10_UNUSED const auto i : c10::irange(10000)) { + for ([[maybe_unused]] const auto i : c10::irange(10000)) { ASSERT_EQ(engine1(), engine2()); } @@ -238,10 +238,9 @@ TEST(CPUGeneratorImpl, TestMT19937EngineReproducibility) { auto seed = rd(); engine1 = at::mt19937(seed); engine2 = std::mt19937(seed); - for (C10_UNUSED const auto i : c10::irange(10000)) { + for ([[maybe_unused]] const auto i : c10::irange(10000)) { ASSERT_EQ(engine1(), engine2()); } - } TEST(CPUGeneratorImpl, TestPhiloxEngineReproducibilityRandN) { diff --git a/aten/src/ATen/test/cuda_cub_test.cu b/aten/src/ATen/test/cuda_cub_test.cu index 9041ef70cedb6..5e5e25d2a8c90 100644 --- a/aten/src/ATen/test/cuda_cub_test.cu +++ b/aten/src/ATen/test/cuda_cub_test.cu @@ -138,7 +138,9 @@ __managed__ int input[] = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }; TEST(InclusiveScanSplit, CubTest) { if (!at::cuda::is_available()) return; - at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator. + at::globalContext().lazyInitDevice( + c10::DeviceType::CUDA); // This is required to use PyTorch's caching + // allocator. int *output1; cudaMallocManaged(&output1, sizeof(int) * 10); @@ -162,7 +164,9 @@ TEST(InclusiveScanSplit, CubTest) { TEST(ExclusiveScanSplit, CubTest) { if (!at::cuda::is_available()) return; - at::globalContext().lazyInitCUDA(); // This is required to use PyTorch's caching allocator. + at::globalContext().lazyInitDevice( + c10::DeviceType::CUDA); // This is required to use PyTorch's caching + // allocator. int *output2; cudaMallocManaged(&output2, sizeof(int) * 10); diff --git a/aten/src/ATen/test/legacy_vmap_test.cpp b/aten/src/ATen/test/legacy_vmap_test.cpp index cbf7ca6ec4bdb..ad74ca0ce11e4 100644 --- a/aten/src/ATen/test/legacy_vmap_test.cpp +++ b/aten/src/ATen/test/legacy_vmap_test.cpp @@ -170,7 +170,7 @@ TEST(VmapTest, TestBatchedTensorActualDim) { { // ActualDim on kVmapMaxTensorDims sized underlying tensor auto tensor = ones({}); - for (C10_UNUSED const auto i : c10::irange(kVmapMaxTensorDims)) { + for ([[maybe_unused]] const auto i : c10::irange(kVmapMaxTensorDims)) { tensor = tensor.unsqueeze(0); } ASSERT_EQ(tensor.dim(), kVmapMaxTensorDims); diff --git a/aten/src/ATen/test/thread_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp index 5c2b9036875aa..7ad7a18e9c660 100644 --- a/aten/src/ATen/test/thread_init_test.cpp +++ b/aten/src/ATen/test/thread_init_test.cpp @@ -14,7 +14,7 @@ void test(int given_num_threads) { ASSERT_TRUE(given_num_threads >= 0); ASSERT_EQ(at::get_num_threads(), given_num_threads); auto t_sum = t.sum(); - for (C10_UNUSED const auto i : c10::irange(1000)) { + for ([[maybe_unused]] const auto i : c10::irange(1000)) { t_sum = t_sum + t.sum(); } } diff --git a/aten/src/ATen/test/vec_test_all_types.cpp b/aten/src/ATen/test/vec_test_all_types.cpp index 042bb56d6ffa1..bc480e781cc81 100644 --- a/aten/src/ATen/test/vec_test_all_types.cpp +++ b/aten/src/ATen/test/vec_test_all_types.cpp @@ -821,6 +821,17 @@ namespace { createDefaultTernaryTestCase(TestSeed()), RESOLVE_OVERLOAD(filter_clamp)); } + TYPED_TEST(MinMax, ClampVecN) { + using VT = ValueType; + using vec = at::vec::VectorizedN; + test_ternary( + NAME_INFO(clamp), clamp, + [](const vec& v0, const vec& v1, const vec& v2) { + return clamp(v0, v1, v2); + }, + createDefaultTernaryTestCase(TestSeed()), + RESOLVE_OVERLOAD(filter_clamp)); + } TYPED_TEST(BitwiseFloatsAdditional, ZeroMask) { using vec = TypeParam; using VT = ValueType; @@ -895,7 +906,25 @@ namespace { .setTestSeed(TestSeed()); test_ternary( - NAME_INFO(clamp), RESOLVE_OVERLOAD(local_fmadd), + NAME_INFO(fmadd), RESOLVE_OVERLOAD(local_fmadd), + [](const vec& v0, const vec& v1, const vec& v2) { + return at::vec::fmadd(v0, v1, v2); + }, + test_case, + RESOLVE_OVERLOAD(filter_fmadd)); + } + TYPED_TEST(BitwiseFloatsAdditional, FmaddVecN) { + using VT = ValueType; + using vec = at::vec::VectorizedN; + + auto test_case = TestingCase::getBuilder() + .addDomain(CheckWithinDomains{ + {{(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}, {(VT)-1000, (VT)1000}}, + true, getDefaultTolerance()}) + .setTestSeed(TestSeed()); + + test_ternary( + NAME_INFO(fmadd), RESOLVE_OVERLOAD(local_fmadd), [](const vec& v0, const vec& v1, const vec& v2) { return at::vec::fmadd(v0, v1, v2); }, @@ -1122,24 +1151,28 @@ namespace { float minv = static_cast(static_cast(min_val) * 2.0); float maxv = static_cast(static_cast(max_val) * 2.0); ValueGen gen(minv, maxv, seed.add(2)); - for (C10_UNUSED const auto i : c10::irange(trials)) { - float scale = generator_sc.get(); - float inv_scale = 1.0f / static_cast(scale); - auto zero_point_val = generator_zp.get(); - int index = 0; - for (int j = 0; j < vec::float_num_vecs(); j++) { - //generate vals - for (auto& v : unit_float_vec) { - v = gen.get(); - expected_qint_vals[index] = quantize_val(scale, zero_point_val, v); - index++; - } - float_ret[j] = vfloat::loadu(unit_float_vec); + for ([[maybe_unused]] const auto i : c10::irange(trials)) { + float scale = generator_sc.get(); + float inv_scale = 1.0f / static_cast(scale); + auto zero_point_val = generator_zp.get(); + int index = 0; + for (int j = 0; j < vec::float_num_vecs(); j++) { + // generate vals + for (auto& v : unit_float_vec) { + v = gen.get(); + expected_qint_vals[index] = + quantize_val(scale, zero_point_val, v); + index++; } - auto expected = vec::loadu(expected_qint_vals); - auto actual = vec::quantize(float_ret, scale, zero_point_val, inv_scale); - if (AssertVectorized(NAME_INFO(Quantize), expected, actual).check()) return; - } //trials; + float_ret[j] = vfloat::loadu(unit_float_vec); + } + auto expected = vec::loadu(expected_qint_vals); + auto actual = + vec::quantize(float_ret, scale, zero_point_val, inv_scale); + if (AssertVectorized(NAME_INFO(Quantize), expected, actual) + .check()) + return; + } // trials; } #if (defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER) // This test case aims to test at::vec::QuantizeAvx512 and @@ -1168,7 +1201,7 @@ namespace { float minv = static_cast(static_cast(min_val) * 2.0); float maxv = static_cast(static_cast(max_val) * 2.0); ValueGen gen(minv, maxv, seed.add(2)); - for (C10_UNUSED const auto i : c10::irange(trials)) { + for ([[maybe_unused]] const auto i : c10::irange(trials)) { float scale = generator_sc.get(); float inv_scale = 1.0f / static_cast(scale); auto zero_point_val = generator_zp.get(); @@ -1227,35 +1260,36 @@ namespace { ValueGen generator(min_val, max_val, seed.add(1)); //scale ValueGen generator_sc(1.f, 15.f, seed.add(2)); - for (C10_UNUSED const auto i : c10::irange(trials)) { - float scale = generator_sc.get(); - int32_t zero_point_val = generator.get(); - float scale_zp_premul = -(scale * zero_point_val); - vfloat vf_scale = vfloat{ scale }; - vfloat vf_zp = vfloat{ static_cast(zero_point_val) }; - vfloat vf_scale_zp = vfloat{ scale_zp_premul }; - //generate vals - for (auto& x : qint_vals) { - x = generator.get(); + for ([[maybe_unused]] const auto i : c10::irange(trials)) { + float scale = generator_sc.get(); + int32_t zero_point_val = generator.get(); + float scale_zp_premul = -(scale * zero_point_val); + vfloat vf_scale = vfloat{scale}; + vfloat vf_zp = vfloat{static_cast(zero_point_val)}; + vfloat vf_scale_zp = vfloat{scale_zp_premul}; + // generate vals + for (auto& x : qint_vals) { + x = generator.get(); + } + // get expected + int index = 0; + auto qint_vec = vec::loadu(qint_vals); + auto actual_float_ret = + qint_vec.dequantize(vf_scale, vf_zp, vf_scale_zp); + for (int j = 0; j < vec::float_num_vecs(); j++) { + for (auto& v : unit_exp_vals) { + v = dequantize_val(scale, zero_point_val, qint_vals[index]); + index++; } - //get expected - int index = 0; - auto qint_vec = vec::loadu(qint_vals); - auto actual_float_ret = qint_vec.dequantize(vf_scale, vf_zp, vf_scale_zp); - for (int j = 0; j < vec::float_num_vecs(); j++) { - for (auto& v : unit_exp_vals) { - v = dequantize_val(scale, zero_point_val, qint_vals[index]); - index++; - } - vfloat expected = vfloat::loadu(unit_exp_vals); - const auto& actual = actual_float_ret[j]; + vfloat expected = vfloat::loadu(unit_exp_vals); + const auto& actual = actual_float_ret[j]; #if defined(CHECK_DEQUANT_WITH_LOW_PRECISION) if (AssertVectorized(NAME_INFO(DeQuantize), seed, expected, actual).check(false, true, 1.e-3f)) return; #else if (AssertVectorized(NAME_INFO(DeQuantize), seed, expected, actual).check()) return; #endif } - } //trials; + } // trials; } TYPED_TEST(QuantizationTests, ReQuantizeFromInt) { using vec = TypeParam; @@ -1274,25 +1308,29 @@ namespace { ValueGen generator(min_val, max_val, seed); //scale ValueGen generator_sc(1.f, 15.f, seed.add(1)); - for (C10_UNUSED const auto i : c10::irange(trials)) { - float multiplier = 1.f / (generator_sc.get()); - auto zero_point_val = generator.get(); - int index = 0; - for (int j = 0; j < vec::float_num_vecs(); j++) { - //generate vals - for (auto& v : unit_int_vec) { - v = c10::qint32(generator.get()); - expected_qint_vals[index] = requantize_from_int(multiplier, zero_point_val, v.val_); - index++; - } - int_ret[j] = vqint::loadu(unit_int_vec); - } - auto expected = vec::loadu(expected_qint_vals); - auto actual = vec::requantize_from_int(int_ret, multiplier, zero_point_val); - if (AssertVectorized(NAME_INFO(ReQuantizeFromInt), seed, expected, actual).check()) { - return; + for ([[maybe_unused]] const auto i : c10::irange(trials)) { + float multiplier = 1.f / (generator_sc.get()); + auto zero_point_val = generator.get(); + int index = 0; + for (int j = 0; j < vec::float_num_vecs(); j++) { + // generate vals + for (auto& v : unit_int_vec) { + v = c10::qint32(generator.get()); + expected_qint_vals[index] = requantize_from_int( + multiplier, zero_point_val, v.val_); + index++; } - } //trials; + int_ret[j] = vqint::loadu(unit_int_vec); + } + auto expected = vec::loadu(expected_qint_vals); + auto actual = + vec::requantize_from_int(int_ret, multiplier, zero_point_val); + if (AssertVectorized( + NAME_INFO(ReQuantizeFromInt), seed, expected, actual) + .check()) { + return; + } + } // trials; } TYPED_TEST(QuantizationTests, WideningSubtract) { using vec = TypeParam; @@ -1311,30 +1349,33 @@ namespace { typename vec::int_vec_return_type expected_int_ret; auto seed = TestSeed(); ValueGen generator(min_val, max_val, seed); - for (C10_UNUSED const auto i : c10::irange(trials)) { - //generate vals - for (int j = 0; j < vec::size(); j++) { - qint_vals[j] = generator.get(); - qint_b[j] = generator.get(); - if constexpr (std::is_same_v) { - //filter overflow cases - filter_sub_overflow(qint_vals[j], qint_b[j]); - } + for ([[maybe_unused]] const auto i : c10::irange(trials)) { + // generate vals + for (int j = 0; j < vec::size(); j++) { + qint_vals[j] = generator.get(); + qint_b[j] = generator.get(); + if constexpr (std::is_same_v) { + // filter overflow cases + filter_sub_overflow(qint_vals[j], qint_b[j]); } - int index = 0; - auto qint_vec = vec::loadu(qint_vals); - auto qint_vec_b = vec::loadu(qint_b); - auto actual_int_ret = qint_vec.widening_subtract(qint_vec_b); - for (int j = 0; j < vec::float_num_vecs(); j++) { - for (auto& v : unit_exp_vals) { - v = widening_subtract(qint_vals[index], qint_b[index]); - index++; - } - auto expected = vqint::loadu(unit_exp_vals); - const auto& actual = actual_int_ret[j]; - if (AssertVectorized(NAME_INFO(WideningSubtract), seed, expected, actual).check()) return; + } + int index = 0; + auto qint_vec = vec::loadu(qint_vals); + auto qint_vec_b = vec::loadu(qint_b); + auto actual_int_ret = qint_vec.widening_subtract(qint_vec_b); + for (int j = 0; j < vec::float_num_vecs(); j++) { + for (auto& v : unit_exp_vals) { + v = widening_subtract(qint_vals[index], qint_b[index]); + index++; } - } //trials; + auto expected = vqint::loadu(unit_exp_vals); + const auto& actual = actual_int_ret[j]; + if (AssertVectorized( + NAME_INFO(WideningSubtract), seed, expected, actual) + .check()) + return; + } + } // trials; } TYPED_TEST(QuantizationTests, Relu) { using vec = TypeParam; @@ -1706,6 +1747,7 @@ namespace { } while (0) TEST_CONVERT_TO(int8_t); TEST_CONVERT_TO(uint8_t); + TEST_CONVERT_TO(float); #undef TEST_CONVERT_TO } #endif @@ -1827,13 +1869,13 @@ namespace { #define TEST_MASK_CAST(dst_t, mask_t, mask_n) \ do { \ - CACHE_ALIGN mask_t x[mask_n * size]; \ - CACHE_ALIGN dst_t y[mask_n * size]; \ - auto seed = TestSeed(); \ - auto vec_mask = generate_vec_mask(seed); \ constexpr int num_dst_elements = \ std::min(size, at::vec::Vectorized::size()); \ constexpr int dst_n = mask_n * size / num_dst_elements; \ + CACHE_ALIGN mask_t x[mask_n * size]; \ + CACHE_ALIGN dst_t y[at::vec::VectorizedN::size()]; \ + auto seed = TestSeed(); \ + auto vec_mask = generate_vec_mask(seed); \ auto vec_mask_new = vec_mask.template cast(); \ vec_mask.template to().store(x); \ vec_mask_new.template to().store(y); \ diff --git a/aten/src/ATen/test/vec_test_all_types.h b/aten/src/ATen/test/vec_test_all_types.h index 9215e9ff393f3..db2e2616a306c 100644 --- a/aten/src/ATen/test/vec_test_all_types.h +++ b/aten/src/ATen/test/vec_test_all_types.h @@ -943,22 +943,25 @@ void test_unary( UVT start = dmn_argc > 0 ? dmn.ArgsDomain[0].start : default_start; UVT end = dmn_argc > 0 ? dmn.ArgsDomain[0].end : default_end; ValueGen generator(start, end, seed.add(changeSeedBy)); - for (C10_UNUSED const auto trial : c10::irange(trialCount)) { - for (const auto k : c10::irange(el_count)) { - vals[k] = generator.get(); - call_filter(filter, vals[k]); - //map operator - expected[k] = expectedFunction(vals[k]); - } - // test - auto input = vec_type::loadu(vals); - auto actual = actualFunction(input); - auto vec_expected = vec_type::loadu(expected); - AssertVectorized vecAssert(testNameInfo, seed, vec_expected, actual, input); - if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return; - - }// trial - //inrease Seed + for ([[maybe_unused]] const auto trial : c10::irange(trialCount)) { + for (const auto k : c10::irange(el_count)) { + vals[k] = generator.get(); + call_filter(filter, vals[k]); + // map operator + expected[k] = expectedFunction(vals[k]); + } + // test + auto input = vec_type::loadu(vals); + auto actual = actualFunction(input); + auto vec_expected = vec_type::loadu(expected); + AssertVectorized vecAssert( + testNameInfo, seed, vec_expected, actual, input); + if (vecAssert.check( + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) + return; + + } // trial + // inrease Seed changeSeedBy += 1; } for (auto& custom : testCase.getCustomChecks()) { @@ -1002,22 +1005,25 @@ void test_binary( UVT end1 = dmn_argc > 1 ? dmn.ArgsDomain[1].end : default_end; ValueGen generator0(start0, end0, seed.add(changeSeedBy)); ValueGen generator1(start1, end1, seed.add(changeSeedBy + 1)); - for (C10_UNUSED const auto trial : c10::irange(trialCount)) { - for (const auto k : c10::irange(el_count)) { - vals0[k] = generator0.get(); - vals1[k] = generator1.get(); - call_filter(filter, vals0[k], vals1[k]); - //map operator - expected[k] = expectedFunction(vals0[k], vals1[k]); - } - // test - auto input0 = vec_type::loadu(vals0); - auto input1 = vec_type::loadu(vals1); - auto actual = actualFunction(input0, input1); - auto vec_expected = vec_type::loadu(expected); - AssertVectorized vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1); - if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError))return; - }// trial + for ([[maybe_unused]] const auto trial : c10::irange(trialCount)) { + for (const auto k : c10::irange(el_count)) { + vals0[k] = generator0.get(); + vals1[k] = generator1.get(); + call_filter(filter, vals0[k], vals1[k]); + // map operator + expected[k] = expectedFunction(vals0[k], vals1[k]); + } + // test + auto input0 = vec_type::loadu(vals0); + auto input1 = vec_type::loadu(vals1); + auto actual = actualFunction(input0, input1); + auto vec_expected = vec_type::loadu(expected); + AssertVectorized vecAssert( + testNameInfo, seed, vec_expected, actual, input0, input1); + if (vecAssert.check( + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) + return; + } // trial changeSeedBy += 1; } for (auto& custom : testCase.getCustomChecks()) { @@ -1067,24 +1073,27 @@ void test_ternary( ValueGen generator1(start1, end1, seed.add(changeSeedBy + 1)); ValueGen generator2(start2, end2, seed.add(changeSeedBy + 2)); - for (C10_UNUSED const auto trial : c10::irange(trialCount)) { - for (const auto k : c10::irange(el_count)) { - vals0[k] = generator0.get(); - vals1[k] = generator1.get(); - vals2[k] = generator2.get(); - call_filter(filter, vals0[k], vals1[k], vals2[k]); - //map operator - expected[k] = expectedFunction(vals0[k], vals1[k], vals2[k]); - } - // test - auto input0 = vec_type::loadu(vals0); - auto input1 = vec_type::loadu(vals1); - auto input2 = vec_type::loadu(vals2); - auto actual = actualFunction(input0, input1, input2); - auto vec_expected = vec_type::loadu(expected); - AssertVectorized vecAssert(testNameInfo, seed, vec_expected, actual, input0, input1, input2); - if (vecAssert.check(bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) return; - }// trial + for ([[maybe_unused]] const auto trial : c10::irange(trialCount)) { + for (const auto k : c10::irange(el_count)) { + vals0[k] = generator0.get(); + vals1[k] = generator1.get(); + vals2[k] = generator2.get(); + call_filter(filter, vals0[k], vals1[k], vals2[k]); + // map operator + expected[k] = expectedFunction(vals0[k], vals1[k], vals2[k]); + } + // test + auto input0 = vec_type::loadu(vals0); + auto input1 = vec_type::loadu(vals1); + auto input2 = vec_type::loadu(vals2); + auto actual = actualFunction(input0, input1, input2); + auto vec_expected = vec_type::loadu(expected); + AssertVectorized vecAssert( + testNameInfo, seed, vec_expected, actual, input0, input1, input2); + if (vecAssert.check( + bitwise, dmn.CheckWithTolerance, dmn.ToleranceError)) + return; + } // trial changeSeedBy += 1; } } diff --git a/aten/src/ATen/vulkan/Context.cpp b/aten/src/ATen/vulkan/Context.cpp index 793c690a0c141..06d959b89fcb5 100644 --- a/aten/src/ATen/vulkan/Context.cpp +++ b/aten/src/ATen/vulkan/Context.cpp @@ -21,7 +21,7 @@ at::Tensor& vulkan_copy_(at::Tensor& self, const at::Tensor& src) { if (p) { return p->vulkan_copy_(self, src); } - AT_ERROR("Vulkan backend was not linked to the build"); + TORCH_CHECK(false, "Vulkan backend was not linked to the build"); } } // namespace vulkan diff --git a/aten/src/ATen/xpu/detail/XPUHooks.cpp b/aten/src/ATen/xpu/detail/XPUHooks.cpp index d9d0f06c0d804..05d4482fe979b 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.cpp +++ b/aten/src/ATen/xpu/detail/XPUHooks.cpp @@ -9,7 +9,7 @@ namespace at::xpu::detail { -void XPUHooks::initXPU() const { +void XPUHooks::init() const { C10_LOG_API_USAGE_ONCE("aten.init.xpu"); const auto device_count = c10::xpu::device_count_ensure_non_zero(); c10::xpu::XPUCachingAllocator::init(device_count); diff --git a/aten/src/ATen/xpu/detail/XPUHooks.h b/aten/src/ATen/xpu/detail/XPUHooks.h index 2f2b2b70e7a93..6c1c064bae80e 100644 --- a/aten/src/ATen/xpu/detail/XPUHooks.h +++ b/aten/src/ATen/xpu/detail/XPUHooks.h @@ -7,7 +7,7 @@ namespace at::xpu::detail { // The real implementation of XPUHooksInterface struct XPUHooks : public at::XPUHooksInterface { XPUHooks(at::XPUHooksArgs) {} - void initXPU() const override; + void init() const override; bool hasXPU() const override; std::string showConfig() const override; int32_t getGlobalIdxFromDevice(const at::Device& device) const override; diff --git a/benchmarks/distributed/ddp/diff.py b/benchmarks/distributed/ddp/diff.py index 14d839e973408..cfeb90cd6fa25 100644 --- a/benchmarks/distributed/ddp/diff.py +++ b/benchmarks/distributed/ddp/diff.py @@ -51,9 +51,7 @@ def main(): print() print(f"{'':>10s}", end="") # noqa: E999 for _ in [75, 95]: - print( - f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="" - ) # noqa: E999 + print(f"{'sec/iter':>16s}{'ex/sec':>10s}{'diff':>10s}", end="") # noqa: E999 print() # Print measurements diff --git a/benchmarks/distributed/rpc/rl/launcher.py b/benchmarks/distributed/rpc/rl/launcher.py index 7c6f74524d79c..40b0fc4308e23 100644 --- a/benchmarks/distributed/rpc/rl/launcher.py +++ b/benchmarks/distributed/rpc/rl/launcher.py @@ -209,9 +209,8 @@ def main(): x_axis_variables ): # run benchmark for every x axis variable if len(x_axis_variables) > 1: - args[ - args["x_axis_name"] - ] = x_axis_variable # set x axis variable for this benchmark iteration + # set x axis variable for this benchmark iteration + args[args["x_axis_name"]] = x_axis_variable processes = [] start_time = time.time() for rank in range(args["world_size"]): diff --git a/benchmarks/dynamo/check_perf_csv.py b/benchmarks/dynamo/check_perf_csv.py index 2a19f6c4a1426..f5911d6a8a513 100644 --- a/benchmarks/dynamo/check_perf_csv.py +++ b/benchmarks/dynamo/check_perf_csv.py @@ -5,7 +5,7 @@ import pandas as pd -def check_perf_csv(filename, threshold): +def check_perf_csv(filename, threshold, threshold_scale): """ Basic performance checking. """ @@ -16,7 +16,7 @@ def check_perf_csv(filename, threshold): for _, row in df.iterrows(): model_name = row["name"] speedup = row["speedup"] - if speedup < threshold: + if speedup < threshold * threshold_scale: failed.append(model_name) print(f"{model_name:34} {speedup}") @@ -39,5 +39,12 @@ def check_perf_csv(filename, threshold): parser.add_argument( "--threshold", "-t", type=float, help="threshold speedup value to check against" ) + parser.add_argument( + "--threshold-scale", + "-s", + type=float, + default=1.0, + help="multiple threshold by this value to relax the check", + ) args = parser.parse_args() - check_perf_csv(args.file, args.threshold) + check_perf_csv(args.file, args.threshold, args.threshold_scale) diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv index c96684bc79462..14e4a23fac053 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_eager_torchbench_inference.csv @@ -82,7 +82,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,21 +detectron2_fcos_r_50_fpn,pass,22 diff --git a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv index b82f003687d65..1b9b034987947 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/aot_inductor_torchbench_inference.csv @@ -30,7 +30,7 @@ basic_gnn_edgecnn,pass,0 -basic_gnn_gcn,fail_to_run,0 +basic_gnn_gcn,pass,0 @@ -278,7 +278,7 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 +sam,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv index 3e81ee43beab4..5232996a8e41a 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_amp_freezing_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,fail_accuracy,46 -detectron2_fcos_r_50_fpn,pass,23 +detectron2_fcos_r_50_fpn,pass,24 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv index a897806e5188b..10dbea3f367e6 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_freezing_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46 -detectron2_fcos_r_50_fpn,pass,23 +detectron2_fcos_r_50_fpn,pass,24 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv index 66b23bd97e600..7671148626441 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cpu_inductor_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46 -detectron2_fcos_r_50_fpn,pass,23 +detectron2_fcos_r_50_fpn,pass,24 diff --git a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv index 1624d6dc7973f..1934304128888 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/cu124/aot_inductor_torchbench_inference.csv @@ -282,7 +282,7 @@ resnext50_32x4d,pass,0 -sam,fail_to_run,0 +sam,pass,0 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv index 9075a4adfd3a1..0a43ad91c7839 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_aot_eager_torchbench_inference.csv @@ -82,7 +82,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,21 +detectron2_fcos_r_50_fpn,pass,22 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv index ae51d78c7e0bb..030558477462d 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_cpu_inductor_torchbench_inference.csv @@ -74,7 +74,7 @@ detectron2_fasterrcnn_r_50_fpn,pass,46 -detectron2_fcos_r_50_fpn,pass,23 +detectron2_fcos_r_50_fpn,pass,24 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv index 9075a4adfd3a1..0a43ad91c7839 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamic_inductor_torchbench_inference.csv @@ -82,7 +82,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,21 +detectron2_fcos_r_50_fpn,pass,22 diff --git a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv index c96684bc79462..14e4a23fac053 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/dynamo_eager_torchbench_inference.csv @@ -82,7 +82,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,21 +detectron2_fcos_r_50_fpn,pass,22 diff --git a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv index f21050a3d3d95..293ae08cd82dd 100644 --- a/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv +++ b/benchmarks/dynamo/ci_expected_accuracy/inductor_torchbench_inference.csv @@ -82,7 +82,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0 -detectron2_fcos_r_50_fpn,pass,21 +detectron2_fcos_r_50_fpn,pass,22 diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py index 104e59bc193a4..0e7771c636a3d 100644 --- a/benchmarks/dynamo/common.py +++ b/benchmarks/dynamo/common.py @@ -12,6 +12,7 @@ import functools import importlib import itertools +import json import logging import os import shutil @@ -60,6 +61,7 @@ reset_rng_state, same, ) +from torch._logging.scribe import open_source_signpost try: @@ -375,6 +377,116 @@ def output_csv(filename, headers, row): writer.writerow(list(line) + ["0"] * (len(headers) - len(line))) +def get_suite_from_model_iter_fn(model_iter_fn): + # TODO: This is a bit of a hack + suite = None + if (runner := getattr(model_iter_fn, "__self__", None)) and hasattr( + runner, "suite_name" + ): + suite = runner.suite_name + return suite + + +def output_signpost(data, args, suite, error=None): + from torch.utils._stats import simple_call_counter + + data = data.copy() + + if "name" not in data: + data["name"] = current_name + + if "dev" not in data: + data["dev"] = current_device + + filtered_args = vars(args).copy() + # I generated this list by reading through all the configs and dropping + # ones that looked irrelevant or redundant + for k in [ + "filter", + "exclude", + "exclude_exact", + "dump_raw_metrics", + "log_operator_inputs", + "distributed_master_port", + "skip_accuracy_check", + "generate_aot_autograd_stats", + "output", + "output_directory", + "disable_output", + "export_profiler_trace", + "profiler_trace_name", + "explain", + "stats", + "print_memory", + "print_compilation_time", + "print_dataframe_summary", + "print_graph_breaks", + "log_graph_breaks", + "timing", + "progress", + "timeout", + "per_process_memory_fraction", + "minify", + "verbose", + "quiet", + "print_fx", + "print_aten_ops", + "log_conv_args", + "recompile_profiler", + "find_batch_sizes", + # Redundant + "batch_size", + "batch_size_file", + "only", + "diff_branch", + "tag", + "coverage", + "overhead", + "speedup_dynamo_ts", + "speedup_fx2trt", + "speedup_fx2trt_fp16", + "accuracy", + "performance", + "tolerance", + ]: + del filtered_args[k] + + event_name = "unknown" + if args.accuracy: + event_name = "accuracy" + elif args.quantization: + event_name = "quantization" + elif args.performance: + event_name = "performance" + + from torch._dynamo.utils import calculate_time_spent, compilation_time_metrics + + open_source_signpost( + subsystem="dynamo_benchmark", + name=event_name, + parameters=json.dumps( + { + **data, + # TODO: Arguably the rest of these should be in the CSV too + "suite": suite, + # Better than using compile_times utils directly + # NB: Externally, compilation_metrics colloquially refers to + # the coarse-grained phase timings, even though internally + # they are called something else + "compilation_metrics": calculate_time_spent(), + "agg_compilation_metrics": { + k: sum(v) for k, v in compilation_time_metrics.items() + }, + "detailed_compilation_metrics": compilation_time_metrics, + "simple_call_counter": simple_call_counter, + # NB: args has training vs inference + "args": filtered_args, + "error": error, + } + ), + ) + + def nothing(f): return f @@ -649,6 +761,7 @@ def speedup_experiment_fx2trt(args, model_iter_fn, model, example_inputs): return speedup_experiment(args, model_iter_fn, model, example_inputs) +# TODO: CompilerProfiler is deprecated, remove this def recompile_profiler_experiment(args, model_iter_fn, model, example_inputs): prof = torch._dynamo.utils.CompilerProfiler() opt_model_iter_fn = torch._dynamo.optimize(prof, nopython=args.nopython)( @@ -753,7 +866,8 @@ def maybe_mark_profile(*args, **kwargs): return timings -def latency_experiment_summary(args, model, timings, **kwargs): +# TODO: This seems to be specifically triggered by torchao testing +def latency_experiment_summary(suite_name, args, model, timings, **kwargs): median = np.median(timings, axis=0) speedup = median[0] / median[1] if args.dump_raw_metrics: @@ -814,15 +928,26 @@ def latency_experiment_summary(args, model, timings, **kwargs): headers, row, ) - headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) + c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) assert ( output_filename.find(".csv") > 0 ), f"expected output_filename to be a .csv, but got {output_filename}" output_csv( output_filename[:-4] + "_compilation_metrics.csv", - first_headers + headers, - first_fields + data, + first_headers + c_headers, + first_fields + c_data, + ) + + # Hypothetically you can use this from other places, but it's currently + # inaccessible, and when this assert fails you need to update the + # event_name here to account for the other cases you are using this + assert args.quantization is not None + output_signpost( + dict(zip(headers, row)), + args, + suite_name, ) + return msg @@ -974,18 +1099,26 @@ def maybe_mark_profile(*args, **kwargs): headers, row, ) - headers, data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) + c_headers, c_data = torch._dynamo.utils.compile_times(repr="csv", aggregate=True) assert ( output_filename.find(".csv") > 0 ), f"expected output_filename to be a .csv, but got {output_filename}" output_csv( output_filename[:-4] + "_compilation_metrics.csv", - first_headers + headers, - first_fields + data, + first_headers + c_headers, + first_fields + c_data, + ) + + output_signpost( + dict(zip(headers, row)), + args, + get_suite_from_model_iter_fn(model_iter_fn), ) + return msg +# WARNING: This code is currently dead def speedup_experiment_ds(args, model_iter_fn, model, example_inputs): """ Run dynamic shapes benchmarks. @@ -1391,9 +1524,7 @@ def load(cls, model, example_inputs, device): strict=False, ).module() with torch.no_grad(): - so_path = torch._inductor.aot_compile( - gm, example_args, example_kwargs - ) # type: ignore[arg-type] + so_path = torch._inductor.aot_compile(gm, example_args, example_kwargs) # type: ignore[arg-type] cls.cache[key] = torch._export.aot_load(so_path, device) @@ -1559,12 +1690,10 @@ def _generate_onnx_model_directory( return model_path @abc.abstractmethod - def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: - ... + def format_pt_inputs(self, pt_inputs: Any) -> Sequence[torch.Tensor]: ... @abc.abstractmethod - def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: - ... + def format_pt_outputs(self, pt_outputs: Any) -> Sequence[torch.Tensor]: ... def adapt_pt_inputs_to_onnx(self, pt_inputs) -> Mapping[str, npt.NDArray]: pt_inputs = self.format_pt_inputs(pt_inputs) @@ -2701,6 +2830,9 @@ def record_status(accuracy_status, dynamo_start_stats): headers.insert(3, "tag") fields.insert(3, tag) + o_headers = list(headers) + o_fields = list(fields) + dynamo_stats = get_dynamo_stats() dynamo_stats.subtract(dynamo_start_stats) for k, v in dynamo_stats.items(): @@ -2708,6 +2840,13 @@ def record_status(accuracy_status, dynamo_start_stats): fields.append(v) output_csv(output_filename, headers, fields) + + output_signpost( + dict(zip(o_headers, o_fields)), + self.args, + self.suite_name, + ) + return accuracy_status if name in self.skip_accuracy_checks_large_models_dashboard: @@ -3023,6 +3162,7 @@ def warmup(fn, model, example_inputs, mode, niters=10): write_csv_when_exception( self.args, current_name, "warmup_failed", current_device ) + output_signpost({}, self.args, self.suite_name, error="warmup_failed") return sys.exit(-1) dynamo_stats = get_dynamo_stats() dynamo_stats.subtract(start_stats) @@ -3134,9 +3274,9 @@ def warmup(fn, model, example_inputs, mode, niters=10): experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem experiment_kwargs["dynamo_stats"] = dynamo_stats if self.args.profile_dynamo_cache_lookup: - experiment_kwargs[ - "cache_lookup_latency" - ] = dynamo_cache_lookup_latency + experiment_kwargs["cache_lookup_latency"] = ( + dynamo_cache_lookup_latency + ) if experiment.func is speedup_experiment_onnx: experiment = functools.partial( @@ -3147,7 +3287,7 @@ def warmup(fn, model, example_inputs, mode, niters=10): ) timings = np.stack((baseline_timings, backend_timings), axis=1) result_summary = latency_experiment_summary( - self.args, model, timings, **experiment_kwargs + self.suite_name, self.args, model, timings, **experiment_kwargs ) if not hasattr(model, name): model.name = name @@ -3184,6 +3324,7 @@ def warmup(fn, model, example_inputs, mode, niters=5): write_csv_when_exception( self.args, current_name, "warmup_failed", current_device ) + output_signpost({}, self.args, self.suite_name, error="warmup_failed") return sys.exit(-1) dynamo_stats = get_dynamo_stats() dynamo_stats.subtract(start_stats) @@ -3290,9 +3431,9 @@ def warmup(fn, model, example_inputs, mode, niters=5): experiment_kwargs["dynamo_peak_mem"] = dynamo_peak_mem experiment_kwargs["dynamo_stats"] = dynamo_stats if self.args.profile_dynamo_cache_lookup: - experiment_kwargs[ - "cache_lookup_latency" - ] = dynamo_cache_lookup_latency + experiment_kwargs["cache_lookup_latency"] = ( + dynamo_cache_lookup_latency + ) if experiment.func is coverage_experiment: ok, total = Stats.reset_counters() @@ -4324,7 +4465,14 @@ def run(runner, args, original_dir=None): runner.skip_models.clear() experiment = null_experiment - global current_name, current_device, current_batch_size, output_filename, disable_output, optimize_ctx, current_onnx_compiler + global \ + current_name, \ + current_device, \ + current_batch_size, \ + output_filename, \ + disable_output, \ + optimize_ctx, \ + current_onnx_compiler optimize_ctx = contextlib.nullcontext() if args.disable_output: @@ -4655,6 +4803,14 @@ def model_iter_fn_and_mark_step(*args, **kwargs): else "eager_fail_to_run" ) write_csv_when_exception(args, name, status, device) + # NB: current_name/current_device not set, so pass + # explicitly + output_signpost( + {"name": name, "dev": device}, + args, + runner.suite_name, + error=status, + ) continue # bad benchmark implementation if args.trace_on_xla: @@ -4767,6 +4923,11 @@ def detect_and_mark_batch(t): ) except subprocess.TimeoutExpired: write_csv_when_exception(args, name, "timeout") + # NB: device is potentially multiple here, though we should + # try our best to report in anyway TODO + output_signpost( + {"name": name}, args, runner.suite_name, error="timeout" + ) except subprocess.CalledProcessError as e: print("Run failed with return code: ", e.returncode, file=sys.stderr) print("Output: ", e.output, file=sys.stderr) diff --git a/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv b/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv index 0fe9f8cd2ecce..9462efef99ae8 100644 --- a/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv +++ b/benchmarks/dynamo/expected_ci_speedup_inductor_torchbench_cpu.csv @@ -7,7 +7,7 @@ resnet50,inductor,float32,dynamic,default,1.67742767 #timm_efficientnet,inductor,float32,static,cpp, mobilenet_v3_large,inductor,float32,static,cpp,2.63311706 timm_resnest,inductor,float32,dynamic,cpp,1.7321529 -functorch_maml_omniglot,inductor,float32,dynamic,cpp,1.17617472 +functorch_maml_omniglot,inductor,float32,dynamic,cpp,1.126799 #hf_GPT2,inductor,float32,dynamic,cpp, yolov3,export-aot-inductor,float32,static,default,1.40687424 mobilenet_v2,export-aot-inductor,float32,static,default,2.90375357 diff --git a/benchmarks/dynamo/huggingface.py b/benchmarks/dynamo/huggingface.py index 06bf4f0ee7610..a96bad12b73f9 100755 --- a/benchmarks/dynamo/huggingface.py +++ b/benchmarks/dynamo/huggingface.py @@ -501,12 +501,12 @@ def get_tolerance_and_cosine_flag(self, is_training, current_device, name): else: return 1e-2, cosine else: - if name in self._config["tolerance"]["higher_inference"]: - return 4e-3, cosine if ( current_device == "cpu" and name in self._config["tolerance"]["higher_inference_cpu"] ): + return 5e-3, cosine + if name in self._config["tolerance"]["higher_inference"]: return 4e-3, cosine return 1e-3, cosine diff --git a/benchmarks/dynamo/huggingface.yaml b/benchmarks/dynamo/huggingface.yaml index 2ddc242537d6e..f0ee57a589657 100644 --- a/benchmarks/dynamo/huggingface.yaml +++ b/benchmarks/dynamo/huggingface.yaml @@ -89,6 +89,7 @@ tolerance: higher_inference_cpu: - LayoutLMForSequenceClassification + - GPT2ForSequenceClassification cosine: [] diff --git a/benchmarks/dynamo/join_results.py b/benchmarks/dynamo/join_results.py index fce6f81580486..006eb57a96975 100644 --- a/benchmarks/dynamo/join_results.py +++ b/benchmarks/dynamo/join_results.py @@ -2,6 +2,7 @@ A tool to merge multiple csv files (generated by torchbench.py/etc) into a single csv file. Performs an outer join based on the benchmark name, filling in any missing data with zeros. """ + import argparse import functools import operator diff --git a/benchmarks/dynamo/microbenchmarks/analyze_templates.py b/benchmarks/dynamo/microbenchmarks/analyze_templates.py index 65fa547123a4b..b9899f8adb590 100644 --- a/benchmarks/dynamo/microbenchmarks/analyze_templates.py +++ b/benchmarks/dynamo/microbenchmarks/analyze_templates.py @@ -4,6 +4,7 @@ That file can be fed into this script to generate the minimizes total, weighted matmul time as a function of allowed templates. """ + import json import click diff --git a/benchmarks/dynamo/microbenchmarks/cache_debug_microbenchmarks.py b/benchmarks/dynamo/microbenchmarks/cache_debug_microbenchmarks.py new file mode 100644 index 0000000000000..f152f0c9bd10f --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/cache_debug_microbenchmarks.py @@ -0,0 +1,32 @@ +import timeit + +import torch.fx +from torch._inductor.codecache import FxGraphHashDetails + + +N = 10000 +K = 100 + + +def huge_graph(): + def fn(x): + for _ in range(N): + x = x.sin() + return x + + return torch.fx.symbolic_trace(fn) + + +def main(): + g = huge_graph() + details = FxGraphHashDetails(g, [], {}, []) + + def fn(): + return details.debug_lines() + + t = min(timeit.repeat(fn, number=K, repeat=3)) + print(f"iterating over {N*K} FX nodes took {t:.1f}s ({N*K/t:.0f} nodes/s)") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/microbenchmarks/cache_hit_microbenchmarks.py b/benchmarks/dynamo/microbenchmarks/cache_hit_microbenchmarks.py new file mode 100644 index 0000000000000..53879f5e8c0ee --- /dev/null +++ b/benchmarks/dynamo/microbenchmarks/cache_hit_microbenchmarks.py @@ -0,0 +1,49 @@ +import os +import timeit + +import torch.fx +from torch._dynamo.utils import counters +from torch._inductor.utils import clear_inductor_caches, fresh_inductor_cache + + +N = 10000 +K = 100 + + +def huge_graph(x): + for _ in range(N): + x = x.sin() + return x + + +def main(): + torch._inductor.config.fx_graph_cache = True + torch._inductor.config.fx_graph_remote_cache = False + + with fresh_inductor_cache(): + a = torch.randn(4).cuda() + compiled_fn = torch.compile(huge_graph, backend="inductor") + + # write to cache + compiled_fn(a) + assert counters["inductor"]["fxgraph_cache_miss"] == 1 + + def setup(): + torch._dynamo.reset() + clear_inductor_caches() + for m in torch._inductor.codecache.PyCodeCache.cache.values(): + os.remove(m.__file__) + counters.clear() + + def fn(): + result = compiled_fn(a) + assert counters["inductor"]["fxgraph_cache_miss"] == 0 + assert counters["inductor"]["fxgraph_cache_hit"] == 1 + return result + + t = min(timeit.repeat(fn, setup=setup, number=K, repeat=3)) + print(f"took {t:.1f}s") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/aotdispatcher.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/aotdispatcher.py new file mode 100644 index 0000000000000..53a8f20b06122 --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/aotdispatcher.py @@ -0,0 +1,72 @@ +import sys + +from benchmark_base import BenchmarkBase + +import torch +from torch.testing._internal.two_tensor import TwoTensor + + +class Benchmark(BenchmarkBase): + def __init__(self, *, training, subclass): + self._training = training + self._subclass = subclass + self._device = "cpu" + + def name(self): + prefix = "aotdispatcher" + if self._training: + prefix += "_training" + else: + prefix += "_inference" + if self._subclass: + prefix += "_subclass" + else: + prefix += "_nosubclass" + if self._device == "cpu": + prefix += "_cpu" + return prefix + + def description(self): + return "100 inputs, 100 outputs, each input is added once" + + def _prepare_once(self): + _args = [ + torch.ones(100, requires_grad=self._training, device=self._device) + for _ in range(100) + ] + if self._subclass: + _args = [ + TwoTensor(x, x.clone().detach().requires_grad_(self._training)) + for x in _args + ] + self._args = _args + + def _prepare(self): + torch._dynamo.reset() + + def _work(self): + @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) + def f(*args): + outs = [torch.add(x, x) for x in args] + return outs + + f(*self._args) + + +def main(): + result_path = sys.argv[1] + all = [ + Benchmark(training=False, subclass=False), + Benchmark(training=True, subclass=False), + Benchmark(training=False, subclass=True), + Benchmark(training=True, subclass=True), + ] + + for benchmark in all: + benchmark.enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/pr_time_benchmarks/benchmarks/aotdispatcher_partitioner.py b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/aotdispatcher_partitioner.py new file mode 100644 index 0000000000000..30fa5fa386124 --- /dev/null +++ b/benchmarks/dynamo/pr_time_benchmarks/benchmarks/aotdispatcher_partitioner.py @@ -0,0 +1,46 @@ +import sys + +from benchmark_base import BenchmarkBase + +import torch + + +class Benchmark(BenchmarkBase): + def name(self): + return "aotdispatcher_partitioner_cpu" + + def description(self): + return "partitioner benchmark 1 input and 100 weights, mix of recompute and non-recompute ops" + + def _prepare_once(self): + self.weights = [torch.randn(16, 16, requires_grad=True) for _ in range(100)] + self.inp = torch.randn(16, 16) + + def _prepare(self): + torch._dynamo.reset() + + def _work(self): + @torch.compile(backend="aot_eager_decomp_partition", fullgraph=True) + def f(inp, *weights): + x = inp + for w in weights: + x = torch.matmul(w, x).sin().sin() + return x + + f(self.inp, *self.weights) + + +def main(): + result_path = sys.argv[1] + all = [ + Benchmark(), + ] + + for benchmark in all: + benchmark.enable_compile_time_instruction_count().collect_all().append_results( + result_path + ) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/dynamo/pr_time_benchmarks/check_results.py b/benchmarks/dynamo/pr_time_benchmarks/check_results.py index 4d9bee7f9ff9a..8b18af47a589e 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/check_results.py +++ b/benchmarks/dynamo/pr_time_benchmarks/check_results.py @@ -1,3 +1,4 @@ +import copy import csv import json import sys @@ -21,6 +22,35 @@ class ResultFileEntry: actual_value: int +def replace_with_zeros(num): + """ + Keeps the first three digits of an integer and replaces the rest with zeros. + + Args: + num (int): The number to modify. + + Returns: + int: The modified number. + + Raises: + ValueError: If the input is not an integer. + """ + # Check if input is an integer + if not isinstance(num, int): + raise ValueError("Input must be an integer") + + # Calculate the number of digits to remove + digits_to_remove = len(str(abs(num))) - 4 + + # Replace digits with zeros + if digits_to_remove > 0: + modified_num = (num // 10**digits_to_remove) * 10**digits_to_remove + else: + modified_num = num + + return modified_num + + def main(): # Expected file is the file that have the results that we are comparing against. # Expected has the following format: @@ -35,12 +65,18 @@ def main(): # add_loop_eager,compile_time_instruction_count,283178305 result_file_path = sys.argv[2] + # A path where a new expected results file will be written that can be used to replace expected_results.csv + # in case of failure. In case of no failure the content of this file will match expected_file_path. + reference_expected_results_path = sys.argv[3] + # Read expected data file. expected_data: dict[str, ExpectedFileEntry] = {} with open(expected_file_path) as f: reader = csv.reader(f) for row in reader: + if len(row) == 0: + continue entry = ExpectedFileEntry( benchmark_name=row[0].strip(), metric_name=row[1].strip(), @@ -68,6 +104,7 @@ def main(): result_data[key] = entry fail = False + new_expected = copy.deepcopy(expected_data) for key, entry in expected_data.items(): if key not in result_data: print(f"Missing entry for {key} in result file") @@ -76,6 +113,7 @@ def main(): low = entry.expected_value - entry.expected_value * entry.noise_margin high = entry.expected_value + entry.expected_value * entry.noise_margin result = result_data[key].actual_value + ratio = float(result - entry.expected_value) * 100 / entry.expected_value def log(event_name): scribe.open_source_signpost( @@ -88,38 +126,60 @@ def log(event_name): "actual_value": result, "expected_value": entry.expected_value, "noise_margin": entry.noise_margin, + "change_ratio": ratio, } ), ) + new_entry = copy.deepcopy(entry) + # only change if abs(ratio) > entry.noise_margin /4. + new_entry.expected_value = ( + replace_with_zeros(result) + if abs(ratio) > entry.noise_margin / 4 + else entry.expected_value + ) + new_expected[key] = new_entry + if result > high: fail = True - ratio = float(result - entry.expected_value) * 100 / entry.expected_value print( f"REGRESSION: benchmark {key} failed, actual result {result} " - f"is {ratio:.2f}% higher than expected {entry.expected_value} ±{entry.noise_margin*100:.2f}% " - f"if this is an expected regression, please update the expected results." + f"is {ratio:.2f}% higher than expected {entry.expected_value} ±{entry.noise_margin*100:+.2f}% " + f"if this is an expected regression, please update the expected results.\n" + ) + print( + "please update all results that changed significantly, and not only the failed ones" ) log("fail_regression") - if result < low: + elif result < low: fail = True - ratio = float(entry.expected_value - result) * 100 / entry.expected_value print( - f"WIN: benchmark {key} failed, actual result {result} is {ratio:.2f}% lower than " + f"WIN: benchmark {key} failed, actual result {result} is {ratio:+.2f}% lower than " f"expected {entry.expected_value} ±{entry.noise_margin*100:.2f}% " - f"please update the expected results." + f"please update the expected results. \n" + ) + print( + "please update all results that changed significantly, and not only the failed ones" ) log("fail_win") + else: + print( + f"PASS: benchmark {key} pass, actual result {result} {ratio:+.2f}% is within " + f"expected {entry.expected_value} ±{entry.noise_margin*100:.2f}%\n" + ) + + log("pass") + # Log all benchmarks that do not have a regression test enabled for them. for key, entry in result_data.items(): if key not in expected_data: print( - f"MISSING REGRESSION TEST: benchmark {key} does not have a regression test enabled for it" + f"MISSING REGRESSION TEST: benchmark {key} does not have a regression test enabled for it.\n" ) scribe.open_source_signpost( subsystem="pr_time_benchmarks", @@ -131,7 +191,34 @@ def log(event_name): } ), ) + + with open(reference_expected_results_path, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for entry in new_expected.values(): + # Write the data to the CSV file + # print(f"{entry.benchmark_name},{entry.metric_name,},{round(entry.expected_value)},{entry.noise_margin}") + writer.writerow( + [ + entry.benchmark_name, + entry.metric_name, + entry.expected_value, + entry.noise_margin, + ] + ) + # Three empty rows for merge conflicts. + writer.writerow([]) + writer.writerow([]) + writer.writerow([]) + + print("new expected results file content if needed:") + with open(reference_expected_results_path) as f: + print(f.read()) + if fail: + print( + f"There was some failures you can use the new reference expected result stored at path:" + f"{reference_expected_results_path} and printed above\n" + ) sys.exit(1) else: print("All benchmarks passed") diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 41649541c7e68..1605327050975 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,5 +1,66 @@ add_loop_eager, compile_time_instruction_count, 3004749893, 0.015 -add_loop_eager_dynamic, compile_time_instruction_count, 5726573328, 0.025 -add_loop_inductor, compile_time_instruction_count, 24146845503, 0.015 -add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 39411706509, 0.025 -add_loop_inductor_gpu, compile_time_instruction_count, 22171041650, 0.015 + + + +add_loop_eager_dynamic, compile_time_instruction_count, 5563298740, 0.025 + + + +add_loop_inductor, compile_time_instruction_count, 24064639114, 0.015 + + + +add_loop_inductor_dynamic_gpu, compile_time_instruction_count, 40992578178, 0.025 + + + +add_loop_inductor_gpu, compile_time_instruction_count, 22822864522, 0.015 + + + +basic_modules_ListOfLinears_eager, compile_time_instruction_count, 1034818091, 0.015 + + + +basic_modules_ListOfLinears_inductor, compile_time_instruction_count, 19049541914, 0.015 + + + +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad, compile_time_instruction_count, 15806042948, 0.015 + + + +basic_modules_ListOfLinears_inductor_gpu, compile_time_instruction_count, 16403080126, 0.20 + + + + +update_hint_regression, compile_time_instruction_count, 1853008305, 0.02 + + + +sum_floordiv_regression, compile_time_instruction_count, 1154135694, 0.015 + + + +symint_sum, compile_time_instruction_count, 3270576815, 0.015 + + + +aotdispatcher_inference_nosubclass_cpu, compile_time_instruction_count, 1981730523, 0.015 + + + +aotdispatcher_inference_subclass_cpu, compile_time_instruction_count, 5711895807, 0.015 + + + +aotdispatcher_partitioner_cpu, compile_time_instruction_count, 8963708885 , 0.015 + + + +aotdispatcher_training_nosubclass_cpu, compile_time_instruction_count, 3795666651, 0.015 + + + +aotdispatcher_training_subclass_cpu, compile_time_instruction_count, 10175364418, 0.015 diff --git a/benchmarks/dynamo/pr_time_benchmarks/test_check_result/expected_test.csv b/benchmarks/dynamo/pr_time_benchmarks/test_check_result/expected_test.csv index 830751a8547c6..a3bcac705ea62 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/test_check_result/expected_test.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/test_check_result/expected_test.csv @@ -1,3 +1,3 @@ -a, instruction count, 110, 0.01 -b, memory, 100, 0.1 -c, something, 100, 0.1 +a, instruction count, 11011111111, 0.01 +b, memory, 10011111111, 0.1 +c, something, 10011111111, 0.1 diff --git a/benchmarks/dynamo/pr_time_benchmarks/test_check_result/result_test.csv b/benchmarks/dynamo/pr_time_benchmarks/test_check_result/result_test.csv index 07f6c814fbaef..f198fcd4e30d0 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/test_check_result/result_test.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/test_check_result/result_test.csv @@ -1,4 +1,4 @@ -a, instruction count, 90 -b, memory, 200 -c, something, 107 -d, missing-test, 10 +a, instruction count, 9011111111 +b, memory, 20011111111 +c, something, 107111111111 +d, missing-test, 10111111111 diff --git a/benchmarks/fastrnns/bench.py b/benchmarks/fastrnns/bench.py index f5325a848d92d..fc18c89fa95d9 100644 --- a/benchmarks/fastrnns/bench.py +++ b/benchmarks/fastrnns/bench.py @@ -214,8 +214,7 @@ def bench(rnn_runners, group_name, print_json=False, sep=" ", **params): k: {"avg": v.avg_fwd, "std": v.std_fwd, "info": v.info_fwd} for k, v in results.items() }, - group_name - + "-backward": { + f"{group_name}-backward": { k: {"avg": v.avg_bwd, "std": v.std_bwd, "info": v.info_bwd} for k, v in results.items() }, diff --git a/benchmarks/gpt_fast/mixtral_moe_quantize.py b/benchmarks/gpt_fast/mixtral_moe_quantize.py index c1840330c025a..2322451560901 100644 --- a/benchmarks/gpt_fast/mixtral_moe_quantize.py +++ b/benchmarks/gpt_fast/mixtral_moe_quantize.py @@ -184,7 +184,5 @@ def forward(self, x, expert_indices): ].to(x.dtype) expert_outs = torch.einsum( "tao, taio -> tai", (x1 * x3), w2_weights - ) * self.scales2[expert_indices].to( - x.dtype - ) # [T, A, D, D] + ) * self.scales2[expert_indices].to(x.dtype) # [T, A, D, D] return expert_outs diff --git a/benchmarks/instruction_counts/applications/ci.py b/benchmarks/instruction_counts/applications/ci.py index 9d50ad0fee061..e5d53ec57d339 100644 --- a/benchmarks/instruction_counts/applications/ci.py +++ b/benchmarks/instruction_counts/applications/ci.py @@ -1,5 +1,7 @@ """Collect instruction counts for continuous integration.""" + # mypy: ignore-errors + import argparse import hashlib import json diff --git a/benchmarks/instruction_counts/core/api.py b/benchmarks/instruction_counts/core/api.py index 640ef3f19270a..55e052d4063d8 100644 --- a/benchmarks/instruction_counts/core/api.py +++ b/benchmarks/instruction_counts/core/api.py @@ -1,5 +1,7 @@ """Key enums and structs used to handle data flow within the benchmark.""" + # mypy: ignore-errors + import dataclasses import enum import itertools as it diff --git a/benchmarks/instruction_counts/core/expand.py b/benchmarks/instruction_counts/core/expand.py index 01b22533dbc6e..6ceb2322fb9de 100644 --- a/benchmarks/instruction_counts/core/expand.py +++ b/benchmarks/instruction_counts/core/expand.py @@ -2,7 +2,9 @@ This is mostly string manipulation, with just a bit of importlib magic. """ + # mypy: ignore-errors + import importlib.abc import importlib.util import itertools as it diff --git a/benchmarks/instruction_counts/core/types.py b/benchmarks/instruction_counts/core/types.py index 06c6c2e87d893..52f722176d020 100644 --- a/benchmarks/instruction_counts/core/types.py +++ b/benchmarks/instruction_counts/core/types.py @@ -1,5 +1,7 @@ """Type annotations for various benchmark objects.""" + # mypy: ignore-errors + from typing import Any, Dict, Optional, Tuple, Union from core.api import AutoLabels, GroupedBenchmark, TimerArgs diff --git a/benchmarks/instruction_counts/definitions/setup.py b/benchmarks/instruction_counts/definitions/setup.py index fbc3798d9988f..4210eb49a71b9 100644 --- a/benchmarks/instruction_counts/definitions/setup.py +++ b/benchmarks/instruction_counts/definitions/setup.py @@ -1,5 +1,7 @@ """Define some common setup blocks which benchmarks can reuse.""" + # mypy: ignore-errors + import enum from core.api import GroupedSetup diff --git a/benchmarks/instruction_counts/execution/runner.py b/benchmarks/instruction_counts/execution/runner.py index 8d18ba02bc200..a86608059038c 100644 --- a/benchmarks/instruction_counts/execution/runner.py +++ b/benchmarks/instruction_counts/execution/runner.py @@ -1,5 +1,7 @@ """Run benchmarks while handling parallelism, isolation, and fault tolerance.""" + # mypy: ignore-errors + import math import multiprocessing import subprocess diff --git a/benchmarks/instruction_counts/execution/work.py b/benchmarks/instruction_counts/execution/work.py index b1b77282c4521..c44cb6489fffd 100644 --- a/benchmarks/instruction_counts/execution/work.py +++ b/benchmarks/instruction_counts/execution/work.py @@ -1,5 +1,7 @@ """Handle the details of subprocess calls and retries for a given benchmark run.""" + # mypy: ignore-errors + import dataclasses import json import os diff --git a/benchmarks/instruction_counts/main.py b/benchmarks/instruction_counts/main.py index 2f8e40b9dcb2e..43f712e99a722 100644 --- a/benchmarks/instruction_counts/main.py +++ b/benchmarks/instruction_counts/main.py @@ -5,7 +5,9 @@ components) in future iterations. However this allows us to excercise the underlying benchmark generation infrastructure in the mean time. """ + # mypy: ignore-errors + import argparse import sys from typing import List diff --git a/benchmarks/instruction_counts/worker/main.py b/benchmarks/instruction_counts/worker/main.py index 151cae993b133..b8c277eb6dcfb 100644 --- a/benchmarks/instruction_counts/worker/main.py +++ b/benchmarks/instruction_counts/worker/main.py @@ -15,6 +15,7 @@ Because this file only expects to run in a child context, error handling means plumbing failures up to the caller, not raising in this process. """ + import argparse import dataclasses import io diff --git a/benchmarks/operator_benchmark/pt/qrnn_test.py b/benchmarks/operator_benchmark/pt/qrnn_test.py index 6d140464e965a..5c0ef809acb7e 100644 --- a/benchmarks/operator_benchmark/pt/qrnn_test.py +++ b/benchmarks/operator_benchmark/pt/qrnn_test.py @@ -48,14 +48,20 @@ def init(self, I, H, NL, B, D, dtype): )[0] x = torch.randn( - sequence_len, batch_size, I # sequence length # batch size - ) # Number of features in X + sequence_len, # sequence length + batch_size, # batch size + I, # Number of features in X + ) h = torch.randn( - NL * (D + 1), batch_size, H # layer_num * dir_num # batch size - ) # hidden size + NL * (D + 1), # layer_num * dir_num + batch_size, # batch size + H, # hidden size + ) c = torch.randn( - NL * (D + 1), batch_size, H # layer_num * dir_num # batch size - ) # hidden size + NL * (D + 1), # layer_num * dir_num + batch_size, # batch size + H, # hidden size + ) self.inputs = {"x": x, "h": h, "c": c} self.set_module_name("QLSTM") diff --git a/benchmarks/transformer/better_transformer_vs_mha_functional.py b/benchmarks/transformer/better_transformer_vs_mha_functional.py index 3aa2e6c214c0f..f7a80169521b7 100644 --- a/benchmarks/transformer/better_transformer_vs_mha_functional.py +++ b/benchmarks/transformer/better_transformer_vs_mha_functional.py @@ -152,8 +152,8 @@ def run( result_entry["sequence_length"] = sequence_length result_entry["n_heads"] = num_heads result_entry["embed_dim"] = embed_dim - result_entry["time_native_mha_slow(\u00B5s)"] = f"{time_native_mha_slow:.3f}" - result_entry["time_native_mha_fast (\u00B5s)"] = f"{time_native_mha_fast:.3f}" + result_entry["time_native_mha_slow(\u00b5s)"] = f"{time_native_mha_slow:.3f}" + result_entry["time_native_mha_fast (\u00b5s)"] = f"{time_native_mha_fast:.3f}" result_entry["speedup flash_mha v native_mha"] = f"{speedup_fast_internal:.3f}" result_entry["padding"] = f"{padding:.3f}" return result_entry diff --git a/benchmarks/transformer/sdp.py b/benchmarks/transformer/sdp.py index 3edda07b309e6..ca15d1a95067c 100644 --- a/benchmarks/transformer/sdp.py +++ b/benchmarks/transformer/sdp.py @@ -82,10 +82,10 @@ def get_entries(self) -> List: @classmethod def get_entry_names(cls) -> List[str]: return [ - "nn_mha_time (\u00B5s)", - "compiled_nn_mha_time (\u00B5s)", - "composite_mha_time (\u00B5s)", - "compiled_composite_mha_time (\u00B5s)", + "nn_mha_time (\u00b5s)", + "compiled_nn_mha_time (\u00b5s)", + "composite_mha_time (\u00b5s)", + "compiled_composite_mha_time (\u00b5s)", ] diff --git a/buckbuild.bzl b/buckbuild.bzl index 4954e10d561ef..1c8b8e39a3a81 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -5,13 +5,13 @@ load("@bazel_skylib//lib:paths.bzl", "paths") load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native") load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") +load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build") load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build") -load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags") load( ":build_variables.bzl", "aten_cpu_source_list", @@ -213,7 +213,6 @@ _PT_COMPILER_FLAGS = [ ATEN_COMPILER_FLAGS = [ "-fexceptions", "-frtti", - "-fPIC", "-Os", "-Wno-absolute-value", "-Wno-deprecated-declarations", @@ -225,10 +224,17 @@ ATEN_COMPILER_FLAGS = [ "-Wno-unused-variable", "-Wno-pass-failed", "-Wno-shadow", -] +] + select({ + # Not supported by clang on Windows + "DEFAULT": ["-fPIC"], + "ovr_config//compiler:clang-windows": [], +}) def get_aten_compiler_flags(): - return ATEN_COMPILER_FLAGS + return select({ + "DEFAULT": ATEN_COMPILER_FLAGS, + "ovr_config//compiler:cl": windows_convert_gcc_clang_flags(ATEN_COMPILER_FLAGS), + }) _COMMON_PREPROCESSOR_FLAGS = [ "-DC10_MOBILE", diff --git a/c10/benchmark/CMakeLists.txt b/c10/benchmark/CMakeLists.txt index 16b268e3800a0..8dee635d7e1d7 100644 --- a/c10/benchmark/CMakeLists.txt +++ b/c10/benchmark/CMakeLists.txt @@ -8,6 +8,7 @@ if(BUILD_TEST) add_executable(${bench_name} "${bench_src}") target_link_libraries(${bench_name} ${C10_LIB} benchmark) if(INSTALL_TEST) + set_target_properties(${bench_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${bench_name} DESTINATION test) endif() endforeach() diff --git a/c10/core/Allocator.cpp b/c10/core/Allocator.cpp index 491c85b081e88..ca02350ce7bfd 100644 --- a/c10/core/Allocator.cpp +++ b/c10/core/Allocator.cpp @@ -87,8 +87,6 @@ void reportOutOfMemoryToProfiler( } } -MemoryReportingInfoBase::MemoryReportingInfoBase() = default; - void MemoryReportingInfoBase::reportOutOfMemory( int64_t /*alloc_size*/, size_t /*total_allocated*/, diff --git a/c10/core/Allocator.h b/c10/core/Allocator.h index 412412557a0d1..bdb8c719fbc53 100644 --- a/c10/core/Allocator.h +++ b/c10/core/Allocator.h @@ -103,7 +103,7 @@ class C10_API DataPtr { * be; be sure to read the source code of the Allocator * in question to confirm this. */ - C10_NODISCARD bool compare_exchange_deleter( + [[nodiscard]] bool compare_exchange_deleter( DeleterFnPtr expected_deleter, DeleterFnPtr new_deleter) { return ptr_.compare_exchange_deleter(expected_deleter, new_deleter); @@ -157,6 +157,7 @@ inline bool operator!=(std::nullptr_t, const DataPtr& dp) noexcept { // possible, or the raw interface will incorrectly reported as unsupported, // when it is actually possible. +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) struct C10_API Allocator { virtual ~Allocator() = default; @@ -223,10 +224,24 @@ struct C10_API Allocator { // allocation InefficientStdFunctionContext, on top of the dynamic // allocation which is implied by std::function itself. struct C10_API InefficientStdFunctionContext { - void* ptr_; + void* ptr_{nullptr}; std::function deleter_; InefficientStdFunctionContext(void* ptr, std::function deleter) : ptr_(ptr), deleter_(std::move(deleter)) {} + InefficientStdFunctionContext(const InefficientStdFunctionContext&) = delete; + InefficientStdFunctionContext(InefficientStdFunctionContext&& rhs) noexcept + : ptr_(std::exchange(rhs.ptr_, nullptr)), + deleter_(std::move(rhs.deleter_)) {} + InefficientStdFunctionContext& operator=( + const InefficientStdFunctionContext&) = delete; + // NOLINTNEXTLINE(performance-noexcept-move-constructor) + InefficientStdFunctionContext& operator=( + InefficientStdFunctionContext&& rhs) { + this->~InefficientStdFunctionContext(); + ptr_ = std::exchange(rhs.ptr_, nullptr); + deleter_ = std::move(rhs.deleter_); + return *this; + } ~InefficientStdFunctionContext() { if (deleter_) { deleter_(ptr_); @@ -270,9 +285,6 @@ struct AllocatorRegisterer { // An interface for reporting thread local memory usage // per device struct C10_API MemoryReportingInfoBase : public c10::DebugInfoBase { - MemoryReportingInfoBase(); - ~MemoryReportingInfoBase() override = default; - /** * alloc_size corresponds to the size of the ptr. * @@ -312,6 +324,7 @@ C10_API void reportOutOfMemoryToProfiler( Device device); // used to hold traceback information in allocators +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) struct GatheredContext { virtual ~GatheredContext() = default; }; diff --git a/c10/core/CPUAllocator.cpp b/c10/core/CPUAllocator.cpp index 144e1b27b6de6..cac00cd7b27d9 100644 --- a/c10/core/CPUAllocator.cpp +++ b/c10/core/CPUAllocator.cpp @@ -75,9 +75,6 @@ ProfiledCPUMemoryReporter& profiledCPUMemoryReporter() { template class DefaultMobileCPUAllocator final : public at::Allocator { public: - DefaultMobileCPUAllocator() = default; - ~DefaultMobileCPUAllocator() override = default; - static void deleter(void* const pointer) { if (C10_UNLIKELY(!pointer)) { return; @@ -114,8 +111,7 @@ class DefaultMobileCPUAllocator final : public at::Allocator { } auto alloc_size = PreGuardBytes + nbytes + PostGuardBytes; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - void* data; + void* data = nullptr; auto allocator_ptr = GetThreadLocalCachingAllocator(); auto profiling_allocator_ptr = GetThreadLocalProfilingAllocator(); if (allocator_ptr != nullptr) { diff --git a/c10/core/Device.cpp b/c10/core/Device.cpp index 1b19114663c1f..5b0e5e5601290 100644 --- a/c10/core/Device.cpp +++ b/c10/core/Device.cpp @@ -36,6 +36,11 @@ DeviceType parse_type(const std::string& device_string) { {"mtia", DeviceType::MTIA}, {"privateuseone", DeviceType::PrivateUse1}, }}; + if (device_string == "mkldnn") { + TORCH_WARN_ONCE( + "'mkldnn' is no longer used as device type. So torch.device('mkldnn') will be " + "deprecated and removed in the future. Please use other valid device types instead."); + } auto device = std::find_if( types.begin(), types.end(), diff --git a/c10/core/DeviceGuard.h b/c10/core/DeviceGuard.h index 94b89bc31b729..7fa3660494804 100644 --- a/c10/core/DeviceGuard.h +++ b/c10/core/DeviceGuard.h @@ -34,6 +34,8 @@ class DeviceGuard { const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {} + ~DeviceGuard() = default; + /// Copy is disallowed DeviceGuard(const DeviceGuard&) = delete; DeviceGuard& operator=(const DeviceGuard&) = delete; @@ -143,6 +145,7 @@ class OptionalDeviceGuard { const impl::DeviceGuardImplInterface* impl) : guard_(device, impl) {} + ~OptionalDeviceGuard() = default; /// Copy is disallowed OptionalDeviceGuard(const OptionalDeviceGuard&) = delete; OptionalDeviceGuard& operator=(const OptionalDeviceGuard&) = delete; diff --git a/c10/core/DispatchKeySet.h b/c10/core/DispatchKeySet.h index ca54e1966c5e6..289a88312c916 100644 --- a/c10/core/DispatchKeySet.h +++ b/c10/core/DispatchKeySet.h @@ -349,10 +349,10 @@ class DispatchKeySet final { } // Add a DispatchKey to the DispatchKey set. Does NOT mutate, // returns the extended DispatchKeySet! - C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const { + [[nodiscard]] constexpr DispatchKeySet add(DispatchKey t) const { return *this | DispatchKeySet(t); } - C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const { + [[nodiscard]] constexpr DispatchKeySet add(DispatchKeySet ks) const { return *this | ks; } @@ -380,7 +380,7 @@ class DispatchKeySet final { // // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd" // bit from the bitset. - C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const { + [[nodiscard]] constexpr DispatchKeySet remove(DispatchKey t) const { return DispatchKeySet( repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask)); } diff --git a/c10/core/GeneratorImpl.cpp b/c10/core/GeneratorImpl.cpp index d7bb389d70453..8025ab966e720 100644 --- a/c10/core/GeneratorImpl.cpp +++ b/c10/core/GeneratorImpl.cpp @@ -88,8 +88,7 @@ static uint64_t readURandomLong() { * a 32 bit number to 64 bit. */ uint64_t getNonDeterministicRandom(bool is_cuda) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint64_t s; + uint64_t s = 0; if (!is_cuda) { #ifdef _WIN32 s = (uint64_t)std::chrono::high_resolution_clock::now() diff --git a/c10/core/GeneratorImpl.h b/c10/core/GeneratorImpl.h index 6757b6de6f65c..3b0b78ef46010 100644 --- a/c10/core/GeneratorImpl.h +++ b/c10/core/GeneratorImpl.h @@ -61,6 +61,7 @@ struct C10_API GeneratorImpl : public c10::intrusive_ptr_target { GeneratorImpl(const GeneratorImpl& other) = delete; GeneratorImpl(GeneratorImpl&& other) = delete; GeneratorImpl& operator=(const GeneratorImpl& other) = delete; + GeneratorImpl& operator=(GeneratorImpl&& other) = delete; ~GeneratorImpl() override = default; c10::intrusive_ptr clone() const; diff --git a/c10/core/GradMode.h b/c10/core/GradMode.h index d60add2cd2b06..a8f6329cf83bd 100644 --- a/c10/core/GradMode.h +++ b/c10/core/GradMode.h @@ -16,6 +16,10 @@ struct C10_API AutoGradMode { AutoGradMode(bool enabled) : prev_mode(GradMode::is_enabled()) { GradMode::set_enabled(enabled); } + AutoGradMode(const AutoGradMode&) = delete; + AutoGradMode(AutoGradMode&&) = delete; + AutoGradMode& operator=(const AutoGradMode&) = delete; + AutoGradMode& operator=(AutoGradMode&&) = delete; ~AutoGradMode() { GradMode::set_enabled(prev_mode); } @@ -35,6 +39,10 @@ struct C10_API AutoFwGradMode { : prev_mode(AutogradState::get_tls_state().get_fw_grad_mode()) { AutogradState::get_tls_state().set_fw_grad_mode(enabled); } + AutoFwGradMode(const AutoFwGradMode&) = delete; + AutoFwGradMode(AutoFwGradMode&&) = delete; + AutoFwGradMode& operator=(const AutoFwGradMode&) = delete; + AutoFwGradMode& operator=(AutoFwGradMode&&) = delete; ~AutoFwGradMode() { AutogradState::get_tls_state().set_fw_grad_mode(prev_mode); } diff --git a/c10/core/InferenceMode.h b/c10/core/InferenceMode.h index 52541886c0aea..a9cf2f0bf32e0 100644 --- a/c10/core/InferenceMode.h +++ b/c10/core/InferenceMode.h @@ -73,6 +73,11 @@ struct C10_API InferenceMode { c10::impl::_force_tls_local_dispatch_key_set(cur_keyset); } + InferenceMode(const InferenceMode&) = delete; + InferenceMode(InferenceMode&&) = delete; + InferenceMode& operator=(const InferenceMode&) = delete; + InferenceMode& operator=(InferenceMode&&) = delete; + ~InferenceMode() { AutogradState::set_tls_state(prev_mode); c10::impl::_force_tls_local_dispatch_key_set(prev_keyset); diff --git a/c10/core/SafePyObject.h b/c10/core/SafePyObject.h index bd6022e8c14da..6102aed8c0ba9 100644 --- a/c10/core/SafePyObject.h +++ b/c10/core/SafePyObject.h @@ -81,9 +81,11 @@ template struct SafePyObjectT : private SafePyObject { SafePyObjectT(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) : SafePyObject(data, pyinterpreter) {} + ~SafePyObjectT() = default; SafePyObjectT(SafePyObjectT&& other) noexcept : SafePyObject(other) {} SafePyObjectT(SafePyObjectT const&) = delete; SafePyObjectT& operator=(SafePyObjectT const&) = delete; + SafePyObjectT& operator=(SafePyObjectT&&) = delete; using SafePyObject::ptr; using SafePyObject::pyinterpreter; diff --git a/c10/core/ScalarType.cpp b/c10/core/ScalarType.cpp index 05f709d648279..e3fe4b07532ad 100644 --- a/c10/core/ScalarType.cpp +++ b/c10/core/ScalarType.cpp @@ -154,6 +154,20 @@ std::pair getDtypeNames(c10::ScalarType scalarType) { return std::make_pair("uint32", ""); case c10::ScalarType::UInt64: return std::make_pair("uint64", ""); + case c10::ScalarType::Int1: + return std::make_pair("int1", ""); + case c10::ScalarType::Int2: + return std::make_pair("int2", ""); + case c10::ScalarType::Int3: + return std::make_pair("int3", ""); + case c10::ScalarType::Int4: + return std::make_pair("int4", ""); + case c10::ScalarType::Int5: + return std::make_pair("int5", ""); + case c10::ScalarType::Int6: + return std::make_pair("int6", ""); + case c10::ScalarType::Int7: + return std::make_pair("int7", ""); case c10::ScalarType::Char: // no "char" because it is not consistently signed or unsigned; we want // to move to int8 diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h index 0d602e2cfec0a..fa0ef9be84129 100644 --- a/c10/core/ScalarType.h +++ b/c10/core/ScalarType.h @@ -31,6 +31,11 @@ namespace c10 { template struct dummy_uint1_7_t {}; +// dummy struct for int1 to int7, actual functionality +// of these dtypes will be implemented in python with Tensor subclass +template +struct dummy_int1_7_t {}; + // For the macros below: // // For users: If you want to macro some code for all non-QInt scalar types @@ -90,7 +95,14 @@ struct dummy_uint1_7_t {}; _(c10::dummy_uint1_7_t<4>, UInt4) /* 33 */ \ _(c10::dummy_uint1_7_t<5>, UInt5) /* 34 */ \ _(c10::dummy_uint1_7_t<6>, UInt6) /* 35 */ \ - _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ + _(c10::dummy_uint1_7_t<7>, UInt7) /* 36 */ \ + _(c10::dummy_int1_7_t<1>, Int1) /* 37 */ \ + _(c10::dummy_int1_7_t<2>, Int2) /* 38 */ \ + _(c10::dummy_int1_7_t<3>, Int3) /* 39 */ \ + _(c10::dummy_int1_7_t<4>, Int4) /* 40 */ \ + _(c10::dummy_int1_7_t<5>, Int5) /* 41 */ \ + _(c10::dummy_int1_7_t<6>, Int6) /* 42 */ \ + _(c10::dummy_int1_7_t<7>, Int7) /* 43 */ // If you want to support ComplexHalf for real, add ComplexHalf // into this macro (and change the name). But beware: convert() @@ -467,6 +479,14 @@ inline bool isSignedType(ScalarType t) { CASE_ISSIGNED(ComplexFloat); CASE_ISSIGNED(ComplexDouble); CASE_ISSIGNED(Bool); + case ScalarType::Int1: + case ScalarType::Int2: + case ScalarType::Int3: + case ScalarType::Int4: + case ScalarType::Int5: + case ScalarType::Int6: + case ScalarType::Int7: + return true; case ScalarType::UInt1: case ScalarType::UInt2: case ScalarType::UInt3: @@ -474,7 +494,7 @@ inline bool isSignedType(ScalarType t) { case ScalarType::UInt5: case ScalarType::UInt6: case ScalarType::UInt7: - return true; + return false; case ScalarType::Undefined: case ScalarType::NumOptions: break; diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp index df43a796acce0..2e8d51cbade1b 100644 --- a/c10/core/StorageImpl.cpp +++ b/c10/core/StorageImpl.cpp @@ -40,6 +40,14 @@ void warnDeprecatedDataPtr() { "isinstance(tensor, FakeTensor).") } +[[noreturn]] void StorageImpl::throw_data_ptr_access_error() const { + if (extra_meta_ && extra_meta_->custom_data_ptr_error_msg_) { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + TORCH_CHECK(false, *extra_meta_->custom_data_ptr_error_msg_); + } + TORCH_CHECK(false, "Cannot access data pointer of Storage that is invalid."); +} + void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) { // Allowlist verification. // Only if the devicetype is in the allowlist, diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h index abe6218fbc941..257ddea005125 100644 --- a/c10/core/StorageImpl.h +++ b/c10/core/StorageImpl.h @@ -16,9 +16,22 @@ namespace c10 { -C10_API void throwNullDataPtrError(); +[[noreturn]] C10_API void throwNullDataPtrError(); C10_API void warnDeprecatedDataPtr(); +// Used in StorageImpl to store extra metadata. +// Currently used only for storing a custom error message +// used when throwing an exception when data_ptr is accessed. +struct C10_API StorageExtraMeta { + std::optional custom_data_ptr_error_msg_ = std::nullopt; + StorageExtraMeta() = default; + StorageExtraMeta(const StorageExtraMeta& other) { + if (other.custom_data_ptr_error_msg_) { + custom_data_ptr_error_msg_ = other.custom_data_ptr_error_msg_; + } + } +}; + // A storage represents the underlying backing data buffer for a // tensor. This concept was inherited from the original Torch7 // codebase; we'd kind of like to get rid of the concept @@ -123,11 +136,17 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { } const at::DataPtr& data_ptr() const { + if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) { + throw_data_ptr_access_error(); + } return data_ptr_; } at::DataPtr& mutable_data_ptr() { - if (C10_UNLIKELY(has_data_ptr_check_)) { + if (C10_UNLIKELY(has_mutable_data_ptr_check_)) { + if (throw_on_immutable_data_ptr_) { + throw_data_ptr_access_error(); + } if (throw_on_mutable_data_ptr_) { throwNullDataPtrError(); } @@ -158,11 +177,17 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { } const void* data() const { + if (C10_UNLIKELY(throw_on_immutable_data_ptr_)) { + throw_data_ptr_access_error(); + } return data_ptr_.get(); } void* mutable_data() { - if (C10_UNLIKELY(has_data_ptr_check_)) { + if (C10_UNLIKELY(has_mutable_data_ptr_check_)) { + if (throw_on_immutable_data_ptr_) { + throw_data_ptr_access_error(); + } if (throw_on_mutable_data_ptr_) { throwNullDataPtrError(); } @@ -248,6 +273,22 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { return &pyobj_slot_; } + StorageExtraMeta& get_extra_meta() { + if (!extra_meta_) { + extra_meta_ = std::make_unique(); + } + return *extra_meta_; + } + + [[noreturn]] void throw_data_ptr_access_error() const; + + void release_data_and_set_meta_custom_data_ptr_error_msg_( + std::optional s) { + throw_on_immutable_data_ptr_ = true; + get_extra_meta().custom_data_ptr_error_msg_ = std::move(s); + refresh_has_data_ptr_check(); + } + void set_throw_on_mutable_data_ptr() { throw_on_mutable_data_ptr_ = true; refresh_has_data_ptr_check(); @@ -273,8 +314,8 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { private: void refresh_has_data_ptr_check() { - has_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ || - warn_deprecated_on_mutable_data_ptr_; + has_mutable_data_ptr_check_ = is_cow() || throw_on_mutable_data_ptr_ || + warn_deprecated_on_mutable_data_ptr_ || throw_on_immutable_data_ptr_; } inline bool is_cow() const { @@ -298,13 +339,16 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target { // All special checks in data/data_ptr calls are guarded behind this single // boolean. This is for performance: .data/.data_ptr calls are commonly in the // hot-path. - bool has_data_ptr_check_ = false; + bool has_mutable_data_ptr_check_ = false; // If we should throw when mutable_data_ptr() or mutable_data() is called. bool throw_on_mutable_data_ptr_ = false; + // If we should throw when data_ptr() or data() is called. + bool throw_on_immutable_data_ptr_ = false; // If we warn when mutable_data_ptr() or mutable_data() is called. bool warn_deprecated_on_mutable_data_ptr_ = false; Allocator* allocator_; impl::PyObjectSlot pyobj_slot_; + std::unique_ptr extra_meta_ = nullptr; }; // Declare StorageImpl create function pointer types. diff --git a/c10/core/StreamGuard.h b/c10/core/StreamGuard.h index db6dbd88cbd9c..d3057823a5cd1 100644 --- a/c10/core/StreamGuard.h +++ b/c10/core/StreamGuard.h @@ -27,6 +27,7 @@ namespace c10 { struct StreamGuard { /// No default constructor, see Note [Omitted default constructor from RAII] explicit StreamGuard() = delete; + ~StreamGuard() = default; /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream. @@ -111,6 +112,7 @@ struct OptionalStreamGuard { // See Note [Move assignment for RAII guards is tricky] OptionalStreamGuard& operator=(OptionalStreamGuard&& other) = delete; + ~OptionalStreamGuard() = default; /// Resets the currently set stream to the original stream and /// the currently set device to the original device. Then, @@ -162,6 +164,7 @@ struct MultiStreamGuard { // See Note [Move assignment for RAII guards is tricky] MultiStreamGuard& operator=(MultiStreamGuard&& other) = delete; + ~MultiStreamGuard() = default; private: c10::impl::InlineMultiStreamGuard guard_; diff --git a/c10/core/SymbolicShapeMeta.cpp b/c10/core/SymbolicShapeMeta.cpp index b59a95a4a2faf..4f272e177be4b 100644 --- a/c10/core/SymbolicShapeMeta.cpp +++ b/c10/core/SymbolicShapeMeta.cpp @@ -186,7 +186,6 @@ SymBool SymbolicShapeMeta::compute_is_non_overlapping_and_dense_anydim() const { return is_contiguous() | compute_non_overlapping_and_dense(); } -// NOLINTNEXTLINE(performance-unnecessary-value-param) void SymbolicShapeMeta::set_numel(SymInt val) const { std::scoped_lock lock(mutables_); if (has_numel()) { diff --git a/c10/core/SymbolicShapeMeta.h b/c10/core/SymbolicShapeMeta.h index 935f6481d02fc..ce0769a8074f7 100644 --- a/c10/core/SymbolicShapeMeta.h +++ b/c10/core/SymbolicShapeMeta.h @@ -22,7 +22,9 @@ class C10_API SymbolicShapeMeta { bool strides_valid_ = true; // e.g. for sparse where there are no strides SymbolicShapeMeta() = default; + ~SymbolicShapeMeta() = default; SymbolicShapeMeta(const SymbolicShapeMeta& other); + SymbolicShapeMeta(SymbolicShapeMeta&& other) = delete; SymbolicShapeMeta& operator=(const SymbolicShapeMeta& other) = delete; SymbolicShapeMeta& operator=(SymbolicShapeMeta&& other) = delete; diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 40bf133d2587e..f268dbe178594 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -81,11 +81,7 @@ TensorImpl::TensorImpl( DispatchKeySet key_set, const caffe2::TypeMeta data_type) // Use std::forward to suppress static analyzer false positive. - : TensorImpl( - std::forward(storage), - key_set, - data_type, - storage.device()) {} + : TensorImpl(std::move(storage), key_set, data_type, storage.device()) {} // [Note: Python key removal] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -111,7 +107,6 @@ TensorImpl::TensorImpl( DispatchKeySet key_set, const caffe2::TypeMeta data_type) : storage_(std::move(storage)), - numel_(0), data_type_(data_type), device_opt_(storage_.device()), @@ -123,7 +118,6 @@ TensorImpl::TensorImpl( } } -// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorImpl::TensorImpl( DispatchKeySet key_set, const caffe2::TypeMeta data_type, @@ -137,7 +131,6 @@ TensorImpl::TensorImpl( const caffe2::TypeMeta data_type, std::optional device_opt) : storage_(std::move(storage)), - numel_(0), data_type_(data_type), device_opt_(device_opt) { diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index a8d05dddcfa26..888881ac2d74d 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -133,6 +133,7 @@ struct C10_API PlacementDeleteContext { DataPtr data_ptr_; PlacementDtor placement_dtor_; size_t size_; + PlacementDeleteContext( DataPtr&& data_ptr, PlacementDtor placement_dtor, @@ -140,6 +141,11 @@ struct C10_API PlacementDeleteContext { : data_ptr_(std::move(data_ptr)), placement_dtor_(placement_dtor), size_(size) {} + + PlacementDeleteContext(PlacementDeleteContext&&) noexcept = delete; + PlacementDeleteContext(const PlacementDeleteContext&) = delete; + PlacementDeleteContext& operator=(const PlacementDeleteContext&) = delete; + PlacementDeleteContext& operator=(PlacementDeleteContext&&) = delete; static DataPtr makeDataPtr( DataPtr&& data_ptr, PlacementDtor placement_dtor, @@ -237,6 +243,7 @@ struct C10_API ExtraMeta { std::optional custom_storage_error_msg_ = std::nullopt; ExtraMeta() = default; + ~ExtraMeta() = default; ExtraMeta(const ExtraMeta& other) { if (other.symbolic_shape_meta_) { symbolic_shape_meta_ = diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index f98a93302e14e..d5412ecbad878 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -29,12 +29,12 @@ DispatchKey computeDispatchKey( std::optional device); inline ScalarType dtype_or_default(std::optional dtype) { - return value_or_else(dtype, [] { return get_default_dtype_as_scalartype(); }); + return dtype.value_or(get_default_dtype_as_scalartype()); } inline caffe2::TypeMeta dtype_or_default( std::optional dtype) { - return value_or_else(dtype, [] { return get_default_dtype(); }); + return dtype.value_or(get_default_dtype()); } inline Layout layout_or_default(std::optional layout) { @@ -42,7 +42,7 @@ inline Layout layout_or_default(std::optional layout) { } inline Device device_or_default(std::optional device) { - return value_or_else(device, [] { return Device(kCPU); }); + return device.value_or(Device(kCPU)); } inline bool pinned_memory_or_default(std::optional pinned_memory) { @@ -192,8 +192,8 @@ struct C10_API TensorOptions { /// Return a copy of `TensorOptions` with `device` set to the given one, or /// cleared if `device` is `nullopt`. - C10_NODISCARD TensorOptions - device(std::optional device) const noexcept { + [[nodiscard]] TensorOptions device( + std::optional device) const noexcept { TensorOptions r = *this; r.set_device(device); return r; @@ -203,7 +203,7 @@ struct C10_API TensorOptions { /// (This overload ensures that variadic template std::optional constructor /// for Device work correctly.) template - C10_NODISCARD TensorOptions device(Args&&... args) const noexcept { + [[nodiscard]] TensorOptions device(Args&&... args) const noexcept { return device( std::optional(std::in_place, std::forward(args)...)); } @@ -213,22 +213,22 @@ struct C10_API TensorOptions { /// /// TODO: This function encourages bad behavior (assuming CUDA is /// the only device that matters). Get rid of it / rename it. - C10_NODISCARD TensorOptions - device_index(c10::DeviceIndex device_index) const noexcept { + [[nodiscard]] TensorOptions device_index( + c10::DeviceIndex device_index) const noexcept { return device(Device::Type::CUDA, device_index); } /// Return a copy of `TensorOptions` with `dtype` set to the given one. - C10_NODISCARD TensorOptions - dtype(std::optional dtype) const noexcept { + [[nodiscard]] TensorOptions dtype( + std::optional dtype) const noexcept { TensorOptions r = *this; r.set_dtype(dtype); return r; } // legacy function to support ScalarType - C10_NODISCARD TensorOptions - dtype(std::optional dtype) const noexcept { + [[nodiscard]] TensorOptions dtype( + std::optional dtype) const noexcept { TensorOptions r = *this; r.set_dtype(dtype); return r; @@ -243,32 +243,32 @@ struct C10_API TensorOptions { } /// Sets the layout of the `TensorOptions`. - C10_NODISCARD TensorOptions - layout(std::optional layout) const noexcept { + [[nodiscard]] TensorOptions layout( + std::optional layout) const noexcept { TensorOptions r = *this; r.set_layout(layout); return r; } /// Sets the `requires_grad` property of the `TensorOptions`. - C10_NODISCARD TensorOptions - requires_grad(std::optional requires_grad) const noexcept { + [[nodiscard]] TensorOptions requires_grad( + std::optional requires_grad) const noexcept { TensorOptions r = *this; r.set_requires_grad(requires_grad); return r; } /// Sets the `pinned_memory` property on the `TensorOptions`. - C10_NODISCARD TensorOptions - pinned_memory(std::optional pinned_memory) const noexcept { + [[nodiscard]] TensorOptions pinned_memory( + std::optional pinned_memory) const noexcept { TensorOptions r = *this; r.set_pinned_memory(pinned_memory); return r; } /// Sets the `memory_format` property on `TensorOptions`. - C10_NODISCARD TensorOptions - memory_format(std::optional memory_format) const noexcept { + [[nodiscard]] TensorOptions memory_format( + std::optional memory_format) const noexcept { TensorOptions r = *this; r.set_memory_format(memory_format); return r; diff --git a/c10/core/impl/InlineDeviceGuard.h b/c10/core/impl/InlineDeviceGuard.h index e0c6d4f1ca8f9..a80ac550906aa 100644 --- a/c10/core/impl/InlineDeviceGuard.h +++ b/c10/core/impl/InlineDeviceGuard.h @@ -62,7 +62,7 @@ class InlineDeviceGuard { // DeviceGuard which reads the current device and promises to // restore to that device on exit. However, most cases where you // would have written this, you probably meant to actually just - // use OptionalDeviceGuard (since you don't actually need the + // use DeviceGuard (since you don't actually need the // restore to happen if you don't ever actually set the device). // We remove the constructor here to encourage you to think about // what you actually want to happen. @@ -221,6 +221,7 @@ class InlineOptionalDeviceGuard { explicit InlineOptionalDeviceGuard() : guard_() // See Note [Explicit initialization of optional fields] {} + ~InlineOptionalDeviceGuard() = default; /// Set the current device to the passed Device, if it is not nullopt. explicit InlineOptionalDeviceGuard(std::optional device_opt) @@ -286,6 +287,7 @@ class InlineOptionalDeviceGuard { // It's in principle possible to raise an error when this occurs // by doing some extra thread-local bookkeeping. But why bother? // Just don't provide the constructor. + InlineOptionalDeviceGuard(const InlineOptionalDeviceGuard& other) = delete; InlineOptionalDeviceGuard(InlineOptionalDeviceGuard&& other) = delete; // Note [Move assignment for RAII guards is tricky] @@ -335,6 +337,8 @@ class InlineOptionalDeviceGuard { // // We could solve this with an extra thread-local variable. But no one is // actually using move-assignment. So just get rid of it. + InlineOptionalDeviceGuard& operator=(const InlineOptionalDeviceGuard& other) = + delete; InlineOptionalDeviceGuard& operator=(InlineOptionalDeviceGuard&& other) = delete; diff --git a/c10/core/impl/InlineStreamGuard.h b/c10/core/impl/InlineStreamGuard.h index 6d2b3c70678ee..51c25e25ffa6b 100644 --- a/c10/core/impl/InlineStreamGuard.h +++ b/c10/core/impl/InlineStreamGuard.h @@ -135,6 +135,7 @@ class InlineOptionalStreamGuard { explicit InlineOptionalStreamGuard() : guard_() // See Note [Explicit initialization of optional fields] {} + ~InlineOptionalStreamGuard() = default; /// Set the current device to the device associated with the passed stream, /// and set the current stream on that device to the passed stream, @@ -151,6 +152,9 @@ class InlineOptionalStreamGuard { explicit InlineOptionalStreamGuard(Args&&... args) : guard_(std::in_place, std::forward(args)...) {} + InlineOptionalStreamGuard(const InlineOptionalStreamGuard& other) = delete; + InlineOptionalStreamGuard& operator=(const InlineOptionalStreamGuard& other) = + delete; // See Note [Move construction for RAII guards is tricky] InlineOptionalStreamGuard(InlineOptionalStreamGuard&& other) = delete; diff --git a/c10/core/impl/LocalDispatchKeySet.h b/c10/core/impl/LocalDispatchKeySet.h index 176d0a6b64219..1232bd25eb3bd 100644 --- a/c10/core/impl/LocalDispatchKeySet.h +++ b/c10/core/impl/LocalDispatchKeySet.h @@ -132,6 +132,11 @@ struct C10_API ForceDispatchKeyGuard { updated_set.excluded_ = exclude; c10::impl::_force_tls_local_dispatch_key_set(updated_set); } + + ForceDispatchKeyGuard(ForceDispatchKeyGuard&&) noexcept = delete; + ForceDispatchKeyGuard(const ForceDispatchKeyGuard&) = delete; + ForceDispatchKeyGuard& operator=(const ForceDispatchKeyGuard&) = delete; + ForceDispatchKeyGuard& operator=(ForceDispatchKeyGuard&&) = delete; ~ForceDispatchKeyGuard() { c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_); } diff --git a/c10/core/impl/PythonDispatcherTLS.h b/c10/core/impl/PythonDispatcherTLS.h index 9016c3e11e157..12c0677f36fdb 100644 --- a/c10/core/impl/PythonDispatcherTLS.h +++ b/c10/core/impl/PythonDispatcherTLS.h @@ -15,6 +15,7 @@ struct C10_API DisablePythonDispatcher { DisablePythonDispatcher() : old_(PythonDispatcherTLS::get_state()) { PythonDispatcherTLS::set_state({}); } + ~DisablePythonDispatcher() { PythonDispatcherTLS::set_state(old_); } diff --git a/c10/core/impl/alloc_cpu.cpp b/c10/core/impl/alloc_cpu.cpp index 9b7ae22f9f841..f976e7b745e21 100644 --- a/c10/core/impl/alloc_cpu.cpp +++ b/c10/core/impl/alloc_cpu.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -53,8 +54,8 @@ void memset_junk(void* data, size_t num) { #if defined(__linux__) && !defined(__ANDROID__) static inline bool is_thp_alloc_enabled() { static bool value = [&] { - const char* ptr = std::getenv("THP_MEM_ALLOC_ENABLE"); - return ptr != nullptr ? std::atoi(ptr) : 0; + auto env = c10::utils::check_env("THP_MEM_ALLOC_ENABLE"); + return env.has_value() ? env.value() : 0; }(); return value; } @@ -71,11 +72,11 @@ inline bool is_thp_alloc(size_t nbytes) { return (is_thp_alloc_enabled() && (nbytes >= gAlloc_threshold_thp)); } #elif !defined(__ANDROID__) && !defined(_MSC_VER) -constexpr size_t c10_compute_alignment(C10_UNUSED size_t nbytes) { +constexpr size_t c10_compute_alignment([[maybe_unused]] size_t nbytes) { return gAlignment; } -constexpr bool is_thp_alloc(C10_UNUSED size_t nbytes) { +constexpr bool is_thp_alloc([[maybe_unused]] size_t nbytes) { return false; } #endif @@ -92,8 +93,7 @@ void* alloc_cpu(size_t nbytes) { "alloc_cpu() seems to have been called with negative number: ", nbytes); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - void* data; + void* data = nullptr; #ifdef __ANDROID__ data = memalign(gAlignment, nbytes); CAFFE_ENFORCE( @@ -163,4 +163,27 @@ void free_cpu(void* data) { #endif } +#ifdef USE_MIMALLOC_ON_MKL +namespace mi_malloc_wrapper { +void* c10_mi_malloc(size_t size) { + return mi_malloc(size); +} + +void* c10_mi_calloc(size_t count, size_t size) { + return mi_calloc(count, size); +} + +void* c10_mi_realloc(void* p, size_t newsize) { + return mi_realloc(p, newsize); +} + +void* c10_mi_malloc_aligned(size_t size, size_t alignment) { + return mi_malloc_aligned(size, alignment); +} + +void c10_mi_free(void* p) { + mi_free(p); +} +} // namespace mi_malloc_wrapper +#endif } // namespace c10 diff --git a/c10/core/impl/alloc_cpu.h b/c10/core/impl/alloc_cpu.h index ee32a0f463068..8d506acf392f4 100644 --- a/c10/core/impl/alloc_cpu.h +++ b/c10/core/impl/alloc_cpu.h @@ -9,4 +9,14 @@ namespace c10 { C10_API void* alloc_cpu(size_t nbytes); C10_API void free_cpu(void* data); +#ifdef USE_MIMALLOC_ON_MKL +namespace mi_malloc_wrapper { +C10_API void* c10_mi_malloc(size_t size); +C10_API void* c10_mi_calloc(size_t count, size_t size); +C10_API void* c10_mi_realloc(void* p, size_t newsize); +C10_API void* c10_mi_malloc_aligned(size_t size, size_t alignment); +C10_API void c10_mi_free(void* p); +} // namespace mi_malloc_wrapper +#endif + } // namespace c10 diff --git a/c10/core/thread_pool.cpp b/c10/core/thread_pool.cpp index dfe6cfaeb3343..cb997c1e59e79 100644 --- a/c10/core/thread_pool.cpp +++ b/c10/core/thread_pool.cpp @@ -62,6 +62,7 @@ ThreadPool::~ThreadPool() { for (auto& t : threads_) { try { t.join(); + // NOLINTNEXTLINE(bugprone-empty-catch) } catch (const std::exception&) { } } diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 0967693da78ed..4dc62366be238 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -1896,16 +1897,41 @@ class DeviceCachingAllocator { std::unordered_map pool_to_id; pool_to_id.reserve(graph_pools.size() + graph_pools_freeable.size()); - for (const auto& pair : graph_pools) { - pool_to_id[pair.second.get()] = pair.first; + std::vector all_blocks; + MempoolId_t mempool_id = {0, 0}; + + auto active_mempool = MemPoolContext::getActiveMemPool(); + if (active_mempool) { + mempool_id = active_mempool->id(); } - for (const auto& pair : graph_pools_freeable) { - pool_to_id[pair.second] = pair.first; + + if (mempool_id.first != 0 || mempool_id.second != 0) { + // If there is an active mempool, we find the corresponding PrivatePool + // in graph_pools and only return the blocks from it. + auto pool = graph_pools.find(mempool_id); + if (pool != graph_pools.end()) { + pool_to_id[pool->second.get()] = pool->first; + all_blocks = get_private_pool_head_blocks(pool->second.get()); + } + auto pool_freeable = graph_pools_freeable.find(mempool_id); + if (pool_freeable != graph_pools_freeable.end()) { + pool_to_id[pool_freeable->second] = pool_freeable->first; + } + } else { + // When snapshot is called outside a MemPoolContext, we return + // all the blocks in the CUDACachingAllocator (as returned by + // get_all_blocks). + for (const auto& pair : graph_pools) { + pool_to_id[pair.second.get()] = pair.first; + } + for (const auto& pair : graph_pools_freeable) { + pool_to_id[pair.second] = pair.first; + } + all_blocks = get_all_blocks(); } size_t total_active = 0; std::vector result; - const auto all_blocks = get_all_blocks(); for (const Block* const head_block : all_blocks) { // For expandable segments, we report one segment for each contiguous @@ -2015,6 +2041,13 @@ class DeviceCachingAllocator { } } + void ensureExistsAndIncrefPool(MempoolId_t mempool_id) { + // Create a PrivatePool object if it does not exist yet + // and increment its use_count + std::lock_guard lock(mutex); + ensure_exists_and_incref_pool(mempool_id); + } + // See Note [Interaction with CUDA graph capture] // Called by CUDAGraph::capture_begin @@ -2022,18 +2055,7 @@ class DeviceCachingAllocator { MempoolId_t mempool_id, std::function filter) { std::lock_guard lock(mutex); - auto it = graph_pools.find(mempool_id); - if (it == graph_pools.end()) { - // mempool_id does not reference an existing pool. Make a new pool for - // this capture. - graph_pools.emplace(mempool_id, std::make_unique()); - } else { - // mempool_id references an existing pool, which the current capture will - // share. Check this pool is live (at least one other capture already - // references it). - TORCH_INTERNAL_ASSERT(it->second->use_count > 0); - it->second->use_count++; - } + ensure_exists_and_incref_pool(mempool_id); for (auto it2 = captures_underway.begin(); it2 != captures_underway.end(); ++it2) { TORCH_CHECK( @@ -2057,7 +2079,7 @@ class DeviceCachingAllocator { false, "endAllocatePool: not currently recording to mempool_id"); } - // Called by CUDAGraph::reset + // Called by CUDAGraph::reset and MemPool::~MemPool() void releasePool(MempoolId_t mempool_id) { std::lock_guard lock(mutex); // The instantiated cudaGraphExec_t has been destroyed. We can't blindly @@ -2069,20 +2091,24 @@ class DeviceCachingAllocator { // mempool. When the count reaches 0, we tell free_cached_blocks it may now // cudaFree blocks from this graph's pool when it discovers they're unused // (unsplit). - auto it = graph_pools.find(mempool_id); - TORCH_INTERNAL_ASSERT(it != graph_pools.end()); - auto uc = --(it->second->use_count); + auto pp = get_private_pool(mempool_id); + auto uc = --(pp->use_count); TORCH_INTERNAL_ASSERT(uc >= 0); if (uc == 0) { // Allows free_cached_blocks to begin cudaFreeing this pool's memory, // and makes sure this pool wasn't somehow made freeable already. // NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores) - bool inserted = - graph_pools_freeable.insert({mempool_id, it->second.get()}).second; + bool inserted = graph_pools_freeable.insert({mempool_id, pp}).second; TORCH_INTERNAL_ASSERT(inserted); } } + int getPoolUseCount(MempoolId_t mempool_id) { + std::lock_guard lock(mutex); + auto pp = get_private_pool(mempool_id); + return pp->use_count; + } + void addPeerAccess(c10::DeviceIndex dev_to_access) { std::lock_guard lock(mutex); if (std::find( @@ -2108,8 +2134,8 @@ class DeviceCachingAllocator { private: // All private methods do not acquire the allocator mutex. - std::vector get_all_blocks() const { - std::vector blocks; + std::vector get_all_blocks() const { + std::vector blocks; blocks.insert( blocks.end(), small_blocks.blocks.begin(), small_blocks.blocks.end()); blocks.insert( @@ -2151,6 +2177,30 @@ class DeviceCachingAllocator { return blocks; } + void ensure_exists_and_incref_pool(MempoolId_t mempool_id) { + auto it = graph_pools.find(mempool_id); + if (it == graph_pools.end()) { + // mempool_id does not reference an existing pool. + // Make a new pool for CUDAGraph capture or torch.cuda.use_mem_pool + // usage. use_count is initially 1, which means the pool is + // being used since somebody called ensureExistsAndIncrefPool. + graph_pools.emplace(mempool_id, std::make_unique()); + } else { + // mempool_id references an existing pool, which the current CUDAGraph + // capture or torch.cuda.use_mem_pool will + // share. Check this pool is live (at least one other capture already + // references it). Increment it to establish the usage. + TORCH_INTERNAL_ASSERT(it->second->use_count > 0); + it->second->use_count++; + } + } + + PrivatePool* get_private_pool(MempoolId_t mempool_id) { + auto it = graph_pools.find(mempool_id); + TORCH_INTERNAL_ASSERT(it != graph_pools.end()); + return it->second.get(); + } + // returns the smallest possible address in any segment // where there is enough free address space to fit size // may be composed of free and unmapped segments @@ -3130,9 +3180,11 @@ class DeviceCachingAllocator { static bool forceUncachedAllocator() { // Allow either CUDA or HIP name for env var for maximum user comfort // the CUDA env var avoids being hipified in cuda_to_hip_mappings.py - static const char* cuda_env = getenv("PYTORCH_NO_CUDA_MEMORY_CACHING"); - static const char* rocm_env = getenv("PYTORCH_NO_HIP_MEMORY_CACHING"); - static bool force_uncached = (cuda_env != nullptr) || (rocm_env != nullptr); + static bool has_cuda_env = + c10::utils::has_env("PYTORCH_NO_CUDA_MEMORY_CACHING"); + static bool has_rocm_env = + c10::utils::has_env("PYTORCH_NO_HIP_MEMORY_CACHING"); + static bool force_uncached = has_cuda_env || has_rocm_env; return force_uncached; } @@ -3533,6 +3585,14 @@ class NativeCachingAllocator : public CUDAAllocator { assertValidDevice(device); device_allocator[device]->resetPeakStats(); } + + void ensureExistsAndIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id) override { + assertValidDevice(device); + device_allocator[device]->ensureExistsAndIncrefPool(std::move(mempool_id)); + } + // CUDAGraph interactions void beginAllocateToPool( c10::DeviceIndex device, @@ -3554,6 +3614,12 @@ class NativeCachingAllocator : public CUDAAllocator { device_allocator[device]->releasePool(std::move(mempool_id)); } + int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) + override { + assertValidDevice(device); + return device_allocator[device]->getPoolUseCount(std::move(mempool_id)); + } + void* raw_alloc(size_t nbytes) override { if (nbytes == 0) { return nullptr; @@ -3779,9 +3845,9 @@ struct BackendStaticInitializer { // version checks, to CUDAAllocatorConfig's runtime doublecheck. If this // works, maybe we should move all of CUDAAllocatorConfig here? CUDAAllocator* parseEnvForBackend() { - const char* val = getenv("PYTORCH_CUDA_ALLOC_CONF"); - if (val != nullptr) { - const std::string config(val); + const auto val = c10::utils::get_env("PYTORCH_CUDA_ALLOC_CONF"); + if (val.has_value()) { + const std::string& config = val.value(); std::regex exp("[\\s,]+"); std::sregex_token_iterator it(config.begin(), config.end(), exp, -1); @@ -3841,6 +3907,13 @@ MemPool::MemPool( } else { id_ = {uuid_++, 0}; } + device_ = c10::cuda::current_device(); + CUDACachingAllocator::ensureExistsAndIncrefPool(device_, id_); +} + +MemPool::~MemPool() { + TORCH_INTERNAL_ASSERT(use_count() == 1); + CUDACachingAllocator::releasePool(device_, id_); } MempoolId_t MemPool::id() { @@ -3851,6 +3924,17 @@ CUDACachingAllocator::CUDAAllocator* MemPool::allocator() { return allocator_; } +int MemPool::use_count() { + return CUDACachingAllocator::getPoolUseCount(device_, id_); +} + +MempoolId_t MemPool::graph_pool_handle(bool is_user_created) { + if (is_user_created) { + return {0, uid_++}; + } + return {uuid_++, 0}; +} + // Note that active_mempool_ is a global variable here // and not inside MemPoolContext class, because in windows we // can't use __declspec(dllexport) and __declspec(thread) diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index 910f928341c17..3d9e1ab9f5348 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -224,6 +224,22 @@ class CUDAAllocator : public Allocator { c10::DeviceIndex device, MempoolId_t mempool_id) = 0; virtual void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) = 0; + virtual int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + TORCH_CHECK( + false, + name(), + " does not yet support getPoolUseCount. " + "If you need it, please file an issue describing your use case."); + } + virtual void ensureExistsAndIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id) { + TORCH_CHECK( + false, + name(), + " does not yet support ensureExistsAndIncrefPool. " + "If you need it, please file an issue describing your use case."); + } // returns true if the allocated blocks are equal to expected live allocations virtual bool checkPoolLiveAllocations( c10::DeviceIndex device, @@ -427,6 +443,16 @@ inline void attachAllocatorTraceTracker(AllocatorTraceTracker tracker) { inline void releasePool(c10::DeviceIndex device, MempoolId_t mempool_id) { return get()->releasePool(device, mempool_id); } +inline void ensureExistsAndIncrefPool( + c10::DeviceIndex device, + MempoolId_t mempool_id) { + get()->ensureExistsAndIncrefPool(device, mempool_id); +} + +inline int getPoolUseCount(c10::DeviceIndex device, MempoolId_t mempool_id) { + return get()->getPoolUseCount(device, mempool_id); +} + // Not part of CUDA_ALLOCATOR_BACKEND_INTERFACE inline std::shared_ptr getIpcDevPtr(std::string handle) { return get()->getIpcDevPtr(std::move(handle)); @@ -472,9 +498,12 @@ struct C10_CUDA_API MemPool { MemPool( CUDACachingAllocator::CUDAAllocator* allocator = nullptr, bool is_user_created = true); + ~MemPool(); MempoolId_t id(); CUDACachingAllocator::CUDAAllocator* allocator(); + int use_count(); + static MempoolId_t graph_pool_handle(bool is_user_created = true); private: static std::atomic uid_; @@ -482,6 +511,7 @@ struct C10_CUDA_API MemPool { CUDACachingAllocator::CUDAAllocator* allocator_; bool is_user_created_; MempoolId_t id_; + c10::DeviceIndex device_; }; // MemPoolContext holds the currently active pool and stashes the previous diff --git a/c10/cuda/CUDADeviceAssertionHost.cpp b/c10/cuda/CUDADeviceAssertionHost.cpp index 1d52af7812273..21fd8b3052d30 100644 --- a/c10/cuda/CUDADeviceAssertionHost.cpp +++ b/c10/cuda/CUDADeviceAssertionHost.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -80,8 +81,8 @@ bool dsa_check_if_all_devices_support_managed_memory() { } bool env_flag_set(const char* env_var_name) { - const char* const env_string = std::getenv(env_var_name); - return (env_string == nullptr) ? false : std::strcmp(env_string, "0"); + const auto env_flag = c10::utils::check_env(env_var_name); + return env_flag.has_value() && env_flag.value(); } /// Deleter for UVM/managed memory pointers @@ -195,7 +196,7 @@ CUDAKernelLaunchRegistry::CUDAKernelLaunchRegistry() dsa_check_if_all_devices_support_managed_memory()), gather_launch_stacktrace(check_env_for_enable_launch_stacktracing()), enabled_at_runtime(check_env_for_dsa_enabled()) { - for (C10_UNUSED const auto _ : c10::irange(dsa_get_device_count())) { + for ([[maybe_unused]] const auto _ : c10::irange(dsa_get_device_count())) { uvm_assertions.emplace_back(nullptr, uvm_deleter); } diff --git a/c10/cuda/CUDAException.cpp b/c10/cuda/CUDAException.cpp index 2d725747c969d..5b51a3e2a5aed 100644 --- a/c10/cuda/CUDAException.cpp +++ b/c10/cuda/CUDAException.cpp @@ -23,7 +23,7 @@ void c10_cuda_check_implementation( return; } - auto error_unused C10_UNUSED = cudaGetLastError(); + [[maybe_unused]] auto error_unused = cudaGetLastError(); (void)error_unused; std::string check_message; diff --git a/c10/cuda/CUDAException.h b/c10/cuda/CUDAException.h index 7ecb9d6f13e34..899d85e8a73f6 100644 --- a/c10/cuda/CUDAException.h +++ b/c10/cuda/CUDAException.h @@ -40,8 +40,7 @@ class C10_CUDA_API CUDAError : public c10::Error { do { \ const cudaError_t __err = EXPR; \ if (C10_UNLIKELY(__err != cudaSuccess)) { \ - auto error_unused C10_UNUSED = cudaGetLastError(); \ - (void)error_unused; \ + [[maybe_unused]] auto error_unused = cudaGetLastError(); \ TORCH_WARN("CUDA warning: ", cudaGetErrorString(__err)); \ } \ } while (0) @@ -50,20 +49,18 @@ class C10_CUDA_API CUDAError : public c10::Error { #define C10_CUDA_ERROR_HANDLED(EXPR) EXPR // Intentionally ignore a CUDA error -#define C10_CUDA_IGNORE_ERROR(EXPR) \ - do { \ - const cudaError_t __err = EXPR; \ - if (C10_UNLIKELY(__err != cudaSuccess)) { \ - cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \ - (void)error_unused; \ - } \ +#define C10_CUDA_IGNORE_ERROR(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (C10_UNLIKELY(__err != cudaSuccess)) { \ + [[maybe_unused]] cudaError_t error_unused = cudaGetLastError(); \ + } \ } while (0) // Clear the last CUDA error -#define C10_CUDA_CLEAR_ERROR() \ - do { \ - cudaError_t error_unused C10_UNUSED = cudaGetLastError(); \ - (void)error_unused; \ +#define C10_CUDA_CLEAR_ERROR() \ + do { \ + [[maybe_unused]] cudaError_t error_unused = cudaGetLastError(); \ } while (0) // This should be used directly after every kernel launch to ensure diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp index 8d88000b89db9..b1d573b16d1c4 100644 --- a/c10/cuda/CUDAFunctions.cpp +++ b/c10/cuda/CUDAFunctions.cpp @@ -1,5 +1,6 @@ #include #include +#include #include @@ -22,7 +23,7 @@ int device_count_impl(bool fail_if_no_driver) { // Clear out the error state, so we don't spuriously trigger someone else. // (This shouldn't really matter, since we won't be running very much CUDA // code in this regime.) - cudaError_t last_err C10_UNUSED = cudaGetLastError(); + [[maybe_unused]] cudaError_t last_err = cudaGetLastError(); switch (err) { case cudaErrorNoDevice: // Zero devices is ok here @@ -138,6 +139,7 @@ void device_synchronize() { if (C10_UNLIKELY(interp)) { (*interp)->trace_gpu_device_synchronization(c10::kCUDA); } + STATIC_SCOPED_WAIT_COUNTER(pytorch.wait_counter.cuda_device_synchronize); C10_CUDA_CHECK(cudaDeviceSynchronize()); } @@ -170,7 +172,7 @@ std::optional getDeviceIndexWithPrimaryContext() { } namespace _internal { -bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) { +bool dummyHasPrimaryContext([[maybe_unused]] DeviceIndex device_index) { TORCH_CHECK(false, "Should never been called"); } bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext; diff --git a/c10/cuda/CUDAGuard.h b/c10/cuda/CUDAGuard.h index 08b7bb711373f..4bde4ecc6507e 100644 --- a/c10/cuda/CUDAGuard.h +++ b/c10/cuda/CUDAGuard.h @@ -147,6 +147,7 @@ struct CUDAStreamGuard { /// stream, and set the current CUDA stream on that device to the passed /// stream. Errors if the Stream is not a CUDA stream. explicit CUDAStreamGuard(Stream stream) : guard_(stream) {} + ~CUDAStreamGuard() = default; /// Copy is disallowed CUDAStreamGuard(const CUDAStreamGuard&) = delete; diff --git a/c10/cuda/CUDAMiscFunctions.cpp b/c10/cuda/CUDAMiscFunctions.cpp index f55bba13e948d..cc6519728f1ea 100644 --- a/c10/cuda/CUDAMiscFunctions.cpp +++ b/c10/cuda/CUDAMiscFunctions.cpp @@ -1,12 +1,14 @@ #include -#include +#include namespace c10::cuda { +// NOLINTNEXTLINE(bugprone-exception-escape,-warnings-as-errors) const char* get_cuda_check_suffix() noexcept { - static char* device_blocking_flag = getenv("CUDA_LAUNCH_BLOCKING"); + static auto device_blocking_flag = + c10::utils::check_env("CUDA_LAUNCH_BLOCKING"); static bool blocking_enabled = - (device_blocking_flag && atoi(device_blocking_flag)); + (device_blocking_flag.has_value() && device_blocking_flag.value()); if (blocking_enabled) { return ""; } else { diff --git a/c10/cuda/driver_api.h b/c10/cuda/driver_api.h index 35a55b91a0f1b..d698beada411f 100644 --- a/c10/cuda/driver_api.h +++ b/c10/cuda/driver_api.h @@ -8,12 +8,12 @@ CUresult __err = EXPR; \ if (__err != CUDA_SUCCESS) { \ const char* err_str; \ - CUresult get_error_str_err C10_UNUSED = \ + CUresult get_error_str_err [[maybe_unused]] = \ c10::cuda::DriverAPI::get()->cuGetErrorString_(__err, &err_str); \ if (get_error_str_err != CUDA_SUCCESS) { \ - AT_ERROR("CUDA driver error: unknown error"); \ + TORCH_CHECK(false, "CUDA driver error: unknown error"); \ } else { \ - AT_ERROR("CUDA driver error: ", err_str); \ + TORCH_CHECK(false, "CUDA driver error: ", err_str); \ } \ } \ } while (0) diff --git a/c10/macros/Macros.h b/c10/macros/Macros.h index 01e71c88e7180..157b2a51287ab 100644 --- a/c10/macros/Macros.h +++ b/c10/macros/Macros.h @@ -33,12 +33,15 @@ #define __ubsan_ignore_pointer_overflow__ \ __attribute__((no_sanitize("pointer-overflow"))) #define __ubsan_ignore_function__ __attribute__((no_sanitize("function"))) +#define __ubsan_ignore_float_cast_overflow__ \ + __attribute__((no_sanitize("float-cast-overflow"))) #else #define __ubsan_ignore_float_divide_by_zero__ #define __ubsan_ignore_undefined__ #define __ubsan_ignore_signed_int_overflow__ #define __ubsan_ignore_pointer_overflow__ #define __ubsan_ignore_function__ +#define __ubsan_ignore_float_cast_overflow__ #endif // Detect address sanitizer as some stuff doesn't work with it @@ -115,46 +118,13 @@ #define C10_HAS_CPP_ATTRIBUTE(x) (0) #endif -/// C10_NODISCARD - Warn if a type or return value is discarded. - -// Technically, we should check if __cplusplus > 201402L here, because -// [[nodiscard]] is only defined in C++17. However, some compilers -// we care about don't advertise being C++17 (e.g., clang), but -// support the attribute anyway. In fact, this is not just a good idea, -// it's the law: clang::warn_unused_result doesn't work on nvcc + clang -// and the best workaround for this case is to use [[nodiscard]] -// instead; see https://github.com/pytorch/pytorch/issues/13118 -// -// Note to future editors: if you have noticed that a compiler is -// misbehaving (e.g., it advertises support, but the support doesn't -// actually work, or it is emitting warnings). Some compilers which -// are strict about the matter include MSVC, which will complain: -// -// error C2429: attribute 'nodiscard' requires compiler flag '/std:c++latest' -// -// Exhibits: -// - MSVC 19.14: https://godbolt.org/z/Dzd7gn (requires /std:c++latest) -// - Clang 8.0.0: https://godbolt.org/z/3PYL4Z (always advertises support) -// - gcc 8.3: https://godbolt.org/z/4tLMQS (always advertises support) -#if C10_HAS_CPP_ATTRIBUTE(nodiscard) +#ifndef FBCODE_CAFFE2 +/// DEPRECATED: Warn if a type or return value is discarded. #define C10_NODISCARD [[nodiscard]] -// Workaround for llvm.org/PR23435, since clang 3.6 and below emit a spurious -// error when __has_cpp_attribute is given a scoped attribute in C mode. -#elif __cplusplus && C10_HAS_CPP_ATTRIBUTE(clang::warn_unused_result) -// TODO: It's possible this is still triggering -// https://github.com/pytorch/pytorch/issues/13118 on Windows; if it is, better -// fix it. -#define C10_NODISCARD [[clang::warn_unused_result]] -#else -#define C10_NODISCARD -#endif -// suppress an unused variable. -#if defined(_MSC_VER) && !defined(__clang__) -#define C10_UNUSED __pragma(warning(suppress : 4100 4101)) -#else -#define C10_UNUSED __attribute__((__unused__)) -#endif //_MSC_VER +/// DEPRECATED: Suppress an unused variable. +#define C10_UNUSED [[maybe_unused]] +#endif #if !defined(__has_attribute) #define __has_attribute(x) 0 @@ -475,66 +445,14 @@ __host__ __device__ #define C10_ALWAYS_INLINE_UNLESS_MOBILE C10_ALWAYS_INLINE #endif -#if defined(__CUDA_ARCH__) -#if defined(_MSC_VER) && defined(__CUDACC__) -#define CONSTEXPR_EXCEPT_WIN_CUDA const -#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA __host__ - -// Note [static constexpr char* members for windows NVCC] -// The Windows NVCC compiler doesn't handle static constexpr class members, -// although it's fixed in a later version. -// (see -// https://developercommunity.visualstudio.com/t/intellisense-error-c11-static-constexpr-member-ini/245425) -// -// If we want to ensure that our field is static under all builds, then we need -// to work around it specifically for windows NVCC by making it (a) const, (b) -// defined outside of the class definition We need to define it outside of the -// class definition because of the C++ standard; char* is not an integral type -// (see -// https://stackoverflow.com/questions/24278473/intellisense-a-member-of-type-const-char-const-cannot-have-an-in-class-in) -// -// So instead of this: -// struct Foo { -// static constexpr const char* name = "foo"; -// } -// In Windows NVCC, we end up with this: -// struct Foo { -// static const char* name; -// } -// const char* Foo::name = "foo"; -// -// This gives us a small perf hit for any code that wants to access these field -// members, but right now it isn't used in any perf-critical code paths. -#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ - static const char* field; -#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) \ - const char* cls::field = val; -#else -#define CONSTEXPR_EXCEPT_WIN_CUDA constexpr -#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA __host__ - -#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ - static constexpr const char* field = val; -#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) -#endif -#else -#if defined(_MSC_VER) && defined(__CUDACC__) -#define CONSTEXPR_EXCEPT_WIN_CUDA const -#define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA - -#define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ - static const char* field; -#define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) \ - const char* cls::field = val; -#else +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) #define CONSTEXPR_EXCEPT_WIN_CUDA constexpr #define C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA constexpr #define STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(field, val) \ - static constexpr const char* field = val; + static constexpr const char field[] = val; #define STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA(cls, field, val) -#endif -#endif +#endif // !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) #ifndef HAS_DEMANGLE #if defined(__ANDROID__) || defined(_WIN32) || defined(__EMSCRIPTEN__) diff --git a/c10/mobile/CPUCachingAllocator.cpp b/c10/mobile/CPUCachingAllocator.cpp index cafef1030f3eb..f881d454a5383 100644 --- a/c10/mobile/CPUCachingAllocator.cpp +++ b/c10/mobile/CPUCachingAllocator.cpp @@ -12,8 +12,7 @@ std::mutex CPUCachingAllocator::mutex_; ska::flat_hash_map CPUCachingAllocator::allocation_map_; inline void* CPUCachingAllocator::allocate_and_cache(const size_t bytes) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - void* ptr; + void* ptr = nullptr; try { ptr = c10::alloc_cpu(bytes); } catch (c10::Error&) { diff --git a/c10/mobile/CPUProfilingAllocator.cpp b/c10/mobile/CPUProfilingAllocator.cpp index 2fc569135e267..d01cdd2b1d24b 100644 --- a/c10/mobile/CPUProfilingAllocator.cpp +++ b/c10/mobile/CPUProfilingAllocator.cpp @@ -152,10 +152,8 @@ std::vector formulate_greedy_allocation_plan( create_and_sort_mem_events(allocation_sizes, allocation_lifetimes); uint64_t max_offset{0}; for (const auto& mem_event : mem_events) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint64_t alloc_offset; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint64_t new_offset, new_size; + uint64_t alloc_offset = 0; + uint64_t new_offset = 0, new_size = 0; if (mem_event.type == EventType::Allocate) { auto it = free_size_to_offset.lower_bound(mem_event.size); if (it == free_size_to_offset.end()) { diff --git a/c10/test/CMakeLists.txt b/c10/test/CMakeLists.txt index 7f2a61246c6c6..83b5b17f9c8a6 100644 --- a/c10/test/CMakeLists.txt +++ b/c10/test/CMakeLists.txt @@ -12,6 +12,7 @@ if(BUILD_TEST) target_link_libraries(${test_name} ${C10_LIB} gmock gtest gtest_main) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() diff --git a/c10/test/util/Half_test.cpp b/c10/test/util/Half_test.cpp index 1176837c06782..fc2a002f3a94a 100644 --- a/c10/test/util/Half_test.cpp +++ b/c10/test/util/Half_test.cpp @@ -41,17 +41,15 @@ float halfbits2float(unsigned short h) { unsigned short float2halfbits(float src) { unsigned x = c10::detail::fp32_to_bits(src); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables,cppcoreguidelines-avoid-magic-numbers) - unsigned u = (x & 0x7fffffff), remainder, shift, lsb, lsb_s1, lsb_m1; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - unsigned sign, exponent, mantissa; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) + unsigned u = (x & 0x7fffffff), shift = 0; // Get rid of +NaN/-NaN case first. if (u > 0x7f800000) { return 0x7fffU; } - sign = ((x >> 16) & 0x8000); + unsigned sign = ((x >> 16) & 0x8000); // Get rid of +Inf/-Inf, +0/-0. if (u > 0x477fefff) { @@ -61,8 +59,8 @@ unsigned short float2halfbits(float src) { return (sign | 0x0000); } - exponent = ((u >> 23) & 0xff); - mantissa = (u & 0x7fffff); + unsigned exponent = ((u >> 23) & 0xff); + unsigned mantissa = (u & 0x7fffff); if (exponent > 0x70) { shift = 13; @@ -72,12 +70,12 @@ unsigned short float2halfbits(float src) { exponent = 0; mantissa |= 0x800000; } - lsb = (1 << shift); - lsb_s1 = (lsb >> 1); - lsb_m1 = (lsb - 1); + unsigned lsb = (1 << shift); + unsigned lsb_s1 = (lsb >> 1); + unsigned lsb_m1 = (lsb - 1); // Round to nearest even. - remainder = (mantissa & lsb_m1); + unsigned remainder = (mantissa & lsb_m1); mantissa >>= shift; if (remainder > lsb_s1 || (remainder == lsb_s1 && (mantissa & 0x1))) { ++mantissa; diff --git a/c10/test/util/TypeIndex_test.cpp b/c10/test/util/TypeIndex_test.cpp index b44bbab356898..5979d92edd592 100644 --- a/c10/test/util/TypeIndex_test.cpp +++ b/c10/test/util/TypeIndex_test.cpp @@ -55,11 +55,11 @@ static_assert( ""); namespace test_top_level_name { -#if C10_TYPENAME_SUPPORTS_CONSTEXPR + static_assert( string_view::npos != get_fully_qualified_type_name().find("Dummy"), ""); -#endif + TEST(TypeIndex, TopLevelName) { EXPECT_NE( string_view::npos, get_fully_qualified_type_name().find("Dummy")); @@ -69,12 +69,11 @@ TEST(TypeIndex, TopLevelName) { namespace test_nested_name { struct Dummy final {}; -#if C10_TYPENAME_SUPPORTS_CONSTEXPR static_assert( string_view::npos != get_fully_qualified_type_name().find("test_nested_name::Dummy"), ""); -#endif + TEST(TypeIndex, NestedName) { EXPECT_NE( string_view::npos, @@ -87,7 +86,6 @@ template struct Outer final {}; struct Inner final {}; -#if C10_TYPENAME_SUPPORTS_CONSTEXPR static_assert( string_view::npos != get_fully_qualified_type_name>().find( @@ -98,7 +96,7 @@ static_assert( get_fully_qualified_type_name>().find( "test_type_template_parameter::Inner"), ""); -#endif + TEST(TypeIndex, TypeTemplateParameter) { EXPECT_NE( string_view::npos, @@ -115,12 +113,11 @@ namespace test_nontype_template_parameter { template struct Class final {}; -#if C10_TYPENAME_SUPPORTS_CONSTEXPR static_assert( string_view::npos != get_fully_qualified_type_name>().find("38474355"), ""); -#endif + TEST(TypeIndex, NonTypeTemplateParameter) { EXPECT_NE( string_view::npos, @@ -134,7 +131,6 @@ struct Type final { using type = const T*; }; -#if C10_TYPENAME_SUPPORTS_CONSTEXPR static_assert( string_view::npos != get_fully_qualified_type_name::type>().find("int"), @@ -151,7 +147,7 @@ static_assert( std::remove_pointer_t::type>>() .find("*"), ""); -#endif + TEST(TypeIndex, TypeComputationsAreResolved) { EXPECT_NE( string_view::npos, @@ -163,21 +159,21 @@ TEST(TypeIndex, TypeComputationsAreResolved) { EXPECT_EQ( string_view::npos, get_fully_qualified_type_name< - typename std::remove_pointer::type>::type>() + std::remove_pointer_t::type>>() .find("*")); } struct Functor final { std::string operator()(int64_t a, const Type& b) const; }; -#if C10_TYPENAME_SUPPORTS_CONSTEXPR + static_assert( // NOLINTNEXTLINE(misc-redundant-expression) get_fully_qualified_type_name&)>() == get_fully_qualified_type_name< typename c10::guts::infer_function_traits_t::func_type>(), ""); -#endif + TEST(TypeIndex, FunctionTypeComputationsAreResolved) { EXPECT_EQ( get_fully_qualified_type_name&)>(), @@ -189,7 +185,6 @@ TEST(TypeIndex, FunctionTypeComputationsAreResolved) { namespace test_function_arguments_and_returns { class Dummy final {}; -#if C10_TYPENAME_SUPPORTS_CONSTEXPR static_assert( string_view::npos != get_fully_qualified_type_name().find( @@ -200,7 +195,7 @@ static_assert( get_fully_qualified_type_name().find( "test_function_arguments_and_returns::Dummy"), ""); -#endif + TEST(TypeIndex, FunctionArgumentsAndReturns) { EXPECT_NE( string_view::npos, diff --git a/c10/test/util/bfloat16_test.cpp b/c10/test/util/bfloat16_test.cpp index 1c6ef27f90ea9..39f2214eef99b 100644 --- a/c10/test/util/bfloat16_test.cpp +++ b/c10/test/util/bfloat16_test.cpp @@ -7,17 +7,14 @@ namespace { float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t bytes; - bytes = 0; + uint32_t bytes = 0; bytes |= sign; bytes <<= 8; bytes |= exponent; bytes <<= 23; bytes |= fraction; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float res; + float res = 0; std::memcpy(&res, &bytes, sizeof(res)); return res; } @@ -160,8 +157,7 @@ TEST(BFloat16Math, NextAfterZero) { } float BinaryToFloat(uint32_t bytes) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - float res; + float res = 0; std::memcpy(&res, &bytes, sizeof(res)); return res; } diff --git a/c10/test/util/ordered_preserving_dict_test.cpp b/c10/test/util/ordered_preserving_dict_test.cpp index 29fde5c1ae394..2279f44867084 100644 --- a/c10/test/util/ordered_preserving_dict_test.cpp +++ b/c10/test/util/ordered_preserving_dict_test.cpp @@ -35,14 +35,12 @@ dict_int_int test_dict(dict_int_int& dict) { // erase via iterators auto begin = dict.begin(); - for (const auto i : c10::irange(20)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(20)) { begin++; } auto end = begin; - for (const auto i : c10::irange(20)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(20)) { erase_set.insert(end->first); end++; } @@ -136,13 +134,11 @@ TEST(OrderedPreservingDictTest, DictCollisions) { // erase a few entries via iterator auto begin = dict.begin(); - for (const auto j : c10::irange(10)) { - (void)j; // Suppress unused variable warning + for ([[maybe_unused]] const auto j : c10::irange(10)) { begin++; } auto end = begin; - for (const auto j : c10::irange(7)) { - (void)j; // Suppress unused variable warning + for ([[maybe_unused]] const auto j : c10::irange(7)) { erase_set.insert(end->first); end++; } diff --git a/c10/test/util/small_vector_test.cpp b/c10/test/util/small_vector_test.cpp index 93df2ab7ed00a..df8c7fb6a656f 100644 --- a/c10/test/util/small_vector_test.cpp +++ b/c10/test/util/small_vector_test.cpp @@ -127,11 +127,6 @@ class Constructable { friend bool operator==(const Constructable& c0, const Constructable& c1) { return c0.getValue() == c1.getValue(); } - - friend bool C10_UNUSED - operator!=(const Constructable& c0, const Constructable& c1) { - return c0.getValue() != c1.getValue(); - } }; int Constructable::numConstructorCalls; @@ -144,6 +139,7 @@ int Constructable::numMoveAssignmentCalls; struct NonCopyable { NonCopyable() = default; + ~NonCopyable() = default; NonCopyable(NonCopyable&&) noexcept = default; NonCopyable& operator=(NonCopyable&&) noexcept = default; @@ -204,13 +200,12 @@ class SmallVectorTest : public SmallVectorTestBase { VectorT otherVector; }; -typedef ::testing::Types< +using SmallVectorTestTypes = ::testing::Types< SmallVector, SmallVector, SmallVector, SmallVector, - SmallVector> - SmallVectorTestTypes; + SmallVector>; TYPED_TEST_SUITE(SmallVectorTest, SmallVectorTestTypes, ); // Constructor test. @@ -472,11 +467,11 @@ TYPED_TEST(SmallVectorTest, AppendNonIterTest) { } struct output_iterator { - typedef std::output_iterator_tag iterator_category; - typedef int value_type; - typedef int difference_type; - typedef value_type* pointer; - typedef value_type& reference; + using iterator_category = std::output_iterator_tag; + using value_type = int; + using difference_type = int; + using pointer = value_type*; + using reference = value_type&; operator int() { return 2; } @@ -821,7 +816,7 @@ class DualSmallVectorsTest> } }; -typedef ::testing::Types< +using DualSmallVectorTestTypes = ::testing::Types< // Small mode -> Small mode. std::pair, SmallVector>, // Small mode -> Big mode. @@ -829,8 +824,7 @@ typedef ::testing::Types< // Big mode -> Small mode. std::pair, SmallVector>, // Big mode -> Big mode. - std::pair, SmallVector>> - DualSmallVectorTestTypes; + std::pair, SmallVector>>; TYPED_TEST_SUITE(DualSmallVectorsTest, DualSmallVectorTestTypes, ); @@ -890,9 +884,12 @@ TEST(SmallVectorCustomTest, NoAssignTest) { struct MovedFrom { bool hasValue{true}; MovedFrom() = default; + ~MovedFrom() = default; + MovedFrom(const MovedFrom& m) = delete; MovedFrom(MovedFrom&& m) noexcept : hasValue(m.hasValue) { m.hasValue = false; } + MovedFrom& operator=(const MovedFrom& m) = delete; MovedFrom& operator=(MovedFrom&& m) noexcept { hasValue = m.hasValue; m.hasValue = false; @@ -924,6 +921,7 @@ struct EmplaceableArg { EmplaceableArg(EmplaceableArg& X) : State(X.State == EAS_Arg ? EAS_LValue : EAS_Failure) {} + ~EmplaceableArg() = default; explicit EmplaceableArg(bool) : State(EAS_Arg) {} EmplaceableArg& operator=(EmplaceableArg&&) = delete; @@ -939,6 +937,7 @@ struct Emplaceable { EmplaceableState State; Emplaceable() : State(ES_Emplaced) {} + ~Emplaceable() = default; template explicit Emplaceable(A0Ty&& A0) diff --git a/c10/util/ApproximateClock.cpp b/c10/util/ApproximateClock.cpp index 0bda220d83da9..755054c50bc9e 100644 --- a/c10/util/ApproximateClock.cpp +++ b/c10/util/ApproximateClock.cpp @@ -26,7 +26,7 @@ ApproximateClockToUnixTimeConverter::measurePair() { ApproximateClockToUnixTimeConverter::time_pairs ApproximateClockToUnixTimeConverter::measurePairs() { static constexpr auto n_warmup = 5; - for (C10_UNUSED const auto _ : c10::irange(n_warmup)) { + for ([[maybe_unused]] const auto _ : c10::irange(n_warmup)) { getApproximateTime(); static_cast(steady_clock_t::now()); } diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 2a56e60832993..c977d7e92b2a6 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -76,13 +76,13 @@ class ArrayRef final { constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} /// Construct an ArrayRef from a pointer and length. - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* data, size_t length) + constexpr ArrayRef(const T* data, size_t length) : Data(data), Length(length) { debugCheckNullptrInvariant(); } /// Construct an ArrayRef from a range. - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef(const T* begin, const T* end) + constexpr ArrayRef(const T* begin, const T* end) : Data(begin), Length(end - begin) { debugCheckNullptrInvariant(); } @@ -162,6 +162,11 @@ class ArrayRef final { return reverse_iterator(begin()); } + /// Check if all elements in the array satisfy the given expression + constexpr bool allMatch(const std::function& pred) const { + return std::all_of(cbegin(), cend(), pred); + } + /// empty - Check if the array is empty. constexpr bool empty() const { return Length == 0; @@ -177,14 +182,14 @@ class ArrayRef final { } /// front - Get the first element. - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& front() const { + constexpr const T& front() const { TORCH_CHECK( !empty(), "ArrayRef: attempted to access front() of empty list"); return Data[0]; } /// back - Get the last element. - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& back() const { + constexpr const T& back() const { TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list"); return Data[Length - 1]; } @@ -195,8 +200,7 @@ class ArrayRef final { } /// slice(n, m) - Take M elements of the array starting at element N - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef slice(size_t N, size_t M) - const { + constexpr ArrayRef slice(size_t N, size_t M) const { TORCH_CHECK( N + M <= size(), "ArrayRef: invalid slice, N = ", @@ -209,7 +213,7 @@ class ArrayRef final { } /// slice(n) - Chop off the first N elements of the array. - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA ArrayRef slice(size_t N) const { + constexpr ArrayRef slice(size_t N) const { TORCH_CHECK( N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size()); return slice(N, size() - N); @@ -223,7 +227,7 @@ class ArrayRef final { } /// Vector compatibility - C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA const T& at(size_t Index) const { + constexpr const T& at(size_t Index) const { TORCH_CHECK( Index < Length, "ArrayRef: invalid index Index = ", diff --git a/c10/util/BFloat16-inl.h b/c10/util/BFloat16-inl.h index f3b05d0e3a660..10ab0c828d7a8 100644 --- a/c10/util/BFloat16-inl.h +++ b/c10/util/BFloat16-inl.h @@ -57,24 +57,6 @@ inline C10_HOST_DEVICE BFloat16::operator __nv_bfloat16() const { return *reinterpret_cast(&x); } #endif -#if defined(__HIPCC__) && defined(USE_ROCM) -// 6.2.0 introduced __hip_bfloat16_raw -#if defined(__BF16_HOST_DEVICE__) -inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) { - x = __hip_bfloat16_raw(value).x; -} -inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const { - return __hip_bfloat16(__hip_bfloat16_raw{x}); -} -#else // !defined(__BF16_HOST_DEVICE__) -inline C10_HOST_DEVICE BFloat16::BFloat16(const __hip_bfloat16& value) { - x = value.data; -} -inline C10_HOST_DEVICE BFloat16::operator __hip_bfloat16() const { - return __hip_bfloat16{x}; -} -#endif // !defined(__BF16_HOST_DEVICE__) -#endif // defined(__HIPCC__) && defined(USE_ROCM) #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) inline C10_HOST_DEVICE BFloat16::BFloat16( diff --git a/c10/util/BFloat16.h b/c10/util/BFloat16.h index 17326d81d7279..09d3051ab71c3 100644 --- a/c10/util/BFloat16.h +++ b/c10/util/BFloat16.h @@ -13,9 +13,6 @@ #if defined(__CUDACC__) && !defined(USE_ROCM) #include #endif -#if defined(__HIPCC__) && defined(USE_ROCM) -#include -#endif #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) #if defined(CL_SYCL_LANGUAGE_VERSION) @@ -110,10 +107,6 @@ struct alignas(2) BFloat16 { inline C10_HOST_DEVICE BFloat16(const __nv_bfloat16& value); explicit inline C10_HOST_DEVICE operator __nv_bfloat16() const; #endif -#if defined(__HIPCC__) && defined(USE_ROCM) - inline C10_HOST_DEVICE BFloat16(const __hip_bfloat16& value); - explicit inline C10_HOST_DEVICE operator __hip_bfloat16() const; -#endif #if defined(SYCL_EXT_ONEAPI_BFLOAT16_MATH_FUNCTIONS) inline C10_HOST_DEVICE BFloat16(const sycl::ext::oneapi::bfloat16& value); diff --git a/c10/util/Backtrace.cpp b/c10/util/Backtrace.cpp index d461267000bef..bfcacfd9740d1 100644 --- a/c10/util/Backtrace.cpp +++ b/c10/util/Backtrace.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -16,7 +17,10 @@ #endif #if SUPPORTS_BACKTRACE +C10_CLANG_DIAGNOSTIC_PUSH() +C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-dynamic-exception-spec") #include +C10_CLANG_DIAGNOSTIC_POP() #ifdef C10_ANDROID #include #include @@ -277,6 +281,7 @@ class GetBacktraceImpl { } private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const bool skip_python_frames_; std::vector callstack_; }; diff --git a/c10/util/C++17.h b/c10/util/C++17.h index fe2044f507d4a..359774b203aa1 100644 --- a/c10/util/C++17.h +++ b/c10/util/C++17.h @@ -45,15 +45,6 @@ constexpr bool is_pod_v = is_pod::value; namespace guts { -template -std::enable_if_t< - !std::is_array_v && !std::is_array_v && - std::is_base_of_v, - std::unique_ptr> -make_unique_base(Args&&... args) { - return std::unique_ptr(new Child(std::forward(args)...)); -} - #if defined(__cpp_lib_apply) && !defined(__CUDA_ARCH__) && !defined(__HIP__) template @@ -69,21 +60,10 @@ C10_HOST_DEVICE inline constexpr decltype(auto) apply(F&& f, Tuple&& t) { // member functions. namespace detail { template -#if defined(_MSC_VER) -// MSVC has a problem with the decltype() return type, but it also doesn't need -// it -C10_HOST_DEVICE constexpr auto apply_impl( - F&& f, - Tuple&& t, - std::index_sequence) -#else -// GCC/Clang need the decltype() return type C10_HOST_DEVICE constexpr decltype(auto) apply_impl( F&& f, Tuple&& t, - std::index_sequence) -#endif -{ + std::index_sequence) { return std::forward(f)(std::get(std::forward(t))...); } } // namespace detail @@ -99,44 +79,8 @@ C10_HOST_DEVICE constexpr decltype(auto) apply(F&& f, Tuple&& t) { #endif -template -std::enable_if_t< - std::is_member_pointer_v>, - typename std::invoke_result_t> -invoke(Functor&& f, Args&&... args) { - return std::mem_fn(std::forward(f))(std::forward(args)...); -} - -template -std::enable_if_t< - !std::is_member_pointer_v>, - typename std::invoke_result_t> -invoke(Functor&& f, Args&&... args) { - return std::forward(f)(std::forward(args)...); -} - -namespace detail { -struct _identity final { - template - using type_identity = T; - - template - decltype(auto) operator()(T&& arg) { - return std::forward(arg); - } -}; - -template -struct function_takes_identity_argument : std::false_type {}; - -template -struct function_takes_identity_argument< - Func, - std::void_t()(_identity()))>> : std::true_type { -}; -} // namespace detail - } // namespace guts + } // namespace c10 #endif // C10_UTIL_CPP17_H_ diff --git a/c10/util/CallOnce.h b/c10/util/CallOnce.h index 04ad455e33133..f6b8af0bda6dc 100644 --- a/c10/util/CallOnce.h +++ b/c10/util/CallOnce.h @@ -1,12 +1,13 @@ #pragma once +#include +#include + #include +#include #include #include -#include -#include - namespace c10 { // custom c10 call_once implementation to avoid the deadlock in std::call_once. @@ -47,7 +48,7 @@ class once_flag { if (init_.load(std::memory_order_relaxed)) { return; } - c10::guts::invoke(std::forward(f), std::forward(args)...); + std::invoke(std::forward(f), std::forward(args)...); init_.store(true, std::memory_order_release); } diff --git a/c10/util/ConstexprCrc.h b/c10/util/ConstexprCrc.h index 0eec44d576e98..96f1113a14c8c 100644 --- a/c10/util/ConstexprCrc.h +++ b/c10/util/ConstexprCrc.h @@ -98,8 +98,10 @@ constexpr uint64_t crc64_table[] = { 0x29b7d047efec8728, }; -inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA uint64_t -crc64impl(uint64_t accumulator, const char* data, size_t size) { +inline constexpr uint64_t crc64impl( + uint64_t accumulator, + const char* data, + size_t size) { for (size_t i = 0; i < size; ++i) { accumulator = crc64_table[(accumulator ^ data[i]) & 0xFF] ^ (accumulator >> 8); @@ -116,12 +118,11 @@ struct crc64_t final : IdWrapper { }; // CRC64 with Jones coefficients and an init value of 0. -inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t -crc64(const char* str, size_t size) { +inline constexpr crc64_t crc64(const char* str, size_t size) { return crc64_t{detail::crc64impl(0, str, size)}; } -inline C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA crc64_t crc64(c10::string_view str) { +inline constexpr crc64_t crc64(c10::string_view str) { return crc64(str.data(), str.size()); } } // namespace c10::util diff --git a/c10/util/Exception.h b/c10/util/Exception.h index 983079ba285cc..275526cf40082 100644 --- a/c10/util/Exception.h +++ b/c10/util/Exception.h @@ -658,12 +658,12 @@ namespace c10::detail { // Report a warning to the user only once. Accepts an arbitrary number of extra // arguments which are concatenated into the warning message using operator<< // -#define _TORCH_WARN_ONCE(...) \ - C10_UNUSED static const auto C10_ANONYMOUS_VARIABLE(torch_warn_once_) = \ - [&] { \ - TORCH_WARN(__VA_ARGS__); \ - return true; \ - }() +#define _TORCH_WARN_ONCE(...) \ + [[maybe_unused]] static const auto C10_ANONYMOUS_VARIABLE( \ + torch_warn_once_) = [&] { \ + TORCH_WARN(__VA_ARGS__); \ + return true; \ + }() #ifdef DISABLE_WARN #define TORCH_WARN_ONCE(...) ((void)0); diff --git a/c10/util/Gauge.h b/c10/util/Gauge.h index f92ecd986bee1..f505c037ebc96 100644 --- a/c10/util/Gauge.h +++ b/c10/util/Gauge.h @@ -36,6 +36,7 @@ class C10_API GaugeHandle { void record(int64_t value); private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) detail::GaugeImpl& impl_; }; diff --git a/c10/util/Lazy.h b/c10/util/Lazy.h index 34424691a8d8b..ad778cc1108d6 100644 --- a/c10/util/Lazy.h +++ b/c10/util/Lazy.h @@ -29,7 +29,7 @@ class OptimisticLazy { } template - T& ensure(Factory&& factory) { + T& ensure(const Factory& factory) { if (T* value = value_.load(std::memory_order_acquire)) { return *value; } diff --git a/c10/util/Logging.cpp b/c10/util/Logging.cpp index 5c12a187b4351..9aeee67f1a1e1 100644 --- a/c10/util/Logging.cpp +++ b/c10/util/Logging.cpp @@ -220,6 +220,7 @@ void SetGlobalRank(int64_t rank) { void LogAPIUsage(const std::string& event) try { if (auto logger = GetAPIUsageLogger()) (*logger)(event); + // NOLINTNEXTLINE(bugprone-empty-catch) } catch (std::bad_function_call&) { // static destructor race } @@ -229,6 +230,7 @@ void LogAPIUsageMetadata( const std::map& metadata_map) try { if (auto logger = GetAPIUsageMetadataLogger()) (*logger)(context, metadata_map); + // NOLINTNEXTLINE(bugprone-empty-catch) } catch (std::bad_function_call&) { // static destructor race } @@ -236,6 +238,7 @@ void LogAPIUsageMetadata( void LogPyTorchDDPUsage(const DDPLoggingData& ddpData) try { if (auto logger = GetDDPUsageLogger()) (*logger)(ddpData); + // NOLINTNEXTLINE(bugprone-empty-catch) } catch (std::bad_function_call&) { // static destructor race } @@ -245,6 +248,7 @@ bool LogAPIUsageFakeReturn(const std::string& event) try { if (auto logger = GetAPIUsageLogger()) (*logger)(event); return true; + // NOLINTNEXTLINE(bugprone-empty-catch) } catch (std::bad_function_call&) { // static destructor race return true; diff --git a/c10/util/Logging.h b/c10/util/Logging.h index a3e4f23e9c58f..fac615d836fca 100644 --- a/c10/util/Logging.h +++ b/c10/util/Logging.h @@ -322,8 +322,8 @@ C10_API const std::unique_ptr& GetEventSampledHandler( * // Logs caller info with an arbitrary text event, if there is a usage. * C10_LOG_API_USAGE_ONCE("my_api"); */ -#define C10_LOG_API_USAGE_ONCE(...) \ - C10_UNUSED static bool C10_ANONYMOUS_VARIABLE(logFlag) = \ +#define C10_LOG_API_USAGE_ONCE(...) \ + [[maybe_unused]] static bool C10_ANONYMOUS_VARIABLE(logFlag) = \ ::c10::detail::LogAPIUsageFakeReturn(__VA_ARGS__); // API usage logging capabilities diff --git a/c10/util/Optional.h b/c10/util/Optional.h index 1c62bc480e5f4..cbb3a5abb47d0 100644 --- a/c10/util/Optional.h +++ b/c10/util/Optional.h @@ -20,6 +20,8 @@ using std::nullopt_t; // NOLINTNEXTLINE(misc-unused-using-decls) using std::optional; +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) + namespace detail_ { // the call to convert(b) has return type A and converts b to type A iff b // decltype(b) is implicitly convertible to A @@ -29,7 +31,9 @@ constexpr U convert(U v) { } } // namespace detail_ template -constexpr T value_or_else(const std::optional& v, F&& func) { +[[deprecated( + "Please use std::optional::value_or instead of c10::value_or_else")]] constexpr T +value_or_else(const std::optional& v, F&& func) { static_assert( std::is_convertible_v, T>, "func parameters must be a callable that returns a type convertible to the value stored in the optional"); @@ -37,12 +41,17 @@ constexpr T value_or_else(const std::optional& v, F&& func) { } template -constexpr T value_or_else(std::optional&& v, F&& func) { +[[deprecated( + "Please use std::optional::value_or instead of c10::value_or_else")]] constexpr T +value_or_else(std::optional&& v, F&& func) { static_assert( std::is_convertible_v, T>, "func parameters must be a callable that returns a type convertible to the value stored in the optional"); return v.has_value() ? constexpr_move(std::move(v).contained_val()) : detail_::convert(std::forward(func)()); } + +#endif + } // namespace c10 #endif // C10_UTIL_OPTIONAL_H_ diff --git a/c10/util/SmallVector.h b/c10/util/SmallVector.h index cbcfbc52cb8ae..d45b8c8616f5f 100644 --- a/c10/util/SmallVector.h +++ b/c10/util/SmallVector.h @@ -81,7 +81,7 @@ class C10_API SmallVectorBase { return Capacity; } - C10_NODISCARD bool empty() const { + [[nodiscard]] bool empty() const { return !Size; } @@ -710,7 +710,7 @@ class SmallVectorImpl : public SmallVectorTemplateBase { this->set_size(this->size() - NumItems); } - C10_NODISCARD T pop_back_val() { + [[nodiscard]] T pop_back_val() { T Result = ::std::move(this->back()); this->pop_back(); return Result; diff --git a/c10/util/StringUtil.h b/c10/util/StringUtil.h index 88a91c84ef0fd..8289fe453f45f 100644 --- a/c10/util/StringUtil.h +++ b/c10/util/StringUtil.h @@ -124,7 +124,7 @@ inline std::string Join(const std::string& delimiter, const Container& v) { for (auto i = v.begin(); i != v.end(); ++i, --cnt) { s << (*i) << (cnt ? delimiter : ""); } - return s.str(); + return std::move(s).str(); } // Replace all occurrences of "from" substring to "to" string. diff --git a/c10/util/ThreadLocal.h b/c10/util/ThreadLocal.h index 850bb5d4c4269..c6f3d6d874b5c 100644 --- a/c10/util/ThreadLocal.h +++ b/c10/util/ThreadLocal.h @@ -115,7 +115,10 @@ class ThreadLocal { explicit ThreadLocal(Accessor accessor) : accessor_(accessor) {} ThreadLocal(const ThreadLocal&) = delete; + ThreadLocal(ThreadLocal&&) noexcept = default; ThreadLocal& operator=(const ThreadLocal&) = delete; + ThreadLocal& operator=(ThreadLocal&&) noexcept = default; + ~ThreadLocal() = default; Type& get() { return *accessor_(); diff --git a/c10/util/ThreadLocalDebugInfo.h b/c10/util/ThreadLocalDebugInfo.h index bea8c5f27ac82..3d26dd44f6a52 100644 --- a/c10/util/ThreadLocalDebugInfo.h +++ b/c10/util/ThreadLocalDebugInfo.h @@ -74,6 +74,8 @@ class C10_API DebugInfoGuard { DebugInfoGuard(const DebugInfoGuard&) = delete; DebugInfoGuard(DebugInfoGuard&&) = delete; + DebugInfoGuard& operator=(const DebugInfoGuard&) = delete; + DebugInfoGuard& operator=(DebugInfoGuard&&) = delete; private: bool active_ = false; diff --git a/c10/util/TypeIndex.h b/c10/util/TypeIndex.h index 75b672d4a183f..d4af28daf52be 100644 --- a/c10/util/TypeIndex.h +++ b/c10/util/TypeIndex.h @@ -9,56 +9,12 @@ #include #include -namespace c10::util { - -// TODO Make it work for more compilers - -// Intel compiler works -#if defined(__INTEL_COMPILER) -#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 -#define C10_TYPENAME_CONSTEXPR - -// Clang works -#elif defined(__clang__) - -// except for NVCC -#if defined(__CUDACC__) -#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 -#define C10_TYPENAME_CONSTEXPR -#else +#if !defined(FBCODE_CAFFE2) && !defined(C10_NODEPRECATED) #define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 #define C10_TYPENAME_CONSTEXPR constexpr #endif -// Windows works -#elif defined(_MSC_VER) - -// except for NVCC -#if defined(__CUDACC__) -#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 -#define C10_TYPENAME_CONSTEXPR -#else -#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 -#define C10_TYPENAME_CONSTEXPR constexpr -#endif - -// GCC works -#elif defined(__GNUC__) - -// except when gcc < 9 -#if (__GNUC__ < 9) || defined(__CUDACC__) -#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 -#define C10_TYPENAME_CONSTEXPR -#else -#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 -#define C10_TYPENAME_CONSTEXPR constexpr -#endif - -// some other compiler we don't know about -#else -#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 -#define C10_TYPENAME_CONSTEXPR constexpr -#endif +namespace c10::util { struct type_index final : IdWrapper { constexpr explicit type_index(uint64_t checksum) : IdWrapper(checksum) {} @@ -76,17 +32,6 @@ struct type_index final : IdWrapper { namespace detail { -#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \ - __GNUC__ < 5 -// Getting __PRETTY_FUNCTION__ at compile time only works with GCC >= 5 -#error "You're running a too old version of GCC. We need GCC 5 or later." -#endif - -#if defined(__clang__) && __clang_major__ < 4 -// Getting __PRETTY_FUNCTION__ at compile time only works with Clang >= 4 -#error "You're running a too old version of Clang. We need Clang 4 or later." -#endif - inline constexpr string_view extract( string_view prefix, string_view suffix, @@ -101,7 +46,7 @@ inline constexpr string_view extract( } template -inline C10_TYPENAME_CONSTEXPR c10::string_view fully_qualified_type_name_impl() { +inline constexpr c10::string_view fully_qualified_type_name_impl() { #if defined(_MSC_VER) && !defined(__clang__) #if defined(__NVCC__) return extract( @@ -121,11 +66,7 @@ inline C10_TYPENAME_CONSTEXPR c10::string_view fully_qualified_type_name_impl() __PRETTY_FUNCTION__); #elif defined(__GNUC__) return extract( -#if C10_TYPENAME_SUPPORTS_CONSTEXPR "constexpr c10::string_view c10::util::detail::fully_qualified_type_name_impl() [with T = ", -#else - "c10::string_view c10::util::detail::fully_qualified_type_name_impl() [with T = ", -#endif "; c10::string_view = c10::basic_string_view]", __PRETTY_FUNCTION__); #endif @@ -181,14 +122,8 @@ inline constexpr type_index get_type_index() { #endif template -inline C10_TYPENAME_CONSTEXPR string_view -get_fully_qualified_type_name() noexcept { -#if C10_TYPENAME_SUPPORTS_CONSTEXPR - constexpr -#else - static -#endif - string_view name = detail::fully_qualified_type_name_impl(); +inline constexpr string_view get_fully_qualified_type_name() noexcept { + constexpr string_view name = detail::fully_qualified_type_name_impl(); return name; } } // namespace c10::util diff --git a/c10/util/UniqueVoidPtr.h b/c10/util/UniqueVoidPtr.h index f82de8c7059dc..175697f7f63b6 100644 --- a/c10/util/UniqueVoidPtr.h +++ b/c10/util/UniqueVoidPtr.h @@ -69,7 +69,7 @@ class UniqueVoidPtr { std::unique_ptr&& move_context() { return std::move(ctx_); } - C10_NODISCARD bool compare_exchange_deleter( + [[nodiscard]] bool compare_exchange_deleter( DeleterFnPtr expected_deleter, DeleterFnPtr new_deleter) { if (get_deleter() != expected_deleter) diff --git a/c10/util/WaitCounter.cpp b/c10/util/WaitCounter.cpp index 3941942dfb350..1edf4fee29f04 100644 --- a/c10/util/WaitCounter.cpp +++ b/c10/util/WaitCounter.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -110,7 +111,7 @@ class WaitCounterImpl { return ctxs; } - void stop(SmallVector&& ctxs) noexcept { + void stop(const SmallVector& ctxs) noexcept { auto now = std::chrono::steady_clock::now(); assert(ctxs.size() == backends_.size()); for (size_t i = 0; i < ctxs.size(); ++i) { @@ -155,7 +156,7 @@ WaitCounterHandle::WaitGuard WaitCounterHandle::start() { return WaitCounterHandle::WaitGuard(*this, impl_.start()); } -void WaitCounterHandle::stop(SmallVector&& ctxs) { - return impl_.stop(std::move(ctxs)); +void WaitCounterHandle::stop(const SmallVector& ctxs) { + return impl_.stop(ctxs); } } // namespace c10::monitor diff --git a/c10/util/WaitCounter.h b/c10/util/WaitCounter.h index 504e88720a9c1..193740cb10dbf 100644 --- a/c10/util/WaitCounter.h +++ b/c10/util/WaitCounter.h @@ -2,7 +2,6 @@ #include #include -#include #include #include @@ -61,7 +60,7 @@ class C10_API WaitCounterHandle { void stop() { if (auto handle = std::exchange(handle_, nullptr)) { - handle->stop(std::move(ctxs_)); + handle->stop(ctxs_); } } @@ -81,8 +80,9 @@ class C10_API WaitCounterHandle { private: // Stops the waiter. Each start() call should be matched by exactly one stop() // call. - void stop(SmallVector&& ctxs); + void stop(const SmallVector& ctxs); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) detail::WaitCounterImpl& impl_; }; } // namespace c10::monitor diff --git a/c10/util/env.cpp b/c10/util/env.cpp index c3d7e38f6ea6f..dcc969ac381ba 100644 --- a/c10/util/env.cpp +++ b/c10/util/env.cpp @@ -2,6 +2,7 @@ #include #include #include +#include #include namespace c10::utils { diff --git a/c10/util/string_utils.h b/c10/util/string_utils.h index 92af736452aba..61b5df3801559 100644 --- a/c10/util/string_utils.h +++ b/c10/util/string_utils.h @@ -2,6 +2,8 @@ #include +#if !defined(FBCODE_CAFFE2) && !defined(C10_NO_DEPRECATED) + namespace c10 { // NOLINTNEXTLINE(misc-unused-using-decls) @@ -16,3 +18,5 @@ using std::stoull; using std::to_string; } // namespace c10 + +#endif diff --git a/c10/util/string_view.h b/c10/util/string_view.h index 136e3cd154ecf..083b4ef5449e5 100644 --- a/c10/util/string_view.h +++ b/c10/util/string_view.h @@ -26,6 +26,7 @@ namespace c10 { * std::char_traits if we wanted to use it with our constexpr basic_string_view. */ template +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) class basic_string_view final { public: using value_type = CharT; @@ -149,7 +150,7 @@ class basic_string_view final { return std::numeric_limits::max(); } - C10_NODISCARD constexpr bool empty() const noexcept { + [[nodiscard]] constexpr bool empty() const noexcept { return size() == 0; } diff --git a/c10/util/typeid.h b/c10/util/typeid.h index 2c6ac38882f50..13f8a2adec085 100644 --- a/c10/util/typeid.h +++ b/c10/util/typeid.h @@ -71,7 +71,7 @@ class C10_API TypeIdentifier final * is generated during run-time. Do NOT serialize the id for storage. */ template - static C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA TypeIdentifier Get() noexcept { + static constexpr TypeIdentifier Get() noexcept { return TypeIdentifier(c10::util::get_type_index()); } @@ -328,6 +328,7 @@ class C10_API TypeMeta final { * type, use TypeMeta::Make(). */ TypeMeta() noexcept; + ~TypeMeta() = default; /** * Copy constructor. @@ -339,6 +340,7 @@ class C10_API TypeMeta final { */ TypeMeta& operator=(const TypeMeta& src) noexcept = default; + TypeMeta& operator=(TypeMeta&& src) noexcept = default; TypeMeta(TypeMeta&& rhs) noexcept = default; inline TypeMeta& operator=(ScalarType scalar_type) noexcept { @@ -423,7 +425,7 @@ class C10_API TypeMeta final { // Below are static functions that can be called by passing a specific type. template - static C10_HOST_CONSTEXPR_EXCEPT_WIN_CUDA TypeIdentifier Id() noexcept { + static constexpr TypeIdentifier Id() noexcept { return TypeIdentifier::Get(); } diff --git a/c10/xpu/test/impl/XPUStreamTest.cpp b/c10/xpu/test/impl/XPUStreamTest.cpp index 6cbe3ae672158..eb748430a9a5e 100644 --- a/c10/xpu/test/impl/XPUStreamTest.cpp +++ b/c10/xpu/test/impl/XPUStreamTest.cpp @@ -115,7 +115,7 @@ TEST(XPUStreamTest, StreamPoolRoundRobinTest) { } std::vector streams{}; - for (C10_UNUSED const auto _ : c10::irange(200)) { + for ([[maybe_unused]] const auto _ : c10::irange(200)) { streams.emplace_back(c10::xpu::getStreamFromPool()); } diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index f25286d5a6fe4..d77a726b41e5e 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -562,6 +562,7 @@ if(USE_CUDA) ${TORCH_SRC_DIR}/csrc/distributed/c10d/intra_node_comm.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/CudaDMAConnectivity.cpp ${TORCH_SRC_DIR}/csrc/distributed/c10d/CUDASymmetricMemory.cu + ${TORCH_SRC_DIR}/csrc/distributed/c10d/ProcessGroupNCCL.cpp PROPERTIES COMPILE_FLAGS "-DPYTORCH_C10_DRIVER_API_SUPPORTED=1" ) endif() @@ -606,7 +607,7 @@ if(USE_ROCM) # caffe2_nvrtc's stubs to driver APIs are useful for HIP. # See NOTE [ ATen NVRTC Stub and HIP ] add_library(caffe2_nvrtc SHARED ${ATen_NVRTC_STUB_SRCS}) - target_link_libraries(caffe2_nvrtc ${PYTORCH_HIP_LIBRARIES} ${ROCM_HIPRTC_LIB}) + target_link_libraries(caffe2_nvrtc hip::amdhip64 hiprtc::hiprtc) target_include_directories(caffe2_nvrtc PRIVATE ${CMAKE_BINARY_DIR}) target_compile_definitions(caffe2_nvrtc PRIVATE USE_ROCM __HIP_PLATFORM_AMD__) install(TARGETS caffe2_nvrtc DESTINATION "${TORCH_INSTALL_LIB_DIR}") @@ -770,6 +771,10 @@ endif() if("${CMAKE_CXX_COMPILER_ID}" MATCHES "Clang" AND NOT USE_IOS AND NOT USE_COREML_DELEGATE) target_compile_options_if_supported(torch_cpu "-Wmissing-prototypes") target_compile_options_if_supported(torch_cpu "-Werror=missing-prototypes") + if(TARGET torch_cuda) + target_compile_options_if_supported(torch_cuda "-Wmissing-prototypes") + target_compile_options_if_supported(torch_cuda "-Werror=missing-prototypes") + endif() get_target_property(TORCH_CPU_SOURCES torch_cpu SOURCES) foreach(generated_file IN LISTS GENERATED_CXX_TORCH) set_source_files_properties(${generated_file} PROPERTIES COMPILE_OPTIONS "-Wno-missing-prototypes;-Wno-error=missing-prototypes") @@ -1325,6 +1330,7 @@ if(USE_ROCM) ${ROCM_SOURCE_DIR}/hcc/include ${ROCM_SOURCE_DIR}/rocblas/include ${ROCM_SOURCE_DIR}/hipsparse/include + ${ROCM_SOURCE_DIR}/include/rccl/ ) if(USE_FLASH_ATTENTION) target_compile_definitions(torch_hip PRIVATE USE_FLASH_ATTENTION) @@ -1713,7 +1719,10 @@ if(BUILD_TEST) endif() else() add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}") - target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library sleef gtest_main) + target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main) + if(NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "ARM64") + target_link_libraries(${test_name}_${CPU_CAPABILITY} sleef) + endif() endif() target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $) target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $) @@ -1739,6 +1748,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1759,6 +1769,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1780,6 +1791,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1801,6 +1813,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() @@ -1815,6 +1828,7 @@ if(BUILD_TEST) target_include_directories(${test_name} PRIVATE ${Caffe2_CPU_INCLUDE}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) @@ -1834,6 +1848,7 @@ if(BUILD_TEST) target_compile_options(${test_name} PRIVATE ${HIP_CXX_FLAGS}) add_test(NAME ${test_name} COMMAND $) if(INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS ${test_name} DESTINATION test) endif() endforeach() diff --git a/caffe2/perfkernels/CMakeLists.txt b/caffe2/perfkernels/CMakeLists.txt index 83e4a5f915d11..1b46916cb9276 100644 --- a/caffe2/perfkernels/CMakeLists.txt +++ b/caffe2/perfkernels/CMakeLists.txt @@ -10,9 +10,13 @@ endif() file(GLOB common_srcs *.cc) file(GLOB avx_srcs *_avx.cc) file(GLOB avx2_srcs *_avx2.cc) -# exclude avx and avx2 srcs from common_srcs +file(GLOB avx512_srcs *_avx512.cc) +file(GLOB sve_srcs *_sve.cc) +# exclude avx, avx2, avx512, and sve srcs from common_srcs exclude(common_srcs "${common_srcs}" ${avx_srcs}) exclude(common_srcs "${common_srcs}" ${avx2_srcs}) +exclude(common_srcs "${common_srcs}" ${avx512_srcs}) +exclude(common_srcs "${common_srcs}" ${sve_srcs}) # We will always build common srcs. set(Caffe2_CPU_SRCS ${Caffe2_CPU_SRCS} ${common_srcs}) @@ -42,6 +46,22 @@ if(CXX_AVX2_FOUND) "Caffe2_perfkernels_avx2_interface") endif() +# We will only build the SVE perfkernel files if the compiler supports SVE +# extensions. +if(CXX_SVE_FOUND) + add_library(Caffe2_perfkernels_sve STATIC ${sve_srcs}) + target_link_libraries(Caffe2_perfkernels_sve PRIVATE c10) + install(TARGETS Caffe2_perfkernels_sve + ARCHIVE DESTINATION "${CMAKE_INSTALL_LIBDIR}") + + target_compile_options(Caffe2_perfkernels_sve PRIVATE "-march=armv8-a+sve") + + caffe2_interface_library( + Caffe2_perfkernels_sve Caffe2_perfkernels_sve_interface) + list(APPEND + Caffe2_DEPENDENCY_WHOLE_LINK_LIBS "Caffe2_perfkernels_sve_interface") +endif() + # TODO(jiayq): currently, we only implement the very base files for the # perfkernels. This is because to implement avx and avx2 files, we actually # need to set up different compilation units and this is a bit more involving diff --git a/caffe2/perfkernels/common.h b/caffe2/perfkernels/common.h index 6fed9e1d6d06c..6e069861b28d2 100644 --- a/caffe2/perfkernels/common.h +++ b/caffe2/perfkernels/common.h @@ -61,9 +61,8 @@ In foo.cc, do: // we use cpuinfo to identify cpu support and run the proper functions. #pragma once - -#if defined(CAFFE2_PERF_WITH_AVX512) || defined(CAFFE2_PERF_WITH_AVX2) \ - || defined(CAFFE2_PERF_WITH_AVX) +#if defined(CAFFE2_PERF_WITH_SVE) || defined(CAFFE2_PERF_WITH_AVX512) || \ + defined(CAFFE2_PERF_WITH_AVX2) || defined(CAFFE2_PERF_WITH_AVX) #include #endif @@ -72,6 +71,18 @@ In foo.cc, do: #define BASE_DO(funcname, ...) return funcname##__base(__VA_ARGS__); +#ifdef CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) \ + { \ + static const bool isDo = cpuinfo_initialize() && cpuinfo_has_arm_sve(); \ + if (isDo) { \ + return funcname##__sve(__VA_ARGS__); \ + } \ + } +#else // CAFFE2_PERF_WITH_SVE +#define SVE_DO(funcname, ...) +#endif // CAFFE2_PERF_WITH_SVE + #ifdef CAFFE2_PERF_WITH_AVX512 #define AVX512_DO(funcname, ...) \ { \ diff --git a/caffe2/perfkernels/common_sve.cc b/caffe2/perfkernels/common_sve.cc new file mode 100644 index 0000000000000..03b0bf983c80d --- /dev/null +++ b/caffe2/perfkernels/common_sve.cc @@ -0,0 +1,22 @@ +// This file is here merely to check that the flags are not mixed up: for +// example, if your compiler did not specify -march=armv8-a+sve, you should not +// provide the CAFFE2_PERF_WITH_SVE macro. + +#include "caffe2/core/common.h" + +#ifdef CAFFE2_PERF_WITH_SVE +#ifndef __ARM_FEATURE_SVE +#error( \ + "You found a build system error: CAFFE2_PERF_WITH_SVE is defined" \ + "but __ARM_FEATURE_SVE is not defined (via e.g. -march=armv8-a+sve)."); +#endif // __ARM_FEATURE_SVE +#endif // CAFFE2_PERF_WITH_SVE + +#ifdef __ARM_FEATURE_SVE +#ifndef CAFFE2_PERF_WITH_SVE +#error( \ + "You found a build system error: __SVE__ is defined \ + (via e.g. -march=armv8-a+sve) " \ + "but CAFFE2_PERF_WITH_SVE is not defined."); +#endif // CAFFE2_PERF_WITH_SVE +#endif diff --git a/caffe2/perfkernels/embedding_lookup_idx.cc b/caffe2/perfkernels/embedding_lookup_idx.cc index 5fcf71016aea6..7c62d9e883fd6 100644 --- a/caffe2/perfkernels/embedding_lookup_idx.cc +++ b/caffe2/perfkernels/embedding_lookup_idx.cc @@ -88,7 +88,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -113,6 +113,9 @@ static bool EmbeddingLookupGenericSlowIdx( decltype( \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \ + decltype( \ + EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ + EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__sve; \ bool \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \ const int64_t block_size, \ @@ -121,7 +124,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ @@ -131,6 +134,19 @@ static bool EmbeddingLookupGenericSlowIdx( } else { \ CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \ } \ + SVE_DO( \ + EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ + block_size, \ + output_size, \ + index_size, \ + data_size, \ + input, \ + indices, \ + offsets, \ + weights, \ + scale_bias, \ + normalize_by_lengths, \ + out); \ AVX2_FMA_DO( \ EmbeddingLookupIdx_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ block_size, \ @@ -166,7 +182,7 @@ static bool EmbeddingLookupGenericSlowIdx( const int64_t data_size, \ const InType* input, \ const IndexType* indices, \ - const IndexType* offsets, \ + const IndexType* offsets, \ const float* weights, \ const float* scale_bias, \ bool normalize_by_lengths, \ diff --git a/caffe2/perfkernels/embedding_lookup_idx_sve.cc b/caffe2/perfkernels/embedding_lookup_idx_sve.cc new file mode 100644 index 0000000000000..873823536b55a --- /dev/null +++ b/caffe2/perfkernels/embedding_lookup_idx_sve.cc @@ -0,0 +1,6769 @@ +//// -------------------------- +//// ATTENTION: +//// THIS CODE IS AUTOGENERATED +//// BY sve_emblookup_codegen.py +//// DO NOT MODIFY!!! +//// -------------------------- + +#include +#include +#include +#include +#include +namespace caffe2 { + +template +static bool EmbeddingLookupIdx_int32_t_float_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + vsum16 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); + vsum17 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); + vsum18 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); + vsum19 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); + vsum20 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); + vsum21 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); + vsum22 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); + vsum23 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); + vsum24 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); + vsum25 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); + vsum26 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); + vsum27 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); + vsum28 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); + vsum29 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); + vsum30 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); + vsum31 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_float_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_float_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int64_t_float_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + vsum16 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[16 * vLen]), vsum16); + vsum17 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[17 * vLen]), vsum17); + vsum18 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[18 * vLen]), vsum18); + vsum19 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[19 * vLen]), vsum19); + vsum20 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[20 * vLen]), vsum20); + vsum21 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[21 * vLen]), vsum21); + vsum22 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[22 * vLen]), vsum22); + vsum23 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[23 * vLen]), vsum23); + vsum24 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[24 * vLen]), vsum24); + vsum25 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[25 * vLen]), vsum25); + vsum26 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[26 * vLen]), vsum26); + vsum27 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[27 * vLen]), vsum27); + vsum28 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[28 * vLen]), vsum28); + vsum29 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[29 * vLen]), vsum29); + vsum30 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[30 * vLen]), vsum30); + vsum31 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[31 * vLen]), vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + vsum8 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[8 * vLen]), vsum8); + vsum9 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[9 * vLen]), vsum9); + vsum10 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[10 * vLen]), vsum10); + vsum11 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[11 * vLen]), vsum11); + vsum12 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[12 * vLen]), vsum12); + vsum13 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[13 * vLen]), vsum13); + vsum14 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[14 * vLen]), vsum14); + vsum15 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[15 * vLen]), vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + vsum4 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[4 * vLen]), vsum4); + vsum5 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[5 * vLen]), vsum5); + vsum6 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[6 * vLen]), vsum6); + vsum7 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[7 * vLen]), vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + vsum2 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[2 * vLen]), vsum2); + vsum3 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[3 * vLen]), vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[0 * vLen]), vsum0); + vsum1 = + svmad_f32_x(svAll, vwgt, svld1_f32(svAll, &ip[1 * vLen]), vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const float* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, vwgt, svld1_f32(pg, &ip[k]), svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_float_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_float_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const float* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_float_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int32_t_half_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * vLen])))), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])))), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])))), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])))), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])))), + vsum20); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])))), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])))), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])))), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])))), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])))), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])))), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])))), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])))), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])))), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])))), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])))), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_f16_x( + pg, + svreinterpret_f16_u32(svld1uh_u32( + pg, reinterpret_cast(&ip[k])))), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_half_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_half_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_half_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_half_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int64_t_half_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * vLen])))), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])))), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])))), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])))), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])))), + vsum20); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])))), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])))), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])))), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])))), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])))), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])))), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])))), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])))), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])))), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])))), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])))), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])))), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])))), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])))), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])))), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])))), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])))), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])))), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])))), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])))), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])))), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])))), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])))), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])))), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])))), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])))), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_f16_x( + svAll, + svreinterpret_f16_u32(svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])))), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::Half* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_f16_x( + pg, + svreinterpret_f16_u32(svld1uh_u32( + pg, reinterpret_cast(&ip[k])))), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_half_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_half_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_half_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::Half* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_half_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int32_t_bfloat16_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * vLen])), + 16)), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])), + 16)), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])), + 16)), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])), + 16)), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])), + 16)), + vsum20); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])), + 16)), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])), + 16)), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])), + 16)), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])), + 16)), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])), + 16)), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])), + 16)), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])), + 16)), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])), + 16)), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])), + 16)), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])), + 16)), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])), + 16)), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + pg, + svld1uh_u32( + pg, reinterpret_cast(&ip[k])), + 16)), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_bfloat16_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_bfloat16_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int64_t_bfloat16_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[16 * vLen])), + 16)), + vsum16); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[17 * vLen])), + 16)), + vsum17); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[18 * vLen])), + 16)), + vsum18); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[19 * vLen])), + 16)), + vsum19); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[20 * vLen])), + 16)), + vsum20); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[21 * vLen])), + 16)), + vsum21); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[22 * vLen])), + 16)), + vsum22); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[23 * vLen])), + 16)), + vsum23); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[24 * vLen])), + 16)), + vsum24); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[25 * vLen])), + 16)), + vsum25); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[26 * vLen])), + 16)), + vsum26); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[27 * vLen])), + 16)), + vsum27); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[28 * vLen])), + 16)), + vsum28); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[29 * vLen])), + 16)), + vsum29); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[30 * vLen])), + 16)), + vsum30); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[31 * vLen])), + 16)), + vsum31); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[8 * vLen])), + 16)), + vsum8); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[9 * vLen])), + 16)), + vsum9); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[10 * vLen])), + 16)), + vsum10); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[11 * vLen])), + 16)), + vsum11); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[12 * vLen])), + 16)), + vsum12); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[13 * vLen])), + 16)), + vsum13); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[14 * vLen])), + 16)), + vsum14); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[15 * vLen])), + 16)), + vsum15); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[4 * vLen])), + 16)), + vsum4); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[5 * vLen])), + 16)), + vsum5); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[6 * vLen])), + 16)), + vsum6); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[7 * vLen])), + 16)), + vsum7); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[2 * vLen])), + 16)), + vsum2); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[3 * vLen])), + 16)), + vsum3); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[0 * vLen])), + 16)), + vsum0); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + svAll, + svld1uh_u32( + svAll, reinterpret_cast(&ip[1 * vLen])), + 16)), + vsum1); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + const svfloat32_t vwgt = svdup_n_f32(wgt); + const at::BFloat16* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svreinterpret_f32_u32(svlsl_n_u32_x( + pg, + svld1uh_u32( + pg, reinterpret_cast(&ip[k])), + 16)), + svld1_f32(pg, &op[k]))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_bfloat16_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const at::BFloat16* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_bfloat16_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int32_t_uint8_t_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), + svadd_f32_x(svAll, vsum16, vbio)); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), + svadd_f32_x(svAll, vsum17, vbio)); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), + svadd_f32_x(svAll, vsum18, vbio)); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), + svadd_f32_x(svAll, vsum19, vbio)); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), + svadd_f32_x(svAll, vsum20, vbio)); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), + svadd_f32_x(svAll, vsum21, vbio)); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), + svadd_f32_x(svAll, vsum22, vbio)); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), + svadd_f32_x(svAll, vsum23, vbio)); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), + svadd_f32_x(svAll, vsum24, vbio)); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), + svadd_f32_x(svAll, vsum25, vbio)); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), + svadd_f32_x(svAll, vsum26, vbio)); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), + svadd_f32_x(svAll, vsum27, vbio)); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), + svadd_f32_x(svAll, vsum28, vbio)); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), + svadd_f32_x(svAll, vsum29, vbio)); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), + svadd_f32_x(svAll, vsum30, vbio)); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), + svadd_f32_x(svAll, vsum31, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + // unimplemented + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), + svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int32_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +template +static bool EmbeddingLookupIdx_int64_t_uint8_t_float__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + const svbool_t svAll = svptrue_b32(); + const auto vLen = static_cast(svcntw()); + int64_t pos = 0; + if (block_size == 32 * vLen) { + // unrolling 32 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + svfloat32_t vsum16 = svdup_n_f32(0); + svfloat32_t vsum17 = svdup_n_f32(0); + svfloat32_t vsum18 = svdup_n_f32(0); + svfloat32_t vsum19 = svdup_n_f32(0); + svfloat32_t vsum20 = svdup_n_f32(0); + svfloat32_t vsum21 = svdup_n_f32(0); + svfloat32_t vsum22 = svdup_n_f32(0); + svfloat32_t vsum23 = svdup_n_f32(0); + svfloat32_t vsum24 = svdup_n_f32(0); + svfloat32_t vsum25 = svdup_n_f32(0); + svfloat32_t vsum26 = svdup_n_f32(0); + svfloat32_t vsum27 = svdup_n_f32(0); + svfloat32_t vsum28 = svdup_n_f32(0); + svfloat32_t vsum29 = svdup_n_f32(0); + svfloat32_t vsum30 = svdup_n_f32(0); + svfloat32_t vsum31 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + vsum16 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[16 * vLen])), + svadd_f32_x(svAll, vsum16, vbio)); + vsum17 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[17 * vLen])), + svadd_f32_x(svAll, vsum17, vbio)); + vsum18 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[18 * vLen])), + svadd_f32_x(svAll, vsum18, vbio)); + vsum19 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[19 * vLen])), + svadd_f32_x(svAll, vsum19, vbio)); + vsum20 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[20 * vLen])), + svadd_f32_x(svAll, vsum20, vbio)); + vsum21 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[21 * vLen])), + svadd_f32_x(svAll, vsum21, vbio)); + vsum22 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[22 * vLen])), + svadd_f32_x(svAll, vsum22, vbio)); + vsum23 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[23 * vLen])), + svadd_f32_x(svAll, vsum23, vbio)); + vsum24 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[24 * vLen])), + svadd_f32_x(svAll, vsum24, vbio)); + vsum25 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[25 * vLen])), + svadd_f32_x(svAll, vsum25, vbio)); + vsum26 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[26 * vLen])), + svadd_f32_x(svAll, vsum26, vbio)); + vsum27 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[27 * vLen])), + svadd_f32_x(svAll, vsum27, vbio)); + vsum28 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[28 * vLen])), + svadd_f32_x(svAll, vsum28, vbio)); + vsum29 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[29 * vLen])), + svadd_f32_x(svAll, vsum29, vbio)); + vsum30 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[30 * vLen])), + svadd_f32_x(svAll, vsum30, vbio)); + vsum31 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[31 * vLen])), + svadd_f32_x(svAll, vsum31, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + svst1_f32(svAll, &op[16 * vLen], svmul_f32_x(svAll, vsum16, vlen_inv)); + svst1_f32(svAll, &op[17 * vLen], svmul_f32_x(svAll, vsum17, vlen_inv)); + svst1_f32(svAll, &op[18 * vLen], svmul_f32_x(svAll, vsum18, vlen_inv)); + svst1_f32(svAll, &op[19 * vLen], svmul_f32_x(svAll, vsum19, vlen_inv)); + svst1_f32(svAll, &op[20 * vLen], svmul_f32_x(svAll, vsum20, vlen_inv)); + svst1_f32(svAll, &op[21 * vLen], svmul_f32_x(svAll, vsum21, vlen_inv)); + svst1_f32(svAll, &op[22 * vLen], svmul_f32_x(svAll, vsum22, vlen_inv)); + svst1_f32(svAll, &op[23 * vLen], svmul_f32_x(svAll, vsum23, vlen_inv)); + svst1_f32(svAll, &op[24 * vLen], svmul_f32_x(svAll, vsum24, vlen_inv)); + svst1_f32(svAll, &op[25 * vLen], svmul_f32_x(svAll, vsum25, vlen_inv)); + svst1_f32(svAll, &op[26 * vLen], svmul_f32_x(svAll, vsum26, vlen_inv)); + svst1_f32(svAll, &op[27 * vLen], svmul_f32_x(svAll, vsum27, vlen_inv)); + svst1_f32(svAll, &op[28 * vLen], svmul_f32_x(svAll, vsum28, vlen_inv)); + svst1_f32(svAll, &op[29 * vLen], svmul_f32_x(svAll, vsum29, vlen_inv)); + svst1_f32(svAll, &op[30 * vLen], svmul_f32_x(svAll, vsum30, vlen_inv)); + svst1_f32(svAll, &op[31 * vLen], svmul_f32_x(svAll, vsum31, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + svst1_f32(svAll, &op[16 * vLen], vsum16); + svst1_f32(svAll, &op[17 * vLen], vsum17); + svst1_f32(svAll, &op[18 * vLen], vsum18); + svst1_f32(svAll, &op[19 * vLen], vsum19); + svst1_f32(svAll, &op[20 * vLen], vsum20); + svst1_f32(svAll, &op[21 * vLen], vsum21); + svst1_f32(svAll, &op[22 * vLen], vsum22); + svst1_f32(svAll, &op[23 * vLen], vsum23); + svst1_f32(svAll, &op[24 * vLen], vsum24); + svst1_f32(svAll, &op[25 * vLen], vsum25); + svst1_f32(svAll, &op[26 * vLen], vsum26); + svst1_f32(svAll, &op[27 * vLen], vsum27); + svst1_f32(svAll, &op[28 * vLen], vsum28); + svst1_f32(svAll, &op[29 * vLen], vsum29); + svst1_f32(svAll, &op[30 * vLen], vsum30); + svst1_f32(svAll, &op[31 * vLen], vsum31); + } + } + } else if (block_size == 16 * vLen) { + // unrolling 16 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + svfloat32_t vsum8 = svdup_n_f32(0); + svfloat32_t vsum9 = svdup_n_f32(0); + svfloat32_t vsum10 = svdup_n_f32(0); + svfloat32_t vsum11 = svdup_n_f32(0); + svfloat32_t vsum12 = svdup_n_f32(0); + svfloat32_t vsum13 = svdup_n_f32(0); + svfloat32_t vsum14 = svdup_n_f32(0); + svfloat32_t vsum15 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + vsum8 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[8 * vLen])), + svadd_f32_x(svAll, vsum8, vbio)); + vsum9 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[9 * vLen])), + svadd_f32_x(svAll, vsum9, vbio)); + vsum10 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[10 * vLen])), + svadd_f32_x(svAll, vsum10, vbio)); + vsum11 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[11 * vLen])), + svadd_f32_x(svAll, vsum11, vbio)); + vsum12 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[12 * vLen])), + svadd_f32_x(svAll, vsum12, vbio)); + vsum13 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[13 * vLen])), + svadd_f32_x(svAll, vsum13, vbio)); + vsum14 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[14 * vLen])), + svadd_f32_x(svAll, vsum14, vbio)); + vsum15 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[15 * vLen])), + svadd_f32_x(svAll, vsum15, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + svst1_f32(svAll, &op[8 * vLen], svmul_f32_x(svAll, vsum8, vlen_inv)); + svst1_f32(svAll, &op[9 * vLen], svmul_f32_x(svAll, vsum9, vlen_inv)); + svst1_f32(svAll, &op[10 * vLen], svmul_f32_x(svAll, vsum10, vlen_inv)); + svst1_f32(svAll, &op[11 * vLen], svmul_f32_x(svAll, vsum11, vlen_inv)); + svst1_f32(svAll, &op[12 * vLen], svmul_f32_x(svAll, vsum12, vlen_inv)); + svst1_f32(svAll, &op[13 * vLen], svmul_f32_x(svAll, vsum13, vlen_inv)); + svst1_f32(svAll, &op[14 * vLen], svmul_f32_x(svAll, vsum14, vlen_inv)); + svst1_f32(svAll, &op[15 * vLen], svmul_f32_x(svAll, vsum15, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + svst1_f32(svAll, &op[8 * vLen], vsum8); + svst1_f32(svAll, &op[9 * vLen], vsum9); + svst1_f32(svAll, &op[10 * vLen], vsum10); + svst1_f32(svAll, &op[11 * vLen], vsum11); + svst1_f32(svAll, &op[12 * vLen], vsum12); + svst1_f32(svAll, &op[13 * vLen], vsum13); + svst1_f32(svAll, &op[14 * vLen], vsum14); + svst1_f32(svAll, &op[15 * vLen], vsum15); + } + } + } else if (block_size == 8 * vLen) { + // unrolling 8 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + svfloat32_t vsum4 = svdup_n_f32(0); + svfloat32_t vsum5 = svdup_n_f32(0); + svfloat32_t vsum6 = svdup_n_f32(0); + svfloat32_t vsum7 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + vsum4 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[4 * vLen])), + svadd_f32_x(svAll, vsum4, vbio)); + vsum5 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[5 * vLen])), + svadd_f32_x(svAll, vsum5, vbio)); + vsum6 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[6 * vLen])), + svadd_f32_x(svAll, vsum6, vbio)); + vsum7 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[7 * vLen])), + svadd_f32_x(svAll, vsum7, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + svst1_f32(svAll, &op[4 * vLen], svmul_f32_x(svAll, vsum4, vlen_inv)); + svst1_f32(svAll, &op[5 * vLen], svmul_f32_x(svAll, vsum5, vlen_inv)); + svst1_f32(svAll, &op[6 * vLen], svmul_f32_x(svAll, vsum6, vlen_inv)); + svst1_f32(svAll, &op[7 * vLen], svmul_f32_x(svAll, vsum7, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + svst1_f32(svAll, &op[4 * vLen], vsum4); + svst1_f32(svAll, &op[5 * vLen], vsum5); + svst1_f32(svAll, &op[6 * vLen], vsum6); + svst1_f32(svAll, &op[7 * vLen], vsum7); + } + } + } else if (block_size == 4 * vLen) { + // unrolling 4 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + svfloat32_t vsum2 = svdup_n_f32(0); + svfloat32_t vsum3 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + vsum2 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[2 * vLen])), + svadd_f32_x(svAll, vsum2, vbio)); + vsum3 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[3 * vLen])), + svadd_f32_x(svAll, vsum3, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + svst1_f32(svAll, &op[2 * vLen], svmul_f32_x(svAll, vsum2, vlen_inv)); + svst1_f32(svAll, &op[3 * vLen], svmul_f32_x(svAll, vsum3, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + svst1_f32(svAll, &op[2 * vLen], vsum2); + svst1_f32(svAll, &op[3 * vLen], vsum3); + } + } + } else if (block_size == 2 * vLen) { + // unrolling 2 times + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + if (pos != offsets[i] - offsets[0]) { + return false; + } + svfloat32_t vsum0 = svdup_n_f32(0); + svfloat32_t vsum1 = svdup_n_f32(0); + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* const ip = &input[idx * block_size]; + // weight * input + out + vsum0 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[0 * vLen])), + svadd_f32_x(svAll, vsum0, vbio)); + vsum1 = svmad_f32_x( + svAll, + vwgt, + svcvt_f32_u32_x(svAll, svld1ub_u32(svAll, &ip[1 * vLen])), + svadd_f32_x(svAll, vsum1, vbio)); + ++pos; + } + // Normalisation + const int64_t length = end_offset - start_offset; + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + const svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svst1_f32(svAll, &op[0 * vLen], svmul_f32_x(svAll, vsum0, vlen_inv)); + svst1_f32(svAll, &op[1 * vLen], svmul_f32_x(svAll, vsum1, vlen_inv)); + } else { + svst1_f32(svAll, &op[0 * vLen], vsum0); + svst1_f32(svAll, &op[1 * vLen], vsum1); + } + } + } else { + // generic code: + for (int64_t i = 0; i < output_size; ++i) { + float* const op = &out[i * block_size]; + memset(op, 0, sizeof(float) * block_size); + if (pos != offsets[i] - offsets[0]) { + return false; + } + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1]; + for (int64_t j = start_offset; j < end_offset; ++j) { + const auto idx = indices[pos]; + if (idx < 0 || idx >= data_size) { + return false; + } + // unimplemented + float wgt = 1.f; + float bio{}; + if (weights) { + wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos]; + } + if (scale_bias) { + bio = wgt * scale_bias[2 * idx + 1]; + wgt = wgt * scale_bias[2 * idx]; + } + svfloat32_t vbio = svdup_n_f32(bio); + const svfloat32_t vwgt = svdup_n_f32(wgt); + const uint8_t* ip = &input[idx * block_size]; + svbool_t pg; + for (int64_t k = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(k, block_size)); + k += vLen) { + svst1_f32( + pg, + &op[k], + svmad_f32_x( + pg, + vwgt, + svcvt_f32_u32_x(pg, svld1ub_u32(pg, &ip[k])), + svadd_f32_x(pg, svld1_f32(pg, &op[k]), vbio))); + } + + ++pos; + } + const int64_t length = end_offset - start_offset; + + if (normalize_by_lengths && length != 0) { + const float len_inv = 1.0f / length; + svfloat32_t vlen_inv = svdup_n_f32(len_inv); + svbool_t pg; + for (int64_t j = 0; + svptest_first(svAll, pg = svwhilelt_b32_s64(j, block_size)); + j += vLen) { + svst1_f32( + pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv)); + } + } + } + } + return pos == index_size; +} +bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} +bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__sve( + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t data_size, + const uint8_t* input, + const int64_t* indices, + const int64_t* offsets, + const float* weights, + const float* scale_bias, + bool normalize_by_lengths, + float* out) { + return EmbeddingLookupIdx_int64_t_uint8_t_float__sve( + block_size, + output_size, + index_size, + data_size, + input, + indices, + offsets, + weights, + scale_bias, + normalize_by_lengths, + out); +} + +} // namespace caffe2 diff --git a/caffe2/perfkernels/sve_emblookup_codegen.py b/caffe2/perfkernels/sve_emblookup_codegen.py new file mode 100644 index 0000000000000..02f010ccc250d --- /dev/null +++ b/caffe2/perfkernels/sve_emblookup_codegen.py @@ -0,0 +1,408 @@ +# mypy: allow-untyped-defs +import argparse +import sys + +# Unroll loops when block_size is a multiple of vector length. +def unroll(num_unrolls, IndexType, InType, OutType, use_weights): + def compute(regid, InType, use_weights): + code = [] + + if InType == "float": + code.append( + f" vsum{regid} =\n" + " svmad_f32_x(" + f"svAll, vwgt, svld1_f32(svAll, &ip[{regid} * vLen])," + f" vsum{regid});" + ) + elif InType == "at::Half": + code.append( + f" vsum{regid} = svmad_f32_x(\n" + " svAll,\n" + " vwgt,\n" + " svcvt_f32_f16_x(\n" + " svAll,\n" + " svreinterpret_f16_u32(svld1uh_u32(\n" + " svAll, reinterpret_cast(" + f"&ip[{regid} * vLen])))),\n" # noqa + f" vsum{regid});" + ) + elif InType == "at::BFloat16": + code.append( + f" vsum{regid} = svmad_f32_x(\n" + " svAll,\n" + " vwgt,\n" + " svreinterpret_f32_u32(svlsl_n_u32_x(\n" + " svAll,\n" + " svld1uh_u32(\n" + " svAll, reinterpret_cast(" + f"&ip[{regid} * vLen])),\n" + " 16)),\n" # noqa + f" vsum{regid});" + ) + elif InType == "uint8_t": + code.append( + f" vsum{regid} = svmad_f32_x(\n" + " svAll,\n" + " vwgt,\n" + " svcvt_f32_u32_x(svAll," + f" svld1ub_u32(svAll, &ip[{regid} * vLen])),\n" # noqa + f" svadd_f32_x(svAll, vsum{regid}, vbio));" + ) + else: + raise ValueError(f"Unknown datatype \"{InType}\"") + + return code + + code = [] + code.append(f" // unrolling {num_unrolls} times") + + code.append(" for (int64_t i = 0; i < output_size; ++i) {") + + code.append(" " + OutType + "* const op = &out[i * block_size];") + code.append( + " if (pos != offsets[i] - offsets[0]) {\n" + + " return false;\n" + + " }" + ) + + # Initialise vector sum registers + for i in range(num_unrolls): + code.append(f" svfloat32_t vsum{i} = svdup_n_f32(0);") + + # inner loop + code.append("""\ + int64_t start_offset = offsets[i]; + int64_t end_offset = offsets[i + 1];""") + code.append( + " for (" + + "int64_t" + + " j = start_offset; j < end_offset; ++j) {" # noqa + ) + + code.append(" const auto idx = indices[pos];") + code.append( + " if (idx < 0 || idx >= data_size) {\n" + + " return false;\n" + + " }" + ) + + if InType == "uint8_t": + code.append(" " + OutType + " wgt = 1.f;") + code.append(" " + OutType + " bio{};") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + ) + code.append(" }") + code.append(" if (scale_bias) {") + code.append(" bio = wgt * scale_bias[2 * idx + 1];") + code.append(" wgt = wgt * scale_bias[2 * idx];") + code.append(" }") + code.append(" svfloat32_t vbio = svdup_n_f32(bio);") + else: + code.append(" " + OutType + " wgt = 1.f;") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + ) + code.append(" }") + + code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") + code.append(f" const {InType}* const ip = &input[idx * block_size];") + code.append(" // weight * input + out") + + for i in range(num_unrolls): + code.extend(compute(i, InType, use_weights)) + + code.append(" ++pos;") + code.append(" }") + + code.append(" // Normalisation") + code.append(" const int64_t length = end_offset - start_offset;") + code.append(" if (normalize_by_lengths && length != 0) {") + code.append(" const float len_inv = 1.0f / length;") + code.append(" const svfloat32_t vlen_inv = svdup_n_f32(len_inv);") + + for i in range(num_unrolls): + code.append(f" svst1_f32(svAll, &op[{i} * vLen]," + + f" svmul_f32_x(svAll, vsum{i}, vlen_inv));") + + code.append(" } else {") + # inv of length + for i in range(num_unrolls): + code.append(f" svst1_f32(svAll, &op[{i} * vLen], vsum{i});") + + code.append(" }") + code.append(" }") + return code + + +# Handle the case where block_size is not a multiple of vector length. +def generic(IndexType, InType, OutType, use_weights): + def compute(InType, use_weights): + code = [] + if InType == "float": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg, vwgt, svld1_f32(pg, &ip[k])," + " svld1_f32(pg, &op[k])));" + ) + elif InType == "at::Half": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg,\n" + " vwgt,\n" + " svcvt_f32_f16_x(\n" + " pg,\n" + " svreinterpret_f16_u32(svld1uh_u32(\n" + " pg," + " reinterpret_cast(&ip[k])))),\n" + " svld1_f32(pg, &op[k])));" + ) + elif InType == "at::BFloat16": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg,\n" + " vwgt,\n" + " svreinterpret_f32_u32(svlsl_n_u32_x(\n" + " pg,\n" + " svld1uh_u32(\n" + " pg," + " reinterpret_cast(&ip[k])),\n" + " 16)),\n" + " svld1_f32(pg, &op[k])));" + ) + elif InType == "uint8_t": + code.append( + " svst1_f32(\n" + " pg,\n" + " &op[k],\n" + " svmad_f32_x(\n" + " pg,\n" + " vwgt,\n" + " svcvt_f32_u32_x(pg," + " svld1ub_u32(pg, &ip[k])),\n" # noqa + " svadd_f32_x(pg," + " svld1_f32(pg, &op[k]), vbio)));" + ) + else: + raise ValueError(f"Unknown datatype \"{InType}\"") + + return code + + code = [] + + code.append( + " for (int64_t i = 0; i < output_size; ++i) {" + ) + + code.append(" " + OutType + "* const op = &out[i * block_size];") + + # initialize to 0 + code.append(" memset(op, 0, sizeof(float) * block_size);") + + # inner loop + code.append( + " if (pos != offsets[i] - offsets[0]) {\n" + + " return false;\n" + + " }" + ) + code.append( + " int64_t start_offset = offsets[i];\n" + + " int64_t end_offset = offsets[i + 1];" + ) + code.append( + " for (" + + "int64_t" + + " j = start_offset; j < end_offset; ++j) {" # noqa + ) + + code.append(" const auto idx = indices[pos];") + code.append( + " if (idx < 0 || idx >= data_size) {\n" + + " return false;\n" + + " }" + ) + + if InType == "uint8_t": + code.append(" // unimplemented") + code.append(" " + OutType + " wgt = 1.f;") + code.append(" " + OutType + " bio{};") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + ) + code.append(" }") + code.append(" if (scale_bias) {") + code.append(" bio = wgt * scale_bias[2 * idx + 1];") + code.append(" wgt = wgt * scale_bias[2 * idx];") + code.append(" }") + code.append(" svfloat32_t vbio = svdup_n_f32(bio);") + else: + code.append(" " + OutType + " wgt = 1.f;") + code.append(" if (weights) {") + code.append( + " wgt = weights[IS_WEIGHT_POSITIONAL ? (j - start_offset) : pos];" # noqa + ) + code.append(" }") + + code.append(" const svfloat32_t vwgt = svdup_n_f32(wgt);") + code.append(f" const {InType}* ip = &input[idx * block_size];") + + # compute and store main loop + code.append(" svbool_t pg;") + code.append(" for (int64_t k = 0;") + code.append(" svptest_first(svAll, pg = svwhilelt_b32_s64(" + + "k, block_size));") + code.append(" k += vLen) {") + code.extend(compute(InType, use_weights)) + code.append(" }\n") + code.append(" ++pos;") + code.append(" }") + + code.append(" const int64_t length = end_offset - start_offset;\n") + code.append(" if (normalize_by_lengths && length != 0) {") + code.append(" const float len_inv = 1.0f / length;") + code.append(" svfloat32_t vlen_inv = svdup_n_f32(len_inv);") + code.append(" svbool_t pg;") + code.append(" for (int64_t j = 0;\n" + " svptest_first(svAll, pg = svwhilelt_b32_s64(" + "j, block_size));") + code.append(" j += vLen) {") + code.append( + " svst1_f32(\n" + " pg, &op[j], svmul_f32_x(pg, svld1_f32(pg, &op[j]), vlen_inv));" + ) + code.append(" }") + code.append(" }") + code.append(" }") + return code + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-f", "--filename", help="file name") + opts = parser.parse_args() + if opts.filename: + filename = opts.filename + else: + filename = "embedding_lookup_idx_sve.cc" + + options = [ + ["int32_t", "int32_t", "float", "float", "float", "float"], + ["int64_t", "int64_t", "float", "float", "float", "float"], + ["int32_t", "int32_t", "half", "at::Half", "float", "float"], + ["int64_t", "int64_t", "half", "at::Half", "float", "float"], + ["int32_t", "int32_t", "bfloat16", "at::BFloat16", "float", "float"], + ["int64_t", "int64_t", "bfloat16", "at::BFloat16", "float", "float"], + ["int32_t", "int32_t", "uint8_t", "uint8_t", "float", "float"], + ["int64_t", "int64_t", "uint8_t", "uint8_t", "float", "float"], + ] + + code = [] + # includes + code.append("//// --------------------------") + code.append("//// ATTENTION:") + code.append("//// THIS CODE IS AUTOGENERATED") + code.append(f"//// BY {' '.join(sys.argv)}") + code.append("//// DO NOT MODIFY!!!") + code.append("//// --------------------------\n") + + code.append("#include ") + code.append("#include ") + code.append("#include ") + code.append("#include ") + code.append("#include ") + + code.append("namespace caffe2 {\n") + for o in options: + [IndexTypeName, IndexType, InTypeName, InType, OutTypeName, OutType] = o + + code.append("template ") + fn_base = f"EmbeddingLookupIdx_{IndexTypeName}_{InTypeName}_{OutTypeName}" + + suffix = "__sve" + fn = "static bool " + fn_base + suffix + code.append(fn + "(") + + args = [] + args.append(" const int64_t block_size,") + args.append(" const int64_t output_size,") + args.append(" const int64_t index_size,") + args.append(" const int64_t data_size,") + args.append(" const " + InType + "* input,") + args.append(" const " + IndexType + "* indices,") + args.append(" const " + IndexType + "* offsets,") + args.append(" const float* weights,") + args.append(" const float* scale_bias,") + args.append(" bool normalize_by_lengths,") + args.append(" " + OutType + "* out) {") + code += args + + code.append(" const svbool_t svAll = svptrue_b32();") + code.append(" const auto vLen = static_cast(svcntw());") + code.append(" int64_t pos = 0;") + + code.append(" if (block_size == 32 * vLen) {") + code += unroll(32, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 16 * vLen) {") + code += unroll(16, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 8 * vLen) {") + code += unroll(8, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 4 * vLen) {") + code += unroll(4, IndexType, InType, OutType, True) + code.append(" } else if (block_size == 2 * vLen) {") + code += unroll(2, IndexType, InType, OutType, True) + code.append(" } else {") + code.append(" // generic code:") + code += generic(IndexType, InType, OutType, True) + code.append(" }") + code.append(" return pos == index_size;") + + code.append("}") + + for is_weight_positional in ["false", "true"]: + code.append("bool " + fn_base + "_" + is_weight_positional + suffix + "(") + code += args + + # Resolve the Lint warnings: Limit of 80 characters in one line. + extra_space = "\n " + ret_string = " return " + fn_base + suffix \ + + "<" + is_weight_positional + ">(" + if len(ret_string) <= 80: + code.append(ret_string) + else: + code.append(" return " + fn_base + suffix + "<" + extra_space + is_weight_positional + ">(") + + code.append(" block_size,") + code.append(" output_size,") + code.append(" index_size,") + code.append(" data_size,") + code.append(" input,") + code.append(" indices,") + code.append(" offsets,") + code.append(" weights,") + code.append(" scale_bias,") + code.append(" normalize_by_lengths,") + code.append(" out);") + code.append("}") + + code.append("") + + code.append("} // namespace caffe2") + + with open(filename, "w") as fout: + fout.write("\n".join(code) + "\n") + + print("Created " + filename) + +if __name__ == "__main__": + main() diff --git a/caffe2/serialize/file_adapter.cc b/caffe2/serialize/file_adapter.cc index 67634d7f7fd27..3839fb5bbb83a 100644 --- a/caffe2/serialize/file_adapter.cc +++ b/caffe2/serialize/file_adapter.cc @@ -21,7 +21,7 @@ FileAdapter::RAIIFile::RAIIFile(const std::string& file_name) { auto error_msg = std::system_category().default_error_condition(old_errno).message(); #endif - AT_ERROR( + TORCH_CHECK(false, "open file failed because of errno ", old_errno, " on fopen: ", diff --git a/caffe2/serialize/istream_adapter.cc b/caffe2/serialize/istream_adapter.cc index 9509a088736ef..438901848f6b0 100644 --- a/caffe2/serialize/istream_adapter.cc +++ b/caffe2/serialize/istream_adapter.cc @@ -29,7 +29,7 @@ size_t IStreamAdapter::read(uint64_t pos, void* buf, size_t n, const char* what) void IStreamAdapter::validate(const char* what) const { if (!*istream_) { - AT_ERROR("istream reader failed: ", what, "."); + TORCH_CHECK(false, "istream reader failed: ", what, "."); } } diff --git a/caffe2/utils/CMakeLists.txt b/caffe2/utils/CMakeLists.txt index e168eb595feb2..c229f88168c23 100644 --- a/caffe2/utils/CMakeLists.txt +++ b/caffe2/utils/CMakeLists.txt @@ -3,7 +3,7 @@ list(APPEND Caffe2_CPU_SRCS utils/threadpool/ThreadPool.cc ) -if(USE_PTHREADPOOL AND NOT USE_INTERNAL_PTHREADPOOL_IMPL) +if(USE_PTHREADPOOL) list(APPEND Caffe2_CPU_SRCS utils/threadpool/pthreadpool-cpp.cc utils/threadpool/thread_pool_guard.cpp diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 7ef8eabb51627..19667b73287ca 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -161,7 +161,7 @@ else() set(AT_MKLDNN_ENABLED 0) set(AT_MKL_ENABLED 0) endif() -set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib") +set_property(CACHE BLAS PROPERTY STRINGS "ATLAS;BLIS;Eigen;FLAME;Generic;MKL;OpenBLAS;vecLib;APL") message(STATUS "Trying to find preferred BLAS backend of choice: " ${BLAS}) if(BLAS STREQUAL "Eigen") @@ -226,6 +226,12 @@ elseif(BLAS STREQUAL "FlexiBLAS") find_package(FlexiBLAS REQUIRED) include_directories(SYSTEM ${FlexiBLAS_INCLUDE_DIR}) list(APPEND Caffe2_DEPENDENCY_LIBS ${FlexiBLAS_LIB}) +elseif(BLAS STREQUAL "APL") + find_package(APL REQUIRED) + include_directories(SYSTEM ${APL_INCLUDE_DIR}) + set(BLAS_INFO "apl") + set(BLAS_FOUND 1) + set(BLAS_LIBRARIES ${APL_LIBRARIES}) elseif(BLAS STREQUAL "Generic") # On Debian family, the CBLAS ABIs have been merged into libblas.so if(ENV{GENERIC_BLAS_LIBRARIES} STREQUAL "") @@ -246,7 +252,7 @@ endif() if(NOT INTERN_BUILD_MOBILE) set(AT_MKL_SEQUENTIAL 0) set(USE_BLAS 1) - if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND)) + if(NOT (ATLAS_FOUND OR BLIS_FOUND OR GENERIC_BLAS_FOUND OR MKL_FOUND OR OpenBLAS_FOUND OR VECLIB_FOUND OR FlexiBLAS_FOUND OR NVPL_BLAS_FOUND OR APL_FOUND)) message(WARNING "Preferred BLAS (" ${BLAS} ") cannot be found, now searching for a general BLAS library") find_package(BLAS) if(NOT BLAS_FOUND) @@ -372,9 +378,6 @@ if(INTERN_BUILD_MOBILE OR NOT DISABLE_NNPACK_AND_FAMILY) set(USE_PTHREADPOOL ON CACHE BOOL "" FORCE) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_PTHREADPOOL") - # Always use third_party/pthreadpool. - set(USE_INTERNAL_PTHREADPOOL_IMPL OFF CACHE BOOL "" FORCE) - if(NOT TARGET pthreadpool) if(USE_SYSTEM_PTHREADPOOL) add_library(pthreadpool SHARED IMPORTED) @@ -384,7 +387,7 @@ if(INTERN_BUILD_MOBILE OR NOT DISABLE_NNPACK_AND_FAMILY) message(FATAL_ERROR "Cannot find pthreadpool") endif() message("-- Found pthreadpool: ${PTHREADPOOL_LIBRARY}") - elseif(NOT USE_INTERNAL_PTHREADPOOL_IMPL) + else() if(NOT DEFINED PTHREADPOOL_SOURCE_DIR) set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party") set(PTHREADPOOL_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/pthreadpool" CACHE STRING "pthreadpool source directory") @@ -400,11 +403,7 @@ if(INTERN_BUILD_MOBILE OR NOT DISABLE_NNPACK_AND_FAMILY) set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) endif() - if(USE_INTERNAL_PTHREADPOOL_IMPL) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_INTERNAL_PTHREADPOOL_IMPL") - else() - list(APPEND Caffe2_DEPENDENCY_LIBS pthreadpool) - endif() + list(APPEND Caffe2_DEPENDENCY_LIBS pthreadpool) endif() else() set(USE_PTHREADPOOL OFF CACHE BOOL "" FORCE) @@ -458,10 +457,6 @@ if(USE_PYTORCH_QNNPACK) endif() if(NOT TARGET pytorch_qnnpack) - if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL) - set(PYTORCH_QNNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") - endif() - set(PYTORCH_QNNPACK_BUILD_TESTS OFF CACHE BOOL "") set(PYTORCH_QNNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") set(PYTORCH_QNNPACK_LIBRARY_TYPE "static" CACHE STRING "") @@ -474,28 +469,6 @@ if(USE_PYTORCH_QNNPACK) set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) # QNNPACK depends on gemmlowp headers target_include_directories(pytorch_qnnpack PRIVATE "${CAFFE2_THIRD_PARTY_ROOT}/gemmlowp") - - if(PYTORCH_QNNPACK_CUSTOM_THREADPOOL) - target_compile_definitions( - pytorch_qnnpack PRIVATE - pthreadpool_t=legacy_pthreadpool_t - pthreadpool_function_1d_t=legacy_pthreadpool_function_1d_t - pthreadpool_function_1d_tiled_t=legacy_pthreadpool_function_1d_tiled_t - pthreadpool_function_2d_t=legacy_pthreadpool_function_2d_t - pthreadpool_function_2d_tiled_t=legacy_pthreadpool_function_2d_tiled_t - pthreadpool_function_3d_tiled_t=legacy_pthreadpool_function_3d_tiled_t - pthreadpool_function_4d_tiled_t=legacy_pthreadpool_function_4d_tiled_t - pthreadpool_create=legacy_pthreadpool_create - pthreadpool_destroy=legacy_pthreadpool_destroy - pthreadpool_get_threads_count=legacy_pthreadpool_get_threads_count - pthreadpool_compute_1d=legacy_pthreadpool_compute_1d - pthreadpool_parallelize_1d=legacy_pthreadpool_parallelize_1d - pthreadpool_compute_1d_tiled=legacy_pthreadpool_compute_1d_tiled - pthreadpool_compute_2d=legacy_pthreadpool_compute_2d - pthreadpool_compute_2d_tiled=legacy_pthreadpool_compute_2d_tiled - pthreadpool_compute_3d_tiled=legacy_pthreadpool_compute_3d_tiled - pthreadpool_compute_4d_tiled=legacy_pthreadpool_compute_4d_tiled) - endif() endif() list(APPEND Caffe2_DEPENDENCY_LIBS pytorch_qnnpack) diff --git a/cmake/External/nnpack.cmake b/cmake/External/nnpack.cmake index 9d5f0643ece7c..7890e1f8a8b74 100644 --- a/cmake/External/nnpack.cmake +++ b/cmake/External/nnpack.cmake @@ -57,10 +57,6 @@ if(ANDROID OR IOS OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux" OR ${CMAKE_SYSTEM_NAM set(GOOGLETEST_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/googletest" CACHE STRING "Google Test source directory") if(NOT TARGET nnpack) - if(NOT USE_SYSTEM_PTHREADPOOL AND USE_INTERNAL_PTHREADPOOL_IMPL) - set(NNPACK_CUSTOM_THREADPOOL ON CACHE BOOL "") - endif() - set(NNPACK_BUILD_TESTS OFF CACHE BOOL "") set(NNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") set(NNPACK_LIBRARY_TYPE "static" CACHE STRING "") @@ -75,27 +71,6 @@ if(ANDROID OR IOS OR ${CMAKE_SYSTEM_NAME} STREQUAL "Linux" OR ${CMAKE_SYSTEM_NAM set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) - if(NNPACK_CUSTOM_THREADPOOL) - target_compile_definitions( - nnpack PRIVATE - pthreadpool_t=legacy_pthreadpool_t - pthreadpool_function_1d_t=legacy_pthreadpool_function_1d_t - pthreadpool_function_1d_tiled_t=legacy_pthreadpool_function_1d_tiled_t - pthreadpool_function_2d_t=legacy_pthreadpool_function_2d_t - pthreadpool_function_2d_tiled_t=legacy_pthreadpool_function_2d_tiled_t - pthreadpool_function_3d_tiled_t=legacy_pthreadpool_function_3d_tiled_t - pthreadpool_function_4d_tiled_t=legacy_pthreadpool_function_4d_tiled_t - pthreadpool_create=legacy_pthreadpool_create - pthreadpool_destroy=legacy_pthreadpool_destroy - pthreadpool_get_threads_count=legacy_pthreadpool_get_threads_count - pthreadpool_compute_1d=legacy_pthreadpool_compute_1d - pthreadpool_parallelize_1d=legacy_pthreadpool_parallelize_1d - pthreadpool_compute_1d_tiled=legacy_pthreadpool_compute_1d_tiled - pthreadpool_compute_2d=legacy_pthreadpool_compute_2d - pthreadpool_compute_2d_tiled=legacy_pthreadpool_compute_2d_tiled - pthreadpool_compute_3d_tiled=legacy_pthreadpool_compute_3d_tiled - pthreadpool_compute_4d_tiled=legacy_pthreadpool_compute_4d_tiled) - endif() endif() set(NNPACK_FOUND TRUE) diff --git a/cmake/External/rccl.cmake b/cmake/External/rccl.cmake index 911c80f3b9b3d..535bf8e28bd7b 100644 --- a/cmake/External/rccl.cmake +++ b/cmake/External/rccl.cmake @@ -7,8 +7,7 @@ if(NOT __NCCL_INCLUDED) if(rccl_FOUND) message(STATUS "RCCL Found!") add_library(__caffe2_nccl INTERFACE) - target_link_libraries(__caffe2_nccl INTERFACE ${PYTORCH_RCCL_LIBRARIES}) - target_include_directories(__caffe2_nccl INTERFACE ${RCCL_INCLUDE_DIRS}) + target_link_libraries(__caffe2_nccl INTERFACE roc::rccl) else() message(STATUS "RCCL NOT Found!") endif() diff --git a/cmake/MiscCheck.cmake b/cmake/MiscCheck.cmake index 10fa810b8fdfb..74fc1487333af 100644 --- a/cmake/MiscCheck.cmake +++ b/cmake/MiscCheck.cmake @@ -101,6 +101,16 @@ endif() # Also, we will turn off deprecated-declarations # due to protobuf. +# ---[ Check if the compiler has SVE support. +find_package(ARM) # checks SVE +if(CXX_SVE_FOUND) + message(STATUS "Compiler supports SVE extension. Will build perfkernels.") + # Also see CMakeLists.txt under caffe2/perfkernels. + add_compile_definitions(CAFFE2_PERF_WITH_SVE=1) +else() + message(STATUS "Compiler does not support SVE extension. Will not build perfkernels.") +endif() + if(IOS AND (${IOS_ARCH} MATCHES "armv7*")) add_definitions("-mfpu=neon-fp16") add_definitions("-arch" ${IOS_ARCH}) diff --git a/cmake/Modules/FindAPL.cmake b/cmake/Modules/FindAPL.cmake new file mode 100644 index 0000000000000..7b97283b67f1f --- /dev/null +++ b/cmake/Modules/FindAPL.cmake @@ -0,0 +1,58 @@ +# - Find APL (Arm Performance Libraries) +# +# This module sets the following variables: +# APL_INCLUDE_SEARCH_PATHS - list of paths to search for APL include files +# APL_LIB_SEARCH_PATHS - list of paths to search for APL libraries +# APL_FOUND - set to true if APL is found +# APL_INCLUDE_DIR - path to include dir. +# APL_LIB_DIR - path to include dir. +# APL_LIBRARIES - list of libraries for base APL + +SET(APL_INCLUDE_SEARCH_PATHS $ENV{ARMPL_DIR}/include) +SET(APL_LIB_SEARCH_PATHS $ENV{ARMPL_DIR}/lib) + +SET(APL_FOUND ON) + +# Check include file +FIND_PATH(APL_INCLUDE_DIR NAMES armpl.h PATHS ${APL_INCLUDE_SEARCH_PATHS}) +IF(NOT APL_INCLUDE_DIR) + SET(APL_FOUND OFF) + MESSAGE(STATUS "Could not verify APL include directory. Turning APL_FOUND off") +ENDIF() + +# Check lib file +FIND_PATH(APL_LIB_DIR NAMES libarmpl_lp64_mp.dll.lib libomp.dll.lib libarmpl_lp64_mp.a PATHS ${APL_LIB_SEARCH_PATHS}) +IF(NOT APL_LIB_DIR) + SET(APL_FOUND OFF) + MESSAGE(STATUS "Could not verify APL lib directory. Turning APL_FOUND off") +ENDIF() + +IF (APL_FOUND) + IF(WIN32) + set(APL_LIBRARIES + "${APL_LIB_DIR}/libarmpl_lp64_mp.dll.lib" + "${APL_LIB_DIR}/libomp.dll.lib" + ) + ELSEIF(UNIX) + set(APL_LIBRARIES + "${APL_LIB_DIR}/libarmpl_lp64_mp.a" + ) + ENDIF() + MESSAGE(STATUS "Found APL header: ${APL_INCLUDE_DIR}") + MESSAGE(STATUS "Found APL library: ${APL_LIB_DIR}") + message(STATUS "APL_LIBRARIES: ${APL_LIBRARIES}") + SET(CMAKE_REQUIRED_LIBRARIES ${APL_LIBRARIES}) + include(CheckCSourceRuns) + CHECK_C_SOURCE_RUNS(" +#include +#include +float x[4] = { 1, 2, 3, 4 }; +float y[4] = { .1, .01, .001, .0001 }; +extern float cblas_sdot(); +int main() { + int i; + double r = cblas_sdot(4, x, 1, y, 1); + exit((float)r != (float).1234); +}" BLAS_USE_CBLAS_DOT ) + MESSAGE(STATUS "BLAS_USE_CBLAS_DOT: ${BLAS_USE_CBLAS_DOT}") +ENDIF (APL_FOUND) \ No newline at end of file diff --git a/cmake/Modules/FindLAPACK.cmake b/cmake/Modules/FindLAPACK.cmake index dbe47d6cdcf19..7d343f8adab7f 100644 --- a/cmake/Modules/FindLAPACK.cmake +++ b/cmake/Modules/FindLAPACK.cmake @@ -223,6 +223,34 @@ if(BLAS_FOUND) endif(LAPACK_LIBRARIES) endif() + #Arm Performance Libraries + IF((NOT LAPACK_INFO) AND (BLAS_INFO STREQUAL "apl")) + SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES}) + check_function_exists("cheev_" APL_LAPACK_WORKS) + if(APL_LAPACK_WORKS) + check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS) + if(NOT LAPACK_CGESDD_WORKS) + find_library(GFORTRAN_LIBRARY + NAMES libgfortran.a gfortran + PATHS ${CMAKE_C_IMPLICIT_LINK_DIRECTORIES}) + list(APPEND CMAKE_REQUIRED_LIBRARIES "${GFORTRAN_LIBRARY}") + unset(LAPACK_CGESDD_WORKS CACHE) + check_function_exists("cgesdd_" LAPACK_CGESDD_WORKS) + if(LAPACK_CGESDD_WORKS) + list(APPEND LAPACK_LIBRARIES "${GFORTRAN_LIBRARY}") + else() + message(WARNING "APL has been compiled with Lapack support, but cgesdd can not be used") + set(APL_LAPACK_WORKS NO) + endif() + endif() + endif() + set(CMAKE_REQUIRED_LIBRARIES) + if(APL_LAPACK_WORKS) + SET(LAPACK_INFO "apl") + else() + message(STATUS "It seems APL has not been compiled with Lapack support") + endif() + endif() else(BLAS_FOUND) message(STATUS "LAPACK requires BLAS") endif(BLAS_FOUND) diff --git a/cmake/Summary.cmake b/cmake/Summary.cmake index d51c451589c2c..3f70465c91d6d 100644 --- a/cmake/Summary.cmake +++ b/cmake/Summary.cmake @@ -163,6 +163,9 @@ function(caffe2_print_configuration_summary) message(STATUS " USE_OPENCL : ${USE_OPENCL}") message(STATUS " USE_OPENMP : ${USE_OPENMP}") message(STATUS " USE_MIMALLOC : ${USE_MIMALLOC}") + if(${USE_MIMALLOC}) + message(STATUS " USE_MIMALLOC_ON_MKL : ${USE_MIMALLOC_ON_MKL}") + endif() message(STATUS " USE_VULKAN : ${USE_VULKAN}") if(${USE_VULKAN}) message(STATUS " USE_VULKAN_FP16_INFERENCE : ${USE_VULKAN_FP16_INFERENCE}") diff --git a/cmake/TorchConfig.cmake.in b/cmake/TorchConfig.cmake.in index cba4d9298551e..8028ef5866a1f 100644 --- a/cmake/TorchConfig.cmake.in +++ b/cmake/TorchConfig.cmake.in @@ -75,6 +75,9 @@ else() if(@USE_CUDA@) append_wholearchive_lib_if_found(torch_cuda c10_cuda) endif() + if(@USE_XPU@) + append_wholearchive_lib_if_found(torch_xpu c10_xpu) + endif() # We need manually add dependent libraries when they are not linked into the # shared library. @@ -99,11 +102,8 @@ else() append_torchlib_if_found(fmt) append_torchlib_if_found(cpuinfo clog) - if(NOT @USE_INTERNAL_PTHREADPOOL_IMPL@) - append_torchlib_if_found(pthreadpool) - endif() - append_torchlib_if_found(eigen_blas) + append_torchlib_if_found(pthreadpool) if(@USE_FBGEMM@) append_torchlib_if_found(fbgemm) @@ -138,6 +138,10 @@ if(@USE_CUDA@) list(APPEND TORCH_LIBRARIES ${TORCH_CUDA_LIBRARIES}) endif() +if(@USE_XPU@ AND @BUILD_SHARED_LIBS@) + append_torchlib_if_found(c10_xpu torch_xpu) +endif() + # When we build libtorch with the old libstdc++ ABI, dependent libraries must too. if(CMAKE_SYSTEM_NAME STREQUAL "Linux") set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@") diff --git a/cmake/public/ComputeLibrary.cmake b/cmake/public/ComputeLibrary.cmake index d0b3b56ff531d..e18527ce65b0c 100644 --- a/cmake/public/ComputeLibrary.cmake +++ b/cmake/public/ComputeLibrary.cmake @@ -21,10 +21,10 @@ if("${ACL_VERSION_FILE}" STREQUAL "") message(WARNING "Build may fail: Could not determine ACL version (minimum required is ${ACL_MINIMUM_VERSION})") else() file(READ ${ACL_VERSION_FILE} ACL_VERSION_STRING) - string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION ${ACL_VERSION_STRING}) + string(REGEX MATCH "v([0-9]+\\.[0-9]+)" ACL_VERSION "${ACL_VERSION_STRING}") set(ACL_VERSION "${CMAKE_MATCH_1}") - if(${ACL_VERSION} VERSION_EQUAL "0.0") + if("${ACL_VERSION}" VERSION_EQUAL "0.0") # Unreleased ACL versions come with version string "v0.0-unreleased", and may not be compatible with oneDNN. # It is recommended to use the latest release of ACL. message(WARNING "Build may fail: Using unreleased ACL version (minimum required is ${ACL_MINIMUM_VERSION})") diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index 1c0d3a203991c..1499977f8e44e 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -1,19 +1,33 @@ set(PYTORCH_FOUND_HIP FALSE) -if(NOT DEFINED ENV{ROCM_PATH}) - set(ROCM_PATH /opt/rocm) -else() +# If ROCM_PATH is set, assume intention is to compile with +# ROCm support and error out if the ROCM_PATH does not exist. +# Else ROCM_PATH does not exist, assume a default of /opt/rocm +# In the latter case, if /opt/rocm does not exist emit status +# message and return. +if(DEFINED ENV{ROCM_PATH}) set(ROCM_PATH $ENV{ROCM_PATH}) + if(NOT EXISTS ${ROCM_PATH}) + message(FATAL_ERROR + "ROCM_PATH environment variable is set to ${ROCM_PATH} but does not exist.\n" + "Set a valid ROCM_PATH or unset ROCM_PATH environment variable to fix.") + endif() +else() + set(ROCM_PATH /opt/rocm) + if(NOT EXISTS ${ROCM_PATH}) + message(STATUS + "ROCM_PATH environment variable is not set and ${ROCM_PATH} does not exist.\n" + "Building without ROCm support.") + return() + endif() endif() + if(NOT DEFINED ENV{ROCM_INCLUDE_DIRS}) set(ROCM_INCLUDE_DIRS ${ROCM_PATH}/include) else() set(ROCM_INCLUDE_DIRS $ENV{ROCM_INCLUDE_DIRS}) endif() -if(NOT EXISTS ${ROCM_PATH}) - return() -endif() # MAGMA_HOME if(NOT DEFINED ENV{MAGMA_HOME}) @@ -30,78 +44,60 @@ endif() message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}") # Add HIP to the CMAKE Module Path +# needed because the find_package call to this module uses the Module mode search +# https://cmake.org/cmake/help/latest/command/find_package.html#search-modes set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH}) +# Add ROCM_PATH to CMAKE_PREFIX_PATH, needed because the find_package +# call to individual ROCM components uses the Config mode search +list(APPEND CMAKE_PREFIX_PATH ${ROCM_PATH}) + macro(find_package_and_print_version PACKAGE_NAME) find_package("${PACKAGE_NAME}" ${ARGN}) message("${PACKAGE_NAME} VERSION: ${${PACKAGE_NAME}_VERSION}") endmacro() # Find the HIP Package -find_package_and_print_version(HIP 1.0) +# MODULE argument is added for clarity that CMake is searching +# for FindHIP.cmake in Module mode +find_package_and_print_version(HIP 1.0 MODULE) if(HIP_FOUND) set(PYTORCH_FOUND_HIP TRUE) - set(FOUND_ROCM_VERSION_H FALSE) - - set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") - set(file "${PROJECT_BINARY_DIR}/detect_rocm_version.cc") # Find ROCM version for checks - # ROCM 5.0 and later will have header api for version management - if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm_version.h) - set(FOUND_ROCM_VERSION_H TRUE) - file(WRITE ${file} "" - "#include \n" - ) - elseif(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) - set(FOUND_ROCM_VERSION_H TRUE) - file(WRITE ${file} "" - "#include \n" - ) + if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) + set(ROCM_HEADER_FILE ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h) else() - message("********************* rocm_version.h couldnt be found ******************\n") - endif() - - if(FOUND_ROCM_VERSION_H) - file(APPEND ${file} "" - "#include \n" - - "#ifndef ROCM_VERSION_PATCH\n" - "#define ROCM_VERSION_PATCH 0\n" - "#endif\n" - "#define STRINGIFYHELPER(x) #x\n" - "#define STRINGIFY(x) STRINGIFYHELPER(x)\n" - "int main() {\n" - " printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n" - " return 0;\n" - "}\n" - ) - - try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file} - CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}" - RUN_OUTPUT_VARIABLE rocm_version_from_header - COMPILE_OUTPUT_VARIABLE output_var - ) - # We expect the compile to be successful if the include directory exists. - if(NOT compile_result) - message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " ${output_var}) - endif() - message(STATUS "Caffe2: Header version is: " ${rocm_version_from_header}) - set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header}) - message("\n***** ROCm version from rocm_version.h ****\n") - endif() - - string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW}) - - if(ROCM_VERSION_DEV_MATCH) - set(ROCM_VERSION_DEV_MAJOR ${CMAKE_MATCH_1}) - set(ROCM_VERSION_DEV_MINOR ${CMAKE_MATCH_2}) - set(ROCM_VERSION_DEV_PATCH ${CMAKE_MATCH_3}) - set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") - math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + message(FATAL_ERROR "********************* rocm_version.h could not be found ******************\n") endif() + # Read the ROCM headerfile into a variable + file(READ ${ROCM_HEADER_FILE} ROCM_HEADER_CONTENT) + + # Below we use a RegEx to find ROCM version numbers. + # Note that CMake does not support \s for blank space. That is + # why in the regular expressions below we have a blank space in + # the square brackets. + # There are three steps: + # 1. Match regular expression + # 2. Strip the non-numerical part of the string + # 3. Strip leading and trailing spaces + string(REGEX MATCH "ROCM_VERSION_MAJOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "ROCM_VERSION_MAJOR" "" TEMP2 ${TEMP1}) + string(STRIP ${TEMP2} ROCM_VERSION_DEV_MAJOR) + string(REGEX MATCH "ROCM_VERSION_MINOR[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "ROCM_VERSION_MINOR" "" TEMP2 ${TEMP1}) + string(STRIP ${TEMP2} ROCM_VERSION_DEV_MINOR) + string(REGEX MATCH "ROCM_VERSION_PATCH[ ]+[0-9]+" TEMP1 ${ROCM_HEADER_CONTENT}) + string(REPLACE "ROCM_VERSION_PATCH" "" TEMP2 ${TEMP1}) + string(STRIP ${TEMP2} ROCM_VERSION_DEV_PATCH) + + # Create ROCM_VERSION_DEV_INT which is later used as a preprocessor macros + set(ROCM_VERSION_DEV "${ROCM_VERSION_DEV_MAJOR}.${ROCM_VERSION_DEV_MINOR}.${ROCM_VERSION_DEV_PATCH}") + math(EXPR ROCM_VERSION_DEV_INT "(${ROCM_VERSION_DEV_MAJOR}*10000) + (${ROCM_VERSION_DEV_MINOR}*100) + ${ROCM_VERSION_DEV_PATCH}") + + message("\n***** ROCm version from rocm_version.h ****\n") message("ROCM_VERSION_DEV: ${ROCM_VERSION_DEV}") message("ROCM_VERSION_DEV_MAJOR: ${ROCM_VERSION_DEV_MAJOR}") message("ROCM_VERSION_DEV_MINOR: ${ROCM_VERSION_DEV_MINOR}") @@ -113,42 +109,9 @@ if(HIP_FOUND) message("HIP_VERSION_MINOR: ${HIP_VERSION_MINOR}") message("TORCH_HIP_VERSION: ${TORCH_HIP_VERSION}") - message("\n***** Library versions from dpkg *****\n") - execute_process(COMMAND dpkg -l COMMAND grep rocm-dev COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep rocm-libs COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep hsakmt-roct COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep rocr-dev COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep -w hcc COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep hip-base COMMAND awk "{print $2 \" VERSION: \" $3}") - execute_process(COMMAND dpkg -l COMMAND grep hip_hcc COMMAND awk "{print $2 \" VERSION: \" $3}") - + # Find ROCM components using Config mode + # These components will be searced for recursively in ${ROCM_PATH} message("\n***** Library versions from cmake find_package *****\n") - - set(CMAKE_HIP_CLANG_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG}) - set(CMAKE_HIP_CLANG_FLAGS_RELEASE ${CMAKE_CXX_FLAGS_RELEASE}) - ### Remove setting of Flags when FindHIP.CMake PR #558 is accepted.### - - set(hip_DIR ${ROCM_PATH}/lib/cmake/hip) - set(hsa-runtime64_DIR ${ROCM_PATH}/lib/cmake/hsa-runtime64) - set(AMDDeviceLibs_DIR ${ROCM_PATH}/lib/cmake/AMDDeviceLibs) - set(amd_comgr_DIR ${ROCM_PATH}/lib/cmake/amd_comgr) - set(rocrand_DIR ${ROCM_PATH}/lib/cmake/rocrand) - set(hiprand_DIR ${ROCM_PATH}/lib/cmake/hiprand) - set(rocblas_DIR ${ROCM_PATH}/lib/cmake/rocblas) - set(hipblas_DIR ${ROCM_PATH}/lib/cmake/hipblas) - set(hipblaslt_DIR ${ROCM_PATH}/lib/cmake/hipblaslt) - set(miopen_DIR ${ROCM_PATH}/lib/cmake/miopen) - set(rocfft_DIR ${ROCM_PATH}/lib/cmake/rocfft) - set(hipfft_DIR ${ROCM_PATH}/lib/cmake/hipfft) - set(hipsparse_DIR ${ROCM_PATH}/lib/cmake/hipsparse) - set(rccl_DIR ${ROCM_PATH}/lib/cmake/rccl) - set(rocprim_DIR ${ROCM_PATH}/lib/cmake/rocprim) - set(hipcub_DIR ${ROCM_PATH}/lib/cmake/hipcub) - set(rocthrust_DIR ${ROCM_PATH}/lib/cmake/rocthrust) - set(hipsolver_DIR ${ROCM_PATH}/lib/cmake/hipsolver) - set(hiprtc_DIR ${ROCM_PATH}/lib/cmake/hiprtc) - - find_package_and_print_version(hip REQUIRED) find_package_and_print_version(hsa-runtime64 REQUIRED) find_package_and_print_version(amd_comgr REQUIRED) @@ -167,27 +130,11 @@ if(HIP_FOUND) find_package_and_print_version(hipsolver REQUIRED) find_package_and_print_version(hiprtc REQUIRED) - - find_library(PYTORCH_HIP_LIBRARIES amdhip64 HINTS ${ROCM_PATH}/lib) - # TODO: miopen_LIBRARIES should return fullpath to the library file, - # however currently it's just the lib name - if(TARGET ${miopen_LIBRARIES}) - set(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES}) - else() - find_library(PYTORCH_MIOPEN_LIBRARIES ${miopen_LIBRARIES} HINTS ${ROCM_PATH}/lib) - endif() - # TODO: rccl_LIBRARIES should return fullpath to the library file, - # however currently it's just the lib name - if(TARGET ${rccl_LIBRARIES}) - set(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES}) - else() - find_library(PYTORCH_RCCL_LIBRARIES ${rccl_LIBRARIES} HINTS ${ROCM_PATH}/lib) - endif() - find_library(ROCM_HIPRTC_LIB hiprtc HINTS ${ROCM_PATH}/lib) # roctx is part of roctracer find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib) # check whether HIP declares new types + set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}") set(file "${PROJECT_BINARY_DIR}/hip_new_types.cc") file(WRITE ${file} "" "#include \n" diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index afc1bc12abf7d..152fbdbe6dd9b 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -170,7 +170,11 @@ else() endif() # nvToolsExt -find_path(nvtx3_dir NAMES nvtx3 PATHS "${PROJECT_SOURCE_DIR}/third_party/NVTX/c/include" NO_DEFAULT_PATH) +if(USE_SYSTEM_NVTX) + find_path(nvtx3_dir NAMES nvtx3) +else() + find_path(nvtx3_dir NAMES nvtx3 PATHS "${PROJECT_SOURCE_DIR}/third_party/NVTX/c/include" NO_DEFAULT_PATH) +endif() find_package_handle_standard_args(nvtx3 DEFAULT_MSG nvtx3_dir) if(nvtx3_FOUND) add_library(torch::nvtx3 INTERFACE IMPORTED) diff --git a/cmake/public/utils.cmake b/cmake/public/utils.cmake index c6647eb457c3b..c796fab1e9ac6 100644 --- a/cmake/public/utils.cmake +++ b/cmake/public/utils.cmake @@ -306,6 +306,17 @@ macro(torch_hip_get_arch_list store_var) string(REPLACE " " ";" ${store_var} "${_TMP}") endmacro() +############################################################################## +# Get the XPU arch flags specified by TORCH_XPU_ARCH_LIST. +# Usage: +# torch_xpu_get_arch_list(variable_to_store_flags) +# +macro(torch_xpu_get_arch_list store_var) + if(DEFINED ENV{TORCH_XPU_ARCH_LIST}) + set(${store_var} $ENV{TORCH_XPU_ARCH_LIST}) + endif() +endmacro() + ############################################################################## # Get the NVCC arch flags specified by TORCH_CUDA_ARCH_LIST and CUDA_ARCH_NAME. # Usage: diff --git a/cmake/public/xpu.cmake b/cmake/public/xpu.cmake index d1a442f8efd41..5395fba562ef1 100644 --- a/cmake/public/xpu.cmake +++ b/cmake/public/xpu.cmake @@ -28,3 +28,8 @@ add_library(torch::xpurt INTERFACE IMPORTED) set_property( TARGET torch::xpurt PROPERTY INTERFACE_LINK_LIBRARIES torch::sycl) + +# setting xpu arch flags +torch_xpu_get_arch_list(XPU_ARCH_FLAGS) +# propagate to torch-xpu-ops +set(TORCH_XPU_ARCH_LIST ${XPU_ARCH_FLAGS}) diff --git a/docs/cpp/source/conf.py b/docs/cpp/source/conf.py index 838f5f2fd1d52..7e8cdb818319c 100644 --- a/docs/cpp/source/conf.py +++ b/docs/cpp/source/conf.py @@ -123,7 +123,7 @@ # General information about the project. project = "PyTorch" -copyright = "2022, PyTorch Contributors" +copyright = "2024, PyTorch Contributors" author = "PyTorch Contributors" # The version info for the project you're documenting, acts as replacement for diff --git a/docs/source/_static/css/custom.css b/docs/source/_static/css/custom.css new file mode 100644 index 0000000000000..5ab7e755a1252 --- /dev/null +++ b/docs/source/_static/css/custom.css @@ -0,0 +1,26 @@ +/* styles needed for the Google Search button */ + +.pytorch-left-menu-search input[type=text] { + background-image: none; +} + +.gsc-control-cse { + padding-left: 0px !important; + padding-bottom: 0px !important; +} + +.gsc-search-button .gsc-search-button-v2:focus { + border: transparent !important; + outline: none; + box-shadow: none; +} + +.gsc-search-button-v2:active { + border: none !important; +} + +.gsc-search-button-v2 { + border: none !important; +} + +/* End of Google Search button styles */ diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 985440d78a179..832b2cc31e57d 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -28,7 +28,10 @@ - {% include "searchbox.html" %} + {% endblock %} {%- block content %} diff --git a/docs/source/backends.rst b/docs/source/backends.rst index 4eefe0ea36019..2fd9277fa814d 100644 --- a/docs/source/backends.rst +++ b/docs/source/backends.rst @@ -215,7 +215,7 @@ torch.backends.opt_einsum .. attribute:: enabled - A :class:``bool`` that controls whether opt_einsum is enabled (``True`` by default). If so, + A :class:`bool` that controls whether opt_einsum is enabled (``True`` by default). If so, torch.einsum will use opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html) if available to calculate an optimal path of contraction for faster performance. @@ -224,7 +224,7 @@ torch.backends.opt_einsum .. attribute:: strategy - A :class:``str`` that specifies which strategies to try when ``torch.backends.opt_einsum.enabled`` + A :class:`str` that specifies which strategies to try when ``torch.backends.opt_einsum.enabled`` is ``True``. By default, torch.einsum will try the "auto" strategy, but the "greedy" and "optimal" strategies are also supported. Note that the "optimal" strategy is factorial on the number of inputs as it tries all possible paths. See more details in opt_einsum's docs diff --git a/docs/source/community/persons_of_interest.rst b/docs/source/community/persons_of_interest.rst index acded203d5756..2ad9d1982e6cd 100644 --- a/docs/source/community/persons_of_interest.rst +++ b/docs/source/community/persons_of_interest.rst @@ -73,6 +73,7 @@ TorchInductor - Horace He (`Chillee `__) - Shunting Zhang (`shunting314 `__) - Jason Ansel (`jansel `__) +- Jiong Gong (`jgong5 `__) Cudagraph Tree ~~~~~~~~~~~~~~ @@ -311,6 +312,12 @@ PowerPC - (emeritus) Alfredo Mendoza (`avmgithub `__) +x86 CPU +~~~~~~~ + +- Mingfei Ma (`mingfeima `__) +- Jiong Gong (`jgong5 `__) + AArch64 CPU ~~~~~~~~~~~~ diff --git a/docs/source/conf.py b/docs/source/conf.py index 577466448e86a..e1e33302da4c8 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -2165,6 +2165,7 @@ "SynchronizationError", "UnsynchronizedAccessError", # torch.cuda.memory + "MemPool", "MemPoolContext", # torch.distributed.elastic.multiprocessing.errors "ChildFailedError", @@ -3352,7 +3353,7 @@ # General information about the project. project = "PyTorch" -copyright = "2023, PyTorch Contributors" +copyright = "2024, PyTorch Contributors" author = "PyTorch Contributors" torch_version = str(torch.__version__) @@ -3470,9 +3471,7 @@ # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ["_static"] -html_css_files = [ - "css/jit.css", -] +html_css_files = ["css/jit.css", "css/custom.css"] from sphinx.ext.coverage import CoverageBuilder diff --git a/docs/source/cuda.tunable.rst b/docs/source/cuda.tunable.rst index 52482122ec754..a73419d01e22c 100644 --- a/docs/source/cuda.tunable.rst +++ b/docs/source/cuda.tunable.rst @@ -19,6 +19,8 @@ API Reference .. autofunction:: is_enabled .. autofunction:: tuning_enable .. autofunction:: tuning_is_enabled +.. autofunction:: record_untuned_enable +.. autofunction:: record_untuned_is_enabled .. autofunction:: set_max_tuning_duration .. autofunction:: get_max_tuning_duration .. autofunction:: set_max_tuning_iterations @@ -30,3 +32,4 @@ API Reference .. autofunction:: write_file_on_exit .. autofunction:: write_file .. autofunction:: read_file +.. autofunction:: tune_gemm_in_file diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst index fc1706a661dd0..98f5520db2fc9 100644 --- a/docs/source/distributed.pipelining.rst +++ b/docs/source/distributed.pipelining.rst @@ -414,7 +414,7 @@ You can implement your own pipeline schedule by extending one of the following t ``PipelineScheduleMulti`` is for schedules that assigns multiple stages per rank. For example, ``ScheduleGPipe`` and ``Schedule1F1B`` are subclasses of ``PipelineScheduleSingle``. -Whereas, ``ScheduleFlexibleInterleaved1F1B``, ``ScheduleInterleaved1F1B`` and ``ScheduleLoopedBFS`` +Whereas, ``ScheduleInterleaved1F1B``, ``ScheduleLoopedBFS``, and ``ScheduleInterleavedZeroBubble`` are subclasses of ``PipelineScheduleMulti``. @@ -483,8 +483,6 @@ Pipeline Schedules .. autoclass:: Schedule1F1B -.. autoclass:: ScheduleFlexibleInterleaved1F1B - .. autoclass:: ScheduleInterleaved1F1B .. autoclass:: ScheduleLoopedBFS diff --git a/docs/source/export.rst b/docs/source/export.rst index 603594847f061..da7d827b3d035 100644 --- a/docs/source/export.rst +++ b/docs/source/export.rst @@ -369,7 +369,7 @@ You can also go from this IR to an inference IR via :func:`run_decompositions` w :: # Lower to core aten inference IR, but keep conv2d - decomp_table = torch.export.core_aten_decompositions() + decomp_table = torch.export.default_decompositions() del decomp_table[torch.ops.aten.conv2d.default] ep_for_inference = ep_for_training.run_decompositions(decomp_table) @@ -418,7 +418,7 @@ You can do even more customizations by directly registering custom decomp behavi :: # Lower to core aten inference IR, but customize conv2d - decomp_table = torch.export.core_aten_decompositions() + decomp_table = torch.export.default_decompositions() def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1): return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups) @@ -849,7 +849,7 @@ API Reference .. autofunction:: load .. autofunction:: register_dataclass .. autofunction:: torch.export.dynamic_shapes.Dim -.. autofunction:: torch.export.exported_program.core_aten_decompositions +.. autofunction:: torch.export.exported_program.default_decompositions .. autofunction:: dims .. autoclass:: torch.export.dynamic_shapes.ShapesCollection @@ -872,6 +872,16 @@ API Reference .. autoclass:: ModuleCallEntry +.. automodule:: torch.export.decomp_utils +.. autoclass:: CustomDecompTable + + .. automethod:: copy + .. automethod:: items + .. automethod:: keys + .. automethod:: materialize + .. automethod:: pop + .. automethod:: update + .. automodule:: torch.export.exported_program .. automodule:: torch.export.graph_signature .. autoclass:: InputKind diff --git a/docs/source/fx.experimental.rst b/docs/source/fx.experimental.rst index 128c744940ddb..d3bd9b6b0af6c 100644 --- a/docs/source/fx.experimental.rst +++ b/docs/source/fx.experimental.rst @@ -39,8 +39,6 @@ torch.fx.experimental.symbolic_shapes definitely_true definitely_false guard_size_oblivious - parallel_or - parallel_and sym_eq constrain_range constrain_unify diff --git a/docs/source/library.rst b/docs/source/library.rst index 768cdc825f126..5a66887926f5b 100644 --- a/docs/source/library.rst +++ b/docs/source/library.rst @@ -11,7 +11,7 @@ custom operators, and extending operators defined with PyTorch's C++ operator registration APIs (e.g. aten operators). For a detailed guide on effectively using these APIs, please see -:ref:`custom-ops-landing-page` +`PyTorch Custom Operators Landing Page `_ for more details on how to effectively use these APIs. Testing custom ops diff --git a/docs/source/miscellaneous_environment_variables.rst b/docs/source/miscellaneous_environment_variables.rst index f783f4c923542..14494241af9de 100644 --- a/docs/source/miscellaneous_environment_variables.rst +++ b/docs/source/miscellaneous_environment_variables.rst @@ -8,7 +8,11 @@ Miscellaneous Environment Variables * - Variable - Description * - ``TORCH_FORCE_WEIGHTS_ONLY_LOAD`` - - If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weight_only=True``. For more documentation on this, see :func:`torch.load`. + - If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=True``. This will happen even if + ``weights_only=False`` was passed at the callsite. For more documentation on this, see :func:`torch.load`. + * - ``TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD`` + - If set to [``1``, ``y``, ``yes``, ``true``], the torch.load will use ``weights_only=False`` if the ``weights_only`` variable was not + passed at the callsite. For more documentation on this, see :func:`torch.load`. * - ``TORCH_AUTOGRAD_SHUTDOWN_WAIT_LIMIT`` - Under some conditions, autograd threads can hang on shutdown, therefore we do not wait for them to shutdown indefinitely but rely on timeout that is default set to ``10`` seconds. This environment variable can be used to set the timeout in seconds. * - ``TORCH_DEVICE_BACKEND_AUTOLOAD`` diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index 1f80e36a48e08..0cad49479d920 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -13,8 +13,7 @@ and have it behave like PyTorch's built-in operators. In order to do so, you mus register the custom operation with PyTorch via the Python :ref:`torch-library-docs` or C++ TORCH_LIBRARY APIs. - -Please see :ref:`custom-ops-landing-page` for more details. +Please see `PyTorch Custom Operators Landing Page `_ for more details. .. _extending-autograd: diff --git a/docs/source/notes/get_start_xpu.rst b/docs/source/notes/get_start_xpu.rst index 38b1b78db1872..7742751d5433e 100644 --- a/docs/source/notes/get_start_xpu.rst +++ b/docs/source/notes/get_start_xpu.rst @@ -15,7 +15,7 @@ Hardware Prerequisite * - Intel Client GPU - Windows/Linux -Intel GPUs support (Beta) is ready in PyTorch* 2.5 for Intel® Data Center GPU Max Series and Intel® Client GPUs on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. +Intel GPUs support (Prototype) is ready in PyTorch* 2.5 for Intel® Data Center GPU Max Series and Intel® Client GPUs on both Linux and Windows, which brings Intel GPUs and the SYCL* software stack into the official PyTorch stack with consistent user experience to embrace more AI application scenarios. Software Prerequisite --------------------- @@ -38,11 +38,11 @@ Platform Linux Now we have all the required packages installed and environment activated. Use the following commands to install ``pytorch``, ``torchvision``, ``torchaudio`` on Linux. -For release wheels +For preview wheels .. code-block:: - pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/xpu + pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/test/xpu For nightly wheels @@ -55,17 +55,17 @@ Platform Windows Now we have all the required packages installed and environment activated. Use the following commands to install ``pytorch`` on Windows, build from source for ``torchvision`` and ``torchaudio``. -For release wheels +For preview wheels .. code-block:: - pip3 install torch --index-url https://download.pytorch.org/whl/xpu + pip3 install torch --index-url https://download.pytorch.org/whl/test/xpu For nightly wheels .. code-block:: - pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/xpu + pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/xpu From Source ^^^^^^^^^^^ diff --git a/docs/source/onnx_torchscript.rst b/docs/source/onnx_torchscript.rst index 8c8032bd26b4d..aec370f4411d5 100644 --- a/docs/source/onnx_torchscript.rst +++ b/docs/source/onnx_torchscript.rst @@ -697,7 +697,6 @@ Functions ^^^^^^^^^ .. autofunction:: export -.. autofunction:: export_to_pretty_string .. autofunction:: register_custom_op_symbolic .. autofunction:: unregister_custom_op_symbolic .. autofunction:: select_model_mode_for_export diff --git a/docs/source/optim.rst b/docs/source/optim.rst index 93d20798894a0..35c23dacc8ef9 100644 --- a/docs/source/optim.rst +++ b/docs/source/optim.rst @@ -13,7 +13,8 @@ Constructing it ^^^^^^^^^^^^^^^ To construct an :class:`Optimizer` you have to give it an iterable containing the -parameters (all should be :class:`~torch.autograd.Variable` s) to optimize. Then, +parameters (all should be :class:`~torch.nn.Parameter` s) or named parameters +(tuples of (str, :class:`~torch.nn.Parameter`)) to optimize. Then, you can specify optimizer-specific options such as the learning rate, weight decay, etc. Example:: @@ -21,6 +22,11 @@ Example:: optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) optimizer = optim.Adam([var1, var2], lr=0.0001) +Named parameters example:: + + optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) + optimizer = optim.Adam([('layer0', var1), ('layer1', var2)], lr=0.0001) + Per-parameter options ^^^^^^^^^^^^^^^^^^^^^ @@ -38,6 +44,11 @@ For example, this is very useful when one wants to specify per-layer learning ra {'params': model.classifier.parameters()} ], lr=1e-3, momentum=0.9) + optim.SGD([ + {'params': model.base.named_parameters(), 'lr': 1e-2}, + {'params': model.classifier.named_parameters()} + ], lr=1e-3, momentum=0.9) + This means that ``model.base``'s parameters will use a learning rate of ``1e-2``, whereas ``model.classifier``'s parameters will stick to the default learning rate of ``1e-3``. Finally a momentum of ``0.9`` will be used for all parameters. @@ -303,6 +314,182 @@ algorithms. lr_scheduler.OneCycleLR lr_scheduler.CosineAnnealingWarmRestarts +How to utilize named parameters to load optimizer state dict +------------------------------------------------------------ + +The function :func:`~Optimizer.load_state_dict` stores the optional ``param_names`` content from the +loaded state dict if present. However, the process of loading the optimizer state is not affected, +as the order of the parameters matters to maintain compatibility (in case of different ordering). +To utilize the loaded parameters names from the loaded state dict, a custom ``register_load_state_dict_pre_hook`` +needs to be implemented according to the desired behavior. + +This can be useful, for instance, when the model architecture changes, but the weights and optimizer states need to +remain unchanged. The following example demonstrates how to implement this customization. + +Example:: + + class OneLayerModel(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(3, 4) + + def forward(self, x): + return self.fc(x) + + model = OneLayerModel() + optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) + # training.. + torch.save(optimizer.state_dict(), PATH) + +Let's say that ``model`` implements an expert (MoE), and we want to duplicate it and resume training +for two experts, both initialized the same way as the ``fc`` layer. For the following ``model2`` we create two layers identical to ``fc`` and resume training by loading the model weights and optimizer states from ``model`` into both ``fc1`` and ``fc2`` of ``model2`` (and adjust them accordingly):: + + class TwoLayerModel(nn.Module): + def __init__(self): + super().__init__() + self.fc1 = nn.Linear(3, 4) + self.fc2 = nn.Linear(3, 4) + + def forward(self, x): + return (self.fc1(x) + self.fc2(x)) / 2 + + model2 = TwoLayerModel() + # adapt and load model weights.. + optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9) + +To load the state dict for ``optimizer2`` with the state dict of the previous optimizer such that both +``fc1`` and ``fc2`` will be initialized with a copy of ``fc`` optimizer states +(to resume training for each layer from ``fc``), we can use the following hook:: + + def adapt_state_dict_ids(optimizer, state_dict): + adapted_state_dict = deepcopy(optimizer.state_dict()) + # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict. + for k, v in state_dict['param_groups'][0].items(): + if k not in ['params', 'param_names']: + adapted_state_dict['param_groups'][0][k] = v + + lookup_dict = { + 'fc1.weight': 'fc.weight', + 'fc1.bias': 'fc.bias', + 'fc2.weight': 'fc.weight', + 'fc2.bias': 'fc.bias' + } + clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()} + for param_id, param_name in zip( + optimizer.state_dict()['param_groups'][0]['params'], + optimizer.state_dict()['param_groups'][0]['param_names']): + name_in_loaded = lookup_dict[param_name] + index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded) + id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list] + # Copy the state of the corresponding parameter + if id_in_loaded in state_dict['state']: + adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded]) + + return adapted_state_dict + + optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids) + optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict + +This ensures that the adapted state_dict with the correct states for the layers of ``model2`` will be used +during model loading. +Note that this code is designed specifically for this example (e.g., assuming a single parameter group), +and other cases might require different adaptations. + +The following example shows how to handle missing parameters in a loaded +``state dict`` when the model structure changes. +The ``Model_bypass`` adds a new ``bypass`` layer, which is not present in the original ``Model1``. +To resume training, a custom ``adapt_state_dict_missing_param`` hook is used to adapt the optimizer's ``state_dict``, +ensuring existing parameters are mapped correctly, while missing ones (like the bypass layer) remain unchanged +(as initialized in this example). +This approach enables smooth loading and resuming of the optimizer state despite model changes. +The new bypass layer will be trained from scratch:: + + class Model1(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(5, 5) + + def forward(self, x): + return self.fc(x) + x + + + model = Model1() + optimizer = optim.SGD(model.named_parameters(), lr=0.01, momentum=0.9) + # training.. + torch.save(optimizer.state_dict(), PATH) + + class Model_bypass(nn.Module): + def __init__(self): + super().__init__() + self.fc = nn.Linear(5, 5) + self.bypass = nn.Linear(5, 5, bias=False) + torch.nn.init.eye_(self.bypass.weight) + + def forward(self, x): + return self.fc(x) + self.bypass(x) + + model2 = Model_bypass() + optimizer2 = optim.SGD(model2.named_parameters(), lr=0.01, momentum=0.9) + + def adapt_state_dict_missing_param(optimizer, state_dict): + adapted_state_dict = deepcopy(optimizer.state_dict()) + # Copy setup parameters (lr, weight_decay, etc.), in case they differ in the loaded state dict. + for k, v in state_dict['param_groups'][0].items(): + if k not in ['params', 'param_names']: + adapted_state_dict['param_groups'][0][k] = v + + lookup_dict = { + 'fc.weight': 'fc.weight', + 'fc.bias': 'fc.bias', + 'bypass.weight': None, + } + + clone_deepcopy = lambda d: {k: (v.clone() if isinstance(v, torch.Tensor) else deepcopy(v)) for k, v in d.items()} + for param_id, param_name in zip( + optimizer.state_dict()['param_groups'][0]['params'], + optimizer.state_dict()['param_groups'][0]['param_names']): + name_in_loaded = lookup_dict[param_name] + if name_in_loaded in state_dict['param_groups'][0]['param_names']: + index_in_loaded_list = state_dict['param_groups'][0]['param_names'].index(name_in_loaded) + id_in_loaded = state_dict['param_groups'][0]['params'][index_in_loaded_list] + # Copy the state of the corresponding parameter + if id_in_loaded in state_dict['state']: + adapted_state_dict['state'][param_id] = clone_deepcopy(state_dict['state'][id_in_loaded]) + + return adapted_state_dict + + optimizer2.register_load_state_dict_pre_hook(adapt_state_dict_ids) + optimizer2.load_state_dict(torch.load(PATH)) # The previous optimizer saved state_dict + + + +As a third example, instead of loading a state according to the order of parameters (the default approach), +this hook can be used to load according to the parameters' names:: + + def names_matching(optimizer, state_dict): + assert len(state_dict['param_groups']) == len(optimizer.state_dict()['param_groups']) + adapted_state_dict = deepcopy(optimizer.state_dict()) + for g_ind in range(len(state_dict['param_groups'])): + assert len(state_dict['param_groups'][g_ind]['params']) == len( + optimizer.state_dict()['param_groups'][g_ind]['params']) + + for k, v in state_dict['param_groups'][g_ind].items(): + if k not in ['params', 'param_names']: + adapted_state_dict['param_groups'][g_ind][k] = v + + for param_id, param_name in zip( + optimizer.state_dict()['param_groups'][g_ind]['params'], + optimizer.state_dict()['param_groups'][g_ind]['param_names']): + index_in_loaded_list = state_dict['param_groups'][g_ind]['param_names'].index(param_name) + id_in_loaded = state_dict['param_groups'][g_ind]['params'][index_in_loaded_list] + # Copy the state of the corresponding parameter + if id_in_loaded in state_dict['state']: + adapted_state_dict['state'][param_id] = deepcopy(state_dict['state'][id_in_loaded]) + + return adapted_state_dict + + + Weight Averaging (SWA and EMA) ------------------------------ diff --git a/docs/source/torch.compiler_api.rst b/docs/source/torch.compiler_api.rst index e1c05f71c1461..bcf9772351a2c 100644 --- a/docs/source/torch.compiler_api.rst +++ b/docs/source/torch.compiler_api.rst @@ -20,6 +20,7 @@ For a quick overview of ``torch.compiler``, see :ref:`torch.compiler_overview`. assume_constant_result list_backends disable + set_stance cudagraph_mark_step_begin is_compiling is_dynamo_compiling diff --git a/docs/source/torch.compiler_get_started.rst b/docs/source/torch.compiler_get_started.rst index 2b5bec254958f..7661c884177d6 100644 --- a/docs/source/torch.compiler_get_started.rst +++ b/docs/source/torch.compiler_get_started.rst @@ -57,7 +57,7 @@ the following: .. code-block:: python - @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [instance_descriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) + @pointwise(size_hints=[16384], filename=__file__, triton_meta={'signature': {'in_ptr0': '*fp32', 'out_ptr0': '*fp32', 'xnumel': 'i32'}, 'device': 0, 'constants': {}, 'mutated_arg_names': [], 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())]}) @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 10000 diff --git a/docs/source/xpu.rst b/docs/source/xpu.rst index a83bea4d1b3f8..0dfbe40ebeee8 100644 --- a/docs/source/xpu.rst +++ b/docs/source/xpu.rst @@ -13,9 +13,11 @@ torch.xpu device device_count device_of + get_arch_list get_device_capability get_device_name get_device_properties + get_gencode_flags init is_available is_initialized diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp index 47fe87c235261..304839cbaeedb 100644 --- a/functorch/csrc/dim/dim.cpp +++ b/functorch/csrc/dim/dim.cpp @@ -867,7 +867,7 @@ mpy::object Tensor::from_positional(Arena & A, at::Tensor tensor, Slice self = Tensor::create(); diff --git a/functorch/csrc/dim/python_variable_simple.h b/functorch/csrc/dim/python_variable_simple.h index caae566107600..fbd5cfd828157 100644 --- a/functorch/csrc/dim/python_variable_simple.h +++ b/functorch/csrc/dim/python_variable_simple.h @@ -26,7 +26,7 @@ struct THPVariable { TORCH_PYTHON_API extern PyObject *THPVariableClass; TORCH_PYTHON_API extern PyObject *ParameterClass; -TORCH_PYTHON_API PyObject * THPVariable_Wrap(at::TensorBase var); +TORCH_PYTHON_API PyObject * THPVariable_Wrap(const at::TensorBase& var); inline bool THPVariable_Check(PyObject *obj) { diff --git a/functorch/dim/dim.py b/functorch/dim/dim.py index cbafce2f0ee0c..9a4b568664849 100644 --- a/functorch/dim/dim.py +++ b/functorch/dim/dim.py @@ -32,8 +32,7 @@ def __del__(self): if self._vmap_level is not None: _vmap_active_levels[self._vmap_stack].alive = False # noqa: F821 while ( - not _vmap_levels[-1].alive - and current_level() == _vmap_levels[-1].level # noqa: F821 + not _vmap_levels[-1].alive and current_level() == _vmap_levels[-1].level # noqa: F821 ): _vmap_decrement_nesting() # noqa: F821 _vmap_levels.pop() diff --git a/functorch/einops/_parsing.py b/functorch/einops/_parsing.py index ffb1fc00a20ee..ee69aa60d1a58 100644 --- a/functorch/einops/_parsing.py +++ b/functorch/einops/_parsing.py @@ -22,6 +22,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import keyword @@ -283,16 +284,16 @@ def comma_separate(collection: Collection[Union[str, Collection[str]]]) -> str: str: the comma-separated string Examples: - >>> comma_separate(('d0',)) + >>> comma_separate(("d0",)) 'd0' - >>> comma_separate(('d0', 'd1', 'd2', 'd3')) + >>> comma_separate(("d0", "d1", "d2", "d3")) 'd0, d1, d2, d3' - >>> comma_separate([('d1', 'd4')]) + >>> comma_separate([("d1", "d4")]) '(d1, d4)' - >>> comma_separate([('d0',), (), ('d1',), ('d2',), ('d3', 'd4')]) + >>> comma_separate([("d0",), (), ("d1",), ("d2",), ("d3", "d4")]) '(d0,), (), (d1,), (d2,), (d3, d4)' """ return ", ".join( diff --git a/functorch/einops/rearrange.py b/functorch/einops/rearrange.py index 1cd3cd8b3cf64..a0bceed738834 100644 --- a/functorch/einops/rearrange.py +++ b/functorch/einops/rearrange.py @@ -95,7 +95,7 @@ def _create_rearrange_callable( raise ValueError(f"Unexpected dimension: {dimension}") def composition_to_dims( - composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]] + composition: Sequence[Union[List[Union[str, AnonymousAxis]], str]], ) -> List[Union[str, Tuple[str, ...]]]: """Convert a `ParsedExpression.composition` into a `Tensor.__getitem__` index of strings representing first class dims.""" @@ -171,31 +171,31 @@ def rearrange( >>> images = torch.randn((32, 30, 40, 3)) >>> # stack along first (batch) axis, output is a single array - >>> rearrange(images, 'b h w c -> b h w c').shape + >>> rearrange(images, "b h w c -> b h w c").shape torch.Size([32, 30, 40, 3]) >>> # concatenate images along height (vertical axis), 960 = 32 * 30 - >>> rearrange(images, 'b h w c -> (b h) w c').shape + >>> rearrange(images, "b h w c -> (b h) w c").shape torch.Size([960, 40, 3]) >>> # concatenated images along horizontal axis, 1280 = 32 * 40 - >>> rearrange(images, 'b h w c -> h (b w) c').shape + >>> rearrange(images, "b h w c -> h (b w) c").shape torch.Size([30, 1280, 3]) >>> # reordered axes to "b c h w" format for deep learning - >>> rearrange(images, 'b h w c -> b c h w').shape + >>> rearrange(images, "b h w c -> b c h w").shape torch.Size([32, 3, 30, 40]) >>> # flattened each image into a vector, 3600 = 30 * 40 * 3 - >>> rearrange(images, 'b h w c -> b (c h w)').shape + >>> rearrange(images, "b h w c -> b (c h w)").shape torch.Size([32, 3600]) >>> # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 - >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape + >>> rearrange(images, "b (h1 h) (w1 w) c -> (b h1 w1) h w c", h1=2, w1=2).shape torch.Size([128, 15, 20, 3]) >>> # space-to-depth operation - >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape + >>> rearrange(images, "b (h h1) (w w1) c -> b h w (c h1 w1)", h1=2, w1=2).shape torch.Size([32, 15, 20, 12]) """ if not isinstance(tensor, torch.Tensor): diff --git a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py index 9067c7b75bcc6..35696675305e9 100755 --- a/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py +++ b/functorch/examples/maml_omniglot/maml-omniglot-ptonly.py @@ -152,7 +152,7 @@ def train(db, net, device, meta_opt, epoch, log): spt_logits = fnet(new_params, buffers, x_spt[i]) spt_loss = F.cross_entropy(spt_logits, y_spt[i]) grads = torch.autograd.grad(spt_loss, new_params, create_graph=True) - new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] + new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)] # The final set of adapted parameters will induce some # final loss and accuracy on the query dataset. @@ -215,7 +215,7 @@ def test(db, net, device, epoch, log): spt_logits = fnet(new_params, buffers, x_spt[i]) spt_loss = F.cross_entropy(spt_logits, y_spt[i]) grads = torch.autograd.grad(spt_loss, new_params) - new_params = [p - g * 1e-1 for p, g, in zip(new_params, grads)] + new_params = [p - g * 1e-1 for p, g in zip(new_params, grads)] # The query loss and acc induced by these parameters. qry_logits = fnet(new_params, buffers, x_qry[i]).detach() diff --git a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py index cbc28ac1ee577..be44863d36f4e 100755 --- a/functorch/examples/maml_omniglot/maml-omniglot-transforms.py +++ b/functorch/examples/maml_omniglot/maml-omniglot-transforms.py @@ -132,7 +132,7 @@ def compute_loss(new_params, buffers, x, y): new_params = params for _ in range(n_inner_iter): grads = grad(compute_loss)(new_params, buffers, x_spt, y_spt) - new_params = {k: new_params[k] - g * 1e-1 for k, g, in grads.items()} + new_params = {k: new_params[k] - g * 1e-1 for k, g in grads.items()} # The final set of adapted parameters will induce some # final loss and accuracy on the query dataset. @@ -216,7 +216,7 @@ def test(db, net, device, epoch, log): spt_loss = F.cross_entropy(spt_logits, y_spt[i]) grads = torch.autograd.grad(spt_loss, new_params.values()) new_params = { - k: new_params[k] - g * 1e-1 for k, g, in zip(new_params, grads) + k: new_params[k] - g * 1e-1 for k, g in zip(new_params, grads) } # The query loss and acc induced by these parameters. diff --git a/functorch/examples/maml_omniglot/support/omniglot_loaders.py b/functorch/examples/maml_omniglot/support/omniglot_loaders.py index 7e54d3584a871..4390caa717b58 100644 --- a/functorch/examples/maml_omniglot/support/omniglot_loaders.py +++ b/functorch/examples/maml_omniglot/support/omniglot_loaders.py @@ -169,9 +169,7 @@ def __init__(self, root, batchsz, n_way, k_shot, k_query, imgsz, device=None): ), ) - temp = ( - {} - ) # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} + temp = {} # {label:img1, img2..., 20 imgs, label2: img1, img2,... in total, 1623 label} for img, label in self.x: if label in temp.keys(): temp[label].append(img) diff --git a/functorch/notebooks/_src/plot_ensembling.py b/functorch/notebooks/_src/plot_ensembling.py index 55554a1985b43..f720f3a612717 100644 --- a/functorch/notebooks/_src/plot_ensembling.py +++ b/functorch/notebooks/_src/plot_ensembling.py @@ -16,6 +16,7 @@ Let's demonstrate how to do this using an ensemble of simple CNNs. """ + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/functorch/notebooks/_src/plot_jacobians_and_hessians.py b/functorch/notebooks/_src/plot_jacobians_and_hessians.py index 295810675ea02..3faeaa9a16752 100644 --- a/functorch/notebooks/_src/plot_jacobians_and_hessians.py +++ b/functorch/notebooks/_src/plot_jacobians_and_hessians.py @@ -8,6 +8,7 @@ efficiently using a standard autodiff system like PyTorch Autograd; functorch provides ways of computing various higher-order autodiff quantities efficiently. """ + from functools import partial import torch diff --git a/functorch/notebooks/_src/plot_per_sample_gradients.py b/functorch/notebooks/_src/plot_per_sample_gradients.py index 98e850e5ce002..c39e9a1794f2a 100644 --- a/functorch/notebooks/_src/plot_per_sample_gradients.py +++ b/functorch/notebooks/_src/plot_per_sample_gradients.py @@ -9,6 +9,7 @@ sample in a batch of data. It is a useful quantity in differential privacy and optimization research. """ + import torch import torch.nn as nn import torch.nn.functional as F diff --git a/ios/TestApp/Gemfile.lock b/ios/TestApp/Gemfile.lock index 4dc5c72263c04..f191218073b4b 100644 --- a/ios/TestApp/Gemfile.lock +++ b/ios/TestApp/Gemfile.lock @@ -196,7 +196,7 @@ GEM unf_ext unf_ext (0.0.8.2) unicode-display_width (1.8.0) - webrick (1.7.0) + webrick (1.8.2) word_wrap (1.0.0) xcodeproj (1.19.0) CFPropertyList (>= 2.3.3, < 4.0) diff --git a/pyproject.toml b/pyproject.toml index 1e7def7ec492f..c15594e54a737 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ requires = [ "ninja", "pyyaml", "cmake", - "typing-extensions", + "typing-extensions>=4.10.0", "requests", ] # Use legacy backend to import local packages in setup.py @@ -150,7 +150,7 @@ select = [ "RUF026", # default factory kwarg "TCH", "TRY002", # ban vanilla raise (todo fix NOQAs) - "TRY302", + "TRY203", "TRY401", # verbose-log-message "UP", ] diff --git a/requirements.txt b/requirements.txt index d087698b4b9e0..6ce86e87d8927 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,14 +11,13 @@ requests # is required until pytorch build not refactored to work for latest setuptools. setuptools<=72.1.0 types-dataclasses -typing-extensions>=4.8.0 -sympy==1.12.1 ; python_version == "3.8" +typing-extensions>=4.10.0 sympy==1.13.1 ; python_version >= "3.9" filelock networkx jinja2 fsspec -lintrunner +lintrunner ; platform_system != "Windows" ninja packaging optree>=0.13.0 diff --git a/setup.py b/setup.py index 25c1d53495f81..576b635c8a260 100644 --- a/setup.py +++ b/setup.py @@ -119,6 +119,10 @@ # These are not CUDA versions, instead, they specify what # classes of NVIDIA hardware we should generate PTX for. # +# TORCH_XPU_ARCH_LIST +# specify which XPU architectures to build for. +# ie `TORCH_XPU_ARCH_LIST="ats-m150,lnl-m"` +# # PYTORCH_ROCM_ARCH # specify which AMD GPU targets to build for. # ie `PYTORCH_ROCM_ARCH="gfx900;gfx906"` @@ -183,7 +187,21 @@ # USE_SYSTEM_LIBS (work in progress) # Use system-provided libraries to satisfy the build dependencies. # When turned on, the following cmake variables will be toggled as well: -# USE_SYSTEM_CPUINFO=ON USE_SYSTEM_SLEEF=ON BUILD_CUSTOM_PROTOBUF=OFF +# USE_SYSTEM_CPUINFO=ON +# USE_SYSTEM_SLEEF=ON +# USE_SYSTEM_GLOO=ON +# BUILD_CUSTOM_PROTOBUF=OFF +# USE_SYSTEM_EIGEN_INSTALL=ON +# USE_SYSTEM_FP16=ON +# USE_SYSTEM_PTHREADPOOL=ON +# USE_SYSTEM_PSIMD=ON +# USE_SYSTEM_FXDIV=ON +# USE_SYSTEM_BENCHMARK=ON +# USE_SYSTEM_ONNX=ON +# USE_SYSTEM_XNNPACK=ON +# USE_SYSTEM_PYBIND11=ON +# USE_SYSTEM_NCCL=ON +# USE_SYSTEM_NVTX=ON # # USE_MIMALLOC # Static link mimalloc into C10, and use mimalloc in alloc_cpu & alloc_free. @@ -1141,9 +1159,8 @@ def main(): ) install_requires = [ "filelock", - "typing-extensions>=4.8.0", + "typing-extensions>=4.10.0", 'setuptools ; python_version >= "3.12"', - 'sympy==1.12.1 ; python_version == "3.8"', 'sympy==1.13.1 ; python_version >= "3.9"', "networkx", "jinja2", @@ -1235,6 +1252,7 @@ def main(): "include/*.h", "include/ATen/*.h", "include/ATen/cpu/*.h", + "include/ATen/cpu/vec/vec128/*.h", "include/ATen/cpu/vec/vec256/*.h", "include/ATen/cpu/vec/vec256/vsx/*.h", "include/ATen/cpu/vec/vec256/zarch/*.h", diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json index 9107da9a37cfe..f2908243477c9 100644 --- a/test/allowlist_for_publicAPI.json +++ b/test/allowlist_for_publicAPI.json @@ -262,7 +262,9 @@ "Future" ], "torch.fx": [ + "PH", "ProxyableClassMeta", + "CodeGen", "Tracer", "symbolic_trace", "wrap" diff --git a/test/ao/sparsity/test_kernels.py b/test/ao/sparsity/test_kernels.py index e34d53349d114..7e4337ba431da 100644 --- a/test/ao/sparsity/test_kernels.py +++ b/test/ao/sparsity/test_kernels.py @@ -261,7 +261,6 @@ def forward(self, x): class TestQuantizedSparseLayers(TestCase): @override_qengines - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") def test_sparse_qlinear(self): # Note: At the moment, for sparse kernels # fbgemm supports only static quantized sparse linear @@ -294,7 +293,6 @@ def test_sparse_qlinear(self): ) @override_qengines - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") def test_sparse_qlinear_serdes(self): # Note: At the moment, for sparse kernels # fbgemm supports only static quantized sparse linear diff --git a/test/benchmark_utils/callgrind_artifacts.json b/test/benchmark_utils/callgrind_artifacts.json index d4cdcdd7804fa..f9f8ce13d3bb8 100644 --- a/test/benchmark_utils/callgrind_artifacts.json +++ b/test/benchmark_utils/callgrind_artifacts.json @@ -159,41 +159,41 @@ "5411822 build/../torch/csrc/autograd/generated/variable_factories.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "5241822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&)", "5130822 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "5114822 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "4964822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4943822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "4682822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "4660822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4597822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4586822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4372822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4352822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "4091822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", - "4069822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "4006822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "3995822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "3905822 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5114822 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4964822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4943822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4682822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "4660822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4597822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4586822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4372822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4352822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4091822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", + "4069822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "4006822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "3995822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "3905822 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3831822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", "3742822 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3718822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", "3715822 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3702822 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "2526822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2438822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2422822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "2209822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "2198822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2183822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2178822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1934822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1917822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "1704822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", - "1693822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", - "1678822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", - "1673822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1669822 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "1658822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1433822 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2526822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2438822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2422822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2209822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "2198822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2183822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2178822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1934822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1917822 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1704822 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", + "1693822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", + "1678822 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", + "1673822 /data/users/test_user/repos/pytorch/build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1669822 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1658822 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1433822 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "1112000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const", "1098500 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "1062157 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", @@ -246,7 +246,7 @@ "209209 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "205609 /tmp/build/80754af9/python_1599604603603/work/Objects/moduleobject.c:module_getattro [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "197500 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor)", - "196000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "196000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "192000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "192000 build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "191567 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", @@ -258,7 +258,7 @@ "179500 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor)", "178000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "173500 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "171000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "171000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "170175 ???:_int_malloc [/usr/lib64/libc-2.28.so]", "169000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat)", "168000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/new_op.cc:operator new(unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", @@ -293,14 +293,14 @@ "100000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", "98098 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", "95000 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool)", - "95000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "95000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "94000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", "92821 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode_nodummy [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "91000 /data/users/test_user/repos/pytorch/build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type()", "91000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so]", "90000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter()", - "90000 /data/users/test_user/repos/pytorch/build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "90000 /data/users/test_user/repos/pytorch/build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(std::optional)", "90000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::end() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "90000 build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "88000 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, bool (at::Tensor const&)>::call(c10::OperatorKernel*, at::Tensor const&)" @@ -327,24 +327,24 @@ "90000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::end() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "84338 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "84000 build/../c10/util/SmallVector.h:at::RecordFunction::RecordFunction(at::RecordScope)", - "78000 build/../c10/core/TensorOptions.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "78000 build/../c10/core/TensorOptions.h:c10::computeDispatchKey(std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "78000 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "74710 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "72000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::~RecordFunction()", "67000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "66066 ???:__pthread_mutex_unlock_usercnt [/usr/lib64/libpthread-2.28.so]", "64110 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "64000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "64000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "64000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionParameter::check(_object*, std::vector >&, int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "61182 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:call_function [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "60061 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:PyTuple_New [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "59177 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "59000 build/../c10/util/Optional.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "59000 build/../c10/util/Optional.h:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "59000 build/../torch/csrc/utils/python_arg_parser.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "57000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "55000 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:tupledealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "54000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::RecordFunction(at::RecordScope)", - "52000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "52000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "50000 build/../c10/util/ThreadLocalDebugInfo.cpp:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "50000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "49049 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:lookdict_unicode [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", @@ -355,14 +355,14 @@ "45000 ???:_mid_memalign [/usr/lib64/libc-2.28.so]", "44044 ???:pthread_cond_signal@@GLIBC_2.3.2 [/usr/lib64/libpthread-2.28.so]", "44000 build/../c10/core/CPUAllocator.cpp:c10::alloc_cpu(unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", "42000 build/../c10/util/typeid.h:c10::typeMetaToScalarType(caffe2::TypeMeta)", "41000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "41000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "40000 build/../c10/core/TensorOptions.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", - "39000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "39000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "37111 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:_PyType_Lookup [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "36613 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "36000 /usr/include/c++/8/bits/stl_construct.h:at::RecordFunction::~RecordFunction()", @@ -370,21 +370,21 @@ "36000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::intlist(int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", - "35000 build/../c10/core/TensorOptions.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "35000 build/../c10/core/TensorOptions.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", "34000 /tmp/build/80754af9/python_1599604603603/work/Objects/weakrefobject.c:PyObject_ClearWeakRefs [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "34000 build/../c10/core/impl/InlineDeviceGuard.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device)", "33066 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:_PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "33000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/new_op.cc:operator new(unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", - "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "33000 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "33000 build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "33000 build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "33000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "32000 build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "31000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "31000 build/../c10/util/SmallVector.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat)", @@ -399,8 +399,8 @@ "27000 ???:posix_memalign [/usr/lib64/libc-2.28.so]", "27000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", "26000 build/../c10/core/TensorImpl.h:c10::TensorImpl::data() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "26000 build/../c10/core/TensorOptions.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "26000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "26000 build/../c10/core/TensorOptions.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "26000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "25000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgs::intlist(int)", "25000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::Delete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "25000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::~TensorImpl()", @@ -414,44 +414,44 @@ "25000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::device(int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "25000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "24000 build/../c10/core/DispatchKey.cpp:c10::getAutogradKeyFromBackend(c10::DispatchKey) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "24000 build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "24000 build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "24000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "23000 build/../aten/src/ATen/core/LegacyTypeDispatch.h:at::AutoNonVariableTypeMode::AutoNonVariableTypeMode(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "23000 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "23000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "23000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "22044 /home/test_user/miniconda3/envs/throwaway/include/pybind11/detail/internals.h:pybind11::detail::get_internals() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "22000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind)", - "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", - "22000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", + "22000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "21021 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", "20035 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyStack_AsTuple [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "20000 build/../c10/util/Optional.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "20000 build/../c10/util/Optional.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "20000 build/../torch/csrc/autograd/generated/variable_factories.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "20000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "20000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::get_autograd_meta(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "19019 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:_PyObject_GC_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "19000 build/../aten/src/ATen/native/TypeProperties.cpp:at::native::is_complex(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "19000 build/../c10/util/SmallVector.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "19000 build/../c10/util/SmallVector.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "18054 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "18000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::~RecordFunction()", "18000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::fill_(c10::Scalar) const", "18000 build/../aten/src/ATen/native/TensorFactories.h:at::native::check_size_nonnegative(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "18000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_NewWithVar(_typeobject*, at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "17000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::gil_scoped_release(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "17000 /usr/include/c++/8/new:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "17000 /usr/include/c++/8/new:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "17000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_dispatch_key() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "17000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "16064 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_Py_CheckFunctionResult [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "16000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "16000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "16000 build/../c10/util/Exception.cpp:c10::Warning::set_warning_handler(c10::WarningHandler*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "16000 build/../c10/util/intrusive_ptr.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", "16000 build/../c10/util/intrusive_ptr.h:torch::autograd::make_variable(at::Tensor, bool, bool)", "16000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "15000 /usr/include/c++/8/new:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "15000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "15000 /usr/include/c++/8/new:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "15000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "15000 build/../c10/core/ScalarType.h:at::native::is_complex(at::Tensor const&)", "15000 build/../c10/core/TensorOptions.h:c10::TensorOptions::computeDispatchKey() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "15000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", @@ -464,7 +464,7 @@ "14000 build/../c10/core/ScalarType.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", "14000 build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "14000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::TensorOptions const&)", - "14000 build/../c10/util/Optional.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", + "14000 build/../c10/util/Optional.h:c10::computeDispatchKey(std::optional, std::optional, std::optional)", "14000 build/../c10/util/typeid.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", "14000 build/../c10/util/typeid.h:at::native::is_complex(at::Tensor const&)", "14000 build/aten/src/ATen/core/TensorBody.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", @@ -476,60 +476,60 @@ "13000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::scalartype(int) [clone .isra.180] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "12000 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:PyObject_GC_Del [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "12000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::RecordFunction(at::RecordScope)", - "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", "12000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", - "12000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "12000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "12000 build/../c10/core/TensorImpl.h:c10::TensorImpl::compute_contiguous() const", "12000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr >::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "11011 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_SaveThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "11000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::~gil_scoped_release() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", - "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "11000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "11000 build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "11000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "11000 build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "11000 build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "11000 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11000 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "10000 build/../c10/core/CPUAllocator.cpp:c10::profiledCPUMemoryReporter() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "10000 build/../c10/util/Exception.cpp:c10::Warning::get_warning_handler() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "10000 build/../c10/util/intrusive_ptr.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", "10000 build/../c10/util/intrusive_ptr.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", - "10000 build/../c10/util/llvmMathExtras.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "10000 build/../c10/util/llvmMathExtras.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "9009 ???:pthread_mutex_unlock [/usr/lib64/libpthread-2.28.so]", - "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "9000 build/../c10/core/Device.h:at::native::fill_out(at::Tensor&, c10::Scalar)", "9000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::release_resources() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "9000 build/../c10/core/TensorOptions.h:c10::TensorOptions::TensorOptions() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", - "9000 build/../c10/util/Optional.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "9000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", - "8000 /usr/include/c++/8/bits/atomic_base.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", + "9000 build/../c10/util/Optional.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "9000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", + "8000 /usr/include/c++/8/bits/atomic_base.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "8000 /usr/include/c++/8/bits/stl_vector.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "8000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::release_resources()", "8000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::is_complex() const", "8000 build/../aten/src/ATen/detail/CPUGuardImpl.h:at::detail::CPUGuardImpl::getDevice() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "8000 build/../c10/core/CPUAllocator.cpp:c10::GetCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "8000 build/../c10/core/DeviceGuard.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "8000 build/../c10/core/DeviceGuard.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(std::optional)", "8000 build/../c10/core/DispatchKeySet.h:c10::DispatchKeySet::has(c10::DispatchKey) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "8000 build/../c10/core/StorageImpl.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/core/StorageImpl.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "8000 build/../c10/core/impl/DeviceGuardImplInterface.h:c10::impl::getDeviceGuardImpl(c10::DeviceType) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "8000 build/../c10/core/impl/VirtualGuardImpl.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", - "8000 build/../c10/util/Optional.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/util/Optional.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "8000 build/../c10/util/Optional.h:c10::TensorOptions::computeDispatchKey() const", "8000 build/../c10/util/SmallVector.h:c10::TensorImpl::~TensorImpl()", - "8000 build/../c10/util/llvmMathExtras.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "8000 build/../c10/util/llvmMathExtras.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 build/../c10/util/llvmMathExtras.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "8000 build/../c10/util/llvmMathExtras.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "8000 build/aten/src/ATen/core/TensorBody.h:at::native::fill_out(at::Tensor&, c10::Scalar)", "7035 /tmp/build/80754af9/python_1599604603603/work/Python/errors.c:PyErr_Occurred [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "7000 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GetDictPtr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "7000 build/../c10/core/Scalar.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", - "7000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "7000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469]", "7000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", "7000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "7000 build/../c10/core/impl/VirtualGuardImpl.h:c10::optional_base >::~optional_base()", @@ -545,35 +545,35 @@ "6000 /usr/include/c++/8/bits/move.h:torch::PythonArgs::intlist(int)", "6000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::memoryProfilingEnabled()", "6000 /usr/include/c++/8/bits/stl_iterator.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", - "6000 /usr/include/c++/8/bits/unique_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 /usr/include/c++/8/bits/unique_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "6000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::~TensorImpl()", - "6000 /usr/include/c++/8/tuple:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 /usr/include/c++/8/tuple:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "6000 build/../aten/src/ATen/record_function.h:at::RecordFunction::RecordFunction(at::RecordScope)", "6000 build/../c10/core/Allocator.cpp:c10::GetAllocator(c10::DeviceType const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "6000 build/../c10/core/Device.h:at::detail::CPUGuardImpl::getDevice() const", "6000 build/../c10/core/DispatchKeySet.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", - "6000 build/../c10/core/DispatchKeySet.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", - "6000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "6000 build/../c10/core/DispatchKeySet.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", + "6000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484]", "6000 build/../c10/core/TensorImpl.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", "6000 build/../c10/core/TensorImpl.h:at::Tensor::device() const", "6000 build/../c10/core/TensorOptions.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&)", - "6000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", - "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "6000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", "6000 build/../c10/util/SmallVector.h:c10::TensorImpl::compute_contiguous() const", "6000 build/../c10/util/TypeCast.h:float c10::checked_convert(double, char const*)", "6000 build/../c10/util/intrusive_ptr.h:THPVariable_Wrap(at::Tensor)", - "6000 build/../c10/util/intrusive_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../c10/util/intrusive_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "6000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "6000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", "5000 /tmp/build/80754af9/python_1599604603603/_build_env/x86_64-conda_cos6-linux-gnu/sysroot/usr/include/bits/string3.h:PyType_GenericAlloc", - "5000 /usr/include/c++/8/bits/atomic_base.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 /usr/include/c++/8/bits/atomic_base.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "5000 build/../aten/src/ATen/DeviceGuard.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", - "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::fill_(c10::Scalar) const", - "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "5000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::(anonymous namespace)::infer_full_options(c10::Scalar, c10::TensorOptions const&) [clone .isra.262] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "5000 build/../c10/core/Device.h:torch::PythonArgs::device(int)", "5000 build/../c10/core/DispatchKeySet.h:at::Tensor::fill_(c10::Scalar) const", @@ -581,12 +581,12 @@ "5000 build/../c10/core/TensorImpl.h:at::Tensor::is_quantized() const", "5000 build/../c10/core/TensorOptions.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", "5000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "5000 build/../c10/util/Optional.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../c10/util/Optional.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "5000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::release_resources()", "5000 build/../torch/csrc/utils/cuda_lazy_init.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "4004 ???:__errno_location [/usr/lib64/libpthread-2.28.so]", - "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", "4000 /usr/include/c++/8/bits/atomic_base.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", "4000 /usr/include/c++/8/bits/atomic_base.h:c10::impl::getDeviceGuardImpl(c10::DeviceType)", "4000 /usr/include/c++/8/bits/atomic_base.h:torch::autograd::make_variable(at::Tensor, bool, bool)", @@ -594,28 +594,28 @@ "4000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::set_autograd_meta(std::unique_ptr >) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "4000 /usr/include/c++/8/cmath:float c10::checked_convert(double, char const*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "4000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::is_complex() const", - "4000 build/../c10/core/Allocator.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "4000 build/../c10/core/Device.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/core/Allocator.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "4000 build/../c10/core/Device.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "4000 build/../c10/core/DeviceGuard.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", "4000 build/../c10/core/DispatchKeySet.h:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet)", "4000 build/../c10/core/TensorImpl.h:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", - "4000 build/../c10/util/Optional.h:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/Optional.h:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "4000 build/../c10/util/Optional.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", - "4000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "4000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484]", "4000 build/../c10/util/SmallVector.h:c10::TensorImpl::sizes() const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "4000 build/../c10/util/UniqueVoidPtr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/UniqueVoidPtr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "4000 build/../c10/util/intrusive_ptr.h:THPVariable_NewWithVar(_typeobject*, at::Tensor)", "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::pyobj(at::Tensor const&)", "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", "4000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::device() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallKeywords [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", "3000 /usr/include/c++/8/array:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", "3000 /usr/include/c++/8/bits/atomic_base.h:c10::intrusive_ptr::reset_()", "3000 /usr/include/c++/8/bits/shared_ptr_base.h:THPVariable_clear(THPVariable*)", - "3000 /usr/include/c++/8/bits/stl_numeric.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 /usr/include/c++/8/bits/stl_numeric.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "3000 /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::PyWarningHandler()", "3000 /usr/include/c++/8/bits/unique_ptr.h:torch::autograd::make_variable(at::Tensor, bool, bool)", "3000 /usr/include/c++/8/tuple:c10::DefaultCPUAllocator::allocate(unsigned long) const", @@ -624,21 +624,21 @@ "3000 build/../c10/core/Backend.h:torch::PythonArgs::device(int)", "3000 build/../c10/core/Backend.h:torch::tensors::get_default_dispatch_key()", "3000 build/../c10/core/Device.h:c10::DefaultCPUAllocator::allocate(unsigned long) const", - "3000 build/../c10/core/Device.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "3000 build/../c10/core/DispatchKeySet.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "3000 build/../c10/core/DispatchKeySet.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3000 build/../c10/core/Device.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "3000 build/../c10/core/Scalar.h:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", - "3000 build/../c10/core/TensorImpl.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "3000 build/../c10/core/TensorImpl.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "3000 build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "3000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&)", - "3000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "3000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469]", "3000 build/../c10/util/Optional.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "3000 build/../c10/util/intrusive_ptr.h:THPVariable_dealloc(THPVariable*)", "3000 build/../c10/util/typeid.h:c10::TensorImpl::data() const", - "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", - "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469]", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484]", "3000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::check_deprecated(torch::FunctionSignature const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "3000 build/aten/src/ATen/core/TensorBody.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 build/aten/src/ATen/core/TensorBody.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "3000 build/aten/src/ATen/core/TensorBody.h:torch::autograd::make_variable(at::Tensor, bool, bool)", "2006 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "2000 /usr/include/c++/8/array:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", @@ -650,9 +650,9 @@ "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::~PyWarningHandler()", "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", "2000 /usr/include/c++/8/ext/new_allocator.h:torch::PythonArgs::intlist(int)", - "2000 /usr/include/c++/8/new:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", - "2000 /usr/include/c++/8/new:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", - "2000 /usr/include/c++/8/tuple:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "2000 /usr/include/c++/8/new:c10::OptionalDeviceGuard::OptionalDeviceGuard(std::optional)", + "2000 /usr/include/c++/8/new:c10::computeDispatchKey(std::optional, std::optional, std::optional)", + "2000 /usr/include/c++/8/tuple:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "2000 build/../aten/src/ATen/Context.cpp:at::getCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "2000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", "2000 build/../aten/src/ATen/core/dispatch/OperatorEntry.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", @@ -681,41 +681,41 @@ "5458967 build/../torch/csrc/autograd/generated/variable_factories.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "5288967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&)", "5177967 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "5161967 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "5011967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4990967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "4729967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "4707967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4644967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4633967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4419967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "4399967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "4138967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", - "4116967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "4053967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "4042967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "3952967 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5161967 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "5011967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4990967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4729967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "4707967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4644967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4633967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4419967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "4399967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "4138967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", + "4116967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "4053967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "4042967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "3952967 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3878967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", "3789967 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3765967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", "3762967 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3749967 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "2573967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2485967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2469967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "2256967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "2245967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2230967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "2225967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1981967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1964967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "1751967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", - "1740967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", - "1725967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", - "1720967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1716967 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "1705967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "1475967 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2573967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2485967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2469967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "2256967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "2245967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2230967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "2225967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1981967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1964967 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1751967 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", + "1740967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", + "1725967 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", + "1720967 /data/users/test_user/repos/pytorch/build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1716967 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "1705967 /data/users/test_user/repos/pytorch/build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "1475967 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "1307993 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "1112000 /data/users/test_user/repos/pytorch/build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const", "1067166 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", @@ -774,7 +774,7 @@ "209209 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "200993 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor)", "200000 /data/users/test_user/repos/pytorch/build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat)", - "196000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "196000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "193993 build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "192000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::~RecordFunction() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "192000 build/../c10/core/Device.h:c10::Device::validate() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", @@ -784,7 +784,7 @@ "178000 ???:malloc [/usr/lib64/libc-2.28.so]", "178000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "176993 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "171000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "171000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "170404 ???:_int_malloc [/usr/lib64/libc-2.28.so]", "170000 build/../c10/core/TensorImpl.h:c10::TensorImpl::empty_tensor_restride(c10::MemoryFormat) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "167167 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", @@ -818,7 +818,7 @@ "100000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", "98098 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", "95000 /data/users/test_user/repos/pytorch/build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool)", - "95000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "95000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "94000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", "93229 /usr/include/c++/8/ext/new_allocator.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)" ], @@ -845,24 +845,24 @@ "90000 build/../aten/src/ATen/record_function.cpp:at::RecordFunction::end() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "84338 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_GetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "84000 build/../c10/util/SmallVector.h:at::RecordFunction::RecordFunction(at::RecordScope)", - "78000 build/../c10/core/TensorOptions.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "78000 build/../c10/core/TensorOptions.h:c10::computeDispatchKey(std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "78000 build/../torch/csrc/autograd/generated/python_torch_functions.cpp:torch::autograd::THPVariable_ones(_object*, _object*, _object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "74710 /tmp/build/80754af9/python_1599604603603/work/Objects/dictobject.c:PyDict_SetItem [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "72000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::~RecordFunction()", - "67000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "67000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "67000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "66066 ???:__pthread_mutex_unlock_usercnt [/usr/lib64/libpthread-2.28.so]", "64110 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "64000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::FunctionParameter::check(_object*, std::vector >&, int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "61182 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:call_function [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "59177 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GenericGetAttrWithDict [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "59000 build/../c10/util/Optional.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "59000 build/../c10/util/Optional.h:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "59000 build/../torch/csrc/utils/python_arg_parser.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "57000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "56000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgs::intlist(int) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "55000 /tmp/build/80754af9/python_1599604603603/work/Objects/tupleobject.c:tupledealloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "54000 /usr/include/c++/8/bits/stl_vector.h:at::RecordFunction::RecordFunction(at::RecordScope)", - "52000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "52000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "50000 build/../c10/util/ThreadLocalDebugInfo.cpp:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "50000 build/../c10/util/typeid.cpp:caffe2::detail::TypeMetaData const* caffe2::TypeMeta::_typeMetaDataInstance() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "49000 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:PyType_GenericAlloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", @@ -874,35 +874,35 @@ "44044 ???:pthread_cond_signal@@GLIBC_2.3.2 [/usr/lib64/libpthread-2.28.so]", "44000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/new_op.cc:operator new(unsigned long) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", "44000 build/../c10/core/CPUAllocator.cpp:c10::alloc_cpu(unsigned long) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", + "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "44000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::compute_contiguous() const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "43000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", "42000 build/../c10/util/typeid.h:c10::typeMetaToScalarType(caffe2::TypeMeta)", "41000 build/../aten/src/ATen/native/Fill.cpp:at::native::fill_out(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "41000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "41000 build/../c10/core/TensorImpl.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "40106 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "40000 build/../c10/core/TensorOptions.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", - "39000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "39000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "38056 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:_PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "37111 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:_PyType_Lookup [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "36000 /usr/include/c++/8/bits/stl_construct.h:at::RecordFunction::~RecordFunction()", "36000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_clear(THPVariable*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval.c:PyEval_RestoreThread [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "35035 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_RestoreThread", - "35000 build/../c10/core/TensorOptions.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", + "35000 build/../c10/core/TensorOptions.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "35000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", "34000 /tmp/build/80754af9/python_1599604603603/work/Objects/weakrefobject.c:PyObject_ClearWeakRefs [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "34000 build/../c10/core/impl/InlineDeviceGuard.h:c10::impl::InlineDeviceGuard::InlineDeviceGuard(c10::Device)", - "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "33000 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/../aten/src/ATen/core/op_registration/hacky_wrapper_for_legacy_signatures.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "33000 build/../torch/csrc/autograd/variable.h:torch::autograd::make_variable(at::Tensor, bool, bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "33000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "33000 build/aten/src/ATen/BackendSelectRegister.cpp:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "32000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "32000 build/../c10/core/Allocator.cpp:c10::memoryProfilingEnabled() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "31000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "30000 build/../aten/src/ATen/core/dispatch/Dispatcher.cpp:c10::Dispatcher::singleton() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", @@ -917,8 +917,8 @@ "27000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", "26000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgs::intlist(int)", "26000 build/../c10/core/TensorImpl.h:c10::TensorImpl::data() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "26000 build/../c10/core/TensorOptions.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "26000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "26000 build/../c10/core/TensorOptions.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "26000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "25000 build/../aten/src/ATen/native/TensorFactories.h:at::native::check_size_nonnegative(c10::ArrayRef) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "25000 build/../c10/core/CPUAllocator.cpp:c10::ProfiledCPUMemoryReporter::Delete(void*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "25000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::~TensorImpl()", @@ -933,27 +933,27 @@ "25000 build/../torch/csrc/utils/python_numbers.h:torch::PythonArgs::intlist(int)", "25000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "24000 build/../c10/core/DispatchKey.cpp:c10::getAutogradKeyFromBackend(c10::DispatchKey) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "24000 build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "24000 build/aten/src/ATen/Functions.cpp:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "24000 build/aten/src/ATen/TypeDefault.cpp:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "23000 build/../aten/src/ATen/core/LegacyTypeDispatch.h:at::AutoNonVariableTypeMode::AutoNonVariableTypeMode(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "23000 build/../c10/core/CPUAllocator.cpp:c10::DefaultCPUAllocator::allocate(unsigned long) const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "23000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "23000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "22044 /home/test_user/miniconda3/envs/throwaway/include/pybind11/detail/internals.h:pybind11::detail::get_internals() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "22000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind)", - "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", - "22000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "22000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", + "22000 build/../c10/core/TensorOptions.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "21021 /tmp/build/80754af9/python_1599604603603/work/Python/ceval_gil.h:PyEval_SaveThread", "20035 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_PyStack_AsTuple [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "20000 build/../c10/util/Optional.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "20000 build/../c10/util/Optional.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "20000 build/../torch/csrc/autograd/generated/variable_factories.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "20000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_Wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "20000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::get_autograd_meta(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "19019 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:_PyObject_GC_Malloc [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "19000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "19000 build/../aten/src/ATen/native/TypeProperties.cpp:at::native::is_complex(at::Tensor const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "19000 build/../c10/util/SmallVector.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "19000 build/../c10/util/SmallVector.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "18054 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "18000 /tmp/build/80754af9/python_1599604603603/work/Objects/longobject.c:PyLong_AsLongLongAndOverflow [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "18000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::~RecordFunction()", @@ -961,17 +961,17 @@ "18000 build/../torch/csrc/autograd/python_variable.cpp:THPVariable_NewWithVar(_typeobject*, at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "17010 /tmp/build/80754af9/python_1599604603603/work/Objects/typeobject.c:PyType_IsSubtype [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "17000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::gil_scoped_release(bool) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "17000 /usr/include/c++/8/new:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "17000 /usr/include/c++/8/new:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "17000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_dispatch_key() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "17000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::fill_(c10::Scalar) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "16064 /tmp/build/80754af9/python_1599604603603/work/Objects/abstract.c:_Py_CheckFunctionResult [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "16000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "16000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "16000 build/../c10/util/Exception.cpp:c10::Warning::set_warning_handler(c10::WarningHandler*) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "16000 build/../c10/util/intrusive_ptr.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", "16000 build/../c10/util/intrusive_ptr.h:torch::autograd::make_variable(at::Tensor, bool, bool)", "16000 build/aten/src/ATen/Functions.cpp:at::ones(c10::ArrayRef, c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "15000 /usr/include/c++/8/new:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "15000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "15000 /usr/include/c++/8/new:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "15000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "15000 build/../c10/core/ScalarType.h:at::native::is_complex(at::Tensor const&)", "15000 build/../c10/core/TensorOptions.h:c10::TensorOptions::computeDispatchKey() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "15000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", @@ -984,7 +984,7 @@ "14000 build/../c10/core/ScalarType.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", "14000 build/../c10/core/TensorImpl.h:c10::TensorImpl::~TensorImpl() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "14000 build/../c10/util/Optional.h:at::ones(c10::ArrayRef, c10::TensorOptions const&)", - "14000 build/../c10/util/Optional.h:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", + "14000 build/../c10/util/Optional.h:c10::computeDispatchKey(std::optional, std::optional, std::optional)", "14000 build/../c10/util/typeid.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", "14000 build/../c10/util/typeid.h:at::native::is_complex(at::Tensor const&)", "14000 build/aten/src/ATen/core/TensorBody.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", @@ -996,10 +996,10 @@ "13000 build/../torch/csrc/utils/tensor_numpy.cpp:torch::utils::is_numpy_int(_object*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "12000 /tmp/build/80754af9/python_1599604603603/work/Modules/gcmodule.c:PyObject_GC_Del [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "12000 /usr/include/c++/8/bits/shared_ptr_base.h:at::RecordFunction::RecordFunction(at::RecordScope)", - "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "12000 /usr/include/c++/8/new:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", "12000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", - "12000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "12000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "12000 build/../c10/core/TensorImpl.h:c10::TensorImpl::compute_contiguous() const", "12000 build/../c10/util/SmallVector.h:c10::TensorImpl::set_sizes_contiguous(c10::ArrayRef)", "12000 build/../c10/util/intrusive_ptr.h:c10::intrusive_ptr >::reset_() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", @@ -1007,56 +1007,56 @@ "11000 /home/test_user/miniconda3/envs/throwaway/include/pybind11/pybind11.h:pybind11::gil_scoped_release::~gil_scoped_release() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "11000 /usr/include/c++/8/bits/stl_algobase.h:torch::PythonArgs::intlist(int)", "11000 /usr/include/c++/8/ext/new_allocator.h:torch::PythonArgs::intlist(int)", - "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", - "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "11000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "11000 build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "11000 build/../aten/src/ATen/core/boxing/KernelFunction_impl.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "11000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "11000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "11000 build/../c10/util/Optional.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "11000 build/../torch/csrc/autograd/utils/wrap_outputs.h:torch::autograd::utils::wrap(at::Tensor) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "11000 build/../torch/csrc/jit/frontend/tracer.cpp:torch::jit::tracer::getTracingState() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "11000 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "11000 build/aten/src/ATen/CPUType.cpp:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "10000 build/../c10/core/CPUAllocator.cpp:c10::profiledCPUMemoryReporter() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "10000 build/../c10/util/Exception.cpp:c10::Warning::get_warning_handler() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "10000 build/../c10/util/intrusive_ptr.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", "10000 build/../c10/util/intrusive_ptr.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", - "10000 build/../c10/util/llvmMathExtras.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "10000 build/../c10/util/llvmMathExtras.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "9009 ???:pthread_mutex_unlock [/usr/lib64/libpthread-2.28.so]", - "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, c10::optional, c10::optional, c10::optional, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::optional, c10::optional, c10::optional, c10::optional > >, at::Tensor (c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)'2", - "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "9000 /usr/include/c++/8/new:c10::impl::wrap_kernel_functor_unboxed_, std::optional, std::optional, std::optional, std::optional), at::Tensor, c10::guts::typelist::typelist, std::optional, std::optional, std::optional, std::optional > >, at::Tensor (c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)'2", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", + "9000 build/../aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2 [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "9000 build/../c10/core/Device.h:at::native::fill_out(at::Tensor&, c10::Scalar)", "9000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::release_resources() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "9000 build/../c10/core/TensorOptions.h:c10::TensorOptions::TensorOptions() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", - "9000 build/../c10/util/Optional.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "9000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "9000 build/../c10/util/Optional.h:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", + "9000 build/../c10/util/Optional.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "9000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "8000 /home/nwani/m3/conda-bld/compilers_linux-64_1560109574129/work/.build/x86_64-conda_cos6-linux-gnu/src/gcc/libstdc++-v3/libsupc++/del_op.cc:operator delete(void*) [/home/test_user/miniconda3/envs/throwaway/lib/libstdc++.so.6.0.26]", - "8000 /usr/include/c++/8/bits/atomic_base.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 /usr/include/c++/8/bits/atomic_base.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "8000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::release_resources()", "8000 build/../aten/src/ATen/core/dispatch/Dispatcher.h:at::Tensor::is_complex() const", "8000 build/../aten/src/ATen/detail/CPUGuardImpl.h:at::detail::CPUGuardImpl::getDevice() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "8000 build/../c10/core/CPUAllocator.cpp:c10::GetCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "8000 build/../c10/core/DeviceGuard.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", + "8000 build/../c10/core/DeviceGuard.h:c10::OptionalDeviceGuard::OptionalDeviceGuard(std::optional)", "8000 build/../c10/core/DispatchKeySet.h:c10::DispatchKeySet::has(c10::DispatchKey) const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", - "8000 build/../c10/core/StorageImpl.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/core/StorageImpl.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "8000 build/../c10/core/impl/DeviceGuardImplInterface.h:c10::impl::getDeviceGuardImpl(c10::DeviceType) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "8000 build/../c10/core/impl/VirtualGuardImpl.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", - "8000 build/../c10/util/Optional.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "8000 build/../c10/util/Optional.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "8000 build/../c10/util/Optional.h:c10::TensorOptions::computeDispatchKey() const", "8000 build/../c10/util/SmallVector.h:c10::TensorImpl::compute_contiguous() const", "8000 build/../c10/util/SmallVector.h:c10::TensorImpl::~TensorImpl()", - "8000 build/../c10/util/llvmMathExtras.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "8000 build/../c10/util/llvmMathExtras.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "8000 build/../c10/util/llvmMathExtras.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "8000 build/../c10/util/llvmMathExtras.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "8000 build/aten/src/ATen/core/TensorBody.h:at::native::fill_out(at::Tensor&, c10::Scalar)", "7035 /tmp/build/80754af9/python_1599604603603/work/Python/errors.c:PyErr_Occurred [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "7000 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:_PyObject_GetDictPtr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "7000 /usr/include/c++/8/bits/stl_numeric.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "7000 /usr/include/c++/8/bits/stl_numeric.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "7000 /usr/include/c++/8/bits/stl_vector.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "7000 build/../c10/core/Scalar.h:at::native::fill_out(at::Tensor&, c10::Scalar)::{lambda()", - "7000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "7000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469]", "7000 build/../c10/core/StorageImpl.h:c10::TensorImpl::release_resources()", "7000 build/../c10/core/TensorImpl.cpp:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "7000 build/../c10/core/impl/VirtualGuardImpl.h:c10::optional_base >::~optional_base()", @@ -1071,34 +1071,34 @@ "6000 /usr/include/c++/8/bits/move.h:torch::PythonArgs::intlist(int)", "6000 /usr/include/c++/8/bits/shared_ptr_base.h:c10::memoryProfilingEnabled()", "6000 /usr/include/c++/8/bits/stl_iterator.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", - "6000 /usr/include/c++/8/bits/unique_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 /usr/include/c++/8/bits/unique_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "6000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::~TensorImpl()", - "6000 /usr/include/c++/8/tuple:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 /usr/include/c++/8/tuple:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "6000 build/../aten/src/ATen/record_function.h:at::RecordFunction::RecordFunction(at::RecordScope)", "6000 build/../c10/core/Allocator.cpp:c10::GetAllocator(c10::DeviceType const&) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "6000 build/../c10/core/Device.h:at::detail::CPUGuardImpl::getDevice() const", "6000 build/../c10/core/DispatchKeySet.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", - "6000 build/../c10/core/DispatchKeySet.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", - "6000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "6000 build/../c10/core/DispatchKeySet.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", + "6000 build/../c10/core/ScalarType.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484]", "6000 build/../c10/core/TensorImpl.h:at::Tensor at::detail::make_tensor >, c10::DispatchKey, caffe2::TypeMeta&>(c10::intrusive_ptr >&&, c10::DispatchKey&&, caffe2::TypeMeta&)", "6000 build/../c10/core/TensorImpl.h:at::Tensor::device() const", "6000 build/../c10/core/TensorOptions.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&)", - "6000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", - "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "6000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "6000 build/../c10/util/Optional.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", "6000 build/../c10/util/TypeCast.h:float c10::checked_convert(double, char const*)", "6000 build/../c10/util/intrusive_ptr.h:THPVariable_Wrap(at::Tensor)", - "6000 build/../c10/util/intrusive_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "6000 build/../c10/util/intrusive_ptr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "6000 build/../torch/csrc/tensor/python_tensor.cpp:torch::tensors::get_default_scalar_type() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "6000 build/../torch/csrc/utils/python_arg_parser.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", "5000 /tmp/build/80754af9/python_1599604603603/_build_env/x86_64-conda_cos6-linux-gnu/sysroot/usr/include/bits/string3.h:PyType_GenericAlloc", - "5000 /usr/include/c++/8/bits/atomic_base.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 /usr/include/c++/8/bits/atomic_base.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "5000 build/../aten/src/ATen/DeviceGuard.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", - "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, c10::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, c10::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, c10::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, c10::optional)'2", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "5000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_, c10::TensorOptions const&, std::optional), at::Tensor, c10::guts::typelist::typelist, c10::TensorOptions const&, std::optional > >, at::Tensor (c10::ArrayRef, c10::TensorOptions const&, std::optional)>::call(c10::OperatorKernel*, c10::ArrayRef, c10::TensorOptions const&, std::optional)'2", "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::fill_(c10::Scalar) const", - "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "5000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "5000 build/../aten/src/ATen/native/TensorFactories.cpp:at::native::(anonymous namespace)::infer_full_options(c10::Scalar, c10::TensorOptions const&) [clone .isra.262] [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "5000 build/../c10/core/Device.h:torch::PythonArgs::device(int)", "5000 build/../c10/core/DispatchKeySet.h:at::Tensor::fill_(c10::Scalar) const", @@ -1106,12 +1106,12 @@ "5000 build/../c10/core/TensorImpl.h:at::Tensor::is_quantized() const", "5000 build/../c10/core/TensorOptions.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", "5000 build/../c10/core/impl/LocalDispatchKeySet.cpp:c10::impl::ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "5000 build/../c10/util/Optional.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "5000 build/../c10/util/Optional.h:at::(anonymous namespace)::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "5000 build/../c10/util/intrusive_ptr.h:c10::TensorImpl::release_resources()", "5000 build/../torch/csrc/utils/cuda_lazy_init.h:torch::utils::maybe_initialize_cuda(c10::TensorOptions const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "4004 ???:__errno_location [/usr/lib64/libpthread-2.28.so]", - "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const", - "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::optional, c10::optional, c10::optional, c10::optional >(c10::TypedOperatorHandle, c10::optional, c10::optional, c10::optional, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional) const'2", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const", + "4000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, std::optional, std::optional, std::optional, std::optional >(c10::TypedOperatorHandle, std::optional, std::optional, std::optional, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, std::optional, std::optional, std::optional, std::optional) const'2", "4000 /usr/include/c++/8/bits/atomic_base.h:at::native::full(c10::ArrayRef, c10::Scalar, c10::TensorOptions const&)", "4000 /usr/include/c++/8/bits/atomic_base.h:c10::impl::getDeviceGuardImpl(c10::DeviceType)", "4000 /usr/include/c++/8/bits/atomic_base.h:torch::autograd::make_variable(at::Tensor, bool, bool)", @@ -1119,24 +1119,24 @@ "4000 /usr/include/c++/8/bits/unique_ptr.h:c10::TensorImpl::set_autograd_meta(std::unique_ptr >) [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", "4000 /usr/include/c++/8/cmath:float c10::checked_convert(double, char const*) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "4000 build/../aten/src/ATen/core/dispatch/DispatchKeyExtractor.h:at::Tensor::is_complex() const", - "4000 build/../c10/core/Allocator.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "4000 build/../c10/core/Device.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/core/Allocator.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "4000 build/../c10/core/Device.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "4000 build/../c10/core/DeviceGuard.h:at::TypeDefault::ones(c10::ArrayRef, c10::TensorOptions const&)", "4000 build/../c10/core/DispatchKeySet.h:c10::impl::ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(c10::DispatchKeySet)", "4000 build/../c10/core/TensorImpl.h:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", - "4000 build/../c10/util/Optional.h:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/Optional.h:at::CPUType::empty_memory_format(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "4000 build/../c10/util/Optional.h:at::TypeDefault::fill__Scalar(at::Tensor&, c10::Scalar)", - "4000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "4000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484]", "4000 build/../c10/util/SmallVector.h:c10::TensorImpl::sizes() const [/data/users/test_user/repos/pytorch/torch/lib/libc10.so]", - "4000 build/../c10/util/UniqueVoidPtr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "4000 build/../c10/util/UniqueVoidPtr.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "4000 build/../c10/util/intrusive_ptr.h:THPVariable_NewWithVar(_typeobject*, at::Tensor)", "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::pyobj(at::Tensor const&)", "4000 build/../torch/csrc/autograd/variable.cpp:torch::autograd::impl::set_pyobj(at::Tensor const&, _object*)", "4000 build/aten/src/ATen/core/TensorMethods.cpp:at::Tensor::device() const [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/methodobject.c:_PyCFunction_FastCallKeywords [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "3006 /tmp/build/80754af9/python_1599604603603/work/Objects/obmalloc.c:PyObject_Free [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", - "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const", - "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, c10::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, c10::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, c10::optional) const'2", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const", + "3000 /usr/include/c++/8/array:at::Tensor c10::Dispatcher::callWithDispatchKey, c10::TensorOptions const&, std::optional >(c10::TypedOperatorHandle, c10::TensorOptions const&, std::optional)> const&, c10::DispatchKey, c10::ArrayRef, c10::TensorOptions const&, std::optional) const'2", "3000 /usr/include/c++/8/array:bool c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor const&) const", "3000 /usr/include/c++/8/bits/atomic_base.h:c10::intrusive_ptr::reset_()", "3000 /usr/include/c++/8/bits/shared_ptr_base.h:THPVariable_clear(THPVariable*)", @@ -1149,22 +1149,22 @@ "3000 build/../c10/core/Backend.h:torch::PythonArgs::device(int)", "3000 build/../c10/core/Backend.h:torch::tensors::get_default_dispatch_key()", "3000 build/../c10/core/Device.h:c10::DefaultCPUAllocator::allocate(unsigned long) const", - "3000 build/../c10/core/Device.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", - "3000 build/../c10/core/DispatchKeySet.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", - "3000 build/../c10/core/DispatchKeySet.h:at::ones(c10::ArrayRef, c10::optional, c10::optional, c10::optional, c10::optional)", + "3000 build/../c10/core/Device.h:c10::impl::detail::with_scattered_tensor_options_impl_, c10::TensorOptions const&), &at::TypeDefault::ones>, c10::guts::typelist::typelist >, c10::guts::typelist::typelist<> >::wrapper(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::empty(c10::ArrayRef, c10::TensorOptions const&, std::optional)", + "3000 build/../c10/core/DispatchKeySet.h:at::ones(c10::ArrayRef, std::optional, std::optional, std::optional, std::optional)", "3000 build/../c10/core/Scalar.h:at::native::ones(c10::ArrayRef, c10::TensorOptions const&)", - "3000 build/../c10/core/TensorImpl.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "3000 build/../c10/core/TensorImpl.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "3000 build/../c10/core/TensorImpl.h:c10::VariableVersion::VersionCounter::~VersionCounter() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", "3000 build/../c10/util/Optional.h:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&)", - "3000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", + "3000 build/../c10/util/Optional.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469]", "3000 build/../c10/util/Optional.h:torch::autograd::THPVariable_ones(_object*, _object*, _object*)", "3000 build/../c10/util/intrusive_ptr.h:THPVariable_dealloc(THPVariable*)", "3000 build/../c10/util/typeid.h:c10::TensorImpl::data() const", - "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.469]", - "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(c10::optional) const [clone .isra.484]", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.469]", + "3000 build/../c10/util/typeid.h:c10::TensorOptions::dtype(std::optional) const [clone .isra.484]", "3000 build/../torch/csrc/utils/object_ptr.h:torch::PythonArgs::intlist(int)", "3000 build/../torch/csrc/utils/python_arg_parser.cpp:torch::PythonArgParser::check_deprecated(torch::FunctionSignature const&) [/data/users/test_user/repos/pytorch/torch/lib/libtorch_python.so]", - "3000 build/aten/src/ATen/core/TensorBody.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, c10::optional)", + "3000 build/aten/src/ATen/core/TensorBody.h:at::native::empty_cpu(c10::ArrayRef, c10::TensorOptions const&, std::optional)", "3000 build/aten/src/ATen/core/TensorBody.h:torch::autograd::make_variable(at::Tensor, bool, bool)", "2006 /tmp/build/80754af9/python_1599604603603/work/Objects/object.c:PyObject_GenericGetAttr [/home/test_user/miniconda3/envs/throwaway/bin/python3.6]", "2000 /usr/include/c++/8/array:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", @@ -1174,9 +1174,9 @@ "2000 /usr/include/c++/8/bits/stl_vector.h:torch::FunctionSignature::parse(_object*, _object*, _object*, _object**, bool)", "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PyWarningHandler::~PyWarningHandler()", "2000 /usr/include/c++/8/bits/stl_vector.h:torch::PythonArgParser::raw_parse(_object*, _object*, _object*, _object**)", - "2000 /usr/include/c++/8/new:c10::OptionalDeviceGuard::OptionalDeviceGuard(c10::optional)", - "2000 /usr/include/c++/8/new:c10::computeDispatchKey(c10::optional, c10::optional, c10::optional)", - "2000 /usr/include/c++/8/tuple:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, c10::optional)", + "2000 /usr/include/c++/8/new:c10::OptionalDeviceGuard::OptionalDeviceGuard(std::optional)", + "2000 /usr/include/c++/8/new:c10::computeDispatchKey(std::optional, std::optional, std::optional)", + "2000 /usr/include/c++/8/tuple:c10::TensorImpl::TensorImpl(c10::Storage&&, c10::DispatchKeySet, caffe2::TypeMeta const&, std::optional)", "2000 build/../aten/src/ATen/Context.cpp:at::getCPUAllocator() [/data/users/test_user/repos/pytorch/torch/lib/libtorch_cpu.so]", "2000 build/../aten/src/ATen/core/boxing/impl/WrapFunctionIntoRuntimeFunctor.h:c10::impl::wrap_kernel_functor_unboxed_ >, at::Tensor& (at::Tensor&, c10::Scalar)>::call(c10::OperatorKernel*, at::Tensor&, c10::Scalar)", "2000 build/../aten/src/ATen/core/dispatch/OperatorEntry.h:at::Tensor& c10::Dispatcher::callWithDispatchKey(c10::TypedOperatorHandle const&, c10::DispatchKey, at::Tensor&, c10::Scalar) const", diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index ff3538769e06d..106d11440218c 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -703,7 +703,7 @@ def custom_transforms(fn: str): 90090 ???:pthread_mutex_lock [/usr/lib64/libpthread-2.28.so] 90000 build/../c10/core/TensorImpl.h:c ... ch/torch/lib/libtorch_python.so] 90000 build/../aten/src/ATen/record_fu ... torch/torch/lib/libtorch_cpu.so] - 90000 /data/users/test_user/repos/pyto ... uard(c10::optional) + 90000 /data/users/test_user/repos/pyto ... uard(std::optional) 90000 /data/users/test_user/repos/pyto ... ersionCounter::~VersionCounter() 88000 /data/users/test_user/repos/pyto ... ratorKernel*, at::Tensor const&)""", ) diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 61cf3a9be5ecd..fe34bf6a5021f 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -73,6 +73,7 @@ if(NOT MSVC) endif() if(INSTALL_TEST) + set_target_properties(test_api PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_api DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/api/dataloader.cpp b/test/cpp/api/dataloader.cpp index 5dd43ab04ce89..e592f10df8dc2 100644 --- a/test/cpp/api/dataloader.cpp +++ b/test/cpp/api/dataloader.cpp @@ -2220,8 +2220,7 @@ TEST(DataLoaderTest, ChunkDatasetCrossChunkShuffle) { for (const auto i : c10::irange( (chunk_count + cross_chunk_shuffle_count - 1) / cross_chunk_shuffle_count)) { - for (const auto j : c10::irange(chunk_size)) { - (void)j; // Suppress unused variable warning + for ([[maybe_unused]] const auto j : c10::irange(chunk_size)) { for (const auto k : c10::irange(cross_chunk_shuffle_count)) { if (i * cross_chunk_shuffle_count + k < chunk_count) { expected_result.push_back(i * cross_chunk_shuffle_count + k); diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp index 68d41cb163d51..8dc29be38ea4b 100644 --- a/test/cpp/api/functional.cpp +++ b/test/cpp/api/functional.cpp @@ -1343,8 +1343,7 @@ TEST_F(FunctionalTest, GumbelSoftmax) { auto counts = torch::zeros_like(logits); torch::Tensor y_draw; - for (const auto i : c10::irange(num_draws)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(num_draws)) { y_draw = F::gumbel_softmax(logits, F::GumbelSoftmaxFuncOptions().hard(true)); counts += y_draw; diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp index cbdf49df1058e..0220f5a6738c3 100644 --- a/test/cpp/api/integration.cpp +++ b/test/cpp/api/integration.cpp @@ -123,8 +123,7 @@ bool test_mnist( torch::Device device(with_cuda ? torch::kCUDA : torch::kCPU); model->to(device); - for (const auto epoch : c10::irange(number_of_epochs)) { - (void)epoch; // Suppress unused variable warning + for ([[maybe_unused]] const auto epoch : c10::irange(number_of_epochs)) { // NOLINTNEXTLINE(performance-for-range-copy) for (torch::data::Example<> batch : *data_loader) { auto data = batch.data.to(device); diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp index bd980dd9b8926..a584624bd1b7a 100644 --- a/test/cpp/api/modules.cpp +++ b/test/cpp/api/modules.cpp @@ -3511,8 +3511,7 @@ void _multihead_attn_test_helper( std::uniform_int_distribution d_2_10(2, 10); std::uniform_int_distribution d_3_10(3, 10); bool registration_checked = false; - for (const auto i : c10::irange(100)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(100)) { const auto batch_sz = d_2_10(generator); const auto seq_len = d_2_10(generator); const auto d_head = d_3_10(generator); diff --git a/test/cpp/api/nn_utils.cpp b/test/cpp/api/nn_utils.cpp index dd9928e80a213..43d9e64c1ed54 100644 --- a/test/cpp/api/nn_utils.cpp +++ b/test/cpp/api/nn_utils.cpp @@ -398,8 +398,8 @@ std::vector PackedSequenceTest_ordered_sequence( torch::ScalarType tensor_type) { std::vector seqs; seqs.reserve(PackedSequenceTest_batch_size); - for (const auto i : c10::irange(PackedSequenceTest_batch_size)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : + c10::irange(PackedSequenceTest_batch_size)) { seqs.emplace_back(torch::empty( {torch::randint(1, PackedSequenceTest_max_length, {1}).item()}, tensor_type)); diff --git a/test/cpp/api/operations.cpp b/test/cpp/api/operations.cpp index bf1643ae1e795..0494a728bb626 100644 --- a/test/cpp/api/operations.cpp +++ b/test/cpp/api/operations.cpp @@ -12,8 +12,7 @@ struct OperationTest : torch::test::SeedingFixture { }; TEST_F(OperationTest, Lerp) { - for (const auto i : c10::irange(TEST_AMOUNT)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(TEST_AMOUNT)) { // test lerp_kernel_scalar auto start = torch::rand({3, 5}); auto end = torch::rand({3, 5}); @@ -37,8 +36,7 @@ TEST_F(OperationTest, Lerp) { } TEST_F(OperationTest, Cross) { - for (const auto i : c10::irange(TEST_AMOUNT)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(TEST_AMOUNT)) { // input auto a = torch::rand({10, 3}); auto b = torch::rand({10, 3}); diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp index b8799a17157fb..33f4d9bf7eee2 100644 --- a/test/cpp/api/optim.cpp +++ b/test/cpp/api/optim.cpp @@ -157,8 +157,7 @@ void check_exact_values( TEST(OptimTest, OptimizerAccessors) { auto options = AdagradOptions(1.0); std::vector params; - for (const auto i : c10::irange(3)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(3)) { params.push_back(torch::randn(10)); } auto optimizer = Adagrad(params, options); diff --git a/test/cpp/c10d/BackoffTest.cpp b/test/cpp/c10d/BackoffTest.cpp index 054f30ba4993e..b229ec5dbfef1 100644 --- a/test/cpp/c10d/BackoffTest.cpp +++ b/test/cpp/c10d/BackoffTest.cpp @@ -1,9 +1,6 @@ #include #include "StoreTestCommon.hpp" -#include -#include - #include TEST(BackoffTest, exponentialBackoffDefaults) { diff --git a/test/cpp/c10d/CMakeLists.txt b/test/cpp/c10d/CMakeLists.txt index 0874852517e33..fdcc20c5bc753 100644 --- a/test/cpp/c10d/CMakeLists.txt +++ b/test/cpp/c10d/CMakeLists.txt @@ -6,37 +6,40 @@ if(USE_CUDA) endif() function(c10d_add_test test_src) + set(prefix ARG) + set(noValues) + set(singleValues INSTALL_TEST) + set(multiValues LINK_LIBRARIES) + + include(CMakeParseArguments) + cmake_parse_arguments(${prefix} "${noValues}" "${singleValues}" "${multiValues}" ${ARGN}) + get_filename_component(test_name ${test_src} NAME_WE) add_executable(${test_name} "${test_src}") target_include_directories(${test_name} PRIVATE $) - target_link_libraries(${test_name} ${ARGN}) + target_link_libraries(${test_name} ${ARG_LINK_LIBRARIES}) if(NOT WIN32) target_link_libraries(${test_name} pthread) endif() add_test(NAME ${test_name} COMMAND $) + + if(ARG_INSTALL_TEST) + set_target_properties(${test_name} PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") + install(TARGETS ${test_name} DESTINATION bin) + endif() endfunction() -c10d_add_test(BackoffTest.cpp torch_cpu gtest_main) -c10d_add_test(FileStoreTest.cpp torch_cpu gtest_main) -c10d_add_test(TCPStoreTest.cpp torch_cpu gtest_main) -if(INSTALL_TEST) - install(TARGETS FileStoreTest DESTINATION bin) - install(TARGETS TCPStoreTest DESTINATION bin) -endif() +c10d_add_test(BackoffTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST OFF) +c10d_add_test(FileStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) +c10d_add_test(TCPStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) if(NOT WIN32) - c10d_add_test(HashStoreTest.cpp torch_cpu gtest_main) - if(INSTALL_TEST) - install(TARGETS HashStoreTest DESTINATION bin) - endif() + c10d_add_test(HashStoreTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST ${INSTALL_TEST}) endif() if(USE_CUDA) if(USE_GLOO AND USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp torch_cpu c10d_cuda_test gtest_main) - if(INSTALL_TEST) - install(TARGETS ProcessGroupGlooTest DESTINATION bin) - endif() - c10d_add_test(ProcessGroupGlooAsyncTest.cpp torch_cpu c10d_cuda_test gtest_main) + c10d_add_test(ProcessGroupGlooTest.cpp LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main INSTALL_TEST ${INSTALL_TEST}) + c10d_add_test(ProcessGroupGlooAsyncTest.cpp LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main INSTALL_TEST ${INSTALL_TEST}) endif() if(USE_NCCL AND USE_C10D_NCCL) # NCCL is a private dependency of libtorch, but the tests include some @@ -45,13 +48,11 @@ if(USE_CUDA) # a private dependency of the tests as well. c10d_add_test( ProcessGroupNCCLTest.cpp - torch_cpu c10d_cuda_test gtest_main __caffe2_nccl) + LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_nccl INSTALL_TEST ${INSTALL_TEST}) c10d_add_test( ProcessGroupNCCLErrorsTest.cpp - torch_cpu c10d_cuda_test gtest_main __caffe2_nccl) + LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_nccl INSTALL_TEST ${INSTALL_TEST}) if(INSTALL_TEST) - install(TARGETS ProcessGroupNCCLTest DESTINATION bin) - install(TARGETS ProcessGroupNCCLErrorsTest DESTINATION bin) install(TARGETS c10d_cuda_test DESTINATION lib) endif() endif() @@ -62,15 +63,14 @@ if(USE_CUDA) # a private dependency of the tests as well. c10d_add_test( ProcessGroupUCCTest.cpp - torch_cpu c10d_cuda_test gtest_main __caffe2_ucc) + LINK_LIBRARIES torch_cpu c10d_cuda_test gtest_main __caffe2_ucc INSTALL_TEST ${INSTALL_TEST}) if(INSTALL_TEST) - install(TARGETS ProcessGroupUCCTest DESTINATION bin) install(TARGETS c10d_cuda_test DESTINATION lib) endif() endif() else() if(USE_GLOO AND USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp torch_cpu gtest_main) + c10d_add_test(ProcessGroupGlooTest.cpp LINK_LIBRARIES torch_cpu gtest_main INSTALL_TEST OFF) endif() endif() @@ -80,10 +80,7 @@ if(USE_MPI AND USE_C10D_MPI) # private headers of libtorch, which in turn include MPI. As a hacky # alternative to making MPI a public dependency of libtorch, we make it # a private dependency of the tests as well. - c10d_add_test(ProcessGroupMPITest.cpp torch_cpu MPI::MPI_CXX) - if(INSTALL_TEST) - install(TARGETS ProcessGroupMPITest DESTINATION bin) - endif() + c10d_add_test(ProcessGroupMPITest.cpp LINK_LIBRARIES torch_cpu MPI::MPI_CXX INSTALL_TEST ${INSTALL_TEST}) endif() if(LINUX AND USE_GLOO AND USE_C10D_GLOO) diff --git a/test/cpp/c10d/FileStoreTest.cpp b/test/cpp/c10d/FileStoreTest.cpp index 29b4b370b011e..67e008ff2a7e5 100644 --- a/test/cpp/c10d/FileStoreTest.cpp +++ b/test/cpp/c10d/FileStoreTest.cpp @@ -40,7 +40,7 @@ std::string tmppath() { } #endif -void testGetSet(std::string path, std::string prefix = "") { +void testGetSet(const std::string& path, const std::string& prefix = "") { // Basic Set/Get on File Store { auto fileStore = c10::make_intrusive(path, 2); @@ -99,17 +99,17 @@ void stressTestStore(std::string path, std::string prefix = "") { std::vector threads; c10d::test::Semaphore sem1, sem2; - for (C10_UNUSED const auto i : c10::irange(numThreads)) { - threads.emplace_back(std::thread([&] { + for ([[maybe_unused]] const auto i : c10::irange(numThreads)) { + threads.emplace_back([&] { auto fileStore = c10::make_intrusive(path, numThreads + 1); c10d::PrefixStore store(prefix, fileStore); sem1.post(); sem2.wait(); - for (C10_UNUSED const auto j : c10::irange(numIterations)) { + for ([[maybe_unused]] const auto j : c10::irange(numIterations)) { store.add("counter", 1); } - })); + }); } sem1.wait(numThreads); diff --git a/test/cpp/c10d/HashStoreTest.cpp b/test/cpp/c10d/HashStoreTest.cpp index f3478f6071b19..ad3f38fb93df9 100644 --- a/test/cpp/c10d/HashStoreTest.cpp +++ b/test/cpp/c10d/HashStoreTest.cpp @@ -3,15 +3,15 @@ #include -#include #include #include #include +#include constexpr int64_t kShortStoreTimeoutMillis = 100; -void testGetSet(std::string prefix = "") { +void testGetSet(const std::string& prefix = "") { // Basic set/get { auto hashStore = c10::make_intrusive(); @@ -60,16 +60,16 @@ void stressTestStore(std::string prefix = "") { std::vector threads; c10d::test::Semaphore sem1, sem2; auto hashStore = c10::make_intrusive(); - c10d::PrefixStore store(prefix, hashStore); + c10d::PrefixStore store(std::move(prefix), hashStore); - for (C10_UNUSED const auto i : c10::irange(numThreads)) { - threads.emplace_back(std::thread([&] { + for ([[maybe_unused]] const auto i : c10::irange(numThreads)) { + threads.emplace_back([&] { sem1.post(); sem2.wait(); - for (C10_UNUSED const auto j : c10::irange(numIterations)) { + for ([[maybe_unused]] const auto j : c10::irange(numIterations)) { store.add("counter", 1); } - })); + }); } sem1.wait(numThreads); diff --git a/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp b/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp index 0059560a602ab..086d26b8e8d14 100644 --- a/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp +++ b/test/cpp/c10d/ProcessGroupGlooAsyncTest.cpp @@ -13,14 +13,14 @@ using namespace c10d::test; using at::cuda::CUDAStream; template -std::vector initialize(const std::string& path, int N, Args&&... args) { +std::vector initialize(const std::string& path, size_t N, Args&&... args) { std::vector tests; - for (C10_UNUSED const auto i : c10::irange(N)) { + for ([[maybe_unused]] const auto i : c10::irange(N)) { tests.push_back(std::move(T(path, std::forward(args)...))); } std::vector threads; - for (C10_UNUSED const auto i : c10::irange(N)) { + for ([[maybe_unused]] const auto i : c10::irange(N)) { threads.push_back(std::thread([i, N, &tests] { tests[i].start(i, N); })); } @@ -35,10 +35,7 @@ class AsyncTest { public: AsyncTest(std::string path) : path_(std::move(path)) {} - AsyncTest(AsyncTest&& other) { - path_ = std::move(other.path_); - pg_ = std::move(other.pg_); - } + AsyncTest(AsyncTest&& other) noexcept = default; ::c10d::ProcessGroupGloo& getProcessGroup() { return *pg_; @@ -53,8 +50,8 @@ class AsyncTest { options->devices.push_back( ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1")); - pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>( - new ::c10d::ProcessGroupGloo(store, rank, size, options)); + pg_ = + std::make_unique<::c10d::ProcessGroupGloo>(store, rank, size, options); } protected: @@ -69,7 +66,7 @@ class AsyncInputIsOutputTest : public AsyncTest { numTensors_(numTensors), numDevices_(cudaNumDevices()) { // Allocate inputs on available devices in a round robin fashion. - ::at::globalContext().lazyInitCUDA(); + ::at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); inputs_.resize(numTensors_); for (const auto i : c10::irange(numTensors_)) { inputs_[i] = at::empty( @@ -88,7 +85,7 @@ class AsyncInputIsOutputTest : public AsyncTest { at::cuda::OptionalCUDAGuard deviceGuard; streams_.reserve(numDevices_); for (const auto i : c10::irange(numDevices_)) { - deviceGuard.set_index(i); + deviceGuard.set_index(static_cast(i)); streams_.push_back(at::cuda::getStreamFromPool()); } } @@ -118,7 +115,9 @@ class AsyncInputIsOutputTest : public AsyncTest { } protected: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const int numTensors_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const int numDevices_; std::vector inputs_; std::vector streams_; @@ -136,13 +135,13 @@ class AsyncAllreduceTest : public AsyncInputIsOutputTest { // Launch sleep on every stream at::cuda::OptionalCUDAGuard deviceGuard; for (const auto i : c10::irange(numDevices_)) { - deviceGuard.set_index(i); - cudaSleep(streams_[i], 10 * 1000 * 1000); + deviceGuard.set_index(static_cast(i)); + cudaSleep(streams_[i], 10ull * 1000 * 1000); } // Launch value initialization for every tensor for (const auto i : c10::irange(numTensors_)) { - deviceGuard.set_index(i % numDevices_); + deviceGuard.set_index(static_cast(i % numDevices_)); inputs_[i].fill_(pg_->getRank() * numTensors_ + i); } @@ -155,26 +154,26 @@ class AsyncBroadcastTest : public AsyncInputIsOutputTest { AsyncBroadcastTest(const std::string& path, int numTensors) : AsyncInputIsOutputTest(path, numTensors) {} - c10::intrusive_ptr run(int rootRank, int rootTensor) { + c10::intrusive_ptr run(size_t rootRank, size_t rootTensor) { // For the duration of this function, make THC use our streams c10::cuda::CUDAMultiStreamGuard guard(streams_); // Launch sleep on every stream at::cuda::OptionalCUDAGuard deviceGuard; for (const auto i : c10::irange(numDevices_)) { - deviceGuard.set_index(i); - cudaSleep(streams_[i], 10 * 1000 * 1000); + deviceGuard.set_index(static_cast(i)); + cudaSleep(streams_[i], 10ull * 1000 * 1000); } // Launch value initialization for every tensor for (const auto i : c10::irange(numTensors_)) { - deviceGuard.set_index(i % numDevices_); + deviceGuard.set_index(static_cast(i % numDevices_)); inputs_[i].fill_(pg_->getRank() * numTensors_ + i); } ::c10d::BroadcastOptions options; - options.rootRank = rootRank; - options.rootTensor = rootTensor; + options.rootRank = static_cast(rootRank); + options.rootTensor = static_cast(rootTensor); return pg_->broadcast(inputs_, options); } }; diff --git a/test/cpp/c10d/ProcessGroupGlooTest.cpp b/test/cpp/c10d/ProcessGroupGlooTest.cpp index a5c48bf31cfad..402ee72cad515 100644 --- a/test/cpp/c10d/ProcessGroupGlooTest.cpp +++ b/test/cpp/c10d/ProcessGroupGlooTest.cpp @@ -1,15 +1,12 @@ #ifndef _WIN32 -#include #include #include +#include #endif #include -#include -#include -#include -#include +#include #include #include @@ -30,7 +27,7 @@ constexpr auto kWaitTimeout = std::chrono::milliseconds(1); #ifndef _WIN32 class SignalTest { public: - SignalTest(const std::string& path) : path_(path) {} + SignalTest(std::string path) : path_(std::move(path)) {} ~SignalTest() { if (arm_.joinable()) { @@ -41,7 +38,7 @@ class SignalTest { // Arms test to send signal to PID when the semaphore unlocks. This // happens as soon as the first collective completes successfully. void arm(int pid, int signal) { - arm_ = std::thread([=] { + arm_ = std::thread([this, pid, signal] { sem_.wait(); kill(pid, signal); }); @@ -108,7 +105,7 @@ class ProcessGroupGlooDelayed : public ::c10d::ProcessGroupGloo { int rank, int size, c10::intrusive_ptr options) - : ProcessGroupGloo(store, rank, size, options) {} + : ProcessGroupGloo(store, rank, size, std::move(options)) {} c10::intrusive_ptr<::c10d::Work> send( std::vector& tensors, @@ -126,14 +123,14 @@ class CollectiveTest { int num, bool delayed = false) { std::vector tests; - for (C10_UNUSED const auto i : c10::irange(num)) { - tests.emplace_back(CollectiveTest(path)); + for ([[maybe_unused]] const auto i : c10::irange(num)) { + tests.emplace_back(path); } std::vector threads; for (const auto i : c10::irange(num)) { - threads.emplace_back(std::thread( - [i, &tests, delayed] { tests[i].start(i, tests.size(), delayed); })); + threads.emplace_back( + [i, &tests, delayed] { tests[i].start(i, tests.size(), delayed); }); } for (auto& thread : threads) { thread.join(); @@ -144,16 +141,13 @@ class CollectiveTest { CollectiveTest(std::string path) : path_(std::move(path)) {} - CollectiveTest(CollectiveTest&& other) { - path_ = std::move(other.path_); - pg_ = std::move(other.pg_); - } + CollectiveTest(CollectiveTest&& other) noexcept = default; ::c10d::ProcessGroupGloo& getProcessGroup() { return *pg_; } - void start(int rank, int size, bool delayed) { + void start(int rank, size_t size, bool delayed) { auto store = c10::make_intrusive<::c10d::FileStore>(path_, size); // Set a timeout that is small enough to make this test run fast, but also @@ -164,11 +158,11 @@ class CollectiveTest { ::c10d::ProcessGroupGloo::createDeviceForHostname("127.0.0.1")); if (!delayed) { - pg_ = std::unique_ptr<::c10d::ProcessGroupGloo>( - new ::c10d::ProcessGroupGloo(store, rank, size, options)); + pg_ = std::make_unique<::c10d::ProcessGroupGloo>( + store, rank, size, options); } else { - pg_ = std::unique_ptr( - new ProcessGroupGlooDelayed(store, rank, size, options)); + pg_ = + std::make_unique(store, rank, size, options); } } @@ -192,13 +186,13 @@ std::vector> copyTensors( } std::vector> waitWork( - std::vector> works) { + const std::vector>& works) { std::vector> outputTensors; for (auto& work : works) { try { work->wait(); } catch (const std::exception& ex) { - LOG(ERROR) << "Exception received: " << ex.what() << std::endl; + LOG(ERROR) << "Exception received: " << ex.what() << '\n'; } outputTensors.emplace_back(work->result()); } @@ -206,14 +200,14 @@ std::vector> waitWork( } std::vector> waitFuture( - std::vector> works) { + const std::vector>& works) { std::vector> outputTensors; for (auto& work : works) { auto fut = work->getFuture(); try { fut->wait(); } catch (const std::exception& ex) { - LOG(ERROR) << "Exception received: " << ex.what() << std::endl; + LOG(ERROR) << "Exception received: " << ex.what() << '\n'; } auto result = fut->value(); if (result.isNone()) { @@ -288,8 +282,7 @@ void testAllreduce( auto outputs = waitFuture(work); auto event_lists = disableProfilerLegacy(); - checkProfiledEvents( - std::move(event_lists), GLOO_ALLREDUCE_STR, size, allShapes); + checkProfiledEvents(event_lists, GLOO_ALLREDUCE_STR, size, allShapes); // Verify outputs const auto expected = (size * (size - 1)) / 2; @@ -334,8 +327,7 @@ void testAllreduceUsingWorkAPI( auto outputs = waitWork(work); auto event_lists = disableProfilerLegacy(); - checkProfiledEvents( - std::move(event_lists), GLOO_ALLREDUCE_STR, size, allShapes); + checkProfiledEvents(event_lists, GLOO_ALLREDUCE_STR, size, allShapes); // Verify outputs const auto expected = (size * (size - 1)) / 2; @@ -371,7 +363,8 @@ void testBroadcast( at::OptionalDeviceGuard deviceGuard; for (const auto l : c10::irange(stride)) { if (b == at::DeviceType::CUDA) { - deviceGuard.reset_device(at::Device(at::kCUDA, l)); + deviceGuard.reset_device( + at::Device(at::kCUDA, static_cast(l))); } inputs[k][l] = at::ones(shapes, at::dtype(dtype).device(b)) * (k * stride + l); @@ -396,8 +389,7 @@ void testBroadcast( auto outputs = waitFuture(work); auto event_lists = disableProfilerLegacy(); - checkProfiledEvents( - std::move(event_lists), GLOO_BROADCAST_STR, size, allShapes); + checkProfiledEvents(event_lists, GLOO_BROADCAST_STR, size, allShapes); // Verify outputs const auto expected = (i * stride + j); @@ -427,8 +419,9 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { {30, 31, 32, 33, 34, 35, 36}, }; for (const auto rank : c10::irange(size)) { - const std::vector& blob = blobs[rank]; - inputs[rank] = at::from_blob((int32_t*)(blob.data()), blob.size()).to(b); + std::vector& blob = blobs[rank]; + inputs[rank] = + at::from_blob(blob.data(), static_cast(blob.size())).to(b); } // Allocate outputs @@ -478,7 +471,7 @@ void testAlltoall(const std::string& path, const at::DeviceType b) { } auto event_lists = disableProfilerLegacy(); - checkProfiledEvents(std::move(event_lists), GLOO_A2A_STR, size, allShapes); + checkProfiledEvents(event_lists, GLOO_A2A_STR, size, allShapes); // Verify outputs std::vector> expected = { {0, 1, 10, 11, 12, 20, 21, 30, 31}, @@ -516,7 +509,7 @@ void testBarrier(const std::string& path) { std::vector> allShapes; // Barrier does not use tensors, so skip shape checking. checkProfiledEvents( - std::move(event_lists), + event_lists, GLOO_STR, size, allShapes, @@ -533,7 +526,7 @@ void testMonitoredBarrier(const std::string& path) { std::vector threads; threads.reserve(size); for (const auto r : c10::irange(size)) { - threads.emplace_back(std::thread([=]() { runMonitoredBarrier(r); })); + threads.emplace_back([=]() { runMonitoredBarrier(r); }); } for (auto& t : threads) { t.join(); @@ -555,8 +548,7 @@ void testMonitoredBarrier(const std::string& path) { }; threads.clear(); for (const auto r : c10::irange(size)) { - threads.emplace_back( - std::thread([=]() { runMonitoredBarrierWithException(r); })); + threads.emplace_back([=]() { runMonitoredBarrierWithException(r); }); } for (auto& t : threads) { t.join(); @@ -613,14 +605,14 @@ void testSend(const std::string& path) { enableProfilerLegacy(ProfilerConfig( ProfilerState::CPU, /* report_input_shapes */ true, false)); auto sendWork = pg.send(tensors, dstRank, tag); - bool sendCompleted; + bool sendCompleted = false; std::thread waitSendThreadAbort([&]() { sendCompleted = sendWork->wait(); }); sendWork->abort(); // Block until the sendWork gets successfully aborted waitSendThreadAbort.join(); EXPECT_FALSE(sendCompleted); auto event_lists = disableProfilerLegacy(); - checkProfiledEvents(std::move(event_lists), GLOO_SEND_STR, 1, allShapes); + checkProfiledEvents(event_lists, GLOO_SEND_STR, 1, allShapes); // Now create a separate sender thread to ensure that future waitsends can // complete successfully. @@ -663,14 +655,14 @@ void testRecv(const std::string& path) { enableProfilerLegacy(ProfilerConfig( ProfilerState::CPU, /* report_input_shapes */ true, false)); auto recvWork = pg.recv(tensors, srcRank, tag); - bool recvCompleted; + bool recvCompleted = false; std::thread waitRecvThreadAbort([&]() { recvCompleted = recvWork->wait(); }); recvWork->abort(); // Block until the first recv gets successfully aborted waitRecvThreadAbort.join(); EXPECT_FALSE(recvCompleted); auto event_lists = disableProfilerLegacy(); - checkProfiledEvents(std::move(event_lists), GLOO_RECV_STR, 1, allShapes); + checkProfiledEvents(event_lists, GLOO_RECV_STR, 1, allShapes); // Now create a separate receiver thread to ensure that future waits can // complete successfully. diff --git a/test/cpp/c10d/ProcessGroupMPITest.cpp b/test/cpp/c10d/ProcessGroupMPITest.cpp index d9fcacc83d2fe..1112ab723bd54 100644 --- a/test/cpp/c10d/ProcessGroupMPITest.cpp +++ b/test/cpp/c10d/ProcessGroupMPITest.cpp @@ -5,23 +5,21 @@ #include #include -#include #include -#include #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) // Wait for work to complete std::vector> waitWork( - c10::intrusive_ptr<::c10d::ProcessGroupMPI> pg, - std::vector> works) { + const c10::intrusive_ptr<::c10d::ProcessGroupMPI>& pg, + const std::vector>& works) { std::vector> outputTensors; for (auto& work : works) { try { work->wait(); } catch (const std::exception& ex) { - std::cerr << "Exception received: " << ex.what() << std::endl; + std::cerr << "Exception received: " << ex.what() << '\n'; pg->abort(); } outputTensors.emplace_back(work->result()); @@ -31,15 +29,15 @@ std::vector> waitWork( // Wait using Futures std::vector> waitFuture( - c10::intrusive_ptr<::c10d::ProcessGroupMPI> pg, - std::vector> works) { + const c10::intrusive_ptr<::c10d::ProcessGroupMPI>& pg, + const std::vector>& works) { std::vector> outputTensors; for (auto& work : works) { auto fut = work->getFuture(); try { fut->wait(); } catch (const std::exception& ex) { - std::cerr << "Exception received: " << ex.what() << std::endl; + std::cerr << "Exception received: " << ex.what() << '\n'; pg->abort(); } auto result = fut->value(); @@ -78,7 +76,7 @@ void testAllreduce(int iter = 1000) { const auto expected = worldSize * i; auto data = outputTensors[i][0].data_ptr(); for (auto j = 0; j < outputTensors[i][0].numel(); ++j) { - if (data[j] != expected) { + if (data[j] != static_cast(expected)) { TORCH_CHECK(false, "BOOM!"); } } @@ -110,7 +108,7 @@ void testBroadcast(int iter = 10000) { const auto expected = i; auto data = outputTensors[i][0].data_ptr(); for (auto j = 0; j < outputTensors[i][0].numel(); ++j) { - if (data[j] != expected) { + if (data[j] != static_cast(expected)) { TORCH_CHECK(false, "BOOM!"); } } @@ -140,7 +138,7 @@ void testReduce(int iter = 10000) { const auto expected = worldSize * i; auto data = outputTensors[i][0].data_ptr(); for (auto j = 0; j < outputTensors[i][0].numel(); ++j) { - if (data[j] != expected) { + if (data[j] != static_cast(expected)) { TORCH_CHECK(false, "BOOM!"); } } @@ -179,7 +177,7 @@ void testAllgather(int iter = 10000) { const auto expected = i * j; auto data = outputTensors[i][j].data_ptr(); for (auto k = 0; k < outputTensors[i][j].numel(); ++k) { - if (data[k] != expected) { + if (data[k] != static_cast(expected)) { TORCH_CHECK(false, "BOOM!"); } } @@ -222,7 +220,7 @@ void testGather(int iter = 10000) { const auto expected = i * j; auto data = outputTensors[i][j].data_ptr(); for (auto k = 0; k < outputTensors[i][j].numel(); ++k) { - if (data[k] != expected) { + if (data[k] != static_cast(expected)) { TORCH_CHECK(false, "BOOM!"); } } @@ -230,7 +228,7 @@ void testGather(int iter = 10000) { } } else { for (const auto i : c10::irange(iter)) { - if (outputTensors[i].size() != 0) { + if (!outputTensors[i].empty()) { TORCH_CHECK(false, "BOOM!"); } } @@ -271,7 +269,7 @@ void testScatter(int iter = 1) { const auto expected = i * j; auto data = outputTensors[i][0].data_ptr(); for (auto k = 0; k < outputTensors[i][0].numel(); ++k) { - if (data[k] != expected) { + if (data[k] != static_cast(expected)) { TORCH_CHECK(false, "BOOM!"); } } @@ -331,7 +329,7 @@ void testSendRecv(bool recvAnysource, int iter = 10000) { const auto expected = i; auto data = outputTensors[i][0].data_ptr(); for (auto j = 0; j < outputTensors[i][0].numel(); ++j) { - if (data[j] != expected) { + if (data[j] != static_cast(expected)) { TORCH_CHECK(false, "BOOM!"); } } @@ -349,7 +347,7 @@ int main(int argc, char** argv) { #ifdef MPIEXEC // If we are within an openmpi mpirun, then skip the exec if (!std::getenv("OMPI_COMM_WORLD_SIZE")) { - std::cout << "Execute mpiexec from: " << STR(MPIEXEC) << std::endl; + std::cout << "Execute mpiexec from: " << STR(MPIEXEC) << '\n'; execl(STR(MPIEXEC), "-np 2", argv[0], (char*)nullptr); } @@ -363,7 +361,7 @@ int main(int argc, char** argv) { testSendRecv(true); testBackendName(); - std::cout << "Test successful" << std::endl; + std::cout << "Test successful" << '\n'; #else std::cout << "MPI executable not found, skipping test" << std::endl; #endif diff --git a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp index d416847f7911a..847c96f07bd0b 100644 --- a/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLErrorsTest.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "CUDATest.hpp" #include "TestUtils.hpp" @@ -24,8 +25,9 @@ class WorkNCCLSimulateErrors : public c10d::ProcessGroupNCCL::WorkNCCL { bool simulate_error, int rank, c10d::OpType opType, - uint64_t seq) - : WorkNCCL("0", "default_pg", device, rank, opType, seq), + uint64_t seq, + bool isP2P) + : WorkNCCL("0", "default_pg", device, rank, opType, seq, isP2P), simulateError_(simulate_error) {} std::exception_ptr checkForNCCLErrors() override { @@ -46,7 +48,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { int rank, int size, c10::intrusive_ptr opts) - : ProcessGroupNCCL(store, rank, size, opts), simulateError_(false) {} + : ProcessGroupNCCL(store, rank, size, std::move(opts)) {} std::exception_ptr checkForNCCLErrors( std::shared_ptr& ncclComm) override { @@ -65,12 +67,18 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { at::Device& device, int rank, c10d::OpType opType, + bool isP2P, const char* profilingTitle, const std::vector& inputs = {}, const std::vector& outputs = {}, bool record = false) override { return c10::make_intrusive( - device, simulateError_, rank, opType, seqCollective_); + device, + simulateError_, + rank, + opType, + isP2P ? seqP2P_ : seqCollective_, + isP2P); } size_t getNCCLCommCacheSize() { @@ -86,7 +94,7 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL { } private: - bool simulateError_; + bool simulateError_{false}; }; class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { @@ -96,8 +104,9 @@ class WorkNCCLTimedoutErrors : public c10d::ProcessGroupNCCL::WorkNCCL { bool set_timedout_error, int rank, c10d::OpType opType, - uint64_t seq) - : WorkNCCL("0", "default_pg", device, rank, opType, seq), + uint64_t seq, + bool isP2P) + : WorkNCCL("0", "default_pg", device, rank, opType, seq, isP2P), setTimedoutError_(set_timedout_error) {} private: @@ -119,20 +128,24 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { int rank, int size, c10::intrusive_ptr opts) - : ProcessGroupNCCLSimulateErrors(store, rank, size, opts), - watchDogDebugInfoFinished_(false), - setTimedoutError_(false) {} + : ProcessGroupNCCLSimulateErrors(store, rank, size, std::move(opts)) {} c10::intrusive_ptr initWork( at::Device& device, int rank, c10d::OpType opType, + bool isP2P, const char* profilingTitle, const std::vector& inputs = {}, const std::vector& outputs = {}, bool record = false) override { return c10::make_intrusive( - device, setTimedoutError_, rank, opType, seqCollective_); + device, + setTimedoutError_, + rank, + opType, + isP2P ? seqP2P_ : seqCollective_, + isP2P); } void setTimedoutError() { @@ -163,10 +176,10 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors { watchDogDebugInfoFinished_ = true; return ""; } - bool watchDogDebugInfoFinished_; + bool watchDogDebugInfoFinished_{false}; private: - bool setTimedoutError_; + bool setTimedoutError_{false}; }; class ProcessGroupNCCLNoHeartbeatCaught @@ -177,8 +190,7 @@ class ProcessGroupNCCLNoHeartbeatCaught int rank, int size, c10::intrusive_ptr opts) - : ProcessGroupNCCLTimedOutErrors(store, rank, size, opts), - hasMonitorThreadCaughtError_(false) {} + : ProcessGroupNCCLTimedOutErrors(store, rank, size, std::move(opts)) {} std::mutex& getWatchdogMutex() { return workMetaListMutex_; @@ -209,11 +221,11 @@ class ProcessGroupNCCLNoHeartbeatCaught // It's really hard to unit test std::abort. So we override it instead. // Commented this override, we do see process aborted with core dump without // this override. - void terminateProcess(std::string errMsg) override { + void terminateProcess(const std::string& errMsg) override { throw std::runtime_error(errMsg); } - bool hasMonitorThreadCaughtError_; + bool hasMonitorThreadCaughtError_{false}; }; class ProcessGroupNCCLDebugInfoStuck @@ -224,7 +236,7 @@ class ProcessGroupNCCLDebugInfoStuck int rank, int size, c10::intrusive_ptr opts) - : ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, opts) {} + : ProcessGroupNCCLNoHeartbeatCaught(store, rank, size, std::move(opts)) {} protected: // Override the heartbeat monitor function to set a long timeout to mimic the @@ -292,13 +304,8 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsBlocking) { // Now run all reduce with errors. pg.simulateError(); work = pg.allreduce(tensors_); - EXPECT_THROW(work->wait(), std::runtime_error); - // Verify the work item failed. - EXPECT_TRUE(work->isCompleted()); EXPECT_THROW(work->wait(), std::runtime_error); - - // Communicators might be aborted here, further operations would fail. } TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { @@ -320,6 +327,10 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLTimedoutErrorsBlocking) { } TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { + // Avoid watchdog thread to throw the exception first to test the barrier + // throw behavior. + ASSERT_TRUE( + setenv(c10d::TORCH_NCCL_ASYNC_ERROR_HANDLING[0].c_str(), "0", 1) == 0); auto options = c10d::ProcessGroupNCCL::Options::create(); options->timeout = std::chrono::milliseconds(3000); ProcessGroupNCCLSimulateErrors pg(store_, 0, 1, options); @@ -332,12 +343,10 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNonBlocking) { pg.simulateError(); work = pg.allreduce(tensors_); - // Should not throw exceptions. work->wait(); - pg.barrier()->wait(); - - EXPECT_TRUE(work->isCompleted()); - // Communicators might be aborted here, further operations would fail. + // a NCCL ERROR happened before should stop the thread from passing the + // barrier. + EXPECT_THROW(pg.barrier()->wait(), std::runtime_error); } // Function to read what we wrote to the local disk for validation. @@ -346,7 +355,7 @@ std::string readTraceFromFile(const std::string& filename, size_t size) { // Read the strings from the file if (file) { // While the file stream is in good state std::string str(size, '\0'); - file.read(&str[0], size); + file.read(&str[0], static_cast(size)); if (file) { return str; } @@ -357,7 +366,7 @@ std::string readTraceFromFile(const std::string& filename, size_t size) { // Extend the nested class outside the parent class class TestDebugInfoWriter : public c10d::DebugInfoWriter { public: - TestDebugInfoWriter(std::string namePrefix) + TestDebugInfoWriter(const std::string& namePrefix) : DebugInfoWriter(namePrefix, 0) {} void write(const std::string& ncclTrace) override { @@ -376,7 +385,7 @@ class TestDebugInfoWriter : public c10d::DebugInfoWriter { TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { int heartBeatIntervalInSec = 2; std::string timeInterval = std::to_string(heartBeatIntervalInSec); - ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "1", 1) == 0); + ASSERT_TRUE(setenv(c10d::TORCH_NCCL_BLOCKING_WAIT[0].c_str(), "0", 1) == 0); ASSERT_TRUE( setenv( c10d::TORCH_NCCL_HEARTBEAT_TIMEOUT_SEC[0].c_str(), @@ -422,7 +431,7 @@ TEST_F(ProcessGroupNCCLErrorsTest, testNCCLErrorsNoHeartbeat) { EXPECT_TRUE(pg.getErrorCaughtFlag()); } work->wait(); - EXPECT_TRUE(traces.size() > 0); + EXPECT_TRUE(!traces.empty()); auto filename = c10::str(tempFilename, 0); auto traceFromStorage = readTraceFromFile(filename, traces.size()); // Check the traces read from storage match with the original nccl trace. diff --git a/test/cpp/c10d/ProcessGroupNCCLTest.cpp b/test/cpp/c10d/ProcessGroupNCCLTest.cpp index d7436248f100c..769bbaeca385d 100644 --- a/test/cpp/c10d/ProcessGroupNCCLTest.cpp +++ b/test/cpp/c10d/ProcessGroupNCCLTest.cpp @@ -21,15 +21,12 @@ using at::cuda::CUDAStream; class NCCLTestBase { public: NCCLTestBase( - const std::string& path, + std::string path, const std::chrono::milliseconds pgTimeout = c10d::kProcessGroupNCCLDefaultTimeout) - : path_(path), pgTimeout_(pgTimeout) {} + : path_(std::move(path)), pgTimeout_(pgTimeout) {} - NCCLTestBase(NCCLTestBase&& other) { - path_ = std::move(other.path_); - pg_ = std::move(other.pg_); - } + NCCLTestBase(NCCLTestBase&& other) noexcept = default; std::shared_ptr<::c10d::ProcessGroupNCCL> getProcessGroup() { return pg_; @@ -41,7 +38,7 @@ class NCCLTestBase { void initialize( int rank, - int size, + size_t size, std::optional<::std::shared_ptr<::c10d::ProcessGroupNCCL>> split_from = std::nullopt) { store_ = c10::make_intrusive<::c10d::FileStore>(path_, size); @@ -55,8 +52,8 @@ class NCCLTestBase { opts->split_color = ++color_; } #endif - pg_ = std::unique_ptr<::c10d::ProcessGroupNCCL>( - new ::c10d::ProcessGroupNCCL(store_, rank, size, std::move(opts))); + pg_ = std::make_unique<::c10d::ProcessGroupNCCL>( + store_, rank, size, std::move(opts)); } protected: @@ -76,22 +73,19 @@ class NCCLTest : public NCCLTestBase { std::chrono::milliseconds pgTimeout = c10d::kProcessGroupNCCLDefaultTimeout, int inputDim = 3) - : NCCLTestBase(path, pgTimeout), - numDevices_(1), // one device per rank (thread) - rank_(rank), - worldSize_(worldSize) { + : NCCLTestBase(path, pgTimeout), rank_(rank), worldSize_(worldSize) { // Each device has a single tensor to perf the NCCL op - ::at::globalContext().lazyInitCUDA(); + ::at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); tensors_.resize(numDevices_); inputs_.resize(numDevices_); outputs_.resize(numDevices_); at::cuda::OptionalCUDAGuard deviceGuard; assert(numDevices_ == 1); for (const auto i : c10::irange(numDevices_)) { - deviceGuard.set_index(rank_); + deviceGuard.set_index(static_cast(rank_)); tensors_[i] = at::empty({inputDim, inputDim}, at::kCUDA); - inputs_[i].resize(worldSize_ * numDevices_); - outputs_[i].resize(worldSize_ * numDevices_); + inputs_[i].resize(static_cast(worldSize_) * numDevices_); + outputs_[i].resize(static_cast(worldSize_) * numDevices_); for (auto j = 0; j < worldSize_ * numDevices_; ++j) { inputs_[i][j] = at::empty({inputDim, inputDim}, at::kCUDA); outputs_[i][j] = at::empty({inputDim, inputDim}, at::kCUDA); @@ -106,7 +100,7 @@ class NCCLTest : public NCCLTestBase { // getters to retrieve the current stream). // // 1 device only, hence 1 stream only - deviceGuard.set_index(rank_); + deviceGuard.set_index(static_cast(rank_)); streams_.push_back(at::cuda::getStreamFromPool()); } @@ -148,7 +142,8 @@ class NCCLTest : public NCCLTestBase { std::vector>& tensor_lists) { std::vector> outputs(numDevices_); for (auto& output : outputs) { - output = std::vector(worldSize_ * numDevices_); + output = std::vector( + static_cast(worldSize_ * numDevices_)); } // For the duration of this function, make THC use our streams @@ -169,8 +164,8 @@ class NCCLTest : public NCCLTestBase { void launchDeviceSleep() { at::cuda::OptionalCUDAGuard deviceGuard; for (const auto i : c10::irange(numDevices_)) { - deviceGuard.set_index(rank_); - cudaSleep(streams_[i], 2000 * 1000 * 1000); + deviceGuard.set_index(static_cast(rank_)); + cudaSleep(streams_[i], 2000ull * 1000 * 1000); } } @@ -178,7 +173,7 @@ class NCCLTest : public NCCLTestBase { void valueInitialization() { at::cuda::OptionalCUDAGuard deviceGuard; for (const auto i : c10::irange(numDevices_)) { - deviceGuard.set_index(rank_); + deviceGuard.set_index(static_cast(rank_)); tensors_[i].fill_(pg_->getRank() * numDevices_ + i); } } @@ -199,14 +194,15 @@ class NCCLTest : public NCCLTestBase { void valueInitializationForSparse() { at::cuda::OptionalCUDAGuard deviceGuard; for (const auto i : c10::irange(numDevices_)) { - deviceGuard.set_index(rank_); + deviceGuard.set_index(static_cast(rank_)); tensors_[i].fill_(pg_->getRank() * numDevices_ + i + 1); // Convert the dense tensor to a sparse tensor in COO row format tensors_[i] = to_sparse_row_indices_format(tensors_[i]); } } - const int numDevices_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const int numDevices_{1}; // one device per rank (thread) int rank_; int worldSize_; std::vector tensors_; @@ -374,7 +370,7 @@ class ReduceScatterBaseNCCLTest : public NCCLTest { ReduceScatterBaseNCCLTest(const std::string& path, int rank, int worldSize) : NCCLTest(path, rank, worldSize) { at::cuda::OptionalCUDAGuard deviceGuard; - deviceGuard.set_index(rank_); + deviceGuard.set_index(static_cast(rank_)); output_tensor_ = at::empty({1}, at::kCUDA); input_tensor_ = at::empty({worldSize}, at::kCUDA); for (const auto i : c10::irange(worldSize)) { @@ -755,7 +751,7 @@ class ProcessGroupNCCLTest : public ::testing::Test { std::vector threads; threads.reserve(size_); for (const auto rank : c10::irange(size_)) { - threads.emplace_back(std::thread(testFunc, file.path, rank, size_)); + threads.emplace_back(testFunc, file.path, rank, size_); } for (const auto rank : c10::irange(size_)) { threads[rank].join(); @@ -765,6 +761,33 @@ class ProcessGroupNCCLTest : public ::testing::Test { int size_{1}; }; +TEST_F(ProcessGroupNCCLTest, CUDAEventCache) { + if (skipTest()) { + return; + } + + // Test that the CUDAEventCache can be used to create CUDA events and reuse. + auto event1 = c10d::ProcessGroupNCCL::CUDAEventCache::get().create(true); + auto event2 = c10d::ProcessGroupNCCL::CUDAEventCache::get().create(false); + + auto event1_ptr = event1.get(); + auto event2_ptr = event2.get(); + // Mimic the behavior of the destroy of events. + event1 = nullptr; + event2 = nullptr; + + // Test that the CUDAEventCache is indeed reused. + auto event3 = c10d::ProcessGroupNCCL::CUDAEventCache::get().create(true); + auto event4 = c10d::ProcessGroupNCCL::CUDAEventCache::get().create(false); + // The cache has been used up, new events should be created. + auto event5 = c10d::ProcessGroupNCCL::CUDAEventCache::get().create(true); + auto event6 = c10d::ProcessGroupNCCL::CUDAEventCache::get().create(false); + EXPECT_EQ(event1_ptr, event3.get()); + EXPECT_EQ(event2_ptr, event4.get()); + EXPECT_NE(event1_ptr, event5.get()); + EXPECT_NE(event2_ptr, event6.get()); +} + TEST_F(ProcessGroupNCCLTest, testAllreduce) { if (skipTest()) { return; @@ -827,7 +850,7 @@ TEST_F(ProcessGroupNCCLTest, testBackendName) { } TemporaryFile file; auto test = NCCLTestBase(file.path); - test.initialize(/*rank=*/0, /*world_size=*/1); + test.initialize(/*rank=*/0, /*size=*/1); EXPECT_EQ( test.getProcessGroup()->getBackendName(), std::string(c10d::NCCL_BACKEND_NAME)); diff --git a/test/cpp/c10d/ProcessGroupUCCTest.cpp b/test/cpp/c10d/ProcessGroupUCCTest.cpp index a31e990536e10..84affb59cc2da 100644 --- a/test/cpp/c10d/ProcessGroupUCCTest.cpp +++ b/test/cpp/c10d/ProcessGroupUCCTest.cpp @@ -1,11 +1,9 @@ +#ifdef USE_C10D_UCC #include #include #include #include - -using namespace c10d; - TEST(ProcessGroupUCCTest, testTrim) { std::vector> tests = { {" allreduce ", "allreduce"}, @@ -13,7 +11,7 @@ TEST(ProcessGroupUCCTest, testTrim) { {"send\n", "send"}, }; for (auto entry : tests) { - ASSERT_EQ(trim(entry.first), entry.second); + ASSERT_EQ(c10d::trim(entry.first), entry.second); } } @@ -24,12 +22,13 @@ TEST(ProcessGroupUCCTest, testToLower) { {"send", "send"}, }; for (auto entry : tests) { - ASSERT_EQ(tolower(entry.first), entry.second); + ASSERT_EQ(c10d::tolower(entry.first), entry.second); } } TEST(ProcessGroupUCCTest, testParseList) { std::string input = "\tAllReduce, ALLGATHER, send\n"; std::vector expect{"allreduce", "allgather", "send"}; - ASSERT_EQ(parse_list(input), expect); + ASSERT_EQ(c10d::parse_list(input), expect); } +#endif diff --git a/test/cpp/c10d/TCPStoreTest.cpp b/test/cpp/c10d/TCPStoreTest.cpp index 7351984f36c99..48504a2d0d973 100644 --- a/test/cpp/c10d/TCPStoreTest.cpp +++ b/test/cpp/c10d/TCPStoreTest.cpp @@ -2,10 +2,7 @@ #include "StoreTestCommon.hpp" #include -#include -#include #include -#include #include #include @@ -104,33 +101,32 @@ void testHelper(bool useLibUV, const std::string& prefix = "") { std::to_string(numThreads * numIterations + 1); for (const auto i : c10::irange(numThreads)) { - threads.emplace_back( - std::thread([=, &sem1, &sem2, &clientStores, &expectedCounterRes] { - for (C10_UNUSED const auto j : c10::irange(numIterations)) { - clientStores[i]->add("counter", 1); - } - // Let each thread set and get key on its client store - std::string key = "thread_" + std::to_string(i); - for (const auto j : c10::irange(numIterations)) { - std::string val = "thread_val_" + std::to_string(j); - c10d::test::set(*clientStores[i], key, val); - c10d::test::check(*clientStores[i], key, val); - } - - sem1.post(); - sem2.wait(); - // Check the counter results - c10d::test::check(*clientStores[i], "counter", expectedCounterRes); - // Now check other threads' written data - for (const auto j : c10::irange(numThreads)) { - if (j == i) { - continue; - } - std::string key = "thread_" + std::to_string(i); - std::string val = "thread_val_" + std::to_string(numIterations - 1); - c10d::test::check(*clientStores[i], key, val); - } - })); + threads.emplace_back([=, &sem1, &sem2, &clientStores, &expectedCounterRes] { + for ([[maybe_unused]] const auto j : c10::irange(numIterations)) { + clientStores[i]->add("counter", 1); + } + // Let each thread set and get key on its client store + std::string key = "thread_" + std::to_string(i); + for (const auto j : c10::irange(numIterations)) { + std::string val = "thread_val_" + std::to_string(j); + c10d::test::set(*clientStores[i], key, val); + c10d::test::check(*clientStores[i], key, val); + } + + sem1.post(); + sem2.wait(); + // Check the counter results + c10d::test::check(*clientStores[i], "counter", expectedCounterRes); + // Now check other threads' written data + for (const auto j : c10::irange(numThreads)) { + if (j == i) { + continue; + } + std::string key = "thread_" + std::to_string(i); + std::string val = "thread_val_" + std::to_string(numIterations - 1); + c10d::test::check(*clientStores[i], key, val); + } + }); } sem1.wait(numThreads); diff --git a/test/cpp/dist_autograd/CMakeLists.txt b/test/cpp/dist_autograd/CMakeLists.txt index 0ae6e3bef1410..6b5bba4b82086 100644 --- a/test/cpp/dist_autograd/CMakeLists.txt +++ b/test/cpp/dist_autograd/CMakeLists.txt @@ -14,6 +14,7 @@ if(USE_DISTRIBUTED AND NOT WIN32) endif() if(INSTALL_TEST) + set_target_properties(test_dist_autograd PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_dist_autograd DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt index f0510d9c81f20..cd2eaf761dffd 100644 --- a/test/cpp/jit/CMakeLists.txt +++ b/test/cpp/jit/CMakeLists.txt @@ -143,14 +143,15 @@ if(USE_CUDA) endif() elseif(USE_ROCM) target_link_libraries(test_jit PRIVATE - ${ROCM_HIPRTC_LIB} - ${PYTORCH_HIP_LIBRARIES} + hiprtc::hiprtc + hip::amdhip64 ${TORCH_CUDA_LIBRARIES}) target_compile_definitions(test_jit PRIVATE USE_ROCM) endif() if(INSTALL_TEST) + set_target_properties(test_jit PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_jit DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/jit/test_alias_analysis.cpp b/test/cpp/jit/test_alias_analysis.cpp index 5a094462fca3f..3c89f8104a106 100644 --- a/test/cpp/jit/test_alias_analysis.cpp +++ b/test/cpp/jit/test_alias_analysis.cpp @@ -298,7 +298,8 @@ inline void expectThrows(Functor&& functor, const char* expectMessageContains) { } catch (const Exception& e) { if (std::string(e.what()).find(expectMessageContains) == std::string::npos) { - AT_ERROR( + TORCH_CHECK( + false, "Expected error message to contain \"", expectMessageContains, "\" but error message was: ", @@ -306,7 +307,8 @@ inline void expectThrows(Functor&& functor, const char* expectMessageContains) { } return; } - AT_ERROR( + TORCH_CHECK( + false, "Expected to throw exception containing \"", expectMessageContains, "\" but didn't throw"); diff --git a/test/cpp/jit/test_custom_class_registrations.cpp b/test/cpp/jit/test_custom_class_registrations.cpp index c3448a46cdf0a..d1e0d5fa2180b 100644 --- a/test/cpp/jit/test_custom_class_registrations.cpp +++ b/test/cpp/jit/test_custom_class_registrations.cpp @@ -145,17 +145,39 @@ struct TensorQueue : torch::CustomClassHolder { } } - c10::Dict serialize() const { - c10::Dict dict; - dict.insert(std::string("init_tensor"), init_tensor_); - const std::string key = "queue"; - dict.insert( - key + "/size", torch::tensor(static_cast(queue_.size()))); - for (const auto index : c10::irange(queue_.size())) { - dict.insert(key + "/" + std::to_string(index), queue_[index]); + std::tuple< + std::tuple, + std::tuple>> + serialize() { + return std::tuple( + std::tuple("init_tensor", this->init_tensor_.clone()), + std::tuple("queue", this->clone_queue())); + } + + static c10::intrusive_ptr deserialize( + std::tuple< + std::tuple, + std::tuple>> flattened) { + TORCH_CHECK(std::tuple_size::value == 2); + + auto init_tensor_tuple = std::get<0>(flattened); + TORCH_CHECK(std::tuple_size::value == 2); + TORCH_CHECK(std::get<0>(init_tensor_tuple) == std::string("init_tensor")); + + c10::intrusive_ptr queue = + c10::make_intrusive(std::get<1>(init_tensor_tuple)); + + auto queue_tuple = std::get<1>(flattened); + TORCH_CHECK(std::tuple_size::value == 2); + TORCH_CHECK(std::get<0>(queue_tuple) == std::string("queue")); + + for (auto& value : std::get<1>(queue_tuple)) { + queue->push(value); } - return dict; + + return queue; } + // Push the element to the rear of queue. // Lock is added for thread safe. void push(at::Tensor x) { @@ -639,13 +661,17 @@ TORCH_LIBRARY(_TorchScriptTesting, m) { .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) - -> c10::Dict { + -> std::tuple< + std::tuple, + std::tuple>> { return self->serialize(); }, // __setstate__ - [](c10::Dict data) + [](std::tuple< + std::tuple, + std::tuple>> data) -> c10::intrusive_ptr { - return c10::make_intrusive(std::move(data)); + return TensorQueue::deserialize(data); }); } diff --git a/test/cpp/lazy/CMakeLists.txt b/test/cpp/lazy/CMakeLists.txt index be37b47ac9b92..9542343ff7816 100644 --- a/test/cpp/lazy/CMakeLists.txt +++ b/test/cpp/lazy/CMakeLists.txt @@ -36,14 +36,15 @@ if(USE_CUDA) target_compile_definitions(test_lazy PRIVATE USE_CUDA) elseif(USE_ROCM) target_link_libraries(test_lazy PRIVATE - ${ROCM_HIPRTC_LIB} - ${PYTORCH_HIP_LIBRARIES} + hiprtc::hiprtc + hip::amdhip64 ${TORCH_CUDA_LIBRARIES}) target_compile_definitions(test_lazy PRIVATE USE_ROCM) endif() if(INSTALL_TEST) + set_target_properties(test_lazy PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_lazy DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/profiler/record_function.cpp b/test/cpp/profiler/record_function.cpp index 2e652ed038b64..0e3ef95f0c614 100644 --- a/test/cpp/profiler/record_function.cpp +++ b/test/cpp/profiler/record_function.cpp @@ -303,3 +303,31 @@ TEST(RecordFunctionTest, MultipleCallbacks) { at::clearCallbacks(); ASSERT_FALSE(at::hasCallbacks()); } + +// Test that KwargsOnly callbacks are run in USER_SCOPE. +TEST(RecordFunctionTest, KwargsOnly) { + at::clearCallbacks(); + ASSERT_FALSE(at::hasCallbacks()); + static const std::unordered_map myMap = { + {"a", 1}, {"b", 2.5}}; + +#define REGISTER_CALLBACK() \ + at::addThreadLocalCallback( \ + at::RecordFunctionCallback( \ + [](const at::RecordFunction& fn) \ + -> std::unique_ptr { \ + EXPECT_EQ(myMap, fn.kwinputs()); \ + return nullptr; \ + }, \ + [](const at::RecordFunction& fn, at::ObserverContext*) {}) \ + .needsInputs(true) \ + .scopes({at::RecordScope::USER_SCOPE})) + + REGISTER_CALLBACK(); +#undef REGISTER_CALLBACK + + RECORD_USER_SCOPE_WITH_KWARGS_ONLY("Test", &myMap); + + at::clearCallbacks(); + ASSERT_FALSE(at::hasCallbacks()); +} diff --git a/test/cpp/rpc/CMakeLists.txt b/test/cpp/rpc/CMakeLists.txt index 6834b428ff937..5c3a0dc020de9 100644 --- a/test/cpp/rpc/CMakeLists.txt +++ b/test/cpp/rpc/CMakeLists.txt @@ -37,6 +37,7 @@ if(USE_CUDA) endif() if(INSTALL_TEST) + set_target_properties(test_cpp_rpc PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_cpp_rpc DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/tensorexpr/CMakeLists.txt b/test/cpp/tensorexpr/CMakeLists.txt index 179270c4a4a15..9c409e078d9dd 100644 --- a/test/cpp/tensorexpr/CMakeLists.txt +++ b/test/cpp/tensorexpr/CMakeLists.txt @@ -58,20 +58,22 @@ if(USE_CUDA) target_compile_definitions(tutorial_tensorexpr PRIVATE USE_CUDA) elseif(USE_ROCM) target_link_libraries(test_tensorexpr PRIVATE - ${ROCM_HIPRTC_LIB} - ${PYTORCH_HIP_LIBRARIES} + hiprtc::hiprtc + hip::amdhip64 ${TORCH_CUDA_LIBRARIES}) target_compile_definitions(test_tensorexpr PRIVATE USE_ROCM) target_link_libraries(tutorial_tensorexpr PRIVATE - ${ROCM_HIPRTC_LIB} - ${PYTORCH_HIP_LIBRARIES} + hiprtc::hiprtc + hip::amdhip64 ${TORCH_CUDA_LIBRARIES}) target_compile_definitions(tutorial_tensorexpr PRIVATE USE_ROCM) endif() if(INSTALL_TEST) + set_target_properties(test_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_tensorexpr DESTINATION bin) + set_target_properties(tutorial_tensorexpr PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS tutorial_tensorexpr DESTINATION bin) # Install PDB files for MSVC builds if(MSVC AND BUILD_SHARED_LIBS) diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp index d65b5c544f6c2..ddb63431fe3f6 100644 --- a/test/cpp/tensorexpr/test_reductions.cpp +++ b/test/cpp/tensorexpr/test_reductions.cpp @@ -1043,8 +1043,7 @@ TEST(Reductions, ReduceSplitRfactor) { SimpleIREvaluator cg(s, {b, c}); cg.call({in, out}); - for (const auto i : c10::irange(M)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(M)) { ASSERT_EQ(out[0], 4950); } } diff --git a/test/cpp/tensorexpr/test_simplify.cpp b/test/cpp/tensorexpr/test_simplify.cpp index 1971304e8e5c4..99a00d0d62c11 100644 --- a/test/cpp/tensorexpr/test_simplify.cpp +++ b/test/cpp/tensorexpr/test_simplify.cpp @@ -3884,8 +3884,7 @@ TEST(Simplify, SimplifyEliminateEmptyFor) { { // Flatten many layers around an empty block to an empty block. StmtPtr last = alloc(std::vector({})); - for (const auto i : c10::irange(11)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(11)) { VarHandle loopVar("loopVar", kInt); last = For::make(loopVar, 0, 10, last); } @@ -3969,8 +3968,7 @@ TEST(Simplify, SimplifyFlattenBlock) { { // Flatten many layers around an empty block to an empty block. StmtPtr last = alloc(std::vector({})); - for (const auto i : c10::irange(11)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(11)) { last = alloc(std::vector({last})); } diff --git a/test/cpp_extensions/mtia_extension.cpp b/test/cpp_extensions/mtia_extension.cpp index fdbfcaa26a27e..257ecf9cc91f8 100644 --- a/test/cpp_extensions/mtia_extension.cpp +++ b/test/cpp_extensions/mtia_extension.cpp @@ -139,7 +139,7 @@ struct MTIAGuardImpl final : public c10::impl::DeviceGuardImplInterface { struct MTIAHooks : public at::MTIAHooksInterface { explicit MTIAHooks(at::MTIAHooksArgs) {} - void initMTIA() const override {} + void init() const override {} bool hasMTIA() const override { return true; diff --git a/test/cpp_extensions/open_registration_extension/README.md b/test/cpp_extensions/open_registration_extension/README.md index 07f1f98d915a7..18d98971eda85 100644 --- a/test/cpp_extensions/open_registration_extension/README.md +++ b/test/cpp_extensions/open_registration_extension/README.md @@ -23,7 +23,6 @@ The main next step would be to: - Split the daemon into a proper user-process driver vs device-process executor. The main goal would be to better mimick which information is held on the user-process side and when we're actually communicating with the device. In particular current device or stream should be user-process informations. - Add Stream/Event system. Most likely by having multiple requests queue that go to the device from the driver. - Add RNG Generator. -- Add Pinned memory and HostAllocator. Longer term: - Replace the current `open_registration_extension.cpp` test in PyTorch CI with this. diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py index c150cdd6b1d6e..588ae26348e45 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -29,6 +29,9 @@ def _(*args, **kwargs): _register_same_name("exchangeDevice") _register_same_name("malloc", True) _register_same_name("free", True) +_register_same_name("isPinnedPtr", True) +_register_same_name("hostMalloc", True) +_register_same_name("hostFree", True) # TODO: replace it with implementing torch.openreg.device diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py index 54d7619ba6bc8..e910603cceb19 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_device_daemon.py @@ -32,6 +32,9 @@ def free(self, ptr): del self.allocated[ptr] return True + def is_allocated(self, ptr): + return ptr in self.allocated + def tensor_from_meta(self, meta): # Usual case, we're receiving a known Tensor found_base = self.allocated.get(meta.data_ptr, None) @@ -92,6 +95,7 @@ def _lazy_init(self): self.num_devices = 2 # Allocated memory belongs to which device self.memory_belong = {} + self.host_allocator = Allocator() self.devices = [] for i in range(self.num_devices): @@ -164,6 +168,18 @@ def free(self, ptr): return False return self.run_on_executor(device_idx, "free", ptr) + @register(registry) + def isPinnedPtr(self, ptr): + return self.host_allocator.is_allocated(ptr) + + @register(registry) + def hostMalloc(self, size): + return self.host_allocator.malloc(size) + + @register(registry) + def hostFree(self, ptr): + return self.host_allocator.free(ptr) + class _Executor: def __init__(self, id): diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp index 97c8d5f5d56ef..d1e111b9c2d89 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/csrc/OpenRegHooks.cpp @@ -1,7 +1,7 @@ #include -#include #include +#include #include @@ -11,18 +11,70 @@ namespace { // Python dictionary where real implementations can be found PyObject* py_registry; +using host_ptr_t = uint64_t; + +struct HostAllocator final : at::Allocator { + HostAllocator() = default; + + at::DataPtr allocate(size_t nbytes) override { + py::gil_scoped_acquire acquire; + void* data = nullptr; + if (nbytes > 0) { + data = reinterpret_cast(get_method("hostMalloc")(nbytes).cast()); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); + } + return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; + } + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + py::gil_scoped_acquire acquire; + TORCH_CHECK( + get_method("hostFree")(reinterpret_cast(ptr)).cast(), + "Failed to free memory pointer at ", ptr); + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + py::gil_scoped_acquire acquire; + get_method("hostCopyData")(reinterpret_cast(dest), reinterpret_cast(src), count); + } +}; +static HostAllocator global_host_alloc; + + // C++ hooks implementation struct OpenRegHooksArgs : public at::PrivateUse1HooksArgs {}; struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { - OpenRegHooksInterface(OpenRegHooksArgs) {}; - ~OpenRegHooksInterface() override = default; + OpenRegHooksInterface(OpenRegHooksArgs) {}; + ~OpenRegHooksInterface() override = default; + + bool hasPrimaryContext(c10::DeviceIndex device_index) const override { + return get_method("hasPrimaryContext")(device_index).cast(); + } - bool hasPrimaryContext(c10::DeviceIndex device_index) const override { - return get_method("hasPrimaryContext")(device_index).cast(); - } + at::Allocator* getPinnedMemoryAllocator() const override { + return &global_host_alloc; + } + + bool isPinnedPtr(const void* data) const override { + py::gil_scoped_acquire acquire; + return get_method("isPinnedPtr")(reinterpret_cast(data)).cast(); + } }; +int register_hook() { + at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface(OpenRegHooksArgs{})); + return 0; +} +int temp_register_hook = register_hook(); + TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs); C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, OpenRegHooksInterface, OpenRegHooksArgs); // Using Create function to get PrivateUse1HooksInterface point from PrivateUse1HooksRegistry class. @@ -237,14 +289,14 @@ C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); // Setter for the python dictionary with implementations void set_impl_registry(PyObject* registry) { - py_registry = registry; + py_registry = registry; } py::function get_method(const char* name) { - auto dict = py::cast(py_registry); + auto dict = py::cast(py_registry); TORCH_CHECK(dict.contains(name), "OpenReg registry does not contain ", "an implementation for '", name, "' make sure to add it in the __init__.py " - "file and register it.") - return dict[name]; + "file and register it.") + return dict[name]; } } // openreg \ No newline at end of file diff --git a/test/cpp_extensions/open_registration_extension/test/test_openreg.py b/test/cpp_extensions/open_registration_extension/test/test_openreg.py index 3c4f8928ddbf5..9dc5ecbd19abe 100644 --- a/test/cpp_extensions/open_registration_extension/test/test_openreg.py +++ b/test/cpp_extensions/open_registration_extension/test/test_openreg.py @@ -67,6 +67,14 @@ def test_data_dependent_output(self): self.assertEqual(out, cpu_a.masked_select(cpu_a.gt(0))) + def test_pin_memory(self): + cpu_a = torch.randn(10) + self.assertFalse(cpu_a.is_pinned()) + pinned_a = cpu_a.pin_memory() + self.assertTrue(pinned_a.is_pinned()) + slice_a = pinned_a[2:5] + self.assertTrue(slice_a.is_pinned()) + if __name__ == "__main__": run_tests() diff --git a/test/custom_operator/op.cpp b/test/custom_operator/op.cpp index ab0506a822f61..c074b818c185a 100644 --- a/test/custom_operator/op.cpp +++ b/test/custom_operator/op.cpp @@ -12,8 +12,7 @@ torch::List custom_op( int64_t repeat) { torch::List output; output.reserve(repeat); - for (const auto i : c10::irange(repeat)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(repeat)) { output.push_back(tensor * scalar); } return output; diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index 2f17d8275a270..6303db7c03416 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -109,6 +109,7 @@ def _init_fsdp_param_group( mesh_info, post_forward_mesh_info, self.device, + None, # shard_placement_fn MixedPrecisionPolicy(), OffloadPolicy(), ) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index d96a8c2b0f753..4d02a06af6900 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -23,7 +23,11 @@ from torch.distributed._composable.fsdp._fsdp_param_group import FSDPParamGroup from torch.distributed._tensor import init_device_mesh from torch.testing import FileCheck -from torch.testing._internal.common_distributed import at_least_x_gpu, skip_if_lt_x_gpu +from torch.testing._internal.common_distributed import ( + at_least_x_gpu, + skip_if_lt_x_gpu, + sm_is_or_higher_than, +) from torch.testing._internal.common_fsdp import FSDPTest, MLP from torch.testing._internal.common_utils import run_tests, skipIfRocm from torch.testing._internal.distributed._tensor.common_dtensor import ( @@ -96,13 +100,21 @@ def patched_trace_rules_check(*args, **kwargs): self.assertTrue(trace_rules_check_count > 0) +@unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") class TestFullyShardCompile(FSDPTest): fake_pg = not at_least_x_gpu(2) - @property - def world_size(self) -> int: - return 2 + # This method is an override of the base class. + # Tests in this class requires bf16 support, so SM arch must be 80 or + # higher. + def skipTestForOldSm(self): + # Assumption: This test class is only run on GPU. See `HAS_GPU` check at + # the top of the class. + device = torch.device("cuda", self.rank % torch.cuda.device_count()) + if not sm_is_or_higher_than(device, 8, 0): + self.skipTest("bf16 requires sm >= 8.0") + @skipIfRocm def test_dynamo_trace_use_training_state(self): torch._dynamo.reset() # Construct a dummy FSDPParamGroup, since we just want to test the `use_training_state` ctx manager. @@ -112,6 +124,7 @@ def test_dynamo_trace_use_training_state(self): None, # mesh_info: FSDPMeshInfo, None, # post_forward_mesh_info: Optional[FSDPMeshInfo], torch.device("cuda"), # device: torch.device, + None, # shard_placement_fn: Optional[Callable], None, # mp_policy: MixedPrecisionPolicy, None, # offload_policy: OffloadPolicy, ) @@ -139,6 +152,7 @@ def f(x): self.assertEqual(cnt.op_count, 1) self.assertEqual(len(cnt.graphs), 1) + @skipIfRocm def test_trace_fsdp_copy_(self): @torch.library.custom_op("mylib::add_one_out", mutates_args={"out"}) def add_one_out(x: torch.Tensor, out: torch.Tensor) -> None: @@ -242,7 +256,7 @@ def _check_count(copy_count, resize_count): f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950 ) - if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region(): _check_count(fwd_copy_count, fwd_resize_count) # fwd graph else: _check_count(bwd_copy_count, bwd_resize_count) # bwd graph @@ -400,6 +414,29 @@ def inductor_code_check_fsdp_reduce_scatter( file_check = file_check.check("torch.ops._c10d_functional.wait_tensor.") return file_check + @skipIfRocm + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + def test_compiled_autograd_ctx(self): + self.skipTestForOldSm() + with torch._dynamo.config.patch( + skip_fsdp_hooks=False, + ), torch._functorch.config.patch( + recompute_views=True, + ): + inputs = torch.randn(8, 8) + model = torch.nn.Linear(8, 8) + fully_shard(model) + model_compiled = torch.compile(model, backend="inductor") + for i in range(10): + torch.compiler.set_stance( + "force_eager" if i < 1 else "default" + ) # eager warmup for 1 iteration + with torch._dynamo.compiled_autograd.enable( + torch.compile(backend="inductor", fullgraph=True) + ): + out = model_compiled(inputs) + out.sum().backward() + def _test_traceable_fsdp( self, model_init_fn, @@ -422,6 +459,8 @@ def run_iters( torch.manual_seed(42) losses = [] for i in range(n_iter): + # eager warmup for 1 iteration, so that all FSDP2 lazy-initialization is done in eager + torch.compiler.set_stance("force_eager" if i < 1 else "default") inp = input_creation_fn() loss = fwd_bwd_func(inp) losses.append(loss.item()) @@ -432,8 +471,6 @@ def run_iters( def test_compiled(): model, optim = model_init_fn() fwd_bwd_fn = functools.partial(fwd_bwd, model) - # FSDP2 does lazy init using 1st run, so run it once to init using eager mode - run_iters(fwd_bwd_fn, optim, n_iter=1) counters.clear() with self._remove_fsdp2_unsharded_param_graph_input_usage_with_optional_checks( @@ -462,8 +499,6 @@ def test_compiled(): def test_eager(): model, optim = model_init_fn() fwd_bwd_fn = functools.partial(fwd_bwd, model) - # FSDP2 does lazy init using 1st run, so run it once to init using eager mode - run_iters(fwd_bwd_fn, optim, n_iter=1) res = run_iters(fwd_bwd_fn, optim) return res @@ -546,6 +581,7 @@ def test_simple_mlp_fullgraph_backend_aot_eager_decomp_partition(self): @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_simple_mlp_fullgraph_backend_inductor(self): + self.skipTestForOldSm() self._test_traceable_fsdp( *self._create_simple_mlp_factory_fns(), "inductor", fwd_fullgraph=True ) @@ -614,7 +650,8 @@ def input_creation_fn(): @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_aot_eager(self): - for fwd_fullgraph in [True, False]: + # TODO: fix fwd_fullgraph=False case + for fwd_fullgraph in [True]: self._test_traceable_fsdp( *self._create_nested_fully_shard_factory_fns( fwd_fullgraph=fwd_fullgraph @@ -626,7 +663,8 @@ def test_nested_fully_shard_backend_aot_eager(self): @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): - for fwd_fullgraph in [True, False]: + # TODO: fix fwd_fullgraph=False case + for fwd_fullgraph in [True]: self._test_traceable_fsdp( *self._create_nested_fully_shard_factory_fns( fwd_fullgraph=fwd_fullgraph @@ -638,6 +676,7 @@ def test_nested_fully_shard_backend_aot_eager_decomp_partition(self): @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_inductor_fullgraph_True(self): + self.skipTestForOldSm() for fwd_fullgraph in [True]: with self._reinplace_all_gather_with_optional_checks( fwd_fullgraph @@ -731,9 +770,11 @@ def test_nested_fully_shard_backend_inductor_fullgraph_True(self): ) file_check.run(bwd_code) + @unittest.skip("TODO: fix fwd_fullgraph=False case") @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_nested_fully_shard_backend_inductor_fullgraph_False(self): + self.skipTestForOldSm() _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( *self._create_nested_fully_shard_factory_fns(fwd_fullgraph=False), @@ -811,8 +852,9 @@ def _sdpa_with_graph_break(*args, **kwargs): @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_transformer_backend_aot_eager(self): + # TODO: fix fwd_fullgraph=False case for fwd_fullgraph, all_requires_grad in itertools.product( - [True, False], [True, False] + [True], [True, False] ): with self._maybe_add_graph_break_to_sdpa( fwd_fullgraph @@ -830,8 +872,9 @@ def test_transformer_backend_aot_eager(self): # TODO: native_dropout has worse accuracy after decomp, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_aot_eager_decomp_partition(self): + # TODO: fix fwd_fullgraph=False case for fwd_fullgraph, all_requires_grad in itertools.product( - [True, False], [True, False] + [True], [True, False] ): with self._maybe_add_graph_break_to_sdpa(fwd_fullgraph): self._test_traceable_fsdp( @@ -847,6 +890,7 @@ def test_transformer_backend_aot_eager_decomp_partition(self): # TODO: native_dropout causes CUDA IMA error, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_inductor_fullgraph_True(self): + self.skipTestForOldSm() for ( fwd_fullgraph, all_requires_grad, @@ -947,11 +991,13 @@ def test_transformer_backend_inductor_fullgraph_True(self): ) file_check.run(bwd_code) + @unittest.skip("TODO: fix fwd_fullgraph=False case") @skipIfRocm @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") # TODO: native_dropout causes CUDA IMA error, need to figure out why @torch._inductor.config.patch(fallback_random=True) def test_transformer_backend_inductor_fullgraph_False(self): + self.skipTestForOldSm() fwd_fullgraph = False # TODO: fix numerical issue in activation_checkpoint=True case for all_requires_grad, activation_checkpoint in itertools.product( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 3b3912aad40a3..33bc1de851b98 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -3,7 +3,7 @@ import copy import itertools import unittest -from typing import List +from typing import List, Optional import torch import torch.distributed as dist @@ -1013,5 +1013,148 @@ def test_hsdp_broadcast_across_replicas(self): model(inp).sum().backward() +class TestFullyShardShardPlacementFn(FSDPTestMultiThread): + @property + def world_size(self) -> int: + return 8 + + def _init_models(self): + torch.manual_seed(42) + model_args = ModelArgs(n_layers=3, dropout_p=0.0) + model = Transformer(model_args) + for param in model.parameters(): + dist.broadcast(param.detach(), src=0) + ref_model = copy.deepcopy(model) + return model, ref_model + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_init_1d_transformer_shard_largest_dim(self): + model, ref_model = self._init_models() + + def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + largest_dim = largest_dim_size = -1 + for dim, dim_size in enumerate(param.shape): + if dim_size > largest_dim_size: + largest_dim = dim + largest_dim_size = dim_size + assert largest_dim >= 0, f"{param.shape}" + return Shard(largest_dim) + + for layer in model.layers: + fully_shard(layer, shard_placement_fn=shard_placement_fn) + fully_shard(model, shard_placement_fn=shard_placement_fn) + + any_shard_dim1 = False + for param in model.parameters(): + self.assertEqual(len(param.placements), 1) + self.assertIsInstance(param.placements[0], Shard) + any_shard_dim1 |= param.placements[0].dim == 1 + self.assertTrue(any_shard_dim1) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_init_1d_transformer_shard_dim_neg1(self): + model, ref_model = self._init_models() + + def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + # Check that FSDP will normalize this dim to non-negative + return Shard(-1) + + for layer in model.layers: + fully_shard(layer, shard_placement_fn=shard_placement_fn) + fully_shard(model, shard_placement_fn=shard_placement_fn) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_init_2d_transformer_shard_diff_dim(self): + model, ref_model = self._init_models() + + dp_size, tp_size = self.world_size // 2, 2 + global_mesh = init_device_mesh( + "cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp") + ) + model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True) + + def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + if isinstance(param, DTensor): + for placement in param.placements: + if isinstance(placement, Shard): + shard_dim = param.ndim - 1 - placement.dim + assert shard_dim >= 0, f"{param.shape}" + return Shard(shard_dim) + return Shard(0) + + for layer in model.layers: + fully_shard( + layer, mesh=global_mesh["dp"], shard_placement_fn=shard_placement_fn + ) + fully_shard( + model, mesh=global_mesh["dp"], shard_placement_fn=shard_placement_fn + ) + + linear_weight_names = ["wq", "wk", "wv", "wo", "w1", "w2"] + for param_name, param in model.named_parameters(): + if ( + any(n in param_name for n in linear_weight_names) + and "weight" in param_name + ): + total_placement_dims = 0 + for placement in param.placements: + self.assertTrue(isinstance(placement, Shard)) + total_placement_dims += placement.dim + self.assertEqual(param.ndim, 2) + # Check that FSDP shards on either dim-0 or dim-1, and TP + # shards on the other + self.assertEqual(total_placement_dims, 1) + else: + self.assertTrue( + any(isinstance(placement, Shard) for placement in param.placements) + ) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_init_1d_uneven_shard_largest_dim(self): + torch.manual_seed(42) + model = nn.Sequential(nn.Linear(16, 17), nn.Linear(17, 8)) + + def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + largest_dim = -1 + largest_dim_size = -1 + for dim, dim_size in enumerate(param.shape): + if dim_size > largest_dim_size: + largest_dim = dim + largest_dim_size = dim_size + assert largest_dim >= 0, f"{param.shape}" + assert largest_dim < param.ndim, f"{largest_dim=} {param.shape}" + return Shard(largest_dim) + + with self.assertRaisesRegex( + NotImplementedError, "FSDP does not support uneven sharding on dim 1" + ): + fully_shard(model, shard_placement_fn=shard_placement_fn) + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_invalid_shard_dim(self): + model = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 8)) + + def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + return Shard(1) + + # Shard(1) is invalid for 1D bias parameters + with self.assertRaisesRegex( + AssertionError, "Shard dim 1 is invalid for 1D tensor" + ): + fully_shard(model, shard_placement_fn=shard_placement_fn) + + if __name__ == "__main__": run_tests() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py index 19ba92724e964..e62f394a9e154 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py @@ -12,6 +12,7 @@ from torch.distributed._composable.fsdp._fsdp_collectives import ( _get_gradient_divide_factors, ) +from torch.distributed.tensor import Shard from torch.testing._internal.common_distributed import ( requires_nccl_version, SaveForwardInputsModel, @@ -38,18 +39,32 @@ def _init_models_and_optims( reshard_after_forward: Union[bool, int], param_dtype: Optional[torch.dtype], reduce_dtype: Optional[torch.dtype], + use_shard_placement_fn, ): torch.manual_seed(42) model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)]) ref_model = copy.deepcopy(model).cuda() ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) + + def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + largest_dim = -1 + largest_dim_size = -1 + for dim, dim_size in enumerate(param.shape): + if dim_size > largest_dim_size: + largest_dim = dim + largest_dim_size = dim_size + assert largest_dim >= 0, f"{param.shape}" + return Shard(largest_dim) + mp_policy = MixedPrecisionPolicy( param_dtype=param_dtype, reduce_dtype=reduce_dtype ) + shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None fully_shard_fn = functools.partial( fully_shard, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy, + shard_placement_fn=shard_placement_fn, ) for mlp in model: fully_shard_fn(mlp) @@ -57,22 +72,41 @@ def _init_models_and_optims( optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) return ref_model, ref_optim, model, optim + def _get_use_shard_placement_fn_vals_for_bf16_reduce(self): + use_shard_placement_fn_vals = [False] + if self.world_size == 2: + # For world size >2, gradient elements get reduced in different + # orders for the baseline vs. dim-1 sharding, leading to numeric + # differences for bf16 reduction, so only test world size 2. + use_shard_placement_fn_vals.append(True) + return use_shard_placement_fn_vals + @skip_if_lt_x_gpu(2) @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives") def test_compute_dtype(self): + use_shard_placement_fn_vals = ( + self._get_use_shard_placement_fn_vals_for_bf16_reduce() + ) self.run_subtests( { "param_dtype": [torch.bfloat16, torch.float16], "reshard_after_forward": [False, True, 2], + "use_shard_placement_fn": use_shard_placement_fn_vals, }, self._test_compute_dtype, ) def _test_compute_dtype( - self, param_dtype: torch.dtype, reshard_after_forward: Union[bool, int] + self, + param_dtype: torch.dtype, + reshard_after_forward: Union[bool, int], + use_shard_placement_fn: bool, ): ref_model, ref_optim, model, optim = self._init_models_and_optims( - reshard_after_forward, param_dtype=param_dtype, reduce_dtype=None + reshard_after_forward, + param_dtype=param_dtype, + reduce_dtype=None, + use_shard_placement_fn=use_shard_placement_fn, ) ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype) orig_reduce_scatter = dist.reduce_scatter_tensor @@ -130,18 +164,38 @@ def assert_fn(output: torch.Tensor): @requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives") def test_reduce_dtype(self): self.run_subtests( - {"reshard_after_forward": [False, True, 2]}, + { + "reshard_after_forward": [False, True, 2], + "use_shard_placement_fn": [False, True], + }, self._test_reduce_dtype_fp32_reduce, ) + use_shard_placement_fn_vals = ( + self._get_use_shard_placement_fn_vals_for_bf16_reduce() + ) self.run_subtests( - {"reshard_after_forward": [False, True, 2]}, + { + "reshard_after_forward": [False, True, 2], + "use_shard_placement_fn": use_shard_placement_fn_vals, + }, self._test_reduce_dtype_bf16_reduce, ) - def _test_reduce_dtype_fp32_reduce(self, reshard_after_forward: Union[bool, int]): + def _test_reduce_dtype_fp32_reduce( + self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool + ): + if ( + self.world_size > 2 + and isinstance(reshard_after_forward, int) + and use_shard_placement_fn + ): + return param_dtype, reduce_dtype = torch.bfloat16, torch.float32 ref_model, ref_optim, model, optim = self._init_models_and_optims( - reshard_after_forward, param_dtype=param_dtype, reduce_dtype=reduce_dtype + reshard_after_forward, + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + use_shard_placement_fn=use_shard_placement_fn, ) ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype) orig_reduce_scatter = dist.reduce_scatter_tensor @@ -182,10 +236,15 @@ def assert_fn(output: torch.Tensor): self.assertEqual(fsdp_loss, ref_loss) check_sharded_parity(self, ref_model, model) - def _test_reduce_dtype_bf16_reduce(self, reshard_after_forward: Union[bool, int]): + def _test_reduce_dtype_bf16_reduce( + self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool + ): param_dtype, reduce_dtype = torch.float32, torch.bfloat16 ref_model, ref_optim, model, optim = self._init_models_and_optims( - reshard_after_forward, param_dtype=param_dtype, reduce_dtype=reduce_dtype + reshard_after_forward, + param_dtype=param_dtype, + reduce_dtype=reduce_dtype, + use_shard_placement_fn=use_shard_placement_fn, ) group = dist.distributed_c10d._get_default_group() orig_reduce_scatter = dist.reduce_scatter_tensor diff --git a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py index 0bada875b51cb..8526d950d4e6b 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_state_dict.py @@ -4,13 +4,13 @@ import functools import unittest from contextlib import nullcontext -from typing import Dict +from typing import Dict, Optional import torch import torch.nn as nn from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard -from torch.distributed._tensor import distribute_tensor, DTensor from torch.distributed.device_mesh import DeviceMesh, init_device_mesh +from torch.distributed.tensor import distribute_tensor, DTensor, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, @@ -39,6 +39,10 @@ def test_dp_state_dict_save_load(self): {"mlp_dim": [2, 3, 4, 5], "mesh": [fsdp_mesh]}, self._test_dp_state_dict_save_load, ) + self.run_subtests( + {"mlp_dim": [16], "mesh": [fsdp_mesh], "use_shard_placement_fn": [True]}, + self._test_dp_state_dict_save_load, + ) if self.world_size % 2 != 0: return hsdp_mesh = init_device_mesh( @@ -50,26 +54,46 @@ def test_dp_state_dict_save_load(self): {"mlp_dim": [2, 3, 4, 5], "mesh": [hsdp_mesh]}, self._test_dp_state_dict_save_load, ) + self.run_subtests( + {"mlp_dim": [16], "mesh": [hsdp_mesh], "use_shard_placement_fn": [True]}, + self._test_dp_state_dict_save_load, + ) - def _test_dp_state_dict_save_load(self, mlp_dim: int, mesh: DeviceMesh): + def _test_dp_state_dict_save_load( + self, mlp_dim: int, mesh: DeviceMesh, use_shard_placement_fn: bool = False + ): torch.manual_seed(42) base_model = nn.Sequential( MLP(mlp_dim), nn.Sequential(MLP(mlp_dim), nn.Linear(mlp_dim, mlp_dim)), MLP(mlp_dim), ) + + def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + largest_dim = largest_dim_size = -1 + for dim, dim_size in enumerate(param.shape): + if dim_size > largest_dim_size: + largest_dim = dim + largest_dim_size = dim_size + return Shard(largest_dim) + + shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None + fully_shard_fn = functools.partial( + fully_shard, mesh=mesh, shard_placement_fn=shard_placement_fn + ) + # Check basic `reshard_after_forward=True` model1 = copy.deepcopy(base_model) for module in model1: - fully_shard(module, mesh=mesh) - fully_shard(model1, mesh=mesh) + fully_shard_fn(module) + fully_shard_fn(model1) self._test_state_dict_save_load(model1) # Check `reshard_after_forward=False` before and after a forward model2 = copy.deepcopy(base_model) for module in model2: - fully_shard(module, mesh=mesh, reshard_after_forward=False) - fully_shard(model2, mesh=mesh, reshard_after_forward=False) + fully_shard_fn(module, reshard_after_forward=False) + fully_shard_fn(model2, reshard_after_forward=False) self._test_state_dict_save_load(model2) ref_sharded_sd = model2.state_dict() inp = torch.randn((2, mlp_dim), device="cuda") diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index ab52cb925709a..3cf4e122915d7 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -6,7 +6,7 @@ import itertools import unittest from collections import defaultdict -from typing import Iterable, List, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union import torch import torch.distributed as dist @@ -19,12 +19,12 @@ OffloadPolicy, register_fsdp_forward_method, ) -from torch.distributed._tensor import DTensor, init_device_mesh from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( _CHECKPOINT_PREFIX, apply_activation_checkpointing, ) from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, init_device_mesh, Shard from torch.distributed.tensor.debug import CommDebugMode from torch.testing._internal.common_cuda import TEST_CUDA from torch.testing._internal.common_distributed import skip_if_lt_x_gpu @@ -216,6 +216,8 @@ def test_to_float64_after_init(self): model.to(dtype) for param in model.parameters(): self.assertEqual(param.dtype, dtype) + self.assertEqual(param.to_local().dtype, dtype) + self.assertEqual(param._spec.tensor_meta.dtype, dtype) optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) check_sharded_parity(self, ref_model, model) torch.manual_seed(42 + self.rank + 1) @@ -227,6 +229,13 @@ def test_to_float64_after_init(self): losses[-1].backward() self.assertEqual(losses[0], losses[1]) check_sharded_parity(self, ref_model, model) + for param in model.parameters(): + self.assertEqual(param.dtype, dtype) + self.assertEqual(param.to_local().dtype, dtype) + self.assertEqual(param._spec.tensor_meta.dtype, dtype) + self.assertEqual(param.grad.dtype, dtype) + self.assertEqual(param.grad.to_local().dtype, dtype) + self.assertEqual(param.grad._spec.tensor_meta.dtype, dtype) for _optim in (ref_optim, optim): _optim.step() _optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) @@ -238,16 +247,41 @@ def world_size(self) -> int: return min(8, torch.cuda.device_count()) @skip_if_lt_x_gpu(2) - def test_train_parity_single_group(self): - """Tests train parity with DDP for a single FSDP group.""" + def test_train_parity_single_group_shard_dim0(self): + """ + Tests train parity with DDP for a single FSDP group when sharding + parameters on dim-0. + """ + self.run_subtests( + { + "lin_shapes": [ + [(16, 15), (15, 8)], + [(7, 15), (15, 3)], + [(16, 17), (17, 8)], + ], + "use_shard_placement_fn": [False], + }, + self._test_train_parity_single_group, + ) + + @skip_if_lt_x_gpu(2) + def test_train_parity_single_group_shard_largest_dim(self): + """ + Tests train parity with DDP for a single FSDP group when sharding + parameters on their largest dim. + """ self.run_subtests( { - "lin_shapes": [[(16, 15), (15, 8)], [(7, 15), (15, 3)]], + # Sharding on nonzero dim requires even sharding + "lin_shapes": [[(32, 16), (16, 8)]], + "use_shard_placement_fn": [True], }, self._test_train_parity_single_group, ) - def _test_train_parity_single_group(self, lin_shapes: List[Tuple[int, int]]): + def _test_train_parity_single_group( + self, lin_shapes: List[Tuple[int, int]], use_shard_placement_fn: bool + ): torch.manual_seed(42) model = nn.Sequential( nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1]) @@ -255,7 +289,20 @@ def _test_train_parity_single_group(self, lin_shapes: List[Tuple[int, int]]): ref_model = copy.deepcopy(model).cuda() replicate(ref_model, device_ids=[self.rank]) ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2) - fully_shard(model) + + def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + largest_dim = -1 + largest_dim_size = -1 + for dim, dim_size in enumerate(param.shape): + if dim_size > largest_dim_size: + largest_dim = dim + largest_dim_size = dim_size + assert largest_dim >= 0, f"{param.shape}" + assert largest_dim < param.ndim, f"{largest_dim=} {param.shape}" + return Shard(largest_dim) + + shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None + fully_shard(model, shard_placement_fn=shard_placement_fn) optim = torch.optim.Adam(model.parameters(), lr=1e-2) torch.manual_seed(42 + self.rank + 1) inp = (torch.randn((4, lin_shapes[0][0]), device="cuda"),) @@ -665,6 +712,8 @@ def _test_train_parity_with_activation_checkpointing( fully_shard(model.layers[0], **fsdp_kwargs) fully_shard([model.layers[1], model.layers[2]], **fsdp_kwargs) fully_shard([model.tok_embeddings, model.pos_embeddings], **fsdp_kwargs) + # Embedding weights are not needed for embedding backward + model.tok_embeddings.set_unshard_in_backward(False) fully_shard([model.norm, model.output], **fsdp_kwargs) elif module_grouping == "mem_eff_weight_tied": fully_shard([model.tok_embeddings, model.output], **fsdp_kwargs) @@ -705,6 +754,100 @@ def _test_train_parity_with_activation_checkpointing( ) +class TestFullyShardShardPlacementFnMultiProcess(FSDPTest): + @property + def world_size(self) -> int: + return min(8, torch.cuda.device_count()) + + @skip_if_lt_x_gpu(2) + def test_train_parity_shard_placement_fn_shard_largest_dim(self): + torch.manual_seed(42) + model_args = ModelArgs(n_layers=3, dropout_p=0.0) + model = Transformer(model_args) + ref_model = copy.deepcopy(model).cuda() + ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2) + + def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + largest_dim = -1 + largest_dim_size = -1 + for dim, dim_size in enumerate(param.shape): + if dim_size > largest_dim_size: + largest_dim = dim + largest_dim_size = dim_size + return Shard(largest_dim) + + for layer in model.layers: + fully_shard(layer, shard_placement_fn=shard_placement_fn) + fully_shard(model, shard_placement_fn=shard_placement_fn) + optim = torch.optim.AdamW(model.parameters(), lr=1e-2) + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + + torch.manual_seed(42 + self.rank) + inp = torch.randint(0, model_args.vocab_size, (2, 16), device="cuda") + for iter_idx in range(5): + ref_loss = ref_model(inp).sum() + loss = model(inp).sum() + self.assertEqual(ref_loss, loss) + + ref_loss.backward() + loss.backward() + for param in ref_model.parameters(): + if param.grad is not None: + dist.all_reduce(param.grad, op=dist.ReduceOp.AVG) + + ref_optim.step() + optim.step() + ref_optim.zero_grad() + optim.zero_grad() + + for param, ref_param in zip(model.parameters(), ref_model.parameters()): + full_param = param.full_tensor() + self.assertEqual(full_param, ref_param) + + +class TestFullyShardShardPlacementFnMultiThread(FSDPTestMultiThread): + @property + def world_size(self) -> int: + return 4 + + @unittest.skipIf(not TEST_CUDA, "no cuda") + def test_shard_placement_fn_contiguous_params_grads(self): + dim = 4 + model = MLP(dim=dim) + + def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + if param.ndim > 1: + return Shard(1) + return Shard(0) + + fully_shard(model.in_proj, shard_placement_fn=shard_placement_fn) + fully_shard(model.out_proj, shard_placement_fn=shard_placement_fn) + fully_shard(model, shard_placement_fn=shard_placement_fn) + + def assert_contiguous_params(module: nn.Module, args: Any): + for param in module.parameters(): + self.assertTrue(param.is_contiguous()) + + model.in_proj.register_forward_pre_hook(assert_contiguous_params) + model.out_proj.register_forward_pre_hook(assert_contiguous_params) + + for param in model.parameters(): + self.assertTrue(param.is_contiguous()) + self.assertTrue(param.to_local().is_contiguous()) + + inp = torch.randn((2, dim), device="cuda") + model(inp).sum().backward() + + for param in model.parameters(): + self.assertTrue(param.is_contiguous()) + self.assertTrue(param.to_local().is_contiguous()) + self.assertTrue(param.grad.is_contiguous()) + self.assertTrue(param.grad.to_local().is_contiguous()) + + class TestFullyShardSharedParams(FSDPTest): @property def world_size(self) -> int: @@ -765,7 +908,13 @@ def test_gradient_accumulation(self): meshes = [init_device_mesh("cuda", (self.world_size,))] # always test FSDP if self.world_size == 4: # test HSDP too if enough GPUs shard_size, replicate_size = 2, 2 - meshes.append(init_device_mesh("cuda", (replicate_size, shard_size))) + meshes.append( + init_device_mesh( + "cuda", + (replicate_size, shard_size), + mesh_dim_names=("dp_replicate", "dp_shard"), + ) + ) self.run_subtests( { "mesh": meshes, @@ -1158,7 +1307,9 @@ def test_train_parity_hsdp(self): shard_size = 2 if self.world_size > 2 else 1 replicate_size = self.world_size // shard_size global_mesh = init_device_mesh( - "cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard") + "cuda", + (replicate_size, shard_size), + mesh_dim_names=("dp_replicate", "dp_shard"), ) self.run_subtests( { diff --git a/test/distributed/_composable/test_composability/test_2d_composability.py b/test/distributed/_composable/test_composability/test_2d_composability.py index 83b0f8f2b5ac6..c6865b0ceeed4 100644 --- a/test/distributed/_composable/test_composability/test_2d_composability.py +++ b/test/distributed/_composable/test_composability/test_2d_composability.py @@ -4,7 +4,7 @@ import functools import io from copy import deepcopy -from typing import List, Type +from typing import List, Optional, Type import torch import torch.distributed as dist @@ -174,6 +174,12 @@ def _test_train_parity_2d_mlp( @skip_if_lt_x_gpu(2) @skipIfRocm def test_train_parity_2d_transformer(self): + self.run_subtests( + {"use_shard_placement_fn": [False, True]}, + self._test_train_parity_2d_transformer, + ) + + def _test_train_parity_2d_transformer(self, use_shard_placement_fn: bool): torch.manual_seed(42) model_args = ModelArgs(n_layers=3, dropout_p=0.0) model = Transformer(model_args) @@ -186,9 +192,23 @@ def test_train_parity_2d_transformer(self): ) model = Transformer.parallelize(model, global_mesh["tp"], use_seq_parallel=True) + def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]: + if isinstance(param, DTensor): + for placement in param.placements: + if isinstance(placement, Shard): + shard_dim = param.ndim - 1 - placement.dim + assert shard_dim >= 0, f"{param.shape}" + return Shard(shard_dim) + return Shard(0) + + shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None for layer in model.layers: - fully_shard(layer, mesh=global_mesh["dp"]) - fully_shard(model, mesh=global_mesh["dp"]) + fully_shard( + layer, mesh=global_mesh["dp"], shard_placement_fn=shard_placement_fn + ) + fully_shard( + model, mesh=global_mesh["dp"], shard_placement_fn=shard_placement_fn + ) optim = torch.optim.AdamW(model.parameters(), lr=1e-2) for param, ref_param in zip(model.parameters(), ref_model.parameters()): diff --git a/test/distributed/_composable/test_composability/test_pp_composability.py b/test/distributed/_composable/test_composability/test_pp_composability.py index e173bb34e3eaa..93895e4d3ae50 100644 --- a/test/distributed/_composable/test_composability/test_pp_composability.py +++ b/test/distributed/_composable/test_composability/test_pp_composability.py @@ -14,7 +14,6 @@ from torch.distributed.pipelining.schedules import ( PipelineScheduleSingle, Schedule1F1B, - ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble, @@ -86,7 +85,6 @@ def device(self): Schedule1F1B, ScheduleInterleaved1F1B, ScheduleLoopedBFS, - ScheduleFlexibleInterleaved1F1B, ScheduleInterleavedZeroBubble, ], ) @@ -176,7 +174,6 @@ def build_stage(stage_idx, num_stages): num_stages, self.device, group=pp_group, - input_args=input_mb[0], ) return stage, offset @@ -212,7 +209,14 @@ def build_stage(stage_idx, num_stages): ) # Run - pipeline_schedule._step_microbatches(arg_mbs=input_mb, target_mbs=input_mb) + # TODO(whc) should we make it a hard error if you pass arguments into the step API on nonzero ranks? + # why are we passing inputs/targets on every rank? + if pp_group.rank() == 0: + pipeline_schedule._step_microbatches(arg_mbs=input_mb, target_mbs=input_mb) + else: + pipeline_schedule._step_microbatches( + arg_mbs=[[] for _ in input_mb], target_mbs=input_mb + ) # Ref model runs on 2 different inputs, accumulating grads across them. # this ensures that we detect if the FSDP reduce becomes a no-op. diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 03dc74b113ce4..0a072ec4ab3ff 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -30,6 +30,7 @@ MultiProcessTestCase, skip_if_lt_x_gpu, skip_if_rocm_multiprocess, + sm_is_or_higher_than, ) from torch.testing._internal.common_utils import run_tests, skipIfRocm from torch.testing._internal.distributed.fake_pg import FakeStore @@ -79,6 +80,8 @@ class MultiProcessInductorTestCase(MultiProcessTestCase, InductorTestCase): class ReplicateTest(MultiProcessInductorTestCase): + # TODO: consider using all devices? The min(2, ...) here would limit the + # test to always run on 2 GPUs only. @property def world_size(self) -> int: return min(2, torch.cuda.device_count()) @@ -219,14 +222,18 @@ def test_compile_cpu_no_sync(self): @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) - @torch._inductor.config.patch(reorder_for_locality=False) + @torch._inductor.config.patch( + reorder_for_locality=False, reorder_for_peak_memory=False + ) def test_compile_gpu(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=False) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) - @torch._inductor.config.patch(reorder_for_locality=False) + @torch._inductor.config.patch( + reorder_for_locality=False, reorder_for_peak_memory=False + ) def test_compile_gpu_ac(self): self._test_compile(use_gpu=True, no_sync=False, checkpoint=True) @@ -234,6 +241,11 @@ def test_compile_gpu_ac(self): @skip_if_rocm_multiprocess @skip_if_lt_x_gpu(2) def test_compile_bf16(self): + # Check device capability wrt bf16 + device = torch.device("cuda", self.rank % torch.cuda.device_count()) + if not sm_is_or_higher_than(device, 8, 0): + self.skipTest("bf16 requires sm >= 8.0") + def setup(model, compiled_replicate_model, compiled_ddp_model) -> None: model.register_comm_hook(None, ddp_default_hooks.bf16_compress_hook) compiled_m = compiled_replicate_model._orig_mod @@ -305,7 +317,9 @@ def bwd(loss): ) # todo: This pass mucks things up since Inductor thinks its inference # and can apply this. Should turn off these passes in compiled autograd - @torch._inductor.config.patch(reorder_for_locality=False) + @torch._inductor.config.patch( + reorder_for_locality=False, reorder_for_peak_memory=False + ) def test_bucketing_coalesced_op(self): # Gradient is None code = self._test_bucketing() @@ -341,7 +355,9 @@ def test_bucketing_coalesced_op(self): ) # todo: This pass mucks things up since Inductor thinks its inference # and can apply this. Should turn off these passes in compiled autograd - @torch._inductor.config.patch(reorder_for_locality=False) + @torch._inductor.config.patch( + reorder_for_locality=False, reorder_for_peak_memory=False + ) def test_bucketing_concat_op(self): # Gradient is None code = self._test_bucketing() @@ -370,6 +386,7 @@ def test_bucketing_concat_op(self): class DDP_TP_Test(InductorTestCase): def setUp(self): + # Hmm, why a specific set_device call for rank 0? self.rank = 0 self.world_size = 4 torch.cuda.set_device("cuda:0") @@ -385,6 +402,9 @@ def setUp(self): def tearDown(self): dist.destroy_process_group() + @unittest.skip( + "Temporarily disabled due to SymInt error: `unhashable type: non-nested SymInt`" + ) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skipIfRocm def test_ddp_tp(self): diff --git a/test/distributed/_tensor/test_attention.py b/test/distributed/_tensor/test_attention.py index 238e551ab6a4a..06cc7ca73aea9 100644 --- a/test/distributed/_tensor/test_attention.py +++ b/test/distributed/_tensor/test_attention.py @@ -12,6 +12,7 @@ _CausalBehavior, _cp_options, _is_causal_behavior, + _RotateMethod, context_parallel, context_parallel_unshard, ) @@ -66,13 +67,16 @@ def world_size(self) -> int: @parametrize("compiled", [True, False]) @parametrize("backend", backends) @parametrize("load_balance", [True, False]) + @parametrize("rotater", [_RotateMethod.ALL_TO_ALL, _RotateMethod.ALL_GATHER]) def test_ring_attention_sdpa( self, is_causal: bool, compiled: bool, backend: SDPBackend, load_balance: bool, + rotater: _RotateMethod, ) -> None: + _cp_options.rotate_method = rotater device_mesh = DeviceMesh(self.device_type, torch.arange(0, self.world_size)) dtype = torch.bfloat16 bs = 8 @@ -148,7 +152,7 @@ def test_ring_attention_sdpa( cp_out = fn(cp_q, cp_k, cp_v, is_causal=is_causal) cp_out.sum().backward() - if not compiled: + if not compiled and rotater == _RotateMethod.ALL_TO_ALL: # Compiler and CommDebugMode do not work well together. self.assertDictEqual( comm_mode.get_comm_counts(), @@ -225,8 +229,12 @@ def test_is_causal_behavior(self) -> None: @with_comms @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]) @parametrize("is_causal", [True, False]) - def test_ring_attention_native_transformer(self, is_causal: bool) -> None: + @parametrize("rotater", [_RotateMethod.ALL_GATHER, _RotateMethod.ALL_TO_ALL]) + def test_ring_attention_native_transformer( + self, is_causal: bool, rotater: _RotateMethod + ) -> None: _cp_options.enable_load_balance = is_causal + _cp_options.rotate_method = rotater device_mesh = DeviceMesh( self.device_type, torch.arange(0, self.world_size), @@ -265,22 +273,42 @@ def test_ring_attention_native_transformer(self, is_causal: bool) -> None: with CommDebugMode() as comm_mode: out = model(seq, mask=mask, is_causal=is_causal) - self.assertDictEqual( - comm_mode.get_comm_counts(), - { - c10d_functional.all_to_all_single: (self.world_size - 1) * num_layers, - }, - ) + + if rotater == _RotateMethod.ALL_TO_ALL: + self.assertDictEqual( + comm_mode.get_comm_counts(), + { + c10d_functional.all_to_all_single: (self.world_size - 1) + * num_layers, + }, + ) + else: + self.assertDictEqual( + comm_mode.get_comm_counts(), + { + c10d_functional.all_gather_into_tensor: num_layers, + }, + ) with CommDebugMode() as comm_mode: out.sum().backward() - self.assertDictEqual( - comm_mode.get_comm_counts(), - { - c10d_functional.all_to_all_single: (self.world_size * 2 - 1) - * num_layers, - }, - ) + + if rotater == _RotateMethod.ALL_TO_ALL: + self.assertDictEqual( + comm_mode.get_comm_counts(), + { + c10d_functional.all_to_all_single: (self.world_size * 2 - 1) + * num_layers, + }, + ) + else: + self.assertDictEqual( + comm_mode.get_comm_counts(), + { + c10d_functional.all_gather_into_tensor: num_layers, + c10d_functional.all_to_all_single: self.world_size * num_layers, + }, + ) @skip_if_lt_x_gpu(2) @unittest.skipIf( @@ -288,7 +316,9 @@ def test_ring_attention_native_transformer(self, is_causal: bool) -> None: ) @with_comms @sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]) - def test_ring_attention_custom_transformer(self) -> None: + @parametrize("rotater", [_RotateMethod.ALL_GATHER, _RotateMethod.ALL_TO_ALL]) + def test_ring_attention_custom_transformer(self, rotater: _RotateMethod) -> None: + _cp_options.rotate_method = rotater device_mesh = DeviceMesh( self.device_type, torch.arange(0, self.world_size), @@ -314,23 +344,40 @@ def test_ring_attention_custom_transformer(self) -> None: with CommDebugMode() as comm_mode: out = model(seq) - self.assertDictEqual( - comm_mode.get_comm_counts(), - { - c10d_functional.all_to_all_single: (self.world_size - 1) - * args.n_layers, - }, - ) + + if rotater == _RotateMethod.ALL_TO_ALL: + self.assertDictEqual( + comm_mode.get_comm_counts(), + { + c10d_functional.all_to_all_single: (self.world_size - 1) + * args.n_layers, + }, + ) + else: + self.assertDictEqual( + comm_mode.get_comm_counts(), + {c10d_functional.all_gather_into_tensor: args.n_layers}, + ) with CommDebugMode() as comm_mode: out.sum().backward() - self.assertDictEqual( - comm_mode.get_comm_counts(), - { - c10d_functional.all_to_all_single: (self.world_size * 2 - 1) - * args.n_layers, - }, - ) + + if rotater == _RotateMethod.ALL_TO_ALL: + self.assertDictEqual( + comm_mode.get_comm_counts(), + { + c10d_functional.all_to_all_single: (self.world_size * 2 - 1) + * args.n_layers, + }, + ) + else: + self.assertDictEqual( + comm_mode.get_comm_counts(), + { + c10d_functional.all_gather_into_tensor: args.n_layers, + c10d_functional.all_to_all_single: self.world_size * args.n_layers, + }, + ) if backends: diff --git a/test/distributed/_tensor/test_dtensor_compile.py b/test/distributed/_tensor/test_dtensor_compile.py index a32cb2a8ae56d..cc84ac196516c 100644 --- a/test/distributed/_tensor/test_dtensor_compile.py +++ b/test/distributed/_tensor/test_dtensor_compile.py @@ -622,6 +622,8 @@ def fn(x_dt): self.assertEqual(ref, res) def test_graph_input_is_async(self): + from torch.distributed._functional_collectives import AsyncCollectiveTensor + mesh = DeviceMesh(self.device_type, torch.arange(self.world_size)) def fn(x): @@ -633,6 +635,7 @@ def fn(x): x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True) x2 = x2.to_local() + self.assertTrue(isinstance(x2, AsyncCollectiveTensor)) out = opt_fn(x2) # The important part: we get a wait_tensor() in the graph. # At runtime, the input to the graph is an AsyncCollectiveTensor, diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py index 454f7aaddc770..d6f9781e3bc40 100644 --- a/test/distributed/_tensor/test_dtensor_ops.py +++ b/test/distributed/_tensor/test_dtensor_ops.py @@ -14,11 +14,7 @@ ops, ) from torch.testing._internal.common_methods_invocations import DecorateInfo, op_db -from torch.testing._internal.common_utils import ( - run_tests, - suppress_warnings, - TEST_WITH_ASAN, -) +from torch.testing._internal.common_utils import run_tests, suppress_warnings from torch.testing._internal.distributed._tensor.common_dtensor import ( DTensorConverter, DTensorOpTestBase, @@ -365,6 +361,7 @@ def wrapped(fn): xfail("ormqr"), xfail("ones"), xfail("pca_lowrank"), + xfail("permute_copy"), xfail("pinverse"), xfail("polar"), xfail("put"), @@ -449,6 +446,7 @@ def wrapped(fn): xfail("trapz"), xfail("triangular_solve"), xfail("unbind"), + xfail("unbind_copy"), xfail("unfold"), xfail("unfold_copy"), xfail("uniform"), @@ -528,7 +526,6 @@ def world_size(self) -> int: # only allow float dytpe for now, we can relax this constraint # when feel necessary later (i.e when adding quantization support). - @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN") @suppress_warnings @ops(op_db, allowed_dtypes=(torch.float,)) @skipOps("TestDTensorOps", "test_dtensor_op_db", dtensor_fails) diff --git a/test/distributed/_tools/test_runtime_estimator.py b/test/distributed/_tools/test_runtime_estimator.py index 3c41bcbce61d6..0d4e4782675cf 100644 --- a/test/distributed/_tools/test_runtime_estimator.py +++ b/test/distributed/_tools/test_runtime_estimator.py @@ -167,9 +167,11 @@ def test_transformer_runtime( f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}" f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}" ) - self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) - self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3) - self.assertAlmostEqual(learned_accuracy, 1.0, delta=0.3) + + # No accuracy check for benchmark in CI as it is highly variable + # self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) + # self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.3) + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") @unittest.skipIf(not TEST_CUDA, "CUDA not available") @@ -202,9 +204,9 @@ def test_conv_model_runtime( f"\nActual: {actual_runtime} Roofline Estimate: {roofline_estimate} Accuracy: {roofline_accuracy}" f"\nActual: {actual_runtime} Learned Estimate: {learned_estimate} Accuracy: {learned_accuracy}" ) - self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) - self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4) - self.assertAlmostEqual(learned_accuracy, 1.0, delta=0.4) + # No accuracy check for benchmark in CI as it is highly variable + # self.assertAlmostEqual(benchmark_accuracy, 1.0, delta=0.2) + # self.assertAlmostEqual(roofline_accuracy, 1.0, delta=0.4) if __name__ == "__main__": diff --git a/test/distributed/_tools/test_sac_estimator.py b/test/distributed/_tools/test_sac_estimator.py new file mode 100644 index 0000000000000..be2eba257455a --- /dev/null +++ b/test/distributed/_tools/test_sac_estimator.py @@ -0,0 +1,90 @@ +# Owner(s): ["module: unknown"] +import unittest + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.sac_estimator import SACEstimator +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, +) + + +class TestSACEstimator(TestCase): + def _sac_estimation( + self, + estimate_mode: str, + model: torch.nn.Module, + inp: torch.Tensor, + ): + sace = SACEstimator() + with sace(estimate_mode_type=estimate_mode): + loss = model(inp).sum() + loss.backward() + sace.pwlf_sac_tradeoff_curve(n_segments=2, save_tradeoff_graphs=False) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_transformer_sac_estimation(self): + """Runs a basic GPT-2 model""" + dev = torch.cuda.current_device() + vocab_size = 8192 + bsz, seq_len = 8, 1024 + model_args = ModelArgs( + n_layers=4, + n_heads=12, + vocab_size=vocab_size, + max_seq_len=seq_len, + dim=768, + dropout_p=0.1, + ) + with FakeTensorMode(): + with torch.device(dev): + model = Transformer(model_args) + inp = torch.randint( + 0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=dev + ) + + self._sac_estimation("operator-level-benchmark", model, inp) + self._sac_estimation("operator-level-cost-model", model, inp) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_simple_model_sac_estimation(self): + """This test checks the correctness of view_ops, random_ops and inplace_ops""" + + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(5, 10) + self.relu1 = torch.nn.ReLU(inplace=True) + + def forward(self, x): + x = self.fc1(x) + x = self.relu1(x) + x = torch.cos_(x) + x = torch.sin_(x) + return x + + dev = torch.cuda.current_device() + with FakeTensorMode(): + with torch.device(dev): + model = Foo() + x = torch.rand((10, 5), device=dev) + + sac_estimator = SACEstimator() + with sac_estimator(estimate_mode_type="operator-level-benchmark"): + loss = model(x).sum() + loss.backward() + + self.assertEqual(sac_estimator.sac_mod_stats["Foo"].view_like_ops, [0]) + self.assertEqual(sac_estimator.sac_mod_stats["Foo"].rand_ops, []) + self.assertEqual( + sac_estimator.sac_mod_stats["Foo"].inplace_ops, [(2, 1), (3, 1), (4, 1)] + ) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/_tools/test_sac_ilp.py b/test/distributed/_tools/test_sac_ilp.py new file mode 100644 index 0000000000000..2d8c96a0a1a07 --- /dev/null +++ b/test/distributed/_tools/test_sac_ilp.py @@ -0,0 +1,252 @@ +# Owner(s): ["module: unknown"] +import copy +import unittest +from typing import Tuple + +import torch +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.ilp_utils import ( + aggregate_stats, + get_peak_memory_runtime_baseline, + ModuleInfo, + parse_module_info, +) +from torch.distributed._tools.mem_tracker import _ModState, MemTracker +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.distributed._tools.sac_estimator import SACEstimator, SACStats +from torch.distributed._tools.sac_ilp import ( + get_optimal_checkpointing_policy_per_module, + sac_milp, +) +from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, +) + + +class TestSACILP(TestCase): + def setUp(self): + super().setUp() + self.device = torch.cuda.current_device() + self.estimate_mode = "operator-level-cost-model" + + def _init_model_input_optimizer( + self, + ) -> Tuple[torch.nn.Module, torch.optim.Optimizer, torch.Tensor]: + bsz = 8 + model_args = ModelArgs( + n_layers=4, + n_heads=12, + vocab_size=8192, + max_seq_len=1024, + dim=768, + dropout_p=0.1, + ) + with torch.device(self.device): + model = Transformer(model_args) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True) + inp = torch.randint( + 0, model_args.vocab_size, (bsz, model_args.max_seq_len), device=self.device + ) + return (model, optimizer, inp) + + def _run_and_get_memTracker( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + inp: torch.Tensor, + ) -> MemTracker: + mem_tracker = MemTracker() + mem_tracker.track_external(model, optimizer) + with mem_tracker as mt: + for iter_idx in range(2): # running twice to initialize optimizer + output = model(inp) + output.sum().backward() + if iter_idx == 1: + last_snapshot = mt.get_tracker_snapshot("current") + optimizer.step() + optimizer.zero_grad() + if iter_idx == 0: + mt.reset_mod_stats() + assert last_snapshot is not None + for mod_stats in mem_tracker.memory_tracking.values(): + # postprocessing due to the fact that for ModTracker, the post backward hook + # is not being called for modules whose inputs don't require gradients + # TODO: fix this in ModTracker and ensure it does not lead to any perf regression + if _ModState.POST_BW not in mod_stats.snapshots.keys(): + mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append( + copy.deepcopy(last_snapshot) + ) + return mem_tracker + + def _run_and_get_runtime_estimator( + self, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer, + inp: torch.Tensor, + ) -> RuntimeEstimator: + def _run_one_step() -> None: + output = model(inp) + output.sum().backward() + optimizer.step() + optimizer.zero_grad() + + # Initializing optimizer states and warm-up + _run_one_step() + + runtime_estimator = RuntimeEstimator() + with runtime_estimator(estimate_mode_type=self.estimate_mode): + _run_one_step() # We use only one iteration for estimation + return runtime_estimator + + def _run_and_get_sac_estimator( + self, + model: torch.nn.Module, + inp: torch.Tensor, + ) -> SACEstimator: + sac_estimator = SACEstimator() + with sac_estimator(estimate_mode_type=self.estimate_mode): + loss = model(inp).sum() + loss.backward() + return sac_estimator + + def _collect_module_info_with_fake_tensor_mode(self) -> ModuleInfo: + with FakeTensorMode(): + model, optimizer, inp = self._init_model_input_optimizer() + mem_tracker = self._run_and_get_memTracker(model, optimizer, inp) + runtime_estimator = self._run_and_get_runtime_estimator( + model, optimizer, inp + ) + sac_estimator = self._run_and_get_sac_estimator(model, inp) + mod_info = aggregate_stats( + model, + mem_tracker, + runtime_estimator, + sac_estimator, + torch.device(self.device), + ) + return mod_info + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_sac_ilp_case1(self): + """ + This is a case where the memory budget is either binding or too tight, + meaning that with some AC, the model can fit into GPU memory. + """ + mod_info = self._collect_module_info_with_fake_tensor_mode() + g = parse_module_info(mod_info) + + peak_mem, compute_time = get_peak_memory_runtime_baseline(g) + self.assertAlmostEqual(peak_mem / 2583888896, 1, delta=0.05) + + ac_decisions, recomputation_time, _ = sac_milp( + g, memory_budget=1.6, world_size=4 + ) + + # The solution should AC all four transformer layers. On A100 machine, the percentage of + # activation memory to discard is 0.5232 for three layers and is 0.7964 for the fourth layer. + # Due to symmetry, the layer that has 0.7964 can be any of the first three layers. On CI, + # due to machine variance and difference in flops, the results can be different -- e.g., + # the ratios are 0.672, 0.5646, 0.5646, 0.5646 for the four transformer layers for test + # linux-focal-cuda11.8-py3.10-gcc9 / test (distributed, 1, 3, lf.linux.8xlarge.nvidia.gpu). + # and recomputation_time = 58.14; compute_time = 902.26 + modules_to_ac = set(ac_decisions.keys()) + sorted_discard_ratio = sorted(ac_decisions.values()) + self.assertEqual( + modules_to_ac, + {"Transformer.layers." + str(i) for i in range(4)}, # n_layers=4 + ) + self.assertAlmostEqual(sorted_discard_ratio[0], 0.55, delta=0.05) + self.assertAlmostEqual(sorted_discard_ratio[1], 0.55, delta=0.05) + self.assertAlmostEqual(sorted_discard_ratio[2], 0.55, delta=0.05) + self.assertAlmostEqual(sum(sorted_discard_ratio), 2.35, delta=0.05) + self.assertAlmostEqual(ac_decisions["Transformer.layers.3"], 0.55, delta=0.05) + + # On A100 machine, recomputation_time is 6.97 ms and compute_time is 97.97 ms. + # Since runtime is device_flops dependent, so we only check the ratio + self.assertAlmostEqual( + (recomputation_time / compute_time) / (6.97 / 97.97), 1, delta=0.25 + ) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_sac_ilp_case2(self): + """ + This is a case where the memory budget is not binding, meaning that no + AC is needed to fit the model into memory. + """ + mod_info = self._collect_module_info_with_fake_tensor_mode() + g = parse_module_info(mod_info) + ac_decisions, recomputation_time, peak_mem = sac_milp( + g, memory_budget=2.4, world_size=4 + ) + self.assertDictEqual(ac_decisions, {}) + self.assertEqual(recomputation_time, 0) + self.assertGreater(peak_mem, 1) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_sac_ilp_case3(self): + """ + This is a case where the memory budget is too tight, meaning that even with + aggressive AC, the model cannot fit into memory. + """ + mod_info = self._collect_module_info_with_fake_tensor_mode() + g = parse_module_info(mod_info) + ac_decisions, recomputation_time, peak_mem = sac_milp( + g, memory_budget=0.8, world_size=4 + ) + self.assertEqual(ac_decisions, {}) + self.assertEqual(recomputation_time, 0) + self.assertEqual(peak_mem, -1) + + +class TestOptimalCheckpointingPolicy(TestCase): + # tests are adpated from tests in xformers + # https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/tests/test_checkpoint.py#L222 + def setUp(self): + super().setUp() + data = [ + ("aten.copy_", 5, 0), + ("aten.add", 5, 100), + ("aten.div", 8, 100), + ("aten.mm", 15, 120), + ("aten.native_dropout", 15, 0), + ("aten.linear", 9, 100), + ("aten.t", 1, 0), + ("aten.relu_", 5, 0), + ] + self.sac_stats = SACStats( + func_names=[x[0] for x in data], + runtimes=[x[1] for x in data], + memory=[x[2] for x in data], + view_like_ops=[6], + rand_ops=[4], + saved_autograd_ops=[], # not needed for SAC decisions + inplace_ops=[(0, 0), (7, 5)], + force_store_random=False, + ) + + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") + @unittest.skipIf(not TEST_CUDA, "CUDA not available") + def test_get_optimial_checkpointing_policy_per_module(self): + for memory_budget, optimal_soln in [ + (0, [1, 0, 0, 0, 1, 0, 0, 0]), + (100 / 420, [1, 0, 0, 0, 1, 1, 0, 1]), + (120 / 420, [1, 0, 0, 1, 1, 0, 0, 0]), + (200 / 420, [1, 0, 1, 0, 1, 1, 0, 1]), + (220 / 420, [1, 0, 0, 1, 1, 1, 0, 1]), + (320 / 420, [1, 0, 1, 1, 1, 1, 0, 1]), + (420 / 420, [1, 1, 1, 1, 1, 1, 0, 1]), + ]: + soln = get_optimal_checkpointing_policy_per_module( + sac_stats=self.sac_stats, memory_budget=memory_budget + ) + self.assertEqual(optimal_soln, soln) + + +if __name__ == "__main__": + run_tests() diff --git a/test/distributed/checkpoint/test_traverse.py b/test/distributed/checkpoint/test_traverse.py index f1815b41d9485..ca79c2daa4774 100644 --- a/test/distributed/checkpoint/test_traverse.py +++ b/test/distributed/checkpoint/test_traverse.py @@ -12,8 +12,11 @@ from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE -# TODO: add comments for TestTraverse class TestTraverse(TestCase): + """ + Test class for util methods of _traverse + """ + def test_traverse_shallow(self) -> None: state_dict = { "key0": 1, diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index 54a444d1dc944..8ed99c655d579 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -4,12 +4,7 @@ import torch from torch import distributed as dist -from torch.distributed.checkpoint import ( - FileSystemReader, - FileSystemWriter, - load_state_dict, - save_state_dict, -) +from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter, load, save from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel from torch.distributed.fsdp.wrap import enable_wrap, wrap @@ -71,13 +66,13 @@ def test_distributed_checkpoint(self, state_dict_type) -> None: ): state_dict = model.state_dict() - save_state_dict(state_dict, writer) + save(state_dict, writer) with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( new_model, state_dict_type ): state_dict = new_model.state_dict() - load_state_dict(state_dict, reader) + load(state_dict, reader) new_model.load_state_dict(state_dict) with FullyShardedDataParallel.summon_full_params( diff --git a/test/distributed/fsdp/test_fsdp_flatten_params.py b/test/distributed/fsdp/test_fsdp_flatten_params.py index cb3cf7087db02..5581318b1c386 100644 --- a/test/distributed/fsdp/test_fsdp_flatten_params.py +++ b/test/distributed/fsdp/test_fsdp_flatten_params.py @@ -13,7 +13,12 @@ ) from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import FSDPTest -from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, + run_tests, + TEST_WITH_DEV_DBG_ASAN, +) if not dist.is_available(): @@ -335,6 +340,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["0.weight"], param_shapes=[(10, 10)], + param_strides=[(10, 1)], + param_contiguities=[True], param_numels=[100], param_offsets=[(0, 0)], ), @@ -346,6 +353,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["0.weight"], param_shapes=[(10, 10)], + param_strides=[(10, 1)], + param_contiguities=[True], param_numels=[100], param_offsets=[(0, 50)], ), @@ -357,6 +366,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["0.weight"], param_shapes=[(10, 10)], + param_strides=[(10, 1)], + param_contiguities=[True], param_numels=[100], param_offsets=[(0, 99)], ), @@ -368,6 +379,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["0.weight", "2.weight"], param_shapes=[(10, 10), (10, 10)], + param_strides=[(10, 1), (10, 1)], + param_contiguities=[True, True], param_numels=[100, 100], param_offsets=[(50, 99), (0, 49)], ), @@ -379,6 +392,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["0.weight", "2.weight"], param_shapes=[(10, 10), (10, 10)], + param_strides=[(10, 1), (10, 1)], + param_contiguities=[True, True], param_numels=[100, 100], param_offsets=[(50, 99), (0, 99)], ), @@ -390,6 +405,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["0.weight", "2.weight"], param_shapes=[(10, 10), (10, 10)], + param_strides=[(10, 1), (10, 1)], + param_contiguities=[True, True], param_numels=[100, 100], param_offsets=[(99, 99), (0, 99)], ), @@ -401,6 +418,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["2.weight"], param_shapes=[(10, 10)], + param_strides=[(10, 1)], + param_contiguities=[True], param_numels=[100], param_offsets=[(0, 99)], ), @@ -412,6 +431,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["2.weight", "4.weight"], param_shapes=[(10, 10), (10, 10)], + param_strides=[(10, 1), (10, 1)], + param_contiguities=[True, True], param_numels=[100, 100], param_offsets=[(0, 99), (0, 99)], ), @@ -423,6 +444,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["2.weight", "4.weight"], param_shapes=[(10, 10), (10, 10)], + param_strides=[(10, 1), (10, 1)], + param_contiguities=[True, True], param_numels=[100, 100], param_offsets=[(0, 99), (0, 99)], ), @@ -434,6 +457,8 @@ def test_flat_param_shard_metadata_unaligned(self): expected=FlatParamShardMetadata( param_names=["4.weight"], param_shapes=[(10, 10)], + param_strides=[(10, 1)], + param_contiguities=[True], param_numels=[100], param_offsets=[(99, 99)], ), @@ -469,6 +494,8 @@ def test_flat_param_shard_metadata_aligned_full_precision(self): expected=FlatParamShardMetadata( param_names=["0.weight", "1.weight"], param_shapes=[(7, 3), (5, 7)], + param_strides=[(3, 1), (7, 1)], + param_contiguities=[True, True], param_numels=[21, 35], # 21 + (3) + 19 = 43 param_offsets=[(0, 20), (0, 18)], @@ -482,6 +509,8 @@ def test_flat_param_shard_metadata_aligned_full_precision(self): expected=FlatParamShardMetadata( param_names=["1.weight", "2.weight"], param_shapes=[(5, 7), (5, 5)], + param_strides=[(7, 1), (5, 1)], + param_contiguities=[True, True], param_numels=[35, 25], # 16 + (1) + 25 = 42 param_offsets=[(19, 34), (0, 24)], @@ -519,6 +548,8 @@ def test_flat_param_shard_metadata_aligned_mixed_precision(self): expected=FlatParamShardMetadata( param_names=["0.weight", "1.weight"], param_shapes=[(5, 2), (5, 5)], + param_strides=[(2, 1), (5, 1)], + param_contiguities=[True, True], param_numels=[10, 25], # 10 + (6) + 16 = 32 param_offsets=[(0, 9), (0, 15)], @@ -532,6 +563,8 @@ def test_flat_param_shard_metadata_aligned_mixed_precision(self): expected=FlatParamShardMetadata( param_names=["1.weight", "2.weight"], param_shapes=[(5, 5), (3, 5)], + param_strides=[(5, 1), (5, 1)], + param_contiguities=[True, True], param_numels=[25, 15], # 9 + (7) + 15 = 31 param_offsets=[(16, 24), (0, 14)], @@ -565,6 +598,57 @@ def _test_flat_param_shard_metadata( msg=f"{handle.shard_metadata()}, {expected}", ) + @parametrize("memory_format", [torch.contiguous_format, torch.channels_last]) + def test_flat_param_shard_metadata_with_memory_format(self, memory_format): + """ + Tests that ``FlatParameter`` shard metadata are computed as expected + with alignment padding and parameter full precision. + """ + module = torch.nn.Sequential( + torch.nn.Conv2d(10, 20, 3, bias=False), # 0.weight, 1800 params + torch.nn.Conv2d(20, 10, 5, bias=False), # 1.weight, 5000 params + torch.nn.Conv2d(10, 10, 1, bias=False), # 2.weight, 100 params + ).to(memory_format=memory_format) + params_to_flatten = list(module.parameters()) + handle_kwargs = self._get_default_config() + handle_kwargs["use_orig_params"] = True + handle = FlatParamHandle(params_to_flatten, module, **handle_kwargs) + contiguous_tensors = memory_format == torch.contiguous_format + self._test_flat_param_shard_metadata( + handle, + # Emulate rank 0 of 2 ranks + start=0, + end=2999, + expected=FlatParamShardMetadata( + param_names=["0.weight", "1.weight"], + param_shapes=[(20, 10, 3, 3), (10, 20, 5, 5)], + param_strides=[(90, 9, 3, 1), (500, 25, 5, 1)] + if contiguous_tensors + else [(90, 1, 30, 10), (500, 1, 100, 20)], + param_contiguities=[contiguous_tensors, contiguous_tensors], + param_numels=[1800, 5000], + param_offsets=[(0, 1799), (0, 1199)], + ), + ) + self._test_flat_param_shard_metadata( + handle, + # Emulate rank 1 of 2 ranks + start=3000, + end=6899, + expected=FlatParamShardMetadata( + param_names=["1.weight", "2.weight"], + param_shapes=[(10, 20, 5, 5), (10, 10, 1, 1)], + param_strides=[(500, 25, 5, 1), (10, 1, 1, 1)] + if contiguous_tensors + else [(500, 1, 100, 20), (10, 1, 10, 10)], + param_contiguities=[contiguous_tensors, contiguous_tensors], + param_numels=[5000, 100], + param_offsets=[(1200, 4999), (0, 99)], + ), + ) + + +instantiate_parametrized_tests(TestFlattenParams) if __name__ == "__main__": run_tests() diff --git a/test/distributed/fsdp/test_fsdp_state_dict.py b/test/distributed/fsdp/test_fsdp_state_dict.py index 96fa6f8457e9b..0fa1b38eef42b 100644 --- a/test/distributed/fsdp/test_fsdp_state_dict.py +++ b/test/distributed/fsdp/test_fsdp_state_dict.py @@ -156,8 +156,7 @@ class TestFSDPStateDict(FSDPTest): def world_size(self): return min(torch.cuda.device_count(), 2) - def _broadcast_state_dict(self, model, state_dict): - # TODO (rohan-varma): remove model + def _broadcast_state_dict(self, state_dict): return _broadcast_state_dict(self.rank, state_dict) def _state_compare(self, model, model_new, assert_fn, state_generator="parameters"): @@ -361,7 +360,7 @@ def apply_ac_to_linears(model) -> None: _zero_model(model_new) self._compare_models(model, model_new, self.assertNotEqual) if rank0_only_and_offload: - state_dict = self._broadcast_state_dict(model, state_dict) + state_dict = self._broadcast_state_dict(state_dict) # Would fail if checkpoint_wrapper did not correctly implement state_dict pre/post hooks model_new.load_state_dict(state_dict, strict=True) self._compare_models(model, model_new, self.assertEqual) @@ -417,8 +416,8 @@ def test_state_dict_with_manual_ac_wrapper( state_dict_ac = model_ac.state_dict() self.assertEqual(state_dict_ac.keys(), state_dict_no_ac.keys()) if rank0_only_and_offload: - state_dict_no_ac = self._broadcast_state_dict(model_no_ac, state_dict_no_ac) - state_dict_ac = self._broadcast_state_dict(model_ac, state_dict_ac) + state_dict_no_ac = self._broadcast_state_dict(state_dict_no_ac) + state_dict_ac = self._broadcast_state_dict(state_dict_ac) with self._get_state_dict_mgr( model_no_ac, state_dict_type, rank0_only_and_offload ): @@ -612,7 +611,7 @@ def test_basic_save_and_load_state_dict( # Verify parameters are the same in the new model. if state_dict_rank0_and_offload: - fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict) + fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict) with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]): model_new.load_state_dict(fsdp_state_dict, strict=True) @@ -679,7 +678,7 @@ def test_buffers_save_and_load_state_dict( # Verify parameters are the same in the new model. if state_dict_rank0_and_offload: - fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict) + fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict) with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]): model_new.load_state_dict(fsdp_state_dict, strict=True) @@ -746,7 +745,7 @@ def test_save_and_load_after_forward_state_dict( # Load state_dict into zeroed model if state_dict_rank0_and_offload: - state_dict = self._broadcast_state_dict(model, state_dict) + state_dict = self._broadcast_state_dict(state_dict) with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]): model.load_state_dict(state_dict, strict=True) @@ -926,7 +925,7 @@ def test_state_dict_load_into_local_module( # Load fsdp's full state dict into the local and verify params are as # expected. if state_dict_rank0_and_offload: - fsdp_state_dict = self._broadcast_state_dict(model, fsdp_state_dict) + fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict) blank_local_model.load_state_dict(fsdp_state_dict, strict=True) local_params = list(blank_local_model.parameters()) diff --git a/test/distributed/pipelining/test_backward.py b/test/distributed/pipelining/test_backward.py index ff2f27c1c0ccf..a19092d8a211d 100644 --- a/test/distributed/pipelining/test_backward.py +++ b/test/distributed/pipelining/test_backward.py @@ -75,7 +75,7 @@ def test_stage_backward_input(self): out = mod(x) loss = loss_fn(out, target) dinputs, param_groups = stage_backward_input( - stage_outputs=(loss,), + stage_outputs_or_loss=(loss,), output_grads=None, input_values=[x], weights=mod.parameters(), @@ -110,7 +110,7 @@ def test_stage_backward_weight(self): out = mod(x) loss = loss_fn(out, target) dinputs, param_groups = stage_backward_input( - stage_outputs=(loss,), + stage_outputs_or_loss=(loss,), output_grads=None, input_values=[x], weights=mod.parameters(), @@ -158,7 +158,7 @@ def test_stage_backward_weight_multiple_iters(self): out = mod(x) loss = loss_fn(out, target) dinputs, param_groups = stage_backward_input( - stage_outputs=(loss,), + stage_outputs_or_loss=(loss,), output_grads=None, input_values=[x], weights=mod.parameters(), diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py index ce1bbc51412b9..0b3134e273f30 100644 --- a/test/distributed/pipelining/test_schedule.py +++ b/test/distributed/pipelining/test_schedule.py @@ -5,8 +5,8 @@ import torch from torch.distributed.pipelining import ( - ScheduleFlexibleInterleaved1F1B, ScheduleInterleaved1F1B, + ScheduleInterleavedZeroBubble, ScheduleLoopedBFS, ) from torch.distributed.pipelining.schedules import ( @@ -66,7 +66,6 @@ def test_get_schedule_class(self): "Interleaved1F1B", "INTERLEAVED1F1B", "GPipe", - "FlexibleInterleaved1F1B", "LoopedBFS", "PipelineScheduleSingle", "PipelineScheduleMulti", @@ -164,7 +163,7 @@ def test_pipeline_order(self, ScheduleClass): @parametrize( "ScheduleClass", - [ScheduleFlexibleInterleaved1F1B], + [ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble], ) def test_pipeline_order_flex_and_zero_bubble(self, ScheduleClass): for num_local_stages, num_microbatches, group_size in self.test_cases: @@ -179,25 +178,22 @@ def test_pipeline_order_flex_and_zero_bubble(self, ScheduleClass): warmup_ops = warmups_ops_last_stage + 2 * (group_size - 1) warmup_ops = min(warmup_ops, num_microbatches * num_local_stages) - for i in range(2): - num_stages = num_local_stages * group_size - stages = [ - MockPipelineStage(group_size=group_size, num_stages=num_stages) - for i in range(num_local_stages) - ] - schedule = ScheduleClass( - stages, num_microbatches, enable_zero_bubble=(i == 0) - ) - formatted_pipeline_order = _format_pipeline_order( - schedule.pipeline_order - ) - # print(formatted_pipeline_order) - _validate_pipeline_order( - schedule.pipeline_order, - num_microbatches, - num_stages, - enable_zero_bubble=(i == 0), - ) + num_stages = num_local_stages * group_size + stages = [ + MockPipelineStage(group_size=group_size, num_stages=num_stages) + for i in range(num_local_stages) + ] + schedule = ScheduleClass(stages, num_microbatches) + formatted_pipeline_order = _format_pipeline_order( + schedule.pipeline_order + ) + # print(formatted_pipeline_order) + _validate_pipeline_order( + schedule.pipeline_order, + num_microbatches, + num_stages, + enable_zero_bubble=(ScheduleClass is ScheduleInterleavedZeroBubble), + ) instantiate_parametrized_tests(TestSchedulePlan) diff --git a/test/distributed/pipelining/test_schedule_multiproc.py b/test/distributed/pipelining/test_schedule_multiproc.py index c70994e4f17a6..7e38fd14e492c 100644 --- a/test/distributed/pipelining/test_schedule_multiproc.py +++ b/test/distributed/pipelining/test_schedule_multiproc.py @@ -16,7 +16,6 @@ pipeline, PipelineStage, Schedule1F1B, - ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble, @@ -278,7 +277,8 @@ def test_grad_with_tracer(self, ScheduleClass, ModelClass): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @parametrize("ScheduleClass", [ScheduleGPipe, Schedule1F1B]) - def test_grad_with_manual(self, ScheduleClass): + @parametrize("shape_inference", [True, False]) + def test_grad_with_manual(self, ScheduleClass, shape_inference): full_mod = MultiMLP(d_hid, n_layers=self.world_size) full_mod.to(self.device) @@ -302,13 +302,23 @@ def test_grad_with_manual(self, ScheduleClass): submod_name = f"layers.{self.rank}" stage_module = full_mod.get_submodule(submod_name) chunks = 4 + + if shape_inference: + input_args = None + output_args = None + else: + input_args = (x.chunk(chunks)[0],) + with torch.no_grad(): + output_args = stage_module(*input_args) + # Create a pipeline stage to wrap that submodule stage = PipelineStage( stage_module, self.rank, self.world_size, self.device, - input_args=x.chunk(chunks)[0], + input_args=input_args, + output_args=output_args, ) # Attach to a schedule @@ -399,7 +409,6 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): stage_idx, n_stages, self.device, - input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, stage_indices) ] @@ -502,10 +511,10 @@ def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - @parametrize("ScheduleClass", [ScheduleWithW, ScheduleFlexibleInterleaved1F1B]) + @parametrize("ScheduleClass", [ScheduleWithW, ScheduleInterleavedZeroBubble]) def test_schedule_with_native_zero_bubble(self, ScheduleClass): print(ScheduleClass) - if ScheduleClass is ScheduleFlexibleInterleaved1F1B: + if ScheduleClass is ScheduleInterleavedZeroBubble: n_stages = 4 num_microbatches = 8 rank_stages = { @@ -545,14 +554,11 @@ def test_schedule_with_native_zero_bubble(self, ScheduleClass): stage_idx, n_stages, self.device, - input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) ] - schedule = ScheduleClass( - stages, num_microbatches, loss_fn=loss_fn, enable_zero_bubble=True - ) + schedule = ScheduleClass(stages, num_microbatches, loss_fn=loss_fn) # Run reference ref_x = x.clone().detach().requires_grad_(x.requires_grad) @@ -633,7 +639,6 @@ def test_non_symmetric_stage_ids(self, ScheduleClass): stage_idx, n_stages, self.device, - input_args=input_args, ) for stage_module, stage_idx in zip(stage_modules, rank_stages[self.rank]) ] @@ -685,7 +690,7 @@ def test_non_symmetric_stage_ids(self, ScheduleClass): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - @parametrize("ScheduleClass", [ScheduleFlexibleInterleaved1F1B]) + @parametrize("ScheduleClass", [ScheduleInterleavedZeroBubble]) def test_schedule_with_weight_update_mlp_e2e(self, ScheduleClass): stages_per_rank = 2 n_stages = stages_per_rank * self.world_size @@ -757,16 +762,13 @@ def dw_runner(): stage_idx, n_stages, self.device, - input_args=input_args, dw_builder=cs[stage_idx].dw_builder, ) for stage_module, stage_idx in zip(stage_modules, stage_indices) ] # Attach to a schedule - schedule = ScheduleClass( - stages, chunks, loss_fn=full_loss_fn, enable_zero_bubble=True - ) + schedule = ScheduleClass(stages, chunks, loss_fn=full_loss_fn) for _ in range(2): # Zero gradients diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py index 3aa715b7826ba..b02e7e25aff0f 100644 --- a/test/distributed/pipelining/test_stage.py +++ b/test/distributed/pipelining/test_stage.py @@ -203,7 +203,6 @@ def test_manual(self): self.rank, self.world_size, self.device, - input_args=x.chunk(chunks)[0], ) # Attach to a schedule @@ -273,7 +272,6 @@ def dw_runner(): self.rank, self.world_size, self.device, - input_args=x.chunk(chunks)[0], dw_builder=cs.dw_builder, ) @@ -320,7 +318,6 @@ def test_custom_dw_errors(self): self.rank, self.world_size, self.device, - input_args=x.chunk(chunks)[0], dw_builder=lambda: None, ) with self.assertRaisesRegex(AssertionError, "backward_one_chunk"): diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py index 06f968f0b0365..f27de4736e536 100644 --- a/test/distributed/tensor/parallel/test_parallelize_api.py +++ b/test/distributed/tensor/parallel/test_parallelize_api.py @@ -265,6 +265,33 @@ def test_parallelize_module_multi_wildcard(self): ) self._compare_module(model, model_tp, inp_size, rank0_only=False) + @with_comms + def test_under_devicemesh_context(self): + # test ColwiseParallel + inp_size = [8, 10] + colwise = ColwiseParallel(output_layouts=Replicate()) + + torch.manual_seed(5) + model = torch.nn.Linear(10, 16, device=self.device_type) + model_tp = deepcopy(model) + + # Call parallelize_module under DeviceMesh context. + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + with device_mesh: + model_tp = parallelize_module(model_tp, parallelize_plan=colwise) + + self._compare_module(model, model_tp, inp_size) + + @with_comms + def test_empty_plan(self): + torch.manual_seed(5) + model = torch.nn.Linear(10, 16, device=self.device_type) + + # Call parallelize_module with empty plan. + # Goal is not to crash. + device_mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + parallelize_module(model, device_mesh) + if __name__ == "__main__": run_tests() diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index 97924f0b70738..0f4bf91edc21b 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -597,14 +597,11 @@ def func(arg: torch.Tensor) -> torch.Tensor: ( FileCheck() .check("buf0 = empty") - # Ensure the all_reduce_ input is a view - .check( - "torch.ops._c10d_functional.all_reduce_.default(reinterpret_tensor(buf0" - ) - .check( - "torch.ops._c10d_functional.wait_tensor.default(reinterpret_tensor(buf0" - ) - .check("return (reinterpret_tensor(buf0") + # We always call .contiguous() on the input to all_reduce_, + # so input will not be a view anymore. + .check("torch.ops._c10d_functional.all_reduce_.default(buf0") + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + .check("return (buf0") .run(code) ) @@ -624,6 +621,16 @@ def func(arg: torch.Tensor) -> torch.Tensor: # clone induced by non contig input assert "torch.ops._c10d_functional.wait_tensor.default" in code + def func2(arg: torch.Tensor) -> torch.Tensor: + torch.ops._c10d_functional.all_reduce_(arg, "avg", "0") + return arg + + compiled = torch.compile(func) + + code = run_and_get_triton_code(compiled, arg) + # clone induced by non contig input + assert "torch.ops._c10d_functional.wait_tensor.default" in code + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reuse_buffer_after_inplace_collective(self): @@ -701,7 +708,7 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: FileCheck() .check( "buf0 = torch.ops._c10d_functional.all_gather_into_tensor_coalesced" - ".default([arg0_1, arg1_1, arg2_1, arg3_1]" + ".default([arg3_1, arg2_1, arg1_1, arg0_1]" ) .check("buf1 = buf0[0]") .check("buf2 = buf0[1]") @@ -720,6 +727,28 @@ def func(args: List[torch.Tensor]) -> torch.Tensor: out = AOTIRunnerUtil.run("cuda", func, (args,)) torch.cuda.synchronize() + @unittest.skipIf(not HAS_GPU, "This is a GPU test!") + @fresh_inductor_cache() + def test_wait_tensor(self): + def func(arg: torch.Tensor) -> torch.Tensor: + t = torch.ops._c10d_functional.all_reduce(arg, "avg", "0") + return funcol.wait_tensor(t) + + # Test aoti + arg = torch.rand(4, 4, device="cuda") + compiled = torch.compile(func) + code = run_and_get_triton_code(compiled, arg) + ( + FileCheck() + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") + .check("return (buf0, )") + .run(code) + ) + + # Test aoti + out = AOTIRunnerUtil.run("cuda", func, (arg,)) + torch.cuda.synchronize() + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @fresh_inductor_cache() def test_inductor_reduce_scatter_tensor_single(self): diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 333ad465e0d9f..64a210ed3b6c0 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -37,7 +37,7 @@ import torch.nn.functional as F import torch.testing._internal.common_utils as common from torch import nn -from torch._C._distributed_c10d import OpType +from torch._C._distributed_c10d import OpType, WorkResult from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_distributed import ( @@ -49,6 +49,7 @@ requires_nccl_version, skip_if_lt_x_gpu, skip_if_rocm_multiprocess, + sm_is_or_higher_than, TEST_SKIPS, with_dist_debug_levels, with_nccl_blocking_wait, @@ -320,25 +321,30 @@ def abortpg(): @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - def test_close_pg(self): + @parametrize("eager_init", [True, False]) + def test_close_pg(self, eager_init: bool): # Disable ASYNC_ERROR_HANDLING for this test to ensure we can programmatically # abort the process group. os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" store = c10d.FileStore(self.file_name, self.world_size) - pg = self._create_process_group_nccl(store, self.opts()) - device = self.rank_to_GPU[self.rank][0] + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") + c10d.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + device_id=device if eager_init else None, + ) t = torch.rand(10, 10, device=device) # First allreduce to initialize state. - pg.allreduce(t) + dist.all_reduce(t) # Destroy pg and validate pg is no longer valid dist.destroy_process_group() - with self.assertRaises(dist.DistBackendError): - pg.allreduce([t]) - - del pg + with self.assertRaises(ValueError): + dist.all_reduce(t) CUDA_12_AND_ABOVE = torch.cuda.is_available() and ( torch.version.cuda is not None and int(torch.version.cuda.split(".")[0]) >= 12 @@ -431,9 +437,12 @@ def test_nan_rank_filter(self): @skip_if_lt_x_gpu(2) def test_nan_check(self): # Not expecting an error, NaN check should not make legit code fail + device = torch.device("cuda:%d" % self.rank) + if not sm_is_or_higher_than(device, 8, 0): + self.skipTest("bf16 requires sm >= 8.0") + os.environ["TORCH_NCCL_NAN_CHECK"] = "1" store = c10d.FileStore(self.file_name, self.world_size) - device = torch.device("cuda:%d" % self.rank) c10d.init_process_group( backend="nccl", store=store, rank=self.rank, world_size=self.world_size ) @@ -446,6 +455,95 @@ def test_nan_check(self): # reset env os.environ["TORCH_NCCL_NAN_CHECK"] = "0" + def _helper_test_extra_cuda_context_by_nvml(self): + """ + A helper for `test_extra_cuda_context`, if pynvml is avaiable. + pynvml provides python bindings for NVIDIA NVML functionalities. + Here we are interested in: nvmlDeviceGetComputeRunningProcesses + """ + import pynvml + + pynvml.nvmlInit() + + device = torch.device("cuda:%d" % self.rank) + x = torch.empty((1,), device=device) + work = c10d.all_reduce(x, async_op=True) + + # Wait for non-0 ranks to garbage collect Work -- this is the latest + # point where extra CUDA context can be created + if self.rank == 0: + time.sleep(5) + del work + handle = pynvml.nvmlDeviceGetHandleByIndex(self.rank) + processes = pynvml.nvmlDeviceGetComputeRunningProcesses(handle) + nprocs = len(processes) + + # A barrier for non-0 ranks + c10d.all_reduce(x) + torch.cuda.synchronize(device) + c10d.destroy_process_group() + self.assertEqual( + nprocs, + 1, + f"Found {nprocs} processes creating contexts on {device}, expecting 1 only", + ) + + def _helper_test_extra_cuda_context_by_memory(self): + """ + A helper for `test_extra_cuda_context`, if pynvml is NOT avaiable. + If extra context is created, it would manifest into device 0's memory usage. + """ + device = torch.device("cuda:%d" % self.rank) + x = torch.empty((1,), device=device) + # Rank 0 takes a snapshot before collective -- this snapshot should have + # included rank 0's own context. + if self.rank == 0: + free, total = torch.cuda.mem_get_info(device) + used_before = float(total - free) + + work = c10d.all_reduce(x, async_op=True) + + # Wait for non-0 ranks to garbage collect Work -- this is the latest + # point where extra CUDA context can be created + if self.rank == 0: + time.sleep(5) + free, total = torch.cuda.mem_get_info(device) + used_after = float(total - free) + del work + + # A barrier for non-0 ranks + c10d.all_reduce(x) + torch.cuda.synchronize(device) + c10d.destroy_process_group() + if self.rank == 0: + # If non-0 rank creates a context on device 0, this assert would + # fail because one context takes about 1 GB -- much more than the + # tensor size created in this test. + self.assertTrue( + used_after < used_before * 1.5, + f"{device} used {used_after} bytes after collective, " + f"50% more than the status before ({used_before} bytes). " + f"Extra CUDA context may have been created.", + ) + + @requires_nccl() + @skip_if_lt_x_gpu(2) + def test_extra_cuda_context(self): + # Check if non-0 ranks would create extra CUDA context on device 0 + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device("cuda:%d" % self.rank) + c10d.init_process_group( + backend="nccl", + store=store, + rank=self.rank, + world_size=self.world_size, + device_id=device, + ) + try: + self._helper_test_extra_cuda_context_by_nvml() + except ModuleNotFoundError: + self._helper_test_extra_cuda_context_by_memory() + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_destruct_before_terminate_pg(self): @@ -533,8 +631,9 @@ def test_abort_in_destroy_multi_pgs(self): new_pg1.allreduce(t1).wait() new_pg2.allreduce(t2).wait() backend = pg._get_backend(torch.device(device)) - # default PG's backend should have a split count of 2 - self.assertEqual(backend.comm_split_count(), 2) + # default PG's backend should have a split count of 0 because + # it's not eager initialized + self.assertEqual(backend.comm_split_count(), 0) # shutdown all NCCL PGs in one shot dist.destroy_process_group() @@ -556,8 +655,8 @@ def test_abort_in_destroy_mixed_empty_pgs(self): new_pg2.allreduce(t2).wait() backend = pg._get_backend(torch.device(device)) - # default PG's backend should have a split count of 1 - self.assertEqual(backend.comm_split_count(), 1) + # default PG's backend should have a split count of 0 + self.assertEqual(backend.comm_split_count(), 0) # shutdown all NCCL PGs in one shot dist.destroy_process_group() @@ -709,27 +808,24 @@ def test_extend_nccl_pg_timeout(self, backend): @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") - def test_comm_split_optimization(self): + @parametrize("eager_init", [True, False]) + def test_new_group(self, eager_init: bool): # Test the optimization of new groups that contain all world # ranks use the "transparent" `ncclCommSplit` optimization. store = c10d.FileStore(self.file_name, self.world_size) - pg = self._create_process_group_nccl(store, self.opts()) - - # Test lazy splitting behavior across each per-device backend. - for device in self.rank_to_GPU[self.rank]: - backend = pg._get_backend(torch.device(device)) - - # split doesn't happen unless the original process group has lazily - # created communicators, so first verify we haven't split even when - # making the new group and running an operation on the original pg. - ng = c10d.new_group() - tensor = torch.tensor([self.rank]).cuda(device) - pg.broadcast(tensor, 0) - self.assertEqual(backend.comm_split_count(), 0) - - # The new group will force a split of the original on first use. - ng.broadcast(tensor, 0) - self.assertEqual(backend.comm_split_count(), 1) + device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}") + c10d.init_process_group( + "nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + device_id=device if eager_init else None, + ) + ng = c10d.new_group() + tensor = torch.tensor([self.rank], device=device) + dist.broadcast(tensor, 0) + dist.broadcast(tensor, 0, group=ng) + dist.destroy_process_group() @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -758,6 +854,26 @@ def test_comm_split_subgroup(self): self.assertEqual(tensor, original_tensor) dist.destroy_process_group() + @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_comm_eager_init_subgroup(self): + # Test `ncclCommSplit` for smaller subgroups of the world when + # we've passed a specific device_id to init_process_group. + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device(f"cuda:{self.rank}") + # default PG comm is not initialized yet + pg = self._create_process_group_nccl(store, self.opts()) + backend = pg._get_backend(torch.device(device)) + self.assertEqual(backend._is_initialized(), False) + # create a subgroup eagerly + new_group = c10d.new_group([0, 1], device_id=device) + tensor = torch.full((1,), self.rank).cuda(device) + dist.broadcast(tensor, 0, group=new_group) + # the default group should stay lazy + self.assertEqual(backend._is_initialized(), False) + torch.cuda.synchronize() + dist.destroy_process_group() + @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_comm_split_group(self): @@ -769,8 +885,10 @@ def test_comm_split_group(self): backend = pg._get_backend(torch.device(device)) tensor = torch.full((1,), self.rank).cuda(device) - ng1 = c10d.split_group(pg, [[0, 1]]) - backend1 = pg._get_backend(torch.device(device)) + # Create subgroup between ranks 0, 1 + subg_ranks = [0, 1] + ng1 = c10d.split_group(pg, [subg_ranks]) + backend1 = ng1._get_backend(torch.device(device)) # check basic options are the same between parent and child self.assertEqual(backend.options._timeout, backend1.options._timeout) @@ -782,10 +900,18 @@ def test_comm_split_group(self): # comm split happens eagerly since device_id is passed to init_process_group. self.assertEqual(backend.comm_split_count(), 1) - dist.broadcast(tensor, 0, group=ng1) - self.assertEqual(tensor, torch.full((1,), 0)) + # dist.get_process_group_ranks returns the global ranks in the subgroup. + self.assertEqual( + dist.get_process_group_ranks(ng1), + subg_ranks if self.rank in subg_ranks else [], + ) + + # is part of ng1; otherwise, -1 + if dist.get_rank(ng1) >= 0: + dist.broadcast(tensor, dist.get_global_rank(ng1, 0), group=ng1) + self.assertEqual(tensor, torch.full((1,), 0)) - ng2 = c10d.split_group(pg, [[0, 1]]) + ng2 = c10d.split_group(pg, [subg_ranks]) self.assertEqual(ng2.group_desc, "default_pg:split:1") self.assertEqual(backend.comm_split_count(), 2) @@ -810,7 +936,7 @@ def test_non_blocking_init(self): self.assertEqual(backend.comm_split_count(), 0) broadcast_tensor = torch.tensor([self.rank]).cuda(device) new_pg.broadcast(broadcast_tensor, 0).wait() - self.assertEqual(backend.comm_split_count(), 1) + self.assertEqual(backend.comm_split_count(), 0) dist.destroy_process_group() @requires_nccl_version((2, 18), "Need NCCL 2.18+ for ncclCommSplit") @@ -838,6 +964,24 @@ def test_non_blocking_with_eager_init(self): self.assertEqual(backend.comm_split_count(), 1) dist.destroy_process_group() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + def test_non_blocking_p2p(self): + # Test creating a pg using nonblocking mode but not eagerly + os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" + os.environ["TORCH_NCCL_NONBLOCKING_TIMEOUT"] = "100" + store = c10d.FileStore(self.file_name, self.world_size) + device = self.rank_to_GPU[self.rank][0] + self._create_process_group_nccl(store, self.opts()) + # Generate the same tensor + send_tensor = torch.ones(10, 10, device=device) + if self.rank == 0: + dist.send(send_tensor, 1) + if self.rank == 1: + recv_tensor = torch.rand(10, 10, device=device) + dist.recv(recv_tensor, 0) + self.assertEqual(send_tensor, recv_tensor) + dist.destroy_process_group() + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_get_uid(self): @@ -2638,6 +2782,50 @@ def test_nccl_non_blocking_wait_with_barrier(self): "TORCH_NCCL_ASYNC_ERROR_HANDLING" ] = prev_nccl_async_error_handling + @requires_nccl() + @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") + @skip_if_lt_x_gpu(3) + def test_get_future_result(self): + def assert_fut_success(fut): + self.assertEqual(WorkResult(fut.value()), WorkResult.SUCCESS) + + # test the barrier behavior in the non blocking wait setting + prev_nccl_async_error_handling = os.environ.get( + "TORCH_NCCL_ASYNC_ERROR_HANDLING", None + ) + # avoid watchdog thread interference + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" + store = c10d.FileStore(self.file_name, self.world_size) + process_group = c10d.ProcessGroupNCCL( + store, + self.rank, + self.world_size, + timeout=timedelta(seconds=2), + ) + barrier_work = process_group.barrier() + barrier_work.wait() + barrier_result = barrier_work.get_future_result().wait() + self.assertEqual(WorkResult(barrier_result), WorkResult.SUCCESS) + ar_work = process_group.allreduce(torch.rand(10).cuda(self.rank)) + ar_work.wait() + fut = ar_work.get_future_result() + # test adding a callback function + fut.then(assert_fut_success) + if self.rank == 0: + work = process_group.allreduce(torch.rand(10).cuda(self.rank)) + work.wait() + result = work.get_future_result().wait() + self.assertEqual(WorkResult(result), WorkResult.TIMEOUT) + else: + # other ranks not exiting before rank 0 timeout, this is to avoid + # nccl error happening before rank 0 timeouts + time.sleep(4) + + if prev_nccl_async_error_handling is not None: + os.environ[ + "TORCH_NCCL_ASYNC_ERROR_HANDLING" + ] = prev_nccl_async_error_handling + def _run_invalid_nccl_blocking_wait_env(self, val): os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val store = c10d.FileStore(self.file_name, self.world_size) @@ -4036,9 +4224,9 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertEqual( t["entries"][p2p_op_idx]["profiling_name"], profiling_name ) - self.assertEqual( - t["entries"][p2p_op_idx]["collective_seq_id"], expected_seq - ) + # we don't increment collective_seq_id for p2p ops. + self.assertEqual(t["entries"][p2p_op_idx]["collective_seq_id"], 0) + self.assertEqual(t["entries"][p2p_op_idx]["p2p_seq_id"], expected_seq) self.assertEqual(t["entries"][p2p_op_idx]["op_id"], expected_op_id) expected_op_id += 1 self.assertEqual(t["entries"][p2p_op_idx]["input_sizes"], [input_sizes]) @@ -4058,9 +4246,7 @@ def test_batched_send_recv(self, op_sizes_per_coalesce, timing_enabled): self.assertEqual( t["entries"][coalesced_op]["profiling_name"], "nccl:coalesced" ) - self.assertEqual( - t["entries"][coalesced_op]["collective_seq_id"], expected_seq - ) + self.assertEqual(t["entries"][coalesced_op]["p2p_seq_id"], expected_seq) expected_seq += 1 self.assertEqual(t["entries"][coalesced_op]["state"], "completed") self.assertEqual(t["entries"][coalesced_op]["input_sizes"], []) @@ -4117,6 +4303,8 @@ def test_individual_send_recv(self, op_sizes, timing_enabled): input_sizes = op_sizes[seq % ops_per_repeat] profiling_name = "nccl:recv 0<-1" if self.rank == 0 else "nccl:send 1->0" self.assertEqual(t["entries"][seq]["profiling_name"], profiling_name) + # we don't increment collective_seq_id for p2p ops. + self.assertEqual(t["entries"][seq]["collective_seq_id"], 0) self.assertEqual(t["entries"][seq]["p2p_seq_id"], expected_seq) expected_seq += 1 self.assertEqual(t["entries"][seq]["op_id"], expected_op_id) @@ -4173,10 +4361,11 @@ def test_coalescing_manager_collective(self, timing_enabled): self.assertEqual( len(t["entries"]), 1 - ) # one for the reduce_scatter_tensor_coalesced, one for the endCoalescing + ) # one for the reduce_scatter_tensor_coalesced self.assertEqual( t["entries"][0]["profiling_name"], "nccl:reduce_scatter_tensor_coalesced" ) + # collective_seq_id should be incremented once. self.assertEqual(t["entries"][0]["collective_seq_id"], 1) self.assertEqual(t["entries"][0]["input_sizes"], [[2, 2], [2, 2]]) self.assertEqual( diff --git a/test/distributed/test_c10d_object_collectives.py b/test/distributed/test_c10d_object_collectives.py index ece50ebe8890b..dcd6de797e725 100644 --- a/test/distributed/test_c10d_object_collectives.py +++ b/test/distributed/test_c10d_object_collectives.py @@ -24,7 +24,6 @@ sys.exit(0) BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO -WORLD_SIZE = min(4, max(2, torch.cuda.device_count())) def with_comms(func=None): @@ -54,14 +53,16 @@ def setUp(self): @property def device(self): return ( - torch.device(self.rank) + torch.device("cuda", self.rank % torch.cuda.device_count()) if BACKEND == dist.Backend.NCCL else torch.device("cpu") ) @property def world_size(self): - return WORLD_SIZE + if BACKEND == dist.Backend.NCCL: + return torch.cuda.device_count() + return super().world_size @property def process_group(self): diff --git a/test/distributed/test_c10d_ops_nccl.py b/test/distributed/test_c10d_ops_nccl.py index c9fb0f30b53f9..f0249877c63bb 100644 --- a/test/distributed/test_c10d_ops_nccl.py +++ b/test/distributed/test_c10d_ops_nccl.py @@ -28,6 +28,7 @@ init_multigpu_helper, MultiProcContinousTest, requires_nccl, + TEST_SKIPS, ) from torch.testing._internal.common_utils import ( skip_but_pass_in_sandcastle_if, @@ -278,16 +279,21 @@ def test_allreduce_in_cudagraph(self): # single warmup pg.allreduce(xs).wait() - self.assertEqual(xs[0].item(), 2) + # 1 + 1 + ... = world_size + expected_val = self.world_size + self.assertEqual(xs[0].item(), expected_val) graph = torch.cuda.CUDAGraph() with torch.cuda.graph(graph): pg.allreduce(xs).wait() - self.assertEqual(xs[0].item(), 2) + # Graph capture should not change the tensor value + self.assertEqual(xs[0].item(), expected_val) graph.replay() + expected_val *= self.world_size graph.replay() - self.assertEqual(xs[0].item(), 8) + expected_val *= self.world_size + self.assertEqual(xs[0].item(), expected_val) @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") @@ -979,8 +985,14 @@ def allgather_base(output_t, input_t): if __name__ == "__main__": + if not torch.cuda.is_available(): + sys.exit(TEST_SKIPS["no_cuda"].exit_code) + rank = int(os.getenv("RANK", -1)) - world_size = int(os.getenv("WORLD_SIZE", 2)) + world_size = int(os.getenv("WORLD_SIZE", -1)) + + if world_size == -1: # Not set by external launcher + world_size = torch.cuda.device_count() if rank != -1: # Launched with torchrun or other multi-proc launchers. Directly run the test. diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index bb0e54484b94e..a2780b55c203f 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -28,6 +28,7 @@ DynamoDistributedMultiProcTestCase, requires_nccl, ) +from torch.testing._internal.common_utils import skipIfRocm from torch.testing._internal.inductor_utils import HAS_GPU @@ -158,7 +159,6 @@ def func(a): inputs = torch.ones(4, 4, dtype=torch.float, device="cuda") + self.rank compiled = torch.compile(func) code = run_and_get_triton_code(compiled, inputs) - print(code) # Verify that the all_reduce_ has been raised above the 2nd matmul # but below the 1st matmul. Note that the all_reduce_ directly # writes to the output buffer of the 1st matmul, which is an input @@ -272,7 +272,7 @@ def func(a, *, tag, ranks, group_size): .check("extern_kernels.mm") .check("extern_kernels.mm") .check("torch.ops._c10d_functional.wait_tensor.default") - .check("triton_poi_fused_mul") + .check("triton_poi_fused_all_reduce_mul") .check("torch.ops._c10d_functional.all_reduce_.default") .check("torch.ops._c10d_functional.wait_tensor.default") .check("triton_poi_fused_add") @@ -329,7 +329,7 @@ def func(a, *, tag, ranks, group_size): .check("extern_kernels.mm") .check("extern_kernels.mm") .check("torch.ops._c10d_functional.wait_tensor.default") - .check("triton_poi_fused_mul") + .check("triton_poi_fused_all_reduce_mul") .check("torch.ops._c10d_functional.all_reduce_.default") .check("torch.ops._c10d_functional.wait_tensor.default") .check("triton_poi_fused_add") @@ -341,6 +341,7 @@ def func(a, *, tag, ranks, group_size): self.assertTrue(same(out, correct)) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") + @skipIfRocm # TODO: somehow inductor bg compile threads are causing hangs at exit with distributed work dtor @patch.object(torch._inductor.config, "compile_threads", 1) @patch.object( @@ -371,7 +372,7 @@ def func(a, *, tag, ranks, group_size): # still happens among nodes within a GroupedSchedulerNode. # 2. `mul = a * a` is not fused with `add` or `div`, because the latter two are within # GroupedSchedulerNode and thus are prevented from being fused with any outside ops. - FileCheck().check("triton_poi_fused_add_div_0.").check( + FileCheck().check("triton_poi_fused_add_all_reduce_div_0.").check( "_c10d_functional.all_reduce_." ).check("triton_poi_fused_mul_1.").run(code) out = compiled(inputs, **self.get_world_trs()) diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 7cbbe1a9e7145..64bd6ebdea75e 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -88,7 +88,27 @@ def test_assert_invalid_mesh_tensor(self): with self.assertRaises(ValueError): device_mesh = DeviceMesh(self.device_type, mesh) - @with_comms + @with_comms() + def test_2d_mesh_non_eager_init_subgroup(self): + mesh_shape = (2, self.world_size // 2) + mesh_2d = init_device_mesh(self.device_type, mesh_shape) + + self.assertEqual(mesh_2d.get_group(0).bound_device_id, None) + self.assertEqual(mesh_2d.get_group(1).bound_device_id, None) + + # TODO: need to refactor the other tests in this file to test both + # eager_init=True and eager_init=False scenarios. + @skip_if_lt_x_gpu(4) + @with_comms(eager_init=True) + def test_2d_mesh_eager_init_subgroup(self): + mesh_shape = (2, self.world_size // 2) + mesh_2d = init_device_mesh(self.device_type, mesh_shape) + + curr_device = torch.cuda.current_device() + self.assertEqual(mesh_2d.get_group(0).bound_device_id.index, curr_device) + self.assertEqual(mesh_2d.get_group(1).bound_device_id.index, curr_device) + + @with_comms() def test_get_group_and_get_all_groups(self): mesh_shape = (2, self.world_size // 2) mesh_2d = init_device_mesh( @@ -200,6 +220,15 @@ def test_from_group_with_global_pg(self): self.assertEqual( ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim ) + # Check when `mesh` is passed as well + global_mesh = DeviceMesh.from_group( + mesh_pg, self.device_type, mesh=torch.arange(self.world_size) + ) + self.assertEqual(ref_global_mesh, global_mesh) + self.assertEqual(ref_global_mesh._dim_group_infos, global_mesh._dim_group_infos) + self.assertEqual( + ref_global_mesh._coordinate_on_dim, global_mesh._coordinate_on_dim + ) @with_comms def test_from_group_with_invalid_mesh(self): diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index 078f319adaf7f..5394a515aad33 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -32,6 +32,7 @@ lambda_auto_wrap_policy, transformer_auto_wrap_policy, ) +from torch.nn.attention.flex_attention import flex_attention from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, @@ -1293,6 +1294,118 @@ def opt_fn(inputs): self.assertEqual(len(break_reasons), 4) self.assertTrue(all("DDPOptimizer" in r.reason for r in break_reasons)) + @patch.object(config, "optimize_ddp", True) + def test_compiled_flex_attention_full_model_ddp(self): + class Model(torch.nn.Module): + def __init__(self, S, H, D): + super().__init__() + + self.S = S + self.H = H + self.D = D + + alibi_bias = self.generate_alibi_bias(H) + self.register_buffer("alibi_bias", alibi_bias, persistent=True) + self.attention = flex_attention + + self.project_qk = torch.nn.Linear(H * D, H * D * 2) + self.project_v = torch.nn.Linear(H * D, H * D) + + def forward(self, hidden_states): + batch_size, _, _ = hidden_states.size() + + query, key = self.project_qk(hidden_states).chunk(2, dim=2) + query = query.view(self.S, batch_size, self.H, self.D) + query = query.permute(1, 2, 0, 3) + + key = key.view(self.S, batch_size, self.H, self.D) + key = key.permute(1, 2, 0, 3) + + value = self.project_v(hidden_states) + value = value.view(self.S, batch_size, self.H, self.D) + value = value.permute(1, 2, 0, 3) + + return self.attention(query, key, value, score_mod=self.alibi_score_mod) + + def generate_alibi_bias(self, num_heads): + alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)] + return torch.tensor(alibi_bias) + + def alibi_score_mod(self, score, b, h, q_idx, kv_idx): + bias = (q_idx - kv_idx) * self.alibi_bias[h] + return score + bias + + B = 16 + H = 12 + S = 512 + D = 64 + + device = "cuda" + model = Model(S, H, D) + model.to(device) + model = torch.compile(model) + model = DDP(model, device_ids=self.device_ids) + + hidden_states = torch.randn(B, S, H * D).to(device) + attention_scores = model(hidden_states) + torch.cuda.synchronize() + + @patch.object(config, "optimize_ddp", True) + def test_compiled_flex_attention_local_ddp(self): + class Model(torch.nn.Module): + def __init__(self, S, H, D): + super().__init__() + + self.S = S + self.H = H + self.D = D + + alibi_bias = self.generate_alibi_bias(H) + self.register_buffer("alibi_bias", alibi_bias, persistent=True) + self.attention = torch.compile(flex_attention) + + self.project_qk = torch.nn.Linear(H * D, H * D * 2) + self.project_v = torch.nn.Linear(H * D, H * D) + + def forward(self, hidden_states): + batch_size, _, _ = hidden_states.size() + + query, key = self.project_qk(hidden_states).chunk(2, dim=2) + query = query.view(self.S, batch_size, self.H, self.D) + query = query.permute(1, 2, 0, 3) + + key = key.view(self.S, batch_size, self.H, self.D) + key = key.permute(1, 2, 0, 3) + + value = self.project_v(hidden_states) + value = value.view(self.S, batch_size, self.H, self.D) + value = value.permute(1, 2, 0, 3) + + return self.attention(query, key, value, score_mod=self.alibi_score_mod) + + def generate_alibi_bias(self, num_heads): + alibi_bias = [-((i + 1) * 8.0) / num_heads for i in range(num_heads)] + return torch.tensor(alibi_bias) + + def alibi_score_mod(self, score, b, h, q_idx, kv_idx): + bias = (q_idx - kv_idx) * self.alibi_bias[h] + return score + bias + + B = 16 + H = 12 + S = 512 + D = 64 + + device = "cuda" + model = Model(S, H, D) + model.to(device) + model = torch.compile(model) + model = DDP(model, device_ids=self.device_ids) + + hidden_states = torch.randn(B, S, H * D).to(device) + attention_scores = model(hidden_states) + torch.cuda.synchronize() + @patch.object(config, "optimize_ddp", True) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_graph_split_inductor(self): @@ -1548,11 +1661,7 @@ def forward(self, x): backend = "aot_eager" cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) - with self.assertRaisesRegex( - torch._dynamo.exc.BackendCompilerFailed, - "DDPOptimizer backend: Found a higher order op in the graph", - ): - torch.compile(mod, backend=cnt)(*args) + torch.compile(mod, backend=cnt)(*args) def test_fsdp_orig_params_assert(self): # Test with basic FSDP wrapping (outer wrap around whole model) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 216643e59cdec..f59c471a0f978 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -156,7 +156,9 @@ def func(x): ) for nelem in [1024, 2048, 4096]: - x = torch.randn(nelem, device="cuda", dtype=torch.bfloat16) + # CI (Tesla T4) does not support bfloat16 compilation natively, + # using float + x = torch.randn(nelem, device="cuda", dtype=torch.float) golden_out = eager_func(x) for _ in range(3): diff --git a/test/distributed/test_nccl.py b/test/distributed/test_nccl.py index ebf03e7ae1ddd..f9bb4f6543ee5 100644 --- a/test/distributed/test_nccl.py +++ b/test/distributed/test_nccl.py @@ -45,6 +45,13 @@ ) or TEST_WITH_ROCM: datatypes.append(torch.bfloat16) +# Broadcast (and alltoall) support float8, while reduce and allreduce do not support float8 currently +broadcast_dtypes = ( + datatypes + [torch.float8_e4m3fnuz, torch.float8_e5m2fnuz] + if TEST_WITH_ROCM + else [torch.float8_e4m3fn, torch.float8_e5m2] +) + class TestNCCL(TestCase): @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows") @@ -58,7 +65,7 @@ def test_unique_id(self, device): ) @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows") @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "only one GPU detected") - @dtypes(*datatypes) + @dtypes(*broadcast_dtypes) def test_broadcast(self, device, dtype): expected = torch.zeros(128).uniform_().to(dtype=dtype) tensors = [expected.cuda()] diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 9e9577bb79165..b2976abd0875f 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -516,6 +516,11 @@ def test_store_timeout_on_missing_clients(self): use_libuv=self._use_libuv, ) + @skip_if_win32() + def test_world_size_0_raises(self): + with self.assertRaisesRegex(ValueError, "TCPStore world size cannot be 0"): + dist.TCPStore("localhost", 0, world_size=0, is_master=False) + class LibUvTCPStoreTest(TCPStoreTest): _use_libuv = True diff --git a/test/distributed/test_symmetric_memory.py b/test/distributed/test_symmetric_memory.py index c1f183c300da5..0f61e7577e950 100644 --- a/test/distributed/test_symmetric_memory.py +++ b/test/distributed/test_symmetric_memory.py @@ -1,9 +1,12 @@ # Owner(s): ["module: c10d"] +import os + import torch import torch.distributed as dist from torch._C._autograd import DeviceType from torch._C._distributed_c10d import _SymmetricMemory +from torch.distributed._functional_collectives import all_gather_tensor from torch.distributed._symmetric_memory import ( _fused_all_gather_matmul_fallback, _fused_all_gather_scaled_matmul_fallback, @@ -83,11 +86,23 @@ def _init_process(self): store=store, ) enable_symm_mem_for_group(dist.group.WORLD.group_name) + torch.manual_seed(42 + self.rank) + + def _get_test_alloc_args(self): + shape = (64, 64) + stride = (64, 1) + dtype = torch.float32 + device = self.device + group_name = "0" + return (shape, stride, dtype, device, group_name) def _verify_symmetric_memory(self, symm_mem): self.assertEqual(symm_mem.world_size, 2) - buf = symm_mem.get_buffer(0, (64, 64), torch.float32) + buf = symm_mem.get_buffer(0, (symm_mem.buffer_size // 4,), torch.float32) + self.assertEqual(buf.storage_offset(), 0) + self.assertEqual(buf.storage().size(), symm_mem.buffer_size // 4) + if symm_mem.rank == 0: symm_mem.wait_signal(src_rank=1) self.assertTrue(buf.eq(42).all()) @@ -123,14 +138,9 @@ def test_cuda_nvlink_connectivity_detection(self) -> None: def test_empty_strided_p2p(self) -> None: self._init_process() - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name) + alloc_args = self._get_test_alloc_args() - t = torch.empty(shape, dtype=dtype, device=device) + t = torch.empty((64, 64), device=self.device) self.assertIsNone(_SymmetricMemory.rendezvous(t)) t = _SymmetricMemory.empty_strided_p2p(*alloc_args) @@ -145,27 +155,21 @@ def test_empty_strided_p2p(self) -> None: def test_empty_strided_p2p_persistent(self) -> None: self._init_process() - shape = (64, 64) - stride = (64, 1) - dtype = torch.float32 - device = self.device - alloc_id = 42 # Persistent allocation - group_name = "0" - alloc_args = (shape, stride, dtype, device, group_name, alloc_id) + alloc_args = self._get_test_alloc_args() - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42) data_ptr = t.data_ptr() # Verify that persistent allocation would fail if there's an active # allocation with the same alloc_id. with self.assertRaises(RuntimeError): - _SymmetricMemory.empty_strided_p2p(*alloc_args) + _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42) # Verify that persistent allocation would succeed in lieu of activate # allocations with the same alloc_id, and the returned tensor would # have the same data pointer. del t - t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + t = _SymmetricMemory.empty_strided_p2p(*alloc_args, alloc_id=42) self.assertEqual(t.data_ptr(), data_ptr) # Verify that get_symmetric_memory would fail if called before @@ -180,6 +184,78 @@ def test_empty_strided_p2p_persistent(self) -> None: self._verify_symmetric_memory(symm_mem_0) dist.destroy_process_group() + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_barrier_timeout(self) -> None: + self._init_process() + + alloc_args = self._get_test_alloc_args() + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + symm_mem = _SymmetricMemory.rendezvous(t) + + if self.rank == 0: + with self.assertRaises(RuntimeError): + symm_mem.barrier(timeout_ms=1000) + torch.cuda.synchronize() + else: + torch.cuda.synchronize() + + # The device-side timeout triggers a __trap() that causes all + # subsequent host/device interactions to result in an "unspecified + # launch failure." Using os._exit(0) to abort the test, as it's + # impossible to terminate the process in this state. + os._exit(0) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_put_signal_timeout(self) -> None: + self._init_process() + + alloc_args = self._get_test_alloc_args() + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + symm_mem = _SymmetricMemory.rendezvous(t) + + if self.rank == 0: + with self.assertRaises(RuntimeError): + # First, put a signal into rank 1's signal pad. Since rank 1 + # doesn't wait on this signal, the subsequent put will timeout. + symm_mem.put_signal(dst_rank=1) + symm_mem.put_signal(dst_rank=1, timeout_ms=1000) + torch.cuda.synchronize() + else: + torch.cuda.synchronize() + + # The device-side timeout triggers a __trap() that causes all + # subsequent host/device interactions to result in an "unspecified + # launch failure." Using os._exit(0) to abort the test, as it's + # impossible to terminate the process in this state. + os._exit(0) + + @skipIfRocm + @skip_if_lt_x_gpu(2) + def test_wait_signal_timeout(self) -> None: + self._init_process() + + alloc_args = self._get_test_alloc_args() + + t = _SymmetricMemory.empty_strided_p2p(*alloc_args) + symm_mem = _SymmetricMemory.rendezvous(t) + + if self.rank == 0: + with self.assertRaises(RuntimeError): + symm_mem.wait_signal(src_rank=1, timeout_ms=1000) + torch.cuda.synchronize() + else: + torch.cuda.synchronize() + + # The device-side timeout triggers a __trap() that causes all + # subsequent host/device interactions to result in an "unspecified + # launch failure." Using os._exit(0) to abort the test, as it's + # impossible to terminate the process in this state. + os._exit(0) + @skipIfRocm @skip_if_lt_x_gpu(2) @parametrize("gather_dim", [0, 1]) @@ -216,7 +292,12 @@ def test_fused_all_gather_matmul(self, gather_dim: int) -> None: @skipIfRocm @skip_if_lt_x_gpu(2) @parametrize("gather_dim", [0, 1]) - def test_fused_all_gather_scaled_matmul(self, gather_dim: int) -> None: + @parametrize( + "scale_mode", ["tensor-wise", "row-wise-replicated", "row-wise-sharded"] + ) + def test_fused_all_gather_scaled_matmul( + self, gather_dim: int, scale_mode: str + ) -> None: self._init_process() BATCH = 8 @@ -227,16 +308,33 @@ def test_fused_all_gather_scaled_matmul(self, gather_dim: int) -> None: rank = self.rank world_size = self.world_size + if gather_dim == 0: + leading_dims = (BATCH // self.world_size, M) + elif gather_dim == 1: + leading_dims = (BATCH, M // self.world_size) + else: + raise AssertionError("Invalid scale_mode: {scale_mode}") + torch.manual_seed(42 + rank) - A_shard = torch.rand(BATCH, M // self.world_size, K, device="cuda").to( - torch.float8_e4m3fn - ) - A_scale = torch.tensor(0.1, device="cuda") + A_shard = torch.rand(*leading_dims, K, device="cuda").to(torch.float8_e4m3fn) Bs = [ torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T for _ in range(3) ] - B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)] - out_dtypes = [None, torch.bfloat16, torch.float32] + + if scale_mode == "tensor-wise": + A_scale = torch.tensor(0.1, device="cuda") + B_scales = [torch.tensor(0.1, device="cuda") for _ in range(3)] + out_dtypes = [None, torch.bfloat16, torch.float32] + elif scale_mode == "row-wise-sharded": + A_scale = torch.full((*leading_dims, 1), 0.1, device="cuda") + B_scales = [torch.full((1, N), 0.1, device="cuda") for _ in range(3)] + out_dtypes = [torch.bfloat16] * 3 + elif scale_mode == "row-wise-replicated": + A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda") + B_scales = [torch.full((1, N), 0.1, device="cuda") for _ in range(3)] + out_dtypes = [torch.bfloat16] * 3 + else: + raise AssertionError(f"Invalid scale_mode: {scale_mode}") ag_output_0, mm_outputs_0 = _fused_all_gather_scaled_matmul_fallback( A_shard, @@ -314,7 +412,10 @@ def test_fused_matmul_reduce_scatter(self, scatter_dim: int) -> None: @skipIfRocm @skip_if_lt_x_gpu(2) @parametrize("scatter_dim", [0, 1]) - def test_fused_scaled_matmul_reduce_scatter(self, scatter_dim: int) -> None: + @parametrize("rowwise", [True, False]) + def test_fused_scaled_matmul_reduce_scatter( + self, scatter_dim: int, rowwise: bool + ) -> None: self._init_process() BATCH = 8 @@ -327,9 +428,14 @@ def test_fused_scaled_matmul_reduce_scatter(self, scatter_dim: int) -> None: torch.manual_seed(42 + rank) A = torch.rand(BATCH, M, K, device="cuda").to(torch.float8_e4m3fn) - A_scale = torch.tensor(0.1, device="cuda") B = torch.rand(N, K, device="cuda").to(torch.float8_e4m3fn).T - B_scale = torch.tensor(0.1, device="cuda") + + if rowwise: + A_scale = torch.full((BATCH, M, 1), 0.1, device="cuda") + B_scale = torch.full((1, N), 0.1, device="cuda") + else: + A_scale = torch.tensor(0.1, device="cuda") + B_scale = torch.tensor(0.1, device="cuda") output_0 = _fused_scaled_matmul_reduce_scatter_fallback( A, @@ -435,7 +541,60 @@ def test_low_contention_reduce_scatter( dist.destroy_process_group() + @skipIfRocm @skip_if_lt_x_gpu(2) + def test_stream_write_value(self): + self._init_process() + group_name = dist.group.WORLD.group_name + + t = _SymmetricMemory.empty_strided_p2p( + size=(64,), + stride=(1,), + dtype=torch.float32, + device=self.device, + group_name=group_name, + ).fill_(self.rank + 42) + symm_mem = _SymmetricMemory.rendezvous(t) + + tensor = torch.zeros(4, dtype=torch.uint32, device=self.device) + expect = torch.tril(torch.ones(4, 4, device=self.device)).to(torch.uint32) + + for i in range(4): + symm_mem.stream_write_value32( + int(tensor.data_ptr()) + i * tensor.element_size(), 1 + ) + torch.testing.assert_close(tensor, expect[i]) + + +@instantiate_parametrized_tests +@requires_cuda_p2p_access() +class SymmMemAllReduceTest(MultiProcessTestCase): + def setUp(self) -> None: + super().setUp() + self._spawn_processes() + + @property + def world_size(self) -> int: + # world_size > 2 is needed to verify accumulation order + return 4 + + @property + def device(self) -> torch.device: + return torch.device(f"cuda:{self.rank}") + + def _init_process(self): + torch.cuda.set_device(self.device) + store = dist.FileStore(self.file_name, self.world_size) + dist.init_process_group( + backend="nccl", + world_size=self.world_size, + rank=self.rank, + store=store, + ) + enable_symm_mem_for_group(dist.group.WORLD.group_name) + torch.manual_seed(42 + self.rank) + + @skip_if_lt_x_gpu(4) @requires_multicast_support() @parametrize("dtype", [torch.float, torch.bfloat16]) @parametrize("align_bytes", [4, 8, 16]) @@ -452,7 +611,7 @@ def test_multimem_all_reduce( dtype=dtype, device=self.device, group_name=group_name, - ).fill_(1) + ).fill_(0) self.assertTrue(t.data_ptr() % 16 == 0) self.assertTrue(align_bytes % t.element_size() == 0) @@ -460,17 +619,20 @@ def test_multimem_all_reduce( shift = align_bytes // t.element_size() numel = size_bytes // t.element_size() - x = t[shift : shift + numel] + res = t[shift : shift + numel] + res.normal_() + inp = res.clone() - torch.ops.symm_mem.multimem_all_reduce_(x, "sum", group_name) - self.assertTrue(x.eq(self.world_size).all().item()) + torch.ops.symm_mem.multimem_all_reduce_(res, "sum", group_name) # Head and tail should not be written - self.assertTrue(t[:shift].eq(1).all().item()) - self.assertTrue(t[shift + numel :].eq(1).all().item()) + self.assertTrue(t[:shift].eq(0).all().item()) + self.assertTrue(t[shift + numel :].eq(0).all().item()) + self._verify_all_reduce_result(inp, res) + dist.destroy_process_group() - @skip_if_lt_x_gpu(2) + @skip_if_lt_x_gpu(4) @requires_multicast_support() @parametrize("dtype", [torch.float, torch.bfloat16]) @parametrize("align_bytes", [4, 8, 16]) @@ -481,6 +643,58 @@ def test_multimem_one_shot_all_reduce( self._init_process() group_name = dist.group.WORLD.group_name + inp = _SymmetricMemory.empty_strided_p2p( + size=(size_bytes,), + stride=(1,), + dtype=dtype, + device=self.device, + group_name=group_name, + ).normal_() + + res = torch.ops.symm_mem.multimem_one_shot_all_reduce(inp, "sum", group_name) + + gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, -1) + # Only verify that the results are close to the sum of inputs across + # ranks (see Note [multimem_one_shot_all_reduce]). + torch.testing.assert_close( + gathered_inps.sum(dim=0), res, rtol=1e-03, atol=1e-05 + ) + + dist.destroy_process_group() + + @skip_if_lt_x_gpu(4) + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("align_bytes", [4, 8, 16]) + @parametrize("size_bytes", [4, 8192, 8196]) + def test_one_shot_all_reduce( + self, dtype: torch.dtype, size_bytes: int, align_bytes: int + ) -> None: + self._init_process() + group_name = dist.group.WORLD.group_name + + inp = _SymmetricMemory.empty_strided_p2p( + size=(size_bytes,), + stride=(1,), + dtype=dtype, + device=self.device, + group_name=group_name, + ).normal_() + + res = torch.ops.symm_mem.one_shot_all_reduce(inp, "sum", group_name) + self._verify_all_reduce_result(inp, res) + + dist.destroy_process_group() + + @skip_if_lt_x_gpu(4) + @parametrize("dtype", [torch.float, torch.bfloat16]) + @parametrize("align_bytes", [4, 8, 16]) + @parametrize("size_bytes", [4, 8192, 8196]) + def test_two_shot_all_reduce( + self, dtype: torch.dtype, size_bytes: int, align_bytes: int + ) -> None: + self._init_process() + group_name = dist.group.WORLD.group_name + t = _SymmetricMemory.empty_strided_p2p( size=(16384,), stride=(1,), @@ -495,36 +709,31 @@ def test_multimem_one_shot_all_reduce( shift = align_bytes // t.element_size() numel = size_bytes // t.element_size() - x = t[shift : shift + numel] - x.fill_(1) + res = t[shift : shift + numel] + res.normal_() + inp = res.clone() - res = torch.ops.symm_mem.multimem_one_shot_all_reduce(x, "sum", group_name) - self.assertTrue(res.eq(self.world_size).all().item()) - dist.destroy_process_group() + torch.ops.symm_mem.two_shot_all_reduce_(res, "sum", group_name) - @skipIfRocm - @skip_if_lt_x_gpu(2) - def test_stream_write_value(self): - self._init_process() - group_name = dist.group.WORLD.group_name + # Head and tail should not be written + self.assertTrue(t[:shift].eq(0).all().item()) + self.assertTrue(t[shift + numel :].eq(0).all().item()) + self._verify_all_reduce_result(inp, res) - t = _SymmetricMemory.empty_strided_p2p( - size=(64,), - stride=(1,), - dtype=torch.float32, - device=self.device, - group_name=group_name, - ).fill_(self.rank + 42) - symm_mem = _SymmetricMemory.rendezvous(t) + dist.destroy_process_group() - tensor = torch.zeros(4, dtype=torch.uint32, device=self.device) - expect = torch.tril(torch.ones(4, 4, device=self.device)).to(torch.uint32) + def _verify_all_reduce_result(self, inp, res): + gathered_res = all_gather_tensor(res, 0, "0").view(self.world_size, -1) + # Verify that the results across ranks are identical + self.assertEqual( + (gathered_res == gathered_res[0, :]).all(dim=0).sum(), inp.numel() + ) - for i in range(4): - symm_mem.stream_write_value32( - int(tensor.data_ptr()) + i * tensor.element_size(), 1 - ) - torch.testing.assert_close(tensor, expect[i]) + # Verify that the result are close to the sum of inputs across ranks + gathered_inps = all_gather_tensor(inp, 0, "0").view(self.world_size, -1) + torch.testing.assert_close( + gathered_inps.sum(dim=0), res, rtol=1e-01, atol=1e-01 + ) if __name__ == "__main__": diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index e82ad38d1fbd5..afdc7bcadf600 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -86,7 +86,7 @@ def match_rng_op(node, op): def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]): - if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region(): # fwd graph return_node = list(graph.nodes)[-1] assert return_node.target == "output" for x in return_node.args[0]: diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 629e21e0daf94..e9b913ed9559d 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -439,29 +439,6 @@ def f(x): self.assertEqual(result, Foo.apply(x)) self.assertEqual(cnt.frame_count, 1) - def test_fwd_no_grad(self): - # autograd.Function.forward should be traced and called under no_grad mode. - # torch.exp with out=... arguments don't support automatic differentiation, - # so can't be traced/called under grad mode (throwing RuntimeError), - # therefore this unit test ensures fwd is under no_grad mode. - class Foo(torch.autograd.Function): - @staticmethod - def forward(ctx, inputs): - torch.exp(inputs, out=inputs) - return inputs - - @staticmethod - def backward(ctx, grad_output): - return None - - @torch.compile(backend="eager", fullgraph=True) - def f(x): - return Foo.apply(x) - - x1 = torch.randn(2, 3, requires_grad=True) - x2 = x1.clone() - self.assertEqual(f(x1), Foo.apply(x2)) - def test_amp_custom_fwd_bwd(self): torch._dynamo.utils.counters.clear() cnt = torch._dynamo.testing.CompileCounter() @@ -570,13 +547,9 @@ def forward(self, L_x_: "f32[]", L_z_: "f32[]", L_weird_b: "f32[]", L_weird_c: " class fwd_body_0(torch.nn.Module): def forward(self, ctx, x: "f32[]", z: "f32[]", l_weird_b: "f32[]", l_weird_c: "f32[]"): - _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None - mul: "f32[]" = l_weird_b * l_weird_c clone: "f32[]" = x.clone(); x = None mul_1: "f32[]" = mul * clone; mul = clone = None - - _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return (mul_1, [l_weird_b, l_weird_c]) class bwd_body_0(torch.nn.Module): @@ -1140,13 +1113,9 @@ def forward(self, L_x_: "f32[]", L_y_: "f32[]"): class fwd_body_0(torch.nn.Module): def forward(self, ctx, x: "f32[]", y: "f32[]"): - _set_grad_enabled = torch._C._set_grad_enabled(False); _set_grad_enabled = None - out1: "f32[]" = x.sin(); x = None out2: "f32[]" = y * 2; y = None - - _set_grad_enabled_1 = torch._C._set_grad_enabled(True); _set_grad_enabled_1 = None return ((out1, out2), []) class bwd_body_0(torch.nn.Module): diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py new file mode 100644 index 0000000000000..a5671dd4a20a4 --- /dev/null +++ b/test/dynamo/test_compiler_bisector.py @@ -0,0 +1,198 @@ +# Owner(s): ["module: dynamo"] + +import unittest +from contextlib import contextmanager +from importlib import import_module + +import torch +import torch._prims_common as utils +from torch._dynamo.test_case import TestCase +from torch._inductor import config +from torch._inductor.bisect_helper import BisectionManager +from torch.library import _scoped_library, Library +from torch.testing._internal.inductor_utils import HAS_CUDA + + +aten = torch.ops.aten + +requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") + +f32 = torch.float32 +i64 = torch.int64 +i32 = torch.int32 + + +@requires_cuda +class TestCompilerBisector(TestCase): + test_ns = "_test_bisector" + + def tearDown(self): + if hasattr(torch.ops, self.test_ns): + delattr(torch.ops, self.test_ns) + if hasattr(self, "lib"): + del self.lib.m + del self.lib + + def get_op(self, name): + return getattr(getattr(torch.ops, self.test_ns), name).default + + def get_lib(self): + lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 + self.lib = lib + return lib + + def test_bad_decomp(self): + mod = import_module("torch._inductor.compile_fx") + + def bad_exp_decomp(self, rate=1, generator=None): + assert generator is None + torch._check( + not utils.is_complex_dtype(self.dtype) + and not utils.is_integer_dtype(self.dtype) + and not utils.is_boolean_dtype(self.dtype), + lambda: f"Exponential distribution is a continuous probability distribution. \ + dtype must be a floating point but you specified {self.dtype}", + ) + torch._check( + rate > 0.0, + lambda: f"exponential_ expects lambda > 0.0, but found lambda={rate}", + ) + return torch.rand_like(self) * float("nan") + + @contextmanager + def patch_exp_decomp(): + from torch._inductor.compile_fx import select_decomp_table as old_decomp + + def get_decomp(): + out = old_decomp() + out = out.copy() + out[aten.exponential.default] = bad_exp_decomp + return out + + torch._inductor.compile_fx.select_decomp_table = get_decomp + try: + yield + + finally: + torch._inductor.compile_fx.select_decomp_table = old_decomp + + def vq(x): + return (x + 3).exponential_() * 10.5 + + def test_fn(): + torch._dynamo.reset() + with patch_exp_decomp(): + vq_compiled = torch.compile(vq) + x = torch.randn(4, 400, 256).cuda() + with torch._dynamo.utils.preserve_rng_state(): + out = vq(x) + out_compiled = vq_compiled(x) + + return not out_compiled.isnan().any() + + out = BisectionManager.do_bisect(test_fn) + self.assertEqual(out.backend, "aot_eager_decomp_partition") + self.assertEqual(out.subsystem, "decomposition") + self.assertEqual(out.bisect_number, 1) + self.assertTrue("aten.exponential" in out.debug_info) + + def test_crossref(self): + test_ns = "bisect_ops" + with _scoped_library(self.test_ns, "FRAGMENT") as lib: + lib.define("foo(Tensor x) -> Tensor") + op = self.get_op("foo") + + class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + # Emulate AutoDispatchBelowADInplaceOrView, which is not bound into python + with torch._C._AutoDispatchBelowAutograd(): + with torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet( + torch._C.DispatchKey.ADInplaceOrView + ) + ): + return op(x) + + @staticmethod + def backward(ctx, gx): + return gx + + def foo_impl(x): + return x.view_as(x).clone() + + def foo_meta(x): + return x.view_as(x) + + lib.impl("foo", Foo.apply, "Autograd") + lib.impl("foo", foo_impl, "CPU") + lib.impl("foo", foo_meta, "Meta") + + x = torch.tensor(3.14159 / 3, requires_grad=True) + + def test_fn(): + torch._dynamo.reset() + + try: + torch.testing.assert_allclose(torch.compile(op)(x), op(x)) + except Exception: + return False + return True + + out = BisectionManager.do_bisect(test_fn) + self.assertEqual(out.backend, "aot_eager_decomp_partition_crossref") + + def test_emulate_precision_casts(self): + def test_fn(): + torch._dynamo.reset() + + def calculate_scale(inp): + amax = torch.abs(torch.max(inp)) + scale = 448.0 / torch.clamp(amax, min=1e-12) + scale = scale.to(torch.float32) + return scale + + dtype = torch.bfloat16 + torch.manual_seed(0) + inp = torch.randn(16, 16, 768, dtype=dtype, device="cuda") + eager_scale = calculate_scale(inp) + compile_scale = torch.compile(calculate_scale)(inp) + + return torch.equal(eager_scale, compile_scale) + + out = BisectionManager.do_bisect(test_fn) + self.assertEqual(out.backend, "inductor") + self.assertEqual(out.subsystem, "inductor_emulate_precision_casts") + + def test_bad_lowering(self): + def test_fn(): + torch._dynamo.reset() + with config.patch("triton.inject_relu_bug_TESTING_ONLY", "accuracy"): + + def my_func(x): + return ((x * -1) - 0.01).relu() + + inp = torch.rand([100], device="cuda") + + return torch.allclose(torch.compile(my_func)(inp), my_func(inp)) + + out = BisectionManager.do_bisect(test_fn) + self.assertEqual(out.backend, "inductor") + self.assertEqual(out.subsystem, "lowerings") + self.assertEqual(out.bisect_number, 2) + self.assertTrue("relu" in out.debug_info) + + def test_eager_backend(self): + # should indicate problem with first backend + def test_fn(): + return False + + out = BisectionManager.do_bisect(test_fn) + self.assertEqual(out.backend, "eager") + self.assertEqual(out.subsystem, None) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/dynamo/test_decorators.py b/test/dynamo/test_decorators.py index 0472702fadca6..f656c5d0a88ab 100644 --- a/test/dynamo/test_decorators.py +++ b/test/dynamo/test_decorators.py @@ -624,6 +624,170 @@ def forward(self, x): # Must be 3 compilations. If not marked static there would be 2, because self.c would be converted to symints. self.assertEqual(cnts.frame_count, 3) + def test_set_stance_force_eager(self): + @torch.compile(backend="eager") + def a(x): + if torch._dynamo.is_compiling(): + return x + 1 + return x + 2 + + @torch.compiler.set_stance("force_eager") + def b(x): + return a(x) + + def c(x): + out0 = a(x) + with torch.compiler.set_stance("force_eager"): + out1 = a(x) + return out0, out1, a(x) + + inp = torch.ones(3) + # test that decorating b has no overall side effect + self.assertEqual(a(inp), inp + 1) + + self.assertEqual(b(inp), inp + 2) + self.assertEqual(c(inp), (inp + 1, inp + 2, inp + 1)) + + torch.compiler.set_stance("force_eager") + self.assertEqual(a(inp), inp + 2) + torch.compiler.set_stance("default") + self.assertEqual(a(inp), inp + 1) + + def test_set_stance_eager_on_recompile(self): + @torch.compile(backend="eager", dynamic=False) + def a(x, n): + if torch._dynamo.is_compiling(): + return x + n + 1 + return x + n + 2 + + inp = torch.ones(3) + out1 = a(inp, 1) + with torch.compiler.set_stance("eager_on_recompile"): + out2 = a(inp, 1) + out3 = a(inp, 2) + + self.assertEqual(out1, inp + 2) + self.assertEqual(out2, inp + 2) + self.assertEqual(out3, inp + 4) + + def test_set_stance_fail_on_recompile(self): + @torch.compile(backend="eager", dynamic=False) + def a(x, n): + if torch._dynamo.is_compiling(): + return x + n + 1 + return x + n + 2 + + inp = torch.ones(3) + out1 = a(inp, 1) + with torch.compiler.set_stance("fail_on_recompile"): + out2 = a(inp, 1) + with self.assertRaisesRegex(RuntimeError, "fail_on_recompile"): + a(inp, 2) + + self.assertEqual(out1, inp + 2) + self.assertEqual(out2, inp + 2) + + def test_set_stance_forbid_in_graph(self): + @torch.compiler.set_stance("force_eager") + def a(x): + return x + 1 + + @torch.compile(backend="eager") + def b(x): + return a(x) + + with self.assertRaisesRegex( + AssertionError, "Attempt to trace forbidden callable" + ): + b(torch.ones(3)) + + @torch.compile(backend="eager") + def c(x): + with torch.compiler.set_stance("force_eager"): + return x + 1 + + with self.assertRaisesRegex( + AssertionError, "Attempt to trace forbidden callable" + ): + c(torch.ones(3)) + + @torch.compile(backend="eager") + @torch.compiler.set_stance("force_eager") + def d(x): + return x + 1 + + with self.assertRaisesRegex( + AssertionError, "Attempt to trace forbidden callable" + ): + d(torch.ones(3)) + + @torch.compile(backend="eager") + def e(x): + with torch._dynamo.set_stance("force_eager"): + return x + 1 + + with self.assertRaisesRegex( + AssertionError, "Attempt to trace forbidden callable" + ): + e(torch.ones(3)) + + @torch.compile(backend="eager") + def f(x): + torch._dynamo.eval_frame._set_stance("force_eager") + return x + 1 + + with self.assertRaisesRegex( + AssertionError, "Attempt to trace forbidden callable" + ): + f(torch.ones(3)) + + @torch.compile(backend="eager") + def g(x): + # cause a skipped frame + try: + torch._dynamo.graph_break() + except Exception: + pass + # NOTE: torch._dynamo.is_compiling() will get traced + # and return true. torch.compiler.is_compiling() is skipped + # and will return false. + if torch.compiler.is_compiling(): + raise RuntimeError("Expect this frame to be skipped") + # should not be traced, but eval frame callback is still set + with torch.compiler.set_stance("force_eager"): + return x + 1 + + with self.assertRaisesRegex(RuntimeError, "set_stance in a torch.compile"): + g(torch.ones(3)) + + def test_set_stance_force_backend(self): + @torch.compile + def a(x): + return x + 1 + + cnts = torch._dynamo.testing.CompileCounter() + + @torch.compiler.set_stance("default", force_backend=cnts) + def b(x): + return a(x) + + b(torch.ones(3)) + + self.assertEqual(cnts.frame_count, 1) + + @torch.compiler.set_stance("default", force_backend="eager") + def c(x): + return a(x) + + # just make sure this doesn't crash + c(torch.ones(3)) + + with self.assertRaisesRegex(RuntimeError, "force_backend"): + + @torch.compiler.set_stance("force_eager", force_backend="eager") + def d(x): + pass + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 03d40e377335b..78a72208b4fcd 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -34,6 +34,11 @@ from torch.testing._internal.common_cuda import TEST_CUDA +@torch._dynamo.assume_constant_result +def dynamo_assume_constant_result_global_function(): + return "test" + + class ExportTests(torch._dynamo.test_case.TestCase): # TODO(voz): Refactor to a shared test function. # The tests in this file are a little redundant, @@ -1272,6 +1277,18 @@ def forward(self, x): result = graph(torch.tensor([[1, 0], [0.25, 0.25]])) self.assertTrue(torch._dynamo.utils.same(result, real_result)) + def test_export_with_constant_global_function(self): + class MyModule(torch.nn.Module): + def forward(self): + a = dynamo_assume_constant_result_global_function() + b = dynamo_assume_constant_result_global_function() + return a + b + + module = MyModule() + graph, _ = torch._dynamo.export(module)() + result = graph() + self.assertEqual(result, "testtest") + def test_export_with_constant_free_function_and_class_method(self): @torch._dynamo.assume_constant_result def helper_fn(x): @@ -3247,10 +3264,11 @@ def false_fn(x): def f(x): return cond(x.shape[0] > 10, true_fn, false_fn) + # Now we allow torch.cond to handle empty args example_inputs = (torch.rand(5),) with self.assertRaisesRegex( TypeError, - r"cond\(\) missing 1 required positional argument: 'operands'", + r"false_fn\(\) missing 1 required positional argument: 'x'", ): f(*example_inputs) @@ -4565,10 +4583,7 @@ def forward(self, x, b, y): return pytree.tree_unflatten([x], self._out_spec)""", # NOQA: B950 ) - with self.assertRaisesRegex( - torch._dynamo.exc.Unsupported, "boolean masking setitem backwards" - ): - gm, _ = torch._dynamo.export(fn)(x, b, y) + gm, _ = torch._dynamo.export(fn)(x, b, y) def test_dynamo_list_index(self): def fn(x, in_list): @@ -4579,6 +4594,18 @@ def fn(x, in_list): out = graph(*inputs) self.assertEqual(out, torch.ones(2, 2) + 1) + def test_dynamo_enum_in_tuple(self): + class IntEnum(int, Enum): + X = 0 + + def fn(tensor): + return tensor[..., IntEnum.X] + + tensor = torch.rand((5, 5)) + graph, _ = torch._dynamo.export(fn)(tensor) + out = graph(tensor) + self.assertEqual(out, tensor[:, 0]) + common_utils.instantiate_parametrized_tests(ExportTests) diff --git a/test/dynamo/test_frame_init.py b/test/dynamo/test_frame_init.py index 5abf6a45c7429..97aac1870e984 100644 --- a/test/dynamo/test_frame_init.py +++ b/test/dynamo/test_frame_init.py @@ -87,11 +87,13 @@ def test_frame_init(self): target_with_varkwargs.__code__: varkwargs_code2.__code__, } + empty_guard_manager = torch._dynamo.guards.GuardManagerWrapper() + def callback1(frame, cache_entry, frame_state): if frame.f_code in code_map1: transformed_code = code_map1[frame.f_code] return torch._dynamo.types.GuardedCode( - transformed_code, lambda f_locals: True, CompileId(0, 0) + transformed_code, empty_guard_manager, CompileId(0, 0) ) return None @@ -99,7 +101,7 @@ def callback2(frame, cache_entry, frame_state): if frame.f_code in code_map2: transformed_code = code_map2[frame.f_code] return torch._dynamo.types.GuardedCode( - transformed_code, lambda f_locals: True, CompileId(0, 0) + transformed_code, empty_guard_manager, CompileId(0, 0) ) return None diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 3d5d8e6928c86..694e8b5d23502 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -938,6 +938,16 @@ def test_tensor_is_complex(x): else: return x - 1 + @make_test + def test_tensor_size(x): + fn = torch.Tensor.size + return fn(x + 1) + + @make_test + def test_tensor_dim(x): + fn = torch.Tensor.dim + return fn(x + 1) + @make_test def test_tensor_is_inference(x): if x.is_inference(): diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index 1c777c1550f4c..ff3c83863a5b0 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -2573,6 +2573,22 @@ def fn(pred, pytree_in): ): torch.compile(fn, backend="eager")(pred, pytree_in) + def test_cond_with_empty_operands(self): + @torch.compile(fullgraph=True) + def fn(x, y, z): + def true_fn(): + return y + 2 + + def false_fn(): + return z + 1 + + return torch.cond(x, true_fn, false_fn) + + zeros = torch.zeros(1) + ones = torch.ones(1) + self.assertEqual(fn(zeros, ones, ones), torch.tensor([2.0])) + self.assertEqual(fn(ones, ones, ones), torch.tensor([3.0])) + def test_hints_wrapper(self): def ref_fn(x, y): x = x + y @@ -4020,7 +4036,7 @@ def forward(self, L_x_: "f32[3, 3, 3]"): set_inplace_requires_grad_allowed_1 = torch._C._functorch.set_inplace_requires_grad_allowed(False); set_inplace_requires_grad_allowed_1 = None sin: "f32[3, 3, 3]" = diff_args.sin() - add: "f32[3, 3, 3]" = sin + y; sin = None + add: "f32[3, 3, 3]" = sin + y; sin = y = None output: "f32[]" = add.sum(); add = None _autograd_grad = torch._functorch.eager_transforms._autograd_grad((output,), [diff_args], create_graph = True); diff_args = None @@ -4032,7 +4048,7 @@ def forward(self, L_x_: "f32[3, 3, 3]"): _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None - return (y, grad_input_1) + return (grad_input_1,) """, ) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 6c1e77bfff13e..9ef49da2037fd 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -162,8 +162,8 @@ def test_dynamo_error(self, records): ) test_aot = within_range_record_test(2, 6, aot=logging.INFO) - test_inductor_debug = within_range_record_test(3, 17, inductor=logging.DEBUG) - test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO) + test_inductor_debug = within_range_record_test(3, 22, inductor=logging.DEBUG) + test_inductor_info = within_range_record_test(2, 9, inductor=logging.INFO) @make_logging_test() def test_inductor_error(self, records): @@ -532,6 +532,24 @@ def fn(x, y): ~~~~~~~~^~~~~~~~~""", ) + @skipIfNotPy311 + @make_logging_test(trace_call=True) + def test_trace_call_prefix(self, records): + def fn(x, y): + return (x * 2) @ (y * 3) + + fn_opt = torch._dynamo.optimize("eager")(fn) + fn_opt(torch.randn(10, 20), torch.randn(20, 30)) + + msg0 = munge_exc(records[0].getMessage()) + self.assertExpectedInline( + msg0, + """\ +TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_prefix.fn) + return (x * 2) @ (y * 3) + ~~^~~""", + ) + @skipIfNotPy311 @make_logging_test(trace_call=True) def test_trace_call_inline_call(self, records): @@ -560,12 +578,14 @@ def f(x): return x * 2 ~~^~~""", ) - self.assertExpectedInline( - messages[2], - """\ - return g(g(x)) - ~^^^^^^""", - ) + # skip this check since 3.13 removed carets for this case + # see https://github.com/python/cpython/issues/99180 + # self.assertExpectedInline( + # messages[2], + # """\ + # return g(g(x)) + # ~^^^^^^""", + # ) self.assertExpectedInline( messages[3], """\ @@ -646,10 +666,10 @@ def f(x, y, z): self.assertExpectedInline( munge_shape_guards(record.getMessage()), """\ -+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # -+- LAMBDA_GUARD: L['z'].size()[0] == L['y'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) -+- LAMBDA_GUARD: Eq(Mod(2*L['y'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in # -+- LAMBDA_GUARD: 2 <= L['y'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 ++- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # ++- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) ++- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in # ++- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 ) @make_logging_test(guards=True) @@ -684,6 +704,18 @@ def fn(x): self.assertGreater(len(records), 0) self.assertLess(len(records), 4) + @make_logging_test(perf_hints=True) + @requires_cuda + def test_optimizer_non_static_param(self, records): + params = [torch.randn(10, 10, device="cuda") for _ in range(2)] + for param in params: + param.grad = torch.zeros_like(param) + opt = torch.optim.Adam(params) + compiled_opt_step = torch.compile(opt.step, mode="reduce-overhead") + compiled_opt_step() + self.assertGreater(len(records), 0) + self.assertLess(len(records), 3) + @skipIfTorchDynamo("too slow") @make_logging_test(**torch._logging.DEFAULT_LOGGING) def test_default_logging(self, records): diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index d59bd76f5bf0a..21be24a2678c6 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -79,6 +79,7 @@ from torch.testing._internal.common_utils import ( freeze_rng_state, IS_FBCODE, + scoped_load_inline, set_default_dtype, skipIfNNModuleInlined, skipIfWindows, @@ -92,7 +93,7 @@ if HAS_OPTREE: import optree -mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"]) +MyTuple = collections.namedtuple("MyTuple", ["a", "b", "ab"]) T = typing.TypeVar("T") @@ -321,16 +322,17 @@ def add_fn(a, b, out): res_compiled = add_fn(2, 3, torch.tensor(0.0)) self.assertEqual(res, res_compiled) + @scoped_load_inline @skipIfNNModuleInlined("fails internal CI") @unittest.skipIf(IS_FBCODE, "inline cpp_extension doesn't work in fbcode") - def test_cpp_extension_recommends_custom_ops(self): + def test_cpp_extension_recommends_custom_ops(self, load_inline): cpp_source = """ #include at::Tensor foobar(const at::Tensor& x) { return x.clone(); } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="mylib", cpp_sources=cpp_source, functions="foobar", @@ -362,7 +364,7 @@ def f(x): return x.clone(); } """ - module2 = torch.utils.cpp_extension.load_inline( + module2 = load_inline( name="mylib2", cpp_sources=cpp_source, functions="baz", @@ -1656,8 +1658,8 @@ def fn(a, b): def test_namedtuple1(self): def fn(a, b): - tmp = mytuple(a, b, a + b) - return mytuple(tmp.a, tmp[1], tmp.ab + b) + tmp = MyTuple(a, b, a + b) + return MyTuple(tmp.a, tmp[1], tmp.ab + b) v1 = torch.Tensor([10]) v2 = torch.Tensor([20]) @@ -1680,24 +1682,48 @@ def fn(packed): v3 = torch.Tensor([3]) cnts = torch._dynamo.testing.CompileCounter() opt_fn = torch._dynamo.optimize(cnts)(fn) - self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7) + self.assertEqual(opt_fn(MyTuple(v1, v2, v3))[0], 7) self.assertEqual(cnts.frame_count, 1) self.assertEqual(cnts.op_count, 3) def test_namedtuple3(self): def fn(x, packed): - if isinstance(packed, mytuple): + if isinstance(packed, MyTuple): return x + 1 else: return x - 1 x = torch.rand([2, 3]) - packed = mytuple(1, 2, 3) + packed = MyTuple(1, 2, 3) ref = fn(x, packed) opt_fn = torch._dynamo.optimize("eager")(fn) res = opt_fn(x, packed) self.assertTrue(same(ref, res)) + def test_structseq1(self): + def fn(x, y): + return torch.return_types.max((x, y)) + + x = torch.randn(3, 2) + y = torch.randn(2, 4) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + + def test_structseq2(self): + def fn(x, y): + return tuple(torch.return_types.qr((2 * x, y - 1))) + + x = torch.randn(3, 2) + y = torch.randn(2, 4) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + def test_range_input(self): def fn(a, rng): x = a @@ -3744,6 +3770,33 @@ def deep(x): self.assertTrue(torch.allclose(exp1, actual1)) self.assertTrue(torch.allclose(exp2, actual2)) + def test_closure_write_across_functions(self): + z = 1 + k = 2 + + def create_fn(): + def fn(x): + nonlocal k, z + k = z + + return fn + + def update_z_and_run_fn(fn, x): + nonlocal z + z = 3 + fn(x) + return x.cos() + + @torch.compile(backend="eager") + def foo(x): + fn = create_fn() + return update_z_and_run_fn(fn, x) + + x = torch.randn(1) + foo(x) + self.assertEqual(3, z) + self.assertEqual(3, k) + def test_top_package_import(self): def fn(x): import torch.fx @@ -7319,7 +7372,9 @@ def fn(): # NOTE this test can be removed once multiline errors are in Python. # See https://github.com/python/cpython/issues/106922 + # Covered by test_logging.py:test_trace_call* tests in 3.13+ @skipIfNotPy311 + @unittest.skipIf(sys.version_info >= (3, 13), "feature landed in 3.13") def test_get_instruction_source_311(self): def f(): # flake8: noqa @@ -7512,6 +7567,20 @@ def fn(x, y, z): opt = torch._dynamo.optimize(nopython=True)(fn) opt(*inputs) + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_symint_fold_nontrivial_product_modulo(self): + @torch.compile(fullgraph=True) + def f(x): + u0, u1 = x.tolist() + torch._check_is_size(u0) + # The condition should fold to true. + if ((u0 + 10) * (u0 + 10)) % (u0 + 10) == 0: + return torch.tensor(True) + return torch.tensor(False) + + res = f(torch.tensor([20, 21])) + self.assertEqual(torch.tensor(True), res) + # Translation validation changes the exception type, don't run with it @torch.fx.experimental._config.patch(translation_validation=False) def test_mark_dynamic_with_ranges(self): @@ -9050,6 +9119,29 @@ def deep(c): self.assertEqual(eager, compiled) self.assertEqual(counter.frame_count, 1) + def test_inline_closure_returned_by_another_function_and_captures(self): + x = torch.ones(1) + + def fn(): + def inner(): + return x + 2 + + return inner + + @torch.compile + def start(): + # Obtain the `inner` function, which holds reference to `x`. + inner = fn() + + # When we call `inner`, we end up looking up `x` from our inlining + # tracer, Dynamo must make sure it still has some modeling of `x` at + # that point. + res = inner() + return res + + res = start() + self.assertEqual(torch.ones(1) * 3, res) + def test_deque_input(self): a = torch.randn([2, 3]) b = torch.randn([2, 3]) @@ -9858,6 +9950,132 @@ def fn(x, y): self.assertEqual(actual, expected) + def test_pytree_tree_leaves(self): + implemtations = [("python", pytree)] + + for name, module in implemtations: + with self.subTest(f"pytree implement: {name}"): + + def fn(x): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + collections.deque([0.0, -x]), + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves = module.tree_leaves(tree) + return leaves + + x = torch.randn(3, 2) + expected = fn(x) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x) + + self.assertEqual(actual, expected) + + def test_pytree_tree_flatten_unflatten(self): + implemtations = [("python", pytree)] + + for name, module in implemtations: + with self.subTest(f"pytree implement: {name}"): + + def fn(x, y): + tree = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + [0.0, -x], + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + leaves, treespec = module.tree_flatten(tree) + new_leaves = [ + x - 1, + y, + x * y, + 3.0, + y - 2, + torch.zeros(2, 2), + 2 * y, + -y, + x + y, + x - y, + torch.ones(3, 2), + 1, + ] + new_tree = module.tree_unflatten(leaves, treespec) + return leaves, new_tree + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + + def test_pytree_tree_map(self): + implemtations = [("python", pytree)] + + for name, module in implemtations: + with self.subTest(f"pytree implement: {name}"): + + def fn(x, y): + tree1 = { + "a": [x, x - 1], + "b": x + 2, + "c": ( + x, + 3.0, + [0.0, -x], + ), + "d": collections.OrderedDict( + { + "e": torch.return_types.qr((2 * x, None)), + "f": MyTuple(x, x + 1, torch.zeros(4, 3)), + }, + ), + } + tree2 = collections.OrderedDict( + [ + ("c", (y, 3.0, [-y, 10.0])), + ("a", [y, y + 1]), + ("b", y + 2), + ( + "d", + { + "f": MyTuple(torch.ones(4, 3), -y, y + 1), + "e": torch.return_types.qr((2 * y, None)), + }, + ), + ], + ) + return module.tree_map(lambda u, v: (u, v), tree1, tree2) + + x = torch.randn(3, 2) + y = torch.randn(3, 2) + expected = fn(x, y) + fn_opt = torch.compile(fullgraph=True)(fn) + actual = fn_opt(x, y) + + self.assertEqual(actual, expected) + def test_shape_env_no_recording(self): main = ShapeEnv(should_record_events=False) @@ -10935,6 +11153,26 @@ def fn(x): self.assertEqual(expected.stride(), actual.stride()) self.assertEqual(expected.storage_offset(), actual.storage_offset()) + def test_dynamic_shapes_as_strided(self): + def fn(t, new_size, new_stride): + tmp = t.as_strided(new_size, new_stride) + tmp = tmp.view(-1) + return t * tmp.sum() + + optfn = torch.compile(backend="eager", dynamic=True)(fn) + + x = torch.randn(3) + new_size = [0, 3] + new_stride = [3, 1] + + expected = fn(x, new_size, new_stride) + actual = optfn(x, new_size, new_stride) + + self.assertEqual(expected.dtype, actual.dtype) + self.assertEqual(expected.shape, actual.shape) + self.assertEqual(expected.stride(), actual.stride()) + self.assertEqual(expected.storage_offset(), actual.storage_offset()) + @torch._dynamo.config.patch(guard_nn_modules=True) def test_hasattr_nn_module_guard(self): class M(torch.nn.Module): @@ -11359,6 +11597,168 @@ def get_rng(): self.assertEqual(expected, actual) self.assertGreater(po.call_count, 0) + def test_data_ptr_graph_break_builtin(self): + def f(a, b): + # builtin + not implemented for DataPtrVariable + return a.data_ptr() + b.data_ptr() + + a = torch.randn(4) + b = torch.randn(5) + + # make sure there is a graph break + with self.assertRaises(torch._dynamo.exc.Unsupported): + torch.compile(f, backend="eager", fullgraph=True)(a, b) + + torch._dynamo.reset() + + expected = f(a, b) + actual = torch.compile(f, backend="eager")(a, b) + + self.assertEqual(expected, actual) + + def test_data_ptr_graph_break_aten(self): + def f(a): + # torch.add not implemented for DataPtrVariable + return torch.add(a, a.data_ptr()) + + a = torch.randn(4) + + counters.clear() + + expected = f(a) + actual = torch.compile(f, backend="eager")(a) + + self.assertEqual(expected, actual) + self.assertTrue(len(counters["graph_break"]) > 0) + counters.clear() + + class AssertNumOutputBackend: + """ + A backend that checks the number of output for compiled graph, and + return the graph as is. + """ + + def __init__(self, test_case, expected_num_output: int): + self.test_case = test_case + self.expected_num_output = expected_num_output + + def __call__(self, gm: torch.fx.GraphModule, example_inputs): + outputs = gm(*example_inputs) + self.test_case.assertEqual(self.expected_num_output, len(outputs)) + return gm + + def test_returning_nested_func_with_captured_tensor(self): + @torch.compile(backend=self.AssertNumOutputBackend(self, 2)) + def test(): + x = torch.rand(1) + + def func(): + return x + x + + # Returning `func` forces dynamo to output `x` in the compiled + # graph, so that we can store it as `func`'s closure. The output of + # compiled graph would be `(x, x + x)`. + return func, func() + + test() + + def test_running_nested_func_with_captured_tensor(self): + @torch.compile(backend=self.AssertNumOutputBackend(self, 1)) + def test(): + x = torch.rand(1) + + def func(): + return x + x + + # `x` is no longer needed after running the compiled graph, so we + # shouldn't return it. The output of compiled graph would be `(x + + # x,)`. + return func() + + test() + + def test_returning_func_with_captured_func_and_tensor(self): + @torch.compile(backend=self.AssertNumOutputBackend(self, 2)) + def test(): + x = torch.rand(1) + + def nested(): + return x + x + + def func(): + return nested() + + # Returning `func` forces dynamo to output `x` in the compiled + # graph, so that we can store it as `func`'s closure. The output of + # compiled graph would be `(x, x + x)`. + return func, func() + + test() + + def test_running_func_with_captured_func_and_tensor(self): + @torch.compile(backend=self.AssertNumOutputBackend(self, 1)) + def test(): + x = torch.rand(1) + + def nested(): + return x + x + + def func(): + return nested() + + # `x` is no longer needed after running the compiled graph, so we + # shouldn't return it. The output of compiled graph would be `(x)`. + return func() + + test() + + def test_escaping_closure_var_with_backward_hook(self): + @torch.compile(backend=self.AssertNumOutputBackend(self, 2)) + def fn(x): + temp = x * x + captured_var = temp + 1 + + # This is where the lambda escapes the lifetime of `fn`, so + # dynamo must generate proper bytecode to update `captured_var`. + x.register_hook(lambda _: captured_var) + + # The output of compiled graph would be `(x * x, x * x + 1)`. + return temp + + ones = torch.ones(4, requires_grad=True) + fn(ones).sum().backward() + + def test_escaping_closure_var_with_nonlocal_var(self): + nonlocal_fn = None + + @torch.compile(backend=self.AssertNumOutputBackend(self, 2)) + def fn(x): + temp = x * x + captured_var = x + 1 + + def inner(): + return captured_var + + # This is where `inner` escapes the lifetime of `fn`, so dynamo must + # generate proper bytecode to update `captured_var`. + nonlocal nonlocal_fn + nonlocal_fn = inner + + # The output of compiled graph would be `(x * x, x * x + 1)`. + return temp + + ones = torch.ones(4, requires_grad=True) + fn(ones) + nonlocal_fn() + + def test_compare_tensor_with_none(self): + @torch.compile() + def f(x): + return torch.tensor(x == None) + + res = f(torch.tensor(1)) + self.assertEqual(torch.tensor(False), res) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/test/dynamo/test_modes.py b/test/dynamo/test_modes.py index 64c979168222f..b89e3cbbd466f 100644 --- a/test/dynamo/test_modes.py +++ b/test/dynamo/test_modes.py @@ -12,6 +12,7 @@ _push_on_torch_function_stack, ) from torch.overrides import _get_current_function_mode_stack, BaseTorchFunctionMode +from torch.testing._internal.triton_utils import requires_cuda from torch.utils._device import DeviceContext from torch.utils._python_dispatch import TorchDispatchMode @@ -581,6 +582,23 @@ def run_checks(setups_and_oplists, skips, ref_map): run_checks(setups_and_oplists, skips, BUILTIN_TO_TENSOR_FN_MAP) run_checks(rsetups_and_oplists, rskips, BUILTIN_TO_TENSOR_RFN_MAP) + @requires_cuda + def test_flex_attention(self): + import torch + from torch.nn.attention.flex_attention import create_block_mask, flex_attention + + torch.set_default_device("cuda") + + flex_attention = torch.compile(flex_attention, dynamic=False) + + prefix_lengths = torch.arange(8) + + def prefix_lm(b, h, q, kv): + return prefix_lengths[b] >= kv + + # This runs in fullgraph already + mask = create_block_mask(prefix_lm, 8, None, 512, 512, _compile=True) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 329b04fd7d810..5e1c1369b423e 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -1613,6 +1613,45 @@ def test_lazy_module_kwargs(self): exp_res = m(x, y) self.assertTrue(torch.allclose(exp_res, opt_m(x, y))) + # RuntimeError: SymIntArrayRef expected to contain only concrete integers + @expectedFailureDynamic + def test_lazy_module_speculation_log_divergence(self): + class ModWithOneLazyLinear(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.layer = torch.nn.LazyLinear(8) + + def forward(self, x): + return self.layer(x) + + # This allows us to restart tracing without clearing speculation log + def id_and_fail_inlining(x): + torch._dynamo.graph_break() + return x + + cnt = torch._dynamo.testing.CompileCounter() + + @torch.compile(backend=cnt) + def test(mod, x): + res = mod(x) + # Speculation log must not diverge in the 2nd round of tracing, + # after we've initialized the `LazyLinear` into a `Linear` in the + # 1st round. + res2 = id_and_fail_inlining(res) + return res + + mod = ModWithOneLazyLinear() + x = torch.ones(10, 3) + + # Make sure we don't get recompilation across multiple runs + actual_res = test(mod, x) + expect_res = mod(x) + self.assertTrue(torch.allclose(expect_res, actual_res)) + actual_res = test(mod, x) + expect_res = mod(x) + self.assertTrue(torch.allclose(expect_res, actual_res)) + self.assertEqual(cnt.frame_count, 1) + def test_call_fn_with_non_const_inputs_safe(self): class ModuleSpecialFwd(torch.nn.Module): def __init__(self) -> None: @@ -3046,6 +3085,31 @@ def forward(self, x): # Must be 3 compilations. If not marked static there would be 2, because strides would be converted to symints. self.assertEqual(cnts.frame_count, 3) + @patch.object(torch._dynamo.config, "inline_inbuilt_nn_modules", True) + def test_overridden_call(self): + class OverRiddenCallModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def __call__(self, x): + # Overrides the __call__ method of torch.nn.Module + return 5 * self.forward(x) + + def forward(self, x): + return x * 3 + + m = OverRiddenCallModule() + + def fn(x): + return m(x) + + x = torch.ones(4) + ref = fn(x) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x) + self.assertEqual(ref, res) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/test/dynamo/test_reconstruct.py b/test/dynamo/test_reconstruct.py index c4fee05889b72..f78660ae248e7 100644 --- a/test/dynamo/test_reconstruct.py +++ b/test/dynamo/test_reconstruct.py @@ -16,9 +16,9 @@ def _filter_instructions(instructions, opname): class ReconstructTest(torch._dynamo.test_case.TestCase): @contextlib.contextmanager - def register_bytecode_hook(self, check_fn): + def register_bytecode_hook(self, fn): def hook(code, out_code): - check_fn(list(dis.get_instructions(out_code))) + fn(list(dis.get_instructions(out_code))) return code torch._dynamo.reset() @@ -40,7 +40,6 @@ def hook(instructions: List[dis.Instruction]): self.assertEqual(build_map[0].argval, 1) def f(d, t): - d[1] = t d[40] = t + 1 t = torch.randn(3, 4) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 392d25528207d..6d488c74fe291 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -23,7 +23,7 @@ from copy import deepcopy from enum import Enum, IntEnum from functools import wraps -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Literal, Tuple, TypedDict from unittest import mock import numpy as np @@ -37,7 +37,7 @@ import torch.utils._pytree as pytree from torch import nn from torch._dynamo.debug_utils import same_two_models -from torch._dynamo.testing import CompileCounter, rand_strided, same +from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312 from torch._inductor.utils import fresh_inductor_cache from torch.nn import functional as F from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION @@ -69,6 +69,11 @@ import msgspec +HAS_OMEGACONG = importlib.util.find_spec("omegaconf") +if HAS_OMEGACONG: + from omegaconf import OmegaConf + + def exists(val): return val is not None @@ -1695,10 +1700,7 @@ def test_issue175(self): opt_model(inp) opt_model(inp) self.assertEqual(cnt.frame_count, 1) - - self.assertEqual( - 15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count - ) + self.assertEqual(12, cnt.op_count) def test_exec_import(self): def fn1(): @@ -4164,7 +4166,7 @@ def fn(x): def test_inductor_no_recursionerror_on_for_loops(self): def forward(x): - for _ in range(1000): + for _ in range(10000): x = 1.0 * x return x @@ -6007,6 +6009,19 @@ def outer_func(x): res = compile_outer(x) self.assertEqual(ref, res) + # https://github.com/pytorch/pytorch/issues/136640 + def test_inductor_dynamic_shapes_broadcasting(self) -> None: + def fn(x, y): + x_view = x.view(-1, 4) + y_view = y.view(-1, 4) + return x_view * y_view + + x = torch.randn(4) + y = torch.randn(8) + out_ref = fn(x, y) + out_test = torch.compile(fn, dynamic=True)(x, y) + self.assertEqual(out_ref, out_test) + # https://github.com/pytorch/pytorch/issues/119162 def test_inductor_rng_default_dtype(self) -> None: @torch.compile @@ -6039,6 +6054,147 @@ def fn(x): opt_fn = torch.compile(fn, backend="eager") self.assertEqual(fn(x), opt_fn(x)) + @unittest.skipIf(not HAS_OMEGACONG, "missing omegaconf package") + def test_omegaconf_dictconfig(self): + def fn(cfg, x): + a = cfg["foo"].a * x + b = cfg.bar["b"] * a + cfg.__dict__["baz"] = 4 + return b * cfg.baz + + config = OmegaConf.create({"foo": {"a": 3}, "bar": {"b": 5}}) + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + + ref = fn(config, x) + cloned_config = copy.deepcopy(config) + res = opt_fn(cloned_config, x) + + self.assertEqual(fn(config, x), opt_fn(config, x)) + self.assertEqual(cloned_config.baz, 4) + + # https://github.com/pytorch/pytorch/issues/136257 + def test_overwriting_params(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(2, 2) + self.fc2 = torch.nn.Linear(2, 2) + + def forward(self, x): + x = self.fc1(x) + x = self.fc2(x) + return x + + class ZeROOrderedDict(collections.OrderedDict): + def __init__(self, parent_module=None, *args, **kwargs): + """A replacement for ``collections.OrderedDict`` to detect external ZeRO params. + + Args: + parent_module (``collections.OrderedDict``): the collection to replace + """ + + super().__init__(*args, **kwargs) + self._parent_module = parent_module + + def __getitem__(self, key): + param = super().__getitem__(key) + + # Params can be registered as None (e.g., bias) + if param is None: + return param + + # do something here + return param + + def inject_parameters(module, cls): + for module in module.modules(): # noqa: B020 + if cls == ZeROOrderedDict: + new_param = cls(parent_module=module) + else: + new_param = cls() + + for key, param in module._parameters.items(): + new_param[key] = param + module._parameters = new_param + + model = M() + + inject_parameters(model, ZeROOrderedDict) + + model = torch.compile(model, backend="eager", fullgraph=True) + + x = torch.ones(2) + with torch.no_grad(): + y = model(x) + + def test_typed_dict(self): + class LlavaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + def fn(x, y): + obj = LlavaImagePixelInputs(type=int, data=y) + out = x * obj["data"] + obj["data"] = 3 + return out * obj["data"] + + x, y = torch.randn(4), torch.randn(4) + ref = fn(x, y) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x, y) + + self.assertEqual(ref, res) + + def test_typed_dict_total(self): + class LlavaImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """Shape: `(batch_size, num_channels, height, width)`""" + + def fn(x, y): + obj = LlavaImagePixelInputs(data=y, total=False) + return x * obj["data"] + + x, y = torch.randn(4), torch.randn(4) + ref = fn(x, y) + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + res = opt_fn(x, y) + + self.assertEqual(ref, res) + + @skipIfPy312 # listcomp bytecode is optimized + def test_listcomp(self): + class Module(torch.nn.Module): + def __init__(self): + super().__init__() + self._num = 4 + + @torch._dynamo.disable(recursive=False) + def forward(self, x): + values = [i * torch.cos(x) for i in range(self._num)] + return sum(values) + + mod = Module() + + def fn(x): + return mod(x) + + cnt = torch._dynamo.testing.CompileCounter() + opt_fn = torch.compile(fn, backend=cnt) + x = torch.randn(4) + + ref = fn(x) + res = opt_fn(x) + self.assertEqual(ref, res) + self.assertEqual(cnt.frame_count, 1) + # Ensure that the listcomp is fully compiled + self.assertEqual(cnt.op_count, 8) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo/test_structured_trace.py b/test/dynamo/test_structured_trace.py index c86580df50b85..4e5c04d399fbf 100644 --- a/test/dynamo/test_structured_trace.py +++ b/test/dynamo/test_structured_trace.py @@ -5,6 +5,7 @@ import json import logging import os +import re import shutil import subprocess import tempfile @@ -37,6 +38,13 @@ def example_fn(a): return output +def example_training_fn(a): + output = a.mul(torch.ones(1000, 1000, requires_grad=True)) + output = output.add(torch.ones(1000, 1000)) + output.sum().backward() + return output + + def dynamo_error_fn(a): output = a.mul(torch.ones(1000, 1000)) output = output.add(torch.ones(10, 10)) @@ -56,6 +64,10 @@ def inductor_schedule_fn(a): ARGS = (torch.ones(1000, 1000, requires_grad=True),) +def replace_dynamic(buffer, key): + return re.sub(r'("' + key + r'":\s*)(\d+\.\d+)', r"\1", buffer) + + class StructuredTraceTestingFilter(logging.Filter): def __init__(self, match_name=None): self.match_name = match_name @@ -195,7 +207,8 @@ def test_schedule(self): {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -219,7 +232,8 @@ def test_cudagraphs(self): {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -251,7 +265,8 @@ def fn(x, y): {"describe_tensor": {"id": 1, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 1, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 1, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_y_": [1000, 1000], "l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -264,7 +279,8 @@ def fn(x, y): {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1000, 1000], "add": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 1, "attempt": 0, "has_payload": "HASH"} @@ -288,7 +304,8 @@ def test_example_fn(self): {"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "ones_1": [1000, 1000], "output_1": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -300,6 +317,68 @@ def test_example_fn(self): self.assertParses() + @requires_tlparse + def test_example_training_fn(self): + fn_opt = torch._dynamo.optimize("inductor")(example_training_fn) + fn_opt(torch.ones(1000, 1000, requires_grad=True)) + buffer = self.buffer.getvalue() + buffer = replace_dynamic(buffer, "inductor_compile_time_s") + buffer = replace_dynamic(buffer, "code_gen_time_s") + buffer = replace_dynamic(buffer, "structured_logging_overhead_s") + self.assertExpectedInline( + buffer, + """\ +{"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "is_leaf": true, "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack1']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"dynamo_cpp_guards_str": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 1, "frame_compile_id": 0, "attempt": 1} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['___stack0']"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"dynamo_output_graph": {"sizes": {"l_stack0_": [1000, 1000], "ones": [1000, 1000], "output": [1000, 1000], "sum_1": []}}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"aot_joint_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"aot_forward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"aot_backward_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"dynamo_cpp_guards_str": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"compilation_metrics": "METRICS", "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"inductor_post_grad_graph": {}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"artifact": {"name": "fx_graph_cache_miss", "encoding": "json"}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} +{"bwd_compilation_metrics": {"compile_id": "2/0", "inductor_compile_time_s": , "code_gen_time_s": , "fail_type": null, "fail_reason": null, "remote_cache_time_saved_s": null, "structured_logging_overhead_s": , "is_forward": false, "remote_fx_graph_cache_get_time_ms": null, "remote_fx_graph_cache_put_time_ms": null}, "frame_id": 2, "frame_compile_id": 0, "attempt": 1} +{"dynamo_start": {"stack": "STACK"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"describe_storage": {"id": 0, "describer_id": "ID", "size": 4000000}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"describe_tensor": {"id": 0, "ndim": 2, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1000, 1000], "requires_grad": true, "stride": [1000, 1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"describe_source": {"describer_id": "ID", "id": 0, "source": "L['output']"}, "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +{"compilation_metrics": "METRICS", "frame_id": 3, "frame_compile_id": 0, "attempt": 0} +""", # noqa: B950 + ) + + self.assertParses() + @requires_tlparse def test_dynamo_error(self): try: @@ -350,6 +429,7 @@ def throw(x): {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1000, 1000], "output": [1000, 1000]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_joint_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -393,6 +473,7 @@ def forward(self, x): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} @@ -429,6 +510,7 @@ def forward(self, x): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "rank": 0, "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} @@ -461,6 +543,7 @@ def forward(self, x): {"describe_tensor": {"id": 2, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 2, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 2, "source": "L['self']._modules['layers']._modules['0']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -474,6 +557,7 @@ def forward(self, x): {"describe_tensor": {"id": 30, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cuda', index=0)", "size": [1024], "is_leaf": true, "requires_grad": true, "is_parameter": true, "stride": [1], "storage": 17, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 30, "source": "L['self']._modules['layers']._modules['1']._parameters['bias']"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"aot_joint_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_forward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"aot_backward_graph": {}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "rank": 0, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -500,6 +584,7 @@ def fn(x): self.buffer.getvalue(), """\ {"dynamo_start": {"stack": "STACK"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} +{"artifact": {"name": "dynamo_graph_break_reason", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 1, "has_payload": "HASH"} {"compilation_metrics": "METRICS", "frame_id": 0, "frame_compile_id": 0, "attempt": 1} {"dynamo_start": {"stack": "STACK"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} @@ -507,7 +592,8 @@ def fn(x): {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['x']"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_x_": [1], "add": [1]}}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 1, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -655,7 +741,8 @@ def fn(a): {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_runnable", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_post_grad_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} @@ -667,7 +754,8 @@ def fn(a): {"describe_tensor": {"id": 0, "ndim": 1, "dtype": "torch.float32", "device": "device(type='cpu')", "size": [1], "is_leaf": true, "stride": [1], "storage": 0, "view_func": "VIEW_FUNC", "describer_id": "ID"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"describe_source": {"describer_id": "ID", "id": 0, "source": "L['a']"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0} {"dynamo_output_graph": {"sizes": {"l_a_": [1], "sin": [1]}}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} -{"aot_forward_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"artifact": {"name": "aot_forward_graph_fw_metadata", "encoding": "string"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} +{"aot_inference_graph": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"inductor_output_code": {"filename": "FILENAME"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"artifact": {"name": "fx_graph_cache_hit", "encoding": "json"}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} {"dynamo_cpp_guards_str": {}, "frame_id": 0, "frame_compile_id": 0, "attempt": 0, "has_payload": "HASH"} diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 5379405bfbe58..db3706f90eeab 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -10,7 +10,7 @@ import torch._functorch.config import torch.utils._pytree as pytree import torch.utils.checkpoint -from torch._dynamo.testing import normalize_gm +from torch._dynamo.testing import CompileCounterWithBackend, normalize_gm from torch._higher_order_ops.wrap import wrap from torch.fx.experimental.symbolic_shapes import ( DimDynamic, @@ -373,9 +373,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) out_a = func(*args_a, **kwargs_a) out = pytree.tree_map( - lambda x: CtxSubclassTensor(x, biggest_constant) - if isinstance(x, torch.Tensor) - else x, + lambda x: ( + CtxSubclassTensor(x, biggest_constant) + if isinstance(x, torch.Tensor) + else x + ), out_a, ) @@ -672,7 +674,7 @@ def test_torch_function_call_on_method(self): wrapped2 = y.as_subclass(SigmoidToExpSubclass) def fn(w): - return w.sigmoid() + return w.exp() fn_opt = compile_full_eager(fn) @@ -683,6 +685,38 @@ def fn(w): self.assertEqual(res_exp, res_act) self.assertEqual(res_exp, res_exp2) + def test_torch_function_call_on_method_arg(self): + class LocalSubclass(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if func == torch._C.TensorBase.add_: + func = torch._C.TensorBase.sub_ + + if kwargs is None: + kwargs = {} + return super().__torch_function__(func, types, args, kwargs) + + def sigmoid(self): + return None + + x = torch.ones(2, 2) + y = torch.ones(2, 2) + z = torch.ones(2, 2) + wrapped = y.as_subclass(LocalSubclass) + wrapped2 = z.as_subclass(LocalSubclass) + + def fn(a, w): + a.add_(w) + return a + + fn_opt = torch.compile(fn) + + with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}): + res_exp = fn(x, wrapped) + res_act = fn_opt(y, wrapped2) + + self.assertEqual(res_exp, res_act) + def test_user_overidden_method_unsupported(self): class LocalSubclass(torch.Tensor): @classmethod @@ -823,6 +857,31 @@ def fn(w): res_act = fn_opt(wrapped) self.assertEqual(res_exp, res_act) + def test_no_torch_function_on_size_bytecode(self): + class TestTensor(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + with torch._C.DisableTorchFunctionSubclass(): + out = func(*args, **kwargs) + + if func == torch.clone: + return out * 2 + else: + return out + + def fn(x): + return torch.clone(x) + + with torch._dynamo.config.patch(traceable_tensor_subclasses={TestTensor}): + inp = torch.ones(4, 4) + x = inp.as_subclass(TestTensor) + torch._dynamo.mark_dynamic(x, 0) + compiled_fn = torch.compile(fn, fullgraph=True) + out = compiled_fn(x) + self.assertEqual(out, torch.ones(4, 4) * 2) + def test_torch_function_wrapper_class_with_kwargs(self): x = torch.ones(2, 2) wrapped = WrapperSubclass(x) @@ -1490,11 +1549,11 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) out_a = func(*args_a, **kwargs_a) out = pytree.tree_map( - lambda x: SubclassTensor( - x, SubclassTensorArgs2(x.shape, x.device, None) - ) - if isinstance(x, torch.Tensor) - else x, + lambda x: ( + SubclassTensor(x, SubclassTensorArgs2(x.shape, x.device, None)) + if isinstance(x, torch.Tensor) + else x + ), out_a, ) return return_and_correct_aliasing(func, args, kwargs, out) @@ -1929,6 +1988,36 @@ def append_guard_fail(guards): return guards_exported, guards_failed + def test_in_graph_is_nested_call(self): + def f(nt): + if nt.is_nested: + return nt + 2 + else: + return nt + 1 + + cnt = CompileCounterWithBackend("aot_eager") + compiled_f = torch.compile(f, backend=cnt, fullgraph=True) + nt, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None) + output = compiled_f(nt) + output.backward(torch.ones_like(output)) + self.assertEqual(cnt.frame_count, 1) + self.assertEqual(len(cnt.graphs), 1) + graph = cnt.graphs[0] + norm_graph = normalize_gm(graph.print_readable(print_output=False)) + + # expect -no- is_nested calls within the graph + self.assertExpectedInline( + norm_graph, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_nt_: "f64[3, s1, 5]", s1: "Sym(s1)"): + l_nt_ = L_nt_ + + add: "f64[3, s1, 5]" = l_nt_ + 2; l_nt_ = None + return (add,) +""", # noqa: B950 + ) + # Note: [What kind of guards are involved in nested tensor compilation] # # Until we implement UnionFind, dynamic shapes guards are not involved. diff --git a/test/dynamo/test_unspec.py b/test/dynamo/test_unspec.py index d58a1caf15b1b..bd154f904fdf4 100644 --- a/test/dynamo/test_unspec.py +++ b/test/dynamo/test_unspec.py @@ -10,7 +10,7 @@ import torch._dynamo.testing import torch.nn.functional as F from torch._dynamo.comptime import comptime -from torch._dynamo.testing import CompileCounter, same +from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend, same from torch.testing._internal.common_utils import skipIfWindows from torch.testing._internal.logging_utils import logs_to_string @@ -601,6 +601,20 @@ def fn(x): compl_fn = torch.compile(fn, dynamic=True, backend="eager") self.assertEqual(compl_fn(inputs), fn(inputs)) + @torch._dynamo.config.patch(specialize_float=False) + def test_unspec_roundtrip_float_input(self): + def f(x, y): + if y == 5.0: + return x + 2 + else: + return x + y + return (x, y) + + cf = torch.compile(backend="eager", fullgraph=True)(f) + x = 1.1234567891234568 + y = 1.1234567891234569 + self.assertAlmostEqual(f(x, y), cf(x, y)) + @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=True) def test_unspec_float_input(self): cnts = torch._dynamo.testing.CompileCounter() @@ -622,6 +636,24 @@ def f(x, y): self.assertEqual(f(x, math.nan), cf(x, math.nan)) self.assertExpectedInline(cnts.frame_count, """3""") # nan always recompiles + @torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True) + def test_unspecialized_float_multiply_precision(self): + dtypes = [torch.bfloat16, torch.float16, torch.float32, torch.float64] + for dtype in dtypes: + + def fn(x, y): + return x * y + + cnt = CompileCounterWithBackend("aot_eager") + fn_opt = torch._dynamo.optimize(cnt)(fn) + x = torch.tensor(9.734375, dtype=dtype, requires_grad=True) + y1 = 1.00048828125 + y2 = 1.00048828126 + + self.assertEqual(fn_opt(x, y1), fn(x, y1)) + self.assertEqual(fn_opt(x, y2), fn(x, y2)) + self.assertEqual(cnt.frame_count, 1) + @torch._dynamo.config.patch(specialize_float=False, assume_static_by_default=False) def test_unspec_float_input_f64(self): cnts = torch._dynamo.testing.CompileCounter() diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bfloat16 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bfloat16 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bool b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_bool rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex128 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex128 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex64 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_complex64 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float16 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float16 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float32 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float32 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float64 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_float64 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int16 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int16 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int32 b/test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int32 rename to test/dynamo_expected_failures/TestAutogradFunctionCPU.test_once_differentiable_autograd_vjp_cpu diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int64 b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int64 rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_input_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int8 b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_int8 rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_jvp_save_tensors_output_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_uint8 b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestAsArrayCPU.test_copy_list_cpu_uint8 rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_input_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv1d b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv1d rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_False_save_for_vjp_save_tensors_output_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv1d_pickle b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv1d_pickle rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_input_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv2d b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv2d rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_jvp_save_tensors_output_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv2d_pickle b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv2d_pickle rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_input_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv3d b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv3d rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_function_returns_input_inner_requires_grad_True_save_for_vjp_save_tensors_output_mark_dirty_False_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv3d_pickle b/test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv3d_pickle rename to test/dynamo_expected_failures/TestAutogradFunctionCUDA.test_once_differentiable_autograd_vjp_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose1d_kwargs b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICPU.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cpu similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose1d_kwargs rename to test/dynamo_expected_failures/TestAutogradFunctionVmapAPICPU.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cpu diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose1d_pickle b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICPU.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cpu similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose1d_pickle rename to test/dynamo_expected_failures/TestAutogradFunctionVmapAPICPU.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cpu diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose2d b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose2d rename to test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_has_vmap_staticmethod_and_has_generate_vmap_rule_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose2d_kwargs b/test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose2d_kwargs rename to test/dynamo_expected_failures/TestAutogradFunctionVmapAPICUDA.test_no_vmap_staticmethod_and_no_generate_vmap_rule_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose2d_pickle b/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose2d_pickle rename to test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_hessian_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose3d b/test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose3d rename to test/dynamo_expected_failures/TestComposabilityCUDA.test_autograd_function_no_setup_context_transform_jacfwd_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose3d_kwargs b/test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose3d_kwargs rename to test/dynamo_expected_failures/TestComposabilityCUDA.test_jvp_supports_saved_tensor_hooks_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose3d_pickle b/test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transpose3d_pickle rename to test/dynamo_expected_failures/TestComposabilityCUDA.test_requires_grad_inside_transform_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transposed1d b/test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda similarity index 100% rename from test/dynamo_expected_failures/TestLazyModules.test_lazy_conv_transposed1d rename to test/dynamo_expected_failures/TestHessianCUDA.test_jacfwd_different_levels_cuda diff --git a/test/dynamo_expected_failures/TestLazyModules.test_lazy_linear_pickle b/test/dynamo_expected_failures/TestLazyModules.test_lazy_linear_pickle deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestLazyModules.test_linear b/test/dynamo_expected_failures/TestLazyModules.test_linear deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestUtils.test_get_fqn_to_example_inputs_complex_args b/test/dynamo_expected_failures/TestUtils.test_get_fqn_to_example_inputs_complex_args deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestUtils.test_get_fqn_to_example_inputs_default_kwargs b/test/dynamo_expected_failures/TestUtils.test_get_fqn_to_example_inputs_default_kwargs deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/TestUtils.test_get_fqn_to_example_inputs_simple b/test/dynamo_expected_failures/TestUtils.test_get_fqn_to_example_inputs_simple deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/edge/CMakeLists.txt b/test/edge/CMakeLists.txt index 50579c9109dc8..72c01a2d36492 100644 --- a/test/edge/CMakeLists.txt +++ b/test/edge/CMakeLists.txt @@ -73,5 +73,6 @@ elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU") ) endif() if(INSTALL_TEST) + set_target_properties(test_edge_op_registration PROPERTIES INSTALL_RPATH "${CMAKE_INSTALL_RPATH}:${_rpath_portable_origin}/../lib") install(TARGETS test_edge_op_registration DESTINATION bin) endif() diff --git a/test/expect/HasDecompTest.test_aten_core_operators.expect b/test/expect/HasDecompTest.test_aten_core_operators.expect index bf7ad0a4659cc..60529dfcc6370 100644 --- a/test/expect/HasDecompTest.test_aten_core_operators.expect +++ b/test/expect/HasDecompTest.test_aten_core_operators.expect @@ -174,8 +174,6 @@ aten::cumsum aten::cumsum.out aten::cumsum_ aten::diagonal -aten::diagonal_copy -aten::diagonal_copy.out aten::diagonal_scatter aten::diagonal_scatter.out aten::digamma @@ -393,6 +391,8 @@ aten::normal.float_float_out aten::normal.out aten::normal_ aten::permute +aten::permute_copy +aten::permute_copy.out aten::polar aten::polar.out aten::pow.Scalar @@ -506,6 +506,8 @@ aten::triu_indices.out aten::trunc aten::trunc.out aten::trunc_ +aten::unbind_copy.int +aten::unbind_copy.int_out aten::unfold aten::uniform aten::uniform.out @@ -513,9 +515,11 @@ aten::uniform_ aten::unsqueeze aten::upsample_bicubic2d aten::upsample_bicubic2d.out +aten::upsample_bilinear2d aten::upsample_nearest1d.out aten::upsample_nearest2d.out aten::upsample_nearest3d.out +aten::upsample_trilinear3d aten::var.correction aten::var.correction_out aten::var_mean.correction diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect index 444462b35dc78..83d660a97584c 100644 --- a/test/expect/HasDecompTest.test_has_decomposition.expect +++ b/test/expect/HasDecompTest.test_has_decomposition.expect @@ -1010,8 +1010,6 @@ aten::ones.names_out aten::ones.out aten::ormqr aten::ormqr.out -aten::permute_copy -aten::permute_copy.out aten::poisson aten::poisson.out aten::polygamma @@ -1294,8 +1292,6 @@ aten::topk.values aten::transpose_ aten::triangular_solve aten::triangular_solve.X -aten::unbind_copy.int -aten::unbind_copy.int_out aten::unique_consecutive aten::unique_consecutive.out aten::unique_dim diff --git a/test/export/test_db.py b/test/export/test_db.py index 50be33740bd8a..30ee827d117de 100644 --- a/test/export/test_db.py +++ b/test/export/test_db.py @@ -9,7 +9,7 @@ filter_examples_by_support_level, get_rewrite_cases, ) -from torch.export import export +from torch.export import export_for_training from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_WINDOWS, @@ -35,7 +35,7 @@ def test_exportdb_supported(self, name: str, case: ExportCase) -> None: kwargs_export = case.example_kwargs args_model = copy.deepcopy(args_export) kwargs_model = copy.deepcopy(kwargs_export) - exported_program = export( + exported_program = export_for_training( model, args_export, kwargs_export, @@ -67,7 +67,7 @@ def test_exportdb_not_supported(self, name: str, case: ExportCase) -> None: with self.assertRaises( (torchdynamo.exc.Unsupported, AssertionError, RuntimeError) ): - export( + export_for_training( model, case.example_args, case.example_kwargs, @@ -92,7 +92,7 @@ def test_exportdb_not_supported_rewrite( self, name: str, rewrite_case: ExportCase ) -> None: # pyre-ignore - export( + export_for_training( rewrite_case.model, rewrite_case.example_args, rewrite_case.example_kwargs, diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py index d5ac532a5cc98..2b03284eaa6d6 100644 --- a/test/export/test_experimental.py +++ b/test/export/test_experimental.py @@ -6,9 +6,8 @@ import torch import torch._dynamo from torch._dynamo.test_case import run_tests, TestCase -from torch._export.wrappers import _mark_strict_experimental from torch._functorch.aot_autograd import aot_export_module -from torch.export import export +from torch.export import export, export_for_training from torch.export._trace import _convert_ts_to_export_experimental from torch.export.experimental import _export_forward_backward from torch.testing import FileCheck @@ -16,98 +15,6 @@ @unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported") class TestExperiment(TestCase): - def test_with_buffer_as_submodule(self): - @_mark_strict_experimental - class B(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.buffer1 = torch.nn.Buffer(torch.ones(3)) - - def forward(self, x): - y = x + 2 - y.add_(4) - # this doesnt' work today with HOO - # self.buffer1.add_(6) - buffer_updated = self.buffer1 + 6 - return x.sum() + y.sum() + buffer_updated.sum() - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.submodule = B() - - def forward(self, x): - x_v2 = x.sin() - return (self.submodule(x_v2), x + 3) - - inp = torch.randn(3) - ep = torch.export.export(M(), (inp,), strict=False) - self.assertExpectedInline( - str(ep.graph_module.code.strip()), - """\ -def forward(self, b_submodule_buffer1, x): - sin = torch.ops.aten.sin.default(x) - strict_graph_0 = self.strict_graph_0 - strict_mode = torch.ops.higher_order.strict_mode(strict_graph_0, (sin, b_submodule_buffer1)); strict_graph_0 = sin = b_submodule_buffer1 = None - getitem = strict_mode[0]; strict_mode = None - add = torch.ops.aten.add.Tensor(x, 3); x = None - return (getitem, add)""", - ) - - self.assertExpectedInline( - str(ep.graph_module.strict_graph_0.code.strip()), - """\ -def forward(self, arg0_1, arg1_1): - add = torch.ops.aten.add.Tensor(arg0_1, 2) - add_1 = torch.ops.aten.add.Tensor(add, 4); add = None - add_2 = torch.ops.aten.add.Tensor(arg1_1, 6); arg1_1 = None - sum_1 = torch.ops.aten.sum.default(arg0_1); arg0_1 = None - sum_2 = torch.ops.aten.sum.default(add_1); add_1 = None - add_3 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None - sum_3 = torch.ops.aten.sum.default(add_2); add_2 = None - add_4 = torch.ops.aten.add.Tensor(add_3, sum_3); add_3 = sum_3 = None - return (add_4,)""", - ) - - eager_mod = M() - ep = torch.export.export(eager_mod, (inp,), strict=True) - - graph_res_1, graph_res_2 = ep.module()(inp) - eager_res_1, eager_res_2 = eager_mod(inp) - - self.assertTrue(torch.allclose(graph_res_2, eager_res_2)) - self.assertTrue(torch.allclose(graph_res_1, eager_res_1)) - - graph_res_1, graph_res_2 = ep.module()(inp) - eager_res_1, eager_res_2 = eager_mod(inp) - - self.assertTrue(torch.allclose(graph_res_2, eager_res_2)) - self.assertTrue(torch.allclose(graph_res_1, eager_res_1)) - - def test_mark_strict_with_container_type(self): - @_mark_strict_experimental - class B(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - - def forward(self, x): - x0 = x[0][0] - return x0.sum() - - class M(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.submodule = B() - - def forward(self, x): - return self.submodule(x) - - inp = ((torch.randn(3),),) - with self.assertRaisesRegex( - RuntimeError, "strict_mode HOO doesn't work unless" - ): - ep = torch.export.export(M(), inp, strict=False) - def test_torchscript_module_export(self): class M(torch.nn.Module): def forward(self, x): @@ -152,7 +59,7 @@ def _check_equality_and_annotations(m_func, inps): ) # ExportedProgram from original module. - original_exported_module = torch.export.export(m_func(), inps) + original_exported_module = torch.export.export_for_training(m_func(), inps) # Check whether input annotations are the same as tracing the original module. orig_ph_name_list = [ @@ -208,7 +115,7 @@ def forward(self, x): m = Module() example_inputs = (torch.randn(3),) m(*example_inputs) - ep = torch.export._trace._export(m, example_inputs, pre_dispatch=True) + ep = torch.export.export_for_training(m, example_inputs) joint_ep = _export_forward_backward(ep) self.assertExpectedInline( str(joint_ep.graph_module.code).strip(), @@ -222,13 +129,10 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): alias = torch.ops.aten.alias.default(_softmax) alias_1 = torch.ops.aten.alias.default(alias); alias = None clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None - alias_2 = torch.ops.aten.alias.default(clone); clone = None - alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None - alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None - alias_5 = torch.ops.aten.alias.default(_log_softmax) - alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None - mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None + alias_2 = torch.ops.aten.alias.default(_log_softmax) + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None + mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None neg = torch.ops.aten.neg.default(sum_1); sum_1 = None div = torch.ops.aten.div.Scalar(neg, 1); neg = None @@ -236,18 +140,18 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None - mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None - alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None - alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None - exp = torch.ops.aten.exp.default(alias_8); alias_8 = None + mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None + exp = torch.ops.aten.exp.default(alias_5); alias_5 = None sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None - alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None - alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None - mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None + alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) - mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None + mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) @@ -271,13 +175,10 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): alias = torch.ops.aten.alias.default(_softmax) alias_1 = torch.ops.aten.alias.default(alias); alias = None clone = torch.ops.aten.clone.default(c_lifted_tensor_0); c_lifted_tensor_0 = None - alias_2 = torch.ops.aten.alias.default(clone); clone = None - alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None - alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None _log_softmax = torch.ops.aten._log_softmax.default(_softmax, 0, False); _softmax = None - alias_5 = torch.ops.aten.alias.default(_log_softmax) - alias_6 = torch.ops.aten.alias.default(alias_5); alias_5 = None - mul = torch.ops.aten.mul.Tensor(_log_softmax, alias_4); _log_softmax = None + alias_2 = torch.ops.aten.alias.default(_log_softmax) + alias_3 = torch.ops.aten.alias.default(alias_2); alias_2 = None + mul = torch.ops.aten.mul.Tensor(_log_softmax, clone); _log_softmax = None sum_1 = torch.ops.aten.sum.dim_IntList(mul, []); mul = None neg = torch.ops.aten.neg.default(sum_1); sum_1 = None div = torch.ops.aten.div.Scalar(neg, 1); neg = None @@ -285,18 +186,18 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x): div_1 = torch.ops.aten.div.Scalar(full_like, 1); full_like = None neg_1 = torch.ops.aten.neg.default(div_1); div_1 = None expand = torch.ops.aten.expand.default(neg_1, [3]); neg_1 = None - mul_1 = torch.ops.aten.mul.Tensor(expand, alias_4); expand = alias_4 = None - alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None - alias_8 = torch.ops.aten.alias.default(alias_7); alias_7 = None - exp = torch.ops.aten.exp.default(alias_8); alias_8 = None + mul_1 = torch.ops.aten.mul.Tensor(expand, clone); expand = clone = None + alias_4 = torch.ops.aten.alias.default(alias_3); alias_3 = None + alias_5 = torch.ops.aten.alias.default(alias_4); alias_4 = None + exp = torch.ops.aten.exp.default(alias_5); alias_5 = None sum_2 = torch.ops.aten.sum.dim_IntList(mul_1, [0], True) mul_2 = torch.ops.aten.mul.Tensor(exp, sum_2); exp = sum_2 = None sub = torch.ops.aten.sub.Tensor(mul_1, mul_2); mul_1 = mul_2 = None - alias_9 = torch.ops.aten.alias.default(alias_1); alias_1 = None - alias_10 = torch.ops.aten.alias.default(alias_9); alias_9 = None - mul_3 = torch.ops.aten.mul.Tensor(sub, alias_10); sub = None + alias_6 = torch.ops.aten.alias.default(alias_1); alias_1 = None + alias_7 = torch.ops.aten.alias.default(alias_6); alias_6 = None + mul_3 = torch.ops.aten.mul.Tensor(sub, alias_7); sub = None sum_3 = torch.ops.aten.sum.dim_IntList(mul_3, [0], True) - mul_4 = torch.ops.aten.mul.Tensor(alias_10, sum_3); alias_10 = sum_3 = None + mul_4 = torch.ops.aten.mul.Tensor(alias_7, sum_3); alias_7 = sum_3 = None sub_1 = torch.ops.aten.sub.Tensor(mul_3, mul_4); mul_3 = mul_4 = None view_2 = torch.ops.aten.view.default(sub_1, [1, 3]); sub_1 = None permute_1 = torch.ops.aten.permute.default(view_2, [1, 0]) @@ -323,8 +224,8 @@ def forward(self, x): m = Module() example_inputs = (torch.randn(3),) m(*example_inputs) - ep = torch.export._trace._export( - m, example_inputs, pre_dispatch=True, dynamic_shapes={"x": {0: Dim("x0")}} + ep = torch.export.export_for_training( + m, example_inputs, dynamic_shapes={"x": {0: Dim("x0")}} ) joint_ep = _export_forward_backward(ep) @@ -359,7 +260,7 @@ def forward(self, x, labels): labels = torch.ones(4, dtype=torch.int64) inputs = (x, labels) - ep = export(net, inputs) + ep = export_for_training(net, inputs) ep = _export_forward_backward(ep) diff --git a/test/export/test_export.py b/test/export/test_export.py index b13ce23c689ef..d00ceac949cf8 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -18,15 +18,12 @@ import torch.nn.functional as F from functorch.experimental.control_flow import cond, map from torch import Tensor -from torch._decomp import ( - _decomp_table_to_post_autograd_aten, - core_aten_decompositions, - get_decompositions, -) +from torch._decomp import decomposition_table, get_decompositions from torch._dynamo.test_case import TestCase from torch._dynamo.testing import normalize_gm from torch._export.pass_base import _ExportPassBaseDeprecatedDoNotUse from torch._export.utils import ( + _decomp_table_to_post_autograd_aten, get_buffer, get_param, is_buffer, @@ -36,7 +33,7 @@ from torch._higher_order_ops.hints_wrap import hints_wrapper from torch._inductor.compile_fx import split_const_gm from torch._subclasses import FakeTensorMode -from torch.export import Dim, export, unflatten +from torch.export import default_decompositions, Dim, export, unflatten from torch.export._trace import ( _export, _export_to_torch_ir, @@ -65,6 +62,7 @@ IS_WINDOWS, run_tests, skipIfCrossRef, + skipIfXpu, TEST_TRANSFORMERS, TestCase as TorchTestCase, ) @@ -167,8 +165,10 @@ class Inp: NON_STRICT_SUFFIX = "_non_strict" -RETRACEABILITY_SUFFIX = "_retraceability" +RETRACEABILITY_STRICT_SUFFIX = "_retraceability" +RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_non_strict" SERDES_SUFFIX = "_serdes" +SERDES_NON_STRICT_SUFFIX = "_serdes_non_strict" PREDISPATCH_SUFFIX = "_pre_dispatch" TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp" TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict" @@ -179,11 +179,15 @@ def is_non_strict_test(test_name): def is_retracebility_test(test_name): - return test_name.endswith(RETRACEABILITY_SUFFIX) + return test_name.endswith(RETRACEABILITY_STRICT_SUFFIX) or test_name.endswith( + RETRACEABILITY_NON_STRICT_SUFFIX + ) def is_serdes_test(test_name): - return test_name.endswith(SERDES_SUFFIX) + return test_name.endswith(SERDES_SUFFIX) or test_name.endswith( + SERDES_NON_STRICT_SUFFIX + ) def is_training_ir_test(test_name): @@ -230,8 +234,14 @@ def forward(self, x): inp = torch.zeros([3]) dim_x = torch.export.Dim("dim_x", min=6) - with self.assertRaisesRegex(torch._dynamo.exc.UserError, "not in range"): - torch.export.export( + + if is_non_strict_test(self._testMethodName): + error_type = torch.fx.experimental.symbolic_shapes.ConstraintViolationError + else: + error_type = torch._dynamo.exc.UserError + + with self.assertRaisesRegex(error_type, "not in range"): + export( InvalidInputConflictWithInputConstraints(), (inp,), dynamic_shapes={"x": {0: dim_x}}, @@ -352,6 +362,60 @@ def forward(self, x, y): inp = ([torch.ones(1, 3)], torch.ones(1, 3)) self._test_export_same_as_eager(f, inp) + @skipIfCrossRef + def test_custom_tag_metadata_re_export(self): + class Foo(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(torch.rand(4, 2)) + self.b = torch.nn.Parameter(torch.rand(4)) + + def forward(self, x): + out = torch.nn.functional.linear(x, self.w, self.b) + return out + + f = Foo() + inputs = (torch.zeros(1, 2),) + ep = export(f, inputs) + + new_gm = copy.deepcopy(ep.graph_module) + new_gm.meta["custom"] = {} + new_gm.meta["custom"]["f"] = "bar" + + for node in new_gm.graph.nodes: + if ( + node.op == "call_function" + and node.target == torch.ops.aten.linear.default + ): + node.meta["custom"] = {} + node.meta["custom"]["quantization_tag"] = "foo" + + new_ep = ep._update(new_gm, ep.graph_signature) + new_ep = export(new_ep.module(), inputs) + self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar") + + # the custom field should be preserved after re-export and + # should not be copied to other nodes + counter = 0 + for node in new_ep.graph.nodes: + if "custom" in node.meta: + counter += 1 + self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") + self.assertTrue(node.target == torch.ops.aten.linear.default) + + self.assertEqual(counter, 1) + + def test_symint_output(self): + class Foo(torch.nn.Module): + def forward(self, x): + z, y = x.size() + return z + y + x[0], z + + inputs = (torch.ones(2, 3),) + dim0_x, dim1_x = torch.export.dims("dim0_x", "dim1_x") + dynamic_shapes = {"x": (dim0_x, dim1_x)} + export(Foo(), inputs, dynamic_shapes=dynamic_shapes) + def test_no_tensor_computation(self): class Module(torch.nn.Module): def forward(self, x, y): @@ -773,6 +837,26 @@ def false_fn(x): ) torch.export.export(M(), args) + def test_cond_int_closure(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.num = 4 + + def forward(self, a, x): + def true_fn(x): + return x * self.num + + def false_fn(x): + return x + self.num + + r = torch.cond(a, true_fn, false_fn, (x,)) + return r * 2 + + args = (torch.tensor(True), torch.randn(10)) + ep = torch.export.export(M(), args) + self.assertEqual(ep.module()(*args), M()(*args)) + def test_state_tensors(self): class M(torch.nn.Module): # simple with register buffer def __init__(self) -> None: @@ -911,7 +995,6 @@ def forward(self, x): ep_model = export(model, (x,), strict=False).module() self.assertTrue(torch.allclose(model(x), ep_model(x))) - @testing.expectedFailureTrainingIRToRunDecompNonStrict # TODO(pianpwk): user_output signature def test_real_tensor_for_max_op(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -930,6 +1013,48 @@ def forward(self, x, y): self.assertEqual(ep.module()(x, x), model(x, x)) self.assertEqual(ep.module()(x, y), model(x, y)) + @testing.expectedFailureSerDer # SymBool serialization? TODO(pianpwk) + @testing.expectedFailureSerDerNonStrict + def test_real_tensor_bool_cast(self): + class Foo(torch.nn.Module): + def forward(self, x): + return bool(x.eq(0.1).any()) + + model = Foo() + inputs = (torch.randn(64),) + with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): + ep = export(model, inputs, strict=False) + + @testing.expectedFailureSerDer + @testing.expectedFailureSerDerNonStrict + def test_is_nonzero(self): + class Foo(torch.nn.Module): + def forward(self, x): + return torch.is_nonzero(x) + + def _long_tensor(nz): + return torch.full((), int(nz)) + + def _float_tensor(nz): + return torch.full((), int(nz), dtype=torch.float32) + + def _bool_tensor(nz): + return torch.full((), int(nz)).bool() + + mod = Foo() + for _tensor in [ + _long_tensor, + _float_tensor, + _bool_tensor, + # local_scalar_dense on complex NYI for fake tensors + ]: + with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True): + for nz in [True, False]: + sample_input = _tensor(nz=nz) + ep = export(mod, (sample_input,), strict=False) + self.assertEqual(ep.module()(sample_input), nz) + print(ep) + def test_export_script_module(self): class Foo(torch.nn.Module): def forward(self, rv: torch.Tensor, t: torch.Tensor): @@ -949,6 +1074,57 @@ def forward(self, rv: torch.Tensor, t: torch.Tensor): TS2EPConverter(foo_script, inp).convert() + def test_dim_auto_and_dim(self): + # test basic Dims + class Foo(torch.nn.Module): + def forward(self, x, y): + return x - y + + inputs = (torch.randn(4, 4), torch.randn(4, 4)) + shapes = { + "x": (Dim.AUTO, Dim("d1", min=3)), + "y": (Dim("d0", max=8), Dim.DYNAMIC), + } + ep = export(Foo(), inputs, dynamic_shapes=shapes) + x, y = [node for node in ep.graph.nodes if node.op == "placeholder"] + self.assertEqual((s0 := x.meta["val"].shape[0]), y.meta["val"].shape[0]) + self.assertEqual((s1 := x.meta["val"].shape[1]), y.meta["val"].shape[1]) + vr0 = ep.range_constraints[s0.node.expr] + vr1 = ep.range_constraints[s1.node.expr] + self.assertEqual([vr0.upper, vr1.lower], [8, 3]) + + # test derived Dims + class Bar(torch.nn.Module): + def forward(self, x, y, z): + return x + y[1::3] + z + + inputs = (torch.randn(4), torch.randn(13), torch.randn(4)) + dx = Dim("dx", min=2, max=10) + shapes = { + "x": (dx,), + "y": (3 * dx + 1,), + "z": (Dim.AUTO,), + } + ep = export(Bar(), inputs, dynamic_shapes=shapes) + x, y, z = [node for node in ep.graph.nodes if node.op == "placeholder"] + self.assertEqual((s0 := x.meta["val"].shape[0]), z.meta["val"].shape[0]) + expr = y.meta["val"].shape[0] + free_symbols = expr.node.expr.free_symbols + self.assertEqual(len(free_symbols), 1) + self.assertEqual(next(iter(free_symbols)), s0.node.expr) + + # test specialization still complains + inputs = (torch.randn(4), torch.randn(4)) + shapes = { + "x": (Dim.STATIC,), + "y": (Dim("dy"),), + } + with self.assertRaisesRegex( + torch._dynamo.exc.UserError, + r"Not all values of dy .* in the specified range are valid because dy was inferred to be a constant", + ): + export(Foo(), inputs, dynamic_shapes=shapes) + def test_torch_fn(self): class M1(torch.nn.Module): def __init__(self) -> None: @@ -1004,6 +1180,7 @@ def forward(self, x, weight, bias): self.assertEqual(actual_result, expected_result) @testing.expectedFailureSerDer # failed serializing SymInt nodes in subgraph (known issue) + @testing.expectedFailureSerDerNonStrict def test_hoo_inline_users_issue(self): # This came from an issue where replace_with_hop passes would inline subgraphs, # and mess up node.users for nodes present in multiple subgraphs (e.g. _x in SetGradCase @@ -1043,6 +1220,7 @@ def forward(self, x): ) check_users_for_graph(ep.graph) + @unittest.skipIf(IS_FBCODE, "Broken in fbcode") def test_export_predispatch_custom_ops_warnings(self): @torch.library.custom_op("mylib::foo", mutates_args={}) def foo(x: torch.Tensor) -> torch.Tensor: @@ -1063,6 +1241,8 @@ def forward(self, x): warnings.simplefilter("error") torch.export.export(Foo(), (x,)) + ops_registered_before = set(torch.ops.mylib) + # Assert warning for CompositeImplictAutograd op with torch.library._scoped_library("mylib", "FRAGMENT") as lib: lib.define("foo123(Tensor x) -> Tensor") @@ -1079,6 +1259,38 @@ def forward(self, x): warnings.simplefilter("always") torch.export.export(Bar(), (x,)) + ops_registered_after = set(torch.ops.mylib) + self.assertEqual(ops_registered_after, ops_registered_before) + + def test_export_preserve_linear_but_not_custom_op(self): + table = torch.export.default_decompositions() + del table[torch.ops.aten.linear.default] + + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo123(Tensor x) -> Tensor") + lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd") + + class Bar(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4) + + def forward(self, x): + lin = self.linear(x) + return torch.ops.mylib.foo123(lin) + + x = torch.randn(4, 4) + ep = export(Bar(), (x,)).run_decompositions(table) + + self.assertExpectedInline( + str(ep.graph_module.code).strip(), + """\ +def forward(self, p_linear_weight, p_linear_bias, x): + linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); x = p_linear_weight = p_linear_bias = None + sin = torch.ops.aten.sin.default(linear); linear = None + return (sin,)""", + ) + def test_export_preserve_linear_at_aot_level(self): class Foo(torch.nn.Module): def __init__(self) -> None: @@ -1090,14 +1302,9 @@ def forward(self, x): return torch.ops.aten.chunk.default(x, 3, 0) ep = torch.export.export(Foo(), (torch.randn(3, 3),)) - if IS_FBCODE: - ep = ep.run_decompositions( - {}, _preserve_ops=(torch.ops.aten.linear.default,) - ) - else: - decomp_table = _decomp_table_to_post_autograd_aten() - del decomp_table[torch.ops.aten.linear.default] - ep = ep.run_decompositions(decomp_table) + decomp_table = _decomp_table_to_post_autograd_aten() + del decomp_table[torch.ops.aten.linear.default] + ep = ep.run_decompositions(decomp_table) gm = ep.graph_module # linear is CompositeImplicitAutograd functional op so we should preserve it @@ -1553,13 +1760,7 @@ def forward(self, x, y): ep = torch.export.export( Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)) ) - if IS_FBCODE: - ep_has_linear_convd = ep.run_decompositions( - {}, - _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, - ) - else: - ep_has_linear_convd = ep.run_decompositions({}) + ep_has_linear_convd = ep.run_decompositions({}) self.assertExpectedInline( str(ep_has_linear_convd.graph_module.code).strip(), @@ -1574,19 +1775,11 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_ return (add,)""", ) - if IS_FBCODE: - ep_has_convd = ep.run_decompositions( - _preserve_ops=( - torch.ops.aten.conv2d.default, - torch.ops.aten.conv1d.default, - ) - ) - else: - decomp_table = core_aten_decompositions() - del decomp_table[torch.ops.aten.conv2d.default] - del decomp_table[torch.ops.aten.conv1d.default] + decomp_table = default_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + del decomp_table[torch.ops.aten.conv1d.default] - ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) + ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ @@ -1602,15 +1795,10 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, c_ add = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None return (add,)""", ) - if IS_FBCODE: - ep_has_convd = ep_has_convd.run_decompositions( - _preserve_ops=(torch.ops.aten.conv2d.default,) - ) - else: - decomp_table = core_aten_decompositions() - del decomp_table[torch.ops.aten.conv2d.default] + decomp_table = default_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] - ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) + ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), """\ @@ -1654,15 +1842,9 @@ def forward(self, x, y): Foo(), (torch.randn(20, 16, 50, 100), torch.randn(20, 16, 50)) ) - if IS_FBCODE: - ep_has_linear_convd = ep.run_decompositions( - {}, - _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY, - ) - else: - ep_has_linear_convd = ep.run_decompositions( - decomp_table={}, - ) + ep_has_linear_convd = ep.run_decompositions( + decomp_table={}, + ) self.assertExpectedInline( str(ep_has_linear_convd.graph_module.code).strip(), @@ -1677,19 +1859,11 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_ return (add,)""", ) - if IS_FBCODE: - ep_has_convd = ep.run_decompositions( - _preserve_ops=( - torch.ops.aten.conv2d.default, - torch.ops.aten.conv1d.default, - ) - ) - else: - decomp_table = core_aten_decompositions() - del decomp_table[torch.ops.aten.conv2d.default] - del decomp_table[torch.ops.aten.conv1d.default] + decomp_table = default_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + del decomp_table[torch.ops.aten.conv1d.default] - ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) + ep_has_convd = ep.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), @@ -1707,14 +1881,9 @@ def forward(self, p_conv_weight, p_conv_bias, p_conv1d_weight, p_conv1d_bias, b_ return (add,)""", ) - if IS_FBCODE: - ep_has_convd = ep_has_convd.run_decompositions( - _preserve_ops=(torch.ops.aten.conv2d.default,) - ) - else: - decomp_table = core_aten_decompositions() - del decomp_table[torch.ops.aten.conv2d.default] - ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) + decomp_table = default_decompositions() + del decomp_table[torch.ops.aten.conv2d.default] + ep_has_convd = ep_has_convd.run_decompositions(decomp_table=decomp_table) self.assertExpectedInline( str(ep_has_convd.graph_module.code).strip(), @@ -1745,20 +1914,80 @@ def forward(self, x): ): ep.run_decompositions({torch.ops.aten.index_put_.default: None}) + def test_export_custom_decomp_table_basic_pop(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo123(Tensor x) -> Tensor") + lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd") + + lib.define("foo456(Tensor x) -> Tensor") + lib.impl("foo456", lambda x: x.sin(), "CompositeImplicitAutograd") + + table = default_decompositions() + # Since this table hasn't been materialized yet, we shouldn't error + val = table.pop(torch.ops.mylib.foo123.default) + self.assertIsNotNone(val) + + with self.assertRaisesRegex(KeyError, "mylib.foo123.default"): + table.pop(torch.ops.mylib.foo123.default) + + val = table.pop(torch.ops.mylib.foo123.default, "HELLO") + self.assertEqual(val, "HELLO") + + all_ops = set(k for k, v in table.items()) + self.assertTrue(table.has_materialized) + # When we force materialize, torch.ops.mylib.foo123.default should have gone + self.assertFalse(torch.ops.mylib.foo123.default in all_ops) + self.assertTrue(torch.ops.mylib.foo456.default in all_ops) + + def test_export_custom_decomp_table_container_methods(self): + # tests __len__ + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + table = default_decompositions() + length_before = len(table) + lib.define("foo123(Tensor x) -> Tensor") + lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd") + + lib.define("foo456(Tensor x) -> Tensor") + lib.impl("foo456", lambda x: x.sin(), "CompositeImplicitAutograd") + + table = default_decompositions() + self.assertEqual(len(table) - length_before, 2) + + # tests __contains__ + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo123(Tensor x) -> Tensor") + lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd") + + table = default_decompositions() + self.assertTrue(torch.ops.mylib.foo123.default in table) + del table[torch.ops.mylib.foo123.default] + self.assertFalse(torch.ops.mylib.foo123.default in table) + + # Lot of ppl do + # for op in all_ops: + # if op in table: + # del table[op] + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + lib.define("foo123(Tensor x) -> Tensor") + lib.impl("foo123", lambda x: x.sin(), "CompositeImplicitAutograd") + + table = default_decompositions() + if torch.ops.mylib.foo123.default in table: + del table[torch.ops.mylib.foo123.default] + + self.assertFalse(torch.ops.mylib.foo123.default in table) + table.materialize() + self.assertFalse(torch.ops.mylib.foo123.default in table) + def test_if_post_autograd_op_preserved(self): class Foo(torch.nn.Module): def forward(self, x): return x.sin() + x.sum() ep = export(Foo(), (torch.ones(3, 3),)) - if IS_FBCODE: - ep_preserve_sum = ep.run_decompositions( - _preserve_ops=(torch.ops.aten.sum.default,) - ) - else: - decomp_table = core_aten_decompositions() - del decomp_table[torch.ops.aten.sum.default] - ep_preserve_sum = ep.run_decompositions(decomp_table) + decomp_table = default_decompositions() + del decomp_table[torch.ops.aten.sum.default] + ep_preserve_sum = ep.run_decompositions(decomp_table) # Even though we are decomposing to core aten which should make # sum into sum.dim_IntList, we explicitly marked it to not do that. @@ -2172,6 +2401,7 @@ def forward(self, x, y, z): if node.op == "placeholder": self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)") + @testing.expectedFailureRetraceabilityNonStrict def test_dynamic_shapes_builder_kwargs(self): class M(torch.nn.Module): def forward(self, x, y, z): @@ -2332,18 +2562,6 @@ def forward(self, x): ): export(M(), inputs, dynamic_shapes=dynamic_shapes) - dynamic_shapes = { - "x": {"k": {"k": [(dim,), (AUTO,)]}} - } # mixing AUTO and Dims is not well supported. - with self.assertRaisesRegex( - torch._dynamo.exc.UserError, - re.escape( - "Specifying both `Dim.AUTO/Dim.DYNAMIC` and `Dim/DerivedDim` in `dynamic_shapes` is not well supported at the moment, " - "and can easily lead to constraint violation errors or obscure errors in torch.export." - ), - ): - export(M(), inputs, dynamic_shapes=dynamic_shapes) - class N(torch.nn.Module): def forward(self, x): return x["k"]["k1"][0] + x["k"]["k2"][0] @@ -2354,6 +2572,47 @@ def forward(self, x): dynamic_shapes = ({"k": {"k2": [(dim,)], "k1": [(dim,)]}},) # ok export(N(), inputs, dynamic_shapes=dynamic_shapes) + @testing.expectedFailureSerDer # no unbacked bindings after deserialization? + @testing.expectedFailureSerDerNonStrict + def test_unbacked_bindings_for_divisible_u_symint(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor a, Tensor b) -> (Tensor)", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + class M(torch.nn.Module): + def forward(self, a, b): + return torch.ops.mylib.foo(a, b) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + def foo_impl(a, b): + return a[b.item()] + + @torch.library.register_fake("mylib::foo", lib=lib) + def foo_fake_impl(a, b): + ctx = torch.library.get_ctx() + u = ctx.new_dynamic_size(min=0, max=len(a) // 10) * 10 + return torch.empty(u, a.shape[1], dtype=a.dtype) + + ep = export( + M(), + (torch.randn(100, 4), torch.tensor(10)), + ) + foo = [node for node in ep.graph.nodes if node.name == "foo"][0] + unbacked_bindings = foo.meta["unbacked_bindings"] + self.assertEqual(len(unbacked_bindings), 1) # check binding is {u: path} + u = next(iter(unbacked_bindings.keys())) + self.assertEqual( + type(u).__name__, "Symbol" + ) # check binding is symbol, not expr + path = unbacked_bindings[u] + self.assertEqual(len(path), 3) # check path is [size, 0, DivideByKey(10)] + self.assertEqual(type(path[2]).__name__, "DivideByKey") + self.assertEqual(path[2].divisor, 10) + def test_torch_check_eq_commutativity(self): class M1(torch.nn.Module): def forward(self, x1, x2, x3, y): @@ -2515,6 +2774,7 @@ def forward(self, t): export(N(), (t,), strict=strict) @testing.expectedFailureSerDer # T195866111 + @testing.expectedFailureSerDerNonStrict def test_suggested_fixes_for_data_dependent_errors_puzzlers(self): # suggested fixes for data-dependent errors only work in non-strict mode strict = False @@ -2775,6 +3035,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ): em.module()(x) + @testing.expectedFailureRetraceabilityNonStrict def test_dont_duck_size_for_auto_dynamic(self): AUTO, STATIC = Dim.AUTO, Dim.STATIC @@ -2795,6 +3056,7 @@ def forward(self, x, y): ep.module()(torch.randn(6, 3), torch.randn(7, 4)) @testing.expectedFailureRetraceability # T183144629 + @testing.expectedFailureSerDerNonStrict def test_map(self): class Module(torch.nn.Module): def forward(self, xs, y, z): @@ -2837,6 +3099,7 @@ def forward(self, image, crop_height, crop_width): args = (torch.rand(3, 700, 700), 150, 150) self.assertEqual(ecrop.module()(*args), ecrop(*args)) + @testing.expectedFailureRetraceabilityNonStrict def test_export_func_with_kwargs(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, kw1, kw2): @@ -2847,6 +3110,7 @@ def forward(self, arg1, arg2, kw1, kw2): kwargs = {"kw1": torch.ones(1, 1), "kw2": torch.ones(6, 4)} self._test_export_same_as_eager(kw_func, args, kwargs) + @testing.expectedFailureRetraceabilityNonStrict def test_export_func_with_pytree_kwargs(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, a, b): @@ -2860,6 +3124,7 @@ def forward(self, arg1, arg2, a, b): } self._test_export_same_as_eager(kw_func, args, kwargs) + @testing.expectedFailureRetraceabilityNonStrict def test_export_func_with_default_kwargs(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, a, b=1): @@ -2890,6 +3155,7 @@ def forward(self, arg1, arg2, *args): args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4)) self._test_export_same_as_eager(kw_func, args) + @testing.expectedFailureRetraceabilityNonStrict def test_export_func_with_keyword_only_args(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, *args, kw1, kw2): @@ -2900,6 +3166,7 @@ def forward(self, arg1, arg2, *args, kw1, kw2): kwargs = {"kw1": torch.ones(2, 3), "kw2": torch.ones(3, 4)} self._test_export_same_as_eager(kw_func, args, kwargs) + @testing.expectedFailureRetraceabilityNonStrict def test_export_func_with_var_keyword_args(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): @@ -2994,6 +3261,7 @@ def forward(self, x, y): self.assertTrue(torch.allclose(orig_res[1], ep_res[1])) self.assertTrue(torch.allclose(orig_res[2], ep_res[2])) + @testing.expectedFailureRetraceabilityNonStrict def test_export_func_with_var_keyword_pytree_args(self): class Module(torch.nn.Module): def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): @@ -3018,8 +3286,10 @@ def forward(self, arg1, arg2, *args, kw1, kw2, **kwargs): self._test_export_same_as_eager(kw_func, args, kwargs) @testing.expectedFailureSerDer # we don't save placeholder metadata + @testing.expectedFailureSerDerNonStrict @testing.expectedFailureNonStrict @testing.expectedFailureTrainingIRToRunDecompNonStrict # source_fn_stack failure + @testing.expectedFailureRetraceabilityNonStrict def test_linear_conv(self): class MyLinear(torch.nn.Module): def __init__(self) -> None: @@ -3935,27 +4205,57 @@ def forward(self, x): "torch.ops.aten._assert_async.msg", 1, exactly=True ).run(ep.graph_module.code) + @testing.expectedFailureRetraceabilityNonStrict def test_decomp_item_in_prim_after_decomposition(self): class M(torch.nn.Module): def forward(self, x): torch.ops.aten._assert_async.msg(torch.tensor(True), "Fail") return x - from torch._decomp import decomposition_table + decomp_table = {**_decomp_table_to_post_autograd_aten(), **decomposition_table} - ep = export(M(), (torch.randn(2, 2),)).run_decompositions(decomposition_table) + ep = export(M(), (torch.randn(2, 2),)).run_decompositions(decomp_table) # The difference seems fine because export_for_training catches const tensor little differently. + # Training IR produces: + # graph(): + # %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] + # %x : [num_users=1] = placeholder[target=x] + # %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) + # %detach_ : [num_users=1] = call_function[target=torch.ops.aten.detach_.default](args = (%lift_fresh_copy,), kwargs = {}) + # %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach_, Fail), kwargs = {}) + # return (x,) + # + # Pre-dispatch functionalization produces: + # graph(): + # %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] + # %x : [num_users=1] = placeholder[target=x] + # %lift_fresh_copy : [num_users=1] = call_function[target=torch.ops.aten.lift_fresh_copy.default](args = (%c_lifted_tensor_0,), kwargs = {}) + # %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%lift_fresh_copy,), kwargs = {}) + # %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach, Fail), kwargs = {}) + # return (x,) + # + # Retracing: + # graph(): + # %c_lifted_tensor_0 : [num_users=1] = placeholder[target=c_lifted_tensor_0] + # %x : [num_users=1] = placeholder[target=x] + # %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%c_lifted_tensor_0,), kwargs = {}) + # %detach : [num_users=1] = call_function[target=torch.ops.aten.detach.default](args = (%clone,), kwargs = {}) + # %_assert_async : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%detach, Fail), kwargs = {}) + # return (x,) + # The difference comes from the fact that prim has registration for aten.detach while not for aten.detach_. + # The diference in retracing comes from the fact that we retrace at pre-dispatch level while the usual flow + # traces to post-dispatch. if is_training_ir_test(self._testMethodName): self.assertExpectedInline( str(ep.graph_module.code).strip(), """\ def forward(self, c_lifted_tensor_0, x): - clone = torch.ops.prims.clone.default(c_lifted_tensor_0, memory_format = torch.preserve_format); c_lifted_tensor_0 = None - _assert_async = torch.ops.aten._assert_async.msg(clone, 'Fail'); clone = _assert_async = None + lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None + _assert_async = torch.ops.aten._assert_async.msg(lift_fresh_copy, 'Fail'); lift_fresh_copy = _assert_async = None return (x,)""", ) - else: + elif is_retracebility_test(self._testMethodName): self.assertExpectedInline( str(ep.graph_module.code).strip(), """\ @@ -3965,6 +4265,18 @@ def forward(self, c_lifted_tensor_0, x): view_of_1 = torch.ops.prims.view_of.default(view_of); view_of = None view_of_2 = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None _assert_async = torch.ops.aten._assert_async.msg(view_of_2, 'Fail'); view_of_2 = _assert_async = None + return (x,)""", + ) + else: + self.assertExpectedInline( + str(ep.graph_module.code).strip(), + """\ +def forward(self, c_lifted_tensor_0, x): + lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(c_lifted_tensor_0); c_lifted_tensor_0 = None + view_of = torch.ops.prims.view_of.default(lift_fresh_copy); lift_fresh_copy = None + view_of_1 = torch.ops.prims.view_of.default(view_of); view_of = None + view_of_2 = torch.ops.prims.view_of.default(view_of_1); view_of_1 = None + _assert_async = torch.ops.aten._assert_async.msg(view_of_2, 'Fail'); view_of_2 = _assert_async = None return (x,)""", ) @@ -4284,6 +4596,34 @@ def forward(self, a, b, alpha: int): if node.op == "placeholder": self.assertTrue(isinstance(node.meta["val"], (Tensor, int))) + def test_tensor_constant_with_wrapped_method(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.constant = torch.ones(4, 4) + + def forward(self, x): + return x + self.constant, self.constant + + class Wrapper(torch.nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, *arg, **kwargs): + return self.fn(*arg, **kwargs) + + inp = (torch.zeros(4, 4),) + + def test(m): + m_result = m(*inp) + ep_result = export(m, inp).module()(*inp) + for m_t, ep_t in zip(m_result, ep_result): + self.assertTrue(torch.allclose(m_t, ep_t)) + + test(M()) + test(Wrapper(M().forward)) + def test_export_with_inline_constraints(self): class Module(torch.nn.Module): def forward(self, x): @@ -4897,7 +5237,7 @@ def forward(self, x): inp = (torch.randn(5, 10),) m = M() - decomp_table = torch.export.core_aten_decompositions() + decomp_table = torch.export.default_decompositions() def _custom_decomp_for_linear(x, weight, bias): return x + bias.sum() @@ -4951,7 +5291,7 @@ def forward(self, x): def custom_decomp_callable(x, weight, bias): return x + bias - decomp_table = core_aten_decompositions() + decomp_table = default_decompositions() decomp_table[torch.ops.aten.linear.default] = custom_decomp_callable core_aten_ep = ep.run_decompositions(decomp_table) self.assertExpectedInline( @@ -5121,12 +5461,12 @@ def forward(self, x, y): return {"prediction": (x + y, self.bff)} mod = ModuleConstant() - ep = torch.export.export(mod, ()) + ep = export(mod, ()) self.assertEqual(ep.module()(), mod()) args = (torch.randn(3, 2), torch.randn(3, 2)) mod = ModuleNestedConstant() - ep = torch.export.export(mod, args) + ep = export(mod, args) self.assertEqual(ep.module()(*args), mod(*args)) def test_non_arg_name_dynamic_shapes_api_with_kwarg(self): @@ -5349,6 +5689,7 @@ def forward(self, x): unflattened = unflatten(ep) self.assertTrue(torch.allclose(unflattened(*inps), M2()(*inps))) + @testing.expectedFailureRetraceabilityNonStrict def test_lazy_module_kwargs(self): class LazyModule(torch.nn.modules.lazy.LazyModuleMixin, torch.nn.Module): def initialize_parameters(self, *args, **kwargs): @@ -5358,9 +5699,7 @@ def forward(self, x, y): return x + y m = LazyModule() - ep = torch.export.export( - m, (), {"x": torch.randn(3, 3), "y": torch.randn(3, 3)} - ) + ep = export(m, (), {"x": torch.randn(3, 3), "y": torch.randn(3, 3)}) inputs = {"x": torch.randn(3, 3), "y": torch.randn(3, 3)} self.assertEqual(ep.module()(**inputs), m(**inputs)) @@ -5375,11 +5714,10 @@ def forward(self, x): return x.sum() + self.buffer.sum() inp = torch.randn(4, 4) - gm = _export( + gm = export( Foo(), (inp,), dynamic_shapes=({0: torch.export.Dim("dim", min=3)},), - pre_dispatch=True, ).module() with self.assertRaisesRegex( @@ -5390,9 +5728,9 @@ def forward(self, x): with self.assertRaisesRegex( RuntimeError, escape("Expected input at *args[0].shape[0]") ): - torch.export.export(gm, (torch.randn(2, 2),)) + export(gm, (torch.randn(2, 2),)) - ep = torch.export.export( + ep = export( gm, (torch.randn(5, 4),), dynamic_shapes=({0: torch.export.Dim("dim", min=3)},), @@ -5479,6 +5817,7 @@ def forward(self, x): export_res = decomposed_ep.module()(x) self.assertTrue(export_res.size() == exp_res.size()) + @skipIfXpu def test_export_with_fake_tensor_inputs_on_cuda_devices(self): fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() @@ -5773,6 +6112,7 @@ def forward(self, q, k, v): self.assertEqual(ep.module()(*inputs), m(*inputs)) @testing.expectedFailureSerDer # symfloat nyi + @testing.expectedFailureSerDerNonStrict def test_sym_sqrt(self): import math @@ -6020,6 +6360,69 @@ def forward(self, x): self.assertEqual(gm_flat_non_strict(*inp), gm_flat_strict(*inp)) + def test_unflatten_no_unroll(self): + inp = (torch.ones(1),) + + class N(torch.nn.Module): + def __init__(self): + super().__init__() + self.const = torch.ones(1) * 4 + self.buf = torch.nn.Buffer(torch.ones(1) * 4) + + def forward(self, x, b): + if b: + return x + self.const + 1 + else: + return x + 2 * (self.buf + 1) - self.const + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.n = N() + + def forward(self, x): + x0 = x + 3 + x1 = self.n(x0, True) + x2 = self.n(x0, False) + return x1 + x2 + + m = M() + eager_result = m(*inp) + + def test(ep, swap): + epm = ep.module() + ufm = torch.export.unflatten(ep) + + exported_result = epm(*inp) + self.assertTrue(torch.allclose(exported_result, eager_result)) + + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) + + for fqn, mod in swap.items(): + ufm.set_submodule(fqn, mod) + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) + + if not is_retracebility_test(self._testMethodName): + test( + export(M(), inp, preserve_module_call_signature=("n",)), + swap={"n": N(), "n@1": N()}, + ) + + class _N(torch.nn.Module): + def forward(self, x): + return x + 5 + + class _N_1(torch.nn.Module): + def forward(self, x): + return x + 6 + + test( + export(M(), inp), + swap={"n": _N(), "n@1": _N_1()}, + ) + def test_preserve_module_call_signature_unflatten_specialization(self): class N(torch.nn.Module): def forward(self, x, b): @@ -6057,7 +6460,7 @@ def forward(self, x): unflattened_result = ufm(*inp) self.assertTrue(torch.allclose(unflattened_result, eager_result)) - def test_unflatten_multiple_graphs_preserve_signature_error(self): + def test_unflatten_multiple_graphs_preserve_signature_no_error(self): class N(torch.nn.Module): def forward(self, x, b): if b: @@ -6071,34 +6474,33 @@ def __init__(self): self.n = N() def forward(self, x): + x = x + 3 x = self.n(x, True) - x = x + 1 + x = x + 4 x = self.n(x, False) - x = x + 1 + x = x + 5 return x inp = (torch.ones(1),) m = M() eager_result = m(*inp) - if not is_retracebility_test(self._testMethodName): - ep = export(M(), inp, preserve_module_call_signature=("n",)) - with self.assertRaisesRegex( - ValueError, - "Cannot unflatten multiple calls to module n while preserving its signature", - ): - torch.export.unflatten(ep) + def test(ep): + epm = ep.module() + ufm = torch.export.unflatten(ep) - ep = export(M(), inp) - epm = ep.module() - ufm = torch.export.unflatten(ep) + exported_result = epm(*inp) + self.assertTrue(torch.allclose(exported_result, eager_result)) - exported_result = epm(*inp) - self.assertTrue(torch.allclose(exported_result, eager_result)) + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) - unflattened_result = ufm(*inp) - self.assertTrue(torch.allclose(unflattened_result, eager_result)) + if not is_retracebility_test(self._testMethodName): + test(export(M(), inp, preserve_module_call_signature=("n",))) + test(export(M(), inp)) + + @testing.expectedFailureRetraceabilityNonStrict def test_unflatten_multiple_graphs_state(self): class N(torch.nn.Module): def __init__(self): @@ -6131,15 +6533,34 @@ def forward(self, x): m = M() eager_result = m(*inp) - ep = export(M(), inp) - epm = ep.module() - ufm = torch.export.unflatten(ep) + def test(ep): + epm = ep.module() + ufm = torch.export.unflatten(ep) + + exported_result = epm(*inp) + self.assertTrue(torch.allclose(exported_result, eager_result)) - exported_result = epm(*inp) - self.assertTrue(torch.allclose(exported_result, eager_result)) + unflattened_result = ufm(*inp) + self.assertTrue(torch.allclose(unflattened_result, eager_result)) - unflattened_result = ufm(*inp) - self.assertTrue(torch.allclose(unflattened_result, eager_result)) + if not is_retracebility_test(self._testMethodName): + test(export(M(), inp, preserve_module_call_signature=("n",))) + # running decompositions again should work for all IRs + ep = export(M(), inp, preserve_module_call_signature=("n",)) + test(ep.run_decompositions({})) + if is_training_ir_test(self._testMethodName): + # since we run decompositions by default when testing training IR, + # also test training IR without running decompositions + strict = not is_non_strict_test(self._testMethodName) + ept = torch.export.export_for_training( + M(), + inp, + strict=strict, + preserve_module_call_signature=("n",), + ) + test(ept) + + test(export(M(), inp)) def test_unflatten_multiple_graphs_shared_submodule(self): class N(torch.nn.Module): @@ -6452,12 +6873,11 @@ def true_fn(x, y): model = Model() with torch.no_grad(): - exported_program = torch.export._trace._export( + exported_program = torch.export.export_for_training( model, (torch.tensor(10), torch.tensor(12)), {}, dynamic_shapes=None, - pre_dispatch=True, strict=False, ) @@ -6510,12 +6930,11 @@ def forward(self, x, y): # no grad model = Model() with torch.no_grad(): - ep_nograd = torch.export._trace._export( + ep_nograd = torch.export.export_for_training( model, (torch.tensor(10), torch.tensor(12)), {}, dynamic_shapes=None, - pre_dispatch=True, strict=False, ) # check that only sub op is wrapped with grad_enabled @@ -6531,12 +6950,11 @@ def forward(self, x, y): # enable grad model = Model() - ep_grad = torch.export._trace._export( + ep_grad = torch.export.export_for_training( model, (torch.tensor(10), torch.tensor(12)), {}, dynamic_shapes=None, - pre_dispatch=True, strict=False, ) # check that only add op is wrapped with grad_enabled @@ -6636,9 +7054,12 @@ def forward(self, x): "torch.ops.higher_order.wrap_with_set_grad_enabled", ep.graph_module.code, ) + gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module() + self.assertIn( + "set_grad_enabled", + gm.code, + ) - # T203671967 - @testing.expectedFailureRetraceability # autocast nodes not created after re-tracing def test_export_with_autocast(self): class Model(torch.nn.Module): def forward(self, x): @@ -6647,23 +7068,26 @@ def forward(self, x): ): y = x.sin().sum() with torch.autocast( - device_type="cpu", dtype=torch.float64, enabled=True + device_type="cpu", dtype=torch.float16, enabled=True ): z = y.sin().sum() return z model = Model() ep = export(model, (torch.randn(4, 4),), {}) - # _export_for_traininig is using pre_dispatch=False - # Therefore the autocast calls are not replaced with a hop. - # non_strict doesn't have autocast nodes - if not is_non_strict_test(self._testMethodName) and not is_training_ir_test( - self._testMethodName - ): + # autocast nodes do not exist after run_decomposition() + if not is_training_ir_test(self._testMethodName): self.assertIn( "torch.ops.higher_order.wrap_with_autocast", ep.graph_module.code, ) + # _export_for_traininig is using pre_dispatch=False + # Therefore the autocast calls are not replaced with a hop. + gm = torch.export.export_for_training(model, (torch.randn(4, 4),)).module() + self.assertIn( + "autocast", + gm.code, + ) def test_export_as_backend(self): def f(x, y): @@ -7162,6 +7586,7 @@ def forward(self, x): } export(f, (inputs,), dynamic_shapes=dynamic_shapes) + @testing.expectedFailureRetraceabilityNonStrict def test_disable_forced_specializations_ok(self): # check that we don't force specialization, and defer to runtime asserts # with allow_complex_guards_as_runtime_asserts=True to successfully export @@ -7282,6 +7707,7 @@ def forward(self, w, x, y, z): # TODO requires_grad doesn't seem to work with serialization. @testing.expectedFailureSerDer + @testing.expectedFailureSerDerNonStrict def test_preserve_requires_grad_placeholders(self): class Module(torch.nn.Module): def __init__(self) -> None: @@ -7402,6 +7828,33 @@ def forward(self, x, y): 0, ) + def test_constant_output_dup(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.constant = torch.ones(4, 4) + + def forward(self, x): + return x + self.constant, self.constant + + ep = export(M(), (torch.ones(4, 4),)).run_decompositions() + mod = ep.module() + a, b = mod(torch.zeros(4, 4)) + self.assertTrue(torch.allclose(a, torch.ones(4, 4))) + self.assertTrue(torch.allclose(b, torch.ones(4, 4))) + + def test_constant_requires_grad_const(self): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.foo = torch.randn(2, 2, requires_grad=True) + + def forward(self, x): + return x.cos() + self.foo.sum() + + gm = export(M(), (torch.ones(2, 2),)).module() + self.assertFalse(gm.foo.requires_grad) + def test_constant_aliasing(self): class M1(torch.nn.Module): def __init__(self, m2, foo): @@ -7415,7 +7868,7 @@ def forward(self, x): class M2(torch.nn.Module): def __init__(self) -> None: super().__init__() - self.foo = torch.ones(3, 3) + self.foo = torch.ones(3, 3, requires_grad=True) def forward(self, x): return x + self.foo @@ -7461,6 +7914,7 @@ def forward(self, x): for param in ["alpha", "beta", "gamma"]: self.assertTrue(param in unep.state_dict()) + @testing.expectedFailureRetraceabilityNonStrict def test_intermediate_shape_comp(self): class Foo(torch.nn.Module): def forward(self, x, y): @@ -7691,15 +8145,11 @@ def forward(self, x): y = torch.ops.testlib.foo_functional.default(x) return torch.ops.testlib.foo_mutated.default(y) - decomp_table = torch.export.core_aten_decompositions() + decomp_table = torch.export.default_decompositions() + del decomp_table[torch.ops.testlib.foo_functional.default] - # FIXME (We need to design a proper way that doesn't need _preserve_ops) ep = torch.export.export(M(), (torch.randn(4, 4),)).run_decompositions( decomp_table, - _preserve_ops=( - torch.ops.testlib.foo_functional.default, - torch.ops.testlib.foo_mutated.default, - ), ) self.assertExpectedInline( @@ -7734,14 +8184,9 @@ def forward(self, x): }, ) - if IS_FBCODE: - ep = ep.run_decompositions( - {}, _preserve_ops=(torch.ops.aten.linear.default,) - ) - else: - table = torch.export.core_aten_decompositions() - del table[torch.ops.aten.linear.default] - ep = ep.run_decompositions(table) + table = torch.export.default_decompositions() + del table[torch.ops.aten.linear.default] + ep = ep.run_decompositions(table) comp_mod = ep.module() inp1 = torch.randn(3, 4) @@ -7749,6 +8194,7 @@ def forward(self, x): self.assertTrue(torch.allclose(comp_mod(inp1), mod(inp1))) self.assertTrue(torch.allclose(comp_mod(inp2), mod(inp2))) + @testing.expectedFailureRetraceabilityNonStrict def test_automatic_dynamic_shapes_simple_equality(self): # The next 3 test cases tests for automatic dynamic shapes specs, verifying that automatic dynamism # leads to replacement symbols being set for equalities, and inferred relationships being checked @@ -7820,6 +8266,7 @@ def forward(self, x, y, z): test_serdes=True, ) + @testing.expectedFailureRetraceabilityNonStrict def test_automatic_dynamic_shapes_constant_relation(self): AUTO, STATIC = Dim.AUTO, Dim.STATIC @@ -7865,6 +8312,7 @@ def forward(self, x, y): test_serdes=True, ) + @testing.expectedFailureRetraceabilityNonStrict def test_automatic_dynamic_shapes_linear_relation(self): AUTO, STATIC = Dim.AUTO, Dim.STATIC @@ -8143,6 +8591,8 @@ def test_dynamic_shapes_serdes_user_errors(self): _load_dynamic_shapes(spec, from_dict=True) @testing.expectedFailureSerDer # TODO(pianpwk): PowByNatural valuerange deserialization + @testing.expectedFailureSerDerNonStrict + @testing.expectedFailureRetraceabilityNonStrict def test_dim_dynamic(self): dynamic = Dim.DYNAMIC @@ -8221,6 +8671,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @testing.expectedFailureNonStrict @testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked? @testing.expectedFailureSerDer # T195866111 + @testing.expectedFailureSerDerNonStrict + @testing.expectedFailureRetraceabilityNonStrict def test_hints_wrapper(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -8281,6 +8733,89 @@ def forward(self, arg0_1: "f32[2, 4]", arg1_1: "f32[4]"): """, ) + def test_export_for_training_with_state_dict_hooks(self): + def _state_dict_pre_hook(mod, prefix, keep_vars): + mod._buffers["test"] = torch.Tensor([1]) + + def _state_dict_hook(mod, state_dict, prefix, *args, **kwargs): + keys = list(state_dict.keys()) + for key in keys: + local_key = key[len(prefix) :] + if local_key.startswith("layer"): + new_key = prefix + local_key.replace("layer.", "") + state_dict[new_key] = state_dict[key] + if new_key != key: + del state_dict[key] + + class Layer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(2, 2) + self.linear2 = torch.nn.Linear(2, 2) + + def forward(self, x): + x = self.linear1(x) + x = torch.relu(x) + x = self.linear2(x) + return x + + class CustomModule(torch.nn.Module): + def __init__(self): + super().__init__() + self._register_state_dict_hook(_state_dict_hook) + self.register_state_dict_pre_hook(_state_dict_pre_hook) + # non-persistent buffer in named_buffers() + self.foo = torch.nn.Buffer(torch.rand(2, 3), persistent=False) + # non-persistent buffer not in named_buffers() + self.register_buffer("buf", None, persistent=False) + self.layer = Layer() + + def forward(self, x): + x = self.layer(x) + return x + + M = CustomModule() + inp = (torch.randn(2, 2),) + ep = export(M, inp) + export_res = ep.module()(*inp) + ref_res = M(*inp) + self.assertEqual(export_res, ref_res) + # we want to store the unprocessed keys + self.assertTrue( + { + "layer.linear1.weight", + "layer.linear1.bias", + "layer.linear2.weight", + "layer.linear2.bias", + }.issubset({spec.target for spec in ep.graph_signature.input_specs}) + ) + unflattened = torch.export.unflatten(ep) + export_res = unflattened(*inp) + self.assertEqual(export_res, ref_res) + + with torch._export.utils._disable_load_state_dict_hooks(M): + state_dict = M.state_dict() + self.assertEqual( + { + "layer.linear1.weight", + "layer.linear1.bias", + "layer.linear2.weight", + "layer.linear2.bias", + }, + state_dict.keys(), + ) + state_dict = M.state_dict() + self.assertEqual( + { + "linear1.weight", + "linear1.bias", + "linear2.weight", + "linear2.bias", + "test", + }, + state_dict.keys(), + ) + @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support") class TestOneOffModelExportResult(TestCase): @@ -8841,15 +9376,12 @@ def forward(self, x): ep.graph_module.code ) - if IS_FBCODE: - ep = ep.run_decompositions(_preserve_ops=(torch.ops.aten.elu.default,)) - else: - decomp_table = core_aten_decompositions() - del decomp_table[torch.ops.aten.elu.default] + decomp_table = default_decompositions() + del decomp_table[torch.ops.aten.elu.default] - ep = ep.run_decompositions( - decomp_table=decomp_table, - ) + ep = ep.run_decompositions( + decomp_table=decomp_table, + ) FileCheck().check_count("torch.ops.aten.elu.default", 1, exactly=True).run( ep.graph_module.code ) @@ -8871,16 +9403,11 @@ def forward(self, x): "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True ).run(ep.graph_module.code) - if IS_FBCODE: - ep = ep.run_decompositions( - _preserve_ops=(torch.ops.aten.upsample_bilinear2d.vec,) - ) - else: - decomp_table = core_aten_decompositions() - del decomp_table[torch.ops.aten.upsample_bilinear2d.vec] - ep = ep.run_decompositions( - decomp_table=decomp_table, - ) + decomp_table = default_decompositions() + del decomp_table[torch.ops.aten.upsample_bilinear2d.vec] + ep = ep.run_decompositions( + decomp_table=decomp_table, + ) FileCheck().check_count( "torch.ops.aten.upsample_bilinear2d.vec", 1, exactly=True diff --git a/test/export/test_export_training_ir_to_run_decomp.py b/test/export/test_export_training_ir_to_run_decomp.py index b1168f54bb227..335f4ec7a0c19 100644 --- a/test/export/test_export_training_ir_to_run_decomp.py +++ b/test/export/test_export_training_ir_to_run_decomp.py @@ -1,6 +1,5 @@ # Owner(s): ["oncall: export"] import torch -from torch.testing._internal.common_utils import IS_FBCODE try: @@ -16,10 +15,6 @@ def mocked_training_ir_to_run_decomp_export_strict(*args, **kwargs): ep = torch.export.export_for_training(*args, **kwargs) - if IS_FBCODE: - return ep.run_decompositions( - {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY - ) return ep.run_decompositions({}) @@ -29,10 +24,6 @@ def mocked_training_ir_to_run_decomp_export_non_strict(*args, **kwargs): else: ep = torch.export.export_for_training(*args, **kwargs, strict=False) - if IS_FBCODE: - return ep.run_decompositions( - {}, _preserve_ops=testing._COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY - ) return ep.run_decompositions({}) diff --git a/test/export/test_retraceability.py b/test/export/test_retraceability.py index e7f243fd9fb7e..071598878e2ab 100644 --- a/test/export/test_retraceability.py +++ b/test/export/test_retraceability.py @@ -12,7 +12,7 @@ test_classes = {} -def mocked_retraceability_export(*args, **kwargs): +def mocked_retraceability_export_strict(*args, **kwargs): ep = export(*args, **kwargs) if "dynamic_shapes" in kwargs: if isinstance(kwargs["dynamic_shapes"], dict): @@ -22,16 +22,39 @@ def mocked_retraceability_export(*args, **kwargs): return ep -def make_dynamic_cls(cls): - cls_prefix = "RetraceExport" +def mocked_retraceability_export_non_strict(*args, **kwargs): + if "strict" in kwargs: + ep = export(*args, **kwargs) + else: + ep = export(*args, **kwargs, strict=False) + if "dynamic_shapes" in kwargs: + if isinstance(kwargs["dynamic_shapes"], dict): + kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values()) + + if "strict" in kwargs: + ep = export(ep.module(), *(args[1:]), **kwargs) + else: + ep = export(ep.module(), *(args[1:]), **kwargs, strict=False) + return ep + - test_class = testing.make_test_cls_with_mocked_export( - cls, - cls_prefix, - test_export.RETRACEABILITY_SUFFIX, - mocked_retraceability_export, - xfail_prop="_expected_failure_retrace", - ) +def make_dynamic_cls(cls, strict): + if strict: + test_class = testing.make_test_cls_with_mocked_export( + cls, + "RetraceExport", + test_export.RETRACEABILITY_STRICT_SUFFIX, + mocked_retraceability_export_strict, + xfail_prop="_expected_failure_retrace", + ) + else: + test_class = testing.make_test_cls_with_mocked_export( + cls, + "RetraceExportNonStrict", + test_export.RETRACEABILITY_NON_STRICT_SUFFIX, + mocked_retraceability_export_non_strict, + xfail_prop="_expected_failure_retrace_non_strict", + ) test_classes[test_class.__name__] = test_class # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING @@ -45,7 +68,8 @@ def make_dynamic_cls(cls): test_export.TestExport, ] for test in tests: - make_dynamic_cls(test) + make_dynamic_cls(test, True) + make_dynamic_cls(test, False) del test if __name__ == "__main__": diff --git a/test/export/test_serdes.py b/test/export/test_serdes.py index a1ced9dd4e5e6..d22d19500f3ae 100644 --- a/test/export/test_serdes.py +++ b/test/export/test_serdes.py @@ -15,7 +15,7 @@ test_classes = {} -def mocked_serder_export(*args, **kwargs): +def mocked_serder_export_strict(*args, **kwargs): ep = export(*args, **kwargs) buffer = io.BytesIO() save(ep, buffer) @@ -24,16 +24,35 @@ def mocked_serder_export(*args, **kwargs): return loaded_ep -def make_dynamic_cls(cls): - cls_prefix = "SerDesExport" +def mocked_serder_export_non_strict(*args, **kwargs): + if "strict" in kwargs: + ep = export(*args, **kwargs) + else: + ep = export(*args, **kwargs, strict=False) + buffer = io.BytesIO() + save(ep, buffer) + buffer.seek(0) + loaded_ep = load(buffer) + return loaded_ep + - test_class = testing.make_test_cls_with_mocked_export( - cls, - cls_prefix, - test_export.SERDES_SUFFIX, - mocked_serder_export, - xfail_prop="_expected_failure_serdes", - ) +def make_dynamic_cls(cls, strict): + if strict: + test_class = testing.make_test_cls_with_mocked_export( + cls, + "SerDesExport", + test_export.SERDES_SUFFIX, + mocked_serder_export_strict, + xfail_prop="_expected_failure_serdes", + ) + else: + test_class = testing.make_test_cls_with_mocked_export( + cls, + "SerDesExportNonStrict", + test_export.SERDES_NON_STRICT_SUFFIX, + mocked_serder_export_non_strict, + xfail_prop="_expected_failure_serdes_non_strict", + ) test_classes[test_class.__name__] = test_class # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING @@ -46,7 +65,8 @@ def make_dynamic_cls(cls): test_export.TestExport, ] for test in tests: - make_dynamic_cls(test) + make_dynamic_cls(test, True) + make_dynamic_cls(test, False) del test if __name__ == "__main__": diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index 19e3db9ed2957..161b7c71ec493 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -29,14 +29,13 @@ ) from torch._higher_order_ops.torchbind import enable_torchbind_tracing from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode -from torch.export import Dim, export, load, save +from torch.export import Dim, export_for_training, load, save from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, IS_WINDOWS, parametrize, run_tests, - skipIfCrossRef, TemporaryFileName, TestCase, ) @@ -97,7 +96,7 @@ def op_schema(cls, op): return torch.ops.aten.add.Tensor._schema inp = (torch.ones(10),) - ep = export(TestModule(), inp) + ep = export_for_training(TestModule(), inp) # Register the custom op handler. foo_custom_op = FooExtensionOp() @@ -162,7 +161,7 @@ def forward(self, x, y, use_p=False): model = MyModule().eval() random_inputs = (torch.rand([2, 3]), torch.rand([2, 3])) - exp_program = torch.export.export(model, random_inputs, {"use_p": True}) + exp_program = export_for_training(model, random_inputs, {"use_p": True}) output_buffer = io.BytesIO() # Tests that example inputs are preserved when saving and loading module. @@ -181,7 +180,7 @@ class M(torch.nn.Module): def forward(self, x): return x.sin() - exp_program = torch.export.export_for_training(M(), (torch.randn(4, 4),)) + exp_program = export_for_training(M(), (torch.randn(4, 4),)) output_buffer = io.BytesIO() # Tests that example forward arg names are preserved when saving and loading module. @@ -221,7 +220,7 @@ def forward(self, x): inp = (torch.ones(10),) # Module will only be able to roundtrip if metadata # can be correctly parsed. - ep = export(MyModule(), inp) + ep = export_for_training(MyModule(), inp) buffer = io.BytesIO() save(ep, buffer) loaded_ep = load(buffer) @@ -244,7 +243,7 @@ def forward(self, x): # Check that module can be roundtripped, thereby confirming proper deserialization. inp = (torch.ones(10),) - ep = export(MyModule(), inp) + ep = export_for_training(MyModule(), inp) buffer = io.BytesIO() save(ep, buffer) loaded_ep = load(buffer) @@ -267,7 +266,7 @@ def forward(self, x, w, b): eps=1e-5, ) - exported_module = export( + exported_module = export_for_training( MyModule(), ( torch.ones([512, 512], requires_grad=True), @@ -310,7 +309,7 @@ def forward(self, a, b, c) -> torch.Tensor: "b": {1: dim1_bc}, "c": {0: dim0_ac, 1: dim1_bc}, } - exported_module = export( + exported_module = export_for_training( DynamicShapeSimpleModel(), inputs, dynamic_shapes=dynamic_shapes ).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) @@ -332,7 +331,7 @@ def forward(self, x): return torch.split(x, 2) input = torch.arange(10.0).reshape(5, 2) - exported_module = export(MyModule(), (input,)).run_decompositions() + exported_module = export_for_training(MyModule(), (input,)).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) node = serialized.exported_program.graph_module.graph.nodes[-1] @@ -373,7 +372,7 @@ def __init__(self) -> None: def forward(self, x): return torch.ops.aten.var_mean.correction(x, [1])[0] - exported_module = export( + exported_module = export_for_training( MyModule(), (torch.ones([512, 512], requires_grad=True),), ).run_decompositions() @@ -395,7 +394,7 @@ class M(torch.nn.Module): def forward(self, x): return x + x - ep = torch.export.export( + ep = export_for_training( M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},) ) @@ -427,7 +426,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: f = Foo() x, _ = torch.sort(torch.randn(3, 4)) - exported_module = export(f, (x,)).run_decompositions() + exported_module = export_for_training(f, (x,)).run_decompositions() serialized = ExportedProgramSerializer().serialize(exported_module) node = serialized.exported_program.graph_module.graph.nodes[-1] @@ -445,7 +444,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: b = x + y return b + a - ep = torch.export.export(Module(), (torch.randn(3, 2), torch.randn(3, 2))) + ep = export_for_training(Module(), (torch.randn(3, 2), torch.randn(3, 2))) s = ExportedProgramSerializer().serialize(ep) c = canonicalize(s.exported_program) g = c.graph_module.graph @@ -459,7 +458,7 @@ class M(torch.nn.Module): def forward(self, x): return torch.ops.aten.sum.dim_IntList(x, []) - ep = torch.export.export(M(), (torch.randn(3, 2),)) + ep = torch.export.export_for_training(M(), (torch.randn(3, 2),)) serialized = ExportedProgramSerializer().serialize(ep) for node in serialized.exported_program.graph_module.graph.nodes: if "aten.sum.dim_IntList" in node.target: @@ -576,21 +575,24 @@ def _deepcopy_inputs(inputs): def _check_graph(pre_dispatch): if pre_dispatch: - ep = torch.export._trace._export( + ep = torch.export.export_for_training( fn, _deepcopy_inputs(inputs), {}, dynamic_shapes=dynamic_shapes, - pre_dispatch=True, strict=strict, ) else: - ep = torch.export.export( + # We should have this branch because + # PT2 Inference goes through this private + # export API. + ep = torch.export._trace._export( fn, _deepcopy_inputs(inputs), {}, dynamic_shapes=dynamic_shapes, strict=strict, + pre_dispatch=False, ) ep.graph.eliminate_dead_code() @@ -926,7 +928,7 @@ def forward(self, x): a = a * 2 return a, b - ep = torch.export.export(M(), (torch.ones(3),)) + ep = torch.export.export_for_training(M(), (torch.ones(3),)) # insert another getitem node for node in ep.graph.nodes: @@ -1072,7 +1074,7 @@ def __init__(self) -> None: def forward(self): return self.p * self.p - ep = torch.export.export(M(), ()) + ep = torch.export.export_for_training(M(), ()) ep._example_inputs = None roundtrip_ep = deserialize(serialize(ep)) self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()())) @@ -1089,7 +1091,7 @@ def forward(self, x): return x + x f = Module() - ep = export(f, (torch.randn(1, 3),)) + ep = export_for_training(f, (torch.randn(1, 3),)) serialized_program = ExportedProgramSerializer().serialize(ep) serialized_program.exported_program.schema_version.major = -1 @@ -1125,7 +1127,7 @@ def forward(self, x): y = self.linear(y) return y - ep = export(Module(), inp) + ep = export_for_training(Module(), inp) buffer = io.BytesIO() save(ep, buffer) @@ -1142,7 +1144,7 @@ def forward(self, x): f = Foo() inp = (torch.randn(2, 2),) - ep = export(f, inp) + ep = export_for_training(f, inp) with tempfile.NamedTemporaryFile() as f: save(ep, f) @@ -1159,7 +1161,7 @@ def forward(self, x, y): f = Foo() inp = (torch.tensor([6]), torch.tensor([7])) - ep = export(f, inp) + ep = export_for_training(f, inp) with TemporaryFileName() as fname: path = Path(fname) @@ -1177,7 +1179,7 @@ def forward(self, x): f = Foo() - ep = export(f, inp) + ep = export_for_training(f, inp) buffer = io.BytesIO() save(ep, buffer, extra_files={"extra.txt": "moo"}) @@ -1195,7 +1197,7 @@ def forward(self, x): f = Foo() - ep = export(f, (torch.randn(1, 3),)) + ep = export_for_training(f, (torch.randn(1, 3),)) with tempfile.NamedTemporaryFile() as f: save(ep, f) @@ -1221,7 +1223,7 @@ def forward(self, x): list_tensor = [torch.tensor(3), torch.tensor(4)] return x + self.a + list_tensor[0] + list_tensor[1] - ep = export(Foo(), (torch.tensor(1),)) + ep = export_for_training(Foo(), (torch.tensor(1),)) buffer = io.BytesIO() save(ep, buffer) buffer.seek(0) @@ -1247,7 +1249,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export(f, inputs) + ep = export_for_training(f, inputs) # Replace one of the values with an instance of our custom class for node in ep.graph.nodes: @@ -1301,7 +1303,7 @@ def forward(self, x): inputs = (torch.zeros(2, 3),) with enable_torchbind_tracing(): - ep = export(f, inputs, strict=False) + ep = export_for_training(f, inputs, strict=False) serialized_vals = serialize(ep) ep = deserialize(serialized_vals) @@ -1315,7 +1317,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export(f, inputs) + ep = export_for_training(f, inputs) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} @@ -1350,7 +1352,7 @@ def forward(self, x): f = Foo() inputs = (torch.ones(2, 2),) - ep = export(f, inputs) + ep = export_for_training(f, inputs) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} @@ -1378,49 +1380,6 @@ def forward(self, x): self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") self.assertTrue(counter > 1) - @skipIfCrossRef - def test_custom_tag_metadata_re_export(self): - class Foo(torch.nn.Module): - def __init__(self): - super().__init__() - self.w = torch.nn.Parameter(torch.rand(4, 2)) - self.b = torch.nn.Parameter(torch.rand(4)) - - def forward(self, x): - out = torch.nn.functional.linear(x, self.w, self.b) - return out - - f = Foo() - inputs = (torch.zeros(1, 2),) - ep = export(f, inputs) - - new_gm = copy.deepcopy(ep.graph_module) - new_gm.meta["custom"] = {} - new_gm.meta["custom"]["f"] = "bar" - - for node in new_gm.graph.nodes: - if ( - node.op == "call_function" - and node.target == torch.ops.aten.linear.default - ): - node.meta["custom"] = {} - node.meta["custom"]["quantization_tag"] = "foo" - - new_ep = ep._update(new_gm, ep.graph_signature) - new_ep = torch.export.export(new_ep.module(), inputs) - self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar") - - # the custom field should be preserved after re-export and - # should not be copied to other nodes - counter = 0 - for node in new_ep.graph.nodes: - if "custom" in node.meta: - counter += 1 - self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo") - self.assertTrue(node.target == torch.ops.aten.linear.default) - - self.assertEqual(counter, 1) - def test_custom_tag_metadata_copy(self): class Foo(torch.nn.Module): def forward(self, x): @@ -1429,7 +1388,7 @@ def forward(self, x): f = Foo() inputs = (torch.zeros(4, 4),) - ep = export(f, inputs) + ep = export_for_training(f, inputs) new_gm = copy.deepcopy(ep.graph_module) new_gm.meta["custom"] = {} diff --git a/test/export/test_torchbind.py b/test/export/test_torchbind.py index 9b2f1546f1a78..997aeecd37dda 100644 --- a/test/export/test_torchbind.py +++ b/test/export/test_torchbind.py @@ -1,6 +1,6 @@ # Owner(s): ["oncall: export"] - +import copy import unittest import torch @@ -1028,6 +1028,30 @@ def forward(self, token, tq, x): return (tq,)""", # noqa: B950 ) + def test_deepcopy(self): + tq = torch.classes._TorchScriptTesting._TensorQueue( + torch.empty( + 0, + ).fill_(-1) + ) + tq_0 = copy.deepcopy(tq) + tq.push(torch.zeros(2, 2)) + tq.push(torch.ones(2, 2)) + tq_1 = copy.deepcopy(tq) + tq.push(torch.ones(2, 2) * 2) + self.assertEqual(tq_0.size(), 0) + self.assertEqual(tq_1.size(), 2) + self.assertEqual(tq.size(), 3) + + foo = torch.classes._TorchScriptTesting._Foo(1, 2) + foo_0 = copy.deepcopy(foo) + foo.increment(1) + foo_1 = copy.deepcopy(foo) + foo.increment(1) + self.assertEqual(foo_0.add(1), 3) + self.assertEqual(foo_1.add(1), 5) + self.assertEqual(foo.add(1), 7) + class TestCompileTorchbind(TestCase): def setUp(self): diff --git a/test/export/test_unflatten_training_ir.py b/test/export/test_unflatten_training_ir.py new file mode 100644 index 0000000000000..684d9a149ecfa --- /dev/null +++ b/test/export/test_unflatten_training_ir.py @@ -0,0 +1,47 @@ +# Owner(s): ["oncall: export"] + + +try: + from . import test_unflatten, testing +except ImportError: + import test_unflatten # @manual=fbcode//caffe2/test:test_export-library + import testing # @manual=fbcode//caffe2/test:test_export-library + +from torch.export import export_for_training + + +test_classes = {} + + +def mocked_training_ir_export(*args, **kwargs): + return export_for_training(*args, **kwargs) + + +def make_dynamic_cls(cls): + cls_prefix = "TrainingIRUnflatten" + + test_class = testing.make_test_cls_with_mocked_export( + cls, + cls_prefix, + "_training_ir", + mocked_training_ir_export, + xfail_prop="_expected_failure_training_ir", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + + +tests = [ + test_unflatten.TestUnflatten, +] +for test in tests: + make_dynamic_cls(test) +del test + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/test/export/test_verifier.py b/test/export/test_verifier.py index ec6c08d75c4e5..dd3d18db1cda1 100644 --- a/test/export/test_verifier.py +++ b/test/export/test_verifier.py @@ -6,7 +6,7 @@ from torch import Tensor from torch._dynamo.eval_frame import is_dynamo_supported from torch._export.verifier import SpecViolationError, Verifier -from torch.export import export +from torch.export import export_for_training from torch.export.exported_program import InputKind, InputSpec, TensorArgument from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase @@ -20,7 +20,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export(f, (torch.randn(100), torch.randn(100))) + ep = export_for_training(f, (torch.randn(100), torch.randn(100))) verifier = Verifier() verifier.check(ep) @@ -47,7 +47,9 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export(f, (torch.randn(100), torch.randn(100))) + ep = export_for_training( + f, (torch.randn(100), torch.randn(100)) + ).run_decompositions({}) for node in ep.graph.nodes: if node.target == torch.ops.aten.add.Tensor: node.target = torch.ops.aten.add_.Tensor @@ -70,7 +72,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export(f, (torch.randn(3, 3), torch.randn(3, 3))) + ep = export_for_training(f, (torch.randn(3, 3), torch.randn(3, 3))) verifier = Verifier() verifier.check(ep) @@ -89,7 +91,9 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: f = Foo() - ep = export(f, (torch.randn(3, 3), torch.randn(3, 3))) + ep = export_for_training( + f, (torch.randn(3, 3), torch.randn(3, 3)) + ).run_decompositions({}) for node in ep.graph_module.true_graph_0.graph.nodes: if node.target == torch.ops.aten.add.Tensor: node.target = torch.ops.aten.add_.Tensor @@ -107,7 +111,7 @@ def __init__(self) -> None: def forward(self, x: Tensor) -> Tensor: return self.linear(x) - ep = export(M(), (torch.randn(10, 10),)) + ep = export_for_training(M(), (torch.randn(10, 10),)) ep.validate() def test_ep_verifier_invalid_param(self) -> None: @@ -121,7 +125,7 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y + self.a - ep = export(M(), (torch.randn(100), torch.randn(100))) + ep = export_for_training(M(), (torch.randn(100), torch.randn(100))) # Parameter doesn't exist in the state dict ep.graph_signature.input_specs[0] = InputSpec( @@ -146,7 +150,7 @@ def __init__(self) -> None: def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: return x + y + self.a - ep = export(M(), (torch.randn(100), torch.randn(100))) + ep = export_for_training(M(), (torch.randn(100), torch.randn(100))) # Buffer doesn't exist in the state dict ep.graph_signature.input_specs[0] = InputSpec( @@ -178,7 +182,7 @@ def forward(self, x1, x2): self.my_buffer2.add_(1.0) return output - ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0))) + ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0))) ep.validate() def test_ep_verifier_invalid_output(self) -> None: @@ -201,14 +205,13 @@ def forward(self, x1, x2): self.my_buffer2.add_(1.0) return output - ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0))) + ep = export_for_training(M(), (torch.tensor(5.0), torch.tensor(6.0))) output_node = list(ep.graph.nodes)[-1] output_node.args = ( ( output_node.args[0][0], next(iter(ep.graph.nodes)), - output_node.args[0][1], ), ) diff --git a/test/export/testing.py b/test/export/testing.py index 3647d4c9edd86..ed72f219eb639 100644 --- a/test/export/testing.py +++ b/test/export/testing.py @@ -258,12 +258,24 @@ def expectedFailureRetraceability(fn): return fn +# Controls tests generated in test/export/test_retraceability.py +def expectedFailureRetraceabilityNonStrict(fn): + fn._expected_failure_retrace_non_strict = True + return fn + + # Controls tests generated in test/export/test_serdes.py def expectedFailureSerDer(fn): fn._expected_failure_serdes = True return fn +# Controls tests generated in test/export/test_serdes.py +def expectedFailureSerDerNonStrict(fn): + fn._expected_failure_serdes_non_strict = True + return fn + + def expectedFailureSerDerPreDispatch(fn): fn._expected_failure_serdes_pre_dispatch = True return fn diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 8c438bc2e4fc7..7fe56facf5b53 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -1,14 +1,17 @@ import argparse import datetime +import logging import re import sys -import warnings from collections import defaultdict import torch -from torch._C import parse_schema +from torch._C import parse_schema, Tag +FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" +logging.basicConfig(level=logging.INFO, format=FORMAT) + # How to run this test locally: # 1 Have two virtual environments (eg conda env), one without PyTorch installed (venv_nightly) # one with your local changes (venv_yours). @@ -22,7 +25,10 @@ # 5. Run this test with # `python test/forward_backward_compatibility/check_forward_backward_compatibility.py --existing-schemas nightly_schemas.txt` -# The date specifies how long the allowlist exclusion should apply to. +# The date specifies how long the allowlist exclusion should apply to. Note that core ATen opset +# (https://pytorch.org/docs/stable/torch.compiler_ir.html#core-aten-ir) is guaranteed to be BC, based on this policy +# (https://dev-discuss.pytorch.org/t/core-aten-opset-backward-forward-compatibility-policy/1772) and hence the +# allowlist does not apply (or the date is always arbitrarily far for core ATen ops). # # - If we NEVER give BC guarantee for an operator, you can put the # date arbitrarily far in the future. @@ -109,34 +115,15 @@ ("aten::mps_max_pool2d_backward.out", datetime.date(9999, 1, 1)), # TODO: FIXME: prims shouldn't be checked ("prims::.*", datetime.date(9999, 1, 1)), - ("aten::_flash_attention_forward", datetime.date(2023, 12, 30)), - ("aten::_flash_attention_backward", datetime.date(2023, 12, 30)), ("aten::_scaled_dot_product_cudnn_attention", datetime.date(9999, 1, 1)), - ("aten::_sparse_mask_helper", datetime.date(2023, 3, 15)), # BetterTransformer 1.0 internal operators ("aten::_transformer_decoder_only_layer_fwd", datetime.date(9999, 1, 1)), ("aten::_native_decoder_only_multi_head_attention", datetime.date(9999, 1, 1)), - ("c10d::_allgather_base_", datetime.date(2023, 12, 30)), - ("c10d::_reduce_scatter_base_", datetime.date(2023, 12, 30)), - ("c10d::broadcast_", datetime.date(2023, 12, 30)), - ("c10d::scatter_", datetime.date(2023, 12, 30)), # These ops were moved to python under the c10d_functional namespace ("aten::wait_tensor", datetime.date(9999, 1, 30)), ("aten::reduce_scatter_tensor", datetime.date(9999, 1, 30)), ("aten::all_gather_into_tensor", datetime.date(9999, 1, 30)), ("aten::all_reduce", datetime.date(9999, 1, 30)), - ("aten::to_sparse.out", datetime.date(2023, 12, 31)), - ("aten::to_sparse.sparse_dim_out", datetime.date(2023, 12, 31)), - ("aten::to_sparse_bsc.out", datetime.date(2023, 12, 31)), - ("aten::to_sparse_bsr.out", datetime.date(2023, 12, 31)), - ("aten::to_sparse_csc.out", datetime.date(2023, 12, 31)), - ("aten::to_sparse_csr.out", datetime.date(2023, 12, 31)), - ("aten::_structured_sparse_linear", datetime.date(2023, 12, 31)), - ("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)), - ("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)), - ("aten::sym_constrain_range", datetime.date(2023, 12, 31)), - ("aten::_efficient_attention_forward", datetime.date(2024, 7, 1)), - ("aten::_efficient_attention_backward", datetime.date(2024, 7, 1)), ("onednn::qconv1d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), @@ -150,8 +137,6 @@ ("_quantized::wrapped_linear_prepack", datetime.date(2024, 12, 31)), ("_quantized::wrapped_linear_prepacked", datetime.date(2024, 12, 31)), ("_quantized::wrapped_quantized_linear_prepacked", datetime.date(2024, 12, 31)), - # BC-breaking change in can_cast signature: 'from' -> 'from_' - ("aten::can_cast", datetime.date(2024, 5, 31)), ] ALLOW_LIST_COMPILED = [ @@ -249,6 +234,14 @@ def process_version_map(version_map): return output +def is_core_aten_op(schema) -> bool: + # Check if the schema is a core ATen op + if "::" not in schema.name: + return False + _, _, tags = torch._C._get_operation_overload(schema.name, schema.overload_name) + return Tag.core in tags + + def check_bc(existing_schemas): new_schema_dict = load_schemas_to_dict() version_map = process_version_map(torch._C._get_operator_version_map()) @@ -256,12 +249,23 @@ def check_bc(existing_schemas): broken_ops = [] for existing_schema in existing_schemas: if allow_listed(existing_schema): - print("schema: ", str(existing_schema), " found on allowlist, skipping") - continue + if not is_core_aten_op(existing_schema): + logging.info("schema: %s found on allowlist, skipping", existing_schema) + continue + else: + logging.info( + "schema: %s found on allowlist, but is a core ATen op, checking BC", + existing_schema, + ) if has_valid_upgraders(existing_schema, version_map): - print("schema: ", str(existing_schema), " has valid upgrader, skipping") - continue - print("processing existing schema: ", str(existing_schema)) + if not is_core_aten_op(existing_schema): + logging.info("schema: %s has valid upgrader, skipping", existing_schema) + continue + else: + logging.info( + "schema: %s has a valid upgrader, but is a core ATen op, checking BC" + ) + logging.debug("processing existing schema: %s", existing_schema) matching_new_schemas = new_schema_dict.get(existing_schema.name, []) found = False for matching_new_schema in matching_new_schemas: @@ -269,24 +273,24 @@ def check_bc(existing_schemas): found = True break if not found: - print( + logging.warning( "Can NOT find backward compatible schemas after changes " - "for schema {} from the following candidates:\n[\n{}\n]".format( - str(existing_schema), - "\n\t".join(str(s) for s in matching_new_schemas), - ) + "for schema %s from the following candidates:\n[\n%s\n]", + str(existing_schema), + "\n\t".join(str(s) for s in matching_new_schemas), ) # TODO Print out more details about why candidates don't match. broken_ops.append(str(existing_schema)) is_bc = False if is_bc: - print("Found backward compatible schemas for all existing schemas") + logging.info("Found backward compatible schemas for all existing schemas") else: - print( + logging.warning( "The PR is introducing backward incompatible changes to the " "operator library. Please contact PyTorch team to confirm " "whether this change is wanted or not. \n\nBroken ops: " - "[\n\t{}\n]".format("\n\t".join(broken_ops)) + "[\n\t%s\n]", + "\n\t".join(broken_ops), ) return is_bc @@ -297,9 +301,9 @@ def check_fc(existing_schemas): broken_ops = [] for existing_schema in existing_schemas: if allow_listed(existing_schema): - print("schema: ", str(existing_schema), " found on allowlist, skipping") + logging.info("schema: %s found on allowlist, skipping", existing_schema) continue - print("processing existing schema: ", str(existing_schema)) + logging.info("processing existing schema: %s", existing_schema) matching_new_schemas = new_schema_dict.get(existing_schema.name, []) found = False possible_failure_reasons = [] @@ -313,29 +317,28 @@ def check_fc(existing_schemas): if reason != "": possible_failure_reasons.append(reason) if not found: - print( + logging.warning( "Can NOT find forward compatible schemas after changes " - "for schema {} from the following candidates:\n[\n{}\n]".format( - str(existing_schema), - "\n\t".join(str(s) for s in matching_new_schemas), - ) + "for schema %s from the following candidates:\n[\n\t%s\n]", + str(existing_schema), + "\n\t".join(str(s) for s in matching_new_schemas), ) - print( + logging.warning( "Refer to following reasons for failure " - "to find FC schema:\n[\n{}\n]".format( - "\n\t".join(str(r) for r in possible_failure_reasons) - ) + "to find FC schema:\n[\n%s\n]", + "\n\t".join(str(r) for r in possible_failure_reasons), ) broken_ops.append(str(existing_schema)) is_fc = False if is_fc: - print("Found forward compatible schemas for all existing schemas") + logging.info("Found forward compatible schemas for all existing schemas") else: - warnings.warn( + logging.warning( "The PR is introducing a potentially forward incompatible changes to the " "operator library. Please contact PyTorch team to confirm " "whether this change is wanted or not. \n\nBroken ops: " - "[\n\t{}\n]".format("\n\t".join(broken_ops)) + "[\n\t%s\n]", + "\n\t".join(broken_ops), ) @@ -357,7 +360,7 @@ def check_fc(existing_schemas): break if dont_parse(line.strip()): - print("Not parsing schema line: ", line.strip()) + logging.info("Not parsing schema line: %s", line.strip()) continue s = parse_schema(line.strip()) slist.append(s) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index ac0c6d89467f3..3e1eeb8255b75 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -6307,12 +6307,6 @@ def fn_(x): xfail( "nn.functional.nll_loss", "" ), # Cannot call sizes() on tensor with symbolic sizes/strides - xfail( - "_segment_reduce", "lengths" - ), # aten.segment_reduce.default - couldn't find symbolic meta functio... - xfail( - "_segment_reduce", "offsets" - ), # aten.segment_reduce.default - couldn't find symbolic meta functio... xfail("trace", ""), # Cannot call sizes() on tensor with symbolic sizes/strides xfail( "_upsample_bilinear2d_aa" diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index e4714fe768fb5..b6c0e103dfee4 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -8,7 +8,7 @@ from functorch.experimental import control_flow from functorch.experimental.control_flow import cond, UnsupportedAliasMutationException from torch._higher_order_ops.associative_scan import associative_scan -from torch._higher_order_ops.scan import scan +from torch._higher_order_ops.scan import _fake_scan, scan from torch._higher_order_ops.while_loop import while_loop from torch._subclasses.functional_tensor import ( CppFunctionalizeAPI, @@ -29,6 +29,7 @@ skipIfCrossRef, skipIfRocm, skipIfTorchDynamo, + TEST_WITH_CROSSREF, TEST_WITH_TORCHDYNAMO, TestCase, xfailIfTorchDynamo, @@ -113,48 +114,6 @@ def _fake_associative_scan(combine_fn, xs, dim, reverse=False): return pytree.tree_unflatten(results, spec) -def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False): - carry_leaves, carry_spec = pytree.tree_flatten(init) - inp_leaves, inp_spec = pytree.tree_flatten(xs) - if xs is None or len(inp_leaves) == 0: - return init, [] - result_flat = [] - carry = carry_leaves - op = reversed if reverse else lambda x: x - - dummy_carry, dummy_out = combine_fn( - pytree.tree_unflatten(carry, carry_spec), - pytree.tree_unflatten( - [torch._ops.ops.aten.slice(elem, dim, 0, 1, 1) for elem in inp_leaves], - inp_spec, - ), - ) - dummy_out_leaves, dummy_out_spec = pytree.tree_flatten(dummy_out) - num_leaves = len(dummy_out_leaves) - - for ind in op(range(inp_leaves[0].size(dim))): - xs = [ - torch._ops.ops.aten.slice(elem, dim, ind, ind + 1, 1) for elem in inp_leaves - ] - - carry, y = combine_fn( - pytree.tree_unflatten(carry, carry_spec), - pytree.tree_unflatten(xs, inp_spec), - ) - carry, _ = pytree.tree_flatten(carry) - y, _ = pytree.tree_flatten(y) - result_flat.append(y) - - results = [ - torch.concatenate([e[leave_ind] for e in op(result_flat)], dim) - for leave_ind in range(num_leaves) - ] - return ( - pytree.tree_unflatten(carry, carry_spec), - pytree.tree_unflatten(results, dummy_out_spec), - ) - - def compile_mode_helper(fct, compile_mode): if compile_mode == "compile": return torch.compile(fct, fullgraph=True, dynamic=False) @@ -1377,6 +1336,17 @@ def test_associative_scan_compile( ) self.assertEqual(cumsum1, cumsum_exp) + def test_scan_y_less_ndim_then_dim(self): + def combine_fn(carry, x): + return carry @ x, (carry @ x).sum() + + init = torch.randn(4, 3) + xs = torch.randn(3, 3, 2) + dim = 2 + out = scan(combine_fn, init, xs, dim=dim) + exp_out = _fake_scan(combine_fn, init, xs, dim=dim) + self.assertEqual(out, exp_out) + # TODO: provide an implementation for all compile modes and re-enable all test @requires_cuda @parametrize("reverse", [False, True]) @@ -1394,12 +1364,12 @@ def add2(x: torch.Tensor, y: torch.Tensor): ( get_scan_combine_fn("add", False), torch.cumsum, - torch.zeros(1, 10, 2, device=device), + torch.zeros(10, 2, device=device), ), ( get_scan_combine_fn("mul", False), torch.cumprod, - torch.ones(1, 10, 2, device=device), + torch.ones(10, 2, device=device), ), ]: result = scan_fct(op, init, x, dim=0, reverse=reverse) @@ -1428,12 +1398,14 @@ def add2(x: torch.Tensor, y: torch.Tensor): ) if not reverse: self.assertEqual( - cumsum1[1], torch.tensor([0.0, 1.0, 3.0, 6.0], dtype=torch.int64) + cumsum1[1], + torch.tensor([[0.0], [1.0], [3.0], [6.0]], dtype=torch.int64), ) self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64)) else: self.assertEqual( - cumsum1[1], torch.tensor([6.0, 6.0, 5.0, 3.0], dtype=torch.int64) + cumsum1[1], + torch.tensor([[6.0], [6.0], [5.0], [3.0]], dtype=torch.int64), ) self.assertEqual(cumsum1[0], torch.tensor([6.0], dtype=torch.int64)) self.assertEqual(cumsum1, cumsum_exp) @@ -1445,12 +1417,14 @@ def add2(x: torch.Tensor, y: torch.Tensor): result_exp = _fake_scan(add2, init=init, xs=x, dim=0, reverse=reverse) if not reverse: self.assertEqual( - result[1], torch.tensor([2.0, 3.0, 5.0, 10.0], dtype=torch.int64) + result[1], + torch.tensor([[2.0], [3.0], [5.0], [10.0]], dtype=torch.int64), ) self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64)) else: self.assertEqual( - result[1], torch.tensor([25.0, 14.0, 7.0, 5.0], dtype=torch.int64) + result[1], + torch.tensor([[25.0], [14.0], [7.0], [5.0]], dtype=torch.int64), ) self.assertEqual(result[0], torch.tensor([24.0], dtype=torch.int64)) self.assertEqual(result, result_exp) @@ -1496,7 +1470,7 @@ def test_scan_dtype(self, reverse, compile_mode, device, dtype): x = torch.randn(3, 10, 2, device=device).to(dtype=dtype) op, init = ( get_scan_combine_fn("adds"), - torch.zeros(1, 10, 2, device=device, dtype=dtype), + torch.zeros(10, 2, device=device, dtype=dtype), ) result = scan_fct(op, init, x, dim=0, reverse=reverse) result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) @@ -1515,7 +1489,7 @@ def test_scan_dtype(self, reverse, compile_mode, device, dtype): x = torch.randn(3, 10, 2, device=device).to(dtype=dtype) op, init = ( get_scan_combine_fn("adds"), - torch.zeros(1, 10, 2, device=device, dtype=torch.float32), + torch.zeros(10, 2, device=device, dtype=torch.float32), ) result = scan_fct(op, init, x, dim=0, reverse=reverse) result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) @@ -1533,7 +1507,7 @@ def test_scan_dtype(self, reverse, compile_mode, device, dtype): x = torch.randn(3, 10, 2, device=device) op, init = ( get_scan_combine_fn("adds"), - torch.zeros(1, 10, 2, device=device, dtype=dtype), + torch.zeros(10, 2, device=device, dtype=dtype), ) result = scan_fct(op, init, x, dim=0, reverse=reverse) result_exp = _fake_scan(op, init=init, xs=x, dim=0, reverse=reverse) @@ -1563,6 +1537,8 @@ def test_scan_dtype(self, reverse, compile_mode, device, dtype): def test_associative_scan_dim(self, combine_mode, reverse, device): import random + random.seed(10) + num_dims = [random.randint(2, 5) for _ in range(10)] for num_dim in num_dims: shapes = [random.randint(1, 10) for _ in range(num_dim)] @@ -1595,8 +1571,7 @@ def test_scan_dim(self, reverse, device): shapes = [random.randint(1, 10) for _ in range(num_dim)] rnd_scan_dim = random.randint(0, num_dim - 1) x = torch.randn(*shapes, device=device) - init_shapes = shapes - init_shapes[rnd_scan_dim] = 1 + init_shapes = shapes[:rnd_scan_dim] + shapes[rnd_scan_dim + 1 :] for op, op_pt, init in [ ( @@ -1617,7 +1592,9 @@ def test_scan_dim(self, reverse, device): self.assertEqual(result, result_exp) if not reverse: result_exp_PT = op_pt(x, rnd_scan_dim) - self.assertEqual(result[1], result_exp_PT) + res_list = list(result) + res_list[1] = res_list[1].movedim(0, rnd_scan_dim) + self.assertEqual(res_list[1], result_exp_PT) @skipIfRocm(msg="Unsupported on ROCM yet") @unittest.skipIf(not SM70OrLater, "triton") @@ -1671,10 +1648,16 @@ def test_scan_binary_operator(self, reverse, device): A = torch.randn(state_dim, requires_grad=True, device=device) elements = (A.repeat((timesteps, 1)), projected_inputs) init = tuple( - [torch.ones_like(torch._ops.ops.aten.slice(elements[0], 0, 0, 1, 1))] + [ + torch.ones_like( + torch._ops.ops.aten.slice(elements[0], 0, 0, 1, 1), + requires_grad=True, + ) + ] + [ torch.zeros_like( - torch._ops.ops.aten.slice(projected_inputs, 0, 0, 1, 1) + torch._ops.ops.aten.slice(projected_inputs, 0, 0, 1, 1), + requires_grad=True, ) ] ) @@ -1835,8 +1818,7 @@ def fct_pointwise(x, y): self.assertEqual(result, expected_result) @requires_cuda - @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_wrong_pytree(self, device): + def test_scan_wrong_pytree(self): # Init and input have same pytree def fct_wrong_pytree(x, y): return ( @@ -1852,9 +1834,9 @@ def fct_wrong_pytree(x, y): }, ) - x = torch.randn(3, 2, 2, device=device) - y = torch.randn(3, 2, 2, device=device) - z = torch.randn(3, 2, 2, device=device) + x = torch.randn(3, 2, 2) + y = torch.randn(3, 2, 2) + z = torch.randn(3, 2, 2) inp = {"i": x, "j": ([y], [{"o": z}])} inp_flat, inp_spec = pytree.tree_flatten(inp) init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] @@ -1864,8 +1846,8 @@ def fct_wrong_pytree(x, y): # Should be: RuntimeError, # r"The number of leaves of the pytree of the new carry produced by # the operator needs to match the length of the pytree of the init", - torch._dynamo.exc.Unsupported, - "Observed exception.*", + RuntimeError, + "The number of leaves of the pytree of the new carry", ): result = scan(fct_wrong_pytree, init, inp, dim=0) @@ -2049,7 +2031,7 @@ def chain_fct_different_dim(inp): @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) def test_scan_downstream_scan_matmul(self, compile_mode, reverse, device): inp = torch.randn(3, 10, 2, device=device) - init = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 2, device=device) for ind in range(2): # Chain with matmul @@ -2076,51 +2058,6 @@ def chain_fct(inp): result1 = fct_cmp(inp) self.assertEqual(result1, expected_result) - # TODO: provide an implementation for all compile modes and re-enable all test - @requires_cuda - @parametrize("compile_mode", ["none", "eager"]) - @parametrize("reverse", [False, True]) - @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_downstream_scan_scan(self, compile_mode, reverse, device): - inp = torch.randn(3, 10, 2, device=device) - init = torch.randn(3, 1, 2, device=device) - - # Chain with scan - def chain_fct_same_dim(inp): - o1 = scan( - get_scan_combine_fn("add", False), - init, - inp, - dim=1, - reverse=reverse, - ) - o2 = scan( - get_scan_combine_fn("add", False), - init, - o1[1], - dim=1, - reverse=reverse, - ) - return o2 - - fct_cmp = compile_mode_helper(chain_fct_same_dim, compile_mode) - - expected_result = _fake_scan( - get_scan_combine_fn("add", False), - init=init, - xs=_fake_scan( - get_scan_combine_fn("add", False), - init=init, - xs=inp, - dim=1, - reverse=reverse, - )[1], - dim=1, - reverse=reverse, - ) - result1 = fct_cmp(inp) - self.assertEqual(result1, expected_result) - # TODO: provide an implementation for all compile modes and re-enable all test @requires_cuda @parametrize("compile_mode", ["none", "eager"]) @@ -2128,7 +2065,7 @@ def chain_fct_same_dim(inp): @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) def test_scan_downstream_scan_scan_dim(self, compile_mode, reverse, device): inp = torch.randn(3, 10, 2, device=device) - init = torch.randn(3, 1, 2, device=device) + init = torch.randn(3, 2, device=device) # Chain with scan on different dim init2 = torch.randn(1, 10, 2, device=device) @@ -2141,6 +2078,7 @@ def chain_fct_different_dim(inp): dim=1, reverse=reverse, ) + o1 = pytree.tree_map(lambda t: t.movedim(0, 1), o1) o2 = scan( get_scan_combine_fn("add", False), init2, @@ -2152,16 +2090,18 @@ def chain_fct_different_dim(inp): fct_cmp = compile_mode_helper(chain_fct_different_dim, compile_mode) + xs = _fake_scan( + get_scan_combine_fn("add", False), + init=init, + xs=inp, + dim=1, + reverse=reverse, + )[1] + xs = pytree.tree_map(lambda t: t.movedim(0, 1), xs) expected_result = _fake_scan( get_scan_combine_fn("add", False), init=init2, - xs=_fake_scan( - get_scan_combine_fn("add", False), - init=init, - xs=inp, - dim=1, - reverse=reverse, - )[1], + xs=xs, dim=0, reverse=reverse, ) @@ -2222,7 +2162,7 @@ def test_associative_scan_non_pointwise_generic(self, reverse, device): @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) def test_scan_non_pointwise(self, reverse, device): x = torch.randn(3, 10, 2, device=device) - init = torch.randn(1, 10, 2, device=device) + init = torch.randn(10, 2, device=device) result_expected = _fake_scan( get_scan_combine_fn("non_pointwise", False), init=init, @@ -2252,7 +2192,7 @@ def test_scan_compile_cnt(self, reverse, device): with torch._dynamo.config.patch(automatic_dynamic_shapes=True): cnt = CompileCounter() x = torch.randn(3, 2, 5, device=device) - init = torch.randn(3, 1, 5, device=device) + init = torch.randn(3, 5, device=device) # First compilation step torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2264,7 +2204,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 1) x = torch.randn(3, 20, 5, device=device) - init = torch.randn(3, 1, 5, device=device) + init = torch.randn(3, 5, device=device) # Recompilation due to first different size torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2276,7 +2216,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 2) x = torch.randn(3, 40, 5, device=device) - init = torch.randn(3, 1, 5, device=device) + init = torch.randn(3, 5, device=device) # No recompilation, because of dynamic shape torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2288,7 +2228,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 2) x = torch.randn(3, 40, 5, device=device) - init = torch.randn(3, 40, 1, device=device) + init = torch.randn(3, 40, device=device) # Recompilation because of dim change torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2300,7 +2240,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 3) x = torch.randn(3, 40, 20, device=device) - init = torch.randn(3, 40, 1, device=device) + init = torch.randn(3, 40, device=device) # Recompilation due to first different size on new dim torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2312,7 +2252,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 4) x = torch.randn(3, 40, 40, device=device) - init = torch.randn(3, 40, 1, device=device) + init = torch.randn(3, 40, device=device) # No recompilation, because of dynamic shape on new dim torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2324,7 +2264,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 4) x = torch.randn(3, 60, 40, device=device) - init = torch.randn(3, 1, 40, device=device) + init = torch.randn(3, 40, device=device) # Recompilation because of dim change torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2336,7 +2276,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 5) x = torch.randn(3, 60, 40, device=device) - init = torch.randn(3, 1, 40, device=device) + init = torch.randn(3, 40, device=device) # Recompilation because of reverse change torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2348,7 +2288,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 6) x = torch.randn(3, 60, 40, device=device) - init = torch.randn(3, 1, 40, device=device) + init = torch.randn(3, 40, device=device) # No recompilation, as nothing changed torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2360,7 +2300,7 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 6) x = torch.randn(3, 120, 80, device=device) - init = torch.randn(3, 1, 80, device=device) + init = torch.randn(3, 80, device=device) # No recompilation, final test torch.compile(scan, backend=cnt)( get_scan_combine_fn("add", False), @@ -2372,115 +2312,135 @@ def test_scan_compile_cnt(self, reverse, device): self.assertEqual(cnt.frame_count, 6) @requires_cuda - @parametrize("reverse", [False, True]) @parametrize("compile_mode", ["none", "eager"]) - @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_init_scanned_0(self, reverse, compile_mode, device): + def test_scan_init_scanned_0(self, compile_mode): scan_fct = compile_mode_helper(scan, compile_mode) # Only init and no input - x = torch.randn(3, 1, 2, device=device) - init = torch.randn(3, 1, 2, device=device) + x = torch.randn(3, 1, 2) + init = torch.randn(3, 2) dim = 1 # Scan dimension is 0 init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) - with self.assertRaisesRegex( - # Should be: RuntimeError, "Input leaves must have a scan dimension > 0" - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - result_init = scan_fct( - get_scan_combine_fn("add", False), - init, - inp, - dim=dim, - reverse=reverse, - ) + if compile_mode == "none": + with self.assertRaisesRegex( + RuntimeError, + "xs leaves must have a scan dimension > 0", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + inp, + dim=dim, + ) + else: + with self.assertRaisesRegex( + # Should be: RuntimeError, "Input leaves must have a scan dimension > 0" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + inp, + dim=dim, + ) @requires_cuda - @parametrize("reverse", [False, True]) @parametrize("compile_mode", ["none", "eager"]) - @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_init_non_tensor(self, reverse, compile_mode, device): + def test_scan_init_non_tensor(self, compile_mode): scan_fct = compile_mode_helper(scan, compile_mode) - # Only init and no input - x = torch.randn(3, 1, 2, device=device) - init = torch.randn(3, 1, 2, device=device) + x = torch.randn(3, 1, 2) dim = 1 # Init is a float and not a tensor - inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) init = 1.0 - with self.assertRaisesRegex( - # Should be: RuntimeError, "Init leaves must be a Tensor" - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - result_init = scan_fct( - get_scan_combine_fn("add", False), init, inp, dim=dim, reverse=reverse - ) + if compile_mode == "none": + with self.assertRaisesRegex( + RuntimeError, + "All init leaves must be a Tensor", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), init, x, dim=dim + ) + else: + with self.assertRaisesRegex( + # Should be: RuntimeError, "Init leaves must be a Tensor" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), init, x, dim=dim + ) @requires_cuda - @parametrize("reverse", [False, True]) @parametrize("compile_mode", ["none", "eager"]) - @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_init_wrong_shape(self, reverse, compile_mode, device): + def test_scan_init_wrong_shape(self, compile_mode): scan_fct = compile_mode_helper(scan, compile_mode) # Only init and no input - x = torch.randn(3, 1, 2, device=device) - init = torch.randn(3, 1, 2, device=device) + x = torch.randn(3, 1, 2) dim = 1 # Init wrong shape (Other dim different) - inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) - init = torch._ops.ops.aten.slice(x, dim, 0, 1, 1) - init = torch.tile(init, (1, 2, 1)) - with self.assertRaisesRegex( - # Should be: RuntimeError, "The size of tensor a.*" - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - result_init = scan_fct( - get_scan_combine_fn("add", False), - init, - inp, - dim=dim, - reverse=reverse, - ) + init = torch.randn(1, 2) + if compile_mode == "none": + with self.assertRaisesRegex(RuntimeError, "The shape of the new_carry"): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + ) + else: + with self.assertRaisesRegex( + # Should be: RuntimeError, "The size of tensor a.*" + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct( + get_scan_combine_fn("add", False), + init, + x, + dim=dim, + ) @requires_cuda - @parametrize("reverse", [False, True]) @parametrize("compile_mode", ["none", "eager"]) - @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_init_wrong_pytree(self, reverse, compile_mode, device): + def test_scan_init_wrong_pytree(self, compile_mode): def add_one_carry(x: torch.Tensor, y: torch.Tensor): return x[0], x scan_fct = compile_mode_helper(scan, compile_mode) # Only init and no input - x = torch.randn(3, 1, 2, device=device) - init = torch.randn(3, 1, 2, device=device) + x = torch.randn(3, 1, 2) dim = 1 # Init wrong pytree - inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) init = ( torch._ops.ops.aten.slice(x, dim, 0, 1, 1), torch._ops.ops.aten.slice(x, dim, 0, 1, 1), ) - with self.assertRaisesRegex( - # Should be: RuntimeError: The number of leaves of the pytree of the new carry produced - # by the operator needs to match the length of the pytree of the init - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - result_init = scan_fct(add_one_carry, init, inp, dim=dim, reverse=reverse) + if compile_mode == "none": + with self.assertRaisesRegex( + RuntimeError, + "The number of leaves of the pytree of the new carry produced by the operator", + ): + result_init = scan_fct(add_one_carry, init, x, dim=dim) + + else: + with self.assertRaisesRegex( + # Should be: RuntimeError: The number of leaves of the pytree of the new carry produced + # by the operator needs to match the length of the pytree of the init + torch._dynamo.exc.Unsupported, + "Observed exception.*", + ): + result_init = scan_fct(add_one_carry, init, x, dim=dim) @requires_cuda @parametrize("reverse", [False, True]) @@ -2528,7 +2488,7 @@ def add_scalar_carry(x: torch.Tensor, y: torch.Tensor): init = torch.randn(7, 8, device=device) def add_scalar_carry2(x: torch.Tensor, y: torch.Tensor): - return x + 1.0, x[: y.shape[1], : y.shape[2]] + y + return x + 1.0, x[: y.shape[0], : y.shape[1]] + y result_init = scan_fct(add_scalar_carry2, init, inp, dim=dim, reverse=reverse) result_exp = _fake_scan( @@ -2560,7 +2520,7 @@ def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): ) self.assertEqual(result_init, result_exp) self.assertEqual(result_init[0].shape, torch.Size([2, 10, 2])) - self.assertEqual(result_init[1].shape, torch.Size([4, 5, 2])) + self.assertEqual(result_init[1].shape, torch.Size([2, 2, 5, 2])) # Correct case op, op_pt = (get_scan_combine_fn("add", False), torch.cumsum) @@ -2568,10 +2528,10 @@ def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): dim = 1 if reverse: - init = torch.zeros_like(torch._ops.ops.aten.slice(x, dim, -1, None, 1)) + init = torch.zeros_like(torch.select_copy(x, -1, 0)) inp = torch._ops.ops.aten.slice(x, dim, 0, -1, 1) else: - init = torch.zeros_like(torch._ops.ops.aten.slice(x, dim, 0, 1, 1)) + init = torch.zeros_like(torch.select_copy(x, 1, 0)) inp = torch._ops.ops.aten.slice(x, dim, 1, None, 1) result = scan_fct(op, init, x, dim=dim, reverse=reverse) @@ -2580,56 +2540,10 @@ def add_scalar_carry_sliced_out(x: torch.Tensor, y: torch.Tensor): self.assertEqual(result, result_exp) if not reverse: result_exp_PT = op_pt(x, dim) + result = list(result) + result[1] = pytree.tree_map(lambda t: t.movedim(0, dim), result[1]) self.assertEqual(result[1], result_exp_PT) - @requires_cuda - @parametrize("reverse", [False, True]) - @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) - def test_scan_carry_wrong_pytree(self, reverse, device): - def fct_pointwise_carry_wrong_pytree(x, y): - return ( - ( - x["i"], - { - "i": x["i"] * y["i"], - "j": ( - [x["j"][0][0] * y["j"][0][0]], - [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], - ), - }, - ), - { - "i": x["i"] * y["i"], - "j": ( - [x["j"][0][0] * y["j"][0][0]], - [{"o": x["j"][1][0]["o"] + y["j"][1][0]["o"]}], - ), - }, - ) - - x = torch.randn(3, 2, 2, device=device) - y = torch.randn(3, 2, 2, device=device) - z = torch.randn(3, 2, 2, device=device) - inp = {"i": x, "j": ([y], [{"o": z}])} - inp_flat, inp_spec = pytree.tree_flatten(inp) - init_flat = [torch._ops.ops.aten.slice(e, 0, 0, 1, 1) for e in inp_flat] - init = pytree.tree_unflatten(init_flat, inp_spec) - - # Wrong pytree of the carry produced by the operation - with self.assertRaisesRegex( - # Should be: RuntimeError: The number of leaves of the pytree of the new carry - # produced by the operator needs to match the length of the pytree of the init - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - result = scan( - fct_pointwise_carry_wrong_pytree, - init, - inp, - dim=0, - reverse=reverse, - ) - @requires_cuda @parametrize("reverse", [False, True]) @parametrize("device", [torch.device("cpu"), torch.device("cuda")]) @@ -2816,50 +2730,10 @@ def RNN(x: torch.Tensor, y: torch.Tensor): expected_result = rnn( torch.permute(x, (1, 0, 2)), torch.unsqueeze(h[:, 0, :], 0) ) - expected_result_out = torch.permute(expected_result[0], (1, 0, 2)) expected_result_state = torch.permute(expected_result[1], (1, 0, 2)) - result = scan(RNN, h[:, 0:1, :], x, dim=dim) - self.assertEqual(result[0], expected_result_state) - self.assertEqual(result[1], expected_result_out) - - @skipIfNoDynamoSupport - def test_scan_simple_graph_no_carry(self): - x = torch.randn(3, 10, 2, device=torch.device("cpu")) - init = torch.randn(1, 10, 2, device=torch.device("cpu")) - - def f(fct, init, xs): - return scan(fct, init, xs, dim=0, reverse=True) - - # Wrong number of returns from function - with self.assertRaisesRegex( - # Should be: RuntimeError: The pytree of the new carry produced - # by the operator needs to match the pytree of the init - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - gm = make_fx(f, tracing_mode="symbolic")( - get_scan_combine_fn("add", True), init, x - ) - - @skipIfNoDynamoSupport - def test_scan_simple_graph_wrong_carry(self): - def add_wrong_carry(x: torch.Tensor, y: torch.Tensor): - return (x + y)[0, :], x + y - - x = torch.randn(3, 10, 2, device=torch.device("cpu")) - init = torch.randn(1, 10, 2, device=torch.device("cpu")) - - def f(fct, init, xs): - return scan(fct, init, xs, dim=0, reverse=True) - - # Wrong carry shape - with self.assertRaisesRegex( - # Should be: RuntimeError: The pytree of the new carry produced by - # the operator needs to match the pytree of the init - torch._dynamo.exc.Unsupported, - "Observed exception.*", - ): - gm = make_fx(f, tracing_mode="symbolic")(add_wrong_carry, init, x) + result = scan(RNN, init=torch.select_copy(h, dim, 0), xs=x, dim=dim) + self.assertEqual(result[0].unsqueeze(0), expected_result_state) + self.assertEqual(result[1], expected_result[0]) @skipIfNoDynamoSupport def test_scan_simple_graph_wrong_dtype(self): @@ -2877,10 +2751,10 @@ def f(fct, init, xs): # Should be: RuntimeError: Expected the init and # the new carry produced by the operator to be a tensor of # torch.int64 but got torch.float32 and torch.int64 - torch._dynamo.exc.UncapturedHigherOrderOpError, - ".*", + RuntimeError, + "The dtype of the new_carry", ): - gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x) + f(add_wrong_dtype, init, x) @skipIfNoDynamoSupport @skipIfCrossRef # Arg order changes with crossref @@ -2901,20 +2775,16 @@ def f(fct, init, xs): gm.code.strip(), """\ def forward(self, fct_1, init_1, xs_1): - slice_1 = torch.ops.aten.slice.Tensor(xs_1, 0, 0, 1) - add = torch.ops.aten.add.Tensor(init_1, slice_1); add = None - add_1 = torch.ops.aten.add.Tensor(init_1, slice_1); slice_1 = add_1 = None - sym_size_int = torch.ops.aten.sym_size.int(init_1, 1) - sym_size_int_1 = torch.ops.aten.sym_size.int(init_1, 2) - new_empty = torch.ops.aten.new_empty.default(init_1, [1, sym_size_int, sym_size_int_1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False); new_empty = None - new_empty_1 = torch.ops.aten.new_empty.default(xs_1, [1, sym_size_int, sym_size_int_1], dtype = torch.float32, device = device(type='cpu'), pin_memory = False); sym_size_int = sym_size_int_1 = new_empty_1 = None + select = torch.ops.aten.select.int(xs_1, 0, 0) + add = torch.ops.aten.add.Tensor(init_1, select); add = None + add_1 = torch.ops.aten.add.Tensor(init_1, select); select = add_1 = None + clone = torch.ops.aten.clone.default(init_1); clone = None + select_copy = torch.ops.aten.select_copy.int(xs_1, 0, 0); select_copy = None scan_combine_graph_0 = self.scan_combine_graph_0 - scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True); scan_combine_graph_0 = init_1 = xs_1 = None + scan = torch.ops.higher_order.scan(scan_combine_graph_0, [init_1], [xs_1], 0, True, []); scan_combine_graph_0 = init_1 = xs_1 = None getitem = scan[0] - getitem_1 = getitem[0]; getitem = None - getitem_2 = scan[1]; scan = None - getitem_3 = getitem_2[0]; getitem_2 = None - return (getitem_1, getitem_3)""", # noqa: B950 + getitem_1 = scan[1]; scan = None + return (getitem, getitem_1)""", # noqa: B950 ) # Check graph @@ -2928,18 +2798,16 @@ def forward(self, fct_1, init_1, xs_1): def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor): l_init_ = L_init_ l_xs_ = L_xs_ - slice_1 = torch.ops.aten.slice(l_xs_, 0, 0, 1, 1) - out_l = l_init_ + slice_1; out_l = None - add_1 = l_init_ + slice_1; slice_1 = add_1 = None - child = l_init_.new_empty((1, 10, 2), dtype = torch.float32, device = device(type='cpu'), requires_grad = False); child = None - child_1 = l_xs_.new_empty((1, 10, 2), dtype = torch.float32, device = device(type='cpu'), requires_grad = False); child_1 = None + select = l_xs_.select(0, 0) + new_carry = l_init_ + select; new_carry = None + add_1 = l_init_ + select; select = add_1 = None + child = l_init_.clone(); child = None + child_1 = torch.select_copy(l_xs_, 0, 0); child_1 = None scan_combine_fn_0 = self.scan_combine_fn_0 - scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, True); scan_combine_fn_0 = l_init_ = l_xs_ = None + scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, True, []); scan_combine_fn_0 = l_init_ = l_xs_ = None getitem = scan[0] - getitem_1 = getitem[0]; getitem = None - getitem_2 = scan[1]; scan = None - getitem_3 = getitem_2[0]; getitem_2 = None - return (getitem_1, getitem_3)""", # noqa: B950 + getitem_1 = scan[1]; scan = None + return (getitem, getitem_1)""", # noqa: B950 ) @@ -5304,6 +5172,45 @@ def forward(self, l_inp_, l_tmp_): ) self.assertEqual(out, f(inp, tmp)) + @parametrize("requires_grad", [True, False]) + def test_cond_symint_operands(self, requires_grad): + from torch._dynamo.testing import EagerAndRecordGraphs + + backend = EagerAndRecordGraphs() + + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.num = 3 + + def forward(self, a, b): + return torch.cond( + pred=torch.tensor([True]), + true_fn=lambda a, b: a + b + self.num, + false_fn=lambda a, b: a - b - self.num, + operands=(a, b), + ) + + a = torch.ones(3, 3, requires_grad=requires_grad) + b = torch.ones(3, 3, requires_grad=requires_grad) + out = torch.compile(Mod(), backend=backend, dynamic=True)(a, b) + self.assertEqual(out, Mod()(a, b)) + self.assertEqual(len(backend.graphs), 1) + self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ +def forward(self, s0 : torch.SymInt, L_a_ : torch.Tensor, L_b_ : torch.Tensor, L_self_num : torch.SymInt): + l_a_ = L_a_ + l_b_ = L_b_ + l_self_num = L_self_num + tensor = torch.tensor([True]) + cond_true_0 = self.cond_true_0 + cond_false_0 = self.cond_false_0 + cond = torch.ops.higher_order.cond(tensor, cond_true_0, cond_false_0, [l_a_, l_b_, l_self_num]); tensor = cond_true_0 = cond_false_0 = l_a_ = l_b_ = l_self_num = None + getitem = cond[0]; cond = None + return (getitem,)""", # noqa: B950 + ) + def test_two_hops_not_sharing_code_obj(self): pred, args = torch.tensor(True), (torch.ones(3, 3),) @@ -5355,7 +5262,7 @@ def f(init, xs): return scan(get_scan_combine_fn("add", False), init, xs, dim=1) example_inputs = torch.ones(5, 7, 4) - example_init = torch.ones(5, 1, 4) + example_init = torch.ones(5, 4) functional_f = torch.func.functionalize(f) self.assertEqual( functional_f(example_init, example_inputs), f(example_init, example_inputs) @@ -5372,7 +5279,7 @@ def f(init, xs): return scan(add1, init, xs, dim=1) example_inputs = torch.ones(5, 7, 4) - example_init = torch.ones(5, 1, 4) + example_init = torch.ones(5, 4) functional_f = torch.func.functionalize(f) with self.assertRaisesRegex( UnsupportedAliasMutationException, @@ -5387,8 +5294,6 @@ def add2(x, y): def f(init, xs): return scan(add2, init, xs, dim=1) - example_inputs = torch.ones(5, 7, 4) - example_init = torch.ones(5, 1, 4) functional_f = torch.func.functionalize(f) with self.assertRaisesRegex( UnsupportedAliasMutationException, @@ -5406,13 +5311,83 @@ def f(init, xs): return scan(add, init, xs, dim=1) example_inputs = torch.ones(5, 7, 4) - example_init = torch.ones(5, 1, 4) + example_init = torch.ones(5, 4) functional_f = torch.func.functionalize(f) with self.assertRaisesRegex( UnsupportedAliasMutationException, "Combine_fn might be aliasing the input!" ): functional_f(example_init, example_inputs) + @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo") + def test_scan_pytree_closure(self): + from torch._dynamo.testing import EagerAndRecordGraphs + + param_buffer = ({"param": torch.randn(3, 3)}, (torch.randn(3),)) + + def add(carry, x): + ret = (carry @ param_buffer[0]["param"]) @ x + param_buffer[1][0] + return ret, ret.sum() + + def f(init, xs): + return scan(add, init, xs) + + init = torch.randn(4, 3) + xs = torch.randn(3, 3, 3) + + backend = EagerAndRecordGraphs() + eager_out = f(init, xs) + compiled_out = torch.compile(f, backend=backend)(init, xs) + exp_out = _fake_scan(add, init, xs) + + self.assertEqual(len(backend.graphs), 1) + if TEST_WITH_CROSSREF: + self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ +def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_cell_contents_0_param_ : torch.Tensor, L_add_closure_0_cell_contents_1_0_ : torch.Tensor): + l_init_ = L_init_ + l_xs_ = L_xs_ + l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_ + l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_ + r = l_xs_.select(0, 0) + r_1 = l_init_.matmul(l_add_closure_0_cell_contents_0_param_) + r_2 = r_1.matmul(r); r_1 = r = None + r_3 = r_2.add(l_add_closure_0_cell_contents_1_0_); r_2 = None + r_4 = r_3.sum(); r_3 = r_4 = None + r_5 = l_init_.clone(); r_5 = None + r_6 = torch.select_copy(l_xs_, 0, 0); r_6 = None + scan_combine_fn_0 = self.scan_combine_fn_0 + scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, False, [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None + getitem = scan[0] + getitem_1 = scan[1]; scan = None + return (getitem, getitem_1)""", # noqa: B950 + ) + + else: + self.assertExpectedInline( + backend.graphs[0].code.strip(), + """\ +def forward(self, L_init_ : torch.Tensor, L_xs_ : torch.Tensor, L_add_closure_0_cell_contents_0_param_ : torch.Tensor, L_add_closure_0_cell_contents_1_0_ : torch.Tensor): + l_init_ = L_init_ + l_xs_ = L_xs_ + l_add_closure_0_cell_contents_0_param_ = L_add_closure_0_cell_contents_0_param_ + l_add_closure_0_cell_contents_1_0_ = L_add_closure_0_cell_contents_1_0_ + select = l_xs_.select(0, 0) + matmul = l_init_ @ l_add_closure_0_cell_contents_0_param_ + matmul_1 = matmul @ select; matmul = select = None + ret = matmul_1 + l_add_closure_0_cell_contents_1_0_; matmul_1 = None + sum_1 = ret.sum(); ret = sum_1 = None + child = l_init_.clone(); child = None + child_1 = torch.select_copy(l_xs_, 0, 0); child_1 = None + scan_combine_fn_0 = self.scan_combine_fn_0 + scan = torch.ops.higher_order.scan(scan_combine_fn_0, [l_init_], [l_xs_], 0, False, [l_add_closure_0_cell_contents_0_param_, l_add_closure_0_cell_contents_1_0_]); scan_combine_fn_0 = l_init_ = l_xs_ = l_add_closure_0_cell_contents_0_param_ = l_add_closure_0_cell_contents_1_0_ = None + getitem = scan[0] + getitem_1 = scan[1]; scan = None + return (getitem, getitem_1)""", # noqa: B950 + ) + self.assertEqual(eager_out, exp_out) + self.assertEqual(compiled_out, exp_out) + _hop_schema_test_schema_types = [ "bool", @@ -5549,6 +5524,44 @@ def test_while_loop_schema_gen(self): ) self.assertEqual(schema.parse(str(schema)), schema) + @skipIfTorchDynamo("Skip because dynamo cannot trace torch.export.") + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_cond_eager_run_with_item(self): + class M(torch.nn.Module): + def forward(self, a, b1, b2, c): + def true_fn(x): + return x * b1.item() + + def false_fn(x): + return x * b2.item() + + r = torch.cond(a, true_fn, false_fn, (c,)) + return r * 2 + + x = torch.randn(10, requires_grad=True) + args = ( + torch.tensor(True), + torch.tensor([3]), + torch.tensor([4]), + x, + ) + model = M() + ep = torch.export.export(model, args) + self.assertExpectedInline( + ep.module().code.strip(), + """\ +def forward(self, a, b1, b2, c): + a, b1, b2, c, = fx_pytree.tree_flatten_spec(([a, b1, b2, c], {}), self._in_spec) + true_graph_0 = self.true_graph_0 + false_graph_0 = self.false_graph_0 + cond = torch.ops.higher_order.cond(a, true_graph_0, false_graph_0, [c, b1, b2]); a = true_graph_0 = false_graph_0 = c = b1 = b2 = None + getitem = cond[0]; cond = None + mul = torch.ops.aten.mul.Tensor(getitem, 2); getitem = None + return pytree.tree_unflatten((mul,), self._out_spec)""", # noqa: B950 + ) + expected_output = model(*args) + self.assertEqual(expected_output, x * 3 * 2) + instantiate_parametrized_tests(TestHopSchema) instantiate_parametrized_tests(TestControlFlowTraced) diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 54136a4f7babb..8de92e04c68fb 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -1038,6 +1038,9 @@ def fn(inp, *args, **kwargs): xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), + xfail( + "unbind_copy" + ), # Batching rule not implemented for aten::unbind_copy.int. } ), ) @@ -1177,6 +1180,9 @@ def vjp_of_vjp(*args_and_cotangents): xfail("sparse.mm", "reduce"), xfail("as_strided_scatter", ""), # calls as_strided xfail("index_reduce", "prod"), # .item() call + xfail( + "unbind_copy" + ), # Batching rule not implemented for aten::unbind_copy.int. # --------------------------------------------------------------------- } ) @@ -1315,6 +1321,9 @@ def test_vmapvjp(self, device, dtype, op): xfail("_native_batch_norm_legit"), # TODO: implement batching rule xfail("_batch_norm_with_update"), + xfail( + "unbind_copy" + ), # Batching rule not implemented for aten::unbind_copy.int. # ---------------------------------------------------------------------- } @@ -1415,6 +1424,7 @@ def test_vmapjvpall(self, device, dtype, op): xfail("nn.functional.dropout3d", ""), xfail("as_strided_scatter", ""), xfail("masked.cumprod", ""), + xfail("permute_copy"), xfail("renorm"), # hit vmap fallback, which is disabled xfail("squeeze_copy"), xfail("t_copy"), @@ -1480,6 +1490,7 @@ def test(): xfail("masked_select"), xfail("nanquantile"), xfail("ormqr"), + xfail("permute_copy"), xfail("put"), xfail("quantile"), xfail("renorm"), @@ -1626,6 +1637,9 @@ def test(): xfail("__getitem__", ""), xfail("index_put", ""), xfail("view_as_complex"), + xfail( + "unbind_copy" + ), # Batching rule not implemented for aten::unbind_copy.int. xfail("nn.functional.gaussian_nll_loss"), xfail("masked_select"), xfail( @@ -1920,6 +1934,9 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents): xfail( "as_strided_scatter" ), # AssertionError: Tensor-likes are not close! + xfail( + "unbind_copy" + ), # Batching rule not implemented for aten::unbind_copy.int. xfail("bernoulli"), # calls random op xfail("bfloat16"), # required rank 4 tensor to use channels_last format xfail("cdist"), # Forward AD not implemented and no decomposition diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index 870b0e61b26e5..bf11423cdcb34 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -4375,6 +4375,9 @@ def sample_vmap_out_dim_numpy_split_copy_with_int(x, splits, dim): xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints # TypeError: expected Tensor as element 0 in argument 0, but got float xfail("item"), + xfail( + "unbind_copy" + ), # Batching rule not implemented for aten::unbind_copy.int. } ), ) @@ -4450,6 +4453,9 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("item"), xfail("tril"), # Exception not raised on error input xfail("triu"), # Exception not raised on error input + xfail( + "unbind_copy" + ), # Batching rule not implemented for aten::unbind_copy.int. xfail("__getitem__", ""), xfail("count_nonzero"), xfail( @@ -4472,6 +4478,7 @@ def test_vmap_exhaustive(self, device, dtype, op): xfail("histc"), xfail("as_strided"), xfail("as_strided_copy"), + xfail("permute_copy"), xfail("t_copy"), xfail("unsqueeze_copy"), xfail("istft"), diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py new file mode 100644 index 0000000000000..df2f94c724d6b --- /dev/null +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -0,0 +1,396 @@ +# Owner(s): ["module: higher order operators"] +# flake8: noqa: B950 + +import torch +import torch._dynamo +import torch._functorch +import torch._inductor +import torch._inductor.decomposition +from functorch.compile import aot_function, nop +from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm +from torch._higher_order_ops import invoke_subgraph +from torch.testing._internal.common_utils import ( + run_tests, + skipIfTorchDynamo, + TEST_WITH_CROSSREF, + TestCase, +) + + +@skipIfTorchDynamo("Not a torch._dynamo test") +class TestInvokeSubgraph(TestCase): + def test_simple(self): + def gn(x, y): + return (torch.mul(x, y),) + + def fn(x, y): + return invoke_subgraph(gn, None, (x, y))[0] + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(True) + y_clone = y.clone().detach().requires_grad_(True) + res = fn(x_clone, y_clone) + + # Run backward + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + def test_aot_function(self): + def gn(x, y): + return (torch.mul(x, y),) + + def fn(x, y): + return invoke_subgraph(gn, None, (x, y))[0] + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(True) + y_clone = y.clone().detach().requires_grad_(True) + aot_fn = aot_function(fn, nop) + res = aot_fn(x_clone, y_clone) + + # Run backward + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + def test_multiple(self): + n_layers = 2 + + def cos(x): + return (torch.cos(x),) + + def sin(x): + return (torch.sin(x),) + + def fn(x): + a = invoke_subgraph(cos, None, (x,))[0] + b = invoke_subgraph(sin, None, (a,))[0] + return invoke_subgraph(cos, None, (b,))[0] + + x = torch.randn(8, requires_grad=True) + ref = fn(x) + aot_fn = aot_function(fn, nop) + res = aot_fn(x) + + self.assertEqual(ref, res) + + def test_differing_strides_for_grad_outs(self): + class CustomOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + return torch.sin(x) + + @staticmethod + def backward(ctx, grad_out): + a = grad_out.view(12, 5) + return torch.cos(torch.reshape(a, (3, 4, 5))) + + def gn(x): + return (CustomOp.apply(x),) + + def fn(x): + a = invoke_subgraph(gn, None, (x,))[0] + # Force stride changes so that backward view causes a failure if + # contiguous not called. + b = torch.permute(a, (0, 2, 1)) + return b + + x = torch.randn(3, 4, 5, requires_grad=True) + ref = torch.permute(gn(x)[0], (0, 2, 1)) + + x_clone = x.clone().detach().requires_grad_(True) + aot_fn = aot_function(fn, nop) + res = aot_fn(x_clone) + + # Run backward + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + + +@skipIfTorchDynamo("Not a torch._dynamo test") +class TestInvokeSubgraphCompile(TestCase): + def count_unique_get_attr_nodes(self, gm, args, expected): + subgraph_attr_names = set() + for node in gm.graph.nodes: + if node.op == "get_attr": + subgraph_attr_names.add(node.target) + self.assertEqual(len(subgraph_attr_names), expected) + + def test_simple(self): + def gn(x, y): + return (torch.mul(x, y),) + + def fn(x, y): + return invoke_subgraph(gn, None, (x, y))[0] + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = gn(x, y)[0] + + x_clone = x.clone().detach().requires_grad_(True) + y_clone = y.clone().detach().requires_grad_(True) + res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone) + + # Run backward + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + def test_dedupe(self): + def gn(x, y): + return (torch.mul(x, y),) + + def fn(x, y): + a = invoke_subgraph(gn, None, (x, y))[0] + return invoke_subgraph(gn, None, (a, y))[0] + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = fn(x, y) + + x_clone = x.clone().detach().requires_grad_(True) + y_clone = y.clone().detach().requires_grad_(True) + backend = AotEagerAndRecordGraphs() + res = torch.compile(fn, backend=backend, fullgraph=True)(x_clone, y_clone) + + # Run backward + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + # Check that the Dynamo and AOT graphs have just one subgraph module + self.assertEqual(len(backend.graphs), 1) + self.assertEqual(len(backend.fw_graphs), 1) + self.assertEqual(len(backend.bw_graphs), 1) + self.count_unique_get_attr_nodes(backend.graphs[0], [], 1) + self.count_unique_get_attr_nodes(backend.fw_graphs[0], [], 1) + self.count_unique_get_attr_nodes(backend.bw_graphs[0], [], 1) + + if not TEST_WITH_CROSSREF: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"): + l_x_ = L_x_ + l_y_ = L_y_ + + invoke_subgraph_0 = self.invoke_subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None + a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + + invoke_subgraph_1 = self.invoke_subgraph_0 + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None + getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None + return (getitem_1,) + + class invoke_subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"): + child: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None + return (child,) +""", + ) + + self.assertExpectedInline( + normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"): + repeated_subgraph0 = self.repeated_subgraph0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, '___forward_invoke_subgraph_0', (primals_1, primals_2)); repeated_subgraph0 = None + getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + + repeated_subgraph0_1 = self.repeated_subgraph0 + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, '___forward_invoke_subgraph_0', (getitem, primals_2)); repeated_subgraph0_1 = None + getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None + return (getitem_1, primals_1, primals_2, getitem) + + class repeated_subgraph0(torch.nn.Module): + def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"): + mul: "f32[8]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + return (mul,) +""", + ) + + def test_nonlocal_update(self): + counter = 2 + + def gn(x, y): + nonlocal counter + return (torch.mul(x, y) * counter,) + + def fn(x, y): + nonlocal counter + counter = 2 + a = invoke_subgraph(gn, None, (x, y))[0] + counter = 3 + return invoke_subgraph(gn, None, (a, y))[0] + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + ref = fn(x, y) + + x_clone = x.clone().detach().requires_grad_(True) + y_clone = y.clone().detach().requires_grad_(True) + res = torch.compile(fn, backend="inductor", fullgraph=True)(x_clone, y_clone) + + # Run backward + ref.sum().backward() + res.sum().backward() + + self.assertEqual(ref, res) + self.assertEqual(x.grad, x_clone.grad) + self.assertEqual(y.grad, y_clone.grad) + + torch._dynamo.reset() + backend = AotEagerAndRecordGraphs() + torch.compile(fn, backend=backend, fullgraph=True)(x_clone, y_clone) + + if not TEST_WITH_CROSSREF: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"): + l_x_ = L_x_ + l_y_ = L_y_ + + invoke_subgraph_0 = self.invoke_subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None + a: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + + invoke_subgraph_1 = self.invoke_subgraph_1 + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_1', (a, l_y_)); invoke_subgraph_1 = a = l_y_ = None + getitem_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None + return (getitem_1,) + + class invoke_subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"): + mul: "f32[8]" = torch.mul(l_x_, l_y_); l_x_ = l_y_ = None + child: "f32[8]" = mul * 2; mul = None + return (child,) + + class invoke_subgraph_1(torch.nn.Module): + def forward(self, a: "f32[8]", l_y_: "f32[8]"): + mul: "f32[8]" = torch.mul(a, l_y_); a = l_y_ = None + child: "f32[8]" = mul * 3; mul = None + return (child,) +""", + ) + + def test_normalize_gm(self): + def gn(x, y): + # Different graph give different names to intermediate nodes + for _ in range(5): + x = x * y + return x + + def fn(x, y): + for _ in range(5): + x = invoke_subgraph(gn, None, (x, y)) + return x + + backend = AotEagerAndRecordGraphs() + opt_fn = torch.compile(fn, backend=backend, fullgraph=True) + + x = torch.randn(8, requires_grad=True) + y = torch.randn(8, requires_grad=True) + + opt_fn(x, y) + + if not TEST_WITH_CROSSREF: + self.assertExpectedInline( + normalize_gm(backend.graphs[0].print_readable(print_output=False)), + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[8]", L_y_: "f32[8]"): + l_x_ = L_x_ + l_y_ = L_y_ + + invoke_subgraph_0 = self.invoke_subgraph_0 + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_0, 'invoke_subgraph_0', (l_x_, l_y_)); invoke_subgraph_0 = l_x_ = None + x: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None + invoke_subgraph_1 = self.invoke_subgraph_0 + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_1, 'invoke_subgraph_0', (x, l_y_)); invoke_subgraph_1 = x = None + x_1: "f32[8]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None + invoke_subgraph_3 = self.invoke_subgraph_0 + invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_3, 'invoke_subgraph_0', (x_1, l_y_)); invoke_subgraph_3 = x_1 = None + x_2: "f32[8]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None + invoke_subgraph_5 = self.invoke_subgraph_0 + invoke_subgraph_6 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_5, 'invoke_subgraph_0', (x_2, l_y_)); invoke_subgraph_5 = x_2 = None + x_3: "f32[8]" = invoke_subgraph_6[0]; invoke_subgraph_6 = None + invoke_subgraph_7 = self.invoke_subgraph_0 + invoke_subgraph_8 = torch.ops.higher_order.invoke_subgraph(invoke_subgraph_7, 'invoke_subgraph_0', (x_3, l_y_)); invoke_subgraph_7 = x_3 = l_y_ = None + x_4: "f32[8]" = invoke_subgraph_8[0]; invoke_subgraph_8 = None + return (x_4,) + + class invoke_subgraph_0(torch.nn.Module): + def forward(self, l_x_: "f32[8]", l_y_: "f32[8]"): + x: "f32[8]" = l_x_ * l_y_; l_x_ = None + x_1: "f32[8]" = x * l_y_; x = None + x_2: "f32[8]" = x_1 * l_y_; x_1 = None + x_3: "f32[8]" = x_2 * l_y_; x_2 = None + x_4: "f32[8]" = x_3 * l_y_; x_3 = l_y_ = None + return (x_4,) +""", + ) + + def test_input_mutation(self): + def gn(x, y): + x.add_(1) + return (torch.mul(x, y),) + + def fn(x, y): + return invoke_subgraph(gn, None, (x, y))[0] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing" + ): + opt_fn(x, y) + + def test_input_aliasing(self): + def gn(x, y): + return (x, torch.mul(x, y)) + + def fn(x, y): + outs = invoke_subgraph(gn, None, (x, y)) + return outs[0] * outs[1] + + x = torch.randn(8, requires_grad=False) + y = torch.randn(8, requires_grad=False) + + opt_fn = torch.compile(fn, backend="inductor", fullgraph=True) + with self.assertRaisesRegex( + torch._dynamo.exc.Unsupported, "NYI: invoke_subgraph with aliasing" + ): + opt_fn(x, y) + + +if __name__ == "__main__": + run_tests() diff --git a/test/inductor/mock_cache.py b/test/inductor/mock_cache.py index 1f654aa1ce5de..9c6d0ad77362d 100644 --- a/test/inductor/mock_cache.py +++ b/test/inductor/mock_cache.py @@ -76,6 +76,7 @@ class _GlobalStats(threading.local): def __init__(self) -> None: self.autotune_local = _GlobalItemStats() self.autotune_remote = _GlobalItemStats() + self.bundled_autotune = _GlobalItemStats() self.fx_graph = _GlobalItemStats() self.triton = _GlobalItemStats() self.aot_autograd = _GlobalItemStats() @@ -83,6 +84,7 @@ def __init__(self) -> None: def reset(self) -> None: self.autotune_local.reset() self.autotune_remote.reset() + self.bundled_autotune.reset() self.fx_graph.reset() self.triton.reset() self.aot_autograd.reset() @@ -94,6 +96,7 @@ def report(self): subs = ( ("autotune_local", self.autotune_local), ("autotune_remote", self.autotune_remote), + ("bundled_autotune", self.bundled_autotune), ("fx_graph", self.fx_graph), ("triton", self.triton), ("aot_autograd", self.aot_autograd), @@ -151,7 +154,7 @@ def _put(self, key: str, data: Any) -> None: "fx_graph_remote_cache", "autotune_local_cache", "autotune_remote_cache", - # "bundled_autotune_cache", + "bundled_autotune_remote_cache", ) @@ -194,6 +197,12 @@ def __enter__(self) -> Self: ) self._stack.enter_context(ctx) + ctx = patch( + "torch._inductor.remote_cache.RemoteBundledAutotuneCache.backend_override_cls", + MockBackend.with_name("bundled_autotune"), + ) + self._stack.enter_context(ctx) + ctx = patch( "torch._inductor.remote_cache.RemoteFxGraphCache.backend_override_cls", MockBackend.with_name("fx_graph"), @@ -213,6 +222,12 @@ def __enter__(self) -> Self: ) self._stack.enter_context(ctx) + ctx = patch( + "torch._inductor.fb.remote_cache.FbRemoteBundledAutotuneCache.backend_override_cls", + MockBackend.with_name("bundled_autotune"), + ) + self._stack.enter_context(ctx) + ctx = patch( "torch._inductor.fb.remote_cache.FbRemoteFxGraphCache.backend_override_cls", MockBackend.with_name("fx_graph"), diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 652dba3f85bf7..3f346a2098140 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1,6 +1,7 @@ # Owner(s): ["module: inductor"] import copy import itertools +import logging import os import sys import tempfile @@ -14,6 +15,8 @@ import torch._inductor import torch._inductor.config import torch.nn as nn +import torch.nn.functional as F +from torch._dynamo import config as dynamo_config from torch._dynamo.testing import rand_strided, same from torch._dynamo.utils import counters from torch._inductor import config @@ -40,6 +43,7 @@ skipIfRocm, TEST_WITH_ROCM, ) +from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test from torch.testing._internal.triton_utils import HAS_CUDA, requires_cuda from torch.utils import _pytree as pytree @@ -108,7 +112,6 @@ def check_model( ): with torch.no_grad(), config.patch( { - "abi_compatible": self.abi_compatible, "allow_stack_allocation": self.allow_stack_allocation, "use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, } @@ -142,8 +145,8 @@ def check_model_with_multiple_inputs( ): with torch.no_grad(), config.patch( { - "abi_compatible": self.abi_compatible, "allow_stack_allocation": self.allow_stack_allocation, + "use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, } ): torch.manual_seed(0) @@ -167,7 +170,14 @@ def code_check_count( target_str: str, target_count: int, ): - so_path = torch._export.aot_compile(model, example_inputs) + with torch.no_grad(), config.patch( + { + "allow_stack_allocation": self.allow_stack_allocation, + "use_minimal_arrayref_interface": self.use_minimal_arrayref_interface, + } + ): + so_path = torch._export.aot_compile(model, example_inputs) + with open(os.path.splitext(so_path)[0] + ".cpp") as cpp: src_code = cpp.read() FileCheck().check_count( @@ -191,7 +201,12 @@ def forward(self, x, y): torch.randn(10, 10, device=self.device), torch.randn(10, 10, device=self.device), ) - self.check_model(Model(), example_inputs) + model = Model() + self.check_model(model, example_inputs) + if self.use_minimal_arrayref_interface: + self.code_check_count( + model, example_inputs, "AOTInductorModelRunMinimalArrayrefInterface(", 1 + ) def test_small_constant(self): class Model(torch.nn.Module): @@ -373,7 +388,8 @@ def forward(self, x, y): "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", ) def test_conv_freezing(self): - for dtype, groups in itertools.product([torch.bfloat16, torch.float], [1, 2]): + dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float] + for dtype, groups in itertools.product(dtypes, [1, 2]): iC = 2 oC = 3 @@ -427,7 +443,8 @@ def forward(self, y): "Not yet runnable in fbcode when the model.so is newly generated while older PyTorch is used", ) def test_linear_freezing(self): - for dtype in [torch.float32, torch.bfloat16]: + dtypes = [torch.bfloat16, torch.float] if SM80OrLater else [torch.float] + for dtype in dtypes: class LinearModel(torch.nn.Module): def __init__(self, device): @@ -1065,6 +1082,25 @@ def forward(self, x, y): ) self.check_model(Repro(), example_inputs) + @config.patch({"triton.autotune_at_compile_time": None}) + def test_stride_with_unbacked_expr(self): + class Repro(torch.nn.Module): + def forward(self, x, y): + u0 = x.item() + torch._check(u0 >= 1) + s0 = y.size(0) + expr = u0 * s0 + sevens = torch.empty_strided( + size=(10, expr, 32), stride=(expr * 32, 32, 1), device=x.device + ).fill_(7) + return sevens * 3 + + example_inputs = ( + torch.scalar_tensor(2, dtype=torch.int, device=self.device), + torch.ones(8, device=self.device), + ) + self.check_model(Repro(), example_inputs) + def test_large_grid(self): if self.device != "cuda": raise unittest.SkipTest("requires CUDA") @@ -1238,6 +1274,30 @@ def test_cond_non_tensor_predicates(self, dynamic): dynamic_shapes=dynamic_shapes, ) + def test_cond_symint_input(self): + class M(torch.nn.Module): + def forward(self, x, y, z): + a = y.shape[0] + b = z.shape[0] + + def true_fn(x): + return x + a + + def false_fn(x): + return x + b * z + + return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,)) + + input1 = (torch.ones(3, 3), torch.ones(5), torch.ones(3, 3)) + input2 = (torch.ones(10, 3), torch.ones(6), torch.ones(10, 3)) + inputs = (input1, input2) + dynamic_shapes = {"x": {0: Dim("d")}, "y": {0: Dim("d1")}, "z": {0: Dim("d")}} + self.check_model_with_multiple_inputs( + M(), + inputs, + dynamic_shapes=dynamic_shapes, + ) + def test_while_loop_simple(self): inputs = ( torch.randn((10, 20), device=self.device), @@ -1532,7 +1592,7 @@ def forward(self, x, y): result_cpu = Model(w1, w2)(*inputs) # Compile model with AOTInductor - with torch.cuda.device(0), config.patch("abi_compatible", self.abi_compatible): + with torch.cuda.device(0): so_path = AOTIRunnerUtil.compile( model=Model(w1.cuda(0), w2.cuda(0)), example_inputs=tuple(t.cuda(0) for t in inputs), @@ -1623,16 +1683,12 @@ def forward(self, x, y): inputs = (torch.randn(10, 10), torch.randn(10, 10)) result_cpu = Model(weight)(*inputs) - with torch.cuda.device(0), torch.no_grad(), config.patch( - "abi_compatible", self.abi_compatible - ): + with torch.cuda.device(0), torch.no_grad(): result_cuda_0 = AOTIRunnerUtil.run( "cuda", Model(weight.cuda(0)), tuple(t.cuda(0) for t in inputs) ) - with torch.cuda.device(1), torch.no_grad(), config.patch( - "abi_compatible", self.abi_compatible - ): + with torch.cuda.device(1), torch.no_grad(): result_cuda_1 = AOTIRunnerUtil.run( "cuda", Model(weight.cuda(1)), tuple(t.cuda(1) for t in inputs) ) @@ -2792,7 +2848,6 @@ def forward(self, x, y): self.check_model(Model(), example_inputs) - @config.patch({"abi_compatible": True}) def test_triton_kernel_reinterpret_view_mem_leak(self): # Check for memory leak when using user-defined Triton Kernel + AOTI. if self.device != "cuda": @@ -2932,23 +2987,33 @@ class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() - def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): - return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) + if SM80OrLater: + + def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): + return (x0, x1, x2, x3, x4, x5, x6, x7, x8, x9) + + else: + + def forward(self, x0, x1, x2, x4, x5, x6, x7, x8, x9): + return (x0, x1, x2, x4, x5, x6, x7, x8, x9) inputs = [] - for dtype in ( + dtypes = [ torch.float16, torch.float32, torch.float64, - torch.bfloat16, torch.bool, torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8, - ): + ] + if SM80OrLater: + dtypes.append(torch.bfloat16) + for dtype in dtypes: inputs.append(torch.ones(4, 8, 10, dtype=dtype, device=self.device)) + dim0 = Dim("s0", min=2, max=1024) dim1 = Dim("s1", min=2, max=512) dim2 = Dim("s2", min=2, max=128) @@ -2956,7 +3021,6 @@ def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): "x0": {0: dim0}, "x1": {0: dim0}, "x2": {0: dim0}, - "x3": {1: dim1}, "x4": {1: dim1}, "x5": {1: dim1}, "x6": {}, @@ -2964,11 +3028,13 @@ def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): "x8": {2: dim2}, "x9": {2: dim2}, } + if SM80OrLater: + dynamic_shapes["x3"] = {1: dim1} + m = Model() inputs = tuple(inputs) with torch.no_grad(), config.patch( { - "abi_compatible": self.abi_compatible, "aot_inductor.debug_compile": True, } ): @@ -2977,22 +3043,28 @@ def forward(self, x0, x1, x2, x3, x4, x5, x6, x7, x8, x9): src_code = cpp.read() FileCheck().check_count( "unmatched dtype", - 10, + 10 if SM80OrLater else 9, exactly=True, ).run(src_code) FileCheck().check_count( "unmatched dim value at", - 21, # we have 9 dynamic dims for which we generate different checks + 21 + if SM80OrLater + else 19, # we have 9 dynamic dims for which we generate different checks exactly=True, ).run(src_code) FileCheck().check_count( "dim value is too", - 18, # we have 9 dynamic dims for which we generate two checks + 18 + if SM80OrLater + else 16, # we have 9 dynamic dims for which we generate two checks exactly=True, ).run(src_code) FileCheck().check_count( "unmatched stride value at", - 21, # we have 9 symbolic strides for which we don't generate checks + 21 + if SM80OrLater + else 19, # we have 9 symbolic strides for which we don't generate checks exactly=True, ).run(src_code) optimized = AOTIRunnerUtil.load(self.device, so_path) @@ -3032,7 +3104,6 @@ def forward(self, x0, x1): } with torch.no_grad(), config.patch( { - "abi_compatible": self.abi_compatible, "aot_inductor.debug_compile": True, } ): @@ -3069,7 +3140,6 @@ def forward(self, x0, x1, x2): } with torch.no_grad(), config.patch( { - "abi_compatible": self.abi_compatible, "aot_inductor.debug_compile": True, } ): @@ -3093,7 +3163,6 @@ def forward(self, x): model = Model() with torch.no_grad(), config.patch( { - "abi_compatible": self.abi_compatible, "aot_inductor.debug_compile": True, } ): @@ -3117,11 +3186,7 @@ def forward(self, x): x = torch.randn(3, 4, dtype=torch.float16, device=self.device) model = Model() - with torch.no_grad(), config.patch( - { - "abi_compatible": self.abi_compatible, - } - ): + with torch.no_grad(): result = AOTIRunnerUtil.run( self.device, model, @@ -3146,11 +3211,7 @@ def forward(self, x): x = torch.randn(3, 4, dtype=torch.float32, device=self.device) model = Model() - with torch.no_grad(), config.patch( - { - "abi_compatible": self.abi_compatible, - } - ): + with torch.no_grad(): result = AOTIRunnerUtil.run( self.device, model, @@ -3189,7 +3250,6 @@ def forward(self, x): model = Model() with torch.no_grad(), config.patch( { - "abi_compatible": self.abi_compatible, "aot_inductor.debug_compile": True, } ): @@ -3313,9 +3373,7 @@ def forward(self, values, offsets): model, example_inputs_list, dynamic_shapes=dynamic_shapes ) - # max_autotune is disabled due to https://github.com/pytorch/pytorch/issues/135106 - # @common_utils.parametrize("max_autotune", [False, True]) - @common_utils.parametrize("max_autotune", [False]) + @common_utils.parametrize("max_autotune", [True, False]) def test_misc_1(self, max_autotune): if self.device == "cpu" and IS_MACOS and max_autotune: raise unittest.SkipTest("max_autotune not supported on macos") @@ -3591,6 +3649,161 @@ def forward(self, x): example_inputs = (torch.randn(8, device=self.device),) self.check_model(Model(), example_inputs) + def test_tile_positional_embedding(self): + class TilePositionalEmbedding(nn.Module): + """ + Positional embedding for tiles, different for every tile, same for every token within a tile. + Notice that tile is different from patch (token). For details, please check the documentation of + :class:`torchtune.modules.vision_transformer.VisionTransformer`. + Args: + max_num_tiles (int): The maximum number of tiles an image can be divided into. + embed_dim (int): The dimensionality of each tile embedding. + """ + + def __init__( + self, + max_num_tiles: int, + embed_dim: int, + ): + super().__init__() + self.max_num_tiles = max_num_tiles + self.embed_dim = embed_dim + + scale = embed_dim**-0.5 + self.embedding = nn.Parameter( + scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim) + ) + self.gate = nn.Parameter(torch.zeros(1)) + + def forward( + self, x: torch.Tensor, aspect_ratio: torch.Tensor + ) -> torch.Tensor: + """ + args: + x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim). + aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2), + representing the aspect ratio of the image before tile-cropping, e.g. (2,1). + returns: + torch.Tensor: The input tensor with added positional embeddings. + """ + bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape + torch._check(n_tiles <= self.max_num_tiles) + + for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio): + # When we batch images, all are padded to the same amount of tiles. + # The aspect_ratio lets us know the non padded tiles for each image. + # We only add positional encoding to those. + n_tiles_h = n_tiles_h.item() + n_tiles_w = n_tiles_w.item() + + n_non_padded_tiles = int(n_tiles_h * n_tiles_w) + + # We get only the positional encoding for non padded tiles, + # i.e. n_tiles_h, n_tiles_w. + torch._check_is_size(n_tiles_h) + torch._check_is_size(n_tiles_w) + torch._check(n_tiles_h > 0) + torch._check(n_tiles_w > 0) + torch._check(n_tiles_h <= self.max_num_tiles) + torch._check(n_tiles_w <= self.max_num_tiles) + padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1)) + # pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :] + pos_embed = padded_embedding.narrow(0, 0, n_tiles_h).narrow( + 1, 0, n_tiles_w + ) + + # Add pos encoding to the non padded tiles. + pos_embed = pos_embed.clone() + pos_embed = pos_embed.view(n_non_padded_tiles, 1, self.embed_dim) + + x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0)) + torch._check_is_size(n_non_padded_tiles) + torch._check(n_non_padded_tiles < x.size(1)) + # x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed + updating = x.narrow(0, batch_idx, batch_idx + 1).narrow( + 1, 0, n_non_padded_tiles + ) + # updating += pos_embed * self.gate.tanh() + updating.add_(pos_embed * self.gate.tanh()) + # x = x[:, :n_tiles, :, :] + x = x.narrow(1, 0, n_tiles) + + return x + + x = torch.ones(1, 4, 1600, 1280, device=self.device) + aspect_ratio = torch.tensor([[2, 2]], device=self.device) + + self.check_model( + TilePositionalEmbedding(4, 1280), + (x, aspect_ratio), + ) + + @dynamo_config.patch({"capture_scalar_outputs": True}) + def test_sym_i64_input_codegen(self): + if self.device != "cuda": + raise unittest.SkipTest("requires CUDA") + + from torch.testing._internal.triton_utils import add_kernel + + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x): + x_symint = x.item() + a = torch.ones(x_symint, device="cuda") + b = torch.ones(x_symint, device="cuda") + out = torch.zeros_like(a) + # unbacked symint in grid + add_kernel[(1, 1, x_symint)](a, b, out, x_symint, 32) + return out + + example_inputs = ( + torch.randint(high=1024, size=(1,), device=self.device, dtype=torch.int32), + ) + # This simple unit test case model generates two triton kernels: + # 1. triton_poi_fused_ones_1: + # triton_meta={'signature': {'out_ptr0': '*fp32', 'xnumel': 'i64'} + # 2. add_kernel: + # triton_meta={'signature': {'in_ptr0': '*fp32', 'in_ptr1': '*fp32', 'out_ptr': '*fp32', 'n_elements': 'i64'} + # input u0 was defined as int32_t initially, verify for every kernel var args downstream, + # it gets explicitly declared using its data types in the cpp wrapper codegen code. + expected_scalar_args = [ + "int64_t var_1 = u0;", + "int64_t var_3 = u0;", + "int64_t var_5 = u0;", + "int64_t var_9 = u0;", + ] + # check the new behavior of codegen is expected + result, code = run_and_get_cpp_code( + AOTIRunnerUtil.compile, Model(), example_inputs + ) + for scalar_line in expected_scalar_args: + FileCheck().check_count( + scalar_line, + 1, + ).run(code) + + self.check_model(Model(), example_inputs) + + +class AOTInductorLoggingTest(LoggingTestCase): + @make_logging_test(dynamic=logging.DEBUG) + def test_shape_env_reuse(self, records): + # make sure ShapeEnv is only created once and reused afterwards + class Foo(torch.nn.Module): + def forward(self, x): + return x + 2 + + inputs = (torch.randn(4, 4),) + dynamic_shapes = { + "x": {0: Dim.AUTO, 1: Dim.AUTO}, + } + ep = export(Foo(), inputs, dynamic_shapes=dynamic_shapes, strict=False) + with torch.no_grad(): + torch._inductor.aot_compile(ep.module(), inputs) + self.assertEqual([r.msg == "create_env" for r in records].count(True), 1) + common_utils.instantiate_parametrized_tests(AOTInductorTestsTemplate) @@ -3609,309 +3822,47 @@ def setUp(self): super().setUp() -class AOTInductorTestABICompatibleCpu(AOTITestCase): - device = "cpu" - abi_compatible = True - check_model = check_model - check_model_with_multiple_inputs = check_model_with_multiple_inputs - code_check_count = code_check_count - allow_stack_allocation = False - use_minimal_arrayref_interface = False - - -def fail_with_and_without_stack_allocation(is_skip=False): - return TestFailure( - ( - "abi_compatible_cpu", - "abi_compatible_cpu_with_stack_allocation", - "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface", - ), - is_skip=is_skip, - ) - - -def fail_stack_allocation(is_skip=False): +def fail_cpu(is_skip=False): return TestFailure( - ( - "abi_compatible_cpu_with_stack_allocation", - "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface", - ), - is_skip=is_skip, - ) - - -def fail_minimal_arrayref_interface(is_skip=False): - return TestFailure( - ("abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface",), + ("cpu",), is_skip=is_skip, ) def fail_cuda(is_skip=False): return TestFailure( - ("abi_compatible_cuda", "non_abi_compatible_cuda"), - is_skip=is_skip, - ) - - -def fail_abi_compatible_cuda(is_skip=False): - return TestFailure( - ("abi_compatible_cuda",), - is_skip=is_skip, - ) - - -def fail_non_abi_compatible_cuda(is_skip=False): - return TestFailure( - ("non_abi_compatible_cuda",), + ("cuda"), is_skip=is_skip, ) # test_failures, xfail by default, set is_skip=True to skip CPU_TEST_FAILURES = { - # TODO: error: ‘complex64’ was not declared in this scope - "test_add_complex": fail_minimal_arrayref_interface(is_skip=True), - "test_conv_freezing": fail_minimal_arrayref_interface(is_skip=True), - "test_deconv_freezing": fail_minimal_arrayref_interface(is_skip=True), - # FIXME: failed with Segfault while exiting the Python runtime - "test_duplicate_constant_folding": fail_with_and_without_stack_allocation( - is_skip=True - ), - # TODO: use of deleted function RAIIAtenTensorHandle - "test_dup_unbacked_sym_decl": fail_minimal_arrayref_interface(is_skip=True), - # TODO: use of deleted function RAIIAtenTensorHandle - "test_dup_unbacked_sym_decl_with_refinement": fail_minimal_arrayref_interface( - is_skip=True - ), - # TODO: error: cannot convert ArrayRefTensor to AtenTensorHandle - "test_dynamic_cat": fail_minimal_arrayref_interface(), - # https://github.com/pytorch/pytorch/issues/129550 - # https://github.com/pytorch/pytorch/issues/123691 - "test_dynamic_scalar": fail_minimal_arrayref_interface(is_skip=True), - # https://github.com/pytorch/pytorch/issues/122980 - "test_fft_c2c": fail_stack_allocation(is_skip=True), - "test_freezing": fail_minimal_arrayref_interface(is_skip=True), - "test_linear_freezing": fail_minimal_arrayref_interface(is_skip=True), - # FIXME: failed with Segfault while exiting the Python runtime - "test_missing_cubin": fail_with_and_without_stack_allocation(is_skip=True), - # minimal arrayref interface only works with CPU; test crashes. - # https://github.com/pytorch/pytorch/issues/122983 - "test_multi_device": fail_minimal_arrayref_interface(is_skip=True), - # TODO: AssertionError: unsupported Optional type in convert_arg_type: Generator - "test_normal_functional": fail_with_and_without_stack_allocation(is_skip=True), - # TODO: The same issue as https://github.com/pytorch/pytorch/issues/122978 - # error: cannot convert ArrayRefTensor to AtenTensorHandle - "test_reuse_kernel_dynamic": fail_minimal_arrayref_interface(is_skip=True), - # the test segfaults - "test_repeat_output": fail_stack_allocation(is_skip=True), # TODO: failed internally - "test_multiple_output_alias": fail_with_and_without_stack_allocation(is_skip=True), - # segfault - "test_buffer_mutation_1": fail_stack_allocation(is_skip=True), - # segfault - "test_buffer_mutation_2": fail_stack_allocation(is_skip=True), - # segfault - "test_bool_input": fail_stack_allocation(is_skip=True), - # segfault - "test_int_list_input": fail_stack_allocation(is_skip=True), - # segfault - # 'AOTInductorTestABICompatibleCpuWithStackAllocation' object has no attribute 'code_check_count' - "test_buffer_mutation_3": fail_stack_allocation(is_skip=True), - # FIXME: failed with Segfault while exiting the Python runtime - "test_scatter_fallback": fail_stack_allocation(is_skip=True), - # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 - "test_scatter_reduce_fallback": fail_minimal_arrayref_interface(is_skip=True), - # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 - "test_index_put_fallback": fail_minimal_arrayref_interface(is_skip=True), - # https://github.com/pytorch/pytorch/issues/122984 - "test_index_put_with_none_index": fail_minimal_arrayref_interface(is_skip=True), - # FIXME: failed with Segfault while exiting the Python runtime - "test_constant": fail_stack_allocation(is_skip=True), - # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 - "test_shifted_constraint_ranges": fail_with_and_without_stack_allocation( - is_skip=True - ), - # https://github.com/pytorch/pytorch/issues/123691 - "test_amp_fallback_random": fail_minimal_arrayref_interface(is_skip=True), - "test_simple_dynamic": fail_minimal_arrayref_interface(), - # https://github.com/pytorch/pytorch/issues/123691 - "test_zero_grid_with_unbacked_symbols": fail_minimal_arrayref_interface( - is_skip=True - ), - # failed on MacOS - "test_zero_grid_with_backed_symbols": fail_with_and_without_stack_allocation( - is_skip=True - ), - # https://github.com/pytorch/pytorch/issues/122990 - "test_cond_non_tensor_predicates_dynamic_False": fail_stack_allocation( - is_skip=True - ), - # same issue as https://github.com/pytorch/pytorch/issues/122990 - "test_cond_non_tensor_predicates_dynamic_True": fail_stack_allocation(is_skip=True), - # https://github.com/pytorch/pytorch/issues/122991 - "test_runtime_checks_complex": fail_with_and_without_stack_allocation(is_skip=True), - "test_runtime_checks_fp8": fail_with_and_without_stack_allocation(is_skip=True), - "test_while_loop_simple": fail_stack_allocation(is_skip=True), - "test_while_loop_nested": fail_stack_allocation(is_skip=True), - "test_while_loop_with_outer_code": fail_stack_allocation(is_skip=True), - # TODO: error: cannot convert ArrayRefTensor to AtenTensorHandle - "test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True), - # TODO: use of undeclared identifier 'float8_e4m3fn' and 'half' - "test_fp8": fail_minimal_arrayref_interface(is_skip=True), - "test_custom_op_add": fail_minimal_arrayref_interface(is_skip=True), - "test_custom_op_all_inputs": fail_minimal_arrayref_interface(is_skip=True), - "test_custom_op_with_multiple_outputs": fail_minimal_arrayref_interface( - is_skip=True - ), - "test_custom_op_with_reinterpret_view_inputs": fail_minimal_arrayref_interface( - is_skip=True - ), - "test_custom_op_with_concat_inputs": fail_minimal_arrayref_interface(is_skip=True), - "test_custom_op_missing_arg_with_default_value": fail_minimal_arrayref_interface( - is_skip=True - ), - "test_size_from_multi_output": fail_stack_allocation(is_skip=True), - "test_torchvision_transforms_functional_tensor_resize": fail_minimal_arrayref_interface(), + "test_multiple_output_alias": fail_cpu(is_skip=True), } # test_failures, xfail by default, set is_skip=True to skip CUDA_TEST_FAILURES = { - # TODO: AssertionError: unsupported Optional type in convert_arg_type: Generator - "test_normal_functional": fail_abi_compatible_cuda(is_skip=True), - # no runtime checks for non_abi_compatible mode - "test_runtime_checks": fail_non_abi_compatible_cuda(is_skip=True), - "test_runtime_checks_complex": fail_non_abi_compatible_cuda(is_skip=True), - "test_runtime_checks_fp8": fail_non_abi_compatible_cuda(is_skip=True), - "test_runtime_checks_dtype_failed": fail_non_abi_compatible_cuda(is_skip=True), - "test_runtime_checks_shape_failed": fail_non_abi_compatible_cuda(is_skip=True), # quantized unsupported for GPU - "test_quantized_linear": fail_cuda(is_skip=True), - "test_quanatized_int8_linear": fail_cuda(is_skip=True), - "test_custom_op_add": fail_non_abi_compatible_cuda(is_skip=True), - "test_custom_op_all_inputs": fail_non_abi_compatible_cuda(is_skip=True), - "test_custom_op_missing_arg_with_default_value": fail_non_abi_compatible_cuda( - is_skip=True - ), - "test_custom_op_with_concat_inputs": fail_non_abi_compatible_cuda(is_skip=True), - "test_custom_op_with_reinterpret_view_inputs": fail_non_abi_compatible_cuda( - is_skip=True - ), - "test_custom_op_with_multiple_outputs": fail_non_abi_compatible_cuda(is_skip=True), - # non-abi compatible mode aoti debug printer is not supported yet - "test_aoti_debug_printer_codegen": fail_non_abi_compatible_cuda(is_skip=True), - "test_aoti_debug_printer_user_defined_triton_kernel": fail_non_abi_compatible_cuda( - is_skip=True - ), - "test_aoti_debug_printer_sym_inputs": fail_non_abi_compatible_cuda(is_skip=True), + "test_quantized_linear": fail_cuda(), + "test_quanatized_int8_linear": fail_cuda(), } -if not IS_FBCODE: - # The following tests look like they pass in both pytest and unittest (xml - # and terminal output say pass), but the process will segfault. This only - # happens in OSS CI and is fine internally. - CPU_TEST_FAILURES.update( - { - "test_duplicated_params": fail_stack_allocation(is_skip=True), - "test_embedding_bag": fail_stack_allocation(is_skip=True), - "test_fqn": fail_stack_allocation(is_skip=True), - "test_no_args": fail_stack_allocation(is_skip=True), - "test_output_misaligned": fail_stack_allocation(is_skip=True), - "test_pytree_inputs": fail_stack_allocation(is_skip=True), - "test_seq": fail_stack_allocation(is_skip=True), - "test_simple_split": fail_stack_allocation(is_skip=True), - "test_addmm": fail_minimal_arrayref_interface(is_skip=True), - "test_aliased_buffer_reuse": fail_minimal_arrayref_interface(is_skip=True), - "test_buffer_reuse": fail_minimal_arrayref_interface(is_skip=True), - "test_constant_folding": fail_minimal_arrayref_interface(is_skip=True), - "test_convolution": fail_minimal_arrayref_interface(is_skip=True), - "test_empty_graph": fail_minimal_arrayref_interface(is_skip=True), - "test_large_weight": fail_minimal_arrayref_interface(is_skip=True), - "test_large_mmaped_weights": fail_minimal_arrayref_interface(is_skip=True), - "test_normal_functional": fail_minimal_arrayref_interface(is_skip=True), - "test_misc_1": fail_minimal_arrayref_interface(is_skip=True), - "test_missing_output": fail_minimal_arrayref_interface(is_skip=True), - "test_model_modified_weights": fail_minimal_arrayref_interface( - is_skip=True - ), - "test_output_path_1": fail_minimal_arrayref_interface(is_skip=True), - "test_quantized_linear": fail_minimal_arrayref_interface(is_skip=True), - "test_quanatized_int8_linear": fail_minimal_arrayref_interface( - is_skip=True - ), - "test_repeat_interleave": fail_minimal_arrayref_interface(is_skip=True), - "test_return_constant": fail_minimal_arrayref_interface(is_skip=True), - "test_reuse_kernel": fail_minimal_arrayref_interface(is_skip=True), - "test_simple": fail_minimal_arrayref_interface(is_skip=True), - "test_small_constant": fail_minimal_arrayref_interface(is_skip=True), - "test_with_no_triton_profiler": fail_minimal_arrayref_interface( - is_skip=True - ), - "test_with_offset": fail_minimal_arrayref_interface(is_skip=True), - "test_with_profiler": fail_minimal_arrayref_interface(is_skip=True), - "test_zero_size_weight": fail_minimal_arrayref_interface(is_skip=True), - "test_aoti_debug_printer_codegen": fail_with_and_without_stack_allocation( - is_skip=True - ), - "test_view_outputs": fail_minimal_arrayref_interface(is_skip=True), - "test_aoti_debug_printer_cpp_kernel": fail_with_and_without_stack_allocation( - is_skip=True - ), - } - ), - # The following test passes internally but fails in OSS CI. To be investigated. - CUDA_TEST_FAILURES.update( - { - "test_aoti_debug_printer_codegen": fail_cuda(is_skip=True), - "test_aoti_debug_printer_user_defined_triton_kernel": fail_cuda( - is_skip=True - ), - "test_aoti_debug_printer_sym_inputs": fail_cuda(is_skip=True), - } - ) - -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestABICompatibleCpu, - "abi_compatible_cpu", - CPU_TEST_FAILURES, -) - - -class AOTInductorTestABICompatibleCpuWithStackAllocation(AOTITestCase): +class AOTInductorTestABICompatibleCpu(AOTITestCase): device = "cpu" - abi_compatible = True check_model = check_model check_model_with_multiple_inputs = check_model_with_multiple_inputs code_check_count = code_check_count - allow_stack_allocation = True + allow_stack_allocation = False use_minimal_arrayref_interface = False copy_tests( AOTInductorTestsTemplate, - AOTInductorTestABICompatibleCpuWithStackAllocation, - "abi_compatible_cpu_with_stack_allocation", - CPU_TEST_FAILURES, -) - - -class AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface( - TestCase -): - device = "cpu" - abi_compatible = True - check_model = check_model - check_model_with_multiple_inputs = check_model_with_multiple_inputs - allow_stack_allocation = True - use_minimal_arrayref_interface = True - - -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface, - "abi_compatible_cpu_with_stack_allocation_and_minimal_arrayref_interface", + AOTInductorTestABICompatibleCpu, + "cpu", CPU_TEST_FAILURES, ) @@ -3919,7 +3870,6 @@ class AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterf @unittest.skipIf(sys.platform == "darwin", "No CUDA on MacOS") class AOTInductorTestABICompatibleCuda(AOTITestCase): device = "cuda" - abi_compatible = True check_model = check_model check_model_with_multiple_inputs = check_model_with_multiple_inputs code_check_count = code_check_count @@ -3930,90 +3880,10 @@ class AOTInductorTestABICompatibleCuda(AOTITestCase): copy_tests( AOTInductorTestsTemplate, AOTInductorTestABICompatibleCuda, - "abi_compatible_cuda", + "cuda", CUDA_TEST_FAILURES, ) - -@unittest.skipIf( - IS_FBCODE or sys.platform == "darwin", - "NonABI mode should not be used in fbcode nor on MacOS", -) -class AOTInductorTestNonABICompatibleCpu(AOTITestCase): - device = "cpu" - abi_compatible = False - check_model = check_model - check_model_with_multiple_inputs = check_model_with_multiple_inputs - code_check_count = code_check_count - allow_stack_allocation = False - use_minimal_arrayref_interface = False - - -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestNonABICompatibleCpu, - "non_abi_compatible_cpu", - # test_failures, xfail by default, set is_skip=True to skip - { - "test_duplicate_constant_folding": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - # no runtime checks for non_abi_compatible mode - "test_runtime_checks": TestFailure(("non_abi_compatible_cpu",), is_skip=True), - "test_runtime_checks_dtype_failed": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_runtime_checks_shape_failed": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_custom_op_add": TestFailure(("non_abi_compatible_cpu",), is_skip=True), - "test_aoti_debug_printer_codegen": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_custom_op_all_inputs": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_custom_op_missing_arg_with_default_value": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_custom_op_with_concat_inputs": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_custom_op_with_multiple_outputs": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_custom_op_with_reinterpret_view_inputs": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - "test_aoti_debug_printer_cpp_kernel": TestFailure( - ("non_abi_compatible_cpu",), is_skip=True - ), - }, -) - - -@unittest.skipIf( - IS_FBCODE or sys.platform == "darwin", - "NonABI mode should not be used in fbcode nor on MacOS", -) -class AOTInductorTestNonABICompatibleCuda(AOTITestCase): - device = "cuda" - abi_compatible = False - check_model = check_model - check_model_with_multiple_inputs = check_model_with_multiple_inputs - code_check_count = code_check_count - allow_stack_allocation = False - use_minimal_arrayref_interface = False - - -copy_tests( - AOTInductorTestsTemplate, - AOTInductorTestNonABICompatibleCuda, - "non_abi_compatible_cuda", - CUDA_TEST_FAILURES, -) - - if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_aot_inductor_arrayref.py b/test/inductor/test_aot_inductor_arrayref.py new file mode 100644 index 0000000000000..a35099c4df8ca --- /dev/null +++ b/test/inductor/test_aot_inductor_arrayref.py @@ -0,0 +1,222 @@ +# Owner(s): ["module: inductor"] +import sys +import unittest + +from torch._inductor.test_case import TestCase +from torch.testing._internal.common_utils import IS_CI, IS_FBCODE, IS_WINDOWS + + +if IS_WINDOWS and IS_CI: + sys.stderr.write( + "Windows CI does not have necessary dependencies for test_torchinductor yet\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + +try: + try: + from .test_aot_inductor import ( + AOTInductorTestsTemplate, + AOTITestCase, + check_model, + check_model_with_multiple_inputs, + code_check_count, + ) + from .test_torchinductor import copy_tests, TestFailure + except ImportError: + from test_aot_inductor import ( # @manual + AOTInductorTestsTemplate, + AOTITestCase, + check_model, + check_model_with_multiple_inputs, + code_check_count, + ) + from test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library + copy_tests, + TestFailure, + ) +except (unittest.SkipTest, ImportError) as e: + if __name__ == "__main__": + sys.exit(0) + raise + + +def fail_stack_allocation(is_skip=False): + return TestFailure( + ( + "cpu_with_stack_allocation", + "cpu_with_stack_allocation_and_minimal_arrayref_interface", + ), + is_skip=is_skip, + ) + + +def fail_minimal_arrayref_interface(is_skip=False): + return TestFailure( + ("cpu_with_stack_allocation_and_minimal_arrayref_interface",), + is_skip=is_skip, + ) + + +# test_failures, xfail by default, set is_skip=True to skip +CPU_TEST_FAILURES = { + # TODO: error: ‘complex64’ was not declared in this scope + "test_add_complex": fail_minimal_arrayref_interface(is_skip=True), + "test_conv_freezing": fail_minimal_arrayref_interface(is_skip=True), + "test_deconv_freezing": fail_minimal_arrayref_interface(is_skip=True), + "test_addmm_multiple_dynamic": fail_minimal_arrayref_interface(), + "test_bmm_multiple_dynamic": fail_minimal_arrayref_interface(), + "test_cond_nested": fail_minimal_arrayref_interface(), + "test_cond_simple": fail_minimal_arrayref_interface(), + "test_cond_symint_input": fail_minimal_arrayref_interface(), + "test_cond_use_buffers_from_outer_scope": fail_minimal_arrayref_interface(), + "test_cond_with_multiple_outputs": fail_minimal_arrayref_interface(), + "test_cond_with_outer_code_before_after": fail_minimal_arrayref_interface(), + "test_cond_with_parameters": fail_minimal_arrayref_interface(), + "test_cond_with_reinterpret_view_inputs_outputs": fail_minimal_arrayref_interface(), + "test_foreach_multiple_dynamic": fail_minimal_arrayref_interface(), + "test_nested_tensor_from_jagged": fail_minimal_arrayref_interface(), + "test_poi_multiple_dynamic": fail_minimal_arrayref_interface(), + "test_while_loop_with_parameters": fail_minimal_arrayref_interface(), + # FIXME: failed with Segfault while exiting the Python runtime + "test_duplicate_constant_folding": fail_stack_allocation(is_skip=True), + "test_stride_with_unbacked_expr": fail_minimal_arrayref_interface(is_skip=True), + # TODO: use of deleted function RAIIAtenTensorHandle + "test_dup_unbacked_sym_decl": fail_minimal_arrayref_interface(is_skip=True), + # TODO: use of deleted function RAIIAtenTensorHandle + "test_dup_unbacked_sym_decl_with_refinement": fail_minimal_arrayref_interface( + is_skip=True + ), + # TODO: error: cannot convert ArrayRefTensor to AtenTensorHandle + "test_dynamic_cat": fail_minimal_arrayref_interface(), + # https://github.com/pytorch/pytorch/issues/129550 + # https://github.com/pytorch/pytorch/issues/123691 + "test_dynamic_scalar": fail_minimal_arrayref_interface(is_skip=True), + # https://github.com/pytorch/pytorch/issues/122980 + "test_fft_c2c": fail_stack_allocation(is_skip=True), + "test_freezing": fail_minimal_arrayref_interface(is_skip=True), + "test_linear_freezing": fail_minimal_arrayref_interface(is_skip=True), + # FIXME: failed with Segfault while exiting the Python runtime + "test_missing_cubin": fail_stack_allocation(is_skip=True), + # minimal arrayref interface only works with CPU; test crashes. + # https://github.com/pytorch/pytorch/issues/122983 + "test_multi_device": fail_minimal_arrayref_interface(is_skip=True), + # TODO: AssertionError: unsupported Optional type in convert_arg_type: Generator + "test_normal_functional": fail_stack_allocation(is_skip=True), + # TODO: The same issue as https://github.com/pytorch/pytorch/issues/122978 + # error: cannot convert ArrayRefTensor to AtenTensorHandle + "test_reuse_kernel_dynamic": fail_minimal_arrayref_interface(is_skip=True), + # the test segfaults + "test_repeat_output": fail_stack_allocation(is_skip=True), + # TODO: failed internally + "test_multiple_output_alias": fail_stack_allocation(is_skip=True), + # segfault + "test_buffer_mutation_1": fail_stack_allocation(is_skip=True), + # segfault + "test_buffer_mutation_2": fail_stack_allocation(is_skip=True), + # segfault + "test_bool_input": fail_stack_allocation(is_skip=True), + # segfault + "test_int_list_input": fail_stack_allocation(is_skip=True), + # segfault + # 'AOTInductorTestABICompatibleCpuWithStackAllocation' object has no attribute 'code_check_count' + "test_buffer_mutation_3": fail_stack_allocation(is_skip=True), + # FIXME: failed with Segfault while exiting the Python runtime + "test_scatter_fallback": fail_stack_allocation(is_skip=True), + # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 + "test_scatter_reduce_fallback": fail_minimal_arrayref_interface(is_skip=True), + # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 + "test_index_put_fallback": fail_minimal_arrayref_interface(is_skip=True), + # https://github.com/pytorch/pytorch/issues/122984 + "test_index_put_with_none_index": fail_minimal_arrayref_interface(is_skip=True), + # FIXME: failed with Segfault while exiting the Python runtime + "test_constant": fail_stack_allocation(is_skip=True), + # Looks like the same issue as https://github.com/pytorch/pytorch/issues/122978 + "test_shifted_constraint_ranges": fail_stack_allocation(is_skip=True), + # https://github.com/pytorch/pytorch/issues/123691 + "test_amp_fallback_random": fail_minimal_arrayref_interface(is_skip=True), + "test_simple_dynamic": fail_minimal_arrayref_interface(), + # https://github.com/pytorch/pytorch/issues/123691 + "test_zero_grid_with_unbacked_symbols": fail_minimal_arrayref_interface( + is_skip=True + ), + # failed on MacOS + "test_zero_grid_with_backed_symbols": fail_stack_allocation(is_skip=True), + # https://github.com/pytorch/pytorch/issues/122990 + "test_cond_non_tensor_predicates_dynamic_False": fail_stack_allocation( + is_skip=True + ), + # same issue as https://github.com/pytorch/pytorch/issues/122990 + "test_cond_non_tensor_predicates_dynamic_True": fail_stack_allocation(is_skip=True), + # https://github.com/pytorch/pytorch/issues/122991 + "test_runtime_checks_complex": fail_stack_allocation(is_skip=True), + "test_runtime_checks_fp8": fail_stack_allocation(is_skip=True), + "test_while_loop_simple": fail_stack_allocation(is_skip=True), + "test_while_loop_nested": fail_stack_allocation(is_skip=True), + "test_while_loop_with_outer_code": fail_stack_allocation(is_skip=True), + # TODO: error: cannot convert ArrayRefTensor to AtenTensorHandle + "test_while_loop_with_outer_buffers": fail_stack_allocation(is_skip=True), + # TODO: use of undeclared identifier 'float8_e4m3fn' and 'half' + "test_fp8": fail_minimal_arrayref_interface(is_skip=True), + "test_custom_op_add": fail_minimal_arrayref_interface(is_skip=True), + "test_custom_op_all_inputs": fail_minimal_arrayref_interface(is_skip=True), + "test_custom_op_with_multiple_outputs": fail_minimal_arrayref_interface( + is_skip=True + ), + "test_custom_op_with_reinterpret_view_inputs": fail_minimal_arrayref_interface( + is_skip=True + ), + "test_custom_op_with_concat_inputs": fail_minimal_arrayref_interface(is_skip=True), + "test_custom_op_missing_arg_with_default_value": fail_minimal_arrayref_interface( + is_skip=True + ), + "test_size_from_multi_output": fail_stack_allocation(is_skip=True), + "test_torchvision_transforms_functional_tensor_resize": fail_minimal_arrayref_interface(), +} + + +class AOTInductorTestABICompatibleCpuWithStackAllocation(AOTITestCase): + device = "cpu" + check_model = check_model + check_model_with_multiple_inputs = check_model_with_multiple_inputs + code_check_count = code_check_count + allow_stack_allocation = True + use_minimal_arrayref_interface = False + + +copy_tests( + AOTInductorTestsTemplate, + AOTInductorTestABICompatibleCpuWithStackAllocation, + "cpu_with_stack_allocation", + CPU_TEST_FAILURES, +) + + +class AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface( + TestCase +): + device = "cpu" + check_model = check_model + check_model_with_multiple_inputs = check_model_with_multiple_inputs + code_check_count = code_check_count + allow_stack_allocation = True + use_minimal_arrayref_interface = True + + +if IS_FBCODE: + # The following tests look like they pass in both pytest and unittest (xml + # and terminal output say pass), but the process will segfault. This only + # happens in OSS CI and is fine internally. + # See https://github.com/pytorch/pytorch/issues/123691 + copy_tests( + AOTInductorTestsTemplate, + AOTInductorTestABICompatibleCpuWithStackAllocationAndMinimalArrayRefInterface, + "cpu_with_stack_allocation_and_minimal_arrayref_interface", + CPU_TEST_FAILURES, + ) + +if __name__ == "__main__": + from torch._inductor.test_case import run_tests + + run_tests(needs="filelock") diff --git a/test/inductor/test_auto_functionalize.py b/test/inductor/test_auto_functionalize.py index 2ab6f38cacc53..f904d97c4f821 100644 --- a/test/inductor/test_auto_functionalize.py +++ b/test/inductor/test_auto_functionalize.py @@ -1,5 +1,7 @@ # Owner(s): ["module: functionalization"] +import unittest + import numpy as np import torch @@ -9,6 +11,8 @@ import torch.utils._pytree as pytree import torch.utils.cpp_extension from torch import Tensor +from torch._dynamo.testing import CompileCounterWithBackend +from torch._higher_order_ops.auto_functionalize import try_use_slice from torch.testing._internal.logging_utils import logs_to_string @@ -180,12 +184,10 @@ def f(x, y, z, n): self.assertExpectedInline( post_grad_graphs, """\ -def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: \ -"f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): +def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): # No stacktrace found for following nodes - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = \ -arg3_1 = arg1_1 = arg0_1 = foo_default = None - return ()""", + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = foo_default = None + return ()""", # noqa: B950 ) eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) @@ -239,7 +241,7 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg4_1 = arg2_1 = arg3_1 = arg1_1 = arg0_1 = None + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg4_1 = arg1_1 = arg0_1 = None getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None return (getitem_4, getitem_5)""", # noqa: B950 @@ -402,9 +404,9 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1]cpu", arg3_1: "f32[s0][1]cpu", arg4_1: "f32[s0][1]cpu", arg5_1: "f32[s0][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg5_1, [arg3_1, arg4_1], arg2_1, 2, arg1_1); arg3_1 = arg4_1 = arg1_1 = foo_default = None + foo_default = torch.ops.mylib.foo.default(arg3_1, [arg4_1, arg5_1], arg2_1, 2, arg1_1); arg4_1 = arg5_1 = arg1_1 = foo_default = None copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy_ = None - copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg5_1, arg5_1); arg5_1 = copy__1 = None + copy__1: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg3_1, arg3_1); arg3_1 = copy__1 = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -414,9 +416,9 @@ def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu", arg2_1: "f32[s0][1 post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = foo_default = None + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = foo_default = None copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None - copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None return ()""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -503,12 +505,11 @@ def f(x, y, z, n): post_grad_graphs, """\ def forward(self, arg0_1: "f32[3][1]cpu", arg1_1: "f32[3][1]cpu", arg2_1: "f32[3][1]cpu", arg3_1: "f32[3][1]cpu", arg4_1: "f32[3][1]cpu"): - foo_default = torch.ops.mylib.foo.default(arg4_1, [arg2_1, arg3_1], arg1_1, 2, arg0_1); arg2_1 = arg3_1 = arg0_1 = None + foo_default = torch.ops.mylib.foo.default(arg2_1, [arg3_1, arg4_1], arg1_1, 2, arg0_1); arg3_1 = arg4_1 = arg0_1 = None getitem_4: "f32[3][1]cpu" = foo_default[0] getitem_5: "f32[3][1]cpu" = foo_default[1]; foo_default = None - copy_: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); arg1_1 = copy_ = None - copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg4_1, arg4_1); arg4_1 = copy__1 = None + copy__1: "f32[3][1]cpu" = torch.ops.aten.copy_.default(arg2_1, arg2_1); arg2_1 = copy__1 = None return (getitem_4, getitem_5)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, @@ -937,7 +938,6 @@ def test_dynamic2_v2(self): def test_dynamic3_v2(self): self.test_auto_functionalize_extra2(_dynamic=True) - # foo takes two views on the same input, function does not have return. @torch._inductor.config.patch(enable_auto_functionalized_v2=True) def test_graph_input_is_view(self): with torch.library._scoped_library("mylib", "FRAGMENT") as lib: @@ -964,6 +964,685 @@ def f(x): # to clone not-inplaced args. f(x[1]) + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_alias(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = torch.ops.aten.alias.default(x) + b = torch.ops.aten.alias.default(x) + torch.ops.mylib.foo(a, b) + return (a, b, x) + + orig_args = [torch.randn(2)] + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 0, _y_alias = True, _all_bases = [arg1_1]) + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None + alias_2: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1) + alias_3: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None + return (alias_2, alias_3)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 0, _y_alias = True, _all_bases = [arg0_1]) + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None + alias_2: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1) + alias_3: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None + return (alias_2, alias_3)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1) + alias_default_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1) + foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); \ +alias_default = alias_default_1 = foo_default = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None + return (arg1_1, arg1_1)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1) + alias_default_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1) + foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); \ +alias_default = alias_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None + return (arg0_1, arg0_1)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # Test that slice view is generated instead of as_strided when split is used. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_split(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + splits = x.split([4, 6], dim=1) + a = splits[0] + b = splits[1] + torch.ops.mylib.foo(a, b) + return (a, b, x) + + orig_args = [torch.randn(10, 10)] + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + # split forces a specialization on size so we dont see arg0_1 dynamic anymore. + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 1, _x_slice_start = 0, _x_slice_end = 4, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 4, _y_slice_end = 10, _all_bases = [arg0_1]) + getitem_3: "f32[10, 10][10, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = copy_ = None + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(getitem_3, [4, 6], 1) + getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(getitem_3, [4, 6], 1); getitem_3 = None + getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None + return (getitem_4, getitem_7)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 1, _x_slice_start = 0, _x_slice_end = 4, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 4, _y_slice_end = 10, _all_bases = [arg0_1]) + getitem_3: "f32[10, 10][10, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_3); arg0_1 = copy_ = None + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(getitem_3, [4, 6], 1) + getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(getitem_3, [4, 6], 1); getitem_3 = None + getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None + return (getitem_4, getitem_7)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + # split forces a specialization on size so we dont see arg0_1 dynamic anymore. + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): + slice_tensor: "f32[10, 4][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 1, 0, 4) + slice_tensor_1: "f32[10, 6][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 1, 4, 10) + foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None + copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [4, 6], 1) + getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(arg0_1, [4, 6], 1); arg0_1 = None + getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None + return (getitem_4, getitem_7)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): + slice_tensor: "f32[10, 4][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 1, 0, 4) + slice_tensor_1: "f32[10, 6][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 1, 4, 10) + foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None + copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None + split_with_sizes_1 = torch.ops.aten.split_with_sizes.default(arg0_1, [4, 6], 1) + getitem_4: "f32[10, 4][10, 1]cpu" = split_with_sizes_1[0]; split_with_sizes_1 = None + split_with_sizes_2 = torch.ops.aten.split_with_sizes.default(arg0_1, [4, 6], 1); arg0_1 = None + getitem_7: "f32[10, 6][10, 1]cpu" = split_with_sizes_2[1]; split_with_sizes_2 = None + return (getitem_4, getitem_7)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # Note that split force the input tensor to get specialized. So we do not see SymInts when _dynamic=True. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_split_dynamic(self): + self.test_split(_dynamic=True) + + # Test that slice view is generated instead of as_strided when slice is used. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_slice(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = torch.ops.aten.slice.Tensor(x, 0, 0, 2) + b = torch.ops.aten.slice.Tensor(x, 1, 3, 4) + torch.ops.mylib.foo(a, b) + return (a, b, x) + + orig_args = [torch.randn(10, 10)] + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"): + floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None + add_6: "Sym(2)" = floordiv + 2 + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = floordiv, _x_slice_end = add_6, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg1_1]); floordiv = add_6 = None + getitem_1: "f32[s0, s0][s0, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None + slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2) + slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None + return (slice_3, slice_4)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_slice_dim = 0, _x_slice_start = 0, _x_slice_end = 2, _y_base_index = 0, _y_slice_dim = 1, _y_slice_start = 3, _y_slice_end = 4, _all_bases = [arg0_1]) + getitem_1: "f32[10, 10][10, 1]cpu" = auto_functionalized_v2[1]; auto_functionalized_v2 = None + copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None + slice_3: "f32[2, 10][10, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 0, 0, 2) + slice_4: "f32[10, 1][10, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_1, 1, 3, 4); getitem_1 = None + return (slice_3, slice_4)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0, s0][s0, 1]cpu"): + floordiv: "Sym(0)" = 0 // arg0_1; arg0_1 = None + add_6: "Sym(2)" = floordiv + 2; floordiv = add_6 = None + slice_tensor: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) + slice_tensor_1: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4) + foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None + copy_: "f32[s0, s0][s0, 1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None + slice_3: "f32[2, s0][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, 0, 2) + slice_4: "f32[s0, 1][s0, 1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 1, 3, 4); arg1_1 = None + return (slice_3, slice_4)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[10, 10][10, 1]cpu"): + slice_tensor: "f32[2, 10][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 2) + slice_tensor_1: "f32[10, 1][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 1, 3, 4) + foo_default = torch.ops.mylib.foo.default(slice_tensor, slice_tensor_1); slice_tensor = slice_tensor_1 = foo_default = None + copy_: "f32[10, 10][10, 1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None + slice_3: "f32[2, 10][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 0, 0, 2) + slice_4: "f32[10, 1][10, 1]cpu" = torch.ops.aten.slice.Tensor(arg0_1, 1, 3, 4); arg0_1 = None + return (slice_3, slice_4)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # Note that split force the input tensor to get specialized. So we do not see SymInts when _dynamic=True. + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_slice_dynamic(self): + self.test_slice(_dynamic=True) + + def test_try_use_slice(self): + def test_round_trip(base, tensor): + (dim, start, end) = try_use_slice(base, tensor) + sliced = torch.ops.aten.slice.Tensor(base, dim, start, end) + self.assertEqual(sliced, tensor) + + t = torch.tensor([[2, 2], [3, 4]]) + test_round_trip(t, t) + + for dim in range(-1, 1): + f = t.split(2, dim) + test_round_trip(t, f[0]) + + for dim in range(-1, 1): + f = t.split(1, dim) + test_round_trip(t, f[0]) + test_round_trip(t, f[1]) + + t = torch.randint(1, 10, (3, 3, 3)) + test_round_trip(t, t) + + for dim in range(-3, 3): + f = t.split([1, 2], dim) + test_round_trip(t, f[0]) + test_round_trip(t, f[1]) + + for dim in range(-3, 3): + f = t.split(1, dim) + test_round_trip(t, f[0]) + test_round_trip(t, f[1]) + test_round_trip(t, f[2]) + + t = torch.rand(10, 10, 10) + test_round_trip(t, t) + for dim in range(-3, 3): + f = t.split([2, 2, 6], dim) + test_round_trip(t, f[0]) + test_round_trip(t, f[1]) + test_round_trip(t, f[2]) + + # example where slice wont work + + # selection + t = torch.ones(10) + b = t[0] + self.assertEqual(try_use_slice(t, b), None) + + t = torch.tensor([[1, 2], [3, 4]]) + self.assertEqual(try_use_slice(t, t[0]), None) + self.assertEqual(try_use_slice(t, t[1]), None) + + t = torch.tensor( + [ + [[1, 2, 3, 4, 5, 6, 7, 8], [10, 11, 12, 13, 14, 15, 16, 17]], + [[71, 72, 73, 74, 75, 76, 77, 78], [81, 82, 83, 84, 85, 86, 87, 88]], + ] + ) + + self.assertEqual(try_use_slice(t, t[0:1, 0:1, :7]), None) + self.assertEqual(try_use_slice(t, t[0:1, 0:2, :3]), None) + self.assertEqual(try_use_slice(t, t[0:2, 1, 0:8]), None) + + # simple slice operations are supported + test_round_trip(t, t[0:2]) + test_round_trip(t, t[3:4]) + + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_alias2(self, _dynamic=False): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + x.sin_() + y.sin_() + + def f(x): + a = torch.ops.aten.alias.default(x) + b = x.clone() + c = b.nonzero().float() + d = torch.ops.aten.slice( + c + ) # d is a Tensor with unbacked Symint in the shape + torch.ops.mylib.foo(a, d) + return a, d + + orig_args = [torch.randn(2)] + [aot_eager_args, result1, graph_aot] = self.run_aot_eager( + f, orig_args, _dynamic + ) + [inductor_args, result2, graph_inductor] = self.run_inductor( + f, orig_args, _dynamic + ) + eager_args = pytree.tree_map_only(torch.Tensor, torch.clone, orig_args) + result3 = f(*eager_args) + + self.assertEqual(inductor_args, eager_args) + self.assertEqual(inductor_args, aot_eager_args) + + self.assertEqual(result3, result1) + self.assertEqual(result3, result2) + + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + clone: "f32[s0][1]cpu" = torch.ops.aten.clone.default(arg1_1) + nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None + sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) + ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + _to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg1_1, _to_copy]); _to_copy = None + getitem_1: "f32[s0][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, getitem_1); arg1_1 = copy_ = None + alias_1: "f32[s0][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None + slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None + return (alias_1, slice_2)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_aot, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + clone: "f32[2][1]cpu" = torch.ops.aten.clone.default(arg0_1) + nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(clone); clone = None + sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) + ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None + _to_copy: "f32[u0, 1][1, 1]cpu" = torch.ops.aten._to_copy.default(nonzero, dtype = torch.float32); nonzero = None + auto_functionalized_v2 = torch.ops.higher_order.auto_functionalized_v2(torch.ops.mylib.foo.default, _x_base_index = 0, _x_alias = True, _y_base_index = 1, _y_alias = True, _all_bases = [arg0_1, _to_copy]); _to_copy = None + getitem_1: "f32[2][1]cpu" = auto_functionalized_v2[1] + getitem_2: "f32[u0, 1][1, 1]cpu" = auto_functionalized_v2[2]; auto_functionalized_v2 = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, getitem_1); arg0_1 = copy_ = None + alias_1: "f32[2][1]cpu" = torch.ops.aten.alias.default(getitem_1); getitem_1 = None + slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(getitem_2); getitem_2 = None + return (alias_1, slice_2)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + # 2. Run with inductor backend + if torch._dynamo.config.assume_static_by_default: + if _dynamic: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "Sym(s0)", arg1_1: "f32[s0][1]cpu"): + nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg1_1) + sym_size_int_1: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) + ge_1: "Sym(u0 >= 0)" = sym_size_int_1 >= 0; sym_size_int_1 = None + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None + alias_default: "f32[s0][1]cpu" = torch.ops.aten.alias.default(arg1_1) + alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type) + foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None + copy_: "f32[s0][1]cpu" = torch.ops.aten.copy_.default(arg1_1, arg1_1); copy_ = None + slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None + return (arg1_1, slice_2)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + else: + self.assertExpectedInline( + graph_inductor, + """\ +def forward(self, arg0_1: "f32[2][1]cpu"): + nonzero: "i64[u0, 1][1, 1]cpu" = torch.ops.aten.nonzero.default(arg0_1) + sym_size_int: "Sym(u0)" = torch.ops.aten.sym_size.int(nonzero, 0) + ge_1: "Sym(u0 >= 0)" = sym_size_int >= 0 + _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None + le: "Sym(u0 <= 2)" = sym_size_int <= 2; sym_size_int = None + _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u0 <= 2 on node 'le'"); le = _assert_scalar_1 = None + convert_element_type: "f32[u0, 1][1, 1]cpu" = torch.ops.prims.convert_element_type.default(nonzero, torch.float32); nonzero = None + alias_default: "f32[2][1]cpu" = torch.ops.aten.alias.default(arg0_1) + alias_default_1: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.alias.default(convert_element_type) + foo_default = torch.ops.mylib.foo.default(alias_default, alias_default_1); alias_default = alias_default_1 = foo_default = None + copy_: "f32[2][1]cpu" = torch.ops.aten.copy_.default(arg0_1, arg0_1); copy_ = None + slice_2: "f32[u0, 1][1, 1]cpu" = torch.ops.aten.slice.Tensor(convert_element_type); convert_element_type = None + return (arg0_1, slice_2)""", # noqa: B950 + ignore_comments=True, + ignore_empty_lines=True, + ) + + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_alias2_dynamic(self): + self.test_alias2(_dynamic=True) + + # Test that the view regenration optimizations do not result in recompilations. By comparing re-compilation in eager backend + # with recompilation in inductor backend. + @torch.fx.experimental._config.patch(use_duck_shape=False) + def test_recompile(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo_impl(x, y): + pass + + def run_and_compare(func, expected=1): + counter_v2 = CompileCounterWithBackend("inductor") + counter_v1 = CompileCounterWithBackend("inductor") + v1 = torch.compile( + func, backend=counter_v1, fullgraph=True, dynamic=True + ) + + v2 = torch.compile( + func, backend=counter_v2, fullgraph=True, dynamic=True + ) + inputs = [ + torch.rand(10, 10), + torch.rand(100, 100), + torch.rand(10, 2), + torch.rand(1000, 1000), + ] + + with torch._inductor.config.patch(enable_auto_functionalized_v2=True): + for input in inputs: + v2(input) + + torch._dynamo.reset() + + with torch._inductor.config.patch(enable_auto_functionalized_v2=False): + for input in inputs: + v1(input) + + self.assertEqual(counter_v2.frame_count, counter_v1.frame_count) + + self.assertEqual(counter_v1.frame_count, expected) + + def func(x): + a = x[0] + b = x[1] + torch.ops.mylib.foo(a, b) + + run_and_compare(func) + + def func(x): + a = torch.ops.aten.alias.default(x) + b = torch.ops.aten.alias.default(x) + torch.ops.mylib.foo(a, b) + + run_and_compare(func) + + def func(x): + # last row + a = x[x.size()[0] - 1] + + # first row + b = x[0] + torch.ops.mylib.foo(a, b) + + run_and_compare(func) + + def func(x): + a = torch.ops.aten.slice.Tensor(x, 1, 3, 4) + b = torch.ops.aten.slice.Tensor(x, 0, 1, 4) + torch.ops.mylib.foo(a, b) + + # recompile here is not triggered by auto_functionalize + # [__recompiles] - 0/0: 4 <= L['x'].size()[1] # a = torch.ops.aten.slice.Tensor(x, 1, 3, 4) + # test/inductor/test_auto_functionalize.py:1160 in func (_decomp/decompositions.py:781 in slice_forward) + run_and_compare(func, 2) + + def func(x): + a = torch.ops.aten.alias.default(x) + b = x.clone() + c = b.nonzero().float() + d = torch.ops.aten.slice( + c + ) # d is a Tensor with unbacked Symint in the shape + torch.ops.mylib.foo(a, d) + return a, d + + with torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True): + run_and_compare(func, 1) + + # Test that the alias optimization, were alias is called instead of as_strided, preserve the fact + # that id(x) != id(base) + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + @unittest.skip( + reason="This test fails because something else in inductor optimize out the alias. issue #137434" + ) + def test_alias_id_input_to_custom_op(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::not_eq", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::not_eq", "cpu", lib=lib) + @torch._dynamo.disable + def not_eq_impl(x, y): + self.assertNotEqual(id(x), id(y)) + + def func(x): + a = torch.ops.aten.alias.default(x) + torch.ops.mylib.not_eq(a, x) + + compiled = torch.compile(func, backend="inductor", fullgraph=True) + compiled(torch.rand(2, 2)) + + # Test that the alias optimization, were alias is called instead of as_strided, preserve the fact + # that id(x) != id(base) + @torch._inductor.config.patch(enable_auto_functionalized_v2=True) + def test_alias_id_output(self): + with torch.library._scoped_library("mylib", "FRAGMENT") as lib: + torch.library.define( + "mylib::foo", + "(Tensor(a!) x, Tensor(b!) y) -> ()", + tags=torch.Tag.pt2_compliant_tag, + lib=lib, + ) + + @torch.library.impl("mylib::foo", "cpu", lib=lib) + @torch._dynamo.disable + def foo(x, y): + pass + + def func(x): + a = torch.ops.aten.alias.default(x) + torch.ops.mylib.foo(a, x) + return a + + compiled = torch.compile(func, backend="inductor", fullgraph=True) + input = torch.rand(2, 2) + output = compiled(torch.rand(2, 2)) + self.assertNotEqual(id(output), id(input)) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_autoheuristic.py b/test/inductor/test_autoheuristic.py index 7679a2cc35926..196ccbfbde17f 100644 --- a/test/inductor/test_autoheuristic.py +++ b/test/inductor/test_autoheuristic.py @@ -4,14 +4,17 @@ import torch import torch._inductor.config as inductor_config +from torch._dynamo.device_interface import get_interface_for_device from torch._inductor.autoheuristic.autoheuristic import AutoHeuristic, LocalFeedback from torch._inductor.autoheuristic.autoheuristic_utils import AHContext from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import get_gpu_shared_memory -from torch.testing._internal.inductor_utils import HAS_CUDA, IS_A100, IS_H100 +from torch.testing._internal.common_utils import skipIfXpu +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, IS_A100, IS_H100 +@skipIfXpu(msg="AutoHeuristic doesn't currently work on the XPU stack") class AutoHeuristicTest(TestCase): def count_lines_in_file(self, file_path): with open(file_path) as file: @@ -23,8 +26,8 @@ def f(a, b): return torch.mm(a, b) cf = torch.compile(f) - a = torch.randn(2047, 2048, device="cuda", dtype=torch.float16) - b = torch.randn(2048, 2048, device="cuda", dtype=torch.float16) + a = torch.randn(2047, 2048, device=GPU_TYPE, dtype=torch.float16) + b = torch.randn(2048, 2048, device=GPU_TYPE, dtype=torch.float16) cf(a, b) def get_path_to_autoheuristic_log(self, name): @@ -99,7 +102,7 @@ def feedback_fn(choice): self.assertEqual(num_lines, 5) shared_memory = get_gpu_shared_memory() - (fst, snd) = torch.cuda.get_device_capability() + (fst, snd) = get_interface_for_device(GPU_TYPE).get_device_capability() with open(path) as file: lines = file.readlines() @@ -131,8 +134,10 @@ def run_mixed_mm(self): def fn(a, b): return torch.mm(a, b.to(a.dtype)) - a = torch.randn(8, 1024, device="cuda", dtype=torch.float16) - b = torch.randint(-128, 127, (1024, 1024), dtype=torch.int8, device="cuda").t() + a = torch.randn(8, 1024, device=GPU_TYPE, dtype=torch.float16) + b = torch.randint( + -128, 127, (1024, 1024), dtype=torch.int8, device=GPU_TYPE + ).t() torch.compile(fn, mode="max-autotune-no-cudagraphs")(a, b) # have to set autoheuristic_use="" because if autoheuristic_use="mixed_mm", @@ -164,5 +169,5 @@ def test_mixed_mm_a100(self): if __name__ == "__main__": - if HAS_CUDA: + if HAS_GPU: run_tests() diff --git a/test/inductor/test_b2b_gemm.py b/test/inductor/test_b2b_gemm.py index 201903c85b9c9..0b4d73368b5c2 100644 --- a/test/inductor/test_b2b_gemm.py +++ b/test/inductor/test_b2b_gemm.py @@ -6,10 +6,14 @@ from torch._inductor.runtime.benchmarking import benchmarker from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import run_and_get_code -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.common_utils import skipIfXpu +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU +@skipIfXpu(msg="Segmentation fault on CI machine") class B2BGEMMTest(TestCase): + device = GPU_TYPE + @torch._dynamo.config.patch(cache_size_limit=32) @torch._inductor.config.patch(b2b_gemm_pass=True) def test_b2b_gemm_left_assoc_good_shape(self): @@ -37,9 +41,9 @@ def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: return f(m1, m2, m3).to(torch.float16) f_opt = torch.compile(f) - A = torch.randn((256, 32), device="cuda", dtype=torch.float16) - B = torch.randn((32, 256), device="cuda", dtype=torch.float16) - C = torch.randn((256, 32), device="cuda", dtype=torch.float16) + A = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((32, 256), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) res, (code,) = run_and_get_code(f_opt, A, B, C) self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01)) self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code) @@ -63,9 +67,9 @@ def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: return f(m1, m2, m3).to(torch.float16) f_opt = torch.compile(f) - A = torch.randn((32, 256), device="cuda", dtype=torch.float16) - B = torch.randn((256, 32), device="cuda", dtype=torch.float16) - C = torch.randn((32, 256), device="cuda", dtype=torch.float16) + A = torch.randn((32, 256), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((32, 256), device=GPU_TYPE, dtype=torch.float16) res, (code,) = run_and_get_code(f_opt, A, B, C) self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01)) self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code) @@ -88,9 +92,9 @@ def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: return f(m1, m2, m3).to(torch.float16) f_opt = torch.compile(f) - A = torch.randn((256, 32), device="cuda", dtype=torch.float16) - B = torch.randn((32, 256), device="cuda", dtype=torch.float16) - C = torch.randn((256, 32), device="cuda", dtype=torch.float16) + A = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((32, 256), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) res, (code,) = run_and_get_code(f_opt, A, B, C) self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01)) self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" in code) @@ -113,9 +117,9 @@ def f_32(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: return f(m1, m2, m3).to(torch.float16) f_opt = torch.compile(f) - A = torch.randn((32, 256), device="cuda", dtype=torch.float16) - B = torch.randn((256, 32), device="cuda", dtype=torch.float16) - C = torch.randn((32, 256), device="cuda", dtype=torch.float16) + A = torch.randn((32, 256), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((32, 256), device=GPU_TYPE, dtype=torch.float16) res, (code,) = run_and_get_code(f_opt, A, B, C) self.assertTrue(torch.allclose(f_32(A, B, C), res, atol=0.1, rtol=0.01)) self.assertTrue("B2B_GEMM_RIGHT_TRITON_ENTRANCE" in code) @@ -133,9 +137,9 @@ def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: return torch.mm(mm1, mm2) f_opt = torch.compile(f) - A = torch.randn((256, 32), device="cuda", dtype=torch.float16) - B = torch.randn((32, 256), device="cuda", dtype=torch.float16) - C = torch.randn((256, 32), device="cuda", dtype=torch.float16) + A = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((32, 256), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((256, 32), device=GPU_TYPE, dtype=torch.float16) res, (code,) = run_and_get_code(f_opt, A, B, C) self.assertTrue(torch.allclose(f(A, B, C), res, atol=0.1, rtol=0.01)) self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code) @@ -152,9 +156,9 @@ def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: return torch.mm(torch.mm(m1, m2), m3) f_opt = torch.compile(f) - A = torch.randn((100, 100), device="cuda", dtype=torch.float16) - B = torch.randn((100, 100), device="cuda", dtype=torch.float16) - C = torch.randn((100, 100), device="cuda", dtype=torch.float16) + A = torch.randn((100, 100), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((100, 100), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((100, 100), device=GPU_TYPE, dtype=torch.float16) res, (code,) = run_and_get_code(f_opt, A, B, C) self.assertTrue(torch.allclose(f(A, B, C), res, atol=0.1, rtol=0.01)) self.assertTrue("B2B_GEMM_LEFT_TRITON_ENTRANCE" not in code) @@ -198,9 +202,9 @@ def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: print(f"M = {M}".ljust(10), end="") for N in Ns: O, P = M, N - A = torch.randn((M, N), device="cuda", dtype=torch.float16) - B = torch.randn((N, O), device="cuda", dtype=torch.float16) - C = torch.randn((O, P), device="cuda", dtype=torch.float16) + A = torch.randn((M, N), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((N, O), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((O, P), device=GPU_TYPE, dtype=torch.float16) speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C) print(f"{round(speedup, 3)}".ljust(10), end="") speedups.append(speedup) @@ -255,9 +259,9 @@ def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: print(f"M = {M}".ljust(10), end="") for N in Ns: O, P = M, N - A = torch.randn((M, N), device="cuda", dtype=torch.float16) - B = torch.randn((N, O), device="cuda", dtype=torch.float16) - C = torch.randn((O, P), device="cuda", dtype=torch.float16) + A = torch.randn((M, N), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((N, O), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((O, P), device=GPU_TYPE, dtype=torch.float16) speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C) print(f"{round(speedup, 3)}".ljust(10), end="") speedups.append(speedup) @@ -312,9 +316,9 @@ def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: print(f"M = {M}".ljust(10), end="") for N in Ns: O, P = N, N - A = torch.randn((M, N), device="cuda", dtype=torch.float16) - B = torch.randn((N, O), device="cuda", dtype=torch.float16) - C = torch.randn((O, P), device="cuda", dtype=torch.float16) + A = torch.randn((M, N), device=GPU_TYPE, dtype=torch.float16) + B = torch.randn((N, O), device=GPU_TYPE, dtype=torch.float16) + C = torch.randn((O, P), device=GPU_TYPE, dtype=torch.float16) speedup = run_with_b2b_gemm_off(A, B, C) / run_with_b2b_gemm_on(A, B, C) print(f"{round(speedup, 3)}".ljust(10), end="") speedups.append(speedup) @@ -331,5 +335,5 @@ def f(m1: torch.Tensor, m2: torch.Tensor, m3: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": - if HAS_CUDA: + if HAS_GPU: run_tests() diff --git a/test/inductor/test_benchmark_fusion.py b/test/inductor/test_benchmark_fusion.py index 9eb25aa305a1a..8f8fbcd9274d9 100644 --- a/test/inductor/test_benchmark_fusion.py +++ b/test/inductor/test_benchmark_fusion.py @@ -279,7 +279,7 @@ def test_equivalent_template_code(self): for out_code in [code, code2]: FileCheck().check("def call").check_count( "empty_strided_cuda", 1, exactly=True - ).check("triton_tem_fused_relu_0.run").check_count( + ).check("triton_tem_fused_addmm_relu_0.run").check_count( "del", 3, exactly=True ).check( "return" diff --git a/test/inductor/test_ck_backend.py b/test/inductor/test_ck_backend.py index bf386df514f80..3d51f621466d4 100644 --- a/test/inductor/test_ck_backend.py +++ b/test/inductor/test_ck_backend.py @@ -364,6 +364,44 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) + @unittest.skipIf(not torch.version.hip, "ROCM only") + @unittest.mock.patch.dict( + os.environ, + {"PATH": _get_path_without_sccache(), "PYTORCH_MIOPEN_SUGGEST_NHWC": "1"}, + ) + @parametrize("max_autotune_conv_backends", ("CK", "ATEN,CK,TRITON")) + def test_max_autotune_conv2d(self, max_autotune_conv_backends): + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False + + tensor_options = {"device": "cuda", "dtype": torch.float32} + + x = torch.randn(1, 8, 224, 224, **tensor_options) + w = torch.randn(64, 8, 7, 7, **tensor_options) + x_cl = x.to(memory_format=torch.channels_last) + w_cl = w.to(memory_format=torch.channels_last) + + assert "rocm" in dir(config) + + with config.patch( + { + "max_autotune": True, + "autotune_in_subproc": False, + "max_autotune_conv_backends": max_autotune_conv_backends, + "compile_threads": 4, + "rocm.ck_dir": self.ck_dir, + "rocm.n_max_profiling_configs": 4, + } + ): + + @torch.compile(dynamic=False) + def conv2d(x, w): + return torch.conv2d(x, w) + + Y_eager = torch.conv2d(x_cl, w_cl) + Y_compiled = conv2d(x_cl, w_cl) + + torch.testing.assert_close(Y_compiled, Y_eager, atol=2e-4, rtol=2e-4) + if __name__ == "__main__": from torch._inductor.utils import is_big_gpu diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index 6b58b9a000b93..70d1ae48f7cb0 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1,5 +1,4 @@ # Owner(s): ["module: inductor"] -import functools import os import pickle import tempfile @@ -37,10 +36,11 @@ HAS_CUDA, HAS_GPU, HAS_MULTIGPU, + HAS_TRITON, requires_gpu, + requires_triton, ) from torch.testing._internal.triton_utils import requires_cuda -from torch.utils._triton import has_triton try: @@ -49,15 +49,11 @@ from mock_cache import global_stats, PatchCaches, Stats # @manual -HAS_TRITON = has_triton() - if HAS_TRITON: import triton # @manual from torch.testing._internal.triton_utils import add_kernel -requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton") - torch._dynamo.config.fake_tensor_cache_enabled = True torch._dynamo.config.fake_tensor_cache_crosscheck_enabled = True @@ -912,6 +908,7 @@ def reset(self): @config.patch({"fx_graph_remote_cache": False}) @config.patch({"autotune_local_cache": False}) @config.patch({"autotune_remote_cache": True}) + @config.patch({"bundled_autotune_remote_cache": False}) @config.patch({"max_autotune": True}) def test_autotune_cache(self): class Model(torch.nn.Module): @@ -944,6 +941,52 @@ def f(x, y, a, b): for k in global_stats.triton.cache.keys(): self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c10") + @unittest.skipIf(not HAS_CUDA, "Requires CUDA") + @unittest.skipIf(not SM80OrLater, "Requires SM80+") + @config.patch({"fx_graph_cache": False}) + @config.patch({"fx_graph_remote_cache": False}) + @config.patch({"autotune_local_cache": True}) + @config.patch({"autotune_remote_cache": False}) + @config.patch({"bundled_autotune_remote_cache": True}) + @config.patch({"max_autotune": True}) + def test_bundled_autotune_remote_cache(self): + class Model(torch.nn.Module): + def forward(self, a, b, c, d, e, f): + return a + b, c + d, e + f + + def f(a, b, c, d, e, f): + return Model()(a, b, c, d, e, f) + + f_compiled = torch.compile(f, fullgraph=True) + + a = torch.randn(101, 100).cuda() + b = torch.randn(101, 100).cuda() + c = torch.randn(102, 100).cuda() + d = torch.randn(102, 100).cuda() + e = torch.randn(103, 100).cuda() + f = torch.randn(103, 100).cuda() + + with PatchCaches(): + f_compiled(a, b, c, d, e, f) + + self.assertEqual(global_stats.autotune_local, Stats(3, 0, 3)) + self.assertEqual(global_stats.bundled_autotune, Stats(1, 0, 1)) + + self.reset() + f_compiled(a, b, c, d, e, f) + + self.assertEqual(global_stats.autotune_local, Stats(6, 3, 3)) + self.assertEqual(global_stats.bundled_autotune, Stats(1, 1, 1)) + + if config.is_fbcode(): + # Check that the cache entries seem reasonable + for k in global_stats.autotune_local.cache.keys(): + self.assertRegex(k, r"tmp[^/]*/([^/]{2})/c\1[^/]{49}\.best_config") + for k in global_stats.bundled_autotune.cache.keys(): + self.assertRegex(k, r"pt2:bundled-autotune-v1::[0-9a-z]{64}:c10") + for k in global_stats.triton.cache.keys(): + self.assertRegex(k, r"triton:[0-9a-f]{64}::[0-9a-f]{64}:c10") + class TestRemoteAOTAutogradCache(TestCase): @unittest.skipIf(not HAS_CUDA, "Requires CUDA") diff --git a/test/inductor/test_codegen_triton.py b/test/inductor/test_codegen_triton.py index c9cea123041d8..84264bf1b0119 100644 --- a/test/inductor/test_codegen_triton.py +++ b/test/inductor/test_codegen_triton.py @@ -39,32 +39,44 @@ def test_config_of_sizearg(self): s0 = sympy.Symbol("s0", positive=True, integer=True) s1 = sympy.Symbol("s1", positive=True, integer=True) + def _check_divisibility(config): + try: + from triton.backends.compiler import AttrsDescriptor # noqa: F401 + + return config.divisibility_16 + except ImportError: + return config.divisible_by_16 + self.assertEqual( (2,), - triton_utils.config_of( - [ - SizeArg("A", two), # no - SizeArg("B", eight), # no - SizeArg("C", sixteen), # yes - SizeArg("D", s0), # no - SizeArg("E", s1), # no - ] - ).divisible_by_16, + _check_divisibility( + triton_utils.config_of( + [ + SizeArg("A", two), # no + SizeArg("B", eight), # no + SizeArg("C", sixteen), # yes + SizeArg("D", s0), # no + SizeArg("E", s1), # no + ] + ) + ), ) self.assertEqual( (0, 2, 4, 5, 6), - triton_utils.config_of( - [ - SizeArg("A", two * eight), # 0: yes - SizeArg("B", eight * s0), # 1: no - SizeArg("C", two * eight * s0), # 2: yes - SizeArg("D", s0 * s1), # 3: no - SizeArg("E", sixteen * s0), # 4: yes - SizeArg("F", sixteen * eight * s0 * s1), # 5: yes - SizeArg("G", two * eight * s0 * s1), # 6: yes - ] - ).divisible_by_16, + _check_divisibility( + triton_utils.config_of( + [ + SizeArg("A", two * eight), # 0: yes + SizeArg("B", eight * s0), # 1: no + SizeArg("C", two * eight * s0), # 2: yes + SizeArg("D", s0 * s1), # 3: no + SizeArg("E", sixteen * s0), # 4: yes + SizeArg("F", sixteen * eight * s0 * s1), # 5: yes + SizeArg("G", two * eight * s0 * s1), # 6: yes + ] + ) + ), ) diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index aa4cc59928a88..80b174cc4ae3e 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -6,9 +6,11 @@ import itertools import logging import os +import queue import re import subprocess import sys +import threading import unittest from importlib.machinery import SourceFileLoader from pathlib import Path @@ -24,7 +26,7 @@ from torch._dynamo.utils import counters from torch._inductor import config as inductor_config from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.common_utils import skipIfWindows +from torch.testing._internal.common_utils import scoped_load_inline, skipIfWindows from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_CUDA, HAS_GPU from torch.testing._internal.logging_utils import logs_to_string @@ -1584,7 +1586,8 @@ def _compiler_fn(gm): f, compiler_fn=compiler_fn_with_op_check, compile_fn=False ) - def test_non_traceable_autograd_cpp_node(self): + @scoped_load_inline + def test_non_traceable_autograd_cpp_node(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = false; @@ -1611,7 +1614,7 @@ def test_non_traceable_autograd_cpp_node(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_non_traceable_autograd_cpp_node", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -1632,8 +1635,8 @@ def fn(): ), compiled_autograd.enable(compiler_fn): fn() - @unittest.skip("Flaky, cache from test ordering affects test. #135369") - def test_autograd_cpp_node(self): + @scoped_load_inline + def test_autograd_cpp_node(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -1660,7 +1663,7 @@ def test_autograd_cpp_node(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_autograd_cpp_node", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -1680,7 +1683,8 @@ def fn(): # compiles for 10 (static) and 100 (dynamic) self.check_output_and_recompiles(fn, 2) - def test_autograd_cpp_node_id(self): + @scoped_load_inline + def test_autograd_cpp_node_id(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -1728,7 +1732,7 @@ def test_autograd_cpp_node_id(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_autograd_cpp_node_id", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -1771,7 +1775,8 @@ def fn(op): self.check_output_and_recompiles(different_autograd_fn, 2) - def test_autograd_cpp_node_saved(self): + @scoped_load_inline + def test_autograd_cpp_node_saved(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -1825,7 +1830,7 @@ def test_autograd_cpp_node_saved(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_autograd_cpp_node_saved", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -1846,7 +1851,8 @@ def fn(): self.check_output_and_recompiles(fn, 2) - def test_autograd_cpp_node_saved_dynamic(self): + @scoped_load_inline + def test_autograd_cpp_node_saved_dynamic(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -1882,7 +1888,7 @@ def test_autograd_cpp_node_saved_dynamic(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_autograd_cpp_node_saved_dynamic", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -1902,7 +1908,8 @@ def fn(): # compiles for 10 (static) and 100 (dynamic) self.check_output_and_recompiles(fn, 2) - def test_autograd_cpp_node_saved_int(self): + @scoped_load_inline + def test_autograd_cpp_node_saved_int(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -1941,7 +1948,7 @@ def test_autograd_cpp_node_saved_int(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_autograd_cpp_node_saved_int", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -1960,7 +1967,8 @@ def fn(): self.check_output_and_recompiles(fn, 1) - def test_autograd_cpp_node_saved_float(self): + @scoped_load_inline + def test_autograd_cpp_node_saved_float(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -1999,7 +2007,7 @@ def test_autograd_cpp_node_saved_float(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_autograd_cpp_node_saved_float", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -2019,7 +2027,8 @@ def fn(): # compiled autograd and dynamo both support symfloat, but not backend self.check_output_and_recompiles(fn, [1, 3]) - def test_autograd_cpp_node_data_dependent(self): + @scoped_load_inline + def test_autograd_cpp_node_data_dependent(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -2090,7 +2099,7 @@ def test_autograd_cpp_node_data_dependent(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_autograd_cpp_node_data_dependent", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -2330,8 +2339,9 @@ def backward(ctx, gO): # Must skip since we do not know if the cpu scalar will be used only in ATen/prim ops. self.assertEqual(counters["inductor"]["cudagraph_skips"], 1) + @scoped_load_inline @unittest.skipIf(not HAS_CUDA, "requires cuda") - def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self): + def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -2369,7 +2379,7 @@ def test_cudagraphs_cpu_scalar_used_in_cpp_custom_op(self): } """ - module = torch.utils.cpp_extension.load_inline( + module = load_inline( name="test_cudagraphs_cpu_scalar_used_in_cpp_custom_op", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", @@ -2405,6 +2415,39 @@ def test_logs(self): not in logs.getvalue() ) + def test_multithreading_tls(self): + def train(errors, model, x): + try: + out = model(x) + with compiled_autograd.enable(compiler_fn): + self.assertEqual(compiled_autograd.enabled(), True) + self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1) + except Exception as e: + print(f"Found error: {e}") + errors.put(1) + raise + + model = torch.nn.Sequential( + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + torch.nn.Linear(4, 4), + torch.nn.ReLU(), + ) + x = torch.randn([2, 4]) + + threads = [] + errors = queue.Queue() + with compiled_autograd.enable(compiler_fn): + for i in range(4): + thread = threading.Thread(target=train, args=(errors, model, x)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert errors.empty() + def test_verbose_logs_graph(self): def fn(): model = torch.nn.Sequential( @@ -2839,6 +2882,12 @@ def wrap_test_class(orig_cls): "test_backward_tensorlist_input_requires_list_grads_none_or_Tensor", # torch/_custom_op/autograd.py in skip files "test_backward_tensorlist_input_requires_list_grads_with_same_numel", # torch/_custom_op/autograd.py in skip files "test_save_for_backward_inputs_are_namedtuple", # torch/_custom_op/autograd.py in skip files + "test_reentrant_with_leaf_variable_hook", # reentrant .backward + "test_reentrant_with_non_leaf_variable_hook", # reentrant .backward + "test_reentrant_child_error", # reentrant .backward + "test_deep_reentrant", # reentrant .backward + "test_reentrant_priority", # reentrant .backward + "test_simple_reentrant", # reentrant .backward } test_contexts = { @@ -2860,9 +2909,11 @@ def wrap_test_class(orig_cls): known_failing_tests = { # Category: Compiled autograd + "test_grad_mode_restored_reentrant", # create_graph + "test_reentrant_with_callbacks_both_depths", # queue_callback + "test_reentrant_with_callbacks_depth_0", # queue_callback + "test_reentrant_with_callbacks_depth_1", # queue_callback "test_current_graph_task_execution_order", # nodes are already freed by the time dynamo traces the lifted hook - "test_reentrant_with_leaf_variable_hook", # hangs when enabled with graph breaks - "test_reentrant_with_non_leaf_variable_hook", # hangs when enabled with graph breaks "test_anomaly_grad_warnings", # does not support anomaly mode "test_autograd_inplace_views_cross_dtype", # view_fn not supported by compiled autograd "test_current_node", # TorchDispatchMode not yet implemented for compiled autograd @@ -2872,7 +2923,6 @@ def wrap_test_class(orig_cls): "test_retain_grad_inplace_over_view", # retains_grad_hooks "test_retains_grad_can_always_observe_tensor_prehook", # retains_grad_hooks "test_retains_grad_inplace_multiple_outputs", # retains_grad_hooks - "test_reentrant_child_error", # hangs when enabled with graph breaks "test_accumulate_grad", # create_graph "test_anomaly_assign_parent_cleanup", # create_graph "test_anomaly_mode_no_check_nan", # anomaly mode @@ -2911,19 +2961,12 @@ def wrap_test_class(orig_cls): "test_custom_autograd_no_early_free", # create_graph "test_custom_function_error", # vjp "test_custom_function_save_for_forward", # vjp - "test_deep_reentrant", # hangs with graph breaks "test_dont_materialize_grads", # undefined grad - "test_grad_mode_restored_reentrant", # hangs with graph breaks "test_no_grad_copy", # setting static member in lifted backward "test_no_grad_copy_sparse", # setting static member in lifted backward "test_node_ordering_when_none_returned", # torch._dynamo.exc.Unsupported: TypeError None: do { \\ hipError_t code = EXPR; \\ const char *msg; \\ - hipDrvGetErrorString(code, &msg); \\ + hipError_t code_get_error = hipDrvGetErrorString(code, &msg); \\ + if (code_get_error != hipSuccess) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string("invalid error code!")); \\ + } \\ if (code != hipSuccess) { \\ throw std::runtime_error( \\ std::string("CUDA driver error: ") + \\ diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py index 17b5291201709..a1069678eff2a 100644 --- a/test/inductor/test_cpu_cpp_wrapper.py +++ b/test/inductor/test_cpu_cpp_wrapper.py @@ -13,6 +13,7 @@ IS_MACOS, IS_WINDOWS, slowTest, + TEST_MKL, TEST_WITH_ROCM, ) from torch.testing._internal.inductor_utils import HAS_CPU @@ -142,7 +143,6 @@ def fn(self): if RUN_CPU: - config.abi_compatible = True class BaseTest(NamedTuple): name: str @@ -168,12 +168,8 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), func_inputs=[ - ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary("] - if config.abi_compatible - else ["op_mkldnn__convolution_pointwise_binary.call"], - ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary_("] - if config.abi_compatible - else ["op_mkldnn__convolution_pointwise__binary.call"], + ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary("], + ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary_("], ], ), BaseTest( @@ -182,12 +178,8 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestPatternMatcher(), condition=torch.backends.mkldnn.is_available(), func_inputs=[ - ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary_("] - if config.abi_compatible - else ["op_mkldnn__convolution_pointwise__binary.call"], - ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary("] - if config.abi_compatible - else ["op_mkldnn__convolution_pointwise_binary.call"], + ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary_("], + ["aoti_torch_cpu_mkldnn__convolution_pointwise_binary("], ], ), BaseTest( @@ -245,6 +237,9 @@ class BaseTest(NamedTuple): if func.startswith("test_lstm_packed_change_input_sizes") ], BaseTest("test_max_pool2d6"), + BaseTest( + "test_mkl_linear", "", test_cpu_repro.CPUReproTests(), condition=TEST_MKL + ), BaseTest("test_mm_views"), BaseTest("test_multihead_attention", "cpu", test_cpu_repro.CPUReproTests()), BaseTest( @@ -290,13 +285,11 @@ class BaseTest(NamedTuple): test_mkldnn_pattern_matcher.TestDynamicPatternMatcher(), condition=torch.backends.mkldnn.is_available() and not IS_WINDOWS, func_inputs=[ - None - if config.abi_compatible - else [ - "op_onednn_qconv2d_pointwise_.call", - "op_quantized_max_pool2d_.call", - "op_onednn_qlinear_pointwise_tensor.call", - ], + [ + "torch.ops.onednn.qconv2d_pointwise", + "torch.ops.quantized.max_pool2d", + "aoti_torch_cpu__qlinear_pointwise_tensor", + ] ], ), *[ diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index a9042e73fdd54..b8ed2c6644a39 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4,6 +4,7 @@ import functools import itertools import math +import os import platform import sys import unittest @@ -34,6 +35,7 @@ parametrize, skipIfRocm, slowTest, + TEST_MKL, ) from torch.utils._python_dispatch import TorchDispatchMode @@ -60,12 +62,16 @@ check_model = test_torchinductor.check_model requires_vectorization = unittest.skipUnless( - cpu_vec_isa.valid_vec_isa_list(), "Does not support vectorization" + cpu_vec_isa.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default", + "Does not support vectorization", ) def check_metrics_vec_kernel_count(num_expected_vec_kernels): - if cpu_vec_isa.valid_vec_isa_list(): + if ( + cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ): assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels @@ -206,6 +212,24 @@ def test_conv2d_autocast(self): (v,), ) + @config.patch(freezing=True) + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @patch("torch.cuda.is_available", lambda: False) + def test_mkl_linear(self): + dtypes = [torch.float32] + options = itertools.product([[2, 3, 10]], [2], [True, False], dtypes) + for input_shape, out_dim, bias, dtype in options: + mod = torch.nn.Sequential( + torch.nn.Linear(input_shape[-1], out_dim, bias=bias) + ).eval() + + v = torch.randn(input_shape) + with torch.no_grad(): + self.common( + mod.to(dtype), + (v.to(dtype),), + ) + @unittest.skipIf(not torch.backends.mkldnn.is_available(), "MKLDNN is not enabled") @patch("torch.cuda.is_available", lambda: False) def test_unsupported_conv_transpose(self): @@ -1639,6 +1663,73 @@ def fn(x): metrics.reset() self.common(fn, (value,)) + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") + @unittest.skipIf( + not cpu_vec_isa.valid_vec_isa_list() + or "avx2" in [str(vec_isa) for vec_isa in cpu_vec_isa.valid_vec_isa_list()], + "Does not support vectorization or not s390x/ppc64le machine", + ) + @patch("torch.cuda.is_available", lambda: False) + def test_auto_zvec_vsx_simd(self): + vec_zvec_vsx = cpu_vec_isa.valid_vec_isa_list()[0] + self.assertTrue(vec_zvec_vsx.bit_width() == 256) + + with config.patch({"cpp.simdlen": 0}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 1}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 257}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 256}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_vsx) + + pre_var = os.getenv("ATEN_CPU_CAPABILITY") + if pre_var: + os.environ.pop("ATEN_CPU_CAPABILITY") + + try: + with config.patch({"cpp.simdlen": None}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx2" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx512" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "default" + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "zvector" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "vsx" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_vsx) + + finally: + if pre_var: + os.environ["ATEN_CPU_CAPABILITY"] = pre_var + elif os.getenv("ATEN_CPU_CAPABILITY"): + os.environ.pop("ATEN_CPU_CAPABILITY") + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @unittest.skipIf( platform.machine() != "x86_64" or not cpu_vec_isa.valid_vec_isa_list(), @@ -1659,15 +1750,6 @@ def test_auto_simd(self): self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) - with config.patch({"cpp.simdlen": None}): - isa = cpu_vec_isa.pick_vec_isa() - if vec_amx in cpu_vec_isa.valid_vec_isa_list(): - self.assertTrue(isa == vec_amx) - elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - with config.patch({"cpp.simdlen": 0}): isa = cpu_vec_isa.pick_vec_isa() self.assertFalse(isa) @@ -1699,6 +1781,71 @@ def test_auto_simd(self): isa = cpu_vec_isa.pick_vec_isa() self.assertTrue(isa == vec_avx2) + pre_var = os.getenv("ATEN_CPU_CAPABILITY") + if pre_var: + os.environ.pop("ATEN_CPU_CAPABILITY") + + try: + with config.patch({"cpp.simdlen": None}): + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx2" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + elif vec_avx2 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx512" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "default" + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "zvector" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "vsx" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + finally: + if pre_var: + os.environ["ATEN_CPU_CAPABILITY"] = pre_var + elif os.getenv("ATEN_CPU_CAPABILITY"): + os.environ.pop("ATEN_CPU_CAPABILITY") + @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_masked_fill_softmax(self): @@ -2579,7 +2726,15 @@ def fn(x): with config.patch({"cpp.simdlen": None}): torch._dynamo.reset() metrics.reset() - self.common(fn, (x,)) + atol = None + rtol = None + if ( + not cpu_vec_isa.valid_vec_isa_list() + or os.getenv("ATEN_CPU_CAPABILITY") == "default" + ): + atol = 1e-5 + rtol = 1e-5 + self.common(fn, (x,), atol=atol, rtol=rtol) self.assertEqual( len(metrics.cpp_outer_loop_fused_inner_counts), 1, @@ -2679,6 +2834,7 @@ def fn(x, y): 1, ) + @requires_vectorization def test_argmin(self): def fn(x): return torch.argmin(x, -1) @@ -2690,6 +2846,7 @@ def fn(x): self.common(fn, (x,)) assert metrics.generated_cpp_vec_kernel_count == 1 + @requires_vectorization def test_argmax_argmin_with_nan_value(self): def fn(x): return torch.argmax(x) @@ -3574,6 +3731,7 @@ def forward(self, idx, x): self.common(m, (idx, x)) check_metrics_vec_kernel_count(1) + @requires_vectorization def test_embedding_vec_bf16(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -3915,7 +4073,7 @@ def fn(x): x = torch.randint(0, 100, (819,), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - assert metrics.generated_cpp_vec_kernel_count == 1 + check_metrics_vec_kernel_count(1) def test_highp_to_lowp_cse_var_cache_with_store(self): # Fix issue: https://github.com/pytorch/pytorch/issues/128263 @@ -3949,7 +4107,7 @@ def fn(x): x = torch.randint(0, 100, (22, 51), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - assert metrics.generated_cpp_vec_kernel_count == 1 + check_metrics_vec_kernel_count(1) @config.patch({"cpp.dynamic_threads": True}) def test_reduction_with_dynamic_threads(self): @@ -4012,6 +4170,47 @@ def forward(self, x): x = torch.randn(1, 4, 2, 2) self.common(fn, (x,)) + @parametrize("is_inference", (True, False)) + def test_disabled_amp(self, is_inference): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.all_head_size = 12 * 64 + self.dense = nn.Linear(self.all_head_size, self.all_head_size) + + def forward(self, q, k, v): + context_layer = F.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.2 + ) + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + ( + self.all_head_size, + ) + context_layer = context_layer.view(new_context_layer_shape) + return self.dense(context_layer) + + mod = M().to(torch.bfloat16).eval() + + q = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0 + k = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0 + v = torch.randn((4, 12, 512, 64), dtype=torch.bfloat16) / 10.0 + inputs = ( + q, + k, + v, + ) + compiler_mode = torch.compile(mod) + from torch.nn.attention import sdpa_kernel, SDPBackend + + context = contextlib.nullcontext if not is_inference else torch.no_grad + with config.patch( + {"fallback_random": True} + ), torch.cpu.amp.autocast(), context(), sdpa_kernel(SDPBackend.MATH): + torch.manual_seed(0) + eager = mod(*inputs) + torch.manual_seed(0) + self.assertEqual(compiler_mode(*inputs), eager) + @requires_vectorization def test_vec_indirect_load_cse_cache(self): # https://github.com/pytorch/pytorch/issues/123502 @@ -4060,6 +4259,7 @@ def fn(arg0_1, arg0_2): exactly=True, ).run(code) + @requires_vectorization def test_repeated_exp(self): def fn(x): y = x.sigmoid() @@ -4088,6 +4288,7 @@ def fn(x): self.common(fn, (x,)) check_metrics_vec_kernel_count(1) + @requires_vectorization def test_consistent_remove_buffers(self): def fn(x): z = x + x @@ -4198,6 +4399,34 @@ def func2(arg0, arg1): ): check_use_full_bits(func, shapes, dtype, mixed, check_vecn) + @config.patch("cpp.simdlen", 256) + @requires_vectorization + def test_avx2_bool_constant_pad_nd(self): + # NOTE: I tried using (0, 12, 12) and removing the cpp.simdlen=256 override, but + # that didn't repro the issue. + result = torch.testing.make_tensor( + (0, 6, 6), dtype=torch.bool, device=torch.device("cpu") + ) + + def fn(arg): + return torch.constant_pad_nd(arg, (1, 1, 1, 1, 1, 1)) + + self.common(fn, (result,)) + + @config.patch(unroll_reductions_threshold=9999) + @requires_vectorization + def test_unrolled_bool_prod_vectorized(self): + result = torch.zeros((37, 37, 37), dtype=torch.bool) + dim_select = [0, 1] + result.narrow(dim_select[0], 0, 1).narrow(dim_select[1], 1, 1).zero_() + result.narrow(dim_select[0], 2, 1).narrow(dim_select[1], 3, 1).zero_() + result.narrow(dim_select[0], 4, 1).narrow(dim_select[1], 3, 1).zero_() + + def fn(arg): + return torch.prod(arg, 1, dtype=torch.bool) + + self.common(fn, (result,)) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index dd9290168fa4b..5812b63e79573 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -1623,6 +1623,111 @@ def forward(self, x): self.common(mod, (v,), atol=atol, rtol=rtol) self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @inductor_config.patch({"freezing": False}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (16,)) + @parametrize("in_features", (128,)) + @parametrize("out_features", (64,)) + @parametrize("bias", (True,)) + @dtypes( + torch.float, + ) + def test_aoti_linear(self, batch_size, in_features, out_features, bias, dtype): + try: + try: + from . import test_aot_inductor_utils + except ImportError: + import test_aot_inductor_utils + except Exception: + # skip this UT if import failed + return + + class M(torch.nn.Module): + def __init__(self, bias=bias) -> None: + super().__init__() + self.mlp = torch.nn.Sequential( + torch.nn.Linear(in_features, out_features, bias=bias), + torch.nn.ReLU(), + ) + + def forward(self, x): + return self.mlp(x) + + assert torch._inductor.config.freezing is False + + counters.clear() + v = torch.randn(batch_size, in_features).to(dtype=dtype) + mod = M(bias=bias).to(dtype=dtype).eval() + torch._dynamo.reset() + torch._inductor.metrics.reset() + torch.manual_seed(0) + with verify(dtype) as (atol, rtol), torch.no_grad(): + expected = mod(v) + actual = test_aot_inductor_utils.AOTIRunnerUtil.run( + "cpu", + mod, + (v,), + ) + self.assertEqual(actual, expected, atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + + @inductor_config.patch({"freezing": False}) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("batch_size", (16,)) + @parametrize("in_features", (128,)) + @parametrize("out_features", (64,)) + @dtypes( + torch.float, + ) + def test_aoti_linear_multi_view_operations( + self, batch_size, in_features, out_features, dtype + ): + try: + try: + from . import test_aot_inductor_utils + except ImportError: + import test_aot_inductor_utils + except Exception: + # skip this UT if import failed + return + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.bias = torch.randn(out_features) + self.weight = torch.randn(out_features // 2, 2, in_features) + self.relu = torch.nn.ReLU() + + def forward(self, x): + tmp = torch.addmm( + self.bias, + x, + self.weight.permute(2, 0, 1).view(in_features, out_features), + ) + return self.relu(tmp) + + assert torch._inductor.config.freezing is False + + counters.clear() + v = torch.randn(batch_size, in_features).to(dtype=dtype) + mod = M().to(dtype=dtype).eval() + torch._dynamo.reset() + torch._inductor.metrics.reset() + torch.manual_seed(0) + with verify(dtype) as (atol, rtol), torch.no_grad(): + expected = mod(v) + actual = test_aot_inductor_utils.AOTIRunnerUtil.run( + "cpu", + mod, + (v,), + ) + self.assertEqual(actual, expected, atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1) + @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False}) class _DynamicShapesTestBase(BaseTestSelectAlgorithm): diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py index 7cdf664ae5260..4ee224462d09b 100644 --- a/test/inductor/test_cuda_cpp_wrapper.py +++ b/test/inductor/test_cuda_cpp_wrapper.py @@ -121,7 +121,6 @@ def fn(self): if RUN_CUDA: - config.abi_compatible = True class BaseTest(NamedTuple): name: str @@ -177,6 +176,7 @@ class BaseTest(NamedTuple): BaseTest(f"test_unspec_inputs_{str(dtype)[6:]}") for dtype in test_torchinductor.test_dtypes ], + BaseTest("test_consecutive_split_cumprod"), BaseTest("test_pointwise_hermite_polynomial_he"), BaseTest("test_pointwise_hermite_polynomial_h"), BaseTest( diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index d54a4c464fca4..b86575e12be34 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1,4 +1,5 @@ # Owner(s): ["module: inductor"] +import functools import gc import math import sys @@ -25,6 +26,7 @@ from torch.testing._internal.common_cuda import ( PLATFORM_SUPPORTS_FLASH_ATTENTION, SM80OrLater, + TEST_MULTIGPU, ) from torch.testing._internal.common_utils import ( DeterministicGuard, @@ -33,6 +35,11 @@ skipIfRocm, TEST_WITH_ASAN, ) + + +requires_multigpu = functools.partial( + unittest.skipIf, not TEST_MULTIGPU, "requires multiple cuda devices" +) from torch.testing._internal.inductor_utils import skipCUDAIf @@ -410,7 +417,7 @@ def test_autotune_inplace_kernel(self): https://github.com/pytorch/torchdynamo/issues/1670 """ from torch._C import _cuda_getCurrentRawStream as get_cuda_stream - from torch._inductor.runtime.hints import HeuristicType, instance_descriptor + from torch._inductor.runtime.hints import AttrsDescriptorWrapper, HeuristicType from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid def autotune(configs, meta): @@ -440,7 +447,9 @@ def decorator(fn): "xnumel": "i32", }, "device": DeviceProperties.create(torch.device("cuda")), - "configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())], + "configs": [ + AttrsDescriptorWrapper(divisible_by_16=(0, 1), equal_to_1=()) + ], "constants": {}, }, ) @@ -1383,6 +1392,24 @@ def foo(inp): foo_c = torch.compile(foo) torch.testing.assert_allclose(foo(inp), foo_c(inp)) + @requires_multigpu() + def test_not_initializing_wrong_device(self): + device_stats = torch.cuda.memory_stats("cuda:0") + + @torch.compile() + def foo(x, y): + return x @ y + + x = torch.rand([256, 256], device="cuda:1", requires_grad=True) + y = torch.rand([256, 256], device="cuda:1", requires_grad=True) + + foo(x, y).sum().backward() + + device_stats2 = torch.cuda.memory_stats("cuda:0") + self.assertTrue( + device_stats2["active.all.peak"] <= device_stats["active.all.peak"] + ) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_cudacodecache.py b/test/inductor/test_cudacodecache.py index 9697451bf6898..549bfd31f3d74 100644 --- a/test/inductor/test_cudacodecache.py +++ b/test/inductor/test_cudacodecache.py @@ -10,6 +10,7 @@ from torch._inductor.codegen.cuda.cuda_env import nvcc_exist from torch._inductor.exc import CUDACompileError from torch._inductor.test_case import TestCase as InductorTestCase +from torch._inductor.utils import fresh_inductor_cache _SOURCE_CODE = r""" @@ -39,51 +40,56 @@ @unittest.skipIf(config.is_fbcode(), "fbcode requires different CUDA_HOME setup") class TestCUDACodeCache(InductorTestCase): def test_cuda_load(self): - # Test both .o and .so compilation. - object_file_path, object_hash_key, source_code_path0 = CUDACodeCache.compile( - _SOURCE_CODE, "o" - ) - dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load( - _SOURCE_CODE, "so" - ) - self.assertNotEqual(source_code_path0, source_code_path1) - self.assertNotEqual(object_hash_key, so_hash_key) - - # Test load and call functions in .so. - x = torch.rand(10).float().cuda() - y = torch.rand(10).float().cuda() - a = 5.0 - expected_y = a * x + y - res = dll_wrapper.saxpy( - ctypes.c_int(10), - ctypes.c_float(a), - ctypes.c_void_p(x.data_ptr()), - ctypes.c_void_p(y.data_ptr()), - ) - torch.testing.assert_close(y, expected_y) + with fresh_inductor_cache(): + # Test both .o and .so compilation. + ( + object_file_path, + object_hash_key, + source_code_path0, + ) = CUDACodeCache.compile(_SOURCE_CODE, "o") + dll_wrapper, so_hash_key, source_code_path1 = CUDACodeCache.load( + _SOURCE_CODE, "so" + ) + self.assertNotEqual(source_code_path0, source_code_path1) + self.assertNotEqual(object_hash_key, so_hash_key) + + # Test load and call functions in .so. + x = torch.rand(10).float().cuda() + y = torch.rand(10).float().cuda() + a = 5.0 + expected_y = a * x + y + res = dll_wrapper.saxpy( + ctypes.c_int(10), + ctypes.c_float(a), + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ) + torch.testing.assert_close(y, expected_y) def test_compilation_error(self): - error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) - with self.assertRaises(CUDACompileError): - CUDACodeCache.compile(error_source_code, "o") + with fresh_inductor_cache(): + error_source_code = _SOURCE_CODE.replace("saxpy_device", "saxpy_wrong", 1) + with self.assertRaises(CUDACompileError): + CUDACodeCache.compile(error_source_code, "o") def test_async_compile(self): - async_compile = AsyncCompile() - compiled_res = async_compile.cuda(_SOURCE_CODE, "so") - async_compile.wait(globals()) - - # Test load and call functions in .so. - x = torch.rand(5).float().cuda() - y = torch.rand(5).float().cuda() - a = 2.0 - expected_y = a * x + y - res = compiled_res.result().saxpy( - ctypes.c_int(5), - ctypes.c_float(a), - ctypes.c_void_p(x.data_ptr()), - ctypes.c_void_p(y.data_ptr()), - ) - torch.testing.assert_close(y, expected_y) + with fresh_inductor_cache(): + async_compile = AsyncCompile() + compiled_res = async_compile.cuda(_SOURCE_CODE, "so") + async_compile.wait(globals()) + + # Test load and call functions in .so. + x = torch.rand(5).float().cuda() + y = torch.rand(5).float().cuda() + a = 2.0 + expected_y = a * x + y + res = compiled_res.result().saxpy( + ctypes.c_int(5), + ctypes.c_float(a), + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ) + torch.testing.assert_close(y, expected_y) if __name__ == "__main__": diff --git a/test/inductor/test_cudagraph_trees.py b/test/inductor/test_cudagraph_trees.py index 7d102875e9222..f1bfaad712771 100644 --- a/test/inductor/test_cudagraph_trees.py +++ b/test/inductor/test_cudagraph_trees.py @@ -1033,15 +1033,18 @@ def foo(mod, x): def foo2(x): return x[2:] - x = torch.rand([10, 10], device="cuda", requires_grad=True) param_c = cdata(m.weight) for _ in range(3): + x = torch.rand([10, 10], device="cuda", requires_grad=True) + torch.compiler.cudagraph_mark_step_begin() out1, alias_1, alias_2 = foo(m, x) self.assertEqual(len({param_c, cdata(alias_1), cdata(alias_2)}), 1) out2 = foo2(out1) out2.sum().backward() self.assertEqual(cdata(out1), cdata(out2)) + m.weight.grad = None + m.bias.grad = None node = self.curr_node() first_node = next(node._path_from_root) @@ -1649,14 +1652,35 @@ def foo(x): out = foo(inp) out2 = foo(inp) - with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."): + with self.assertRaisesRegex(Exception, "overwritten by a subsequent"): + out + out + + foo(inp) + + with self.assertRaisesRegex(Exception, "overwritten by a subsequent"): + out2 + out2 + + def test_error_on_dealloc_use2(self): + @torch.compile() + def foo(x): + return x * x * x + + inp = torch.rand([4], device="cuda") + out = foo(inp).detach() + out2 = foo(inp).detach() + + with self.assertRaises(Exception) as exc: out + out + FileCheck().check("overwritten").check("x * x * x").run(repr(exc.exception)) + foo(inp) - with self.assertRaisesRegex(Exception, "overwritten by a subsequent run."): + with self.assertRaises(Exception) as exc: out2 + out2 + FileCheck().check("overwritten").check("x * x * x").run(repr(exc.exception)) + @unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn") def test_conv_benchmark(self): with torch.backends.cudnn.flags( @@ -1681,6 +1705,7 @@ def foo(x): streams_init = {seg["stream"] for seg in get_all_cudagraph_segments()} for _ in range(4): foo(inp).sum().backward() + inp.grad = None streams = { seg["stream"] for seg in get_all_cudagraph_segments() @@ -1768,6 +1793,7 @@ def foo2(x): out2.sum().backward() self.assertFalse(self.get_manager().running_forwards_with_pending_backwards) + ones.grad = None del out del out2 @@ -2150,6 +2176,7 @@ def forward(self, x): fn_compiled = torch.compile(Foo(), mode="reduce-overhead") for _ in range(3): fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward() + fn_compiled.param.grad = None # Change static tensor address fn_compiled.param.data = torch.rand([2, 2], device="cuda") @@ -2187,11 +2214,13 @@ def forward(self, x): fn_compiled = torch.compile(Foo(), mode="reduce-overhead") for _ in range(3): fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward() + fn_compiled.param.grad = None for _ in range(5): # Change static tensor address fn_compiled.param.data = torch.rand([2, 2], device="cuda") fn_compiled(torch.rand([2, 2], device="cuda")).sum().backward() + fn_compiled.param.grad = None FileCheck().check_count( "skipping cudagraph due to function 0 exceeding max re-recording limit (=0) " diff --git a/test/inductor/test_custom_lowering.py b/test/inductor/test_custom_lowering.py index 4aaeac2b95458..17eb27ef4ec27 100644 --- a/test/inductor/test_custom_lowering.py +++ b/test/inductor/test_custom_lowering.py @@ -1,6 +1,5 @@ # Owner(s): ["module: inductor"] -import unittest from functools import partial import torch @@ -8,8 +7,13 @@ from torch._inductor.lowering import make_pointwise, register_lowering from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.virtualized import ops -from torch.testing._internal.common_utils import skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.common_utils import skipIfRocm, skipIfXpu +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_GPU, + requires_gpu, +) # These tests check issues for lowerings that aren't in the main pytorch repo @@ -20,12 +24,15 @@ def setUpClass(cls): cls.test_inductor_ops = torch.library.Library( # noqa: TOR901 "test_inductor_ops", "DEF" ) - cls.impl_cuda = torch.library.Library( # noqa: TOR901 - "test_inductor_ops", "IMPL", "CUDA" - ) - cls.impl_meta = torch.library.Library( # noqa: TOR901 - "test_inductor_ops", "IMPL", "Meta" - ) + cls.device_list = ["Meta", "CUDA", "XPU"] + for device in cls.device_list: + setattr( + cls, + "impl_" + device.lower(), + torch.library.Library( # noqa: TOR901 + "test_inductor_ops", "IMPL", device + ), + ) cls._register_jagged_to_padded_dense() cls._register_asm_op() @@ -47,7 +54,7 @@ def j2pd_meta(inp, offsets, max_seq_len, pad_value): dtype=inp.dtype, ) - def j2pd_cuda(inp, offsets, max_seq_len, pad_value): + def j2pd_gpu(inp, offsets, max_seq_len, pad_value): res = torch.full( (offsets.shape[0] - 1, max_seq_len, inp.shape[1]), pad_value, @@ -96,7 +103,8 @@ def inner_fn(index): )(j2pd_lowering) cls.impl_meta.impl("jagged_to_padded_dense", j2pd_meta) - cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_cuda) + cls.impl_cuda.impl("jagged_to_padded_dense", j2pd_gpu) + cls.impl_xpu.impl("jagged_to_padded_dense", j2pd_gpu) @classmethod def _register_asm_op(cls): @@ -131,15 +139,15 @@ def add_custom_lowering(a, b): torch.ops.test_inductor_ops.add_custom, type_promotion_kind=None )(add_custom_lowering) - @unittest.skipIf(not HAS_CUDA, "CUDA needed") + @requires_gpu() def test_jagged_to_padded_dense_sanity_cuda(self): def fn(inp, offsets, max_seq_len): return torch.ops.test_inductor_ops.jagged_to_padded_dense( inp, offsets, max_seq_len, 60.0 ) - inp = torch.rand((9, 96), device="cuda") - offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device="cuda") + inp = torch.rand((9, 96), device=GPU_TYPE) + offsets = torch.tensor([0, 2, 5, 9], dtype=torch.int32, device=GPU_TYPE) max_seq_len = 4 res = fn(inp, offsets, max_seq_len) @@ -156,19 +164,19 @@ def fn(inp, offsets, max_seq_len): fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len) ) - @unittest.skipIf(not HAS_CUDA, "CUDA needed") + @requires_gpu() def test_jagged_to_padded_dense_zero_size(self): # Previously, the masking was being completely stripped for the # masked load of the input value. That would lead to an IMA # because cuda was trying to read index 0 of a zero-size tensor. def fn(inp, offsets, max_seq_len): - inp = torch.bmm(inp, torch.ones((1, 96, 1), device="cuda")).view((0, 1)) + inp = torch.bmm(inp, torch.ones((1, 96, 1), device=GPU_TYPE)).view((0, 1)) return torch.ops.test_inductor_ops.jagged_to_padded_dense( inp, offsets, max_seq_len, 60.0 ) - inp = torch.rand((1, 0, 96), device="cuda") - offsets = torch.zeros(1025, device="cuda", dtype=torch.int32) + inp = torch.rand((1, 0, 96), device=GPU_TYPE) + offsets = torch.zeros(1025, device=GPU_TYPE, dtype=torch.int32) max_seq_len = 20 fn_opt = torch.compile(fn) @@ -177,27 +185,29 @@ def fn(inp, offsets, max_seq_len): fn(inp, offsets, max_seq_len), fn_opt(inp, offsets, max_seq_len) ) - @unittest.skipIf(not HAS_CUDA, "CUDA needed") + @requires_gpu() @skipIfRocm + @skipIfXpu def test_tanh_approx(self): def fn(inp): return torch.ops.test_inductor_ops.tanh_approx(inp) - inp = torch.randn(32, device="cuda") + inp = torch.randn(32, device=GPU_TYPE) fn_opt = torch.compile(fn) a = torch.tanh(inp) b = fn_opt(inp) self.assertEqual(a, b) - @unittest.skipIf(not HAS_CUDA, "CUDA needed") + @requires_gpu() @skipIfRocm + @skipIfXpu def test_multi_inp_asm(self): def fn(a, b): return torch.ops.test_inductor_ops.add_custom(a, b) - a = torch.randn(32, device="cuda") - b = torch.randn(32, device="cuda") + a = torch.randn(32, device=GPU_TYPE) + b = torch.randn(32, device=GPU_TYPE) fn_opt = torch.compile(fn) out1 = a + b @@ -208,5 +218,5 @@ def fn(a, b): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_dependencies.py b/test/inductor/test_dependencies.py index e9fe0fb519d7e..d61317832ed10 100644 --- a/test/inductor/test_dependencies.py +++ b/test/inductor/test_dependencies.py @@ -13,7 +13,10 @@ class TestDependencies(InductorTestCase): def _create_buffer(self, name, shape, dtype=torch.float32): - return Buffer(name, FixedLayout(torch.device(GPU_TYPE), dtype, shape)) + return Buffer( + name=name, + layout=FixedLayout(torch.device(GPU_TYPE), dtype=dtype, size=shape), + ) def setUp(self): super().setUp() diff --git a/test/inductor/test_efficient_conv_bn_eval.py b/test/inductor/test_efficient_conv_bn_eval.py index 911f320ce3dfd..9307345e6d590 100644 --- a/test/inductor/test_efficient_conv_bn_eval.py +++ b/test/inductor/test_efficient_conv_bn_eval.py @@ -17,7 +17,7 @@ from torch._inductor import config as inductor_config from torch._inductor.test_case import TestCase from torch.testing._internal.common_utils import TEST_WITH_ASAN -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU importlib.import_module("functorch") @@ -207,17 +207,17 @@ class EfficientConvBNEvalCpuTests(TestCase): copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalCpuTests, "cpu") -if HAS_CUDA and not TEST_WITH_ASAN: +if HAS_GPU and not TEST_WITH_ASAN: - class EfficientConvBNEvalCudaTests(TestCase): - device = "cuda" + class EfficientConvBNEvalGpuTests(TestCase): + device = GPU_TYPE - copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalCudaTests, "cuda") + copy_tests(EfficientConvBNEvalTemplate, EfficientConvBNEvalGpuTests, GPU_TYPE) del EfficientConvBNEvalTemplate if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 4504d472fe2fe..3742de5f31526 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -22,6 +22,8 @@ ExtensionWrapperCodegen, ) +from filelock import FileLock, Timeout + import torch._inductor.config as config from torch._inductor import cpu_vec_isa, metrics from torch._inductor.codegen import cpp_utils @@ -48,14 +50,23 @@ TestCase = test_torchinductor.TestCase -@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now") -class ExtensionBackendTests(TestCase): +class BaseExtensionBackendTests(TestCase): module = None + # Use a lock file so that only one test can build this extension at a time + lock_file = "extension_device.lock" + lock = FileLock(lock_file) + @classmethod def setUpClass(cls): super().setUpClass() + try: + cls.lock.acquire(timeout=600) + except Timeout: + # This shouldn't happen, still attempt to build the extension anyway + pass + # Build Extension torch.testing._internal.common_utils.remove_cpp_extensions_build_root() source_file_path = os.path.dirname(os.path.abspath(__file__)) @@ -78,6 +89,10 @@ def tearDownClass(cls): torch.testing._internal.common_utils.remove_cpp_extensions_build_root() + if os.path.exists(cls.lock_file): + os.remove(cls.lock_file) + cls.lock.release() + def setUp(self): torch._dynamo.reset() super().setUp() @@ -95,6 +110,9 @@ def tearDown(self): # return the working directory (see setUp) os.chdir(self.old_working_dir) + +@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now") +class ExtensionBackendTests(BaseExtensionBackendTests): def test_open_device_registration(self): torch.utils.rename_privateuse1_backend("extension_device") torch._register_device_module("extension_device", self.module) @@ -138,7 +156,10 @@ def fn(a, b, c): metrics.reset() opt_fn = torch.compile()(fn) _, code = run_and_get_cpp_code(opt_fn, x, y, z) - if cpu_vec_isa.valid_vec_isa_list(): + if ( + cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ): load_expr = "loadu" else: load_expr = " = in_ptr0[static_cast(i0)];" diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index b6aaa1c374aa6..25fe73aaad246 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -989,14 +989,15 @@ def composed_score_mod(score, b, h, m, n): self.run_test(composed_score_mod, dtype) @supported_platform + @expectedFailure # TODO: Remove this after supporting compiled flex attention with training bias @common_utils.parametrize("dtype", test_dtypes) - def test_captured_buffers(self, dtype: torch.dtype): - head_offset = torch.rand(H, device="cuda", dtype=dtype) + def test_captured_buffers_req_grad(self, dtype: torch.dtype): + head_offset = torch.rand(8, device="cuda", dtype=dtype, requires_grad=True) def score_mod(score, b, h, m, n): return score + head_offset[h] - self.run_test(score_mod, dtype) + self.run_test(score_mod, dtype, 4, 8, 128, 128) @supported_platform @common_utils.parametrize("dtype", test_dtypes) @@ -1345,8 +1346,8 @@ def test_make_block_mask(self): def causal_mask(b, h, q_idx, kv_idx): return q_idx >= kv_idx - block_mask_a = create_block_mask(causal_mask, 1, 1, 512, 512, _compile=True) - block_mask_b = create_block_mask(causal_mask, 1, 1, 512, 512, _compile=False) + block_mask_a = torch.compile(create_block_mask)(causal_mask, 1, 1, 512, 512) + block_mask_b = create_block_mask(causal_mask, 1, 1, 512, 512) self.assertEqual(block_mask_a.kv_num_blocks, block_mask_b.kv_num_blocks) self.assertEqual(block_mask_a.kv_indices, block_mask_b.kv_indices) self.assertEqual(block_mask_a.q_num_blocks, block_mask_b.q_num_blocks) @@ -2068,6 +2069,257 @@ def causal_mask(b, h, q_idx, kv_idx): f"Ref error: {ref_error}, Flex Error: {flex_error}", ) + @supported_platform + def test_block_mask_non_divisible(self): + seq = torch.arange(1023, device="cuda") // 128 + + def mod(b, h, q, kv): + return seq[q] == seq[kv] + + block_mask = create_block_mask(mod, None, None, 1023, 1023, device="cuda") + torch.compile(create_block_mask)(mod, None, None, 1023, 1023, device="cuda") + self.run_test_with_call( + lambda q, k, v: flex_attention(q, k, v, block_mask=block_mask), + Q_S=1023, + KV_S=1023, + ) + + @supported_platform + def test_head_bias_req_grad(self): + B, H, S, D = 1, 4, 256, 64 + bias = torch.randn(H, device="cuda", dtype=torch.float16, requires_grad=True) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def head_bias(score, b, h, q_idx, kv_idx): + return score + bias_flex[h] + + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref + implicit_bias_sdpa_ref = implicit_bias_sdpa_ref.view(H, 1, 1).expand(H, S, S) + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold + implicit_bias_sdpa_gold = implicit_bias_sdpa_gold.view(H, 1, 1).expand(H, S, S) + + self._test_learnable_bias_inner( + B, + H, + S, + D, + head_bias, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + @supported_platform + def test_comparison_vs_sdpa_with_learnable_bias(self): + # 1-dimensional bias: + B, H, S, D = 1, 1, 256, 64 + bias = torch.randn( + 2 * S, device="cuda", dtype=torch.float16, requires_grad=True + ) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def rel_pos_1d(score, b, h, q_idx, kv_idx): + return score + bias_flex[q_idx + kv_idx] + + bias_indices = torch.arange(S)[:, None] + torch.arange(S) + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref[bias_indices] + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold[bias_indices] + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_1d, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + # 2-dimensional bias: + B, H, S, D = 1, 1, 256, 64 + bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def rel_pos_2d(score, b, h, q_idx, kv_idx): + return score + bias_flex[q_idx, kv_idx] + + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_2d, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + # 2-dimensional bias + index multiple + B, H, S, D = 1, 1, 256, 64 + bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def rel_pos_2d(score, b, h, q_idx, kv_idx): + return score + bias_flex[q_idx][kv_idx] + + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_2d, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + # 2-dimensional bias + transposed: + B, H, S, D = 1, 1, 256, 64 + bias = torch.randn(S, S, device="cuda", dtype=torch.float16, requires_grad=True) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def rel_pos_2d_transposed(score, b, h, q_idx, kv_idx): + return score + bias_flex[kv_idx, q_idx] + + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_2d_transposed, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + # 3-dimensional bias + transposed + B, H, S, D = 4, 8, 256, 64 + bias = torch.randn( + H, S, S, device="cuda", dtype=torch.float16, requires_grad=True + ) + + bias_flex = bias.detach().clone().requires_grad_(True) + + def rel_pos_3d_transposed(score, b, h, q_idx, kv_idx): + return score + bias_flex[h, kv_idx, q_idx] + + bias_sdpa_ref = bias.detach().clone().requires_grad_(True) + implicit_bias_sdpa_ref = bias_sdpa_ref.transpose(-1, -2) + bias_sdpa_gold = ( + bias.detach().clone().to(dtype=torch.float64).requires_grad_(True) + ) + implicit_bias_sdpa_gold = bias_sdpa_gold.transpose(-1, -2) + + self._test_learnable_bias_inner( + B, + H, + S, + D, + rel_pos_3d_transposed, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ) + + def _test_learnable_bias_inner( + self, + B, + H, + S, + D, + score_mod, + bias_flex, + implicit_bias_sdpa_ref, + bias_sdpa_ref, + implicit_bias_sdpa_gold, + bias_sdpa_gold, + ): + make_tensor = functools.partial( + torch.ones, + (B, H, S, D), + device="cuda", + dtype=torch.float16, + requires_grad=True, + ) + q_ref, k_ref, v_ref = make_tensor(), make_tensor(), make_tensor() + q_gold, k_gold, v_gold = query_key_value_clones( + q_ref, k_ref, v_ref, torch.float64 + ) + q_flex, k_flex, v_flex = query_key_value_clones(q_ref, k_ref, v_ref) + + out_ref = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, attn_mask=implicit_bias_sdpa_ref + ) + out_ref.sum().backward() + out_gold = torch.nn.functional.scaled_dot_product_attention( + q_gold, k_gold, v_gold, attn_mask=implicit_bias_sdpa_gold + ) + out_gold.sum().backward() + out_flex = flex_attention(q_flex, k_flex, v_flex, score_mod=score_mod) + out_flex.sum().backward() + + name = score_mod.__name__ + for ref, flex, gold in [ + (out_ref, out_flex, out_gold), + (q_ref.grad, q_flex.grad, q_gold.grad), + (k_ref.grad, k_flex.grad, k_gold.grad), + (v_ref.grad, v_flex.grad, v_gold.grad), + (bias_sdpa_ref.grad, bias_flex.grad, bias_sdpa_gold.grad), + ]: + ref_error = rmse(ref, gold) + flex_error = rmse(flex, gold) + self.assertTrue( + ref_error * 1.2 >= flex_error, + f"{name} -> Ref error: {ref_error}, Flex eager Error: {flex_error}", + ) + @supported_platform def test_causal_block_non_divisible(self): def mask_mod(b, h, q, kv): @@ -2315,6 +2567,51 @@ def mask_mod(b, h, q, kv): ): torch.compile(flex_attention)(query, key, value, block_mask=block_mask) + @supported_platform + def test_free_symbol_dynamic(self): + def batch_flip_causal(b, h, q_idx, kv_idx): + return (q_idx >= kv_idx) & (b % 2 == 0) + + class SimpleAttention(torch.nn.Module): + def __init__(self, dim=512, n_head=8): + super().__init__() + self.qkv = torch.nn.Linear(dim, 3 * dim) + self.n_head = n_head + self.head_dim = dim // n_head + + def forward(self, x, block_mask=None): + B, T, C = x.size() + qkv = self.qkv(x).view(B, T, 3, self.n_head, self.head_dim) + qkv = qkv.permute(2, 0, 3, 1, 4) + q, k, v = qkv + y = flex_attention(q, k, v, block_mask=block_mask) + return y.transpose(1, 2).contiguous().view(B, T, C) + + model = SimpleAttention().cuda() + model.compile(mode="default", dynamic=True) + sequence_len = 256 + + # Test different batch shapes with dense masks + torch._dynamo.reset() + for batch_shape in [4, 16, 32]: + # Create dense mask + rand_mask = torch.randint(0, 2, (batch_shape, sequence_len)).cuda().bool() + block_mask = torch.compile(create_block_mask, dynamic=True)( + B=batch_shape, + BLOCK_SIZE=128, + mask_mod=lambda b, h, q_idx, kv_idx: ~rand_mask[b, q_idx], + H=None, + Q_LEN=sequence_len, + KV_LEN=sequence_len, + device="cuda", + ) + + # Run forward pass + x = torch.randn(batch_shape, sequence_len, 512).cuda() + y = model(x, block_mask=block_mask) + + self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2) + @supported_platform def test_fw_bw_graph_correctness(self): cnt = CompileCounterWithBackend("aot_eager") @@ -2562,10 +2859,14 @@ def causal_mask(b, h, q, kv): @supported_platform def test_compiling_create_block_mask(self): + seq = torch.arange(512, device="cuda") // 127 + def mask_mod(b, h, q, kv): - return q >= kv + return (q >= kv) & (seq[q] == seq[kv]) - block_mask = create_block_mask(mask_mod, 1, 1, 512, 512, _compile=True) + block_mask = torch.compile(create_block_mask, fullgraph=True)( + mask_mod, 1, 1, 512, 512 + ) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((1, 1, 4))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((1, 1, 4, 4))) @@ -2576,21 +2877,21 @@ def mask_mod(b, h, q, kv): return q >= kv torch._dynamo.reset() - block_mask = create_block_mask(mask_mod, 2, 4, 1024, 1024, _compile=True) + block_mask = torch.compile(create_block_mask)(mask_mod, 2, 4, 1024, 1024) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((2, 4, 8))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((2, 4, 8, 8))) self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 1) # automatic dynamic shapes triggered and recompilation. - block_mask = create_block_mask(mask_mod, 4, 8, 2048, 2048, _compile=True) + block_mask = torch.compile(create_block_mask)(mask_mod, 4, 8, 2048, 2048) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((4, 8, 16))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((4, 8, 16, 16))) self.assertEqual(torch._dynamo.utils.counters["aot_autograd"]["ok"], 2) # no recompilation. - block_mask = create_block_mask(mask_mod, 6, 16, 3072, 3072, _compile=True) + block_mask = torch.compile(create_block_mask)(mask_mod, 6, 16, 3072, 3072) self.assertIsInstance(block_mask, BlockMask) self.assertEqual(block_mask.kv_num_blocks.shape, torch.Size((6, 16, 24))) self.assertEqual(block_mask.kv_indices.shape, torch.Size((6, 16, 24, 24))) @@ -2867,23 +3168,21 @@ def generate_random_lengths(total_length, num_documents): offsets = length_to_offsets(lengths, device) document_causal_mask = generate_doc_mask_mod(offsets) - block_mask_compiled = create_block_mask( + block_mask_compiled = torch.compile(create_block_mask)( document_causal_mask, 1, 1, SEQ_LEN, SEQ_LEN, device=device, - _compile=True, ) - block_mask = create_block_mask( + block_mask = torch.compile(create_block_mask)( document_causal_mask, 1, 1, SEQ_LEN, SEQ_LEN, device=device, - _compile=True, ) self.assertEqual(block_mask_compiled.kv_indices, block_mask.kv_indices) self.assertEqual( diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index 3348a90bc909e..72211eee70e74 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -88,6 +88,27 @@ def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype): @instantiate_parametrized_tests class TestFP8Types(TestCase): + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") + @parametrize("float8_dtype", (torch.float8_e4m3fn, torch.float8_e5m2)) + def test_xblock_for_small_numel(self, float8_dtype: torch.dtype): + """ + TritonOverrides.to_dtype will set min_elem_per_thread to 2 or 4 + depends on the variant of fp8 type. + This cause triton_heuristics.triton_config pick a XBLOCK larger + than numel and fail the config sanity check. + + We should not pick a XBLOCK larger than xnumel + """ + + def f(x): + return x.to(dtype=float8_dtype) + + x = torch.randn(1, device="cuda") + expected = f(x) + actual = torch.compile(f)(x) + torch.testing.assert_close(expected.half(), actual.half(), rtol=1e-2, atol=1e-2) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet") @parametrize("dtype", (torch.float16, torch.bfloat16)) diff --git a/test/inductor/test_fused_attention.py b/test/inductor/test_fused_attention.py index c17d78f628a37..336a6c07946d7 100644 --- a/test/inductor/test_fused_attention.py +++ b/test/inductor/test_fused_attention.py @@ -161,7 +161,6 @@ def dot_prod_attention( check_train=False, ) - @skipIfRocm def _test_insignificant_strides(self): f32 = torch.float32 @@ -368,7 +367,6 @@ def sfdp_pattern_6(query, key, value, training): checkpoint_wrapper(sfdp_pattern_6), contains=False, has_dropout=True ) - @skipIfRocm def _test_sdpa_rewriter_7(self): def sfdp_pattern_7(query, key, value, training): q = query.permute(0, 2, 1, 3) @@ -410,7 +408,6 @@ def sfdp_pattern_7(query, key, value, training): atol=2e-3, ) - @skipIfRocm def _test_sdpa_rewriter_8(self): def sfdp_pattern_8(query, key, value): q = query.permute(0, 2, 1, 3) @@ -436,7 +433,6 @@ def sfdp_pattern_8(query, key, value): ) self._check_common(checkpoint_wrapper(sfdp_pattern_8), args, atol=2e-3) - @skipIfRocm def _test_sdpa_rewriter_9(self): def sfdp_pattern_9(query, key, value, training): q = query.permute(0, 2, 1, 3) @@ -478,7 +474,6 @@ def sfdp_pattern_9(query, key, value, training): atol=2e-3, ) - @skipIfRocm def _test_sdpa_rewriter_10(self): def sfdp_pattern_10(query, key, value): q = query.permute(0, 2, 1, 3) @@ -668,7 +663,6 @@ def dot_prod_attention( self._check_common(dot_prod_attention, check_train=False) - @skipIfRocm def _test_sdpa_rewriter_13(self, dtype): def dot_prod_attention( query: torch.Tensor, @@ -909,7 +903,6 @@ def dot_prod_attention( check_train=False, ) - @skipIfRocm def _test_sdpa_rewriter_19(self): def dot_prod_attention( query: torch.Tensor, diff --git a/test/inductor/test_graph_transform_observer.py b/test/inductor/test_graph_transform_observer.py index 081f46a9e5d85..1def72ae9e273 100644 --- a/test/inductor/test_graph_transform_observer.py +++ b/test/inductor/test_graph_transform_observer.py @@ -10,7 +10,7 @@ import torch._inductor.config as inductor_config from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FUSED_ATTENTION -from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm +from torch.testing._internal.common_utils import IS_LINUX from torch.testing._internal.inductor_utils import HAS_CUDA @@ -26,7 +26,6 @@ class TestGraphTransformObserver(TestCase): - @skipIfRocm def test_sdpa_rewriter(self): if not ( HAS_CUDA and PLATFORM_SUPPORTS_FUSED_ATTENTION and HAS_PYDOT and HAS_DOT diff --git a/test/inductor/test_group_batch_fusion.py b/test/inductor/test_group_batch_fusion.py index 9e40fc8e25a0e..6bde0305137be 100644 --- a/test/inductor/test_group_batch_fusion.py +++ b/test/inductor/test_group_batch_fusion.py @@ -9,7 +9,7 @@ import torch._inductor.fx_passes.group_batch_fusion from torch._dynamo.utils import counters, optimus_scuba_log from torch._inductor.test_case import run_tests, TestCase -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, requires_gpu try: @@ -20,8 +20,6 @@ except Exception: has_fbgemm = False -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") - class TestHighwaySelfGating(torch.nn.Module): def __init__( @@ -240,10 +238,8 @@ def forward(self, x): inputs = torch.split(x.to(self.device), 500, dim=1) x_split = torch.split(inputs[0].to(self.device), 50, dim=1) y_split = torch.split(inputs[1].to(self.device), 50, dim=1) - tanh_1 = [torch.tanh(x_split[i]) for i in range(len(x_split))] - tanh_2 = [torch.tanh(y_split[i]) for i in range(len(y_split))] - sigmoid_1 = [torch.sigmoid(tanh_1[i]) for i in range(len(tanh_1))] - sigmoid_2 = [torch.sigmoid(tanh_2[i]) for i in range(len(tanh_2))] + sigmoid_1 = [torch.sigmoid(x_split[i]) for i in range(len(x_split))] + sigmoid_2 = [torch.sigmoid(y_split[i]) for i in range(len(y_split))] relu_1 = [torch.nn.functional.relu(sigmoid_1[i]) for i in range(len(sigmoid_1))] relu_2 = [torch.nn.functional.relu(sigmoid_2[i]) for i in range(len(sigmoid_2))] add = [torch.add(relu_1[i], relu_2[i]) for i in range(len(relu_1))] @@ -272,7 +268,26 @@ def forward(self, x): return torch.cat(add, dim=1) -@requires_cuda +class TestMathOps(torch.nn.Module): + def __init__(self, device): + super().__init__() + self.device = device + + def forward(self, x): + inputs = [x.to(self.device) for i in range(10)] + others = [x.to(self.device) for i in range(10)] + clamp_input = [x.clamp(min=-1000.1, max=1000.1) for x in inputs] + clamp_other = [x.clamp(min=-1000.1, max=1000.1) for x in others] + nan_to_num_input = [torch.nan_to_num(x, 0.0) for x in clamp_input] + nan_to_num_other = [torch.nan_to_num(x, 0.0) for x in clamp_other] + detach_input = [x.detach() for x in nan_to_num_input] + detach_other = [x.detach() for x in nan_to_num_other] + stack_input = torch.stack(detach_input, dim=0) + stack_other = torch.stack(detach_other, dim=0) + return torch.stack((stack_input, stack_other), dim=0) + + +@requires_gpu() @torch._inductor.config.patch( pre_grad_fusion_options={ "batch_linear": {}, @@ -323,8 +338,8 @@ def test_group_linear_fusion(self): z = 10 for has_bias in [True, False]: counters.clear() - module = MyModule(z, has_bias).to("cuda") - input = [torch.randn(z, z, device="cuda")] + module = MyModule(z, has_bias).to(GPU_TYPE) + input = [torch.randn(z, z, device=GPU_TYPE)] traced = torch.compile(module) ref = module(*input) res = traced(*input) @@ -344,7 +359,7 @@ def test_group_linear_fusion(self): ) self.assertEqual( counters["inductor"]["batch_aten_add"], - 3, + 0, ) self.assertIn("GroupLinearFusion", optimus_scuba_log) counters.clear() @@ -352,8 +367,8 @@ def test_group_linear_fusion(self): @unittest.skipIf(not has_fbgemm, "requires fbgemm") def test_group_linear_fusion_different_shapes(self): counters.clear() - module = MyModule2().eval().to("cuda") - input = [torch.rand(4, 24, device="cuda")] + module = MyModule2().eval().to(GPU_TYPE) + input = [torch.rand(4, 24, device=GPU_TYPE)] traced = torch.compile(module) ref = module(*input) res = traced(*input) @@ -384,8 +399,8 @@ def test_batch_layer_norm_fusion(self): for has_weight in [True, False]: for has_bias in [True, False]: counters.clear() - module = MyModule3("cuda", has_weight, has_bias).to("cuda") - input = [torch.randn(2, 5, 50, device="cuda")] + module = MyModule3(GPU_TYPE, has_weight, has_bias).to(GPU_TYPE) + input = [torch.randn(2, 5, 50, device=GPU_TYPE)] traced = torch.compile(module) ref = module(*input) res = traced(*input) @@ -401,8 +416,8 @@ def test_batch_linear_lhs_fusion(self): z = 10 for has_bias in [True, False]: counters.clear() - module = MyModule4(z, "cuda", has_bias) - input = [torch.randn(20, z, device="cuda")] + module = MyModule4(z, GPU_TYPE, has_bias) + input = [torch.randn(20, z, device=GPU_TYPE)] traced = torch.compile(module) ref = module(*input) res = traced(*input) @@ -417,8 +432,8 @@ def test_batch_linear_lhs_fusion(self): def test_batch_linear_pre_grad_fusion(self): for has_bias in [True, False]: counters.clear() - module = MyModule5("cuda", has_bias) - input = [torch.randn(50, 500, device="cuda")] + module = MyModule5(GPU_TYPE, has_bias) + input = [torch.randn(50, 500, device=GPU_TYPE)] traced = torch.compile(module) ref = module(*input) res = traced(*input) @@ -432,13 +447,12 @@ def test_batch_linear_pre_grad_fusion(self): def test_pointwise_op_fusion(self): counters.clear() - module = TestPoitwiseOps("cuda") - input = [torch.randn(50, 1000, requires_grad=True, device="cuda")] + module = TestPoitwiseOps(GPU_TYPE) + input = [torch.randn(50, 1000, requires_grad=True, device=GPU_TYPE)] traced = torch.compile(module) ref = module(*input) res = traced(*input) self.compare_pred(module, traced, input) - self.assertEqual(counters["inductor"]["batch_tanh"], 1) self.assertEqual(counters["inductor"]["batch_relu"], 1) self.assertEqual(counters["inductor"]["batch_sigmoid"], 1) self.assertEqual(counters["inductor"]["batch_aten_add"], 1) @@ -451,7 +465,7 @@ def test_pointwise_op_fusion(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_gpu() @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -463,8 +477,8 @@ def test_pointwise_op_fusion(self): ) def test_pointwise_op_fusion_post_grad(self): counters.clear() - module = TestPoitwiseOpsPostGrad("cuda") - input = [torch.randn(50, 1000, requires_grad=True, device="cuda")] + module = TestPoitwiseOpsPostGrad(GPU_TYPE) + input = [torch.randn(50, 1000, requires_grad=True, device=GPU_TYPE)] traced = torch.compile(module) ref = module(*input) res = traced(*input) @@ -472,14 +486,14 @@ def test_pointwise_op_fusion_post_grad(self): self.assertEqual(counters["inductor"]["batch_aten_tanh"], 1) self.assertEqual(counters["inductor"]["batch_aten_relu"], 1) self.assertEqual(counters["inductor"]["batch_aten_sigmoid"], 1) - self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 2) + self.assertEqual(counters["inductor"]["unbind_stack_aten_pass"], 1) ref.sum().backward() res.sum().backward() self.compare_parameters(module, traced, rtol=1e-8, atol=1e-8) self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() - @requires_cuda + @requires_gpu() @torch._inductor.config.patch( pre_grad_fusion_options={}, post_grad_fusion_options={ @@ -497,10 +511,10 @@ def test_pointwise_op_fusion_post_grad(self): def test_gate_fusion_post_grad(self): counters.clear() size = 20 - module = TestHighwaySelfGating(d_model=10, size=size) + module = TestHighwaySelfGating(d_model=10, size=size, device=GPU_TYPE) input = [ [ - torch.randn(10, 10, requires_grad=True, device="cuda") + torch.randn(10, 10, requires_grad=True, device=GPU_TYPE) for i in range(size) ] ] @@ -520,6 +534,39 @@ def test_gate_fusion_post_grad(self): self.compare_gradients(module, traced, rtol=1e-8, atol=1e-8) counters.clear() + @requires_gpu() + @torch._inductor.config.patch( + pre_grad_fusion_options={ + "normalization_pass": {}, + "batch_detach": {}, + "batch_nan_to_num": {}, + "batch_clamp": {}, + "unbind_stack_pass": {}, + "unbind_stack_to_slices_pass": {}, + }, + post_grad_fusion_options={}, + ) + def test_math_op_fusion(self): + counters.clear() + module = TestMathOps(GPU_TYPE) + input = [ + torch.tensor( + [float("nan"), float("inf"), -float("inf"), 3.14], device=GPU_TYPE + ) + ] + traced = torch.compile(module) + ref = module(*input) + res = traced(*input) + self.compare_pred(module, traced, input) + self.assertEqual(counters["inductor"]["normalization_pass"], 3) + self.assertEqual(counters["inductor"]["batch_clamp"], 1) + self.assertEqual(counters["inductor"]["batch_detach"], 1) + self.assertEqual(counters["inductor"]["batch_nan_to_num"], 1) + self.assertEqual(counters["inductor"]["unbind_stack_to_slices_pass"], 2) + self.assertEqual(counters["inductor"]["unbind_stack_pass"], 2) + self.assertTrue(torch.allclose(ref, res)) + counters.clear() + class TestBMMFusionModule(torch.nn.Module): def __init__(self) -> None: @@ -538,16 +585,16 @@ def forward(self, inputs): return output -@requires_cuda +@requires_gpu() @torch._inductor.config.patch( post_grad_fusion_options={"batch_linear_post_grad": {"require_fbgemm": False}} ) class TestPostGradBatchLinearFusion(TestCase): def test_batch_linear_post_grad_fusion(self): - pt1_module = TestBMMFusionModule().cuda() + pt1_module = TestBMMFusionModule().to(GPU_TYPE) inputs = [] for _ in range(10): - inputs.append(torch.randn(10, 10).cuda()) + inputs.append(torch.randn(10, 10).to(GPU_TYPE)) eager_output = pt1_module(inputs) pt2_module = torch.compile(pt1_module) pt2_output = pt2_module(inputs) diff --git a/test/inductor/test_inductor_freezing.py b/test/inductor/test_inductor_freezing.py index 88f5530b57870..c3c2c95d99aed 100644 --- a/test/inductor/test_inductor_freezing.py +++ b/test/inductor/test_inductor_freezing.py @@ -16,7 +16,7 @@ from torch._inductor.utils import override_lowering, run_and_get_code from torch.testing import FileCheck from torch.testing._internal.common_cuda import SM80OrLater -from torch.testing._internal.common_utils import skipIfRocm +from torch.testing._internal.common_utils import IS_FBCODE, skipIfRocm, skipIfXpu # Make the helper files in test/ importable @@ -25,7 +25,7 @@ from inductor.test_torchinductor import ( # @manual=fbcode//caffe2/test/inductor:test_inductor-library check_model, - check_model_cuda, + check_model_gpu, copy_tests, ) from torch.testing._internal.common_utils import TEST_WITH_ASAN, TEST_WITH_ROCM @@ -34,12 +34,16 @@ importlib.import_module("functorch") importlib.import_module("filelock") -from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CPU, + HAS_GPU, + requires_gpu, +) aten = torch.ops.aten prims = torch.ops.prims -requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") class TestCase(InductorTestCase): @@ -250,7 +254,7 @@ def foo(mod, inp): return mod(inp) with torch.no_grad(): - with self.autocast(): + with torch.autocast(self.device): out_eager = mod(inp) out_compiled, code = run_and_get_code(foo, mod, inp) @@ -389,7 +393,7 @@ def fn(a): torch._dynamo.mark_dynamic(inp2, 1) self.assertEqual(fn(inp2), fn_opt(inp2)) - @requires_cuda + @requires_gpu() def test_conv_multiple_uses(self): from torch import nn @@ -404,10 +408,10 @@ def forward(self, x, y): return self.conv1(x) + self.bn1(self.conv1(y)) model = ToyModel() - model.eval().cuda() + model.eval().to(GPU_TYPE) - a = torch.rand(64, 1, 32, 32).cuda() - b = torch.rand(64, 1, 32, 32).cuda() + a = torch.rand(64, 1, 32, 32).to(GPU_TYPE) + b = torch.rand(64, 1, 32, 32).to(GPU_TYPE) output = model(a, b) @@ -441,7 +445,7 @@ def test_folded_conv_bn(self): if self.device == "cpu" and dtype == torch.float16: continue - if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if self.device == GPU_TYPE and dtype == torch.bfloat16 and not SM80OrLater: continue mod = ( @@ -468,7 +472,7 @@ def foo(mod, x): out_optimized_for_infernece, code = run_and_get_code(foo, mod, x) # we unfuse the conv bias, but it should only have one constant in the kernel - if self.device == "cuda": + if self.device == GPU_TYPE: FileCheck().check_not(".run(").check("conv").check(".run(").check_same( "frozen_param" ).check_not("frozen_param").check_next("return").run(code[0]) @@ -486,7 +490,7 @@ def test_folded_conv_bn_hardswish(self): if self.device == "cpu" and dtype == torch.float16: continue - if self.device == "cuda" and dtype == torch.bfloat16 and not SM80OrLater: + if self.device == GPU_TYPE and dtype == torch.bfloat16 and not SM80OrLater: continue mod = ( @@ -513,7 +517,7 @@ def foo(mod, x): out_optimized_for_infernece, code = run_and_get_code(foo, mod, x) # we unfuse the conv bias, but it should only have one constant in the kernel - if self.device == "cuda": + if self.device == GPU_TYPE: FileCheck().check_not(".run(").check("conv").check(".run(").check_same( "frozen_param" ).check_not("frozen_param").check_next("return").run(code[0]) @@ -648,7 +652,7 @@ def foo(mod, x): @torch._inductor.config.patch(layout_optimization=False) def test_dont_change_dtype_folding(self): - dtype = torch.float16 if self.device == "cuda" else torch.bfloat16 + dtype = torch.float16 if self.device == GPU_TYPE else torch.bfloat16 mod = ( torch.nn.Conv2d(3, 32, bias=None, kernel_size=3, stride=2) @@ -742,6 +746,8 @@ def foo(mod, inp): mod_eager = mod(x) self.assertEqual(foo(mod, x), mod_eager) + @skipIfXpu + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") def test_cpp_wrapper(self): mod = ConvBN(3, 32, kernel_size=3, stride=2).eval().to(self.device) @@ -835,7 +841,7 @@ def my_inner_compile(gm, example_inputs, *args, **kwargs): # in the joint graph rather than torch.ops.aten.convolution.default. # Currently we only handle aten.convolution.default in layout # optimization. That's why the count may be 0 here for CPU. - if self.device == "cuda": + if self.device == GPU_TYPE: self.assertTrue(nconv == 1) def test_unequal_bias_horizontal_addmm_fusion(self): @@ -956,14 +962,13 @@ class FreezingCpuTests(TestCase): copy_tests(OptimizeForInferenceTemplate, FreezingCpuTests, "cpu") -if HAS_CUDA and not TEST_WITH_ASAN: +if HAS_GPU and not TEST_WITH_ASAN: - class FreezingCudaTests(TestCase): - common = check_model_cuda - device = "cuda" - autocast = torch.cuda.amp.autocast + class FreezingGpuTests(TestCase): + common = check_model_gpu + device = GPU_TYPE - copy_tests(OptimizeForInferenceTemplate, FreezingCudaTests, "cuda") + copy_tests(OptimizeForInferenceTemplate, FreezingGpuTests, GPU_TYPE) del OptimizeForInferenceTemplate @@ -972,5 +977,5 @@ class FreezingCudaTests(TestCase): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CPU or HAS_CUDA: + if HAS_CPU or HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_layout_optim.py b/test/inductor/test_layout_optim.py index bd698d5b23b55..946cd45413f05 100644 --- a/test/inductor/test_layout_optim.py +++ b/test/inductor/test_layout_optim.py @@ -9,7 +9,8 @@ from torch._inductor import config from torch._inductor.test_case import run_tests, TestCase from torch.testing._internal.common_cuda import tf32_off -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.common_utils import skipIfXpu +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU USE_DDP_WRAPPER = os.environ.get("USE_DDP_WRAPPER", "1") == "1" @@ -33,6 +34,7 @@ def get_example_inputs(self): return (torch.rand(2, 3, 16, 16),) +@skipIfXpu(msg="ccl doesn't currently work on the XPU stack") class TestLayoutOptim(TestCase): @classmethod def setUpClass(cls): @@ -45,8 +47,12 @@ def setUpClass(cls): for retry_no in range(tot_retry): try: port = random.randint(10000, 60000) + if GPU_TYPE == "cuda": + backend = "nccl" + elif GPU_TYPE == "xpu": + backend = "ccl" dist.init_process_group( - backend="nccl", + backend=backend, init_method=f"tcp://localhost:{port}", world_size=1, rank=0, @@ -85,8 +91,8 @@ def f(*inp): return m manual_graph_break = not use_ddp_wrapper - mod = model_class(manual_graph_break=manual_graph_break).cuda() - inp = [t.cuda() for t in mod.get_example_inputs()] + mod = model_class(manual_graph_break=manual_graph_break).to(GPU_TYPE) + inp = [t.to(GPU_TYPE) for t in mod.get_example_inputs()] expected_out = wrap_mod(mod)(*inp) fp64_mod = copy.deepcopy(mod).to(torch.float64) @@ -167,8 +173,8 @@ def forward(self, x): def get_example_inputs(self): return (torch.randn(2, 3, 5, 5),) - mod = Model().cuda() - inp = [t.cuda() for t in mod.get_example_inputs()] + mod = Model().to(GPU_TYPE) + inp = [t.to(GPU_TYPE) for t in mod.get_example_inputs()] out = mod(*inp) opt_mod = torch.compile(mod) @@ -206,9 +212,9 @@ def f(x): y = x.view(3, 2) y.mul_(2) - x = torch.ones(2, 3).cuda() + x = torch.ones(2, 3).to(GPU_TYPE) f(x) - self.assertTrue(torch.equal(x, torch.ones(2, 3).cuda() * 2)) + self.assertTrue(torch.equal(x, torch.ones(2, 3).to(GPU_TYPE) * 2)) def test_mutate_base(self): """ @@ -225,9 +231,9 @@ def f(x): x.mul_(2) return y - x = torch.ones(2, 3).cuda() + x = torch.ones(2, 3).to(GPU_TYPE) y = f(x) - self.assertTrue(torch.equal(y, torch.ones(3, 2).cuda() * 2)) + self.assertTrue(torch.equal(y, torch.ones(3, 2).to(GPU_TYPE) * 2)) @tf32_off() def test_mutate_base_for_conv_output(self): @@ -279,8 +285,8 @@ def f(a, b): return z for size in [4, 8, 16]: - a = torch.randn(2, size, requires_grad=True).cuda() - b = torch.randn(2, size).cuda() + a = torch.randn(2, size, requires_grad=True).to(GPU_TYPE) + b = torch.randn(2, size).to(GPU_TYPE) actual = torch.compile(f, dynamic=True)(a, b) self.assertTrue(torch.allclose(f(a, b), actual)) @@ -312,7 +318,7 @@ def forward(self, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: loss = torch.nn.functional.cross_entropy(logits, targets) return loss - device = "cuda" + device = GPU_TYPE batch_size = 48 seq_len = 144 input_dim = 39 @@ -336,5 +342,5 @@ def forward(self, x: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: if __name__ == "__main__": - if HAS_CUDA: + if HAS_GPU: run_tests() diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index f0d931ed41994..fe1c821ea9d3e 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -18,13 +18,14 @@ from torch._inductor.utils import sympy_index_symbol from torch._inductor.virtualized import ops, V from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8 -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.common_device_type import expectedFailureXPU +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU from torch.utils._pytree import tree_map from torch.utils._sympy.functions import ModularIndexing -if HAS_CUDA: - torch.set_default_device("cuda") +if HAS_GPU: + torch.set_default_device(GPU_TYPE) class MockScheduler: @@ -76,7 +77,13 @@ def _create_computed_buffer_ax2(sizes=(32, 64), strides=None): box_a = ir.TensorBox.create( ir.Buffer( - "a", ir.FixedLayout(torch.device("cuda"), torch.float32, sizes, strides) + name="a", + layout=ir.FixedLayout( + torch.device(GPU_TYPE), + dtype=torch.float32, + size=sizes, + stride=strides, + ), ) ) box_a_loader = box_a.make_loader() @@ -139,7 +146,7 @@ def inner_fn(index): ) buf = ir.Pointwise.create( - device=torch.device("cuda"), + device=torch.device(GPU_TYPE), dtype=torch.float32, inner_fn=inner_fn, ranges=[128, 4, 49, 49], @@ -174,6 +181,8 @@ def inner_fn(index): } ) class LoopOrderingTest(TestCase): + device = GPU_TYPE + def do_acc_test(self, f, *args, cast_fp8=True): expect = f(*args) actual = torch.compile(f)(*args) @@ -217,7 +226,7 @@ def f(x, y): A, B = 20, 30 # Make the first 2 dimension not able to merge on purpose so that # ComputedBuffer.iter_reoredering_reindex will be updated. - x = rand_strided([A, A, B], [B, B * A + 300, 1], device="cuda") + x = rand_strided([A, A, B], [B, B * A + 300, 1], device=GPU_TYPE) y = torch.randn(A, A) self.do_acc_test(f, x, y) @@ -228,6 +237,8 @@ def f(x, y): expected_num_bytes *= x.itemsize self.assertEqual(expected_num_bytes, metrics.num_bytes_accessed) + # xpu generate 2 kernels + @expectedFailureXPU def test_apbt_realize(self): M = 1024 N = 2048 @@ -247,6 +258,8 @@ def f(x, y): self.do_acc_test(f, x, y) self.assertEqual(1, metrics.generated_kernel_count) + # xpu generate 2 kernels + @expectedFailureXPU def test_sum_and_t(self): N = 1024 @@ -257,6 +270,8 @@ def f(x): self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) + # xpu generate 2 kernels + @expectedFailureXPU def test_pw_outer_red(self): def f(x): x = realize(x + 1) @@ -267,6 +282,8 @@ def f(x): self.do_acc_test(f, x) self.assertEqual(1, metrics.generated_kernel_count) + # xpu generate 2 kernels + @expectedFailureXPU def test_pw_outer_red_2(self): """ The pointwise kernel is a fused kernel @@ -340,6 +357,8 @@ def f(*args): # some buffer is used before being defined. f(input_ids, labels, position_ids) + # xpu generate 2 kernels + @expectedFailureXPU def test_different_broadcast_shapes(self): def f(x, y, c): return x + c, y + c @@ -387,13 +406,53 @@ def f(x, scale): return x, x_t x = torch.randn(4096, 4096, dtype=torch.bfloat16) - scale = torch.Tensor([10.0]).cuda() + scale = torch.Tensor([10.0]).to(GPU_TYPE) E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max self.do_acc_test(f, x, scale) self.assertEqual(1, metrics.generated_kernel_count) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "FP8 requires H100+ and MI300+") + def test_fp8_pattern_2(self): + """ + This test repros the fp8 fusion relation issue here: + https://github.com/pytorch/pytorch/issues/133242 + """ + ref_dtype = torch.bfloat16 + M, K = 4096, 4096 + + input_tensor = torch.randn( + M, K, device="cuda", dtype=ref_dtype, requires_grad=False + ) + scale = torch.Tensor([10.0]).to("cuda") + + E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max + E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max + + def test_pattern2(tensor_x_inp, scale_x): + tensor_x = tensor_x_inp * scale_x + tensor_x = tensor_x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + tensor_fp8 = tensor_x.to(torch.float8_e4m3fn) + + tensor_x_t = (tensor_x_inp * scale_x).t() + tensor_x_t = tensor_x_t.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS) + tensor_fp8_t = tensor_x_t.to(torch.float8_e4m3fn) + + tensor_fp8_t = tensor_fp8_t.contiguous().t() + + return (tensor_fp8, tensor_fp8_t) + + test_pattern = torch.compile(test_pattern2) + tensor_fp8, tensor_fp8_t = test_pattern(input_tensor, scale) + + self.assertEqual(1, metrics.generated_kernel_count) + + expected_numbytes = scale.nbytes # scalar + expected_numbytes += input_tensor.nbytes # input + expected_numbytes += tensor_fp8.nbytes + tensor_fp8_t.nbytes # output + self.assertEqual(expected_numbytes, metrics.num_bytes_accessed) + if __name__ == "__main__": - if HAS_CUDA: + if HAS_GPU: run_tests() diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index be8fdf99d74d0..429759c945941 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -21,6 +21,9 @@ AlgorithmSelectorCache, TritonTemplateCaller, ) + + +aten = torch.ops.aten from torch._inductor.test_case import run_tests, TestCase from torch._inductor.utils import fresh_inductor_cache, run_and_get_code from torch._inductor.virtualized import V @@ -30,6 +33,7 @@ instantiate_parametrized_tests, parametrize, skipIfRocm, + TEST_WITH_ROCM, ) from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA @@ -72,7 +76,10 @@ def benchmark(self, *args, out): @instantiate_parametrized_tests class TestMaxAutotune(TestCase): def _create_buffer(self, name, shape): - return Buffer(name, FixedLayout(torch.device("cuda:0"), torch.float32, shape)) + return Buffer( + name=name, + layout=FixedLayout(torch.device("cuda:0"), dtype=torch.float32, size=shape), + ) def test_benchmark_choice_in_subproc(self): gm = make_fx( @@ -135,7 +142,7 @@ def test_benchmark_choice_fail_in_subproc(self): out = AlgorithmSelectorCache.benchmark_example_value(layout) expected_out = (mat1 @ mat2) + (mat3 @ mat4) - choice = FailChoiceCaller("fail_choice_caller", [], None) + choice = FailChoiceCaller("fail_choice_caller", [], None, description="") # use a tensor since python list is not synced back timings = torch.zeros(3, dtype=torch.float32) @@ -232,7 +239,7 @@ def test_precompilation_threads(self): class FakeChoiceCaller(ChoiceCaller): def __init__(self) -> None: - super().__init__("none", [], Mock()) + super().__init__("none", [], Mock(), description="") self.thread_id = None def precompile(self): @@ -599,6 +606,72 @@ def f(x, y, z): z = torch.randint(0, 10, (224,)).to(device="cuda") f(x, y, z) + def _test_cat_max_autotune_impl(self, using_triton_mm): + def f(x, y): + y = torch.cos(y) + x = torch.mm(x, x) + return torch.cat([x, y]) + + f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f) + inps = [torch.randn(32, 32, device="cuda"), torch.randn(32, 32, device="cuda")] + out, code = run_and_get_code(f_c, inps[0], inps[1]) + self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25) + + # mm kernel, and cos kernel + count = 2 if using_triton_mm else 1 + FileCheck().check("call(").check_count(".run", count, exactly=True).run(code[0]) + + def f(x, y): + y = torch.cos(y) + x = torch.mm(x, x) + out = torch.cat([x, y]) + return out, x + 1 + + f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f) + out, code = run_and_get_code(f_c, inps[0], inps[1]) + self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25) + FileCheck().check("call(").check_count(".run", 2, exactly=True).run(code[0]) + + def f(x, y): + y = torch.cos(y) + x = torch.mm(x, x) + return torch.cat([x, y]), torch.cat([y, x]) + + f_c = torch.compile(mode="max-autotune-no-cudagraphs")(f) + self.assertEqual(f_c(*inps), f(*inps), atol=0.03, rtol=0.25) + + @config.patch({"test_configs.force_extern_kernel_in_multi_template": True}) + def test_cat_max_autotune_extern(self): + self._test_cat_max_autotune_impl(using_triton_mm=False) + + @config.patch(max_autotune_gemm_backends="TRITON") + def test_cat_max_autotune_triton(self): + self._test_cat_max_autotune_impl(using_triton_mm=True) + + def test_conv_cat(self): + class ToyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + 3, 64, kernel_size=3, stride=1, padding=1, bias=False + ) + + def forward(self, x): + x = self.conv(x) + return torch.cat((x, x + 1)) + + with torch.no_grad(): + m = ToyModel().to(device="cuda") + input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda") + + # convolution is not currently plannable + m = torch.compile(m, mode="max-autotune-no-cudagraphs") + out, code = run_and_get_code(m, input_tensor) + self.assertEqual(out, m(input_tensor)) + + if not TEST_WITH_ROCM: + FileCheck().check("triton_poi_fused_cat_2.run").run(code[0]) + def test_conv3d(self): fn = torch.nn.functional.conv3d image = torch.randn([1, 3, 8, 16, 32]) diff --git a/test/inductor/test_memory.py b/test/inductor/test_memory.py index 87465e2016e4b..185095673a6b5 100644 --- a/test/inductor/test_memory.py +++ b/test/inductor/test_memory.py @@ -39,6 +39,11 @@ def forward(self, x): return t3.sum() + t4.sum() +# The tests in this class uses very small tensors. The default +# score_fusion_memory threshold will cause different fusion decisions and +# generate a different wrapper. Override the threshold to make these tests +# happy. +@config.patch("score_fusion_memory_threshold", 1) class TestOperatorReorderForPeakMemory(TestCase): def setUp(self): super().setUp() @@ -142,10 +147,10 @@ def reorder_with_only_bfs( FileCheck() .check("def call(args):") .check("buf0 = ") - .check("buf2 = ") .check("buf1 = ") - .check("buf4 = ") + .check("buf2 = ") .check("buf3 = ") + .check("buf4 = ") .check("buf5 = ") .check("buf7 = ") .run(code) diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index 6ded0991319f1..43da48156f366 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -3,8 +3,14 @@ import sys import unittest -from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.common_device_type import expectedFailureXPU +from torch.testing._internal.common_utils import ( + IS_CI, + IS_WINDOWS, + skipIfRocm, + skipIfXpu, +) +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu if IS_WINDOWS and IS_CI: @@ -24,9 +30,11 @@ from torch.export import Dim -@unittest.skipIf(not HAS_CUDA, "Inductor+gpu needs triton and CUDA") +@requires_gpu() @config.patch(memory_planning=True) class TestMemoryPlanning(TestCase): + device = GPU_TYPE + def _generate(self, *, device): """ Generate a simple test case that has multiple simultaneously-live intermediate tensors. @@ -46,12 +54,14 @@ def forward(self, x, y, z): return (Foo(), (x, y, z)) def test_python_wrapper(self): - f, args = self._generate(device="cuda") + f, args = self._generate(device=GPU_TYPE) compiled = torch.compile(f, dynamic=True) result, code = run_and_get_cpp_code(compiled, *args) FileCheck().check( - "pool1 = empty_strided_cuda(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )" + "pool1 = empty_strided_" + + GPU_TYPE + + "(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )" ).check_next( "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))" ).check( @@ -61,25 +71,25 @@ def test_python_wrapper(self): ) self.assertTrue(same(f(*args), result)) + @expectedFailureXPU def test_cpp_wrapper(self): - f, args = self._generate(device="cuda") + f, args = self._generate(device=GPU_TYPE) compiled = torch.compile(f, dynamic=True) - with config.patch({"cpp_wrapper": True, "abi_compatible": False}): + with config.patch({"cpp_wrapper": True}): result, code = run_and_get_cpp_code(compiled, *args) FileCheck().check( - "pool1 = at::detail::empty_strided_cuda({(4L*s0*s1) + (align(4L*(static_cast(s0*s0)))), }, {1L, }" - ).check_next( - "auto buf0 = alloc_from_pool(pool1, 0, at::kFloat, {s0, s0}, {s0, 1L});" - ).check( - "auto buf1 = alloc_from_pool(pool1, align(4L*(static_cast(s0*s0)))," + "aoti_torch__alloc_from_pool(pool1, 0, cached_torch_dtype_float32, 2, int_array_4, int_array_5, &tmp_tensor_handle_1)" + ).check_next("auto buf0 = RAIIAtenTensorHandle(tmp_tensor_handle_1);").check( + "auto buf1 = RAIIAtenTensorHandle(tmp_tensor_handle_2);" ).run( code ) self.assertTrue(same(f(*args), result)) @skipIfRocm(msg="test_aot_inductor doesn't work on ROCm") - def test_abi_compatible(self): + @skipIfXpu(msg="aoti doesn't work on XPU") + def test_aoti(self): try: from .test_aot_inductor import AOTIRunnerUtil except ImportError: @@ -87,15 +97,12 @@ def test_abi_compatible(self): AOTIRunnerUtil, ) - f, args = self._generate(device="cuda") + f, args = self._generate(device=GPU_TYPE) dim0_x = Dim("dim0_x", min=1, max=2048) dynamic_shapes = ({0: dim0_x}, None, None) - with config.patch("abi_compatible", True): - result, code = run_and_get_cpp_code( - lambda: AOTIRunnerUtil.run( - "cuda", f, args, dynamic_shapes=dynamic_shapes - ) - ) + result, code = run_and_get_cpp_code( + lambda: AOTIRunnerUtil.run(GPU_TYPE, f, args, dynamic_shapes=dynamic_shapes) + ) FileCheck().check( "int64_t int_array_2[] = {24L + (align(12L*s0)), };" @@ -120,5 +127,5 @@ def test_abi_compatible(self): if __name__ == "__main__": - if HAS_CUDA: + if HAS_GPU: run_tests() diff --git a/test/inductor/test_metrics.py b/test/inductor/test_metrics.py index bec9943eafd7e..90d6b0132e176 100644 --- a/test/inductor/test_metrics.py +++ b/test/inductor/test_metrics.py @@ -18,7 +18,7 @@ 'device': 0, 'device_type': 'GPU_TYPE', 'constants': {}, - 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=(), ids_of_folded_args=(), divisible_by_8=(2, 3))]}, + 'configs': [AttrsDescriptor(divisible_by_16=(0, 1, 2, 3), equal_to_1=())]}, inductor_meta={ 'autotune_hints': set(), 'kernel_name': 'triton_red_fused_add_sum_2', diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py index d5fd2eb94cc44..d0e13ab26c4aa 100644 --- a/test/inductor/test_mkldnn_pattern_matcher.py +++ b/test/inductor/test_mkldnn_pattern_matcher.py @@ -940,27 +940,115 @@ def matcher_check_fn(): matcher_check_fn=matcher_check_fn, ) + def _qconv2d_add_cpu_test_helper2(self, use_relu=False, int8_mixed_bf16=False): + r""" + This testcase will quantize two Conv2d->Add patterns as: + + Conv(X) extra input + \ / + Add + | + Optional(relu) + | + Y + + , and + + extra input Conv(X) + \ / + Add + | + Optional(relu) + | + Y + """ + + class M(torch.nn.Module): + def __init__( + self, + add_fn, + use_relu, + swap_inputs, + **kwargs, + ): + super().__init__() + self.conv1 = torch.nn.Conv2d(3, 6, kernel_size=3, stride=1) + self.add_fn = add_fn + self.relu = torch.nn.ReLU() + self.conv2 = torch.nn.Conv2d(6, 6, kernel_size=3, stride=1, bias=False) + self.add_fn2 = add_fn + self.relu2 = torch.nn.ReLU() + self.use_relu = use_relu + self.swap_inputs = swap_inputs + + def forward(self, x, x2, x3): + x1 = self.conv1(x) + if self.swap_inputs: + tmp = self.add_fn(x2, x1) + else: + tmp = self.add_fn(x1, x2) + if self.use_relu: + tmp = self.relu(tmp) + tmp1 = self.conv2(tmp) + if self.swap_inputs: + res = self.add_fn2(x3, tmp1) + else: + res = self.add_fn2(tmp1, x3) + if self.use_relu: + res = self.relu2(res) + return res + + for add_fn, swap_inputs in itertools.product( + quantization_add_fn_list + quantization_inplace_add_fn_list, [False, True] + ): + mod = M(add_fn, use_relu, swap_inputs).eval() + x = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=False) + x2 = torch.randn((1, 6, 6, 6), dtype=torch.float32, requires_grad=False) + x3 = torch.randn((1, 6, 4, 4), dtype=torch.float32, requires_grad=False) + + def matcher_check_fn(): + # 1. Dequant-Conv2D pattern matched in quantization weight prepack * 2 + self.assertEqual( + counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 2 + ) + # 2. Qconv2d Binary Unary fusion in post-grad fusion pass * 2 + self.assertEqual( + counters["inductor"]["qconv2d_binary_matcher_count"], 2 + ) + + self._test_common( + mod, + (x, x2, x3), + check_quantization=True, + check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, + matcher_check_fn=matcher_check_fn, + ) + @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_add_cpu(self): self._qconv2d_add_cpu_test_helper() + self._qconv2d_add_cpu_test_helper2() @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN def test_qconv2d_add_int8_mixed_bf16(self): self._qconv2d_add_cpu_test_helper(int8_mixed_bf16=True) + self._qconv2d_add_cpu_test_helper2(int8_mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN def test_qconv2d_add_relu_cpu(self): self._qconv2d_add_cpu_test_helper(use_relu=True) + self._qconv2d_add_cpu_test_helper2(use_relu=True) @skipIfNoDynamoSupport @skipIfNoONEDNNBF16 @skipIfNoONEDNN def test_qconv2d_add_relu_int8_mixed_bf16(self): self._qconv2d_add_cpu_test_helper(use_relu=True, int8_mixed_bf16=True) + self._qconv2d_add_cpu_test_helper2(use_relu=True, int8_mixed_bf16=True) @skipIfNoDynamoSupport @skipIfNoONEDNN @@ -1457,9 +1545,11 @@ def _default_matcher_check_fn(): inputs, check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, check_quantization=True, - matcher_check_fn=matcher_check_fn - if matcher_check_fn is not None - else _default_matcher_check_fn, + matcher_check_fn=( + matcher_check_fn + if matcher_check_fn is not None + else _default_matcher_check_fn + ), is_qat=is_qat, is_dynamic=is_dynamic, ) @@ -1806,17 +1896,12 @@ def matcher_check_fn(): mod, (v,), [ - "torch.ops.onednn.qlinear_pointwise.tensor", - "torch.ops.onednn.qlinear_pointwise.binary", - ] - if config.abi_compatible - else [ - "op_onednn_qlinear_pointwise_tensor.call", - "op_onednn_qlinear_pointwise_binary_tensor.call", + "aoti_torch_cpu__qlinear_pointwise_tensor", + "aoti_torch_cpu__qlinear_pointwise_binary_tensor", ], [], check_quantization=True, - num_include_ops=[4, 4] if config.abi_compatible else [2, 2], + num_include_ops=[2, 2], ) else: # For python wrapper @@ -1895,9 +1980,11 @@ def default_matcher_check_fn(): inputs, check_autocast=torch.bfloat16 if int8_mixed_bf16 else torch.float, check_quantization=True, - matcher_check_fn=matcher_check_fn - if matcher_check_fn is not None - else default_matcher_check_fn, + matcher_check_fn=( + matcher_check_fn + if matcher_check_fn is not None + else default_matcher_check_fn + ), is_dynamic=is_dynamic, ) diff --git a/test/inductor/test_pad_mm.py b/test/inductor/test_pad_mm.py index 752676a1a212a..b2d1c51fcf956 100644 --- a/test/inductor/test_pad_mm.py +++ b/test/inductor/test_pad_mm.py @@ -4,6 +4,7 @@ import torch import torch._inductor.config as inductor_config from torch._dynamo.testing import rand_strided +from torch._dynamo.utils import counters from torch._inductor.fx_passes.pad_mm import ( get_alignment_size, get_pad_cache, @@ -489,6 +490,36 @@ def mm(mat1, mat2): assert torch.allclose(res2, mm_expected_result), "MM results are not identical" + @fresh_inductor_cache() + @inductor_config.patch( + { + "triton.unique_kernel_names": "original_aten", + "max_autotune_gemm_backends": "TRITON", + "shape_padding": True, + } + ) + def test_original_aten_preserved_pad_mm(self): + def fn(x, y): + return x @ y + + args = [ + torch.randn(2**4, 2**14 - 1, device="cuda", dtype=torch.float16), + torch.randn(2**14 - 1, 2**4, device="cuda", dtype=torch.float16), + ] + + counters.clear() + + with unittest.mock.patch( + "torch._inductor.fx_passes.pad_mm._skip_do_bench_times", True + ): + opt_fn = torch.compile(fn, mode="max-autotune") + ret, code = run_and_get_code(opt_fn, *args) + self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1) + + # The mm kernel should use a template (because we set max_autotune_gemm_backends = TRITON). + # Its name should contain `mm` because `mm` was the original aten op where the mm came from. + FileCheck().check("def triton_tem_fused_mm").run(code[0]) + if __name__ == "__main__": if HAS_CUDA: diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index 9ae3dd3a125df..fd976f69d93b0 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -8,6 +8,7 @@ import torch from torch import nn, Tensor from torch._dynamo.convert_frame import maybe_cprofile +from torch._dynamo.device_interface import get_interface_for_device from torch._dynamo.test_case import run_tests, TestCase from torch._dynamo.testing import rand_strided, reduce_to_scalar_loss from torch._inductor import config, ir, metrics @@ -17,10 +18,9 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, - requires_cuda, serialTest, ) -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu DO_PERF_TEST = os.environ.get("DO_PERF_TEST") == "1" @@ -91,19 +91,19 @@ def forward_and_backward_pass(m, inputs): "triton.cudagraphs": USE_CUDA_GRAPHS, } ) -@requires_cuda +@requires_gpu() class TestCaseBase(TestCase): @classmethod def setUpClass(cls): - if HAS_CUDA: + if HAS_GPU: cls.prior_float32_matmul_precision = torch.get_float32_matmul_precision() cls.prior_default_device = torch.get_default_device() torch.set_float32_matmul_precision("high") - torch.set_default_device("cuda") + torch.set_default_device(GPU_TYPE) @classmethod def tearDownClass(cls): - if HAS_CUDA: + if HAS_GPU: torch.set_float32_matmul_precision(cls.prior_float32_matmul_precision) torch.set_default_device(cls.prior_default_device) @@ -141,7 +141,8 @@ def do_profiling( ): if kwargs is None: kwargs = {} - torch.cuda.synchronize() + device_interface = get_interface_for_device(GPU_TYPE) + device_interface.synchronize() with torch.profiler.profile(with_stack=WITH_STACK) as p: niter = 3 for _ in range(niter): @@ -150,7 +151,7 @@ def do_profiling( with torch.profiler.record_function(tag_rhs): f_rhs(*args, **kwargs) - torch.cuda.synchronize() + device_interface.synchronize() profile_path = "/tmp/chrome.json" p.export_chrome_trace(profile_path) @@ -207,7 +208,7 @@ def create_model(vocab_size): def f(**inputs): optim.zero_grad(True) - with torch.cuda.amp.autocast(): + with torch.autocast(GPU_TYPE): pred = model(**inputs) loss = pred[0] loss.backward() @@ -279,7 +280,7 @@ def _process_inputs(x): def get_f(m, optim): def f(*args, **kwargs): optim.zero_grad(True) - with torch.cuda.amp.autocast(): + with torch.autocast(GPU_TYPE): pred = m(*args, **kwargs) loss = reduce_to_scalar_loss(pred) loss.backward() @@ -443,7 +444,7 @@ def test_matmul(self): # Using stride (30522, 1) does not make a difference here. x_bad_shape = rand_strided( - (8192, 30522), (30528, 1), device="cuda", dtype=torch.float16 + (8192, 30522), (30528, 1), device=GPU_TYPE, dtype=torch.float16 ) weight_bad_shape = torch.randn(30522, 768, dtype=torch.float16) out_bad_shape = torch.randn(8192, 768, dtype=torch.float16) @@ -592,7 +593,7 @@ def test_conv(self): x1 = torch.randn(*x_shape) padded_stride = ir.Layout._pad_strides(x1.stride(), x1.shape, torch.float32) - x2 = rand_strided(x_shape, padded_stride, device="cuda") + x2 = rand_strided(x_shape, padded_stride, device=GPU_TYPE) x2.copy_(x1) weight = torch.randn(64, 128, 3, 3) @@ -710,5 +711,5 @@ def test_pad_outputs( if __name__ == "__main__": - if HAS_CUDA: + if HAS_GPU: run_tests() diff --git a/test/inductor/test_pattern_matcher.py b/test/inductor/test_pattern_matcher.py index 4b1a5bc78ab8b..d6bfdbcc05f91 100644 --- a/test/inductor/test_pattern_matcher.py +++ b/test/inductor/test_pattern_matcher.py @@ -29,12 +29,20 @@ from torch._inductor.virtualized import V from torch.testing import FileCheck from torch.testing._internal.common_cuda import SM80OrLater -from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm -from torch.testing._internal.inductor_utils import HAS_CUDA, IS_A100, IS_BIG_GPU +from torch.testing._internal.common_device_type import expectedFailureXPU, skipCUDAIf +from torch.testing._internal.common_utils import IS_LINUX, skipIfRocm, skipIfXpu +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_GPU, + IS_A100, + IS_BIG_GPU, +) from torch.utils import _pytree as pytree class TestPatternMatcher(TestCase): + device_type = GPU_TYPE + def common( self, fn, @@ -74,16 +82,16 @@ def fn(a, b, c, d): # when m1 == n1 and m2 == n2, mm_plus_mm can be matched to fused op fusible_args_list = [ ( - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), ), ( - torch.randn(1, 4, device="cuda"), - torch.randn(4, 2, device="cuda"), - torch.randn(1, 5, device="cuda"), - torch.randn(5, 2, device="cuda"), + torch.randn(1, 4, device=GPU_TYPE), + torch.randn(4, 2, device=GPU_TYPE), + torch.randn(1, 5, device=GPU_TYPE), + torch.randn(5, 2, device=GPU_TYPE), ), ] for args in fusible_args_list: @@ -93,16 +101,16 @@ def fn(a, b, c, d): unfusible_args_list = [ # https://github.com/pytorch/pytorch/issues/100670. ( - torch.randn(1, 4, device="cuda"), - torch.randn(4, 2, device="cuda"), - torch.randn(1, 2, device="cuda"), - torch.randn(2, 1, device="cuda"), + torch.randn(1, 4, device=GPU_TYPE), + torch.randn(4, 2, device=GPU_TYPE), + torch.randn(1, 2, device=GPU_TYPE), + torch.randn(2, 1, device=GPU_TYPE), ), ( - torch.randn(1, 2, device="cuda"), - torch.randn(2, 1, device="cuda"), - torch.randn(1, 4, device="cuda"), - torch.randn(4, 2, device="cuda"), + torch.randn(1, 2, device=GPU_TYPE), + torch.randn(2, 1, device=GPU_TYPE), + torch.randn(1, 4, device=GPU_TYPE), + torch.randn(4, 2, device=GPU_TYPE), ), ] for args in unfusible_args_list: @@ -121,7 +129,8 @@ def _test_fused_int_mm_mul_impl(self, fn, args, fused_int_mm_mul_expected=True): ) # also checks that dtype is correct @skipIfRocm - @unittest.skipIf(not SM80OrLater, "need sm_80") + @skipIfXpu + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(force_fuse_int_mm_with_mul=True) def test_fused_int_mm_mul(self): def fn1(a, b, c): @@ -134,19 +143,19 @@ def fn2(a, b, c): args_list = [ ( - torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"), - torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"), - torch.randn((32, 1), dtype=torch.float16, device="cuda") * 0 + 0.5, + torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=GPU_TYPE), + torch.randint(-128, 127, (32, 8), dtype=torch.int8, device=GPU_TYPE), + torch.randn((32, 1), dtype=torch.float16, device=GPU_TYPE) * 0 + 0.5, ), ( - torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"), - torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"), - torch.randn((1, 8), dtype=torch.bfloat16, device="cuda"), + torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=GPU_TYPE), + torch.randint(-128, 127, (32, 8), dtype=torch.int8, device=GPU_TYPE), + torch.randn((1, 8), dtype=torch.bfloat16, device=GPU_TYPE), ), ( - torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"), - torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"), - torch.randn((1, 8), dtype=torch.float32, device="cuda"), + torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=GPU_TYPE), + torch.randint(-128, 127, (32, 8), dtype=torch.int8, device=GPU_TYPE), + torch.randn((1, 8), dtype=torch.float32, device=GPU_TYPE), ), ] @@ -155,22 +164,23 @@ def fn2(a, b, c): self._test_fused_int_mm_mul_impl(fn2, args, True) @skipIfRocm - @unittest.skipIf(not SM80OrLater, "need sm_80") + @skipIfXpu + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(force_fuse_int_mm_with_mul=True) def test_fused_int_mm_mul_gating(self): def fn1(a, b, c): return out_dtype(torch.ops.aten.mm.default, torch.int32, a, b) * c args1 = ( - torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"), - torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"), - torch.randn((8), dtype=torch.float32, device="cuda"), + torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=GPU_TYPE), + torch.randint(-128, 127, (32, 8), dtype=torch.int8, device=GPU_TYPE), + torch.randn((8), dtype=torch.float32, device=GPU_TYPE), ) args2 = ( - torch.randint(-128, 127, (32, 32), dtype=torch.int8, device="cuda"), - torch.randint(-128, 127, (32, 8), dtype=torch.int8, device="cuda"), - torch.randn((32, 1), dtype=torch.float16, device="cuda"), + torch.randint(-128, 127, (32, 32), dtype=torch.int8, device=GPU_TYPE), + torch.randint(-128, 127, (32, 8), dtype=torch.int8, device=GPU_TYPE), + torch.randn((32, 1), dtype=torch.float16, device=GPU_TYPE), ) self._test_fused_int_mm_mul_impl(fn1, args1, False) self._test_fused_int_mm_mul_impl(fn1, [arg.cpu() for arg in args2], False) @@ -194,7 +204,8 @@ def _test_mixed_impl( self.assertEqual("mixed_mm" in code, mixed_mm_expected) self.assertEqual("fallback_mixed_mm" in code, fallback_mixed_mm_expected) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(mixed_mm_choice="triton") def test_mixed_mm(self): def fn(a, b): @@ -202,27 +213,28 @@ def fn(a, b): args_list = [ ( - torch.randn(8, 8, device="cuda"), - torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(-128, 127, (8, 8), dtype=torch.int8, device=GPU_TYPE), ), ( - torch.randn(8, 2, device="cuda", dtype=torch.bfloat16), - torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"), + torch.randn(8, 2, device=GPU_TYPE, dtype=torch.bfloat16), + torch.randint(-128, 127, (2, 8), dtype=torch.int8, device=GPU_TYPE), ), ( - torch.randn(8, 5, device="cuda", dtype=torch.float16), - torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"), + torch.randn(8, 5, device=GPU_TYPE, dtype=torch.float16), + torch.randint(0, 255, (5, 2), dtype=torch.uint8, device=GPU_TYPE), ), ( - torch.randn(8, 8, device="cuda", dtype=torch.float32), - torch.randn(8, 8, device="cuda", dtype=torch.bfloat16), + torch.randn(8, 8, device=GPU_TYPE, dtype=torch.float32), + torch.randn(8, 8, device=GPU_TYPE, dtype=torch.bfloat16), ), ] for args in args_list: self._test_mixed_impl(fn, args, True, False) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(mixed_mm_choice="triton") def test_mixed_mm_exhaustive_dtypes(self): def fn(a, b): @@ -234,8 +246,10 @@ def fn(a, b): for dtype_left, dtype_right in itertools.product(dtypes_left, dtypes_right): low, high = dtype_ranges[dtype_right] args = ( - torch.randn(256, 256, dtype=dtype_left, device="cuda"), - torch.randint(low, high, (256, 256), dtype=dtype_right, device="cuda"), + torch.randn(256, 256, dtype=dtype_left, device=GPU_TYPE), + torch.randint( + low, high, (256, 256), dtype=dtype_right, device=GPU_TYPE + ), ) fallback_mixed_mm_expected = ( dtype_left == torch.bfloat16 and dtype_right == torch.uint8 @@ -244,7 +258,8 @@ def fn(a, b): fn, args, True, fallback_mixed_mm_expected, rtol=0.16, atol=1e-4 ) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(mixed_mm_choice="triton") def test_mixed_mm_bad_cases(self): def fn(a, b): @@ -253,14 +268,14 @@ def fn(a, b): # when b is transposed and not contiguous, we skip triton and use fallback args_list = [ ( - torch.randn(8, 8, device="cuda", dtype=torch.float16), - torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda").t()[ + torch.randn(8, 8, device=GPU_TYPE, dtype=torch.float16), + torch.randint(-128, 127, (4, 8), dtype=torch.int8, device=GPU_TYPE).t()[ :, ::2 ], ), ( - torch.randn(8, 8, device="cuda", dtype=torch.bfloat16), - torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda").t()[ + torch.randn(8, 8, device=GPU_TYPE, dtype=torch.bfloat16), + torch.randint(0, 255, (4, 8), dtype=torch.uint8, device=GPU_TYPE).t()[ :, ::2 ], ), @@ -269,7 +284,8 @@ def fn(a, b): for args in args_list: self._test_mixed_impl(fn, args, True, True) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(mixed_mm_choice="triton", max_autotune_gemm=True) def test_mixed_mm_epi_works(self): def fn(a, b, c, d): @@ -277,31 +293,32 @@ def fn(a, b, c, d): args_list = [ ( - torch.randn(8, 8, device="cuda"), - torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"), - torch.randn(8, device="cuda"), - torch.randn(8, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(-128, 127, (8, 8), dtype=torch.int8, device=GPU_TYPE), + torch.randn(8, device=GPU_TYPE), + torch.randn(8, device=GPU_TYPE), ), ( - torch.randn(8, 2, device="cuda", dtype=torch.bfloat16), - torch.randint(-128, 127, (2, 8), dtype=torch.int8, device="cuda"), - torch.randn(8, device="cuda", dtype=torch.bfloat16), - torch.randn(8, device="cuda", dtype=torch.bfloat16), + torch.randn(8, 2, device=GPU_TYPE, dtype=torch.bfloat16), + torch.randint(-128, 127, (2, 8), dtype=torch.int8, device=GPU_TYPE), + torch.randn(8, device=GPU_TYPE, dtype=torch.bfloat16), + torch.randn(8, device=GPU_TYPE, dtype=torch.bfloat16), ), ( - torch.randn(8, 5, device="cuda", dtype=torch.float16), - torch.randint(0, 255, (5, 2), dtype=torch.uint8, device="cuda"), - torch.randn(2, device="cuda", dtype=torch.float16), - torch.randn(2, device="cuda", dtype=torch.float16), + torch.randn(8, 5, device=GPU_TYPE, dtype=torch.float16), + torch.randint(0, 255, (5, 2), dtype=torch.uint8, device=GPU_TYPE), + torch.randn(2, device=GPU_TYPE, dtype=torch.float16), + torch.randn(2, device=GPU_TYPE, dtype=torch.float16), ), ] for args in args_list: self._test_mixed_impl(fn, args, True, False) - @unittest.skipIf(not SM80OrLater, "need sm_80") - @unittest.skipIf(not IS_A100, "heuristic only run on Linux A100") - @unittest.skipIf(not IS_BIG_GPU, "tests fail on small GPU") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") + @skipCUDAIf(not IS_A100, "heuristic only run on Linux A100") + @skipCUDAIf(not IS_BIG_GPU, "tests fail on small GPU") @inductor_config.patch( mixed_mm_choice="heuristic", autoheuristic_use="", @@ -315,53 +332,64 @@ def fn(a, b): # examples that should not be selected by handwritten heuristic mat1_dtype = torch.float16 - dyn_tensor = torch.randn(4, 4096, dtype=mat1_dtype, device="cuda") + dyn_tensor = torch.randn(4, 4096, dtype=mat1_dtype, device=GPU_TYPE) torch._dynamo.mark_dynamic(dyn_tensor, 0) args_list = [ ( - torch.randn(1, 4097, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4097, 4096), dtype=torch.int8, device="cuda"), + torch.randn(1, 4097, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4097, 4096), dtype=torch.int8, device=GPU_TYPE + ), ), ( - torch.randn(1, 4096, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4096, 4097), dtype=torch.int8, device="cuda"), + torch.randn(1, 4096, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 4097), dtype=torch.int8, device=GPU_TYPE + ), ), ( - torch.randn(8, 8, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"), + torch.randn(8, 8, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint(-128, 127, (8, 8), dtype=torch.int8, device=GPU_TYPE), ), ( - torch.randn(8, 2048, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (2048, 2048), dtype=torch.int8, device="cuda"), + torch.randn(8, 2048, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (2048, 2048), dtype=torch.int8, device=GPU_TYPE + ), ), ( - torch.randn(8, 2048, dtype=mat1_dtype, device="cuda"), + torch.randn(8, 2048, dtype=mat1_dtype, device=GPU_TYPE), torch.randint( - -128, 127, (2048, 2048), dtype=torch.int8, device="cuda" + -128, 127, (2048, 2048), dtype=torch.int8, device=GPU_TYPE ).t(), ), ( - torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda")[ - :, ::2 - ], + torch.randn(8, 4096, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE + )[:, ::2], ), ( - torch.randn(1, 4096, dtype=torch.float32, device="cuda"), - torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"), + torch.randn(1, 4096, dtype=torch.float32, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE + ), ), ( dyn_tensor, - torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"), + torch.randint( + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE + ), ), ] for args in args_list: self._test_mixed_impl(fn, args, True, True) - @unittest.skipIf(not SM80OrLater, "need sm_80") - @unittest.skipIf(not IS_A100, "heuristic only run on Linux A100") - @unittest.skipIf(not IS_BIG_GPU, "tests fail on small GPU") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") + @skipCUDAIf(not IS_A100, "heuristic only run on Linux A100") + @skipCUDAIf(not IS_BIG_GPU, "tests fail on small GPU") @inductor_config.patch( mixed_mm_choice="heuristic", autoheuristic_use="", @@ -377,50 +405,61 @@ def fn(a, b): # examples that should be selected by handwritten heuristic args_list = [ ( - torch.randn(1, 4096, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"), + torch.randn(1, 4096, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE + ), ), ( - torch.randn(4, 4096, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"), + torch.randn(4, 4096, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE + ), ), ( - torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"), + torch.randn(8, 4096, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE + ), ), ( - torch.randn(8, 4096, dtype=mat1_dtype, device="cuda"), + torch.randn(8, 4096, dtype=mat1_dtype, device=GPU_TYPE), torch.randint( - -128, 127, (4096, 4096), dtype=torch.int8, device="cuda" + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE ).t(), ), ( - torch.randn(16, 4096, dtype=mat1_dtype, device="cuda"), + torch.randn(16, 4096, dtype=mat1_dtype, device=GPU_TYPE), torch.randint( - -128, 127, (8192, 4096), dtype=torch.int8, device="cuda" + -128, 127, (8192, 4096), dtype=torch.int8, device=GPU_TYPE ).t(), ), ( - torch.randn(32, 4096, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4096, 8192), dtype=torch.int8, device="cuda"), + torch.randn(32, 4096, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 8192), dtype=torch.int8, device=GPU_TYPE + ), ), ( - torch.randn(64, 4096, dtype=mat1_dtype, device="cuda"), - torch.randint(-128, 127, (4096, 4096), dtype=torch.int8, device="cuda"), + torch.randn(64, 4096, dtype=mat1_dtype, device=GPU_TYPE), + torch.randint( + -128, 127, (4096, 4096), dtype=torch.int8, device=GPU_TYPE + ), ), ] for args in args_list: self._test_mixed_impl(fn, args, True, False, rtol=0.01, atol=0.04) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") def test_mixed_mm_gating(self): def fn(a, b): return torch.mm(a, b.to(a.dtype)) args = ( - torch.randn(8, 8, device="cuda"), - torch.randint(-128, 127, (8, 8), dtype=torch.int8, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(-128, 127, (8, 8), dtype=torch.int8, device=GPU_TYPE), ) # will ignore the mixed_mm code (including fallback) with inductor_config.patch( @@ -469,7 +508,8 @@ def fn(a, b): ) self._test_mixed_impl(fn, args, False, False) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(use_mixed_mm=True) def test_uint4x2_mixed_mm(self): def fn(a, b): @@ -491,12 +531,12 @@ def check_uint4x2_mixed_mm(args, expect_mixed_mm): args_expect_mixed_mm = [ ( - torch.randn(8, 8, device="cuda"), - torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(0, 255, (4, 8), dtype=torch.uint8, device=GPU_TYPE), ), ( - torch.randn(8, 8, device="cuda", dtype=torch.float16), - torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda") + torch.randn(8, 8, device=GPU_TYPE, dtype=torch.float16), + torch.randint(0, 255, (4, 8), dtype=torch.uint8, device=GPU_TYPE) .t() .contiguous() .t(), @@ -509,19 +549,20 @@ def check_uint4x2_mixed_mm(args, expect_mixed_mm): # mixed mm is only enabled when casting from a lower-bitwidth dtype to a higher one args_expect_no_mixed_mm = [ ( - torch.randn(8, 8, device="cuda"), - torch.randint(0, 255, (4, 8), dtype=torch.int32, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(0, 255, (4, 8), dtype=torch.int32, device=GPU_TYPE), ), ( - torch.randn(8, 8, device="cuda"), - torch.randint(0, 255, (4, 8), dtype=torch.int64, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(0, 255, (4, 8), dtype=torch.int64, device=GPU_TYPE), ), ] for args in args_expect_no_mixed_mm: check_uint4x2_mixed_mm(args, False) - @unittest.skipIf(not SM80OrLater, "need sm_80") + @expectedFailureXPU + @skipCUDAIf(not SM80OrLater, "need sm_80") @inductor_config.patch(use_mixed_mm=True) def test_uint4x2_mixed_mm_epi(self): def fn(a, b, c, d): @@ -539,10 +580,10 @@ def fn(a, b, c, d): args_list = [ ( - torch.randn(8, 8, device="cuda"), - torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"), - torch.randn(8, device="cuda"), - torch.randn(8, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(0, 255, (4, 8), dtype=torch.uint8, device=GPU_TYPE), + torch.randn(8, device=GPU_TYPE), + torch.randn(8, device=GPU_TYPE), ), ] @@ -572,8 +613,8 @@ def fn(a, b): torch.randint(0, 255, (4, 8), dtype=torch.uint8), ), ( # int8 - torch.randn(8, 8, device="cuda"), - torch.randint(-128, 127, (4, 8), dtype=torch.int8, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(-128, 127, (4, 8), dtype=torch.int8, device=GPU_TYPE), ), # we don't match for int8 since numerics ] # for int8 bitshifts don't match between triton and pytorch @@ -599,8 +640,8 @@ def fn(a, b): args_list = [ ( - torch.randn(8, 8, device="cuda"), - torch.randint(0, 255, (4, 8), dtype=torch.uint8, device="cuda"), + torch.randn(8, 8, device=GPU_TYPE), + torch.randint(0, 255, (4, 8), dtype=torch.uint8, device=GPU_TYPE), ), ] @@ -618,33 +659,33 @@ def fn(a, b, c): args_list = [ ( - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), True, ), ( - torch.randn(8, device="cuda"), - torch.randn(16, 16, device="cuda"), - torch.randn(16, 8, device="cuda"), + torch.randn(8, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 8, device=GPU_TYPE), True, ), ( - torch.randn(16, 16, device="cuda"), - torch.randn(1, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(1, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), False, ), ( - torch.randn(1, 16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), + torch.randn(1, 16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), False, ), ( 4, - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), False, ), ] @@ -665,8 +706,8 @@ def fn(m1, m2): bias = m1.size(0) return torch.add(bias, torch.mm(m1, m2)), torch.mm(m1, m2) + bias - m1 = torch.randn(16, 16, device="cuda") - m2 = torch.randn(16, 16, device="cuda") + m1 = torch.randn(16, 16, device=GPU_TYPE) + m2 = torch.randn(16, 16, device=GPU_TYPE) counters.clear() expect = fn(m1, m2) @@ -679,16 +720,16 @@ class Model(torch.nn.Module): def __init__(self) -> None: super().__init__() self.linear = torch.nn.functional.linear - self.linear_weight = torch.randn(4, 4).cuda() - self.bias = torch.randn(1, 4).cuda() + self.linear_weight = torch.randn(4, 4).to(GPU_TYPE) + self.bias = torch.randn(1, 4).to(GPU_TYPE) def forward(self, x): x = self.linear(x, self.linear_weight, self.bias) return x - input_tensor = torch.randn(1, 3, 4).cuda() + input_tensor = torch.randn(1, 3, 4).to(GPU_TYPE) - func = Model().cuda() + func = Model().to(GPU_TYPE) res1 = func(input_tensor) jit_func = torch.compile(func) @@ -708,11 +749,13 @@ def fn(a, b, c): ) args = [ - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), ] - self.common(fn, args, 1, 4) + out, code = run_and_get_code(torch.compile(fn), *args) + self.assertEqual(out, fn(*args)) + FileCheck().check("call").check_not(".run").run(code[0]) def test_cat_addmm(self): def fn(a, b, c): @@ -726,11 +769,13 @@ def fn(a, b, c): ) args = [ - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), - torch.randn(16, 16, device="cuda"), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), + torch.randn(16, 16, device=GPU_TYPE), ] - self.common(fn, args, 1, 4) + out, code = run_and_get_code(torch.compile(fn), *args) + self.assertEqual(out, fn(*args)) + FileCheck().check("call").check_not(".run").run(code[0]) def test_cat_slice_cat_cuda(self): def fn(a, b): @@ -740,14 +785,14 @@ def fn(a, b): return torch.ops.aten.cat.default([cat_1, slice_2], 1) args = [ - torch.randn(2, 32, device="cuda"), - torch.randn(2, 16, device="cuda"), + torch.randn(2, 32, device=GPU_TYPE), + torch.randn(2, 16, device=GPU_TYPE), ] self.common(fn, args, 1, 3) args = [ - torch.randn(2, 8, device="cuda"), - torch.randn(2, 16, device="cuda"), + torch.randn(2, 8, device=GPU_TYPE), + torch.randn(2, 16, device=GPU_TYPE), ] torch._dynamo.reset() counters.clear() @@ -767,8 +812,8 @@ def fn(a, b): return torch.ops.aten.cat.default([cat_1, slice_2], 1) args = [ - torch.randn(2, 8, device="cuda"), - torch.randn(2, 16, device="cuda"), + torch.randn(2, 8, device=GPU_TYPE), + torch.randn(2, 16, device=GPU_TYPE), ] self.common(fn, args, 1, 3) @@ -843,7 +888,7 @@ def fn(a): return cat**2 args = [ - torch.randn(2, 32, device="cuda"), + torch.randn(2, 32, device=GPU_TYPE), ] self.common(fn, args, 1, 4) @@ -857,7 +902,7 @@ def fn(a): return cat**2 + getitem_2 args = [ - torch.randn(2, 32, device="cuda"), + torch.randn(2, 32, device=GPU_TYPE), ] self.common(fn, args, 0, 0) @@ -870,7 +915,7 @@ def fn(a): return cat**2 args = [ - torch.randn(2, 32, device="cuda"), + torch.randn(2, 32, device=GPU_TYPE), ] self.common(fn, args, 0, 0) @@ -881,7 +926,7 @@ def fn(a): return cat args = [ - torch.randn(1, 8, device="cuda"), + torch.randn(1, 8, device=GPU_TYPE), ] self.common(fn, args, 0, 0) @@ -895,9 +940,9 @@ def fn(a, b, c): return [s**2 for s in split_with_sizes] args = [ - torch.randn(2, 2, device="cuda"), - torch.randn(2, 3, device="cuda"), - torch.randn(2, 5, device="cuda"), + torch.randn(2, 2, device=GPU_TYPE), + torch.randn(2, 3, device=GPU_TYPE), + torch.randn(2, 5, device=GPU_TYPE), ] self.common(fn, args, 1, 2) @@ -910,9 +955,9 @@ def fn(a, b, c): return [s**2 for s in split_with_sizes] + [cat**3] args = [ - torch.randn(2, 2, device="cuda"), - torch.randn(2, 3, device="cuda"), - torch.randn(2, 5, device="cuda"), + torch.randn(2, 2, device=GPU_TYPE), + torch.randn(2, 3, device=GPU_TYPE), + torch.randn(2, 5, device=GPU_TYPE), ] self.common(fn, args, 0, 0) @@ -925,9 +970,9 @@ def fn(a, b, c): return [s**2 for s in split_with_sizes] args = [ - torch.randn(10, 2, device="cuda"), - torch.randn(10, 3, device="cuda"), - torch.randn(10, 5, device="cuda"), + torch.randn(10, 2, device=GPU_TYPE), + torch.randn(10, 3, device=GPU_TYPE), + torch.randn(10, 5, device=GPU_TYPE), ] self.common(fn, args, 0, 0) @@ -938,9 +983,9 @@ def fn(a, b, c): return [s**2 for s in split_with_sizes] args = [ - torch.randn(2, 2, device="cuda"), - torch.randn(2, 3, device="cuda"), - torch.randn(2, 5, device="cuda"), + torch.randn(2, 2, device=GPU_TYPE), + torch.randn(2, 3, device=GPU_TYPE), + torch.randn(2, 5, device=GPU_TYPE), ] self.common(fn, args, 0, 0) @@ -953,9 +998,9 @@ def fn(a, b, c): return [s**2 for s in split_with_sizes] args = [ - torch.randn(2, 2, device="cuda"), - torch.randn(2, 3, device="cuda"), - torch.randn(2, 5, device="cuda"), + torch.randn(2, 2, device=GPU_TYPE), + torch.randn(2, 3, device=GPU_TYPE), + torch.randn(2, 5, device=GPU_TYPE), ] self.common(fn, args, 0, 0) @@ -1064,7 +1109,7 @@ def fn2(x, y): def fn3(x, y): a = torch.sin(x) - with torch.autocast("cuda"): + with torch.autocast(GPU_TYPE): b = torch.add(x, a) return b @@ -1081,8 +1126,8 @@ def fn5(x, y): return b args = [ - torch.randn(5, 5, device="cuda"), - torch.randn(5, 5, device="cuda"), + torch.randn(5, 5, device=GPU_TYPE), + torch.randn(5, 5, device=GPU_TYPE), ] with unittest.mock.patch( @@ -1113,11 +1158,12 @@ def fn(a, b): self.assertIn("return (buf0, )", code[0]) self.assertNotIn("async_compile.cpp", code[0]) + @expectedFailureXPU def test_unfuse_bias_addmm(self): args = [ - torch.randn(20, device="cuda"), - torch.randn(10, 15, device="cuda"), - torch.randn(15, 20, device="cuda"), + torch.randn(20, device=GPU_TYPE), + torch.randn(10, 15, device=GPU_TYPE), + torch.randn(15, 20, device=GPU_TYPE), ] @torch.compile() @@ -1188,15 +1234,46 @@ def remap_fake_tensor(x): # of search_fn). self.assertTrue(pattern.pattern_eq(search_fn_pattern)) + @skipIfXpu + @inductor_config.patch( + { + "triton.unique_kernel_names": "original_aten", + "fx_graph_remote_cache": False, + "max_autotune_gemm_backends": "TRITON", + } + ) + def test_original_aten_preserved_split_addmm(self): + # addmm -> elementwise should be decomposed into mm -> add -> elementwise + def fn(x, y, z): + return torch.addmm(z, x, y).sin() + + args = [ + torch.randn(16, 24, device=GPU_TYPE), + torch.randn(24, 32, device=GPU_TYPE), + torch.randn(16, 32, device=GPU_TYPE), + ] + + counters.clear() + + opt_fn = torch.compile(fn, mode="max-autotune") + ret, code = run_and_get_code(opt_fn, *args) + self.assertEqual(counters["inductor"]["pattern_matcher_count"], 1) + + # The mm kernel should use a template (because we set max_autotune_gemm_backends = TRITON). + # Its name should contain `addmm` because `addmm` was the original aten op where the mm came from. + FileCheck().check_not("extern_kernels.addmm(").check( + "def triton_tem_fused_addmm" + ).run(code[0]) + @inductor_config.patch(fx_graph_remote_cache=False) def test_match_equivalent_function_invocations1(self): counter = 0 test_pass = PatternMatcherPass() args = [ - torch.randn(20, device="cuda"), - torch.randn(10, 15, device="cuda"), - torch.randn(15, 20, device="cuda"), + torch.randn(20, device=GPU_TYPE), + torch.randn(10, 15, device=GPU_TYPE), + torch.randn(15, 20, device=GPU_TYPE), ] def f0(inp, a, b): @@ -1251,9 +1328,9 @@ def test_match_equivalent_function_invocations2(self): test_pass = PatternMatcherPass() args = [ - torch.randn(20, device="cuda"), - torch.randn(10, 15, device="cuda"), - torch.randn(15, 20, device="cuda"), + torch.randn(20, device=GPU_TYPE), + torch.randn(10, 15, device=GPU_TYPE), + torch.randn(15, 20, device=GPU_TYPE), ] def f0(inp, a, b): @@ -1297,9 +1374,9 @@ def test_match_equivalent_function_invocations3(self): test_pass = PatternMatcherPass() args = [ - torch.randn(20, device="cuda"), - torch.randn(10, 15, device="cuda"), - torch.randn(15, 20, device="cuda"), + torch.randn(20, device=GPU_TYPE), + torch.randn(10, 15, device=GPU_TYPE), + torch.randn(15, 20, device=GPU_TYPE), ] def f0(inp, a, b): @@ -1425,7 +1502,7 @@ def check(type, func_name, args, kwargs, expect=True): check( "call_function", torch.amp.autocast_mode._enter_autocast, - ("cuda", None, True, None), + (GPU_TYPE, None, True, None), {}, ) check("call_function", torch.amp.autocast_mode._exit_autocast, (None,), {}) @@ -1474,5 +1551,5 @@ def fused_rms_norm_quant_static(out: torch.Tensor, input: torch.Tensor) -> None: if __name__ == "__main__": - if IS_LINUX and HAS_CUDA: + if IS_LINUX and HAS_GPU: run_tests() diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py index c7c341d9165a4..87d8e383bd58a 100644 --- a/test/inductor/test_perf.py +++ b/test/inductor/test_perf.py @@ -533,6 +533,52 @@ def f(x, scale, amax_keep_dim): self.assertEqual(actual_numel_amax_keep_dim, actual_numel_amax_no_keep_dim) self.assertGreaterAlmostEqual(actual_numel_amax_keep_dim, str(expected_numel)) + def test_create_block_mask(self): + def mk_3d_flex_natten_mask(dims, kernel_size): + T, H, W = dims + K_T, K_H, K_W = kernel_size + spatial = H * W + + def get_x_y_t(idx: int) -> tuple[int, int, int]: + t = idx // spatial + s = idx % spatial + x = s // W + y = s % W + return x, y, t + + def get_mask(b, h, q_idx, kv_idx): + q_x, q_y, q_t = get_x_y_t(q_idx) + kv_x, kv_y, kv_t = get_x_y_t(kv_idx) + kernel_x = q_x.clamp(K_W // 2, (W - 1) - K_W // 2) + kernel_y = q_y.clamp(K_H // 2, (H - 1) - K_H // 2) + kernel_t = q_t.clamp(K_T // 2, (T - 1) - K_T // 2) + hori_mask = (kernel_x - kv_x).abs() <= K_W // 2 + vert_mask = (kernel_y - kv_y).abs() <= K_H // 2 + temp_mask = (kernel_t - kv_t).abs() <= K_T // 2 + return hori_mask & vert_mask & temp_mask + + return get_mask + + T = 4 + H = 16 + W = 16 + t = 5 + h = 5 + w = 5 + data_size = (T, H, W) + kernel_size = (t, h, w) + S = T * H * W + from torch.nn.attention.flex_attention import create_block_mask + + mask_mod = mk_3d_flex_natten_mask(data_size, kernel_size) + + torch.compile(create_block_mask)(mask_mod, None, None, S, S) + numel = int(count_numel(create_block_mask, mask_mod, None, None, S, S)) + + # We should be writing way less than a quadratic amount of bytes here + # With fusion, we should only be writing a linear number of bytes + self.assertLess(numel * 5, S * S) + class SchedulerFusionTests(TestCase): """ diff --git a/test/inductor/test_snode_runtime.py b/test/inductor/test_snode_runtime.py index 146c095e21d23..e002a61b6725f 100644 --- a/test/inductor/test_snode_runtime.py +++ b/test/inductor/test_snode_runtime.py @@ -10,7 +10,8 @@ from torch._inductor.compile_fx import compile_fx, compile_fx_inner from torch._inductor.test_case import TestCase as InductorTestCase from torch._inductor.utils import is_collective -from torch.testing._internal.inductor_utils import HAS_CUDA +from torch.testing._internal.common_device_type import expectedFailureXPU +from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU aten = torch.ops.aten @@ -41,7 +42,7 @@ def calculate_runtime(f, *args) -> float: return ret -DEVICE = "cuda" +DEVICE = GPU_TYPE def T(*size, dtype=torch.float32, device=DEVICE, grad=False) -> torch.Tensor: @@ -81,6 +82,8 @@ def assertNotZero(self, x): class UnsupportedTests(TestCase): + device = DEVICE + def test_no_op(self): def f(a): return a @@ -97,6 +100,10 @@ def f(a): class ComputeBoundedTests(TestCase): + device = DEVICE + + # lack of profiler on XPU + @expectedFailureXPU def test_conv1d(self): def f(x, y): return torch.nn.functional.conv1d(x, y) @@ -104,6 +111,8 @@ def f(x, y): inp = (T(33, 16, 30), T(20, 16, 5)) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_conv2d(self): def f(x, y): return torch.nn.functional.conv2d(x, y, padding=1) @@ -111,6 +120,8 @@ def f(x, y): inp = (T(8, 4, 3, 3), T(1, 4, 5, 5)) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_conv2d_transpose(self): def f(x, y): return torch.nn.functional.conv_transpose2d(x, y, padding=1) @@ -118,6 +129,8 @@ def f(x, y): inp = (T(8, 1, 1, 1), T(1, 4, 5, 5)) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_conv3d(self): def f(x, y): return torch.nn.functional.conv3d(x, y) @@ -125,6 +138,8 @@ def f(x, y): inp = (T(20, 16, 50, 10, 20), T(33, 16, 3, 3, 3)) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_mm(self): def f(a, b): return torch.mm(a, b) @@ -135,6 +150,8 @@ def f(a, b): ) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_addmm(self): def f(a, b, c): return torch.addmm(a, b, c) @@ -146,6 +163,8 @@ def f(a, b, c): ) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_bmm(self): def f(a, b): return torch.bmm(a, b) @@ -158,6 +177,10 @@ def f(a, b): class MemoryBoundedTests(TestCase): + device = DEVICE + + # lack of profiler on XPU + @expectedFailureXPU def test_relu(self): def f(a): return torch.nn.functional.relu(a) @@ -165,6 +188,8 @@ def f(a): inp = (T(10, 10),) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_horizontal_reduction_pointwise(self): def f(a): b = a.sum(dim=1) @@ -174,6 +199,8 @@ def f(a): inp = (T(10, 10),) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU def test_pointwise(self): def f(x): return x.cos() @@ -181,6 +208,8 @@ def f(x): inp = (T(10),) self.assertNotZero(calculate_runtime(f, *inp)) + # lack of profiler on XPU + @expectedFailureXPU @torch._dynamo.config.patch(assume_static_by_default=False) def test_dynamic(self): def f(x): @@ -192,6 +221,8 @@ def f(x): @skipIf(not dist.is_available(), "requires distributed") class TestCommAnalysis(TestCase): + device = DEVICE + WORLD_SIZE: int = 8 RANKS = list(range(8)) @@ -223,6 +254,8 @@ def _verify_runtime_estimation(self, fn, inps): finally: dist.destroy_process_group() + # lack of profiler on XPU + @expectedFailureXPU def test_legacy_all_reduce(self): def fn(x): r = c10d.all_reduce(x, "sum", "", self.RANKS, self.WORLD_SIZE) @@ -231,6 +264,8 @@ def fn(x): inp = T(10, 10) self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_legacy_all_reduce_coalesced(self): def fn(x): rs = c10d.all_reduce_coalesced(x, "sum", "", self.RANKS, self.WORLD_SIZE) @@ -239,6 +274,8 @@ def fn(x): inp = [T(10, 10), T(15, 15)] self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_legacy_all_gather_into_tensor_coalesced(self): def fn(x): rs = c10d.all_gather_into_tensor_coalesced( @@ -252,6 +289,8 @@ def fn(x): inp = [T(10, 10), T(15, 15)] self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_all_reduce(self): def fn(x): r = _c10d.all_reduce(x, "sum", "0") @@ -260,6 +299,8 @@ def fn(x): inp = T(10, 10) self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_all_reduce_coalesced(self): def fn(x): rs = _c10d.all_reduce_coalesced(x, "sum", "0") @@ -268,6 +309,8 @@ def fn(x): inp = [T(10, 10), T(15, 15)] self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_all_gather_into_tensor(self): def fn(x): rs = _c10d.all_gather_into_tensor( @@ -280,6 +323,8 @@ def fn(x): inp = T(10, 10) self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_all_gather_into_tensor_coalesced(self): def fn(x): rs = _c10d.all_gather_into_tensor_coalesced( @@ -292,6 +337,8 @@ def fn(x): inp = [T(10, 10), T(15, 15)] self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_reduce_scatter_tensor(self): def fn(x): rs = _c10d.reduce_scatter_tensor( @@ -305,6 +352,8 @@ def fn(x): inp = T(self.WORLD_SIZE, 10) self._verify_runtime_estimation(fn, (inp,)) + # lack of profiler on XPU + @expectedFailureXPU def test_reduce_scatter_tensor_coalesced(self): def fn(x): rs = _c10d.reduce_scatter_tensor_coalesced( @@ -322,5 +371,5 @@ def fn(x): if __name__ == "__main__": from torch._inductor.test_case import run_tests - if HAS_CUDA: + if HAS_GPU: run_tests(needs="filelock") diff --git a/test/inductor/test_split_cat_fx_passes.py b/test/inductor/test_split_cat_fx_passes.py index 3e775ef2de8e4..974895258eafb 100644 --- a/test/inductor/test_split_cat_fx_passes.py +++ b/test/inductor/test_split_cat_fx_passes.py @@ -113,8 +113,6 @@ def normalize_reshape_with_dynamic_shape(x): expected_split_norm_count, msg=f"for {fn}", ) - if expected_split_norm_count > 0: - self.assertIn("normalization_pass_pre_grad", optimus_scuba_log) counters.clear() @patch diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ea7fb83219d24..73687f41d95a8 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -99,7 +99,7 @@ importlib.import_module("functorch") importlib.import_module("filelock") -from torch._inductor import config, test_operators +from torch._inductor import config, cpu_vec_isa, test_operators from torch._inductor.compile_fx import ( compile_fx, compile_fx_inner, @@ -1487,7 +1487,6 @@ def nested(x, repeats): actual = nested_opt(*example_inputs) self.assertEqual(expect, actual) - @xfail_if_triton_cpu def test_index_propagation_flip(self): def flip(x): i = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -1575,10 +1574,16 @@ def test( pass # no device asserts in halide elif self.device == "cpu" and not is_triton_cpu_backend(self.device): _, code = run_and_get_cpp_code(fn_opt, *inps) - self.assertTrue((") ? (" in code or "blendv" in code) is has_wrapping) self.assertTrue(("TORCH_CHECK" in code) is has_assert) - # Assert that we always vectorize the kernel regardless of wrapping / checks - self.assertTrue(("loadu" in code) is vectorize) + if ( + cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ): + self.assertTrue( + (") ? (" in code or "blendv" in code) is has_wrapping + ) + # Assert that we always vectorize the kernel regardless of wrapping / checks + self.assertTrue(("loadu" in code) is vectorize) else: code = run_and_get_triton_code(fn_opt, *inps) self.assertTrue(("tl.where" in code) is has_wrapping) @@ -1669,7 +1674,6 @@ def constant_propagation_neg(a): vectorize=False, # There's no loop to vectorize! ) - @xfail_if_triton_cpu def test_computed_buffer_inlining(self): def flip(x): idx = torch.arange(x.size(0) - 1, -1, -1, device=x.device) @@ -1691,7 +1695,7 @@ def fn(a, mask, idx): ( torch.randn(8, device=self.device), torch.tensor([True, False, True], device=self.device), - [torch.tensor([3, 9, -2], device=self.device)], + [torch.tensor([3, 9, 2], device=self.device)], ), ) @@ -1704,7 +1708,7 @@ def fn(a, mask, idx, values): ( torch.randn(8, device=self.device), torch.tensor([True, False, True], device=self.device), - [torch.tensor([3, 9, -2], device=self.device)], + [torch.tensor([3, 9, 2], device=self.device)], torch.randn(3, device=self.device), ), ) @@ -1891,8 +1895,20 @@ def test_multilayer_var_lowp(self): def fn(a): return torch.var(a) - self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),)) - self.common(fn, (torch.rand((14923), dtype=torch.float16),)) + atol = None + rtol = None + if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default": + atol = 1e-3 + rtol = 1e-3 + self.common( + fn, + (torch.rand((16, 16, 352, 352), dtype=torch.float16),), + atol=atol, + rtol=rtol, + ) + self.common( + fn, (torch.rand((14923), dtype=torch.float16),), atol=atol, rtol=rtol + ) def test_split_cumsum(self): def fn(a): @@ -3331,6 +3347,7 @@ def fn(a, b): ) @skipIfPy312 # segfaults + @skipCUDAIf(not SM80OrLater, "Requires sm80") @config.patch(mixed_mm_choice="triton") def test_mixed_mm(self): def fn(a, b): @@ -3346,6 +3363,7 @@ def fn(a, b): ) @skipIfPy312 # segfaults + @skipCUDAIf(not SM80OrLater, "Requires sm80") @config.patch(mixed_mm_choice="triton") def test_mixed_mm2(self): def fn(a, b, scale, bias): @@ -3363,6 +3381,7 @@ def fn(a, b, scale, bias): ) @skipIfPy312 # segfaults + @skipCUDAIf(not SM80OrLater, "Requires sm80") @config.patch(mixed_mm_choice="triton") def test_mixed_mm3(self): def fn(a, b): @@ -7706,6 +7725,9 @@ def fn(a, dim, index, b): check_lowp = False for deterministic in [False, True]: + if deterministic and self.device == "xpu": + # There is no deterministic implementation for scatter_add on Intel GPU. + continue with DeterministicGuard(deterministic): self.common( fn, @@ -9742,7 +9764,6 @@ def forward(arg6, arg7, arg16): not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support SDPA or pre-SM80 hardware", ) - @skipIfRocm def test_sdpa(self, use_block_ptr: bool, prefer_nd_tiling: bool): def foo(arg0_1, arg1_1, arg2_1, arg3_1, arg4_1): view = torch.ops.aten.view.default(arg3_1, [23760, 128]) @@ -10197,9 +10218,15 @@ def fn(query, scores, window_overlap): if is_cpp_backend(self.device): opt_fn = torch._dynamo.optimize("inductor")(fn) _, code = run_and_get_cpp_code(opt_fn, *args) + num = ( + 2 + if cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + else 1 + ) FileCheck().check_count( "static_cast(256)", - 2, + num, exactly=True, ).run(code) @@ -10396,7 +10423,6 @@ def fn(q, k, v): rtol=1e-2, # to pass lowp check on GPU ) - @skipIfRocm @expectedFailureXPU def test_scaled_dot_product_efficient_attention(self): if self.device == "cpu": @@ -10433,7 +10459,6 @@ def fn(x): self.common(fn, (torch.randn((16, 16, 16)),), check_lowp=False) - @xfail_if_triton_cpu def test_searchsorted(self): def fn(sorted_sequence, values, out_int32, right, side, sorter): return torch.searchsorted( @@ -10783,7 +10808,6 @@ def fn(x): self.common(fn, (inp,), check_lowp=False) @requires_gpu() - @xfail_if_triton_cpu @config.patch(implicit_fallbacks=True) def test_mutable_custom_op_fixed_layout2(self): with torch.library._scoped_library("mylib", "DEF") as lib: @@ -11104,8 +11128,6 @@ def fn(x): self.common(fn, (x,)) @skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80") - # We only support dtypeview for abi_compatible aoti - @torch._inductor.config.patch(abi_compatible=True) @parametrize( "dtype_x, dtype_y", list(itertools.product(test_dtypes, test_dtypes)), @@ -11140,7 +11162,6 @@ def fn(x, y, x_dtype, x2): check_lowp=False, ) - @torch._inductor.config.patch(abi_compatible=True) def test_dtypeview_fusion(self): @torch.compile def fn(x): @@ -11538,7 +11559,6 @@ def forward(): self.common(forward, ()) - @xfail_if_triton_cpu def test_flip_cat(self): def forward(unsqueeze, unsqueeze_1): cat_1 = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1) @@ -12125,18 +12145,35 @@ def f(x, mask): @requires_gpu() @parametrize("upcast_to_fp32", [False, True]) + @config.patch("triton.use_block_ptr", True) def test_codegen_upcast_to_fp32(self, upcast_to_fp32): @torch.compile - def func(a, b): - return a * b + def func(a, b, c, d): + return a * b * c * d - inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 2 + inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 4 with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32): func_opt = torch._dynamo.optimize("inductor")(func) code = run_and_get_triton_code(func_opt, *inps) fp32_cast_in_code = "to(tl.float32)" in code self.assertEqual(fp32_cast_in_code, upcast_to_fp32) + @requires_gpu() + @parametrize("load_upcast_to_fp32", [False, True]) + @parametrize("input_dtype", [torch.float16, torch.bfloat16]) + @config.patch("triton.use_block_ptr", True) + def test_dtype_aware_codegen(self, load_upcast_to_fp32, input_dtype): + @torch.compile + def func(a, b, c, d): + return torch.sqrt(a * b * c * d) + + inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=input_dtype),) * 4 + with config.patch("triton.codegen_upcast_to_fp32", load_upcast_to_fp32): + func_opt = torch._dynamo.optimize("inductor")(func) + code = run_and_get_triton_code(func_opt, *inps) + libdevice_cast_in_code = "libdevice.sqrt(tmp3.to(tl.float32))" in code + self.assertNotEqual(libdevice_cast_in_code, load_upcast_to_fp32) + @config.patch("triton.use_block_ptr", False) def test_evict_last_non_coalesced_loads(self): @torch.compile diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 07d5f98fb5595..fda02bda12bc7 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -11,7 +11,7 @@ import torch import torch.library -from torch._dynamo.testing import make_test_cls_with_patches +from torch._dynamo.testing import CompileCounterWithBackend, make_test_cls_with_patches from torch._inductor import metrics from torch._inductor.codegen.common import device_codegens, register_backend_for_device from torch._inductor.codegen.cpp import CppScheduling @@ -951,6 +951,38 @@ def f(xt): f(torch.tensor([5] * 320)) + def test_mark_unbacked_slice(self): + @torch.compile(backend="inductor", mode="reduce-overhead", fullgraph=True) + def f(x): + return x.sum() + + x = torch.empty_strided((1, 4), (5, 1), device=GPU_TYPE) + torch._dynamo.decorators.mark_unbacked(x, 0) + f(x) + + @torch._dynamo.config.patch(specialize_float=False, capture_scalar_outputs=True) + def test_unspecialized_float_operations(self): + operations = { + "multiply": operator.mul, + "add": operator.add, + "subtract": operator.sub, + "divide": operator.truediv, + } + + for name, op in operations.items(): + with self.subTest(operation=name): + + def fn(x, y): + return op(x, y) + + cnt = CompileCounterWithBackend("inductor") + fn_opt = torch._dynamo.optimize(cnt)(fn) + + x = torch.arange(3) + self.assertEqual(fn(x, 2.0), fn_opt(x, 2.0)) + self.assertEqual(fn(x, 3.0), fn_opt(x, 3.0)) + self.assertEqual(cnt.frame_count, 1) + def test_sort_dynamic_shape_with_check(self, device): if TEST_WITH_ROCM or torch.device(device).type != GPU_TYPE: diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py index 429fc1d187e6b..f2cc17e258f60 100644 --- a/test/inductor/test_torchinductor_opinfo.py +++ b/test/inductor/test_torchinductor_opinfo.py @@ -351,13 +351,9 @@ def format_op(op): "nn.functional.conv_transpose3d": {f32, f64}, # rrelu not supported on XPU now "nn.functional.rrelu": {f16, f32, f64}, - "histc": {i32, i64}, # not implemented for 'Half' "nn.functional.multilabel_margin_loss": {f16}, "nn.functional.multi_margin_loss": {f16}, - "nn.functional.avg_pool3d": {f16}, - # not implemented for 'Bool' - "nn.functional.unfold": {b8}, } diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index 0d1c629413d7e..d9d6c7415fd29 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -497,6 +497,27 @@ def get_input() -> torch.Tensor: else: self.assertNotIn(tile_name, program) + def test_complex_reshape_block_ptr(self): + def func(x, y): + add_ = x + y + reshape_0 = add_.reshape([8, 16, 128]) + permute_0 = reshape_0.permute([0, 2, 1]) + reshape_1 = permute_0.reshape([1024, 16]) + clone_0 = reshape_1.clone(memory_format=torch.contiguous_format) + permute_1 = clone_0.permute([1, 0]) + clone_1 = permute_1.clone(memory_format=torch.contiguous_format) + + return clone_0, clone_1 + + inps = (torch.rand((8, 2048), device=GPU_TYPE, dtype=torch.float32),) * 2 + result, code = self.run_and_compare( + func, + *inps, + expected_num_triton_kernels=2, + expected_num_block_pointers=4, + ) + self.assertTrue("Min" not in code[0]) + if __name__ == "__main__": from torch._inductor.test_case import run_tests diff --git a/test/inductor/test_triton_extension_backend.py b/test/inductor/test_triton_extension_backend.py index 5646c2f7b365f..c2a0a8cdea7f7 100644 --- a/test/inductor/test_triton_extension_backend.py +++ b/test/inductor/test_triton_extension_backend.py @@ -36,9 +36,14 @@ register_device_op_overrides, ) from torch._inductor.utils import get_triton_code -from torch.testing._internal.common_utils import IS_MACOS +from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS +try: + from .test_extension_backend import BaseExtensionBackendTests +except ImportError: + from test_extension_backend import BaseExtensionBackendTests + try: try: from . import test_torchinductor @@ -59,42 +64,31 @@ def mock_triton_hash_with_backend(*args, **kwargs): return "".join(random.choices(string.ascii_uppercase + string.digits, k=64)) -class TritonExtensionBackendTests(TestCase): +@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now") +class TritonExtensionBackendTests(BaseExtensionBackendTests): """ Test creating a backend for inductor with Triton scheduling. """ - @classmethod - def setUpClass(cls): - super().setUpClass() - - @classmethod - def tearDownClass(cls): - cls._stack.close() - super().tearDownClass() - - def setUp(self): - torch._dynamo.reset() - super().setUp() - - def tearDown(self): - super().tearDown() - torch._dynamo.reset() - def test_open_device_registration(self): - register_backend_for_device("cpu", ExtensionScheduling, ExtensionWrapperCodegen) - register_device_op_overrides("cpu", CPUDeviceOpOverrides()) - device_interface.register_interface_for_device("cpu", DeviceInterface) + torch._register_device_module("privateuseone", self.module) + register_backend_for_device( + "privateuseone", ExtensionScheduling, ExtensionWrapperCodegen + ) + register_device_op_overrides("privateuseone", CPUDeviceOpOverrides()) + device_interface.register_interface_for_device("privateuseone", DeviceInterface) - self.assertTrue(get_scheduling_for_device("cpu") == ExtensionScheduling) - self.assertTrue( - get_wrapper_codegen_for_device("cpu") == ExtensionWrapperCodegen + self.assertEqual( + get_scheduling_for_device("privateuseone"), ExtensionScheduling + ) + self.assertEqual( + get_wrapper_codegen_for_device("privateuseone"), ExtensionWrapperCodegen ) - self.assertTrue( - device_interface.get_interface_for_device("cpu") == DeviceInterface + self.assertEqual( + device_interface.get_interface_for_device("privateuseone"), DeviceInterface ) - device = torch.device("cpu") + device = torch.device("privateuseone") x = torch.empty(2, 16).fill_(1).to(device) def foo(x): @@ -113,7 +107,7 @@ def foo(x): FileCheck().check("import triton").check("@triton.jit").check( "tl_math.sin" - ).check("device_str='cpu'").run(code) + ).check("device_str='privateuseone'").run(code) if __name__ == "__main__": diff --git a/test/inductor/test_triton_heuristics.py b/test/inductor/test_triton_heuristics.py index 9800fa25253d6..c2bd415a58da8 100644 --- a/test/inductor/test_triton_heuristics.py +++ b/test/inductor/test_triton_heuristics.py @@ -18,6 +18,7 @@ from torch._inductor import config from torch._inductor.runtime.hints import ( + AttrsDescriptorWrapper, AutotuneHint, DeviceProperties, HeuristicType, @@ -93,8 +94,6 @@ def test_artificial_grid_cpp_wrapper(self): self._test_artificial_zgrid() def _get_cos_kernel_caching_autotuner_args(self): - from triton.compiler.compiler import AttrsDescriptor # @manual - @triton.jit def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): xnumel = 16 @@ -110,7 +109,9 @@ def triton_(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): "signature": {"in_ptr0": "*fp32", "out_ptr0": "*fp32", "xnumel": "i32"}, "device": DeviceProperties.create(torch.device("cuda")), "constants": {}, - "configs": [AttrsDescriptor(divisible_by_16=(0, 1, 2), equal_to_1=())], + "configs": [ + AttrsDescriptorWrapper(divisible_by_16=(0, 1, 2), equal_to_1=()) + ], } configs = [ diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index 7900cd5c674bd..759f46e3c8ac3 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -29,7 +29,7 @@ # Defines all the kernels for tests from torch.testing._internal.triton_utils import * # noqa: F403 -from torch.utils._triton import has_triton_package +from torch.utils._triton import has_triton_package, has_triton_tma if HAS_GPU: @@ -54,10 +54,19 @@ fast_dividef as my_fast_dividef, ) + def _triton_get_ast_equal_to_str(params): + try: + from triton.backends.compiler import AttrsDescriptor # noqa: F401 + + return f"'tt.equal_to': {params}" + except ImportError: + return f"equal_to_1={params}" + # Define shared triton constants here. CONSTANT_C: tl.constexpr = 4 STRING_CONSTANT_C: tl.constexpr = "CONSTANT_C" BOOL_CONSTANT_C: tl.constexpr = True + FLOAT_CONSTANT_C = tl.constexpr(3.14) # intentionally un-annotated class KernelTests(torch._inductor.test_case.TestCase): @@ -99,6 +108,7 @@ def test_triton_kernel_higher_order_func(self): kernel_idx=add_kernel_id, constant_args_idx=constant_args_idx, grid=[grid], + tma_descriptor_metadata={}, kwargs={ "in_ptr0": t1, "in_ptr1": t2, @@ -115,6 +125,7 @@ def test_triton_kernel_higher_order_func(self): kernel_idx=add_kernel_id, constant_args_idx=constant_args_idx, grid=[grid], + tma_descriptor_metadata={}, kwargs={ "in_ptr0": t1, "in_ptr1": t2, @@ -145,6 +156,7 @@ def f(x, output): {"n_elements": output.numel(), "BLOCK_SIZE": 16} ), grid=[(x.numel(),)], + tma_descriptor_metadata={}, kwargs={ "in_ptr0": x, "out_ptr": output, @@ -173,7 +185,7 @@ def f(x, output): gm.code.strip(), """\ def forward(self, x_1, output_1): - triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 3, grid = [(5,)], tma_descriptor_metadata = {}, kwargs = {'in_ptr0': x_1, 'out_ptr': output_1}, tensors_to_clone = ['in_ptr0', 'out_ptr']); x_1 = output_1 = None getitem = triton_kernel_wrapper_functional_proxy['in_ptr0']; getitem = None getitem_1 = triton_kernel_wrapper_functional_proxy['out_ptr']; triton_kernel_wrapper_functional_proxy = None return getitem_1""", @@ -217,6 +229,7 @@ def prep(): {"n_elements": x_func.numel(), "BLOCK_SIZE": 16} ), grid=[(x_func.numel(),)], + tma_descriptor_metadata={}, kwargs={ "ptr": x_func, }, @@ -238,6 +251,7 @@ def prep(): {"n_elements": x_func.numel(), "BLOCK_SIZE": 16} ), grid=[(x_func.numel(),)], + tma_descriptor_metadata={}, kwargs={ "ptr": x_func, }, @@ -941,7 +955,7 @@ def f(x): f(x_cloned) out.sum().backward() - @requires_cuda + @requires_gpu @patch.object(torch._inductor.config, "allow_buffer_reuse", True) def test_triton_kernel_inputs_buffer_reuse(self): def _mul2(x): @@ -962,15 +976,15 @@ def f(x): x = _mul2(x) return x + 1 - x = torch.randn(10, device="cuda", dtype=torch.float32) + x = torch.randn(10, device=GPU_TYPE, dtype=torch.float32) eager_out = f(x) compiled_out, (code,) = run_and_get_code(torch.compile(f), x) self.assertEqual(compiled_out, eager_out) # Check that we're allocating the minimal # of buffers. - num_bufs_allocated = code.count( - "empty_strided_cuda((10, ), (1, ), torch.float32)" - ) + code_string = f"empty_strided_{GPU_TYPE}((10, ), (1, ), torch.float32)" + + num_bufs_allocated = code.count(code_string) self.assertEqual(num_bufs_allocated, 2) # Check we're re-using buffers if not allocating. @@ -1254,9 +1268,9 @@ def f(x, y): if dynamic: # when half_n_elements passed to the Triton kernel is # dynamic, equal_to_1 specializaiton can't be enforced - self.assertTrue("equal_to_1=()" in sources[0]) + self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0]) else: - self.assertTrue("equal_to_1=(3,)" in sources[0]) + self.assertTrue(_triton_get_ast_equal_to_str((3,)) in sources[0]) self.assertEqual(compiled_out, eager_out) @requires_gpu @@ -1285,7 +1299,7 @@ def f(x, y): # float 1.0 (both literal or symbolic) # should not be added to equal_to_1 - self.assertTrue("equal_to_1=()" in sources[0]) + self.assertTrue(_triton_get_ast_equal_to_str(()) in sources[0]) self.assertEqual(compiled_out, eager_out) @requires_gpu @@ -1630,6 +1644,266 @@ def f(x, y, z): self.assertEqual(out2, x + y + 1) self.assertEqual(out3, z**2) + @requires_gpu + @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") + @common_utils.parametrize("dynamic", [False, True]) + def test_tma_capture_and_functionalize(self, dynamic): + def f(a, b): + BLOCK_SIZE = 256 + out = torch.zeros_like(a) + n_elements = out.numel() + + desc_a, desc_b, desc_out = ( + triton.tools.experimental_descriptor.create_1d_tma_descriptor( + t.data_ptr(), + n_elements, + BLOCK_SIZE, + t.element_size(), + ) + for t in (a, b, out) + ) + + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel_with_tma_1d[grid]( + desc_a, + desc_b, + desc_out, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out + + a = torch.randn(301, device=GPU_TYPE) + b = torch.randn(301, device=GPU_TYPE) + + backend = torch._dynamo.testing.AotEagerAndRecordGraphs() + torch.compile( + f, + fullgraph=True, + backend=backend, + dynamic=dynamic, + )(a, b) + + if dynamic: + self.assertExpectedInline( + backend.fw_graphs[0].code.strip(), + """\ +def forward(self, arg0_1, arg1_1, arg2_1): + zeros_like = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False) + add_2 = arg0_1 + 256 + sub_1 = add_2 - 1; add_2 = None + floordiv = sub_1 // 256; sub_1 = None + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 1, grid = [(floordiv, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([arg0_1], [256], 4), 'in_desc_ptr1': ([arg0_1], [256], 4), 'out_desc_ptr': ([arg0_1], [256], 4)}, kwargs = {'in_desc_ptr0': arg1_1, 'in_desc_ptr1': arg2_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); floordiv = arg0_1 = arg1_1 = arg2_1 = zeros_like = None + getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None + return (getitem,)""", + ) + else: + self.assertExpectedInline( + backend.fw_graphs[0].code.strip(), + """\ +def forward(self, arg0_1, arg1_1): + zeros_like = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False) + triton_kernel_wrapper_functional_proxy = torch.ops.higher_order.triton_kernel_wrapper_functional(kernel_idx = 0, constant_args_idx = 0, grid = [(2, 1, 1)], tma_descriptor_metadata = {'in_desc_ptr0': ([301], [256], 4), 'in_desc_ptr1': ([301], [256], 4), 'out_desc_ptr': ([301], [256], 4)}, kwargs = {'in_desc_ptr0': arg0_1, 'in_desc_ptr1': arg1_1, 'out_desc_ptr': zeros_like}, tensors_to_clone = ['out_desc_ptr']); arg0_1 = arg1_1 = zeros_like = None + getitem = triton_kernel_wrapper_functional_proxy['out_desc_ptr']; triton_kernel_wrapper_functional_proxy = None + return (getitem,)""", + ) + + @requires_gpu + @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") + @common_utils.parametrize("after_data_ptr", [False, True]) + @common_utils.parametrize("after_create_desc", [False, True]) + def test_tma_graph_breaks(self, after_data_ptr, after_create_desc): + def f(a, b): + BLOCK_SIZE = 256 + out = torch.zeros_like(a) + n_elements = out.numel() + + ptrs = [t.data_ptr() for t in (a, b, out)] + + if after_data_ptr: + torch._dynamo.graph_break() + + descs = [ + triton.tools.experimental_descriptor.create_1d_tma_descriptor( + ptr, + n_elements, + BLOCK_SIZE, + t.element_size(), + ) + for ptr in ptrs + ] + + if after_create_desc: + torch._dynamo.graph_break() + + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel_with_tma_1d[grid]( + *descs, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out + + a = torch.randn(301, device=GPU_TYPE) + b = torch.randn(301, device=GPU_TYPE) + + expected_out = a + b + eager_out = f(a, b) + compiled_out = torch.compile( + f, + fullgraph=False, + backend="eager", + dynamic=False, + )(a, b) + + self.assertEqual(eager_out, expected_out) + self.assertEqual(compiled_out, expected_out) + + @requires_gpu + @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") + @common_utils.parametrize("dynamic", [False, True]) + @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) + def test_tma_descriptor_1d(self, dynamic, backend): + def f(a, b): + BLOCK_SIZE = 256 + out = torch.zeros_like(a) + n_elements = out.numel() + + desc_a, desc_b, desc_out = ( + triton.tools.experimental_descriptor.create_1d_tma_descriptor( + t.data_ptr(), + n_elements, + BLOCK_SIZE, + t.element_size(), + ) + for t in (a, b, out) + ) + + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel_with_tma_1d[grid]( + desc_a, + desc_b, + desc_out, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out + + a = torch.randn(301, device=GPU_TYPE) + b = torch.randn(301, device=GPU_TYPE) + + expected_out = a + b + eager_out = f(a, b) + compiled_out = torch.compile( + f, + fullgraph=True, + backend=backend, + dynamic=dynamic, + )(a, b) + + self.assertEqual(eager_out, expected_out) + self.assertEqual(compiled_out, expected_out) + + @requires_gpu + @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") + def test_tma_descriptor_dedup(self): + def f(a): + BLOCK_SIZE = 256 + out = torch.zeros_like(a) + n_elements = out.numel() + + desc_a, desc_out = ( + triton.tools.experimental_descriptor.create_1d_tma_descriptor( + t.data_ptr(), + n_elements, + BLOCK_SIZE, + t.element_size(), + ) + for t in (a, out) + ) + + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + add_kernel_with_tma_1d[grid]( + desc_a, + desc_a, + desc_out, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out + + a = torch.randn(301, device=GPU_TYPE) + + expected_out = a + a + eager_out = f(a) + compiled_out, (code,) = run_and_get_code( + torch.compile( + f, + fullgraph=True, + backend="inductor", + dynamic=True, + ), + a, + ) + + self.assertEqual(eager_out, expected_out) + self.assertEqual(compiled_out, expected_out) + + # 2 calls: one for two inputs (dedupped), one for the output + self.assertEqual(code.count("create_1d_tma_descriptor("), 2) + + @requires_gpu + @unittest.skipIf(not has_triton_tma(), "requires Triton TMA support") + @common_utils.parametrize("dynamic", [False, True]) + @common_utils.parametrize("backend", ["eager", "aot_eager"]) + def test_tma_descriptor_2d(self, dynamic, backend): + def f(a, b): + BLOCK_SIZE_X = 16 + BLOCK_SIZE_Y = 32 + out = torch.zeros_like(a) + x_size, y_size = out.size() + + desc_a, desc_b, desc_out = ( + triton.tools.experimental_descriptor.create_2d_tma_descriptor( + t.data_ptr(), + x_size, + y_size, + BLOCK_SIZE_X, + BLOCK_SIZE_Y, + t.element_size(), + ) + for t in (a, b, out) + ) + + grid = lambda meta: ( + triton.cdiv(x_size, meta["BLOCK_SIZE_X"]), + triton.cdiv(y_size, meta["BLOCK_SIZE_Y"]), + ) + add_kernel_with_tma_2d[grid]( + desc_a, + desc_b, + desc_out, + BLOCK_SIZE_X=BLOCK_SIZE_X, + BLOCK_SIZE_Y=BLOCK_SIZE_Y, + ) + + return out + + a = torch.randn((25, 16), device=GPU_TYPE) + b = torch.randn((25, 16), device=GPU_TYPE) + + expected_out = a + b + eager_out = f(a, b) + compiled_out = torch.compile( + f, + fullgraph=True, + backend=backend, + dynamic=dynamic, + )(a, b) + + self.assertEqual(eager_out, expected_out) + self.assertEqual(compiled_out, expected_out) + @requires_gpu @common_utils.parametrize("backend", ["eager", "aot_eager", "inductor"]) def test_triton_kernel_num_ctas(self, backend): @@ -1699,15 +1973,13 @@ def f(x): # TODO enable this test case on XPU. @requires_cuda - @parametrize("cfg", ["normal", "cpp_wrapper", "cpp_abi"]) + @parametrize("cfg", ["normal", "cpp_wrapper"]) def test_triton_kernel_dtype_view(self, cfg): # https://github.com/pytorch/pytorch/issues/136159 if cfg == "normal": - config_kwargs = {"cpp_wrapper": False, "abi_compatible": False} + config_kwargs = {"cpp_wrapper": False} elif cfg == "cpp_wrapper": - config_kwargs = {"cpp_wrapper": True, "abi_compatible": False} - elif cfg == "cpp_abi": - config_kwargs = {"cpp_wrapper": True, "abi_compatible": True} + config_kwargs = {"cpp_wrapper": True} with torch._inductor.config.patch(**config_kwargs): @@ -1745,7 +2017,7 @@ def fn(x): self.assertEqual(out_e[1], out_c[1]) # TODO enable this test case on XPU. - @requires_cuda + @requires_gpu def test_i64_input(self): # The i64 "seed" input needs to be marked as "i64", not "i32". @triton.jit @@ -1855,6 +2127,78 @@ def fn(x): res2 = fn_c(x2) self.assertEqual(x2 * x2, res2) + @requires_gpu + def test_triton_kernel_none_args(self): + # https://github.com/pytorch/pytorch/issues/115344 + @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 32}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4), + ], + key=["n_elements"], + ) + @triton.jit + def sin_kernel( + in_ptr0, + out_ptr, + n_elements, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + if in_ptr0 is not None: + x = tl.load(in_ptr0 + offsets, mask=mask) + else: + x = 0.0 + output = tl.sin(x) + tl.store(out_ptr + offsets, output, mask=mask) + + def sin_triton(x, out): + n_elements = out.numel() + sin_kernel[(n_elements,)](x, out, n_elements) + + x = torch.randn(65, device=GPU_TYPE) + out = torch.empty_like(x) + out_compiled = torch.empty_like(x) + sin_triton_compiled = torch.compile(fullgraph=True)(sin_triton) + + sin_triton(x, out) + sin_triton_compiled(x, out_compiled) + self.assertEqual(out, out_compiled) + + sin_triton(None, out) + sin_triton_compiled(None, out_compiled) + self.assertEqual(out, out_compiled) + + @requires_gpu + def test_triton_kernel_global_constexpr(self): + @triton.jit + def triton_(in_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(in_ptr + offsets) + output = x + FLOAT_CONSTANT_C + tl.store(out_ptr + offsets, output) + + def fn(x): + y = torch.empty_like(x) + BLOCK_SIZE = 256 + grid = (triton.cdiv(x.numel(), BLOCK_SIZE),) + triton_[grid](x, y, BLOCK_SIZE) + return y + + # make sure FLOAT_CONSTANT_C is NOT annotated + self.assertFalse("FLOAT_CONSTANT_C" in globals().get("__annotations__", {})) + # sanity check: STRING_CONSTANT_C _should_ be annotated + self.assertTrue("STRING_CONSTANT_C" in globals().get("__annotations__", {})) + + x = torch.randn(512, device=GPU_TYPE) + expected = x + 3.14 + actual = torch.compile(fn)(x) + self.assertEqual(expected, actual) + def make_mutation_test(fn): @requires_gpu @@ -2031,7 +2375,7 @@ def argmax_kernel(a_ptr, c_ptr, stride_am, stride_an): expected, ) - @requires_cuda + @requires_gpu @skipIfRocm def test_triton_kernel_inference_mode(self): def f(x, y, out): @@ -2040,8 +2384,8 @@ def f(x, y, out): add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=4) with torch.inference_mode(): - x = torch.ones(32, device="cuda") - y = torch.ones(32, device="cuda") + x = torch.ones(32, device=GPU_TYPE) + y = torch.ones(32, device=GPU_TYPE) out_ref = torch.zeros_like(x) out_test = torch.zeros_like(x) f(x, y, out_ref) @@ -2910,6 +3254,7 @@ def f(x, y): gm = make_fx(f, tracing_mode=tracing_mode)(x, x) self.assertEqual(gm(x, x), x + x) + @skipIfXpu @requires_gpu @patch.object(torch._inductor.config, "cpp_wrapper", True) @patch.object(torch._inductor.config, "triton.autotune_at_compile_time", True) @@ -3018,8 +3363,8 @@ def grid(META): return z M, K, N = 128, 64, 32 - x = torch.randn(M, K, device="cuda") - w = torch.randn(K, N, device="cuda") + x = torch.randn(M, K, device=GPU_TYPE) + w = torch.randn(K, N, device=GPU_TYPE) torch._dynamo.decorators.mark_unbacked(x, 0) torch._logging.set_logs(output_code=True) diff --git a/test/inductor/test_unbacked_symints.py b/test/inductor/test_unbacked_symints.py index 0ef2e6131166c..5c438d6cbc709 100644 --- a/test/inductor/test_unbacked_symints.py +++ b/test/inductor/test_unbacked_symints.py @@ -11,14 +11,19 @@ from torch.testing import make_tensor from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, - skipCUDAIf, + skipGPUIf, ) from torch.testing._internal.common_utils import IS_LINUX, parametrize -from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CUDA +from torch.testing._internal.inductor_utils import ( + GPU_TYPE, + HAS_CUDA, + HAS_GPU, + requires_gpu, +) class TestUnbackedSymints(InductorTestCase): - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_expand(self, device): def fn(x, y): @@ -39,7 +44,7 @@ def fn(x, y): torch.testing.assert_close(actual, expected) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_expand_ok_with_runtime_assert(self, device): def fn(x): @@ -50,7 +55,7 @@ def fn(x): x = make_tensor(32, 4, device=device, dtype=torch.float32, exclude_zero=True) actual = torch.compile(fn, fullgraph=True)(x) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_broadcast_tensors(self, device): def fn(x): @@ -64,7 +69,7 @@ def fn(x): expected = fn(x) torch.testing.assert_close(actual, expected) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_autotuning(self, device): def fn(x, y): @@ -88,7 +93,7 @@ def fn(x, y): torch.testing.assert_close(actual, expected) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_scalar_outputs": True}) def test_split_with_sizes(self, device): def fn(x, y): @@ -104,7 +109,7 @@ def fn(x, y): torch.testing.assert_close(actual, expected) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_view_of_slice(self, device): # Tests View.create(slice, size_with_unbacked_symint) @@ -122,9 +127,8 @@ def fn(x): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @requires_gpu() @dynamo_config.patch({"capture_scalar_outputs": True}) - @inductor_config.patch({"abi_compatible": True}) def test_triton_kernel_grid(self, device): if device == "cpu": raise unittest.SkipTest("Triton kernel requires GPU") @@ -145,7 +149,7 @@ def fn(x): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @skipGPUIf(not HAS_GPU, "requires gpu and triton") @dynamo_config.patch({"capture_dynamic_output_shape_ops": True}) def test_nonzero_in_inference_mode(self, device): def fn(x): @@ -191,15 +195,12 @@ def fn(x, w, a, b): expected = fn(*example_inputs) torch.testing.assert_close(actual, expected) - @skipCUDAIf(not HAS_CUDA, "requires cuda") + @requires_gpu() @dynamo_config.patch({"capture_scalar_outputs": True}) def test_vertical_pointwise_reduction_fusion(self, device): # reset in case we run both cpu and cuda tests torch._inductor.metrics.reset() - if device == "cpu": - raise unittest.SkipTest("This test requires cuda") - # Tests fusing a pointwise & reduction op with unbacked numel/rnumel. def fn(x, y, repeats): u0 = repeats.item() @@ -213,9 +214,9 @@ def fn(x, y, repeats): return pointwise, reduction example_inputs = ( - torch.randn(32, 16).cuda(), - torch.randn(1, 16).cuda(), - torch.tensor(32).cuda(), + torch.randn(32, 16).to(GPU_TYPE), + torch.randn(1, 16).to(GPU_TYPE), + torch.tensor(32).to(GPU_TYPE), ) actual = torch.compile(fn, fullgraph=True)(*example_inputs) @@ -279,12 +280,10 @@ def fn(x, num): torch.testing.assert_close(actual, expected) -instantiate_device_type_tests( - TestUnbackedSymints, globals(), only_for=(GPU_TYPE, "cpu") -) +instantiate_device_type_tests(TestUnbackedSymints, globals(), allow_xpu=True) if __name__ == "__main__": from torch._inductor.test_case import run_tests - if IS_LINUX and HAS_CUDA and is_big_gpu(0): + if IS_LINUX and HAS_GPU and (not HAS_CUDA or is_big_gpu(0)): run_tests() diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index 213ac2b69b2bf..8ed22b930134a 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -2850,7 +2850,6 @@ def test_append(self): with self.assertRaises(TypeError): script_data.append("str") - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") def test_clear(self): """ Test clear. diff --git a/test/nn/test_convolution.py b/test/nn/test_convolution.py index aa9a69bbfd228..27c57f302d193 100644 --- a/test/nn/test_convolution.py +++ b/test/nn/test_convolution.py @@ -724,6 +724,7 @@ def test_ConvTranspose2d_half_cublas_gemm(self): # For https://github.com/pytorch/pytorch/pull/1273 # Almost identical to the above `test_Conv2d_naive_groups` @torch.backends.cudnn.flags(enabled=True, benchmark=False) + @tf32_on_and_off(0.001) @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_groups_nobias(self): dev_dtypes = [("cpu", torch.float)] @@ -769,6 +770,7 @@ def test_Conv2d_groups_nobias(self): # See also https://github.com/pytorch/pytorch/pull/18463#issuecomment-476563686 # and https://github.com/pytorch/pytorch/pull/18463#issuecomment-477001024 @torch.backends.cudnn.flags(enabled=True, benchmark=False) + @tf32_on_and_off(0.001) @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_groups_nobias_v2(self): torch.manual_seed(123) @@ -3396,6 +3398,7 @@ def test_ConvTranspose3d_size_1_kernel(self, device): ) @dtypes(torch.float) @torch.backends.cudnn.flags(enabled=True, benchmark=False) + @tf32_on_and_off(0.001) @unittest.skipIf(TEST_WITH_ROCM, "Skipped on ROCm, since it is failing on ROCm 5.7") def test_Conv2d_naive_groups(self, device, dtype): # Check that grouped convolutions matches two half convolutions diff --git a/test/nn/test_embedding.py b/test/nn/test_embedding.py index 7ac8d58d3b2f1..40081b779dc61 100644 --- a/test/nn/test_embedding.py +++ b/test/nn/test_embedding.py @@ -876,7 +876,7 @@ def test_embedding_bag_dimension_errors(self, device): @dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long))) def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes): - # Failure 1: mismatched embeddings / per_sample_weights dtype + # Failure 1: mismatched embeddings / per_sample_weights dtype (only on CPU device) es = nn.EmbeddingBag(5, 2, mode="sum").to(dtype=torch.float, device=device) input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device) offsets = torch.tensor([0, 0, 3, 3, 6], dtype=dtypes[1], device=device) @@ -884,9 +884,6 @@ def test_EmbeddingBag_per_sample_weights_failures(self, device, dtypes): if device == "cpu": with self.assertRaisesRegex(RuntimeError, "have the same type as"): es(input, offsets, per_sample_weights) - else: - with self.assertRaisesRegex(RuntimeError, "expected scalar type"): - es(input, offsets, per_sample_weights) # Failure 2.1: input/per_sample_weights have different sizes (1d input) input = torch.tensor([3, 1, 1, 1, 4, 0], dtype=dtypes[0], device=device) diff --git a/test/nn/test_pooling.py b/test/nn/test_pooling.py index c609520b30869..b033003e93ccb 100644 --- a/test/nn/test_pooling.py +++ b/test/nn/test_pooling.py @@ -492,6 +492,16 @@ def test_quantized_max_pool1d_empty_kernel(self): with self.assertRaises(RuntimeError): torch.quantized_max_pool1d(temp_tensor, []) + def test_quantized_max_pool3d(self): + # This used to segfault when called with a negative dilation + # see https://github.com/pytorch/pytorch/issues/136716 + input = torch.randn([1, 1, 1, 1, 1]) + input = torch.quantize_per_tensor(input, -0.1, -10, torch.qint32) + with self.assertRaisesRegex(RuntimeError, "Expected dilation >= 1"): + torch.quantized_max_pool3d( + input, (1, 1, 1), (1, 1, 1), (0, 0, 0), (-3, 1, 1) + ) + class TestPoolingNNDeviceType(NNTestCase): @onlyNativeDeviceTypes diff --git a/test/onnx/exporter/test_capture_strategies.py b/test/onnx/exporter/test_capture_strategies.py new file mode 100644 index 0000000000000..c795fc21ecee7 --- /dev/null +++ b/test/onnx/exporter/test_capture_strategies.py @@ -0,0 +1,40 @@ +# Owner(s): ["module: onnx"] +"""Unit tests for the _capture_strategies module.""" + +from __future__ import annotations + +import torch +from torch.onnx._internal.exporter import _capture_strategies +from torch.testing._internal import common_utils + + +@common_utils.instantiate_parametrized_tests +class ExportStrategiesTest(common_utils.TestCase): + @common_utils.parametrize( + "strategy_cls", + [ + _capture_strategies.TorchExportStrategy, + _capture_strategies.TorchExportNonStrictStrategy, + _capture_strategies.JitTraceConvertStrategy, + ], + name_fn=lambda strategy_cls: strategy_cls.__name__, + ) + def test_jit_isinstance(self, strategy_cls): + class Model(torch.nn.Module): + def forward(self, a, b): + if torch.jit.isinstance(a, torch.Tensor): + return a.cos() + return b.sin() + + model = Model() + a = torch.tensor(0.0) + b = torch.tensor(1.0) + + result = strategy_cls()(model, (a, b), kwargs=None, dynamic_shapes=None) + ep = result.exported_program + assert ep is not None + torch.testing.assert_close(ep.module()(a, b), model(a, b)) + + +if __name__ == "__main__": + common_utils.run_tests() diff --git a/test/onnx/exporter/test_small_models_e2e.py b/test/onnx/exporter/test_small_models_e2e.py index c8f2dc223c615..cd60570329fd3 100644 --- a/test/onnx/exporter/test_small_models_e2e.py +++ b/test/onnx/exporter/test_small_models_e2e.py @@ -33,6 +33,20 @@ def forward(self, query, key, value): ) onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1) + def test_constant_complex(self): + class MulModule(torch.nn.Module): + def forward(self, x): + y = 2 + 3j + return torch.ops.aten.mul(x, y) + + # Example usage with complex inputs + x = torch.tensor( + [[1.0 + 2.0j, 3.0 + 4.0j], [5.0 + 6.0j, 7.0 + 8.0j]], dtype=torch.complex64 + ) + + onnx_program = torch.onnx.export(MulModule(), (x,), dynamo=True) + onnx_testing.assert_onnx_program(onnx_program, atol=1e-3, rtol=1) + if __name__ == "__main__": common_utils.run_tests() diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index 41a0bb4f8a860..d075c8f88f7c1 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -3,7 +3,6 @@ import logging import tempfile -from typing import Mapping, Tuple import onnx import onnx.inliner @@ -111,75 +110,6 @@ def forward(self, x): _ = dynamo_export(TopKModel(), x, export_options=self.export_options) - def test_symbolic_shape_of_values_inside_function_is_exported_as_graph_value_info( - self, - ): - class SubModule(torch.nn.Module): - def forward(self, x, y, bias): - output = x @ y - return output + bias - - class Module(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.submodule = SubModule() - - def forward(self, x, y, bias): - return self.submodule(x, y, bias) - - x = torch.randn(2, 3) - y = torch.randn(3, 4) - bias = torch.randn(4) - onnx_program = torch.onnx.dynamo_export( - Module(), - x, - y, - bias, - export_options=torch.onnx.ExportOptions(dynamic_shapes=True), - ) - model_proto = onnx_program.model_proto - - # Assert value_info for values inside local function can be retrieved - def _assert_node_outputs_has_value_info( - node: onnx.NodeProto, - value_infos: Mapping[str, onnx.ValueInfoProto], - local_functions: Mapping[Tuple[str, str], onnx.FunctionProto], - exclude_names_in_value_info, - function_id: str = "", - ): - for output in node.output: - name = f"{function_id}/{output}" if function_id else output - if name not in exclude_names_in_value_info: - self.assertIn(name, value_infos) - if node.domain.startswith("pkg.onnxscript.torch_lib"): - # No shape info available for values inside torchlib functions. - return - if ( - function := local_functions.get((node.domain, node.op_type)) - ) is not None: - for node in function.node: - function_id = f"{function.domain}::{function.name}" - _assert_node_outputs_has_value_info( - node, - value_infos, - local_functions, - exclude_names_in_value_info, - function_id, - ) - - type_infos = {vi.name: vi for vi in model_proto.graph.value_info} - functions = {(f.domain, f.name): f for f in model_proto.functions} - # NOTE: inputs, outputs, and initializers are not included in value_info spec - exclude_names_in_value_info = ( - [input.name for input in model_proto.graph.input] - + [output.name for output in model_proto.graph.output] - + [init.name for init in model_proto.graph.initializer] - ) - for node in model_proto.graph.node: - _assert_node_outputs_has_value_info( - node, type_infos, functions, exclude_names_in_value_info - ) - def test_dynamo_export_retains_readable_parameter_and_buffer_names(self): class SubModule(torch.nn.Module): def __init__(self) -> None: diff --git a/test/onnx/test_pytorch_onnx_no_runtime.py b/test/onnx/test_pytorch_onnx_no_runtime.py index 37ca3836e5387..380a208bf9881 100644 --- a/test/onnx/test_pytorch_onnx_no_runtime.py +++ b/test/onnx/test_pytorch_onnx_no_runtime.py @@ -83,7 +83,7 @@ def forward(self, x): x = torch.ones(3, 3) f = io.BytesIO() - torch.onnx.export(AddmmModel(), x, f, verbose=False) + torch.onnx.export(AddmmModel(), x, f) def test_onnx_transpose_incomplete_tensor_type(self): # Smoke test to get us into the state where we are attempting to export @@ -115,7 +115,8 @@ def foo(x): traced = torch.jit.trace(foo, (torch.rand([2]))) - torch.onnx.export_to_pretty_string(traced, (torch.rand([2]),)) + f = io.BytesIO() + torch.onnx.export(traced, (torch.rand([2]),), f) def test_onnx_export_script_module(self): class ModuleToExport(torch.jit.ScriptModule): @@ -125,7 +126,8 @@ def forward(self, x): return x + x mte = ModuleToExport() - torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False) + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) @common_utils.suppress_warnings def test_onnx_export_func_with_warnings(self): @@ -138,9 +140,8 @@ def forward(self, x): return func_with_warning(x) # no exception - torch.onnx.export_to_pretty_string( - WarningTest(), torch.randn(42), verbose=False - ) + f = io.BytesIO() + torch.onnx.export(WarningTest(), torch.randn(42), f) def test_onnx_export_script_python_fail(self): class PythonModule(torch.jit.ScriptModule): @@ -161,7 +162,7 @@ def forward(self, x): mte = ModuleToExport() f = io.BytesIO() with self.assertRaisesRegex(RuntimeError, "Couldn't export Python"): - torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f, verbose=False) + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) def test_onnx_export_script_inline_trace(self): class ModuleToInline(torch.nn.Module): @@ -179,7 +180,8 @@ def forward(self, x): return y + y mte = ModuleToExport() - torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False) + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) def test_onnx_export_script_inline_script(self): class ModuleToInline(torch.jit.ScriptModule): @@ -198,7 +200,8 @@ def forward(self, x): return y + y mte = ModuleToExport() - torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False) + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) def test_onnx_export_script_module_loop(self): class ModuleToExport(torch.jit.ScriptModule): @@ -212,7 +215,8 @@ def forward(self, x): return x mte = ModuleToExport() - torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False) + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) @common_utils.suppress_warnings def test_onnx_export_script_truediv(self): @@ -224,9 +228,8 @@ def forward(self, x): mte = ModuleToExport() - torch.onnx.export_to_pretty_string( - mte, (torch.zeros(1, 2, 3, dtype=torch.float),), verbose=False - ) + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3, dtype=torch.float),), f) def test_onnx_export_script_non_alpha_add_sub(self): class ModuleToExport(torch.jit.ScriptModule): @@ -236,7 +239,8 @@ def forward(self, x): return bs - 1 mte = ModuleToExport() - torch.onnx.export_to_pretty_string(mte, (torch.rand(3, 4),), verbose=False) + f = io.BytesIO() + torch.onnx.export(mte, (torch.rand(3, 4),), f) def test_onnx_export_script_module_if(self): class ModuleToExport(torch.jit.ScriptModule): @@ -247,7 +251,8 @@ def forward(self, x): return x mte = ModuleToExport() - torch.onnx.export_to_pretty_string(mte, (torch.zeros(1, 2, 3),), verbose=False) + f = io.BytesIO() + torch.onnx.export(mte, (torch.zeros(1, 2, 3),), f) def test_onnx_export_script_inline_params(self): class ModuleToInline(torch.jit.ScriptModule): @@ -277,7 +282,8 @@ def forward(self, x): torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4) ) self.assertEqual(result, reference) - torch.onnx.export_to_pretty_string(mte, (torch.ones(2, 3),), verbose=False) + f = io.BytesIO() + torch.onnx.export(mte, (torch.ones(2, 3),), f) def test_onnx_export_speculate(self): class Foo(torch.jit.ScriptModule): @@ -312,8 +318,10 @@ def transpose(x): f1 = Foo(transpose) f2 = Foo(linear) - torch.onnx.export_to_pretty_string(f1, (torch.ones(1, 10, dtype=torch.float),)) - torch.onnx.export_to_pretty_string(f2, (torch.ones(1, 10, dtype=torch.float),)) + f = io.BytesIO() + torch.onnx.export(f1, (torch.ones(1, 10, dtype=torch.float),), f) + f = io.BytesIO() + torch.onnx.export(f2, (torch.ones(1, 10, dtype=torch.float),), f) def test_onnx_export_shape_reshape(self): class Foo(torch.nn.Module): @@ -326,7 +334,8 @@ def forward(self, x): return reshaped foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3)) - torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3))) + f = io.BytesIO() + torch.onnx.export(foo, (torch.zeros(1, 2, 3)), f) def test_listconstruct_erasure(self): class FooMod(torch.nn.Module): @@ -334,9 +343,11 @@ def forward(self, x): mask = x < 0.0 return x[mask] - torch.onnx.export_to_pretty_string( + f = io.BytesIO() + torch.onnx.export( FooMod(), (torch.rand(3, 4),), + f, add_node_names=False, do_constant_folding=False, operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK, @@ -351,13 +362,10 @@ def forward(self, x): retval += torch.sum(x[0:i], dim=0) return retval - mod = DynamicSliceExportMod() - input = torch.rand(3, 4, 5) - torch.onnx.export_to_pretty_string( - DynamicSliceExportMod(), (input,), opset_version=10 - ) + f = io.BytesIO() + torch.onnx.export(DynamicSliceExportMod(), (input,), f, opset_version=10) def test_export_dict(self): class DictModule(torch.nn.Module): @@ -368,10 +376,12 @@ def forward(self, x_in: torch.Tensor) -> Dict[str, torch.Tensor]: mod = DictModule() mod.train(False) - torch.onnx.export_to_pretty_string(mod, (x_in,)) + f = io.BytesIO() + torch.onnx.export(mod, (x_in,), f) with self.assertRaisesRegex(RuntimeError, r"DictConstruct.+is not supported."): - torch.onnx.export_to_pretty_string(torch.jit.script(mod), (x_in,)) + f = io.BytesIO() + torch.onnx.export(torch.jit.script(mod), (x_in,), f) def test_source_range_propagation(self): class ExpandingModule(torch.nn.Module): @@ -497,11 +507,11 @@ def forward(self, box_regression: Tensor, proposals: List[Tensor]): proposal = [torch.randn(2, 4), torch.randn(2, 4)] with self.assertRaises(RuntimeError) as cm: - onnx_model = io.BytesIO() + f = io.BytesIO() torch.onnx.export( model, (box_regression, proposal), - onnx_model, + f, ) def test_initializer_sequence(self): @@ -637,7 +647,7 @@ def forward(self, x): x = torch.randn(1, 2, 3, requires_grad=True) f = io.BytesIO() - torch.onnx.export(Model(), x, f) + torch.onnx.export(Model(), (x,), f) model = onnx.load(f) model.ir_version = 0 @@ -744,7 +754,7 @@ def forward(self, x): f = io.BytesIO() with warnings.catch_warnings(record=True): - torch.onnx.export(MyDrop(), (eg,), f, verbose=False) + torch.onnx.export(MyDrop(), (eg,), f) def test_pack_padded_pad_packed_trace(self): from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence @@ -791,7 +801,7 @@ def forward(self, x, seq_lens): self.assertEqual(grad, grad_traced) f = io.BytesIO() - torch.onnx.export(m, (x, seq_lens), f, verbose=False) + torch.onnx.export(m, (x, seq_lens), f) # Suppression: ONNX warns when exporting RNNs because of potential batch size mismatch. @common_utils.suppress_warnings @@ -851,7 +861,7 @@ def forward(self, x, seq_lens): self.assertEqual(grad, grad_traced) f = io.BytesIO() - torch.onnx.export(m, (x, seq_lens), f, verbose=False) + torch.onnx.export(m, (x, seq_lens), f) def test_pushpackingpastrnn_in_peephole_create_own_gather_input(self): from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence @@ -931,7 +941,8 @@ class Mod(torch.nn.Module): def forward(self, x, w): return torch.matmul(x, w).detach() - torch.onnx.export_to_pretty_string(Mod(), (torch.rand(3, 4), torch.rand(4, 5))) + f = io.BytesIO() + torch.onnx.export(Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f) def test_aten_fallback_must_fallback(self): class ModelWithAtenNotONNXOp(torch.nn.Module): @@ -1088,12 +1099,12 @@ def sym_scatter_max(g, src, index, dim, out, dim_size): torch.onnx.register_custom_op_symbolic( "torch_scatter::scatter_max", sym_scatter_max, 1 ) + f = io.BytesIO() with torch.no_grad(): torch.onnx.export( m, (src, idx), - "mymodel.onnx", - verbose=False, + f, opset_version=13, custom_opsets={"torch_scatter": 1}, do_constant_folding=True, @@ -1176,7 +1187,7 @@ def forward(self, x): model = Net(C).cuda().half() x = torch.randn(N, C).cuda().half() f = io.BytesIO() - torch.onnx.export(model, x, f, opset_version=14) + torch.onnx.export(model, (x,), f, opset_version=14) onnx_model = onnx.load_from_string(f.getvalue()) const_node = [n for n in onnx_model.graph.node if n.op_type == "Constant"] self.assertNotEqual(len(const_node), 0) diff --git a/test/optim/test_lrscheduler.py b/test/optim/test_lrscheduler.py index 6b8dcbe05795e..316d639a6b5d5 100644 --- a/test/optim/test_lrscheduler.py +++ b/test/optim/test_lrscheduler.py @@ -759,6 +759,32 @@ def test_sequentiallr4(self): # Ensure that multiple schedulers does not affect the initial learning rate self.assertEqual(prev_lr, new_lr) + def test_sequentiallr5(self): + """ + Test SequentialLR with a ChainedScheduler. + """ + epochs = 10 + schedulers = [] + milestones = [] + + targets = [ + [0.0005, 0.0014, 0.0023, 0.0032, 0.0041] + + [0.025, 0.025, 0.025, 0.025, 0.025] + ] + + const_sched = ConstantLR(optimizer=self.opt, factor=0.1, total_iters=5) + lin_sched = LinearLR(optimizer=self.opt, start_factor=0.1, total_iters=5) + milestones.append(5) + + chained = ChainedScheduler([lin_sched, const_sched]) + schedulers.append(chained) + + const_sched2 = ConstantLR(optimizer=self.opt, factor=0.5, total_iters=5) + schedulers.append(const_sched2) + + scheduler = SequentialLR(self.opt, schedulers=schedulers, milestones=milestones) + self._test(scheduler, targets, epochs) + def test_get_last_lr_sequentiallr(self): epochs = 12 milestones = [3, 6] @@ -2405,6 +2431,60 @@ def test_lr_scheduler_state_dict_load(self, LRClass, weights_only): scheduler2.load_state_dict(state_dict_loaded) self.assertEqual(scheduler2.state_dict(), state_dict) + @parametrize("min_lr", ["scalar", "list"]) + def test_add_param_group_does_not_break_reduce_lr_on_plateau(self, min_lr): + epochs = 20 + for param_group in self.opt.param_groups: + param_group["lr"] = 0.5 + targets = [[0.5] * 6 + [0.05] * (5 + 6) + [0.005] * 4] + metrics = [1] * 7 + [0.6] + [0.5] * 12 + scheduler = ReduceLROnPlateau( + self.opt, + mode="min", + threshold_mode="rel", + threshold=0.1, + patience=5, + cooldown=5, + min_lr=0 if min_lr == "scalar" else [1e-5, 1e-4], + ) + for epoch in range(epochs): + # Point is to test the use case in #104361 + if epoch == 8: + param = torch.nn.Parameter(torch.rand(2, 3)) + self.opt.add_param_group({"params": [param], "lr": 0.05}) + if min_lr == "list": + scheduler.min_lrs.append(1e-6) + self.opt.step() + scheduler.step(metrics[epoch]) + for param_group, target in zip(self.opt.param_groups, targets): + self.assertEqual( + target[epoch], + param_group["lr"], + msg="LR is wrong in epoch {}: expected {}, got {}".format( + epoch, target[epoch], param_group["lr"] + ), + atol=1e-5, + rtol=0, + ) + + def test_add_param_group_errors_reduce_lr_on_plateau(self): + scheduler = ReduceLROnPlateau( + self.opt, + mode="min", + threshold_mode="rel", + threshold=1e-5, + patience=0, + cooldown=0, + min_lr=[1e-5, 1e-4], + ) + param = torch.nn.Parameter(torch.rand(2, 3)) + self.opt.add_param_group({"params": [param], "lr": 0.05}) + self.opt.step() + scheduler.step(1) + with self.assertRaisesRegex(RuntimeError, "The number of param groups in the"): + self.opt.step() + scheduler.step(1.3) + @parametrize( "LRClass", [ diff --git a/test/profiler/test_execution_trace.py b/test/profiler/test_execution_trace.py index 2514ecd71c1fd..da52d17845c63 100644 --- a/test/profiler/test_execution_trace.py +++ b/test/profiler/test_execution_trace.py @@ -34,10 +34,13 @@ supported_activities, ) from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_device_type import instantiate_device_type_tests from torch.testing._internal.common_utils import ( IS_WINDOWS, run_tests, + skipIfHpu, skipIfTorchDynamo, + TEST_HPU, TestCase, ) from torch.utils._triton import has_triton @@ -47,7 +50,7 @@ class TestExecutionTrace(TestCase): - def payload(self, use_cuda=False): + def payload(self, device, use_device=False): u = torch.randn(3, 4, 5, requires_grad=True) with record_function("## TEST 1 ##", "1, 2, 3"): inf_val = float("inf") @@ -67,17 +70,17 @@ def payload(self, use_cuda=False): nan_val, ) x = torch.randn(10, 10, requires_grad=True) - if use_cuda: - x = x.cuda() + if use_device: + x = x.to(device) y = torch.randn(10, 10, requires_grad=True) - if use_cuda: - y = y.cuda() + if use_device: + y = y.to(device) z = x + y + x * y + x * y z.backward(z) gelu = nn.GELU() m = torch.randn(2) _ = gelu(m) - if use_cuda: + if use_device: z = z.cpu() _record_function_with_args_exit(rf_handle) @@ -117,14 +120,19 @@ def get_kineto_rf_ids(self, events: List[Json]) -> List[int]: ) @unittest.skipIf(not kineto_available(), "Kineto is required") - def test_execution_trace_with_kineto(self): + @skipIfHpu + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") + def test_execution_trace_with_kineto(self, device): trace_called_num = 0 def trace_handler(p): nonlocal trace_called_num trace_called_num += 1 - use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() + use_device = ( + torch.profiler.ProfilerActivity.CUDA + or torch.profiler.ProfilerActivity.HPU in supported_activities() + ) # Create a temp file to save execution trace and kineto data. fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() @@ -145,7 +153,7 @@ def trace_handler(p): ) as p: for idx in range(10): with record_function(f"## LOOP {idx} ##"): - self.payload(use_cuda=use_cuda) + self.payload(device, use_device=use_device) p.step() self.assertEqual(fp.name, p.execution_trace_observer.get_output_file_path()) @@ -190,8 +198,11 @@ def trace_handler(p): f" rf_ids_kineto = {rf_ids_kineto}\n", ) - def test_execution_trace_alone(self): - use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() + def test_execution_trace_alone(self, device): + use_device = ( + torch.profiler.ProfilerActivity.CUDA + or torch.profiler.ProfilerActivity.HPU in supported_activities() + ) # Create a temp file to save execution trace data. fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() @@ -203,7 +214,7 @@ def test_execution_trace_alone(self): for idx in range(5): expected_loop_events += 1 with record_function(f"## LOOP {idx} ##"): - self.payload(use_cuda=use_cuda) + self.payload(device, use_device=use_device) et.stop() assert fp.name == et.get_output_file_path() @@ -231,14 +242,15 @@ def test_execution_trace_alone(self): sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+" ) @unittest.skipIf(not TEST_CUDA or not has_triton(), "need CUDA and triton to run") - def test_execution_trace_with_pt2(self): + @skipIfHpu + def test_execution_trace_with_pt2(self, device): @torchdynamo.optimize("inductor") def fn(a, b, c): x = torch.nn.functional.linear(a, b) x = x + c return x.cos() - a, b, c = (torch.randn(4, 4, requires_grad=True).to("cuda") for _ in range(3)) + a, b, c = (torch.randn(4, 4, requires_grad=True).to(device) for _ in range(3)) inputs = [a, b, c] with torch._inductor.config.patch(compile_threads=1): @@ -275,8 +287,11 @@ def fn(a, b, c): assert len(n["outputs"]["values"]) == 0 assert found_captured_triton_kernel_node - def test_execution_trace_start_stop(self): - use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() + def test_execution_trace_start_stop(self, device): + use_device = ( + torch.profiler.ProfilerActivity.CUDA + or torch.profiler.ProfilerActivity.HPU in supported_activities() + ) # Create a temp file to save execution trace data. fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() @@ -294,7 +309,7 @@ def test_execution_trace_start_stop(self): if et._execution_trace_running: expected_loop_events += 1 with record_function(f"## LOOP {idx} ##"): - self.payload(use_cuda=use_cuda) + self.payload(device, use_device=use_device) assert fp.name == et.get_output_file_path() et.unregister_callback() @@ -310,8 +325,11 @@ def test_execution_trace_start_stop(self): assert found_root_node assert loop_count == expected_loop_events - def test_execution_trace_repeat_in_loop(self): - use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities() + def test_execution_trace_repeat_in_loop(self, device): + use_device = ( + torch.profiler.ProfilerActivity.CUDA + or torch.profiler.ProfilerActivity.HPU in supported_activities() + ) iter_list = {3, 4, 6, 8} expected_loop_events = len(iter_list) output_files = [] @@ -324,7 +342,7 @@ def test_execution_trace_repeat_in_loop(self): et = ExecutionTraceObserver().register_callback(fp.name) et.start() with record_function(f"## LOOP {idx} ##"): - self.payload(use_cuda=use_cuda) + self.payload(device, use_device=use_device) if idx in iter_list: et.stop() et.unregister_callback() @@ -343,7 +361,8 @@ def test_execution_trace_repeat_in_loop(self): assert found_root_node assert event_count == expected_loop_events - def test_execution_trace_no_capture(self): + @skipIfHpu + def test_execution_trace_no_capture(self, device): fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() et = ExecutionTraceObserver().register_callback(fp.name) @@ -358,7 +377,8 @@ def test_execution_trace_no_capture(self): assert found_root_node @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/124500") - def test_execution_trace_nested_tensor(self): + @skipIfHpu + def test_execution_trace_nested_tensor(self, device): fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False) fp.close() @@ -383,5 +403,10 @@ def fn(nt): assert found_cos +devices = ["cpu", "cuda"] +if TEST_HPU: + devices.append("hpu") +instantiate_device_type_tests(TestExecutionTrace, globals(), only_for=devices) + if __name__ == "__main__": run_tests() diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index f4f4e2e99270a..ba9cbd79bb817 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -337,6 +337,7 @@ def extract(pattern: str): ) @serialTest() @parametrize("work_in_main_thread", [True, False]) + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_source_multithreaded(self, name, thread_spec, work_in_main_thread): """Test various threading configurations. @@ -1452,6 +1453,7 @@ def test_nested_tensor_with_shapes(self): @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"}) @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"}) + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") def test_kineto_profiler_with_environment_variable(self): script = """ import torch @@ -1957,6 +1959,9 @@ def test_cpu_annotation_overlap(self): record_shapes=True, with_stack=True, schedule=torch.profiler.schedule(wait=0, warmup=0, active=5, repeat=1), + experimental_config=torch._C._profiler._ExperimentalConfig( + adjust_profiler_step=True + ), ) as prof: for i in range(5): self._step_helper_func(prof) @@ -2042,6 +2047,40 @@ def test_lazy_build_tree(self): self.assertGreater(stats.function_events_build_tree_call_duration_us, 0) self.assertGreater(stats.number_of_events, 0) + @skipIfTorchDynamo("profiler gets ignored if dynamo activated") + @unittest.skipIf( + torch.cuda.is_available(), "CUDA complains about forking after init" + ) + @unittest.skipIf(IS_WINDOWS, "can't use os.fork() on Windows") + def test_forked_process(self): + # Induce a pid cache by running the profiler with payload + def validate_forked_json(profiler): + nonlocal cpu_op_found, parent_tid, child_pid + with TemporaryFileName(mode="w+") as fname: + profiler.export_chrome_trace(fname) + with open(fname) as f: + events = json.load(f)["traceEvents"] + for event in events: + if "cat" in event and event["cat"] == "cpu_op": + self.assertEqual(event["pid"], child_pid) + self.assertNotEqual(event["tid"], parent_tid) + cpu_op_found = True + + cpu_op_found = False + parent_tid = threading.current_thread().ident + with profile() as p: + self.payload() + pid = os.fork() + if pid == 0: + child_pid = os.getpid() + with profile() as p: + self.payload() + validate_forked_json(p) + self.assertTrue(cpu_op_found) + os._exit(0) + else: + os.waitpid(pid, 0) + class SimpleNet(nn.Module): def __init__(self) -> None: diff --git a/test/quantization/core/test_docs.py b/test/quantization/core/test_docs.py index 6e5a7cc18d923..6462366992457 100644 --- a/test/quantization/core/test_docs.py +++ b/test/quantization/core/test_docs.py @@ -6,15 +6,16 @@ import torch -# import torch.ao.nn.quantized as nnq from torch.testing._internal.common_quantization import ( QuantizationTestCase, SingleLayerLinearModel, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import IS_ARM64 +from torch.testing._internal.common_utils import IS_ARM64, IS_FBCODE +import unittest +@unittest.skipIf(IS_FBCODE, "some path issues in fbcode") class TestQuantizationDocs(QuantizationTestCase): r""" The tests in this section import code from the quantization docs and check that diff --git a/test/quantization/core/test_utils.py b/test/quantization/core/test_utils.py index 6024fe29eaefb..e4a3d3079c4ec 100644 --- a/test/quantization/core/test_utils.py +++ b/test/quantization/core/test_utils.py @@ -192,30 +192,31 @@ def test_quantize_weight_clamping_per_channel(self): assert quantized_tensor.int_repr().max().item() == q8_max assert quantized_tensor.int_repr().min().item() == q8_min - def test_uint1_7_dtype(self): + def test_uint4_int4_dtype(self): def up_size(size): return (*size[:-1], size[-1] * 2) - class UInt4Tensor(torch.Tensor): - @staticmethod - def __new__(cls, elem, **kwargs): - assert elem.dtype is torch.uint8 - assert not kwargs.get("requires_grad", False) - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=torch.uint4, **kwargs) - - def __init__(self, elem): - self.elem = elem - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): - pass - - # make sure it runs - x = UInt4Tensor(torch.tensor([ - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], - ], dtype=torch.uint8)) - assert x.dtype == torch.uint4 + for dtype in [torch.uint4, torch.int4]: + class UInt4OrInt4Tensor(torch.Tensor): + @staticmethod + def __new__(cls, elem, **kwargs): + assert elem.dtype is torch.uint8 + assert not kwargs.get("requires_grad", False) + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, up_size(elem.shape), dtype=dtype, **kwargs) + + def __init__(self, elem): + self.elem = elem + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + pass + + # make sure it runs + x = UInt4OrInt4Tensor(torch.tensor([ + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF], + ], dtype=torch.uint8)) + assert x.dtype == dtype diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index 6958b0e277359..8e3b5fa1cb44d 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -1,7 +1,6 @@ # Owner(s): ["oncall: quantization"] import copy import itertools -import sys from enum import Enum import torch @@ -25,12 +24,7 @@ skipIfNoX86, ) from torch.testing._internal.common_quantized import override_quantized_engine -from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, skipIfTorchDynamo - - -if IS_WINDOWS and IS_CI: - sys.stderr.write("Windows CI still has some issue to be fixed.\n") - sys.exit(0) +from torch.testing._internal.common_utils import skipIfTorchDynamo class NodePosType(Enum): diff --git a/test/run_test.py b/test/run_test.py index b2b04cfdad54e..10be33fcfadbb 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -70,6 +70,7 @@ ShardedTest, THRESHOLD, ) +from tools.testing.upload_artifacts import zip_and_upload_artifacts # Make sure to remove REPO_ROOT after import is done @@ -170,9 +171,6 @@ def __contains__(self, item): "distributed/_shard/checkpoint/test_checkpoint" "distributed/_shard/checkpoint/test_file_system_checkpoint" "distributed/_shard/sharding_spec/test_sharding_spec", - "distributed/_shard/sharding_plan/test_sharding_plan", - "distributed/_shard/sharded_tensor/test_sharded_tensor", - "distributed/_shard/sharded_tensor/test_sharded_tensor_reshard", "distributed/_shard/sharded_tensor/ops/test_embedding", "distributed/_shard/sharded_tensor/ops/test_embedding_bag", "distributed/_shard/sharded_tensor/ops/test_binary_cmp", @@ -220,6 +218,7 @@ def __contains__(self, item): "test_cuda_nvml_based_avail", # temporarily sets a global config "test_autograd_fallback", + "inductor/test_compiler_bisector", ] + FSDP_TEST # Test files that should always be run serially with other test files, @@ -289,23 +288,26 @@ def __contains__(self, item): } if dist.is_nccl_available(): DISTRIBUTED_TESTS_CONFIG["nccl"] = { - "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3", + "WORLD_SIZE": f"{torch.cuda.device_count()}", "TEST_REPORT_SOURCE_OVERRIDE": "dist-nccl", } if dist.is_gloo_available(): DISTRIBUTED_TESTS_CONFIG["gloo"] = { - "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3", + # TODO: retire testing gloo with CUDA + "WORLD_SIZE": f"{torch.cuda.device_count()}", "TEST_REPORT_SOURCE_OVERRIDE": "dist-gloo", } - if dist.is_ucc_available(): - DISTRIBUTED_TESTS_CONFIG["ucc"] = { - "WORLD_SIZE": "2" if torch.cuda.device_count() == 2 else "3", - "TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc", - "UCX_TLS": "tcp,cuda", - "UCC_TLS": "nccl,ucp,cuda", - "UCC_TL_UCP_TUNE": "cuda:0", # don't use UCP TL on CUDA as it is not well supported - "UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH": "n", # CI nodes (M60) fail if it is on - } + # Test with UCC backend is deprecated. + # See https://github.com/pytorch/pytorch/pull/137161 + # if dist.is_ucc_available(): + # DISTRIBUTED_TESTS_CONFIG["ucc"] = { + # "WORLD_SIZE": f"{torch.cuda.device_count()}", + # "TEST_REPORT_SOURCE_OVERRIDE": "dist-ucc", + # "UCX_TLS": "tcp,cuda", + # "UCC_TLS": "nccl,ucp,cuda", + # "UCC_TL_UCP_TUNE": "cuda:0", # don't use UCP TL on CUDA as it is not well supported + # "UCC_EC_CUDA_USE_COOPERATIVE_LAUNCH": "n", # CI nodes (M60) fail if it is on + # } # https://stackoverflow.com/questions/2549939/get-signal-names-from-numbers-in-python SIGNALS_TO_NAMES_DICT = { @@ -1330,6 +1332,10 @@ def parse_args(): action="store_false", help="Run tests without translation validation.", ) + parser.add_argument( + "--upload-artifacts-while-running", + action="store_true", + ) group = parser.add_mutually_exclusive_group() group.add_argument( @@ -1676,6 +1682,8 @@ def handle_error_messages(failure: Optional[TestFailure]): def parallel_test_completion_callback(failure): test_failed = handle_error_messages(failure) + if IS_CI and options.upload_artifacts_while_running: + zip_and_upload_artifacts(test_failed) if ( test_failed and not options.continue_through_error @@ -1768,6 +1776,8 @@ def main(): selected_tests = get_selected_tests(options) test_prioritizations = import_results() + if len(test_prioritizations.get_all_tests()) == 0: + options.enable_td = False test_prioritizations.amend_tests(selected_tests) os.makedirs(REPO_ROOT / "test" / "test-reports", exist_ok=True) diff --git a/test/slow_tests.json b/test/slow_tests.json index b198bdcf2d428..5ced052f14e15 100644 --- a/test/slow_tests.json +++ b/test/slow_tests.json @@ -1,302 +1,283 @@ { - "test_AllenaiLongformerBase_repro_cpu (__main__.CpuHalideTests)": 211.949, - "test_adaptive_max_pool2d1_cpu (__main__.CpuHalideTests)": 111.929, - "test_alexnet_prefix_cpu (__main__.CpuHalideTests)": 185.141, - "test_aot_autograd_symbolic_exhaustive_linalg_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 67.44693333333333, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 103.4952, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 215.06906666666666, - "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 126.95360000000001, - "test_aot_autograd_symbolic_exhaustive_svd_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 69.75275, - "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 110.57966666666667, - "test_aot_export_joint_simple_repro_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 345.20975000000004, - "test_aoti_eager_override_registration_cpu (__main__.CpuTests)": 81.80724000000001, - "test_aoti_eager_override_registration_cuda (__main__.GPUTests)": 81.5502857142857, - "test_aoti_eager_override_registration_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 80.72995238095237, - "test_aoti_eager_override_registration_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 79.88047619047619, - "test_aoti_eager_override_registration_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 75.02325, - "test_aoti_eager_override_registration_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 74.38550000000001, - "test_aoti_eager_with_scalar_cpu (__main__.CpuTests)": 87.54754166666667, - "test_aoti_eager_with_scalar_cuda (__main__.GPUTests)": 85.06014285714285, - "test_aoti_eager_with_scalar_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 87.73019047619047, - "test_aoti_eager_with_scalar_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 87.14119047619049, - "test_aoti_eager_with_scalar_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 83.3455, - "test_aoti_eager_with_scalar_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 82.4865, - "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 94.67914285714285, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 70.80028571428572, - "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 66.188125, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 66.239, - "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 135.30899999999997, - "test_avg_pool3d_backward_cpu (__main__.CpuHalideTests)": 61.719, - "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 78.70693333333334, - "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 96.261, - "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 237.76485714285712, - "test_basic_cuda (__main__.EfficientConvBNEvalCudaTests)": 92.503, - "test_captured_score_mod_aot_eager_gradcheck_score_mod_name__head_offset_mode_eager (__main__.TestFlexAttention)": 139.8705, - "test_checkpoint_cast (__main__.TestFxToOnnx)": 136.983, - "test_comprehensive_constant_pad_nd_cpu_float16 (__main__.TestInductorOpInfoCPU)": 60.31935294117646, - "test_comprehensive_diff_cpu_bool (__main__.TestInductorOpInfoCPU)": 92.7407, - "test_comprehensive_diff_cpu_float32 (__main__.TestInductorOpInfoCPU)": 92.67049999999999, - "test_comprehensive_diff_cpu_float64 (__main__.TestInductorOpInfoCPU)": 91.261, - "test_comprehensive_diff_cpu_int32 (__main__.TestInductorOpInfoCPU)": 92.00640000000001, - "test_comprehensive_diff_cpu_int64 (__main__.TestInductorOpInfoCPU)": 88.7649, - "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 120.986, - "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 105.15944444444446, - "test_comprehensive_diff_cuda_float32 (__main__.TestDecompCUDA)": 87.45349999999999, - "test_comprehensive_diff_cuda_float64 (__main__.TestDecompCUDA)": 88.94550000000001, - "test_comprehensive_dist_cpu_float16 (__main__.TestInductorOpInfoCPU)": 72.8767, - "test_comprehensive_dist_cpu_float32 (__main__.TestInductorOpInfoCPU)": 71.64869999999999, - "test_comprehensive_dist_cpu_float64 (__main__.TestInductorOpInfoCPU)": 70.62299999999999, - "test_comprehensive_eye_cpu_bool (__main__.TestInductorOpInfoCPU)": 112.79639999999999, - "test_comprehensive_eye_cpu_float16 (__main__.TestInductorOpInfoCPU)": 110.69359999999999, - "test_comprehensive_eye_cpu_float32 (__main__.TestInductorOpInfoCPU)": 111.8332, - "test_comprehensive_eye_cpu_float64 (__main__.TestInductorOpInfoCPU)": 113.01580000000001, - "test_comprehensive_eye_cpu_int32 (__main__.TestInductorOpInfoCPU)": 110.6647, - "test_comprehensive_eye_cpu_int64 (__main__.TestInductorOpInfoCPU)": 113.61270000000002, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 337.5013, - "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 112.65060000000001, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 352.82779999999997, - "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 67.2527, - "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 277.8468888888889, - "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 261.31533333333334, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 1217.510111111111, - "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 73.16566666666667, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 1187.5324999999998, - "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 81.23666666666666, - "test_comprehensive_linalg_svd_cuda_complex128 (__main__.TestDecompCUDA)": 83.36449999999999, - "test_comprehensive_linalg_svd_cuda_complex64 (__main__.TestDecompCUDA)": 85.197, - "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 176.8523, - "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 176.5644, - "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 176.33440000000002, - "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 379.27200000000005, - "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 382.0692, - "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 365.48, - "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 366.99120000000005, - "test_comprehensive_masked_amax_cpu_float16 (__main__.TestInductorOpInfoCPU)": 85.73589999999999, - "test_comprehensive_masked_amax_cpu_float32 (__main__.TestInductorOpInfoCPU)": 84.76559999999999, - "test_comprehensive_masked_amax_cpu_float64 (__main__.TestInductorOpInfoCPU)": 83.74539999999999, - "test_comprehensive_masked_amax_cpu_int32 (__main__.TestInductorOpInfoCPU)": 81.6752, - "test_comprehensive_masked_amax_cpu_int64 (__main__.TestInductorOpInfoCPU)": 80.1269, - "test_comprehensive_masked_amin_cpu_float16 (__main__.TestInductorOpInfoCPU)": 85.1681, - "test_comprehensive_masked_amin_cpu_float32 (__main__.TestInductorOpInfoCPU)": 87.01599999999999, - "test_comprehensive_masked_amin_cpu_float64 (__main__.TestInductorOpInfoCPU)": 85.30009999999999, - "test_comprehensive_masked_amin_cpu_int32 (__main__.TestInductorOpInfoCPU)": 81.06280000000001, - "test_comprehensive_masked_amin_cpu_int64 (__main__.TestInductorOpInfoCPU)": 84.49640000000001, - "test_comprehensive_masked_mean_cpu_bool (__main__.TestInductorOpInfoCPU)": 82.6498, - "test_comprehensive_masked_mean_cpu_float16 (__main__.TestInductorOpInfoCPU)": 85.0721, - "test_comprehensive_masked_mean_cpu_float32 (__main__.TestInductorOpInfoCPU)": 86.45490000000002, - "test_comprehensive_masked_mean_cpu_float64 (__main__.TestInductorOpInfoCPU)": 84.9486, - "test_comprehensive_masked_mean_cpu_int32 (__main__.TestInductorOpInfoCPU)": 85.1464, - "test_comprehensive_masked_mean_cpu_int64 (__main__.TestInductorOpInfoCPU)": 83.1313, - "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 422.03270000000003, - "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 419.49539999999996, - "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 409.55060000000003, - "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 93.67716666666666, - "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 88.622, - "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 91.79166666666667, - "test_comprehensive_masked_prod_cpu_bool (__main__.TestInductorOpInfoCPU)": 81.3001, - "test_comprehensive_masked_prod_cpu_float16 (__main__.TestInductorOpInfoCPU)": 86.5596, - "test_comprehensive_masked_prod_cpu_float32 (__main__.TestInductorOpInfoCPU)": 85.2926, - "test_comprehensive_masked_prod_cpu_float64 (__main__.TestInductorOpInfoCPU)": 84.71660000000001, - "test_comprehensive_masked_prod_cpu_int32 (__main__.TestInductorOpInfoCPU)": 84.0162, - "test_comprehensive_masked_prod_cpu_int64 (__main__.TestInductorOpInfoCPU)": 81.37209999999999, - "test_comprehensive_masked_sum_cpu_bool (__main__.TestInductorOpInfoCPU)": 81.57050000000001, - "test_comprehensive_masked_sum_cpu_float16 (__main__.TestInductorOpInfoCPU)": 82.18870000000001, - "test_comprehensive_masked_sum_cpu_float32 (__main__.TestInductorOpInfoCPU)": 82.77929999999999, - "test_comprehensive_masked_sum_cpu_float64 (__main__.TestInductorOpInfoCPU)": 81.9615, - "test_comprehensive_masked_sum_cpu_int32 (__main__.TestInductorOpInfoCPU)": 82.8871, - "test_comprehensive_masked_sum_cpu_int64 (__main__.TestInductorOpInfoCPU)": 83.2116, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex128 (__main__.TestDecompCUDA)": 62.840444444444444, - "test_comprehensive_nn_functional_conv_transpose3d_cuda_complex64 (__main__.TestDecompCUDA)": 63.12155555555556, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 115.99399999999999, - "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 112.42922222222222, - "test_comprehensive_nn_functional_glu_cpu_float16 (__main__.TestInductorOpInfoCPU)": 66.2836, - "test_comprehensive_nn_functional_glu_cpu_float32 (__main__.TestInductorOpInfoCPU)": 63.87760000000001, - "test_comprehensive_nn_functional_glu_cpu_float64 (__main__.TestInductorOpInfoCPU)": 61.07164705882354, - "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 93.46609090909091, - "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 92.66881818181818, - "test_comprehensive_nn_functional_grid_sample_cuda_bfloat16 (__main__.TestDecompCUDA)": 72.35, - "test_comprehensive_nn_functional_grid_sample_cuda_float16 (__main__.TestDecompCUDA)": 64.90466666666666, - "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 265.2443333333333, - "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 250.08033333333333, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 (__main__.TestDecompCUDA)": 61.85044444444444, - "test_comprehensive_nn_functional_interpolate_bicubic_cuda_float64 (__main__.TestDecompCUDA)": 63.002444444444436, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 102.0025, - "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 104.59100000000001, - "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 152.2596, - "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 140.01214285714286, - "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 140.58085714285716, - "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 710.7855714285714, - "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 697.3474285714285, - "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 678.1218571428572, - "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 641.9231428571428, - "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 655.1732857142857, - "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 775.9625, - "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 767.9121666666666, - "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 778.5028333333333, - "test_comprehensive_nn_functional_pad_constant_cpu_float16 (__main__.TestInductorOpInfoCPU)": 60.259235294117644, - "test_comprehensive_nn_functional_pad_constant_cpu_float32 (__main__.TestInductorOpInfoCPU)": 60.22264705882352, - "test_comprehensive_nn_functional_pad_constant_cpu_float64 (__main__.TestInductorOpInfoCPU)": 60.483411764705885, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float16 (__main__.TestInductorOpInfoCPU)": 94.4827142857143, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestInductorOpInfoCPU)": 96.45214285714285, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float64 (__main__.TestInductorOpInfoCPU)": 91.70985714285715, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int32 (__main__.TestInductorOpInfoCPU)": 95.28557142857143, - "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int64 (__main__.TestInductorOpInfoCPU)": 92.5167142857143, - "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 189.38628571428572, - "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 183.38171428571428, - "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 184.58571428571423, - "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 137.61211111111112, - "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 139.59522222222222, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestDecompCUDA)": 88.364, - "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 83.6305, - "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 87.9585, - "test_comprehensive_pca_lowrank_cuda_complex128 (__main__.TestDecompCUDA)": 68.5215, - "test_comprehensive_pca_lowrank_cuda_complex64 (__main__.TestDecompCUDA)": 62.06933333333333, - "test_comprehensive_svd_cuda_complex128 (__main__.TestDecompCUDA)": 94.2525, - "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 94.307, - "test_comprehensive_svd_lowrank_cuda_complex128 (__main__.TestDecompCUDA)": 68.14222222222222, - "test_comprehensive_svd_lowrank_cuda_complex64 (__main__.TestDecompCUDA)": 72.507, - "test_cond_autograd_nested (__main__.TestControlFlow)": 162.97220000000002, - "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 110.65333333333334, - "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 119.97566666666665, - "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 86.57166666666667, - "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 88.60611111111112, - "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 93.19026470588237, - "test_conv2d_binary_inplace_fusion_failed_cpu_cpp_wrapper (__main__.TestCppWrapper)": 75.7442, - "test_conv3d_binary_dynamic_shapes (__main__.TestDynamicPatternMatcher)": 113.2255238095238, - "test_conv3d_unary_dynamic_shapes (__main__.TestDynamicPatternMatcher)": 71.51290476190476, - "test_conv_freezing_non_abi_compatible_cuda (__main__.AOTInductorTestNonABICompatibleCuda)": 73.27433333333333, - "test_conv_transpose2d_packed_cpu_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 62.88960000000001, - "test_correctness_NAdam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 66.9665, - "test_cusparse_multiple_threads_same_device (__main__.TestCuda)": 93.28726315789474, - "test_custom_module_lstm (__main__.TestQuantizedOps)": 81.82516, - "test_ddp_model_diff_shape_across_ranks (__main__.TestDistBackendWithSpawn)": 91.468, - "test_ddp_uneven_inputs (__main__.TestDistBackendWithSpawn)": 538.063, - "test_diff_hyperparams_sharding_strategy_str_full_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 70.84433333333332, - "test_diff_hyperparams_sharding_strategy_str_no_shard (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 69.47433333333333, - "test_diff_hyperparams_sharding_strategy_str_shard_grad_op (__main__.TestFSDPUseOrigParamsMultipleParamGroups)": 67.11266666666667, - "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 113.79488888888889, - "test_dtypeview_cpu (__main__.CpuTests)": 84.04108000000001, - "test_dtypeview_cuda_cuda_wrapper (__main__.TestCudaWrapper)": 279.51483333333334, - "test_dtypeview_cuda_dynamic_shapes_cuda_wrapper (__main__.DynamicShapesCudaWrapperCudaTests)": 283.54316666666665, - "test_dtypeview_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 96.16909523809525, - "test_dtypeview_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 97.33604761904763, - "test_fail_creation_ops.py (__main__.TestTyping)": 65.19961538461537, - "test_fail_random.py (__main__.TestTyping)": 63.64657575757575, - "test_fake_tensor_mode_huggingface_databricks_dolly_v2_3b (__main__.TORCH_EXPORT_EXPORTEDPROGRAM)": 96.72200000000001, - "test_fake_tensor_mode_huggingface_google_t5 (__main__.TORCH_EXPORT_EXPORTEDPROGRAM)": 108.259, - "test_fake_tensor_mode_huggingface_google_t5 (__main__.TORCH_NN_MODULE)": 209.738, - "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 100.12588888888891, - "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 108.84245454545454, - "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 81.5349411764706, - "test_fn_gradgrad_map_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 78.576, - "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 514.4126666666667, - "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 398.1390909090909, - "test_fn_gradgrad_ormqr_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 70.3155, - "test_fuse_large_params_cpu (__main__.CpuTests)": 101.38141666666665, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 142.97099999999998, - "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 144.0665238095238, - "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 73.9485, - "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 97.20400000000001, - "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 230.74563636363638, - "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 130.09063636363638, - "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 152.54545454545453, - "test_grid_sampler_2d_cpu (__main__.CpuHalideTests)": 191.693, - "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 105.473, - "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 136.84644444444444, - "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 382.9056, - "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 403.69230000000005, - "test_linear_dynamic_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_False_input_3d_False_cpu_float16 (__main__.TestSelectAlgorithmDynamicShapesCPU)": 66.2571, - "test_linear_dynamic_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_False_input_3d_True_cpu_float16 (__main__.TestSelectAlgorithmDynamicShapesCPU)": 131.267, - "test_linear_dynamic_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_True_input_3d_False_cpu_float16 (__main__.TestSelectAlgorithmDynamicShapesCPU)": 69.4666, - "test_linear_dynamic_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_True_input_3d_True_cpu_float16 (__main__.TestSelectAlgorithmDynamicShapesCPU)": 134.5207, - "test_linear_packed_cpp_wrapper (__main__.TestCppWrapper)": 198.969, - "test_linear_packed_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 206.356, - "test_linear_static_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_False_input_3d_False_cpu_float16 (__main__.TestSelectAlgorithmCPU)": 67.953, - "test_linear_static_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_False_input_3d_True_cpu_float16 (__main__.TestSelectAlgorithmCPU)": 128.8496, - "test_linear_static_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_True_input_3d_False_cpu_float16 (__main__.TestSelectAlgorithmCPU)": 71.0599, - "test_linear_static_shapes_batch_size_1000_in_features_1000_out_features_1024_bias_True_input_3d_True_cpu_float16 (__main__.TestSelectAlgorithmCPU)": 132.49540000000005, - "test_lstm_cpu (__main__.TestMkldnnCPU)": 94.47895454545456, - "test_lstm_packed_change_input_sizes_cpu_cpp_wrapper (__main__.TestCppWrapper)": 62.3839, - "test_lstm_packed_change_input_sizes_cpu_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 60.9987, - "test_max_autotune_cutlass_backend_addmm_dynamic_False_max_autotune_gemm_backends_ATen,Triton,CUTLASS (__main__.TestCutlassBackend)": 81.91425, - "test_missing_cubin_non_abi_compatible_cuda (__main__.AOTInductorTestNonABICompatibleCuda)": 76.22216666666667, - "test_pipeline_order_flex_and_zero_bubble_ScheduleClass0 (__main__.TestSchedulePlan)": 76.18842857142859, - "test_python_ref__refs_special_zeta_cuda_float64 (__main__.TestCommonCUDA)": 64.07560000000001, - "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 111.9, - "test_python_ref_torch_fallback__refs_special_zeta_cuda_float64 (__main__.TestCommonCUDA)": 62.56008333333333, - "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 165.02692, - "test_qat_conv_bn_fusion_cuda (__main__.TestQuantizePT2EQAT_ConvBn1d)": 64.07754545454546, - "test_qat_conv_bn_fusion_cuda (__main__.TestQuantizePT2EQAT_ConvBn2d)": 63.76154545454545, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 75.56493617021276, - "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 74.55095744680851, - "test_qat_conv_bn_relu_fusion_cuda (__main__.TestQuantizePT2EQAT_ConvBn1d)": 62.93418181818183, - "test_qat_conv_bn_relu_fusion_cuda (__main__.TestQuantizePT2EQAT_ConvBn2d)": 64.67954545454546, - "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 206.24339393939394, - "test_qat_resnet18 (__main__.TestQuantizePT2EQATModels)": 68.29046153846154, - "test_qlinear_add_cpu (__main__.TestPatternMatcher)": 70.38814285714287, - "test_qlinear_add_cpu_cpp_wrapper (__main__.TestCppWrapper)": 504.9821999999999, - "test_qlinear_add_cpu_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 528.0255, - "test_qlinear_add_int8_mixed_bf16 (__main__.TestPatternMatcher)": 153.60283333333334, - "test_qlinear_add_relu_cpu (__main__.TestPatternMatcher)": 71.7899523809524, - "test_qlinear_add_relu_cpu_cpp_wrapper (__main__.TestCppWrapper)": 512.7648999999999, - "test_qlinear_add_relu_cpu_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 530.9839, - "test_qlinear_add_relu_int8_mixed_bf16 (__main__.TestPatternMatcher)": 157.79833333333332, - "test_qlinear_gelu_cpu_cpp_wrapper (__main__.TestCppWrapper)": 61.369600000000005, - "test_qlinear_gelu_cpu_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 61.51380000000001, - "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 389.8819, - "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 888.4723333333333, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 578.2715000000001, - "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 1291.462111111111, - "test_quick_core_backward_expand_copy_cuda_float64 (__main__.TestDecompCUDA)": 92.35, - "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 97.31772727272727, - "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 212.4562222222222, - "test_quick_core_backward_select_scatter_cpu_float64 (__main__.TestDecompCPU)": 65.85936363636364, - "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 151.2542222222222, - "test_quick_core_backward_split_cuda_float64 (__main__.TestDecompCUDA)": 79.787, - "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 76.91172727272728, - "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 173.51111111111112, - "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 117.01122222222222, - "test_rnn_decomp_module_nn_LSTM_train_mode_cuda_float32 (__main__.TestDecompCUDA)": 72.3970909090909, - "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 128.59333333333333, - "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 111.14317647058824, - "test_sum_all_cpu_float64 (__main__.TestReductionsCPU)": 161.41700000000003, - "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 306.7471111111111, - "test_svd_lowrank_cuda_float64 (__main__.TestLinalgCUDA)": 70.03777777777779, - "test_terminate_handler_on_crash (__main__.TestTorch)": 68.968, - "test_terminate_signal (__main__.ForkTest)": 144.6801515151515, - "test_terminate_signal (__main__.SpawnTest)": 135.16911764705878, - "test_train_parity_multi_group (__main__.TestFullyShard1DTrainingCore)": 88.26866666666666, - "test_transpose_copy (__main__.CPUReproTests)": 62.54614285714285, - "test_triton_bsr_scatter_mm_blocksize_32_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 119.48249999999999, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 71.85900000000001, - "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 87.0915, - "test_unary_ops (__main__.TestTEFuserDynamic)": 119.30365384615385, - "test_unary_ops (__main__.TestTEFuserStatic)": 90.27661538461538, - "test_unspec_inputs_cuda_cuda_wrapper (__main__.TestCudaWrapper)": 84.24216666666666, - "test_unspec_inputs_cuda_dynamic_shapes_cuda_wrapper (__main__.DynamicShapesCudaWrapperCudaTests)": 83.43050000000001, - "test_upsample_bicubic2d_cpu (__main__.CpuHalideTests)": 96.144, - "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 106.12053333333334, - "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 105.8185, - "test_vec_compare_op_cpu_only (__main__.CPUReproTests)": 71.327, - "test_verify_model_across_rank_with_logger (__main__.TestDistBackendWithSpawn)": 61.44333333333333, - "test_verify_model_across_rank_without_logger (__main__.TestDistBackendWithSpawn)": 61.16233333333334, - "test_vmapjvpvjp_diff_cuda_float32 (__main__.TestOperatorsCUDA)": 81.9345, - "test_vmapjvpvjp_linalg_lu_solve_cpu_float32 (__main__.TestOperatorsCPU)": 62.15947826086957, - "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 86.87155555555556, - "test_vmapjvpvjp_linalg_multi_dot_cuda_float32 (__main__.TestOperatorsCUDA)": 84.2345, - "test_vmapjvpvjp_linalg_solve_triangular_cuda_float32 (__main__.TestOperatorsCUDA)": 83.042, - "test_vmapjvpvjp_linalg_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 91.31800000000001, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 84.47900000000003, - "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 111.041, - "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 70.18206666666667, - "test_vmapjvpvjp_nn_functional_conv2d_cuda_float32 (__main__.TestOperatorsCUDA)": 71.4435, - "test_vmapjvpvjp_nn_functional_max_pool1d_cuda_float32 (__main__.TestOperatorsCUDA)": 65.864, - "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 78.47160000000001, - "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 108.5055, - "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 78.44033333333334, - "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 61.437625000000004, - "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 113.4555, - "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 111.5335, - "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 117.1695, - "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 146.6571111111111 + "test_AllenaiLongformerBase_repro_cpu (__main__.CpuHalideTests)": 217.70533333333333, + "test__adaptive_avg_pool2d (__main__.CPUReproTests)": 183.53220000000002, + "test_adaptive_max_pool2d1_cpu (__main__.CpuHalideTests)": 113.71199999999999, + "test_after_aot_cpu_runtime_error (__main__.MinifierIsolateTests)": 65.14765517241378, + "test_alexnet_prefix_cpu (__main__.CpuHalideTests)": 192.105, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool1d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 90.7365, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 152.649, + "test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool3d_cpu_float32 (__main__.TestEagerFusionOpInfoCPU)": 102.364, + "test_aot_autograd_symbolic_module_exhaustive_nn_TransformerDecoderLayer_cpu_float32 (__main__.TestEagerFusionModuleInfoCPU)": 107.731, + "test_aot_export_joint_simple_repro_dynamic_shapes (__main__.DynamicShapesAotAutogradFallbackTests)": 521.2014, + "test_associative_scan_dim_reverse_False_combine_mode_generic_cpu (__main__.TestControlFlow)": 71.284, + "test_associative_scan_dim_reverse_True_combine_mode_generic_cpu (__main__.TestControlFlow)": 72.0559090909091, + "test_avg_pool3d_backward2_cpu (__main__.CpuTests)": 482.14825423728814, + "test_avg_pool3d_backward2_cuda (__main__.GPUTests)": 91.70333333333333, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 507.2, + "test_avg_pool3d_backward2_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 504.7608, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 63.166666666666664, + "test_avg_pool3d_backward2_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 124.896, + "test_avg_pool3d_backward_cpu (__main__.CpuHalideTests)": 62.757666666666665, + "test_backward_nn_functional_multi_head_attention_forward_cpu_float32 (__main__.TestCompositeComplianceCPU)": 76.14750000000001, + "test_backward_nn_functional_multi_head_attention_forward_cuda_float32 (__main__.TestCompositeComplianceCUDA)": 60.264, + "test_basic_cpu (__main__.EfficientConvBNEvalCpuTests)": 278.2862, + "test_basic_cuda (__main__.EfficientConvBNEvalGpuTests)": 111.48954545454545, + "test_builtin_equivalent_funcs (__main__.TorchFunctionModeTests)": 108.13043478260867, + "test_captured_score_mod_aot_eager_gradcheck_score_mod_name__head_offset_mode_eager (__main__.TestFlexAttention)": 163.244, + "test_collect_callgrind (__main__.TestBenchmarkUtils)": 378.20866666666666, + "test_comprehensive_constant_pad_nd_cpu_float16 (__main__.TestInductorOpInfoCPU)": 80.16999999999999, + "test_comprehensive_constant_pad_nd_cpu_float32 (__main__.TestInductorOpInfoCPU)": 70.06291666666665, + "test_comprehensive_constant_pad_nd_cpu_float64 (__main__.TestInductorOpInfoCPU)": 69.58566666666667, + "test_comprehensive_constant_pad_nd_cpu_int32 (__main__.TestInductorOpInfoCPU)": 69.70495833333334, + "test_comprehensive_constant_pad_nd_cpu_int64 (__main__.TestInductorOpInfoCPU)": 69.52449999999999, + "test_comprehensive_diff_cpu_bool (__main__.TestInductorOpInfoCPU)": 118.6565, + "test_comprehensive_diff_cpu_float32 (__main__.TestInductorOpInfoCPU)": 122.9565, + "test_comprehensive_diff_cpu_float64 (__main__.TestInductorOpInfoCPU)": 114.0035, + "test_comprehensive_diff_cpu_int32 (__main__.TestInductorOpInfoCPU)": 112.271, + "test_comprehensive_diff_cpu_int64 (__main__.TestInductorOpInfoCPU)": 113.428, + "test_comprehensive_diff_cuda_complex128 (__main__.TestDecompCUDA)": 73.43875, + "test_comprehensive_diff_cuda_complex64 (__main__.TestDecompCUDA)": 75.05725, + "test_comprehensive_dist_cpu_float16 (__main__.TestInductorOpInfoCPU)": 95.8675, + "test_comprehensive_dist_cpu_float32 (__main__.TestInductorOpInfoCPU)": 91.992, + "test_comprehensive_dist_cpu_float64 (__main__.TestInductorOpInfoCPU)": 92.976, + "test_comprehensive_eye_cpu_bool (__main__.TestInductorOpInfoCPU)": 143.73000000000002, + "test_comprehensive_eye_cpu_float16 (__main__.TestInductorOpInfoCPU)": 138.624, + "test_comprehensive_eye_cpu_float32 (__main__.TestInductorOpInfoCPU)": 139.755, + "test_comprehensive_eye_cpu_float64 (__main__.TestInductorOpInfoCPU)": 147.81799999999998, + "test_comprehensive_eye_cpu_int32 (__main__.TestInductorOpInfoCPU)": 140.828, + "test_comprehensive_eye_cpu_int64 (__main__.TestInductorOpInfoCPU)": 143.93099999999998, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestDecompCPU)": 341.48900000000003, + "test_comprehensive_grid_sampler_2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 86.49000000000001, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestDecompCPU)": 323.1645, + "test_comprehensive_grid_sampler_2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 85.9655, + "test_comprehensive_grid_sampler_2d_cuda_bfloat16 (__main__.TestDecompCUDA)": 215.22800000000004, + "test_comprehensive_grid_sampler_2d_cuda_float16 (__main__.TestDecompCUDA)": 201.79633333333334, + "test_comprehensive_grid_sampler_2d_cuda_float32 (__main__.TestDecompCUDA)": 756.4825, + "test_comprehensive_grid_sampler_2d_cuda_float64 (__main__.TestDecompCUDA)": 856.9263333333333, + "test_comprehensive_linalg_vector_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 219.004, + "test_comprehensive_linalg_vector_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 227.63799999999998, + "test_comprehensive_linalg_vector_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 223.42000000000002, + "test_comprehensive_logspace_cpu_float32 (__main__.TestInductorOpInfoCPU)": 474.6385, + "test_comprehensive_logspace_cpu_float64 (__main__.TestInductorOpInfoCPU)": 496.866, + "test_comprehensive_logspace_cpu_int32 (__main__.TestInductorOpInfoCPU)": 459.975, + "test_comprehensive_logspace_cpu_int64 (__main__.TestInductorOpInfoCPU)": 457.97450000000003, + "test_comprehensive_masked_amax_cpu_float16 (__main__.TestInductorOpInfoCPU)": 107.4735, + "test_comprehensive_masked_amax_cpu_float32 (__main__.TestInductorOpInfoCPU)": 106.3655, + "test_comprehensive_masked_amax_cpu_float64 (__main__.TestInductorOpInfoCPU)": 112.69399999999999, + "test_comprehensive_masked_amax_cpu_int32 (__main__.TestInductorOpInfoCPU)": 105.007, + "test_comprehensive_masked_amax_cpu_int64 (__main__.TestInductorOpInfoCPU)": 100.816, + "test_comprehensive_masked_amin_cpu_float16 (__main__.TestInductorOpInfoCPU)": 106.1785, + "test_comprehensive_masked_amin_cpu_float32 (__main__.TestInductorOpInfoCPU)": 106.233, + "test_comprehensive_masked_amin_cpu_float64 (__main__.TestInductorOpInfoCPU)": 106.112, + "test_comprehensive_masked_amin_cpu_int32 (__main__.TestInductorOpInfoCPU)": 101.783, + "test_comprehensive_masked_amin_cpu_int64 (__main__.TestInductorOpInfoCPU)": 102.14850000000001, + "test_comprehensive_masked_mean_cpu_bool (__main__.TestInductorOpInfoCPU)": 103.827, + "test_comprehensive_masked_mean_cpu_float16 (__main__.TestInductorOpInfoCPU)": 105.78999999999999, + "test_comprehensive_masked_mean_cpu_float32 (__main__.TestInductorOpInfoCPU)": 104.21549999999999, + "test_comprehensive_masked_mean_cpu_float64 (__main__.TestInductorOpInfoCPU)": 110.54400000000001, + "test_comprehensive_masked_mean_cpu_int32 (__main__.TestInductorOpInfoCPU)": 107.743, + "test_comprehensive_masked_mean_cpu_int64 (__main__.TestInductorOpInfoCPU)": 102.5795, + "test_comprehensive_masked_norm_cpu_float16 (__main__.TestInductorOpInfoCPU)": 520.4745, + "test_comprehensive_masked_norm_cpu_float32 (__main__.TestInductorOpInfoCPU)": 526.9034999999999, + "test_comprehensive_masked_norm_cpu_float64 (__main__.TestInductorOpInfoCPU)": 526.642, + "test_comprehensive_masked_norm_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 120.993, + "test_comprehensive_masked_norm_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 117.78525, + "test_comprehensive_masked_norm_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 106.073, + "test_comprehensive_masked_prod_cpu_bool (__main__.TestInductorOpInfoCPU)": 99.75399999999999, + "test_comprehensive_masked_prod_cpu_float16 (__main__.TestInductorOpInfoCPU)": 107.64099999999999, + "test_comprehensive_masked_prod_cpu_float32 (__main__.TestInductorOpInfoCPU)": 106.8455, + "test_comprehensive_masked_prod_cpu_float64 (__main__.TestInductorOpInfoCPU)": 107.0445, + "test_comprehensive_masked_prod_cpu_int32 (__main__.TestInductorOpInfoCPU)": 106.7095, + "test_comprehensive_masked_prod_cpu_int64 (__main__.TestInductorOpInfoCPU)": 98.7585, + "test_comprehensive_masked_sum_cpu_bool (__main__.TestInductorOpInfoCPU)": 103.51249999999999, + "test_comprehensive_masked_sum_cpu_float16 (__main__.TestInductorOpInfoCPU)": 103.95750000000001, + "test_comprehensive_masked_sum_cpu_float32 (__main__.TestInductorOpInfoCPU)": 102.953, + "test_comprehensive_masked_sum_cpu_float64 (__main__.TestInductorOpInfoCPU)": 106.582, + "test_comprehensive_masked_sum_cpu_int32 (__main__.TestInductorOpInfoCPU)": 104.2175, + "test_comprehensive_masked_sum_cpu_int64 (__main__.TestInductorOpInfoCPU)": 96.6305, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestDecompCUDA)": 84.78766666666667, + "test_comprehensive_nn_functional_gaussian_nll_loss_cuda_float64 (__main__.TestDecompCUDA)": 89.57775, + "test_comprehensive_nn_functional_glu_cpu_float16 (__main__.TestInductorOpInfoCPU)": 79.2515, + "test_comprehensive_nn_functional_glu_cpu_float32 (__main__.TestInductorOpInfoCPU)": 79.8335, + "test_comprehensive_nn_functional_glu_cpu_float64 (__main__.TestInductorOpInfoCPU)": 80.49, + "test_comprehensive_nn_functional_grid_sample_cpu_float32 (__main__.TestDecompCPU)": 88.69, + "test_comprehensive_nn_functional_grid_sample_cpu_float64 (__main__.TestDecompCPU)": 80.43, + "test_comprehensive_nn_functional_grid_sample_cuda_float32 (__main__.TestDecompCUDA)": 183.62233333333333, + "test_comprehensive_nn_functional_grid_sample_cuda_float64 (__main__.TestDecompCUDA)": 228.25766666666667, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float32 (__main__.TestDecompCUDA)": 63.317666666666675, + "test_comprehensive_nn_functional_interpolate_trilinear_cuda_float64 (__main__.TestDecompCUDA)": 64.50699999999999, + "test_comprehensive_nn_functional_max_pool1d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 174.4395, + "test_comprehensive_nn_functional_max_pool1d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 164.878, + "test_comprehensive_nn_functional_max_pool1d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 165.157, + "test_comprehensive_nn_functional_max_pool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 939.275, + "test_comprehensive_nn_functional_max_pool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 873.8385, + "test_comprehensive_nn_functional_max_pool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 869.2495, + "test_comprehensive_nn_functional_max_pool2d_cpu_int32 (__main__.TestInductorOpInfoCPU)": 798.8875, + "test_comprehensive_nn_functional_max_pool2d_cpu_int64 (__main__.TestInductorOpInfoCPU)": 835.7080000000001, + "test_comprehensive_nn_functional_max_pool2d_cuda_float16 (__main__.TestInductorOpInfoCUDA)": 719.5625, + "test_comprehensive_nn_functional_max_pool2d_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 722.98475, + "test_comprehensive_nn_functional_max_pool2d_cuda_float64 (__main__.TestInductorOpInfoCUDA)": 710.978, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 194.88083333333336, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 196.07629166666663, + "test_comprehensive_nn_functional_max_unpool2d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 194.2515, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float16 (__main__.TestInductorOpInfoCPU)": 128.8409583333333, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float32 (__main__.TestInductorOpInfoCPU)": 126.98766666666666, + "test_comprehensive_nn_functional_max_unpool3d_cpu_float64 (__main__.TestInductorOpInfoCPU)": 127.97137500000001, + "test_comprehensive_nn_functional_pad_constant_cpu_float16 (__main__.TestInductorOpInfoCPU)": 78.40700000000001, + "test_comprehensive_nn_functional_pad_constant_cpu_float32 (__main__.TestInductorOpInfoCPU)": 77.36099999999999, + "test_comprehensive_nn_functional_pad_constant_cpu_float64 (__main__.TestInductorOpInfoCPU)": 80.1695, + "test_comprehensive_nn_functional_pad_constant_cpu_int32 (__main__.TestInductorOpInfoCPU)": 69.87650000000001, + "test_comprehensive_nn_functional_pad_constant_cpu_int64 (__main__.TestInductorOpInfoCPU)": 70.05720833333334, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float16 (__main__.TestInductorOpInfoCPU)": 127.6425, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float32 (__main__.TestInductorOpInfoCPU)": 129.609, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_float64 (__main__.TestInductorOpInfoCPU)": 136.7715, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int32 (__main__.TestInductorOpInfoCPU)": 137.901, + "test_comprehensive_nn_functional_poisson_nll_loss_cpu_int64 (__main__.TestInductorOpInfoCPU)": 126.132, + "test_comprehensive_nn_functional_unfold_cpu_bool (__main__.TestInductorOpInfoCPU)": 123.48891666666667, + "test_comprehensive_nn_functional_unfold_cpu_float16 (__main__.TestInductorOpInfoCPU)": 261.71, + "test_comprehensive_nn_functional_unfold_cpu_float32 (__main__.TestInductorOpInfoCPU)": 273.28200000000004, + "test_comprehensive_nn_functional_unfold_cpu_float64 (__main__.TestInductorOpInfoCPU)": 266.879, + "test_comprehensive_ormqr_cuda_complex128 (__main__.TestDecompCUDA)": 101.85466666666666, + "test_comprehensive_ormqr_cuda_complex64 (__main__.TestDecompCUDA)": 86.66175000000001, + "test_comprehensive_ormqr_cuda_float32 (__main__.TestInductorOpInfoCUDA)": 83.36566666666666, + "test_comprehensive_ormqr_cuda_float64 (__main__.TestDecompCUDA)": 60.800333333333334, + "test_comprehensive_svd_cuda_complex64 (__main__.TestDecompCUDA)": 68.02033333333334, + "test_cond_autograd_nested (__main__.TestControlFlow)": 144.61216666666667, + "test_constructor_autograd_SparseBSC_cuda (__main__.TestSparseAnyCUDA)": 104.74549999999999, + "test_constructor_autograd_SparseBSR_cuda (__main__.TestSparseAnyCUDA)": 120.5435, + "test_constructor_autograd_SparseCSC_cuda (__main__.TestSparseAnyCUDA)": 68.446, + "test_constructor_autograd_SparseCSR_cuda (__main__.TestSparseAnyCUDA)": 81.96675, + "test_conv1d_basic (__main__.TestXNNPACKConv1dTransformPass)": 142.8464, + "test_conv1d_with_relu_fc (__main__.TestXNNPACKConv1dTransformPass)": 584.9645, + "test_conv2d_unary_cpu_cpp_wrapper (__main__.TestCppWrapper)": 275.89, + "test_conv_bn_fuse_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 60.4926, + "test_correctness_NAdam_use_closure_True_cuda_float32 (__main__.CompiledOptimizerParityTestsCUDA)": 63.9085, + "test_count_nonzero_all (__main__.TestBool)": 585.3546, + "test_cusparse_multiple_threads_same_device (__main__.TestCuda)": 108.34033333333333, + "test_custom_module_lstm (__main__.TestQuantizedOps)": 311.147, + "test_dispatch_symbolic_meta_outplace_all_strides_nn_functional_gaussian_nll_loss_cuda_float32 (__main__.TestMetaCUDA)": 83.78333333333335, + "test_dtensor_op_db_nn_functional_gaussian_nll_loss_cpu_float32 (__main__.TestDTensorOpsCPU)": 88.53099999999999, + "test_eig_check_magma_cuda_float32 (__main__.TestLinalgCUDA)": 253.42925, + "test_fail_creation_ops.py (__main__.TestTyping)": 60.42466666666667, + "test_fn_fwgrad_bwgrad_cumprod_cuda_complex128 (__main__.TestFwdGradientsCUDA)": 73.053, + "test_fn_fwgrad_bwgrad_nn_functional_scaled_dot_product_attention_cuda_float64 (__main__.TestFwdGradientsCUDA)": 64.3565294117647, + "test_fn_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 135.923, + "test_fn_gradgrad_map_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 91.0015, + "test_fn_gradgrad_map_triple_nested_cpu_float64 (__main__.TestBwdGradientsCPU)": 520.559, + "test_fn_gradgrad_map_triple_nested_cuda_float64 (__main__.TestBwdGradientsCUDA)": 355.68533333333335, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 79.11880000000001, + "test_fuse_large_params_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 77.26866666666666, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesCodegenGPUTests)": 72.283, + "test_fuse_large_params_dynamic_shapes_cuda (__main__.DynamicShapesGPUTests)": 63.989, + "test_grad_nn_Transformer_cuda_float64 (__main__.TestModuleCUDA)": 97.49375, + "test_gradgrad_nn_LSTM_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 92.881, + "test_gradgrad_nn_LSTM_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 89.92000000000002, + "test_gradgrad_nn_TransformerDecoderLayer_cuda_float64 (__main__.TestModuleCUDA)": 215.37775, + "test_gradgrad_nn_TransformerEncoder_eval_mode_cuda_float64 (__main__.TestModuleCUDA)": 141.02475, + "test_gradgrad_nn_TransformerEncoder_train_mode_cuda_float64 (__main__.TestModuleCUDA)": 144.87225, + "test_grid_sampler_2d_cpu (__main__.CpuHalideTests)": 187.98233333333334, + "test_hessian_argnums_dynamic_shapes (__main__.DynamicShapesFuncTorchHigherOrderOpTests)": 243.68739024390243, + "test_indexing (__main__.TestAutogradWithCompiledAutograd)": 66.53580487804876, + "test_indirect_device_assert (__main__.TritonCodeGenTests)": 178.977625, + "test_inplace_gradgrad_cumprod_cuda_complex128 (__main__.TestBwdGradientsCUDA)": 96.526, + "test_jacrev_two_tensors_argnums_dynamic_shapes (__main__.DynamicShapesFuncTorchHigherOrderOpTests)": 66.33568292682926, + "test_jit_cuda_archflags (__main__.TestCppExtensionJIT)": 117.10499999999999, + "test_linalg_solve_triangular_large_cuda_complex128 (__main__.TestLinalgCUDA)": 697.57475, + "test_linalg_solve_triangular_large_cuda_complex64 (__main__.TestLinalgCUDA)": 83.33125000000001, + "test_linalg_solve_triangular_large_cuda_float64 (__main__.TestLinalgCUDA)": 73.11175, + "test_linear (__main__.TestStaticQuantizedModule)": 83.46508771929825, + "test_linear_binary_cpp_wrapper (__main__.TestCppWrapper)": 167.69299999999998, + "test_linear_binary_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 174.84699999999998, + "test_linear_packed_cpp_wrapper (__main__.TestCppWrapper)": 78.657, + "test_linear_packed_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 80.5985, + "test_linear_relu (__main__.TestStaticQuantizedModule)": 67.61533333333334, + "test_lstm_cpu (__main__.TestMkldnnCPU)": 62.057500000000005, + "test_matmul_small_brute_force_tunableop_cuda_float16 (__main__.TestLinalgCUDA)": 84.33282352941175, + "test_max_autotune_cutlass_backend_addmm_dynamic_False_max_autotune_gemm_backends_ATen,Triton,CUTLASS (__main__.TestCutlassBackend)": 84.89699999999999, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCodegenCpuTests)": 62.57302439024391, + "test_max_pool2d_with_indices_backward4_dynamic_shapes_cpu (__main__.DynamicShapesCpuTests)": 62.35275609756097, + "test_memory_format_operators_cpu (__main__.TestTorchDeviceTypeCPU)": 76.08051111111111, + "test_mixed_mm_exhaustive_dtypes (__main__.TestPatternMatcher)": 103.41881818181818, + "test_pipeline_order_flex_and_zero_bubble_ScheduleClass1 (__main__.TestSchedulePlan)": 71.24585714285715, + "test_proper_exit (__main__.TestDataLoader)": 188.92475, + "test_proper_exit (__main__.TestDataLoaderPersistentWorkers)": 172.64125, + "test_python_ref_executor__refs_special_zeta_executor_aten_cuda_float64 (__main__.TestCommonCUDA)": 69.83999999999999, + "test_qat_conv2d_unary (__main__.TestQuantizePT2EX86Inductor)": 146.72099999999998, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn1d)": 63.105317460317494, + "test_qat_conv_bn_fusion_no_conv_bias (__main__.TestQuantizePT2EQAT_ConvBn2d)": 62.03507936507937, + "test_qat_mobilenet_v2 (__main__.TestQuantizePT2EQATModels)": 72.86783333333334, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 110.89726666666668, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 111.5624, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 111.90106666666665, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 113.40520000000001, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 111.72566666666668, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 112.32593333333335, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 112.38706666666667, + "test_qlinear_add_int8_mixed_bf16_use_relu_False_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 114.79066666666668, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 115.27933333333333, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 113.48193333333333, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 112.9768, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_False_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 113.761, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_cpp_wrapper (__main__.TestCppWrapper)": 112.22186666666668, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_False_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 112.23886666666668, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_cpp_wrapper (__main__.TestCppWrapper)": 112.1882, + "test_qlinear_add_int8_mixed_bf16_use_relu_True_is_qat_True_is_dynamic_True_dynamic_shapes_cpp_wrapper (__main__.DynamicShapesCppWrapperCpuTests)": 115.28106666666667, + "test_quick_core_backward__unsafe_masked_index_cpu_float64 (__main__.TestDecompCPU)": 356.806, + "test_quick_core_backward__unsafe_masked_index_cuda_float64 (__main__.TestDecompCUDA)": 613.9933333333333, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cpu_float64 (__main__.TestDecompCPU)": 571.8885, + "test_quick_core_backward__unsafe_masked_index_put_accumulate_cuda_float64 (__main__.TestDecompCUDA)": 952.0746666666668, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cpu_float64 (__main__.TestDecompCPU)": 66.04319047619047, + "test_quick_core_backward_nn_functional_max_unpool3d_grad_cuda_float64 (__main__.TestDecompCUDA)": 251.12686666666664, + "test_quick_core_backward_roll_cpu_float64 (__main__.TestDecompCPU)": 89.99199999999999, + "test_quick_core_backward_roll_cuda_float64 (__main__.TestDecompCUDA)": 152.62966666666668, + "test_quick_core_backward_select_scatter_cuda_float64 (__main__.TestDecompCUDA)": 99.9515, + "test_quick_core_backward_split_with_sizes_copy_cpu_float64 (__main__.TestDecompCPU)": 73.9705, + "test_quick_core_backward_split_with_sizes_copy_cuda_float64 (__main__.TestDecompCUDA)": 126.93, + "test_quick_core_backward_std_cuda_float64 (__main__.TestDecompCUDA)": 79.00166666666667, + "test_replicatepad_64bit_indexing_cuda_float16 (__main__.TestNNDeviceTypeCUDA)": 320.14799999999997, + "test_retracibility_dict_container_inp_out_dynamic_shapes (__main__.DynamicShapesExportTests)": 1307.0030000000002, + "test_retracibility_nested_list_out_dynamic_shapes (__main__.DynamicShapesExportTests)": 1304.0381999999997, + "test_reveal_module_list.py (__main__.TestTyping)": 69.52092857142858, + "test_rosenbrock_sparse_with_lrsched_True_SGD_cuda_float64 (__main__.TestOptimRenewedCUDA)": 73.97947058823529, + "test_save_load_large_string_attribute (__main__.TestSaveLoad)": 116.828, + "test_sdpa_kernel_ctx_manager2_dynamic_shapes (__main__.DynamicShapesCtxManagerTests)": 384.3718536585367, + "test_shuffler_iterdatapipe (__main__.IntegrationTestDataLoaderDataPipe)": 124.52640000000001, + "test_slow_tasks (__main__.TestFunctionalAutogradBenchmark)": 143.84016666666668, + "test_sort_stable_cpu (__main__.CpuTritonTests)": 76.68900000000001, + "test_split_cumsum_cpu (__main__.CpuTritonTests)": 105.111, + "test_svd_lowrank_cuda_complex128 (__main__.TestLinalgCUDA)": 250.93199999999996, + "test_terminate_handler_on_crash (__main__.TestTorch)": 70.20433333333334, + "test_terminate_signal (__main__.ForkTest)": 94.89233333333334, + "test_terminate_signal (__main__.ParallelForkServerShouldWorkTest)": 138.68493220338985, + "test_terminate_signal (__main__.SpawnTest)": 97.77383333333334, + "test_transformer_backend_inductor_fullgraph_True (__main__.TestFullyShardCompile)": 113.80291666666669, + "test_transpose_copy (__main__.CPUReproTests)": 67.3468, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 95.54525, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 63.90972727272728, + "test_triton_bsr_scatter_mm_blocksize_64_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 108.88275, + "test_triton_bsr_softmax_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 126.33936363636363, + "test_triton_bsr_softmax_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 121.914, + "test_triton_bsr_softmax_cuda_float32 (__main__.TestSparseCompressedTritonKernelsCUDA)": 97.888, + "test_triton_scaled_dot_product_attention_block_size_16_cuda_bfloat16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 74.247, + "test_triton_scaled_dot_product_attention_block_size_16_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 73.68725, + "test_triton_scaled_dot_product_attention_block_size_32_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 72.97200000000001, + "test_triton_scaled_dot_product_attention_block_size_64_cuda_float16 (__main__.TestSparseCompressedTritonKernelsCUDA)": 73.30687499999999, + "test_unary_ops (__main__.TestTEFuserDynamic)": 195.76619999999997, + "test_unary_ops (__main__.TestTEFuserStatic)": 161.1238, + "test_upsample_bicubic2d_cpu (__main__.CpuHalideTests)": 95.887, + "test_variant_consistency_jit_nn_functional_max_pool2d_cpu_float32 (__main__.TestJitCPU)": 94.2415, + "test_variant_consistency_jit_nn_functional_max_pool2d_cuda_float32 (__main__.TestJitCUDA)": 79.27025, + "test_vec_bitwise (__main__.CPUReproTests)": 67.08171428571428, + "test_vmapjvpvjp_linalg_lu_solve_cuda_float32 (__main__.TestOperatorsCUDA)": 65.91033333333333, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cpu_float32 (__main__.TestOperatorsCPU)": 64.7045, + "test_vmapjvpvjp_max_pool2d_with_indices_backward_cuda_float32 (__main__.TestOperatorsCUDA)": 67.02433333333333, + "test_vmapjvpvjp_nn_functional_conv2d_cpu_float32 (__main__.TestOperatorsCPU)": 60.5145, + "test_vmapjvpvjp_nn_functional_max_pool2d_cpu_float32 (__main__.TestOperatorsCPU)": 69.64450000000001, + "test_vmapjvpvjp_nn_functional_max_pool2d_cuda_float32 (__main__.TestOperatorsCUDA)": 70.06475, + "test_vmapjvpvjp_svd_cuda_float32 (__main__.TestOperatorsCUDA)": 63.21600000000001, + "test_vmapjvpvjp_unbind_cpu_float32 (__main__.TestOperatorsCPU)": 66.5395, + "test_vmapjvpvjp_unbind_cuda_float32 (__main__.TestOperatorsCUDA)": 78.33349999999999, + "test_vmapvjpvjp_meshgrid_list_of_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 68.875, + "test_vmapvjpvjp_meshgrid_variadic_tensors_cuda_float32 (__main__.TestOperatorsCUDA)": 92.238, + "test_vmapvjpvjp_nn_functional_bilinear_cuda_float32 (__main__.TestOperatorsCUDA)": 125.17075 } \ No newline at end of file diff --git a/test/test_autograd.py b/test/test_autograd.py index 8a141ea20c357..f25ca30fd963b 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -69,6 +69,7 @@ IS_WINDOWS, parametrize, run_tests, + scoped_load_inline, set_warn_always_context, skipIfMps, skipIfNoLapack, @@ -85,7 +86,6 @@ CheckpointPolicy, create_selective_checkpoint_contexts, ) -from torch.utils.cpp_extension import load_inline from torch.utils.flop_counter import FlopCounterMode @@ -7456,20 +7456,20 @@ def backward(ctx, input): def test_reentrant_with_callbacks_depth_0(self): # Verify callback is called only once. ret = self._test_reentrant_with_callbacks([0]) - self.assertEqual(1, ret["outer"]) - self.assertEqual(0, ret["inner"]) + self.assertEqual(ret["outer"], 1) + self.assertEqual(ret["inner"], 0) def test_reentrant_with_callbacks_depth_1(self): # Verify callback is called only once. ret = self._test_reentrant_with_callbacks([1]) - self.assertEqual(0, ret["outer"]) - self.assertEqual(1, ret["inner"]) + self.assertEqual(ret["outer"], 0) + self.assertEqual(ret["inner"], 1) def test_reentrant_with_callbacks_both_depths(self): # Verify callback is called twice. ret = self._test_reentrant_with_callbacks([0, 1]) - self.assertEqual(1, ret["outer"]) - self.assertEqual(1, ret["inner"]) + self.assertEqual(ret["outer"], 1) + self.assertEqual(ret["inner"], 1) def test_reentrant_with_leaf_variable_hook(self): handle = None @@ -9854,7 +9854,8 @@ def test_scalar_grad_mixed_device(self): out = x * y out.sum().backward() - def test_multi_grad_all_hooks(self): + @scoped_load_inline + def test_multi_grad_all_hooks(self, load_inline): t1 = torch.rand(2, requires_grad=True) t2 = torch.rand(2, requires_grad=True) t3 = torch.rand(2, requires_grad=True) @@ -9899,19 +9900,19 @@ def backward(ctx, gO): return CustomOpAutogradFunction::apply(x); } -TORCH_LIBRARY(test_autograd_cpp_node, m) { +TORCH_LIBRARY(test_multigrad_all_hooks, m) { m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); } """ module = load_inline( - name="test_autograd_cpp_node", + name="test_multigrad_all_hooks", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", verbose=True, ) - t4 = torch.ops.test_autograd_cpp_node.custom_op_backed_by_autograd_fn(t4) + t4 = torch.ops.test_multigrad_all_hooks.custom_op_backed_by_autograd_fn(t4) res = [None] * 4 count = [0] diff --git a/test/test_binary_ufuncs.py b/test/test_binary_ufuncs.py index 9cc041deed924..1166de2b70dd5 100644 --- a/test/test_binary_ufuncs.py +++ b/test/test_binary_ufuncs.py @@ -159,6 +159,15 @@ def _helper_reference_numerics( actual = op(l, r) expected = op.ref(l_numpy, r_numpy) + # Dtype promo rules have changed since NumPy 2. + # Specialize the backward-incompatible cases. + if ( + np.__version__ > "2" + and op.name in ("sub", "_refs.sub") + and isinstance(l_numpy, np.ndarray) + ): + expected = expected.astype(l_numpy.dtype) + # Crafts a custom error message for smaller, printable tensors def _numel(x): if isinstance(x, torch.Tensor): @@ -3199,7 +3208,12 @@ def test_shift_limits(self, device, dtype): ): shift_left_expected = torch.zeros_like(input) shift_right_expected = torch.clamp(input, -1, 0) - for shift in chain(range(-100, -1), range(bits, 100)): + # NumPy 2 does not support negative shift values. + if np.__version__ > "2": + iterator = range(bits, 100) + else: + iterator = chain(range(-100, -1), range(bits, 100)) + for shift in iterator: shift_left = input << shift self.assertEqual(shift_left, shift_left_expected, msg=f"<< {shift}") self.compare_with_numpy( diff --git a/test/test_cpp_extensions_aot.py b/test/test_cpp_extensions_aot.py index 808ecff991eb7..f4cb94ba22ee0 100644 --- a/test/test_cpp_extensions_aot.py +++ b/test/test_cpp_extensions_aot.py @@ -11,7 +11,11 @@ import torch.testing._internal.common_utils as common import torch.utils.cpp_extension from torch.testing._internal.common_cuda import TEST_CUDA -from torch.testing._internal.common_utils import IS_WINDOWS, skipIfTorchDynamo +from torch.testing._internal.common_utils import ( + IS_WINDOWS, + skipIfTorchDynamo, + xfailIfTorchDynamo, +) try: @@ -315,7 +319,7 @@ class TestRNGExtension(common.TestCase): def setUp(self): super().setUp() - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") + @xfailIfTorchDynamo def test_rng(self): fourty_two = torch.full((10,), 42, dtype=torch.int64) diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index e6883b0f2e942..40d33df3b8c36 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -1,6 +1,7 @@ # Owner(s): ["module: cpp-extensions"] import _codecs +import io import os import tempfile import types @@ -529,35 +530,110 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() + @unittest.skipIf( + np.__version__ < "1.25", + "versions < 1.25 serialize dtypes differently from how it's serialized in data_legacy_numpy", + ) def test_open_device_numpy_serialization(self): + """ + This tests the legacy _rebuild_device_tensor_from_numpy serialization path + """ torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL - # This is a hack to test serialization through numpy + + # Legacy data saved with _rebuild_device_tensor_from_numpy on f80ed0b8 via + + # with patch.object(torch._C, "_has_storage", return_value=False): + # x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device) + # x_foo = x.to(device) + # sd = {"x": x_foo} + # rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0] + # self.assertTrue( + # rebuild_func is torch._utils._rebuild_device_tensor_from_numpy + # ) + # with open("foo.pt", "wb") as f: + # torch.save(sd, f) + + data_legacy_numpy = ( + b"PK\x03\x04\x00\x00\x08\x08\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + b"\x00\x00\x00\x10\x00\x12\x00archive/data.pklFB\x0e\x00ZZZZZZZZZZZZZZ\x80\x02}q\x00X\x01" + b"\x00\x00\x00xq\x01ctorch._utils\n_rebuild_device_tensor_from_numpy\nq\x02(cnumpy.core.m" + b"ultiarray\n_reconstruct\nq\x03cnumpy\nndarray\nq\x04K\x00\x85q\x05c_codecs\nencode\nq\x06" + b"X\x01\x00\x00\x00bq\x07X\x06\x00\x00\x00latin1q\x08\x86q\tRq\n\x87q\x0bRq\x0c(K\x01K\x02K" + b"\x03\x86q\rcnumpy\ndtype\nq\x0eX\x02\x00\x00\x00f4q\x0f\x89\x88\x87q\x10Rq\x11(K\x03X\x01" + b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00" + b"PK\x05\x06\x00\x00\x00\x00\x04\x00\x04\x00\x06\x01\x00\x008\x03\x00\x00\x00\x00" + ) + buf_data_legacy_numpy = io.BytesIO(data_legacy_numpy) + + with safe_globals( + [ + np.core.multiarray._reconstruct, + np.ndarray, + np.dtype, + _codecs.encode, + np.dtypes.Float32DType, + ] + ): + sd_loaded = torch.load(buf_data_legacy_numpy, weights_only=True) + buf_data_legacy_numpy.seek(0) + # Test map_location + sd_loaded_cpu = torch.load( + buf_data_legacy_numpy, weights_only=True, map_location="cpu" + ) + expected = torch.tensor( + [[1, 2, 3], [4, 5, 6]], dtype=torch.float32, device=device + ) + self.assertEqual(sd_loaded["x"].cpu(), expected.cpu()) + self.assertFalse(sd_loaded["x"].is_cpu) + self.assertTrue(sd_loaded_cpu["x"].is_cpu) + + def test_open_device_cpu_serialization(self): + torch.utils.rename_privateuse1_backend("foo") + device = self.module.custom_device() + default_protocol = torch.serialization.DEFAULT_PROTOCOL + with patch.object(torch._C, "_has_storage", return_value=False): x = torch.randn(2, 3) x_foo = x.to(device) sd = {"x": x_foo} rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0] self.assertTrue( - rebuild_func is torch._utils._rebuild_device_tensor_from_numpy + rebuild_func is torch._utils._rebuild_device_tensor_from_cpu_tensor ) # Test map_location with TemporaryFileName() as f: torch.save(sd, f) - with safe_globals( - [ - np.core.multiarray._reconstruct, - np.ndarray, - np.dtype, - _codecs.encode, - type(np.dtype(np.float32)) - if np.__version__ < "1.25.0" - else np.dtypes.Float32DType, - ] - ): - sd_loaded = torch.load(f, map_location="cpu") - self.assertTrue(sd_loaded["x"].is_cpu) + sd_loaded = torch.load(f, weights_only=True) + # Test map_location + sd_loaded_cpu = torch.load(f, weights_only=True, map_location="cpu") + self.assertFalse(sd_loaded["x"].is_cpu) + self.assertEqual(sd_loaded["x"].cpu(), x) + self.assertTrue(sd_loaded_cpu["x"].is_cpu) # Test metadata_only with TemporaryFileName() as f: diff --git a/test/test_cuda.py b/test/test_cuda.py index 8a7b2193cf3a8..a0a5edba32ffd 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -590,10 +590,8 @@ def test_manual_seed(self): self.assertEqual(torch.cuda.initial_seed(), 2) def test_specify_improper_device_name(self): - import os - - fname = "tempfile.pt" - try: + with tempfile.TemporaryDirectory() as tmpdir: + fname = os.path.join(tmpdir, "tempfile.pt") with self.assertRaisesRegex(RuntimeError, "Invalid device string"): torch.save( [torch.nn.Parameter(torch.randn(10, 10))], @@ -601,9 +599,6 @@ def test_specify_improper_device_name(self): _use_new_zipfile_serialization=True, ) torch.load(fname, "cuda0") - finally: - if os.path.exists(fname): - os.remove(fname) def test_get_device_index(self): from torch.cuda._utils import _get_device_index @@ -749,7 +744,7 @@ def test_record_stream_on_shifted_view(self): # Record another stream on a shifted view tensor. view = base[5:] - assert view.storage_offset() > 0 + self.assertTrue(view.storage_offset() > 0) stream_record = torch.cuda.Stream() with torch.cuda.stream(stream_record): @@ -1048,7 +1043,9 @@ def run(dev: torch.device) -> int: return torch.stack([t1, t2]).unique().shape[0] # Use CPU as reference. The results should not deviate too much. - assert abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000 + self.assertTrue( + abs(run(torch.device("cuda")) - run(torch.device("cpu"))) < 10_000 + ) @parametrize("dtype", [torch.float32, torch.double]) def test_random_no_reused_random_states(self, dtype: torch.dtype) -> None: @@ -1071,7 +1068,7 @@ def run(func, dev: torch.device, dtype: torch.dtype) -> int: run(func, torch.device("cuda"), dtype) - run(func, torch.device("cpu"), dtype) ) - assert deviation < 50_000, deviation + self.assertTrue(deviation < 50_000, deviation) def test_min_max_inits(self): # Testing if THC_reduceAll received the correct index initialization. @@ -1627,7 +1624,7 @@ def test_graph_capture_simple(self): g.replay() - self.assertTrue(b.sum().item() == 11000.0) + self.assertEqual(b.sum().item(), 11000.0) @unittest.skipIf( not TEST_CUDA_GRAPH, "CUDA >= 11.0 or ROCM >= 5.3 required for graphs" @@ -1704,7 +1701,7 @@ def get_final_offsets_of_states(generator_state): graph_offset = get_final_offsets_of_states(default_generator_state) # Compare the final offsets of states for both generators to ensure consistency - self.assertTrue(offset == graph_offset) + self.assertEqual(offset, graph_offset) # Compare the states generated outside and inside the graph self.assertEqual(random_values, graphed_random_values) @@ -1759,12 +1756,14 @@ def test(num_graphs, num_generators): expected_blocks_diff = 2 * num_generators expected_size_diff = 2 * 512 * num_generators # Each block's size is 512 - self.assertTrue( - (num_blocks - baseline_num_blocks) == expected_blocks_diff, + self.assertEqual( + (num_blocks - baseline_num_blocks), + expected_blocks_diff, "Unexpected number of active blocks.", ) - self.assertTrue( - (total_size - baseline_total_size) == expected_size_diff, + self.assertEqual( + (total_size - baseline_total_size), + expected_size_diff, "Unexpected total memory size.", ) @@ -1775,8 +1774,9 @@ def test(num_graphs, num_generators): clear_cuda_cache() # Assert that memory stats return to baseline after cleanup - self.assertTrue( - get_memory_stats() == baseline, + self.assertEqual( + get_memory_stats(), + baseline, "Memory stats do not match baseline after cleanup.", ) @@ -1804,7 +1804,7 @@ def test_graph_capture_reset_recapture(self): g.replay() - self.assertTrue(b.sum().item() == 11000.0) + self.assertEqual(b.sum().item(), 11000.0) g.reset() @@ -1817,7 +1817,7 @@ def test_graph_capture_reset_recapture(self): torch.cuda.current_stream().wait_stream(s) g.replay() - self.assertTrue(b.sum().item() == 22000.0) + self.assertEqual(b.sum().item(), 22000.0) g.reset() del g @@ -3519,8 +3519,8 @@ def thefree(): thealloc() thefree() ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue(("thefree" in ss) == (context == "all")) - self.assertTrue(("thealloc" in ss) == (context != "state")) + self.assertEqual(("thefree" in ss), (context == "all")) + self.assertEqual(("thealloc" in ss), (context != "state")) finally: torch.cuda.memory._record_memory_history(None) @@ -3576,7 +3576,7 @@ def test_memory_plots_free_segment_stack(self): torch.cuda.memory.empty_cache() ss = json.dumps(torch.cuda.memory._snapshot()) - self.assertTrue(("empty_cache" in ss) == (context == "all")) + self.assertEqual(("empty_cache" in ss), (context == "all")) finally: torch.cuda.memory._record_memory_history(None) @@ -3599,7 +3599,7 @@ def foo(): for seg in ss: for b in seg["blocks"]: if b["requested_size"] == 311 * 411 * 4: - self.assertTrue(b["frames"][0]["name"] == "foo") + self.assertEqual(b["frames"][0]["name"], "foo") found_it = True self.assertTrue(found_it) @@ -3613,7 +3613,7 @@ def test_max_split_expandable(self): pre_reserved = torch.cuda.memory_reserved() total_allowed = 120 * mb + pre_reserved fraction_allowed = total_allowed / all_memory - assert int(fraction_allowed * all_memory) == total_allowed + self.assertEqual(int(fraction_allowed * all_memory), total_allowed) torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) def alloc(n): @@ -3644,7 +3644,7 @@ def test_garbage_collect_expandable(self): pre_reserved = torch.cuda.memory_reserved() total_allowed = 120 * mb + pre_reserved fraction_allowed = total_allowed / all_memory - assert int(fraction_allowed * all_memory) == total_allowed + self.assertEqual((fraction_allowed * all_memory), total_allowed) torch.cuda.memory.set_per_process_memory_fraction(fraction_allowed) def alloc(n): @@ -3704,11 +3704,11 @@ def power2_div(size, div_factor): pow2_div4_mem = torch.cuda.memory_stats()[key_allocated] current_requested = torch.cuda.memory_stats()[key_requested] - self.assertTrue(reg_mem - start_mem == nbytes) + self.assertEqual(reg_mem - start_mem, nbytes) if not TEST_CUDAMALLOCASYNC: # not supported with the cudaMallocAsync backend - self.assertTrue(pow2_div4_mem - reg_mem == power2_div(nbytes, 4)) - self.assertTrue(current_requested - start_requested == nbytes) + self.assertEqual(pow2_div4_mem - reg_mem, power2_div(nbytes, 4)) + self.assertEqual(current_requested - start_requested, nbytes) torch.cuda.memory._set_allocator_settings("garbage_collection_threshold:0.5") torch.cuda.memory._set_allocator_settings( @@ -3720,7 +3720,7 @@ def power2_div(size, div_factor): start_mem = torch.cuda.memory_stats()[key_allocated] z = torch.rand(nelems, device="cuda") reg_mem = torch.cuda.memory_stats()[key_allocated] - self.assertTrue(reg_mem - start_mem == nbytes) + self.assertEqual(reg_mem - start_mem, nbytes) # roundup_power2_divisions knob array syntax torch.cuda.memory.empty_cache() @@ -3733,7 +3733,7 @@ def power2_div(size, div_factor): pow2_div8_mem = torch.cuda.memory_stats()[key_allocated] if not TEST_CUDAMALLOCASYNC: # not supported with the cudaMallocAsync backend - self.assertTrue(pow2_div8_mem - start_mem == power2_div(nbytes, 8)) + self.assertEqual(pow2_div8_mem - start_mem, power2_div(nbytes, 8)) torch.cuda.memory.empty_cache() start_mem = torch.cuda.memory_stats()[key_allocated] @@ -3742,14 +3742,14 @@ def power2_div(size, div_factor): pow2_div2_mem = torch.cuda.memory_stats()[key_allocated] if not TEST_CUDAMALLOCASYNC: # not supported with the cudaMallocAsync backend - self.assertTrue(pow2_div2_mem - start_mem == power2_div(nbytes_big, 2)) + self.assertEqual(pow2_div2_mem - start_mem, power2_div(nbytes_big, 2)) torch.cuda.memory.empty_cache() torch.cuda.memory._set_allocator_settings("release_lock_on_cudamalloc:True") start_mem = torch.cuda.memory_stats()[key_allocated] w = torch.rand(nelems, device="cuda") reg_mem = torch.cuda.memory_stats()[key_allocated] - self.assertTrue(reg_mem - start_mem == nbytes) + self.assertEqual(reg_mem - start_mem, nbytes) with self.assertRaises(RuntimeError): torch.cuda.memory._set_allocator_settings("foo:1,bar:2") @@ -3878,8 +3878,8 @@ def run(): self.assertTrue("case.py" in frame_text) found = True last_action = mem["device_traces"][0][-1] - self.assertTrue(last_action["action"] == "alloc") - self.assertTrue(last_action["size"] == 311 * 411 * 4) + self.assertEqual(last_action["action"], "alloc") + self.assertEqual(last_action["size"], 311 * 411 * 4) self.assertTrue(found) finally: m.record(False, False) @@ -3918,7 +3918,7 @@ def free(): nonlocal total idx = random.randrange(0, len(mem)) v, x = mem.pop(idx) - assert torch.all(v == x) + self.assertTrue(torch.all(v == x)) total -= x.numel() choices = [alloc, free, torch.cuda.memory.empty_cache] @@ -3930,25 +3930,108 @@ def free(): finally: random.setstate(state) - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + @unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available") def test_nvml_get_handler(self): if not torch.version.hip: self.assertTrue(torch.cuda._get_pynvml_handler() is not None) else: self.assertTrue(torch.cuda._get_amdsmi_handler() is not None) - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + @unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available") def test_temperature(self): self.assertTrue(0 <= torch.cuda.temperature() <= 150) - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + @unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available") def test_power_draw(self): self.assertTrue(torch.cuda.power_draw() >= 0) - @unittest.skipIf(TEST_PYNVML, "pynvml is not available") + @unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available") def test_clock_speed(self): self.assertTrue(torch.cuda.clock_rate() >= 0) + @unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available") + @unittest.skipIf(not TEST_WITH_ROCM, "amdsmi specific test") + def test_raw_amdsmi_device_count(self): + """ + This unit test will verify if the number of GPUs shown in `amd-smi + list` is equivalent to the count returned by `_raw_device_count_amdsmi`. + This should be unaffected by visible device settings. + """ + raw_device_cnt = int( + subprocess.check_output( + "amd-smi list | grep 'GPU' | wc -l", shell=True + ).strip() + ) + self.assertEqual(torch.cuda._raw_device_count_amdsmi(), raw_device_cnt) + + @unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available") + @unittest.skipIf(not TEST_WITH_ROCM, "amdsmi specific test") + def test_raw_amdsmi_device_uuids(self): + """ + This unit test will extract a list of UUIDs for each GPU using + rocminfo information, and check whether each UUID is present in + the output from `_raw_device_uuid_amdsmi` this allows us to test + that the pytorch call is returning a correct list of UUIDs. + """ + cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid:.*GPU-//'" + uuids = ( + subprocess.check_output(cmd, shell=True, universal_newlines=True) + .strip() + .split("\n") + ) + uuids = [s.strip() for s in uuids] + raw_uuids = torch.cuda._raw_device_uuid_amdsmi() + for uuid in uuids: + matching = True + if not any(uuid in raw_id for raw_id in raw_uuids): + matching = False + self.assertEqual(True, matching) + + @unittest.skipIf(TEST_PYNVML, "pynvml/amdsmi is not available") + @unittest.skipIf(not TEST_WITH_ROCM, "amdsmi specific test") + def test_uuid_visible_devices(self): + """ + This unit test will simulate an environment where a UUID is passed + via CUDA/HIP_VISIBLE_DEVICES and ensure that the correct device count + is returned. This allows us to test that the visible device functionality + is operating as expected. + """ + test_script = """\ +import torch +import os +print(f"{torch.cuda.device_count()}") + """ + cmd = "rocminfo | grep -o 'Uuid:.*GPU-.*' | sed 's/Uuid://'" + uuids = ( + subprocess.check_output(cmd, shell=True, universal_newlines=True) + .strip() + .split("\n") + ) + uuids = [s.strip() for s in uuids] + + custom_envs = [] + for uuid in uuids: + custom_envs.append( + {"CUDA_VISIBLE_DEVICES": f"{uuid}", "HIP_VISIBLE_DEVICES": None} + ) + custom_envs.append( + {"HIP_VISIBLE_DEVICES": f"{uuid}", "CUDA_VISIBLE_DEVICES": None} + ) + + for env_config in custom_envs: + env = os.environ.copy() + for key, value in env_config.items(): + if value is None: + env.pop(key, None) + else: + env[key] = value + r = ( + subprocess.check_output([sys.executable, "-c", test_script], env=env) + .decode("ascii") + .strip() + ) + self.assertEqual("1", r) + MIN_BLOCK_SIZE = 512 SMALL_SIZE = 1048576 @@ -4221,7 +4304,7 @@ def foo(x): device = outputs[0].device.index for i in range(len(outputs)): - self.assertTrue(outputs[i].mean(dtype=torch.float) == 2) + self.assertEqual(outputs[i].mean(dtype=torch.float), 2) state = torch._C._cuda_getCheckpointState(outputs[0].device.index, pool_id) @@ -4239,13 +4322,13 @@ def foo(x): ] for i in range(len(reconstructed_tensors)): - self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 2) + self.assertEqual(reconstructed_tensors[i].mean(dtype=torch.float), 2) inp.add_(1) graph.replay() for i in range(len(reconstructed_tensors)): - self.assertTrue(reconstructed_tensors[i].mean(dtype=torch.float) == 3) + self.assertEqual(reconstructed_tensors[i].mean(dtype=torch.float), 3) self.setCheckpointPoolState( device, state, [], [reconstructed_tensors[0], reconstructed_tensors[1]] @@ -4451,18 +4534,48 @@ def test_mempool_with_allocator(self): # pool should point to the same allocator as the one passed into it self.assertEqual(allocator.allocator(), pool.allocator) + # pool's use count should be 1 at this point as MemPool object + # holds a reference + self.assertEqual(pool.use_count(), 1) + # no allocations happened yet, so called_dummy_alloc should be 0 alloc_lib = ctypes.CDLL(dummy_allocator) called_dummy_alloc = ctypes.c_int.in_dll(alloc_lib, "called_dummy_alloc") self.assertEqual(called_dummy_alloc.value, 0) + nelem_1mb = 1024 * 1024 // 4 with torch.cuda.use_mem_pool(pool): - out = torch.randn(1, device="cuda") + out_0 = torch.randn(nelem_1mb, device="cuda") + + # pool's use count should be 2 at this point as use_mem_pool + # holds a reference + self.assertEqual(pool.use_count(), 2) + + # pool's use count should be back to 1 at this point as use_mem_pool + # released its reference + self.assertEqual(pool.use_count(), 1) # called_dummy_alloc should be 123 if dummy_alloc was used to allocate # out tensor self.assertEqual(called_dummy_alloc.value, 123) + with torch.cuda.use_mem_pool(pool): + # pool should have 1 segment since we made a small allocation (1 MB) + # above and so the CUDACachingAllocator packed it into a 2 MB buffer + self.assertEqual(len(pool.snapshot()), 1) + + out_1 = torch.randn(nelem_1mb, device="cuda") + + # pool should still have 1 segment since we made another small allocation + # (1 MB) that got packed into the existing 2 MB buffer + self.assertEqual(len(pool.snapshot()), 1) + + out_2 = torch.randn(nelem_1mb, device="cuda") + + # pool now should have 2 segments since the CUDACachingAllocator had + # to make a new 2 MB buffer to accomodate out_2 + self.assertEqual(len(pool.snapshot()), 2) + def test_mempool_context(self): active_pool = torch.cuda.MemPoolContext.active_pool() diff --git a/test/test_cuda_sanitizer.py b/test/test_cuda_sanitizer.py index b6397255ecfab..d4b65a1541e97 100644 --- a/test/test_cuda_sanitizer.py +++ b/test/test_cuda_sanitizer.py @@ -9,6 +9,7 @@ import torch.cuda._sanitizer as csan from torch.cuda._sanitizer import DataPtr, EventId, StreamId from torch.testing._internal.common_utils import NoTest, run_tests, TEST_CUDA, TestCase +from torch.testing._internal.two_tensor import TwoTensor if not TEST_CUDA: @@ -23,9 +24,9 @@ def test_add(self): b = torch.randn(5, 3, device="cuda") argument_handler = csan.ArgumentHandler() - argument_handler.parse_inputs(add_func._schema, (a, b), {}) + argument_handler.parse_inputs(add_func._schema, (a, b), {}, is_factory=False) c = torch.add(a, b) - argument_handler.parse_outputs(c) + argument_handler.parse_outputs(add_func._schema, c, is_factory=False) self.assertEqual({a.data_ptr(), b.data_ptr()}, argument_handler.dataptrs_read) self.assertEqual({c.data_ptr()}, argument_handler.dataptrs_written) @@ -37,9 +38,11 @@ def test_cat(self): c = torch.rand(2, 7, 5, device="cuda") argument_handler = csan.ArgumentHandler() - argument_handler.parse_inputs(cat_func._schema, ([a, b, c], 1), {}) + argument_handler.parse_inputs( + cat_func._schema, ([a, b, c], 1), {}, is_factory=False + ) d = torch.cat((a, b, c), dim=1) - argument_handler.parse_outputs(d) + argument_handler.parse_outputs(cat_func._schema, d, is_factory=False) self.assertEqual( {a.data_ptr(), b.data_ptr(), c.data_ptr()}, argument_handler.dataptrs_read @@ -51,22 +54,25 @@ def test_split(self): a = torch.arange(10, device="cuda").reshape(5, 2) argument_handler = csan.ArgumentHandler() - argument_handler.parse_inputs(split_func._schema, (a, 2), {}) + argument_handler.parse_inputs(split_func._schema, (a, 2), {}, is_factory=False) out = torch.split(a, 2) - argument_handler.parse_outputs(out) + argument_handler.parse_outputs(split_func._schema, out, is_factory=False) outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} - self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) - self.assertEqual(outputs, argument_handler.dataptrs_written) + # Split is a view op, no data is read or written! + self.assertEqual(len(argument_handler.dataptrs_read), 0) + self.assertEqual(len(argument_handler.dataptrs_written), 0) def test_inplace(self): add_inplace_func = torch.ops.aten.add_.Tensor a = torch.rand(4, 2, device="cuda") argument_handler = csan.ArgumentHandler() - argument_handler.parse_inputs(add_inplace_func._schema, (a, 5), {}) + argument_handler.parse_inputs( + add_inplace_func._schema, (a, 5), {}, is_factory=False + ) a.add_(5) - argument_handler.parse_outputs(a) + argument_handler.parse_outputs(add_inplace_func._schema, a, is_factory=False) self.assertEqual(set(), argument_handler.dataptrs_read) self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_written) @@ -77,9 +83,11 @@ def test_out(self): b = torch.empty(8, device="cuda") argument_handler = csan.ArgumentHandler() - argument_handler.parse_inputs(mul_out_func._schema, (a, 3), {"out": b}) + argument_handler.parse_inputs( + mul_out_func._schema, (a, 3), {"out": b}, is_factory=False + ) torch.mul(a, 3, out=b) - argument_handler.parse_outputs(b) + argument_handler.parse_outputs(mul_out_func._schema, b, is_factory=False) self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) self.assertEqual({b.data_ptr()}, argument_handler.dataptrs_written) @@ -89,9 +97,11 @@ def test_nonzero(self): a = torch.ones(5, 3, 2, device="cuda") argument_handler = csan.ArgumentHandler() - argument_handler.parse_inputs(nonzero_func._schema, (a,), {"as_tuple": True}) + argument_handler.parse_inputs( + nonzero_func._schema, (a,), {"as_tuple": True}, is_factory=False + ) out = torch.nonzero(a, as_tuple=True) - argument_handler.parse_outputs(out) + argument_handler.parse_outputs(nonzero_func._schema, out, is_factory=False) outputs = {out[0].data_ptr(), out[1].data_ptr(), out[2].data_ptr()} self.assertEqual({a.data_ptr()}, argument_handler.dataptrs_read) @@ -103,9 +113,11 @@ def test_tensor_names(self): M = torch.zeros(3, 3, device="cuda") argument_handler = csan.ArgumentHandler() - argument_handler.parse_inputs(addr_func._schema, (M, vec, vec), {}) + argument_handler.parse_inputs( + addr_func._schema, (M, vec, vec), {}, is_factory=False + ) out = torch.addr(M, vec, vec) - argument_handler.parse_outputs(out) + argument_handler.parse_outputs(addr_func._schema, out, is_factory=False) self.assertEqual( argument_handler.tensor_aliases, @@ -491,6 +503,22 @@ def test_error_message(self): ), ) + def test_subclass(self): + class MyT(torch.Tensor): + def __new__(cls, data): + new_data = data.clone() + return new_data.as_subclass(cls) + + try: + csan.enable_cuda_sanitizer() + + # These two tests ensure that subclass creation + # happens smoothly under the mode used by csan + t = TwoTensor(torch.rand(2), torch.rand(2)) + t = MyT(torch.rand(2)) + finally: + csan.cuda_sanitizer.disable() + if __name__ == "__main__": run_tests() diff --git a/test/test_custom_ops.py b/test/test_custom_ops.py index cb6ff55f3f471..f0ee8b65be6c5 100644 --- a/test/test_custom_ops.py +++ b/test/test_custom_ops.py @@ -33,6 +33,7 @@ IS_WINDOWS, parametrize, run_tests, + scoped_load_inline, skipIfTorchDynamo, subtest, TestCase, @@ -2088,7 +2089,8 @@ def test_impl_device_invalid(self): with self.assertRaisesRegex(RuntimeError, "Expected one of cpu, cuda"): torch.library.impl("blah::blah", "somethingsomething") - def test_autograd_function_backed_op(self): + @scoped_load_inline + def test_autograd_function_backed_op(self, load_inline): cpp_source = """ struct CustomOpAutogradFunction : public torch::autograd::Function { static constexpr bool is_traceable = true; @@ -2110,13 +2112,13 @@ def test_autograd_function_backed_op(self): return CustomOpAutogradFunction::apply(x); } -TORCH_LIBRARY(mylib, m) { +TORCH_LIBRARY(test_autograd_function_backed_op, m) { m.def("custom_op_backed_by_autograd_fn", custom_op_backed_by_autograd_fn); } """ - module = torch.utils.cpp_extension.load_inline( - name="mylib", + module = load_inline( + name="test_autograd_function_backed_op", cpp_sources=cpp_source, functions="custom_op_backed_by_autograd_fn", verbose=True, @@ -2124,7 +2126,11 @@ def test_autograd_function_backed_op(self): x = torch.ones(2, 2, requires_grad=True) temp = x.clone().detach() - out = torch.ops.mylib.custom_op_backed_by_autograd_fn(x) + out = ( + torch.ops.test_autograd_function_backed_op.custom_op_backed_by_autograd_fn( + x + ) + ) loss = out.sum() loss.backward() self.assertEqual(x.grad, temp) diff --git a/test/test_decomp.py b/test/test_decomp.py index 7ab3859454ff6..86f3531bb5ab4 100644 --- a/test/test_decomp.py +++ b/test/test_decomp.py @@ -10,8 +10,9 @@ import torch._inductor.decomposition import torch.autograd from torch import Tensor -from torch._decomp import _is_cia_op, core_aten_decompositions, decomposition_table +from torch._decomp import core_aten_decompositions, decomposition_table from torch._dispatch.python import enable_python_dispatcher +from torch._export.utils import _is_cia_op from torch._ops import DispatchKey from torch.testing import make_tensor from torch.testing._internal.common_cuda import tf32_off @@ -1230,9 +1231,7 @@ def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool: try: # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions - return not op.has_kernel_for_dispatch_key( - DispatchKey.CompositeImplicitAutograd - ) + return not _is_cia_op(op) except RuntimeError as e: # has_key fails for some jit-registered ops, which shouldn't be # relevant here anyway diff --git a/test/test_dlpack.py b/test/test_dlpack.py index a9036be160b0a..fe1107ac850fc 100644 --- a/test/test_dlpack.py +++ b/test/test_dlpack.py @@ -15,6 +15,23 @@ from torch.utils.dlpack import from_dlpack, to_dlpack +# Wraps a tensor, exposing only DLPack methods: +# - __dlpack__ +# - __dlpack_device__ +# +# This is used for guaranteeing we are going through the DLPack method, and not +# something else, e.g.: CUDA array interface, buffer protocol, etc. +class TensorDLPackWrapper: + def __init__(self, tensor): + self.tensor = tensor + + def __dlpack__(self, *args, **kwargs): + return self.tensor.__dlpack__(*args, **kwargs) + + def __dlpack_device__(self, *args, **kwargs): + return self.tensor.__dlpack_device__(*args, **kwargs) + + class TestTorchDlPack(TestCase): exact_dtype = True @@ -251,6 +268,19 @@ def test_dlpack_normalize_strides(self): # gh-83069, make sure __dlpack__ normalizes strides self.assertEqual(z.stride(), (1,)) + @skipMeta + @onlyNativeDeviceTypes + def test_automatically_select_in_creation(self, device): + # Create a new tensor, and wrap it using TensorDLPackWrapper. + tensor = torch.rand(10) + wrap = TensorDLPackWrapper(tensor) + # Create a new tensor from the wrapper. + # This should identify that the wrapper class provides the DLPack methods + # and use them for creating the new tensor, instead of iterating element + # by element. + new_tensor = torch.tensor(wrap) + self.assertEqual(tensor, new_tensor) + instantiate_device_type_tests(TestTorchDlPack, globals()) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 01408586c4d1b..58f240e507544 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -832,6 +832,15 @@ def test_non_overlapping_and_dense_unbacked(self): ) ) + def test_sym_max_multi_max_simplify(self): + shape_env = ShapeEnv() + u0 = shape_env.create_unbacked_symint() + self.assertTrue( + statically_known_true( + torch.sym_max(1, torch.sym_max(257, u0)) == torch.sym_max(257, u0) + ) + ) + def test_numpy_sym_max(self): self.assertEqual(torch.sym_max(np.int64(10), 12), 12) self.assertEqual(torch.sym_max(np.int64(12), 10), 12) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index bf756f7b30fcd..621c59edc4960 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -25,6 +25,7 @@ from torch._guards import tracing, TracingContext from torch._higher_order_ops.scan import scan from torch._subclasses.fake_tensor import ( + _CacheKeyState, DynamicOutputShapeException, extract_tensor_metadata, FakeTensor, @@ -32,7 +33,6 @@ FakeTensorMode, unset_fake_temporarily, UnsupportedOperatorException, - _CacheKeyState ) from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( @@ -61,10 +61,11 @@ TemporaryFileName, TEST_WITH_TORCHDYNAMO, TestCase, + xfailIfTorchDynamo, ) +from torch.testing._internal.custom_op_db import custom_op_db from torch.testing._internal.inductor_utils import GPU_TYPE -from torch.testing._internal.custom_op_db import custom_op_db from torch.testing._internal.jit_utils import RUN_CUDA from torch.utils._mode_utils import no_dispatch from torch.utils._python_dispatch import TorchDispatchMode @@ -931,13 +932,11 @@ def add(x, y): with torch._subclasses.fake_tensor.FakeTensorMode(): x = torch.randn((3, 5, 7), device="cpu") - init = torch.randn((3, 1, 7), device="cpu") + init = torch.randn((3, 7), device="cpu") r = scan(add, init, x, dim=1, reverse=reverse) self.assertIsInstance(r[0], FakeTensor) self.assertIsInstance(r[1], FakeTensor) - self.assertEqual(r[0].size(), init.size()) - self.assertEqual(r[1].size(), x.size()) instantiate_parametrized_tests(FakeTensorTest) @@ -1102,7 +1101,7 @@ def test_separate_tensor_storages_view(self): y_conv = converter.from_real_tensor(mode, y) self.assertEqual(torch._C._storage_id(x_conv), torch._C._storage_id(y_conv)) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") + @xfailIfTorchDynamo def test_separate_tensor_storages_non_view(self): x = torch.rand(2, 2, 2) y = torch.rand(4, 2) @@ -1122,7 +1121,6 @@ def test_separate_tensor_storages_non_view(self): self.assertEqual(len(converter.tensor_memo), 0) self.assertEqual(len(converter.meta_converter.storage_memo), 0) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") def test_dead_weak_ref(self): x = torch.rand(2, 2, 2) y = x[0] @@ -1135,7 +1133,7 @@ def test_dead_weak_ref(self): y_conv = converter.from_real_tensor(mode, y) self.assertIs(x_conv_storage, y_conv.untyped_storage()) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") + @xfailIfTorchDynamo def test_dead_key(self): x = torch.rand(2, 2, 2) mode = FakeTensorMode() @@ -1177,7 +1175,7 @@ def test_separate_mode_error(self): y = torch.empty(2, 2, device="cpu") self.assertRaises(Exception, lambda: x, y) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") + @xfailIfTorchDynamo def test_no_ref_cycle(self): x = torch.rand([4]) mode = FakeTensorMode() @@ -1925,6 +1923,29 @@ def test_inference_mode(self): extract_tensor_metadata(res4), ) + def test_cache_tuple_outputs(self): + """ + Test to check that ops with tuple outputs work. + """ + with FakeTensorMode(): + x = torch.randn(6, 4) + y = torch.randn(6, 4) + + FakeTensorMode.cache_clear() + self.assertHitsMisses(0, 0) + + ref = torch.split(x, 2) + self.assertHitsMisses(0, 1) + + res = torch.split(y, 2) + self.assertHitsMisses(1, 1) + self.assertEqual(len(ref), len(res)) + for a, b in zip(ref, res): + self.assertEqual( + extract_tensor_metadata(a), + extract_tensor_metadata(b), + ) + if __name__ == "__main__": run_tests() diff --git a/test/test_indexing.py b/test/test_indexing.py index 98f503765c1cb..77456affe507d 100644 --- a/test/test_indexing.py +++ b/test/test_indexing.py @@ -980,6 +980,37 @@ def test_index_put_accumulate_expanded_values(self, device): out_cpu = t.index_put_(indices, values2d, accumulate=True) self.assertEqual(out_cuda.cpu(), out_cpu) + @onlyCUDA + def test_index_put_large_indices(self, device): + def generate_indices(num_indices: int, index_range: int): + indices = [] + for _ in range(num_indices): + x = random.randint(0, index_range - 1) + indices.append(x) + return torch.tensor(indices) + + num_indices = 401988 + max_index_range = 2000 + results = [] + target_index_range = [16, 256, 2000] + for generated_index_range in target_index_range: + # create CPU tensors + a_tensor_size = (max_index_range, 256) + a = torch.randn(a_tensor_size, dtype=torch.bfloat16) + b = generate_indices( + num_indices=num_indices, index_range=generated_index_range + ) + c_tensor_size = (num_indices, 256) + c = torch.randn(c_tensor_size, dtype=torch.bfloat16) + # create GPU copies + a_dev = a.to(device) + b_dev = b.to(device) + c_dev = c.to(device) + # run + a.index_put_(indices=[b], values=c, accumulate=True) + a_dev.index_put_(indices=[b_dev], values=c_dev, accumulate=True) + self.assertEqual(a_dev.cpu(), a) + @onlyCUDA def test_index_put_accumulate_non_contiguous(self, device): t = torch.zeros((5, 2, 2)) diff --git a/test/test_linalg.py b/test/test_linalg.py index 430cf4ec85f5a..e9a7aa0fa6c76 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -4573,6 +4573,69 @@ def test_matmul_small_brute_force_tunableop(self, device, dtype): # disables TunableOp torch.cuda.tunable.enable(False) + @onlyCUDA + @dtypes(torch.half) + def test_matmul_offline_tunableop(self, device, dtype): + import os + os.putenv('PYTORCH_TUNABLEOP_ROTATING_BUFFER_SIZE', '0') + + # Pointing to temp files. The test cannot remove them on Windows because + # they are in use and locked + import tempfile + tmp_dir = tempfile.mkdtemp() + os.putenv("PYTORCH_TUNABLEOP_UNTUNED_FILENAME", os.path.join(tmp_dir, "tunableop_untuned.csv")) + os.putenv("PYTORCH_TUNABLEOP_FILENAME", os.path.join(tmp_dir, "tunableop_results.csv")) + + torch.cuda.tunable.enable() + # record GEMM + torch.cuda.tunable.tuning_enable(False) + torch.cuda.tunable.record_untuned_enable(True) + assert torch.cuda.tunable.record_untuned_is_enabled() + + make_arg = partial(make_tensor, device=device, dtype=dtype) + for (size_x, size_y), nctg_x, nctg_y in product(self.gen_sizes_matmul(1), (True, False), (True, False)): + x = make_arg(size_x, noncontiguous=nctg_x) + y = make_arg(size_y, noncontiguous=nctg_y) + self.check_single_matmul(x, y) + + assert torch.cuda.tunable.is_enabled() + assert torch.cuda.tunable.tuning_is_enabled() is False + ordinal = torch.cuda.current_device() + untuned_filename = os.path.join(tmp_dir, f"tunableop_untuned{ordinal}.csv") + assert os.path.exists(untuned_filename) + + # tuning the untuned GEMMs in file + torch.cuda.tunable.tuning_enable(True) + torch.cuda.tunable.record_untuned_enable(False) + + # set these to single iterations to keep it short but still exercise the code + torch.cuda.tunable.set_max_tuning_duration(1) + torch.cuda.tunable.set_max_tuning_iterations(1) + + torch.cuda.tunable.tune_gemm_in_file(untuned_filename) + assert len(torch.cuda.tunable.get_validators()) > 0 + assert len(torch.cuda.tunable.get_results()) > 0 + assert torch.cuda.tunable.write_file() + + result_filename = os.path.join(tmp_dir, f"tunableop_results{ordinal}.csv") + assert os.path.exists(result_filename) + + # remove the files created above to avoid error 'Build left local git repository checkout dirty', ignore errors + for filename in [untuned_filename, result_filename]: + try: + os.remove(filename) + # NB: The file is locked on Windows + except (FileNotFoundError, PermissionError): + pass + + # disables TunableOp, no file will be written, restore to default values + torch.cuda.tunable.enable(False) + torch.cuda.tunable.record_untuned_enable(False) + torch.cuda.tunable.set_max_tuning_duration(30) + torch.cuda.tunable.set_max_tuning_iterations(100) + assert torch.cuda.tunable.is_enabled() is False, "TunableOp should be off after resetting" + assert torch.cuda.tunable.get_max_tuning_iterations() == 100 + @onlyCUDA @skipCUDAIfNotRocm @dtypes(torch.float) @@ -4817,6 +4880,71 @@ def test_matmul_check_entries_tunableop(self, device, dtype): except FileNotFoundError: pass + @onlyCUDA + @dtypes(torch.float) + def test_disable_tuning_tunableop(self, device, dtype): + # Test that the Python API for disabling tuning stops + # additional tunings even when TunableOp is enabled. + # In other words, test that: + # PYTORCH_TUNABLEOP_ENABLED=1 + # PYTORCH_TUNABLEOP_TUNING=0 + # is no longer tuning GEMMs. + + try: + set_tunableop_defaults() + torch.cuda.tunable.enable() + # set these to single iterations to keep it short but still exercise the code + torch.cuda.tunable.set_max_tuning_iterations(1) + + # Reference number of results + ref_num_results = len(torch.cuda.tunable.get_results()) + + # Tune one GEMMs to make sure TunableOp is enabled + M = 3 + N = 3 + K = 3 + A = torch.randn(N, K, device=device, dtype=dtype) + B = torch.randn(K, M, device=device, dtype=dtype) + C = torch.matmul(A, B) + + # This stores total number of cummulative results + total_num_results = len(torch.cuda.tunable.get_results()) + + # Take the difference to calculate the number of results from + # this test. There should be one additional tuned GEMM + self.assertEqual((total_num_results - ref_num_results), 1) + + # New total number of results becomes new reference result + ref_num_results = total_num_results + + # Now disable further tuning, while keeping TunableOp Enabled + torch.cuda.tunable.tuning_enable(False) + + # Try to tune one more GEMM + M = 3 + N = 3 + K = 4 + A = torch.randn(N, K, device=device, dtype=dtype) + B = torch.randn(K, M, device=device, dtype=dtype) + C = torch.matmul(A, B) + + # Take the difference to calculate the number of results from + # this test. There should be no change in the number of results + # since tuning is disabe. + self.assertEqual((total_num_results - ref_num_results), 0) + + finally: + # disable TunableOp + torch.cuda.tunable.enable(False) + + # clean up, remove any file that was generated + try: + import os + filename = torch.cuda.tunable.get_filename() + os.remove(filename) + except FileNotFoundError: + pass + @dtypes(torch.float, torch.complex64) def test_matmul_out_kernel_errors_with_autograd(self, device, dtype): a = torch.empty((256, 512), device=device, dtype=dtype, requires_grad=True).unsqueeze(0) @@ -6307,7 +6435,7 @@ def genf_int_float(x, y, use_transpose, non_contig_type): x, y = y, x if non_contig_type != 0: y = y * 2 - x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device) + x_int8 = torch.randint(-128, 127, (x, y), dtype=torch.int8, device=device) x_float = x_int8.to(torch.float32) if non_contig_type == 1: x_int8 = x_int8[:, : y // 2] @@ -8359,6 +8487,22 @@ def test_preferred_blas_library(self): self.assertEqual(out1, out2) self.assertEqual(out_ref, out2.cpu()) + @skipCUDAIfNotRocm + @unittest.skipIf(not blaslt_supported_device(), "blasLt not supported on current device") + @setBlasBackendsToDefaultFinally + def test_ck_blas_library(self): + m1 = torch.randint(2, 5, (7168, 8192), device='cuda', dtype=torch.float) + m2 = torch.randint(2, 5, (1280, 8192), device='cuda', dtype=torch.float) + + torch.backends.cuda.preferred_blas_library('ck') + ck_out = torch.nn.functional.linear(m1, m2) + + cpu_out = torch.nn.functional.linear(m1.cpu(), m2.cpu()) + + self.assertEqual(ck_out, cpu_out) + + + def test_permute_matmul(self): a = torch.ones([2, 5, 24, 24]) b = torch.ones([3, 2, 5, 24, 24]) diff --git a/test/test_maskedtensor.py b/test/test_maskedtensor.py index 580f1301d4060..6d4a851ca8659 100644 --- a/test/test_maskedtensor.py +++ b/test/test_maskedtensor.py @@ -78,7 +78,6 @@ def _compare_forward_backward(data, mask, fn): _compare_mt_t(masked_res, tensor_res) _compare_mt_t(mt.grad, t.grad, atol=1e-06) - def _create_random_mask(shape, device): return make_tensor(shape, device=device, dtype=torch.bool) diff --git a/test/test_meta.py b/test/test_meta.py index 3a77d86b128a9..106355d435c37 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -17,13 +17,13 @@ from torch.testing._internal.common_utils import ( TestCase, skipIfCrossRef, - skipIfTorchDynamo, suppress_warnings, TEST_WITH_ASAN, TEST_WITH_TORCHDYNAMO, run_tests, dtype_abbrs, - parametrize + parametrize, + xfailIfTorchDynamo, ) from torch.testing._internal.common_device_type import ( ops, @@ -294,7 +294,7 @@ def test_inplace_set_storage(self): meta.set_(storage, 0, (), ()) self.assertEqual(storage.size(), ssize) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") + @xfailIfTorchDynamo def test_weakref(self): x = torch.randn(4, 4, 4) m = MetaConverter() @@ -334,7 +334,7 @@ def test_weakref(self): self.assertEqual(len(m.tensor_memo), 0) self.assertEqual(len(m.storage_memo), 0) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") + @xfailIfTorchDynamo def test_tensor_outlives_converter(self): m = MetaConverter() ref = weakref.ref(m) @@ -1636,6 +1636,42 @@ def test_embedding_bag_dense_backward(self, mode): ) self.assertEqual(grad_weight.to('meta'), meta_grad_weight) + def test_segment_reduce_backward(self): + grad = torch.ones(16, dtype=torch.float) + output = torch.ones(16, dtype=torch.float) + data = torch.ones(16, dtype=torch.float) + reduce_str = 'max' + lengths = torch.ones(16, dtype=torch.long) + + out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths) + out_meta = torch.ops.aten._segment_reduce_backward( + grad.to(device='meta'), + output.to(device='meta'), + data.to(device='meta'), + reduce_str, + lengths=lengths.to(device='meta'), + ) + self.assertEqual(out.shape, out_meta.shape) + self.assertEqual(out.stride(), out_meta.stride()) + self.assertEqual(out.dtype, out_meta.dtype) + self.assertEqual(out.layout, out_meta.layout) + + # noncontiguous + grad = torch.ones(16, 2, dtype=torch.float)[:, 1] + data = torch.ones(16, 2, dtype=torch.float)[:, 1] + out = torch.ops.aten._segment_reduce_backward(grad, output, data, reduce_str, lengths=lengths) + out_meta = torch.ops.aten._segment_reduce_backward( + grad.to(device='meta'), + output.to(device='meta'), + data.to(device='meta'), + reduce_str, + lengths=lengths.to(device='meta'), + ) + self.assertEqual(out.shape, out_meta.shape) + self.assertEqual(out.stride(), out_meta.stride()) + self.assertEqual(out.dtype, out_meta.dtype) + self.assertEqual(out.layout, out_meta.layout) + def test_embedding_bag_dense_backward_per_sample_weights(self): weight = torch.randn(4, 3, requires_grad=True) indices = torch.tensor([1, 0, 2, 1, 3]) diff --git a/test/test_mps.py b/test/test_mps.py index 295f3cfbb4945..4540e154ccfa9 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -95,6 +95,7 @@ def mps_ops_grad_modifier(ops): 'index_fill': [torch.float16, torch.float32], # missing `aten::_unique`. 'linalg.lu_factor': [torch.float16, torch.float32], # missing `aten::lu_unpack`. 'aminmax': [torch.float32, torch.float16], + 'i0': None, # missing `aten::i1`. # Correctness issues 'atanh': [torch.float32], @@ -199,6 +200,8 @@ def mps_ops_grad_modifier(ops): # Exception: Caused by sample input at index 3 on MPS 'nn.functional.conv3d': [torch.float32], + + } def addDecorator(op, d) -> None: @@ -273,7 +276,6 @@ def mps_ops_modifier(ops): 'empty', 'empty_permuted', 'empty_strided', - 'eye', 'exp', 'expand', 'expand_as', @@ -292,10 +294,6 @@ def mps_ops_modifier(ops): 'kron', 'linalg.diagonal', 'linalg.svd', - 'linspace', - 'logspace', - 'linspacetensor_overload', - 'logspacetensor_overload', 'mH', 'mT', 'masked_scatter', @@ -321,6 +319,7 @@ def mps_ops_modifier(ops): 'ones', 'outer', 'permute', + 'permute_copy', 'positive', 'randn', 'ravel', @@ -351,6 +350,7 @@ def mps_ops_modifier(ops): 'transpose_copy', 'T', 'unbind', + 'unbind_copy', 'unflatten', 'unfold', 'unfold_copy', @@ -407,6 +407,7 @@ def mps_ops_modifier(ops): 'equal', 'exp2', 'expm1', + 'eye', 'fft.fft', 'fft.fft2', 'fft.fftn', @@ -435,6 +436,8 @@ def mps_ops_modifier(ops): 'ldexp', 'linalg.multi_dot', 'linalg.pinv', + 'linspace', + 'linspacetensor_overload', 'log10', 'log1p', 'log2', @@ -667,6 +670,8 @@ def mps_ops_modifier(ops): UNIMPLEMENTED_XFAILLIST = { # Failures due to lack of op implementation on MPS backend 'login': None, + 'logspace': None, + 'logspacetensor_overload': None, 'linalg.eig': None, 'linalg.eigvals': None, 'put': None, @@ -689,7 +694,6 @@ def mps_ops_modifier(ops): 'geqrf': None, 'nn.functional.grid_sample': None, # Unsupported Border padding mode 'heaviside': None, - 'i0': None, 'igamma': None, 'igammac': None, 'index_copy': None, @@ -939,8 +943,10 @@ def mps_ops_modifier(ops): 'multinomial': [torch.float16, torch.float32, torch.bfloat16], # random results 'uniform': [torch.float16, torch.float32, torch.bfloat16], 'rand_like': [torch.float16, torch.float32, torch.bfloat16], + 'randint': None, 'randint_like': None, - 'randn_like': [torch.float16, torch.float32, torch.bfloat16], + 'randn': None, + 'randn_like': None, 'bernoulli': [torch.float16, torch.float32, torch.bfloat16], 'exponential': [torch.float16, torch.float32, torch.bfloat16], 'nn.functional.feature_alpha_dropoutwith_train': [torch.float16, torch.float32, torch.bfloat16], @@ -985,6 +991,9 @@ def mps_ops_modifier(ops): # Failures due to lack of implementation of downstream functions on MPS backend # TODO: remove these once downstream function 'aten::_linalg_svd.U' have been implemented 'linalg.matrix_rank': None, + + # Exception: Caused by `torch.arange(-8.001, -4.0, dtype=torch.uint8, device="mps")` + 'arange': [torch.uint8], } EMPTY_OPS_SKIPLIST = { @@ -7376,7 +7385,7 @@ def helper(value, dim, index, idx_dtype=torch.int32): def test_embedding_dense_backward(self): def helper(n, d, m, idx): embeddingMPS = nn.Embedding(n, d, max_norm=True, device='mps') - emedding_weight = embeddingMPS.weight.detach().cpu() + embedding_weight = embeddingMPS.weight.detach().cpu() W_MPS = torch.randn((m, d), requires_grad=True, device='mps') idx_MPS = torch.tensor(idx, device='mps') a_MPS = embeddingMPS.weight.clone() @ W_MPS.t() # weight must be cloned for this to be differentiable @@ -7387,7 +7396,7 @@ def helper(n, d, m, idx): loss_MPS = out_MPS.sigmoid().prod() loss_MPS.backward() - embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=emedding_weight) + embeddingCPU = nn.Embedding(n, d, max_norm=True, _weight=embedding_weight) W_CPU = W_MPS.to('cpu') idx_CPU = torch.tensor(idx) a_CPU = embeddingCPU.weight.clone() @ W_CPU.t() # weight must be cloned for this to be differentiable @@ -10992,6 +11001,12 @@ def test_nonzero_multi_threading(self): t1.start() t2.start() + def test_sliced_view_cast(self): + # This used to crash on MacOS Sequoia + # See https://github.com/pytorch/pytorch/issues/137800 + x = torch.rand(16, 16, device='mps', dtype=torch.float16) + y = x[:, 0:2].view(torch.float32) + 1 + def test_masked_select(self): x = torch.randn(3, 4) x_mps = x.to("mps") @@ -12026,6 +12041,16 @@ def test_serialization_map_location(self): MPS_GRAD_DTYPES = [torch.float32, torch.float16] +def transform_opinfo_sample_to_mps(sample): + """Transforms opinfo.core.SampleInput from CPU to MPS""" + mps_sample = sample.transform( + lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x) + + # Transform kwargs `device="cpu"` to `device="mps"` + if mps_sample.kwargs.get("device", "") == "cpu": + mps_sample.kwargs["device"] = "mps" + return mps_sample + class TestConsistency(TestCaseMPS): # TODO: This is only used while some ops are being added. # This list should contain all ops and dtypes eventually @@ -12053,6 +12078,10 @@ class TestConsistency(TestCaseMPS): 'nn.functional.triplet_margin_loss', 'nn.functional.triplet_margin_with_distance_loss', 'nn.functional.batch_norm', + # NOTE: nn.functional.group_norm is here because 1 ULP difference in the mean + # output from the forward pass (tolerable) blew up into 8 ULP difference from + # the backward pass, and MPS uses fp16 accumulation anyway. + 'nn.functional.group_norm', 'nn.functional.instance_norm', 'round', 'xlogy', 'addcmul', 'nn.functional.cross_entropy', @@ -12081,6 +12110,7 @@ class TestConsistency(TestCaseMPS): 'nn.functional.upsample_bilinear', 'nn.functional.upsample_nearest', 'norm', 'masked.normalize', + 'arange', 'linspace', } FP32_LOW_PRECISION_LIST = { @@ -12123,6 +12153,9 @@ def _compute_tolerances(self, op, dtype): # TODO: Investigate why this is needed # See https://github.com/pytorch/pytorch/issues/120237 return (3e-5, 3e-5) + # TODO: Rounding is broken for linspace, see https://github.com/pytorch/pytorch/issues/137635 + if op.name == 'linspace' and dtype in [torch.int8, torch.uint8, torch.int32, torch.int16, torch.int64]: + return (1.0, 0.0) return (None, None) # Used for accept mode only @@ -12147,8 +12180,7 @@ def get_samples(): # # Forward check # - mps_sample = cpu_sample.transform( - lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x) + mps_sample = transform_opinfo_sample_to_mps(cpu_sample) cpu_args = [cpu_sample.input] + list(cpu_sample.args) cpu_kwargs = cpu_sample.kwargs @@ -12189,8 +12221,7 @@ def get_samples(): # Forward check # forward_failed = False - mps_sample = cpu_sample.transform( - lambda x: x.detach().to("mps").requires_grad_(x.requires_grad) if isinstance(x, torch.Tensor) else x) + mps_sample = transform_opinfo_sample_to_mps(cpu_sample) cpu_args = [cpu_sample.input] + list(cpu_sample.args) cpu_kwargs = cpu_sample.kwargs diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py index 1b4a84a484dce..4f724c93069c2 100644 --- a/test/test_nestedtensor.py +++ b/test/test_nestedtensor.py @@ -3888,6 +3888,21 @@ def grad_test_func(a, b, c): gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) + def test_binary_pointwise_with_nested_int_second_arg(self, device): + # See https://github.com/pytorch/pytorch/issues/138496 + nt = random_nt_from_dims( + [3, None, 5], + device=device, + dtype=torch.float32, + layout=torch.jagged, + ) + + with self.assertRaisesRegex(RuntimeError, "invalid argument"): + nt * nt.size(1) + + with self.assertRaisesRegex(RuntimeError, "invalid argument"): + nt + nt.size(1) + def test_split(self, device): a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) @@ -4231,6 +4246,45 @@ def test_threshold_backward(self, device): self.assertEqual(res_dense, res_nt.values()) + @onlyCUDA + @dtypes(torch.float32) + def test_record_stream(self, device, dtype): + def _create_nt(): + values = torch.ones(1024, 4 * 1024, device="cuda") + offsets = torch.tensor([0, 500, 1024], device="cuda", dtype=torch.int64) + lengths = offsets.diff() + nt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths) + data_ptrs = { + nt._values.data_ptr(), + nt._offsets.data_ptr(), + nt._lengths.data_ptr(), + } + return nt, data_ptrs + + def fn(record_stream): + nt, data_ptrs = _create_nt() + s = torch.cuda.Stream() + + with torch.cuda.stream(s): + # emulate doing something long via sleep + per_ms = 2e7 + torch.cuda._sleep(int(per_ms * 100)) + if record_stream: + nt.record_stream(s) + return data_ptrs + + # expect memory reuse when record_stream() is not run + data_ptrs = fn(record_stream=False) + nt, nt_data_ptrs = _create_nt() + self.assertEqual(data_ptrs, nt_data_ptrs) + del nt + torch.cuda.synchronize() + + # expect memory to be preserved (no reuse) when record_stream() is run + data_ptrs = fn(record_stream=True) + nt, nt_data_ptrs = _create_nt() + self.assertEqual(len(data_ptrs.intersection(nt_data_ptrs)), 0) + @dtypes(torch.float32) @parametrize( "func", @@ -7055,6 +7109,36 @@ def test_noncontiguous_to(self, device, dtype, contiguity): if nt._lengths is not None: self.assertEqual(nt3._lengths.device, other_device) + @dtypes(torch.float32) + def test_autograd_function_with_None_grad(self, device, dtype): + class MyFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, inp): + ctx.save_for_backward(inp) + out1 = inp + 1 + out2 = inp * 2 + return out1, out2 + + @staticmethod + def backward(ctx, grad_out1, grad_out2): + (inp,) = ctx.saved_tensors + return grad_out1 + grad_out2 + + f = MyFunction.apply + nt = random_nt_from_dims( + [5, None, 10], + device=device, + dtype=dtype, + layout=torch.jagged, + requires_grad=True, + ) + + # Only use one of the autograd.Function outputs downstream so that the grad + # for the other output is None. We're testing that the engine can allocate + # correctly-shaped (NJT) zeros for the grad of the other output in this case. + (out1, _) = f(nt) + out1.backward(torch.ones_like(out1)) + @dtypes(torch.float64, torch.float32, torch.half) def test_jagged_padded_dense_conversion_kernels(self, device, dtype): values = torch.randn(10, 5, device=device, dtype=dtype) diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py index 0ad99a58ed8ad..00f3a75ddce7d 100644 --- a/test/test_numpy_interop.py +++ b/test/test_numpy_interop.py @@ -412,7 +412,7 @@ def test_numpy_array_interface(self, device): self.assertEqual(asarray.dtype, dtype) # Only concrete class can be given where "Type[number[_64Bit]]" is expected if np.dtype(dtype).kind == "u": # type: ignore[misc] - wrapped_x = np.array([1, -2, 3, -4], dtype=dtype) + wrapped_x = np.array([1, -2, 3, -4]).astype(dtype) for i in range(len(x)): self.assertEqual(asarray[i], wrapped_x[i]) else: diff --git a/test/test_openmp.py b/test/test_openmp.py index 473a687925762..95a2bd0fdc52c 100644 --- a/test/test_openmp.py +++ b/test/test_openmp.py @@ -4,7 +4,7 @@ import unittest import torch -from torch.testing._internal.common_utils import run_tests, TEST_WITH_ASAN, TestCase +from torch.testing._internal.common_utils import run_tests, TestCase try: @@ -27,7 +27,6 @@ def forward(self, x): @unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") -@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN") class TestOpenMP_ParallelFor(TestCase): batch = 20 channels = 1 diff --git a/test/test_optim.py b/test/test_optim.py index 30b489f02fa6f..046b8728e3c00 100644 --- a/test/test_optim.py +++ b/test/test_optim.py @@ -1341,8 +1341,12 @@ def test_optimizer_can_be_printed(self, device, dtype, optim_info): optimizer = optim_cls(params, **optim_input.kwargs) optimizer.__repr__() + @parametrize("is_named_optim0", [True, False]) + @parametrize("is_named_optim1", [True, False]) @optims(optim_db, dtypes=[torch.float32]) - def test_state_dict_deterministic(self, device, dtype, optim_info): + def test_state_dict_deterministic( + self, device, dtype, optim_info, is_named_optim0, is_named_optim1 + ): optim_cls = optim_info.optim_cls # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 @@ -1356,6 +1360,17 @@ def test_state_dict_deterministic(self, device, dtype, optim_info): input = torch.randn(3, requires_grad=True, device=device, dtype=dtype) params = [weight, bias] + def make_named_param(param, is_named): + if not is_named: + return param + return [(f"name{i}", p) for i, p in enumerate(param)] + + def without_param_names(state_dict): + new_state_dict = deepcopy(state_dict) + for pg in new_state_dict["param_groups"]: + pg.pop("param_names", None) + return new_state_dict + def fwd_bwd(optim, w, b, i): optim.zero_grad() loss = (w.mv(i) + b).pow(2).sum() @@ -1368,7 +1383,8 @@ def fwd_bwd(optim, w, b, i): return loss for optim_input in all_optim_inputs: - optimizer = optim_cls(params, **optim_input.kwargs) + params_in = make_named_param(params, is_named=is_named_optim0) + optimizer = optim_cls(params_in, **optim_input.kwargs) closure = functools.partial(fwd_bwd, optimizer, weight, bias, input) # Prime the optimizer @@ -1383,8 +1399,8 @@ def fwd_bwd(optim, w, b, i): with torch.no_grad(): weight_c = Parameter(weight.clone()) bias_c = Parameter(bias.clone()) - - optimizer_c = optim_cls([weight_c, bias_c], **optim_input.kwargs) + params_c = make_named_param([weight_c, bias_c], is_named=is_named_optim1) + optimizer_c = optim_cls(params_c, **optim_input.kwargs) closure_c = functools.partial(fwd_bwd, optimizer_c, weight_c, bias_c, input) # Load the state dict from the original optimizer into the new one @@ -1405,13 +1421,17 @@ def fwd_bwd(optim, w, b, i): self.assertEqual(bias, bias_c) # Make sure state dict is deterministic with equal (not identical) parameters - self.assertEqual(optimizer.state_dict(), optimizer_c.state_dict()) + # Param names are optional and not needed to be the consistent. + self.assertEqual( + without_param_names(optimizer.state_dict()), + without_param_names(optimizer_c.state_dict()), + ) # Make sure repeated parameters have identical representation (see #36831) optimizer_c.param_groups.extend(optimizer_c.param_groups) self.assertEqual( - optimizer.state_dict()["param_groups"][-1], - optimizer_c.state_dict()["param_groups"][-1], + without_param_names(optimizer.state_dict())["param_groups"][-1], + without_param_names(optimizer_c.state_dict())["param_groups"][-1], ) @optims(optim_db, dtypes=[torch.float32]) @@ -1462,8 +1482,77 @@ def fwd_bwd(optim, mod, i): fwd_bwd(optimizer, model, input) optimizer.step() + @parametrize("is_named_optim0", [True, False]) + @parametrize("is_named_optim1", [True, False]) + @optims( + [o for o in optim_db if not o.only_supports_sparse_grads], + dtypes=[torch.float32], + ) + def test_can_load_from_to_named_state_dict( + self, device, dtype, optim_info, is_named_optim0, is_named_optim1 + ): + optim_cls = optim_info.optim_cls + + # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 + all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs( + device, dtype, optim_info, skip=("differentiable",) + ) + for optim_input in all_optim_inputs: + torch.manual_seed(1) + model = torch.nn.Sequential( + torch.nn.Conv2d(4, 2, 1, stride=2), + torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1), + ) + model.to(dtype=dtype, device=device) + input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype) + + def fwd_bwd(optim, mod, i): + optim.zero_grad() + loss = mod(i).sum() + loss.backward() + return loss + + # test for parameters, named_parameters, and 2 groups: + params_to_optimizer = ( + model.named_parameters() if is_named_optim0 else model.parameters() + ) + optimizer = optim_cls(params_to_optimizer, **optim_input.kwargs) + + for _ in range(3): + if optim_info.step_requires_closure: + optimizer.step(functools.partial(fwd_bwd, optimizer, model, input)) + else: + fwd_bwd(optimizer, model, input) + optimizer.step() + + # old_state_dict has all new flags del'd + old_state_dict = deepcopy(optimizer.state_dict()) + + params_to_optimizer2 = ( + model.named_parameters() if is_named_optim1 else model.parameters() + ) + optimizer2 = optim_cls(params_to_optimizer2, **optim_input.kwargs) + optimizer2.load_state_dict(old_state_dict) + + # Make sure we can still step + if optim_info.step_requires_closure: + optimizer2.step(functools.partial(fwd_bwd, optimizer2, model, input)) + else: + fwd_bwd(optimizer2, model, input) + optimizer2.step() + + # Make sure that param_names are preserved when provided to at least one of the optimizers + if is_named_optim0 or is_named_optim1: + self.assertEqual( + optimizer2.state_dict()["param_groups"][0]["param_names"], + ["0.weight", "0.bias", "1.weight", "1.bias"], + ) + + @parametrize("is_named_optim", [True, False]) @optims(optim_db, dtypes=[torch.float32]) - def test_save_load_equality_with_weights_only(self, device, dtype, optim_info): + def test_save_load_equality_with_weights_only( + self, device, dtype, optim_info, is_named_optim + ): optim_cls = optim_info.optim_cls # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490 @@ -1477,6 +1566,11 @@ def test_save_load_equality_with_weights_only(self, device, dtype, optim_info): input = torch.randn(3, requires_grad=True, device=device, dtype=dtype) params = [weight, bias] + def make_named_param(param, is_named): + if not is_named: + return param + return [(f"name{i}", p) for i, p in enumerate(param)] + def fwd_bwd(optim, w, b, i): optim.zero_grad() loss = (w.mv(i) + b).pow(2).sum() @@ -1487,7 +1581,8 @@ def fwd_bwd(optim, w, b, i): return loss for optim_input in all_optim_inputs: - optimizer = optim_cls(params, **optim_input.kwargs) + params_in = make_named_param(params, is_named=is_named_optim) + optimizer = optim_cls(params_in, **optim_input.kwargs) closure = functools.partial(fwd_bwd, optimizer, weight, bias, input) # Prime the optimizer diff --git a/test/test_overrides.py b/test/test_overrides.py index 6c4f9229e63cf..dc8597309a19d 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -1553,6 +1553,15 @@ class A(torch.Tensor): finally: del g + def test_disable_enable_torch_function_ctx(self): + class A(torch.Tensor): + pass + + x = A(torch.randn(5)) + with torch._C.DisableTorchFunction(): + with torch.overrides._enable_torch_function(): + self.assertIsInstance(torch.sum(x), A) + def test_torch_function_all_disabled_api(self): from torch._C import _is_torch_function_all_disabled @@ -1570,6 +1579,7 @@ def test_torch_function_all_disabled_api(self): state = _is_torch_function_all_disabled() self.assertFalse(state) + def test_subclass_hash(self): class DiagTensor(torch.Tensor): def __init__(self, diag): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 496dfbccea35a..f31f85f12e219 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1519,11 +1519,6 @@ def f(x1, x2, x3, y): z3 = x3.item() torch._check(z1 == z2 + z3) return y * 2 - if z2 + z3 == z1: - return y * 2 - else: - return y + 3 - # NB: inputs are done as CUDA to ensure they aren't queried to be # backed diff --git a/test/test_reductions.py b/test/test_reductions.py index 323866c80153c..1e2625c4f606e 100644 --- a/test/test_reductions.py +++ b/test/test_reductions.py @@ -20,6 +20,7 @@ from torch.testing._internal.common_utils import ( TestCase, run_tests, skipIfNoSciPy, slowTest, torch_to_numpy_dtype_dict, parametrize, + skipIfTorchDynamo, IS_WINDOWS) from torch.testing._internal.common_device_type import ( OpDTypes, expectedFailureMeta, instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, dtypesIfCPU, @@ -2589,7 +2590,7 @@ def check(op, a, args, key): self.assertEqual(a[:, ::2, :].median(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) self.assertEqual(a[:, ::2, :].nanmedian(-1)[0], torch.tensor([[0, 4], [6, 10]], device=device)) - + @skipIfTorchDynamo("https://github.com/pytorch/pytorch/pull/138657 discovers a latent bug") @onlyNativeDeviceTypes @dtypes(torch.float, torch.double) def test_quantile(self, device, dtype): diff --git a/test/test_serialization.py b/test/test_serialization.py index d0202b73cc1a2..a58e47c083176 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -4308,6 +4308,32 @@ def _save_load(t): f.seek(0) torch.load(f, weights_only=True) + @parametrize("force_weights_only", (True, False)) + def test_weights_only_env_variables(self, force_weights_only): + env_var = "TORCH_FORCE_WEIGHTS_ONLY_LOAD" if force_weights_only else "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD" + args = ( + (pickle.UnpicklingError, "Weights only load failed") + if force_weights_only + else (UserWarning, "forcing weights_only=False") + ) + ctx = self.assertRaisesRegex if force_weights_only else self.assertWarnsRegex + m = torch.nn.Linear(3, 5) + with TemporaryFileName() as f: + torch.save(m, f) + try: + old_value = os.environ[env_var] if env_var in os.environ else None + os.environ[env_var] = "1" + # if weights_only is explicitly set, TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD cannot override it + with self.assertRaisesRegex(pickle.UnpicklingError, "Weights only load failed"): + m = torch.load(f, weights_only=not force_weights_only) + with ctx(*args): + m = torch.load(f, weights_only=None) + finally: + if old_value is None: + del os.environ[env_var] + else: + os.environ[env_var] = old_value + def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) @@ -4500,6 +4526,14 @@ def test_safe_globals_context_manager_weights_only(self): finally: torch.serialization.clear_safe_globals() + def test_sets_are_loadable_with_weights_only(self): + s = {1, 2, 3} + with tempfile.NamedTemporaryFile() as f: + torch.save(s, f) + f.seek(0) + l_s = torch.load(f, weights_only=True) + self.assertEqual(l_s, s) + @unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda") def test_tensor_subclass_map_location(self): t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3)) diff --git a/test/test_sparse.py b/test/test_sparse.py index 554172e3bcafc..e6d9c8285dafa 100644 --- a/test/test_sparse.py +++ b/test/test_sparse.py @@ -320,7 +320,6 @@ def test_shape(sparse_dims, nnz, with_size): @coalescedonoff @dtypes(torch.double, torch.cdouble, torch.bfloat16) @precisionOverride({torch.bfloat16: 1e-2}) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") def test_coalesce(self, device, dtype, coalesced): def _test_coalesce(t): diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py index f897fd041889f..a63620dcdbee6 100644 --- a/test/test_sparse_csr.py +++ b/test/test_sparse_csr.py @@ -3567,7 +3567,6 @@ def _to_block_triangular_inplace(self, d, row_block, col_block): return d @onlyCUDA - @skipIfRocm(msg="test is too slow on ROCm stack") @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") @@ -3774,7 +3773,6 @@ def broadcast_input(*ts): @parametrize("block_size", [16, 32, 64]) @onlyCUDA - @skipIfRocm @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") @@ -3843,7 +3841,6 @@ def test_triton_sampled_addmm(self, device, dtype, block_size): self.assertEqual(res_tri, res_tri_grid) @onlyCUDA - @skipIfRocm @dtypes(torch.half, torch.bfloat16, torch.float) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") @@ -4023,16 +4020,24 @@ def test_TensorAsKey(self, device): @suppress_warnings @parametrize("op", ['bsr_dense_addmm', 'bsr_dense_mm', 'bsr_dense_linear', '_int_bsr_dense_addmm']) @parametrize("blocksize", [16, '16x32', 32]) + @parametrize("out_dtype", ['unspecified', 'int32']) @onlyCUDA - @skipIfRocm @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8) @precisionOverride({torch.float16: 6e-1}) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") - def test_triton_kernel(self, op, device, dtype, blocksize): + def test_triton_kernel(self, op, device, dtype, blocksize, out_dtype): from torch.sparse._triton_ops import bsr_dense_addmm, bsr_dense_mm, _int_bsr_dense_addmm from torch.sparse._triton_ops_meta import (create_blocked_tensor, get_meta, optimize_bsr_dense_addmm, dump) + if out_dtype == "unspecified": + out_dtype = None + elif op == "bsr_dense_addmm": + out_dtype = getattr(torch, out_dtype) + if out_dtype.is_floating_point != dtype.is_floating_point: + self.skipTest("incompatible out dtype") + else: + self.skipTest("out dtype not implemented") def bsr_dense_linear(input, weights, bias=None): return torch.nn.functional.linear(input, weights, bias=bias).transpose(-1, -2) @@ -4048,7 +4053,10 @@ def reference(input, mat1, mat2, beta=1, alpha=1, left_alpha=None, right_alpha=N mat12 = torch._int_mm(mat1, mat2) else: # workaround RuntimeError: "addmm_cuda" not implemented for 'Char' - mat12 = torch._int_mm(mat1, mat2).to(torch.int8) + if out_dtype is not None: + mat12 = torch._int_mm(mat1, mat2).to(out_dtype) + else: + mat12 = torch._int_mm(mat1, mat2).to(torch.int8) else: mat12 = mat1 @ mat2 if alpha != 1: @@ -4144,7 +4152,12 @@ def nc_copy(t, axes=(-1,)): dump() # this will update torch/sparse/_triton_ops_meta.py expected = reference(input, mat1, mat2, beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha) - kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, + if out_dtype is not None: + expected = expected.to(out_dtype) + out = expected.new_empty(input.shape, dtype=out_dtype) + else: + out = None + kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, out=out, left_alpha=left_alpha, right_alpha=right_alpha), bsr_dense_mm={}, bsr_dense_linear=dict(bias=input.transpose(-1, -2)))[op] @@ -4175,21 +4188,30 @@ def nc_copy(t, axes=(-1,)): if op in {'bsr_dense_addmm', 'bsr_dense_linear'}: args = dict(bsr_dense_addmm=(nc_input, bsr, nc_mat2), bsr_dense_linear=(nc_mat2.transpose(-1, -2), bsr))[op] - kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha), + kwargs = dict(bsr_dense_addmm=dict(beta=beta, alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha, out=out), bsr_dense_linear=dict(bias=nc_input.transpose(-1, -2)))[op] result = operation(*args, **kwargs) self.assertEqual(result, expected) @parametrize("op", ['bsr_dense_addmm', '_int_bsr_dense_addmm']) @onlyCUDA - @skipIfRocm + @parametrize("out_dtype", ['unspecified', 'int32']) @dtypes(torch.half, torch.bfloat16, torch.float, torch.int8) @dtypesIfCUDA(torch.half, *[torch.bfloat16] if SM80OrLater else [], torch.float, torch.int8) @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") - def test_triton_tune(self, op, device, dtype): + def test_triton_tune(self, op, device, dtype, out_dtype): from torch.sparse._triton_ops import bsr_dense_addmm, _int_bsr_dense_addmm from torch.sparse._triton_ops_meta import (create_blocked_tensor, tune_bsr_dense_addmm, tune__int_bsr_dense_addmm, get_meta) + if out_dtype == "unspecified": + out_dtype = None + elif op == "bsr_dense_addmm": + out_dtype = getattr(torch, out_dtype) + if out_dtype.is_floating_point != dtype.is_floating_point: + self.skipTest("incompatible out dtype") + else: + self.skipTest("out dtype not implemented") + operation = dict(bsr_dense_addmm=bsr_dense_addmm, _int_bsr_dense_addmm=_int_bsr_dense_addmm)[op] tuner = dict(bsr_dense_addmm=tune_bsr_dense_addmm, _int_bsr_dense_addmm=tune__int_bsr_dense_addmm)[op] @@ -4205,12 +4227,19 @@ def test_triton_tune(self, op, device, dtype): sparsity = 1 - bsr._nnz() * blocksize[0] * blocksize[1] / (M * K) input = make_tensor(K, N, dtype=dtype, device=device) dense = make_tensor(K, N, dtype=dtype, device=device) + version_dtype = dtype + if out_dtype is None: + out = None + else: + out = input.new_empty(input.shape, dtype=out_dtype) + if dtype is not out_dtype: + version_dtype = (dtype, out_dtype) if op in {'bsr_dense_addmm', '_int_bsr_dense_addmm'}: args = (input, bsr, dense) def get_current_meta(): - version = (0, dtype, sparsity) + version = (0, version_dtype, sparsity) meta_key = (M, K, N, *blocksize, False, True, True) return get_meta(op, meta_key, version=version, exact=True) else: @@ -4218,15 +4247,14 @@ def get_current_meta(): self.assertEqual(get_current_meta(), None) - meta = tuner(*args, **dict(store=True, verbose=False)) + meta = tuner(*args, **dict(store=True, verbose=False, out=out)) self.assertEqual(get_current_meta(), meta) - expected = operation(*args) - result = operation(*args, **dict(meta=meta)) + expected = operation(*args, **dict(out=None if out_dtype is None else out.clone())) + result = operation(*args, **dict(meta=meta, out=out)) self.assertEqual(result, expected) @onlyCUDA - @skipIfRocm @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "Test requires Triton") def test_triton_bsr_dense_addmm_meta(self, device): from torch.sparse._triton_ops import bsr_dense_addmm_meta diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 1a2adc104e944..f5a7975c9e841 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -14,11 +14,27 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import ( - TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, - torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest, - set_default_dtype, set_default_tensor_type, - TEST_SCIPY, IS_MACOS, IS_PPC, IS_JETSON, IS_WINDOWS, parametrize, skipIfTorchDynamo, - xfailIfTorchDynamo) + TestCase, + run_tests, + do_test_empty_full, + TEST_WITH_ROCM, + suppress_warnings, + torch_to_numpy_dtype_dict, + numpy_to_torch_dtype_dict, + slowTest, + set_default_dtype, + set_default_tensor_type, + TEST_SCIPY, + IS_MACOS, + IS_PPC, + IS_JETSON, + IS_WINDOWS, + IS_FBCODE, + IS_SANDCASTLE, + parametrize, + skipIfTorchDynamo, + xfailIfTorchDynamo, +) from torch.testing._internal.common_device_type import ( expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes, onlyCPU, largeTensorTest, precisionOverride, dtypes, @@ -148,7 +164,16 @@ def test_vander_types(self, device, dtype): exact_dtype=False) def test_cat_all_dtypes_and_devices(self, device): - for dt in all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.chalf): + for dt in all_types_and_complex_and( + torch.half, + torch.bool, + torch.bfloat16, + torch.chalf, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + torch.float8_e5m2, + torch.float8_e5m2fnuz, + ): x = torch.tensor([[1, 2], [3, 4]], dtype=dt, device=device) expected1 = torch.tensor([[1, 2], [3, 4], [1, 2], [3, 4]], dtype=dt, device=device) @@ -1046,6 +1071,9 @@ def test_float_to_int_conversion_finite(self, device, dtype): # Note: numpy -2.0 or -1.5 -> uint8 conversion is undefined # see https://github.com/pytorch/pytorch/issues/97794 refs = (0, 254, 255, 0, 0, 0, 1, 2) + elif dtype == torch.int16: + # CPU min and max float -> int16 conversion is divergent. + vals = (-2, -1.5, -.5, 0, .5, 1.5, 2) self._float_to_int_conversion_helper(vals, device, dtype, refs) @@ -2478,7 +2506,6 @@ def test_arange(self, device): self.assertEqual(d.shape[0], 800) # TODO: this test should be updated - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") @onlyCPU def test_arange_inference(self, device): # end only @@ -3556,6 +3583,7 @@ def test_randperm(self, device): # Test exceptions when device and generator types are incompatible @onlyCUDA + @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, "Produces inconsistent errors when run in fbcode.") def test_randperm_device_compatibility(self, device): cuda_gen = torch.Generator(device='cuda') cpu_gen = torch.Generator(device='cpu') diff --git a/test/test_testing.py b/test/test_testing.py index b215e62a7ac44..56ce579374b9e 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -19,7 +19,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import \ (IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest, - parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf) + parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf, skipIfRocm) from torch.testing._internal.common_device_type import \ (PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes, get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes, @@ -30,6 +30,7 @@ from torch.testing._internal.common_modules import modules, module_db, ModuleInfo from torch.testing._internal.opinfo.core import SampleInput, DecorateInfo, OpInfo import operator +import string # For testing TestCase methods and torch.testing functions class TestTesting(TestCase): @@ -2221,6 +2222,9 @@ def _check_python_output(cls, program) -> str: # fail, so just set CWD to this script's directory cwd=os.path.dirname(os.path.realpath(__file__)),).decode("utf-8") + # The test is flaky on ROCm and has been open and close multiple times + # https://github.com/pytorch/pytorch/issues/110040 + @skipIfRocm def test_circular_dependencies(self) -> None: """ Checks that all modules inside torch can be imported Prevents regression reported in https://github.com/pytorch/pytorch/issues/77441 """ @@ -2296,7 +2300,7 @@ def test_no_mutate_global_logging_on_import(self, path) -> None: # Calling logging.basicConfig, among other things, modifies the global # logging state. It is not OK to modify the global logging state on # `import torch` (or other submodules we own) because users do not expect it. - expected = 'abcdefghijklmnopqrstuvwxyz' + expected = string.ascii_lowercase commands = [ 'import logging', f'import {path}', diff --git a/test/test_torch.py b/test/test_torch.py index 868aeeb7b5689..666a590b1db3c 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -237,7 +237,7 @@ def test_storage_setitem(self, device, dtype): s[2:7] = 1 self.assertEqual(s, storage_type(l)) - @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991") + @xfailIfTorchDynamo @onlyNativeDeviceTypes @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16)) def test_tensor_storage_type(self, device, dtype): @@ -1337,7 +1337,8 @@ def test_deterministic_resize(self, device, dtype): # point tensors with NaN and integer tensors with MAX_INT @skipXLA @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") - @dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64)) + @dtypes(*all_types_and_complex_and( + torch.half, torch.bool, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64, torch.complex32)) def test_deterministic_empty(self, device, dtype): gen_fns = [ lambda: torch.empty(10, 9, device=device, dtype=dtype), @@ -1739,11 +1740,33 @@ def test_nondeterministic_alert_EmbeddingBag_max(self, device): 'embedding_bag_backward_cuda_max', torch.device(device).type == 'cuda') + @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") + @onlyCUDA + def test_deterministic_cumsum(self, device): + test_cases = [ + # size, dim + [(2, 3, 4), 0], + [(2, 3, 4), 1], + [(2, 3, 4), 2], + [(1000, 10, 2), 0], + ] + for size, dim in test_cases: + input = 100 * torch.randn(*size, device=device) + with DeterministicGuard(True): + res0 = input.cumsum(dim) + for _ in range(3): + res1 = input.cumsum(dim) + self.assertEqual(res0, res1, atol=0, rtol=0) + + res_cpu = input.cpu().cumsum(dim) + self.assertEqual(res0, res_cpu, atol=1e-3, rtol=1e-2) + + @dtypes(*all_types_and_complex_and(torch.bool)) @skipIfTorchInductor("https://github.com/pytorch/pytorch/issues/113707") def test_nondeterministic_alert_cumsum(self, device, dtype): input = make_tensor((10,), dtype=dtype, device=device, low=-9, high=9) - should_alert = torch.device(device).type == 'cuda' and (dtype.is_floating_point or dtype.is_complex) + should_alert = False for op_call in [torch.Tensor.cumsum, torch.cumsum]: self.check_nondeterministic_alert( @@ -2891,7 +2914,7 @@ def test_diff(self, device, dtype): # if the given input arg is not a list, it returns a list of single element: [arg] def _wrap_to_list(self, input_array): - return input_array if isinstance(input_array, list) else [input_array] + return list(input_array) if isinstance(input_array, (list, tuple)) else [input_array] # To ensure inf, -inf, and nan values do not cause divergence between Numpy and PyTorch. # There are two types of possible divergence: @@ -3029,7 +3052,7 @@ def test_gradient_type_promotion(self, device): # Result is given just as real number and all the imaginary parts to be equal to zero. self.assertEqual(expected[i].imag, torch.zeros(actual[i].shape), exact_dtype=False) else: - actual, expected = self._inf_nan_preprocess(list(actual), expected) + actual, expected = self._inf_nan_preprocess(list(actual), list(expected)) self.assertEqual(actual, expected, equal_nan=True, exact_dtype=False) @onlyNativeDeviceTypes @@ -5086,10 +5109,20 @@ def _get_tensors(**kwargs): @deviceCountAtLeast(1) @onlyCUDA - def test_storage_all_devices(self, devices): + @parametrize("non_blocking", (True, False)) + def test_storage_all_devices(self, devices, non_blocking): for device in devices: - t = torch.tensor((), device=device) + t = torch.randn(6, device=device) self.assertEqual(t.dtype, t.storage().dtype) + s = t.untyped_storage() + s_cpu = s.to(device='cpu', non_blocking=non_blocking) + if non_blocking: + torch.cuda.synchronize() + self.assertTrue(s_cpu.is_pinned()) + else: + self.assertFalse(s_cpu.is_pinned()) + t_cpu = torch.empty(()).set_(s_cpu) + self.assertEqual(t.cpu(), t_cpu) # Note [lazy_clone_ tests with inductor enabled] # These `lazy_clone_` tests are written in a way that makes them pass in @@ -7559,10 +7592,10 @@ def test_sobolengine_distribution(self, scramble=False): torch.mean(sample, dim=0), torch.full((d,), 0.5), atol=2, rtol=2 ) torch.testing.assert_close( - np.percentile(sample, 25, axis=0), np.repeat(0.25, d), atol=2, rtol=2 + np.percentile(sample, 25, axis=0).astype(np.float64), np.repeat(0.25, d), atol=2, rtol=2 ) torch.testing.assert_close( - np.percentile(sample, 75, axis=0), np.repeat(0.75, d), atol=2, rtol=2 + np.percentile(sample, 75, axis=0).astype(np.float64), np.repeat(0.75, d), atol=2, rtol=2 ) @skipIfTorchDynamo("np.float64 restored as float32 after graph break.") diff --git a/test/test_transformers.py b/test/test_transformers.py index e332342e77289..737f830fef329 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -377,7 +377,7 @@ def test_multiheadattention_fastpath_attn_mask(self, device, attn_mask_dim, key_ out_fp, _ = mha(X, X, X, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) # The FP kernel will return NaNs while the sdpa kernel which is ran when the fast path is turned off returns 0 instead # of NaNs for fully masked rows - torch.testing.assert_close(out, out_fp.nan_to_num()) + self.assertEqual(out, out_fp.nan_to_num()) @parametrize("nhead", [1, 4, 8]) def test_transformerencoderlayer_src_mask(self, device, nhead): @@ -2495,6 +2495,46 @@ def test_cudnn_attention_trivial_output_transpose(self, device): o.backward(o) torch.testing.assert_close(x.grad, x_cpu.grad.cuda(), atol=7e-3, rtol=7e-3) + @skipIfRocm # No cuDNN Attention + @unittest.skipIf(not PLATFORM_SUPPORTS_CUDNN_ATTENTION, "cudnn Attention is not supported on this system") + def test_cudnn_attention_nonmodulo64seqlen(self, device): + # see also: https://github.com/pytorch/pytorch/issues/137347 + mask = torch.randint(0, 2, (2, 1, 157, 6404)).to(device="cuda", dtype=torch.bool) + q = torch.randn(2, 32, 157, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True) + k = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True) + v = torch.randn(2, 32, 6404, 128, device='cuda', dtype=torch.bfloat16, requires_grad=True) + q_cpu = q.detach().clone().cpu() + k_cpu = k.detach().clone().cpu() + v_cpu = v.detach().clone().cpu() + q_cpu.requires_grad = True + k_cpu.requires_grad = True + v_cpu.requires_grad = True + mask_cpu = mask.detach().clone().cpu() + with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION): + out = nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=0.0, + is_causal=False, + ) + out_cpu = nn.functional.scaled_dot_product_attention( + q_cpu, + k_cpu, + v_cpu, + attn_mask=mask_cpu, + dropout_p=0.0, + is_causal=False, + ) + + out.sum().backward() + out_cpu.sum().backward() + + torch.testing.assert_close(q.grad, q_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) + torch.testing.assert_close(k.grad, k_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) + torch.testing.assert_close(v.grad, v_cpu.grad.cuda(), atol=3e-3, rtol=2e-3) + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") @parametrize("mask_dim", [1, 2, 3, 4]) def test_mem_efficient_attention_mask_variants(self, device, mask_dim: List[int]): @@ -2689,7 +2729,7 @@ def rand_tensor(shape): math_ref_test = math_ref_test.to(dtype=torch.float32).contiguous() math_ref_lp_test = math_ref_lp_test.to(dtype=torch.float32).contiguous() - self.assertEqual(math_ref_test, math_ref_lp_test, atol=7e-3, rtol=7e-3) + self.assertEqual(math_ref_test, math_ref_lp_test, atol=8e-3, rtol=7e-3) self.assertEqual(actual_test, math_ref_test, atol=7e-3, rtol=7e-3) @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Efficient Attention was not built for this system") @@ -2808,12 +2848,18 @@ def test_fused_sdp_choice(self, device, type: str): value = value.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, num_heads, head_dim).transpose(1, 2) + # TODO we are currently disabling this by default, lets assert that this returns + # FlashAttention, we need to change when we make remove opt-in for cudnn if type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater: - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) elif PLATFORM_SUPPORTS_FLASH_ATTENTION: self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.FLASH_ATTENTION.value) elif type != "nested" and PLATFORM_SUPPORTS_CUDNN_ATTENTION: # e.g., we're on Windows - self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value) + with sdpa_kernel(backends=[SDPBackend.CUDNN_ATTENTION]): + self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.CUDNN_ATTENTION.value) else: self.assertEqual(torch._fused_sdp_choice(query, key, value), SDPBackend.EFFICIENT_ATTENTION.value) @@ -3111,7 +3157,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, fudge_factors = { "out": 4, - "grad_query": 150.0, + "grad_query": 160.0, "grad_key": 25.0, "grad_value": 8.0, "grad_attn_mask": 45.0, diff --git a/test/test_utils_config_module.py b/test/test_utils_config_module.py new file mode 100644 index 0000000000000..f0539452961f3 --- /dev/null +++ b/test/test_utils_config_module.py @@ -0,0 +1,264 @@ +# Owner(s): ["module: unknown"] +import pickle + +from torch.testing._internal import fake_config_module as config +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestConfigModule(TestCase): + def test_base_value_loading(self): + self.assertTrue(config.e_bool) + self.assertTrue(config.nested.e_bool) + self.assertEqual(config.e_int, 1) + self.assertEqual(config.e_float, 1.0) + self.assertEqual(config.e_string, "string") + self.assertEqual(config.e_list, [1]) + self.assertEqual(config.e_set, {1}) + self.assertEqual(config.e_tuple, (1,)) + self.assertEqual(config.e_dict, {1: 2}) + self.assertEqual(config.e_none, None) + with self.assertRaises( + AttributeError, msg="fake_config_module.does_not_exist does not exist" + ): + config.does_not_exist + + def test_overrides(self): + config.e_bool = False + self.assertFalse(config.e_bool) + config.nested.e_bool = False + self.assertFalse(config.nested.e_bool) + config.e_int = 2 + self.assertEqual(config.e_int, 2) + config.e_float = 2.0 + self.assertEqual(config.e_float, 2.0) + config.e_string = "string2" + self.assertEqual(config.e_string, "string2") + config.e_list = [2] + self.assertEqual(config.e_list, [2]) + config.e_set = {2} + self.assertEqual(config.e_set, {2}) + config.e_tuple = (2,) + self.assertEqual(config.e_tuple, (2,)) + config.e_dict = {2: 3} + self.assertEqual(config.e_dict, {2: 3}) + config.e_none = "not none" + self.assertEqual(config.e_none, "not none") + config.e_none = None + self.assertEqual(config.e_none, None) + with self.assertRaises( + AttributeError, msg="fake_config_module.does_not_exist does not exist" + ): + config.does_not_exist = 0 + # Config changes get persisted between test cases + config.e_bool = True + config.nested.e_bool = True + config.e_int = 1 + config.e_float = 1.0 + config.e_string = "string" + config.e_list = [1] + config.e_set = {1} + config.e_tuple = (1,) + config.e_dict = {1: 2} + config.e_none = None + + def test_delete(self): + self.assertTrue(config.e_bool) + del config.e_bool + with self.assertRaises( + AttributeError, msg="fake_config_module.e_bool does not exist" + ): + print(config.e_bool) + # Config changes get persisted between test cases + config.e_bool = True + + def test_save_config(self): + p = config.save_config() + self.assertEqual( + pickle.loads(p), + { + "_cache_config_ignore_prefix": ["magic_cache_config"], + "e_bool": True, + "e_dict": {1: 2}, + "e_float": 1.0, + "e_int": 1, + "e_list": [1], + "e_none": None, + "e_set": {1}, + "e_string": "string", + "e_tuple": (1,), + "nested.e_bool": True, + "_e_ignored": True, + "e_compile_ignored": True, + "magic_cache_config_ignored": True, + "_save_config_ignore": ["e_ignored"], + }, + ) + config.e_bool = False + config.e_ignored = False + config.load_config(p) + self.assertTrue(config.e_bool) + self.assertFalse(config.e_ignored) + # Config changes get persisted between test cases + config.e_ignored = True + + def test_save_config_portable(self): + p = config.save_config_portable() + self.assertEqual( + p, + { + "e_bool": True, + "e_dict": {1: 2}, + "e_float": 1.0, + "e_int": 1, + "e_list": [1], + "e_none": None, + "e_set": {1}, + "e_string": "string", + "e_tuple": (1,), + "nested.e_bool": True, + "e_ignored": True, + "e_compile_ignored": True, + }, + ) + config.e_bool = False + config._e_ignored = False + config.load_config(p) + self.assertTrue(config.e_bool) + self.assertFalse(config._e_ignored) + # Config changes get persisted between test cases + config._e_ignored = True + + def test_codegen_config(self): + config.e_bool = False + config.e_ignored = False + code = config.codegen_config() + self.assertEqual( + code, "torch.testing._internal.fake_config_module.e_bool = False" + ) + # Config changes get persisted between test cases + config.e_bool = True + config.e_ignored = True + + def test_get_hash(self): + self.assertEqual( + config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6" + ) + # Test cached value + self.assertEqual( + config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6" + ) + self.assertEqual( + config._hash_digest, b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6" + ) + config._hash_digest = "fake" + self.assertEqual(config.get_hash(), "fake") + + # BUG + config.e_bool = False + self.assertNotEqual( + config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6" + ) + config.e_bool = True + + # Test ignored values + config.e_compile_ignored = False + self.assertEqual( + config.get_hash(), b"\xcd\x96\x93\xf5(\xf8(\xa5\x1c+O\n\xd3_\x0b\xa6" + ) + config.e_compile_ignored = True + + def test_dict_copy_semantics(self): + p = config.shallow_copy_dict() + self.assertEqual( + p, + { + "e_bool": True, + "e_dict": {1: 2}, + "e_float": 1.0, + "e_int": 1, + "e_list": [1], + "e_none": None, + "e_set": {1}, + "e_string": "string", + "e_tuple": (1,), + "nested.e_bool": True, + "e_ignored": True, + "_e_ignored": True, + "e_compile_ignored": True, + "_cache_config_ignore_prefix": ["magic_cache_config"], + "_save_config_ignore": ["e_ignored"], + "magic_cache_config_ignored": True, + }, + ) + p2 = config.to_dict() + self.assertEqual( + p2, + { + "e_bool": True, + "e_dict": {1: 2}, + "e_float": 1.0, + "e_int": 1, + "e_list": [1], + "e_none": None, + "e_set": {1}, + "e_string": "string", + "e_tuple": (1,), + "nested.e_bool": True, + "e_ignored": True, + "_e_ignored": True, + "e_compile_ignored": True, + "_cache_config_ignore_prefix": ["magic_cache_config"], + "_save_config_ignore": ["e_ignored"], + "magic_cache_config_ignored": True, + }, + ) + p3 = config.get_config_copy() + self.assertEqual( + p3, + { + "e_bool": True, + "e_dict": {1: 2}, + "e_float": 1.0, + "e_int": 1, + "e_list": [1], + "e_none": None, + "e_set": {1}, + "e_string": "string", + "e_tuple": (1,), + "nested.e_bool": True, + "e_ignored": True, + "_e_ignored": True, + "e_compile_ignored": True, + "_cache_config_ignore_prefix": ["magic_cache_config"], + "_save_config_ignore": ["e_ignored"], + "magic_cache_config_ignored": True, + }, + ) + + # Shallow + deep copy semantics + config.e_dict[2] = 3 + self.assertEqual(p["e_dict"], {1: 2}) + self.assertEqual(p2["e_dict"], {1: 2}) + self.assertEqual(p3["e_dict"], {1: 2}) + config.e_dict = {1: 2} + + def test_patch(self): + with config.patch("e_bool", False): + self.assertFalse(config.e_bool) + self.assertTrue(config.e_bool) + with config.patch(e_bool=False): + self.assertFalse(config.e_bool) + self.assertTrue(config.e_bool) + with self.assertRaises(AssertionError): + with config.patch("does_not_exist"): + pass + + def test_make_closur_patcher(self): + revert = config._make_closure_patcher(e_bool=False)() + self.assertFalse(config.e_bool) + revert() + self.assertTrue(config.e_bool) + + +if __name__ == "__main__": + run_tests() diff --git a/test/test_xpu.py b/test/test_xpu.py index 7ace821693089..933f1be9ac3dd 100644 --- a/test/test_xpu.py +++ b/test/test_xpu.py @@ -420,6 +420,14 @@ def test_device_memory_allocated(self): ) ) + def test_get_arch_list(self): + arch_list = torch.xpu.get_arch_list() + if not arch_list: + return + flags = torch.xpu.get_gencode_flags() + for arch in arch_list: + self.assertTrue(arch in flags) + instantiate_device_type_tests(TestXpu, globals(), only_for="xpu", allow_xpu=True) diff --git a/test/torch_np/numpy_tests/lib/test_function_base.py b/test/torch_np/numpy_tests/lib/test_function_base.py index 876bd553d0399..0664664d2ac68 100644 --- a/test/torch_np/numpy_tests/lib/test_function_base.py +++ b/test/torch_np/numpy_tests/lib/test_function_base.py @@ -33,6 +33,8 @@ IS_WASM = False IS_PYPY = False +import string + # FIXME: make from torch._numpy # These are commented, as if they are imported, some of the tests pass for the wrong reasons # from numpy lib import digitize, piecewise, trapz, select, trim_zeros, interp @@ -1528,7 +1530,7 @@ def test_execution_order_ticket_1487(self): def test_string_ticket_1892(self): # Test vectorization over strings: issue 1892. f = np.vectorize(lambda x: x) - s = "0123456789" * 10 + s = string.digits * 10 assert_equal(s, f(s)) def test_cache(self): diff --git a/test/torch_np/test_basic.py b/test/torch_np/test_basic.py index c5bd65369f6fb..4f7551bb471ec 100644 --- a/test/torch_np/test_basic.py +++ b/test/torch_np/test_basic.py @@ -561,14 +561,19 @@ def test_set_default_float(self, dt): @skip(_np.__version__ <= "1.23", reason="from_dlpack is new in NumPy 1.23") class TestExport(TestCase): def test_exported_objects(self): - exported_fns = ( + exported_fns = { x for x in dir(w) if inspect.isfunction(getattr(w, x)) and not x.startswith("_") and x != "set_default_dtype" - ) - diff = set(exported_fns).difference(set(dir(_np))) + } + if _np.__version__ > "2": + # The following methods are removed in NumPy 2. + # See https://numpy.org/devdocs/numpy_2_0_migration_guide.html#main-namespace + exported_fns -= {"product", "round_", "sometrue", "cumproduct", "alltrue"} + + diff = exported_fns.difference(set(dir(_np))) assert len(diff) == 0, str(diff) diff --git a/third_party/composable_kernel b/third_party/composable_kernel new file mode 160000 index 0000000000000..11b7a4db005dc --- /dev/null +++ b/third_party/composable_kernel @@ -0,0 +1 @@ +Subproject commit 11b7a4db005dc38e60b1ea045d03a92d2a8f9cd0 diff --git a/third_party/cpuinfo b/third_party/cpuinfo index a5ff6df40ce52..1e83a2fdd3102 160000 --- a/third_party/cpuinfo +++ b/third_party/cpuinfo @@ -1 +1 @@ -Subproject commit a5ff6df40ce528721cfc310c7ed43946d77404d5 +Subproject commit 1e83a2fdd3102f65c6f1fb602c1b320486218a99 diff --git a/third_party/cudnn_frontend b/third_party/cudnn_frontend index 2533f5e5c1877..936021bfed8c9 160000 --- a/third_party/cudnn_frontend +++ b/third_party/cudnn_frontend @@ -1 +1 @@ -Subproject commit 2533f5e5c1877fd76266133c1479ef1643ce3a8b +Subproject commit 936021bfed8c91dc416af1588b2c4eca631a9e45 diff --git a/third_party/kineto b/third_party/kineto index b5c85daac1ee1..ed052ea024b94 160000 --- a/third_party/kineto +++ b/third_party/kineto @@ -1 +1 @@ -Subproject commit b5c85daac1ee123aa7f04eb6f2bc71363f429e68 +Subproject commit ed052ea024b9468908d558b15cd3f7584fb0f492 diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 69606f14a7af3..b79c4dba924e9 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -7e3d00acea9f0d3728048a5b2743de20d55c64ba +b3d5d78c72eadc5140aef1f8e06844385e9a2d45 diff --git a/tools/autograd/context.py b/tools/autograd/context.py index d838aa3c77bbb..146cf571d3041 100644 --- a/tools/autograd/context.py +++ b/tools/autograd/context.py @@ -9,7 +9,7 @@ # Like tools.api.context.with_native_function, but for # NativeFunctionWithDifferentiabilityInfo. def with_native_function_with_differentiability_info( - func: Callable[[NFWDI], T] + func: Callable[[NFWDI], T], ) -> Callable[[NFWDI], T]: @functools.wraps(func) def wrapper(f: NFWDI) -> T: @@ -21,7 +21,7 @@ def wrapper(f: NFWDI) -> T: # Like the above but with an additional dispatch key string argument def with_native_function_with_differentiability_info_and_key( - func: Callable[[NFWDI, str], T] + func: Callable[[NFWDI, str], T], ) -> Callable[[NFWDI, str], T]: @functools.wraps(func) def wrapper(f: NFWDI, key: str) -> T: diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 9df4d965d9f78..af829ff2bf77e 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -112,9 +112,9 @@ # # - `wrap_opt_if`, is a 2-argument function that accepts a tensor # variable and a boolean condition that dictates whether to save that -# variable in a graph. The result of this function is `c10::optional`, +# variable in a graph. The result of this function is `std::optional`, # and it is `::std::nullopt` when the condition evalutes to `false`, -# otherwise it is the variable wrapped in `c10::optional`. +# otherwise it is the variable wrapped in `std::optional`. # For example, wrap_opt_if(var_0, grad_input_mask[1] || grad_input_mask[2]) # would mean that `var_0` is saved as long as the second (grad_input_mask[1]) # or the third (grad_input_mask[2]) argument requires gradients. diff --git a/tools/autograd/gen_autograd.py b/tools/autograd/gen_autograd.py index f6e7be149ad6d..d93d3f4cab4a6 100644 --- a/tools/autograd/gen_autograd.py +++ b/tools/autograd/gen_autograd.py @@ -70,9 +70,9 @@ def gen_autograd( ), key=lambda f: cpp.name(f.func), ) - fns_with_diff_infos: list[ - NativeFunctionWithDifferentiabilityInfo - ] = match_differentiability_info(fns, differentiability_infos) + fns_with_diff_infos: list[NativeFunctionWithDifferentiabilityInfo] = ( + match_differentiability_info(fns, differentiability_infos) + ) # Generate VariableType.h/cpp if not disable_autograd: diff --git a/tools/autograd/gen_autograd_functions.py b/tools/autograd/gen_autograd_functions.py index 785ea68315b76..769334d2ee243 100644 --- a/tools/autograd/gen_autograd_functions.py +++ b/tools/autograd/gen_autograd_functions.py @@ -447,7 +447,7 @@ def get_infos_with_derivatives_list( - differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]] + differentiability_infos: dict[FunctionSchema, dict[str, DifferentiabilityInfo]], ) -> list[DifferentiabilityInfo]: diff_info_list = [ info diff --git a/tools/autograd/gen_inplace_or_view_type.py b/tools/autograd/gen_inplace_or_view_type.py index e8141658b0335..afc932606a519 100644 --- a/tools/autograd/gen_inplace_or_view_type.py +++ b/tools/autograd/gen_inplace_or_view_type.py @@ -590,8 +590,7 @@ def inplace_or_view_method_definition( # For functions that modify their inputs but don't return them, # we can't give them autograd support. # See https://github.com/pytorch/pytorch/issues/53796 - not modifies_arguments(f) - or len(f.func.returns) == 0 + not modifies_arguments(f) or len(f.func.returns) == 0 ): return None return METHOD_DEFINITION.substitute( diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 44453306a0ecb..5c736cf3f8b9e 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -386,9 +386,9 @@ def group_filter_overloads( pairs: Sequence[PythonSignatureNativeFunctionPair], pred: Callable[[NativeFunction], bool], ) -> dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]]: - grouped: dict[ - BaseOperatorName, list[PythonSignatureNativeFunctionPair] - ] = defaultdict(list) + grouped: dict[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] = ( + defaultdict(list) + ) for pair in pairs: if pred(pair.function): grouped[pair.function.func.name.name].append(pair) @@ -522,12 +522,12 @@ def create_python_bindings_sharded( grouped = group_filter_overloads(pairs, pred) def key_func( - kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]], ) -> str: return kv[0].base def env_func( - kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]] + kv: tuple[BaseOperatorName, list[PythonSignatureNativeFunctionPair]], ) -> dict[str, list[str]]: name, fn_pairs = kv return { @@ -679,9 +679,7 @@ def is_schema_compatible( function=pair.function, ) ) - assert ( - any_schema_found - ), f"No native function with name {aten_name} matched signature:\n {str(schema)}" + assert any_schema_found, f"No native function with name {aten_name} matched signature:\n {str(schema)}" return results diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index d26a83713a68f..08530d42acfa4 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -199,6 +199,7 @@ "transpose", "transpose_copy", "permute", + "permute_copy", "squeeze", "squeeze_copy", "unsqueeze", @@ -240,6 +241,7 @@ "slice", "constant_pad_nd", "unbind", + "unbind_copy", "split", "split_with_sizes", "unsafe_split", diff --git a/tools/autograd/load_derivatives.py b/tools/autograd/load_derivatives.py index 645a569c45e3d..e0223cf74351b 100644 --- a/tools/autograd/load_derivatives.py +++ b/tools/autograd/load_derivatives.py @@ -128,9 +128,9 @@ def load_derivatives( # function schema is the complete declaration including mutability annotation / default value and etc. # signature is the canonical schema for a group of functions (in-place/out/functional variants) # that are semantically related. - functions_by_signature: dict[ - FunctionSchema, list[NativeFunction] - ] = defaultdict(list) + functions_by_signature: dict[FunctionSchema, list[NativeFunction]] = ( + defaultdict(list) + ) functions_by_schema: dict[str, NativeFunction] = {} for function in native_functions: functions_by_signature[function.func.signature()].append(function) @@ -991,7 +991,7 @@ def _create_op_prefix(name: str) -> str: OP names correspond to classes, hence the change to title case. Example:: - >>> _create_op_prefix('add') + >>> _create_op_prefix("add") 'AddBackward' """ camel_case = "".join([p.title() for p in name.split("_")]) diff --git a/tools/autograd/templates/VariableType.cpp b/tools/autograd/templates/VariableType.cpp index 08f1f8b698e52..23976a48473a3 100644 --- a/tools/autograd/templates/VariableType.cpp +++ b/tools/autograd/templates/VariableType.cpp @@ -41,13 +41,13 @@ namespace torch::autograd { namespace VariableType { namespace{ - C10_UNUSED void reset_grad_accumulator(Variable & self) { - AutogradMeta* meta = torch::autograd::impl::get_autograd_meta(self); - if (meta != nullptr) { - meta->grad_accumulator_.reset(); - } +[[maybe_unused]] void reset_grad_accumulator(Variable& self) { + AutogradMeta* meta = torch::autograd::impl::get_autograd_meta(self); + if (meta != nullptr) { + meta->grad_accumulator_.reset(); } } +} namespace { diff --git a/tools/build_with_debinfo.py b/tools/build_with_debinfo.py index 066d6ce414d67..26c054bf2a0c4 100755 --- a/tools/build_with_debinfo.py +++ b/tools/build_with_debinfo.py @@ -78,8 +78,11 @@ def create_build_plan() -> list[tuple[str, str]]: if line.startswith(": &&") and line.endswith("&& :"): line = line[4:-4] line = line.replace("-O2", "-g").replace("-O3", "-g") - name = line.split("-o ", 1)[1].split(" ")[0] - rc.append((name, line)) + try: + name = line.split("-o ", 1)[1].split(" ")[0] + rc.append((name, line)) + except IndexError: + print(f"Skipping {line} as it does not specify output file") return rc diff --git a/tools/flight_recorder/components/builder.py b/tools/flight_recorder/components/builder.py index e89914eb459b2..8108114b80b8d 100644 --- a/tools/flight_recorder/components/builder.py +++ b/tools/flight_recorder/components/builder.py @@ -112,8 +112,8 @@ def build_groups_memberships( assert ( _groups[pg_guid].desc == desc ), f"mismatch in desc {_groups[pg_guid].desc} vs {desc} for group {pg_guid}" - assert _memberships[pg_guid] == set( - ranks + assert ( + _memberships[pg_guid] == set(ranks) ), f"mismatch in membership for group {pg_guid} {_memberships[pg_guid]} vs {set(ranks)}" return groups, _groups, memberships, _memberships, _pg_guids diff --git a/tools/flight_recorder/components/utils.py b/tools/flight_recorder/components/utils.py index 3b0631b3a018a..7f2af5eeb29ec 100644 --- a/tools/flight_recorder/components/utils.py +++ b/tools/flight_recorder/components/utils.py @@ -278,7 +278,7 @@ def get_version_detail(version: str) -> Tuple[int, int]: def align_trace_from_beginning( - entries: Dict[int, List[Dict[str, Any]]] + entries: Dict[int, List[Dict[str, Any]]], ) -> Dict[int, List[Dict[str, Any]]]: """ Align the trace entries by record ID for entries. diff --git a/tools/linter/adapters/constexpr_linter.py b/tools/linter/adapters/constexpr_linter.py deleted file mode 100644 index adb7fe001749a..0000000000000 --- a/tools/linter/adapters/constexpr_linter.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -CONSTEXPR: Ensures users don't use vanilla constexpr since it causes issues -""" - -from __future__ import annotations - -import argparse -import json -import logging -import sys -from enum import Enum -from typing import NamedTuple - - -CONSTEXPR = "constexpr char" -CONSTEXPR_MACRO = "CONSTEXPR_EXCEPT_WIN_CUDA char" - -LINTER_CODE = "CONSTEXPR" - - -class LintSeverity(str, Enum): - ERROR = "error" - - -class LintMessage(NamedTuple): - path: str | None - line: int | None - char: int | None - code: str - severity: LintSeverity - name: str - original: str | None - replacement: str | None - description: str | None - - -def check_file(filename: str) -> LintMessage | None: - logging.debug("Checking file %s", filename) - - with open(filename) as f: - lines = f.readlines() - - for idx, line in enumerate(lines): - if CONSTEXPR in line: - original = "".join(lines) - replacement = original.replace(CONSTEXPR, CONSTEXPR_MACRO) - logging.debug("replacement: %s", replacement) - return LintMessage( - path=filename, - line=idx, - char=None, - code=LINTER_CODE, - severity=LintSeverity.ERROR, - name="Vanilla constexpr used, prefer macros", - original=original, - replacement=replacement, - description="Vanilla constexpr used, prefer macros run `lintrunner --take CONSTEXPR -a` to apply changes.", - ) - return None - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="CONSTEXPR linter", - fromfile_prefix_chars="@", - ) - parser.add_argument( - "--verbose", - action="store_true", - ) - parser.add_argument( - "filenames", - nargs="+", - help="paths to lint", - ) - - args = parser.parse_args() - - logging.basicConfig( - format="<%(threadName)s:%(levelname)s> %(message)s", - level=logging.NOTSET - if args.verbose - else logging.DEBUG - if len(args.filenames) < 1000 - else logging.INFO, - stream=sys.stderr, - ) - - lint_messages = [] - for filename in args.filenames: - lint_message = check_file(filename) - if lint_message is not None: - lint_messages.append(lint_message) - - for lint_message in lint_messages: - print(json.dumps(lint_message._asdict()), flush=True) diff --git a/tools/linter/adapters/flake8_linter.py b/tools/linter/adapters/flake8_linter.py index df5ccb4934249..c046f18ac04fe 100644 --- a/tools/linter/adapters/flake8_linter.py +++ b/tools/linter/adapters/flake8_linter.py @@ -115,7 +115,8 @@ def as_posix(name: str) -> str: def _test_results_re() -> None: """ - >>> def t(s): return RESULTS_RE.search(s).groupdict() + >>> def t(s): + ... return RESULTS_RE.search(s).groupdict() >>> t(r"file.py:80:1: E302 expected 2 blank lines, found 1") ... # doctest: +NORMALIZE_WHITESPACE diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 33a7d9fe4e959..ae292100a0631 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -31,17 +31,11 @@ [ # ** # .ci/** - ".ci/**", # .github/** - ".github/**", # benchmarks/** - "benchmarks/**", # functorch/** - "functorch/**", # tools/** - "tools/**", # torchgen/** - "torchgen/**", # test/** # test/[a-h]*/** "test/[a-h]*/**", diff --git a/tools/lite_interpreter/gen_selected_mobile_ops_header.py b/tools/lite_interpreter/gen_selected_mobile_ops_header.py index 09f0f4e80bbaf..24bc62cdab137 100644 --- a/tools/lite_interpreter/gen_selected_mobile_ops_header.py +++ b/tools/lite_interpreter/gen_selected_mobile_ops_header.py @@ -33,9 +33,9 @@ const char *kernel_tag_str, at::ScalarType scalar_type ) { - c10::string_view kernel_tag_sv C10_UNUSED = c10::string_view(kernel_tag_str); - $body - return false; + [[maybe_unused]] c10::string_view kernel_tag_sv = + c10::string_view(kernel_tag_str); + $body return false; } } """ diff --git a/tools/nightly.py b/tools/nightly.py index 6d1dc48604089..f09563a8c0334 100755 --- a/tools/nightly.py +++ b/tools/nightly.py @@ -57,6 +57,7 @@ REPO_ROOT = Path(__file__).absolute().parent.parent GITHUB_REMOTE_URL = "https://github.com/pytorch/pytorch.git" SPECS_TO_INSTALL = ("pytorch", "mypy", "pytest", "hypothesis", "ipython", "sphinx") +DEFAULT_ENV_NAME = "pytorch-deps" LOGGER: logging.Logger | None = None URL_FORMAT = "{base_url}/{platform}/{dist_name}.tar.bz2" @@ -212,6 +213,28 @@ def check_branch(subcommand: str, branch: str | None) -> str | None: return None +def check_conda_env_exists(name: str | None = None, prefix: str | None = None) -> bool: + """Checks that the conda environment exists.""" + if name is not None and prefix is not None: + raise ValueError("Cannot specify both --name and --prefix") + if name is None and prefix is None: + raise ValueError("Must specify either --name or --prefix") + + try: + cmd = ["conda", "info", "--envs"] + output = subprocess.check_output(cmd, text=True, encoding="utf-8") + except subprocess.CalledProcessError: + logger = cast(logging.Logger, LOGGER) + logger.warning("Failed to list conda environments", exc_info=True) + return False + + if name is not None: + return len(re.findall(rf"^{name}\s+", output, flags=re.MULTILINE)) > 0 + assert prefix is not None + prefix = Path(prefix).absolute() + return len(re.findall(rf"\s+{prefix}$", output, flags=re.MULTILINE)) > 0 + + @contextlib.contextmanager def timer(logger: logging.Logger, prefix: str) -> Iterator[None]: """Timed context manager""" @@ -271,7 +294,7 @@ def conda_solve( else: # create new environment existing_env = False - env_opts = ["--name", "pytorch-deps"] + env_opts = ["--name", DEFAULT_ENV_NAME] # run solve if existing_env: cmd = [ @@ -280,8 +303,8 @@ def conda_solve( "--yes", "--dry-run", "--json", + *env_opts, ] - cmd.extend(env_opts) else: cmd = [ "conda", @@ -321,8 +344,9 @@ def deps_install(deps: list[str], existing_env: bool, env_opts: list[str]) -> No """Install dependencies to deps environment""" if not existing_env: # first remove previous pytorch-deps env - cmd = ["conda", "env", "remove", "--yes", *env_opts] - subprocess.check_call(cmd) + if check_conda_env_exists(name=DEFAULT_ENV_NAME): + cmd = ["conda", "env", "remove", "--yes", *env_opts] + subprocess.check_output(cmd) # install new deps install_command = "install" if existing_env else "create" cmd = ["conda", install_command, "--yes", "--no-deps", *env_opts, *deps] diff --git a/tools/setup_helpers/cmake.py b/tools/setup_helpers/cmake.py index 4b605fe597505..e417f6d56a0e6 100644 --- a/tools/setup_helpers/cmake.py +++ b/tools/setup_helpers/cmake.py @@ -229,6 +229,7 @@ def generate( "STATIC_DISPATCH_BACKEND", "SELECTED_OP_LIST", "TORCH_CUDA_ARCH_LIST", + "TORCH_XPU_ARCH_LIST", "TRACING_BASED", "PYTHON_LIB_REL_PATH", ) diff --git a/tools/test/test_codegen.py b/tools/test/test_codegen.py index cefd8aeeded69..83356e9694622 100644 --- a/tools/test/test_codegen.py +++ b/tools/test/test_codegen.py @@ -383,9 +383,9 @@ def test_native_function_declaration_1_op_1_ns_valid(self) -> None: class TestNativeFunctionGeneratrion(unittest.TestCase): def setUp(self) -> None: self.native_functions: list[NativeFunction] = [] - self.backend_indices: dict[ - DispatchKey, dict[OperatorName, BackendMetadata] - ] = defaultdict(dict) + self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = ( + defaultdict(dict) + ) yaml_entry = """ - func: op(Tensor self) -> Tensor dispatch: @@ -442,9 +442,9 @@ def test_functional_variant_autogen_out_variant_two_returns(self) -> None: # Test for static_dispatch class TestStaticDispatchGeneratrion(unittest.TestCase): def setUp(self) -> None: - self.backend_indices: dict[ - DispatchKey, dict[OperatorName, BackendMetadata] - ] = defaultdict(dict) + self.backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] = ( + defaultdict(dict) + ) yaml_entry = """ - func: op.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) dispatch: diff --git a/tools/test/test_test_selections.py b/tools/test/test_test_selections.py index 1dbe6e1f60bd7..5e3e7a949fa38 100644 --- a/tools/test/test_test_selections.py +++ b/tools/test/test_test_selections.py @@ -474,7 +474,8 @@ def test_split_shards_random(self) -> None: else: # x.time is not None because of the above check self.assertAlmostEqual( - random_times[test], sum(x.time for x in sharded_tests) # type: ignore[misc] + random_times[test], + sum(x.time for x in sharded_tests), # type: ignore[misc] ) self.assertListEqual( list(range(sharded_tests[0].num_shards)), diff --git a/tools/testing/target_determination/determinator.py b/tools/testing/target_determination/determinator.py index ff65251945ed7..9207e62c28ba5 100644 --- a/tools/testing/target_determination/determinator.py +++ b/tools/testing/target_determination/determinator.py @@ -19,10 +19,15 @@ def get_test_prioritizations( print(f" {test}", file=file) for heuristic in HEURISTICS: - new_rankings: TestPrioritizations = heuristic.get_prediction_confidence(tests) - aggregated_results.add_heuristic_results(heuristic, new_rankings) + try: + new_rankings: TestPrioritizations = heuristic.get_prediction_confidence( + tests + ) + aggregated_results.add_heuristic_results(heuristic, new_rankings) - print(f"Results from {heuristic.__class__.__name__}") - print(new_rankings.get_info_str(verbose=False), file=file) + print(f"Results from {heuristic.__class__.__name__}") + print(new_rankings.get_info_str(verbose=False), file=file) + except Exception as e: + print(f"Error in {heuristic.__class__.__name__}: {e}", file=file) return aggregated_results diff --git a/tools/testing/target_determination/heuristics/interface.py b/tools/testing/target_determination/heuristics/interface.py index 5ce0ffe576450..e1e03eee7a4b1 100644 --- a/tools/testing/target_determination/heuristics/interface.py +++ b/tools/testing/target_determination/heuristics/interface.py @@ -52,13 +52,11 @@ def validate(self) -> None: files[test.test_file] |= test for test in files.values(): - assert ( - test.is_full_file() - ), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that" + assert test.is_full_file(), f"All includes should have been excluded elsewhere, and vice versa. Test run `{test}` violates that" # noqa: B950 # Ensure that the set of tests in the TestPrioritizations is identical to the set of tests passed in - assert self._original_tests == set( - files.keys() + assert ( + self._original_tests == set(files.keys()) ), "The set of tests in the TestPrioritizations must be identical to the set of tests passed in" def _traverse_scores(self) -> Iterator[tuple[float, TestRun]]: @@ -279,9 +277,9 @@ def get_test_stats(self, test: TestRun) -> dict[str, Any]: stats["heuristics"] = heuristics - stats[ - "aggregated" - ] = self.get_aggregated_priorities().get_priority_info_for_test(test) + stats["aggregated"] = ( + self.get_aggregated_priorities().get_priority_info_for_test(test) + ) stats["aggregated_trial"] = self.get_aggregated_priorities( include_trial=True diff --git a/tools/testing/upload_artifacts.py b/tools/testing/upload_artifacts.py new file mode 100644 index 0000000000000..2a226b1896d29 --- /dev/null +++ b/tools/testing/upload_artifacts.py @@ -0,0 +1,110 @@ +import glob +import os +import time +import zipfile +from functools import lru_cache +from pathlib import Path +from typing import Any, List + + +REPO_ROOT = Path(__file__).resolve().parent.parent.parent +LAST_UPDATED = 0.0 + + +@lru_cache(maxsize=1) +def get_s3_resource() -> Any: + import boto3 # type: ignore[import] + + return boto3.client("s3") + + +def zip_artifact(file_name: str, paths: List[str]) -> None: + """Zip the files in the paths listed into file_name. The paths will be used + in a glob and should be relative to REPO_ROOT.""" + + with zipfile.ZipFile(file_name, "w") as f: + for path in paths: + for file in glob.glob(f"{REPO_ROOT}/{path}", recursive=True): + f.write(file, os.path.relpath(file, REPO_ROOT)) + + +def upload_to_s3_artifacts() -> None: + """Upload the file to S3.""" + workflow_id = os.environ.get("GITHUB_RUN_ID") + workflow_run_attempt = os.environ.get("GITHUB_RUN_ATTEMPT") + file_suffix = os.environ.get("ARTIFACTS_FILE_SUFFIX") + if not workflow_id or not workflow_run_attempt or not file_suffix: + print( + "GITHUB_RUN_ID, GITHUB_RUN_ATTEMPT, or ARTIFACTS_FILE_SUFFIX not set, not uploading" + ) + return + + test_reports_zip_path = f"{REPO_ROOT}/test-reports-{file_suffix}.zip" + zip_artifact( + test_reports_zip_path, + ["test/test-reports/**/*.xml", "test/test-reports/**/*.csv"], + ) + test_logs_zip_path = f"{REPO_ROOT}/logs-{file_suffix}.zip" + zip_artifact(test_logs_zip_path, ["test/test-reports/**/*.log"]) + jsons_zip_path = f"{REPO_ROOT}/test-jsons-{file_suffix}.zip" + zip_artifact(jsons_zip_path, ["test/test-reports/**/*.json"]) + + s3_prefix = f"pytorch/pytorch/{workflow_id}/{workflow_run_attempt}/artifact" + get_s3_resource().upload_file( + test_reports_zip_path, + "gha-artifacts", + f"{s3_prefix}/{Path(test_reports_zip_path).name}", + ) + get_s3_resource().upload_file( + test_logs_zip_path, + "gha-artifacts", + f"{s3_prefix}/{Path(test_logs_zip_path).name}", + ) + get_s3_resource().upload_file( + test_logs_zip_path, + "gha-artifacts", + f"{s3_prefix}/{Path(jsons_zip_path).name}", + ) + get_s3_resource().put_object( + Body=b"", + Bucket="gha-artifacts", + Key=f"workflows_failing_pending_upload/{workflow_id}.txt", + ) + + +def zip_and_upload_artifacts(failed: bool) -> None: + # not thread safe but correctness of the LAST_UPDATED var doesn't really + # matter for this + # Upload if a test failed or every 20 minutes + global LAST_UPDATED + + if failed or time.time() - LAST_UPDATED > 20 * 60: + start = time.time() + try: + upload_to_s3_artifacts() + LAST_UPDATED = time.time() + except Exception as e: + print(f"Failed to upload artifacts: {e}") + print(f"Uploading artifacts took {time.time() - start:.2f} seconds") + + +def trigger_upload_test_stats_intermediate_workflow() -> None: + import requests + + # The GITHUB_TOKEN cannot trigger workflow so this isn't used for now + print("Triggering upload_test_stats_intermediate workflow") + x = requests.post( + "https://api.github.com/repos/pytorch/pytorch/actions/workflows/upload_test_stats_intermediate.yml/dispatches", + headers={ + "Accept": "application/vnd.github.v3+json", + "Authorization": f"Bearer {os.environ.get('GITHUB_TOKEN')}", + }, + json={ + "ref": "main", + "inputs": { + "workflow_run_id": os.environ.get("GITHUB_RUN_ID"), + "workflow_run_attempt": os.environ.get("GITHUB_RUN_ATTEMPT"), + }, + }, + ) + print(x.text) diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index c74b45431c947..5e6ee6f9dab28 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -415,6 +415,12 @@ if(USE_ROCM) set_source_files_properties(${TORCH_SRC_DIR}/csrc/cuda/Module.cpp PROPERTIES COMPILE_FLAGS "-DCUDA_ARCH_FLAGS=\"${PYTORCH_ROCM_ARCH_readable}\"") endif() +# Preserve XPU arch flags +if(USE_XPU) + string(REPLACE "," " " _ARCH_FLAGS_readable "${TORCH_XPU_ARCH_LIST}") + set_source_files_properties(${TORCH_SRC_DIR}/csrc/xpu/Module.cpp PROPERTIES COMPILE_FLAGS "-DXPU_ARCH_FLAGS=\"${_ARCH_FLAGS_readable}\"") +endif() + target_compile_definitions(torch_python PRIVATE "-DTHP_BUILD_MAIN_LIB") target_link_libraries(torch_python PRIVATE ${TORCH_LIB} ${TORCH_PYTHON_LINK_LIBRARIES}) @@ -456,20 +462,15 @@ else() set(TORCH_VERSION_DEBUG 0) endif() -add_custom_command( - OUTPUT ${TORCH_SRC_DIR}/version.py - COMMAND "${CMAKE_COMMAND}" -E touch "${TOOLS_PATH}/generate_torch_version.py" - COMMAND - "${Python_EXECUTABLE}" "${TOOLS_PATH}/generate_torch_version.py" - --is-debug=${TORCH_VERSION_DEBUG} - --cuda-version=${CUDA_VERSION} - --hip-version=${HIP_VERSION} - DEPENDS ${TOOLS_PATH}/generate_torch_version.py - WORKING_DIRECTORY ${TORCH_ROOT} -) add_custom_target( gen_torch_version ALL - DEPENDS ${TORCH_SRC_DIR}/version.py + "${Python_EXECUTABLE}" "${TOOLS_PATH}/generate_torch_version.py" + --is-debug=${TORCH_VERSION_DEBUG} + --cuda-version=${CUDA_VERSION} + --hip-version=${HIP_VERSION} + BYPRODUCTS ${TORCH_SRC_DIR}/version.py + COMMENT "Regenerating version file..." + WORKING_DIRECTORY ${TORCH_ROOT} ) add_dependencies(torch_python gen_torch_version) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 96122e2848163..4fec36e8e6574 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1281,6 +1281,7 @@ def _set_blas_preferred_backend(arg: torch._C._BlasBackend): ... class _BlasBackend: Cublas: _BlasBackend Cublaslt: _BlasBackend + Ck: _BlasBackend class ConvBackend(Enum): ... @@ -1873,6 +1874,7 @@ def _tensors_data_ptrs_at_indices_equal(tensors: List[Union[Tensor, _int]], ptrs def _construct_CUDA_Tensor_From_Storage_And_Metadata(metadata: dict, storage: Storage) -> Tensor: ... def _storage_Use_Count(storage_ptr: _int) -> _int: ... def _set_storage_access_error_msg(t: Tensor, s: str) -> None: ... +def _set_storage_data_ptr_access_error_msg(storage_ptr: _int, s: str) -> None: ... def _free_And_Remove_DeleterFn(storage_ptr: _int) -> None: ... def _has_Standard_Deleter(storage_ptr: _int) -> _bool: ... @@ -2091,6 +2093,7 @@ class _MemPool: def id(self) -> Tuple[_int, _int]: ... @property def allocator(self) -> Optional[_cuda_CUDAAllocator]: ... + def use_count(self) -> _int: ... class _MemPoolContext: def __init__(self, pool: _MemPool) -> None: ... @@ -2106,6 +2109,7 @@ def _xpu_exchangeDevice(device: _int) -> _int: ... def _xpu_maybeExchangeDevice(device: _int) -> _int: ... def _xpu_getDevice() -> _int: ... def _xpu_getDeviceCount() -> _int: ... +def _xpu_getArchFlags() -> Optional[str]: ... def _xpu_init() -> None: ... def _xpu_setStream(stream_id: _int, device_index: _int, device_type: _int) -> None: ... def _xpu_getCurrentStream(device: _int) -> Tuple: ... diff --git a/torch/_C/_cpu.pyi b/torch/_C/_cpu.pyi index 6593222a119f4..ddd9c4a95ec0a 100644 --- a/torch/_C/_cpu.pyi +++ b/torch/_C/_cpu.pyi @@ -8,5 +8,6 @@ def _is_avx512_vnni_supported() -> _bool: ... def _is_avx512_bf16_supported() -> _bool: ... def _is_amx_tile_supported() -> _bool: ... def _init_amx() -> _bool: ... +def _is_arm_sve_supported() -> _bool: ... def _L1d_cache_size() -> _int: ... def _L2_cache_size() -> _int: ... diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f1cbf47ea0f3f..fea0f54f53848 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -2,7 +2,7 @@ # mypy: disable-error-code="type-arg" from datetime import timedelta from enum import Enum -from typing import Any, Optional, overload +from typing import Any, overload import torch from torch import Tensor @@ -577,6 +577,8 @@ class ProcessGroupNCCL(Backend): def perform_nocolor_split(self, device: torch.device) -> None: ... def comm_split_count(self) -> int: ... def _add_ephemeral_timeout(self, timeout: timedelta) -> None: ... + def abort(self) -> None: ... + def _is_initialized(self) -> bool: ... @property def uid(self) -> int: ... @property @@ -661,30 +663,3 @@ class _SymmetricMemory: def barrier(self, channel: int = 0) -> None: ... def put_signal(self, dst_rank: int, channel: int = 0) -> None: ... def wait_signal(self, src_rank: int, channel: int = 0) -> None: ... - -class ProcessGroupCudaP2P(Backend): - class Options: - nccl_options: Optional[ProcessGroupNCCL.Options] - buffer_size: Optional[int] - - def __init__(self) -> None: ... - - def __init__( - self, - store: Store, - rank: int, - size: int, - options: ProcessGroupCudaP2P.Options, - ) -> None: ... - def is_p2p_available(self) -> bool: ... - def get_buffer_size(self) -> int: ... - def stream(self) -> torch.cuda.Stream: ... - def intra_node_barrier(self) -> Work: ... - def get_p2p_buffer( - self, - rank: int, - sizes: torch.Size, - dtype: torch.dtype, - storage_offset: Optional[int] = 0, - ) -> torch.Tensor: ... - def _shutdown(self) -> None: ... diff --git a/torch/_C/_dynamo/compiled_autograd.pyi b/torch/_C/_dynamo/compiled_autograd.pyi index 80144e3a77907..b308f63844ed6 100644 --- a/torch/_C/_dynamo/compiled_autograd.pyi +++ b/torch/_C/_dynamo/compiled_autograd.pyi @@ -1,10 +1,6 @@ from typing import Callable -from torch._dynamo.compiled_autograd import AutogradCompilerInstance - -def set_autograd_compiler( - autograd_compiler: Callable[[], AutogradCompilerInstance] | None, -) -> Callable[[], AutogradCompilerInstance] | None: ... +def notify_autograd_engine() -> None: ... def clear_cache() -> None: ... def is_cache_empty() -> bool: ... def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ... diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index 00e556764a4b8..d39e2cc292e22 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -16,6 +16,7 @@ skip_code_recursive_flag: SkipCodeRecursiveFlag cache_limit_hit_flag: CacheLimitHitFlag def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... +def get_eval_frame_callback() -> DynamoCallback: ... def reset_code(code: types.CodeType) -> None: ... def unsupported(obj1: object, obj2: object) -> object: ... def skip_code(code: types.CodeType) -> None: ... diff --git a/torch/__init__.py b/torch/__init__.py index c21619a56f46f..5ff3c610abff6 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -34,7 +34,7 @@ TypeVar as _TypeVar, Union as _Union, ) -from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard +from typing_extensions import ParamSpec as _ParamSpec, TypeIs as _TypeIs if TYPE_CHECKING: @@ -309,7 +309,6 @@ def _load_global_deps() -> None: "cuda_runtime": "libcudart.so.*[0-9]", "cuda_cupti": "libcupti.so.*[0-9]", "cufft": "libcufft.so.*[0-9]", - "cufile": "libcufile.so.*[0-9]", "curand": "libcurand.so.*[0-9]", "nvjitlink": "libnvJitLink.so.*[0-9]", "cusparse": "libcusparse.so.*[0-9]", @@ -524,6 +523,9 @@ def __neg__(self): def __sub__(self, other: "IntLikeType") -> "SymInt": raise TypeError("type stub not overridden") + def __rsub__(self, other: "IntLikeType") -> "SymInt": + raise TypeError("type stub not overridden") + def __repr__(self): return self.node._graph_repr() @@ -1031,7 +1033,7 @@ def typename(obj: _Any, /) -> str: return f"{module}.{qualname}" -def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]: +def is_tensor(obj: _Any, /) -> _TypeIs["torch.Tensor"]: r"""Returns True if `obj` is a PyTorch tensor. Note that this function is simply doing ``isinstance(obj, Tensor)``. @@ -1051,7 +1053,7 @@ def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]: return isinstance(obj, torch.Tensor) -def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]: +def is_storage(obj: _Any, /) -> _TypeIs[_Union["TypedStorage", "UntypedStorage"]]: r"""Returns True if `obj` is a PyTorch storage object. Args: @@ -1264,6 +1266,7 @@ def use_deterministic_algorithms( * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU tensor * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor + * :func:`torch.cumsum` when called on a CUDA tensor * :func:`torch.gather` when called on a CUDA tensor that requires grad * :func:`torch.index_add` when called on CUDA tensor * :func:`torch.index_select` when attempting to differentiate a CUDA tensor @@ -1310,7 +1313,6 @@ def use_deterministic_algorithms( * :func:`torch.kthvalue` with called on a CUDA tensor * :func:`torch.median` with indices output when called on a CUDA tensor * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor - * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor * :func:`torch.Tensor.resize_` when called with a quantized tensor @@ -2237,12 +2239,20 @@ def apply_mode(self, mode: _Optional[str]): ) def apply_options(self, options: _Optional[_Dict[str, _Any]]): + from torch._inductor.bisect_helper import BisectionManager + + if bisect_changes := BisectionManager.get_config_change("inductor"): + options = {} if options is None else options + options = ( + {**bisect_changes} if options is None else {**options, **bisect_changes} # type: ignore[dict-item] + ) + if not options: return from torch._inductor import config - current_config: _Dict[str, _Any] = config.shallow_copy_dict() + current_config: _Dict[str, _Any] = config.get_config_copy() for key, val in options.items(): attr_name = key.replace("-", "_") @@ -2469,6 +2479,12 @@ def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: ) if mode is None and options is None: mode = "default" + + from torch._inductor.bisect_helper import BisectionManager + + if bisect_backend := BisectionManager.get_backend(): + backend = bisect_backend + if backend == "inductor": backend = _TorchCompileInductorWrapper(mode, options, dynamic) else: @@ -2523,6 +2539,7 @@ def _register_device_module(device_type, module): # Populate magic methods on SymInt and SymFloat import torch.fx.experimental.sym_node +from torch import fx as fx # Register MPS specific decomps diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py index ec2883d7eff28..0541e2366e898 100644 --- a/torch/_decomp/__init__.py +++ b/torch/_decomp/__init__.py @@ -31,7 +31,6 @@ "register_decomposition", "get_decompositions", "core_aten_decompositions", - "_decomp_table_to_post_autograd_aten", "_special_op_to_preserve_cia", ] @@ -263,191 +262,29 @@ def remove_decompositions( import torch._refs -# Our strategy for deciding if we can preserve a op is following: -# 1. The op should be known statically that it is functional -# 2. If it is maybe aliasing, we decompose because we must know if an op -# is mutating or aliasing. -# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor -# decomp part. (https://github.com/pytorch/pytorch/issues/129431) -def _check_valid_to_preserve(op_overload: "OperatorBase"): - if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops: - return False - if op_overload in FunctionalTensor.metadata_fns: - return False - - if not hasattr(op_overload, "_schema"): - return False - - alias_info = len( - [i for i in op_overload._schema.arguments if i.alias_info is not None] - ) - - is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable - - if is_mutating_or_aliasing: - return False - - if not torch._C._dispatch_has_kernel(op_overload.name()): - return False - - return True - - -def _is_cia_op(op: "OperatorBase") -> bool: - return ( - torch._C._dispatch_has_kernel_for_dispatch_key( - op.name(), torch._C.DispatchKey.CompositeImplicitAutograd - ) - or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels - ) - - -def _is_preservable_cia_op(op: "OperatorBase") -> bool: - return _check_valid_to_preserve(op) and _is_cia_op(op) - - -@lru_cache(maxsize=1) -def _collect_all_valid_cia_ops() -> Set["OperatorBase"]: - """ - This is an util function that gets the all CIA functional ops. - - The algorithm is in 2 steps: - 1. We first query C++ dispatcher to get the list of CIA ops - and then we call getattr on torch.ops.aten to lazily populate - them. - - 2. Sometimes, handful of ops have CIA registered in python dispatcher - but not on the C++ side, these can't be caught at the first step. - So we walk again to get the final list. - - Note that the output of this function should never be modified - """ - # First step to lazily populate torch.ops.aten - cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key( - "CompositeImplicitAutograd" - ) - # Ignore quantized namespace ops - cia_ops = [name[6:] for name in cia_ops if name.startswith("aten::")] - # Materialize all CIA ops first - for op in cia_ops: - split_list = op.split(".") - # Sometime overload could be missing - assert len(split_list) == 1 or len(split_list) == 2 - op_name = split_list[0] - op_overload_name = "default" - if len(split_list) == 2: - op_overload_name = split_list[1] - - _ = getattr(getattr(torch.ops.aten, op_name), op_overload_name) - - # Second step to finally compile the list of all valid ops - cia_ops = set() - for op in torch.ops.aten: - op_packet = getattr(torch.ops.aten, op) - for overload in op_packet.overloads(): - op_overload = getattr(op_packet, overload) - if _is_preservable_cia_op(op_overload): - cia_ops.add(op_overload) - return cia_ops - - -def _get_decomp_for_cia(op): - # [NOTE] Seperating out func.decompose - # Ideally we should be able to just register func.decompose but - # we can't as this decomp is gonna be registered to the py_impl. - # As a result it will infinitely recurse. So we first check if the op - # has py_impl entry for CIA and if it is we use that first. If not, - # we register C++ query to py_impl. - dk = torch._C.DispatchKey.CompositeImplicitAutograd - if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey): - return op.py_kernels[dk] - - def _special_op_to_decompose_cia(*args, **kwargs): - kernel = kwargs["kernel"] - del kwargs["kernel"] - # Can't call kernel.decompose due to infinite recursion as - # we register this kernel to py_impl directly - dk = torch._C.DispatchKey.CompositeImplicitAutograd - if torch._C._dispatch_has_kernel_for_dispatch_key( - kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd - ): - return kernel._op_dk(dk, *args, **kwargs) - else: - raise AssertionError( - f"Expected {kernel} to have CompositeImplicitAutograd kernel" - ) - - return partial(_special_op_to_decompose_cia, kernel=op) - - # See NOTE [Core ATen Ops] # # list was copied from torch/_inductor/decomposition.py # excluding decompositions that results in prim ops # Resulting opset of decomposition is core aten ops def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: - decomp_table = _core_aten_decompositions_post_autograd() - # If it is fbcode change, we return the old decomposition list - from torch._inductor import config - - if config.is_fbcode(): - return decomp_table - - aten = torch.ops.aten + from torch._export.utils import ( + _collect_all_valid_cia_ops_for_aten_namespace, + _get_decomp_for_cia, + ) - # We are deleting custom decomp in core_aten_decomp - # for CIA ops but it should be fine technically - # because this table is only meant to be used in export context - # in which we really carefully control the decomp behaviour - # In any case, C++ decomps should be preferred - cia_ops_that_should_be_removed = [ - aten.all.dimname, - aten.index_add.dimname, - aten.index_copy.dimname, - aten.index_fill.Dimname_Scalar, - aten.index_fill.Dimname_Tensor, - aten.norm.names_ScalarOpt_dim_dtype, - aten.norm.names_ScalarOpt_dim, - aten.silu_backward.default, - aten.std.default, - aten.std.dim, - aten.std.names_dim, - aten.std.correction_names, - aten.std_mean.default, - aten.std_mean.dim, - aten.std_mean.names_dim, - aten.std_mean.correction_names, - aten.upsample_bilinear2d.vec, - aten.upsample_trilinear3d.vec, - ] - - for k in list(decomp_table.keys()): - if k in cia_ops_that_should_be_removed: - del decomp_table[k] - - for op in _collect_all_valid_cia_ops(): + # Entry without functional CIA ops + decomp_table = _core_aten_decompositions_post_autograd() + for op in _collect_all_valid_cia_ops_for_aten_namespace(): decomp_table[op] = _get_decomp_for_cia(op) return decomp_table -# This table is a stop-gap table which replicates -# the old behaviour of post-dispatch IR. -# This table contains all functional CIA ops mapping -# to their default decomp. In old export, this will -# be decomposed implicitly. -def _decomp_table_to_post_autograd_aten(): - decomp_table = {} - for k in _collect_all_valid_cia_ops(): - decomp_table[k] = _get_decomp_for_cia(k) - return decomp_table - - def _core_aten_decompositions_post_autograd() -> ( Dict[torch._ops.OperatorBase, Callable] ): aten = torch.ops.aten - # TODO Delete all mutating or CIA ops from this list return get_decompositions( [ aten.addcdiv, @@ -483,6 +320,7 @@ def _core_aten_decompositions_post_autograd() -> ( aten.detach, aten.diag_embed, aten.diagonal_backward, + aten.diagonal_copy, aten.dot, aten.vdot, aten.elu, @@ -519,11 +357,16 @@ def _core_aten_decompositions_post_autograd() -> ( aten.huber_loss, aten.huber_loss_backward, aten.im2col, - aten.index_add, + aten.index_add.out, + aten.index_add.default, aten.index_add_, - aten.index_copy, + aten.index_copy.out, + aten.index_copy.default, aten.index_copy_, - aten.index_fill, + aten.index_fill.int_Scalar, + aten.index_fill.int_Tensor, + aten.index_fill.int_Scalar_out, + aten.index_fill.int_Tensor_out, aten.index_fill_, aten.isin, aten.isneginf, @@ -575,7 +418,16 @@ def _core_aten_decompositions_post_autograd() -> ( aten.nll_loss2d_backward, aten.nll_loss_backward, aten.nll_loss_forward, - aten.norm, + aten.norm.ScalarOpt_dtype, + aten.norm.Scalar, + aten.norm.ScalarOpt_dim_dtype, + aten.norm.ScalarOpt_dim, + aten.norm.dtype_out, + aten.norm.out, + aten.norm.names_dtype_out, + aten.norm.names_out, + aten.norm.ScalarOpt_dtype_out, + aten.norm.Scalar_out, aten.ones, aten.ones_like, aten.pixel_shuffle, @@ -612,7 +464,7 @@ def _core_aten_decompositions_post_autograd() -> ( aten.sigmoid_backward, aten.silu, aten.silu_, - aten.silu_backward, + aten.silu_backward.grad_input, aten.sinc, aten.sinc_, aten.slice_backward, @@ -632,8 +484,13 @@ def _core_aten_decompositions_post_autograd() -> ( aten.squeeze_copy, aten.squeeze.default, aten.squeeze.dim, - aten.std, - aten.std_mean, + aten.std.correction, + aten.std.out, + aten.std.correction_out, + aten.std.names_out, + aten.std.correction_names_out, + aten.std_mean.correction, + aten.std_mean.correction_out, aten.stack, aten.sum.default, aten.sum.out, @@ -663,8 +520,8 @@ def _core_aten_decompositions_post_autograd() -> ( aten.unsqueeze_copy, aten._unsafe_view, aten.upsample_linear1d, - aten.upsample_bilinear2d, - aten.upsample_trilinear3d, + aten.upsample_bilinear2d.out, + aten.upsample_trilinear3d.out, aten.upsample_nearest2d_backward, aten.view_as_complex, aten.xlogy, diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 4383519ffdeaa..5ec05bf0b6db2 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1689,7 +1689,9 @@ def native_layer_norm_backward( N = prod(inner_dims) # type: ignore[arg-type] M = prod(outer_dims) # type: ignore[arg-type] - if M <= 0 or N <= 0: + from torch.fx.experimental.symbolic_shapes import guard_size_oblivious + + if guard_size_oblivious(M <= 0) or guard_size_oblivious(N <= 0): return ( input.new_zeros(input_shape) if output_mask[0] else None, input.new_zeros(input_shape[axis:]) if output_mask[1] else None, diff --git a/torch/_dynamo/__init__.py b/torch/_dynamo/__init__.py index 7f58ba7f7bf7f..4ca47f3d776cb 100644 --- a/torch/_dynamo/__init__.py +++ b/torch/_dynamo/__init__.py @@ -17,6 +17,7 @@ mark_static_address, maybe_mark_dynamic, run, + set_stance, substitute_in_graph, ) from .eval_frame import ( @@ -57,6 +58,7 @@ "run", "replay", "disable", + "set_stance", "reset", "OptimizedModule", "is_compiling", diff --git a/torch/_dynamo/_trace_wrapped_higher_order_op.py b/torch/_dynamo/_trace_wrapped_higher_order_op.py index c698ded100943..bff15b33b8b02 100644 --- a/torch/_dynamo/_trace_wrapped_higher_order_op.py +++ b/torch/_dynamo/_trace_wrapped_higher_order_op.py @@ -1,15 +1,21 @@ -# mypy: allow-untyped-defs +from typing import Any, Dict, List, Optional, Tuple + import torch +import torch.utils._pytree as pytree from torch._C import DispatchKey from torch._higher_order_ops.utils import autograd_not_implemented -from torch._ops import HigherOrderOperator +from torch._ops import HigherOrderOperator, OpOverload from torch._subclasses import FakeTensorMode from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree +from torch.overrides import TorchFunctionMode from torch.utils._python_dispatch import _get_current_dispatch_mode from torch.utils._pytree import tree_map_only +Tensor = torch.Tensor + + __all__ = ["trace_wrapped"] @@ -43,16 +49,109 @@ # compiled autograd do we inline into the function. -def trace_wrapped(*args, **kwargs): +if not torch._running_with_deploy(): + # torch.library.custom_op does not work with torch.deploy/multipy + + @torch.library.custom_op("FlexAttentionLib::zeros_and_scatter", mutates_args=()) # type: ignore[misc] + def zeros_and_scatter( + shape: List[int], + indices: List[Tensor], + vals: Tensor, + ) -> Tensor: + """Custom Op so that we can register a custom lowering for the new_output + scatter in the backwards pass""" + grad = torch.zeros(shape, device=vals.device, dtype=vals.dtype) + return torch.ops.aten.index_put(grad, indices, vals, accumulate=True) + + @zeros_and_scatter.register_fake # type: ignore[misc] + def _( + shape: List[int], + indices: List[Tensor], + vals: Tensor, + ) -> Tensor: + return vals.new_empty(shape) + + @zeros_and_scatter.register_vmap # type: ignore[misc] + def _(info, indims, shape, indices, value): # type: ignore[no-untyped-def] + """The batching rule is special in that it returns a tensor that is not batched""" + indices_indims = indims[1] + expanded_indices = [] + for idx, idx_indim in zip(indices, indices_indims): + # The index is not a being batched, we should unsqueeze and expand to val + if idx_indim is None: + expanded_indices.append(idx.expand(value.shape)) + else: + # the index is being part of the vmap batch, it should be the same size as val + assert idx.shape == value.shape + expanded_indices.append(idx) + + out = torch.ops.FlexAttentionLib.zeros_and_scatter( + shape, + expanded_indices, + value, + ) + return out, None + + +class ModIndex(torch.autograd.Function): + generate_vmap_rule = True + + @staticmethod + def forward(x: Tensor, indices: List[Tensor]) -> Tensor: + return torch.ops.aten.index(x, indices) + + @staticmethod + def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None: + x, indices = inputs + ctx.save_for_backward(*indices) + ctx.input_shape = x.shape + + @staticmethod + def backward(ctx, gradOut): # type: ignore[no-untyped-def] + indices = ctx.saved_tensors + return ( + torch.ops.FlexAttentionLib.zeros_and_scatter( + ctx.input_shape, + indices, + gradOut, + ), + None, + ) + + +mod_index = ModIndex.apply + + +class TransformGetItemToIndex(TorchFunctionMode): + # This is needed since we want to support calling + # A[q_idx], where q_idx is a scalar tensor in score_mod. + # Today, when q_idx is a scalar tensor, we implicitly convert it to a python + # scalar and create a view. We do not want that behavior in this case, so we + # use this torchfunctionmode to override that behavior for score_mod + # wherever we're running it. + def __torch_function__( + self, + func: OpOverload, + types: Tuple[torch._C._TensorMeta, ...], + args: Tuple[object, ...] = (), + kwargs: Optional[Dict[str, object]] = None, + ) -> object: + if func == torch.Tensor.__getitem__: + index_args = pytree.tree_leaves(args[1]) + if all(isinstance(x, torch.Tensor) for x in index_args): + return mod_index(args[0], index_args) + return func(*args, **(kwargs or {})) + + +def trace_wrapped(*args: Any, **kwargs: Any) -> Any: with torch.no_grad(): return _trace_wrapped_op(*args, **kwargs) class TraceWrapped(HigherOrderOperator): - def __init__(self): + def __init__(self) -> None: super().__init__("trace_wrapped") - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: return super().__call__(*args, **kwargs) @@ -60,7 +159,12 @@ def __call__(self, *args, **kwargs): _trace_wrapped_op = TraceWrapped() -def _assert_meta(grad, size, stride, dtype): +def _assert_meta( + grad: torch.Tensor, + size: Tuple[int, ...], + stride: Tuple[int, ...], + dtype: torch.dtype, +) -> torch.Tensor: assert grad.size() == size, "size mismatch" assert grad.stride() == stride, "stride mismatch" assert grad.dtype == dtype, "dtype mismatch" @@ -68,14 +172,19 @@ def _assert_meta(grad, size, stride, dtype): @_trace_wrapped_op.py_impl(ProxyTorchDispatchMode) -def inner_trace(mode, *args, bw_state=None, **kwargs): - def self_invoke(*args, **dyn_kwargs): +def inner_trace( + mode: ProxyTorchDispatchMode, + *args: Any, + bw_state: Optional[BackwardState] = None, + **kwargs: Any, +) -> Any: + def self_invoke(*args: Any, **dyn_kwargs: Any) -> Any: with torch.no_grad(): return _trace_wrapped_op(*args, **dyn_kwargs, **kwargs) - def unwrap_proxies(x): + def unwrap_proxies(x: Any) -> Any: if isinstance(x, torch.Tensor): - return mode.tracer.unwrap_proxy(x) + return mode.tracer.unwrap_proxy(x) # type: ignore[union-attr] if isinstance(x, (list, tuple)): return type(x)(map(unwrap_proxies, x)) if x is None: @@ -104,12 +213,12 @@ def unwrap_proxies(x): @_trace_wrapped_op.py_impl(FakeTensorMode) -def inner_fake(*args, **kwargs): +def inner_fake(*args: Any, **kwargs: Any) -> None: raise RuntimeError("This op should never be invoked here") @_trace_wrapped_op.py_impl(DispatchKey.CompositeExplicitAutograd) -def _trace_wrapped_op_dense(*args, fn, **kwargs): +def _trace_wrapped_op_dense(*args: Any, fn: Any, **kwargs: Any) -> Any: mode = _get_current_dispatch_mode() assert mode is None, "Mode should never be enabled for CPU/CUDA key" return fn(*args, **kwargs) @@ -121,7 +230,7 @@ def _trace_wrapped_op_dense(*args, fn, **kwargs): @_trace_wrapped_op.py_functionalize_impl -def _trace_wrapped_functionalized(ctx, *args, **kwargs): +def _trace_wrapped_functionalized(ctx: Any, *args: Any, **kwargs: Any) -> Any: unwrapped_args = ctx.unwrap_tensors(args) with ctx.redispatch_to_next(): return ctx.wrap_tensors(_trace_wrapped_op(*unwrapped_args, **kwargs)) diff --git a/torch/_dynamo/backends/common.py b/torch/_dynamo/backends/common.py index 323ac9412a9fd..e5815ad266b3e 100644 --- a/torch/_dynamo/backends/common.py +++ b/torch/_dynamo/backends/common.py @@ -77,7 +77,7 @@ def _wrapped_bw_compiler(*args, **kwargs): raise -def aot_autograd(**kwargs): +def aot_autograd(**kwargs) -> AotAutograd: return AotAutograd(**kwargs) diff --git a/torch/_dynamo/backends/debugging.py b/torch/_dynamo/backends/debugging.py index 2121db54c264b..94ed9b0865091 100644 --- a/torch/_dynamo/backends/debugging.py +++ b/torch/_dynamo/backends/debugging.py @@ -118,13 +118,40 @@ def run(args): return run +def fake_crossref_boxed_nop(fx_g, example_inputs): + def run(args): + with torch._subclasses.CrossRefFakeMode(): + return torch.fx.Interpreter(fx_g).boxed_run(args) + + run._boxed_call = True + return run + + +def get_nop_func(): + return ( + boxed_nop + if not torch._functorch.config.fake_tensor_crossref + else fake_crossref_boxed_nop + ) + + # Useful for debugging purpose # aot_eager uses AOT Autograd backend with nop compiler. It is helpful in debugging. -aot_eager = aot_autograd( - fw_compiler=boxed_nop, - partition_fn=min_cut_rematerialization_partition, - keep_inference_input_mutations=True, -) +def aot_eager( + gm, + fake_tensor_inputs, + fw_compiler=None, + bw_compiler=None, + **kwargs, +): + return aot_autograd( + fw_compiler=fw_compiler or boxed_nop, + bw_compiler=bw_compiler or boxed_nop, + partition_fn=min_cut_rematerialization_partition, + keep_inference_input_mutations=True, + )(gm, fake_tensor_inputs, **kwargs) + + register_backend(name="aot_eager", compiler_fn=aot_eager) aot_eager_default_partitioner = aot_autograd( @@ -145,11 +172,19 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs): "aot_eager_decomp_partition backend ignoring extra kwargs %s", kwargs ) - with functorch_config.patch(unlift_effect_tokens=True): + from torch._inductor.bisect_helper import BisectionManager + + config_patches = {"unlift_effect_tokens": True} + if bisect_changes := BisectionManager.get_config_change( + "aot_eager_decomp_partition" + ): + config_patches.update(bisect_changes) + + with functorch_config.patch(config_patches): return aot_autograd( # these are taken from memory_efficient_fusion() - fw_compiler=boxed_nop, - bw_compiler=boxed_nop, + fw_compiler=get_nop_func(), + bw_compiler=get_nop_func(), # NB: lambda here is to delay import of inductor decompositions=lambda: import_module( "torch._inductor.compile_fx" @@ -165,6 +200,17 @@ def aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs): ) +def aot_eager_decomp_partition_crossref(gm, fake_tensor_inputs, **kwargs): + with functorch_config.patch(fake_tensor_crossref=True): + return aot_eager_decomp_partition(gm, fake_tensor_inputs, **kwargs) + + +register_backend( + name="aot_eager_decomp_partition_crossref", + compiler_fn=aot_eager_decomp_partition_crossref, +) + + # AOT Autograd with torchscript backend. Default partitioner. # aot_ts uses torchscript backend. We can use this with both nnc and nvfuser # by using the relevant fuser with torch.jit.fuser(...) diff --git a/torch/_dynamo/backends/distributed.py b/torch/_dynamo/backends/distributed.py index 3b79d1e68cf8a..bb35a9117daa6 100644 --- a/torch/_dynamo/backends/distributed.py +++ b/torch/_dynamo/backends/distributed.py @@ -413,23 +413,6 @@ def compile_fn(self, gm: fx.GraphModule, example_inputs: List[torch.Tensor]): to compile each subgraph. Finally, stiches compiled graphs into one graphmodule and returns its callable. """ - if has_higher_order_op(gm): - # This indicates presence of a higher order op. For now, we - # have no way to break the higher order op into two buckets. - # Allowing higher order ops in the graph also requires - # changes in the split_module, becuase graph splitter - # currently assumes that all the args of all ops are - # tensors, but in the case of higher order ops, it could be - # a graph module. As a workaround, we are shortcircuiting - raise NotImplementedError( - "DDPOptimizer backend: Found a higher order op in the graph. " - "This is not supported. Please turn off DDP optimizer using " - "torch._dynamo.config.optimize_ddp=False. Note that this can " - "cause performance degradation because there will be one bucket " - "for the entire Dynamo graph. Please refer to this issue - " - "https://github.com/pytorch/pytorch/issues/104674." - ) - # 1: compute the partition map according to DDP bucket logic buckets = [Bucket()] # (size, param_names) processed_modules = set() diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index 1812145fcf162..8202d32dcd1b4 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -1037,7 +1037,7 @@ def remove_fused_load_store(instructions: List[Instruction]) -> None: create_instruction(inst0, argval=argval0), create_instruction(inst1, argval=argval1), ] - new_insts.append(overwrite_instruction(inst, replace_insts)) + new_insts.extend(overwrite_instruction(inst, replace_insts)) else: new_insts.append(inst) instructions[:] = new_insts diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py index 5c675ad052907..1d0c169345d2e 100644 --- a/torch/_dynamo/cache_size.py +++ b/torch/_dynamo/cache_size.py @@ -15,10 +15,10 @@ [Note on cache size limit] Background - TorchDynamo cache is a linked list. Each cache entry is a -(check_fn, out_code, next pointer). These are stored on the f_code's co_extra +(guard_manager, out_code, next pointer). These are stored on the f_code's co_extra scratch space. When a frame is invoked, we walk this linked list and run -check_fn in each cache_entry to decide if the frame needs recompilation. If none -of the check_fn's returns True, we recompile and add a new entry. To ensure we +guard_manager in each cache_entry to decide if the frame needs recompilation. If none +of the guard_manager's returns True, we recompile and add a new entry. To ensure we don't end up recompiling infinitely, we put limits on the cache size. There are two limits @@ -121,7 +121,7 @@ def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool: for ( local_name, weakref_from_cache_entry, - ) in cache_entry.check_fn.id_matched_objs.items(): + ) in cache_entry.guard_manager.id_matched_objs.items(): if weakref_from_cache_entry() is not None: weakref_from_frame = _get_weakref_from_f_locals(frame, local_name) if weakref_from_frame is not weakref_from_cache_entry: @@ -176,7 +176,7 @@ def exceeds_cache_size_limit( if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit): return True, "cache_size_limit" # NOTE this check is needed in the case that the frame's cache doesn't grow - # and we keep recompiling. This can happen if the guard check_fn becomes invalidated, + # and we keep recompiling. This can happen if the guard guard_manager becomes invalidated, # e.g. due to guarded objects being freed. This technically makes the # will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the # check in case we have a better fix in the future. diff --git a/torch/_dynamo/codegen.py b/torch/_dynamo/codegen.py index 6b072cf5a9093..dd83f56c34615 100644 --- a/torch/_dynamo/codegen.py +++ b/torch/_dynamo/codegen.py @@ -51,6 +51,7 @@ def __init__( root: Optional[torch.nn.Module] = None, graph_output_var: Optional[str] = None, tempvars=None, + overridden_sources=None, ) -> None: self.root = root self.top_of_stack: Optional[VariableTracker] = None @@ -65,6 +66,10 @@ def __init__( self.new_var = self.tx.output.new_var self.mutable_side_effects_from_source = False self.value_from_source: bool = True + # This serves as a way for codegen to use a different source; we need + # this because sometimes we can't easily modify the original source + # without affecting other components, e.g., guards. + self.overridden_sources: Dict[Source, Source] = overridden_sources or {} def restore_stack(self, stack_values, *, value_from_source=True): prior = self.mutable_side_effects_from_source @@ -116,7 +121,9 @@ def add_push_null(self, gen_fn, call_function_ex=False): def __call__(self, value, allow_cache=True): """Generate code such that top-of-stack (TOS) is set to value""" if isinstance(value, Source): - self.call_reconstruct(value) + # If the source needs to be overridden, use the new one. + source = self.overridden_sources.get(value, value) + self.call_reconstruct(source) self.clear_tos() return @@ -130,27 +137,25 @@ def __call__(self, value, allow_cache=True): if self.mutable_side_effects_from_source: # this is needed to get aliasing relationships right - # value.mutable_local.source will get mutated to hold `value` + # value.source will get mutated to hold `value` # mutable_side_effects_from_source=False is used to codegen the mutation # mutable_side_effects_from_source=True is used to codegen a reference from .side_effects import MutableSideEffects if isinstance(value.mutable_local, MutableSideEffects): - self(value.mutable_local.source) + self(value.source) return if allow_cache: - if value.mutable_local and value.mutable_local in self.tempvars: - output.append(self.create_load(self.tempvars[value.mutable_local])) - self.top_of_stack = value - return if self.tempvars.get(value) is not None: output.append(self.create_load(self.tempvars[value])) self.top_of_stack = value return if value.source is not None and allow_cache and self.value_from_source: - self.call_reconstruct(value.source) + # If the source needs to be overridden, use the new one. + source = self.overridden_sources.get(value.source, value.source) + self.call_reconstruct(source) elif value.is_python_constant() and is_safe_constant( value.as_python_constant() ): @@ -180,7 +185,9 @@ def __call__(self, value, allow_cache=True): # NB: It works to add_graph_output on a computed expression # as_tensor here, because we memoize as_tensor calls on # SymNodeVariable! - graph_outputs_key = self.add_graph_output(value.as_tensor(self.tx)) + graph_outputs_key = self.add_graph_output( + value.as_tensor(self.tx, torch.float64) + ) def gen_fn(): self.load_graph_output(graph_outputs[graph_outputs_key].index) @@ -254,8 +261,6 @@ def load_graph_output(self, index): def add_cache(self, value): var = self.new_var() self.tempvars[value] = var - if value.mutable_local: - self.tempvars[value.mutable_local] = var self._output.append(self.create_store(var)) def foreach(self, items): diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index 088b0c3579ebb..ada3aa4eee0b4 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,7 +1,10 @@ # mypy: allow-untyped-defs import contextlib import functools -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +import threading +from dataclasses import dataclass +from logging import Logger +from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch from torch._dynamo.external_utils import ( @@ -38,18 +41,95 @@ verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") -def snapshot_verbose_logging_enabled(): - return torch._logging._internal.log_state.is_artifact_enabled( - "compiled_autograd_verbose" - ) +@dataclass +class CompiledAutogradTLS: + next_ctx_id: int = 0 + in_compiled_autograd_region: bool = False + compiler: Optional["AutogradCompilerInstance"] = None + vlogger: Optional[Logger] = None + + +class TLSWrapper: + tls_key = "compiled_autograd_state" + + def __init__(self): + self._local = threading.local() + + def _get_tls(self) -> CompiledAutogradTLS: + if hasattr(self._local, self.tls_key): + # first look in python + state = getattr(self._local, self.tls_key) + if torch._C._is_key_in_tls(self.tls_key): + # then look in cpp + state = torch._C._get_obj_in_tls(self.tls_key) + else: + # init new thread created outside of autograd + # TODO: what if context manager wrapped outside of thread? + setattr(self._local, self.tls_key, CompiledAutogradTLS()) + state = getattr(self._local, self.tls_key) + torch._C._stash_obj_in_tls(self.tls_key, state) + return state + + # queries on the object stored in TLS + def get(self, name): + return getattr(self._get_tls(), name) + + def set_tls(self, **kwargs) -> Callable[[], None]: + priors: Dict[str, Any] = {} + for k, v in kwargs.items(): + state = self._get_tls() + priors[k] = getattr(state, k) + setattr(state, k, v) + + torch._C._dynamo.compiled_autograd.notify_autograd_engine() + + def revert(): + self.set_tls(**priors) + + return revert + + def enter_ctx(self) -> Callable[[], None]: + state = self._get_tls() + state.next_ctx_id += 1 + id = state.next_ctx_id + + def exit(): + assert ( + state is self._get_tls() + ), "Runtime must begin and end on the same thread" + assert state.next_ctx_id == id, ( + "Error nesting compiled autograd context managers: " + "inner context managers must have shorter lifetime than the outer context manager" + ) + state.next_ctx_id -= 1 + + return exit + + def enter_compiled_region(self) -> Callable[[], None]: + state = self._get_tls() + prior = state.in_compiled_autograd_region + state.in_compiled_autograd_region = True + assert prior is False, "Nested compiled autograd regions are not supported" + + def exit(): + assert ( + state is self._get_tls() + ), "Runtime must begin and end on the same thread" + assert state.in_compiled_autograd_region is True + state.in_compiled_autograd_region = prior + return exit -def cpp_verbose_log_fn(msg: str) -> None: - verbose_log.debug(msg) +local = TLSWrapper() -def snapshot_cudagraph_enabled(): - return torch._inductor.config.triton.cudagraphs + +def enabled() -> bool: + return local.get("compiler") is not None + + +def in_compiled_autograd_region() -> bool: + return local.get("in_compiled_autograd_region") def maybe_clone(x): @@ -311,7 +391,7 @@ def end_capture(self, outputs): self.rename_aot_dispatcher_nodes() self.reorder_accumulate_grad_nodes() runtime_inputs_to_move: List[int] = [] - if snapshot_cudagraph_enabled(): + if torch._inductor.config.triton.cudagraphs: runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) graph = GraphModule( @@ -333,16 +413,15 @@ def end_capture(self, outputs): ) def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks): - global in_compiled_autograd_region try: - in_compiled_autograd_region = True + exit_compiled_region = local.enter_compiled_region() for i in runtime_inputs_to_move: inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True) with disable(): return compiled_fn(inputs, sizes, scalars, hooks) finally: - in_compiled_autograd_region = False + exit_compiled_region() return runtime_wrapper, self.compiler_fn(graph) @@ -514,51 +593,64 @@ def set_node_origin( set_stack_trace(new_stack_trace) -# state of the autograd engine dispatch, kept in sync by enable/disable context managers -compiled_autograd_enabled = False - -# global flag to check if we are processing graphs produced from a compiled autograd graph -in_compiled_autograd_region = False +# global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager" +compiled_autograd_enabled_force_eager = False @contextlib.contextmanager def enable(compiler_fn): - # we need to import this, because user might not have imported it if they directly use this context manager - # we need to lazily import it, because of circular dependencies - import torch._inductor.cudagraph_trees - - prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( - functools.partial(AutogradCompilerInstance, compiler_fn) - ) - if snapshot_verbose_logging_enabled(): - torch._C._dynamo.compiled_autograd.set_verbose_logger(cpp_verbose_log_fn) - global compiled_autograd_enabled - compiled_autograd_enabled = True - try: - with torch.autograd.set_multithreading_enabled(False): + from torch._dynamo import eval_frame + + if eval_frame._stance.stance == "force_eager": + # If user explicitly sets Dynamo stance to "force_eager", we want Compiled Autograd + # to fall back to eager as well. + global compiled_autograd_enabled_force_eager + compiled_autograd_enabled_force_eager = True + try: yield - finally: - if not prior: - compiled_autograd_enabled = False - torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) + finally: + compiled_autograd_enabled_force_eager = False + else: + # we need to import this, because user might not have imported it if they directly use this context manager + # we need to lazily import it, because of circular dependencies + import torch._inductor.cudagraph_trees + + exit_ctx = local.enter_ctx() + revert_tls = local.set_tls( + compiler=functools.partial(AutogradCompilerInstance, compiler_fn), + vlogger=verbose_log + if torch._logging._internal.log_state.is_artifact_enabled( + "compiled_autograd_verbose" + ) + else None, + ) + try: + with torch.autograd.set_multithreading_enabled(False): + yield + finally: + revert_tls() + exit_ctx() @contextlib.contextmanager def disable(): - prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) - global compiled_autograd_enabled - compiled_autograd_enabled = False + exit_ctx = local.enter_ctx() + revert_tls = local.set_tls( + compiler=None, + vlogger=None, + ) try: yield finally: - if prior: - compiled_autograd_enabled = True - torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) + revert_tls() + exit_ctx() # return to starting state of a new process def reset() -> None: - compiled_autograd_enable = False - assert not in_compiled_autograd_region - torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) - torch._C._dynamo.compiled_autograd.set_verbose_logger(None) + assert local.get("next_ctx_id") == 0 + assert local.get("in_compiled_autograd_region") is False + local.set_tls( + compiler=None, + vlogger=None, + ) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index c3db01dfa454b..e974e4ccb852b 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -252,10 +252,6 @@ # compile this code; however, this can be useful for export. force_unspec_int_unbacked_size_like_on_torchrec_kjt = False -# Should almost always be true in prod. This relaxes the requirement that cond's true_fn and -# false_fn produces code with identical guards. -enforce_cond_guards_match = True - # Specify how to optimize a compiled DDP module. The flag accepts a boolean # value or a string. There are 4 modes. # 1. "ddp_optimizer" (or True): with "ddp_ptimizer", Dynamo will automatically @@ -373,8 +369,8 @@ def _get_optimize_ddp_mode(): # use numpy's PRNG if True, pytorch otherwise use_numpy_random_stream = False -# Use C++ guard manager -enable_cpp_guard_manager = os.environ.get("TORCHDYNAMO_CPP_GUARD_MANAGER", "1") == "1" +# Use C++ guard manager (deprecated: always true) +enable_cpp_guard_manager = True # Inline inbuilt nn modules inline_inbuilt_nn_modules = not is_fbcode() @@ -472,6 +468,11 @@ def default_debug_dir_root(): # Overrides torch.compile() kwargs for Compiled Autograd: compiled_autograd_kwargs_override: Dict[str, Any] = {} +# Compiled Autograd will attempt to automatically wrap C++ autograd functions found in the autograd graph, +# and make them opaque to the compiler. This does not work when the C++ backward implementation involves +# other dispatcher subsystems e.g. custom subclasses, autocast, vmap. +compiled_autograd_opaque_cpp_node = False + # Enables use of collectives *during* compilation to synchronize behavior # across ranks. Today, this is used solely to modify automatic_dynamic_shapes # behavior, making it so that we infer that if an input is dynamic by diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 57ce3e972ba7b..a3aa8eb00e4b4 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -7,6 +7,7 @@ import dis import functools import itertools +import json import logging import os import pstats @@ -44,6 +45,7 @@ GuardOnDataDependentSymNode, ) from torch.fx.graph_module import _forward_from_src as original_forward_from_src +from torch.monitor import _WaitCounter from torch.nn.parallel.distributed import DistributedDataParallel from torch.utils._python_dispatch import ( _disable_current_modes, @@ -117,6 +119,7 @@ record_compilation_metrics, reset_graph_break_dup_checker, setup_compile_debug, + to_int_ms, troubleshooting_url, write_record_to_file, ) @@ -645,7 +648,7 @@ def transform( one_graph, export, export_constraints, - mutated_closure_cell_contents, + mutated_closure_cell_ids, frame_state=frame_state, speculation_log=speculation_log, distributed_state=distributed_state, @@ -689,9 +692,19 @@ def compile_inner( hooks: Hooks, transform: Callable[[List[Instruction], Dict[str, Any]], Any], ) -> Optional[GuardedCode]: - with dynamo_timed("_compile.compile_inner", phase_name="entire_frame_compile"): - with CompileTimeInstructionCounter.record(): - return _compile_inner(code, one_graph, hooks, transform) + with contextlib.ExitStack() as stack: + stack.enter_context( + dynamo_timed( + "_compile.compile_inner", phase_name="entire_frame_compile" + ) + ) + stack.enter_context( + _WaitCounter("pytorch.wait_counter.dynamo_compile").guard() + ) + stack.enter_context(CompileTimeInstructionCounter.record()) + return _compile_inner(code, one_graph, hooks, transform) + + return None # dead, but see https://github.com/python/mypy/issues/7577 @compile_time_strobelight_meta(phase_name="compile_inner") @maybe_cprofile @@ -829,7 +842,7 @@ def count_args(code: CodeType) -> int: compile_id_str = str(compile_id) if compile_id is not None else "Unknown" annotation_str = "Torch-Compiled Region: " + compile_id_str guarded_code = GuardedCode( - out_code, check_fn.check_fn, compile_id, annotation_str + out_code, check_fn.guard_manager, compile_id, annotation_str # type: ignore[arg-type] ) if not output.is_empty_graph() and hooks.guard_export_fn is not None: @@ -842,12 +855,17 @@ def count_args(code: CodeType) -> int: return guarded_code + chromium_event_log = get_chromium_event_logger() + + chromium_event_log.reset() + chromium_start_time = time.time_ns() + chromium_event_log.log_event_start("dynamo", chromium_start_time, {}) with _use_lazy_graph_module(config.use_lazy_graph_module), compile_context( CompileContext(compile_id) ): restart_reasons: set[str] = set() # This is shared across restarts - mutated_closure_cell_contents: Set[str] = set() + mutated_closure_cell_ids: Set[int] = set() speculation_log = SpeculationLog() if compile_pg := get_compile_pg(): distributed_state = DistributedState(compile_pg, LocalState()) @@ -925,8 +943,6 @@ def format_guard_failures() -> str: # torch/_dynamo/convert_frame.py:780 in convert_frame_intern = structured.intern_string(__file__) # Initialize the ChromiumEventLogger on start - chromium_event_log = get_chromium_event_logger() - chromium_event_log.reset() torch._logging.trace_structured( "dynamo_start", lambda: { @@ -1052,6 +1068,12 @@ def format_guard_failures() -> str: "auto_functionalize", {"missed_reinplacing_bytes": possibly_missed_reinplacing_bytes}, ) + remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get( + "remote_fx_graph_cache_get", None + ) + remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get( + "remote_fx_graph_cache_put", None + ) else: guard_count = None shape_env_guard_count = None @@ -1069,11 +1091,25 @@ def format_guard_failures() -> str: dynamo_time_before_restart = time.time() - start_time possibly_missed_reinplacing_opportunities = None remote_cache_time_saved = None + remote_fx_graph_cache_get_time = None + remote_fx_graph_cache_put_time = None structured_logging_overhead_s = ( torch._logging.get_structured_logging_overhead() ) + def handle_sets(d: Dict[str, Any]) -> Dict[str, Any]: + # Remove entries that have set values which are functions + del d["reorderable_logging_functions"] + # Remove entries that have set values which are _TensorMeta + del d["traceable_tensor_subclasses"] + + return { + key: list(value) if isinstance(value, set) else value + for key, value in d.items() + } + + config_dict = handle_sets(config.get_config_copy()) metrics = CompilationMetrics( str(compile_id), frame_key, @@ -1107,9 +1143,16 @@ def format_guard_failures() -> str: config.suppress_errors, config.inline_inbuilt_nn_modules, config.specialize_float, + json.dumps(config_dict), + True, # is_forward + to_int_ms(remote_fx_graph_cache_get_time), + to_int_ms(remote_fx_graph_cache_put_time), ) record_compilation_metrics(metrics) torch._dynamo.callback_handler.run_end_callbacks() + chromium_event_log.log_event_end( + "dynamo", time.time_ns(), {}, chromium_start_time + ) class ConvertFrame: @@ -1179,9 +1222,17 @@ def __call__( user_stack_formatted = "".join( traceback.format_list(user_stack) ) + user_stack_trace = f"Graph break: skip: from user code at:\n{user_stack_formatted}" + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc()}", + ) graph_break_log.debug( - "Graph break: skip: from user code at:\n%s", - user_stack_formatted, + user_stack_trace, exc_info=True, ) diff --git a/torch/_dynamo/decorators.py b/torch/_dynamo/decorators.py index 67d6c0f27a4c2..73a942c6fbab7 100644 --- a/torch/_dynamo/decorators.py +++ b/torch/_dynamo/decorators.py @@ -6,11 +6,18 @@ from typing import Any, Callable, Dict, Type, TYPE_CHECKING, TypeVar import torch +from torch.utils._contextlib import _DecoratorContextManager from torch.utils._python_dispatch import is_traceable_wrapper_subclass from . import trace_rules, variables from .comptime import comptime -from .eval_frame import DisableContext, innermost_fn, RunOnlyContext +from .eval_frame import ( + _set_stance, + DisableContext, + DynamoStance, + innermost_fn, + RunOnlyContext, +) from .exc import IncorrectUsage from .external_utils import is_compiling from .utils import is_function @@ -49,7 +56,7 @@ def run(fn=None): def disable(fn=None, recursive=True): """ - Decorator and context manager to disable TorchDynamo + Decorator to disable TorchDynamo If recursive=True, Dynamo is completely skipped on the decorated function frame as well as the recursively invoked functions. @@ -81,6 +88,39 @@ def skip(fn=None): return fn +class set_stance(_DecoratorContextManager): + """ + Decorator, context manager, function to set the current stance of the compiler. + + Stances documented in corresponding function in torch/compiler/__init__.py + """ + + _dynamo_forbidden = True + + def __init__(self, stance: str, force_backend=None) -> None: + if force_backend is not None and stance != "default": + raise RuntimeError("non-default stance cannot have force_backend set") + + self.stance = DynamoStance(stance, force_backend) + self.prev = _set_stance(self.stance) + + def __call__(self, fn): + _set_stance(self.prev) + wrapper = super().__call__(fn) + # forbid wrapper in graph + wrapper._dynamo_forbidden = True # type: ignore[attr-defined] + return wrapper + + def __enter__(self): + _set_stance(self.stance) + + def __exit__(self, exc_type, exc_val, exc_tb): + _set_stance(self.prev) + + def clone(self): + return self.__class__(self.stance.stance, force_backend=self.stance.backend) + + def assume_constant_result(fn): fn._dynamo_marked_constant = True return fn diff --git a/torch/_dynamo/device_interface.py b/torch/_dynamo/device_interface.py index baa26c6478988..141defe0210ae 100644 --- a/torch/_dynamo/device_interface.py +++ b/torch/_dynamo/device_interface.py @@ -346,13 +346,13 @@ def register_interface_for_device( device: Union[str, torch.device], device_interface: Type[DeviceInterface] ): if isinstance(device, torch.device): - device = str(device) + device = device.type device_interfaces[device] = device_interface def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]: if isinstance(device, torch.device): - device = str(device) + device = device.type if not _device_initialized: init_device_reg() if device in device_interfaces: diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 6ca7368006032..a3fedd4cd2128 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -23,6 +23,7 @@ import types import warnings import weakref +from dataclasses import dataclass from enum import Enum from os.path import dirname, join from typing import ( @@ -57,7 +58,11 @@ from torch._dispatch.python import enable_python_dispatcher from torch._subclasses.fake_tensor import unset_fake_temporarily from torch._utils_internal import justknobs_check, log_export_usage -from torch.export.dynamic_shapes import _combine_args, _process_dynamic_shapes +from torch.export.dynamic_shapes import ( + _combine_args, + _process_dynamic_shapes, + _RelaxedConstraint, +) from torch.fx import GraphModule from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( @@ -114,6 +119,68 @@ def _maybe_set_eval_frame(callback: DynamoCallback): return set_eval_frame(callback) +@dataclass +class DynamoStance: + stance: str = "default" + backend: Union[str, Callable[..., Any], None] = None + + +_stance = DynamoStance() + + +def _set_stance(stance: DynamoStance) -> DynamoStance: + global _stance + + from torch._C._dynamo.eval_frame import get_eval_frame_callback + + callback = get_eval_frame_callback() + + if callback is not False and callback is not None: + raise RuntimeError("attempted to set_stance in a torch.compile region") + + prior = _stance + _stance = stance + return prior + + +_set_stance._dynamo_forbidden = True # type: ignore[attr-defined] + + +def _callback_from_stance(callback): + if _stance.stance == "default": + # force_backend + if _stance.backend is not None: + hooks = Hooks() + callback = convert_frame.catch_errors_wrapper( + convert_frame.convert_frame( # type: ignore[arg-type] + get_compiler_fn(_stance.backend), + hooks, + ), + hooks, + ) + + return callback + elif _stance.stance == "force_eager": + # disable + return None + elif _stance.stance == "eager_on_recompile": + # run mode + return False + elif _stance.stance == "fail_on_recompile": + + def fail_callback(*args, **kwargs): + raise RuntimeError( + "Detected recompile when torch.compile stance is 'fail_on_recompile'" + ) + + # to prevent cache miss due to different callback + fail_callback._torchdynamo_orig_callable = callback # type: ignore[attr-defined] + + return fail_callback + else: + raise RuntimeError(f"invalid torch.compile stance '{_stance}'") + + def _reset_guarded_backend_cache(): global cached_backends for backend in cached_backends.values(): @@ -376,7 +443,7 @@ def __enter__(self): "to use torch._dynamo.optimize(...) as an annotation/decorator. " ) self.cleanup_fns = [enter() for enter in self.enter_exit_hooks] - self.prior = _maybe_set_eval_frame(self.callback) + self.prior = _maybe_set_eval_frame(_callback_from_stance(self.callback)) def __exit__(self, exc_type, exc_val, exc_tb): assert self.prior is not unset @@ -471,7 +538,7 @@ def _fn(*args, **kwargs): ) cleanups = [enter() for enter in self.enter_exit_hooks] - prior = _maybe_set_eval_frame(callback) + prior = _maybe_set_eval_frame(_callback_from_stance(callback)) # Ensure that if an assertion occurs after graph pushes # something onto the DynamicLayerStack then we pop it off (the @@ -645,11 +712,9 @@ def __call__(self, fn): assert callable(fn) - callback = self.callback - @functools.wraps(fn) def _fn(*args, **kwargs): - prior = _maybe_set_eval_frame(callback) + prior = _maybe_set_eval_frame(_callback_from_stance(self.callback)) try: return fn(*args, **kwargs) finally: @@ -817,9 +882,11 @@ def toy_example(a, b): hooks, backend_ctx_ctor, dynamic=dynamic, - compiler_config=backend.get_compiler_config() - if hasattr(backend, "get_compiler_config") - else None, + compiler_config=( + backend.get_compiler_config() + if hasattr(backend, "get_compiler_config") + else None + ), rebuild_ctx=rebuild_ctx, ) @@ -933,9 +1000,11 @@ def __init__( flat_args[i], symbolic_context=StatelessSymbolicContext( dynamic_sizes=[ - DimDynamic.DYNAMIC - if d in flat_args_dynamic_dims[i] - else DimDynamic.STATIC + ( + DimDynamic.DYNAMIC + if d in flat_args_dynamic_dims[i] + else DimDynamic.STATIC + ) for d in range(len(flat_args[i].shape)) ], constraint_sizes=[None] * len(flat_args[i].shape), @@ -1597,6 +1666,7 @@ def graph_with_interpreter(*args): for c in (constraints or ()) if ( c.t_id == id(x) + and not isinstance(c, _RelaxedConstraint) and c.constraint_range.vr.lower != c.constraint_range.vr.upper ) } diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 5012a30581a5b..fad5cb2f35170 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -144,6 +144,7 @@ class UserErrorType(Enum): DYNAMIC_DIM = auto() INVALID_INPUT = auto() INVALID_OUTPUT = auto() + UNSUPPORTED_ALIASED_MUTATED_DYNAMIC_INPUTS = auto() class UserError(Unsupported): @@ -287,6 +288,14 @@ def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn: # exception, its ok to fallback to eager but not silently. Here, we can use # this function to log the message and the stack trace. graph_break_msg = format_error_msg_verbose(e, code) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: graph_break_msg, + ) graph_breaks_log.debug("%s", graph_break_msg) log.warning(msg) unimplemented(msg, from_exc=e) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index e0d1c853268f0..f00b96300e4c9 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -13,10 +13,10 @@ import itertools import logging import math -import os import re import sys import textwrap +import time import types import weakref from contextlib import contextmanager @@ -37,6 +37,7 @@ from weakref import ReferenceType import torch +import torch.overrides import torch.utils._device from torch._C._dynamo.guards import ( check_obj_id, @@ -46,7 +47,6 @@ install_no_tensor_aliasing_guard, install_object_aliasing_guard, RootGuardManager, - TensorGuards, ) from torch._dynamo.source import ( is_from_flatten_script_object_source, @@ -146,7 +146,7 @@ verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") -class GuardManager: +class GuardManagerWrapper: """ A helper class that contains the root guard manager. An instance of this class is stored in the Dynamo cache entry, so that the cache entry can @@ -294,7 +294,10 @@ def visit(mgr): def from_numpy(a): # If not numpy array, piggy back on e.g. tensor guards to check type - return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a + # Re-enable torch function since we disable it on leaf guards + # we need it to properly construct the tensor if a default device is set + with torch.overrides._enable_torch_function(): + return torch.as_tensor(a) if isinstance(a, (np.generic, np.ndarray)) else a # For user stack printing @@ -524,7 +527,7 @@ def __init__( lookup_weakrefs: Callable[[object], ReferenceType[object]], local_scope: Dict[str, object], global_scope: Dict[str, object], - guard_manager: Optional[GuardManager], + guard_manager: GuardManagerWrapper, check_fn_manager: CheckFunctionManager, ): self.id_ref = id_ref @@ -568,7 +571,7 @@ def __init__( self.tensor_check_names: List[str] = [] self.tensor_check_examples: List[torch.Tensor] = [] self.tensor_check_guards: List[Guard] = [] - self.tensor_check_guard_managers: List[GuardManager] = [] + self.tensor_check_guard_managers: List[GuardManagerWrapper] = [] self.check_fn_manager: CheckFunctionManager = check_fn_manager @@ -581,7 +584,7 @@ def __init__( self.key_order_guarded_dict_ids.add(id(self.get(source_name))) # Keep track of weak references of objects with ID_MATCH guard. This - # info is stored alongside optimized_code and check_fn and is used to + # info is stored alongside optimized_code and guard_manager and is used to # limit the number of cache entries with same ID_MATCH'd object. self.id_matched_objs: Dict[str, ReferenceType[object]] = {} @@ -589,7 +592,6 @@ def __init__( self._cached_guard_managers: Dict[ str, torch._C._dynamo.guards.GuardManager ] = {} - self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set() def guard_on_dict_keys_and_ignore_order(self, example_value, guard): @@ -824,7 +826,6 @@ def manager_guards_on_keys(self, mgr_enum): ) def get_global_guard_manager(self): - assert self.guard_manager # to make mypy happy return self.guard_manager.root.globals_dict_manager( f_globals=self.scope["G"], source="G", @@ -833,7 +834,6 @@ def get_global_guard_manager(self): ) def get_guard_manager_from_source(self, source): - assert self.guard_manager # to make mypy happy root_guard_manager = self.guard_manager.root example_value = None @@ -1158,7 +1158,6 @@ def add_python_lambda_leaf_guard_to_root( globals_for_guard_fn = {"G": self.scope["G"]} exec(pycode, globals_for_guard_fn, out) guard_fn = out["___make_guard_fn"](*closure_vars.values()) - assert self.guard_manager # to make mypy happy if is_epilogue: # Epilogue guards are run after all the other guards have finished. # If epilogue guards contain a getattr or getitem access, one of the @@ -1227,44 +1226,39 @@ def HASATTR(self, guard: Guard): guard, [code], provided_guarded_object=self.get(base) ) - if config.enable_cpp_guard_manager: - base_manager = self.get_guard_manager_from_source(base_source) - if val: - # Just install a getattr manager. GetAttrGuardAccessor itself - # acts as hasattr guard. - example_value = self.get(source.name()) - base_example_value = self.get(base) - guard_manager_enum = self.get_guard_manager_type(source, example_value) - - # if the base value is nn.Module, check if we can speedup the - # guard by going through __dict__ attrs. - if ( - isinstance(base_example_value, torch.nn.Module) - and get_custom_getattr(base_example_value) - is unpatched_nn_module_getattr - ): - return self.getattr_on_nn_module( - source, - base_manager, - base_example_value, - example_value, - base, - source.name(), - guard_manager_enum, - ) - else: - base_manager.getattr_manager( - attr=attr, - source=guard.name, - example_value=example_value, - guard_manager_enum=guard_manager_enum, - ) + base_manager = self.get_guard_manager_from_source(base_source) + if val: + # Just install a getattr manager. GetAttrGuardAccessor itself + # acts as hasattr guard. + example_value = self.get(source.name()) + base_example_value = self.get(base) + guard_manager_enum = self.get_guard_manager_type(source, example_value) + + # if the base value is nn.Module, check if we can speedup the + # guard by going through __dict__ attrs. + if ( + isinstance(base_example_value, torch.nn.Module) + and get_custom_getattr(base_example_value) + is unpatched_nn_module_getattr + ): + return self.getattr_on_nn_module( + source, + base_manager, + base_example_value, + example_value, + base, + source.name(), + guard_manager_enum, + ) else: - base_manager.add_no_hasattr_guard( - attr, get_verbose_code_parts(code, guard) + base_manager.getattr_manager( + attr=attr, + source=guard.name, + example_value=example_value, + guard_manager_enum=guard_manager_enum, ) else: - self._produce_guard_code(guard, [code]) + base_manager.add_no_hasattr_guard(attr, get_verbose_code_parts(code, guard)) def NOT_PRESENT_IN_GENERIC_DICT(self, guard: Guard, attr=None) -> None: assert attr is not None @@ -1293,12 +1287,9 @@ def TYPE_MATCH(self, guard: Guard) -> None: code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})" self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_type_match_guard( - obj_id, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, [code]) + self.get_guard_manager(guard).add_type_match_guard( + obj_id, get_verbose_code_parts(code, guard) + ) def DICT_VERSION(self, guard: Guard): # ___check_dict_version is same as `dict_version(x) == y` @@ -1308,14 +1299,11 @@ def DICT_VERSION(self, guard: Guard): code = f"___dict_version({ref}) == {version}" self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - # TODO(anijain2305) - Delete this when DictGuardManager uses tags - # for dicts. - self.get_guard_manager(guard).add_dict_version_guard( - val, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, [code]) + # TODO(anijain2305) - Delete this when DictGuardManager uses tags + # for dicts. + self.get_guard_manager(guard).add_dict_version_guard( + val, get_verbose_code_parts(code, guard) + ) def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): dict_ref = self.arg_ref(guard) @@ -1324,12 +1312,9 @@ def DICT_CONTAINS(self, guard: Guard, key: str, invert: bool): code = f"{maybe_not}___dict_contains({key!r}, {dict_ref})" self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_dict_contains_guard( - not invert, key, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, [code]) + self.get_guard_manager(guard).add_dict_contains_guard( + not invert, key, get_verbose_code_parts(code, guard) + ) def ID_MATCH(self, guard: Guard): # ___check_obj_id is same as `id(x) == y` @@ -1345,12 +1330,9 @@ def ID_MATCH(self, guard: Guard): code = f"___check_obj_id({ref}, {id_val})" self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_id_match_guard( - id_val, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, [code]) + self.get_guard_manager(guard).add_id_match_guard( + id_val, get_verbose_code_parts(code, guard) + ) # Keep track of ID_MATCH'd objects. This will be used to modify the # cache size logic @@ -1371,32 +1353,22 @@ def NOT_NONE_MATCH(self, guard: Guard, value=None): code = f"{ref} is not None" self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_not_none_guard( - get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, [code]) + self.get_guard_manager(guard).add_not_none_guard( + get_verbose_code_parts(code, guard) + ) def NAME_MATCH(self, guard: Guard): self._guard_on_attribute(guard, "__name__", GuardBuilder.EQUALS_MATCH) def DATA_PTR_MATCH(self, guard: Guard): - # Add a type check. C++ guard has the type check internally, so only - # enable it for Python guards. - if not config.enable_cpp_guard_manager: - self.TYPE_MATCH(guard) - + # C++ guard has the type check internally obj = self.get(guard.name) code = f"{self.arg_ref(guard)}.data_ptr() == {obj.data_ptr()}" self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_data_ptr_guard( - obj, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, [code]) + self.get_guard_manager(guard).add_data_ptr_guard( + obj, get_verbose_code_parts(code, guard) + ) def DUAL_LEVEL(self, guard: Guard): # Invalidate dual level if current dual level is different than the one @@ -1404,19 +1376,15 @@ def DUAL_LEVEL(self, guard: Guard): dual_level = torch.autograd.forward_ad._current_level code = [f"torch.autograd.forward_ad._current_level == {dual_level}"] self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - # TODO(anijain2305) - Consider this moving this guard to C++ - forward_ad = torch.autograd.forward_ad + # TODO(anijain2305) - Consider this moving this guard to C++ + forward_ad = torch.autograd.forward_ad - def fn(x): - return forward_ad._current_level == dual_level + def fn(x): + return forward_ad._current_level == dual_level - assert self.guard_manager # to make mypy happy - self.guard_manager.root.add_lambda_guard( - fn, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, code) + self.guard_manager.root.add_lambda_guard( + fn, get_verbose_code_parts(code, guard) + ) def FUNCTORCH_STACK_MATCH(self, guard: Guard): # Invalidate functorch code if current level is different than @@ -1426,19 +1394,15 @@ def FUNCTORCH_STACK_MATCH(self, guard: Guard): code = [f"torch._functorch.pyfunctorch.compare_functorch_state({states})"] self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - # TODO(anijain2305) - Consider this moving this guard to C++ - compare_fn = torch._functorch.pyfunctorch.compare_functorch_state + # TODO(anijain2305) - Consider this moving this guard to C++ + compare_fn = torch._functorch.pyfunctorch.compare_functorch_state - def fn(x): - return compare_fn(states) + def fn(x): + return compare_fn(states) - assert self.guard_manager # to make mypy happy - self.guard_manager.root.add_lambda_guard( - fn, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, code) + self.guard_manager.root.add_lambda_guard( + fn, get_verbose_code_parts(code, guard) + ) def TENSOR_SUBCLASS_METADATA_MATCH(self, guard: Guard): value = self.get(guard.name) @@ -1457,15 +1421,9 @@ def metadata_checker(x): return x.__tensor_flatten__()[1] == original_metadata global_name = f"___check_metadata_{id(metadata_checker)}_c{CompileContext.current_compile_id()}" - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_lambda_guard( - metadata_checker, get_verbose_code_parts(global_name, guard) - ) - else: - global_scope = self.get("G") - global_scope[global_name] = metadata_checker - code = [f"{global_name}({self.get(guard.name)})"] - self._produce_guard_code(guard, code) + self.get_guard_manager(guard).add_lambda_guard( + metadata_checker, get_verbose_code_parts(global_name, guard) + ) def EQUALS_MATCH(self, guard: Guard): ref = self.arg_ref(guard) @@ -1536,13 +1494,10 @@ def EQUALS_MATCH(self, guard: Guard): code.append(f"__math_isnan({ref})") self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_lambda_guard( - _get_closure_vars()["__math_isnan"], - get_verbose_code_parts(code, guard), - ) - else: - self._produce_guard_code(guard, code) + self.get_guard_manager(guard).add_lambda_guard( + _get_closure_vars()["__math_isnan"], + get_verbose_code_parts(code, guard), + ) return # Python math library doesn't support complex nan, so we need to use numpy @@ -1552,58 +1507,24 @@ def EQUALS_MATCH(self, guard: Guard): code.append(f"__numpy_isnan({ref})") self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_lambda_guard( - _get_closure_vars()["__numpy_isnan"], - get_verbose_code_parts(code, guard), - ) - else: - self._produce_guard_code(guard, code) - return - - if config.enable_cpp_guard_manager: - # Construct a debug string to put into the c++ equals match guard. - code = [f"{ref} == {val!r}"] - if istype(val, ok_mutable_types): - # C++ guards perform a pointer equality check to speedup guards, but the assumption is that the object - # is mutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the - # pointer equality check. - val = deepcopy(val) - self.get_guard_manager(guard).add_equals_match_guard( - val, get_verbose_code_parts(code, guard) + self.get_guard_manager(guard).add_lambda_guard( + _get_closure_vars()["__numpy_isnan"], + get_verbose_code_parts(code, guard), ) - self._set_guard_export_info(guard, code) return - code = [] - - # If matching equality against list/tuple, we must also check that - # the internal types match. (TODO: what about nested lists?) - if istype(val, (list, tuple)): - # NB: SEQUENCE_LENGTH takes care of the outer __check_type_id test - self.SEQUENCE_LENGTH(guard) - - for idx, elem in enumerate(val): - code.append( - f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})" - ) - else: - # Add type check to prevent equality check between tensor and non-tensor. - self.TYPE_MATCH(guard) - - if istype(val, torch.Size): - val = tuple(val) - - # Code object can not be compared against their string representation - # I.e `eval(f"{compile('2+2','','exec')!r}")` raises SyntaxError - assert not istype(val, types.CodeType) - - # TODO: It feels like it would be better to just implement our own - # equality test in C that handles all of the necessary type checking - # and NaN tests - code.append(f"{ref} == {val!r}") - self._produce_guard_code(guard, code) + # Construct a debug string to put into the c++ equals match guard. + code = [f"{ref} == {val!r}"] + if istype(val, ok_mutable_types): + # C++ guards perform a pointer equality check to speedup guards, but the assumption is that the object + # is mutable. For a few corner cases like sets and lists, we make a deepcopy to purposefully fail the + # pointer equality check. + val = deepcopy(val) + self.get_guard_manager(guard).add_equals_match_guard( + val, get_verbose_code_parts(code, guard) + ) self._set_guard_export_info(guard, code) + return def CONSTANT_MATCH(self, guard: Guard): val = self.get(guard.name) @@ -1648,7 +1569,7 @@ def SEQUENCE_LENGTH(self, guard): value = self.get(guard.name) t = type(value) - if not (config.enable_cpp_guard_manager and isinstance(value, dict)): + if not isinstance(value, dict): # C++ DICT_LENGTH checks for type self.TYPE_MATCH(guard) @@ -1659,40 +1580,30 @@ def SEQUENCE_LENGTH(self, guard): code.append(f"len({ref}) == {len(value)}") self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - if isinstance(value, dict): - self.get_guard_manager(guard).add_dict_length_check_guard( - len(value), get_verbose_code_parts(code, guard) - ) - else: - self.get_guard_manager(guard).add_length_check_guard( - len(value), get_verbose_code_parts(code, guard) - ) + if isinstance(value, dict): + self.get_guard_manager(guard).add_dict_length_check_guard( + len(value), get_verbose_code_parts(code, guard) + ) else: - self._produce_guard_code(guard, code) + self.get_guard_manager(guard).add_length_check_guard( + len(value), get_verbose_code_parts(code, guard) + ) def TUPLE_ITERATOR_LEN(self, guard): ref = self.arg_ref(guard) value = self.get(guard.name) t = type(value) - if not config.enable_cpp_guard_manager: - # C++ guard already checks the type - self.TYPE_MATCH(guard) - code = [] code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}") self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - t = type(value) - obj_id = self.id_ref(t) + t = type(value) + obj_id = self.id_ref(t) - self.get_guard_manager(guard).add_tuple_iterator_length_guard( - tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, code) + self.get_guard_manager(guard).add_tuple_iterator_length_guard( + tuple_iterator_len(value), obj_id, get_verbose_code_parts(code, guard) + ) # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards def DUPLICATE_INPUT(self, guard, source_b): @@ -1707,21 +1618,18 @@ def DUPLICATE_INPUT(self, guard, source_b): code = [f"{ref_b} is {ref_a}"] self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - # Check that the guard has not been inserted already - key = (ref_a, ref_b) - if key in self._cached_duplicate_input_guards: - return - self._cached_duplicate_input_guards.add((ref_a, ref_b)) - self._cached_duplicate_input_guards.add((ref_b, ref_a)) - - install_object_aliasing_guard( - self.get_guard_manager(guard), - self.get_guard_manager_from_source(source_b), - get_verbose_code_parts(code, guard), - ) - else: - self._produce_guard_code(guard, code) + # Check that the guard has not been inserted already + key = (ref_a, ref_b) + if key in self._cached_duplicate_input_guards: + return + self._cached_duplicate_input_guards.add((ref_a, ref_b)) + self._cached_duplicate_input_guards.add((ref_b, ref_a)) + + install_object_aliasing_guard( + self.get_guard_manager(guard), + self.get_guard_manager_from_source(source_b), + get_verbose_code_parts(code, guard), + ) def DICT_KEYS(self, guard): # Guard on the keys and their order @@ -1742,24 +1650,18 @@ def DICT_KEYS(self, guard): code.append(f"list({ref}.keys()) == {const_keys_repr}") self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - if self.requires_key_order_guarding(guard.originating_source): - self.guard_on_dict_keys_and_order(value, guard) - else: - self.guard_on_dict_keys_and_ignore_order(value, guard) + if self.requires_key_order_guarding(guard.originating_source): + self.guard_on_dict_keys_and_order(value, guard) else: - self._produce_guard_code(guard, code) + self.guard_on_dict_keys_and_ignore_order(value, guard) def WEAKREF_ALIVE(self, guard): code = [f"{self.arg_ref(guard)} is not None"] self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_not_none_guard( - get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, code) + self.get_guard_manager(guard).add_not_none_guard( + get_verbose_code_parts(code, guard) + ) def DICT_CONST_KEYS(self, guard): """Constant keys match""" @@ -1767,21 +1669,14 @@ def DICT_CONST_KEYS(self, guard): value = self.get(guard.name) t = type(value) - if not config.enable_cpp_guard_manager: - # DictGuardManager supports TYPE_MATCH internally - self.TYPE_MATCH(guard) - code = [] code.append(f"list({ref}.keys()) == {list(value.keys())!r}") self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - if self.requires_key_order_guarding(guard.originating_source): - self.guard_on_dict_keys_and_order(value, guard) - else: - self.guard_on_dict_keys_and_ignore_order(value, guard) + if self.requires_key_order_guarding(guard.originating_source): + self.guard_on_dict_keys_and_order(value, guard) else: - self._produce_guard_code(guard, code) + self.guard_on_dict_keys_and_ignore_order(value, guard) def EMPTY_NN_MODULE_HOOKS_DICT(self, guard): """Special guard to skip guards on empty hooks. This is controlled by skip_nnmodule_hook_guards""" @@ -1813,12 +1708,9 @@ def DEFAULT_DEVICE(self, guard: Guard): code = [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"] self._set_guard_export_info(guard, code) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_default_device_guard( - get_verbose_code_parts(code, guard) - ) - else: - self._produce_guard_code(guard, code) + self.get_guard_manager(guard).add_default_device_guard( + get_verbose_code_parts(code, guard) + ) def SHAPE_ENV(self, guard: Guard): # Let's handle ShapeEnv guards. To do this, we will resolve @@ -1845,6 +1737,7 @@ def get_sources(t_id, dim): Tuple[Source, Union[Source, Symbol], Callable] ] = [] phantom_symbols: Dict[str, Symbol] = {} + relaxed_sources: Set[Source] = set() for constraint in output_graph.export_constraints: if constraint.t_id in output_graph.tracked_fakes_id_to_source: torch.export.dynamic_shapes._process_equalities( @@ -1855,6 +1748,7 @@ def get_sources(t_id, dim): source_pairs, derived_equalities, phantom_symbols, + relaxed_sources, ) else: log.warning("Untracked tensor used in export constraints") @@ -1862,6 +1756,7 @@ def get_sources(t_id, dim): source_pairs=source_pairs, derived_equalities=derived_equalities, phantom_symbols=list(phantom_symbols.values()), + relaxed_sources=relaxed_sources, warn_only=False, ) else: @@ -1883,18 +1778,14 @@ def get_sources(t_id, dim): for code in code_parts: self._set_guard_export_info(guard, [code]) - if config.enable_cpp_guard_manager: - # Install all the symbolic guards in one lambda guard. These are run - # at the very end of the RootGuardManager via epilogue guards. - # TODO(anijain2305,williamwen42) - Consider moving this to C++. - self.add_python_lambda_leaf_guard_to_root( - code_parts, - verbose_code_parts, - closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, - ) - else: - for code in code_parts: - self._produce_guard_code(guard, [code], shape_env=True) + # Install all the symbolic guards in one lambda guard. These are run + # at the very end of the RootGuardManager via epilogue guards. + # TODO(anijain2305,williamwen42) - Consider moving this to C++. + self.add_python_lambda_leaf_guard_to_root( + code_parts, + verbose_code_parts, + closure_vars={**SYMPY_INTERP, **_get_closure_vars()}, + ) def TENSOR_MATCH(self, guard: Guard, value=None): # For FSDP modules, we can skip guards on nn module tensors because FSDP @@ -1963,30 +1854,29 @@ def TENSOR_MATCH(self, guard: Guard, value=None): self.tensor_check_names.append(tensor_name) self.tensor_check_guards.append(guard) - if config.enable_cpp_guard_manager: - guard_manager = self.get_guard_manager(guard) - # Keep track of all the tensor guard managers to insert - # NoAliasing check at the end. - self.tensor_check_guard_managers.append(guard_manager) - - output_graph = self.check_fn_manager.output_graph - metadata = output_graph.input_source_to_sizes_strides[ - guard.originating_source - ] - size = convert_to_concrete_values(metadata["size"]) - stride = convert_to_concrete_values(metadata["stride"]) - - verbose_code_parts = get_verbose_code_parts( - get_tensor_guard_code_part(value, tensor_name, size, stride), - guard, - ) - guard_manager.add_tensor_match_guard( - value, - size, - stride, - tensor_name, - verbose_code_parts, - ) + guard_manager = self.get_guard_manager(guard) + # Keep track of all the tensor guard managers to insert + # NoAliasing check at the end. + self.tensor_check_guard_managers.append(guard_manager) + + output_graph = self.check_fn_manager.output_graph + metadata = output_graph.input_source_to_sizes_strides[ + guard.originating_source + ] + size = convert_to_concrete_values(metadata["size"]) + stride = convert_to_concrete_values(metadata["stride"]) + + verbose_code_parts = get_verbose_code_parts( + get_tensor_guard_code_part(value, tensor_name, size, stride), + guard, + ) + guard_manager.add_tensor_match_guard( + value, + size, + stride, + tensor_name, + verbose_code_parts, + ) # A frame is valid for reuse with dynamic dimensions if the new # (user-requested) dynamic dimensions are a subset of the old @@ -2026,10 +1916,9 @@ def TENSOR_MATCH(self, guard: Guard, value=None): dynamic_indices = value._dynamo_dynamic_indices code_part = f"(({tensor_name}._dynamo_dynamic_indices.issubset({dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True)" # noqa: B950 code.append(code_part) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_dynamic_indices_guard( - dynamic_indices, get_verbose_code_parts(code_part, guard) - ) + self.get_guard_manager(guard).add_dynamic_indices_guard( + dynamic_indices, get_verbose_code_parts(code_part, guard) + ) # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled. else: @@ -2037,23 +1926,12 @@ def TENSOR_MATCH(self, guard: Guard, value=None): f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False" ) code.append(code_part) - if config.enable_cpp_guard_manager: - self.get_guard_manager(guard).add_no_hasattr_guard( - "_dynamo_dynamic_indices", - get_verbose_code_parts(code_part, guard), - ) + self.get_guard_manager(guard).add_no_hasattr_guard( + "_dynamo_dynamic_indices", + get_verbose_code_parts(code_part, guard), + ) if len(code) > 0: self._set_guard_export_info(guard, code) - if not config.enable_cpp_guard_manager: - self._produce_guard_code(guard, code) - - # A util that appends guarded code - def _produce_guard_code(self, guard, code_list, shape_env=False): - assert not config.enable_cpp_guard_manager - if shape_env: - self.shape_env_code.append(GuardCodeList(code_list, guard)) - else: - self.code.append(GuardCodeList(code_list, guard)) # A util that in the case of export, adds data onto guards def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None): @@ -2086,8 +1964,9 @@ def _set_guard_export_info(self, guard, code_list, provided_guarded_object=None) obj_ref = None # Not necessary to have weakref for Enum type, but there is a bug that # makes hasattr(guarded_object.__class__, "__weakref__") return True. + # See D64140537 for why we are checking for tuple. if hasattr(guarded_object.__class__, "__weakref__") and not isinstance( - guarded_object, enum.Enum + guarded_object, (enum.Enum, tuple) ): obj_ref = weakref.ref(guarded_object) @@ -2232,9 +2111,7 @@ def __init__( ): guards = output_graph.guards if output_graph else None self._weakrefs: Dict[int, ReferenceType[object]] = {} - self.guard_manager = None - if config.enable_cpp_guard_manager: - self.guard_manager = GuardManager() + self.guard_manager = GuardManagerWrapper() self.output_graph = output_graph w_builder = None @@ -2294,40 +2171,40 @@ def cleanup_builder(weak_b): guard.create(builder) - self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn) + self.compile_check_fn(builder, guards, guard_fail_fn) # Keep track of weak references of objects with ID_MATCH guard. This - # info is stored alongside optimized_code and check_fn and is used to + # info is stored alongside optimized_code and guard_manager and is used to # limit the number of cache entries with same ID_MATCH'd object. # TODO(anijain2305) - Currently this information is stored as an attr on - # the check_fn itself to avoid changing CacehEntry datastructure in - # eval_frame.c. In future, we should probably replace check_fn with a + # the guard_manager itself to avoid changing CacheEntry data structure in + # eval_frame.c. In future, we should probably replace guard_manager with a # queryable data structure such that this information is already present # in some form. - self.check_fn.id_matched_objs = builder.id_matched_objs + self.guard_manager.id_matched_objs = builder.id_matched_objs - if config.enable_cpp_guard_manager: - # TODO: don't do the string rep, do something more structured here - torch._logging.trace_structured( - "dynamo_cpp_guards_str", payload_fn=lambda: str(self.guard_manager) - ) - guards_log.debug("%s", self.guard_manager) - assert self.guard_manager # to make mypy happy - self.guard_manager.id_matched_objs = builder.id_matched_objs - self.check_fn = self.guard_manager - - # Check that the guard returns True. False means that we will always - # recompile. - # TODO(anijain2305, ydwu4) - Skipping export because of following test - # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs - if not output_graph.export: - if not self.guard_manager.check(output_graph.local_scope): - reasons = get_guard_fail_reason_helper( - self.guard_manager, # type: ignore[arg-type] - output_graph.local_scope, - CompileContext.current_compile_id(), - ) - raise AssertionError(f"Guard check failed: {reasons}") + # TODO: don't do the string rep, do something more structured here + torch._logging.trace_structured( + "dynamo_cpp_guards_str", payload_fn=lambda: str(self.guard_manager) + ) + guards_log.debug("%s", self.guard_manager) + self.guard_manager.id_matched_objs = builder.id_matched_objs + + # Check that the guard returns True. False means that we will always + # recompile. + # TODO(anijain2305, ydwu4) - Skipping export because of following test + # python -s test/dynamo/test_export.py -k test_export_with_symbool_inputs + if not output_graph.export: + if not self.guard_manager.check(output_graph.local_scope): + reasons = get_guard_fail_reason_helper( + self.guard_manager, # type: ignore[arg-type] + output_graph.local_scope, + CompileContext.current_compile_id(), + ) + raise AssertionError(f"Guard check failed: {reasons}") + + if guards_log.isEnabledFor(logging.DEBUG): + self.profile_guard_eval(output_graph.local_scope) # NB - We have to very careful of cleaning up here. Because of the # invalidate function, we can create a weakref finalizer that keeps @@ -2340,6 +2217,18 @@ def cleanup_builder(weak_b): self._weakrefs.clear() self.output_graph = None + def profile_guard_eval(self, f_locals): + start_time = time.time() + iterations = 0 + profile_duration = 1 # unit is seconds + + while time.time() - start_time < profile_duration: + self.guard_manager.check(f_locals) + iterations += 1 + + guard_latency = 10**6 / iterations # us + guards_log.debug("Guard eval latency = %s us", f"{guard_latency:.2f}") + def compile_check_fn(self, builder, guards_out, guard_fail_fn): # see parallel handling of ".0" / "___implicit0" in _eval_frame.c largs = builder.argnames @@ -2355,26 +2244,15 @@ def compile_check_fn(self, builder, guards_out, guard_fail_fn): self.torch_function_mode_stack ) - if config.enable_cpp_guard_manager: - # Insert the global_state guard - assert self.guard_manager # to make mypy happy - self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) + # Insert the global_state guard + self.guard_manager.root.add_global_state_guard(["___check_global_state()"]) - self.guard_manager.root.add_torch_function_mode_stack_guard( - self.torch_function_mode_stack, - ["___check_torch_function_mode_stack()"], - ) - # Clear references to torch_function modes held in the list - self.torch_function_mode_stack = None - else: - # Don't report this guard, it's always the same, useless! - global_guard = "___check_global_state()" - code_parts.append(global_guard) - verbose_code_parts.append(global_guard) - - tf_mode_stack_guard = "___check_torch_function_mode_stack()" - code_parts.append(tf_mode_stack_guard) - verbose_code_parts.append(tf_mode_stack_guard) + self.guard_manager.root.add_torch_function_mode_stack_guard( + self.torch_function_mode_stack, + ["___check_torch_function_mode_stack()"], + ) + # Clear references to torch_function modes held in the list + self.torch_function_mode_stack = None def add_code_part(code_part, guard, log_only=False): verbose_code_part = get_verbose_code_part(code_part, guard) @@ -2419,54 +2297,14 @@ def add_code_part(code_part, guard, log_only=False): if code not in seen: # If Cpp guard manager is enabled, we don't need to add to # code_parts. - add_code_part(code, gcl.guard, config.enable_cpp_guard_manager) + add_code_part(code, gcl.guard, True) seen.add(code) tensor_check_names = builder.tensor_check_names check_tensors_fn = None check_tensors_verbose_fn = None - if tensor_check_names and not config.enable_cpp_guard_manager: - tensor_check_guards = builder.tensor_check_guards - assert ( - not self.output_graph.export - ), "Illegal to set tensor_check_names in export." - tensor_check_examples = builder.tensor_check_examples - - dynamic_dims_sizes = [] - dynamic_dims_strides = [] - for t, g in zip(tensor_check_examples, tensor_check_guards): - metadata = self.output_graph.input_source_to_sizes_strides[ - g.originating_source - ] - dynamic_dims_sizes.append(convert_to_concrete_values(metadata["size"])) - dynamic_dims_strides.append( - convert_to_concrete_values(metadata["stride"]) - ) - tensor_guards = TensorGuards( - *tensor_check_examples, - dynamic_dims_sizes=dynamic_dims_sizes, - dynamic_dims_strides=dynamic_dims_strides, - ) - check_tensors_fn = tensor_guards.check - check_tensors_verbose_fn = tensor_guards.check_verbose - tensor_check_args = ", ".join( - tensor_check_names + ["tensor_check_names=tensor_check_names"] - ) - # Do this manually, to un-stagger the guards in log message - code_parts.append(f"___check_tensors({tensor_check_args})") - verbose_code_parts.append(f"___check_tensors({tensor_check_args})") - - for i, name in enumerate(tensor_check_names): - # This is a copy of what guards.cpp checks against - # Keep this in sync with TensorCheck constructor - t = tensor_check_examples[i] - sizes = dynamic_dims_sizes[i] - strides = dynamic_dims_strides[i] - code_part = get_tensor_guard_code_part(t, name, sizes, strides) - add_code_part(code_part, tensor_check_guards[i], log_only=True) - - if len(tensor_check_names) > 1 and config.enable_cpp_guard_manager: + if len(tensor_check_names) > 1: # Install tensor aliasing guard. TENSOR_MATCH guards are already # installed for cpp guard manager. install_no_tensor_aliasing_guard( @@ -2489,13 +2327,12 @@ def add_code_part(code_part, guard, log_only=False): source_a = guard.input_source_a source_b = guard.input_source_b code_part = f"{source_a.name()} is {source_b.name()}" - if config.enable_cpp_guard_manager: - install_object_aliasing_guard( - builder.get_guard_manager_from_source(source_a), - builder.get_guard_manager_from_source(source_b), - [code_part], - ) - add_code_part(code_part, None, config.enable_cpp_guard_manager) + install_object_aliasing_guard( + builder.get_guard_manager_from_source(source_a), + builder.get_guard_manager_from_source(source_b), + [code_part], + ) + add_code_part(code_part, None, True) else: raise RuntimeError(f"Unknown GuardEnvExpr: {guard}") @@ -2505,7 +2342,7 @@ def add_code_part(code_part, guard, log_only=False): for code in gcl.code_list: # Shape env guards are already added for CPP guard manager in # SHAPE_ENV implementation. - add_code_part(code, gcl.guard, config.enable_cpp_guard_manager) + add_code_part(code, gcl.guard, True) # OK, all done generating guards if structured_guard_fns: @@ -2528,71 +2365,39 @@ def add_code_part(code_part, guard, log_only=False): } globals_for_guard_fn = {"G": builder.scope["G"]} - if config.enable_cpp_guard_manager: - # Guard manager construction is complete - assert self.guard_manager # to make mypy happy - # TODO (anijain2305) - When enable_cpp_guard_manager is ON by - # default, change the guard_fn name to be guard_manager everywhere - # to avoid confusion. - guard_fn = self.guard_manager - # Ensure we did not miss to insert a guard in cpp guard manager. - assert len(code_parts) == 0 - else: - unique_code_parts = list(unique(code_parts)) - make_guard_fn_args = ", ".join(closure_vars.keys()) - guard_body, pycode = build_guard_function( - unique_code_parts, make_guard_fn_args - ) - - if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1": - print("GUARDS\n", guard_body) - - out: Dict[str, Any] = {} - - # We don't put builder.scope as the globals in exec call because - # guard_fn.__globals__ becomes equal to builder.scope. This causes - # guard_fn to hold a referece to f_locals sitting in builder.scope["L"] - try: - exec(pycode, globals_for_guard_fn, out) - except SyntaxError as ex: - log.exception("Failed to exec guard at line %s.\n%s", ex.lineno, pycode) - raise - guard_fn = out["___make_guard_fn"](*closure_vars.values()) - - guard_fn.closure_vars = closure_vars - # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both - guard_fn.args = largs - if config.enable_cpp_guard_manager: - guard_fn.populate_code_parts_for_debugging() - else: - guard_fn.code_parts = code_parts - guard_fn.verbose_code_parts = verbose_code_parts + # Guard manager construction is complete. Ensure we did not miss to + # insert a guard in cpp guard manager. + assert len(code_parts) == 0 + + self.guard_manager.closure_vars = closure_vars + self.guard_manager.args = largs + self.guard_manager.populate_code_parts_for_debugging() + self.guard_manager.verbose_code_parts = verbose_code_parts # Grab only G, but preserve "G" because guards access it as "G" - guard_fn.global_scope = globals_for_guard_fn - guard_fn.guard_fail_fn = guard_fail_fn + self.guard_manager.global_scope = globals_for_guard_fn + self.guard_manager.guard_fail_fn = guard_fail_fn # will be populated by a non-owning reference to CacheEntry/ExtraState # when the CacheEntry is constructed - guard_fn.cache_entry = None - guard_fn.extra_state = None - guard_fn.no_tensor_aliasing_sources = tensor_check_names - return guard_fn + self.guard_manager.cache_entry = None + self.guard_manager.extra_state = None + self.guard_manager.no_tensor_aliasing_sources = tensor_check_names def invalidate(self): # Some tests reveal that CheckFunctionManager has no attribute - # check_fn, but this case should not be of any concern. + # guard_manager, but this case should not be of any concern. # This case doesn't seem easy to repro. if ( - hasattr(self, "check_fn") - and self.check_fn is not DeletedGuardFn - and (cache_entry := self.check_fn.cache_entry) is not None - and (extra_state := self.check_fn.extra_state) is not None + hasattr(self, "guard_manager") + and self.guard_manager is not DeletedGuardFn + and (cache_entry := self.guard_manager.cache_entry) is not None + and (extra_state := self.guard_manager.extra_state) is not None ): assert isinstance(cache_entry, CacheEntry) assert isinstance(extra_state, ExtraState) extra_state.invalidate(cache_entry) - self.check_fn.cache_entry = None - self.check_fn.extra_state = None - self.check_fn = DeletedGuardFn + self.guard_manager.cache_entry = None + self.guard_manager.extra_state = None + self.guard_manager = DeletedGuardFn # type: ignore[assignment] def id_ref(self, obj): """add a weakref, return the id""" @@ -2702,54 +2507,49 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope): def get_guard_fail_reason_helper( - guard_fn: GuardFn, + guard_manager: GuardFn, f_locals: Dict[str, object], compile_id: CompileId, ) -> str: """ - Return the reason why `guard_fn` failed. + Return the reason why `guard_manager` failed. Updates `guard_failures` with the generated reason. - Only the first failed check of guard_fn is reported. + Only the first failed check of guard_manager is reported. """ - scope = {"L": f_locals, "G": guard_fn.global_scope["G"]} - scope.update(guard_fn.closure_vars) + scope = {"L": f_locals, "G": guard_manager.global_scope["G"]} + scope.update(guard_manager.closure_vars) reasons: List[str] = [] no_tensor_aliasing_check_failed = False verbose_code_parts: List[str] = [] - if config.enable_cpp_guard_manager: - guard_manager = guard_fn - guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined] - # For test_export_with_map_cond, the check_verbose fail even without the - # C++ guard manager. We need to fix the issue to remove the comment. - # assert not guard_debug_info.result - if not guard_debug_info.result: - verbose_code_parts = guard_debug_info.verbose_code_parts - # verbose_code_parts is either the actual reason (e.g. in case of - # TENSOR_MATCH) or it could be a list of verbose_code_part that we - # passed to the leaf guard at construction time. If its a list, we - # walk through this list and find the guard that failed. This is - # very important for symbolic shape guards which are currently - # installed as a lambda guard and can encompass a long list of code_parts. - - if len(verbose_code_parts) == 1: - if "Duplicate tensor found" in verbose_code_parts[0]: - no_tensor_aliasing_check_failed = True - else: - reasons = verbose_code_parts - verbose_code_parts = [] - else: - verbose_code_parts = guard_fn.verbose_code_parts - # This is not needed for CPP guard because the verbose check is already - # run in C++. - scope["___check_tensors"] = scope["___check_tensors_verbose"] + guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined] + # For test_export_with_map_cond, the check_verbose fail even without the + # C++ guard manager. We need to fix the issue to remove the comment. + # assert not guard_debug_info.result + if not guard_debug_info.result: + verbose_code_parts = guard_debug_info.verbose_code_parts + # verbose_code_parts is either the actual reason (e.g. in case of + # TENSOR_MATCH) or it could be a list of verbose_code_part that we + # passed to the leaf guard at construction time. If its a list, we + # walk through this list and find the guard that failed. This is + # very important for symbolic shape guards which are currently + # installed as a lambda guard and can encompass a long list of code_parts. + + if len(verbose_code_parts) == 1: + if "Duplicate tensor found" in verbose_code_parts[0]: + no_tensor_aliasing_check_failed = True + else: + reasons = verbose_code_parts + verbose_code_parts = [] if no_tensor_aliasing_check_failed: - reasons = recompilation_reason_for_no_tensor_aliasing_guard(guard_fn, scope) + reasons = recompilation_reason_for_no_tensor_aliasing_guard( + guard_manager, scope + ) else: for part in verbose_code_parts: - global_scope = dict(guard_fn.global_scope) + global_scope = dict(guard_manager.global_scope) global_scope["__compile_source__"] = part with report_compile_source_on_error(): try: @@ -2774,17 +2574,17 @@ def get_guard_fail_reason_helper( def get_guard_fail_reason( - guard_fn: GuardFn, + guard_manager: GuardFn, code: types.CodeType, f_locals: Dict[str, object], compile_id: CompileId, ) -> str: - reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id) + reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id) guard_failures[orig_code_map[code]].append(reason_str) try: - if guard_fn.guard_fail_fn is not None: - guard_fn.guard_fail_fn( + if guard_manager.guard_fail_fn is not None: + guard_manager.guard_fail_fn( GuardFail(reason_str or "unknown reason", orig_code_map[code]) ) except Exception as e: @@ -2806,7 +2606,7 @@ def get_and_maybe_log_recompilation_reason( reasons = [] while cache_entry is not None: reason = get_guard_fail_reason( - cache_entry.check_fn, + cache_entry.guard_manager, cache_entry.code, frame.f_locals, cache_entry.compile_id, @@ -2856,7 +2656,7 @@ def get_and_maybe_log_recompilation_reason( def guard_error_hook( - guard_fn: GuardFn, + guard_manager: GuardFn, code: types.CodeType, f_locals: Dict[str, object], index: int, @@ -2865,16 +2665,15 @@ def guard_error_hook( print( f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}" ) - print("lambda " + ", ".join(guard_fn.args) + ":") - print(" ", " and\n ".join(guard_fn.code_parts)) + print("lambda " + ", ".join(guard_manager.args) + ":") + print(" ", " and\n ".join(guard_manager.code_parts)) - if config.enable_cpp_guard_manager: - print(guard_fn) + print(guard_manager) - local_scope = {"L": f_locals, **guard_fn.closure_vars} - for guard in guard_fn.code_parts: + local_scope = {"L": f_locals, **guard_manager.closure_vars} + for guard in guard_manager.code_parts: try: - eval(guard, guard_fn.global_scope, local_scope) + eval(guard, guard_manager.global_scope, local_scope) except: # noqa: B001,E722 print(f"Malformed guard:\n{guard}") diff --git a/torch/_dynamo/mutation_guard.py b/torch/_dynamo/mutation_guard.py index a773ca334bfbb..bdc24c421dba4 100644 --- a/torch/_dynamo/mutation_guard.py +++ b/torch/_dynamo/mutation_guard.py @@ -6,7 +6,7 @@ from torch.nn import Module from . import config -from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks +from .utils import ExactWeakKeyDictionary, nn_module_has_global_hooks unpatched_nn_module_init = torch.nn.Module.__init__ @@ -99,8 +99,6 @@ def is_dynamic_nn_module(obj: Any, is_export: bool) -> bool: return True if hasattr(obj, "torchdynamo_force_dynamic"): return obj.torchdynamo_force_dynamic - if is_lazy_module(obj): - return False # For export, we will have to fix # 1) Input signature problem because params are lifted as inputs # 2) nn module stack info changes diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 6c052237473e0..cf7e8b5ebb5d3 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -91,7 +91,6 @@ BackwardStateGraphArg, GraphArg, TrackedFake, - VariableBuilder, wrap_fx_proxy, ) from .variables.lists import BaseListVariable @@ -185,7 +184,7 @@ def __init__(self, nn_modules: Dict[str, torch.nn.Module]): for k, v in nn_modules.items(): setattr(self, k, v) - def __repr__(self): + def __repr__(self) -> str: return "FakeRootModule(...)" @@ -498,7 +497,7 @@ def synthetic_graph_input(self, fn, args): cg.store(varname) self.pregraph_bytecode.extend(cg.get_instructions()) source = SyntheticLocalSource(varname) - result = VariableBuilder(self.root_tx, source)(example_value) + result = VariableTracker.build(self.root_tx, example_value, source) TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( source ) @@ -767,8 +766,8 @@ def register_attr_or_module( ): if is_dynamic_nn_module(target, self.root_tx.export): # Instead of returning UnspecializedNNModuleVariable, call - # VariableBuilder so that it is tracked for mutation. - return VariableBuilder(self.current_tx, **options)(target) + # VariableTracker.build so that it is tracked for mutation. + return VariableTracker.build(self.current_tx, target, **options) options = dict(options) assert "source" in options @@ -860,8 +859,8 @@ def wrap_name(module_key): def wrap_name(module_key): self.output.update_co_names(module_key) self.global_scope[module_key] = target - return VariableBuilder(self, ConstantSource(source_name=module_key))( - target + return VariableTracker.build( + self, target, ConstantSource(source_name=module_key) ) for k, v in self.nn_modules.items(): @@ -873,7 +872,7 @@ def wrap_name(module_key): base = name for i in itertools.count(): - if name not in self.nn_modules: + if name not in self.nn_modules and name not in self.global_scope: self.nn_modules[name] = target if isinstance(target, torch.nn.Module): @@ -906,12 +905,10 @@ def handle_aliases_for_stolen_lists(self, tx): maybe_gm = self.local_scope.get("self") stolen_list_names = get_locals_to_steal(maybe_gm) if not stolen_list_names: - return [] + return [], {} alias_insts = [] - needs_alias: Dict[ - str, List[Union[VariableTracker, AttributeMutationExisting]] - ] = {} + needs_alias: Dict[str, List[VariableTracker]] = {} queue = [ *tx.stack, @@ -927,7 +924,10 @@ def handle_aliases_for_stolen_lists(self, tx): continue if not ( - isinstance(x, (VariableTracker, AttributeMutationExisting)) + ( + x not in self.side_effects.store_attr_mutations + or isinstance(x.mutable_local, AttributeMutationExisting) + ) and isinstance(x.source, GetItemSource) and isinstance(x.source.base, LocalSource) and x.source.base.local_name in stolen_list_names @@ -940,6 +940,7 @@ def handle_aliases_for_stolen_lists(self, tx): needs_alias[stolen_name].append(x) visited = {} + overridden_sources: Dict[Source, Source] = {} for arg in self.graphargs: if not ( isinstance(arg._example, list) @@ -952,6 +953,12 @@ def handle_aliases_for_stolen_lists(self, tx): list_name = arg.source.local_name assert list_name in self.code_options["co_varnames"] for x in needs_alias[list_name]: + # Skip if already handled. + if x.source in overridden_sources: + continue + + # A small codegen optimization because we might have different + # VariableTrackers that share the same source. list_idx = x.source.index if list_idx not in visited: alias_name = self.new_var( @@ -970,9 +977,14 @@ def handle_aliases_for_stolen_lists(self, tx): ) # operate on alias, handled by suffix codegen - x.source = LocalSource(visited[list_idx]) + old_source = x.source + overridden_sources[old_source] = LocalSource(visited[list_idx]) - return alias_insts + # NOTE: we need `overridden_sources` because (1) we want to codegen for + # these list items to use the new local source, but (2) we want to avoid + # updating `source` in place because that might break invariants in + # other parts of Dynamo like guards. + return alias_insts, overridden_sources def compile_subgraph( self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None @@ -1014,7 +1026,8 @@ def compile_subgraph( self.pregraph_bytecode and self.export ), "export does not support pregraph_bytecode" prefix_insts.extend(self.pregraph_bytecode) - prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx)) + alias_insts, overridden_sources = self.handle_aliases_for_stolen_lists(tx) + prefix_insts.extend(alias_insts) def append_prefix_insts(): self.add_output_instructions(prefix_insts) @@ -1082,7 +1095,7 @@ def append_prefix_insts(): self.random_values_var = self.new_var("random_values") rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) rand_fn_name = self.install_global("__gen_rand_values", rand_fn) - codegen = PyCodegen(tx, root) + codegen = PyCodegen(tx, root, overridden_sources=overridden_sources) random_calls_instructions.extend( codegen.load_function_name(rand_fn_name, True) ) @@ -1120,11 +1133,18 @@ def append_prefix_insts(): ) # restore all the live local vars self.add_output_instructions( - [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] + [ + PyCodegen(tx, overridden_sources=overridden_sources).create_store( + var + ) + for var in reversed(restore_vars) + ] ) else: graph_output_var = self.new_var("graph_out") - pass1 = PyCodegen(tx, root, graph_output_var) + pass1 = PyCodegen( + tx, root, graph_output_var, overridden_sources=overridden_sources + ) self.codegen_suffix(tx, stack_values, pass1) # one more time now that we have established tempvars @@ -1133,6 +1153,7 @@ def append_prefix_insts(): root, graph_output_var, tempvars={val: None for val, count in pass1.uses.items() if count > 1}, + overridden_sources=overridden_sources, ) self.codegen_suffix(tx, stack_values, pass2) @@ -1157,12 +1178,21 @@ def append_prefix_insts(): # restore all the live local vars self.add_output_instructions( - [PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] + [ + PyCodegen(tx, overridden_sources=overridden_sources).create_store( + var + ) + for var in reversed(restore_vars) + ] ) if stored_graph_output_var: self.add_output_instructions( - [PyCodegen(tx).create_delete(graph_output_var)] + [ + PyCodegen( + tx, overridden_sources=overridden_sources + ).create_delete(graph_output_var) + ] ) def codegen_suffix(self, tx, stack_values, cg): @@ -1996,7 +2026,11 @@ def get_trace_call_log_str(): rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ ( rv.node.name, - rv.node.meta["nn_module_stack"][target][1], + next( + ty + for k, (_, ty) in rv.node.meta["nn_module_stack"].items() + if k.split("@")[0] == target + ), ) ] diff --git a/torch/_dynamo/repro/after_aot.py b/torch/_dynamo/repro/after_aot.py index 1f411b96acdbe..5d3995ee41666 100644 --- a/torch/_dynamo/repro/after_aot.py +++ b/torch/_dynamo/repro/after_aot.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs + import argparse import copy import functools @@ -12,7 +13,8 @@ import uuid from importlib import import_module from tempfile import TemporaryFile -from typing import Any, Callable, Dict, Union +from typing import Any, Callable, Dict, Sequence, TYPE_CHECKING, Union +from typing_extensions import Unpack import torch import torch.fx as fx @@ -45,6 +47,12 @@ from .. import config +if TYPE_CHECKING: + from torch._inductor.codecache import CompiledFxGraph + from torch._inductor.compile_fx import _CompileFxCallableEx, _CompileFxKwargsEx + from torch._inductor.utils import InputType + + log = logging.getLogger(__name__) @@ -56,7 +64,10 @@ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # -def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str): +def wrap_compiler_debug( + unconfigured_compiler_fn: "_CompileFxCallableEx", + compiler_name: str, +) -> "_CompileFxCallableEx": """ Minifier for Fx Graph modules after Aot Autograd has finished. We wrap both forward and backward call separately with the backend compiler_fn - like @@ -66,7 +77,11 @@ def wrap_compiler_debug(unconfigured_compiler_fn, compiler_name: str): """ @functools.wraps(unconfigured_compiler_fn) - def debug_wrapper(gm, example_inputs, **kwargs): + def debug_wrapper( + gm: torch.fx.GraphModule, + example_inputs: Sequence["InputType"], + **kwargs: Unpack["_CompileFxKwargsEx"], + ) -> Union["CompiledFxGraph", str]: from torch._subclasses import FakeTensorMode compiler_fn = functools.partial(unconfigured_compiler_fn, **kwargs) @@ -104,11 +119,15 @@ def debug_wrapper(gm, example_inputs, **kwargs): # We may run regular PyTorch compute that may trigger Dynamo, do NOT # recursively attempt to accuracy minify in that case! - def deferred_for_real_inputs(real_inputs): + def deferred_for_real_inputs( + real_inputs: Sequence["InputType"], **_kwargs: object + ) -> Any: # This is a bit obscure: if we recursively try to accuracy minify # the SAME function, this would trigger. But most of the time # we should never hit this branch + assert not _kwargs if config.repro_after != "aot": + assert not isinstance(inner_compiled_fn, str) return inner_compiled_fn(real_inputs) with config.patch(repro_after=None): return inner_debug_fn(real_inputs) @@ -165,11 +184,11 @@ def inner_debug_fn(real_inputs): raise AccuracyError("Bad accuracy detected") else: # Call the compiled function with real inputs - return inner_compiled_fn(real_inputs) + return inner_compiled_fn(real_inputs) # type: ignore[operator] else: try: # Call the compiled function with real inputs - out = inner_compiled_fn(real_inputs) + out = inner_compiled_fn(real_inputs) # type: ignore[operator] # sync cuda kernels to ensure IMA detection for arg in example_inputs: if isinstance(arg, torch.Tensor) and arg.is_cuda: @@ -194,7 +213,7 @@ def inner_debug_fn(real_inputs): if config.repro_after == "aot": compiled_fn = deferred_for_real_inputs compiled_fn._boxed_call = True # type: ignore[attr-defined] - return compiled_fn + return compiled_fn # type: ignore[return-value] else: return inner_compiled_fn @@ -432,6 +451,7 @@ def sync(): try: compile_mod = compile_fx_inner(fx_g, args) + assert not isinstance(compile_mod, str) compile_mod(args) sync() except Exception as e: @@ -601,6 +621,7 @@ def save_hook(name, val): with intermediate_hook(save_hook), tqdm( desc="Saving inductor intermediates", total=total ) as pbar: + assert not isinstance(compiled, str) compiled(new_args) assert not new_args @@ -717,6 +738,7 @@ def repro_run(options, mod, load_args): from torch.cuda import synchronize compiled = compile_fx_inner(mod, args) + assert not isinstance(compiled, str) if options.accuracy != "": # We don't really respect --accuracy vs --strict-accuracy here, it @@ -731,14 +753,16 @@ def repro_run(options, mod, load_args): raise AccuracyError("Bad accuracy detected") else: need_sync = False + for arg in args: if isinstance(arg, torch.Tensor) and arg.is_cuda: need_sync = True break - ref = compiled(list(args)) + + compiled(list(args)) + if need_sync: synchronize() # ensure segfaults are surfaced - return lambda: compiled(list(args)) # TODO: lazily load the inputs or something, rather than cloning them diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 7fd3561e63aed..cdf0649a7cbc4 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -5,7 +5,7 @@ import warnings import weakref from collections.abc import MutableMapping -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Set, Type import torch.nn @@ -36,9 +36,8 @@ class MutableSideEffects(MutableLocalBase): the graph runs. """ - def __init__(self, source: Source, is_modified: bool = False): + def __init__(self, is_modified: bool = False): super().__init__(MutableLocalSource.Existing) - self.source = source self.is_modified = is_modified @@ -47,20 +46,18 @@ class AttributeMutation(MutableLocalBase): VariableTracker.mutable_local marker to track changes to attributes """ - def __init__(self, typ: MutableLocalSource, source: Optional[Source]): + def __init__(self, typ: MutableLocalSource): super().__init__(typ) - self.source = source class AttributeMutationExisting(AttributeMutation): - def __init__(self, source: Source): - super().__init__(MutableLocalSource.Existing, source) - self.source = source + def __init__(self): + super().__init__(MutableLocalSource.Existing) class AttributeMutationNew(AttributeMutation): - def __init__(self, source: Optional[Source], cls_source: Optional[Source]): - super().__init__(MutableLocalSource.Local, source) + def __init__(self, cls_source: Optional[Source] = None): + super().__init__(MutableLocalSource.Local) self.cls_source = cls_source @@ -76,7 +73,7 @@ class SideEffects: """ id_to_variable: Dict[int, VariableTracker] - store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]] + store_attr_mutations: Dict[VariableTracker, Dict[str, VariableTracker]] keepalive: List[Any] def __init__( @@ -175,13 +172,13 @@ def check_allowed_side_effect(self, item): def store_attr(self, item: VariableTracker, name: str, value: VariableTracker): assert self.is_attribute_mutation(item) self.check_allowed_side_effect(item) - if item.mutable_local not in self.store_attr_mutations: - self.store_attr_mutations[item.mutable_local] = {} - self.store_attr_mutations[item.mutable_local][name] = value + if item not in self.store_attr_mutations: + self.store_attr_mutations[item] = {} + self.store_attr_mutations[item][name] = value def load_attr(self, item, name, deleted_ok=False): assert self.is_attribute_mutation(item) - result = self.store_attr_mutations[item.mutable_local][name] + result = self.store_attr_mutations[item][name] if not deleted_ok and isinstance(result, variables.DeletedVariable): unimplemented("read deleted attribute") return result @@ -216,19 +213,19 @@ def is_attribute_mutation(self, item): def has_pending_mutation(self, item): return self.is_attribute_mutation(item) and bool( - self.store_attr_mutations.get(item.mutable_local) + self.store_attr_mutations.get(item) ) def has_pending_mutation_of_attr(self, item, name): return self.is_attribute_mutation( item - ) and name in self.store_attr_mutations.get(item.mutable_local, ()) + ) and name in self.store_attr_mutations.get(item, ()) def is_modified(self, item): if isinstance(item.mutable_local, AttributeMutationNew): return True if self.is_attribute_mutation(item): - return item.mutable_local in self.store_attr_mutations + return item in self.store_attr_mutations return item.mutable_local.is_modified def _track_obj( @@ -249,7 +246,7 @@ def _track_obj( f"Source of previously tracked object: {self.id_to_variable[id(item)].source}." ) - variable.mutable_local = mutable_cls(variable.source) + variable.mutable_local = mutable_cls() self.id_to_variable[id(item)] = variable self.keepalive.append(item) return variable @@ -285,7 +282,7 @@ def track_object_new( unimplemented(f"Unable to construct the object of type {user_cls}") variable = variable_cls( obj, - mutable_local=AttributeMutationNew(None, cls_source), + mutable_local=AttributeMutationNew(cls_source), **options, ) self.id_to_variable[id(obj)] = variable @@ -323,7 +320,7 @@ def track_cell_new( ): obj = object() variable = variables.NewCellVariable( - mutable_local=AttributeMutationNew(None, None), + mutable_local=AttributeMutationNew(), ) self.id_to_variable[id(obj)] = variable self.keepalive.append(obj) @@ -331,7 +328,8 @@ def track_cell_new( def track_cell_existing(self, source: Source, item: Any): variable = variables.NewCellVariable( - mutable_local=AttributeMutationExisting(source), + mutable_local=AttributeMutationExisting(), + source=source, ) self.id_to_variable[id(item)] = variable self.keepalive.append(item) @@ -339,7 +337,8 @@ def track_cell_existing(self, source: Source, item: Any): def track_global_existing(self, source: Source, item: Any): variable = variables.NewGlobalVariable( - mutable_local=AttributeMutationExisting(source), + mutable_local=AttributeMutationExisting(), + source=source, ) self.id_to_variable[id(item)] = variable self.keepalive.append(item) @@ -362,35 +361,30 @@ def track_tensor_variables_from_runahead_side_effects(self, other): self.track_object_existing(other_item, other_variable) def prune_dead_object_new(self, tx): - live_new_objects = set() + live_new_objects: Set[VariableTracker] = set() # use this to avoid cycles in mutable_local (though I'm not sure if that # can actually happen). - visited: Any = set({}) + visited: Set[VariableTracker] = set({}) def visit(var: VariableTracker): - mutable_local = var.mutable_local - if mutable_local is None: - return - if mutable_local in visited: + if var in visited: return - visited.add(mutable_local) + visited.add(var) # Object may have been mutated, store this mutation. - if isinstance(mutable_local, AttributeMutationNew): - live_new_objects.add(mutable_local) + if isinstance(var.mutable_local, AttributeMutationNew): + live_new_objects.add(var) # It's possible that we have mutated the value of this variable # to be another one. The new value is in store_attr_mutations. # Also recurse through the new value to detect alive AttributeMutationNew. - if var.mutable_local in self.store_attr_mutations: + if var in self.store_attr_mutations: VariableTracker.visit( - visit, self.store_attr_mutations[var.mutable_local] + visit, self.store_attr_mutations[var] # noqa: F821 ) - def is_live(var: Union[MutableLocalBase, VariableTracker]): - if isinstance(var, AttributeMutationNew): + def is_live(var: VariableTracker): + if isinstance(var.mutable_local, AttributeMutationNew): return var in live_new_objects - if isinstance(var, VariableTracker): - return is_live(var.mutable_local) return True pre_existing_vars = [ @@ -403,6 +397,10 @@ def is_live(var: Union[MutableLocalBase, VariableTracker]): # during a graph break (tx.symbolic_locals), and mutation on pre-existing variables. # Recursively visit Variables and see if any of them have been mutated. VariableTracker.visit(visit, (tx.stack, tx.symbolic_locals, pre_existing_vars)) + # Manually release the self-referential function, which indirectly + # captures certain `VariableTracker` and affects parts of PT test/logic + # that are sensitive to when certain objects get released. + del visit # NB: cell variable handling.is tricky. # cell variables must stay alive if any NestedUserFunctionVariable @@ -420,23 +418,22 @@ def is_live(var: Union[MutableLocalBase, VariableTracker]): def mutation(self, var): self.check_allowed_side_effect(var) if isinstance(var.mutable_local, MutableSideEffects): - var.mutable_local = MutableSideEffects(var.mutable_local.source, True) + var.mutable_local.is_modified = True def _get_modified_vars(self): return [var for var in self.id_to_variable.values() if self.is_modified(var)] def codegen_save_tempvars(self, cg: PyCodegen): for var in self._get_modified_vars(): - if isinstance( - var.mutable_local, (AttributeMutationExisting, AttributeMutationNew) - ) and isinstance(var, variables.NewCellVariable): + if isinstance(var.mutable_local, AttributeMutationNew) and isinstance( + var, variables.NewCellVariable + ): cg.add_push_null( lambda: cg.load_import_from(utils.__name__, "make_cell") ) cg.extend_output(create_call_function(0, False)) cg.add_cache(var) - if isinstance(var.mutable_local, AttributeMutationNew): - var.mutable_local.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] + var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif isinstance(var.mutable_local, AttributeMutationNew): if isinstance(var, variables.AutogradFunctionContextVariable): unimplemented("AutogradFunctionContextVariable escaped") @@ -446,11 +443,11 @@ def codegen_save_tempvars(self, cg: PyCodegen): cg(var.mutable_local.cls_source) cg.extend_output(create_call_function(1, False)) cg.add_cache(var) - var.mutable_local.source = LocalSource(cg.tempvars[var]) + var.source = LocalSource(cg.tempvars[var]) elif var in cg.tempvars: assert cg.tempvars.get(var) is None # subsequent usage should point to the original variable - cg(var.mutable_local.source) + cg(var.source) cg.add_cache(var) for ctx, args in self.save_for_backward: @@ -553,7 +550,7 @@ def codegen_update_mutated(self, cg: PyCodegen): if isinstance(var, variables.ListVariable): # old[:] = new cg(var, allow_cache=False) - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg.extend_output( [ cg.create_load_const(None), @@ -568,7 +565,7 @@ def codegen_update_mutated(self, cg: PyCodegen): for name in _manual_update_dict.__code__.co_varnames: varname_map[name] = cg.tx.output.new_var() - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg.extend_output( [create_instruction("STORE_FAST", argval=varname_map["dict_to"])] ) @@ -578,7 +575,7 @@ def codegen_update_mutated(self, cg: PyCodegen): [create_instruction("STORE_FAST", argval=varname_map["dict_from"])] ) - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg.load_method("clear") # unfortunately can't just use DICT_MERGE due to possible custom behaviors @@ -603,12 +600,12 @@ def codegen_update_mutated(self, cg: PyCodegen): # + only if a key was removed from the input dict # (4) update the original dictionary with the dict created in (2) - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg.load_method("update") cg(var, allow_cache=False) if var.should_reconstruct_all: - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg.load_method("clear") suffixes.append( @@ -669,12 +666,12 @@ def codegen_update_mutated(self, cg: PyCodegen): # for this reversal, we iterate through the mutable attributes # in reverse order. for name, value in reversed( - self.store_attr_mutations.get(var.mutable_local, {}).items() + self.store_attr_mutations.get(var, {}).items() ): if isinstance(var, variables.NewGlobalVariable): cg.tx.output.update_co_names(name) cg(value) - assert isinstance(var.mutable_local.source, GlobalSource) # type: ignore[attr-defined] + assert isinstance(var.source, GlobalSource) # type: ignore[attr-defined] suffixes.append( [create_instruction("STORE_GLOBAL", argval=name)] ) @@ -683,7 +680,7 @@ def codegen_update_mutated(self, cg: PyCodegen): var.mutable_local, AttributeMutationExisting ) and hasattr(getattr(var, "value", None), name): cg.tx.output.update_co_names(name) - cg(var.mutable_local.source) + cg(var.source) suffixes.append( [create_instruction("DELETE_ATTR", argval=name)] ) @@ -694,7 +691,7 @@ def codegen_update_mutated(self, cg: PyCodegen): # __setattr__ is defined on this object, so call object.__setattr__ directly cg.load_import_from("builtins", "object") cg.load_method("__setattr__") - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg(variables.ConstantVariable(name)) cg(value) suffixes.append( @@ -703,20 +700,20 @@ def codegen_update_mutated(self, cg: PyCodegen): else: cg.tx.output.update_co_names(name) cg(value) - cg(var.mutable_local.source) + cg(var.source) suffixes.append([create_instruction("STORE_ATTR", argval=name)]) elif isinstance(var, variables.TupleIteratorVariable): for _ in range(var.index): cg.add_push_null( lambda: cg.load_import_from(utils.__name__, "iter_next") ) - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg.call_function(1, False) cg.pop_top() elif isinstance(var, variables.RandomVariable): # set correct random seed state def gen_fn(): - cg(var.mutable_local.source) # type: ignore[attr-defined] + cg(var.source) # type: ignore[attr-defined] cg.load_attr("setstate") cg.add_push_null(gen_fn) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 09b7cbdf5023a..2e7a1daa7e830 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -306,15 +306,17 @@ def __post_init__(self): assert self.idx is not None def reconstruct(self, codegen): - def gen_fn(): - self.base.reconstruct(codegen) - codegen.append_output(codegen.create_load_attr(self.prop.method_name())) + codegen.add_push_null( + lambda: codegen.load_import_from( + utils.__name__, f"call_{self.prop.method_name()}" + ) + ) + self.base.reconstruct(codegen) - codegen.add_push_null(gen_fn) if self.idx is not None: codegen.append_output(codegen.create_load_const(self.idx)) codegen.extend_output( - create_call_function(1 if self.idx is not None else 0, False) + create_call_function(2 if self.idx is not None else 1, False) ) def guard_source(self): diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 9d4b644543c44..09c2c59e60944 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -70,8 +70,8 @@ LazyString, proxy_args_kwargs, ) -from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker -from .variables.builder import VariableBuilder, wrap_fx_proxy +from .variables.base import MutableLocal, typestr, VariableTracker +from .variables.builder import wrap_fx_proxy from .variables.builtin import BuiltinVariable from .variables.constant import ConstantVariable from .variables.ctx_manager import ( @@ -88,6 +88,7 @@ UserMethodVariable, ) from .variables.iter import MAX_ITERATOR_LIMIT +from .variables.lazy import LazyVariableTracker from .variables.lists import ( BaseListVariable, ListIteratorVariable, @@ -98,7 +99,6 @@ from .variables.misc import ( ClosureVariable, GetAttrVariable, - InlinedClosureVariable, NullVariable, PythonModuleVariable, UnknownVariable, @@ -377,6 +377,25 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None): code_options["co_firstlineno"], ) + user_stack_formatted = "".join(traceback.format_list(user_stack)) + user_stack_trace = ( + "Graph break in user code at %s:%s\nReason: %s\nUser code traceback:\n%s" # noqa: UP031 + % ( + frame_loc[0], + frame_loc[1], + reason, + user_stack_formatted, + ) + ) + torch._logging.trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "dynamo_graph_break_reason", + "encoding": "string", + }, + payload_fn=lambda: f"{user_stack_trace}\n{traceback.format_exc() if exc_info else ''}", + ) + # torch._dynamo.explain() formats this a little nicer, and presents a slightly # more actionable user code pointer if ( @@ -384,16 +403,11 @@ def log_graph_break(code_options, reason="", exc_info=False, user_stack=None): and not explain and graph_break_dup_warning_checker.add(frame_loc) ): - user_stack_formatted = "".join(traceback.format_list(user_stack)) # This log line MUST contain the string "Graph break in user code", # This log line is exercised from # python test/dynamo/test_exc.py -k test_graph_break_log graph_break_log.debug( - "Graph break in user code at %s:%s\nReason: %s\nUser code traceback:\n%s", - frame_loc[0], - frame_loc[1], - reason, - user_stack_formatted, + user_stack_trace, exc_info=exc_info, ) else: @@ -550,6 +564,11 @@ def inner(self: "InstructionTranslatorBase", inst: Instruction): if push: self.push(value) self.jump(inst) + elif isinstance(result, SymNodeVariable): + if result.evaluate_expr(): + if push: + self.push(value) + self.jump(inst) else: unimplemented( "generic_jump on UserDefined with __bool__ returning non-constant" @@ -798,16 +817,27 @@ def maybe_has_backedge(self): return True return False - def cell_and_freevars(self): - if not hasattr(self, "_cell_and_freevars"): - self._cell_and_freevars = tuple( - self.code_options["co_cellvars"] or [] - ) + tuple(self.code_options["co_freevars"] or []) + def cellvars(self): + if not hasattr(self, "_cellvars"): + self._cellvars = tuple(self.code_options["co_cellvars"] or []) + # An inlined function might depend on the cellvar of the parent + # function. So, recursively obtain parent cellvars. + if isinstance(self, InliningInstructionTranslator): + self._cellvars += self.parent.cellvars() + return self._cellvars + def freevars(self): + if not hasattr(self, "_freevars"): + self._freevars = tuple(self.code_options["co_freevars"] or []) # An inlined function might depend on the freevar of the parent - # function. So, recursively obtain parent cell and freevars. + # function. So, recursively obtain parent freevars. if isinstance(self, InliningInstructionTranslator): - self._cell_and_freevars += self.parent.cell_and_freevars() + self._freevars += self.parent.freevars() + return self._freevars + + def cell_and_freevars(self): + if not hasattr(self, "_cell_and_freevars"): + self._cell_and_freevars = self.cellvars() + self.freevars() return self._cell_and_freevars def prune_dead_locals(self): @@ -815,12 +845,99 @@ def prune_dead_locals(self): # implicit use by super() # reads = reads | {"__class__"} # output variables? - reads = reads | set(self.cell_and_freevars()) + reads = reads | set(self.freevars()) + + # First we prune the non-cell local vars, this allows us to prune more + # cell local vars later on (e.g., if we manage to prune a + # `NestedUserFunctionVariable` that makes use of some cell locals). + cellvars = set(self.cellvars()) self.symbolic_locals = { - k: v for k, v in self.symbolic_locals.items() if k in reads + k: v for k, v in self.symbolic_locals.items() if k in cellvars or k in reads } + + # Then we prune the side effects, which might enable us to prune more + # cellvars afterwards. self.output.side_effects.prune_dead_object_new(self) + # Then we prune the cell locals. + # + # Note that we keep certain cell locals, because the current mechanism + # for codegen closure initialization for nested function creation is: + # 1. `NestedUserFunctionVariable` codegen assumes its closure has been + # initialized properly by its creator, i.e., the tuple of cells will + # be populated with correct content before the function is used. + # 2. `OutputGraph::compile_subgraph`, we populate the tuple of cells + # _after_ emitting the `MAKE_FUNCTION` bytecode, via `STORE_DEREF`; + # these `STORE_DEREF` are generated partly based on the current + # `symbolic_locals`. + # As a result, we must be careful not to prune the cell locals that'll + # allow `OutputGraph` to generate the proper `STORE_DEREF`. + # + # On the other hand, we do want to prune away the truly dead ones, e.g., + # say after we invoke a nested function, and the function is never used + # again. So here we do some conservative pruning, by tracing from a + # series of must-live root variables -- for any reachable cell, it must + # be kept alive. + # + # TODO(#137123) there are extra complexities due to side-effects (e.g., + # the nested function leaking out into backward hook or globals). We + # could probably improve the variable tracing here to include the + # relevant variables in `output.side_effects`. + if self.output.side_effects.is_empty(): + cellvars_that_must_live = set() + visited = set() + + def visit(var: VariableTracker): + if var in visited: + return + visited.add(var) + + # Avoid realizing the lazy variable which could end up adding a + # graph input which isn't needed, this is sound because there's + # there doesn't seem to be a way to go from a + # `LazyVariableTracker` to `ClosureVariable`. TODO is this + # really true in general? + if isinstance(var, LazyVariableTracker): + return + + # We need to do this explicitly to walk the entire use chain, + # e.g., from a `ClosureVariable` to its underlying + # `NestedUserFunctionVariable`, rather than just stopping at the + # `ClosureVariable` with a name. + if isinstance(var, ClosureVariable): + cellvars_that_must_live.add(var.name) + + # We only recur if the closure variable has been initialized. + actual_var = self.symbolic_locals.get(var.name, None) + if actual_var is not None: + VariableTracker.visit(visit, actual_var) + + # Populate `cellvars_that_must_live` + # + # NOTE: Don't trace from the cell locals which aren't explicitly + # read anymore; if they are indirectly used, they will be reached by + # other roots. These initially excluded cells are the ones that will + # hopefully be pruned. + local_roots = [ + var + for name, var in self.symbolic_locals.items() + if name not in cellvars or name in reads + ] + VariableTracker.visit( + visit, (local_roots, self.stack, self.output.backward_state) + ) + # Manually release the self-referential nested function, which + # captures `self.symbolic_locals` and affects parts of PT test/logic + # that are sensitive to when certain objects get released. + del visit + + # Only keep locals that will be read, or are cellvars that must live. + self.symbolic_locals = { + k: v + for k, v in self.symbolic_locals.items() + if k in reads or k in cellvars_that_must_live + } + def call_function( self, fn: VariableTracker, @@ -1126,15 +1243,14 @@ def _load_global(self, inst): except KeyError: return self.load_builtin(inst) - source = GlobalSource(name) - self.push(VariableBuilder(self, source)(value)) + self.push(VariableTracker.build(self, value, GlobalSource(name))) @functools.cached_property def nn_modules_globals_vt(self): module_name = "torch.nn.modules.module" module_source = self.import_source(module_name) fglobals_value = importlib.import_module(module_name) # type: ignore[assignment] - return VariableBuilder(self, module_source)(fglobals_value) + return VariableTracker.build(self, fglobals_value, module_source) def LOAD_GLOBAL(self, inst): if sys.version_info >= (3, 11) and sys.version_info < (3, 13) and inst.arg % 2: @@ -1276,7 +1392,7 @@ def load_builtin_from_argval(self, argval): self.output.name_of_builtins_dict_key_in_fglobals ) var_source = GetItemSource(builtins_source, argval) - self.push(VariableBuilder(self, var_source)(val)) + self.push(VariableTracker.build(self, val, var_source)) else: assert is_builtin_constant(val) self.push(ConstantVariable.create(value=val)) @@ -2047,13 +2163,7 @@ def DUP_TOP_TWO(self, inst): self.push(b) self.push(a) - def FORMAT_VALUE(self, inst): - flags = inst.arg - if (flags & 0x04) == 0x04: - fmt_spec = self.pop() - else: - fmt_spec = ConstantVariable.create("") - + def _format_value(self, fmt_spec, flags): value = self.pop() if isinstance(value, SymNodeVariable): from torch._dynamo.variables.lazy import ( @@ -2077,6 +2187,15 @@ def FORMAT_VALUE(self, inst): self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) + def FORMAT_VALUE(self, inst): + flags = inst.arg + if (flags & 0x04) == 0x04: + fmt_spec = self.pop() + else: + fmt_spec = ConstantVariable.create("") + + return self._format_value(fmt_spec, flags) + def BUILD_STRING(self, inst): format_string_parts: List[str] = [] args: List[VariableTracker] = [] @@ -2473,20 +2592,11 @@ def SET_FUNCTION_ATTRIBUTE(self, inst): self.push(fn) - def _format_value_313(self, fmt_spec): - value = self.pop() - if isinstance(value, SymNodeVariable): - value = ConstantVariable.create(str(value.sym_num)) - - fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}") - - self.call_function(BuiltinVariable(str.format), [fmt_var, value], {}) - def FORMAT_SIMPLE(self, inst): - self._format_value_313(ConstantVariable.create("")) + self._format_value(ConstantVariable.create(""), 0) def FORMAT_WITH_SPEC(self, inst): - self._format_value_313(self.pop()) + self._format_value(self.pop(), 0) def is_non_empty_graph(self): if self.output.count_calls() > 1: @@ -2618,6 +2728,7 @@ def __init__( # The first field of tuple is the fully qualified name of current module # in original hierarchy. The second field is the type of current nn.module self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {} + self.num_calls: Dict[str, int] = {} # Flag to indicate whether tracing is used for export. self.export = export self.one_graph = False @@ -2648,7 +2759,8 @@ def __init__( class InstructionTranslator(InstructionTranslatorBase): - mutated_closure_cell_contents: Set[str] + mutated_closure_cell_ids: Set[int] + contents_var_to_mutated_cell: Dict[VariableTracker, Any] @staticmethod def current_tx() -> "InstructionTranslator": @@ -2676,7 +2788,7 @@ def __init__( one_graph, export, export_constraints, - mutated_closure_cell_contents: Set[str], + mutated_closure_cell_ids: Set[int], frame_state, speculation_log: SpeculationLog, distributed_state: Optional[DistributedState], @@ -2721,7 +2833,8 @@ def __init__( with tracing(self.output.tracing_context), self.set_current_tx(): self.one_graph: bool = one_graph self.export = export - self.mutated_closure_cell_contents = mutated_closure_cell_contents + self.mutated_closure_cell_ids = mutated_closure_cell_ids + self.contents_var_to_mutated_cell = {} if self.export: assert ( self.one_graph @@ -3239,19 +3352,15 @@ def STORE_DEREF(self, inst): # type: ignore[override] self.symbolic_locals[inst.argval], self.pop() ) else: + root_tx = self.output.root_tx if ( maybe_cell is not None - and maybe_cell.source.name() - not in self.output.root_tx.mutated_closure_cell_contents + and maybe_cell in root_tx.contents_var_to_mutated_cell + and id(root_tx.contents_var_to_mutated_cell[maybe_cell]) + not in root_tx.mutated_closure_cell_ids ): - # Why is the source name here unique? - # mutated_closure_cell_contents is a per-frame - # concept, and sources identify, e.g., particular - # locals from the frame. If you had two locals, - # they'll get different source names, and therefore - # differ here. - self.output.root_tx.mutated_closure_cell_contents.add( - maybe_cell.source.name() + self.output.root_tx.mutated_closure_cell_ids.add( + id(root_tx.contents_var_to_mutated_cell[maybe_cell]) ) raise exc.UnspecializeRestartAnalysis unimplemented("write to __closure__ while inlining") @@ -3275,13 +3384,10 @@ def _load_closure(self, name): if name in self.closure_cells: return self.closure_cells[name] else: - return InlinedClosureVariable(name=name) - - def check_replace_is_safe(self, oldvar): - if not is_side_effect_safe(oldvar.mutable_local): - unimplemented( - "HigherOrderOperator: Mutating a variable not in the current scope (replace_all)" - ) + # We model unmodified cells captured by `UserFunctionVariable` as + # their contents, in `self.symbolic_locals`. See + # `UserFunctionVariable::bind_args`. + return self.symbolic_locals[name] def should_compile_partial_graph(self): return False # inlining functions is all-or-nothing @@ -3307,7 +3413,7 @@ def get_globals_source_and_value(self, name): fglobals_value = torch.package.package_importer._package_imported_modules[module_name] # type: ignore[assignment] else: fglobals_value = importlib.import_module(module_name) # type: ignore[assignment] - fglobals_vt = VariableBuilder(self, module_source)(fglobals_value) + fglobals_vt = VariableTracker.build(self, fglobals_value, module_source) global_source = AttrSource(module_source, name) else: globals_name = self.output.install_global_by_id( @@ -3315,7 +3421,7 @@ def get_globals_source_and_value(self, name): ) globals_source = GlobalSource(globals_name) fglobals_value = self.f_globals # type: ignore[assignment] - fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value) + fglobals_vt = VariableTracker.build(self, fglobals_value, globals_source) global_source = GetItemSource(globals_source, name) # type: ignore[assignment] return fglobals_value, fglobals_vt, global_source @@ -3334,7 +3440,7 @@ def _load_global(self, inst): except KeyError: return self.load_builtin(inst) - self.push(VariableBuilder(self, global_source)(value)) + self.push(VariableTracker.build(self, value, global_source)) def STORE_GLOBAL(self, inst): if self.f_globals is self.parent.f_globals: diff --git a/torch/_dynamo/test_minifier_common.py b/torch/_dynamo/test_minifier_common.py index b05542d578f43..b8e31ee0f578d 100644 --- a/torch/_dynamo/test_minifier_common.py +++ b/torch/_dynamo/test_minifier_common.py @@ -100,8 +100,8 @@ def _maybe_subprocess_run(self, args, *, isolate, cwd=None): # NB: Can't use save_config because that will omit some fields, # but we must save and reset ALL fields - dynamo_config = torch._dynamo.config.shallow_copy_dict() - inductor_config = torch._inductor.config.shallow_copy_dict() + dynamo_config = torch._dynamo.config.get_config_copy() + inductor_config = torch._inductor.config.get_config_copy() try: stderr = io.StringIO() log_handler = logging.StreamHandler(stderr) diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 2e998873d0c41..9281c7c7e284e 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -24,6 +24,7 @@ import torch from torch import fx +from torch._dynamo.backends.debugging import aot_eager from torch._dynamo.output_graph import OutputGraph from . import config, eval_frame, optimize_assert, reset @@ -190,7 +191,7 @@ def insert_nops(instructions: List[Any], code_options: Any) -> None: torch_function_mode_stack=[], ) - return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0)) + return GuardedCode(code, CheckFunctionManager(graph).guard_manager, CompileId(0, 0)) # type: ignore[arg-type] class CompileCounter: @@ -245,6 +246,37 @@ def __call__( return gm.forward +class AotEagerAndRecordGraphs: + def __init__(self) -> None: + self.graphs: List[torch.fx.GraphModule] = [] + self.fw_graphs: List[torch.fx.GraphModule] = [] + self.bw_graphs: List[torch.fx.GraphModule] = [] + + def __call__( + self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] + ) -> Callable[..., Any]: + self.graphs.append(gm) + + def fw_compiler( + gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] + ) -> Callable[..., Any]: + self.fw_graphs.append(gm) + return gm.forward + + def bw_compiler( + gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor] + ) -> Callable[..., Any]: + self.bw_graphs.append(gm) + return gm.forward + + return aot_eager( + gm, + example_inputs, + fw_compiler=fw_compiler, + bw_compiler=bw_compiler, + ) + + def strip_comment(code: str) -> str: return re.sub(r"(?m)^ *#.*\n?", "", code) diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 39372d83a60d9..7a8d25d98c4d4 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -421,6 +421,7 @@ "torch._C._cpu._is_avx512_bf16_supported", "torch._C._cpu._is_amx_tile_supported", "torch._C._cpu._init_amx", + "torch._C._cpu._is_arm_sve_supported", "torch._C._crash_if_aten_asan", "torch._C._crash_if_csrc_asan", "torch._C._crash_if_csrc_ubsan", @@ -2445,6 +2446,7 @@ "torch._C._cpu._is_avx512_bf16_supported", "torch._C._cpu._is_amx_tile_supported", "torch.cpu._init_amx", + "torch._C._cpu._is_arm_sve_supported", "torch.cpu.current_device", "torch.cpu.current_stream", "torch.cpu.device_count", @@ -2912,6 +2914,9 @@ def get_tensor_method(): method, (types.MethodDescriptorType, types.WrapperDescriptorType) ): s.add(method) + + # mlazos: this is a function which we handle specially in TensorVariable + s.add(torch.Tensor.__contains__) # type: ignore[arg-type] return frozenset(s) @@ -3154,7 +3159,6 @@ def is_numpy_type_info(obj) -> bool: "hypothesis", "networkx", "numpy", - "omegaconf", "onnx", "onnxruntime", "onnx_tf", @@ -3246,6 +3250,7 @@ def _module_dir(m: types.ModuleType): "torch._functorch.functional_call", "torch._functorch.vmap", "torch._higher_order_ops.associative_scan", + "torch._higher_order_ops.invoke_subgraph", "torch._higher_order_ops.scan", "torch._higher_order_ops.strict_mode", "torch._higher_order_ops.while_loop", @@ -3269,6 +3274,7 @@ def _module_dir(m: types.ModuleType): "torch.nn", "torch.overrides", "torch.random", + "torch.return_types", "torch.sparse", "torch.testing", "torch.utils._content_store", diff --git a/torch/_dynamo/types.py b/torch/_dynamo/types.py index 16ef7b5821c2a..298741a4e9586 100644 --- a/torch/_dynamo/types.py +++ b/torch/_dynamo/types.py @@ -3,7 +3,7 @@ import types from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union -# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object. +# CacheEntry has a `guard_manager` field for the guard, and a `code` field for the code object. from torch._C._dynamo.eval_frame import ( _CacheEntry as CacheEntry, _ExtraState as ExtraState, @@ -46,7 +46,7 @@ def __call__(self, f_locals: Dict[str, object]) -> bool: @dataclasses.dataclass class GuardedCode: code: types.CodeType - check_fn: GuardFn + guard_manager: GuardFn compile_id: CompileId trace_annotation: str = "Unknown" @@ -67,7 +67,7 @@ def __call__( class DynamoGuardHook(Protocol): def __call__( self, - guard_fn: GuardFn, + guard_manager: GuardFn, code: types.CodeType, f_locals: Dict[str, object], index: int, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index cf69bae46d792..775ec8b488dd9 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -21,8 +21,10 @@ import os import re import sys +import textwrap import threading import time +import traceback import types import typing import uuid @@ -54,7 +56,7 @@ Union, ValuesView, ) -from typing_extensions import Literal, TypeGuard +from typing_extensions import Literal, TypeIs import torch import torch._functorch.config @@ -118,6 +120,8 @@ T = TypeVar("T") unpatched_nn_module_getattr = torch.nn.Module.__getattr__ +unpatched_nn_module_call = torch.nn.Module.__call__ +unpatched_nn_module_call_impl = torch.nn.Module._call_impl counters: DefaultDict[str, Counter[str]] = collections.defaultdict(collections.Counter) optimus_scuba_log: Dict[str, Any] = {} @@ -234,16 +238,6 @@ def add_remote_cache_time_saved(time_saved_ns: int, is_backward: bool = False) - _add_time_spent(key, "remote_cache_time_saved", time_saved) -def get_cache_stats() -> Dict[str, Any]: - """Get a bunch of metadata about cache hits and misses to use in chromium events""" - cache_stats = { - "fxgraph_cache_hit": counters["inductor"]["fxgraph_cache_hit"], - "fxgraph_cache_miss": counters["inductor"]["fxgraph_cache_miss"], - "fxgraph_cache_bypass": counters["inductor"]["fxgraph_cache_bypass"], - } - return cache_stats - - # dynamo_timed is a context manager # By wrapping a function in dynamo_timed, we can store a record in compilation_time_metrics # where the key is the functions name. @@ -288,9 +282,10 @@ def dynamo_timed( try: with torch.profiler.record_function(f"{key} (dynamo_timed)"): t0 = time.time() - chromium_log.log_event_start(key, start, None) if phase_name: - chromium_log.log_event_start(phase_name, start) + chromium_log.log_event_start(phase_name, start, {"fn_name": key}) + else: + chromium_log.log_event_start(key, start, {}) yield time_spent = time.time() - t0 compilation_time_metrics[key].append(time_spent) @@ -304,16 +299,15 @@ def dynamo_timed( chromium_log.log_event_end( phase_name, time.time_ns(), - {"cache_stats": get_cache_stats()}, + {}, start, ) - chromium_log.log_event_end( - key, time.time_ns(), {"cache_stats": get_cache_stats()}, start - ) + else: + chromium_log.log_event_end(key, time.time_ns(), {}, start) # Only record backward compilation metrics if phase_name is not None! if phase_name: frame_key = str(curr_frame) - # fwd only compilation stages: entire_frame_compile, backend_compile. + # fwd only compilation stages: entire_frame_compile, backend_compile, aotdispatch. # use frame_key as time aggregation key. if fwd_only and fail_type is None: _add_time_spent(frame_key, phase_name, time_spent) @@ -349,10 +343,18 @@ def dynamo_timed( remote_cache_time_saved = frame_phase_timing[ compile_id ].get("remote_cache_time_saved", None) + remote_fx_graph_cache_get_time = frame_phase_timing[ + compile_id + ].get("remote_fx_graph_cache_get", None) + remote_fx_graph_cache_put_time = frame_phase_timing[ + compile_id + ].get("remote_fx_graph_cache_put", None) else: inductor_compile_time = None code_gen_time = None remote_cache_time_saved = None + remote_fx_graph_cache_get_time = None + remote_fx_graph_cache_put_time = None structured_logging_overhead_s = ( torch._logging.get_structured_logging_overhead() ) @@ -364,6 +366,9 @@ def dynamo_timed( fail_reason, remote_cache_time_saved, structured_logging_overhead_s, + False, # is_forward + to_int_ms(remote_fx_graph_cache_get_time), + to_int_ms(remote_fx_graph_cache_put_time), ) record_compilation_metrics(metrics) @@ -524,7 +529,7 @@ def count_calls(g: fx.Graph) -> int: return c -def identity(x): +def identity(x: T) -> T: return x @@ -577,14 +582,14 @@ def clear(self): @overload -def istype(obj: object, allowed_types: Type[T]) -> TypeGuard[T]: +def istype(obj: object, allowed_types: Type[T]) -> TypeIs[T]: ... @overload def istype( obj: object, allowed_types: Tuple[Type[List[T]], Type[Tuple[T, ...]]] -) -> TypeGuard[T]: +) -> TypeIs[T]: ... @@ -770,6 +775,10 @@ def proxy_args_kwargs(args, kwargs): ) +def to_int_ms(v: Optional[float]) -> Optional[int]: + return None if v is None else int(v * 1000) + + @dataclasses.dataclass class CompilationMetrics: compile_id: str @@ -807,6 +816,10 @@ class CompilationMetrics: config_suppress_errors: Optional[bool] config_inline_inbuilt_nn_modules: Optional[bool] specialize_float: Optional[bool] + dynamo_config: Optional[str] + is_forward: Optional[bool] + remote_fx_graph_cache_get_time_ms: Optional[int] + remote_fx_graph_cache_put_time_ms: Optional[int] @dataclasses.dataclass @@ -818,6 +831,9 @@ class BwdCompilationMetrics: fail_reason: Optional[str] remote_cache_time_saved_s: Optional[float] structured_logging_overhead_s: Optional[float] + is_forward: Optional[bool] + remote_fx_graph_cache_get_time_ms: Optional[int] + remote_fx_graph_cache_put_time_ms: Optional[int] DEFAULT_COMPILATION_METRICS_LIMIT = 64 @@ -828,6 +844,37 @@ class BwdCompilationMetrics: ] = collections.deque(maxlen=DEFAULT_COMPILATION_METRICS_LIMIT) +def add_compilation_metrics_to_chromium(c: CompilationMetrics): + event_logger = get_chromium_event_logger() + # The following compilation metrics are related to + # dynamo, so go with the "entire frame compile" event + event_logger.add_event_data( + event_name="dynamo", + frame_key=c.frame_key, + co_name=c.co_name, + co_filename=c.co_filename, + co_firstlineno=c.co_firstlineno, + cache_size=c.cache_size, + accumulated_cache_size=c.accumulated_cache_size, + guard_count=c.guard_count, + shape_env_guard_count=c.shape_env_guard_count, + graph_op_count=c.graph_op_count, + graph_node_count=c.graph_node_count, + graph_input_count=c.graph_input_count, + fail_type=c.fail_type, + fail_reason=c.fail_reason, + fail_user_frame_filename=c.fail_user_frame_filename, + fail_user_frame_lineno=c.fail_user_frame_lineno, + # Sets aren't JSON serializable + non_compliant_ops=list(c.non_compliant_ops), + compliant_custom_ops=list(c.compliant_custom_ops), + restart_reasons=list(c.restart_reasons), + dynamo_time_before_restart_s=c.dynamo_time_before_restart_s, + has_guarded_code=c.has_guarded_code, + dynamo_config=c.dynamo_config, + ) + + def record_compilation_metrics( compilation_metrics: Union[CompilationMetrics, BwdCompilationMetrics] ): @@ -835,6 +882,7 @@ def record_compilation_metrics( _compilation_metrics.append(compilation_metrics) if isinstance(compilation_metrics, CompilationMetrics): name = "compilation_metrics" + add_compilation_metrics_to_chromium(compilation_metrics) else: name = "bwd_compilation_metrics" torch._logging.trace_structured( @@ -884,6 +932,11 @@ def get_stack(self): self.tls.stack = ["__start__"] return self.tls.stack + def get_event_data(self) -> Dict[str, Any]: + if not hasattr(self.tls, "event_data"): + self.tls.event_data = {} + return self.tls.event_data + def __init__(self): self.tls = threading.local() # Generate a unique id for this logger, which we can use in scuba to filter down @@ -893,11 +946,30 @@ def __init__(self): # TODO: log to init/id tlparse after I add support for it log.info("ChromiumEventLogger initialized with id %s", self.id_) + def add_event_data( + self, + event_name: str, + **kwargs, + ) -> None: + """ + Adds additional metadata info to an in-progress event + This metadata is recorded in the END event + """ + if event_name not in self.get_stack(): + raise RuntimeError( + "Cannot add metadata to events that aren't in progress." + "Please make sure the event has started and hasn't ended." + ) + event_data = self.get_event_data() + if event_name not in event_data: + event_data[event_name] = {} + event_data[event_name].update(kwargs) + def log_event_start( self, event_name: str, time_ns: int, - metadata: Optional[Dict[str, Any]] = None, + metadata: Dict[str, Any], ) -> None: """ Logs the start of a single event. @@ -905,20 +977,14 @@ def log_event_start( :param time_ns Timestamp in nanoseconds :param metadata: Any extra metadata associated with this event """ - - # Add compile id to metadata - if metadata is None: - metadata = {} compile_id = str(torch._guards.CompileContext.current_compile_id()) metadata["compile_id"] = compile_id - - event = self._log_timed_event( + self._log_timed_event( event_name, time_ns, "B", metadata, ) - log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_) self.get_stack().append(event_name) def reset(self) -> None: @@ -927,13 +993,15 @@ def reset(self) -> None: stack = self.get_stack() stack.clear() stack.append("__start__") + event_data = self.get_event_data() + event_data.clear() def log_event_end( self, event_name: str, time_ns: int, - metadata: Optional[Dict[str, Any]] = None, - start_time_ns: Optional[int] = None, + metadata: Dict[str, Any], + start_time_ns: int, ) -> None: """ Logs the end of a single event. This function should only be @@ -942,12 +1010,26 @@ def log_event_end( :param time_ns: Timestamp in nanoseconds :param metadata: Any extra metadata associated with this event """ - # Add compile id to metadata - if metadata is None: - metadata = {} compile_id = str(torch._guards.CompileContext.current_compile_id()) metadata["compile_id"] = compile_id + # Grab metadata collected during event span + all_event_data = self.get_event_data() + if event_name in all_event_data: + event_metadata = all_event_data[event_name] + del all_event_data[event_name] + else: + event_metadata = {} + # Add the passed in metadata + event_metadata.update(metadata) + + event = self._log_timed_event( + event_name, + time_ns, + "E", + event_metadata, + ) + # These stack health checks currently never happen, # but they're written this way to future proof any weird event # overlaps in the future. @@ -958,13 +1040,6 @@ def log_event_end( log.warning("ChromiumEventLogger: Start event not in stack, ignoring") return - event = self._log_timed_event( - event_name, - time_ns, - "E", - metadata, - ) - while event_name != stack[-1]: # If the event isn't the most recent one to end, pop # off the stack until it is. @@ -974,7 +1049,7 @@ def log_event_end( ) stack.pop() - log_chromium_event_internal(event, stack, compile_id, self.id_, start_time_ns) + log_chromium_event_internal(event, stack, self.id_, start_time_ns) # Finally pop the actual event off the stack stack.pop() @@ -1041,7 +1116,7 @@ def log_instant_event( expect_trace_id=True, ) # Log an instant event with the same start and end time - log_chromium_event_internal(event, self.get_stack(), compile_id, self.id_) + log_chromium_event_internal(event, self.get_stack(), self.id_, time_ns) CHROMIUM_EVENT_LOG: Optional[ChromiumEventLogger] = None @@ -1453,6 +1528,7 @@ def check_numpy_ndarray_args(args, kwargs): dict_values: Type[ValuesView[Any]] = type({}.values()) odict_values: Type[ValuesView[Any]] = type(collections.OrderedDict().values()) tuple_iterator: Type[Iterator[Any]] = type(iter(())) +range_iterator: Type[Iterator[Any]] = type(iter(range(0))) tuple_iterator_len = tuple_iterator.__length_hint__ # type: ignore[attr-defined] object_new = object.__new__ @@ -2493,7 +2569,7 @@ def __init__(self, f): self.f = f self.__name__ = "wrapped_" + self.f.__name__ - def __repr__(self): + def __repr__(self) -> str: return f">" def __call__(self, *args, **kwargs): @@ -2517,7 +2593,7 @@ def __init__(self, method: str): self.method = method self.__name__ = "wrapped_" + self.method - def __repr__(self): + def __repr__(self) -> str: return f">" def __call__(self, *args, **kwargs): @@ -2536,7 +2612,7 @@ def __init__(self, op: Callable[..., Any]): self.op = op self.__name__ = f"wrapped_{op.__name__}" - def __repr__(self): + def __repr__(self) -> str: return f">" def __call__(self, *args, **kwargs): @@ -2778,10 +2854,34 @@ def get_instruction_source_311(code: types.CodeType, inst: dis.Instruction) -> s h(x))) ^^^^^ - We need our own implementation since `format_frame_summary` in + We need our own implementation in < 3.13 since `format_frame_summary` in Python's `traceback` module doesn't handle multi-line expressions (and their anchor extraction code is not completely correct). """ + if sys.version_info >= (3, 13): + # multiline traceback implemented in 3.13+ + frame_summary = traceback.FrameSummary( + code.co_filename, + inst.positions.lineno, + code.co_name, + end_lineno=inst.positions.end_lineno, + colno=inst.positions.col_offset, + end_colno=inst.positions.end_col_offset, + ) + result = traceback.format_list([frame_summary])[0] + # remove first line containing filename info + result = "\n".join(result.splitlines()[1:]) + # indent lines with original indentation + orig_lines = [ + linecache.getline(code.co_filename, lineno).rstrip() + for lineno in range(inst.positions.lineno, inst.positions.end_lineno + 1) + ] + orig_lines_dedent = textwrap.dedent("\n".join(orig_lines)).splitlines() + indent_len = len(orig_lines[0]) - len(orig_lines_dedent[0]) + indent = orig_lines[0][:indent_len] + result = textwrap.indent(textwrap.dedent(result), indent) + return result + assert inst.positions is not None if inst.positions.lineno is None: return "" @@ -2912,18 +3012,28 @@ def is_torch_function_object(value): def has_torch_function(vt: torch._dynamo.variables.base.VariableTracker) -> bool: - from torch._dynamo.variables import LazyVariableTracker, UserDefinedObjectVariable + from torch._dynamo.variables import UserDefinedObjectVariable from torch._dynamo.variables.torch_function import TensorWithTFOverrideVariable - if isinstance(vt, TensorWithTFOverrideVariable): - return True + # Note on lazy vars: The value will either be realized or not throughout the course of execution + # if the value has a torch function, it will eventually be realized so we can realize it here + # if the value does not have a torch function, it may or may not be realized + # if it is realized it will be used and guards will be installed properly + # if it is not used, guards won't be installed, and it doesn't matter + # if the value has a torch function or not, so we should *not* realize it. + # NB: We technically know that if is_realized is False, LazyVariableTracker has the peek_value method + # but mypy does not unfortunately + if vt.is_realized() or ( + hasattr(vt, "peek_value") and hasattr(vt.peek_value(), "__torch_function__") + ): + if isinstance(vt, TensorWithTFOverrideVariable): + return True - if isinstance(vt, LazyVariableTracker): - LazyVariableTracker.realize(vt) + return isinstance(vt, UserDefinedObjectVariable) and hasattr( + vt.value, "__torch_function__" + ) - return isinstance(vt, UserDefinedObjectVariable) and hasattr( - vt.value, "__torch_function__" - ) + return False # see note [Tensor Fakification and Symbol Caching] @@ -3028,7 +3138,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): if node.op == "placeholder" and node.meta.get("steal_arg", False) ] - if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): # fast path, avoid pytree overhead # compiled autograd inputs are always a list of tensors, maybe followed by symints assert inputs_idx_to_clear == [0] @@ -3077,7 +3187,7 @@ class Lit: def __init__(self, s): self.s = s - def __repr__(self): + def __repr__(self) -> str: return self.s @@ -3166,6 +3276,34 @@ def does_not_override_dict_iter_methods(user_cls): ) +# Helper functions below are to prevent __torch_function__ +# calls from happening in the middle of __torch_function__ +# compiled bytecode +# They will be skipped which is the desired result +def call_size(x, i): + @torch._dynamo.disable(recursive=True) + def fn(x, i): + return x.size(i) + + return fn(x, i) + + +def call_stride(x, i): + @torch._dynamo.disable(recursive=True) + def fn(x, i): + return x.stride(i) + + return fn(x, i) + + +def call_storage_offset(x): + @torch._dynamo.disable(recursive=True) + def fn(x): + return x.storage_offset() + + return fn(x) + + # Helper function to extract relevant parts of a tensor's __dict__ to store in node meta. # To avoid ref cycles, it's important that no tensors are present here, so leave those out. def _extract_tensor_dict(t): diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 5a8522e68c4c0..7d07978d7a0bf 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -30,10 +30,12 @@ ) from .distributed import BackwardHookVariable, DistributedVariable, PlacementVariable from .functions import ( + CreateTMADescriptorVariable, FunctoolsPartialVariable, NestedUserFunctionVariable, PolyfilledFunctionVariable, SkipFunctionVariable, + TMADescriptorVariable, UserFunctionVariable, UserMethodVariable, ) @@ -95,6 +97,7 @@ from .optimizer import OptimizerVariable from .sdpa import SDPAParamsVariable from .tensor import ( + DataPtrVariable, FakeItemVariable, NumpyNdarrayVariable, SymNodeVariable, @@ -124,9 +127,11 @@ "ConstDictVariable", "ContextWrappingVariable", "CountIteratorVariable", + "CreateTMADescriptorVariable", "CUDADeviceVariable", "CustomizedDictVariable", "CycleIteratorVariable", + "DataPtrVariable", "DefaultDictVariable", "DeletedVariable", "DeterministicAlgorithmsVariable", @@ -163,6 +168,7 @@ "StringFormatVariable", "SuperVariable", "TensorVariable", + "TMADescriptorVariable", "TorchCtxManagerClassVariable", "TorchInGraphFunctionVariable", "TorchVersionVariable", diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index 723c5a90c66ac..4572131553ea0 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: - from torch._dynamo.symbolic_convert import InstructionTranslator + from .symbolic_convert import InstructionTranslator, InstructionTranslatorBase class MutableLocalSource(Enum): @@ -121,6 +121,8 @@ class VariableTracker(metaclass=VariableTrackerMeta): VariableTracker instances are immutable and should be copied in order to change them. + + Prefer the factory function VariableTracker.build() over VariableTracker.__init__(). """ # fields to leave unmodified in apply() @@ -244,9 +246,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke value = self.const_getattr(tx, name) if not variables.ConstantVariable.is_literal(value): raise NotImplementedError - source = None - if self.source: - source = AttrSource(self.source, name) + source = self.source and AttrSource(self.source, name) return variables.ConstantVariable.create(value, source=source) def is_proxy(self): @@ -363,6 +363,20 @@ def next_variable(self, tx): def is_strict_mode(self, tx): return tx.strict_checks_fn and tx.strict_checks_fn(self) + @staticmethod + def build( + tx: "InstructionTranslatorBase", + value: Any, + source: Optional[Source] = None, + ) -> Any: + """Create a new VariableTracker from a value and optional Source""" + from . import builder + + if source is None: + return builder.SourcelessBuilder.create(tx, value) + else: + return builder.VariableBuilder(tx, source)(value) + def __init__( self, *, diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index dbae72c31693b..a47948dc541f0 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -3,6 +3,7 @@ import abc import collections import contextlib +import copy import dataclasses import enum import functools @@ -14,6 +15,7 @@ import random import re import sys +import time import types import warnings import weakref @@ -33,6 +35,7 @@ import torch from torch import SymInt +from torch._dynamo.utils import get_chromium_event_logger from torch._guards import GuardSource, TracingContext from torch._higher_order_ops.torchbind import call_torchbind from torch._ops import HigherOrderOperator @@ -104,6 +107,7 @@ istype, odict_values, proxy_args_kwargs, + range_iterator, set_example_value, tensor_always_has_static_shape, tuple_iterator, @@ -139,6 +143,7 @@ ) from .functions import ( CollectiveFunctionRewriteVariable, + CreateTMADescriptorVariable, FunctoolsPartialVariable, TritonKernelVariable, UserFunctionVariable, @@ -150,6 +155,7 @@ from .lazy import LazyVariableTracker from .lists import ( BaseListVariable, + ListIteratorVariable, ListVariable, NamedTupleVariable, RangeVariable, @@ -193,6 +199,7 @@ from .sdpa import SDPAParamsVariable from .tensor import ( NumpyNdarrayVariable, + supported_const_comparison_op_values, SymNodeVariable, TensorSubclassVariable, TensorVariable, @@ -426,8 +433,12 @@ def set_source_and_track_mutable(self, value, var): return self.tx.output.side_effects.track_mutable(value, var) @classmethod - @functools.lru_cache(None) def _type_dispatch(cls): + return cls._type_dispatch_impl(config.trace_numpy) + + @classmethod + @functools.lru_cache(None) + def _type_dispatch_impl(cls, trace_numpy): # NB: Careful not to close over self to avoid ref cycle from lru_cache entries = [ ( @@ -444,6 +455,7 @@ def _type_dispatch(cls): cls.wrap_listlike, ), (tuple_iterator, cls.wrap_tuple_iterator), + (range_iterator, cls.wrap_range_iterator), ((slice, range), cls.wrap_slice_range), (tuple(common_constant_types), cls.wrap_literal), (re.Pattern, cls.wrap_regex_pattern), @@ -452,7 +464,7 @@ def _type_dispatch(cls): (torch.jit.ScriptFunction, cls.wrap_jit_function), ] - if config.trace_numpy and np: + if trace_numpy and np: entries.append((np.ndarray, cls.wrap_numpy_ndarray)) result = {} @@ -523,7 +535,7 @@ def _id_dispatch( def _wrap(self, value): # import here to avoid circular dependencies - from torch.utils._triton import has_triton + from torch.utils._triton import has_triton, has_triton_tma if has_triton(): from triton.runtime.autotuner import Autotuner @@ -536,6 +548,19 @@ class JITFunction: class Autotuner: pass + if has_triton_tma(): + from triton.tools.experimental_descriptor import ( + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + else: + + def create_1d_tma_descriptor(): + pass + + def create_2d_tma_descriptor(): + pass + # Handle exact type() match type_dispatch = self._type_dispatch().get(type(value)) if type_dispatch is not None: @@ -965,6 +990,10 @@ def build_key_value(i, k, v): None, # No grid provided source=self.source, ) + elif value is create_1d_tma_descriptor: + return CreateTMADescriptorVariable(rank=1) + elif value is create_2d_tma_descriptor: + return CreateTMADescriptorVariable(rank=2) elif isinstance(value, torch.amp.autocast_mode.autocast): self.install_guards(GuardBuilder.ID_MATCH) return AutocastModeVariable( @@ -1291,6 +1320,12 @@ def wrap_tuple_iterator(self, value: tuple_iterator): return self.set_source_and_track_mutable(value, result) + def wrap_range_iterator(self, value: range_iterator): + self.install_guards(GuardBuilder.TYPE_MATCH) + # Get all the values from the range iterator + items = [ConstantVariable.create(v) for v in copy.deepcopy(value)] + return ListIteratorVariable(items, mutable_local=MutableLocal()) + def wrap_slice_range(self, value: Union[slice, range]): items = [ VariableBuilder(self.tx, AttrSource(self.get_source(), k))( @@ -1760,6 +1795,17 @@ def update_frame_state(value): value, frame_state_entry.scalar, ) + get_chromium_event_logger().log_instant_event( + "automatic_dynamic", + time.time_ns(), + { + "name": name, + "dim_changed": "scalar", + "reason": "scalar change", + "cached": str(frame_state_entry.scalar), + "new": str(value), + }, + ) if self.source.guard_source().is_unspecialized_nn_module(): log.info( "%s", @@ -2323,12 +2369,16 @@ def _clone_input(value): set_example_value(proxy.node, example_value) return SDPAParamsVariable(proxy, **options) - elif isinstance(example_value, bool) and proxy.node.target in [ - torch._C._are_functorch_transforms_active, - torch.backends.cuda.is_flash_attention_available, - torch.backends.cuda.can_use_flash_attention, - torch.backends.cuda.can_use_efficient_attention, - ]: + elif isinstance(example_value, bool) and ( + proxy.node.target + in [ + torch._C._are_functorch_transforms_active, + torch.backends.cuda.is_flash_attention_available, + torch.backends.cuda.can_use_flash_attention, + torch.backends.cuda.can_use_efficient_attention, + ] + + list(supported_const_comparison_op_values.keys()) + ): set_example_value(proxy.node, example_value) return ConstantVariable.create(example_value, **options) elif ( @@ -2466,6 +2516,17 @@ def update_frame_state(size, stride): len(size), frame_state_entry.size, ) + get_chromium_event_logger().log_instant_event( + "automatic_dynamic", + time.time_ns(), + { + "name": name, + "dim_changed": "all", + "reason": "dimensionality change", + "cached": str(frame_state_entry.size), + "new": str(size), + }, + ) frame_state_entry.size = None frame_state_entry.stride = None else: @@ -2483,6 +2544,17 @@ def update_frame_state(size, stride): size[i], dim, ) + get_chromium_event_logger().log_instant_event( + "automatic_dynamic", + time.time_ns(), + { + "name": name, + "dim_changed": i, + "reason": "size change", + "cached": str(dim), + "new": str(size[i]), + }, + ) frame_state_entry.size[i] = None has_size_changed = ( has_size_changed or frame_state_entry.size[i] is None @@ -2513,6 +2585,17 @@ def update_frame_state(size, stride): stride[i], dim, ) + get_chromium_event_logger().log_instant_event( + "automatic_dynamic", + time.time_ns(), + { + "name": name, + "dim_changed": i, + "reason": "stride change", + "cached": str(dim), + "new": str(stride[i]), + }, + ) frame_state_entry.stride[i] = None tx.output.frame_state[name] = frame_state_entry @@ -2558,8 +2641,12 @@ def update_dim2constraint(dim, constraint_range, name): else: dim2constraint[dim] = constraint_range, name + from torch.export.dynamic_shapes import _RelaxedConstraint + if tx.output.export_constraints: for constraint in tx.output.export_constraints: + if isinstance(constraint, _RelaxedConstraint): + continue if constraint.t_id == t_id: update_dim2constraint( constraint.dim, constraint.constraint_range, constraint.name diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 66b5be01221a0..135b492e62485 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -629,7 +629,7 @@ def __init__(self, fn, **kwargs) -> None: super().__init__(**kwargs) self.fn = fn - def __str__(self) -> str: + def __repr__(self) -> str: if self.fn is None: name = "None" else: @@ -701,7 +701,6 @@ def has_constant_handler(self, args, kwargs): @staticmethod def _make_handler(fn, arg_types: List[type], has_kwargs: bool): - from .builder import SourcelessBuilder from .lazy import LazyVariableTracker obj = BuiltinVariable(fn) @@ -794,8 +793,6 @@ def call_self_handler(tx: "InstructionTranslator", args, kwargs): handlers.append(call_self_handler) if obj.can_constant_fold_through(): - builder = SourcelessBuilder.create - if ( all(issubclass(x, ConstantVariable) for x in arg_types) and not has_kwargs @@ -809,7 +806,7 @@ def constant_fold_handler(tx: "InstructionTranslator", args, kwargs): ) except Exception as exc: unimplemented(f"constant fold exception: {repr(exc)}") - return builder(tx, res) + return VariableTracker.build(tx, res) else: @@ -825,7 +822,7 @@ def constant_fold_handler(tx: "InstructionTranslator", args, kwargs): ) except Exception as exc: unimplemented(f"constant fold exception: {repr(exc)}") - return builder(tx, res) + return VariableTracker.build(tx, res) handlers.append(constant_fold_handler) @@ -1361,8 +1358,6 @@ def call_dict(self, tx: "InstructionTranslator", *args, **kwargs): @staticmethod def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): - from .builder import SourcelessBuilder - if not kwargs: if not args: args = ({},) @@ -1399,7 +1394,7 @@ def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): ) new_dict = dict(arg.value.items()) - return SourcelessBuilder.create(tx, new_dict) + return VariableTracker.build(tx, new_dict) else: func_var = arg.var_getattr(tx, "items") if not isinstance(func_var, variables.UserFunctionVariable): @@ -1631,7 +1626,6 @@ def call_getattr( TorchInGraphFunctionVariable, UserFunctionVariable, ) - from .builder import SourcelessBuilder, VariableBuilder name = name_var.as_python_constant() @@ -1666,34 +1660,21 @@ def call_getattr( if not hasattr_var.as_python_constant(): return default - options = {} - if obj.source: - source = AttrSource(obj.source, name) - options["source"] = source - else: - source = None - + source = obj.source and AttrSource(obj.source, name) if name in {"__bases__", "__base__", "__flags__"}: try: value = obj.as_python_constant() if isinstance(value, type): if name == "__bases__": - bases = value.__bases__ - if source is not None: - tuple_args = [ - VariableBuilder(tx, GetItemSource(source, i))(b) - for i, b in enumerate(bases) - ] - else: - tuple_args = [ - SourcelessBuilder.create(tx, b) for b in bases - ] - return variables.TupleVariable(tuple_args, **options) + tuple_args = [ + VariableTracker.build( + tx, b, source and GetItemSource(source, i) + ) + for i, b in enumerate(value.__bases__) + ] + return variables.TupleVariable(tuple_args, source=source) if name == "__base__": - base = value.__base__ - if source is not None: - return VariableBuilder(tx, source)(base) - return SourcelessBuilder.create(tx, base) + return VariableTracker.build(tx, value.__base__, source) if name == "__flags__": return ConstantVariable.create(value.__flags__) except NotImplementedError: @@ -1715,14 +1696,14 @@ def call_getattr( try: return obj.var_getattr(tx, name) except NotImplementedError: - return GetAttrVariable(obj, name, **options) + return GetAttrVariable(obj, name, source=source) elif isinstance(obj, TorchInGraphFunctionVariable): # Get OpOverload from an OpOverloadPacket, e.g., torch.ops.aten.add.default. member = getattr(obj.value, name) if isinstance( member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) ) and trace_rules.is_aten_op_or_tensor_method(member): - return TorchInGraphFunctionVariable(member, **options) + return TorchInGraphFunctionVariable(member, source=source) elif isinstance(obj, DummyModule): # TODO(mlazos) - Do we need this? if obj.is_torch or name not in obj.value.__dict__: @@ -1732,18 +1713,15 @@ def call_getattr( if config.replay_record_enabled: tx.exec_recorder.record_module_access(obj.value, name, member) + return VariableTracker.build(tx, member, source) - if source is not None: - return VariableBuilder(tx, source)(member) - else: - return SourcelessBuilder.create(tx, member) elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"): return ConstantVariable.create(getattr(obj.fn, name)) else: try: return obj.var_getattr(tx, name) except NotImplementedError: - return GetAttrVariable(obj, name, **options) + return GetAttrVariable(obj, name, source=source) def call_setattr( self, @@ -1882,8 +1860,6 @@ def call_delattr( return self.call_setattr(tx, obj, name_var, variables.DeletedVariable()) def call_type(self, tx: "InstructionTranslator", obj: VariableTracker): - from .builder import SourcelessBuilder, VariableBuilder - try: py_type = obj.python_type() except NotImplementedError as error: @@ -1893,10 +1869,8 @@ def call_type(self, tx: "InstructionTranslator", obj: VariableTracker): case_name="unknown_python_type", ) from None - if obj.source is None: - return SourcelessBuilder.create(tx, py_type) - else: - return VariableBuilder(tx, TypeSource(obj.source))(py_type) + source = obj.source and TypeSource(obj.source) + return VariableTracker.build(tx, py_type, source) def call_reversed(self, tx: "InstructionTranslator", obj: VariableTracker): if obj.has_unpack_var_sequence(tx): diff --git a/torch/_dynamo/variables/constant.py b/torch/_dynamo/variables/constant.py index de357cf8094f3..cdd977383996a 100644 --- a/torch/_dynamo/variables/constant.py +++ b/torch/_dynamo/variables/constant.py @@ -82,7 +82,7 @@ def __init__(self, value, **kwargs) -> None: def as_proxy(self): return self.value - def __str__(self) -> str: + def __repr__(self) -> str: return f"ConstantVariable({type(self.value).__name__}: {repr(self.value)})" def as_python_constant(self): @@ -226,7 +226,7 @@ def as_proxy(self): return int(self.value) # convert IntEnum to a normal int return self.value - def __str__(self) -> str: + def __repr__(self) -> str: return f"EnumVariable({type(self.value)})" def as_python_constant(self): diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index ac17713cf4d9a..b1688060db3ac 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -203,10 +203,13 @@ def len(self): ] ) - def _maybe_realize(self, item): - return item.realize() if item else item - def reconstruct(self, codegen): + def is_new_item(value, other): + # compare the id of the realized values if both values are not lazy VTs + if value and value.is_realized() and other.is_realized(): + return id(value.realize()) != id(other.realize()) + return id(value) != id(other) + # instructions to load collections.OrderedDict if necessary if self.user_cls is collections.OrderedDict: codegen.add_push_null( @@ -221,11 +224,8 @@ def reconstruct(self, codegen): num_args = 0 for key, value in self.items.items(): # We can safely call realize() here as it won't introduce any new guards - is_new_item = ( - self._maybe_realize(self.original_items.get(key.vt)) != value.realize() - ) - - if is_new_item or self.should_reconstruct_all: + item = self.original_items.get(key.vt) + if is_new_item(item, value) or self.should_reconstruct_all: codegen(key.vt) codegen(value) num_args += 1 @@ -984,12 +984,10 @@ def __init__(self, obj, **kwargs) -> None: assert self.is_matching_cls(type(obj)) def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - from .builder import VariableBuilder - try: attr_value = getattr(self.obj, name) - attr_source = AttrSource(self.source, name) - return VariableBuilder(tx, attr_source)(attr_value) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, attr_value, source) except AttributeError: unimplemented(f"getattr({self.value}, {name})") @@ -1053,15 +1051,11 @@ def call_get( key: VariableTracker, default: Optional[VariableTracker] = None, ): - from .builder import VariableBuilder - k, has_key = self._contains_helper(tx, key) if has_key: - return VariableBuilder( - tx, - GetItemSource(self.source, k), - )(sys.modules[k]) + source = self.source and GetItemSource(self.source, k) + return VariableTracker.build(tx, sys.modules[k], source) if default is not None: return default @@ -1069,10 +1063,6 @@ def call_get( return ConstantVariable.create(value=None) def call_getitem(self, tx: "InstructionTranslator", key: VariableTracker): - from .builder import VariableBuilder - k, has_key = self._contains_helper(tx, key) - return VariableBuilder( - tx, - GetItemSource(self.source, k), - )(sys.modules[k]) + source = self.source and GetItemSource(self.source, k) + return VariableTracker.build(tx, sys.modules[k], source) diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index c14b8794cba5f..6afffc15ad169 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -313,7 +313,7 @@ def create( user_hooks: VariableTracker, user_pre_hooks: VariableTracker, ): - if not compiled_autograd.compiled_autograd_enabled: + if not compiled_autograd.enabled(): unimplemented("module-level backwards hooks require compiled autograd") def _in_graph_bw_hooks(bw_state: BackwardState): diff --git a/torch/_dynamo/variables/functions.py b/torch/_dynamo/variables/functions.py index 14e4cf0c820a3..a178b0a5956c0 100644 --- a/torch/_dynamo/variables/functions.py +++ b/torch/_dynamo/variables/functions.py @@ -6,7 +6,18 @@ import inspect import itertools import types -from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, TypeVar, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import Never import torch @@ -37,6 +48,10 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator from torch._guards import Source + from torch._higher_order_ops.triton_kernel_wrap import ( + TritonGridType, + TritonKernelType, + ) _F = TypeVar("_F", bound=Callable) @@ -47,9 +62,7 @@ def wrap_bound_arg(tx: "InstructionTranslator", val, source=None): if isinstance(val, VariableTracker): return val elif not source: - from torch._dynamo.variables.builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, val) + return VariableTracker.build(tx, val) else: # Create a lazy variable to avoid guarding on __defaults__ unless really # needed. @@ -208,9 +221,11 @@ def bind_args(self, parent, args, kwargs): ) if fn.__kwdefaults__: kwdefaults_sources = { - k: None - if self.source is None - else DefaultsSource(self.source, k, is_kw=True) + k: ( + None + if self.source is None + else DefaultsSource(self.source, k, is_kw=True) + ) for k in fn.__kwdefaults__ } fake_func.__kwdefaults__ = { @@ -241,8 +256,6 @@ def bind_args(self, parent, args, kwargs): # optimization for cleaner codegen result[name] = var elif self.source: - from .builder import VariableBuilder - side_effects = parent.output.side_effects if cell in side_effects: out = side_effects[cell] @@ -254,17 +267,14 @@ def bind_args(self, parent, args, kwargs): closure_cell, "cell_contents" ) try: - contents_var = VariableBuilder( - parent, closure_cell_contents - )(cell.cell_contents) + contents_var = VariableTracker.build( + parent, cell.cell_contents, closure_cell_contents + ) except ValueError: # Cell has not yet been assigned contents_var = variables.DeletedVariable() - if ( - closure_cell_contents.name() - not in tx.mutated_closure_cell_contents - ): + if id(cell) not in tx.mutated_closure_cell_ids: # Optimistically don't allocate the cell, to # reduce the number of side effects. This is # important for cond, as without it, any accesses @@ -274,6 +284,10 @@ def bind_args(self, parent, args, kwargs): # the analysis with this cell's name in the # mutated list here result[name] = contents_var + # Map the variable to the original cell so we can + # look it up later, see + # `InliningInstructionTranslator.STORE_DEREF`. + tx.contents_var_to_mutated_cell[contents_var] = cell continue # cells are written to with "cell_contents", @@ -287,9 +301,7 @@ def bind_args(self, parent, args, kwargs): result[name] = out else: - from .builder import SourcelessBuilder - - result[name] = SourcelessBuilder.create(tx, cell.cell_contents) + result[name] = VariableTracker.build(tx, cell.cell_contents) return result, closure_cells @@ -297,17 +309,14 @@ def export_freevars(self, parent, child): pass def var_getattr(self, tx: "InstructionTranslator", name: str): - source = AttrSource(self.source, name) if self.source else None + source = self.source and AttrSource(self.source, name) try: subobj = inspect.getattr_static(self.fn, name) except AttributeError: - options = {"source": source} - return variables.GetAttrVariable(self, name, **options) + return variables.GetAttrVariable(self, name, source=source) if source: return variables.LazyVariableTracker.create(subobj, source) - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, subobj) + return VariableTracker.build(tx, subobj) def call_hasattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: result = hasattr(self.fn, name) @@ -347,7 +356,7 @@ def __init__(self, fn, obj, **kwargs) -> None: super().__init__(fn=fn, **kwargs) self.obj = obj - def __str__(self) -> str: + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.fn}, {self.obj})" def self_args(self): @@ -539,7 +548,8 @@ def get_globals(self): return self.f_globals def bind_args(self, parent, args, kwargs): - from .misc import InlinedClosureVariable + # Avoid circular import + from .misc import ClosureVariable, NewCellVariable code = self.get_code() func = types.FunctionType( @@ -560,23 +570,15 @@ def bind_args(self, parent, args, kwargs): for idx, name in enumerate(code.co_freevars): cell = self.closure.items[idx] assert name not in result - if isinstance(cell, InlinedClosureVariable): - # InlinedClosureVariable's are created from LOAD_CLOSURE's from - # InliningInstructionTranslators when the variable name is not found in closure_cells. - # They should remain outside of closure_cells, so that our callee (the - # InliningInstructionTranslator that traces `func`) handles - # the cell correctly - that is, the cell's contents are treated as if they - # are local variables, like in UserFunctionVariable's bind_args for freevars. - cand = parent - while cand and name not in cand.symbolic_locals: - cand = cand.parent - if cand is None: - raise RuntimeError( - f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack" - ) - result[name] = cand.symbolic_locals[name] + # In the regular case, a cell is either a `ClosureVariable` or + # `NewCellVariable`. + if isinstance(cell, (ClosureVariable, NewCellVariable)): + closure_cells[name] = cell else: - closure_cells[name] = self.closure.items[idx] + # We model unmodified cells captured by `UserFunctionVariable` as + # their contents, in tracer's `symbolic_locals`. See + # `UserFunctionVariable::bind_args`. + result[name] = cell return result, closure_cells @@ -739,6 +741,12 @@ def wraps(fn): ) # also warn on it because most users won't see the graph break message torch._dynamo.utils.warn_once(msg) + if self.value.__qualname__ == "allow_in_graph": + msg = ( + "Found an allow_in_graph decorator to a function which " + "is created inside the parent function that is getting " + "compiled. This is not supported for now." + ) msg += f"', {self.reason}'" if self.reason else "" unimplemented(msg) @@ -759,14 +767,8 @@ def __init__(self, wrapper_obj, attr_to_trace, **kwargs) -> None: def var_getattr(self, tx: "InstructionTranslator", name): if name == self.attr_to_trace: val = getattr(self.wrapper_obj, self.attr_to_trace) - if self.source: - from .builder import VariableBuilder - - return VariableBuilder(tx, AttrSource(self.source, name))(val) - else: - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, val) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, val, source) return super().var_getattr(tx, name) @@ -1001,8 +1003,6 @@ def call_function( args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": - from torch._dynamo.variables.builder import SourcelessBuilder - if self.can_constant_fold_through() and check_unspec_or_constant_args( args, kwargs ): @@ -1012,7 +1012,7 @@ def call_function( **{k: v.as_python_constant() for k, v in kwargs.items()}, ) ) - return SourcelessBuilder.create(tx, result) + return VariableTracker.build(tx, result) # Special case for sum on tuple/list of ints if ( @@ -1036,15 +1036,17 @@ def call_function( ), sym_num=torch.sym_sum( [ - x.value - if isinstance(x, variables.ConstantVariable) - else x.sym_num + ( + x.value + if isinstance(x, variables.ConstantVariable) + else x.sym_num + ) for x in args[0].items ] ), ) - traceable_function_variable = SourcelessBuilder.create(tx, self.traceable_fn) + traceable_function_variable = VariableTracker.build(tx, self.traceable_fn) return traceable_function_variable.call_function(tx, args, kwargs) def call_method( @@ -1070,22 +1072,25 @@ def as_python_constant(self): return self.fn -from torch._higher_order_ops.triton_kernel_wrap import TritonHOPifier +from torch._higher_order_ops.triton_kernel_wrap import ( + TMADescriptorMetadata, + TritonHOPifier, +) class DynamoTritonHOPifier(TritonHOPifier): - def raise_unsupported(self, msg): + def raise_unsupported(self, msg: str) -> Never: raise Unsupported(msg) - def is_callable(self, maybe_callable): + def is_callable(self, maybe_callable: Any) -> bool: return isinstance( maybe_callable, (NestedUserFunctionVariable, UserFunctionVariable) ) - def get_value(self, val): + def get_value(self, val: Any) -> Any: return val.value - def check_grid(self, grid): + def check_grid(self, grid) -> Tuple[torch.fx.proxy.Proxy, ...]: from .lists import BaseListVariable if isinstance(grid, BaseListVariable): @@ -1098,10 +1103,22 @@ def call_grid(self, grid, meta, tx): grid = grid.call_function(tx, [meta], {}) return grid - def call_HOP(self, variable, grids, combined_args_raw, tx): + def call_HOP(self, variable, grids, combined_args_raw, tx) -> ConstantVariable: from .constant import ConstantVariable from .dicts import ConstDictVariable + # as we can only pass tensors as non-const args in fx graph, + # here we replace TMA descriptors (TMADescriptorVariable + # instances) with the underlying tensors, while moving the + # TMA descriptor-related metadata to a separate argument, + # so that we can reconstruct the TMA descriptors downstream + tma_descriptor_metadata: TMADescriptorMetadata = {} + for k in list(combined_args_raw.keys()): + v = combined_args_raw[k] + if isinstance(v, TMADescriptorVariable): + tma_descriptor_metadata[k] = v.to_metadata() + combined_args_raw[k] = v.data_ptr.from_tensor + combined_args = { variables.ConstantVariable.create(k): v for k, v in combined_args_raw.items() @@ -1126,6 +1143,13 @@ def call_HOP(self, variable, grids, combined_args_raw, tx): if not isinstance(v, ConstantVariable) } + for v in non_constant_args.values(): + v = v.realize() + if not isinstance(v, (variables.TensorVariable, variables.SymNodeVariable)): + self.raise_unsupported( + f"Unexpected argument type for a Triton kernel: {repr(v)}." + ) + constant_args_idx = kernel_side_table.add_constant_args(constant_args) meta = ConstDictVariable(non_constant_args, dict) tx.output.create_proxy( @@ -1136,6 +1160,7 @@ def call_HOP(self, variable, grids, combined_args_raw, tx): "kernel_idx": variable.kernel_idx, "constant_args_idx": constant_args_idx, "grid": grids, + "tma_descriptor_metadata": tma_descriptor_metadata, "kwargs": meta.as_proxy(), }, ) @@ -1149,6 +1174,10 @@ def call_HOP(self, variable, grids, combined_args_raw, tx): class TritonKernelVariable(VariableTracker): + grid: "TritonGridType" + kernel: "TritonKernelType" + kernel_idx: Optional[int] + def __init__(self, kernel, kernel_idx, grid, **kwargs) -> None: super().__init__(**kwargs) dynamo_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) @@ -1186,3 +1215,92 @@ def specialize_symbolic(self, arg: Any) -> Any: if isinstance(arg, SymNodeVariable): return ConstantVariable.create(arg.evaluate_expr()) return arg + + +class TMADescriptorVariable(VariableTracker): + def __init__( + self, + data_ptr: "variables.DataPtrVariable", + dims: "List[ConstantVariable]", + block_dims: "List[ConstantVariable]", + element_size: "ConstantVariable", + **kwargs, + ): + assert isinstance(data_ptr, variables.DataPtrVariable) + super().__init__(**kwargs) + self.data_ptr = data_ptr + self.dims = dims + self.block_dims = block_dims + self.element_size = element_size + + def to_metadata(self): + return ( + [dim.as_proxy() for dim in self.dims], + [dim.as_proxy() for dim in self.block_dims], + self.element_size.as_proxy(), + ) + + def reconstruct(self, codegen): + codegen.add_push_null( + lambda: codegen.load_import_from( + "triton.tools.experimental_descriptor", + f"create_{len(self.dims)}d_tma_descriptor", + ) + ) + self.data_ptr.reconstruct(codegen) + args = [*self.dims, *self.block_dims, self.element_size] + codegen.foreach(args) + codegen.call_function(len(args) + 1, False) + + +class CreateTMADescriptorVariable(VariableTracker): + def __init__( + self, + rank: int, + **kwargs, + ) -> None: + assert rank in (1, 2) + super().__init__(**kwargs) + self.rank = rank + + def call_function( + self, + tx: "InstructionTranslator", + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + ptr = kwargs["ptr"] if "ptr" in kwargs else args[0] + + if not isinstance(ptr, variables.DataPtrVariable): + raise Unsupported( + "Please ensure there were no graph breaks between " + f"create_{self.rank}d_tma_descriptor and the upstream " + ".data_ptr() call." + ) + + if self.rank == 1: + assert len(args) + len(kwargs) == 4 + dims = [ + kwargs["dim"] if "dim" in kwargs else args[1], + ] + block_dims = [ + kwargs["block_dim"] if "block_dim" in kwargs else args[2], + ] + else: + assert len(args) + len(kwargs) == 6 + dims = [ + kwargs["dim1"] if "dim1" in kwargs else args[1], + kwargs["dim0"] if "dim0" in kwargs else args[2], + ] + block_dims = [ + kwargs["block_dim1"] if "block_dim1" in kwargs else args[3], + kwargs["block_dim0"] if "block_dim0" in kwargs else args[4], + ] + element_size = kwargs["element_size"] if "element_size" in kwargs else args[-1] + + return TMADescriptorVariable( + data_ptr=ptr, + dims=dims, + block_dims=block_dims, + element_size=element_size, + ) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 007bfb2f97b0c..3af819db81dbd 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1,6 +1,7 @@ # mypy: ignore-errors import contextlib +import copy import functools import inspect import itertools @@ -13,12 +14,12 @@ import torch.nn from torch._dynamo.utils import get_fake_value from torch._dynamo.variables import ConstantVariable -from torch._dynamo.variables.base import VariableTracker from torch._dynamo.variables.builtin import BuiltinVariable from torch._dynamo.variables.functions import UserFunctionVariable from torch._dynamo.variables.tensor import SymNodeVariable from torch._guards import Source from torch._ops import HigherOrderOperator +from torch.fx.node import map_arg from torch.fx.passes.shape_prop import _extract_tensor_metadata from torch.utils import _pytree as pytree @@ -31,6 +32,7 @@ ) from ..source import AttrSource from ..utils import proxy_args_kwargs +from .base import VariableTracker from .dicts import ConstDictVariable from .lazy import LazyVariableTracker from .lists import ListVariable, TupleVariable @@ -623,11 +625,15 @@ def make(value, source=None, **kwargs): return CallTorchbindHigherOrderVariable(value, source, **kwargs) elif value.__name__ == "wrap_with_set_grad_enabled": return WrapWithSetGradEnabledHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "wrap_with_autocast": + return WrapWithAutocastHigherOrderVariable(value, source, **kwargs) elif ( value.__name__ == "auto_functionalized" or value.__name__ == "auto_functionalized_v2" ): return AutoFunctionalizeHigherOrderVariable(value, source, **kwargs) + elif value.__name__ == "invoke_subgraph": + return InvokeSubgraphHigherOrderVariable(value, source, **kwargs) else: unimplemented(f"HigherOrderOperator {value.__name__}") @@ -673,38 +679,40 @@ def call_function( ) # Specialize into one of the branches since pred is constant - if type(args[0]) is ConstantVariable: + pred, true_fn, false_fn, operands = args + if type(pred) is ConstantVariable: log.warning( "Pred is a Python constant. When used with torch.cond, it executes only one of the branches." " If you want torch.cond to perserve two branches, please make the predicate a boolean tensor or a SymBool." ) - if args[0].as_python_constant(): - return args[1].call_function(tx, args[3].unpack_var_sequence(tx), {}) + if pred.as_python_constant(): + return true_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) else: - return args[2].call_function(tx, args[3].unpack_var_sequence(tx), {}) + return false_fn.call_function(tx, operands.unpack_var_sequence(tx), {}) # predicate - if type(args[0]) not in (ConstantVariable, TensorVariable, SymNodeVariable): + if type(pred) not in (ConstantVariable, TensorVariable, SymNodeVariable): unimplemented( f"Expected pred to be bool or a boolean tensor with single " - f"item but got {str(type(args[0]))} " - f"with original python type {str(args[0].python_type())}.", + f"item but got {str(type(pred))} " + f"with original python type {str(pred.python_type())}.", ) # operands - if not isinstance(args[3], (ListVariable, TupleVariable)): + if not isinstance(operands, (ListVariable, TupleVariable)): unimplemented( - f"Expected a tuple but got {args[3].python_type()}", + f"Expected operands to be a list/tuple but got " + f"{operands.python_type()}", ) - operands = args[3].unpack_var_sequence(tx) - if not only_consist_of(args[3], (TensorVariable,)): + operands_seq = operands.unpack_var_sequence(tx) + if not only_consist_of(operands, (TensorVariable,)): unimplemented( "Expect operands to be a tuple of pytrees that only consists of tensor leaves." ) # branches - _check_supported_callable_arg(tx, args[1], "true_fn") - _check_supported_callable_arg(tx, args[2], "false_fn") + _check_supported_callable_arg(tx, true_fn, "true_fn") + _check_supported_callable_arg(tx, false_fn, "false_fn") # Our strategy for tracing the true/false branches of cond # are to checkpoint our graphstate, run the true branch, @@ -730,7 +738,7 @@ def speculate_branch(branch): ) = speculate_subgraph( tx, args[ix], - operands, + operands_seq, {}, "cond", source_target=self.value, @@ -817,7 +825,7 @@ def diff_meta(tensor_vars1, tensor_vars2): false_node = make_attr(tx, false_name) p_args = ( - args[0].as_proxy(), + pred.as_proxy(), true_node, false_node, # We pick true_shared but it shouldn't matter @@ -903,26 +911,30 @@ def call_function( f"Usage: while_loop(cond_fn, body_fn, operands)", ) - _check_supported_callable_arg(tx, args[0], "cond_fn") - _check_supported_callable_arg(tx, args[1], "body_fn") + cond_fn, body_fn, operands, additional_inputs = args + _check_supported_callable_arg(tx, cond_fn, "cond_fn") + _check_supported_callable_arg(tx, body_fn, "body_fn") # operands - if not isinstance(args[2], (ListVariable, TupleVariable)): + if not isinstance(operands, (ListVariable, TupleVariable)): unimplemented( - f"Expected a tuple but got {args[2].python_type()}", + f"Expected operands to be a list/tuple but got " + f"{operands.python_type()}", ) - operands = args[2].unpack_var_sequence(tx) - if not only_consist_of(args[2], (TensorVariable,)): + operands_seq = operands.unpack_var_sequence(tx) + if not only_consist_of(operands, (TensorVariable,)): unimplemented( "Expect operands to be a tuple of pytrees that only consists of tensor leaves." ) # additional inputs check - if not isinstance(args[3], (ListVariable, TupleVariable)): + if not isinstance(additional_inputs, (ListVariable, TupleVariable)): unimplemented( - f"Expected a tuple but got {args[3].python_type()}", + f"Expected additional_inputs to be a list/tuple but got " + f"{additional_inputs.python_type()}. It seems to be an " + f"internal error, please report an issue to PyTorch." ) - additional_inputs = args[3].unpack_var_sequence(tx) + additional_inputs_seq = additional_inputs.unpack_var_sequence(tx) ( (cond_r, cond_treespec), @@ -930,8 +942,8 @@ def call_function( cond_lifted_freevars, ) = speculate_subgraph( tx, - args[0], - operands + additional_inputs, + cond_fn, + operands_seq + additional_inputs_seq, {}, "while_loop", source_target=self.value, @@ -959,8 +971,8 @@ def call_function( body_lifted_freevars, ) = speculate_subgraph( tx, - args[1], - operands + additional_inputs, + body_fn, + operands_seq + additional_inputs_seq, {}, "while_loop", source_target=self.value, @@ -1006,9 +1018,10 @@ def call_function( p_args = ( cond_node, body_node, - tuple([operand.as_proxy() for operand in operands]), + tuple([operand.as_proxy() for operand in operands_seq]), tuple( - [inp.as_proxy() for inp in additional_inputs] + additional_lifted_inputs + [inp.as_proxy() for inp in additional_inputs_seq] + + additional_lifted_inputs ), ) @@ -1038,7 +1051,7 @@ def call_function( args: List[VariableTracker], kwargs: Dict[str, VariableTracker], ) -> VariableTracker: - from .builder import SourcelessBuilder, wrap_fx_proxy + from .builder import wrap_fx_proxy args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) @@ -1060,18 +1073,22 @@ def arg_extractor(combine_fn, xs, dim): tx, "new_empty", args=( - SourcelessBuilder.create( + VariableTracker.build( tx, - leaf.size - if leaf.size is not None - else BuiltinVariable(getattr) - .call_function(tx, [leaf, ConstantVariable.create("shape")], {}) - .items, + ( + leaf.size + if leaf.size is not None + else BuiltinVariable(getattr) + .call_function( + tx, [leaf, ConstantVariable.create("shape")], {} + ) + .items + ), ), ), kwargs={ - "dtype": SourcelessBuilder.create(tx, leaf.dtype), - "requires_grad": SourcelessBuilder.create(tx, leaf.requires_grad), + "dtype": VariableTracker.build(tx, leaf.dtype), + "requires_grad": VariableTracker.build(tx, leaf.requires_grad), }, ) for leaf in itertools.chain(xs.items, xs.items) @@ -1148,27 +1165,34 @@ def call_function( args: List[VariableTracker], kwargs: Dict[str, VariableTracker], ) -> VariableTracker: - from torch._higher_order_ops.scan import make_expanded_output_shape + from torch._higher_order_ops.scan import ( + _extract_carry_and_out, + first_slice_copy, + stack_y, + ) - from .builder import SourcelessBuilder, wrap_fx_proxy + from .builder import wrap_fx_proxy args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) - def arg_extractor(combine_fn, init, xs, dim, reverse): - return combine_fn, init, xs, dim, reverse + def arg_extractor(combine_fn, init, xs, dim, reverse, additional_inputs): + return combine_fn, init, xs, dim, reverse, additional_inputs - combine_fn, init, xs, dim, reverse = arg_extractor(*args, **kwargs) + combine_fn, init, xs, dim, reverse, additional_inputs = arg_extractor( + *args, **kwargs + ) + assert isinstance(additional_inputs, variables.BaseListVariable) if xs.python_type() != list: unimplemented( f"Expected xs to be a list of tensors but got {xs.python_type()}", ) - assert isinstance(xs, torch._dynamo.variables.lists.BaseListVariable) + assert isinstance(xs, variables.BaseListVariable) if init.python_type() != list: unimplemented( f"Expected init to be a list of tensors but got {init.python_type()}", ) - assert isinstance(init, torch._dynamo.variables.lists.BaseListVariable) + assert isinstance(init, variables.BaseListVariable) dim_fake = ( dim.as_proxy() @@ -1189,58 +1213,18 @@ def arg_extractor(combine_fn, init, xs, dim, reverse): # TODO: Fix these pointless new_empty calls appearing in the dynamo output graph. # TODO: Unify handling of sub_args across control flow ops, such as cond, while_loop, etc. sub_args_init = [ - ini.call_method( - tx, - "new_empty", - args=( - SourcelessBuilder.create( - tx, - ini.size - if ini.size is not None - else tuple( - BuiltinVariable(getattr) - .call_function( - tx, [ini, ConstantVariable.create("shape")], {} - ) - .items - ), - ), - ), - kwargs={ - "dtype": SourcelessBuilder.create(tx, ini.dtype), - "device": SourcelessBuilder.create(tx, ini.device), - "requires_grad": SourcelessBuilder.create(tx, ini.requires_grad), - }, - ) - for ini in init.items + ini.call_method(tx, "clone", args=(), kwargs={}) for ini in init.items ] - sub_args_inp_shapes = make_expanded_output_shape( - dim_fake, - 1, - [ - tuple( - BuiltinVariable(getattr) - .call_function(tx, [inp, ConstantVariable.create("shape")], {}) - .items - ) - for inp in xs.items - ], - True, - ) + # The sub_args_inp is a slice of original input, e.g. if input.size is (3, 4), and scan dim=0 + # the sub_args_inp shape will be (4, ). sub_args_inp = [ - inp.call_method( - tx, - "new_empty", - args=(SourcelessBuilder.create(tx, inp_sh),), - kwargs={ - "dtype": SourcelessBuilder.create(tx, inp.dtype), - "device": SourcelessBuilder.create(tx, inp.device), - "requires_grad": SourcelessBuilder.create(tx, inp.requires_grad), - }, - ) - for inp, inp_sh in zip(xs.items, sub_args_inp_shapes) + _make_inlined(tx, first_slice_copy)(inp, dim) for inp in xs.items ] - sub_args = sub_args_init + sub_args_inp + sub_args_additional_inputs = [ + t.call_method(tx, "clone", args=(), kwargs={}) + for t in additional_inputs.items + ] + sub_args = sub_args_init + sub_args_inp + sub_args_additional_inputs ( (combine_result, combine_treespec), combine_graph, @@ -1255,22 +1239,42 @@ def arg_extractor(combine_fn, init, xs, dim, reverse): set_subgraph_inputs="flatten_manual", ) - if combine_lifted_freevars: - unimplemented( - f"Combine fn had unexpected freevars: {combine_lifted_freevars}" - ) + # key in the combine_lifted_freevars are proxies in the root tracer. + # We use root tracer's proxies to create scan op's inputs. + def _check_phs_position_match( + combine_graph: torch.fx.Graph, lifted_proxies: list[torch.fx.Proxy] + ): + lifted_phs = [ + node for node in combine_graph.nodes if node.op == "placeholder" + ][-len(lifted_proxies) :] + for ph, lifted_proxy in zip(lifted_phs, lifted_proxies): + if ph is not lifted_proxy.node: + unimplemented( + "The postion lifted freevars doesn't match the order of placeholders in subgraph." + ) + + _check_phs_position_match(combine_graph, list(combine_lifted_freevars.values())) + combine_freevars_proxy = list(combine_lifted_freevars.keys()) - if any(cr.python_type() != list for cr in combine_result.items): + if combine_result.python_type() != list: unimplemented( f"Expected combine_fn to return a list if tensor but got {combine_result.python_type()}", ) xs_proxy = xs.as_proxy() init_proxy = init.as_proxy() - combine_carry_proxy = combine_result.items[0].as_proxy() + additional_inputs_proxy = additional_inputs.as_proxy() + combine_freevars_proxy + num_init_leaves = len(init_proxy) + # combine_result is a flatten list concated by carry + y, len(carry) is len(init) since they have + # same pytree structure. + carry_vars, y_vars = _extract_carry_and_out( + combine_result.items, num_init_leaves + ) + carry_proxies = [carry_var.as_proxy() for carry_var in carry_vars] + y_proxies = [y_var.as_proxy() for y_var in y_vars] # Checks for carry and init - for ini_proxy, carry in zip(init_proxy, combine_carry_proxy): + for ini_proxy, carry in zip(init_proxy, carry_proxies): ini_meta = ini_proxy.node.meta["example_value"] carry_meta = carry.node.meta["example_value"] if ( @@ -1292,32 +1296,19 @@ def arg_extractor(combine_fn, init, xs, dim, reverse): xs_proxy, dim.as_proxy(), reverse.as_proxy(), + additional_inputs_proxy, ) with tx.fake_mode: + example_carry = [ + init_p.node.meta["example_value"].clone() for init_p in init_proxy + ] # For the fake mode, we need to duplicate the init tensor along the dim # to have the same size as the xs arguments - # We also do a clone with contiguous_format. This is to be consistent with - # eager semantic of map, which stacks the outputs. The result is contiguous - # as a result of the stack operation. - fake_out_shapes = make_expanded_output_shape( - dim_fake, - scan_length, - [ - get_fake_value(o.as_proxy().node, tx).size() - for o in combine_result.items[1].items - ], - ) - out_meta = ( - [init_p.node.meta["example_value"].clone() for init_p in init_proxy], - list( # noqa: C400 - t.as_proxy() - .node.meta["example_value"] - .expand(*sh) - .clone(memory_format=torch.contiguous_format) - for t, sh in zip(combine_result.items[1].items, fake_out_shapes) - ), - ) + example_stacked_out = [ + stack_y(y.node.meta["example_value"], scan_length) for y in y_proxies + ] + out_meta = [*example_carry, *example_stacked_out] return wrap_fx_proxy( tx=tx, @@ -1502,10 +1493,20 @@ def call_function( class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="wrap_body" + ): + return add_subgraph( + tx, + f"{attr_name}", + body_gmod, + ) + def create_wrapped_node( self, tx: "InstructionTranslator", - args, + fn_vt, + fn_args_vt, kwargs, description, under_activation_checkpoint=False, @@ -1518,8 +1519,8 @@ def create_wrapped_node( body_lifted_freevars, ) = speculate_subgraph( tx, - args[0], # function - [*args[1:]], + fn_vt, + fn_args_vt, kwargs, description, source_target=self.value, @@ -1528,12 +1529,9 @@ def create_wrapped_node( ) body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) - body_name = add_subgraph( - tx, - "wrap_body", - body_gmod, + body_name = self.install_subgraph_in_output_graph( + tx, fn_vt, fn_args_vt, kwargs, body_gmod ) - body_node = make_attr(tx, body_name) # Since, we call `speculate_subgraph` with `set_subgraph_inputs="automatic`, @@ -1547,7 +1545,7 @@ def create_wrapped_node( body_r.as_proxy(), ) - return proxy_args, {}, example_value, body_r, treespec, body_gmod + return proxy_args, {}, example_value, body_r, treespec, body_gmod, body_name def call_function( self, @@ -1556,9 +1554,15 @@ def call_function( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": # This flattens the kwargs into lifted args - p_args, p_kwargs, example_value, body_r, treespec, _ = self.create_wrapped_node( - tx, args, kwargs, "wrap" - ) + ( + p_args, + p_kwargs, + example_value, + body_r, + treespec, + _, + _, + ) = self.create_wrapped_node(tx, args[0], args[1:], kwargs, "wrap") if len(p_kwargs) > 0: unimplemented("kwargs should have been flattened into lifted args") @@ -1647,6 +1651,88 @@ def call_function( ) +class WrapWithAutocastHigherOrderVariable(TorchHigherOrderOperatorVariable): + """ + This hop is not exposed to users but is inserted into the graph + after export as a post-processing step. + """ + + def call_function( + self, + tx: "InstructionTranslator", + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + args, kwargs = LazyVariableTracker.realize_all((args, kwargs)) + + if kwargs: + unimplemented( + f"wrap_with_autocast: Got unexpected kwargs: {list(kwargs.keys())}" + ) + + device_type, dtype, enabled, cache_enabled, fn_var, *rest_args = args + + for arg in [device_type, dtype, enabled, cache_enabled]: + if not isinstance(arg, ConstantVariable): + unimplemented( + "device_type, dtype, enabled, cache_enabled must be constants" + ) + + _check_supported_callable_arg(tx, fn_var, "autocast") + + python_constants = [ + arg.as_python_constant() + for arg in [device_type, dtype, enabled, cache_enabled] + ] + + with torch.autocast(*python_constants): + ( + (body_r, treespec), + body_graph, + body_lifted_freevars, + ) = speculate_subgraph( + tx, + fn_var, + [*rest_args], + {}, + "torch.ops.higher_order.wrap_with_autocast", + source_target=self.value, + set_subgraph_inputs="manual", + should_flatten_outputs=True, + ) + + if len(body_lifted_freevars) > 0: + unimplemented( + f"wrap_with_autocast: Got unexpected freevars {body_lifted_freevars}" + ) + + body_gmod = torch.fx.GraphModule(tx.output.nn_modules, body_graph) + body_name = add_subgraph( + tx, + "wrap_body", + body_gmod, + ) + + body_node = make_attr(tx, body_name) + + proxy_args = tuple( + [ + *python_constants, + body_node, + ] + + [operand.as_proxy() for operand in rest_args] + ) + example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + return _call_function_and_unflatten_output( + tx, self.value, proxy_args, {}, example_value, treespec + ) + + class HintsWrapperHigherOrderVariable(TorchHigherOrderOperatorVariable): @raise_hard_error_if_graph_break( reason="Hints_wrapper doesn't work unless it is captured completely with torch.compile." @@ -1860,9 +1946,11 @@ def call_function( body_r, treespec, checkpointed_gmod, + _, ) = self.create_wrapped_node( tx, - args, + args[0], + args[1:], gmod_kwargs, "torch.utils.checkpoint.checkpoint", under_activation_checkpoint=True, @@ -1998,9 +2086,7 @@ def create_wrapped_node( fn: "VariableTracker", fn_name: str, ): - from torch._higher_order_ops.flex_attention import TransformGetItemToIndex - - from .builder import SourcelessBuilder + from .._trace_wrapped_higher_order_op import TransformGetItemToIndex tx: InstructionTranslator = tx @@ -2008,9 +2094,9 @@ def create_scalar(): return query.call_method( tx, "new_empty", - (SourcelessBuilder.create(tx, []),), + (VariableTracker.build(tx, []),), { - "dtype": SourcelessBuilder.create(tx, torch.int32), + "dtype": VariableTracker.build(tx, torch.int32), }, ) @@ -2020,8 +2106,8 @@ def create_scalar(): score = query.call_method( tx, "new_empty", - (SourcelessBuilder.create(tx, []),), - {"requires_grad": SourcelessBuilder.create(tx, scores_require_grad)}, + (VariableTracker.build(tx, []),), + {"requires_grad": VariableTracker.build(tx, scores_require_grad)}, ) new_args = [score, *bhmn] else: @@ -2197,7 +2283,6 @@ def bwd(ctx, grad, x): source_target="autograd.Function", ) - fwd_src = AttrSource(self.parent_source, member="forward") ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) if isinstance(self.fwd_graph, types.FunctionType): fwd_fn = UserFunctionVariable(self.fwd_graph) @@ -2218,16 +2303,15 @@ def bwd(ctx, grad, x): fwd_args, kwargs, "autograd.Function", - enable_grad=False, set_subgraph_inputs="semi_automatic", restore_side_effects=False, tracer=fwd_tracer, ) - if ctx.mutable_local in tx.output.side_effects.store_attr_mutations: + if ctx in tx.output.side_effects.store_attr_mutations: if ( "_materialize_non_diff_grads" - in tx.output.side_effects.store_attr_mutations[ctx.mutable_local] + in tx.output.side_effects.store_attr_mutations[ctx] ): unimplemented("NYI") @@ -2457,3 +2541,152 @@ def maybe_positional_arg_names(func): else: result.append(name) return result + + +def canonicalize(gmod, root_gmod): + # autograd_cache_key is sensitive to the name of the placeholder and intermediate nodes. + # So, we first canonicalize it. + new_graph = torch.fx.Graph() + env = {} + + placeholder_counter = itertools.count(0) + + def next_placeholder_name(): + nonlocal placeholder_counter + return f"placeholder_{next(placeholder_counter)}" + + node_counter = itertools.count(0) + + def next_node_name(): + nonlocal node_counter + return f"node_{next(node_counter)}" + + for node in gmod.graph.nodes: + if node.op == "placeholder": + env[node] = new_graph.placeholder(next_placeholder_name()) + else: + # Can't use node_copy because node.name will not be unique. + args = map_arg(node.args, lambda x: env[x]) + kwargs = map_arg(node.kwargs, lambda x: env[x]) + env[node] = new_graph.create_node( + node.op, node.target, args, kwargs, next_node_name(), node.type + ) + env[node].meta = copy.copy(node.meta) + + new_graph.lint() + new_gmod = torch.fx.GraphModule(root_gmod, new_graph) + return new_gmod + + +@functools.lru_cache(None) +def get_dummy_aot_autograd_config(): + from torch._functorch._aot_autograd.schemas import AOTConfig + + return AOTConfig( + fw_compiler=None, + bw_compiler=None, + inference_compiler=None, + partition_fn=None, + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + dynamic_shapes=True, + aot_autograd_arg_pos_to_source=None, + is_export=False, + no_tangents=False, + enable_log=False, + ) + + +def hash_graph_and_inputs(tx, gmod, fake_inputs): + # Here, we use the existing autograd_cache_key infrastructure to hash the + # graph and fake inputs. + + # TODO(anijain2305) - Consider reorganizing autograd_cache_key such that the + # namespaces seem more intuitive. It seems somewhat confusing that we are + # calling an API from aot_autograd here. + from torch._functorch._aot_autograd.autograd_cache import autograd_cache_key + + # autograd_cache_key is sensitive to the name of the placeholder nodes. + # So, we first canonicalize it. + canonicalized_gmod = canonicalize(gmod, tx.output.nn_modules) + config = get_dummy_aot_autograd_config() + + key, _ = autograd_cache_key(canonicalized_gmod, fake_inputs, config, {}) + return key + + +class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable): + def install_subgraph_in_output_graph( + self, tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name="invoke_subgraph" + ): + # Check if the subgraph from speculate_subgraph (body_gmod) and the fake + # inputs have already been seen before. If yes, the subgraph is already + # installed in the output graph and we can just access the subgraph + # using the saved attr name. + from torch._higher_order_ops.utils import has_potential_input_alias_or_mutation + + fake_inputs = [arg.as_proxy().node.meta["example_value"] for arg in fn_args_vt] + + # TODO(anijain2305) - This might be too big of a limitation. Consider + # supporting mutation/aliasing in HOP itself to remove this restriction. + if has_potential_input_alias_or_mutation(body_gmod, fake_inputs): + unimplemented("NYI: invoke_subgraph with aliasing/mutation") + + key = hash_graph_and_inputs(tx, body_gmod, fake_inputs) + + invoke_subgraph_cache = ( + tx.output.tracing_context.hop_dispatch_set_cache.get_cache( + torch._higher_order_ops.invoke_subgraph + ) + ) + + if invoke_subgraph_cache: + if identifier := invoke_subgraph_cache.get_dynamo_identifier(key): + return identifier + + body_name = super().install_subgraph_in_output_graph( + tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name + ) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_dynamo_identifier(key, body_name) + + return body_name + + def call_function( + self, + tx: "InstructionTranslator", + args: "List[VariableTracker]", + kwargs: "Dict[str, VariableTracker]", + ) -> "VariableTracker": + # This flattens the kwargs into lifted args + ( + p_args, + p_kwargs, + example_value, + body_r, + treespec, + body_gmod, + body_name, + ) = self.create_wrapped_node( + tx, args[0], args[2].items, kwargs, "invoke_subgraph" + ) + + if len(p_kwargs) > 0: + unimplemented("kwargs should have been flattened into lifted args") + + flat_example_value = pytree.tree_map_only( + torch.fx.Proxy, + lambda a: a.node.meta["example_value"], + body_r.as_proxy(), + ) + + p_args = ( + p_args[0], + body_name, + p_args[1:], + ) + return _call_function_and_unflatten_output( + tx, self.value, tuple(p_args), p_kwargs, flat_example_value, treespec + ) diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 471d72e90535a..aee2d89488bd7 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -172,10 +172,8 @@ def retrieve_const_key(key): *args, mutable_local=MutableLocal() ) - from .builder import SourcelessBuilder - return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.repeat), args, kwargs + VariableTracker.build(tx, polyfills.repeat), args, kwargs ) elif self.value is itertools.count: return variables.CountIteratorVariable(*args, mutable_local=MutableLocal()) diff --git a/torch/_dynamo/variables/lazy.py b/torch/_dynamo/variables/lazy.py index 23866616249f9..f2f32bb15de2b 100644 --- a/torch/_dynamo/variables/lazy.py +++ b/torch/_dynamo/variables/lazy.py @@ -20,14 +20,15 @@ def __init__(self, value: Any, source: Any) -> None: def realize(self) -> None: assert self.vt is None from ..symbolic_convert import InstructionTranslator - from .builder import SourcelessBuilder, VariableBuilder tx = InstructionTranslator.current_tx() + if isinstance(self.value, LazySymNodeFormatString): - self.vt = SourcelessBuilder.create(tx, self.value) + source = None else: - self.vt = VariableBuilder(tx, self.source)(self.value) + source = self.source + self.vt = VariableTracker.build(tx, self.value, source) del self.value del self.source @@ -37,7 +38,7 @@ class LazyVariableTracker(VariableTracker): A structure that defers the creation of the actual VariableTracker for a given underlying value until it is accessed. - The `realize` function invokes VariableBuilder to produce the real object. + The `realize` function invokes VariableTracker.build() to produce the real object. Once a LazyVariableTracker has been realized, internal bookkeeping will prevent double realization. @@ -80,17 +81,25 @@ def clone(self, **kwargs: Any) -> VariableTracker: self.realize() return VariableTracker.clone(self.unwrap(), **kwargs) + def peek_type(self) -> type[Any]: + assert not self.is_realized() + return type(self._cache.value) + + def peek_value(self) -> Any: + assert not self.is_realized() + return self._cache.value + def __str__(self) -> str: if self.is_realized(): - return self.unwrap().__str__() - return VariableTracker.__str__(self.unwrap()) + return repr(self.unwrap()) + return super().__repr__() def __getattr__(self, item: str) -> Any: return getattr(self.realize(), item) # most methods are auto-generated below, these are the ones we want to exclude visit = VariableTracker.visit # type: ignore[assignment] - __repr__ = VariableTracker.__repr__ + __repr__ = __str__ @classmethod def realize_all( @@ -144,7 +153,7 @@ def __init__( "{:" + fmt_spec_var.as_python_constant() + "}" ) - def __str__(self) -> str: + def __repr__(self) -> str: return str.format( self.fmt_var.as_python_constant(), str(self.sym_node_var.evaluate_expr()), diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 7cbeff07953d3..44f053452ad81 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -33,6 +33,7 @@ if TYPE_CHECKING: + from torch._dynamo.codegen import PyCodegen from torch._dynamo.symbolic_convert import InstructionTranslator @@ -134,10 +135,8 @@ def call_method( assert not kwargs return iter_contains(self.unpack_var_sequence(tx), args[0], tx) elif name == "index": - from .builder import SourcelessBuilder - return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.index), + VariableTracker.build(tx, polyfills.index), [self] + list(args), kwargs, ) @@ -296,7 +295,7 @@ def as_proxy(self): def unpack_var_sequence(self, tx=None): return [variables.ConstantVariable.create(x) for x in self.as_python_constant()] - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: assert "range" not in codegen.tx.f_globals codegen.add_push_null( lambda: codegen.append_output(codegen.create_load_python_module(range)) @@ -402,7 +401,7 @@ def __repr__(self) -> str: def debug_repr(self): return self.debug_repr_helper("[", "]") - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach(self.items) codegen.append_output(create_instruction("BUILD_LIST", arg=len(self.items))) @@ -459,7 +458,7 @@ def python_type(self): def debug_repr(self): return self.debug_repr_helper("deque([", "])") - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: assert "deque" not in codegen.tx.f_globals codegen.add_push_null( lambda: codegen.append_output( @@ -534,7 +533,7 @@ def __repr__(self) -> str: def debug_repr(self): return self.debug_repr_helper("(", ")") - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach(self.items) codegen.append_output(create_instruction("BUILD_TUPLE", arg=len(self.items))) @@ -633,7 +632,7 @@ def as_proxy(self): ) return proxy - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen.load_import_from("torch", "Size")) codegen.foreach(self.items) build_torch_size = [ @@ -717,21 +716,45 @@ def __init__(self, items, tuple_cls, **kwargs) -> None: super().__init__(items, **kwargs) self.tuple_cls = tuple_cls + def is_namedtuple(self): + return hasattr(self.tuple_cls, "_fields") and callable( + getattr(self.tuple_cls, "_make", None) + ) + + def is_structseq(self): + return not self.is_namedtuple() + def debug_repr(self): + if self.is_structseq(): + # StructSequenceType(iterable) + return repr(self.tuple_cls([Lit(x.debug_repr()) for x in self.items])) + # NamedTupleType(*iterable) return repr(self.tuple_cls(*(Lit(x.debug_repr()) for x in self.items))) def python_type(self): return self.tuple_cls def as_python_constant(self): + if self.is_structseq(): + # StructSequenceType(iterable) + return self.python_type()([x.as_python_constant() for x in self.items]) + # NamedTupleType(*iterable) return self.python_type()(*[x.as_python_constant() for x in self.items]) def as_proxy(self): assert self.python_type() is not SizeVariable + if self.is_structseq(): + # StructSequenceType(iterable) + return self.python_type()(self._as_proxy()) + # NamedTupleType(*iterable) return self.python_type()(*self._as_proxy()) - def reconstruct(self, codegen): - create_fn = getattr(self.tuple_cls, "_make", self.tuple_cls) + def reconstruct(self, codegen: "PyCodegen") -> None: + # Constructors: + # StructSequenceType(iterable) + # NamedTupleType(*iterable) + # NamedTupleType._make(iterable) + create_fn = self.tuple_cls if self.is_structseq() else self.tuple_cls._make codegen.add_push_null( lambda: codegen.append_output(codegen._create_load_const(create_fn)) ) @@ -804,7 +827,7 @@ def python_type(self): def as_python_constant(self): return slice(*[guard_if_dyn(x) for x in self.items]) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach(self.items) codegen.append_output(create_instruction("BUILD_SLICE", arg=len(self.items))) @@ -872,7 +895,7 @@ def unpack_var_sequence(self, tx): def force_unpack_var_sequence(self, tx) -> List[VariableTracker]: return self.unpack_var_sequence(tx) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: remaining_items = self.items[self.index :] codegen.foreach(remaining_items) codegen.extend_output( @@ -977,7 +1000,7 @@ def modified(self, items, **kwargs): **kwargs, ) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: codegen(self.user_cls_source)) super().reconstruct(codegen) codegen.extend_output(create_call_function(1, False)) diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 54b5734835199..32a09df5961c3 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -207,12 +207,10 @@ def call_method( and len(kwargs) == 0 and args[0].is_python_constant() ): - from .builder import VariableBuilder - key = args[0].as_python_constant() - return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))( - collections.OrderedDict.__getitem__(self.objvar.value, key) - ) + value = collections.OrderedDict.__getitem__(self.objvar.value, key) + source = ODictGetItemSource(self.objvar.source, key) + return VariableTracker.build(tx, value, source) elif inner_fn in ( collections.OrderedDict.__setitem__, object.__setattr__, @@ -349,21 +347,6 @@ def reconstruct(self, codegen): codegen.append_output(codegen.create_load_closure(self.name)) -# closure variable created by an inlined function -class InlinedClosureVariable(UnknownVariable): - _nonvar_fields = { - "name", - *UnknownVariable._nonvar_fields, - } - - def __init__(self, name, **kwargs) -> None: - super().__init__(**kwargs) - self.name = name - - def reconstruct(self, codegen): - codegen.append_output(codegen.create_load_closure(self.name)) - - class NewCellVariable(VariableTracker): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @@ -482,15 +465,10 @@ def __init__(self, value, **kwargs) -> None: self.value = value def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": - from .builder import SourcelessBuilder, VariableBuilder - try: attr_value = getattr(self.value, name) - if self.source: - attr_source = AttrSource(self.source, name) - return VariableBuilder(tx, attr_source)(attr_value) - else: - return SourcelessBuilder.create(tx, attr_value) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, attr_value, source) except AttributeError: unimplemented(f"getattr({self.value}, {name})") @@ -727,6 +705,9 @@ def visit(node): ctx = AutogradFunctionContextVariable.create(tx, args, kwargs) args = [ctx, *args] if isinstance(fn, types.FunctionType): + sig = inspect.signature(fn) + if len(args) - 1 == len(sig._parameters): + args = args[1:] # Don't use context return variables.UserFunctionVariable(fn, source=source).call_function( tx, args, kwargs ) @@ -921,11 +902,9 @@ def var_getattr(self, tx: "InstructionTranslator", name): if self.needs_input_grad is not None: return variables.ConstantVariable.create(self.needs_input_grad) if self.source: - from .builder import VariableBuilder + source = AttrSource(self.source, "needs_input_grad") + return VariableTracker.build(tx, self.value.needs_input_grad, source) - return VariableBuilder(tx, AttrSource(self.source, "needs_input_grad"))( - self.value.needs_input_grad - ) return super().var_getattr(tx, name) @@ -950,7 +929,7 @@ def call_method( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": if name == "queue_callback": - if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): assert ( tx.one_graph ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" @@ -997,7 +976,7 @@ def __init__(self, obj, name, **kwargs) -> None: self.obj = obj self.name = name - def __str__(self) -> str: + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.obj}, {self.name})" @staticmethod @@ -1130,11 +1109,8 @@ def __init__(self, desc, **kwargs) -> None: def var_getattr(self, tx: "InstructionTranslator", name): if name == "__get__" and self.source: - from .builder import VariableBuilder - - return VariableBuilder(tx, AttrSource(self.source, "__get__"))( - self.desc.__get__ - ) + source = AttrSource(self.source, "__get__") + return VariableTracker.build(tx, self.desc.__get__, source) else: return super().var_getattr(tx, name) @@ -1174,18 +1150,13 @@ def var_getattr(self, tx: "InstructionTranslator", name): if tx.output.side_effects.has_pending_mutation_of_attr(self, name): return tx.output.side_effects.load_attr(self, name) - from .builder import SourcelessBuilder, VariableBuilder - if self.is_torch or name not in self.value.__dict__: attr_value = getattr(self.value, name) else: attr_value = self.value.__dict__[name] - if self.source: - new_source = AttrSource(self.source, name) - return VariableBuilder(tx, new_source)(attr_value) - else: - return SourcelessBuilder.create(tx, attr_value) + source = self.source and AttrSource(self.source, name) + return VariableTracker.build(tx, attr_value, source) class TypingVariable(VariableTracker): @@ -1206,6 +1177,19 @@ def call_method( ) unimplemented("typing") + def var_getattr(self, tx: "InstructionTranslator", name: str): + from .builder import SourcelessBuilder, VariableBuilder + + if tx.output.side_effects.has_pending_mutation_of_attr(self, name): + return tx.side_effects.load_attr(self, name) + + value = getattr(self.value, name) + if self.source: + attr_source = AttrSource(self.source, name) + return VariableBuilder(tx, attr_source)(value) + else: + return SourcelessBuilder(tx, value) + def as_python_constant(self): return self.value @@ -1341,7 +1325,7 @@ class NullVariable(VariableTracker): def __init__(self, **kwargs) -> None: super().__init__(**kwargs) - def __str__(self) -> str: + def __repr__(self) -> str: return "NullVariable" def reconstruct(self, codegen): diff --git a/torch/_dynamo/variables/nn_module.py b/torch/_dynamo/variables/nn_module.py index f3ddcd80cf551..08c036949a999 100644 --- a/torch/_dynamo/variables/nn_module.py +++ b/torch/_dynamo/variables/nn_module.py @@ -39,6 +39,8 @@ object_has_getattribute, proxy_args_kwargs, set_example_value, + unpatched_nn_module_call, + unpatched_nn_module_call_impl, ) from .base import MutableLocal, typestr, VariableTracker from .functions import invoke_and_store_as_constant @@ -82,8 +84,11 @@ def convert_to_fake(x): @contextmanager def record_nn_module_stack(module_key: str, source, tx, mod: torch.nn.Module): fully_qualified_name = source.name() + num_calls = tx.num_calls.get(fully_qualified_name, 0) + module_key = f"{module_key}@{num_calls}" if num_calls > 0 else module_key try: tx.nn_module_stack[module_key] = (fully_qualified_name, mod.__class__) + tx.num_calls[fully_qualified_name] = num_calls + 1 yield finally: del tx.nn_module_stack[module_key] @@ -241,12 +246,7 @@ def _custom_getattr_fallback(self, base, tx, name, options): ) def var_getattr(self, tx: "InstructionTranslator", name): - from .builder import VariableBuilder - - if self.source: - source = AttrSource(self.source, name) - else: - source = None + source = self.source and AttrSource(self.source, name) base = tx.output.get_submodule(self.module_key) base_dict = object.__getattribute__(base, "__dict__") @@ -294,7 +294,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): return variables.UserDefinedClassVariable(base.__class__, source=source) if object_member: - out = VariableBuilder(tx, NNModuleSource(source))(subobj) + out = VariableTracker.build(tx, subobj, NNModuleSource(source)) if isinstance(out, (NNModuleVariable, UnspecializedNNModuleVariable)): # nn_module_stack source is BC surface area. Ensure that @@ -330,7 +330,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): return variables.UserMethodVariable(subobj, self, source=source) elif is_safe_constant(subobj) or istensor(subobj): # Support possibly common cases of class members - return VariableBuilder(tx, NNModuleSource(source))(subobj) + return VariableTracker.build(tx, subobj, NNModuleSource(source)) else: unimplemented( f"class property {name} - {typestr(base)} {typestr(subobj)}" @@ -859,12 +859,26 @@ def call_function( if mod.cls_to_become is not None: self.value_type = mod.cls_to_become initialize_lazy_module(tx, mod, args, kwargs) - name = "_call_impl" - fn = getattr(self.value_type, name) + + if ( + not isinstance(mod, torch.fx.GraphModule) + and mod.__call__.__func__ is not unpatched_nn_module_call + ): + name = "__call__" + fn = getattr(self.value_type, name) + else: + name = "_call_impl" + fn = getattr(self.value_type, name) # Check if we can short circuit nn.Module._call_impl to the forward # method. NB - This is done to reduce the compile time of Dynamo. - if fn is torch.nn.Module._call_impl and "forward" not in mod.__dict__: + if ( + istype(mod.__call__, types.MethodType) + and istype(mod._call_impl, types.MethodType) + and mod.__call__.__func__ is unpatched_nn_module_call + and mod._call_impl.__func__ is unpatched_nn_module_call_impl + and "forward" not in mod.__dict__ + ): forward_method = inspect.getattr_static(mod, "forward") if isinstance(forward_method, types.FunctionType): globals_vt = tx.nn_modules_globals_vt @@ -1080,7 +1094,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): ) return variables.ConstDictVariable({}) - # For non-empty hook dicts, one way is to just fallback to VariableBuilder and create a ConstDictVariable. + # For non-empty hook dicts, one way is to just fallback to VariableTracker.build() and create a ConstDictVariable. # However, ConstDictVariable guards on keys. This can cause recompiles when the same hook is installed for # differnt nn module instances, because the key keeps changing (look more into RemovableHandle to understand why # key changes - also related https://github.com/pytorch/pytorch/issues/125836). Here, we carefully craft a diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 8bed74fb21613..3d9432e2bd32a 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -1,9 +1,11 @@ # mypy: ignore-errors +import logging import weakref from typing import Dict, List, TYPE_CHECKING import torch +from torch._logging import getArtifactLogger from torch.utils._pytree import tree_map_only from ..guards import GuardBuilder, install_guard @@ -15,6 +17,7 @@ GradSource, ) from ..utils import GLOBAL_KEY_PREFIX +from .base import VariableTracker from .constant import ConstantVariable from .dicts import ConstDictVariable from .lists import ListVariable @@ -25,8 +28,6 @@ if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator - from .base import VariableTracker - class ArgMappingException(Exception): pass @@ -36,6 +37,27 @@ class GuardInstallException(Exception): pass +perf_hint_log = getArtifactLogger(__name__, "perf_hints") + + +def _is_static_for_cudagraphs(x): + from torch._inductor.cudagraph_trees import get_manager + + if x.is_cuda: + manager = get_manager(x.device.index, False) + is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None + if manager: + return ( + is_static_address + or manager.current_node._is_cuda_graph_recorded_tensor(x) + ) + else: + return is_static_address + else: + # Don't print a warning for non-cuda tensors + return True + + class OptimizerVariable(UserDefinedObjectVariable): _nonvar_fields = { "grad_to_source", @@ -124,7 +146,6 @@ def graph_break_if_pending_mutation(self, tx): def _set_capturable(self, tx): from . import LazyVariableTracker - from .builder import VariableBuilder # We only set capturable if params are on cuda # and the state is not initialized @@ -145,10 +166,9 @@ def safe_to_set_capturable(group): if safe_to_set_capturable(group): group["capturable"] = True + source = self.source and AttrSource(self.source, "param_groups") param_groups_vt = LazyVariableTracker.realize_all( - VariableBuilder(tx, AttrSource(self.source, "param_groups"))( - self.value.param_groups - ) + VariableTracker.build(tx, self.value.param_groups, source) ) for ind, param_group_vt in enumerate(param_groups_vt.items): key = ConstDictVariable._HashableTracker( @@ -191,7 +211,6 @@ def move_step_if_cpu(self): def map_sources_and_install_guards(self, tx): from ..decorators import mark_static_address - from .builder import VariableBuilder from .lazy import LazyVariableTracker self.grad_to_source = {} @@ -212,15 +231,13 @@ def mark_static(x): # Recursively realize the variable trackers for optim.state and # optim.param_groups, which recursively install the necessary guards. + params_groups_source = self.source and AttrSource(self.source, "param_groups") param_groups_vt = LazyVariableTracker.realize_all( - VariableBuilder(tx, AttrSource(self.source, "param_groups"))( - self.value.param_groups - ) + VariableTracker.build(tx, self.value.param_groups, params_groups_source) ) - state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))( - self.value.state - ) + state_source = self.source and AttrSource(self.source, "state") + state_vt = VariableTracker.build(tx, self.value.state, state_source) # We need to realize the top level state dict to populate # the guard locals @@ -242,20 +259,22 @@ def mark_static(x): key_index = i break if key_index: - state_source = AttrSource(self.source, "state") LazyVariableTracker.realize_all( - VariableBuilder( + VariableTracker.build( tx, + self.value.state[param], GetItemSource( state_source, ConstDictKeySource(state_source, key_index), ), - )(self.value.state[param]) + ) ) break group_source = group_vt.source params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params")) + all_static = True + non_static_grads = [] for p_ind, (p, p_vt) in enumerate( zip(group["params"], params_vt.unpack_var_sequence(tx)) ): @@ -268,12 +287,25 @@ def mark_static(x): if p.grad is not None: self.grad_to_source[p.grad] = grad_source + if not _is_static_for_cudagraphs(p.grad): + all_static = False + non_static_grads.append(grad_source) else: install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH)) + if not all_static and perf_hint_log.isEnabledFor(logging.WARNING): + non_static_grads = [src.name() for src in non_static_grads] + perf_hint_log.warning( + ( + "Grad tensors %s will be copied during cudagraphs execution." + "If using cudagraphs and the grad tensor addresses will be the same across runs," + " use torch._dynamo.decorators.mark_static_address to elide this copy.", + ), + non_static_grads, + ) + # We have to again iterate over the state dict to collect the # tensor_to_source dict. This is used for the finalizer. - state_source = AttrSource(self.source, "state") for idx, (p, value) in enumerate(self.value.state.items()): p_state_source = GetItemSource( state_source, ConstDictKeySource(state_source, idx) @@ -289,7 +321,6 @@ def mark_static(x): def wrap_tensor(self, tx: "InstructionTranslator", tensor_value): """Wrap state tensor in a TensorVariable""" from ..decorators import mark_static_address - from .builder import VariableBuilder # If we have a source for a tensor already use it, # if we have not seen a tensor before, stash and use a @@ -299,20 +330,19 @@ def wrap_tensor(self, tx: "InstructionTranslator", tensor_value): if tensor_value in self.tensor_to_source: # mark these tensors as static for cudagraphs mark_static_address(tensor_value) - builder = VariableBuilder(tx, self.tensor_to_source[tensor_value]) - self.static_tensor_names.add(tx.output.module_key_name(builder.name)) + source = self.tensor_to_source[tensor_value] + self.static_tensor_names.add(tx.output.module_key_name(source.name)) elif tensor_value in self.grad_to_source: - builder = VariableBuilder(tx, self.grad_to_source[tensor_value]) + source = self.grad_to_source[tensor_value] else: # mark these tensors as static for cudagraphs mark_static_address(tensor_value) global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value) - builder = VariableBuilder(tx, GlobalWeakRefSource(global_name)) - self.static_tensor_names.add(tx.output.module_key_name(builder.name)) + source = GlobalWeakRefSource(global_name) + self.static_tensor_names.add(tx.output.module_key_name(source.name)) - result = builder(tensor_value) - return result + return VariableTracker.build(tx, tensor_value, source) def update_list_args( self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs @@ -328,14 +358,8 @@ def update_list_args( if isinstance(val, torch.Tensor): arg.items.append(self.wrap_tensor(tx, val)) else: - from .builder import SourcelessBuilder, VariableBuilder - - if arg.source: - arg.items.append( - VariableBuilder(tx, GetItemSource(arg.source, i))(val) - ) - else: - arg.items.append(SourcelessBuilder.create(tx, val)) + source = arg.source and GetItemSource(arg.source, i) + arg.items.append(VariableTracker.build(tx, val, source)) def create_finalizer(self, tx): names_to_delete = self.static_tensor_names diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 611450ae6cf9a..51c1ea6bf141d 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -5,12 +5,15 @@ from ..bytecode_transformation import create_call_function from ..exc import Unsupported +from ..source import AttrSource from .base import VariableTracker if TYPE_CHECKING: from torch._dynamo.symbolic_convert import InstructionTranslator +PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split() + class SDPAParamsVariable(VariableTracker): """Represents the c++ params struct for scaled dot product attention. @@ -20,35 +23,13 @@ class SDPAParamsVariable(VariableTracker): def create(tx: "InstructionTranslator", value, source): from torch.backends.cuda import SDPAParams - from ..source import AttrSource - from .builder import VariableBuilder from .torch import TorchInGraphFunctionVariable - query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query) - key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key) - value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value) - attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))( - value.attn_mask - ) - dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout) - is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))( - value.is_causal - ) - enable_gqa_var = VariableBuilder(tx, AttrSource(source, "enable_gqa"))( - value.enable_gqa - ) - param_vars = [ - query_var, - key_var, - value_var, - attn_mask_var, - dropout_var, - is_causal_var, - enable_gqa_var, + params = [ + VariableTracker.build(tx, getattr(value, p), AttrSource(source, p)) + for p in PARAM_NAMES ] - return TorchInGraphFunctionVariable(SDPAParams).call_function( - tx, param_vars, {} - ) + return TorchInGraphFunctionVariable(SDPAParams).call_function(tx, params, {}) def __init__(self, proxy, param_vars, **kwargs) -> None: self.proxy = proxy @@ -70,7 +51,6 @@ def as_proxy(self): def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker: import torch._C - from ..source import AttrSource from .builder import wrap_fx_proxy from .misc import GetAttrVariable diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 514a712c89165..8c55e7d5a6cc6 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -103,6 +103,7 @@ class TensorVariable(VariableTracker): "requires_grad", "is_quantized", "is_contiguous", + "is_nested", "is_sparse", "class_type", "specialized_value", @@ -128,6 +129,7 @@ def __init__( layout, ndim, requires_grad, + is_nested, is_quantized, is_sparse, class_type, @@ -149,6 +151,7 @@ def __init__( self.requires_grad = requires_grad self.is_quantized = is_quantized self.is_contiguous = is_contiguous + self.is_nested = is_nested self.is_sparse = is_sparse self.class_type = class_type self.has_grad_fn = has_grad_fn @@ -175,6 +178,7 @@ def specialize(value: torch.Tensor): "layout": value.layout, "ndim": int(value.ndim), "requires_grad": value.requires_grad, + "is_nested": value.is_nested, "is_quantized": value.is_quantized, "is_sparse": value.is_sparse, "class_type": type(value), @@ -238,9 +242,7 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name): # any other attributes on the subclass (that are not methods) # are assumed to be constant metadata. elif not callable(example_value): - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, example_value) + return VariableTracker.build(tx, example_value) if not (self.source and self.source.subguards_allowed()): raise NotImplementedError @@ -277,12 +279,9 @@ def dynamic_getattr(self, tx: "InstructionTranslator", name): # Note - at a certain point we may want to handle raise NotImplementedError - from ..guards import GuardBuilder - from .builder import VariableBuilder - attr_source = AttrSource(self.source, name) install_guard(attr_source.make_guard(GuardBuilder.HASATTR)) - return VariableBuilder(tx, attr_source)(real_value) + return VariableTracker.build(tx, real_value, attr_source) def method_attr_ndim(self, tx): if self.ndim is not None: @@ -325,6 +324,10 @@ def method_attr_is_sparse(self, tx): if self.is_sparse is not None: return ConstantVariable.create(self.is_sparse) + def method_attr_is_nested(self, tx): + if self.is_nested is not None: + return ConstantVariable.create(self.is_nested) + def method_attr_data(self, tx): return variables.TorchInGraphFunctionVariable( torch._C._autograd._get_data_attr @@ -510,9 +513,37 @@ def call_method( args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": + from .builder import SourcelessBuilder, VariableBuilder + from .torch_function import can_dispatch_torch_function, dispatch_torch_function + if self.is_strict_mode(tx) and name in self._strict_mode_banned_ops(): unimplemented(f"Illegal method invocation {name} in strict mode") + # Only override builtin tensor methods + # The user can manually add override handling + # with a decorator for other methods (e.g. a dispatch subclass with other methods) + has_torch_function_override = False + try: + inspect.getattr_static(torch.Tensor, name) + has_torch_function_override = True + except AttributeError: + has_torch_function_override = False + + if ( + can_dispatch_torch_function(tx, tuple([self] + list(args)), kwargs) + and has_torch_function_override + ): + if self.source: + func_var = VariableBuilder( + tx, AttrSource(AttrSource(self.source, "__class__"), name) + )(inspect.getattr_static(torch.Tensor, name)) + else: + func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name)) + + return dispatch_torch_function( + tx, func_var, tuple([self] + list(args)), kwargs + ) + """ Dispatch to a method-specific handler defined below. If the handler returns None (or doesn't exist) we put the method call @@ -667,7 +698,6 @@ def method_type(self, dtype=None, non_blocking=False, **kwargs): def method_as_subclass(self, cls): if isinstance(cls, TensorSubclassVariable) and cls.source: from ..symbolic_convert import InstructionTranslator - from .builder import VariableBuilder from .torch_function import TensorWithTFOverrideVariable tx = InstructionTranslator.current_tx() @@ -677,10 +707,11 @@ def method_as_subclass(self, cls): # defines a constructor, but if only a __torch_function__ impl is defined, this is okay to call. # It is up to the user whether this is correct behavior or not. py_cls = cls.as_python_constant() - torch_fn = VariableBuilder( + torch_fn = VariableTracker.build( tx, + py_cls.__torch_function__.__func__, AttrSource(AttrSource(cls.source, "__torch_function__"), "__func__"), - )(py_cls.__torch_function__.__func__) + ) return TensorWithTFOverrideVariable.from_tensor_var( tx, self, py_cls, torch_fn @@ -722,7 +753,6 @@ def method_numpy(self, *, force=False): def method_tolist(self): from ..symbolic_convert import InstructionTranslator - from .builder import SourcelessBuilder tx = InstructionTranslator.current_tx() @@ -759,20 +789,20 @@ def wrap(i, sub_proxy): tensor = self.as_proxy().node.meta["example_value"] out = tolist(tensor, self.as_proxy()) - return SourcelessBuilder.create(tx, out) + return VariableTracker.build(tx, out) def method_backward(self, *args, **kwargs): unimplemented("Tensor.backward") def method_data_ptr(self, *args, **kwargs): - unimplemented("Tensor.data_ptr") + return DataPtrVariable(self) def method_item(self, *args, **kwargs): if not config.capture_scalar_outputs: self._warn_capture_scalar_outputs() unimplemented("Tensor.item") - def method_getitem(self, *args, **kwargs): + def method___getitem__(self, *args, **kwargs): from ..symbolic_convert import InstructionTranslator from .builder import wrap_fx_proxy @@ -829,10 +859,9 @@ def method_addcmul_(self, tensor1, tensor2, *, value=None): tx = InstructionTranslator.current_tx() if value is not None: from .. import polyfills - from .builder import SourcelessBuilder return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.addcmul_inplace), + VariableTracker.build(tx, polyfills.addcmul_inplace), [self, tensor1, tensor2, value], {}, ) @@ -846,15 +875,6 @@ def has_bool_key(v): else: return False - if ( - has_bool_key(key) - and isinstance(value, TensorVariable) - and value.requires_grad - and torch.is_grad_enabled() - ): - unimplemented( - "boolean masking setitem backwards, see https://github.com/pytorch/pytorch/issues/114123" - ) from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() @@ -995,7 +1015,7 @@ def _method_register_hook(self, name: str, hook: VariableTracker): tx = InstructionTranslator.current_tx() if not self.source: - if not compiled_autograd.compiled_autograd_enabled: + if not compiled_autograd.enabled(): # TODO(voz): # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary # python state. @@ -1134,13 +1154,11 @@ def python_type(self): def as_proxy(self): return self.proxy - def as_tensor(self, tx): + def as_tensor(self, tx, dtype): if self._tensor_var is None: - from .builder import SourcelessBuilder - - self._tensor_var = SourcelessBuilder.create( + self._tensor_var = VariableTracker.build( tx, torch.scalar_tensor - ).call_function(tx, [self], {}) + ).call_function(tx, [self], {"dtype": VariableTracker.build(tx, dtype)}) return self._tensor_var def evaluate_expr(self, output_graph=None): @@ -1343,12 +1361,10 @@ def call_function( kwargs: Dict[str, VariableTracker], ) -> VariableTracker: if len(args) == 1 and isinstance(args[0], TensorVariable): - from .builder import VariableBuilder from .torch_function import TensorWithTFOverrideVariable - torch_fn = VariableBuilder( - tx, AttrSource(self.source, "__torch_function__") - )(self.value.__torch_function__) + source = AttrSource(self.source, "__torch_function__") + torch_fn = VariableTracker.build(tx, self.value.__torch_function__, source) return TensorWithTFOverrideVariable.from_tensor_var( tx, args[0], self.value, torch_fn @@ -1420,3 +1436,18 @@ def reconstruct(self, codegen): codegen(self.from_tensor) codegen.load_method("untyped_storage") codegen.call_method(0) + + +class DataPtrVariable(VariableTracker): + def __init__( + self, + from_tensor: TensorVariable, + **kwargs, + ) -> None: + super().__init__(**kwargs), + self.from_tensor = from_tensor + + def reconstruct(self, codegen): + codegen(self.from_tensor, allow_cache=False) + codegen.load_method("data_ptr") + codegen.call_method(0) diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index 77d8d2fcf8c10..5ec7910b43cce 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -397,7 +397,7 @@ def _register(handler): TensorVariable, UserDefinedObjectVariable, ) - from .builder import SourcelessBuilder, wrap_fx_proxy, wrap_fx_proxy_cls + from .builder import wrap_fx_proxy, wrap_fx_proxy_cls @register(*tracing_state_functions) def handle_tracing_state_functions( @@ -422,14 +422,14 @@ def handle_get_default_nowrap_functions( # the set of functions that we trace __torch_function__ on to # functions outside of the actual set. Implementing this properly will require implementing # some variable types to track and compare tensor getset descriptors - return SourcelessBuilder.create( + return VariableTracker.build( tx, torch.overrides.get_default_nowrap_functions() ) @register(torch.ops.inductor.accumulate_grad_.default) def handle_accumulate_grad_(self, tx: "InstructionTranslator", *args, **kwargs): return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.accumulate_grad), args, kwargs + VariableTracker.build(tx, polyfills.accumulate_grad), args, kwargs ) @register(math.radians) @@ -437,7 +437,7 @@ def handle_radians(self, tx: "InstructionTranslator", *args, **kwargs): if not check_unspec_or_constant_args(args, kwargs): # Use polyfill to convert math.radians(x) into math.pi * x / 180.0 return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.radians), args, kwargs + VariableTracker.build(tx, polyfills.radians), args, kwargs ) @register(torch.is_tensor, torch.overrides.is_tensor_like) @@ -622,7 +622,7 @@ def handle_inplace_foreach_lerp_scalar( ): if len(args) == 3 and not isinstance(args[2], ListVariable) and not kwargs: return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.foreach_lerp_inplace), + VariableTracker.build(tx, polyfills.foreach_lerp_inplace), args, kwargs, ) @@ -635,7 +635,7 @@ def handle_foreach_pow_scalar( # in compile, it's more performant to not graph break. if len(args) == 2 and isinstance(args[0], TensorVariable) and not kwargs: return tx.inline_user_function_return( - SourcelessBuilder.create(tx, polyfills.foreach_pow_scalar), + VariableTracker.build(tx, polyfills.foreach_pow_scalar), args, kwargs, ) @@ -704,7 +704,7 @@ def handle_constant_processgroup_functions( # Note - while we *could* cook up sources around invocations, like a FunctionSource # the space of invoking functions in the middle of the guard chain is very iffy. As such, # guard propagation via options is the best we can do. - return SourcelessBuilder.create(tx, invocation_result) + return VariableTracker.build(tx, invocation_result) @register(DTensor.from_local) def handle_from_local(self, tx: "InstructionTranslator", *args, **kwargs): @@ -871,10 +871,6 @@ def handle_set_default_device( return ConstantVariable.create(None) - @register(torch._C.TensorBase.__getitem__) - def handle_getitem(self, tx: "InstructionTranslator", *args, **kwargs): - return args[0].call_method(tx, "getitem", args[1:], kwargs) - return handlers def call_function( @@ -904,6 +900,9 @@ def call_function( ), ) + if self.is_tensor_method(): + return self.call_tensor_method(tx, args, kwargs) + special_handler = self._get_handlers().get(self.value) if special_handler: result = special_handler(self, tx, *args, **kwargs) @@ -1144,8 +1143,6 @@ def call_nn_parameter(cls, tx, data=None, requires_grad=True): @staticmethod def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad): # Alternate version if we have a .source - from .builder import VariableBuilder - varname = tx.output.new_var() # construct the nn.Parmeter before the graph save it to varname @@ -1168,7 +1165,7 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad example_value = torch.nn.Parameter( tx.output.example_value_from_input_node(data.as_proxy().node) ) - result = VariableBuilder(tx, source)(example_value) + result = VariableTracker.build(tx, example_value, source) # No need to guard on this since we already guarded on `data`. # These guards would fail since varname doesn't exist until after the function starts TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( @@ -1176,6 +1173,16 @@ def _nn_param_via_prefix_insert(tx: "InstructionTranslator", data, requires_grad ) return result + def call_tensor_method(self, tx, args, kwargs): + return args[0].call_method(tx, self.get_function().__name__, args[1:], kwargs) + + def is_tensor_method(self): + return ( + inspect.ismethoddescriptor(self.get_function()) + and hasattr(self.get_function(), "__objclass__") + and self.get_function().__objclass__ == torch._C.TensorBase + ) or self.get_function() is torch.Tensor.__contains__ + def torch_function_override_enabled(self, tx, args, kwargs): return ( self.get_function() in get_overridable_functions() diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 3662f34804a01..b89fad0274799 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -442,7 +442,6 @@ def _flatten_vts(vts): from collections import deque from .dicts import ConstDictVariable - from .lazy import LazyVariableTracker from .lists import ListVariable vts = deque(vts) @@ -450,13 +449,17 @@ def _flatten_vts(vts): while vts: vt = vts.pop() - LazyVariableTracker.realize_all(vt) - if isinstance(vt, ListVariable): - vts.extend(vt.items) - elif isinstance(vt, ConstDictVariable): - vts.extend(vt.items.values()) - else: - output.append(vt) + + if not vt.is_realized() and vt.peek_type() in (dict, list, tuple): + vt.realize() + + if vt.is_realized(): + if isinstance(vt, ListVariable): + vts.extend(vt.items) + elif isinstance(vt, ConstDictVariable): + vts.extend(vt.items.values()) + + output.append(vt) return output @@ -471,12 +474,8 @@ def _get_subclass_type_var(tx: "InstructionTranslator", var): if isinstance(var, TensorWithTFOverrideVariable): return var.class_type_var(tx) elif isinstance(var, UserDefinedObjectVariable): - from .builder import SourcelessBuilder, VariableBuilder - - if var.source: - return VariableBuilder(tx, TypeSource(var.source))(var.python_type()) - else: - return SourcelessBuilder.create(tx, var.python_type()) + source = var.source and TypeSource(var.source) + return VariableTracker.build(tx, var.python_type(), source) def _is_attr_overidden(tx: "InstructionTranslator", var, name): @@ -495,16 +494,14 @@ def _is_attr_overidden(tx: "InstructionTranslator", var, name): def call_torch_function( tx, torch_function_type, torch_function_var, fn, types, args, kwargs ): - from .builder import SourcelessBuilder - # signature: # def __torch_function__(cls, func, types, args=(), kwargs=None): tf_args = ( torch_function_type, fn, types, - SourcelessBuilder.create(tx, tuple(args)), - SourcelessBuilder.create(tx, kwargs), + VariableTracker.build(tx, tuple(args)), + VariableTracker.build(tx, kwargs), ) return tx.inline_user_function_return(torch_function_var, tf_args, {}) @@ -512,20 +509,13 @@ def call_torch_function( def build_torch_function_fn(tx: "InstructionTranslator", value, source): from types import FunctionType - from .builder import SourcelessBuilder, VariableBuilder - func = value.__torch_function__.__func__ if not isinstance(func, FunctionType): unimplemented("Builtin/C++ torch function implementations NYI") - if source: - return VariableBuilder( - tx, - AttrSource(AttrSource(source, "__torch_function__"), "__func__"), - )(value.__torch_function__.__func__) - else: - return SourcelessBuilder.create(tx, value.__torch_function__.__func__) + source = source and AttrSource(AttrSource(source, "__torch_function__"), "__func__") + return VariableTracker.build(tx, func, source) def can_dispatch_torch_function(tx: "InstructionTranslator", args, kwargs): @@ -622,8 +612,6 @@ def var_getattr(self, tx: "InstructionTranslator", name): # base tensors, custom attribute accesses will graph break. import torch - from .builder import SourcelessBuilder - if name in banned_attrs: unimplemented( f"Accessing {name} on a tensor subclass with a __torch_function__ override is not supported" @@ -642,7 +630,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): GuardBuilder.FUNCTION_MATCH ) ) - get_fn = SourcelessBuilder.create(tx, getattr(torch.Tensor, name).__get__) + get_fn = VariableTracker.build(tx, getattr(torch.Tensor, name).__get__) return self.call_torch_function( tx, @@ -677,8 +665,6 @@ def call_method( if tx.output.torch_function_enabled: import torch - from .builder import SourcelessBuilder, VariableBuilder - if _is_attr_overidden(tx, self, name): unimplemented( f"Calling overridden method {name} on a tensor" @@ -690,11 +676,12 @@ def call_method( # We've established with the above check that the method is not overridden, so we guard that the method is the same # as the impl defined on tensor and retrieve it if self.source: - func_var = VariableBuilder( - tx, AttrSource(AttrSource(self.source, "__class__"), name) - )(inspect.getattr_static(self.python_type(), name)) + source = AttrSource(AttrSource(self.source, "__class__"), name) + value = inspect.getattr_static(self.python_type(), name) else: - func_var = SourcelessBuilder.create(tx, getattr(torch.Tensor, name)) + source = None + value = getattr(torch.Tensor, name) + func_var = VariableTracker.build(tx, value, source) return dispatch_torch_function(tx, func_var, [self] + args, kwargs) else: return super().call_method(tx, name, args, kwargs) diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 609c4872ea868..ff0c2bdf6a50c 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -13,6 +13,7 @@ import types import warnings from typing import Dict, Generic, List, TYPE_CHECKING +from typing_extensions import is_typeddict import torch._dynamo.config import torch.nn @@ -115,7 +116,7 @@ def as_python_constant(self): def as_proxy(self): return self.value - def __str__(self) -> str: + def __repr__(self) -> str: return f"UserDefinedClassVariable({self.value})" @staticmethod @@ -158,7 +159,6 @@ def has_key_in_generic_dict(self, tx: "InstructionTranslator", key): def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": from . import ConstantVariable, EnumVariable - from .builder import SourcelessBuilder, VariableBuilder source = AttrSource(self.source, name) if self.source is not None else None @@ -187,11 +187,7 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke obj = None if isinstance(obj, staticmethod): - func = obj.__get__(self.value) - if source is not None: - return VariableBuilder(tx, source)(func) - else: - return SourcelessBuilder.create(tx, func) + return VariableTracker.build(tx, obj.__get__(self.value), source) elif isinstance(obj, classmethod): if isinstance(obj.__func__, property): return variables.UserFunctionVariable(obj.__func__.fget).call_function( @@ -202,16 +198,13 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke # e.g.: inspect.getattr_static(dict, "fromkeys") # inspect.getattr_static(itertools.chain, "from_iterable") func = obj.__get__(None, self.value) - if source is not None: - return VariableBuilder(tx, source)(func) - else: - return SourcelessBuilder.create(tx, func) + return VariableTracker.build(tx, func, source) elif source: # __mro__ is a member in < 3.12, an attribute in >= 3.12 if inspect.ismemberdescriptor(obj) or ( sys.version_info >= (3, 12) and name == "__mro__" ): - return VariableBuilder(tx, source)(obj.__get__(self.value)) + return VariableTracker.build(tx, obj.__get__(self.value), source) if ConstantVariable.is_literal(obj): return ConstantVariable.create(obj) @@ -222,14 +215,15 @@ def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracke or self.value.__module__ == "torch" ): if source: - return VariableBuilder(tx, source)(obj) + return VariableTracker.build(tx, obj, source) if ( source and not inspect.ismethoddescriptor(obj) and not is_wrapper_or_member_descriptor(obj) ): - return VariableBuilder(tx, source)(obj) + return VariableTracker.build(tx, obj, source) + return super().var_getattr(tx, name) def _call_cross_entropy_loss(self, tx: "InstructionTranslator", args, kwargs): @@ -341,7 +335,7 @@ def call_function( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from ..side_effects import SideEffects - from .builder import SourcelessBuilder, wrap_fx_proxy + from .builder import wrap_fx_proxy from .builtin import BuiltinVariable constant_args = check_constant_args(args, kwargs) @@ -376,6 +370,10 @@ def call_function( args[0], mutable_local=MutableLocal(), ) + elif is_typeddict(self.value): + if self.value.__optional_keys__: + unimplemented("TypedDict with optional keys not supported") + return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs) elif self.value is collections.deque and not kwargs: if len(args) == 0: items = [] @@ -436,35 +434,35 @@ def call_function( fields = namedtuple_fields(self.value) # check if this a quasi-namedtuple or a real one if self.value.__module__ == "torch.return_types": - # create pseudo-defaults from values of the quasi-namedtuple - field_defaults = dict(zip(fields, args[0].items)) + assert len(args) == 1 + assert not kwargs + items = args[0].force_unpack_var_sequence(tx) else: field_defaults = self.value._field_defaults - items = list(args) - items.extend([None] * (len(fields) - len(items))) + items = list(args) + items.extend([None] * (len(fields) - len(items))) - var_tracker_kwargs = {} - for field_name, var_tracker in zip(fields, items): - if var_tracker is None: - if field_name in kwargs: - field_var = kwargs[field_name] - else: - assert field_name in field_defaults - field_var = SourcelessBuilder.create( - tx, field_defaults[field_name] - ) - var_tracker_kwargs[field_name] = field_var + var_tracker_kwargs = {} + for field_name, var_tracker in zip(fields, items): + if var_tracker is None: + if field_name in kwargs: + field_var = kwargs[field_name] + else: + assert field_name in field_defaults + field_var = VariableTracker.build( + tx, field_defaults[field_name] + ) + var_tracker_kwargs[field_name] = field_var + + for name, value in var_tracker_kwargs.items(): + assert name in fields + items[fields.index(name)] = value - for name, value in var_tracker_kwargs.items(): - assert name in fields - items[fields.index(name)] = value + assert all(x is not None for x in items) - assert all(x is not None for x in items) return variables.NamedTupleVariable(items, self.value) elif is_frozen_dataclass(self.value) and self.is_standard_new(): - from .builder import SourcelessBuilder - fields = dataclasses.fields(self.value) items = list(args) items.extend([None] * (len(fields) - len(items))) @@ -479,9 +477,9 @@ def call_function( continue if field.default is not dataclasses.MISSING: - var_tracker = SourcelessBuilder.create(tx, field.default) + var_tracker = VariableTracker.build(tx, field.default) elif field.default_factory is not dataclasses.MISSING: - factory_fn = SourcelessBuilder.create( + factory_fn = VariableTracker.build( tx, field.default_factory ) var_tracker = factory_fn.call_function(tx, [], {}) @@ -571,7 +569,7 @@ def call_function( and self.source ): return tx.inline_user_function_return( - SourcelessBuilder.create( + VariableTracker.build( tx, polyfills.instantiate_user_defined_class_object ), [self, *args], @@ -855,7 +853,6 @@ def call_function( kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": from .. import trace_rules - from .builder import VariableBuilder if ( self.is_supported_random() @@ -892,9 +889,9 @@ def call_function( "Sourceless UserDefinedObjectVariable method not supported" ) func_src = AttrSource(self.source, "__func__") - func_var = VariableBuilder(tx, func_src)(func) + func_var = VariableTracker.build(tx, func, func_src) obj_src = AttrSource(self.source, "__self__") - obj_var = VariableBuilder(tx, obj_src)(obj) + obj_var = VariableTracker.build(tx, obj, obj_src) return func_var.call_function(tx, [obj_var] + args, kwargs) elif ( istype(self.value, functools.partial) @@ -996,7 +993,6 @@ def is_supported_nn_module_method(self, method): def var_getattr(self, tx: "InstructionTranslator", name): from .. import trace_rules from . import ConstantVariable - from .builder import SourcelessBuilder, VariableBuilder source = AttrSource(self.source, name) if self.source else None self._check_for_getattribute() @@ -1029,8 +1025,13 @@ def var_getattr(self, tx: "InstructionTranslator", name): if isinstance(getattr_fn, types.FunctionType): # Dynamo is going to trace the __getattr__ function with # args=name. Set the source accordingly. - if getattr_fn is unpatched_nn_module_getattr and isinstance( - self, variables.UnspecializedNNModuleVariable + if ( + getattr_fn is unpatched_nn_module_getattr + and isinstance(self, variables.UnspecializedNNModuleVariable) + # prevent against overwriting of params/buffers/submodules + and istype(self.value._parameters, dict) + and istype(self.value._buffers, dict) + and istype(self.value._modules, dict) ): # Manually trace out the nn module __getattr__ to avoid large compilation latency. out = self.manually_trace_nn_module_getattr(tx, name) @@ -1083,10 +1084,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): elif isinstance(subobj, types.ClassMethodDescriptorType): # e.g.: inspect.getattr_static({}, "fromkeys") func = subobj.__get__(self.value, None) - if source is not None: - return VariableBuilder(tx, source)(func) - else: - return SourcelessBuilder.create(tx, func) + return VariableTracker.build(tx, func, source) elif inspect.ismethoddescriptor(subobj) and not is_wrapper_or_member_descriptor( subobj.__get__ ): @@ -1181,7 +1179,7 @@ def var_getattr(self, tx: "InstructionTranslator", name): subobj_from_class, src_from_class ) - return SourcelessBuilder.create(tx, subobj) + return VariableTracker.build(tx, subobj) # Earlier we were returning GetAttrVariable but its incorrect. In absence of attr, Python raises AttributeError. raise_observed_exception(AttributeError, tx) @@ -1205,7 +1203,6 @@ def call_hasattr(self, tx: "InstructionTranslator", name: str) -> "VariableTrack return variables.ConstantVariable.create(False) def odict_getitem(self, tx: "InstructionTranslator", key): - from .builder import VariableBuilder from .dicts import is_hashable # TODO this should probably be merged with the dict handling @@ -1216,10 +1213,11 @@ def odict_getitem(self, tx: "InstructionTranslator", key): else key.as_python_constant() ) - return VariableBuilder( + return VariableTracker.build( tx, - ODictGetItemSource(self.source, index), - )(collections.OrderedDict.__getitem__(self.value, key.as_python_constant())) + collections.OrderedDict.__getitem__(self.value, key.as_python_constant()), + self.source and ODictGetItemSource(self.source, index), + ) class FrozenDataClassVariable(UserDefinedObjectVariable): @@ -1229,14 +1227,14 @@ def create(tx, value, source): assert is_frozen_dataclass(value) - from .builder import VariableBuilder - field_map = {} for field in fields(value): if hasattr(value, field.name): - field_map[field.name] = VariableBuilder( - tx, AttrSource(source, field.name) - )(getattr(value, field.name)) + field_map[field.name] = VariableTracker.build( + tx, + getattr(value, field.name), + source and AttrSource(source, field.name), + ) return FrozenDataClassVariable(value, fields=field_map, source=source) @@ -1308,16 +1306,8 @@ def call_function( ) -> "VariableTracker": call_source = None referent = self.value() - - if self.source: - from .builder import VariableBuilder - - call_source = WeakRefCallSource(self.source) - return VariableBuilder(tx, call_source)(referent) - else: - from .builder import SourcelessBuilder - - return SourcelessBuilder.create(tx, referent) + source = self.source and WeakRefCallSource(self.source) + return VariableTracker.build(tx, referent, source) class KeyedJaggedTensorVariable(UserDefinedObjectVariable): @@ -1393,10 +1383,27 @@ class MutableMappingVariable(UserDefinedObjectVariable): def __init__(self, value, **kwargs): super().__init__(value, **kwargs) + self.generic_dict_vt = variables.ConstDictVariable({}) def var_getattr(self, tx: "InstructionTranslator", name: str) -> "VariableTracker": + # A common pattern in the init code of MutableMapping objects is to + # update the __dict__ attribute. To prevent graph break, we directly + # return a ConstDictVariable for the __dict__attr. + # + # However, users can try to add a new attribute to the class using the + # __dict__ attribute. To catch this, we save the ConstDictVariable for + # the __dict__ and then lookup into this vt for each attr lookup. if name == "get" and type(self.value).get is collections.abc.Mapping.get: return variables.UserMethodVariable(polyfills.mapping_get, self) + elif name == "__dict__" and self.source: + self.generic_dict_vt = variables.LazyVariableTracker.create( + self.value.__dict__, AttrSource(self.source, "__dict__") + ) + return self.generic_dict_vt + elif out := self.generic_dict_vt.maybe_getitem_const( + variables.ConstantVariable(name) + ): + return out else: return super().var_getattr(tx, name) diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index e3515b486c4c5..a8bb964a0ff92 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -3,7 +3,7 @@ import inspect import logging from collections import defaultdict -from typing import Any, Callable, Dict, List, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Set, Tuple, TYPE_CHECKING, Union import torch import torch.utils._pytree as pytree @@ -26,6 +26,7 @@ _combine_args, _DimHint, _process_dynamic_shapes, + _RelaxedConstraint, _tree_map_with_path, ) from torch.export.graph_signature import CustomObjArgument @@ -115,9 +116,11 @@ def fakify( assert mode.shape_env is not None if t_id in t_constraints: for i, constraint in t_constraints[t_id].items(): - symbolic_context.constraint_sizes[i] = constraint.constraint_range src = TensorPropertySource(base=source, prop=TensorProperty.SIZE, idx=i) sources[(t_id, i)].append(src) + if isinstance(constraint, _RelaxedConstraint): + continue + symbolic_context.constraint_sizes[i] = constraint.constraint_range mode.shape_env.source_name_to_debug_name[src.name()] = constraint.name # type: ignore[assignment] fake = mode.from_tensor(t, source=source, symbolic_context=symbolic_context) mode.shape_env.tracked_fakes.append(TrackedFake(fake, source, symbolic_context)) # type: ignore[union-attr] @@ -209,6 +212,7 @@ def make_fake_inputs( source_pairs: List[Tuple[Source, Source]] = [] derived_equalities: List[Tuple[Source, Union[Source, Symbol], Callable]] = [] phantom_symbols: Dict[str, Symbol] = {} + relaxed_sources: Set[Source] = set() for constraint in constraints: torch.export.dynamic_shapes._process_equalities( constraint, @@ -218,12 +222,14 @@ def make_fake_inputs( source_pairs, derived_equalities, phantom_symbols, + relaxed_sources, ) equalities_inputs = EqualityConstraint( source_pairs=source_pairs, derived_equalities=derived_equalities, phantom_symbols=list(phantom_symbols.values()), + relaxed_sources=relaxed_sources, warn_only=False, ) return ( diff --git a/torch/_export/passes/collect_tracepoints_pass.py b/torch/_export/passes/collect_tracepoints_pass.py index c89d2216632fa..c84aec50b8369 100644 --- a/torch/_export/passes/collect_tracepoints_pass.py +++ b/torch/_export/passes/collect_tracepoints_pass.py @@ -66,6 +66,18 @@ def get_arg_spec(arg): node.meta["nn_module_stack"].popitem() else: nn_module_stack = None + + def copy_sig(sig): + from torch.export.exported_program import ModuleCallSignature + + return ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=sig.in_spec, + out_spec=sig.out_spec, + forward_arg_names=None, + ) + for module in gm.modules(): if not isinstance(module, torch.fx.GraphModule): continue @@ -73,16 +85,19 @@ def get_arg_spec(arg): if node.op != "call_function": continue if node.target == torch.ops.higher_order._export_tracepoint: + path = node.kwargs["path"] + module_key = next(reversed(node.meta["nn_module_stack"])) + if "@" in module_key: + call_path = f"{path}@{module_key.split('@')[-1]}" + if call_path not in self.specs: + self.specs[call_path] = copy_sig(self.specs[path]) + path = call_path + kind = node.kwargs["kind"] for i, arg in enumerate(node.args): - kind = node.kwargs["kind"] if kind == "module_call_inputs": - self.specs[node.kwargs["path"]].inputs.append( - get_arg_spec(arg) - ) + self.specs[path].inputs.append(get_arg_spec(arg)) elif kind == "module_call_outputs": - self.specs[node.kwargs["path"]].outputs.append( - get_arg_spec(arg) - ) + self.specs[path].outputs.append(get_arg_spec(arg)) else: raise AssertionError(f"Unknown tracepoint kind: {kind}") if isinstance(arg, torch.fx.Node): diff --git a/torch/_export/passes/lift_constants_pass.py b/torch/_export/passes/lift_constants_pass.py index 60123be711b67..3085b62d53206 100644 --- a/torch/_export/passes/lift_constants_pass.py +++ b/torch/_export/passes/lift_constants_pass.py @@ -16,6 +16,7 @@ InputSpec, TensorArgument, ) +from torch.fx.graph_module import _get_attr class ConstantAttrMap(collections.abc.MutableMapping): @@ -154,9 +155,10 @@ def lift_constants_pass( first_user_input_loc += 1 lifted_objs = ConstantAttrMap() + renamed_targets = {} for node in gm.graph.nodes: if node.op == "get_attr": - constant_val = getattr(gm, node.target) + constant_val = _get_attr(gm, node.target) if constant_val in lifted_objs: # We already lifted this constant elsewhere. Just rewrite uses # of this get_attr to point to the already-existing placeholder @@ -164,6 +166,7 @@ def lift_constants_pass( const_placeholder_node = _get_first_fqn(lifted_objs, constant_val) node.replace_all_uses_with(const_placeholder_node) gm.graph.erase_node(node) + renamed_targets[node.name] = const_placeholder_node.name continue # For ScriptObject, Tensor and FakeScriptObject constants: @@ -262,6 +265,8 @@ def lift_constants_pass( node.replace_all_uses_with(const_placeholder_node) gm.graph.erase_node(node) + renamed_targets[node.name] = const_placeholder_node.name + # Add the constant as a buffer to the graph signature graph_signature.input_specs.insert( first_user_input_loc, @@ -278,6 +283,10 @@ def lift_constants_pass( all_constants[constant_fqn] = constant_val first_user_input_loc += 1 + for spec in graph_signature.output_specs: + if spec.arg.name in renamed_targets: + spec.arg.name = renamed_targets[spec.arg.name] + return all_constants diff --git a/torch/_export/utils.py b/torch/_export/utils.py index 47c8529d5c4fd..a34aea5519a0a 100644 --- a/torch/_export/utils.py +++ b/torch/_export/utils.py @@ -1,20 +1,37 @@ # mypy: allow-untyped-defs import ast import dataclasses +import functools import inspect import math import operator import re +from contextlib import contextmanager from inspect import Parameter -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, +) import torch from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import FakeTensor +from torch._subclasses.functional_tensor import FunctionalTensor +from torch.fx._utils import first_call_function_nn_module_stack +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts if TYPE_CHECKING: from torch._export.passes.lift_constants_pass import ConstantAttrMap + from torch._ops import OperatorBase from torch.export import ExportedProgram from torch.export.graph_signature import ExportGraphSignature @@ -518,6 +535,35 @@ def nodes_filter(nodes: List[torch.fx.Node], node_call_back) -> List[torch.fx.No return [node for node in nodes if node_call_back(node)] +def apply_runtime_assertion_pass(gm, graph_signature): + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names + + if not torch._dynamo.config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + shape_env = _get_shape_env_from_gm(gm) + if shape_env: + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + # update output specs + gm.recompile() + graph_signature.user_outputs = _graph_output_names(gm) + return gm, graph_signature + + def nodes_first( nodes: List[torch.fx.Node], node_call_back=None ) -> Optional[torch.fx.Node]: @@ -555,6 +601,15 @@ def node_replace_(old_node: torch.fx.Node, new_node: torch.fx.Node) -> None: old_node.graph.erase_node(old_node) +def _update_gm_meta_if_possible(gm: torch.fx.GraphModule, mod: torch.nn.Module) -> None: + if ( + isinstance(mod, torch.fx.GraphModule) + and hasattr(mod, "meta") + and "custom" in mod.meta + ): + gm.meta.update({"custom": mod.meta["custom"]}) + + def node_inline_(call_mod_node: torch.fx.Node) -> None: """ Inline the submodule of the given node into the parent module. @@ -579,6 +634,8 @@ def node_inline_(call_mod_node: torch.fx.Node) -> None: with gm.graph.inserting_before(call_mod_node): for node in body: new_node = gm.graph.node_copy(node) + if node.op == "get_attr": + setattr(gm, node.target, getattr(sub_gm, node.target)) node_replace_(node, new_node) if len(output) > 0: @@ -894,3 +951,187 @@ def _detect_fake_mode_from_gm( fake_vals.append(fake_val) return detect_fake_mode(fake_inps + fake_vals) + + +@contextmanager +def _disable_load_state_dict_hooks(mod: torch.nn.Module): + state_dict_hooks: Dict[int, Callable] = dict(mod._state_dict_hooks) + state_dict_pre_hooks: Dict[int, Callable] = dict(mod._state_dict_pre_hooks) + mod._state_dict_hooks.clear() + mod._state_dict_pre_hooks.clear() + try: + yield + finally: + mod._state_dict_hooks = state_dict_hooks + mod._state_dict_pre_hooks = state_dict_pre_hooks + + +def _is_cia_op(op: "OperatorBase") -> bool: + return ( + torch._C._dispatch_has_kernel_for_dispatch_key( + op.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ) + or torch._C.DispatchKey.CompositeImplicitAutograd in op.py_kernels + ) + + +def _is_preservable_cia_op(op: "OperatorBase") -> bool: + return _check_valid_to_preserve(op) and _is_cia_op(op) + + +def _is_aten_op(op: "OperatorBase") -> bool: + return op.name().split("::")[0] == "aten" + + +def _is_custom_op(op: "OperatorBase") -> bool: + return not _is_aten_op(op) + + +# We can't cache this because custom op registry API in python can still +# add entries to the C++ dispatcher. +def _materialize_cpp_cia_ops() -> None: + """ + Utility function to query C++ dispatcher to get the all + possible CIA ops and populate them into torch.ops namespace + """ + cia_ops = torch._C._dispatch_get_registrations_for_dispatch_key( + "CompositeImplicitAutograd" + ) + + # Materialize all CIA ops + for op in cia_ops: + namespace, op_name = tuple(op.split("::")) + split_list = op_name.split(".") + # Sometime overload could be missing + assert len(split_list) == 1 or len(split_list) == 2 + op_name = split_list[0] + op_overload_name = "default" + if len(split_list) == 2: + op_overload_name = split_list[1] + + _ = getattr(getattr(getattr(torch.ops, namespace), op_name), op_overload_name) + + +def _special_op_to_preserve_cia(*args, **kwargs): + """ + This is an special marker that tells our infra that we shouldn't decompose this op. + """ + return NotImplemented + + +# Our strategy for deciding if we can preserve a op is following: +# 1. The op should be known statically that it is functional +# 2. If it is maybe aliasing, we decompose because we must know if an op +# is mutating or aliasing. +# TODO (tmanlaibaatar) make this utility function and share it with functional_tensor +# decomp part. (https://github.com/pytorch/pytorch/issues/129431) +def _check_valid_to_preserve(op_overload: "OperatorBase"): + if op_overload in FunctionalTensor.maybe_aliasing_or_mutating_ops: + return False + if op_overload in FunctionalTensor.metadata_fns: + return False + + if not hasattr(op_overload, "_schema"): + return False + + alias_info = len( + [i for i in op_overload._schema.arguments if i.alias_info is not None] + ) + + is_mutating_or_aliasing = alias_info != 0 or op_overload._schema.is_mutable + + if is_mutating_or_aliasing: + return False + + if not torch._C._dispatch_has_kernel(op_overload.name()): + return False + + return True + + +@functools.lru_cache(maxsize=1) +def _collect_all_valid_cia_ops_for_aten_namespace() -> Set["OperatorBase"]: + return _collect_all_valid_cia_ops_for_namespace("aten") + + +def _collect_all_valid_cia_ops_for_namespace(namespace: str) -> Set["OperatorBase"]: + # Step 1: Materialize all ops from C++ dispatcher + _materialize_cpp_cia_ops() + + # Step 2: Query all ops from python dispatcher + assert hasattr(torch.ops, namespace) + op_namespace = getattr(torch.ops, namespace) + cia_ops = set() + for op in op_namespace: + op_packet = getattr(op_namespace, op) + for overload in op_packet.overloads(): + op_overload = getattr(op_packet, overload) + if _is_preservable_cia_op(op_overload): + cia_ops.add(op_overload) + return cia_ops + + +def _collect_all_valid_cia_ops() -> Set["OperatorBase"]: + """ + This is an util function that gets the all CIA functional ops. + + The algorithm is in 2 steps: + 1. We first query C++ dispatcher to get the list of CIA ops + and then we call getattr on torch.ops.aten to lazily populate + them. + + 2. Sometimes, handful of ops have CIA registered in python dispatcher + but not on the C++ side, these can't be caught at the first step. + So we walk again to get the final list. + + Note that the output of this function should never be modified + """ + cia_ops = set() + for op_namespace_name in torch.ops._dir: + # The reason we split here is because aten ops are safe to cache. + if op_namespace_name != "aten": + cia_ops |= _collect_all_valid_cia_ops_for_namespace(op_namespace_name) + else: + cia_ops |= _collect_all_valid_cia_ops_for_aten_namespace() + return cia_ops + + +def _get_decomp_for_cia(op: "OperatorBase"): + # [NOTE] Seperating out func.decompose + # Ideally we should be able to just register func.decompose but + # we can't as this decomp is gonna be registered to the py_impl. + # As a result it will infinitely recurse. So we first check if the op + # has py_impl entry for CIA and if it is we use that first. If not, + # we register C++ query to py_impl. + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if dk in op.py_kernels and not isinstance(op.py_kernels[dk], torch._C.DispatchKey): + return op.py_kernels[dk] + + def _special_op_to_decompose_cia(*args, **kwargs): + kernel = kwargs["kernel"] + del kwargs["kernel"] + # Can't call kernel.decompose due to infinite recursion as + # we register this kernel to py_impl directly + dk = torch._C.DispatchKey.CompositeImplicitAutograd + if torch._C._dispatch_has_kernel_for_dispatch_key( + kernel.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + return kernel._op_dk(dk, *args, **kwargs) + else: + raise AssertionError( + f"Expected {kernel} to have CompositeImplicitAutograd kernel" + ) + + return functools.partial(_special_op_to_decompose_cia, kernel=op) + + +# This table is a stop-gap table which replicates +# the old behaviour of post-dispatch IR. +# This table contains all functional CIA ops mapping +# to their default decomp. In old export, this will +# be decomposed implicitly. +def _decomp_table_to_post_autograd_aten(): + decomp_table = {} + for k in _collect_all_valid_cia_ops_for_aten_namespace(): + decomp_table[k] = _get_decomp_for_cia(k) + return decomp_table diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index 39a3c823a5b3f..e5bc8bc95d3c6 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -33,6 +33,7 @@ from torch._inductor.runtime.runtime_utils import cache_dir from torch._inductor.utils import should_use_remote_fx_graph_cache from torch._logging import LazyString +from torch._utils_internal import log_cache_bypass from .runtime_wrappers import ( AOTDispatchAutograd, @@ -47,6 +48,7 @@ if TYPE_CHECKING: + from torch._inductor.compile_fx import _CompileFxKwargs from torch._inductor.remote_cache import JsonDataTy, RemoteCache from torch._inductor.utils import BoxedBool from torch.fx.node import Node @@ -175,7 +177,7 @@ def check_cacheable(gm: torch.fx.GraphModule): Checks that the graph module only uses supported operators """ nodes = gm.graph.nodes - if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): raise BypassAOTAutogradCache( "Cannot cache a graph with compiled autograd enabled" ) @@ -205,7 +207,7 @@ def __init__( gm: torch.fx.GraphModule, example_inputs, aot_config: AOTConfig, - fx_config: Dict[str, BoxedBool], + fx_config: _CompileFxKwargs, ): # FxGraphHashDetails contains all the keys related to inductor. Also includes some system info self.aot_config = aot_config @@ -269,7 +271,7 @@ def autograd_cache_key( gm: torch.fx.GraphModule, example_inputs, config: AOTConfig, - fx_config: Dict[str, BoxedBool], + fx_config: _CompileFxKwargs, # TODO: add args and parameters ) -> Tuple[str, List[str]]: """ @@ -295,7 +297,7 @@ class FXGraphCacheLoadable: def is_backward(self): return False - def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGraph: + def load(self, example_inputs, fx_config: _CompileFxKwargs) -> CompiledFxGraph: # [Note: AOTAutogradCache and FXGraphCache Guard interactions] # As mentioned, AOTAutograd takes in the symint inputs from dynamo's list of arguments. # FXGraphCache serializes guards that are needed in the shape_env based on these symint inputs to the graph. @@ -332,7 +334,7 @@ def load(self, example_inputs, fx_config: Dict[str, BoxedBool]) -> CompiledFxGra payload_fn=lambda: json.dumps(cache_info), ) - FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"]) + FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"]) # type: ignore[arg-type] result._boxed_call = True return result @@ -369,6 +371,12 @@ class AOTAutogradCacheEntry: compiled_fw: CompiledForward compiled_bw: Optional[CompiledBackward] + # Code of the joint graph using print_readable() + # Used for logging purposes + aot_joint_graph_str: Optional[str] + aot_forward_graph_str: Optional[str] + aot_backward_graph_str: Optional[str] + # Runtime_metadata saved right before compilation runtime_metadata: ViewAndMutationMeta @@ -393,7 +401,7 @@ def wrap_post_compile( self, args: List[torch.Tensor], aot_config: AOTConfig, - fx_config: Dict[str, BoxedBool], + fx_config: _CompileFxKwargs, ) -> Callable: """ This function takes a cache entry and carefully reconstructs the original callable @@ -411,13 +419,35 @@ def wrap_post_compile( Which we'll handle separately later on, if necessary. """ + + # Log the output of AOTAutogradCache + if aot_config.enable_log: + # TODO: maybe also log to aot_graphs_log + # Unfortunately aot_graphs_log uses + # slightly different formatting though + if self.aot_joint_graph_str is not None: + torch._logging.trace_structured( + "aot_joint_graph", payload_fn=lambda: self.aot_joint_graph_str + ) + if self.aot_forward_graph_str is not None: + torch._logging.trace_structured( + "aot_forward_graph", payload_fn=lambda: self.aot_forward_graph_str + ) + if self.aot_backward_graph_str is not None: + torch._logging.trace_structured( + "aot_backward_graph", payload_fn=lambda: self.aot_backward_graph_str + ) + compiled_fw_func = self.compiled_fw.load(args, fx_config) compiled_bw_func = None + chromium_log = get_chromium_event_logger() if self.compiled_bw is not None: compiled_bw_func = self.compiled_bw.load(args, fx_config) needs_autograd = True + chromium_log.add_event_data("backend_compile", dispatch_mode="autograd") else: needs_autograd = False + chromium_log.add_event_data("backend_compile", dispatch_mode="inference") # Wrap the forward function in post compile wrappers compiled_fw_func = AOTDispatchSubclassWrapper( @@ -429,6 +459,11 @@ def wrap_post_compile( compiled_fw_func, aot_config, runtime_metadata=self.runtime_metadata ) + req_subclass_dispatch = self.maybe_subclass_meta is not None + chromium_log.add_event_data( + "backend_compile", requires_subclass_dispatch=req_subclass_dispatch + ) + # In autograd case, functionalizedRngWrapper should not modify outs return_new_outs = not needs_autograd compiled_fw_func = FunctionalizedRngRuntimeWrapper( @@ -541,7 +576,7 @@ def load( debug_lines: List[str] = [] cache_event_time = time.time_ns() cache_state = None - fx_config = {"cudagraphs": cudagraphs} + fx_config: _CompileFxKwargs = {"cudagraphs": cudagraphs} try: cache_key, debug_lines = autograd_cache_key(gm, args, aot_config, fx_config) entry: Optional[AOTAutogradCacheEntry] = AOTAutogradCache._lookup( @@ -593,6 +628,9 @@ def load( counters["aot_autograd"]["autograd_cache_bypass"] += 1 cache_state = "bypass" cache_event_time = time.time_ns() + cache_info["cache_bypass_reason"] = str(e) + if remote: + log_cache_bypass("bypass_aot_autograd", str(e)) if config.strict_autograd_cache: raise e if compiled_fn is None: @@ -612,6 +650,18 @@ def load( chromium_log.log_instant_event( f"autograd_cache_{cache_state}", cache_event_time, metadata=cache_info ) + + chromium_log.add_event_data( + "backend_compile", + cache_state=cache_state, + cache_event_time=cache_event_time, + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) + torch._logging.trace_structured( "artifact", metadata_fn=lambda: { diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index c16aadfac5232..7fa17ba2ff9dc 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -12,7 +12,7 @@ import contextlib import logging from functools import wraps -from typing import Callable, DefaultDict, Dict, List, Optional +from typing import Callable, DefaultDict, Dict, List, Optional, Set import torch import torch.utils._pytree as pytree @@ -34,7 +34,7 @@ from_fun, has_data_mutation, has_metadata_mutation, - has_same_metadata, + MetadataKey, to_fun, was_inductor_storage_resized, ) @@ -152,6 +152,9 @@ def run_functionalized_fw_and_collect_metadata( # Note: this is guaranteed to be set when running under dynamo static_input_indices: Optional[List[int]] = None, pre_dispatch: bool = False, + # is_export is technically only needed to avoid using functionalization V2 + # during analysis + is_export: bool = False, ) -> Callable[..., ViewAndMutationMeta]: memo: Dict[Tensor, Tensor] = {} @@ -183,7 +186,7 @@ def inner(*flat_args): # It doesn't matter if we run this under predispatch or not because it is # only for figuring out metadata - mode = FunctionalTensorMode(_allow_token_discovery=True) + mode = FunctionalTensorMode(_allow_token_discovery=True, export=is_export) suppress_pending = contextlib.nullcontext() fake_mode = detect_fake_mode() if fake_mode and (shape_env := fake_mode.shape_env): @@ -225,10 +228,6 @@ def inner(*flat_args): "tensor subclasses" ) - if not isinstance(arg, Tensor): - new_arg = arg - else: - new_arg = from_fun(f_arg) mutates_metadata = has_metadata_mutation( f_arg, arg, check_only_storage_mutation=False ) @@ -292,7 +291,11 @@ def inner(*flat_args): num_aliased_tensors_that_are_multi_output_views: DefaultDict = ( collections.defaultdict(int) ) - out_storage_to_tensors: DefaultDict = collections.defaultdict(set) + + out_storage_to_metadata_key_to_tensors: DefaultDict[ + Optional[StorageWeakRef], DefaultDict[MetadataKey, Set[torch.Tensor]] + ] = collections.defaultdict(lambda: collections.defaultdict(set)) + curr_storage = None for o in flat_f_outs: if isinstance(o, torch.Tensor): @@ -383,7 +386,10 @@ def inner(*flat_args): ) if is_cur_tensor_multi_out_view: num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1 - out_storage_to_tensors[curr_storage].add(o) + if o.requires_grad: + out_storage_to_metadata_key_to_tensors[curr_storage][ + MetadataKey.make(o) + ].add(o) # maps the id of an intermediate base to its index in the output of the compiled forward intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {} @@ -428,10 +434,10 @@ def inner(*flat_args): if not isinstance(o, Tensor) else [ curr - for curr in out_storage_to_tensors[curr_storage] - if has_same_metadata(o, curr) - and curr.requires_grad - and o is not curr + for curr in out_storage_to_metadata_key_to_tensors[curr_storage][ + MetadataKey.make(o) + ] + if o is not curr ] ) @@ -701,7 +707,7 @@ def view_avoid_dupes_with_primals(t): traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats] nonlocal static_input_indices static_input_indices = static_input_indices or [] - if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): passed_indices = set(static_input_indices) static_input_indices = [ i diff --git a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py index 62cf7b68cd3fa..393331102ec20 100644 --- a/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py +++ b/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py @@ -16,6 +16,7 @@ from torch._logging import getArtifactLogger, trace_structured from torch._subclasses.functional_tensor import FunctionalTensorMode from torch.fx.experimental.proxy_tensor import make_fx +from torchgen.utils import dataclass_repr from .. import config from .functional_utils import ( @@ -192,8 +193,27 @@ def aot_dispatch_base_graph( colored=True, ), ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + trace_structured( - "aot_forward_graph", + "aot_inference_graph", payload_fn=lambda: fw_module.print_readable( print_output=False, include_stride=True, include_device=True ), diff --git a/torch/_functorch/_aot_autograd/functional_utils.py b/torch/_functorch/_aot_autograd/functional_utils.py index 71862997ae071..ec647888c8612 100644 --- a/torch/_functorch/_aot_autograd/functional_utils.py +++ b/torch/_functorch/_aot_autograd/functional_utils.py @@ -8,7 +8,8 @@ """ from __future__ import annotations -from typing import Optional +from dataclasses import dataclass +from typing import Optional, Tuple import torch from torch import Tensor @@ -16,7 +17,11 @@ from torch._subclasses.fake_tensor import FakeTensor from torch._subclasses.functional_tensor import FunctionalTensor from torch._subclasses.meta_utils import is_sparse_any -from torch.fx.experimental.symbolic_shapes import definitely_true, sym_eq +from torch.fx.experimental.symbolic_shapes import ( + definitely_true, + sym_eq, + SymIntEqByExpr, +) from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._python_dispatch import ( is_traceable_wrapper_subclass, @@ -326,6 +331,35 @@ def has_same_metadata(t1, t2): ) +@dataclass(frozen=True) +class MetadataKey: + """ + This should be equal whenever has_same_metadata would return True + """ + + size: Tuple[SymIntEqByExpr, ...] + layout: torch.layout + is_sparse: bool + # these are empty when is_sparse + stride: Optional[Tuple[SymIntEqByExpr, ...]] + storage_offset: Optional[SymIntEqByExpr] + is_conj: bool + is_neg: bool + + @staticmethod + def make(t): + is_sparse = is_sparse_any(t) + return MetadataKey( + size=tuple(SymIntEqByExpr(s) for s in t.size()), + layout=t.layout, + is_sparse=is_sparse, + stride=None if is_sparse else tuple(SymIntEqByExpr(s) for s in t.stride()), + storage_offset=None if is_sparse else SymIntEqByExpr(t.storage_offset()), + is_conj=t.is_conj(), + is_neg=t.is_neg(), + ) + + # Wrapper around a FunctionalTensorWrapper for comparing only the resulting metadata # after applying all the ViewMeta operations. class FunctionalTensorMetadataEq: diff --git a/torch/_functorch/_aot_autograd/input_output_analysis.py b/torch/_functorch/_aot_autograd/input_output_analysis.py index 8f330a056b7ae..77e9e23faeb1a 100644 --- a/torch/_functorch/_aot_autograd/input_output_analysis.py +++ b/torch/_functorch/_aot_autograd/input_output_analysis.py @@ -16,6 +16,7 @@ import torch import torch.utils._pytree as pytree from torch import Tensor +from torch._dynamo.exc import Unsupported from torch._subclasses.functional_tensor import FunctionalTensor from torch.fx.experimental.symbolic_shapes import is_concrete_int @@ -375,9 +376,7 @@ def compute_overlapping_inputs(fwd_inputs, aliased_input_indices): ) ): dynamic_shape_indices.add(j_) - assert ( - len(dynamic_shape_indices) == 0 - ), f"""\ + err_message = f"""\ Encountered a graph where: - {num_aliases} graph inputs all share the same storage (input indices: {str(aliased_input_indices)}) - at least one of these aliased inputs was mutated @@ -397,6 +396,11 @@ def compute_overlapping_inputs(fwd_inputs, aliased_input_indices): If you are running into this issue in a situation where your parameters are static but some other inputs are aliased and mutated, and they should be dynamic, please file an issue. """ + if len(dynamic_shape_indices) != 0: + raise Unsupported( + err_message, + case_name="dynamic_shapes_validation", + ) for j in range(num_aliases): for i in range(j): j_ = aliased_input_indices[j] diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index 0d29e1b5594d7..fda2a22d202c2 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -19,14 +19,17 @@ import torch import torch.utils.dlpack from torch import Tensor -from torch._dynamo.utils import lazy_format_graph_code +from torch._dynamo.utils import detect_fake_mode, lazy_format_graph_code from torch._guards import CompileContext, TracingContext from torch._logging import getArtifactLogger, trace_structured from torch._subclasses import FakeTensor from torch.fx.experimental._backward_state import BackwardState from torch.fx.experimental.proxy_tensor import is_sym_node from torch.fx.experimental.symbolic_shapes import fx_placeholder_vals +from torch.fx.graph_module import GraphModule +from torch.fx.passes._tensorify_python_scalars import tensorify_python_scalars from torch.multiprocessing.reductions import StorageWeakRef +from torchgen.utils import dataclass_repr from .. import config from .autograd_cache import ( @@ -143,6 +146,13 @@ def aot_dispatch_base( fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph( # type: ignore[misc] flat_fn, flat_args, aot_config, fw_metadata=fw_metadata ) + # Save the forward_graph_str right after aot_dispatch_base_graph, + # to save in the cache + aot_forward_graph_str = None + if autograd_cache_enabled(): + aot_forward_graph_str = fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ) fakified_out_wrapper = FakifiedOutWrapper() ( @@ -179,6 +189,10 @@ def aot_dispatch_base( ) with TracingContext.report_output_strides() as fwd_output_strides: + fake_mode = detect_fake_mode() + if fake_mode is not None: + assert isinstance(fw_module, GraphModule) + tensorify_python_scalars(fw_module, fake_mode.shape_env, fake_mode) compiled_fw = compiler(fw_module, updated_flat_args) if fakified_out_wrapper.needs_post_compile: @@ -203,6 +217,9 @@ def aot_dispatch_base( entry = AOTAutogradCacheEntry( compiled_fw=CompiledForward(fw_key), compiled_bw=None, + aot_joint_graph_str=None, + aot_forward_graph_str=aot_forward_graph_str, + aot_backward_graph_str=None, runtime_metadata=fw_metadata, dispatch_wrappers=wrappers, maybe_subclass_meta=maybe_subclass_meta, @@ -344,7 +361,7 @@ def aot_dispatch_autograd( # Copied from aot_dispatch_autograd_graph. disable_amp = torch._C._is_any_autocast_enabled() - + joint_graph_str = None if aot_config.enable_log: aot_joint_log.info( "%s", @@ -357,11 +374,12 @@ def aot_dispatch_autograd( colored=True, ), ) + joint_graph_str = fx_g.print_readable( + print_output=False, include_stride=True, include_device=True + ) trace_structured( "aot_joint_graph", - payload_fn=lambda: fx_g.print_readable( - print_output=False, include_stride=True, include_device=True - ), + payload_fn=lambda: joint_graph_str, ) with torch.no_grad(): @@ -387,6 +405,9 @@ def aot_dispatch_autograd( + inner_meta.num_outputs_rng_offset + num_tokens # See Note [Side-Effectful Tokens in AOTAutograd] ) + fake_mode = detect_fake_mode() + if fake_mode is not None: + tensorify_python_scalars(fx_g, fake_mode.shape_env, fake_mode) fw_module, bw_module = aot_config.partition_fn( fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs ) @@ -527,6 +548,8 @@ def aot_dispatch_autograd( if bw_out is None and not metadata_mutation_in_graph and is_non_leaf: _indices_of_inps_to_detach.append(i) + fw_module_str = None + bw_module_str = None if aot_config.enable_log: aot_graphs_log.info( "%s", @@ -550,20 +573,43 @@ def aot_dispatch_autograd( colored=True, ), ) + fw_module_str = fw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ) + bw_module_str = bw_module.print_readable( + print_output=False, include_stride=True, include_device=True + ) + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(fw_metadata), + ) + if maybe_subclass_meta is not None: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_forward_graph_fw_subclass_metadata", + "encoding": "string", + }, + payload_fn=lambda: dataclass_repr(maybe_subclass_meta), + ) + trace_structured( "aot_forward_graph", - payload_fn=lambda: fw_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), + payload_fn=lambda: fw_module_str, ) trace_structured( "aot_backward_graph", - payload_fn=lambda: bw_module.print_readable( - print_output=False, include_stride=True, include_device=True - ), + payload_fn=lambda: bw_module_str, ) - with track_graph_compiling(aot_config, "forward"): + # AMP is already traced out in joint graph. we do not wish to reapply it accidentally + # in the compiler. + with track_graph_compiling(aot_config, "forward"), torch._C._DisableAutocast(): # flat_args at this point might still be subclasses- # make sure to pass the unwrapped fake tensors into the compiler! adjusted_flat_args = joint_inputs[0] @@ -628,7 +674,7 @@ def aot_dispatch_autograd( # NB: It's important to compile backwards ahead of time, as this may # add extra guards which we need to apply to the Dynamo cache at # forwards - with track_graph_compiling(aot_config, "backward"): + with track_graph_compiling(aot_config, "backward"), torch._C._DisableAutocast(): placeholder_list = fx_placeholder_vals(bw_module) forward_saved_for_backwards_strides = None @@ -680,28 +726,24 @@ def aot_dispatch_autograd( compiled_bw_func = None if num_symints_saved_for_bw > 0: - context = torch._C._DisableAutocast if disable_amp else nullcontext - with context(): - try: - compiled_bw_func = aot_config.bw_compiler( - bw_module, placeholder_list - ) - except Exception as e: - exc = e - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "eager_compile_backwards_failure", - "encoding": "string", - }, - payload_fn=lambda: "\n".join( - traceback.format_exception(exc) - ), - ) - log.warning( - "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed", - exc_info=True, - ) + try: + compiled_bw_func = aot_config.bw_compiler( + bw_module, placeholder_list + ) + except Exception as e: + exc = e + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "eager_compile_backwards_failure", + "encoding": "string", + }, + payload_fn=lambda: "\n".join(traceback.format_exception(exc)), + ) + log.warning( + "failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed", + exc_info=True, + ) # Compiled autograd will run the bw_module in the backward pass, # so recompilation need happen anyway if the backward pass is ever # called. @@ -716,7 +758,7 @@ def aot_dispatch_autograd( # becomes the lazy version again. One example is when dynamic shape is enabled # upfront, the bw_compiler will be called above which can cause extra # graph module recompilation on bw_module. - if torch._dynamo.compiled_autograd.in_compiled_autograd_region: + if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): from torch.fx._lazy_graph_module import _LazyGraphModule _LazyGraphModule.force_recompile(bw_module) @@ -762,11 +804,20 @@ def try_save_cache_entry( # noqa: F811 # It's possible this changes in the future, in which case we should # update backward_time_taken_ns to be more inclusive backward_time_taken_ns = getattr(compiled_bw_func, "_time_taken_ns", 0) + + aot_forward_graph_str: Optional[str] = fw_module_str + aot_backward_graph_str: Optional[str] = bw_module_str + aot_joint_graph_str: Optional[str] = joint_graph_str entry = AOTAutogradCacheEntry( CompiledForward(fw_key), CompiledBackward( - bw_key, backward_state_indices, num_symints_saved_for_bw + bw_key, + backward_state_indices, + num_symints_saved_for_bw, ), + aot_joint_graph_str, + aot_forward_graph_str, + aot_backward_graph_str, _fw_metadata, wrappers, maybe_subclass_meta, diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index e42a1a9976199..9e8a21321ad77 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -1854,13 +1854,6 @@ def get_types_for_tangents(tangents): "The grad inputs should be same tensor subclass type as forward output" ) - # Get the number of tangents after unwrapping - len_tangents = len( - unwrap_tensor_subclasses( - tangents, - is_joint_structure=False, - ) - ) assert CompiledFunction.metadata.traced_tangent_metas is not None all_args = [ ( diff --git a/torch/_functorch/_aot_autograd/subclass_utils.py b/torch/_functorch/_aot_autograd/subclass_utils.py index ad8de0eac069f..62b6223440a33 100644 --- a/torch/_functorch/_aot_autograd/subclass_utils.py +++ b/torch/_functorch/_aot_autograd/subclass_utils.py @@ -113,12 +113,9 @@ def create_subclass_meta( # NOTE: this function is hot, since we unwrap tensor subclass inputs at runtime def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool): def concat_inner_tensors_from_subclasses(xs): - xs_inner = [] + xs_inner: List[Tensor] = [] for x in xs: - if is_traceable_wrapper_subclass(x): - xs_inner.extend(get_plain_tensors(typing.cast(Tensor, x))) - else: - xs_inner.append(x) + get_plain_tensors(x, out_append_list=xs_inner) return xs_inner if is_joint_structure: diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py index 25a85c9d4dd45..85120254b2e66 100644 --- a/torch/_functorch/aot_autograd.py +++ b/torch/_functorch/aot_autograd.py @@ -15,7 +15,11 @@ from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions from torch._dispatch.python import enable_python_dispatcher from torch._dynamo import compiled_autograd -from torch._dynamo.utils import dynamo_timed, preserve_rng_state +from torch._dynamo.utils import ( + dynamo_timed, + get_chromium_event_logger, + preserve_rng_state, +) from torch._guards import detect_fake_mode from torch._inductor.utils import BoxedBool from torch._subclasses import FakeTensor, FakeTensorMode @@ -581,6 +585,13 @@ def _create_aot_dispatcher_function( enable_python_dispatcher() if shape_env is not None else nullcontext() ) + def try_record_chromium_data(**kwargs): + # `backend_compile` only exists as an event if we are compiling with dynamo + # In some unit tests we don't use dynamo, so we ignore those cases + chromium_log = get_chromium_event_logger() + if "backend_compile" in chromium_log.get_stack(): + chromium_log.add_event_data("backend_compile", **kwargs) + # See NOTE: [Deferring tensor pack/unpack hooks until runtime] # If any saved tensor hooks are active, we **don't** want to trace them. # Instead, we'll let them run at runtime, around the custom autograd.Function @@ -628,11 +639,15 @@ def _dup_fake_script_obj(fake_flat_args): keep_input_mutations=aot_config.keep_inference_input_mutations, is_train=needs_autograd, pre_dispatch=aot_config.pre_dispatch, + is_export=aot_config.is_export, )(*_dup_fake_script_obj(fake_flat_args)) req_subclass_dispatch = requires_subclass_dispatch( fake_flat_args, fw_metadata ) + try_record_chromium_data( + requires_subclass_dispatch=req_subclass_dispatch + ) output_and_mutation_safe = not any( x.requires_grad @@ -751,10 +766,13 @@ def choose_dispatcher(needs_autograd, aot_config): if aot_config.is_export: # export uses just the "graph bits", whereas the other # two dispatchers include some extra work around handling a runtime epilogue + try_record_chromium_data(dispatch_mode="export") return partial(aot_dispatch_export, needs_autograd=needs_autograd) elif needs_autograd and not aot_config.pre_dispatch: + try_record_chromium_data(dispatch_mode="autograd") return aot_dispatch_autograd else: + try_record_chromium_data(dispatch_mode="inference") return aot_dispatch_base compiler_fn = choose_dispatcher(needs_autograd, aot_config) @@ -1454,6 +1472,7 @@ def _aot_export_function( flat_fn, out_spec = create_tree_flattened_fn(func, args, kwargs) flat_args, in_spec = pytree.tree_flatten((args, kwargs)) + fake_mode = None if dynamic_shapes is None: # Try to infer `dynamic_shapes from inputs and graph nodes fake_mode = detect_fake_mode(flat_args) @@ -1491,7 +1510,10 @@ def _aot_export_function( no_tangents=no_tangents, pre_dispatch=pre_dispatch, ) - fake_mode, shape_env = construct_fake_mode(flat_args, aot_config) + if fake_mode is None: + fake_mode, shape_env = construct_fake_mode(flat_args, aot_config) + else: + shape_env = fake_mode.shape_env fake_flat_args = process_inputs(flat_args, aot_config, fake_mode, shape_env) fx_g, meta = create_aot_dispatcher_function( @@ -1509,7 +1531,7 @@ def _detect_attribute_assignment(mod: torch.nn.Module): # Do not allow assignment of tensor attributes during export unless # the attribute is registered as a buffer. - STD_ATTRS = { + NN_MODULE_STD_ATTRS = [ "_backward_hooks", "_backward_pre_hooks", "_buffers", @@ -1527,6 +1549,14 @@ def _detect_attribute_assignment(mod: torch.nn.Module): "_state_dict_hooks", "_state_dict_pre_hooks", "training", + ] + NN_MODULE_LAZY_STD_ATTRS = [ + "_initialize_hook", + "_load_hook", + ] + STD_ATTRS = { + *NN_MODULE_STD_ATTRS, + *NN_MODULE_LAZY_STD_ATTRS, } def _get_attributes(mod): diff --git a/torch/_functorch/benchmark_utils.py b/torch/_functorch/benchmark_utils.py index e0bcae4c836e9..ac69e8bd4744c 100644 --- a/torch/_functorch/benchmark_utils.py +++ b/torch/_functorch/benchmark_utils.py @@ -222,7 +222,7 @@ def f(a): optimize_ctx, [ProfilerActivity.CUDA], num_runs=num_runs, - devices="cuda", + devices=["cuda"], ) utilization, mm_conv_utilization = compute_utilization( chrome_trace_file_name, total_length diff --git a/torch/_functorch/config.py b/torch/_functorch/config.py index 8c042ee7ed56a..9d148de1aa794 100644 --- a/torch/_functorch/config.py +++ b/torch/_functorch/config.py @@ -162,6 +162,10 @@ def remote_autograd_cache_default() -> Optional[bool]: # tokens. unlift_effect_tokens = False + +# Run aot eager decomp partition with CrossRefFakeMode +fake_tensor_crossref = False + # This mode specifies that we should also keep track of the real # tensor along with the fake tensor, and do real compute. While # seemingly this eliminates the whole point of fake tensors, there are diff --git a/torch/_functorch/partitioners.py b/torch/_functorch/partitioners.py index d5ba05eca0390..e36a02853c235 100644 --- a/torch/_functorch/partitioners.py +++ b/torch/_functorch/partitioners.py @@ -1506,9 +1506,12 @@ def dp_knapsack( def _optimize_runtime_with_given_memory( + joint_graph: fx.Graph, memory: List[float], runtimes: List[float], max_memory: float, + node_info: NodeInfo, + all_recomputable_banned_nodes: List[fx.Node], ) -> Tuple[float, List[int], List[int]]: SOLVER = config.activation_memory_budget_solver if SOLVER == "greedy": @@ -1517,6 +1520,11 @@ def _optimize_runtime_with_given_memory( return ilp_knapsack(memory, runtimes, max_memory) elif SOLVER == "dp": return dp_knapsack(memory, runtimes, max_memory) + elif callable(SOLVER): + saved_node_idx, recomp_node_idx = SOLVER( + memory, joint_graph, max_memory, node_info, all_recomputable_banned_nodes + ) + return (0.0, saved_node_idx, recomp_node_idx) else: raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}") @@ -1572,7 +1580,9 @@ def realize_symbol(d): def choose_saved_values_set( - joint_graph: fx.Graph, node_info: NodeInfo, memory_budget=1 + joint_graph: fx.Graph, + node_info: NodeInfo, + memory_budget=1, ) -> List[fx.Node]: if memory_budget > 1 or memory_budget < 0: raise RuntimeError( @@ -1680,18 +1690,28 @@ def get_recomputable_banned_nodes(banned_nodes: List[fx.Node]) -> List[fx.Node]: ] from torch.utils._mode_utils import no_dispatch - def get_saved_values_knapsack(memory_budget): + def get_saved_values_knapsack(memory_budget, node_info, joint_graph): with no_dispatch(): ( expected_runtime, saved_node_idxs, recomputable_node_idxs, ) = _optimize_runtime_with_given_memory( - memories_banned_nodes, runtimes_banned_nodes, max(memory_budget, 0) + joint_graph, + memories_banned_nodes, + runtimes_banned_nodes, + max(memory_budget, 0), + node_info, + all_recomputable_banned_nodes, ) dont_ban = set() for idx in recomputable_node_idxs: - dont_ban.add(all_recomputable_banned_nodes[idx]) + # if idx in all_recomputable_banned_nodes: + try: + dont_ban.add(all_recomputable_banned_nodes[idx]) + except BaseException: + pass + assert dont_ban.issubset(all_recomputable_banned_nodes) saved_values, _ = solve_min_cut( @@ -1706,7 +1726,7 @@ def get_saved_values_knapsack(memory_budget): options = [] for sweep_memory_budget in range(100, -1, -5): saved_values, expected_runtime = get_saved_values_knapsack( - sweep_memory_budget / 100 + sweep_memory_budget / 100, node_info=node_info, joint_graph=joint_graph ) options.append( ( @@ -1751,7 +1771,9 @@ def get_saved_values_knapsack(memory_budget): # tensors we actually banned from recompute, but there may be other # tensors that we choose to save. - return get_saved_values_knapsack(memory_budget=memory_budget)[0] + return get_saved_values_knapsack( + memory_budget=memory_budget, node_info=node_info, joint_graph=joint_graph + )[0] def min_cut_rematerialization_partition( @@ -1877,7 +1899,9 @@ def classify_nodes(joint_module): break # print("Memory Budget: ", memory_budget) saved_values = choose_saved_values_set( - joint_graph, node_info, memory_budget=memory_budget + joint_graph, + node_info, + memory_budget=memory_budget, ) # save_for_backward on tensors and stashes symints in autograd .ctx saved_sym_nodes = list(filter(is_sym_node, saved_values)) diff --git a/torch/_guards.py b/torch/_guards.py index f6bd852d26d47..c42d483750087 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -570,6 +570,66 @@ def restore_graphstate(self, state): self.dynamo_guards = GuardsSet(state.dynamo_guards) +class HopSubgraphCache: + @abstractmethod + def add_dynamo_identifier(self, cache_key: str, identifier: str): ... + + @abstractmethod + def get_dynamo_identifier(self, cache_key: str) -> Optional[str]: ... + + @abstractmethod + def add_autograd_key_entry(self, identifier: str, key: Callable): ... + + @abstractmethod + def get_autograd_key_entry(self, identifier: str): ... + + @abstractmethod + def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ... + + @abstractmethod + def get_proxy_dispatch_entry(self, identifier: str): ... + + +class InvokeSubgraphCache(HopSubgraphCache): + def __init__(self) -> None: + self.autograd_cache: Dict[str, Callable] = {} + self.proxy_dispatch_cache: Dict[str, Callable] = {} + self.dynamo_identifiers: Dict[str, str] = {} + + def add_dynamo_identifier(self, cache_key: str, identifier: str): + self.dynamo_identifiers[cache_key] = identifier + + def get_dynamo_identifier(self, cache_key: str) -> Optional[str]: + return self.dynamo_identifiers.get(cache_key, None) + + def add_autograd_key_entry(self, identifier: str, key: Callable): + self.autograd_cache[identifier] = key + + def get_autograd_key_entry(self, identifier: str): + return self.autograd_cache.get(identifier, None) + + def add_proxy_dispatch_entry(self, identifier: str, key: Callable): + self.proxy_dispatch_cache[identifier] = key + + def get_proxy_dispatch_entry(self, identifier: str): + return self.proxy_dispatch_cache.get(identifier, None) + + +class HopDispatchSetCache: + def __init__(self) -> None: + # Delayed import to avoid circular dependency + from torch._higher_order_ops.invoke_subgraph import invoke_subgraph + + self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()} + + def get_cache( + self, op: torch._ops.HigherOrderOperator + ) -> Optional[HopSubgraphCache]: + if op not in self.hop_cache_map: + return None + return self.hop_cache_map[op] # type: ignore[index] + + _TLS = threading.local() """ @@ -686,6 +746,7 @@ def __init__(self, fake_mode): # meta on the first invocation # see note: [Returning Fake Tensors on First AOT Autograd Call] self.fakify_first_call = False + self.hop_dispatch_set_cache = HopDispatchSetCache() def clear(self): # Look at the note in output_graph.py in function `save_global_state` diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 72800cae7fc98..8c78306699f5e 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -4,12 +4,16 @@ flex_attention_backward, ) from torch._higher_order_ops.hints_wrap import hints_wrapper +from torch._higher_order_ops.invoke_subgraph import invoke_subgraph +from torch._higher_order_ops.scan import scan from torch._higher_order_ops.while_loop import while_loop __all__ = [ "cond", "while_loop", + "invoke_subgraph", + "scan", "flex_attention", "flex_attention_backward", "hints_wrapper", diff --git a/torch/_higher_order_ops/auto_functionalize.py b/torch/_higher_order_ops/auto_functionalize.py index ac06b9c822942..7557deede66d5 100644 --- a/torch/_higher_order_ops/auto_functionalize.py +++ b/torch/_higher_order_ops/auto_functionalize.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import warnings +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Any, Dict, List, Optional, Sequence, Tuple, Union @@ -25,24 +26,30 @@ def get_base(tensor): return tensor._base -@dataclass -class ViewInfo: +class ViewInfo(ABC): base_index: int - size: Optional[Sequence[Union[int, torch.SymInt]]] = None - stride: Optional[Sequence[Union[int, torch.SymInt]]] = None - storage_offset: Optional[int] = None - # When is_view is false, the tensor is the base, and - # size, stride and storage_offset are all None. - is_view: bool = True + def __init__(self, base_index): + self.base_index = base_index + + @abstractmethod def regenerate_view(self, bases_list: List[Tensor]): - if not self.is_view: - return bases_list[self.base_index] + pass - assert self.stride is not None - assert self.size is not None - assert self.storage_offset is not None +@dataclass +class AsStridedViewInfo(ViewInfo): + size: Sequence[Union[int, torch.SymInt]] + stride: Sequence[Union[int, torch.SymInt]] + storage_offset: int + + def __init__(self, base_index, size, stride, storage_offset): + super().__init__(base_index) + self.size = size + self.stride = stride + self.storage_offset = storage_offset + + def regenerate_view(self, bases_list: List[Tensor]): return torch.as_strided( bases_list[self.base_index], self.size, @@ -51,6 +58,85 @@ def regenerate_view(self, bases_list: List[Tensor]): ) +@dataclass +class SliceViewInfo(ViewInfo): + dim: Union[int, torch.SymInt] + start: Union[int, torch.SymInt] + end: Union[int, torch.SymInt] + + def __init__(self, base_index, dim, start, end): + super().__init__(base_index) + self.dim = dim + self.start = start + self.end = end + + def regenerate_view(self, bases_list: List[Tensor]): + return torch.ops.aten.slice.Tensor( + bases_list[self.base_index], self.dim, self.start, self.end + ) + + +@dataclass +class AliasViewInfo(ViewInfo): + def __init__(self, base_index): + super().__init__(base_index) + + def regenerate_view(self, bases_list: List[Tensor]): + return torch.ops.aten.alias.default(bases_list[self.base_index]) + + +@dataclass +class NotView(ViewInfo): + def __init__(self, base_index): + super().__init__(base_index) + + def regenerate_view(self, bases_list: List[Tensor]): + return bases_list[self.base_index] + + +def is_alias(base, tensor): + from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq + + return all( + statically_known_true(a) + for a in [ + sym_eq(base.storage_offset(), tensor.storage_offset()), + sym_eq(base.stride(), tensor.stride()), + sym_eq(base.size(), tensor.size()), + ] + ) + + +# return None or (dim, start, end) +def try_use_slice(base, tensor): + from torch.fx.experimental.symbolic_shapes import statically_known_true, sym_eq + + # This condition should never be triggered. + if is_alias(base, tensor): + return (0, 0, base.size()[0]) + + # TODO is there cases can we use slice even if stride or len(sizes) are not equal? + if not statically_known_true(sym_eq(tensor.stride(), base.stride())): + return None + if not statically_known_true(sym_eq(len(tensor.size()), len(base.size()))): + return None + + dim = None + count = 0 + for i in range(len(tensor.size())): + if base.size()[i] != tensor.size()[i]: + dim = i + count = count + 1 + if count != 1: + return None + + if tensor.storage_offset() % tensor.stride()[dim] != 0: + return None + start = tensor.storage_offset() // tensor.stride()[dim] + end = start + tensor.size()[dim] + return (dim, start, end) + + def write_view_information_to_args( mutable_arg_names: List[str], mutable_arg_types: List[torch.Type], @@ -73,16 +159,38 @@ def write_single_view(prefix: str, tensor: Tensor, base_index: int): assert f"{prefix}_stride" not in kwargs assert f"{prefix}_storage_offset" not in kwargs + assert f"{prefix}_slice_dim" not in kwargs + assert f"{prefix}_slice_start" not in kwargs + assert f"{prefix}_slice_end" not in kwargs + + def use_as_strided(tensor): + kwargs[f"{prefix}_size"] = tensor.size() + kwargs[f"{prefix}_stride"] = tensor.stride() + kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset() + + def use_slice(dim, start, end): + kwargs[f"{prefix}_slice_dim"] = dim + kwargs[f"{prefix}_slice_start"] = start + kwargs[f"{prefix}_slice_end"] = end + + def use_alias(): + kwargs[f"{prefix}_alias"] = True + + # The start if the function if tensor is None: kwargs[f"{prefix}_base_index"] = None - elif get_base(tensor) is None: - # if the tensor is the base (not view), for simplicity we do not serialize view meta. - kwargs[f"{prefix}_base_index"] = base_index else: + base = get_base(tensor) kwargs[f"{prefix}_base_index"] = base_index - kwargs[f"{prefix}_size"] = tensor.size() - kwargs[f"{prefix}_stride"] = tensor.stride() - kwargs[f"{prefix}_storage_offset"] = tensor.storage_offset() + if base is None: + # no need to add anything else other than _base_index + return + elif is_alias(base, tensor): + use_alias() + elif (slice_info := try_use_slice(base, tensor)) is not None: + use_slice(*slice_info) + else: + use_as_strided(tensor) for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): arg = kwargs[arg_name] @@ -129,18 +237,23 @@ def read_single_view(prefix): base_index = get_arg(f"{prefix}_base_index") if base_index is None: return None - elif f"{prefix}_size" not in kwargs: - assert f"{prefix}_stride" not in kwargs - assert f"{prefix}_storage_offset" not in kwargs - - # This means that the argument is the base tensor - return ViewInfo(base_index, all_bases[base_index], is_view=False) - - else: + elif f"{prefix}_alias" in kwargs: + get_arg(f"{prefix}_alias") + return AliasViewInfo(base_index) + elif f"{prefix}_storage_offset" in kwargs: + # The view is regenerated using as_strided. size = get_arg(f"{prefix}_size") stride = get_arg(f"{prefix}_stride") storage_offset = get_arg(f"{prefix}_storage_offset") - return ViewInfo(base_index, size, stride, storage_offset, is_view=True) + return AsStridedViewInfo(base_index, size, stride, storage_offset) + elif f"{prefix}_slice_dim" in kwargs: + dim = get_arg(f"{prefix}_slice_dim") + start = get_arg(f"{prefix}_slice_start") + end = get_arg(f"{prefix}_slice_end") + return SliceViewInfo(base_index, dim, start, end) + else: + # This means that the argument is the base tensor + return NotView(base_index) args_view_info: Dict[str, Any] = {} for arg_name, arg_type in zip(mutable_arg_names, mutable_arg_types): @@ -566,9 +679,11 @@ def auto_functionalized_dense( new_kwargs[name] = ( [clone_preserve_strides(x) for x in kwargs[name]] if kwargs[name] is not None and isinstance(kwargs[name], list) - else clone_preserve_strides(kwargs[name]) - if kwargs[name] is not None - else None + else ( + clone_preserve_strides(kwargs[name]) + if kwargs[name] is not None + else None + ) ) result.append(new_kwargs[name]) out = _mutable_op(**new_kwargs) diff --git a/torch/_higher_order_ops/cond.py b/torch/_higher_order_ops/cond.py index d997044c7f020..a90bcd1bc9e03 100644 --- a/torch/_higher_order_ops/cond.py +++ b/torch/_higher_order_ops/cond.py @@ -1,6 +1,8 @@ +# mypy: allow-untyped-decorators # mypy: allow-untyped-defs import contextlib import logging +from typing import Any, Callable, List, Tuple, Union import torch import torch._subclasses.functional_tensor @@ -21,6 +23,8 @@ _maybe_run_with_interpreter, _set_compilation_env, reenter_make_fx, + save_tensors_and_symints_for_backward, + saved_tensors_and_symints, unique_graph_id, UnsupportedAliasMutationException, ) @@ -60,7 +64,12 @@ def __call__(self, pred, true_fn, false_fn, operands): @exposed_in("torch") -def cond(pred, true_fn, false_fn, operands): +def cond( + pred: Union[bool, int, float, torch.Tensor], + true_fn: Callable, + false_fn: Callable, + operands: Union[Tuple, List] = (), +) -> Any: r""" Conditionally applies `true_fn` or `false_fn`. @@ -93,7 +102,8 @@ def cond(pred, true_branch, false_branch, operands): have consistent input and outputs, meaning the inputs have to be the same, and the outputs have to be the same type and shape. - operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the true/false functions. + operands (Tuple of possibly nested dict/list/tuple of torch.Tensor): A tuple of inputs to the + true/false functions. It can be empty if true_fn/false_fn doesn't require input. Defaults to (). Example:: @@ -154,7 +164,7 @@ def _validate_input(pred, true_fn, false_fn, operands): ) if not callable(true_fn) or not callable(false_fn): - raise RuntimeError("Expect both branches to be callbale.") + raise RuntimeError("Expect both branches to be callable.") if not isinstance(operands, (tuple, list)) or pytree.tree_any( lambda t: not isinstance(t, torch.Tensor), operands @@ -229,10 +239,10 @@ def create_fw_bw_graph_branches(true_fn, false_fn, *operands): def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands): assert isinstance( operands, (list, tuple) - ), "Cond operands must be a list or tuple of tensors" + ), f"Cond operands must be a list or tuple of tensors and SymInts {operands}" assert all( - isinstance(o, torch.Tensor) for o in operands - ), "Cond operands must be a list of tensors" + isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands + ), f"Cond operands must be a list of tensors and SymInts {operands}" true_graph = reenter_make_fx(true_fn)(*operands) false_graph = reenter_make_fx(false_fn)(*operands) @@ -372,14 +382,14 @@ def forward( ctx._pred = pred ctx._joint_true_graph = joint_true_graph ctx._joint_false_graph = joint_false_graph - ctx.save_for_backward(*operands) + save_tensors_and_symints_for_backward(ctx, operands) with torch._C._AutoDispatchBelowAutograd(): return cond_op(pred, fw_true_graph, fw_false_graph, operands) @staticmethod def backward(ctx, *flat_grads): - operands = ctx.saved_tensors + operands = saved_tensors_and_symints(ctx) grads = cond_op( ctx._pred, @@ -442,6 +452,14 @@ def cond_fake_tensor_mode(mode, pred, true_fn, false_fn, operands): raise RuntimeError("Unmatched number of outputs from cond() branches.") for true_out, false_out in zip(flat_true_outs, flat_false_outs): + if true_out is None or false_out is None: + if true_out is None and false_out is None: + continue + raise torch._dynamo.exc.CondOpArgsMismatchError( + f"Expected both branches to return None:" + f"\n {true_fn.__name__} returns {true_out}" + f"\n {false_fn.__name__} returns {false_out}" + ) true_meta = _extract_tensor_metadata(true_out) false_meta = _extract_tensor_metadata(false_out) if true_meta != false_meta: @@ -466,14 +484,17 @@ def cond_func(ctx, pred, true_fn, false_fn, inputs): branch, unwrapped_inputs, pre_dispatch=pre_dispatch ): raise UnsupportedAliasMutationException( - "One of torch.cond branch might be modifying the input!" + "One of torch.cond branch might be modifying the input! " + "Consider cloning the input before modifying it. " ) for branch in [true_fn, false_fn]: if _has_potential_branch_input_alias( branch, unwrapped_inputs, pre_dispatch=pre_dispatch ): raise UnsupportedAliasMutationException( - "One of torch.cond branch might be aliasing the input!" + "One of torch.cond branch might be aliasing the input! " + "If you are returning a view of the input, please make sure " + "to clone it. " ) cond_return = cond_op( diff --git a/torch/_higher_order_ops/flex_attention.py b/torch/_higher_order_ops/flex_attention.py index 94e39a96ca759..56794cc1b93e9 100644 --- a/torch/_higher_order_ops/flex_attention.py +++ b/torch/_higher_order_ops/flex_attention.py @@ -11,7 +11,7 @@ reenter_make_fx, UnsupportedAliasMutationException, ) -from torch._ops import HigherOrderOperator, OpOverload +from torch._ops import HigherOrderOperator from torch._subclasses import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( make_fx, @@ -19,7 +19,6 @@ track_tensor_tree, ) from torch.fx.graph_module import GraphModule -from torch.overrides import TorchFunctionMode # Duplicate of _inductor/kernel/flex_attention.py to avoid circular import @@ -58,10 +57,9 @@ def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch torch.Tensor: A new tensor with same shape and data as the input, but with strides permuted based on the query tensor's stride order. """ - from torch._inductor.ir import get_stride_order, stride_order2fill_order + from torch._inductor.ir import get_fill_order - stride_order = get_stride_order(query_strides) - fill_order = stride_order2fill_order(stride_order) + fill_order = get_fill_order(query_strides) assert out.storage_offset() == 0, "Only support storage_offset == 0" out_strides = _construct_strides(out.shape, fill_order) new_out = out.new_empty(out.shape).as_strided(out.shape, out_strides) @@ -69,27 +67,6 @@ def _permute_strides(out: torch.Tensor, query_strides: Tuple[int, ...]) -> torch return new_out -class TransformGetItemToIndex(TorchFunctionMode): - # This is needed since we want to support calling - # A[q_idx], where q_idx is a scalar tensor in score_mod. - # Today, when q_idx is a scalar tensor, we implicitly convert it to a python - # scalar and create a view. We do not want that behavior in this case, so we - # use this torchfunctionmode to override that behavior for score_mod - # wherever we're running it. - def __torch_function__( - self, - func: OpOverload, - types: Tuple[torch._C._TensorMeta, ...], - args: Tuple[object, ...] = (), - kwargs: Optional[Dict[str, object]] = None, - ) -> object: - if func == torch.Tensor.__getitem__: - index_args = pytree.tree_leaves(args[1]) - if all(isinstance(x, torch.Tensor) for x in index_args): - return torch.ops.aten.index(args[0], index_args) - return func(*args, **(kwargs or {})) - - class FlexAttentionHOP(HigherOrderOperator): def __init__(self) -> None: super().__init__("flex_attention", cacheable=True) @@ -147,7 +124,9 @@ def __call__( kernel_options: Dict[str, Any], score_mod_other_buffers: Tuple = (), mask_mod_other_buffers: Tuple = (), - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] + ]: if not all( isinstance(buf, torch.Tensor) for buf in score_mod_other_buffers + mask_mod_other_buffers @@ -185,6 +164,8 @@ def _math_attention_inner( score_mod_other_buffers: Tuple = (), mask_mod_other_buffers: Tuple = (), ) -> Tuple[torch.Tensor, torch.Tensor]: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + working_precision = torch.float64 if query.dtype == torch.float64 else torch.float32 scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision) @@ -318,6 +299,8 @@ def trace_flex_attention( This will produce a GraphModule that will be stored on the root tracer as "sdpa_score". We access this graph module in inductor to inline the score_mod function to the triton template. """ + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + example_out = flex_attention( query, key, @@ -414,6 +397,8 @@ def flex_attention_functionalize( guard against any mutations in the score_mod function, to the other_buffers since those are free variables. """ + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + query_unwrapped = ctx.unwrap_tensors(query) key_unwrapped = ctx.unwrap_tensors(key) value_unwrapped = ctx.unwrap_tensors(value) @@ -594,16 +579,15 @@ def forward( block_mask: Tuple[Any, ...], scale: float, kernel_options: Dict[str, Any], - score_mod_other_buffers: Tuple[Any, ...], mask_mod_other_buffers: Tuple[Any, ...], + *score_mod_other_buffers: Tuple[Any, ...], ) -> Tuple[torch.Tensor, torch.Tensor]: any_buffer_requires_grad = any( - buffer.requires_grad - for buffer in score_mod_other_buffers + mask_mod_other_buffers + buffer.requires_grad for buffer in mask_mod_other_buffers ) assert ( not any_buffer_requires_grad - ), "Captured buffers that require grad are not yet supported." + ), "Captured buffers from mask mod that require grad are not yet supported." ctx._fw_graph = fw_graph ctx._joint_graph = joint_graph ctx._mask_graph = block_mask[-1] @@ -670,9 +654,15 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option mask_mod_other_buffers = tuple( other_buffers[ctx._score_mod_other_buffers_len :] ) - # We have asserted that other_buffers do not require grad in the forward - none_grads = [None] * 7 - grad_query, grad_key, grad_value = flex_attention_backward( + # We have asserted that mask_mod_other_buffers do not require grad, + # but score_mod_other_buffers can require grad. + none_grads = [None] * 6 + ( + grad_query, + grad_key, + grad_value, + grad_score_mod_captured, + ) = flex_attention_backward( query, key, value, @@ -700,7 +690,7 @@ def backward(ctx: Any, grad_out: Tensor, grad_logsumexp: Tensor) -> Tuple[Option score_mod_other_buffers, mask_mod_other_buffers, ) - return grad_query, grad_key, grad_value, *none_grads + return grad_query, grad_key, grad_value, *none_grads, *grad_score_mod_captured @flex_attention.py_impl(DispatchKey.Autograd) @@ -715,6 +705,8 @@ def flex_attention_autograd( score_mod_other_buffers: Tuple[Tensor, ...] = (), mask_mod_other_buffers: Tuple[Tensor, ...] = (), ) -> Tuple[torch.Tensor, torch.Tensor]: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + with TransformGetItemToIndex(): input_requires_grad = any(t.requires_grad for t in (query, key, value)) if torch.is_grad_enabled() and input_requires_grad: @@ -739,8 +731,8 @@ def flex_attention_autograd( block_mask, scale, kernel_options, - score_mod_other_buffers, mask_mod_other_buffers, + *score_mod_other_buffers, ) return out, logsumexp @@ -764,11 +756,19 @@ def sdpa_dense_backward( kernel_options: Dict[str, Any], score_mod_other_buffers: Tuple, mask_mod_other_buffers: Tuple, -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] +]: + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + # Get outputs before calling repeat interleave actual_grad_query = torch.empty_like(query) actual_grad_key = torch.empty_like(key) actual_grad_value = torch.empty_like(value) + actual_grad_score_mod_captured = [ + torch.empty_like(buffer) if buffer.requires_grad else None + for buffer in score_mod_other_buffers + ] Bq, Bkv = query.size(0), key.size(0) if not ((Bq == Bkv) or (Bq > 1 and Bkv == 1)): @@ -829,7 +829,7 @@ def sdpa_dense_backward( out_dims=out_dims, ) with TransformGetItemToIndex(): - grad_scores, *_ = joint_score_mod( + grad_scores, _, _, _, _, *grad_score_mod_captured = joint_score_mod( scores, b, h, m, n, grad_score_mod, *score_mod_other_buffers ) grad_scores = grad_scores * scale @@ -870,8 +870,19 @@ def sdpa_dense_backward( actual_grad_query.copy_(grad_query) actual_grad_key.copy_(grad_key) actual_grad_value.copy_(grad_value) + score_mod_other_buffer_grads = [ + actual_grad.copy_(grad) if actual_grad is not None else actual_grad + for actual_grad, grad in zip( + actual_grad_score_mod_captured, grad_score_mod_captured + ) + ] - return actual_grad_query, actual_grad_key, actual_grad_value + return ( + actual_grad_query, + actual_grad_key, + actual_grad_value, + tuple(score_mod_other_buffer_grads), + ) def trace_flex_attention_backward( @@ -890,8 +901,12 @@ def trace_flex_attention_backward( kernel_options: Dict[str, Any], score_mod_other_buffers: Tuple = (), mask_mod_other_buffers: Tuple = (), -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] +]: """We already have the forward graph and joint graph from the forward pass, so we create a proxy attach both graphs""" + from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex + example_out = flex_attention_backward( query, key, @@ -974,7 +989,9 @@ def flex_attention_backward_proxy_torch_dispatch_mode( kernel_options: Dict[str, Any], score_mod_other_buffers: Tuple = (), mask_mod_other_buffers: Tuple = (), -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] +]: assert mode is not None, "Mode should always be enabled for python fallback key" return trace_flex_attention_backward( mode, @@ -1012,7 +1029,9 @@ def flex_attention_backward_functionalize( kernel_options: Dict[str, Any], score_mod_other_buffers: Tuple = (), mask_mod_other_buffers: Tuple = (), -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] +]: """Defines the functionalization rules for the flex_attention operator. Write now we are unwrapping each tensor and then redispatching to the next, @@ -1050,7 +1069,12 @@ def flex_attention_backward_functionalize( functional_fw_graph = ctx.functionalize(fw_graph) functional_joint_graph = ctx.functionalize(joint_graph) - grad_query, grad_key, grad_value = flex_attention_backward( + ( + grad_query, + grad_key, + grad_value, + grad_score_mod_captured, + ) = flex_attention_backward( query_unwrapped, key_unwrapped, value_unwrapped, @@ -1067,7 +1091,7 @@ def flex_attention_backward_functionalize( mask_mod_other_buffers_unwrapped, ) - return ctx.wrap_tensors((grad_query, grad_key, grad_value)) # type: ignore[return-value,arg-type] + return ctx.wrap_tensors((grad_query, grad_key, grad_value, grad_score_mod_captured)) # type: ignore[return-value,arg-type] @flex_attention_backward.py_impl(FakeTensorMode) @@ -1087,12 +1111,20 @@ def flex_attention_backward_fake_tensor_mode( kernel_options: Dict[str, Any], score_mod_other_buffers: Tuple = (), mask_mod_other_buffers: Tuple = (), -) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +) -> Tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, Tuple[Optional[torch.Tensor], ...] +]: with mode: grad_query = torch.empty_like(query) grad_key = torch.empty_like(key) grad_value = torch.empty_like(value) - return grad_query, grad_key, grad_value + grad_score_mod_captured = tuple( + [ + torch.empty_like(buffer) if buffer.requires_grad else None + for buffer in score_mod_other_buffers + ] + ) + return grad_query, grad_key, grad_value, grad_score_mod_captured flex_attention_backward.py_impl(DispatchKey.Autograd)( diff --git a/torch/_higher_order_ops/invoke_subgraph.py b/torch/_higher_order_ops/invoke_subgraph.py new file mode 100644 index 0000000000000..f9b82d701b198 --- /dev/null +++ b/torch/_higher_order_ops/invoke_subgraph.py @@ -0,0 +1,266 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + + +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils._pytree as pytree +from torch._C import DispatchKey +from torch._dispatch.python import suspend_functionalization +from torch._higher_order_ops.utils import ( + _from_fun, + _maybe_reenter_make_fx, + clone_outputs_aliasing_inputs, + get_dummy_aot_autograd_config, + prepare_fw_with_masks, + reenter_make_fx, +) +from torch._ops import HigherOrderOperator +from torch._subclasses import FakeTensorMode +from torch._subclasses.functional_tensor import disable_functional_mode +from torch.fx.experimental.proxy_tensor import ( + disable_proxy_modes_tracing, + ProxyTorchDispatchMode, + track_tensor_tree, +) +from torch.fx.graph_module import GraphModule + + +invoke_subgraph_counter = 0 + + +class InvokeSubgraphHOP(HigherOrderOperator): + def __init__(self) -> None: + super().__init__("invoke_subgraph") + + # identifier is setup by upper part of the stack. This helps us in + # identifying two invoke_subgraph calls have same subgraph. + def __call__( + self, + subgraph: GraphModule, + identifier: Optional[str], + operands: Union[ + List[Union[torch.Tensor, torch.SymInt]], + Tuple[Union[torch.Tensor, torch.SymInt]], + ], + ): + assert identifier is None or isinstance( + identifier, str + ), "identifier must be a None or a string" + + assert isinstance( + operands, (list, tuple) + ), f"invoke_subgraph operands must be a list or tuple of tensors and SymInts {operands}" + assert all( + isinstance(o, (torch.Tensor, torch.SymInt)) for o in operands + ), f"invoke_subgraph operands must be a list of tensors and SymInts {operands}" + + return super().__call__(subgraph, identifier, operands) + + +invoke_subgraph = InvokeSubgraphHOP() + + +def get_invoke_subgraph_cache(): + cache = None + if tracing_ctx := torch._guards.TracingContext.try_get(): + cache = tracing_ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph) + return cache + + +def trace_joint_graph(fn, fw_inputs, fw_outputs): + """ + Naively trace out a joint graph. This simplifies the reconstruction of joint + graph in the min-cut partitioner later on. + """ + from torch._functorch.aot_autograd import create_joint + + dummy_aot_config = get_dummy_aot_autograd_config() + + def joint_fn(*primals_and_tangents): + primals = primals_and_tangents[: len(fw_inputs)] + tangents = primals_and_tangents[len(fw_inputs) :] + + fw_outs, grads = create_joint( + prepare_fw_with_masks(fn), aot_config=dummy_aot_config + )(primals, tangents) + + maybe_clone = clone_outputs_aliasing_inputs(primals_and_tangents) + + return pytree.tree_map(maybe_clone, list(fw_outs) + grads) + + primals = list(fw_inputs) + # This assumes that the tangent strides match fw_outputs strides. Check the + # InvokeSubgraphAutogradOp backward op for the contiguous call. + tangents = [_from_fun(out) for out in fw_outputs] + + joint_operands = primals + tangents + + return _maybe_reenter_make_fx(joint_fn)(*joint_operands) + + +def create_fw_bw_graph(subgraph, operands): + with suspend_functionalization(), disable_functional_mode(): + with disable_proxy_modes_tracing(): + # args are functional tensors, generate some example tensors + fw_inputs = pytree.tree_map(_from_fun, operands) + + fw_outputs = pytree.tree_map(_from_fun, subgraph(*fw_inputs)) + if any( + not isinstance(out, torch.Tensor) + for out in fw_outputs + if out is not None + ): + raise RuntimeError( + "Expect outputs of invoke_subgraph to only contains tensors or None. " + f"Got types {[type(out) for out in fw_outputs]}." + ) + + # Trace the forward subgraph + fw_graph = _maybe_reenter_make_fx(subgraph)(*fw_inputs) + + # Trace the joint graph and assign it to the bwd graph + bw_graph = trace_joint_graph( + subgraph, + fw_inputs, + fw_outputs, + ) + return fw_graph, bw_graph, len(fw_outputs) + + +class InvokeSubgraphAutogradOp(torch.autograd.Function): + """ + This autograd function op is to stash the backward graph in the ctx while + running forward. + """ + + @staticmethod + def forward(ctx, fw_graph, bw_graph, identifier, num_fw_outs, *operands): + ctx._fw_graph = fw_graph + ctx._bw_graph = bw_graph + ctx._identifier = identifier + ctx._num_fw_outs = num_fw_outs + + with torch._C._AutoDispatchBelowAutograd(): + out = invoke_subgraph( + fw_graph, + f"___forward_{identifier}", + operands, + ) + + ctx.save_for_backward(*operands) + return out + + @staticmethod + def backward(ctx, *grad_outs): + bw_graph = ctx._bw_graph + identifier = ctx._identifier + primals = ctx.saved_tensors + num_fw_outs = ctx._num_fw_outs + + # While tracing we made the assumption that tangents are contiguous. So, + # force the grad_outs to be contiguous. + contiguous_grad_outs = tuple([o.contiguous() for o in grad_outs]) + + # bw_graph is a joint graph with signature (*primals_and_tangents) and + # returns (*fw_outs_and_grads). To get the grads, we use the num_fw_outs + # to extract the grads. + primals_and_tangents = primals + contiguous_grad_outs + grads = invoke_subgraph( + bw_graph, f"___backward_{identifier}", primals_and_tangents + )[num_fw_outs:] + return None, None, None, None, *grads + + +@invoke_subgraph.py_impl(DispatchKey.CompositeExplicitAutograd) +def _(subgraph, identifier, operands): + from torch.utils._python_dispatch import _get_current_dispatch_mode + + mode = _get_current_dispatch_mode() + assert mode is None, "Mode should never be enabled for CPU/CUDA key" + return subgraph(*operands) + + +@invoke_subgraph.py_impl(DispatchKey.Autograd) +def _(subgraph, identifier, operands): + if not torch.is_grad_enabled(): + with torch._C._AutoDispatchBelowAutograd(): + return invoke_subgraph(subgraph, identifier, operands) + + # A shortcut for the case where all inputs don't require gradient, + # we skip tracing the forward and backward graph. + if pytree.tree_all_only( + torch.Tensor, + lambda t: not t.requires_grad, # type: ignore[union-attr] + operands, + ): + with torch._C._AutoDispatchBelowAutograd(): + return invoke_subgraph(subgraph, identifier, operands) + + # Check if we have already traced the subgraph. + invoke_subgraph_cache = get_invoke_subgraph_cache() + if invoke_subgraph_cache: + if saved_autograd_fn := invoke_subgraph_cache.get_autograd_key_entry( + identifier + ): + return saved_autograd_fn(*operands) + + fw_graph, bw_graph, num_fw_outs = create_fw_bw_graph(subgraph, operands) + + def autograd_fn_callable(*args): + return InvokeSubgraphAutogradOp.apply( + fw_graph, bw_graph, identifier, num_fw_outs, *args + ) + + # Save the autograd_fn_callable in the dispatch set cache. + if invoke_subgraph_cache: + invoke_subgraph_cache.add_autograd_key_entry(identifier, autograd_fn_callable) + + return autograd_fn_callable(*operands) + + +@invoke_subgraph.py_functionalize_impl +def _(ctx, subgraph, identifier, operands): + unwrapped_operands = ctx.unwrap_tensors(operands) + with ctx.redispatch_to_next() as m: + # NB: There is an assumption that subgraph does not mutate inputs and + # there is no aliasing. Its Dynamo responsibility to prevent formation + # of invoke_subgraph ops if input aliasing/mutation is detected. + functionalized_subgraph = ctx.functionalize(subgraph) + out = invoke_subgraph(functionalized_subgraph, identifier, unwrapped_operands) + return ctx.wrap_tensors(out) + + +@invoke_subgraph.py_impl(FakeTensorMode) +def _(mode, subgraph, identifier, operands): + # TODO(anijain2305) - Implement fake tensor caching. + return subgraph(*operands) + + +@invoke_subgraph.py_impl(ProxyTorchDispatchMode) +def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands): + # Check if we have already traced the subgraph. + graph = None + invoke_subgraph_cache = get_invoke_subgraph_cache() + if invoke_subgraph_cache: + graph = invoke_subgraph_cache.get_proxy_dispatch_entry(identifier) + + if graph is None: + graph = reenter_make_fx(subgraph)(*operands) + assert isinstance(proxy_mode.tracer, torch.fx.Tracer) + qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph") + proxy_mode.tracer.root.register_module(qualname, graph) + if invoke_subgraph_cache: + invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph) + + node_args = (graph, identifier, operands) + proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) # type: ignore[union-attr] + out_proxy = proxy_mode.tracer.create_proxy( + "call_function", invoke_subgraph, proxy_args, {} + ) + + example_out = invoke_subgraph(graph, identifier, operands) + return track_tensor_tree( + example_out, out_proxy, constant=None, tracer=proxy_mode.tracer + ) diff --git a/torch/_higher_order_ops/scan.py b/torch/_higher_order_ops/scan.py index d66cff067f668..a5a08fea26a31 100644 --- a/torch/_higher_order_ops/scan.py +++ b/torch/_higher_order_ops/scan.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import itertools -from typing import Callable, List, Tuple +from typing import Any, Callable, List, Tuple import torch import torch._prims_common as utils @@ -41,7 +41,11 @@ def wrap_combine_fn_flat( carry_flat = pytree.tree_leaves(carry) combined_flat = pytree.tree_leaves(combined) assert num_init_leaves == len(carry_flat) - return (carry_flat, combined_flat) + return [*carry_flat, *combined_flat] + + +def _extract_carry_and_out(flat_out: List[Any], num_carry: int): + return flat_out[:num_carry], flat_out[num_carry:] def scan( @@ -50,7 +54,6 @@ def scan( ], init: pytree.PyTree, xs: pytree.PyTree, - /, *, dim: int = 0, reverse: bool = False, @@ -86,7 +89,7 @@ def scan( final_carry (torch.Tensor or pytree with tensor leaves), the final carry of the scan operation with same pytree structure as init. out (torch.Tensor or pytree with tensor leaves), - each tensor leaf is a stacked output along dim, where each slice is the output of a scan iteration. + each tensor leaf is a stacked output along first dim, where each slice is the output of a scan iteration. Example:: @@ -95,8 +98,8 @@ def add(x: torch.Tensor, y: torch.Tensor): return next_carry, y i0 = torch.zeros(1) - xs = torch.arange(1, 5) - # returns torch.tensor([10]), torch.tensor([1., 3., 6., 10.]) + xs = torch.arange(5) + # returns torch.tensor([10.]), torch.tensor([[0], [1.], [3.], [6.], [10.]]) last_carry, cumsum = scan(add, init=i0, xs=xs) @@ -108,15 +111,85 @@ def add(x: torch.Tensor, y: torch.Tensor): if not isinstance(reverse, bool): raise RuntimeError("Reverse must be a bool, but got " + str(type(reverse))) + leaves_init, spec_init = pytree.tree_flatten(init) + leaves_xs, spec_xs = pytree.tree_flatten(xs) + + if len(leaves_init) == 0: + raise RuntimeError("Init tensors must be provided") + for x in leaves_init: + if not isinstance(x, torch.Tensor): + raise RuntimeError(f"All init leaves must be a Tensor but got {x}") + for x in leaves_xs: + if not isinstance(x, torch.Tensor): + raise RuntimeError(f"All xs leaves must be a Tensor but got {x}") + if x.shape[dim] == 0: + raise RuntimeError( + f"All xs leaves must have a scan dimension > 0 but got {x}" + ) + + if len(leaves_xs) == 0: + return pytree.tree_unflatten(leaves_init, spec_init), xs + + shape = leaves_xs[0].shape + ndim = len(shape) + dim = utils.canonicalize_dim(ndim, dim) + + out = combine_fn( + pytree.tree_unflatten(leaves_init, spec_init), + pytree.tree_unflatten([elem.select(dim, 0) for elem in leaves_xs], spec_xs), + ) + + # The first output needs to have the same pytree as init + carry_leaves = pytree.tree_leaves(out[0]) + if len(carry_leaves) != len(leaves_init): + raise RuntimeError( + f"The number of leaves of the pytree of the new carry produced by the operator is {len(carry_leaves)}\ +doesn't match the length of the pytree of the init {len(leaves_init)}" + ) + + def _check_new_carry_match_init(leaves_init, carry_leaves): + for i, (init, new_carry) in enumerate(zip(leaves_init, carry_leaves)): + if init.shape != new_carry.shape: + raise RuntimeError( + f"The shape of the new_carry[{i}] {new_carry.shape} doesn't match that of the init[{i}] {init.shape}." + ) + if init.stride() != new_carry.stride(): + raise RuntimeError( + f"The stride of the new_carry[{i}] {new_carry.stride()} doesn't match that of the init[{i}] {init.stride()}." + ) + if init.dtype != new_carry.dtype: + raise RuntimeError( + f"The dtype of the new_carry[{i}] {new_carry.dtype} doesn't match that of the init[{i}] {init.dtype}." + ) + if init.requires_grad != new_carry.requires_grad: + raise RuntimeError( + f"The requires_grad of the new_carry[{i}] {new_carry.requires_grad} doesn't match that of the init[{i}] {init.requires_grad}." # noqa: B950 + ) + + _check_new_carry_match_init(leaves_init, carry_leaves) + + # There are no pytree restrictions on the second output of the operator + out_leaves, tree_out = pytree.tree_flatten(out[1]) + # TODO: Support closures/nn_modules in order to be able represent RNNs with scan # TODO: Support _inductor lowering # TODO: Support Autograd # TODO: Unify handling of pytrees for control flow ops, such as cond, while_loop, etc. + # TODO: Unify the list inputs of control flow ops to tuple. + + combine_fn = functools.partial( + wrap_combine_fn_flat, + combine_fn=combine_fn, + spec_init=spec_init, + spec_xs=spec_xs, + num_init_leaves=len(leaves_init), + num_inp_leaves=len(leaves_xs), + ) - # Dynamo is expecting a callable with "__code__" attribute. - # We cannot directly pass cond_op to it. So we wrap it in a dummy function. - def _scan_op_wrapper(*args, **kwargs): - return scan(*args, **kwargs) + def run_flattened_scan(combine_fn, leaves_init, leaves_xs, dim, reverse): + return scan_op( + combine_fn, leaves_init, leaves_xs, dim, reverse, additional_inputs=[] + ) if not torch._dynamo.is_compiling(): from torch._dynamo.backends.debugging import ( @@ -129,84 +202,43 @@ def _scan_op_wrapper(*args, **kwargs): backend = make_eager_backend_with_torch_function_mode(metadata_mode) else: backend = "eager" - return torch.compile(_scan_op_wrapper, backend=backend, fullgraph=True)( - combine_fn, init, xs, dim=dim, reverse=reverse + result = torch.compile( + run_flattened_scan, backend=backend, fullgraph=True + )( + combine_fn, + leaves_init, + leaves_xs, + dim=dim, + reverse=reverse, ) + else: + result = run_flattened_scan(combine_fn, leaves_init, leaves_xs, dim, reverse) - leaves_init, spec_init = pytree.tree_flatten(init) - leaves_xs, spec_xs = pytree.tree_flatten(xs) - - if len(leaves_init) == 0: - raise RuntimeError("Init tensors must be provided") - if any(not isinstance(x, torch.Tensor) for x in leaves_init): - raise RuntimeError("All init leaves must be a Tensor") - if any(not isinstance(x, torch.Tensor) for x in leaves_xs): - raise RuntimeError("All xs leaves must be a Tensor") - if any(x.shape[dim] == 0 for x in leaves_xs): - raise RuntimeError("All xs leaves must have a scan dimension > 0") - - if len(leaves_xs) > 0: - shape = leaves_xs[0].shape - ndim = len(shape) - dim = utils.canonicalize_dim(ndim, dim) - - out = combine_fn( - pytree.tree_unflatten(leaves_init, spec_init), - pytree.tree_unflatten( - [aten.slice(elem, dim, 0, 1, 1) for elem in leaves_xs], spec_xs - ), - ) - - # The first output needs to have the same pytree as init - carry_leaves = pytree.tree_leaves(out[0]) - if len(carry_leaves) != len(leaves_init): - raise RuntimeError( - "The number of leaves of the pytree of the new carry produced by the operator\ - needs to match the length of the pytree of the init" - ) - if any( - in_l.shape != out_l.shape for in_l, out_l in zip(leaves_init, carry_leaves) - ): - raise RuntimeError( - "The pytree of the new carry produced by the operator needs to match the pytree of the init" - ) - - # There are no pytree restrictions on the second output of the operator - out_leaves, tree_out = pytree.tree_flatten(out[1]) - - combine_fn = functools.partial( - wrap_combine_fn_flat, - combine_fn=combine_fn, - spec_init=spec_init, - spec_xs=spec_xs, - num_init_leaves=len(leaves_init), - num_inp_leaves=len(leaves_xs), - ) - - result_carry, result_flat = scan_op( - combine_fn, leaves_init, leaves_xs, dim, reverse - ) - - return pytree.tree_unflatten(result_carry, spec_init), pytree.tree_unflatten( - result_flat, tree_out - ) + result_carry, result_flat = _extract_carry_and_out( + result, + len(leaves_init), + ) - else: - return pytree.tree_unflatten(leaves_init, spec_init), xs + return pytree.tree_unflatten(result_carry, spec_init), pytree.tree_unflatten( + result_flat, tree_out + ) class ScanOp(HigherOrderOperator): def __init__(self): super().__init__("scan") - def __call__(self, combine_fn, init, xs, dim, reverse): - return super().__call__(combine_fn, init, xs, dim, reverse) + def __call__(self, combine_fn, init, xs, dim, reverse, additional_inputs): + assert isinstance(additional_inputs, list), "additional_inputs must be a list." + return super().__call__(combine_fn, init, xs, dim, reverse, additional_inputs) scan_op = ScanOp() -def generic_scan(operator, init, xs, dim=0, reverse=False): +def generic_scan(operator, init, xs, dim=0, reverse=False, additional_inputs=None): + additional_inputs = additional_inputs if additional_inputs is not None else [] + def _scan(init, xs): """Perform scan on `elems` using `elems_init.""" carry = init @@ -220,85 +252,77 @@ def _scan(init, xs): ind = 0 # Compute dummy shapes for the pre-allocation - dummy_carry, dummy_out = operator( - *carry, *[aten.slice(elem, dim, 0, 1, 1) for elem in xs] + num_init_leaves = len(init) + dummy_carry, dummy_out = _extract_carry_and_out( + operator( + *carry, + *[first_slice_copy(elem, dim) for elem in xs], + *additional_inputs, + ), + num_init_leaves, ) - output_scanned_dim = dummy_out[0].shape[dim] # Pre-alocate # outs -> Output matrix # idxs -> Index matrix for scatter_ - outs, outs_idxs = zip( + # out: (num_elems, M, N, ...) + # idx: (1, M, N) + outs, idxs = zip( *[ [ torch.zeros( - list(e.size())[:dim] - + [list(e.size())[dim] * num_elems] - + list(e.size())[dim + 1 :], + [num_elems] + list(e.size()), dtype=e.dtype, device=e.device, ), - torch.cat( - [ - id * t - for id, t in zip( - range(output_scanned_dim), - torch.tensor_split( - torch.ones_like(e, dtype=torch.int64), - output_scanned_dim, - dim=dim, - ), - ) - ], - dim, - ), + torch.ones_like(e, dtype=torch.int64).unsqueeze(0), ] for i, e in enumerate(dummy_out) ] ) - def store_in_mat(mat, out, d, index, index_modifier): + def store_out_in_outs(out, ind): # Store the intermediate out in the outs matrix - for o, x, idx in zip(mat, out, index): - o.scatter_(d, idx + index_modifier, x) - - def cond(i, n, r): - if (r and i < 0) or (not r and i > (n - 1)): - return False - else: - return True - - def op(i): - if reverse: - return i - 1 - else: - return i + 1 - - while cond(ind, num_elems, reverse): - carry, out = operator( - *carry, - *[aten.slice(elem, dim, ind, ind + 1, 1) for elem in xs], + for o, x, idx in zip(outs, out, idxs): + # o: (num_elems, M, N ...) + # x: (M, N, ...) -> (1, M, N) + # ind * idx: (1, M, N,) with values to be ind + # essentially: o[ind][n][k] = x[0][n][k] + o.scatter_(0, ind * idx, x.unsqueeze(0)) + + for i in range(num_elems): + ind = i if not reverse else num_elems - i - 1 + carry, out = _extract_carry_and_out( + operator( + *carry, + *[elem.select(dim, ind) for elem in xs], + *additional_inputs, + ), + num_init_leaves, ) # Store the inits in the outs matrix. - store_in_mat(outs, out, dim, outs_idxs, ind * output_scanned_dim) - - ind = op(ind) + store_out_in_outs(out, ind) - return (carry, list(outs)) + return [*carry, *list(outs)] scans = _scan(init, xs) return scans -def make_expanded_output_shape(dim, scan_length, shapes, use_sh=False): - expanded_shapes = [ - tuple( - (s if use_sh else -1) if i != dim else scan_length for i, s in enumerate(sh) - ) - for sh in shapes - ] - return expanded_shapes +def first_slice_copy(t: torch.Tensor, dim: int) -> torch.Tensor: + return torch.select_copy(t, dim, 0) + + +# We also do a clone with contiguous_format. This is to be consistent with +# eager semantic of scan, which stacks the outputs. The result is contiguous +# as a result of the stack operation. +def stack_y(y: torch.Tensor, scan_length: int) -> torch.Tensor: + return ( + y.unsqueeze(0) + .repeat(*([scan_length] + [1] * y.ndim)) + .clone(memory_format=torch.contiguous_format) + ) def trace_scan( @@ -309,27 +333,15 @@ def trace_scan( xs: List[torch.Tensor], dim: int, reverse: bool, + additional_inputs: List[torch.Tensor], ): with disable_proxy_modes_tracing(): - sample_inits = [ - torch.empty_like( - x_init, - dtype=x_init.dtype, - device=x_init.device, - requires_grad=x_init.requires_grad, - ) - for x_init in init - ] - sample_xs = [ - torch.empty_like( - aten.slice(x, dim, 0, 1, 1), - dtype=x.dtype, - device=x.device, - requires_grad=x.requires_grad, - ) - for x in xs - ] - combine_graph = reenter_make_fx(combine_fn)(*sample_inits, *sample_xs) + sample_inits = [x_init.clone() for x_init in init] + sample_inputs = [first_slice_copy(x, dim) for x in xs] + sample_additional_inputs = [x.clone() for x in additional_inputs] + combine_graph = reenter_make_fx(combine_fn)( + *sample_inits, *sample_inputs, *sample_additional_inputs + ) outputs = None for node in combine_graph.graph.nodes: @@ -339,16 +351,13 @@ def trace_scan( outputs = node.args[0] assert outputs is not None - if len(outputs) != 2: - raise RuntimeError( - f"Expected to return 2 outputs: carry, out_matrix, but got:" - f"\n {len(outputs)} elements" - ) - for ini, carry in zip(init, outputs[0]): + carry, output = _extract_carry_and_out(outputs, len(init)) + + for ini, ca in zip(init, carry): ini_meta = ini - carry_meta = carry.meta["tensor_meta"] - carry_val = carry.meta["val"] + carry_meta = ca.meta["tensor_meta"] + carry_val = ca.meta["val"] if ( carry_val.device != ini_meta.device or carry_meta.dtype != ini_meta.dtype @@ -363,7 +372,7 @@ def trace_scan( proxy_mode.tracer.root.register_module(combine_graph_name, combine_graph) - args = (combine_graph, init, xs, dim, reverse) + args = (combine_graph, init, xs, dim, reverse, additional_inputs) proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, args) out_proxy = proxy_mode.tracer.create_proxy( "call_function", func_overload, proxy_args, {}, name="scan" @@ -371,29 +380,22 @@ def trace_scan( with disable_proxy_modes_tracing(): scan_length = xs[0].shape[dim] - fake_out_shapes = make_expanded_output_shape( - dim, scan_length, [o.meta["val"].size() for o in outputs[1]] + fake_carry, fake_outputs = _extract_carry_and_out( + [o.meta["val"] for o in outputs], len(init) + ) + out = ( + *fake_carry, + *(stack_y(t, scan_length) for t in fake_outputs), ) - - def expand_tensor(t, sh): - if isinstance(t, torch.Tensor): - return t.expand(*sh) - return t - - expanded_outs = [ - pytree.tree_map(expand_tensor, t.meta["val"], sh) - for t, sh in zip(outputs[1], fake_out_shapes) - ] - out = (init, expanded_outs) return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer) @scan_op.py_impl(DispatchKey.CompositeExplicitAutograd) -def scan_op_dense(combine_fn, init, xs, dim, reverse): +def scan_op_dense(combine_fn, init, xs, dim, reverse, additional_inputs): mode = _get_current_dispatch_mode() assert mode is None, "Mode should never be enabled for CPU/CUDA key" - return generic_scan(combine_fn, init, xs, dim, reverse) + return generic_scan(combine_fn, init, xs, dim, reverse, additional_inputs) scan_op.py_impl(DispatchKey.Autograd)( @@ -402,47 +404,108 @@ def scan_op_dense(combine_fn, init, xs, dim, reverse): @scan_op.py_impl(ProxyTorchDispatchMode) -def scan_proxy_mode(mode, combine_fn, init, xs, dim, reverse): - return trace_scan(mode, scan_op, combine_fn, init, xs, dim, reverse) +def scan_proxy_mode(mode, combine_fn, init, xs, dim, reverse, additional_inputs): + return trace_scan( + mode, scan_op, combine_fn, init, xs, dim, reverse, additional_inputs + ) @scan_op.py_impl(FakeTensorMode) -def scan_fake_tensor_mode(mode, combine_fn, init, xs, dim, reverse): +def scan_fake_tensor_mode(mode, combine_fn, init, xs, dim, reverse, additional_inputs): with mode: - dim_len = xs[0].shape[dim] - carry, outputs = combine_fn( - *init, *[aten.slice(inp, dim, 0, 1, 1) for inp in xs] + scan_length = xs[0].shape[dim] + carry, outputs = _extract_carry_and_out( + combine_fn( + *init, + *[first_slice_copy(inp, dim) for inp in xs], + *additional_inputs, + ), + len(init), ) - fake_out_shapes = [ - tuple(-1 if i != dim else dim_len for i, sh in enumerate(o.size())) - for o in outputs - ] out = ( - carry, - tuple(t.expand(*sh).clone() for t, sh in zip(outputs, fake_out_shapes)), + *carry, + *(stack_y(t, scan_length) for t in outputs), ) return out @scan_op.py_functionalize_impl -def scan_functionalize(ctx, combine_fn, init, xs, dim, reverse): +def scan_functionalize(ctx, combine_fn, init, xs, dim, reverse, additional_inputs): unwrapped_xs = ctx.unwrap_tensors(xs) unwrapped_init = ctx.unwrap_tensors(init) + unwrapped_additional_inputs = ctx.unwrap_tensors(additional_inputs) with ctx.redispatch_to_next() as m: functional_combine_fn = ctx.functionalize(combine_fn) pre_dispatch = hasattr(ctx, "mode") and ctx.mode.pre_dispatch - sample_xs = list(itertools.chain(unwrapped_init, unwrapped_init)) + sample_unwrapped_xs_sliced = [ + first_slice_copy(inp, dim) for inp in unwrapped_xs + ] + sample_inputs = list( + itertools.chain( + unwrapped_init, + sample_unwrapped_xs_sliced, + unwrapped_additional_inputs, + ) + ) if _has_potential_branch_input_mutation( - functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch + functional_combine_fn, sample_inputs, pre_dispatch=pre_dispatch ): raise UnsupportedAliasMutationException( "Combine_fn might be modifying the input!" ) if _has_potential_branch_input_alias( - functional_combine_fn, sample_xs, pre_dispatch=pre_dispatch + functional_combine_fn, sample_inputs, pre_dispatch=pre_dispatch ): raise UnsupportedAliasMutationException( "Combine_fn might be aliasing the input!" ) - ret = scan_op(functional_combine_fn, unwrapped_init, unwrapped_xs, dim, reverse) + ret = scan_op( + functional_combine_fn, + unwrapped_init, + unwrapped_xs, + dim, + reverse, + unwrapped_additional_inputs, + ) return ctx.wrap_tensors(ret) + + +# dense implementation for scan. Used for testing only. +def _fake_scan(combine_fn, init, xs=None, dim=0, reverse=False): + carry_leaves, carry_spec = pytree.tree_flatten(init) + inp_leaves, inp_spec = pytree.tree_flatten(xs) + if xs is None or len(inp_leaves) == 0: + return init, [] + result_flat = [] + carry = carry_leaves + op = reversed if reverse else lambda x: x + + dummy_carry, dummy_out = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten( + [first_slice_copy(elem, dim) for elem in inp_leaves], + inp_spec, + ), + ) + dummy_out_leaves, dummy_out_spec = pytree.tree_flatten(dummy_out) + num_leaves = len(dummy_out_leaves) + + for ind in op(range(inp_leaves[0].size(dim))): + xs = [elem.select(dim, ind) for elem in inp_leaves] + + carry, y = combine_fn( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(xs, inp_spec), + ) + carry, _ = pytree.tree_flatten(carry) + y, _ = pytree.tree_flatten(y) + result_flat.append(y) + + results = [ + torch.stack([e[leave_ind] for e in op(result_flat)]) + for leave_ind in range(num_leaves) + ] + return ( + pytree.tree_unflatten(carry, carry_spec), + pytree.tree_unflatten(results, dummy_out_spec), + ) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 9e020760329c9..5780a2fb638c5 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import collections import copy import dataclasses @@ -6,11 +5,24 @@ import logging import threading from collections import defaultdict -from typing import Any, Dict, List, Optional, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) +from typing_extensions import Never + +import sympy import torch.fx as fx import torch.utils._pytree as pytree -from torch import Tensor +from torch import SymInt, Tensor from torch._C import DispatchKey from torch._ops import HigherOrderOperator from torch._prims_common import clone_preserve_strides @@ -23,8 +35,55 @@ from torch.fx.experimental.symbolic_shapes import guard_scalar +if TYPE_CHECKING: + from triton._C.libtriton.ir import ( + module as TritonIRModule, + operation as TritonIROperation, + ) + + from torch._dynamo.symbolic_convert import InstructionTranslator + from torch._dynamo.variables.constant import ConstantVariable + from torch._dynamo.variables.functions import TritonKernelVariable + from torch._subclasses.functional_tensor import BaseFunctionalizeAPI + from torch.fx.proxy import Proxy + from torch.utils._triton import has_triton + + TritonMetaParamsType = Dict[str, int] + TritonGridTupleType = Tuple[Union[int, sympy.Expr, SymInt], ...] + TritonGridCallableType = Callable[[TritonMetaParamsType], Tuple[int, ...]] + TritonGridType = Union[TritonGridTupleType, TritonGridCallableType] + + if has_triton(): + from triton.runtime.autotuner import Autotuner + from triton.runtime.jit import JITFunction + else: + + class Autotuner: # type: ignore[no-redef] + pass + + class JITFunction: # type: ignore[no-redef] + pass + + TritonKernelType = Union[Autotuner, JITFunction] + + log = logging.getLogger("torch._dynamo") +# TMADescriptorMetadata maps kernel parameter names to the metadata that allows +# reconstructing TMA descriptors from the underlying tensors (passed as kernel +# arguments in the fx graph, instead of the TMA descriptors). Namely: a tuple +# conisting of list of dims, list of block dims, and element size. E.g., for this +# call in host-side Triton TMA API ``create_2d_tma_descriptor(ptr, 50, 60, 32, 15, 4)``, +# the metadata will look like ``([50, 60], [32, 15], 4)``. All ints can be SymInts. +TMADescriptorMetadata = Dict[ + str, # kernel parameter name + Tuple[ + List[Union[int, SymInt]], # dims + List[Union[int, SymInt]], # block_dims + Union[int, SymInt], # element_size + ], +] + ############################################################################### # Kernel Side Table @@ -35,13 +94,13 @@ # Use a side table. # We use two dicts so that fetching both the kernel and id are O(1) class KernelSideTable: - id_to_kernel: Dict[int, Any] = {} - kernel_to_id: Dict[Any, int] = {} - constant_args: Dict[int, Any] = {} + id_to_kernel: Dict[int, "TritonKernelType"] = {} + kernel_to_id: Dict["TritonKernelType", int] = {} + constant_args: Dict[int, Dict[str, Any]] = {} lock = threading.Lock() # Returns index on the table - def add_kernel(self, kernel) -> int: + def add_kernel(self, kernel: "TritonKernelType") -> int: with self.lock: if kernel in self.kernel_to_id: return self.kernel_to_id[kernel] @@ -52,21 +111,21 @@ def add_kernel(self, kernel) -> int: return idx # Returns the triton kernel at the given index - def get_kernel(self, idx: int): + def get_kernel(self, idx: int) -> "TritonKernelType": # No need to lock here as fetching from dict is atomic assert idx in self.id_to_kernel return self.id_to_kernel[idx] # Not every constant arg can be added to the graph. Use this side table # for constant args. - def add_constant_args(self, args) -> int: + def add_constant_args(self, args: Dict[str, Any]) -> int: with self.lock: idx = len(self.constant_args) self.constant_args[idx] = args return idx # Returns the constant args - def get_constant_args(self, idx: int): + def get_constant_args(self, idx: int) -> Dict[str, Any]: # No need to lock here as fetching from dict is atomic assert idx in self.constant_args return self.constant_args[idx] @@ -95,7 +154,7 @@ class Param: class Intermediate: idx: int - def fake(self): + def fake(self) -> bool: return self.idx < 0 @@ -106,14 +165,16 @@ class Op: args: List[Union[Param, Intermediate]] ret: Intermediate = dataclasses.field(repr=False) - def __post_init__(self): + def __post_init__(self) -> None: if self.name == "tt.call": assert self.fn_call_name is not None else: assert self.fn_call_name is None -def generate_ttir(kernel, kwargs): +def generate_ttir( + kernel: "TritonKernelType", kwargs: Dict[str, Any] +) -> Tuple["TritonIRModule", List[str]]: """ Uses Triton's internal code generation to create TTIR """ @@ -155,7 +216,18 @@ def generate_ttir(kernel, kwargs): ordered_tensor_names = [ name for name, arg in ordered_args.items() if isinstance(arg, Tensor) ] - specialization = kernel._get_config(*ordered_args.values()) + + def _get_specialization(args): # type: ignore[no-untyped-def] + try: + from triton.backends.compiler import AttrsDescriptor # noqa: F401 + + target = triton.runtime.driver.active.get_current_target() + backend = triton.compiler.compiler.make_backend(target) + return backend.get_attrs_descriptor(args, kernel.params) + except ImportError: + return kernel._get_config(*args) + + specialization = _get_specialization(ordered_args.values()) constants = { name: arg for name, arg in ordered_args.items() if not isinstance(arg, Tensor) } @@ -194,7 +266,9 @@ def generate_ttir(kernel, kwargs): return ttir_module, ordered_tensor_names -def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]: +def ttir_to_functions( + ttir_module: "TritonIRModule", +) -> Dict[str, Dict[Intermediate, List[Op]]]: """ Walk the `ttir_module` bottom up to mine the `functions` from the structured MLIR entities representing the Triton kernel @@ -212,12 +286,12 @@ def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]: reindex_map: Dict[int, int] = {} next_fake_intermediate = 0 - def reindex(idx): + def reindex(idx: int) -> int: if idx not in reindex_map: reindex_map[idx] = len(reindex_map) return reindex_map[idx] - def mlir_to_functions(op) -> None: + def mlir_to_functions(op: "TritonIROperation") -> None: name: str = op.get_name() if name == "builtin.module": # this wraps all tt.func ops @@ -393,11 +467,19 @@ def mlir_to_functions(op) -> None: class MemoizeWithCycleCheck: - def __init__(self, fn): + fn: Callable[..., Any] + cache: Dict[Tuple[str, int], Any] + + def __init__(self, fn: Callable[..., Any]) -> None: self.fn = fn self.reset() - def __call__(self, functions, fn_name, num_args): + def __call__( + self, + functions: Dict[str, Dict[Intermediate, List[Op]]], + fn_name: str, + num_args: int, + ) -> List[bool]: key = (fn_name, num_args) if key not in self.cache: self.cache[key] = None @@ -406,12 +488,14 @@ def __call__(self, functions, fn_name, num_args): raise RuntimeError("Recursion is not supported") return self.cache[key] - def reset(self): + def reset(self) -> None: self.cache = {} @MemoizeWithCycleCheck -def analyze_kernel_mutations(functions, fn_name, num_args): +def analyze_kernel_mutations( + functions: Dict[str, Dict[Intermediate, List[Op]]], fn_name: str, num_args: int +) -> List[bool]: """ Analyzes the graph to detect all sinks from a predefined list of sinks by using triton's MemWrite trait list. NOTE: What if triton exposed this? @@ -422,7 +506,12 @@ def analyze_kernel_mutations(functions, fn_name, num_args): # List from Triton Github include/triton/Dialect/Triton/IR/TritonOps.td # All the OPs that have MemWrite trait. # What if Triton exposed this? - MUTATION_OPS = {"tt.store": [0], "tt.atomic_cas": [0], "tt.atomic_rmw": [0]} + MUTATION_OPS = { + "tt.store": [0], + "tt.atomic_cas": [0], + "tt.atomic_rmw": [0], + "tt.experimental_descriptor_store": [0], + } # Ops that we want to bail out on UNKNOWN_OPS = {"tt.elementwise_inline_asm"} @@ -468,7 +557,9 @@ def analyze_kernel_mutations(functions, fn_name, num_args): return mutated -def identify_mutated_tensors(kernel, kwargs): +def identify_mutated_tensors( + kernel: "TritonKernelType", kwargs: Dict[str, Any] +) -> List[str]: """ Given a triton kernel and the arguments for this kernel, this function 1) Retrieves the TTIR converted version of the kernel from Triton's API. @@ -524,11 +615,19 @@ class TritonKernelWrapperMutation(HigherOrderOperator): def __init__(self) -> None: super().__init__("triton_kernel_wrapper_mutation", cacheable=False) - def __call__(self, kernel_idx, constant_args_idx, grid, kwargs): + def __call__( + self, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], + ) -> Any: return super().__call__( kernel_idx=kernel_idx, constant_args_idx=constant_args_idx, grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, kwargs=kwargs, ) @@ -541,11 +640,20 @@ class TritonKernelWrapperFunctional(HigherOrderOperator): def __init__(self) -> None: super().__init__("triton_kernel_wrapper_functional", cacheable=False) - def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone): + def __call__( + self, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], + tensors_to_clone: List[str], + ) -> Dict[str, Any]: return super().__call__( kernel_idx=kernel_idx, constant_args_idx=constant_args_idx, grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, kwargs=kwargs, tensors_to_clone=tensors_to_clone, ) @@ -556,8 +664,13 @@ def __call__(self, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone @triton_kernel_wrapper_mutation.py_impl(DispatchKey.CompositeExplicitAutograd) def triton_kernel_wrapper_mutation_dense( - *, kernel_idx, constant_args_idx, grid, kwargs -): + *, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], +) -> None: from torch._inductor.codegen.wrapper import user_defined_kernel_grid_fn_code kernel = kernel_side_table.get_kernel(kernel_idx) @@ -573,27 +686,70 @@ def triton_kernel_wrapper_mutation_dense( exec(code, namespace) grid_fn = namespace[fn_name] + if tma_descriptor_metadata: + from triton.tools.experimental_descriptor import ( # noqa: F401 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + + # as we need to launch the kernel here, we "unwrap" the + # tma_descriptor_metadata, create the TMA descriptors + # from it, and replace the tensors in the kwargs by the + # correspoinding TMA descriptors before launching + kwargs = kwargs.copy() + for k, v in tma_descriptor_metadata.items(): + tensor = kwargs[k] + dims, block_dims, element_size = v + create_tma_descriptor = ( + create_1d_tma_descriptor if len(dims) == 1 else create_2d_tma_descriptor + ) + kwargs[k] = create_tma_descriptor( + tensor.data_ptr(), + *dims, + *block_dims, + element_size, + ) + kernel[grid_fn](**kwargs, **constant_args) @triton_kernel_wrapper_mutation.py_impl(FakeTensorMode) def triton_kernel_wrapper_mutation_fake_tensor_mode( - mode, *, kernel_idx, constant_args_idx, grid, kwargs -): + mode: FakeTensorMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], +) -> None: with mode: return None @triton_kernel_wrapper_mutation.py_impl(DispatchKey.Meta) -def _(*, kernel_idx, constant_args_idx, grid, kwargs): +def _( + *, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], +) -> None: return None -def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args): +def trace_triton_kernel_wrapper( + proxy_mode: ProxyTorchDispatchMode, + func_overload: Callable[..., Any], + node_args: Dict[str, Any], +) -> Optional[Dict[str, Any]]: with disable_proxy_modes_tracing(): out = func_overload(**node_args) - proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) + proxy_args = pytree.tree_map( + proxy_mode.tracer.unwrap_proxy, node_args # type: ignore[union-attr] + ) out_proxy = proxy_mode.tracer.create_proxy( "call_function", func_overload, @@ -607,8 +763,14 @@ def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args): @triton_kernel_wrapper_mutation.py_impl(ProxyTorchDispatchMode) def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( - mode, *, kernel_idx, constant_args_idx, grid, kwargs -): + mode: ProxyTorchDispatchMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], +) -> None: trace_triton_kernel_wrapper( mode, triton_kernel_wrapper_mutation, @@ -616,6 +778,7 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( "kernel_idx": kernel_idx, "constant_args_idx": constant_args_idx, "grid": grid, + "tma_descriptor_metadata": tma_descriptor_metadata, "kwargs": kwargs, }, ) @@ -623,7 +786,9 @@ def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode( return None -def get_mutated_tensors(kernel_idx, constant_args_idx, kwargs): +def get_mutated_tensors( + kernel_idx: int, constant_args_idx: int, kwargs: Dict[str, Any] +) -> List[str]: kernel = kernel_side_table.get_kernel(kernel_idx) constant_args = kernel_side_table.get_constant_args(constant_args_idx) return identify_mutated_tensors(kernel, {**kwargs, **constant_args}) @@ -631,9 +796,14 @@ def get_mutated_tensors(kernel_idx, constant_args_idx, kwargs): @triton_kernel_wrapper_mutation.py_functionalize_impl def triton_kernel_wrapper_mutation_functionalize( - ctx, kernel_idx, constant_args_idx, grid, kwargs -): - unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + ctx: "BaseFunctionalizeAPI", + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], +) -> None: + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] # TODO(oulgen): Preexisting bug, if two kernel inputs are views of each # other, and one gets mutated in kernel, and later another gets mutated, # they are no longer equal. Fix this by graph breaking on this condition @@ -646,6 +816,7 @@ def triton_kernel_wrapper_mutation_functionalize( kernel_idx=kernel_idx, constant_args_idx=constant_args_idx, grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, kwargs=unwrapped_kwargs, tensors_to_clone=tensors_to_clone, ) @@ -667,8 +838,14 @@ def triton_kernel_wrapper_mutation_functionalize( @triton_kernel_wrapper_functional.py_impl(DispatchKey.CompositeExplicitAutograd) def triton_kernel_wrapper_functional_dense( - *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone -): + *, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], + tensors_to_clone: List[str], +) -> Dict[str, Any]: # TODO(oulgen): For performance reasons, we want to ensure that these # `clone_preserve_strides` calls are never executed at runtime # (inductor should always optimize them away). @@ -681,6 +858,7 @@ def triton_kernel_wrapper_functional_dense( kernel_idx=kernel_idx, constant_args_idx=constant_args_idx, grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, kwargs=kwargs, ) return {key: val for key, val in kwargs.items() if key in tensors_to_clone} @@ -688,8 +866,15 @@ def triton_kernel_wrapper_functional_dense( @triton_kernel_wrapper_functional.py_impl(FakeTensorMode) def triton_kernel_wrapper_functional_fake_tensor_mode( - mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone -): + mode: FakeTensorMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], + tensors_to_clone: List[str], +) -> Dict[str, Any]: # TODO(oulgen): For performance reasons, we want to ensure that these # `clone_preserve_strides` calls are never executed at runtime # (inductor should always optimize them away). @@ -704,35 +889,52 @@ def triton_kernel_wrapper_functional_fake_tensor_mode( @triton_kernel_wrapper_functional.py_impl(ProxyTorchDispatchMode) def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode( - mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone -): - return trace_triton_kernel_wrapper( + mode: ProxyTorchDispatchMode, + *, + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], + tensors_to_clone: List[str], +) -> Dict[str, Any]: + ret = trace_triton_kernel_wrapper( mode, triton_kernel_wrapper_functional, { "kernel_idx": kernel_idx, "constant_args_idx": constant_args_idx, "grid": grid, + "tma_descriptor_metadata": tma_descriptor_metadata, "kwargs": kwargs, "tensors_to_clone": tensors_to_clone, }, ) + assert ret is not None + return ret @triton_kernel_wrapper_functional.py_functionalize_impl def triton_kernel_wrapper_functional_functionalize( - ctx, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone -): - unwrapped_kwargs = ctx.unwrap_tensors(kwargs) + ctx: "BaseFunctionalizeAPI", + kernel_idx: int, + constant_args_idx: int, + grid: List["TritonGridType"], + tma_descriptor_metadata: TMADescriptorMetadata, + kwargs: Dict[str, Any], + tensors_to_clone: List[str], +) -> Dict[str, Any]: + unwrapped_kwargs = ctx.unwrap_tensors(kwargs) # type: ignore[arg-type] with ctx.redispatch_to_next(): outputs = triton_kernel_wrapper_functional( kernel_idx=kernel_idx, constant_args_idx=constant_args_idx, grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, kwargs=unwrapped_kwargs, tensors_to_clone=tensors_to_clone, ) - return ctx.wrap_tensors(outputs) + return ctx.wrap_tensors(outputs) # type: ignore[return-value,arg-type] triton_kernel_wrapper_mutation.fallthrough(DispatchKey.PythonDispatcher) # type: ignore[attr-defined] @@ -780,25 +982,44 @@ class TritonHOPifier: TritonHOPifier is an abstract class that can be overriden by its subclasses. """ - def raise_unsupported(self, msg): + def raise_unsupported(self, msg: str) -> Never: raise NotImplementedError("abstract method") - def is_callable(self, maybe_callable): + def is_callable(self, maybe_callable: Any) -> bool: raise NotImplementedError("abstract method") - def get_value(self, val): + def get_value(self, val: Any) -> Any: raise NotImplementedError("abstract method") - def call_grid(self, grid, meta, tx): + def call_grid( # type: ignore[no-untyped-def] + self, + grid, + meta, + tx, + ) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]: raise NotImplementedError("abstract method") - def call_HOP(self, variable, grids, combined_args, tx): + def call_HOP( # type: ignore[no-untyped-def] + self, + variable, + grids, + combined_args: Dict[str, Any], + tx, + ) -> Optional["ConstantVariable"]: raise NotImplementedError("abstract method") - def check_grid(self, grid): + def check_grid( # type: ignore[no-untyped-def] + self, grid + ) -> Union[Tuple[Union[int, sympy.Expr, SymInt], ...], Tuple["Proxy", ...]]: raise NotImplementedError("abstract method") - def init_variable(self, variable, kernel, kernel_idx, grid): + def init_variable( + self, + variable: Union["TraceableTritonKernelWrapper", "TritonKernelVariable"], + kernel: "TritonKernelType", + kernel_idx: Optional[int], + grid: Optional["TritonGridType"], + ) -> None: from triton.runtime.autotuner import Autotuner assert kernel is not None @@ -853,7 +1074,11 @@ def init_variable(self, variable, kernel, kernel_idx, grid): "Only configs and keys are supported for triton.autotune" ) - def call_getitem(self, variable, args): + def call_getitem( + self, + variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"], + args: Sequence[Any], + ) -> Union["TritonKernelVariable", "TraceableTritonKernelWrapper"]: # __getitem__ should only be called if we don't already have a grid # Only grid needs to be passed if variable.grid is not None or len(args) != 1: @@ -867,7 +1092,13 @@ def call_getitem(self, variable, args): grid=args[0], ) - def call_run(self, variable, args, kwargs, tx): + def call_run( + self, + variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"], + args: Sequence[Any], + kwargs: Dict[str, Any], + tx: Optional["InstructionTranslator"], + ) -> Optional["ConstantVariable"]: if "grid" not in kwargs: self.raise_unsupported("Triton kernel requires to be called with a grid") grid = kwargs.pop("grid") @@ -882,7 +1113,13 @@ def call_run(self, variable, args, kwargs, tx): tx, ) - def call_triton_kernel(self, variable, args, kwargs, tx): + def call_triton_kernel( + self, + variable: Union["TritonKernelVariable", "TraceableTritonKernelWrapper"], + args: Sequence[Any], + kwargs: Dict[str, Any], + tx: Optional["InstructionTranslator"], + ) -> Optional["ConstantVariable"]: from triton import JITFunction from triton.runtime.autotuner import autotune, Autotuner, Config @@ -969,10 +1206,11 @@ def call_triton_kernel(self, variable, args, kwargs, tx): # If the grid is a function, then lets execute it and convert it to # a list grid = variable.grid + assert grid is not None if self.is_callable(grid): # Populate the special "meta" argument to call the grid function meta = {**combined_args_raw, **config_args} - grid = self.call_grid(grid, meta, tx) + grid = self.call_grid(grid, meta, tx) # type: ignore[arg-type] grids.append(self.check_grid(grid)) for i in range(len(grids)): @@ -987,7 +1225,6 @@ def call_triton_kernel(self, variable, args, kwargs, tx): self.raise_unsupported("Grid can have at most rank 3") assert len(grids) != 0 - if isinstance(variable.kernel, JITFunction): constexprs = variable.kernel.constexprs else: @@ -1010,7 +1247,6 @@ def call_triton_kernel(self, variable, args, kwargs, tx): combined_args_raw[arg_name] = variable.specialize_symbolic( combined_args_raw[arg_name] ) - return self.call_HOP(variable, grids, combined_args_raw, tx) @@ -1020,20 +1256,30 @@ def call_triton_kernel(self, variable, args, kwargs, tx): class TracingTritonHOPifier(TritonHOPifier): - def raise_unsupported(self, msg): + def raise_unsupported(self, msg: str) -> Never: raise RuntimeError(msg) - def is_callable(self, maybe_callable): + def is_callable(self, maybe_callable: Any) -> bool: return callable(maybe_callable) - def get_value(self, val): + def get_value(self, val: Any) -> Any: return val - def call_grid(self, grid, meta, tx): + def call_grid( + self, + grid: "TritonGridCallableType", + meta: "TritonMetaParamsType", + tx: None, + ) -> Tuple[Union[int, sympy.Expr, SymInt], ...]: assert tx is None + assert isinstance(meta, dict) + assert callable(grid) return grid(meta) - def check_grid(self, grid): + def check_grid( + self, + grid: "TritonGridType", + ) -> Tuple[Union[int, sympy.Expr, SymInt], ...]: if not isinstance(grid, collections.abc.Sequence): raise RuntimeError( "capture_triton can only handle grids that resolve to Sequence[int]." @@ -1041,10 +1287,17 @@ def check_grid(self, grid): # normalize to tuple return tuple(grid) - def call_HOP(self, variable, grids, combined_args, tx): + def call_HOP( + self, + variable: "TraceableTritonKernelWrapper", + grids: List["TritonGridTupleType"], + combined_args: Dict[str, Any], + tx: None, + ) -> None: assert tx is None + assert isinstance(variable, TraceableTritonKernelWrapper) - def is_graphable(val): + def is_graphable(val: Any) -> bool: return isinstance(val, fx.node.base_types) non_graphable_args = { @@ -1053,10 +1306,14 @@ def is_graphable(val): graphable_args = {k: v for k, v in combined_args.items() if is_graphable(v)} constant_args_idx = kernel_side_table.add_constant_args(non_graphable_args) + assert isinstance(variable.kernel_idx, int) return triton_kernel_wrapper_mutation( kernel_idx=variable.kernel_idx, constant_args_idx=constant_args_idx, - grid=grids, + grid=grids, # type: ignore[arg-type] + # TMA descriptor capturing not yet + # supported in non-dynamo tracing + tma_descriptor_metadata={}, kwargs=graphable_args, ) @@ -1065,16 +1322,25 @@ def is_graphable(val): class TraceableTritonKernelWrapper: - def __init__(self, kernel, kernel_idx, grid): + kernel: "TritonKernelType" + kernel_idx: Optional[int] + grid: Optional["TritonGridType"] + + def __init__( + self, + kernel: "TritonKernelType", + kernel_idx: Optional[int], + grid: Optional["TritonGridType"], + ) -> None: self.kernel = None self.grid = None tracing_triton_hopifier_singleton.init_variable(self, kernel, kernel_idx, grid) assert self.kernel is not None - def __getitem__(self, *args): - return tracing_triton_hopifier_singleton.call_getitem(self, args) + def __getitem__(self, *args: Sequence[Any]) -> "TraceableTritonKernelWrapper": + return tracing_triton_hopifier_singleton.call_getitem(self, args) # type: ignore[return-value] - def run(self, *args, **kwargs): + def run(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any: from torch._library.triton import is_capture_triton_enabled if is_capture_triton_enabled(): @@ -1083,7 +1349,7 @@ def run(self, *args, **kwargs): assert self.kernel is not None return self.kernel.run(*args, **kwargs) - def __call__(self, *args, **kwargs): + def __call__(self, *args: Sequence[Any], **kwargs: Dict[str, Any]) -> Any: from torch._library.triton import is_capture_triton_enabled if is_capture_triton_enabled(): @@ -1094,7 +1360,7 @@ def __call__(self, *args, **kwargs): assert self.kernel is not None return self.kernel[self.grid](*args, **kwargs) - def specialize_symbolic(self, arg: Any) -> Any: + def specialize_symbolic(self, arg: Sequence[Any]) -> Any: import torch # See [Note: Specialize tl.constexpr args in user-defined triton kernels] diff --git a/torch/_higher_order_ops/utils.py b/torch/_higher_order_ops/utils.py index 139e9a160cbe2..549e1af54f9b6 100644 --- a/torch/_higher_order_ops/utils.py +++ b/torch/_higher_order_ops/utils.py @@ -2,7 +2,7 @@ import functools from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Callable +from typing import Any, Callable, List import torch import torch.fx.traceback as fx_traceback @@ -99,7 +99,23 @@ def _maybe_reenter_make_fx(fn): if _CURRENT_MAKE_FX_TRACER is not None: return reenter_make_fx(fn) else: - return make_fx(fn) + + def _maybe_make_fx_with_fake_mode(fn): + @functools.wraps(fn) + def wrapped(*args): + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(args) + if fake_mode is None: + # we creaeta a fake_mode here to make sure we could + # trace the graph with data-dependent calls e.g. .item() + return make_fx(fn, tracing_mode="fake")(*args) + # Tracing with real if all inputs have been fakfied + return make_fx(fn)(*args) + + return wrapped + + return _maybe_make_fx_with_fake_mode(fn) @contextmanager @@ -114,6 +130,73 @@ def _set_compilation_env(): torch.fx._symbolic_trace._is_fx_tracing_flag = _old_is_tracing +def _detect_input_mutation(gm): + input_nodes = set() + for node in gm.graph.nodes: + if node.op == "placeholder": + input_nodes.add(node) + if node.op == "call_function": + target = node.target + if isinstance(target, torch._ops.OpOverload) and target._schema.is_mutable: + for arg in node.args: + if arg in input_nodes: + return True + + for _, module in gm.named_children(): + if isinstance(module, torch.fx.GraphModule): + if _detect_input_mutation(module): + return True + + return False + + +def _detect_input_alias(gm): + input_storages = set() + for node in gm.graph.nodes: + # We need to check existence of "val" because we reuse the logic here + # for map operator, where num_mapped_args is a scalar + # and doesn't have a "val" meta. + if ( + node.op == "placeholder" + and "val" in node.meta + and isinstance(node.meta["val"], torch.Tensor) + ): + input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage())) + if node.op == "output": + + def check_alias(out): + if ( + out is not None + and "val" in out.meta + and isinstance(out.meta["val"], torch.Tensor) + ): + out_storage = StorageWeakRef(out.meta["val"]._typed_storage()) + return out_storage in input_storages + return False + + if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))): + return True + + for _, module in gm.named_children(): + if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module): + return True + + return False + + +def has_potential_input_alias_or_mutation(gm, inputs, pre_dispatch=False): + try: + gm = make_fx(gm, pre_dispatch=pre_dispatch)(*inputs) + except UnsupportedAliasMutationException: + # this can happen when nested cond_op is + # functionalized + return True + except Exception as e: + raise e + + return _detect_input_mutation(gm) or _detect_input_alias(gm) + + def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): """ Dispatch-trace the branch with inputs and check if @@ -129,28 +212,6 @@ def _has_potential_branch_input_mutation(branch, inputs, pre_dispatch=False): except Exception as e: raise e - def _detect_input_mutation(gm): - input_nodes = set() - for node in gm.graph.nodes: - if node.op == "placeholder": - input_nodes.add(node) - if node.op == "call_function": - target = node.target - if ( - isinstance(target, torch._ops.OpOverload) - and target._schema.is_mutable - ): - for arg in node.args: - if arg in input_nodes: - return True - - for _, module in gm.named_children(): - if isinstance(module, torch.fx.GraphModule): - if _detect_input_mutation(module): - return True - - return False - return _detect_input_mutation(gm) @@ -169,31 +230,6 @@ def _has_potential_branch_input_alias(branch, inputs, pre_dispatch=False): except Exception as e: raise e - def _detect_input_alias(gm): - input_storages = set() - for node in gm.graph.nodes: - # We need to check existence of "val" because we reuse the logic here - # for map operator, where num_mapped_args is a scalar - # and doesn't have a "val" meta. - if node.op == "placeholder" and "val" in node.meta: - input_storages.add(StorageWeakRef(node.meta["val"]._typed_storage())) - if node.op == "output": - - def check_alias(out): - if out is not None and "val" in out.meta: - out_storage = StorageWeakRef(out.meta["val"]._typed_storage()) - return out_storage in input_storages - return False - - if any(pytree.tree_leaves(pytree.tree_map(check_alias, node.args))): - return True - - for _, module in gm.named_children(): - if isinstance(module, torch.fx.GraphModule) and _detect_input_alias(module): - return True - - return False - return _detect_input_alias(gm) @@ -378,3 +414,63 @@ def _stack_pytree(pytrees): else: raise RuntimeError(f"Cannot stack {leaves}.") return pytree.tree_unflatten(stacked_out, out_spec) + + +# We cannot call save_for_backward for symints. This helper function +# can be used to save symints as direct attributes of ctx in autograd.Function. +# +# For example, if args = (x, y, s0, z, s1), +# save_tensors_and_symints_for_backward will partition the args into two lists, and a bookkeeping list pos: +# partitioned_args[0] = (x, y, z) +# partitioned_args[1] = (s0, s1) +# pos = (0, 0, 1, 0, 1) +# pos list keeps track of which partition the args +# is partitioned into in order to recover it in saved_tensors_and_symints. +# +# In saved_tensors_and_symints, we can recover the original args by: +# iterating over the pos list and pop one item from the front of paritioned_args[pos[i]]. +# We use t_idx and s_idx to keep track of the next index of the item we are going to pop for the two lists. +def save_tensors_and_symints_for_backward(ctx, args): + assert all(isinstance(arg, (torch.Tensor, torch.SymInt, int)) for arg in args), args + partitioned_args: List[Any] = [[], []] + pos = [] + for i, arg in enumerate(args): + idx = 0 if isinstance(arg, torch.Tensor) else 1 + partitioned_args[idx].append(arg) + pos.append(idx) + + assert not hasattr(ctx, "sym_int_args"), "ctx already has sym_int_args attribute." + assert not hasattr(ctx, "pos"), "ctx already has pos attribute." + ctx.save_for_backward(*partitioned_args[0]) + ctx.sym_int_args = partitioned_args[1] + ctx.pos = pos + + +def saved_tensors_and_symints(ctx): + args = [] + t_idx = 0 + s_idx = 0 + saved_tensors = ctx.saved_tensors + for p in ctx.pos: + if p == 0: + args.append(saved_tensors[t_idx]) + t_idx += 1 + else: + args.append(ctx.sym_int_args[s_idx]) + s_idx += 1 + assert t_idx + s_idx == len(ctx.pos) + return tuple(args) + + +def get_dummy_aot_autograd_config(): + from torch._functorch.aot_autograd import AOTConfig + + return AOTConfig( + fw_compiler=None, # type: ignore[arg-type] + bw_compiler=None, # type: ignore[arg-type] + partition_fn=None, # type: ignore[arg-type] + decompositions={}, + num_params_buffers=0, + aot_id=0, + keep_inference_input_mutations=False, + ) diff --git a/torch/_higher_order_ops/while_loop.py b/torch/_higher_order_ops/while_loop.py index f14321842f40b..fe8f11a9a7a36 100644 --- a/torch/_higher_order_ops/while_loop.py +++ b/torch/_higher_order_ops/while_loop.py @@ -129,7 +129,7 @@ def body_fn(iter, x): def _validate_input(cond_fn, body_fn, carried_inputs): if not callable(cond_fn) or not callable(body_fn): - raise RuntimeError("Expect cond_fn and body_fn to be callbale.") + raise RuntimeError("Expect cond_fn and body_fn to be callable.") if not isinstance(carried_inputs, (tuple, list)) or pytree.tree_any( lambda t: not isinstance(t, torch.Tensor), carried_inputs diff --git a/torch/_inductor/__init__.py b/torch/_inductor/__init__.py index 666b332a4af61..397739147c13b 100644 --- a/torch/_inductor/__init__.py +++ b/torch/_inductor/__init__.py @@ -1,16 +1,21 @@ # mypy: allow-untyped-defs -from typing import Any, Dict, List, Optional, Tuple + +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import torch.fx import torch.utils._pytree as pytree +if TYPE_CHECKING: + from torch._inductor.utils import InputType + + __all__ = ["compile", "list_mode_options", "list_options", "cudagraph_mark_step_begin"] def compile( gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], + example_inputs: List["InputType"], options: Optional[Dict[str, Any]] = None, ): """ @@ -74,7 +79,9 @@ def aoti_compile_and_package( if not isinstance(exported_program, ExportedProgram): raise ValueError("Only ExportedProgram is supported") - assert package_path is None or package_path.endswith(".pt2") + assert package_path is None or package_path.endswith( + ".pt2" + ), f"Expect package path to end with .pt2, got {package_path}" inductor_configs = inductor_configs or {} @@ -232,12 +239,14 @@ def list_mode_options( # enable max-autotune "max-autotune-no-cudagraphs": { "max_autotune": True, + "coordinate_descent_tuning": True, }, # enable max-autotune # enable cudagraphs "max-autotune": { "max_autotune": True, "triton.cudagraphs": True, + "coordinate_descent_tuning": True, }, } return mode_options[mode] if mode else mode_options # type: ignore[return-value] @@ -256,7 +265,7 @@ def list_options() -> List[str]: from torch._inductor import config - current_config: Dict[str, Any] = config.shallow_copy_dict() + current_config: Dict[str, Any] = config.get_config_copy() return list(current_config.keys()) diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index b965f3f129031..c803e7690f9fc 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -49,6 +49,8 @@ kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") +log = logging.getLogger(__name__) + def pre_fork_setup(): """ @@ -160,10 +162,14 @@ def process_pool() -> AnyPool: pool: AnyPool if get_worker_start_method() == "subprocess": # Wrapper around ProcessPoolExecutor forks in a new process we control + log.info("Creating subprocess pool with %d workers", get_compile_threads()) pool = SubprocPool(get_compile_threads()) else: pre_fork_setup() ctx = multiprocessing.get_context(get_worker_start_method()) + log.info( + "Creating forked subprocess pool with %d workers", get_compile_threads() + ) pool = ProcessPoolExecutor( get_compile_threads(), mp_context=ctx, diff --git a/torch/_inductor/autotune_process.py b/torch/_inductor/autotune_process.py index 47ec77883ecb5..4be35dceece17 100644 --- a/torch/_inductor/autotune_process.py +++ b/torch/_inductor/autotune_process.py @@ -433,7 +433,7 @@ def from_irnodes( node = irnodes if isinstance(node, ir.Layout): - node = ir.Buffer("fake", node) + node = ir.Buffer(name="fake", layout=node) dtype = node.get_dtype() assert dtype is not None diff --git a/torch/_inductor/bisect_helper.py b/torch/_inductor/bisect_helper.py new file mode 100644 index 0000000000000..5cb1dd5691804 --- /dev/null +++ b/torch/_inductor/bisect_helper.py @@ -0,0 +1,593 @@ +import collections +import dataclasses +import functools +import os +import shutil +import sys +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional, Tuple + +from torch._inductor.runtime.cache_dir_utils import cache_dir + + +# Set the subdirectory name +SUBDIR_NAME = "bisect" + + +@dataclass +class Subsystem: + name: str + + +@dataclass +class BisectSubsystem(Subsystem): + pass + + +@dataclass +class BinarySubsystem(Subsystem): + pass + + +@dataclass +class ConfigChange(BinarySubsystem): + name: str = field(init=False) + config_name: str + config_field: str + config_value: object + + def __post_init__(self) -> None: + self.name = f"{self.config_name}_{self.config_field}" + + +# Dictionary of backend -> subsystems +BACKENDS: Dict[str, List[Subsystem]] = { + # run dynamo without aot_autograd + "eager": [], + # run dynamo with aot_autograd, but no partitioner or decomps + "aot_eager": [], + # run dynamo with aot autograd, decompositions and partitioner + "aot_eager_decomp_partition": [ + ConfigChange("aot_eager_decomp_partition", "cse", False), + BisectSubsystem( + "decomposition" + ), # number of decompositions we apply in tracing + ], # TODO - add cse ? + # applies CrossRefFakeMode on invocation + "aot_eager_decomp_partition_crossref": [], + "inductor": [ + BisectSubsystem( + "post_grad_passes" + ), # passes applied individually on forward, and backward in inductor + ConfigChange("inductor", "emulate_precision_casts", True), + BisectSubsystem("lowerings"), # lowering aten operators to inductor + ], # TODO - add more - fusions, amp numeric mode ? +} + +subsystem_call_counter: Dict[str, int] = collections.Counter() +call_counter_debug_info: Dict[int, str] = {} + + +def reset_counters() -> None: + subsystem_call_counter.clear() + call_counter_debug_info.clear() + + +@functools.lru_cache(None) +def get_env_val(env_str: str) -> Optional[str]: + return os.environ.get(env_str, None) + + +@dataclasses.dataclass +class BisectionResult: + """ + backend: torch.compile backend responsible for failure + subsystem: optional, registered component identified for failure + bisect_number: optional, number of times the subsystem needed to be applied to trigger failure + debug_info: associated info of the triggering bisect application of subsystem + """ + + backend: str + subsystem: Optional[str] = None + bisect_number: Optional[int] = None + debug_info: Optional[str] = None + + +class BisectionManager: + bisection_enabled: bool = False + + @classmethod + def get_dir(cls) -> str: + return f"{cache_dir()}/{SUBDIR_NAME}" + + @classmethod + def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None: + os.makedirs(os.path.dirname(file_path), exist_ok=True) + with open(file_path, "w") as file: + file.writelines(lines) + + @classmethod + def read_lines_from_file(cls, file_path: str) -> List[str]: + if os.path.exists(file_path): + with open(file_path) as file: + return file.readlines() + return [] + + @classmethod + def update_run_state( + cls, backend_name: str, subsystem: Subsystem, run_state: str + ) -> None: + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem.name}_run_state.txt" + ) + if isinstance(subsystem, ConfigChange): + assert run_state == "test_disable" + cls.set_config_values( + backend_name, + subsystem.name, + {subsystem.config_field: subsystem.config_value}, + ) + + cls.write_lines_to_file(file_path, [run_state]) + + @classmethod + def set_config_values( + cls, backend: str, subsystem: str, config_data: Dict[str, object] + ) -> None: + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt") + lines = [f"{k}={v}\n" for k, v in config_data.items()] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def update_bisect_status(cls, backend_name: str, subsystem_name: str) -> None: + assert isinstance(subsystem_name, str) + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = [f"backend={backend_name}\n", f"subsystem={subsystem_name}\n"] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def update_bisect_range( + cls, backend_name: str, subsystem_name: str, low: int, high: int + ) -> None: + assert isinstance(subsystem_name, str) + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt" + ) + lines = [f"low={low}\n", f"high={high}\n"] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def get_backend(cls) -> Optional[str]: + """ + Returns the active backend, if any + """ + if val := get_env_val("TORCH_BISECT_BACKEND"): + return val + + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = cls.read_lines_from_file(file_path) + for line in lines: + if line.startswith("backend="): + return line.strip().split("=")[1] + return None + + @classmethod + def get_subsystem(cls) -> Optional[str]: + """ + Returns the active subsystem, if any + """ + + if val := get_env_val("TORCH_BISECT_SUBSYSTEM"): + return val + + file_path = os.path.join(cls.get_dir(), "bisect_status.txt") + lines = cls.read_lines_from_file(file_path) + for line in lines: + if line.startswith("subsystem="): + out = line.strip().split("=")[1] + return out if out else None + return None + + @classmethod + def get_subsystem_object(cls, backend_name: str, subsystem_name: str) -> Subsystem: + return next(obj for obj in BACKENDS[backend_name] if obj.name == subsystem_name) + + @classmethod + def get_run_state(cls, backend_name: str, subsystem_name: str) -> Optional[str]: + """ + Returns the current stage of bisecting, if Any + """ + + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_run_state.txt" + ) + lines = cls.read_lines_from_file(file_path) + if lines: + out = lines[0].strip() + assert out in ("test_disable", "find_max_bounds", "bisect") + return out + return None + + @classmethod + def get_bisect_range( + cls, backend_name: str, subsystem_name: str + ) -> Tuple[int, int]: + file_path = os.path.join( + cls.get_dir(), backend_name, f"{subsystem_name}_bisect_range.txt" + ) + lines = cls.read_lines_from_file(file_path) + low = None + high = None + for line in reversed(lines): + if line.startswith("low="): + low = int(line.strip().split("=")[1]) + elif line.startswith("high="): + high = int(line.strip().split("=")[1]) + + if low is not None and high is not None: + break + + if low is None or high is None: + raise RuntimeError( + f"Trying to get bisect range when it is not set: subsystem {subsystem_name}" + ) + + return low, high + + @classmethod + def update_config_change(cls, backend: str, subsystem: ConfigChange) -> None: + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem.name}_config.txt") + lines = [ + f"config_name={subsystem.config_name}\n", + f"config_field={subsystem.config_field}\n", + f"config_value={subsystem.config_value}\n", + ] + cls.write_lines_to_file(file_path, lines) + + @classmethod + def get_config_change(cls, config_name: str) -> Optional[Dict[str, object]]: + backend = cls.get_backend() + subsystem = cls.get_subsystem() + + if not backend or not subsystem: + return None + + file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt") + + if not os.path.exists(file_path): + return None + + lines = cls.read_lines_from_file(file_path) + config_data = {} + for line in lines: + key, value = line.strip().split("=", 1) + config_data[key] = eval(value) + + return config_data + + @classmethod + def delete_bisect_status(cls) -> None: + if os.path.exists(cls.get_dir()): + shutil.rmtree(cls.get_dir()) + print("Bisection status deleted.") + else: + print("No bisection status found.") + + @classmethod + def get_system_counter(cls, name: str, increment: bool = True) -> int: + global subsystem_call_counter + curr = subsystem_call_counter[name] + if increment: + subsystem_call_counter[name] += 1 + return curr + + @classmethod + def disable_subsystem( + cls, + backend: str, + subsystem: str, + debug_info: Optional[Callable[[], str]] = None, + ) -> bool: + if not cls.bisection_enabled: + return False + + if cls.get_backend() != backend: + return False + + if cls.get_subsystem() != subsystem: + return False + + if val := get_env_val("TORCH_BISECT_MAX"): + counter = cls.get_system_counter(subsystem, increment=True) + return counter > int(val) + + run_state = cls.get_run_state(backend, subsystem) + if run_state == "test_disable": + # First run, disable completely + return True + elif run_state == "find_max_bounds": + # Second run, update bisection range and return True to enable the subsystem + cls.update_bisect_range( + backend, + subsystem, + 0, + cls.get_system_counter(subsystem, increment=True), + ) + return False + else: + assert run_state == "bisect" + # If the environment variable is not set, use the bisection range midpoint + low, high = cls.get_bisect_range(backend, subsystem) + # if high - low <= 2: + midpoint = (low + high) // 2 + call_counter = cls.get_system_counter(subsystem) + + if ( + call_counter >= low + and call_counter <= high + and (low - high) <= 2 + and debug_info is not None + ): + call_counter_debug_info[call_counter] = debug_info() + + return call_counter > midpoint + + @classmethod + def advance_subsystem( + cls, curr_backend: str, curr_subsystem: Subsystem + ) -> Optional[Subsystem]: + """ + Tries to move to the next subsystem within the current system. + """ + print(f"Disabling {curr_subsystem.name} did not fix the issue.") + + current_subsystems = BACKENDS[curr_backend] + current_subsystem_index = next( + i + for i, subsystem in enumerate(current_subsystems) + if subsystem.name == curr_subsystem.name + ) + + if current_subsystem_index < len(current_subsystems) - 1: + next_subsystem = current_subsystems[current_subsystem_index + 1] + cls.update_bisect_status(curr_backend, next_subsystem.name) + cls.update_run_state(curr_backend, next_subsystem, "test_disable") + print( + f"Moving to the next subsystem: {curr_backend} - {next_subsystem.name}" + ) + return next_subsystem + else: + print( + f"All subsystems in {curr_backend} have been checked. The issue is not in this system." + ) + return None + + @classmethod + def advance_backend(cls, curr_backend: str) -> Optional[str]: + """ + Tries Move to the next backend. + """ + current_system_index = list(BACKENDS.keys()).index(curr_backend) + + if current_system_index < len(BACKENDS) - 1: + curr_backend = list(BACKENDS.keys())[current_system_index + 1] + cls.update_bisect_status(curr_backend, "") + print(f"Moving to the next system: {curr_backend}") + return curr_backend + else: + return None + + @classmethod + def process_subsystem( + cls, + curr_backend: str, + curr_subsystem: Subsystem, + fn: Callable[[], bool], + cli_interface: bool = True, + ) -> bool: + """ + Process the current subsystem. Returns True if the issue is found, False otherwise. + """ + assert isinstance(curr_subsystem, Subsystem) + while True: + run_state = cls.get_run_state(curr_backend, curr_subsystem.name) + reset_counters() + if run_state == "test_disable": + if not fn(): + next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem) + if not next_subsystem: + return False + curr_subsystem = next_subsystem + else: + if isinstance(curr_subsystem, ConfigChange): + print( + f"Setting config {curr_subsystem.config_name} field {curr_subsystem.config_field}" + f"to {curr_subsystem.config_value} fixed the issue" + ) + else: + print(f"Disabling {curr_subsystem.name} fixed the issue.") + if isinstance(curr_subsystem, BinarySubsystem): + return True + print("Starting bisect by getting upper bound.") + cls.update_run_state( + curr_backend, curr_subsystem, "find_max_bounds" + ) + elif run_state == "find_max_bounds": + if fn(): + raise RuntimeError( + f"Function succeeded with 'find_max_bounds' status for {curr_backend} - {curr_subsystem.name}." + ) + else: + _, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + print(f"Upper bound of {high} found for {curr_backend}.") + cls.update_run_state(curr_backend, curr_subsystem, "bisect") + elif run_state == "bisect": + low, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + midpoint = (low + high) // 2 + print( + f"Bisecting {curr_backend} - {curr_subsystem.name} (Range: [{low}, {high}], Midpoint: {midpoint})" + ) + if fn(): + cls.update_bisect_range( + curr_backend, curr_subsystem.name, midpoint + 1, high + ) + else: + cls.update_bisect_range( + curr_backend, curr_subsystem.name, low, midpoint + ) + low, high = cls.get_bisect_range(curr_backend, curr_subsystem.name) + if low == high: + print( + f"Binary search completed for {curr_backend} - {curr_subsystem.name}. The bisect number is {low}. " + f"Debug info: {call_counter_debug_info.get(low, 'not found')}" + ) + return True + else: + raise RuntimeError(f"Unexpected run_state {run_state}") + + if cli_interface: + sys.exit(0) + + @classmethod + def initialize_system(cls) -> None: + curr_backend = next(iter(BACKENDS.keys())) + curr_subsystem = "" + cls.update_bisect_status(curr_backend, curr_subsystem) + print(f"Starting bisection process with system: {curr_backend}") + + @classmethod + def do_bisect( + cls, fn: Callable[[], bool], cli_interface: bool = False + ) -> Optional[BisectionResult]: + if not cli_interface: + bisection_enabled_orig = cls.bisection_enabled + cls.delete_bisect_status() + cls.bisection_enabled = True + + # TODO - cli interface, and in-process different directories + class DisableBisect: + def __del__(self) -> None: + cls.bisection_enabled = bisection_enabled_orig + cls.delete_bisect_status() + + cleanup = DisableBisect() + + curr_backend = cls.get_backend() + curr_subsystem_name = cls.get_subsystem() + + if not curr_backend: + cls.initialize_system() + curr_backend = cls.get_backend() + assert curr_backend is not None + curr_subsystem_name = cls.get_subsystem() + + curr_subsystem = ( + cls.get_subsystem_object(curr_backend, curr_subsystem_name) + if curr_subsystem_name is not None + else None + ) + while True: + assert curr_backend is not None + reset_counters() + if curr_subsystem: + result = cls.process_subsystem( + curr_backend, curr_subsystem, fn, cli_interface=cli_interface + ) + if result: + curr_subsystem = cls.get_subsystem_object( + curr_backend, cls.get_subsystem() # type: ignore[arg-type] + ) + + if isinstance(curr_subsystem, BinarySubsystem): + return BisectionResult( + curr_backend, + curr_subsystem.name, + 0, + curr_subsystem.name, + ) + + low, _ = cls.get_bisect_range(curr_backend, curr_subsystem.name) + return BisectionResult( + curr_backend, + curr_subsystem.name, + low, + call_counter_debug_info.get(low, None), + ) + + next_subsystem = cls.advance_subsystem(curr_backend, curr_subsystem) + if not next_subsystem: + print( + f"The issue is in the {curr_backend} system, but could not identify subsystem." + ) + assert curr_backend is not None + return BisectionResult(curr_backend) + + curr_subsystem = next_subsystem + else: + if fn(): + next_backend = cls.advance_backend(curr_backend) + if not next_backend: + print("All systems have been checked.") + return None + + curr_backend = next_backend + else: + current_subsystems = BACKENDS[curr_backend] + if current_subsystems: + curr_subsystem = current_subsystems[0] + cls.update_bisect_status(curr_backend, curr_subsystem.name) + cls.update_run_state( + curr_backend, curr_subsystem, "test_disable" + ) + print( + f"The issue is in the {curr_backend} system. Moving to the first subsystem: {curr_subsystem}" + ) + else: + print(f"The issue is in the {curr_backend} system.") + return BisectionResult(curr_backend) + + if cli_interface: + sys.exit(0) + + +def command_line_usage() -> None: + if len(sys.argv) < 2: + print("Usage: python bisect_update.py ") + sys.exit(1) + + bisection_manager = BisectionManager() + command = sys.argv[1] + + if command == "end": + bisection_manager.delete_bisect_status() + sys.exit(0) + + if command == "start": + bisection_manager.delete_bisect_status() + bisection_manager.initialize_system() + sys.exit(0) + + if command not in ["good", "bad"]: + print("Invalid command. Must be 'good', 'bad', 'start', or 'end'.") + sys.exit(1) + + def test_function() -> bool: + return command == "good" + + if not bisection_manager.get_backend(): + raise ValueError("Must call start prior to good or bad") + + bisection_manager.do_bisect(test_function, cli_interface=True) + + +def get_is_bisection_enabled() -> bool: + return ( + BisectionManager.get_subsystem() is not None + or BisectionManager.get_backend() is not None + ) + + +BisectionManager.bisection_enabled = get_is_bisection_enabled() + +if __name__ == "__main__": + command_line_usage() diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index c902015b683dd..c914c6a7338bd 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -69,6 +69,8 @@ from torch._utils_internal import log_cache_bypass from .remote_cache import create_cache +from .runtime import autotune_cache +from .runtime.autotune_cache import AutotuneCacheBundler from .utils import _align @@ -78,7 +80,9 @@ if TYPE_CHECKING: from collections.abc import KeysView + from .compile_fx import _CompileFxKwargs from .remote_cache import JsonDataTy, RemoteCache + from .utils import InputType """ @@ -194,6 +198,15 @@ def get_cpp_wrapper_cubin_path_name() -> str: return "cubin_path" if torch.version.hip is None else "hsaco_path" +@functools.lru_cache(None) +def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]: + return ( + Path(os.path.join(global_cache_dir, CacheBase.get_system()["hash"])) + if global_cache_dir is not None + else None + ) + + class CacheBase: @staticmethod @functools.lru_cache(None) @@ -240,13 +253,8 @@ def get_local_cache_path() -> Path: return Path(os.path.join(cache_dir(), "cache", CacheBase.get_system()["hash"])) @staticmethod - @functools.lru_cache(None) def get_global_cache_path() -> Optional[Path]: - return ( - Path(os.path.join(config.global_cache_dir, CacheBase.get_system()["hash"])) - if config.global_cache_dir is not None - else None - ) + return get_global_cache_path_impl(config.global_cache_dir) def __init__(self) -> None: self.system = CacheBase.get_system() @@ -470,7 +478,17 @@ def write_atomic( write_mode = "w" if isinstance(content, str) else "wb" with tmp_path.open(write_mode, encoding="utf-8" if encode_utf_8 else None) as f: f.write(content) - tmp_path.rename(path) + try: + tmp_path.rename(target=path) + except FileExistsError as e_file_exist: + if not _IS_WINDOWS: + raise + # On Windows file exist is expected: https://docs.python.org/3/library/pathlib.html#pathlib.Path.rename + # Below two lines code is equal to `tmp_path.rename(path)` on non-Windows OS. + # 1. Copy tmp_file to Target(Dst) file. + shutil.copy2(src=tmp_path, dst=path) + # 2. Delete tmp_file. + os.remove(tmp_path) @dataclasses.dataclass @@ -676,34 +694,35 @@ def torch_key() -> bytes: """ Compute a key that contains relevant information about torch source files """ - if not config.is_fbcode(): - - def get_code_hash(root: str) -> bytes: - # This function isn't meant to be used outside of torch_key, just a - # helper for clarity. Instead, use torch_key() directly when you need - # a hash representing the state of the source code. - extra_files = ( - "codegen/aoti_runtime/interface.cpp", - "codegen/aoti_runtime/implementation.cpp", - "codegen/cpp_prefix.h", - "script.ld", - ) - inductor_root = os.path.dirname(__file__) - extra_files = [os.path.join(inductor_root, x) for x in extra_files] - hasher = hashlib.sha256() - hasher.update(torch.__version__.encode("utf-8")) - build_code_hash([root], "", hasher) - for path in extra_files: - if os.path.exists(path): - with open(path, "rb") as f: - hasher.update(f.read()) - return hasher.digest() - - return get_code_hash(_TORCH_PATH) + with dynamo_timed("inductor_codecache_torch_key"): + if not config.is_fbcode(): + + def get_code_hash(root: str) -> bytes: + # This function isn't meant to be used outside of torch_key, just a + # helper for clarity. Instead, use torch_key() directly when you need + # a hash representing the state of the source code. + extra_files = ( + "codegen/aoti_runtime/interface.cpp", + "codegen/aoti_runtime/implementation.cpp", + "codegen/cpp_prefix.h", + "script.ld", + ) + inductor_root = os.path.dirname(__file__) + extra_files = [os.path.join(inductor_root, x) for x in extra_files] + hasher = hashlib.sha256() + hasher.update(torch.__version__.encode("utf-8")) + build_code_hash([root], "", hasher) + for path in extra_files: + if os.path.exists(path): + with open(path, "rb") as f: + hasher.update(f.read()) + return hasher.digest() + + return get_code_hash(_TORCH_PATH) - from libfb.py import parutil + from libfb.py import parutil - return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") + return parutil.get_file_contents("torch/src_hash.txt").rstrip().encode("ascii") def get_inductor_root() -> str: @@ -738,23 +757,25 @@ class FxGraphHashDetails: def __init__( self, gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - fx_kwargs: Dict[str, Any], + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, inputs_to_check: Sequence[int], ) -> None: self.gm = gm self.example_inputs = example_inputs - # Order kwargs so hashing is stable to changes in kwarg order. - self.fx_kwargs = {} - for k in sorted(fx_kwargs): + # Order kwargs so hashing is stable to changes in kwarg order. Although + # it's technically a _CompileFxKwargs we don't actually need it typed as + # such since we're just using it to generate a hash. + self.fx_kwargs: Dict[str, object] = {} + for k, v in sorted(fx_kwargs.items()): if k not in self.EXCLUDED_KWARGS: - if type(fx_kwargs[k]) is set: + if type(v) is set: # Special case to handle set params. Python sets can't be # ordered, so sort the elements and store them in a proxy. - self.fx_kwargs[k] = OrderedSetHolder(sorted(fx_kwargs[k])) + self.fx_kwargs[k] = OrderedSetHolder(sorted(v)) else: - self.fx_kwargs[k] = fx_kwargs[k] + self.fx_kwargs[k] = v # Alignment checks self.inputs_to_check = inputs_to_check @@ -805,8 +826,8 @@ def debug_lines(self) -> List[str]: def compiled_fx_graph_hash( gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - fx_kwargs: Dict[str, Any], + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, inputs_to_check: Sequence[int], ) -> Tuple[str, List[str]]: """ @@ -823,7 +844,7 @@ def compiled_fx_graph_hash( def cudagraph_post_compile( - example_inputs: List[Any], + example_inputs: Sequence[InputType], compiled_graph: CompiledFxGraph, cudagraphs: BoxedBool, ) -> None: @@ -866,7 +887,7 @@ def cudagraph_post_compile( assert current_callable is not None compiled_graph.current_callable = cudagraphify( current_callable, - static_input_idxs=static_input_idxs, + static_input_idxs=static_input_idxs or (), device_index=next(iter(compiled_graph.device_idxs)), stack_traces=stack_traces, is_backward=is_backward, @@ -1005,7 +1026,7 @@ def _get_tmp_dir_for_key(key: str) -> str: return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key) @staticmethod - def _filter_backed_symints(inputs: List[Any]) -> List[torch.SymInt]: + def _filter_backed_symints(inputs: Sequence[InputType]) -> List[torch.SymInt]: """ Get the backed SymInt objects from the input list. Note that we can never have guards that depend on unbacked symint. @@ -1025,7 +1046,7 @@ def _get_shape_env() -> Optional[ShapeEnv]: @staticmethod def _lookup_graph( key: str, - example_inputs: List[torch.Tensor], + example_inputs: Sequence[InputType], local: bool, remote_cache: Optional[RemoteCache[JsonDataTy]], ) -> Optional[CompiledFxGraph]: @@ -1116,6 +1137,9 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: write_atomic(artifact_path, code, make_dirs=True) + inductor_meta = autotune_cache.inductor_meta_from_config() + AutotuneCacheBundler.begin_compile(inductor_meta, code=code) + try: graph.current_callable = PyCodeCache.load_by_key_path( graph.cache_key, @@ -1161,7 +1185,7 @@ def iterate_over_candidates() -> Generator[CompiledFxGraph, None, None]: @staticmethod def post_compile( compiled_graph: CompiledFxGraph, - example_inputs: List[torch.Tensor], + example_inputs: Sequence[InputType], cudagraphs: BoxedBool, ) -> CompiledFxGraph: """ @@ -1208,7 +1232,7 @@ def post_compile( def _save_graph( key: str, compiled_graph: CompiledFxGraph, - example_inputs: List[torch.Tensor], + example_inputs: Sequence[InputType], local: bool, remote_cache: Optional[RemoteCache[JsonDataTy]], ) -> None: @@ -1285,6 +1309,12 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: "Freezing may introduce constants that aren't static across runs" ) + from torch._inductor.bisect_helper import BisectionManager + + if BisectionManager.bisection_enabled: + log.debug("dont cache graph when bisect enabled") + raise BypassFxGraphCache + # The treatment of guards in the caching implementation requires that # we have a shape env. if FxGraphCache._get_shape_env() is None: @@ -1311,8 +1341,8 @@ def _check_can_cache(gm: torch.fx.GraphModule) -> None: @staticmethod def prepare_key( gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - fx_kwargs: Dict[str, Any], + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, inputs_to_check: Sequence[int], remote: bool, ) -> Tuple[Optional[Tuple[str, List[str]]], Dict[str, Any]]: @@ -1362,7 +1392,7 @@ def get_remote_cache() -> Optional[RemoteCache[JsonDataTy]]: def load_with_key( key: str, debug_lines: List[str], - example_inputs: List[torch.Tensor], + example_inputs: Sequence[InputType], local: bool, remote_cache: Optional[RemoteCache[JsonDataTy]], is_backward: bool, @@ -1381,7 +1411,7 @@ def load_with_key( "cache_event_time": time_ns(), } if compiled_graph is not None: - log.debug("fx graph cache hit for key %s", key) + log.info("fx graph cache hit for key %s", key) counters["inductor"]["fxgraph_cache_hit"] += 1 cache_info["cache_state"] = "hit" @@ -1395,7 +1425,7 @@ def load_with_key( ) != 0: cache_info["ephemeral_timeout_increase"] = ephemeral_increase else: - log.debug("fx graph cache miss for key %s", key) + log.info("fx graph cache miss for key %s", key) counters["inductor"]["fxgraph_cache_miss"] += 1 cache_info["cache_state"] = "miss" @@ -1405,8 +1435,8 @@ def load_with_key( def load( # type: ignore[no-untyped-def] compile_fx_fn: Callable[..., Any], gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - fx_kwargs: Dict[str, Any], + example_inputs: Sequence[InputType], + fx_kwargs: _CompileFxKwargs, inputs_to_check: Sequence[int], local: bool, remote: bool, @@ -1482,6 +1512,18 @@ def load( # type: ignore[no-untyped-def] cache_info["cache_event_time"], metadata=cache_info, ) + # Add event data about cache hits/miss + # TODO: add remote cache get/put timings here too + chromium_log.add_event_data( + "inductor_compile", + cache_state=cache_state, + cache_event_time=cache_info["cache_event_time"], + key=cache_info.get("key"), + components=cache_info.get("components"), + cache_bypass_reason=cache_info.get("cache_bypass_reason"), + remote_cache_enabled=remote, + local_cache_enabled=local, + ) torch._logging.trace_structured( "artifact", metadata_fn=lambda: { @@ -1492,7 +1534,7 @@ def load( # type: ignore[no-untyped-def] ) # Use the passed in cudagraphs so that we mutate the BoxedBool correctly FxGraphCache.post_compile( - compiled_graph, example_inputs, fx_kwargs["cudagraphs"] + compiled_graph, example_inputs, fx_kwargs["cudagraphs"] # type: ignore[arg-type] ) return compiled_graph @@ -1539,7 +1581,7 @@ class CompiledFxGraph: guards_expr: Optional[str] cudagraph_info: Optional[CudagraphCachedInfo] - fx_kwargs: Dict[str, Any] + fx_kwargs: _CompileFxKwargs inputs_to_check: Sequence[int] boxed_forward_device_index: Optional[BoxedDeviceIndex] @@ -1579,9 +1621,12 @@ def __init__( self.inputs_to_check = () self.boxed_forward_device_index = None - def __call__(self, inputs: List[Any]) -> Any: + def __call__(self, inputs: Sequence[Any]) -> Any: assert self.current_callable is not None - return self.current_callable(inputs) + try: + return self.current_callable(inputs) + finally: + AutotuneCacheBundler.end_compile() def run_command_and_check(cmd_: str) -> None: @@ -1673,7 +1718,7 @@ def compile( else: objcopy_command = build_paths.objcopy else: - ld_command = "ld" + ld_command = "ld -z noexecstack" objcopy_command = "objcopy" ( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 20c66fcf6c9a2..2329cc1aba9ab 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import contextlib import dataclasses +import enum import functools import itertools import logging @@ -38,6 +39,7 @@ DeferredLineBase, generate_assert, IndentedBuffer, + ir_dataclass, sympy_dot, sympy_subs, unique, @@ -53,7 +55,27 @@ def data_type_logger(msg): schedule_log.debug("Data type propagation: %s", msg) -@dataclasses.dataclass +class WorkspaceZeroMode(enum.Enum): + UNINITIALIZED = 0 + ZERO_ON_CALL = 1 # kernel may leave workspace dirty + ZERO_PER_GRAPH = 2 # must be re-zeroed by kernel + + @staticmethod + def combine(a, b): + if a == b or b == WorkspaceZeroMode.UNINITIALIZED: + return a + if a == WorkspaceZeroMode.UNINITIALIZED: + return b + raise NotImplementedError(f"WorkspaceZeroMode.combine({a!r}, {b!r})") + + @staticmethod + def from_bool(zero_fill): + if zero_fill: + return WorkspaceZeroMode.ZERO_ON_CALL + return WorkspaceZeroMode.UNINITIALIZED + + +@ir_dataclass(frozen=True) class WorkspaceArg: """A temporary buffer used for a single kernel, then discarded. @@ -61,8 +83,84 @@ class WorkspaceArg: so it would be dead code eliminated. """ - nbytes: sympy.Expr - zero_fill: bool + count: sympy.Expr + zero_mode: WorkspaceZeroMode + device: torch.device + outer_name: str + inner_name: str = "ws_ptr" + dtype: torch.dtype = torch.uint8 + + @staticmethod + def unique_name(prefix="workspace_"): + return f"{prefix}{next(V.graph.workspace_id)}" + + @staticmethod + def can_join(a, b) -> bool: + return ( + a.inner_name == b.inner_name and a.dtype == b.dtype and a.device == b.device + ) + + @staticmethod + def join(a, b): + return WorkspaceArg( + count=a.count + b.count, + zero_mode=WorkspaceZeroMode.combine(a.zero_mode, b.zero_mode), + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + @staticmethod + def maximum(a, b): + assert ( + a.zero_mode == b.zero_mode + and a.dtype == b.dtype + and a.device == b.device + and a.inner_name == b.inner_name + and a.outer_name == b.outer_name + ) + return WorkspaceArg( + count=sympy.Max(a.count, b.count), + zero_mode=a.zero_mode, + dtype=a.dtype, + device=a.device, + inner_name=a.inner_name, + outer_name=a.outer_name, + ) + + # These methods let WorkspaceArg pretend it is a buffer to reuse allocation code + def get_device(self): + return self.device + + def get_dtype(self): + return self.dtype + + def get_layout(self): + from ..ir import FixedLayout + + return FixedLayout( + device=self.device, + dtype=self.dtype, + size=[self.count], + stride=[1], + ) + + @property + def layout(self): + return self.get_layout() + + def get_size(self): + return [self.count] + + def get_stride(self): + return [1] + + def get_name(self): + return self.outer_name + + def get_inputs_that_alias_output(self): + return [] @dataclasses.dataclass @@ -84,6 +182,11 @@ def alias_of(self): return None +@dataclasses.dataclass +class TMADescriptorArg: + name: str + + @dataclasses.dataclass class DeviceCodegen: scheduling: Any @@ -91,7 +194,7 @@ class DeviceCodegen: cpp_wrapper_codegen: type = type(None) -KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg] +KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg, TMADescriptorArg] device_codegens: Dict[str, DeviceCodegen] = {} @@ -224,8 +327,7 @@ def get_wrapper_codegen_for_device(device: str, cpp_wrapper: bool = False): if cpp_wrapper else wrapper_codegen_obj.wrapper_codegen ) - else: - return None + return None @functools.lru_cache(None) @@ -575,8 +677,7 @@ def _print_Pow(self, expr): assert exp >= 0 if exp > 0: return "*".join([self.paren(base)] * exp) - else: # exp == 0 - return "1" + return "1" # Explicit NotImplemented functions are to prevent default sympy printing # behavior, which will just barf out ToFloat(...) to your IR. The error @@ -1255,7 +1356,7 @@ def __init__(self, sizevars=None): self.output_buffers = {} self.inplace_buffers = {} self.sizevars = sizevars or {} - self.workspace_arg = None + self.workspace_args = [] def __repr__(self): return "KernelArgs({})".format( @@ -1310,14 +1411,62 @@ def make_inplace(self, input_name, output_name): self.inplace_buffers[output_name] = buf def workspace(self, nbytes: sympy.Expr, zero_fill: bool): - if self.workspace_arg is None: - self.workspace_arg = WorkspaceArg(nbytes, zero_fill) - return "ws_ptr", 0 + """ + Allocate a new uint32 scratch space for use within this kernel. Multiple calls to this function will + extend the buffer returning a new region on each call. + + Args: + nbytes: size to add to the workspace. + zero_fill: True if the workspace should be zero-filled. - offset = self.workspace_arg.nbytes - zero_fill = zero_fill or self.workspace_arg.zero_fill - self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill) - return "ws_ptr", offset + Returns: + (buffer_name: str, offset: sympy.Expr) + """ + arg = WorkspaceArg( + count=nbytes, + zero_mode=WorkspaceZeroMode.from_bool(zero_fill), + device=V.graph.get_current_device_or_throw(), + outer_name=WorkspaceArg.unique_name(), + ) + for i, existing_arg in enumerate(self.workspace_args): + if WorkspaceArg.can_join(existing_arg, arg): + offset = existing_arg.count + self.workspace_args[i] = WorkspaceArg.join(existing_arg, arg) + return existing_arg.inner_name, offset + assert ( + existing_arg.inner_name != arg.inner_name + and existing_arg.outer_name != arg.outer_name + ) + self.workspace_args.append(arg) + return arg.inner_name, 0 + + def semaphores(self, min_size: sympy.Expr): + """ + Lazily allocate a graph-wide semaphores buffer with at least min_size. This is a single buffer shared by + all kernels and zero initialized once at graph start. Each kernel must leave the buffer zeroed on exit. + + Warning: multiple calls to this function will return the same buffer. + + Args: + min_size: the number of int32 semaphores required + + Returns: + name of the semaphores buffer + """ + current_device = V.graph.get_current_device_or_throw() + arg = WorkspaceArg( + count=min_size, + zero_mode=WorkspaceZeroMode.ZERO_PER_GRAPH, + dtype=torch.int32, + inner_name="sem_ptr", + outer_name=f"semaphores_{current_device.type}_{current_device.index}", + device=current_device, + ) + for existing_arg in self.workspace_args: + if existing_arg.inner_name == arg.inner_name: + assert arg == existing_arg + self.workspace_args.append(arg) + return arg.inner_name def seed_offset(self, name, value): if value in self.sizevars: @@ -1384,7 +1533,7 @@ def cpp_argdefs(self): arg_types.append(f"const {INDEX_TYPE}") if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) - assert self.workspace_arg is None, "Workspace not supported on CPU " + assert not self.workspace_args, "Workspace not supported on CPU " return arg_defs, call_args, arg_types def python_argdefs(self): @@ -1427,11 +1576,11 @@ def python_argdefs(self): precompile_args.append(SizeArg(inner, outer)) if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) - if self.workspace_arg is not None: - arg_defs.append("ws_ptr") - call_args.append("workspace") - precompile_args.append(self.workspace_arg) - arg_types.append(torch.uint8) + for arg in self.workspace_args: + arg_defs.append(arg.inner_name) + call_args.append(arg.outer_name) + precompile_args.append(arg) + arg_types.append(arg.dtype) return arg_defs, call_args, precompile_args, arg_types def aliases(self): @@ -1480,11 +1629,17 @@ class CSEVariable: See example of TritonCSEVariable in triton.py """ - def __init__(self, name, bounds: ValueRanges[Any]): + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ): assert isinstance(bounds, ValueRanges) self.name = name self.bounds = bounds self.use_count = 1 # track how many tims this expression is used + self.dtype = dtype def __str__(self): return self.name @@ -1503,16 +1658,6 @@ def __repr__(self): class CppWrapperKernelArgs(KernelArgs): - def wrap_ptr_arg(self, buf, dtype): - from .cpp_utils import DTYPE_TO_CPP - - if config.abi_compatible: - # In the abi_compatible model, we just return the buf here. - # We will form correct call args later in wrapper.generate_kernel_all. - return buf - else: - return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())" - def wrap_size_arg(self, size): return f"{size}" @@ -1566,6 +1711,7 @@ def generate( bounds: ValueRanges[Any] = ValueRanges.unknown(), write=True, assignment=True, + dtype: Optional[torch.dtype] = None, ) -> CSEVariable: if isinstance(expr, OpsValue): expr = expr.value @@ -1582,7 +1728,7 @@ def generate( cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr var = self.cache.get(cache_key, None) if not var: - var = self.newvar(bounds) + var = self.newvar(bounds, dtype) self.cache[cache_key] = var if write: if V.kernel.current_node: @@ -1606,13 +1752,217 @@ def generate( return var - def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable: + def newvar( + self, + bounds: ValueRanges[Any] = ValueRanges.unknown(), + dtype: Optional[torch.dtype] = None, + ) -> CSEVariable: var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}" - var = V.kernel.create_cse_var(var_name, bounds) + var = V.kernel.create_cse_var(var_name, bounds, dtype) self.varname_map[var_name] = var return var +@functools.lru_cache(None) +def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND): + def construct_input(inp): + if isinstance(inp, torch._prims_common.Number): + return inp + else: + assert hasattr(inp, "dtype") + + # construct a tmp tensor to use dtype promotion util function + return torch.empty([1], dtype=inp.dtype) + + inps = [construct_input(arg) for arg in args] + _, dtype = torch._prims_common.elementwise_dtypes( + *inps, type_promotion_kind=type_promotion_kind + ) + return dtype + + +def promote_types(args): + dtype_prop_candidates = [] + + # CSEVariable and scalar will be included in dtype_prop_candidates + for arg in args: + if isinstance(arg, str): + continue + elif ( + isinstance(arg, OpsValue) + and isinstance(arg.value, CSEVariable) + and arg.value.dtype is not None + ): + dtype_prop_candidates.append(arg.value) + elif (isinstance(arg, CSEVariable) and arg.dtype is not None) or isinstance( + arg, torch._prims_common.Number + ): + dtype_prop_candidates.append(arg) # type: ignore[arg-type] + + dtype = get_promoted_dtype( + *dtype_prop_candidates, + type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT, + ) + + return dtype + + +class DtypePropagationOpsHandler: + """ + Propagate dtype from args to output + """ + + @staticmethod + def default_handler(*args): + # Fallback to FP32 dtype + return torch.float32 + + @staticmethod + def randint64(seed, offset, low, high): + return torch.int64 + + @staticmethod + def where(a, b, c): + return promote_types([b, c]) + + @staticmethod + def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype): + return dtype + + @staticmethod + def load_seed(name, offset): + return torch.float32 + + @staticmethod + def masked(mask, body, other): + # TODO: inspect body to propagate dtype + return torch.float32 + + @staticmethod + def index_expr(expr, dtype): + return dtype + + @staticmethod + def isnan(x): + return torch.bool + + @staticmethod + def lt(a, b): + return torch.bool + + @staticmethod + def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): + return dtype + + @staticmethod + def constant(value, dtype): + return dtype + + @staticmethod + def mul(a, b): + return promote_types([a, b]) + + @staticmethod + def sub(a, b): + return promote_types([a, b]) + + @staticmethod + def add(a, b): + return promote_types([a, b]) + + @staticmethod + def div(a, b): + return promote_types([a, b]) + + @staticmethod + def abs(x): + return promote_types([x]) + + @staticmethod + def exp(x): + return promote_types([x]) + + @staticmethod + def truediv(a, b): + return promote_types([a, b]) + + @staticmethod + def pow(a, b): + return promote_types([a, b]) + + @staticmethod + def sqrt(x): + return promote_types([x]) + + @staticmethod + def rsqrt(x): + return promote_types([x]) + + @staticmethod + def sigmoid(x): + return promote_types([x]) + + @staticmethod + def gelu(x): + return promote_types([x]) + + @staticmethod + def neg(x): + return promote_types([x]) + + @staticmethod + def minimum(a, b): + return promote_types([a, b]) + + @staticmethod + def maximum(a, b): + return promote_types([a, b]) + + @staticmethod + def log(x): + return promote_types([x]) + + @staticmethod + def log1p(x): + return promote_types([x]) + + @staticmethod + def gt(a, b): + return torch.bool + + @staticmethod + def ge(a, b): + return torch.bool + + @staticmethod + def reciprocal(x): + return promote_types([x]) + + @staticmethod + def and_(a, b): + return torch.bool + + @staticmethod + def bitwise_right_shift(a, b): + return a.dtype + + @staticmethod + def bitwise_left_shift(a, b): + return a.dtype + + @staticmethod + def sin(x): + return promote_types([x]) + + @staticmethod + def cos(x): + return promote_types([x]) + + @staticmethod + def mod(a, b): + return promote_types([a, b]) + + class CodeGen: def __init__(self) -> None: super().__init__() @@ -1848,8 +2198,17 @@ def inner(*args, **kwargs): value = getattr(parent_handler, name)(*args, **kwargs) # type: ignore[has-type] def do_cse(v): + output_dtype = getattr( + DtypePropagationOpsHandler, + name, + DtypePropagationOpsHandler.default_handler, + )(*args) + csevar = V.kernel.cse.generate( - V.kernel.compute, v, bounds=bounds + V.kernel.compute, + v, + bounds=bounds, + dtype=output_dtype, ) csevar.update_on_args(name, args, kwargs) return csevar @@ -1898,8 +2257,7 @@ def arg_to_bound(x): arg_bounds = list(map(arg_to_bound, args)) return getattr(CSEProxy.vr_analysis, name)(*arg_bounds) - else: - return ValueRanges.unknown() + return ValueRanges.unknown() @staticmethod def indirect_indexing( @@ -1993,8 +2351,7 @@ def store( CSEProxy._update_store_cache(name, value) if name not in V.graph.removed_buffers: return self.store(name, index, value, mode=mode) - else: - return None # type: ignore[return-value] + return None # type: ignore[return-value] @staticmethod def store_reduction(name: str, index: sympy.Expr, value: CSEVariable): @@ -2198,47 +2555,46 @@ def indent_except_first(source: str, num_indents: int, indents_spacing=4): @staticmethod def _template_from_string(source): env = jinja2_env() - if env is not None: - env.filters["indent_except_first"] = KernelTemplate.indent_except_first - from jinja2 import TemplateSyntaxError - - class DetailedTemplateSyntaxError(TemplateSyntaxError): - def __init__(self, original_error): - super().__init__( - original_error.message, - original_error.lineno, - original_error.name, - original_error.filename, - ) - self.original_error = original_error - - def __str__(self): - error_info = f"Error in template at line {self.lineno}\n" - error_info += f"Error message: {self.message}\n" - if hasattr(self.original_error, "source"): - lines = self.original_error.source.split("\n") - error_info += "Context:\n" - start = max(0, self.lineno - 2) - end = min(len(lines), self.lineno + 2) - for i in range(start, end): - if i == self.lineno - 1: - error_info += f"{i + 1}: --> {lines[i]}\n" - if hasattr(self.original_error, "column"): - error_info += ( - " " - + " " * (self.original_error.column - 1) - + "^\n" - ) - else: - error_info += f"{i + 1}: {lines[i]}\n" - return error_info - - try: - return env.from_string(source) - except TemplateSyntaxError as e: - raise DetailedTemplateSyntaxError(e) from e + if env is None: + return None + env.filters["indent_except_first"] = KernelTemplate.indent_except_first + from jinja2 import TemplateSyntaxError + + class DetailedTemplateSyntaxError(TemplateSyntaxError): + def __init__(self, original_error): + super().__init__( + original_error.message, + original_error.lineno, + original_error.name, + original_error.filename, + ) + self.original_error = original_error + + def __str__(self): + error_info = f"Error in template at line {self.lineno}\n" + error_info += f"Error message: {self.message}\n" + if hasattr(self.original_error, "source"): + lines = self.original_error.source.split("\n") + error_info += "Context:\n" + start = max(0, self.lineno - 2) + end = min(len(lines), self.lineno + 2) + for i in range(start, end): + if i == self.lineno - 1: + error_info += f"{i + 1}: --> {lines[i]}\n" + if hasattr(self.original_error, "column"): + error_info += ( + " " + + " " * (self.original_error.column - 1) + + "^\n" + ) + else: + error_info += f"{i + 1}: {lines[i]}\n" + return error_info - return None + try: + return env.from_string(source) + except TemplateSyntaxError as e: + raise DetailedTemplateSyntaxError(e) from e @staticmethod def _fake_get_dtype(fake_out): diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index 00da8c0282e47..ceaa9c8cdb1cf 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1487,18 +1487,21 @@ def masked(mask, body, other): dtype = result.dtype body_code = f"{var}()" - body_code_vec = ( - body_code - if result.is_vec - else f"{V.kernel._get_vec_type(dtype)}({body_code})" - ) + + def maskify_or_vecify(code): + return ( + f"{V.kernel._get_mask_type()}::from({code})" + if dtype == torch.bool + else f"{V.kernel._get_vec_type(dtype)}({code})" + ) + + if result.is_vec: + body_code_vec = body_code + else: + body_code_vec = maskify_or_vecify(body_code) other_code = value_to_cpp(other, DTYPE_TO_CPP[dtype]) # loading bool as VecMask - other_code_vec = ( - f"{V.kernel._get_mask_type()}::from({other_code})" - if dtype == torch.bool - else f"{V.kernel._get_vec_type(dtype)}({other_code})" - ) + other_code_vec = maskify_or_vecify(other_code) assert isinstance(new_mask, CppCSEVariable), new_mask if new_mask.is_vec: code = BracesBuffer() @@ -3088,7 +3091,12 @@ def gen_transposed_tile_load_store(self, name, var, index, is_store): tile_var = self.cse.cache[load_or_store] if need_define: - define_line = f"alignas({factor}) {DTYPE_TO_CPP[dtype]} {tile_var}[{factor}*{factor}];" + cpp_dtype = DTYPE_TO_CPP[dtype] + # tiling_factor might be smaller than the alignment of cpp_dtype, such as + # with a vector that only holds 4 elements due to NEON 128-bit vectors and + # cpp_dtype being a 64-bit integer. + alignas = f"alignas(std::max(std::size_t({factor}), alignof({cpp_dtype})))" + define_line = f"{alignas} {cpp_dtype} {tile_var}[{factor}*{factor}];" self.preloads.writeline(define_line) load_or_store = load_or_store.replace("__place_holder__", str(tile_var)) @@ -4393,8 +4401,8 @@ def try_share_local_buffer(local_buffer_layout, local_buffers): if not local_buffer_used: # Create new local buffer local_buffer_used = ir.Buffer( - f"{local_buf_prefix}_{len(local_buffers)}", - local_buffer_layout, + name=f"{local_buf_prefix}_{len(local_buffers)}", + layout=local_buffer_layout, ) local_buffers.append(local_buffer_used) local_to_global_buffers[local_buffer_used.name] = [] diff --git a/torch/_inductor/codegen/cpp_gemm_template.py b/torch/_inductor/codegen/cpp_gemm_template.py index 9ffc3da0578ee..2552b01bf1a0e 100644 --- a/torch/_inductor/codegen/cpp_gemm_template.py +++ b/torch/_inductor/codegen/cpp_gemm_template.py @@ -13,7 +13,13 @@ from .. import config, ir, lowering as L from ..kernel.mm_common import mm_args from ..select_algorithm import DataProcessorTemplateWrapper -from ..utils import cache_on_self, has_free_symbols, parallel_num_threads +from ..utils import ( + cache_on_self, + has_free_symbols, + is_same_mkldnn_tensor, + is_same_tensor, + parallel_num_threads, +) from ..virtualized import ops, V from .cpp import get_export_declaration from .cpp_micro_gemm import CppMicroGemmAMX, create_micro_gemm, LayoutType @@ -555,6 +561,19 @@ def reorder_and_filter(inputs, layout_or_out): assert len(input_indices) >= 2 return [inputs[idx] for idx in input_indices], layout_or_out + new_inputs, new_layout = reorder_and_filter(input_nodes, layout) + assert new_inputs[1].get_name() in V.graph.constants + is_mkldnn_wgt = V.graph.constants[new_inputs[1].get_name()].is_mkldnn + if is_mkldnn_wgt: + # It shouldn't happen as viewing an mkldnn tensor, we can extend the + # implementation if it does. + assert not isinstance(new_inputs[1], ir.BaseView) + assert isinstance(new_inputs[1].layout, ir.FixedLayout) + # Note that the layout of MKLDNN Tensor is with the wrong stride + view_size = new_inputs[1].layout.size + view_stride = new_inputs[1].layout.stride + view_offset = new_inputs[1].layout.offset + def maybe_to_dense(inputs, layout_or_out): new_inputs = list(inputs) if isinstance(inputs[1], torch.Tensor): @@ -563,12 +582,19 @@ def maybe_to_dense(inputs, layout_or_out): return new_inputs, layout_or_out def normalize_shapes(inputs, layout_or_out): - if not trans_w: - return inputs, layout_or_out new_inputs = list(inputs) - X = inputs[0] - W = inputs[1] - B = inputs[2] if has_bias else None + if not is_mkldnn_wgt and isinstance(new_inputs[1], torch.Tensor): + # With the assumptation that W is the storage of unwrap view + # thus view it back here + new_inputs[1] = new_inputs[1].as_strided( + view_size, view_stride, view_offset + ) + + if not trans_w: + return new_inputs, layout_or_out + X = new_inputs[0] + W = new_inputs[1] + B = new_inputs[2] if has_bias else None if isinstance(W, ir.IRNode): if trans_w: if not isinstance(W, ir.TensorBox): @@ -593,9 +619,7 @@ def normalize_shapes(inputs, layout_or_out): # TODO(jgong5): decide proper number of threads per problem size num_threads = parallel_num_threads() - new_inputs, _ = normalize_shapes( - *maybe_to_dense(*reorder_and_filter(input_nodes, layout)) - ) + new_inputs, _ = normalize_shapes(*maybe_to_dense(new_inputs, new_layout)) m, n, k, *_ = mm_args(new_inputs[0], new_inputs[1]) output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( new_inputs[0].get_dtype() @@ -623,8 +647,8 @@ def pack_weight(inputs, layout_or_out): if isinstance(W, ir.IRNode): new_size = [padded_n // block_n, k, block_n] blocked_w = ir.Buffer( - W.get_name(), # Borrow the registered buffer name - ir.FixedLayout( + name=W.get_name(), # Borrow the registered buffer name + layout=ir.FixedLayout( W.get_device(), W.get_dtype(), new_size, @@ -697,71 +721,89 @@ def preprocessor(inputs, layout): *normalize_shapes(*maybe_to_dense(*reorder_and_filter(inputs, layout))) ) - def postprocessor(output): - if isinstance(output, ir.TensorBox): - # prepack the weight as input to the template buffer - template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) - assert isinstance(template_buffer, ir.CppTemplateBuffer) - new_input_nodes, _ = reorder_and_filter(input_nodes, layout) - - W_node = new_input_nodes[1] - assert W_node.get_name() in V.graph.constants - W = V.graph.constants[W_node.get_name()] - new_input_nodes[1] = W - new_input_nodes, _ = pack_weight( - *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + def prune_tensors(input_nodes, new_input_nodes): + def share_storage(base_tensor: torch.Tensor, comp_tensor: torch.Tensor): + return base_tensor.is_mkldnn == comp_tensor.is_mkldnn and ( + is_same_tensor(base_tensor, comp_tensor) + or is_same_mkldnn_tensor(base_tensor, comp_tensor) ) + def get_candidates(input_nodes, new_input_nodes): + # Only Constant Buffer like weight and bias might be changed in GEMM Template. + # The Inductor IR Node may changed, but still share the storage. For example: + # bias in bfloat16 case which only do the expand + return [ + node + for node in input_nodes + if ( + node not in new_input_nodes + and isinstance(node, (ir.TensorBox, ir.StorageBox)) + and node.get_name() in V.graph.constants + and not any( + ( + isinstance(new_node, (ir.TensorBox, ir.StorageBox)) + and new_node.get_name() in V.graph.constants + and share_storage( + V.graph.constants[node.get_name()], + V.graph.constants[new_node.get_name()], + ) + ) + for new_node in new_input_nodes + ) + ) + ] + + for candidate_node in get_candidates(input_nodes, new_input_nodes): # By using the new packed weight for the GEMM template, we can prune the # old weight if it has no other users. This saves memory but makes the FX graph # non-retraceable. To support retracing, we can add a repack node to the # FX graph. For example: # mkldnn._linear_pointwise <- repack_linear_wgt <- packed_wgt_for_template - W_tensor_users = 0 + candidate_tensor_users = 0 + candidate_tensor = V.graph.constants[candidate_node.get_name()] for node in reversed(V.graph.graph.nodes): - # Case may happen when the wgt tensor is used by more than 1 get_attr node + # Case may happen when the candidate tensor is used by more than 1 get_attr node # https://github.com/pytorch/pytorch/issues/134998 if node.op == "get_attr" and hasattr( V.graph.module, node.name - ): # wgt might already be deleted + ): # candidate tensor might already be deleted comp_tensor = getattr(V.graph.module, node.name) - if ( - W.is_mkldnn == comp_tensor.is_mkldnn - and W.dtype == comp_tensor.dtype - and W.device == comp_tensor.device - and ( - ( - not W.is_mkldnn - and ( - W.untyped_storage().data_ptr() - == comp_tensor.untyped_storage().data_ptr() - ) - ) - or ( - W.is_mkldnn - and ( - torch.ops.mkldnn.data_ptr(W) - == torch.ops.mkldnn.data_ptr(comp_tensor) - ) - ) - ) - ): - W_tensor_users += 1 + if share_storage(candidate_tensor, comp_tensor): + candidate_tensor_users += 1 for node in reversed(V.graph.graph.nodes): - # The wgt tensor has been used by only 1 get_attr node # The get_attr node has only 1 user fx node + # The candidate tensor has been used by only 1 get_attr node if ( - node.name == W_node.get_name() + node.name == candidate_node.get_name() and len(node.users) == 1 - and W_tensor_users == 1 + and candidate_tensor_users == 1 ): del V.graph.constants[node.name] delattr(V.graph.module, node.name) delattr(V.graph.graph.owning_module, node.name) + def postprocessor(output): + if isinstance(output, ir.TensorBox): + # prepack the weight as input to the template buffer + template_buffer = ir.InputsKernel.unwrap_storage_for_input(output) + assert isinstance(template_buffer, ir.CppTemplateBuffer) + new_input_nodes, _ = reorder_and_filter(input_nodes, layout) + + W_node = new_input_nodes[1] + assert W_node.get_name() in V.graph.constants + W = V.graph.constants[W_node.get_name()] + new_input_nodes[1] = W + new_input_nodes, _ = pack_weight( + *normalize_shapes(*maybe_to_dense(new_input_nodes, layout)) + ) W_packed = new_input_nodes[1] W_packed_constant = V.graph.add_tensor_constant(W_packed) + new_input_nodes[1] = W_packed_constant + + # Prune unused tensors + prune_tensors(input_nodes, new_input_nodes) + template_buffer.inputs[1] = ir.InputsKernel.unwrap_storage_for_input( W_packed_constant ) @@ -910,7 +952,9 @@ def copy_inner(index): # Y if epilogue_creators: gemm_output_name = f"{template_buffer.get_name()}_GemmOut" - gemm_output_buffer = ir.Buffer(gemm_output_name, template_buffer.layout) + gemm_output_buffer = ir.Buffer( + name=gemm_output_name, layout=template_buffer.layout + ) current_input_buffer = gemm_output_buffer for i, creator in enumerate(epilogue_creators): if i == len(epilogue_creators) - 1: @@ -929,7 +973,7 @@ def copy_inner(index): reindexers.append(None) if i < len(epilogue_creators) - 1: current_input_buffer = ir.Buffer( - buffer_name, template_buffer.layout + name=buffer_name, layout=template_buffer.layout ) Y_2d: Union[ir.Buffer, ir.ReinterpretView] = Y @@ -993,7 +1037,9 @@ def get_reindexer(epilogue_node): else: assert isinstance(Y, ir.Buffer) storage = ir.StorageBox(Y) - Y_2d = ir.ReinterpretView(storage, template_buffer.get_layout()) + Y_2d = ir.ReinterpretView( + data=storage, layout=template_buffer.get_layout() + ) output_dtype, compute_dtype = get_gemm_template_output_and_compute_dtype( X.get_dtype() diff --git a/torch/_inductor/codegen/cpp_prefix.h b/torch/_inductor/codegen/cpp_prefix.h index be7289b40d771..0ae57c7d4c649 100644 --- a/torch/_inductor/codegen/cpp_prefix.h +++ b/torch/_inductor/codegen/cpp_prefix.h @@ -28,7 +28,7 @@ #include #include -#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE256) #define INDUCTOR_USE_VECTOR_TYPES() 1 #else #define INDUCTOR_USE_VECTOR_TYPES() 0 @@ -637,8 +637,8 @@ void atomic_add_vec(T *addr, at::vec::VectorizedN index, at::vec::V static_assert(len <= at::vec::VectorizedN::size()); __at_align__ std::array tmpbuf; __at_align__ std::array tmpidx; - offset.store(tmpbuf.data()); - index.store(tmpidx.data()); + offset.store(tmpbuf.data(), len); + index.store(tmpidx.data(), len); for (int i = 0; i < len; i++){ atomic_add(addr + tmpidx[i], tmpbuf[i]); } diff --git a/torch/_inductor/codegen/cpp_template.py b/torch/_inductor/codegen/cpp_template.py index a237924b9182d..57da3f3dd4d82 100644 --- a/torch/_inductor/codegen/cpp_template.py +++ b/torch/_inductor/codegen/cpp_template.py @@ -33,7 +33,7 @@ def __init__( ) -> None: super().__init__(name) self.input_nodes = input_nodes - self.output_node: ir.Buffer = ir.Buffer("buf_out", layout) + self.output_node: ir.Buffer = ir.Buffer(name="buf_out", layout=layout) self.layout = layout self.num_threads = num_threads self.epilogue_creator = epilogue_creator @@ -113,8 +113,7 @@ def header(self) -> IndentedBuffer: res.writeline(codecache.cpp_prefix()) # TODO: add c10::ForcedUnroll test to test_aoti_abi_check res.splice("""#include """) - if config.abi_compatible: - res.splice("""#include """) + res.splice("""#include """) enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ "linux", "win32", diff --git a/torch/_inductor/codegen/cpp_template_kernel.py b/torch/_inductor/codegen/cpp_template_kernel.py index 7ac125c76cc6f..453e4b37375e9 100644 --- a/torch/_inductor/codegen/cpp_template_kernel.py +++ b/torch/_inductor/codegen/cpp_template_kernel.py @@ -182,7 +182,9 @@ def unroll_pragma(self, unroll): def define_buffer(self, name, sizes: List[Any], dtype=torch.float) -> str: """Define kernel local buffer""" sizes = parse_expr_with_index_symbols(sizes) - buf = ir.Buffer(name, ir.FixedLayout(torch.device("cpu"), dtype, sizes)) + buf = ir.Buffer( + name=name, layout=ir.FixedLayout(torch.device("cpu"), dtype, sizes) + ) self.local_buffers[name] = buf ctype = f"{DTYPE_TO_CPP[dtype]}" numel = f"{cexpr_index(buf.get_numel())}" @@ -346,7 +348,7 @@ def __init__( Dict[str, Union[ir.PrimitiveInfoType, List[ir.PrimitiveInfoType]]] ] = None, ): - super().__init__(name, input_nodes, layout) + super().__init__(name, input_nodes, layout, description="") self.category = category self.make_kernel_render = make_kernel_render self.bmreq = bmreq diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index edc9d8a4efae9..6c15e76253b94 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import contextlib -import copy +import dataclasses import functools import math import sys @@ -176,10 +176,14 @@ def deduce_dtype_for_cpp_cse_variable(name, *args, **kwargs): class CppCSEVariable(CSEVariable): - def __init__(self, name, bounds: ValueRanges[Any]) -> None: - super().__init__(name, bounds) + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(name, bounds, dtype) self.is_vec = False - self.dtype: Optional[torch.dtype] = None self.dependent_itervars: Set[sympy.Symbol] = set() def __repr__(self) -> str: @@ -652,18 +656,19 @@ def localize_nodes( def wrap_inner_fn_for_node(node: ir.IRNode): loops = node.data if isinstance(node, ir.ComputedBuffer) else node assert isinstance(loops, ir.Loops) - new_loops = copy.copy(loops) + new_inner_fn = self.localize_function( + loops.inner_fn, + rewrite_index, + ) + + new_loops = dataclasses.replace(loops, inner_fn=new_inner_fn) if isinstance(node, ir.ComputedBuffer): new_node = ir.ComputedBuffer( - node.get_name(), node.get_layout(), new_loops + name=node.get_name(), layout=node.get_layout(), data=new_loops ) else: new_node = new_loops # type: ignore[assignment] - new_loops.inner_fn = self.localize_function( - new_loops.inner_fn, - rewrite_index, - ) return new_node return [wrap_inner_fn_for_node(node) for node in nodes] diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index 21c9e235e354d..891ed89ed8d66 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -15,17 +15,11 @@ from torch.fx.experimental.symbolic_shapes import ConvertIntKey, DivideByKey, SymTypes from .. import config, ir -from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name, sympy_product +from ..utils import _align, ALIGN_BYTES, cache_on_self, normalize_name from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import IndentedBuffer, Kernel -from .cpp_utils import ( - cexpr, - DEVICE_TO_ATEN, - DTYPE_TO_ATEN, - DTYPE_TO_CPP, - LAYOUT_TO_ATEN, -) +from .cpp_utils import cexpr, DEVICE_TO_ATEN, DTYPE_TO_ATEN, DTYPE_TO_CPP from .wrapper import EnterSubgraphLine, ExitSubgraphLine, PythonWrapperCodegen @@ -45,7 +39,7 @@ def __init__(self): self.closed_bracket = "}" self.comment = "//" self.namespace = "at::" - self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()" + self.none_str = "nullptr" self.extern_call_ops = set() self.size = "sizes()" self.stride = "strides()" @@ -70,6 +64,14 @@ def __init__(self): self.initialized_kernels: Dict[str, Kernel] = {} self.expr_printer = cexpr + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperCpu() + def generate_kernel_call( self, kernel_name: str, @@ -110,34 +112,29 @@ def generate_kernel_call( grid_extra_kwargs, ) else: - if config.abi_compatible: - assert arg_types is not None and len(call_args) == len( - arg_types - ), "Mismatch call_args and arg_types in generate_kernel_call" - new_args = [] - for idx, arg in enumerate(call_args): - if "*" in arg_types[idx]: - var_name = f"var_{next(self.arg_var_id)}" - self.writeline( - f"auto* {var_name} = get_data_ptr_wrapper({arg});" - ) - new_args.append(f"({arg_types[idx]})({var_name})") - else: - # arg is a scalar - new_args.append(arg) - # debug printer related logic for cpp kernel type. - debug_printer_manager = V.graph.wrapper_code.debug_printer - debug_printer_manager.set_printer_args( - call_args, - kernel_name, - None, - None, - "cpp", - ) - with debug_printer_manager: - self.writeline(self.wrap_kernel_call(kernel_name, new_args)) - else: - self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + assert arg_types is not None and len(call_args) == len( + arg_types + ), "Mismatch call_args and arg_types in generate_kernel_call" + new_args = [] + for idx, arg in enumerate(call_args): + if "*" in arg_types[idx]: + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"auto* {var_name} = get_data_ptr_wrapper({arg});") + new_args.append(f"({arg_types[idx]})({var_name})") + else: + # arg is a scalar + new_args.append(arg) + # debug printer related logic for cpp kernel type. + debug_printer_manager = V.graph.wrapper_code.debug_printer + debug_printer_manager.set_printer_args( + call_args, + kernel_name, + None, + None, + "cpp", + ) + with debug_printer_manager: + self.writeline(self.wrap_kernel_call(kernel_name, new_args)) def write_constant(self, name, hashed): # include a hash so our code cache gives different constants different files @@ -167,40 +164,23 @@ def write_header(self): """ ) - if config.abi_compatible: - self.header.splice( - f"#include " - ) - self.header.splice( - """ - #include - #include - #include - """ - ) - if V.graph.aot_mode: - self.header.splice( - """ - #include - """ - ) - else: + self.header.splice( + f"#include " + ) + self.header.splice( + """ + #include + #include + #include + """ + ) + if V.graph.aot_mode: self.header.splice( """ - #include - #include - #include - #include - #include - #include - #include - #include - #include - - #define reinterpret_tensor torch::inductor::_reinterpret_tensor - #define alloc_from_pool torch::inductor::_alloc_from_pool + #include """ ) + enable_kernel_profile = config.cpp.enable_kernel_profile and sys.platform in [ "linux", "win32", @@ -292,18 +272,6 @@ def write_input_output_info( ): self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""") - @staticmethod - def get_input_cpp_type(input): - assert config.use_minimal_arrayref_interface - - if isinstance(input, sympy.Expr): - from ..graph import may_get_constant_buffer_dtype - - dtype = may_get_constant_buffer_dtype(input) - assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" - return DTYPE_TO_CPP[dtype] - return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" - def generate_input_output_runtime_checks(self): # In debug_compile mode, we generate checks to ensure the dtype/shape/stride of each # real input/output tensor match ones provided at compile time via sample @@ -401,23 +369,6 @@ def gen_check(handle_kind, idx, name, tensor): def write_wrapper_decl(self): inputs_len = len(V.graph.graph_inputs.keys()) if V.graph.aot_mode: - if config.use_minimal_arrayref_interface and not V.graph.is_const_graph: - input_cpp_types = ", ".join( - f"{CppWrapperCpu.get_input_cpp_type(x)}" - for x in V.graph.graph_inputs.values() - ) - output_arrayref_types = ", ".join( - f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>" - for x in V.graph.graph_outputs - ) - - self.prefix.splice( - f""" - using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; - using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; - """ - ) - if V.graph.const_module: self.header.splice(V.graph.const_module.wrapper_code.header) self.prefix.splice(V.graph.const_code) @@ -460,58 +411,13 @@ def write_wrapper_decl(self): AOTIProxyExecutorHandle proxy_executor ) { """ - # Since we are removing non-abi-compatible mode, let's generate - # runtime checks only for abi_compatible mode to avoid extra branches. - if config.aot_inductor.debug_compile and config.abi_compatible: + if config.aot_inductor.debug_compile: self.generate_input_output_runtime_checks() run_impl_proto += """ __check_inputs_outputs(input_handles, output_handles); """ - if config.use_minimal_arrayref_interface: - self.prefix.splice( - """ - template <> - AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< - AOTInductorModelInputs, AOTInductorModelOutputs>( - const AOTInductorModelInputs& inputs, - DeviceStreamType stream, - AOTIProxyExecutorHandle proxy_executor - ) { - """ - ) - self.suffix.splice(run_impl_proto) - self.suffix.splice( - """ - AOTInductorModelInputs inputs; - convert_handles_to_inputs(input_handles, inputs); - auto outputs = run_impl_minimal_arrayref_interface( - inputs, stream, proxy_executor); - // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this - // interface to perform well for a DSO using the minimal arrayref interface, all we need - // to do is provide ThreadLocalCachedTensor for each one! - convert_outputs_to_handles(outputs, output_handles); - } - """ - ) - self.suffix.splice( - """ - extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface( - AOTInductorModelHandle model_handle, - const AOTInductorModelInputs& inputs, - AOTInductorModelOutputs& outputs) { - auto model = reinterpret_cast(model_handle); - CONVERT_EXCEPTION_TO_ERROR_CODE({ - outputs = model->run_impl_minimal_arrayref_interface( - inputs, - (torch::aot_inductor::DeviceStreamType)nullptr, - nullptr); - }) - } - """ - ) - else: - self.prefix.splice(run_impl_proto) + self.prefix.splice(run_impl_proto) else: # cpp entry function for JIT with cpp wrapper self.prefix.splice( @@ -529,37 +435,23 @@ def write_wrapper_decl(self): ) with self.prefix.indent(): # assign inputs and outputs in both cases so the later codegen can be simplified - if not config.use_minimal_arrayref_interface: - if not V.graph.is_const_graph: - if V.graph.aot_mode: - num_args = len(V.graph.graph_inputs) - else: - # Weights are promoted in the JIT mode - num_args = len(V.graph.graph_inputs) + len(V.graph.constants) - # release GIL to support multiple instances inference (in different threads of the same process) - self.prefix.splice("py::gil_scoped_release release;") + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release release;") - if config.abi_compatible: - self.prefix.splice( - f""" - auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); - """ - ) - else: - # This looks dumb, but can avoid creating two versions of code in the AOTInductor runtime. - self.prefix.splice( - f""" - auto inputs = alloc_tensors_by_stealing_from_handles(input_handles, {num_args}); - """ - ) + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) if inputs_len != 0: for idx, input_key in enumerate(V.graph.graph_inputs.keys()): - if config.use_minimal_arrayref_interface: - self.prefix.writeline( - f"auto {input_key} = std::get<{idx}>(inputs);" - ) - continue # unwrap input tensor back to scalar if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): from ..graph import may_get_constant_buffer_dtype @@ -570,15 +462,9 @@ def write_wrapper_decl(self): assert ( dtype is not None ), "Fails to get the dtype of the sympy.Expr" - cpp_dtype = DTYPE_TO_CPP[dtype] - if config.abi_compatible: - self.codegen_tensor_item( - dtype, f"inputs[{idx}]", input_key, self.prefix - ) - else: - self.prefix.writeline( - f"{cpp_dtype} {input_key} = inputs[{idx}].item<{cpp_dtype}>();" - ) + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) else: self.prefix.writeline( f"auto {input_key} = std::move(inputs[{idx}]);" @@ -591,81 +477,43 @@ def write_wrapper_decl(self): if V.graph.aot_mode: # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. # Don't call std::move here because it will cause constants_ to lose the ownership. - if config.abi_compatible: - self.prefix.writeline( - f"""auto {constants_key} = constants_->at({idx});""" - ) - else: - self.prefix.writeline( - f"auto {constants_key} = *tensor_handle_to_tensor_pointer(" - + f"""constants_->at({idx}));""" - ) + self.prefix.writeline( + f"""auto {constants_key} = constants_->at({idx});""" + ) else: # Append constants as inputs to the graph constants_idx = inputs_len + idx - if config.abi_compatible: - self.prefix.writeline( - f"auto {constants_key} = std::move(inputs[{constants_idx}]);" - ) - else: - self.prefix.writeline( - f"auto {constants_key} = inputs[{constants_idx}];" - ) + self.prefix.writeline( + f"auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) self.codegen_inputs(self.prefix, V.graph.graph_inputs) if V.graph.aot_mode: if not V.graph.is_const_graph: - if config.use_minimal_arrayref_interface: - # TODO: input shape checking for regular tensor interface as well? - self.codegen_input_numel_asserts() - else: - self.prefix.writeline("inputs.clear();") + self.prefix.writeline("inputs.clear();") self.prefix.writeline( "auto& kernels = static_cast(*this->kernels_.get());" ) - def codegen_input_numel_asserts(self): - for name, buf in V.graph.graph_inputs.items(): - if isinstance(buf, sympy.Expr): - continue - - # comparing strides for 0 size tensor is tricky. Ignore them for now. - if sympy_product(buf.get_size()) == 0: - continue - numel = buf.get_numel() - self.prefix.writeline(f"assert_numel({name}, {numel});") - def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name): - if config.abi_compatible: - code.writeline(f"int32_t {name}_dtype;") - code.writeline( - "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype" - f"({name}, &{name}_dtype));" - ) - else: - # Note that we don't have a corresponding class method from - # the PythonWrapperCodegen since this method is used for asserting AOTI - # cpp wrapper code. - code.writeline(f"auto {name}_dtype = {name}.dtype();") + code.writeline(f"int32_t {name}_dtype;") + code.writeline( + "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype" + f"({name}, &{name}_dtype));" + ) def codegen_input_size_var_decl(self, code: IndentedBuffer, name): - if config.abi_compatible: - code.writeline(f"int64_t* {name}_size;") - code.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));" - ) - else: - super().codegen_input_size_var_decl(code, name) + code.writeline(f"int64_t* {name}_size;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes({name}, &{name}_size));" + ) def codegen_input_stride_var_decl(self, code: IndentedBuffer, name): - if config.abi_compatible: - code.writeline(f"int64_t* {name}_stride;") - code.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));" - ) - else: - super().codegen_input_stride_var_decl(code, name) + code.writeline(f"int64_t* {name}_stride;") + code.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides({name}, &{name}_stride));" + ) def codegen_model_kernels(self): self.prefix.writeline("namespace {") @@ -958,13 +806,12 @@ def generate(self, is_inference): def finalize_prefix(self): cached_dtypes_buffer = IndentedBuffer() - if config.abi_compatible: - for dtype in self.used_cached_dtypes: - cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});") - for device in self.used_cached_devices: - cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});") - for layout in self.used_cached_layouts: - cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});") + for dtype in self.used_cached_dtypes: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DTYPE({dtype});") + for device in self.used_cached_devices: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_DEVICE({device});") + for layout in self.used_cached_layouts: + cached_dtypes_buffer.writeline(f"CACHE_TORCH_LAYOUT({layout});") cached_dtypes_buffer.splice(self.prefix) self.prefix = cached_dtypes_buffer @@ -983,9 +830,6 @@ def codegen_scalar_to_tensor(self, output: str): def codegen_tensor_item( self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None ): - assert ( - config.abi_compatible - ), "codegen_tensor_item is only used for the ABI-compatible mode" dtype_str = str(dtype).split(".")[-1] writer = indented_buffer or self @@ -1012,52 +856,28 @@ def codegen_tensor_item( @cache_on_self def get_output_refs(self): - return [ - f"torch::tensor({x.codegen_reference(self.wrapper_call)})" - if isinstance(x, ir.ShapeAsConstantBuffer) and not config.abi_compatible - else x.codegen_reference(self.wrapper_call) - for x in V.graph.graph_outputs - ] + return [x.codegen_reference(self.wrapper_call) for x in V.graph.graph_outputs] def generate_return(self, output_refs: List[str]): - cst_names = V.graph.constants.keys() - arr_iface = ( - not V.graph.is_const_graph and config.use_minimal_arrayref_interface - ) # For brevity. - def use_thread_local_cached_output_tensor(idx, output): cached_output_name = f"cached_output_{next(self.cached_output_id)}" - cache_type = "Array" if arr_iface else "Tensor" + cache_type = "Tensor" self.wrapper_call.writeline( f"thread_local ThreadLocalCachedOutput{cache_type}> " f"{cached_output_name}({output});" ) - if arr_iface: - self.wrapper_call.writeline( - f"{cached_output_name}.copy_data_from({output});" - ) - output_entry = f"std::get<{idx}>(output_arrayref_tensors)" - element_type = f"std::decay_t" - self.wrapper_call.writeline( - f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();" - ) - else: - self.wrapper_call.writeline( - f"{cached_output_name}.copy_data_from({output});" - ) - self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" - ) - self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " - f"output_handles[{idx}]));" - ) - - if arr_iface: self.wrapper_call.writeline( - "AOTInductorModelOutputs output_arrayref_tensors;" + f"{cached_output_name}.copy_data_from({output});" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " + f"output_handles[{idx}]));" ) + cst_names = V.graph.constants.keys() output2idx: Dict[str, int] = {} for idx, output in enumerate(output_refs): if output == self.none_str: @@ -1070,99 +890,48 @@ def use_thread_local_cached_output_tensor(idx, output): if isinstance(output_storage.data, ir.ConstantBuffer): is_constant_buffer = True - if config.abi_compatible: - if isinstance(output_buffer, ir.ShapeAsConstantBuffer): - # Need to wrap scalar into tensor as the main function returns a vector of tensors - output_tensor = self.codegen_scalar_to_tensor(output) - self.wrapper_call.writeline( - f"output_handles[{idx}] = {output_tensor}.release();" - ) - continue - - output_is_tensor_handle_expr = ( - f"std::is_same_v," - "RAIIAtenTensorHandle> || " - f"std::is_same_v," - "AtenTensorHandle> || " - f"std::is_same_v," - "ConstantHandle>" - ) + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) self.wrapper_call.writeline( - f"if constexpr ({output_is_tensor_handle_expr}) {{" + f"output_handles[{idx}] = {output_tensor}.release();" ) - with self.wrapper_call.indent(): - if arr_iface: - cached_output_name = ( - f"cached_output_{next(self.cached_output_id)}" - ) - output_value_type = f"std::decay_t(output_arrayref_tensors).data()[0])>" - self.wrapper_call.writeline( - f"thread_local RAIIAtenTensorHandle {cached_output_name};" - ) - if is_constant_buffer: - # NOTE(return_constant): In some rare cases where we return - # a constant, we have to return a copy of this constant, - # because (1) constants are not owned by the Model instance - # (2) constants remain the same cross inference runs, - # assuming they are not updated at runtime Basically, we - # cannot release or transfer the ownership of any original - # constant to the user. - self.wrapper_call.writeline( - f"AtenTensorHandle {cached_output_name}_tmp;" - ) - self.wrapper_call.writeline( - f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" - ) - self.wrapper_call.writeline( - f"{cached_output_name} = {cached_output_name}_tmp;" - ) - else: - self.wrapper_call.writeline( - f"{cached_output_name} = {output}.release();" - ) - self.wrapper_call.writeline( - f"convert_handle_to_arrayref_tensor({cached_output_name}, " - f"std::get<{idx}>(output_arrayref_tensors));" - ) - else: - if is_constant_buffer: - # See NOTE(return_constant) above. - self.wrapper_call.writeline( - f"aoti_torch_clone({output}, &output_handles[{idx}]);" - ) - else: - if output in output2idx: - src_idx = output2idx[output] - self.wrapper_call.writeline( - f"output_handles[{idx}] = output_handles[{src_idx}];" - ) - else: - self.wrapper_call.writeline( - f"output_handles[{idx}] = {output}.release();" - ) - self.wrapper_call.writeline("} else {") - with self.wrapper_call.indent(): - use_thread_local_cached_output_tensor(idx, output) - self.wrapper_call.writeline("}") + continue - else: - assert ( - not arr_iface - ), "minimal ArrayRef interface is only supported in ABI-compatible mode" + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + self.wrapper_call.writeline( + f"if constexpr ({output_is_tensor_handle_expr}) {{" + ) + with self.wrapper_call.indent(): if is_constant_buffer: - output_expr = f"{output}.clone()" # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) else: - output_expr = output - self.wrapper_call.writeline( - f"output_handles[{idx}] = reinterpret_cast(" - + f"new at::Tensor({output_expr}));" - ) + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + self.wrapper_call.writeline("} else {") + with self.wrapper_call.indent(): + use_thread_local_cached_output_tensor(idx, output) + self.wrapper_call.writeline("}") if output not in output2idx: output2idx[output] = idx - if arr_iface: - self.wrapper_call.writeline("return output_arrayref_tensors;") def generate_before_suffix(self, result): if not V.graph.is_const_graph: @@ -1219,9 +988,11 @@ def generate_end(self, result): outputs_str = "output_tensors" else: outputs = [ - f"output_tensors[{i}]" - if self.output_is_tensor[i] - else f"output_tensors[{i}].item()" + ( + f"output_tensors[{i}]" + if self.output_is_tensor[i] + else f"output_tensors[{i}].item()" + ) for i in range(len(V.graph.graph_outputs)) ] outputs_str = f"[{', '.join(outputs)}]" @@ -1244,7 +1015,7 @@ def g(args): ) def get_c_shim_func_name(self, kernel): - if not config.abi_compatible or kernel.startswith("aoti_torch_"): + if kernel.startswith("aoti_torch_"): return kernel assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'" @@ -1294,14 +1065,11 @@ def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args): self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});") def generate_extern_kernel_alloc(self, extern_kernel, args): - if config.abi_compatible: - if getattr(extern_kernel, "outputs", None): - # ir.ExternKernelAlloc may have outputs if it returns a tuple - self.generate_c_shim_fallback_kernel(extern_kernel, args) - else: - self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) + if getattr(extern_kernel, "outputs", None): + # ir.ExternKernelAlloc may have outputs if it returns a tuple + self.generate_c_shim_fallback_kernel(extern_kernel, args) else: - super().generate_extern_kernel_alloc(extern_kernel, args) + self.generate_c_shim_extern_kernel_alloc(extern_kernel, args) def generate_c_shim_fallback_kernel(self, fallback_kernel, args): output_args = [] @@ -1339,10 +1107,7 @@ def generate_c_shim_fallback_kernel(self, fallback_kernel, args): self.writeline(raii_handle) def generate_fallback_kernel(self, fallback_kernel, args): - if config.abi_compatible: - self.generate_c_shim_fallback_kernel(fallback_kernel, args) - else: - super().generate_fallback_kernel(fallback_kernel, args) + self.generate_c_shim_fallback_kernel(fallback_kernel, args) def generate_extern_kernel_out( self, kernel: str, out: str, out_view: Optional[str], args: List[str] @@ -1354,11 +1119,7 @@ def generate_extern_kernel_out( else: args.insert(0, out) - if config.abi_compatible: - self.generate_c_shim_extern_kernel_call(kernel, args) - else: - # TODO: add debug printing info for non-abi compatible mode extern kernel call - self.writeline(self.wrap_kernel_call(kernel, args)) + self.generate_c_shim_extern_kernel_call(kernel, args) def generate_scatter_fallback( self, @@ -1370,20 +1131,19 @@ def generate_scatter_fallback( reduce, kwargs, ): - if config.abi_compatible: - # call the ABI shim function instead of the ATen one - cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) - # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py - cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" - inputs_wrapped = [ + # call the ABI shim function instead of the ATen one + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + inputs_wrapped = [ + ( f"convert_arrayref_tensor_to_tensor({x})" if isinstance(x, str) else str(x) - for x in inputs - ] - line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}" - else: - line = f"{cpp_kernel_name}({','.join(map(str, inputs))}" + ) + for x in inputs + ] + line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}" if python_kernel_name.startswith("aten.scatter_reduce"): line += f", {','.join(kwargs)}" @@ -1400,35 +1160,28 @@ def generate_scatter_fallback( def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version - if config.abi_compatible: - # See the comment in codegen_reinterpret_view about why having something like - # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding - # tensor prematurely deallocated, thus this std::vector().data() trick here. - indices_str = ( - "std::vector{" - + ( - ", ".join( - [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices] - ) + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding + # tensor prematurely deallocated, thus this std::vector().data() trick here. + indices_str = ( + "std::vector{" + + ( + ", ".join( + [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices] ) - + "}.data()" ) - args = [ - f"convert_arrayref_tensor_to_tensor({x})", - indices_str, - str(len(indices)), - f"convert_arrayref_tensor_to_tensor({values})", - accumulate, - ] - args.insert( - 0, f"convert_arrayref_tensor_to_tensor({x})" - ) # set x as the output tensor, this fallback mutates x. - else: - indices_str = ( - f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" - ) - args = [x, indices_str, values, accumulate] - args.insert(0, x) # set x as the output tensor, this fallback mutates + + "}.data()" + ) + args = [ + f"convert_arrayref_tensor_to_tensor({x})", + indices_str, + str(len(indices)), + f"convert_arrayref_tensor_to_tensor({values})", + accumulate, + ] + args.insert( + 0, f"convert_arrayref_tensor_to_tensor({x})" + ) # set x as the output tensor, this fallback mutates x. self.writeline(self.wrap_kernel_call(kernel, args)) @@ -1441,11 +1194,8 @@ def codegen_sizevar(self, x: Expr) -> str: return self.expr_printer(V.graph.sizevars.simplify(x)) def codegen_tuple_access(self, basename: str, name: str, index: str) -> str: - if config.abi_compatible: - # in the abi_compatible mode, outputs are returned via arguments - return name - else: - return f"std::get<{index}>({basename})" + # in the abi_compatible mode, outputs are returned via arguments + return name def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: parts = list(map(self.codegen_sizevar, shape)) @@ -1457,21 +1207,13 @@ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str: def codegen_dynamic_scalar(self, node): (data,) = (t.codegen_reference() for t in node.inputs) - if config.abi_compatible: - self.codegen_tensor_item( - node.inputs[0].get_dtype(), data, f"{node.sym}_raw" - ) - else: - convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace( - "at::k", "to" - ) - self.writeline(f"auto {node.sym}_raw = {data}.item().{convert_type}();") + self.codegen_tensor_item(node.inputs[0].get_dtype(), data, f"{node.sym}_raw") if len(node.keypath) == 0: self.writeline(f"auto {node.sym} = {node.sym}_raw;") - elif len(node.keypath == 1) and isinstance(node.keypath[0], ConvertIntKey): + elif len(node.keypath) == 1 and isinstance(node.keypath[0], ConvertIntKey): self.writeline(f"int64_t {node.sym} = {node.sym}_raw ? 1 : 0;") - elif len(node.keypath == 1) and isinstance(node.keypath[0], DivideByKey): + elif len(node.keypath) == 1 and isinstance(node.keypath[0], DivideByKey): # TODO: assert divisibility here self.writeline( f"int64_t {node.sym} = {node.sym}_raw / {node.keypath[0].divisor};" @@ -1493,10 +1235,7 @@ def make_free_by_names(self, names_to_del: List[str]): return " ".join(f"{name}.reset();" for name in names_to_del) def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): - if config.abi_compatible: - return f"auto {new_name} = std::move({old_name}); // reuse" - else: - return super().codegen_exact_buffer_reuse(old_name, new_name, del_line) + return f"auto {new_name} = std::move({old_name}); // reuse" def generate_profiler_mark_wrapper_call(self, stack): self.wrapper_call.writeline( @@ -1521,35 +1260,22 @@ def generate_inf_and_nan_checker(self, nodes): ) def codegen_device(self, device): - if config.abi_compatible: - assert device.type in DEVICE_TO_ATEN, ( - device.type + " not found in DEVICE_TO_ATEN" - ) - device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" - self.used_cached_devices.add(device_str) - return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}" - else: - return ( - f"c10::Device({DEVICE_TO_ATEN[device.type]}, {device.index})" - if device.index is not None - else f"{DEVICE_TO_ATEN[device.type]}" - ) + assert device.type in DEVICE_TO_ATEN, ( + device.type + " not found in DEVICE_TO_ATEN" + ) + device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" + self.used_cached_devices.add(device_str) + return f"cached_torch_device_type_{device_str}, {device.index if device.index else 0}" def codegen_dtype(self, dtype): - if config.abi_compatible: - dtype_str = str(dtype).split(".")[-1] - self.used_cached_dtypes.add(dtype_str) - return f"cached_torch_dtype_{dtype_str}" - else: - return DTYPE_TO_ATEN[dtype] + dtype_str = str(dtype).split(".")[-1] + self.used_cached_dtypes.add(dtype_str) + return f"cached_torch_dtype_{dtype_str}" def codegen_layout(self, layout): - if config.abi_compatible: - layout_str = str(layout).split(".")[-1] - self.used_cached_layouts.add(layout_str) - return f"cached_torch_layout_{layout_str}" - else: - return LAYOUT_TO_ATEN[layout] + layout_str = str(layout).split(".")[-1] + self.used_cached_layouts.add(layout_str) + return f"cached_torch_layout_{layout_str}" @functools.lru_cache(None) # noqa: B019 def codegen_int_array_var( @@ -1594,96 +1320,60 @@ def make_allocation(self, name, device, dtype, shape, stride): dtype_code = self.codegen_dtype(dtype) size = self.codegen_shape_tuple(shape) stride = self.codegen_shape_tuple(orig_stride) - if config.abi_compatible: - size_array_var = self.codegen_int_array_var( - size, - self.wrapper_call, - known_statically=self.is_statically_known_list_of_ints(shape), - graph=self.get_codegened_graph(), - ) - stride_array_var = self.codegen_int_array_var( - stride, - self.wrapper_call, - known_statically=self.is_statically_known_list_of_ints(orig_stride), - graph=self.get_codegened_graph(), - ) - device_type, device_id = device_str.split(",") - device_idx = "this->device_idx_" if V.graph.aot_mode else device_id - - args = [ - str(len(shape)), - size_array_var, - stride_array_var, - dtype_code, - device_type, - device_idx, - f"&{name}_handle", - ] - - self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") - self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" - ) - - return f"RAIIAtenTensorHandle {name}({name}_handle);" - - if V.graph.aot_mode and device_str.startswith("c10::Device("): - tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)" - else: - tensor_device = device_str + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] - if device.type == "cpu": - return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});" - if device.type == "cuda": - return ( - f"at::Tensor {name} = at::detail::empty_strided_cuda(" - f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);" - ) - if device.type == "xpu": - return ( - f"at::Tensor {name} = at::detail::empty_strided_xpu(" - f"{size}, {stride}, {dtype_code}, c10::DeviceType::XPU);" - ) - return ( - f"{self.declare}{name} = {self.namespace}empty_strided(" - f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}" + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" ) + return f"RAIIAtenTensorHandle {name}({name}_handle);" + def codegen_alloc_from_pool(self, name, offset, dtype, shape, stride) -> str: - if config.abi_compatible: - size = self.codegen_shape_tuple(shape) - stride = self.codegen_shape_tuple(stride) - tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" - args = [ - name, - self.expr_printer(offset), # bytes not numel - self.codegen_dtype(dtype), - str(len(shape)), - self.codegen_int_array_var( - size, self.wrapper_call, graph=self.get_codegened_graph() - ), - self.codegen_int_array_var( - stride, self.wrapper_call, graph=self.get_codegened_graph() - ), - f"&{tmp_name}", - ] - self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") - self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" - ) - return f"RAIIAtenTensorHandle({tmp_name})" - - return "alloc_from_pool({})".format( - ", ".join( - [ - name, - self.expr_printer(offset), # bytes not numel - self.codegen_dtype(dtype), - self.codegen_shape_tuple(shape), - self.codegen_shape_tuple(stride), - ] - ) + size = self.codegen_shape_tuple(shape) + stride = self.codegen_shape_tuple(stride) + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + name, + self.expr_printer(offset), # bytes not numel + self.codegen_dtype(dtype), + str(len(shape)), + self.codegen_int_array_var( + size, self.wrapper_call, graph=self.get_codegened_graph() + ), + self.codegen_int_array_var( + stride, self.wrapper_call, graph=self.get_codegened_graph() + ), + f"&{tmp_name}", + ] + self.wrapper_call.writeline(f"AtenTensorHandle {tmp_name};") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__alloc_from_pool({', '.join(args)}));" ) + return f"RAIIAtenTensorHandle({tmp_name})" def codegen_reinterpret_view( self, data, size_list, stride_list, offset, writer, dtype=None @@ -1698,60 +1388,47 @@ def codegen_reinterpret_view( final_tmp_name_is_RAIIAtenTensorHandle = False def create_reinterpret_call() -> Tuple[str, str]: - if config.abi_compatible: - tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" - args = [ - f"{data.get_name()}", - dim, - self.codegen_int_array_var( - size, - writer, - known_statically=self.is_statically_known_list_of_ints( - size_list - ), - graph=self.get_codegened_graph(), - ), - self.codegen_int_array_var( - stride, - writer, - known_statically=self.is_statically_known_list_of_ints( - stride_list - ), - graph=self.get_codegened_graph(), - ), - offset, - ] - call_str = ( - f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});" - ) - return tmp_name, call_str - else: - args = [data.get_name(), size, stride, offset] - return f"reinterpret_tensor({', '.join(args)})", "" + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + size, + writer, + known_statically=self.is_statically_known_list_of_ints(size_list), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + stride, + writer, + known_statically=self.is_statically_known_list_of_ints(stride_list), + graph=self.get_codegened_graph(), + ), + offset, + ] + call_str = ( + f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});" + ) + return tmp_name, call_str def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]: - if config.abi_compatible: - tmp_AtenTensorHandle = ( - f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" - ) - call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] - dtype_name = str(dtype).split(".")[-1] - device_name = data.layout.device.type - get_dtype_function = f"aoti_torch_dtype_{dtype_name}" - dtypeview_function = f"aoti_torch_{device_name}_view_dtype" - call_strs.append( - f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}" - f"({reinterpret_call}, {get_dtype_function}(), &{tmp_AtenTensorHandle}));" - ) - tmp_RAIIAtenTensorHandle = ( - f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}_handle" - ) - call_strs.append( - f"RAIIAtenTensorHandle {tmp_RAIIAtenTensorHandle}({tmp_AtenTensorHandle});" - ) - return tmp_RAIIAtenTensorHandle, call_strs - else: - return f"{reinterpret_call}.view({self.codegen_dtype(dtype)})", [] + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] + dtype_name = str(dtype).split(".")[-1] + device_name = data.layout.device.type + get_dtype_function = f"aoti_torch_dtype_{dtype_name}" + dtypeview_function = f"aoti_torch_{device_name}_view_dtype" + call_strs.append( + f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}" + f"({reinterpret_call}, {get_dtype_function}(), &{tmp_AtenTensorHandle}));" + ) + tmp_RAIIAtenTensorHandle = ( + f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}_handle" + ) + call_strs.append( + f"RAIIAtenTensorHandle {tmp_RAIIAtenTensorHandle}({tmp_AtenTensorHandle});" + ) + return tmp_RAIIAtenTensorHandle, call_strs if ( size_list == data.layout.size @@ -1783,124 +1460,111 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]: # of self.generate), the writeline behavior is different in the two passes. writer.writelines(call_strs) - if config.abi_compatible: - # NB, the return handle here represents a temporary tensor, which will be automatically - # released. - # Here's a sample usage in the cpp wrapper code: - # ``` - # aoti_torch_addmm_out( - # buf1, - # arg1_1, - # RAIIAtenTensorHandle(tmp_tensor_handle_0), - # buf0, - # 1L, - # 1L)); - # ``` - # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. - # This could be problematic when it's used in a different pattern, for example: - # ```` - # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; - # aoti_torch_proxy_executor_call_function(..., tensor_args); - # ```` - # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter - # kernel call. - # - # This is solved by updating the proxy_executor invocation to - # ``` - # aoti_torch_proxy_executor_call_function(..., - # std::vector{ - # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 - # }.data() - # ); - # ``` - if not final_tmp_name_is_RAIIAtenTensorHandle: - return f"wrap_with_raii_handle_if_needed({final_tmp_name})" - else: - return final_tmp_name + # NB, the return handle here represents a temporary tensor, which will be automatically + # released. + # Here's a sample usage in the cpp wrapper code: + # ``` + # aoti_torch_addmm_out( + # buf1, + # arg1_1, + # RAIIAtenTensorHandle(tmp_tensor_handle_0), + # buf0, + # 1L, + # 1L)); + # ``` + # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. + # This could be problematic when it's used in a different pattern, for example: + # ```` + # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; + # aoti_torch_proxy_executor_call_function(..., tensor_args); + # ```` + # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter + # kernel call. + # + # This is solved by updating the proxy_executor invocation to + # ``` + # aoti_torch_proxy_executor_call_function(..., + # std::vector{ + # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 + # }.data() + # ); + # ``` + if not final_tmp_name_is_RAIIAtenTensorHandle: + return f"wrap_with_raii_handle_if_needed({final_tmp_name})" else: return final_tmp_name def codegen_device_copy(self, src, dst, non_blocking: bool): - if config.abi_compatible: - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));" - ) - else: - self.writeline(f"{dst}.copy_({src}, {non_blocking});") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));" + ) def codegen_multi_output(self, name, value): # in the abi_compatible mode, outputs are retrieved by passing # output pointers, so we skip its codegen here. - if not config.abi_compatible: - super().codegen_multi_output(name, value) + pass def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): - for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs): - if config.abi_compatible: - # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional - # input (outer_input) into another at::Tensor to be used as a subgraph input - # (inner_input) in the nested scope. we can't std::move here, as the codegened - # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we - # can't necessarily std::move it back to the origin (x). - self.writeline(f"AtenTensorHandle {inner_input}_handle;") - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));" - ) - self.writeline( - f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);" - ) - else: - self.writeline( - f"{self.declare}{inner_input} = {outer_input}{self.ending}" - ) + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + + for (inner_input, inner_input_val), outer_input in zip( + subgraph.graph.graph_inputs.items(), outer_inputs + ): + if not isinstance(inner_input_val, ir.TensorBox): + continue + + # in ABI-compatible mode, we copy the underlying at::Tensor of the conditional + # input (outer_input) into another at::Tensor to be used as a subgraph input + # (inner_input) in the nested scope. we can't std::move here, as the codegened + # outer input may be an expression / rvalue (e.g., reinterpret_view(x)), so we + # can't necessarily std::move it back to the origin (x). + self.writeline(f"AtenTensorHandle {inner_input}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({outer_input}, &{inner_input}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {inner_input}({inner_input}_handle);") def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): for inner_output, outer_output in zip( subgraph.graph.graph_outputs, outer_outputs ): src = inner_output.codegen_reference() - if config.abi_compatible: - # in ABI-compatible mode, we need to std::move subgraph output (inner_output) - # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy - # constructor is deleted. - src = f"std::move({src})" - # in case the outer_output carried a value - # before (e.g., in the while_loop codegen) - self.writeline(f"{outer_output}.reset();") + # in ABI-compatible mode, we need to std::move subgraph output (inner_output) + # to the conditional output (outer_output), as RAIIAtenTensorHandle's copy + # constructor is deleted. + src = f"std::move({src})" + # in case the outer_output carried a value + # before (e.g., in the while_loop codegen) + self.writeline(f"{outer_output}.reset();") self.writeline(f"{outer_output} = {src}{self.ending}") + def codegen_invoke_subgraph(self, invoke_subgraph): + raise NotImplementedError( + "codegen invoke_subgraph is not implemented for cpp wrapper" + ) + def codegen_conditional(self, conditional): name = conditional.get_name() outer_inputs = [f"{buf.codegen_reference()}" for buf in conditional.operands] - if config.abi_compatible: - outer_outputs = [] - for out in conditional.outputs: - # in ABI-compatible mode, ir.MultiOutput is not codegened, - # hence pre-declare output variables directly and separately - self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") - outer_outputs.append(out.get_name()) - - if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): - # in ABI-compatible mode, we need to use the ABI shim function - # to extract a C++ bool from the unrelying scalar bool Tensor - predicate = f"{conditional.predicate.get_name()}_scalar" - self.codegen_tensor_item( - torch.bool, - conditional.predicate.codegen_reference(), - predicate, - ) - else: - # the predicate is not a Tensor: SymBool or Python bool - predicate = conditional.predicate.codegen_reference() + outer_outputs = [] + for out in conditional.outputs: + # in ABI-compatible mode, ir.MultiOutput is not codegened, + # hence pre-declare output variables directly and separately + self.writeline(f"RAIIAtenTensorHandle {out.get_name()};") + outer_outputs.append(out.get_name()) + + if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): + # in ABI-compatible mode, we need to use the ABI shim function + # to extract a C++ bool from the unrelying scalar bool Tensor + predicate = f"{conditional.predicate.get_name()}_scalar" + self.codegen_tensor_item( + torch.bool, + conditional.predicate.codegen_reference(), + predicate, + ) else: - # in non-ABI-compatible mode, we can codegen the conditional outputs - # as array of at::Tensor instances, as the ir.MultiOutput is codegened - outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] - self.writeline(f"at::Tensor {name}[{len(conditional.outputs)}];") - predicate = f"{conditional.predicate.codegen_reference()}" - if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer): - # move the Tensor predicate to host - predicate = f"{predicate}.item()" + # the predicate is not a Tensor: SymBool or Python bool + predicate = conditional.predicate.codegen_reference() self.writeline(f"if ({predicate}) {{") self.writeline(EnterSubgraphLine(self, conditional.true_subgraph.graph)) @@ -1912,6 +1576,25 @@ def codegen_conditional(self, conditional): self.writeline(ExitSubgraphLine(self)) self.writeline("}") + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + # TODO (desertfire) - This function is the old way of supporting + # subgraph codegen by inlining subgraphs in the output code. For python + # wrapper, we have moved to lifting subgraphs as functions, supported by + # PythonWrapperCode `codegen_subgraph` function. We should perhaps + # support lifting of subgraphs as functions for cpp wrapper as well. + try: + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + parent_graph = V.graph + with V.set_graph_handler(subgraph.graph): + subgraph.graph.codegen_subgraph( + parent_graph=parent_graph, + ) + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + finally: + self.pop_codegened_graph() + def codegen_while_loop(self, while_loop): name = while_loop.get_name() outer_carried_inputs = [ @@ -1921,38 +1604,25 @@ def codegen_while_loop(self, while_loop): buf.codegen_reference() for buf in while_loop.additional_inputs ] cond_result_name = f"{name}_cond_result" + self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") + + cond_outer_inputs = [] + for inp, out in zip(outer_carried_inputs, while_loop.outputs): + # in ABI-compatible mode, the carried inputs are codegened + # as buffers outside the while loop and set to the initial + # values. at the end of each while_loop iteration, they + # will be assined the carried values. + out_name = out.get_name() + self.writeline(f"AtenTensorHandle {out_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);") + cond_outer_inputs.append(out_name) - if config.abi_compatible: - self.writeline(f"RAIIAtenTensorHandle {cond_result_name};") - - cond_outer_inputs = [] - for inp, out in zip(outer_carried_inputs, while_loop.outputs): - # in ABI-compatible mode, the carried inputs are codegened - # as buffers outside the while loop and set to the initial - # values. at the end of each while_loop iteration, they - # will be assined the carried values. - out_name = out.get_name() - self.writeline(f"AtenTensorHandle {out_name}_handle;") - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out({inp}, &{out_name}_handle));" - ) - self.writeline(f"RAIIAtenTensorHandle {out_name}({out_name}_handle);") - cond_outer_inputs.append(out_name) - - # additional inputs will be assinged within the while_loop - # iteration directly from the corresponding outer graph buffers - cond_outer_inputs.extend(outer_additional_inputs) - else: - self.writeline(f"at::Tensor {cond_result_name};") - self.writeline(f"at::Tensor {name}[{len(outer_carried_inputs)}];") - for i, inp in enumerate(outer_carried_inputs): - # set the initial state before the loop - self.writeline(f"{name}[{i}] = {inp};") - - cond_outer_inputs = [ - *[f"{name}[{i}]" for i in range(len(outer_carried_inputs))], - *outer_additional_inputs, - ] + # additional inputs will be assinged within the while_loop + # iteration directly from the corresponding outer graph buffers + cond_outer_inputs.extend(outer_additional_inputs) cond_outer_outputs = [cond_result_name] body_outer_inputs = list(cond_outer_inputs) @@ -1964,11 +1634,8 @@ def codegen_while_loop(self, while_loop): while_loop.cond_subgraph, cond_outer_inputs, cond_outer_outputs ) - if config.abi_compatible: - cond_result = f"{cond_result_name}_scalar" - self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) - else: - cond_result = f"{cond_result_name}.item()" + cond_result = f"{cond_result_name}_scalar" + self.codegen_tensor_item(torch.bool, cond_result_name, cond_result) self.writeline(f"if (!{cond_result}) break;") self.writeline(ExitSubgraphLine(self)) @@ -1980,7 +1647,11 @@ def codegen_while_loop(self, while_loop): self.writeline("}") def generate_extern_kernel_args_decl_if_needed( - self, op_overload, raw_args, output_args + self, + op_overload, + raw_args, + output_args: Optional[List[str]] = None, + raw_outputs: Optional[List[ir.Buffer]] = None, ): arg_types = [x.real_type for x in op_overload._schema.arguments] return_types = [x.type for x in op_overload._schema.returns] @@ -2068,13 +1739,14 @@ def fill_args(arg, arg_type): else: fill_args(arg, arg_type) - def fill_output_arg(arg, return_type): + def fill_output_arg(arg, return_type, is_mutated_output: bool): if isinstance(return_type, torch.TensorType): - self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" - ) - self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") + if not is_mutated_output: + self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);") new_tensor_args.append(f"{arg}") elif isinstance(return_type, torch.SymIntType): raise NotImplementedError("NYI support for return type: SymInt") @@ -2098,13 +1770,21 @@ def fill_output_arg(arg, return_type): f"return type {return_type} is not yet supported." ) - for output_arg in output_args: + for output_arg, raw_output_arg in zip(output_args, raw_outputs): # type: ignore[arg-type] assert output_arg is not None, "Optional return types are not yet supported" if isinstance(output_arg, (list, tuple)): for out in output_arg: - fill_output_arg(out, torch.TensorType.get()) + fill_output_arg( + out, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) else: - fill_output_arg(output_arg, torch.TensorType.get()) + fill_output_arg( + output_arg, + torch.TensorType.get(), + isinstance(raw_output_arg, ir.MutationOutput), + ) return new_tensor_args, new_int_args @@ -2126,6 +1806,12 @@ def extract_output_name(out): return None elif isinstance(out, (ir.MultiOutput, ir._CollectiveKernel)): return out.get_name() + elif isinstance(out, ir.MutationOutput): + mutated_buf_names = out.get_mutation_names() + assert ( + isinstance(mutated_buf_names, list) and len(mutated_buf_names) == 1 + ), "Expect only one mutated buffer in MutationOutput" + return mutated_buf_names[0] elif isinstance(out, (list, tuple)): return type(out)(extract_output_name(o) for o in out) else: @@ -2140,7 +1826,7 @@ def extract_output_name(out): if isinstance(output_args, str): output_args = [output_args] - if V.graph.aot_mode and config.abi_compatible: + if V.graph.aot_mode: assert op_overload is not None assert raw_args is not None assert output_args is not None @@ -2150,6 +1836,7 @@ def extract_output_name(out): op_overload, raw_args, output_args, + outputs, ) else: return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit( @@ -2163,6 +1850,7 @@ def extract_output_name(out): op_overload, raw_args, output_args, + outputs, ) def generate_scoped_gil_acquire(self, declarations_before_scope, lines_in_scope): @@ -2306,78 +1994,71 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_jit( op_overload: Optional[torch._ops.OpOverload] = None, raw_args=None, output_args: Optional[List[str]] = None, + raw_outputs: Optional[List[ir.Buffer]] = None, ): - if not config.abi_compatible: - # Will update this to use an OSS version ProxyExecutor - if cpp_kernel_key not in self.extern_call_ops: - self.writeline( - f"static auto op_{cpp_kernel_key} = c10::Dispatcher::singleton()" - ) - self.writeline( - f'\t.findSchemaOrThrow("{cpp_kernel_name}", "{cpp_kernel_overload_name}")' - ) - self.writeline(f"\t.typed<{cpp_op_schema}>();") - self.extern_call_ops.add(cpp_kernel_key) - - self.writeline( - f"auto {buf_name} = op_{cpp_kernel_key}.call({', '.join(codegen_args)});" - ) - else: - # In the JIT mode, because of the ABI-compatible requirement, we can't directly call - # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python - # to invoke this custom op. - self.load_custom_op_wrapper() - - assert output_args is not None, "output_args should not be None" - num_args = len(raw_args) - py_args_var = f"py_args_{next(self.arg_var_id)}" - # First arg is always the python op name - lines = f""" + # In the JIT mode, because of the ABI-compatible requirement, we can't directly call + # c10::Dispatcher to find the custom op and call it. Instead, we go back to Python + # to invoke this custom op. + self.load_custom_op_wrapper() + + assert output_args is not None, "output_args should not be None" + num_args = len(raw_args) + py_args_var = f"py_args_{next(self.arg_var_id)}" + # First arg is always the python op name + lines = f""" RAIIPyObject {py_args_var}(PyTuple_New({num_args+1})); if ({py_args_var}.get() == NULL) {{ - throw std::runtime_error("PyTuple_New {py_args_var} failed"); +throw std::runtime_error("PyTuple_New {py_args_var} failed"); }} PyTuple_SetItem({py_args_var}, 0, PyUnicode_FromString("{python_kernel_name}")); """ - assert op_overload is not None, "op_overload should not be None" + assert op_overload is not None, "op_overload should not be None" - for idx, (raw_arg, schema_arg) in enumerate( - zip(raw_args, op_overload._schema.arguments) - ): - lines += self.generate_py_arg( - py_args_var, idx + 1, raw_arg, schema_arg.real_type - ) + for idx, (raw_arg, schema_arg) in enumerate( + zip(raw_args, op_overload._schema.arguments) + ): + lines += self.generate_py_arg( + py_args_var, idx + 1, raw_arg, schema_arg.real_type + ) - lines += f""" + lines += f""" // Call the custom op in Python RAIIPyObject py_{buf_name}(PyObject_CallObject(custom_op_wrapper, {py_args_var})); if (py_{buf_name}.get() == NULL) {{ - throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed"); +throw std::runtime_error("PyObject_CallObject {python_kernel_name} failed"); }}""" - if len(output_args) == 1: - # result is a single tensor - lines += f""" + if len(output_args) == 1: + # result is a single tensor + lines += f""" {output_args[0]} = reinterpret_cast(PyCapsule_GetPointer(py_{buf_name}.get(), NULL));""" - else: - # result is a tuple of tensors - for idx, output_arg in enumerate(output_args): - if output_arg is None: - continue - lines += f""" + else: + # result is a tuple of tensors + for idx, output_arg in enumerate(output_args): + if output_arg is None: + continue + lines += f""" {output_arg} = - reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));""" +reinterpret_cast(PyCapsule_GetPointer(PyList_GET_ITEM(py_{buf_name}.get(), {idx}), NULL));""" + if raw_outputs: declarations_before_scope = [ f"RAIIAtenTensorHandle {output_arg};" - for output_arg in output_args + for output_arg, raw_output_arg in zip(output_args, raw_outputs) # type: ignore[arg-type] if output_arg is not None + and not isinstance(raw_output_arg, ir.MutationOutput) ] - scope_gil_acquire = self.generate_scoped_gil_acquire( - declarations_before_scope, lines - ) - self.writelines(scope_gil_acquire) + else: + declarations_before_scope = [ + f"RAIIAtenTensorHandle {output_arg};" + for output_arg in output_args # type: ignore[arg-type] + if output_arg is not None + ] + scope_gil_acquire = self.generate_scoped_gil_acquire( + declarations_before_scope, lines + ) + self.writelines(scope_gil_acquire) def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor( self, @@ -2385,12 +2066,16 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_with_proxy_executor( op_overload, raw_args, # contains both args and flatten kwargs output_args: Optional[List[str]] = None, + raw_outputs: Optional[List[ir.Buffer]] = None, ): ( tensor_call_args, int_call_args, ) = self.generate_extern_kernel_args_decl_if_needed( - op_overload, raw_args, output_args + op_overload, + raw_args, + output_args, + raw_outputs, ) tensor_call_args_str = ", ".join(tensor_call_args) @@ -2416,9 +2101,6 @@ def generate_save_uncompiled_kernels(self): pass def c_type_for_prim_type(self, val, type_) -> str: - assert ( - config.abi_compatible - ), "c_type_for_prim_type is only used in ABI compatible mode" if isinstance(type_, torch.OptionalType): return f"{self.c_type_for_prim_type(val, type_.getElementType())}*" elif isinstance(type_, torch.TensorType): @@ -2451,10 +2133,7 @@ def c_type_for_prim_type(self, val, type_) -> str: def val_to_arg_str_for_prim_type(self, val, type_) -> str: # TODO: not using type_ as the first step of refactoring. Will update this later. if isinstance(val, bool): - if config.abi_compatible: - return "1" if val else "0" - else: - return "true" if val else "false" + return "1" if val else "0" elif isinstance(val, int): # uint64_t is long on Linux, but long long on MacOS and Windows return f"{val}LL" if sys.platform in ["darwin", "win32"] else f"{val}L" @@ -2483,101 +2162,84 @@ def val_to_arg_str_for_prim_type(self, val, type_) -> str: def val_to_arg_str(self, val, type_=None) -> str: if val is None: # None needs special care. It either represent nullopt or an empty tensor - if config.abi_compatible: - if type_ is None or isinstance(type_, torch.OptionalType): - if type_ is not None and isinstance( - type_.getElementType(), - ( - torch.ListType, - torch.TupleType, - torch.DeviceObjType, - ), - ): - return "0, 0" - else: - return "0" # nullptr is not available in C - elif isinstance(type_, torch.TensorType): - # create an empty tensor, the equivalent of at::Tensor() - var_name = f"var_{next(self.arg_var_id)}" - self.writeline(f"AtenTensorHandle {var_name}_handle;") - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" - ) - self.writeline( - f"RAIIAtenTensorHandle {var_name}({var_name}_handle);" - ) - return var_name + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.ListType, + torch.TupleType, + torch.DeviceObjType, + ), + ): + return "0, 0" else: - raise AssertionError("Can not map None to a known data type") + return "0" # nullptr is not available in C + elif isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {var_name}({var_name}_handle);") + return var_name else: - return "std::nullopt" + raise AssertionError("Can not map None to a known data type") if isinstance(type_, torch.OptionalType): element_type = type_.getElementType() - if config.abi_compatible: - if not isinstance(element_type, torch.TensorType): - var_name = f"var_{next(self.arg_var_id)}" - if isinstance( - element_type, - (torch.ListType, torch.TupleType, torch.DeviceObjType), - ): - # type_ is something like Optional[List] or Optional[Device] - arg_str = self.val_to_arg_str(val, element_type) - # For datatypes with auxiliary info, we need to hoist out the extra arguments. - # NOTE: This only works if there is one additional argument, though it can easily be generalized. - main_value, aux = arg_str.rsplit(", ") - self.writeline(f"auto {var_name} = {main_value};") - return f"&{var_name}, {aux}" - else: - self.writeline( - f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" - ) - return f"&{var_name}" + if not isinstance(element_type, torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + if isinstance( + element_type, + (torch.ListType, torch.TupleType, torch.DeviceObjType), + ): + # type_ is something like Optional[List] or Optional[Device] + arg_str = self.val_to_arg_str(val, element_type) + # For datatypes with auxiliary info, we need to hoist out the extra arguments. + # NOTE: This only works if there is one additional argument, though it can easily be generalized. + main_value, aux = arg_str.rsplit(", ") + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" else: - # type_ is Optional[Tensor] - # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim - base_handle = self.val_to_arg_str(val, element_type) - if config.use_minimal_arrayref_interface: - base_handle = ( - f"convert_arrayref_tensor_to_tensor({base_handle})" - ) - ( - tmp_raii_handle_var, - tmp_raii_handle_var_decl, - ) = self.create_tmp_raii_handle_var(base_handle) - if tmp_raii_handle_var: - self.writeline(tmp_raii_handle_var_decl) - base_handle = tmp_raii_handle_var - var_name = f"var_{next(self.arg_var_id)}" self.writeline( - f"AtenTensorHandle {var_name} = {base_handle}.get();" + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" ) return f"&{var_name}" else: - return self.val_to_arg_str(val, element_type) + # type_ is Optional[Tensor] + # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim + base_handle = self.val_to_arg_str(val, element_type) + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + self.writeline(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") + return f"&{var_name}" elif isinstance(type_, torch.ListType): assert isinstance( val, (list, tuple) ), f"{val} does not match with arg type {type_}" element_type = type_.getElementType() - if config.abi_compatible: - var_name = f"var_array_{next(self.var_array_id)}" - if len(val) == 0: - # Zero-size array is not supported in the C or C++ standard, so - # we declare a null pointer for it. - self.writeline( - f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;" - ) - else: - result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" - self.writeline( - f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};" - ) - # Need to pass the array length because we can't use std::vector - return f"{var_name}, {len(val)}" + var_name = f"var_array_{next(self.var_array_id)}" + if len(val) == 0: + # Zero-size array is not supported in the C or C++ standard, so + # we declare a null pointer for it. + self.writeline( + f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;" + ) else: - return f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + self.writeline( + f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};" + ) + # Need to pass the array length because we can't use std::vector + return f"{var_name}, {len(val)}" return self.val_to_arg_str_for_prim_type(val, type_) diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py index 40dacf4dba085..4cdff622dd646 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu_array_ref.py @@ -2,19 +2,24 @@ from itertools import count from typing import Dict, List, Optional, Tuple +import sympy + import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch._ops from .. import config, ir +from ..utils import sympy_product from ..virtualized import V from .cpp_utils import cexpr, DTYPE_TO_CPP from .cpp_wrapper_cpu import CppWrapperCpu from .wrapper import ( + BufferLike, EnterSubgraphLine, ExitSubgraphLine, MemoryPlanningLine, MemoryPlanningState, + PythonWrapperCodegen, ) @@ -47,7 +52,7 @@ def __init__(self): self.closed_bracket = "}" self.comment = "//" self.namespace = "at::" - self.none_str = "nullptr" if config.abi_compatible else "at::Tensor()" + self.none_str = "nullptr" self.extern_call_ops = set() self.size = "sizes()" self.stride = "strides()" @@ -70,7 +75,366 @@ def __init__(self): self.custom_op_wrapper_loaded = False self.expr_printer = cexpr self.allow_stack_allocation: Optional[bool] = None - self.stack_allocated_buffers: Dict[BufferName, ir.Buffer] = {} + self.stack_allocated_buffers: Dict[BufferName, BufferLike] = {} + + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperCpuArrayRef() + + @staticmethod + def get_input_cpp_type(input): + assert config.use_minimal_arrayref_interface + + if isinstance(input, sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype(input) + assert dtype is not None, f"Failed to get the dtype of sympy.Expr: {input}" + return DTYPE_TO_CPP[dtype] + return f"ArrayRefTensor<{DTYPE_TO_CPP[input.get_dtype()]}>" + + def codegen_input_numel_asserts(self): + for name, buf in V.graph.graph_inputs.items(): + if isinstance(buf, sympy.Expr): + continue + + # comparing strides for 0 size tensor is tricky. Ignore them for now. + if sympy_product(buf.get_size()) == 0: + continue + numel = buf.get_numel() + self.prefix.writeline(f"assert_numel({name}, {numel});") + + def write_wrapper_decl(self): + inputs_len = len(V.graph.graph_inputs.keys()) + if V.graph.aot_mode: + if config.use_minimal_arrayref_interface and not V.graph.is_const_graph: + input_cpp_types = ", ".join( + f"{CppWrapperCpuArrayRef.get_input_cpp_type(x)}" + for x in V.graph.graph_inputs.values() + ) + output_arrayref_types = ", ".join( + f"ArrayRefTensor<{DTYPE_TO_CPP[x.get_dtype()]}>" + for x in V.graph.graph_outputs + ) + + self.prefix.splice( + f""" + using AOTInductorModelInputs = std::tuple<{input_cpp_types}>; + using AOTInductorModelOutputs = std::tuple<{output_arrayref_types}>; + """ + ) + + if V.graph.const_module: + self.header.splice(V.graph.const_module.wrapper_code.header) + self.prefix.splice(V.graph.const_code) + + if V.graph.is_const_graph: + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + else: + if not config.aot_inductor.use_runtime_constant_folding: + # If we do not split the constant graph, we'll just create + # an empty implementation when wrapping the main module. + self.prefix.splice( + """ + void AOTInductorModel::_const_run_impl( + std::vector& output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) {} + + """ + ) + + run_impl_proto = """ + void AOTInductorModel::run_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + if config.aot_inductor.debug_compile: + self.generate_input_output_runtime_checks() + run_impl_proto += """ + __check_inputs_outputs(input_handles, output_handles); + """ + if config.use_minimal_arrayref_interface: + self.prefix.splice( + """ + template <> + AOTInductorModelOutputs AOTInductorModel::run_impl_minimal_arrayref_interface< + AOTInductorModelInputs, AOTInductorModelOutputs>( + const AOTInductorModelInputs& inputs, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor + ) { + """ + ) + self.suffix.splice(run_impl_proto) + self.suffix.splice( + """ + AOTInductorModelInputs inputs; + convert_handles_to_inputs(input_handles, inputs); + auto outputs = run_impl_minimal_arrayref_interface( + inputs, stream, proxy_executor); + // NOTE: outputs is full of ArrayRef to thread_local storage. If in the future we need this + // interface to perform well for a DSO using the minimal arrayref interface, all we need + // to do is provide ThreadLocalCachedTensor for each one! + convert_outputs_to_handles(outputs, output_handles); + } + """ + ) + + self.suffix.splice( + """ + extern "C" AOTIRuntimeError AOTInductorModelRunMinimalArrayrefInterface( + AOTInductorModelHandle model_handle, + const AOTInductorModelInputs& inputs, + AOTInductorModelOutputs& outputs) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + outputs = model->run_impl_minimal_arrayref_interface( + inputs, + (torch::aot_inductor::DeviceStreamType)nullptr, + nullptr); + }) + } + """ + ) + else: + self.prefix.splice(run_impl_proto) + else: + # cpp entry function for JIT with cpp wrapper + self.prefix.splice( + """ + void inductor_entry_impl( + AtenTensorHandle* + input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle* + output_handles // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed) + ) { + """ + ) + with self.prefix.indent(): + # assign inputs and outputs in both cases so the later codegen can be simplified + if not config.use_minimal_arrayref_interface: + if not V.graph.is_const_graph: + if V.graph.aot_mode: + num_args = len(V.graph.graph_inputs) + else: + # Weights are promoted in the JIT mode + num_args = len(V.graph.graph_inputs) + len(V.graph.constants) + # release GIL to support multiple instances inference (in different threads of the same process) + self.prefix.splice("py::gil_scoped_release release;") + + self.prefix.splice( + f""" + auto inputs = steal_from_raw_handles_to_raii_handles(input_handles, {num_args}); + """ + ) + + if inputs_len != 0: + for idx, input_key in enumerate(V.graph.graph_inputs.keys()): + if config.use_minimal_arrayref_interface: + self.prefix.writeline( + f"auto {input_key} = std::get<{idx}>(inputs);" + ) + continue + # unwrap input tensor back to scalar + if isinstance(V.graph.graph_inputs[input_key], sympy.Expr): + from ..graph import may_get_constant_buffer_dtype + + dtype = may_get_constant_buffer_dtype( + V.graph.graph_inputs[input_key] # type: ignore[arg-type] + ) + assert ( + dtype is not None + ), "Fails to get the dtype of the sympy.Expr" + self.codegen_tensor_item( + dtype, f"inputs[{idx}]", input_key, self.prefix + ) + else: + self.prefix.writeline( + f"auto {input_key} = std::move(inputs[{idx}]);" + ) + + assert all( + isinstance(v, torch.Tensor) for v in list(V.graph.constants.values()) + ), "Expect all constants to be Tensor" + for idx, constants_key in enumerate(V.graph.constants.keys()): + if V.graph.aot_mode: + # Weights are stored in constants_ and owned by RAIIAtenTensorHandle there. + # Don't call std::move here because it will cause constants_ to lose the ownership. + self.prefix.writeline( + f"""auto {constants_key} = constants_->at({idx});""" + ) + else: + # Append constants as inputs to the graph + constants_idx = inputs_len + idx + self.prefix.writeline( + f"auto {constants_key} = std::move(inputs[{constants_idx}]);" + ) + + self.codegen_inputs(self.prefix, V.graph.graph_inputs) + + if V.graph.aot_mode: + if not V.graph.is_const_graph: + if config.use_minimal_arrayref_interface: + # TODO: input shape checking for regular tensor interface as well? + self.codegen_input_numel_asserts() + else: + self.prefix.writeline("inputs.clear();") + self.prefix.writeline( + "auto& kernels = static_cast(*this->kernels_.get());" + ) + + def generate_return(self, output_refs: List[str]): + cst_names = V.graph.constants.keys() + arr_iface = ( + not V.graph.is_const_graph and config.use_minimal_arrayref_interface + ) # For brevity. + + def use_thread_local_cached_output_tensor(idx, output): + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + cache_type = "Array" if arr_iface else "Tensor" + self.wrapper_call.writeline( + f"thread_local ThreadLocalCachedOutput{cache_type}> " + f"{cached_output_name}({output});" + ) + if arr_iface: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + output_entry = f"std::get<{idx}>(output_arrayref_tensors)" + element_type = f"std::decay_t" + self.wrapper_call.writeline( + f"{output_entry} = {cached_output_name}.arrayref_tensor<{element_type}>();" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name}.copy_data_from({output});" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]));" + ) + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors({cached_output_name}.tensor(), " + f"output_handles[{idx}]));" + ) + + if arr_iface: + self.wrapper_call.writeline( + "AOTInductorModelOutputs output_arrayref_tensors;" + ) + + output2idx: Dict[str, int] = {} + for idx, output in enumerate(output_refs): + if output == self.none_str: + continue + + is_constant_buffer = output in cst_names + output_buffer = V.graph.graph_outputs[idx] + if isinstance(output_buffer, ir.BaseView): + output_storage = output_buffer.unwrap_view() + if isinstance(output_storage.data, ir.ConstantBuffer): + is_constant_buffer = True + + if isinstance(output_buffer, ir.ShapeAsConstantBuffer): + # Need to wrap scalar into tensor as the main function returns a vector of tensors + output_tensor = self.codegen_scalar_to_tensor(output) + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output_tensor}.release();" + ) + continue + + output_is_tensor_handle_expr = ( + f"std::is_same_v," + "RAIIAtenTensorHandle> || " + f"std::is_same_v," + "AtenTensorHandle> || " + f"std::is_same_v," + "ConstantHandle>" + ) + self.wrapper_call.writeline( + f"if constexpr ({output_is_tensor_handle_expr}) {{" + ) + with self.wrapper_call.indent(): + if arr_iface: + cached_output_name = f"cached_output_{next(self.cached_output_id)}" + output_value_type = f"std::decay_t(output_arrayref_tensors).data()[0])>" + self.wrapper_call.writeline( + f"thread_local RAIIAtenTensorHandle {cached_output_name};" + ) + if is_constant_buffer: + # NOTE(return_constant): In some rare cases where we return + # a constant, we have to return a copy of this constant, + # because (1) constants are not owned by the Model instance + # (2) constants remain the same cross inference runs, + # assuming they are not updated at runtime Basically, we + # cannot release or transfer the ownership of any original + # constant to the user. + self.wrapper_call.writeline( + f"AtenTensorHandle {cached_output_name}_tmp;" + ) + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &{cached_output_name}_tmp);" + ) + self.wrapper_call.writeline( + f"{cached_output_name} = {cached_output_name}_tmp;" + ) + else: + self.wrapper_call.writeline( + f"{cached_output_name} = {output}.release();" + ) + self.wrapper_call.writeline( + f"convert_handle_to_arrayref_tensor({cached_output_name}, " + f"std::get<{idx}>(output_arrayref_tensors));" + ) + else: + if is_constant_buffer: + # See NOTE(return_constant) above. + self.wrapper_call.writeline( + f"aoti_torch_clone({output}, &output_handles[{idx}]);" + ) + else: + if output in output2idx: + src_idx = output2idx[output] + self.wrapper_call.writeline( + f"output_handles[{idx}] = output_handles[{src_idx}];" + ) + else: + self.wrapper_call.writeline( + f"output_handles[{idx}] = {output}.release();" + ) + self.wrapper_call.writeline("} else {") + with self.wrapper_call.indent(): + use_thread_local_cached_output_tensor(idx, output) + self.wrapper_call.writeline("}") + + if output not in output2idx: + output2idx[output] = idx + if arr_iface: + self.wrapper_call.writeline("return output_arrayref_tensors;") def memory_plan(self): from .memory_planning import MemoryPlanner @@ -158,76 +522,53 @@ def make_allocation( dtype_code = self.codegen_dtype(dtype) size = self.codegen_shape_tuple(shape) stride = self.codegen_shape_tuple(orig_stride) - if config.abi_compatible: - size_array_var = self.codegen_int_array_var( - size, - self.wrapper_call, - known_statically=self.is_statically_known_list_of_ints(shape), - graph=self.get_codegened_graph(), - ) - stride_array_var = self.codegen_int_array_var( - stride, - self.wrapper_call, - known_statically=self.is_statically_known_list_of_ints(orig_stride), - graph=self.get_codegened_graph(), - ) - device_type, device_id = device_str.split(",") - device_idx = "this->device_idx_" if V.graph.aot_mode else device_id - if buffer_if_can_stack_allocate is not None: - self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate - cpp_type = DTYPE_TO_CPP[dtype] - numel = buffer_if_can_stack_allocate.get_numel() - # Note: we don't zero storage because empty_strided doesn't zero either. - self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") - args = [ - f"{name}_storage", - size_array_var, - stride_array_var, - device_type, - device_idx, - ] - return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" - + size_array_var = self.codegen_int_array_var( + size, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(shape), + graph=self.get_codegened_graph(), + ) + stride_array_var = self.codegen_int_array_var( + stride, + self.wrapper_call, + known_statically=self.is_statically_known_list_of_ints(orig_stride), + graph=self.get_codegened_graph(), + ) + device_type, device_id = device_str.split(",") + device_idx = "this->device_idx_" if V.graph.aot_mode else device_id + if buffer_if_can_stack_allocate is not None: + self.stack_allocated_buffers[name] = buffer_if_can_stack_allocate + cpp_type = DTYPE_TO_CPP[dtype] + numel = buffer_if_can_stack_allocate.get_numel() + # Note: we don't zero storage because empty_strided doesn't zero either. + self.wrapper_call.writeline(f"{cpp_type} {name}_storage[{numel}];") args = [ - str(len(shape)), + f"{name}_storage", size_array_var, stride_array_var, - dtype_code, device_type, device_idx, - f"&{name}_handle", ] - - self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") - self.wrapper_call.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" - ) - - return f"RAIIAtenTensorHandle {name}({name}_handle);" - - if V.graph.aot_mode and device_str.startswith("c10::Device("): - tensor_device = f"{device_str.split(',')[0]}, this->device_idx_)" - else: - tensor_device = device_str - - if device.type == "cpu": - return f"at::Tensor {name} = at::detail::empty_strided_cpu({size}, {stride}, {dtype_code});" - if device.type == "cuda": - return ( - f"at::Tensor {name} = at::detail::empty_strided_cuda(" - f"{size}, {stride}, {dtype_code}, c10::DeviceType::CUDA);" - ) - if device.type == "xpu": - return ( - f"at::Tensor {name} = at::detail::empty_strided_xpu(" - f"{size}, {stride}, {dtype_code}, c10::DeviceType::XPU);" - ) - return ( - f"{self.declare}{name} = {self.namespace}empty_strided(" - f"{size}, {stride}, at::TensorOptions({tensor_device}).dtype({dtype_code})){self.ending}" + return f"ArrayRefTensor<{cpp_type}> {name}({', '.join(args)});" + + args = [ + str(len(shape)), + size_array_var, + stride_array_var, + dtype_code, + device_type, + device_idx, + f"&{name}_handle", + ] + + self.wrapper_call.writeline(f"AtenTensorHandle {name}_handle;") + self.wrapper_call.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided({', '.join(args)}));" ) - def make_buffer_reuse(self, old: ir.Buffer, new: ir.Buffer, delete_old: bool): + return f"RAIIAtenTensorHandle {name}({name}_handle);" + + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): assert old.get_dtype() == new.get_dtype() old_name = old.get_name() new_name = new.get_name() @@ -291,22 +632,19 @@ def generate_scatter_fallback( # No stack allocation when there is a fallback op self.allow_stack_allocation = False - if config.abi_compatible: - # call the ABI shim function instead of the ATen one - cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) - # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py - cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" - inputs_wrapped = [ - ( - f"convert_arrayref_tensor_to_tensor({x})" - if isinstance(x, str) - else str(x) - ) - for x in inputs - ] - line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}" - else: - line = f"{cpp_kernel_name}({','.join(map(str, inputs))}" + # call the ABI shim function instead of the ATen one + cpp_kernel_name = self.get_c_shim_func_name(cpp_kernel_name) + # TODO: consider remove "_out" and add missing inplace variants to fallback_ops.py + cpp_kernel_name = cpp_kernel_name.replace("__", "_") + "_out" + inputs_wrapped = [ + ( + f"convert_arrayref_tensor_to_tensor({x})" + if isinstance(x, str) + else str(x) + ) + for x in inputs + ] + line = f"{cpp_kernel_name}(convert_arrayref_tensor_to_tensor({output}), {','.join(inputs_wrapped)}" if python_kernel_name.startswith("aten.scatter_reduce"): line += f", {','.join(kwargs)}" @@ -326,36 +664,28 @@ def generate_index_put_fallback(self, kernel, x, indices, values, accumulate): self.allow_stack_allocation = False # TODO: update aoti_torch_index_put_out in ir.py to use autogen out version - if config.abi_compatible: - # See the comment in codegen_reinterpret_view about why having something like - # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding - # tensor prematurely deallocated, thus this std::vector().data() trick here. - indices_str = ( - "std::vector{" - + ( - ", ".join( - [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices] - ) + # See the comment in codegen_reinterpret_view about why having something like + # RAIIAtenTensorHandle(tmp_tensor_handle_2) in a tmp array can cause the correponding + # tensor prematurely deallocated, thus this std::vector().data() trick here. + indices_str = ( + "std::vector{" + + ( + ", ".join( + [f"convert_arrayref_tensor_to_tensor({ind})" for ind in indices] ) - + "}.data()" ) - args = [ - f"convert_arrayref_tensor_to_tensor({x})", - indices_str, - str(len(indices)), - f"convert_arrayref_tensor_to_tensor({values})", - accumulate, - ] - args.insert( - 0, f"convert_arrayref_tensor_to_tensor({x})" - ) # set x as the output tensor, this fallback mutates x. - else: - indices_str = ( - f"{self.open_bracket}{', '.join(indices)}{self.closed_bracket}" - ) - args = [x, indices_str, values, accumulate] - args.insert(0, x) # set x as the output tensor, this fallback mutates - + + "}.data()" + ) + args = [ + f"convert_arrayref_tensor_to_tensor({x})", + indices_str, + str(len(indices)), + f"convert_arrayref_tensor_to_tensor({values})", + accumulate, + ] + args.insert( + 0, f"convert_arrayref_tensor_to_tensor({x})" + ) # set x as the output tensor, this fallback mutates x. self.writeline(self.wrap_kernel_call(kernel, args)) def generate_extern_kernel_alloc_and_find_schema_if_needed( @@ -386,16 +716,15 @@ def extract_output_name(out): # output_args has the same pytree structure as outputs output_args = None - if config.abi_compatible: - if outputs is None: - # outputs is not specified, the default is to write to buf_name - output_args = [buf_name] - else: - output_args = extract_output_name(outputs) - if isinstance(output_args, str): - output_args = [output_args] + if outputs is None: + # outputs is not specified, the default is to write to buf_name + output_args = [buf_name] + else: + output_args = extract_output_name(outputs) + if isinstance(output_args, str): + output_args = [output_args] - if V.graph.aot_mode and config.abi_compatible: + if V.graph.aot_mode: assert op_overload is not None assert raw_args is not None assert outputs is not None @@ -405,6 +734,7 @@ def extract_output_name(out): op_overload, raw_args, output_args, + outputs, ) else: return self.generate_extern_kernel_alloc_and_find_schema_if_needed_jit( @@ -418,19 +748,17 @@ def extract_output_name(out): op_overload, raw_args, output_args, + outputs, ) def codegen_device_copy(self, src, dst, non_blocking: bool): - if config.abi_compatible: - # aoti_torch_tensor_copy_ takes AtenTensorHandle as input, - # while stack-allocation results in ArrayRefTensor - # so disable stack allocation here - self.allow_stack_allocation = False - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));" - ) - else: - self.writeline(f"{dst}.copy_({src}, {non_blocking});") + # aoti_torch_tensor_copy_ takes AtenTensorHandle as input, + # while stack-allocation results in ArrayRefTensor + # so disable stack allocation here + self.allow_stack_allocation = False + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_copy_(expensive_copy_to_tensor_if_needed({dst}), {src}, {non_blocking}));" + ) def codegen_reinterpret_view( self, data, size_list, stride_list, offset, writer, dtype=None @@ -445,60 +773,47 @@ def codegen_reinterpret_view( final_tmp_name_is_RAIIAtenTensorHandle = False def create_reinterpret_call() -> Tuple[str, str]: - if config.abi_compatible: - tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" - args = [ - f"{data.get_name()}", - dim, - self.codegen_int_array_var( - size, - writer, - known_statically=self.is_statically_known_list_of_ints( - size_list - ), - graph=self.get_codegened_graph(), - ), - self.codegen_int_array_var( - stride, - writer, - known_statically=self.is_statically_known_list_of_ints( - stride_list - ), - graph=self.get_codegened_graph(), - ), - offset, - ] - call_str = ( - f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});" - ) - return tmp_name, call_str - else: - args = [data.get_name(), size, stride, offset] - return f"reinterpret_tensor({', '.join(args)})", "" + tmp_name = f"tmp_tensor_handle_{next(self.tmp_tensor_id)}" + args = [ + f"{data.get_name()}", + dim, + self.codegen_int_array_var( + size, + writer, + known_statically=self.is_statically_known_list_of_ints(size_list), + graph=self.get_codegened_graph(), + ), + self.codegen_int_array_var( + stride, + writer, + known_statically=self.is_statically_known_list_of_ints(stride_list), + graph=self.get_codegened_graph(), + ), + offset, + ] + call_str = ( + f"auto {tmp_name} = reinterpret_tensor_wrapper({', '.join(args)});" + ) + return tmp_name, call_str def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]: - if config.abi_compatible: - tmp_AtenTensorHandle = ( - f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" - ) - call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] - dtype_name = str(dtype).split(".")[-1] - device_name = data.layout.device.type - get_dtype_function = f"aoti_torch_dtype_{dtype_name}" - dtypeview_function = f"aoti_torch_{device_name}_view_dtype" - call_strs.append( - f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}" - f"({reinterpret_call}, {get_dtype_function}(), &{tmp_AtenTensorHandle}));" - ) - tmp_RAIIAtenTensorHandle = ( - f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}_handle" - ) - call_strs.append( - f"RAIIAtenTensorHandle {tmp_RAIIAtenTensorHandle}({tmp_AtenTensorHandle});" - ) - return tmp_RAIIAtenTensorHandle, call_strs - else: - return f"{reinterpret_call}.view({self.codegen_dtype(dtype)})", [] + tmp_AtenTensorHandle = f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}" + call_strs = [f"AtenTensorHandle {tmp_AtenTensorHandle};"] + dtype_name = str(dtype).split(".")[-1] + device_name = data.layout.device.type + get_dtype_function = f"aoti_torch_dtype_{dtype_name}" + dtypeview_function = f"aoti_torch_{device_name}_view_dtype" + call_strs.append( + f"AOTI_TORCH_ERROR_CODE_CHECK({dtypeview_function}" + f"({reinterpret_call}, {get_dtype_function}(), &{tmp_AtenTensorHandle}));" + ) + tmp_RAIIAtenTensorHandle = ( + f"tmp_{data.get_name()}_{next(self.tmp_tensor_id)}_handle" + ) + call_strs.append( + f"RAIIAtenTensorHandle {tmp_RAIIAtenTensorHandle}({tmp_AtenTensorHandle});" + ) + return tmp_RAIIAtenTensorHandle, call_strs if ( size_list == data.layout.size @@ -530,47 +845,130 @@ def create_dtypeview_call(reinterpret_call: str) -> Tuple[str, List[str]]: # of self.generate), the writeline behavior is different in the two passes. writer.writelines(call_strs) - if config.abi_compatible: - if ( - self.can_stack_allocate_buffer(data) - and self.is_statically_known_list_of_ints(size_list) - and self.is_statically_known_list_of_ints(stride_list) - and ir.is_contiguous_strides_for_shape(stride_list, size_list) - ): - return final_tmp_name - - # NB, the return handle here represents a temporary tensor, which will be automatically - # released. - # Here's a sample usage in the cpp wrapper code: - # ``` - # aoti_torch_addmm_out( - # buf1, - # arg1_1, - # RAIIAtenTensorHandle(tmp_tensor_handle_0), - # buf0, - # 1L, - # 1L)); - # ``` - # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. - # This could be problematic when it's used in a different pattern, for example: - # ```` - # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; - # aoti_torch_proxy_executor_call_function(..., tensor_args); - # ```` - # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter - # kernel call. - # - # This is solved by updating the proxy_executor invocation to - # ``` - # aoti_torch_proxy_executor_call_function(..., - # std::vector{ - # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 - # }.data() - # ); - # ``` - if not final_tmp_name_is_RAIIAtenTensorHandle: - return f"wrap_with_raii_handle_if_needed({final_tmp_name})" - else: - return final_tmp_name + if ( + self.can_stack_allocate_buffer(data) + and self.is_statically_known_list_of_ints(size_list) + and self.is_statically_known_list_of_ints(stride_list) + and ir.is_contiguous_strides_for_shape(stride_list, size_list) + ): + return final_tmp_name + + # NB, the return handle here represents a temporary tensor, which will be automatically + # released. + # Here's a sample usage in the cpp wrapper code: + # ``` + # aoti_torch_addmm_out( + # buf1, + # arg1_1, + # RAIIAtenTensorHandle(tmp_tensor_handle_0), + # buf0, + # 1L, + # 1L)); + # ``` + # RAIIAtenTensorHandle(tmp_tensor_handle_0) will be released after the call to addmm_out. + # This could be problematic when it's used in a different pattern, for example: + # ```` + # AtenTensorHandle tensor_args[] = {RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6}; + # aoti_torch_proxy_executor_call_function(..., tensor_args); + # ```` + # RAIIAtenTensorHandle(tmp_tensor_handle_2) will be invalid when it's used in the latter + # kernel call. + # + # This is solved by updating the proxy_executor invocation to + # ``` + # aoti_torch_proxy_executor_call_function(..., + # std::vector{ + # RAIIAtenTensorHandle(tmp_tensor_handle_2), buf5, buf6 + # }.data() + # ); + # ``` + if not final_tmp_name_is_RAIIAtenTensorHandle: + return f"wrap_with_raii_handle_if_needed({final_tmp_name})" else: return final_tmp_name + + def val_to_arg_str(self, val, type_=None) -> str: + if val is None: + # None needs special care. It either represent nullopt or an empty tensor + if type_ is None or isinstance(type_, torch.OptionalType): + if type_ is not None and isinstance( + type_.getElementType(), + ( + torch.ListType, + torch.TupleType, + torch.DeviceObjType, + ), + ): + return "0, 0" + else: + return "0" # nullptr is not available in C + elif isinstance(type_, torch.TensorType): + # create an empty tensor, the equivalent of at::Tensor() + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name}_handle;") + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{var_name}_handle));" + ) + self.writeline(f"RAIIAtenTensorHandle {var_name}({var_name}_handle);") + return var_name + else: + raise AssertionError("Can not map None to a known data type") + + if isinstance(type_, torch.OptionalType): + element_type = type_.getElementType() + if not isinstance(element_type, torch.TensorType): + var_name = f"var_{next(self.arg_var_id)}" + if isinstance( + element_type, + (torch.ListType, torch.TupleType, torch.DeviceObjType), + ): + # type_ is something like Optional[List] or Optional[Device] + arg_str = self.val_to_arg_str(val, element_type) + # For datatypes with auxiliary info, we need to hoist out the extra arguments. + # NOTE: This only works if there is one additional argument, though it can easily be generalized. + main_value, aux = arg_str.rsplit(", ") + self.writeline(f"auto {var_name} = {main_value};") + return f"&{var_name}, {aux}" + else: + self.writeline( + f"{self.c_type_for_prim_type(val, element_type)} {var_name} = {self.val_to_arg_str(val, element_type)};" + ) + return f"&{var_name}" + else: + # type_ is Optional[Tensor] + # Similar to other data type, use pointer to denote optional tensor arg in v2 C shim + base_handle = self.val_to_arg_str(val, element_type) + if config.use_minimal_arrayref_interface: + base_handle = f"convert_arrayref_tensor_to_tensor({base_handle})" + ( + tmp_raii_handle_var, + tmp_raii_handle_var_decl, + ) = self.create_tmp_raii_handle_var(base_handle) + if tmp_raii_handle_var: + self.writeline(tmp_raii_handle_var_decl) + base_handle = tmp_raii_handle_var + var_name = f"var_{next(self.arg_var_id)}" + self.writeline(f"AtenTensorHandle {var_name} = {base_handle}.get();") + return f"&{var_name}" + + elif isinstance(type_, torch.ListType): + assert isinstance( + val, (list, tuple) + ), f"{val} does not match with arg type {type_}" + element_type = type_.getElementType() + var_name = f"var_array_{next(self.var_array_id)}" + if len(val) == 0: + # Zero-size array is not supported in the C or C++ standard, so + # we declare a null pointer for it. + self.writeline( + f"const {self.c_type_for_prim_type(None, element_type)}* {var_name} = nullptr;" + ) + else: + result = f"{{{', '.join(self.val_to_arg_str(x, element_type) for x in val)}}}" + self.writeline( + f"const {self.c_type_for_prim_type(val[0], element_type)} {var_name}[] = {result};" + ) + # Need to pass the array length because we can't use std::vector + return f"{var_name}, {len(val)}" + + return self.val_to_arg_str_for_prim_type(val, type_) diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 59e9055d8e636..c8dae73be75ef 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -1,24 +1,23 @@ # mypy: allow-untyped-defs import functools import os -from itertools import chain, count +from itertools import chain, count, zip_longest from typing import Any, Callable, List, Optional, Tuple, TYPE_CHECKING, Union import sympy -from torch import dtype as torch_dtype, uint8 +from torch import dtype as torch_dtype from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name from torch._inductor.runtime.triton_heuristics import grid as default_grid_fn -from .. import config from ..codecache import CudaKernelParamCache from ..utils import DeferredLineBase, get_gpu_type from ..virtualized import V from .aoti_hipify_utils import maybe_hipify_code_wrapper from .common import get_device_op_overrides -from .cpp_utils import cexpr, DTYPE_TO_CPP +from .cpp_utils import cexpr from .cpp_wrapper_cpu import CppWrapperCpu -from .wrapper import SymbolicCallArg +from .wrapper import PythonWrapperCodegen, SymbolicCallArg if TYPE_CHECKING: @@ -98,13 +97,7 @@ def __call__(self): assert ( params is not None ), f"{self.kernel_name} not found in CudaKernelParamCache" - block_cfg = { - "XBLOCK": params["x_block"], - "YBLOCK": params["y_block"], - "ZBLOCK": params["z_block"], - "RBLOCK": params["r_block"], - } - return grid_fn(block_cfg) + return grid_fn(params["meta"]) class DeferredGpuGridLine(DeferredLineBase): @@ -171,6 +164,14 @@ def __init__(self) -> None: super().__init__() self.grid_id = count() + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + # TODO - support subgraph codegen by lifting functions. Check the + # comment at CppWrapperCpu `codegen_subgraph` function. + return CppWrapperGpu() + def write_header(self): if V.graph.is_const_graph: # We do not write header for constant graph, it will be written by main module. @@ -179,12 +180,7 @@ def write_header(self): super().write_header() self.header.splice("#include ") - if config.abi_compatible: - self.header.splice(self.device_codegen.abi_compatible_header()) - else: - self.header.splice( - maybe_hipify_code_wrapper(self.device_codegen.kernel_header()) - ) + self.header.splice(self.device_codegen.abi_compatible_header()) self.header.splice( maybe_hipify_code_wrapper(self.device_codegen.kernel_driver()) ) @@ -254,6 +250,9 @@ def generate_user_defined_triton_kernel( autotune_configs=configs, ) + def generate_tma_descriptor(self, desc): + raise NotImplementedError("Host-side TMA descriptors NYI in C++ wrapper.") + @functools.lru_cache(None) # noqa: B019 def generate_load_kernel_once( self, @@ -266,72 +265,75 @@ def generate_load_kernel_once( self.writeline( DeferredGpuKernelLine( kernel_name, - """ """ - + kernel_var_name - + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" - if V.graph.aot_mode - else """ """ - + kernel_var_name - + """ = loadKernel("%s", "%s", %s);""", + ( + """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s, this->cubin_dir_);""" + if V.graph.aot_mode + else """ """ + + kernel_var_name + + """ = loadKernel("%s", "%s", %s);""" + ), keys, ) ) self.writeline("}") return kernel_var_name - def generate_args_decl(self, call_args, arg_types): + def generate_args_decl(self, call_args, arg_types, arg_signatures): new_args = [] - for arg, arg_type in zip(call_args, arg_types): + + # Add more cases for other types as needed + signature2dtype = { + "i32": "int32_t", + "i64": "int64_t", + "fp32": "float", + } + + def process_args(arg, arg_type, arg_signature=None): var_name = f"var_{next(self.arg_var_id)}" if isinstance(arg_type, torch_dtype): if arg.endswith(".item()"): # Need to declare a scalar in this case - ctype = DTYPE_TO_CPP[arg_type] arg = arg[:-7] - if config.abi_compatible: - self.codegen_tensor_item( - arg_type, - arg, - var_name, - ) - else: - from torch import bfloat16, float16 - - if arg_type in (float16, bfloat16): - var_name_tmp = f"{var_name}_tmp" - self.writeline( - f"{ctype} {var_name_tmp} = {arg}.item<{ctype}>();" - ) - self.writeline(f"float {var_name} = float({var_name_tmp});") - else: - self.writeline( - f"{ctype} {var_name} = {arg}.item<{ctype}>();" - ) + self.codegen_tensor_item( + arg_type, + arg, + var_name, + ) else: - if config.abi_compatible: - self.writeline( - maybe_hipify_code_wrapper( - f"{self.device_codegen.cpp_device_ptr()} {var_name};" - ) - ) - self.writeline( - f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast(&{var_name})));" - ) - else: - self.writeline( - maybe_hipify_code_wrapper( - f"{self.device_codegen.cpp_device_ptr()} {var_name} = \ - reinterpret_cast<{self.device_codegen.cpp_device_ptr()}>({arg}.data_ptr());" - ) + self.writeline( + maybe_hipify_code_wrapper( + f"{self.device_codegen.cpp_device_ptr()} {var_name};" ) + ) + self.writeline( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr({arg}, reinterpret_cast(&{var_name})));" + ) elif arg_type in (sympy.Integer, int): self.writeline(f"int {var_name} = {self.expr_printer(arg)};") elif arg_type in (sympy.Float, float): self.writeline(f"float {var_name} = {self.expr_printer(arg)};") + # For symbolic call arguments, examine the arg signatures from triton meta + # to explicitly cast to the right type + # Reason: `auto` can infer unexpected type against kernel input signature. + elif ( + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype.keys() + ): + self.writeline( + f"{signature2dtype[arg_signature]} {var_name} = {self.expr_printer(arg)};" + ) else: self.writeline(f"auto {var_name} = {self.expr_printer(arg)};") new_args.append(f"&{var_name}") + for arg, arg_type, arg_signature in zip_longest( + call_args, arg_types, arg_signatures + ): + process_args(arg, arg_type, arg_signature) + return ", ".join(new_args) def generate_default_grid( @@ -391,7 +393,7 @@ def generate_kernel_call( ) if device_index is None: - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() device_index = current_device.index stream = ( "stream" @@ -408,18 +410,26 @@ def generate_kernel_call( # args with value 1 are added into equal_to_1 and constants # in triton_meta (in the Python codegen) which makes them # inlined in the PTX and compiled CUBIN + arg_signatures = [] if ( triton_meta is not None - and "configs" in triton_meta - and triton_meta["configs"] + and triton_meta.get("configs") + and triton_meta.get("signature") ): equal_to_1 = triton_meta["configs"][0].equal_to_1 call_args = [ arg for i, arg in enumerate(call_args) if i not in equal_to_1 ] arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] + # extract the arg signatures from triton_meta + arg_signatures = triton_meta["signature"].values() + arg_signatures = [ + v for i, v in enumerate(arg_signatures) if i not in equal_to_1 + ] - call_args_str = self.generate_args_decl(call_args, arg_types) + call_args_str = self.generate_args_decl( + call_args, arg_types, arg_signatures + ) kernel_args_var = f"kernel_args_var_{next(self.kernel_callsite_id)}" self.writeline(f"void* {kernel_args_var}[] = {{{call_args_str}}};") @@ -458,31 +468,13 @@ def generate_kernel_call( for arg_type, arg in zip(arg_types, call_args): new_arg = arg if arg_type.endswith("*") and arg != "nullptr": - if config.abi_compatible: - new_arg = f"var_{next(self.arg_var_id)}" - self.writeline( - f"auto* {new_arg} = get_data_ptr_wrapper({arg});" - ) - else: - new_arg = f"{arg}.data_ptr()" + new_arg = f"var_{next(self.arg_var_id)}" + self.writeline(f"auto* {new_arg} = get_data_ptr_wrapper({arg});") casted.append(f"({arg_type}){new_arg}") call_args_str = ", ".join(casted) self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") - def generate_workspace_allocation(self, nbytes, device, zero_fill): - line = self.make_allocation( - "workspace", device, uint8, shape=(nbytes,), stride=(1,) + def make_zero_buffer(self, name): + return ( + f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get())){self.ending}" ) - self.writeline(line) - if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(line) - if zero_fill: - if config.abi_compatible: - # TODO: remove this function to use the default WrapperCodegen behavior after service platform has zero_() symbol - # default behavior is f"workspace.zero_(){self.ending}" - # or add f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_(workspace.get())){self.ending}" - pass - else: - self.writeline(f"workspace.zero_(){self.ending}") - if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}") diff --git a/torch/_inductor/codegen/cuda/cuda_kernel.py b/torch/_inductor/codegen/cuda/cuda_kernel.py index c6748e6cde619..ad288244012b1 100644 --- a/torch/_inductor/codegen/cuda/cuda_kernel.py +++ b/torch/_inductor/codegen/cuda/cuda_kernel.py @@ -16,7 +16,13 @@ ) from ...utils import sympy_product from ...virtualized import V -from ..common import IndentedBuffer, Kernel, OpOverrides +from ..common import ( + IndentedBuffer, + Kernel, + OpOverrides, + WorkspaceArg, + WorkspaceZeroMode, +) from ..cpp_utils import CppPrinter, DTYPE_TO_CPP @@ -197,14 +203,19 @@ def call_kernel( arg_types.append("size_t*") if node.get_workspace_size() > 0: - wrapper.generate_workspace_allocation( - node.get_workspace_size(), V.graph.scheduler.current_device, False + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), ) - data_ptr = "workspace.data_ptr()" + wrapper.generate_workspace_allocation(ws) + data_ptr = f"{ws.outer_name}.data_ptr()" call_args.append( data_ptr if V.graph.cpp_wrapper else f"c_void_p({data_ptr})" ) else: + ws = None call_args.append("nullptr" if V.graph.cpp_wrapper else "None") if V.graph.cpp_wrapper: arg_types.append("uint8_t*") @@ -216,8 +227,8 @@ def call_kernel( triton=False, arg_types=arg_types, ) - if node.get_workspace_size() > 0: - wrapper.writeline(wrapper.make_free_by_names(["workspace"])) + if ws: + wrapper.generate_workspace_deallocation(ws) def dtype(self, node: IRNode) -> Optional[str]: """ @@ -363,8 +374,9 @@ def __init__( bmreq: CUDABenchmarkRequest, template: "CUDATemplate", # type: ignore[name-defined] info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg] + description: str, ) -> None: - super().__init__(name, input_nodes, layout) + super().__init__(name, input_nodes, layout, description) self.category = category self.make_kernel_render = make_kernel_render self.bmreq = bmreq diff --git a/torch/_inductor/codegen/cuda/cuda_template.py b/torch/_inductor/codegen/cuda/cuda_template.py index 1f5e59b3b8cc2..2902c25cfcf60 100644 --- a/torch/_inductor/codegen/cuda/cuda_template.py +++ b/torch/_inductor/codegen/cuda/cuda_template.py @@ -43,12 +43,13 @@ def __init__( """ super().__init__(name) self.input_nodes = input_nodes - self.output_node: Buffer = Buffer("buf_out", layout) + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) self.input_reorder = input_reorder self.layout = layout def generate( # type: ignore[override] self, + description, **kwargs, ) -> CUDATemplateCaller: """ @@ -129,6 +130,7 @@ def make_kernel_render( bmreq, self, kwargs, + description, ) def header(self) -> IndentedBuffer: diff --git a/torch/_inductor/codegen/cuda/device_op_overrides.py b/torch/_inductor/codegen/cuda/device_op_overrides.py index 011e503a7b889..ecd89a7f3eb97 100644 --- a/torch/_inductor/codegen/cuda/device_op_overrides.py +++ b/torch/_inductor/codegen/cuda/device_op_overrides.py @@ -46,7 +46,12 @@ def kernel_driver(self): do { \\ CUresult code = EXPR; \\ const char *msg; \\ - cuGetErrorString(code, &msg); \\ + CUresult code_get_error = cuGetErrorString(code, &msg); \\ + if (code_get_error != CUDA_SUCCESS) { \\ + throw std::runtime_error( \\ + std::string("CUDA driver error: ") + \\ + std::string("invalid error code!")); \\ + } \\ if (code != CUDA_SUCCESS) { \\ throw std::runtime_error( \\ std::string("CUDA driver error: ") + \\ diff --git a/torch/_inductor/codegen/cuda/gemm_template.py b/torch/_inductor/codegen/cuda/gemm_template.py index b9a38f65df1e1..ee2c51bd779cb 100644 --- a/torch/_inductor/codegen/cuda/gemm_template.py +++ b/torch/_inductor/codegen/cuda/gemm_template.py @@ -505,9 +505,10 @@ def _add_cutlass_gemm_choices( """ ops = self.gen_ops() - for op in ops: + for name, op in ops: self.maybe_append_choice( choices, + description=name, op=op, ) if len(ops) == 0: @@ -809,7 +810,7 @@ def filter_op( return op - def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name-defined] # noqa: F821 + def gen_ops(self) -> "List[Tuple[str, cutlass_gemm_op.GemmOperation]]": # type: ignore[name-defined] # noqa: F821 """ Creates a list of Cutlass GemmOperation instances that match the operation this template is designed to represent. The matching is carried out with respect to the input and output specifications of the operation. @@ -817,8 +818,8 @@ def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name No function arguments. Returns: - List[cutlass_gemm_op.GemmOperation]: A list of GemmOperation instances that are compatible with the - operation requirements of this template. + List[Tuple[str, cutlass_gemm_op.GemmOperation]]: A list of (cutlass_name, GemmOperation) + tuples that are compatible with the operation requirements of this template. """ assert cutlass_utils.try_import_cutlass() import cutlass_library.gemm_operation as cutlass_gemm_op @@ -837,7 +838,7 @@ def gen_ops(self) -> "List[cutlass_gemm_op.GemmOperation]": # type: ignore[name ): res[filter_res.configuration_name()] = filter_res log.debug("Got cutlass configs: total number of ops: %d, ", len(res)) - return list(res.values())[: inductor_cuda_config.cutlass_max_profiling_configs] + return list(res.items())[: inductor_cuda_config.cutlass_max_profiling_configs] def gemm_mode(self) -> str: """ @@ -1277,7 +1278,7 @@ def clone_with_transposed_stride(node: IRNode) -> IRNode: new_stride, old_layout.offset, ) - return Buffer(node.get_name(), new_layout) + return Buffer(name=node.get_name(), layout=new_layout) new_X = clone_with_transposed_stride(X) new_W = clone_with_transposed_stride(W) diff --git a/torch/_inductor/codegen/debug_utils.py b/torch/_inductor/codegen/debug_utils.py index 25fab760b9a9b..791cbc69dd6da 100644 --- a/torch/_inductor/codegen/debug_utils.py +++ b/torch/_inductor/codegen/debug_utils.py @@ -184,13 +184,9 @@ def codegen_intermediate_tensor_value_save( continue launch_prefix = "before_launch" if before_launch else "after_launch" if V.graph.cpp_wrapper: - if config.abi_compatible: - V.graph.wrapper_code.writeline( - f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' - ) - else: - # TODO: add non-abi compatible mode debug printing info - pass + V.graph.wrapper_code.writeline( + f'aoti_torch_save_tensor_handle({arg}, "{arg}", "{launch_prefix}", "{kernel_name}");' + ) else: cwd = os.getcwd() saved_dir = cwd + "/tmp/jit_inductor/" @@ -226,11 +222,10 @@ def codegen_intermediate_tensor_value_print( == IntermediateValueDebuggingLevel.PRINT_KERNEL_NAMES_ONLY ): if V.graph.cpp_wrapper: - if config.abi_compatible: - V.graph.wrapper_code.writeline( - f'printf("[ {launch_prefix}: {kernel_name} ]");' - ) - V.graph.wrapper_code.writeline('printf("\\n");') + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix}: {kernel_name} ]");' + ) + V.graph.wrapper_code.writeline('printf("\\n");') return if self.debug_printer_level != IntermediateValueDebuggingLevel.PRINT_ONLY: @@ -244,38 +239,30 @@ def codegen_intermediate_tensor_value_print( ): continue if V.graph.cpp_wrapper: - if config.abi_compatible: - if arg_signatures is not None and isinstance( - arg_signatures[i], (torch_dtype) - ): - # infer from the arg data type (has torch.dtype) to see if it is a tensor type + if arg_signatures is not None and isinstance( + arg_signatures[i], (torch_dtype) + ): + # infer from the arg data type (has torch.dtype) to see if it is a tensor type + V.graph.wrapper_code.writeline( + f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' + ) + elif arg_signatures is not None and isinstance( + arg_signatures[i], + ( + type(torch._inductor.codegen.wrapper.SymbolicCallArg), + type(int), + type(float), + type(bool), + ), + ): + V.graph.wrapper_code.writeline( + f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\n");' + ) + else: + if arg_signatures is None and self.kernel_type == "cpp" or "extern": V.graph.wrapper_code.writeline( f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' ) - elif arg_signatures is not None and isinstance( - arg_signatures[i], - ( - type(torch._inductor.codegen.wrapper.SymbolicCallArg), - type(int), - type(float), - type(bool), - ), - ): - V.graph.wrapper_code.writeline( - f'printf("[ {launch_prefix} - {kernel_name} - {arg}: %ld ]", {arg}); printf("\\n");' - ) - else: - if ( - arg_signatures is None - and self.kernel_type == "cpp" - or "extern" - ): - V.graph.wrapper_code.writeline( - f'aoti_torch_print_tensor_handle({arg}, "{launch_prefix} - {kernel_name} - {arg}");' - ) - else: - # TODO: add non-abi compatible mode debug printing info - pass else: V.graph.wrapper_code.writeline( f'_print_debugging_tensor_value_info("inductor: {launch_prefix} - {kernel_name} - {arg}", {arg})' diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index bbf6866fdc381..27a043d785443 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -572,8 +572,13 @@ def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]: class HalideCSEVariable(CSEVariable): undefined_re = re.compile(r"\b(tmp\d+)\[\?\]") - def __init__(self, name, bounds: ValueRanges[Any]) -> None: - super().__init__(name, bounds) + def __init__( + self, + name, + bounds: ValueRanges[Any], + dtype: Optional[torch.dtype] = None, + ) -> None: + super().__init__(name, bounds, dtype) self.used_dims: Optional[List[sympy.Symbol]] = None def update_on_args(self, name, args, kwargs): @@ -706,9 +711,9 @@ def __init__( self.buffer_aliases: Dict[str, List[str]] = defaultdict(list) self.has_indirect_indexing = False - def create_cse_var(self, name, bounds=None): + def create_cse_var(self, name, bounds=None, dtype=None): self.body.writeline(f"{name} = hl.Func({name!r})") - return HalideCSEVariable(name, bounds) + return HalideCSEVariable(name, bounds, dtype) def finalize_indexing(self, indices: Sequence[sympy.Expr]): """ @@ -1444,7 +1449,7 @@ def halide_kernel_meta(self) -> HalideMeta: ) ) - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() if current_device.type == "cpu": target = [config.halide.cpu_target] schduler = config.halide.scheduler_cpu @@ -1621,7 +1626,7 @@ def _autoscheduler_workarounds(n, dims): if ( len(dims) == 1 and config.halide.scheduler_cuda == "Anderson2021" - and V.graph.scheduler.get_current_device_or_throw().type == "cuda" + and V.graph.get_current_device_or_throw().type == "cuda" ): # workaround https://github.com/halide/Halide/issues/8246 n = max(2, n) @@ -1631,7 +1636,7 @@ def call_kernel(self, name: str, node=None): """Codegen a call to this kernel""" wrapper = V.graph.wrapper_code call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None] - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() if current_device.type == "cuda": stream_name = wrapper.write_get_raw_stream(current_device.index, V.graph) call_args.append(stream_name) diff --git a/torch/_inductor/codegen/memory_planning.py b/torch/_inductor/codegen/memory_planning.py index 60360597ec1cb..b1841da6a5f48 100644 --- a/torch/_inductor/codegen/memory_planning.py +++ b/torch/_inductor/codegen/memory_planning.py @@ -11,11 +11,12 @@ import torch -from .. import config, ir +from .. import config from ..utils import _align, align, cache_on_self, CachedMethod, IndentedBuffer from ..virtualized import V from .wrapper import ( AllocateLine, + BufferLike, FreeIfNotReusedLine, MemoryPlanningLine, NullLine, @@ -129,7 +130,7 @@ class Allocation(AllocationTreeNode): Represents memory allocated to a given node in the allocation pool. """ - node: ir.Buffer + node: BufferLike live_range: LiveRange size_hint: int symbolic_size: sympy.Expr @@ -506,7 +507,7 @@ class BufferGroup: This tracks these collections of buffers sharing underlying memory. """ - def __init__(self, node: ir.Buffer): + def __init__(self, node: BufferLike): self.node = node self.names = [node.get_name()] self.is_output = False diff --git a/torch/_inductor/codegen/rocm/ck_conv_template.py b/torch/_inductor/codegen/rocm/ck_conv_template.py new file mode 100644 index 0000000000000..02ad5a404808a --- /dev/null +++ b/torch/_inductor/codegen/rocm/ck_conv_template.py @@ -0,0 +1,558 @@ +# mypy: allow-untyped-defs +import copy +import logging +import random +from typing import Tuple + +from torch._inductor.virtualized import V + + +try: + import ck4inductor # type: ignore[import] +except ImportError: + ck4inductor = None + +if ck4inductor is not None: + from ck4inductor.grouped_conv_fwd.gen_instances import ( # type: ignore[import] + gen_conv_ops_library, + ) + from ck4inductor.grouped_conv_fwd.op import ( # type: ignore[import] # noqa: TCH002 + CKGroupedConvFwdOp, + ) +else: + + def gen_conv_ops_library(): + return [] + + +from torch._inductor import config +from torch._inductor.codegen.rocm.ck_template import CKTemplate +from torch._inductor.codegen.rocm.rocm_kernel import ROCmTemplateKernel +from torch._inductor.utils import IndentedBuffer + + +log = logging.getLogger(__name__) + + +def torch_layout_to_ck_layouts(torch_layout): + # logically, torch tensors are always NCHW, + # and channels-last memory layout is visible in the strides + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + # when input or output is NCHW + # NB: torch.conv2d result is always NCHW + return ["NGCHW", "GKCYX", "NGKHW"] + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + # when input or output or weight is channels-last + return ["NHWGC", "GKYXC", "NHWGK"] + else: + return None + + +def torch_layout_to_ck_input_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGCHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGC" + else: + return None + + +def torch_layout_to_ck_weight_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "GKCYX" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "GKYXC" + else: + return None + + +def torch_layout_to_ck_output_layout(torch_layout): + if V.graph.sizevars.statically_known_equals(torch_layout.stride[-1], 1): + return "NGKHW" + elif V.graph.sizevars.statically_known_equals(torch_layout.stride[-3], 1): + return "NHWGK" + else: + return None + + +class CKGroupedConvFwdTemplate(CKTemplate): + conv_template = r""" + {{headers}} + {{globals}} + {{instance_definition}} + extern "C" { + PT_EXPORT {{kernel_definition}} { + auto conv = {{instance_type}} {}; + auto invoker = conv.MakeInvoker(); + + using ck::index_t; + + constexpr index_t NumDTensor = {{n_d_tensors}}; + constexpr index_t NDimSpatial = {{n_dim_spatial}}; + constexpr index_t GroupCount = {{group_count}}; + constexpr index_t NBatch = {{batch_size}}; + constexpr index_t NOutChannels = {{n_output_channels}}; + constexpr index_t NInChannels = {{n_input_channels}}; + const std::vector FilterSize = { {{filter_size}} }; + const std::vector InputSize = { {{input_size}} }; + const std::vector ConvolutionStrides = { {{convolution_strides}} }; + const std::vector Dilations = { {{dilations}} }; + const std::vector LeftPads = { {{left_pads}} }; + const std::vector RightPads = { {{right_pads}} }; + + auto conv_param = ck::utils::conv::ConvParam { + NDimSpatial, + GroupCount, + NBatch, + NOutChannels, + NInChannels, + FilterSize, + InputSize, + ConvolutionStrides, + Dilations, + LeftPads, + RightPads, + }; + + using InLayout = ck::tensor_layout::convolution::{{input_layout}}; + using WeiLayout = ck::tensor_layout::convolution::{{weight_layout}}; + using OutLayout = ck::tensor_layout::convolution::{{output_layout}}; + + const auto in_g_n_c_wis_desc = + ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed(conv_param); + const auto wei_g_k_c_xs_desc = + ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed(conv_param); + const auto out_g_n_k_wos_desc = + ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed(conv_param); + + const void* p_a = input; + const void* p_b = weight; + const std::array p_ds; + void* p_e = output; + std::array a_g_n_c_wis_lengths; + std::array a_g_n_c_wis_strides; + std::array b_g_k_c_xs_lengths; + std::array b_g_k_c_xs_strides; + std::array, NumDTensor> ds_g_n_k_wos_lengths; + std::array, NumDTensor> ds_g_n_k_wos_strides; + std::array e_g_n_k_wos_lengths; + std::array e_g_n_k_wos_strides; + std::array conv_filter_strides; + std::array conv_filter_dilations; + std::array input_left_pads; + std::array input_right_pads; + const auto a_element_op = PassThrough {}; + const auto b_element_op = PassThrough {}; + const auto cde_element_op = PassThrough {}; + + auto copy = [](auto& x, auto& y) { ck::ranges::copy(x, y.begin()); }; + + copy(in_g_n_c_wis_desc.GetLengths(), a_g_n_c_wis_lengths); + copy(in_g_n_c_wis_desc.GetStrides(), a_g_n_c_wis_strides); + copy(wei_g_k_c_xs_desc.GetLengths(), b_g_k_c_xs_lengths); + copy(wei_g_k_c_xs_desc.GetStrides(), b_g_k_c_xs_strides); + copy(out_g_n_k_wos_desc.GetLengths(), e_g_n_k_wos_lengths); + copy(out_g_n_k_wos_desc.GetStrides(), e_g_n_k_wos_strides); + copy(conv_param.conv_filter_strides_, conv_filter_strides); + copy(conv_param.conv_filter_dilations_, conv_filter_dilations); + copy(conv_param.input_left_pads_, input_left_pads); + copy(conv_param.input_right_pads_, input_right_pads); + + auto argument = conv.MakeArgument( + p_a, + p_b, + p_ds, + p_e, + a_g_n_c_wis_lengths, + a_g_n_c_wis_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_k_wos_lengths, + ds_g_n_k_wos_strides, + e_g_n_k_wos_lengths, + e_g_n_k_wos_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op + ); + if (!conv.IsSupportedArgument(argument)) { + // we do our best to statically avoid this case in `filter_op` + std::cerr << "invalid argument for conv instance " << conv.GetTypeString() << std::endl; + argument.Print(); + return -23; + } + if (workspace_size) { + *workspace_size = conv.GetWorkSpaceSize(&argument); + return 0; + } + + if (p_a == nullptr) { + std::cerr << "p_a is nullptr" << std::endl; + return -1; + } + if (p_b == nullptr) { + std::cerr << "p_b is nullptr" << std::endl; + return -1; + } + if (p_e == nullptr) { + std::cerr << "p_e is nullptr" << std::endl; + return -1; + } + + // when debugging, do time kernel to serialize launches + auto stream_config = StreamConfig{stream, /* time kernel */ false, /* log level */ 0}; + + if (workspace != nullptr) { + conv.SetWorkSpacePointer(&argument, workspace, stream_config); + } + + // run the kernel + float elapsed_time = invoker.Run(argument, stream_config); + return 0; + } // kernel definition + } // extern C +""" + + def globals(self) -> IndentedBuffer: + res = super().globals() + res.splice( + """ + // CK conv globals + + using NWC = ck::tensor_layout::convolution::NWC; + using NHWC = ck::tensor_layout::convolution::NHWC; + using NDHWC = ck::tensor_layout::convolution::NDHWC; + + using KXC = ck::tensor_layout::convolution::KXC; + using KYXC = ck::tensor_layout::convolution::KYXC; + using KZYXC = ck::tensor_layout::convolution::KZYXC; + + using NWK = ck::tensor_layout::convolution::NWK; + using NHWK = ck::tensor_layout::convolution::NHWK; + using NDHWK = ck::tensor_layout::convolution::NDHWK; + + using GNWC = ck::tensor_layout::convolution::GNWC; + using GNHWC = ck::tensor_layout::convolution::GNHWC; + using GNDHWC = ck::tensor_layout::convolution::GNDHWC; + + using GKXC = ck::tensor_layout::convolution::GKXC; + using GKYXC = ck::tensor_layout::convolution::GKYXC; + using GKZYXC = ck::tensor_layout::convolution::GKZYXC; + + using GKCX = ck::tensor_layout::convolution::GKCX; + using GKCYX = ck::tensor_layout::convolution::GKCYX; + using GKCZYX = ck::tensor_layout::convolution::GKCZYX; + + using GNWK = ck::tensor_layout::convolution::GNWK; + using GNHWK = ck::tensor_layout::convolution::GNHWK; + using GNDHWK = ck::tensor_layout::convolution::GNDHWK; + + using NGKW = ck::tensor_layout::convolution::NGKW; + using NGKHW = ck::tensor_layout::convolution::NGKHW; + using NGKDHW = ck::tensor_layout::convolution::NGKDHW; + + using NWGC = ck::tensor_layout::convolution::NWGC; + using NHWGC = ck::tensor_layout::convolution::NHWGC; + using NDHWGC = ck::tensor_layout::convolution::NDHWGC; + + using KXGC = ck::tensor_layout::convolution::KXGC; + using KYXGC = ck::tensor_layout::convolution::KYXGC; + using KZYXGC = ck::tensor_layout::convolution::KZYXGC; + + using NWGK = ck::tensor_layout::convolution::NWGK; + using NHWGK = ck::tensor_layout::convolution::NHWGK; + using NDHWGK = ck::tensor_layout::convolution::NDHWGK; + + using NGCW = ck::tensor_layout::convolution::NGCW; + using NGCHW = ck::tensor_layout::convolution::NGCHW; + using NGCDHW = ck::tensor_layout::convolution::NGCDHW; + + using G_K = ck::tensor_layout::convolution::G_K; + + using BlockGemmPipelineScheduler = ck::BlockGemmPipelineScheduler; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + using BlockGemmPipelineVersion = ck::BlockGemmPipelineVersion; + + using ConvolutionForwardSpecialization = ck::tensor_operation::device::ConvolutionForwardSpecialization; + + namespace ck { + namespace utils { + namespace conv { + + ConvParam::ConvParam(ck::index_t n_dim, + ck::index_t group_count, + ck::index_t n_batch, + ck::index_t n_out_channels, + ck::index_t n_in_channels, + const std::vector& filters_len, + const std::vector& input_len, + const std::vector& strides, + const std::vector& dilations, + const std::vector& left_pads, + const std::vector& right_pads) + : num_dim_spatial_(static_cast(n_dim)), + G_(static_cast(group_count)), + N_(static_cast(n_batch)), + K_(static_cast(n_out_channels)), + C_(static_cast(n_in_channels)), + filter_spatial_lengths_(num_dim_spatial_), + input_spatial_lengths_(num_dim_spatial_), + output_spatial_lengths_(num_dim_spatial_), + conv_filter_strides_(num_dim_spatial_), + conv_filter_dilations_(num_dim_spatial_), + input_left_pads_(num_dim_spatial_), + input_right_pads_(num_dim_spatial_) + { + if(static_cast(filter_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(input_spatial_lengths_.size()) != num_dim_spatial_ || + static_cast(conv_filter_strides_.size()) != num_dim_spatial_ || + static_cast(conv_filter_dilations_.size()) != num_dim_spatial_ || + static_cast(input_left_pads_.size()) != num_dim_spatial_ || + static_cast(input_right_pads_.size()) != num_dim_spatial_) + { + throw( + std::runtime_error("ConvParam::ConvParam: " + "parameter size is different from number of declared dimensions!")); + } + + for(ck::index_t i = 0; i < num_dim_spatial_; ++i) + { + filter_spatial_lengths_[i] = static_cast(filters_len[i]); + input_spatial_lengths_[i] = static_cast(input_len[i]); + conv_filter_strides_[i] = static_cast(strides[i]); + conv_filter_dilations_[i] = static_cast(dilations[i]); + input_left_pads_[i] = static_cast(left_pads[i]); + input_right_pads_[i] = static_cast(right_pads[i]); + + // XEff = (X - 1) * conv_dilation_w + 1; + // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + const ck::long_index_t x_eff = + (filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1; + + output_spatial_lengths_[i] = + (input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) / + conv_filter_strides_[i] + + 1; + } + } + + } // namespace conv + } // namespace utils + } // namespace ck + + const std::vector& HostTensorDescriptor::GetLengths() const { return mLens; } + const std::vector& HostTensorDescriptor::GetStrides() const { return mStrides; } + std::size_t HostTensorDescriptor::GetNumOfDimension() const { return mLens.size(); } + void HostTensorDescriptor::CalculateStrides() { + mStrides.clear(); + mStrides.resize(mLens.size(), 0); + if(mStrides.empty()) + return; + + mStrides.back() = 1; + std::partial_sum( + mLens.rbegin(), mLens.rend() - 1, mStrides.rbegin() + 1, std::multiplies()); + } + """ + ) + return res + + def header(self) -> IndentedBuffer: + res = super().header() + res.splice( + """ + // CK conv headers + + #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp" + #include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp" + #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" + + #include "ck/library/utility/convolution_parameter.hpp" + #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + """ + ) + return res + + @staticmethod + def add_ck_conv_choices( + choices, + layout, + input_nodes, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + template = CKGroupedConvFwdTemplate( + input_nodes, + layout, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=n_spatial_dimensions, + ) + ops = template.gen_ops() + for op in ops: + template.maybe_append_choice( + choices, + op=op, + ) + + def __init__( + self, + input_nodes, + layout, + *, + stride, + padding, + dilation, + groups, + n_spatial_dimensions, + ): + super().__init__( + "ck_conv_template", + input_nodes, + layout, + ) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.n_spatial_dimensions = n_spatial_dimensions + + def filter_op(self, op: "CKGroupedConvFwdOp"): # type: ignore[name-defined] + metas = [ + T.get_layout() + for T in [*self.input_nodes, self.output_node] + if T is not None + ] + X_meta = metas[0] + W_meta = metas[1] + Y_meta = metas[-1] + # disable the instance if dtypes don't match + if op.a_element_dtype != self._TORCH_DTYPE_TO_CK[X_meta.dtype]: + return None + if op.b_element_dtype != self._TORCH_DTYPE_TO_CK[W_meta.dtype]: + return None + if op.e_element_dtype != self._TORCH_DTYPE_TO_CK[Y_meta.dtype]: + return None + # disable the instance if layouts don't match + if op.a_layout != torch_layout_to_ck_input_layout(X_meta): + return None + if op.b_layout != torch_layout_to_ck_weight_layout(W_meta): + return None + if op.e_layout != torch_layout_to_ck_output_layout(Y_meta): + return None + # disable the instance if number of spatial dimensions doesn't match + if op.n_dim_spatial != self.n_spatial_dimensions: + return None + # disable 1x1 and odd-channels conv specializations for now + if "Default" not in op.conv_forward_specialization: + return None + return op + + def gen_ops(self): + unfiltered_instances = gen_conv_ops_library() + + filtered_instances = list( + filter(lambda op: self.filter_op(op), unfiltered_instances) + ) + # NB: when using a fixed list order, most likely we will pick the subset of instances + # which are very similar to each other. Randomizing the choice seems to solve this. + random.seed(-11) + chosen_instances = ( + random.sample( + filtered_instances, + min(len(filtered_instances), config.rocm.n_max_profiling_configs), + ) + if config.rocm.n_max_profiling_configs + else filtered_instances + ) + log.debug( + "generated %d ck instances after filter: %s", + len(chosen_instances), + chosen_instances, + ) + return chosen_instances + + def emit_ck_instance(self, op: "CKGroupedConvFwdOp") -> Tuple[str, str]: # type: ignore[name-defined] + # The Jinja template for generating a C++ type alias *definition* for a Universal GEMM instance + template_definition = r""" + // Gemm operator {{operation_name}} + using Operation_{{operation_name}} = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< + {{template_params}}>; + +""" + # The Jinja template for generating a C++ type alias *usage* for a Universal GEMM instance + template_type = r""" + Operation_{{operation_name}} +""" + template_params = [] + for field_name, field_value in op.dict_items(): + if isinstance(field_value, tuple): + tuple_elements = ", ".join(map(str, iter(field_value))) + if "ds" in field_name: # element type and layout for bias + arg = f"/* {field_name} */ Tuple<{tuple_elements}>" + else: # tile shape + arg = f"/* {field_name} */ S<{tuple_elements}>" + template_params.append(arg) + else: + if field_value is not None: + template_params.append(f"/* {field_name} */ {field_value}") + return self._template_from_string(template_definition).render( + operation_name=op.name(), + template_params=(",\n" + 12 * " ").join(template_params), + ), self._template_from_string(template_type).render(operation_name=op.name()) + + def render(self, kernel: ROCmTemplateKernel, op: "CKGroupedConvFwdOp", **kwargs) -> str: # type: ignore[override, name-defined] + template_buffer_node = kwargs.get("template_buffer_node", None) + if template_buffer_node is not None: + self.output_node = template_buffer_node + X, W = self.input_nodes[0], self.input_nodes[1] + Y = self.output_node + Bias = self.input_nodes[2] if 3 == len(self.input_nodes) else None + + op = copy.deepcopy(op) + + instance_definition, instance_type = self.emit_ck_instance(op) + + return self._template_from_string(self.conv_template).render( + headers=self.header().getvalue(), + globals=self.globals().getvalue(), + instance_definition=instance_definition, + instance_type=instance_type, + kernel_definition=kernel.def_kernel( + inputs=[X, W, Bias] if Bias is not None else [X, W], + outputs=[Y], + names_str="input, weight, bias, output" + if Bias is not None + else "input, weight, output", + size_args=[], + ), + n_d_tensors=1 if Bias is not None else 0, + n_dim_spatial=self.n_spatial_dimensions, + group_count=self.groups, + batch_size=X.shape[0], + n_output_channels=Y.shape[1], + n_input_channels=X.shape[1], + filter_size=", ".join(map(str, W.shape[2:])), + input_size=", ".join(map(str, X.shape[2:])), + convolution_strides=", ".join(map(str, self.stride)), + dilations=", ".join(map(str, self.dilation)), + left_pads=", ".join(map(str, self.padding)), + right_pads=", ".join(map(str, self.padding)), + input_layout=op.a_layout, + weight_layout=op.b_layout, + output_layout=op.e_layout, + ) + + def size_args(self): + return [] diff --git a/torch/_inductor/codegen/rocm/compile_command.py b/torch/_inductor/codegen/rocm/compile_command.py index c765d98bd9431..228a250f1b8c3 100644 --- a/torch/_inductor/codegen/rocm/compile_command.py +++ b/torch/_inductor/codegen/rocm/compile_command.py @@ -25,14 +25,19 @@ def _rocm_include_paths() -> List[str]: from libfb.py import parutil ck_path = parutil.get_dir_path("composable-kernel-headers") - ck_include = os.path.join(ck_path, "include") else: - ck_include = os.path.join( - config.rocm.ck_dir or cpp_extension._join_rocm_home("composable_kernel"), - "include", + ck_path = config.rocm.ck_dir or cpp_extension._join_rocm_home( + "composable_kernel" ) - paths = [os.path.realpath(rocm_include), os.path.realpath(ck_include)] + ck_include = os.path.join(ck_path, "include") + ck_library_include = os.path.join(ck_path, "library", "include") + + # CK has to take priority over ROCm include paths + # Since CK is potentially more up-to-date + paths = [ + os.path.realpath(p) for p in (ck_include, ck_library_include, rocm_include) + ] return paths diff --git a/torch/_inductor/codegen/rocm/rocm_kernel.py b/torch/_inductor/codegen/rocm/rocm_kernel.py index e954de5af7e0c..40fb5b8a7011d 100644 --- a/torch/_inductor/codegen/rocm/rocm_kernel.py +++ b/torch/_inductor/codegen/rocm/rocm_kernel.py @@ -6,7 +6,7 @@ from ...ir import Buffer, ChoiceCaller, IRNode, Layout, PrimitiveInfoType, TensorBox from ...virtualized import V -from ..common import Kernel, OpOverrides +from ..common import Kernel, OpOverrides, WorkspaceArg, WorkspaceZeroMode from ..cpp_utils import CppPrinter from .rocm_benchmark_request import ROCmBenchmarkRequest from .rocm_template_buffer import ROCmTemplateBuffer @@ -85,7 +85,6 @@ def def_kernel( and the actual input passed into this template could be [Bias, X, W]. In this case, the `input_reorder` would be [2, 0, 1]. """ - names = [x.strip() for x in names_str.strip().split(",")] if len(inputs) + len(outputs) != len(names): raise RuntimeError( @@ -111,7 +110,7 @@ def def_kernel( arg_defs, *_ = self.args.cpp_argdefs() - signature = f"int {self.kernel_name}({', '.join(arg_defs)}, {', '.join(size_args)}, {self._EXTRA_CPP_ARGS})" + signature = f"int {self.kernel_name}({', '.join(arg_defs + size_args)}, {self._EXTRA_CPP_ARGS})" self.signature = signature return signature @@ -170,19 +169,24 @@ def call_kernel( arg_types.append("size_t*") if node.get_workspace_size() > 0: - wrapper.generate_workspace_allocation( - node.get_workspace_size(), V.graph.scheduler.current_device, False + ws = WorkspaceArg( + count=node.get_workspace_size(), + device=V.graph.get_current_device_or_throw(), + zero_mode=WorkspaceZeroMode.UNINITIALIZED, + outer_name=WorkspaceArg.unique_name(), ) - data_ptr = "workspace.data_ptr()" + wrapper.generate_workspace_allocation(ws) + data_ptr = f"{ws.outer_name}.data_ptr()" kernel_args.append( data_ptr if V.graph.cpp_wrapper else f"c_void_p({data_ptr})" ) else: + ws = None kernel_args.append("nullptr" if V.graph.cpp_wrapper else "None") if V.graph.cpp_wrapper: arg_types.append("uint8_t*") - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() wrapper.generate_kernel_call( name, kernel_args, @@ -191,8 +195,8 @@ def call_kernel( triton=False, arg_types=arg_types, ) - if node.get_workspace_size() > 0: - wrapper.writeline(wrapper.make_free_by_names(["workspace"])) + if ws: + wrapper.generate_workspace_deallocation(ws) class ROCmTemplateCaller(ChoiceCaller): @@ -218,7 +222,7 @@ def __init__( template: "ROCmTemplate", # type: ignore[name-defined] info_kwargs: Optional[Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]], # type: ignore[type-arg] ) -> None: - super().__init__(name, input_nodes, layout) + super().__init__(name, input_nodes, layout, description="") self.category = category self.make_kernel_render = make_kernel_render self.bmreq = bmreq diff --git a/torch/_inductor/codegen/rocm/rocm_template.py b/torch/_inductor/codegen/rocm/rocm_template.py index ea18d7ee2c83c..069606d226913 100644 --- a/torch/_inductor/codegen/rocm/rocm_template.py +++ b/torch/_inductor/codegen/rocm/rocm_template.py @@ -41,7 +41,7 @@ def __init__( """ super().__init__(name) self.input_nodes = input_nodes - self.output_node: Buffer = Buffer("buf_out", layout) + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) self.input_reorder = input_reorder self.layout = layout diff --git a/torch/_inductor/codegen/simd.py b/torch/_inductor/codegen/simd.py index eedf55a41ec81..b952a2d4529aa 100644 --- a/torch/_inductor/codegen/simd.py +++ b/torch/_inductor/codegen/simd.py @@ -1349,19 +1349,9 @@ def codegen_node_schedule( index_dtype=index_dtype, ) - def _node_has_sort(node): - if node in (EnableReduction, DisableReduction): - return False - - sort_nodes = node._body.root_block.graph.find_nodes( - op="call_method", target="sort" - ) - return bool(sort_nodes) - # ops.sort only works with persistent reduction, and is not bandwidth bound anyway # so taking the hit of non-coalesced loads is okay - has_sort = any(_node_has_sort(node) for node in node_schedule) - if has_sort: + if has_sort := schedule_contains_op(node_schedule, "sort"): kernel_kwargs["override_persistent_reduction"] = True kernel = kernel_type( @@ -1806,8 +1796,6 @@ class LastUsageHolder: def __del__(self) -> None: self.n.last_usage = self.last_usage - last_usage_holders = [LastUsageHolder(n, n.last_usage) for n in nodes] - # empty last_usage. May cause more aggressive 'evict_last'. Should be fine. for n in nodes: n.last_usage = OrderedSet() @@ -1897,3 +1885,12 @@ def filter(node_schedule): class CantSplit(Exception): pass + + +def schedule_contains_op(node_schedule, op_name: str) -> bool: + """True if V.ops.{op_name} is used in node_schedule""" + for node in node_schedule: + if node not in (EnableReduction, DisableReduction): + if node._body.has_op(op_name): + return True + return False diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 51b95f03ef882..bae35bf3c2cbe 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -6,6 +6,7 @@ import itertools import logging import os +import re import textwrap from functools import lru_cache from typing import ( @@ -25,6 +26,7 @@ import sympy import torch +import torch._inductor.metrics as metrics import torch._logging from torch._dynamo.utils import preserve_rng_state from torch._inductor.runtime.hints import AutotuneHint, DeviceProperties @@ -38,10 +40,10 @@ from ...utils._sympy.value_ranges import ValueRanges from .. import config, ir from ..codecache import code_hash, get_path, PyCodeCache -from ..metrics import is_metric_table_enabled, log_kernel_metadata from ..runtime.benchmarking import benchmarker from ..runtime.hints import ReductionHint, TRITON_MAX_BLOCK from ..runtime.runtime_utils import get_max_y_grid, next_power_of_2 +from ..scheduler import BaseSchedulerNode, FusedSchedulerNode, Scheduler, SchedulerNode from ..utils import ( cache_on_self, get_bounds_index_expr, @@ -65,6 +67,7 @@ SizeArg, TensorArg, WorkspaceArg, + WorkspaceZeroMode, ) from .simd import ( constant_repr, @@ -103,6 +106,8 @@ class defined. import triton.compiler.compiler + # Note: this works because triton.compiler.compiler imports AttrsDescriptor from triton.backends.compiler + # When support for the legacy AttrsDescriptor is removed then this import path should be changed. if hasattr(triton.compiler.compiler, "AttrsDescriptor"): return "from triton.compiler.compiler import AttrsDescriptor" else: @@ -125,7 +130,7 @@ def gen_common_triton_imports(): """ from torch._inductor.runtime import triton_helpers, triton_heuristics from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math - from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, instance_descriptor, DeviceProperties + from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties """ ) return imports.getvalue() @@ -364,9 +369,11 @@ def format(self, name: str, roffset=True) -> str: self.replace_roffset(offset, sympy.Integer(0)) for offset in offsets ] args = [ - f"{name} + ({f(self.constant_offset)})" - if self.constant_offset != 0 - else name, + ( + f"{name} + ({f(self.constant_offset)})" + if self.constant_offset != 0 + else name + ), f"shape={f(self.shape)}", f"strides={f(self.strides)}", f"block_shape={f(self.block_shape)}", @@ -448,12 +455,8 @@ def triton_reshape( """Workaround https://github.com/openai/triton/issues/2836""" assert isinstance(old_shape, list) and isinstance(new_shape, list) - def shape_to_str(shape: List[sympy.Expr]) -> List[str]: - return [str(dim) for dim in shape] - - old_shape_str, new_shape_str = tuple( - shape_to_str(shape) for shape in (old_shape, new_shape) - ) + old_shape_str = [V.kernel.index_to_str(shape) for shape in old_shape] + new_shape_str = [V.kernel.index_to_str(shape) for shape in new_shape] if old_shape_str == new_shape_str: return value @@ -640,67 +643,72 @@ def _print_RoundDecimal(self, expr): texpr = TritonPrinter().doprint +# correct cases where Triton types names don't match PyTorch +_triton_type_mapping = { + "tl.bool": "tl.int1", + "tl.float8_e4m3fn": "tl.float8e4nv", + "tl.float8_e5m2": "tl.float8e5", + "tl.float8_e4m3fnuz": "tl.float8e4b8", + "tl.float8_e5m2fnuz": "tl.float8e5b16", +} +_triton_type_re = re.compile(r"^.*[.]") + + +def triton_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type""" + triton_type_name = _triton_type_re.sub("tl.", str(dtype)) + return _triton_type_mapping.get(triton_type_name, triton_type_name) -def triton_compute_type(dtype): - triton_type_name = str(dtype).split(".")[-1] - if triton_type_name == "bool": - triton_type_name = "int1" - elif ( - triton_type_name in ("float16", "bfloat16") - and config.triton.codegen_upcast_to_fp32 + +def triton_compute_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type and upcast [b]float16 to float32""" + return triton_type(upcast_compute_type(dtype)) + + +def upcast_compute_type(dtype: torch.dtype) -> torch.dtype: + """Maybe upcast [b]float16 to float32""" + if config.triton.codegen_upcast_to_fp32 and ( + dtype == torch.float16 or dtype == torch.bfloat16 ): - # float16 math is done in float32 inside the kernel - triton_type_name = "float32" - elif triton_type_name == "float8_e4m3fn": - triton_type_name = "float8e4nv" - elif triton_type_name == "float8_e5m2": - triton_type_name = "float8e5" - elif triton_type_name == "float8_e4m3fnuz": - triton_type_name = "float8e4b8" - elif triton_type_name == "float8_e5m2fnuz": - triton_type_name = "float8e5b16" - return f"tl.{triton_type_name}" - - -def _get_primitive_bitwidth(dtype): - if hasattr(dtype, "is_floating_point"): - if dtype.is_floating_point: - # triton_compute_type changes the bitwidth - if ( - dtype in [torch.bfloat16, torch.float16] - and config.triton.codegen_upcast_to_fp32 - ): - return 32 - return torch.finfo(dtype).bits - else: - return torch.iinfo(dtype).bits + return torch.float32 + return dtype + + +def _get_primitive_bitwidth(dtype: torch.dtype) -> int: + """Number of bits of triton_compute_type()""" + dtype = upcast_compute_type(dtype) + itemsize = getattr(dtype, "itemsize", None) + if itemsize: + return itemsize * 8 else: return -1 -def triton_store_type(dtype): - triton_type_name = str(dtype).split(".")[-1] - if triton_type_name == "bool": - triton_type_name = "int8" - elif triton_type_name == "float8_e4m3fn": - triton_type_name = "float8e4nv" - elif triton_type_name == "float8_e5m2": - triton_type_name = "float8e5" - return f"tl.{triton_type_name}" +def triton_store_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with fix for storing tl.bool""" + if dtype == torch.bool: + dtype = torch.int8 + return triton_type(dtype) + +def upcast_acc_dtype(dtype: torch.dtype) -> torch.dtype: + """Implicit upcasts used for Triton reduction types""" + if is_integer_dtype(dtype) and dtype.is_signed and dtype.itemsize <= 4: + return torch.int32 + return upcast_compute_type(dtype) -def triton_acc_type(dtype): - if is_integer_dtype(dtype) and dtype.is_signed: - nbits = 64 if dtype == torch.int64 else 32 - return f"tl.int{nbits}" - return triton_compute_type(dtype) + +def triton_acc_type(dtype: torch.dtype) -> str: + """Convert torch.dtype to triton type, with reduction upcasts""" + return triton_compute_type(upcast_acc_dtype(dtype)) class TritonCSEVariable(CSEVariable): - def __init__(self, name, bounds: ValueRanges[Any]) -> None: - super().__init__(name, bounds) + def __init__(self, name, bounds: ValueRanges[Any], dtype: torch.dtype) -> None: + super().__init__(name, bounds, dtype) # We'll use this to track which masks the variable needs when used for indirect indexing self.mask_vars: OrderedSet[str] = OrderedSet() + assert dtype is not None, "TritonCSEVariable must have dtype" def update_on_args(self, name, args, kwargs): for arg in args: @@ -842,7 +850,14 @@ def expm1(x): @staticmethod def sqrt(x): - return f"libdevice.sqrt({x})" + if config.triton.codegen_upcast_to_fp32: + return f"libdevice.sqrt({x})" + else: + needs_upcast = x.dtype in (torch.float16, torch.bfloat16) + orig_dtype = triton_type(x.dtype) + upcast_string = ".to(tl.float32)" if needs_upcast else "" + downcast_string = f".to({orig_dtype})" if needs_upcast else "" + return f"libdevice.sqrt({x}{upcast_string}){downcast_string}" @staticmethod def libdevice_sqrt(x): @@ -1162,11 +1177,18 @@ def index_expr(cls, expr, dtype): indexing = V.kernel.indexing(expr, block_ptr=False) assert isinstance(indexing, IndexingOptions) var = V.kernel.cse.generate( - V.kernel.compute, indexing.index_str, bounds=get_bounds_index_expr(expr) + V.kernel.compute, + indexing.index_str, + bounds=get_bounds_index_expr(expr), + dtype=dtype, ) if dtype not in (torch.int32, torch.int64): - var = V.kernel.cse.generate(V.kernel.compute, cls.to_dtype(var, dtype)) + var = V.kernel.cse.generate( + V.kernel.compute, + cls.to_dtype(var, dtype), + dtype=dtype, + ) var.mask_vars = indexing.mask_vars return var @@ -1176,6 +1198,7 @@ def masked(mask, body, other): mask = V.kernel.cse.generate( V.kernel.compute, f"{mask}.to(tl.int1)", + dtype=torch.bool, ) nodes = body.graph.find_nodes(op="output") @@ -1200,6 +1223,7 @@ def masked(mask, body, other): V.kernel.compute, f"tl.full({result}.shape, {constant_repr(other)}, {result}.dtype)", bounds=ValueRanges.wrap(other), + dtype=result.dtype, ) ret = ops.where(new_mask, result, other) else: @@ -1221,8 +1245,8 @@ def frexp(x): if cache_key in V.kernel.cse.cache: return V.kernel.cse.cache[cache_key] - mantissa = V.kernel.cse.newvar() - exponent = V.kernel.cse.newvar() + mantissa = V.kernel.cse.newvar(dtype=x.dtype) + exponent = V.kernel.cse.newvar(dtype=x.dtype) V.kernel.compute.writeline( f"{mantissa}, {exponent} = triton_helpers.frexp({x})" ) @@ -1794,7 +1818,7 @@ def check_bounds( isinstance(m, TritonCSEVariable) for m in indexing.mask_vars ) buffer = self.get_load_buffer(indexing) - self.cse.generate(buffer, line, assignment=False) + self.cse.generate(buffer, line, assignment=False, dtype=torch.int32) def get_load_buffer(self, indexing): if indexing.has_indirect() or indexing.has_tmpmask(): @@ -1862,6 +1886,8 @@ def load(self, name: str, index: sympy.Expr): advance_block_ptr = None append_broadcast = None + dtype = V.graph.get_dtype(name) + if should_unwrap_unspec_arg(name): line = var else: @@ -1880,26 +1906,27 @@ def load(self, name: str, index: sympy.Expr): else: line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})" - dtype = V.graph.get_dtype(name) if ( dtype in (torch.float16, torch.bfloat16) and config.triton.codegen_upcast_to_fp32 ): line += ".to(tl.float32)" + dtype = torch.float32 if dtype == torch.bool and torch.version.hip is None: # Workaround for https://github.com/openai/triton/issues/2151 # tl.load returns int8 when loading from pointer to int1 # NOTE: Currently causes hangs on bool UTs for ROCm line += ".to(tl.int1)" + dtype = torch.bool load_buffer = self.get_load_buffer(indexing) - result_var = self.cse.generate(load_buffer, line) + result_var = self.cse.generate(load_buffer, line, dtype=dtype) assert isinstance(result_var, TritonCSEVariable) result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] if append_broadcast: line = f"tl.broadcast_to({result_var}, {append_broadcast})" - result_var = self.cse.generate(load_buffer, line) + result_var = self.cse.generate(load_buffer, line, dtype=dtype) if advance_block_ptr: load_buffer.writeline(advance_block_ptr) @@ -1997,6 +2024,7 @@ def bucketize( f"{sorter_indices}, " f"{block_size}, " ")", + dtype=values.dtype, # type: ignore[attr-defined] ) return result @@ -2034,7 +2062,9 @@ def reduction( dense_size_str = self.dense_size_str() value = self._map_tuple_or_scalar( lambda v: self.cse.generate( - self.compute, f"tl.broadcast_to({v}, {dense_size_str})" + self.compute, + f"tl.broadcast_to({v}, {dense_size_str})", + dtype=v.dtype, ), value, ) @@ -2065,7 +2095,7 @@ def final_argreduce(buffer, result_var, value, index): dim = self.triton_tensor_ndim() - 1 acc_type = triton_acc_type(src_dtype) - result_var: Any = self.cse.newvar() + result_var: Any = self.cse.newvar(dtype=dtype) result_var.mask_vars = OrderedSet(var for var in masks if var[0] != "r") cond = " & ".join(masks) @@ -2079,7 +2109,9 @@ def where_cond(tval, fval): default = self._map_tuple_or_scalar(constant_repr, default) def _mask_value(value, default): - return self.cse.generate(self.compute, where_cond(value, default)) + return self.cse.generate( + self.compute, where_cond(value, default), dtype=value.dtype + ) if isinstance(value, tuple): masked_value = [_mask_value(v, d) for v, d in zip(value, default)] @@ -2091,6 +2123,7 @@ def _mask_value(value, default): self.cse.generate( self.compute, f"tl.broadcast_to({reduction_range_prefix}index, {masked_value}.shape)", + dtype=torch.int64, ) ) root_op = {"argmax": "max", "argmin": "min"}[reduction_type] @@ -2105,16 +2138,18 @@ def _mask_value(value, default): elif reduction_type == "welford_combine": mean, m2, weight = masked_value welford = f"triton_helpers.welford({mean}, {m2}, {weight}, {dim})" - mean, m2, weight = (self.cse.newvar() for _ in range(3)) + mean, m2, weight = (self.cse.newvar(dtype=dtype) for _ in range(3)) self.compute.writeline(f"{mean}, {m2}, {weight} = {welford}") result_var = tuple( - self.cse.generate(self.compute, self.reduction_resize(var_name)) + self.cse.generate( + self.compute, self.reduction_resize(var_name), dtype=dtype + ) for var_name in (mean, m2, weight) ) else: result_var = self.cse.generate( - self.compute, final_reduction(masked_value) + self.compute, final_reduction(masked_value), dtype=dtype ) else: accumulator = f"_{result_var}" @@ -2186,8 +2221,8 @@ def _mask_value(value, default): ) result_mean = result_var - result_m2 = self.cse.newvar() - result_weight = self.cse.newvar() + result_m2 = self.cse.newvar(dtype=dtype) + result_weight = self.cse.newvar(dtype=dtype) self.suffix.splice( f"""\ {result_mean}_tmp, {result_m2}_tmp, {result_weight}_tmp = triton_helpers.welford( @@ -2289,6 +2324,7 @@ def inner(*args, **kwargs): return cse.generate( helper, getattr(overrides, name)(*args, **kwargs), + dtype=torch.float32, ) return inner @@ -2329,10 +2365,12 @@ def scan( value_dtype = self.cse.generate( self.compute, f"{value}.to({triton_compute_type(dtype)})", + dtype=dtype, ) value = self.cse.generate( self.compute, f"tl.broadcast_to({value_dtype}, {self.dense_size_str()})", + dtype=dtype, ) broadcasted_values.append(value) @@ -2340,7 +2378,7 @@ def scan( cond = " & ".join(masks) if not self.persistent_reduction: - accumulator = self.cse.newvar() + accumulator = self.cse.newvar(dtype=dtype) reduced_size = self.dense_size_list() reduced_size[-1] = "1" reduced_size = f"[{', '.join(reduced_size)}]" @@ -2355,11 +2393,12 @@ def scan( def csv(values): return " ".join(f"{value}," for value in values) - def cse_multiple(line, n, masks): + def cse_multiple(line, values, masks, dtypes): + n = len(values) cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] if all(cache_key in self.cse.cache for cache_key in cache_keys): return [self.cse.cache[cache_key] for cache_key in cache_keys] - result_vars = [self.cse.newvar() for _ in range(n)] + result_vars = [self.cse.newvar(dtype=_dtype) for _dtype in dtypes] self.compute.writeline( f"{csv(result_vars)} = {line}", ) @@ -2371,8 +2410,9 @@ def cse_multiple(line, n, masks): partial_scan_vars = cse_multiple( f"tl.associative_scan(({csv(broadcasted_values)}), {dim}, {combine_helper_fn})", - len(values), + values, masks, + dtypes, ) if not self.persistent_reduction: @@ -2381,14 +2421,18 @@ def cse_multiple(line, n, masks): # last scan value partial_reduce_vars = [ cse_compute( - f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)" + f"triton_helpers.select_one(({partial_scan_var}), rbase == (RBLOCK - 1), dim=-1, keep_dims=True)", + dtype=partial_scan_var.dtype, ) for partial_scan_var in partial_scan_vars ] accs_next = combine_fn(tuple(accumulators), tuple(partial_reduce_vars)) full_scan_vars = combine_fn(tuple(accumulators), partial_scan_vars) result_vars = [ - cse_compute(f"tl.where(roffset > 0, {full_scan}, {partial_scan})") + cse_compute( + f"tl.where(roffset > 0, {full_scan}, {partial_scan})", + dtype=partial_scan.dtype, + ) for full_scan, partial_scan in zip(full_scan_vars, partial_scan_vars) ] for acc_next, accumulator, partial_reduce in zip( @@ -2425,19 +2469,22 @@ def sort( cse_compute = functools.partial(self.cse.generate, self.compute) dim = self.triton_tensor_ndim() - 1 + assert len(dtypes) == len(values) broadcasted_values = [ - cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})") - for value in values + cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", dtype=dtypes[i] + ) + for i, value in enumerate(values) ] def csv(values): return " ".join(f"{value}," for value in values) - def cse_multiple(line, n, masks): + def cse_multiple(line, n, masks, dtypes): cache_keys = [f"{line}, {i}, {masks}" for i in range(n)] if all(cache_key in self.cse.cache for cache_key in cache_keys): return [self.cse.cache[cache_key] for cache_key in cache_keys] - result_vars = [self.cse.newvar() for _ in range(n)] + result_vars = [self.cse.newvar(dtype=dtypes[i]) for i in range(n)] # type: ignore[attr-defined] self.compute.writeline( f"{csv(result_vars)} = {line}", ) @@ -2455,7 +2502,7 @@ def cse_multiple(line, n, masks): f"triton_helpers.sort_with_index({broadcasted_values[0]}, {broadcasted_values[1]}," f" {rnumel}, {dim}, stable={stable}, descending={descending})" ) - result_vars = cse_multiple(line, len(values), masks) + result_vars = cse_multiple(line, len(values), masks, dtypes) else: raise AssertionError("Unhandled sort") @@ -2540,10 +2587,10 @@ def codegen_kernel_benchmark(self, num_gb, grid=None): symval_hint = 0 result.writeline(f"{var_name} = {symval_hint}") elif isinstance(arg_sig, WorkspaceArg): - device = V.graph.scheduler.get_current_device_or_throw() - nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes) + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint(arg_sig.count) result.writeline( - f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)" + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" ) else: raise KeyError( @@ -2569,7 +2616,7 @@ def codegen_kernel_benchmark(self, num_gb, grid=None): grid_arg = f"{extra_args_str}grid=grid({', '.join(grid)})" else: grid_arg = f"grid={grid}" - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() index = current_device.index with result.indent(): result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") @@ -2704,7 +2751,7 @@ def codegen_kernel(self, name=None): if name is None: code.splice(gen_common_triton_imports()) - device_type = V.graph.scheduler.get_current_device_or_throw().type + device_type = V.graph.get_current_device_or_throw().type if device_type == "cpu": code.splice("triton_helpers.set_driver_to_cpu()") else: @@ -2746,9 +2793,13 @@ def codegen_kernel(self, name=None): # zero_fill: that's because, if we don't expect the buffer to be pre-filled with # zeros, then, although we still mutate the data, we don't care about those # mutations because we don't make any assumptions about the contents of the - # workspace buffer. + # workspace buffer. Similarly, ZERO_PER_GRAPH requires the kernel to return + # the buffer back to its original state. for argname, arg in zip(argdefs, signature): - if isinstance(arg, WorkspaceArg) and arg.zero_fill: + if ( + isinstance(arg, WorkspaceArg) + and arg.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL + ): mutated_args.add(argname) mutated_args = sorted(mutated_args) @@ -2758,9 +2809,7 @@ def codegen_kernel(self, name=None): ) triton_meta = { "signature": triton_meta_signature, - "device": DeviceProperties.create( - V.graph.scheduler.get_current_device_or_throw() - ), + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), "constants": {}, } @@ -2936,13 +2985,10 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): _, call_args, _, arg_types = self.args.python_argdefs() grid: List[Any] = [] self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid) - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() - if self.args.workspace_arg is not None: - ws = self.args.workspace_arg - wrapper.generate_workspace_allocation( - ws.nbytes, current_device, ws.zero_fill - ) + for ws in self.args.workspace_args: + wrapper.generate_workspace_allocation(ws) grid = wrapper.generate_default_grid( name, grid, grid_callable=self._get_grid_fn() @@ -2959,8 +3005,8 @@ def call_kernel(self, name: str, node: Optional[IRNode] = None): triton_meta=self.triton_meta, ) - if self.args.workspace_arg is not None: - wrapper.writeline(wrapper.make_free_by_names(["workspace"])) + for ws in reversed(self.args.workspace_args): + wrapper.generate_workspace_deallocation(ws) def codegen_nan_check(self): wrapper = V.graph.wrapper_code @@ -2968,12 +3014,9 @@ def codegen_nan_check(self): for arg, arg_signature in zip(call_args, arg_signatures): if isinstance(arg_signature, TensorArg): if V.graph.cpp_wrapper: - if config.abi_compatible: - wrapper.writeline( - f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' - ) - else: - wrapper.writeline(f'assert_inf_and_nan("{arg}", {arg});') + wrapper.writeline( + f'AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("{arg}", {arg}));' + ) else: line = f"assert not {arg}.isnan().any().item()" wrapper.writeline(line) @@ -3102,6 +3145,14 @@ class TritonScheduling(SIMDScheduling): ) ) + def __init__(self, scheduler: Scheduler) -> None: + super().__init__(scheduler) + if scheduler is None or not hasattr(scheduler, "nodes"): + return + for node in scheduler.nodes: + if isinstance(node, (SchedulerNode, FusedSchedulerNode)): + node.debug_device_str = debug_triton_code + @classmethod def get_backend_features(cls, device: torch.device): return cls.backend_features @@ -3165,7 +3216,7 @@ def define_kernel(self, src_code, node_schedule, kernel): compile_wrapper = IndentedBuffer() compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") compile_wrapper.splice(src_code, strip=True) - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() compile_wrapper.writeline(f"''', device_str='{current_device.type}')") metadata_comment = f"# kernel path: {kernel_path}" @@ -3178,14 +3229,14 @@ def define_kernel(self, src_code, node_schedule, kernel): # log kernel metadata for offline analysis. # E.g. one can find all unaligned inner reduction and check if # padding helps with the perf kernel by kernel. - if is_metric_table_enabled("kernel_metadata"): - log_kernel_metadata(kernel_name, kernel_path, src_code) + if metrics.is_metric_table_enabled("kernel_metadata"): + metrics.log_kernel_metadata(kernel_name, kernel_path, src_code) return kernel_name def benchmark_fused_nodes(self, nodes): with preserve_rng_state(), torch.cuda.device( - self.scheduler.get_current_device_or_throw() + V.graph.get_current_device_or_throw() ): src_code = self.generate_kernel_code_from_nodes( nodes, benchmark_kernel=True @@ -3350,3 +3401,35 @@ def store_cache(): V.graph.removed_buffers = removed_buffers_orig V.graph.inplaced_to_remove = inplaced_to_remove_orig return total_ms, total_clone_ms, file_list + + +def debug_triton_code(node: BaseSchedulerNode) -> List[str]: + lines = [] + multi_template = node.get_template_node() + assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) + if multi_template and multi_template.make_kernel_render is None: + lines.append(f"{node.get_name()} Unfinalized multi template buffer") + else: + from torch._inductor.codegen.cuda_combined_scheduling import ( + CUDACombinedScheduling, + ) + + device = node.get_device() + backend = node.scheduler.get_backend(device) + assert isinstance( + backend, (SIMDScheduling, CUDACombinedScheduling) + ), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}" + + with V.graph.set_current_device(device): + # Don't increment kernel count when generating debug string. + # This will confuse some unit tests that check the number of + # generated kernels. + old_generated_kernel_count = metrics.generated_kernel_count + triton_code = backend.generate_kernel_code_from_nodes( + node.get_nodes() + ).strip() + metrics.generated_kernel_count = old_generated_kernel_count + + lines.append(f"{node.get_name()} Triton code:") + lines.append(textwrap.indent(triton_code, " ")) + return lines diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index a89bcbfd60270..2b6c185170dd6 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -298,7 +298,7 @@ def codegen_pid_range( else: code.splice(f"elif pid < num_xblocks_{num}:") with code.indent(): - code.splice(f"pid_offset = pid - num_xblocks_{num-1}") + code.splice(f"pid_offset = pid - num_xblocks_{num - 1}") @classmethod def _calculate_xblocks( @@ -322,7 +322,7 @@ def _calculate_xblocks( if i == 0: code.splice(f"num_xblocks_{i} = {xblock_str}") else: - code.splice(f"num_xblocks_{i} = num_xblocks_{i-1} + {xblock_str}") + code.splice(f"num_xblocks_{i} = num_xblocks_{i - 1} + {xblock_str}") @classmethod def grid( @@ -673,9 +673,7 @@ def jit_line( "signature": signature_to_meta( signature, size_dtype=size_dtype, argdefs=argdefs ), - "device": DeviceProperties.create( - V.graph.scheduler.get_current_device_or_throw() - ), + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), "constants": {}, } triton_meta["configs"] = [config_of(signature)] @@ -918,10 +916,11 @@ def codegen_kernel_benchmark( symval_hint = 0 result.writeline(f"{var_name} = {symval_hint}") elif isinstance(arg_sig, WorkspaceArg): - device = V.graph.scheduler.get_current_device_or_throw() - nbytes = V.graph.sizevars.size_hint(arg_sig.nbytes) + device = V.graph.get_current_device_or_throw() + count = V.graph.sizevars.size_hint(arg_sig.count) + # for benchmark harness, we ignore arg_sig.zero_mode and always zero it result.writeline( - f"{var_name} = torch.zeros({nbytes}, device='{device}', dtype=torch.uint8)" + f"{var_name} = torch.zeros({count}, device='{device}', dtype={arg_sig.dtype})" ) else: raise KeyError( @@ -960,7 +959,7 @@ def codegen_kernel_benchmark( grid_arg = f"{extra_args_str}grid=grid_combo_kernels({grid_str})" else: grid_arg = f"grid={grid}" - index = V.graph.scheduler.get_current_device_or_throw().index + index = V.graph.get_current_device_or_throw().index with result.indent(): result.writeline(f"with {V.graph.device_ops.device_guard(index)}:") with result.indent(): @@ -1088,7 +1087,7 @@ def call_kernel(self, code: IndentedBuffer, name: str) -> None: name, call_args, grid, - V.graph.scheduler.get_current_device_or_throw().index, + V.graph.get_current_device_or_throw().index, gpu=True, triton=True, arg_types=arg_types, diff --git a/torch/_inductor/codegen/triton_split_scan.py b/torch/_inductor/codegen/triton_split_scan.py index 31dab39923649..3ffe313aec4da 100644 --- a/torch/_inductor/codegen/triton_split_scan.py +++ b/torch/_inductor/codegen/triton_split_scan.py @@ -97,7 +97,7 @@ def scan(self, dtypes, combine_fn, values): scratch_type_triton.primitive_bitwidth // 8 ) - cse_load = functools.partial(self.cse.generate, self.loads) + cse_load = functools.partial(self.cse.generate, self.loads, dtype=dtype) cse_compute = functools.partial(self.cse.generate, self.compute) assert len(self.numels) == 2, "Unexpected tiling" @@ -115,18 +115,28 @@ def scan(self, dtypes, combine_fn, values): masks = {f"{tree.prefix}mask" for tree in self.range_trees} self.filter_masks(masks) - masks = sorted(masks) assert not self._load_mask, "ops.scan not supported inside ops.masked" - value = cse_compute(f"{value}.to({compute_type})") - value = cse_compute(f"tl.broadcast_to({value}, {self.dense_size_str()})") + value = cse_compute( + f"{value}.to({compute_type})", + dtype=dtype, + ) + value = cse_compute( + f"tl.broadcast_to({value}, {self.dense_size_str()})", + dtype=dtype, + ) combine_helper_fn = self._lift_helper(combine_fn, 1) dim = self.triton_tensor_ndim() - 1 assert dim == 0, "" - block_sum = cse_compute(f"tl.reduce({value}, {dim}, {combine_helper_fn})") - exclusive_prefix = self.cse.newvar() + block_sum = cse_compute( + f"tl.reduce({value}, {dim}, {combine_helper_fn})", + dtype=dtype, + ) + exclusive_prefix = self.cse.newvar( + dtype=dtype, + ) if element_nbits == 64: self.compute.splice( f""" @@ -159,13 +169,18 @@ def scan(self, dtypes, combine_fn, values): ) # Compute final cumsum block_scan = cse_compute( - f"tl.associative_scan({value}, {dim}, {combine_helper_fn})" + f"tl.associative_scan({value}, {dim}, {combine_helper_fn})", + dtype=dtype, ) combined_result = cse_compute( - f"{combine_helper_fn}({exclusive_prefix}, {block_scan})" + f"{combine_helper_fn}({exclusive_prefix}, {block_scan})", + dtype=dtype, ) return ( - cse_compute(f"tl.where(roffset == 0, {block_scan}, {combined_result})"), + cse_compute( + f"tl.where(roffset == 0, {block_scan}, {combined_result})", + dtype=dtype, + ), ) def _get_heuristic(self): diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 6c6c296d5fd0d..8b8c29bbb1524 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -6,16 +6,16 @@ import torch from .. import config -from ..runtime.hints import instance_descriptor +from ..runtime.hints import AttrsDescriptorWrapper from ..utils import _type_of, expr_fits_within_32bit from ..virtualized import V -from .common import KernelArgType, SizeArg, TensorArg, WorkspaceArg +from .common import KernelArgType, SizeArg, TensorArg, TMADescriptorArg, WorkspaceArg def should_unwrap_unspec_arg(name: str): if V.graph.is_unspec_arg(name): # Unwrap on all devices except CPU - if V.graph.scheduler.get_current_device_or_throw().type != "cpu": + if V.graph.get_current_device_or_throw().type != "cpu": return True # Only unwrap on CPU if the input is not used as an output if name not in V.graph.mutated_buffers: @@ -70,7 +70,9 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: else: raise NotImplementedError(f"unhandled size_dtype {size_dtype}") if isinstance(arg, WorkspaceArg): - return "*i8" + return _type_of(arg.dtype) + if isinstance(arg, TMADescriptorArg): + return "nvTmaDesc" raise NotImplementedError(f"unhandled {type(arg)}: {arg}") @@ -150,6 +152,8 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: if isinstance(x, WorkspaceArg): # We allocate the workspace ourselves, so it is always aligned return True + if isinstance(x, TMADescriptorArg): + return False raise NotImplementedError(f"unhandled {type(x)}: {x}") if config.triton.divisible_by_16: @@ -160,11 +164,6 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: ) else: divisible_by_16 = () - divisible_by_8 = tuple( - i - for i, arg in zip(indices, args) - if is_aligned(arg, alignment=8, include_tensor=False) - ) equal_to_1 = tuple( i @@ -173,10 +172,5 @@ def is_aligned(x: KernelArgType, alignment: int, include_tensor: bool) -> bool: and isinstance(arg.expr, (int, sympy.Integer)) and V.graph.sizevars.statically_known_equals(arg.expr, 1) # type: ignore[arg-type] ) - # ids_of_folded_args is set from equal_to_1 - # and None args by the Triton compiler - ids_of_folded_args = tuple(equal_to_1) - return instance_descriptor( - divisible_by_16, equal_to_1, ids_of_folded_args, divisible_by_8 - ) + return AttrsDescriptorWrapper(divisible_by_16, equal_to_1) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 55ed206d26a95..ec6d72b93bfe6 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -53,8 +53,14 @@ sympy_str, ) from ..virtualized import V -from .aoti_hipify_utils import maybe_hipify_code_wrapper -from .common import CodeGen, DeferredLine, IndentedBuffer, PythonPrinter +from .common import ( + CodeGen, + DeferredLine, + IndentedBuffer, + PythonPrinter, + WorkspaceArg, + WorkspaceZeroMode, +) from .triton_utils import config_of, should_unwrap_unspec_arg, signature_to_meta @@ -68,9 +74,10 @@ ReuseKey = Tuple[torch.device, torch.dtype, str] +BufferLike = Union[ir.Buffer, WorkspaceArg] -def buffer_reuse_key(node: ir.Buffer) -> ReuseKey: +def buffer_reuse_key(node: BufferLike) -> ReuseKey: return ( node.get_device(), node.get_dtype(), @@ -181,13 +188,16 @@ def determine_grid( sympy_grid = tuple(_convert_to_sympy_expr(g) for g in grid) return ( wrapper.codegen_shape_tuple(sympy_grid), - wrapper.codegen_shape_tuple( - tuple( - wrapper.generate_example_arg_value(g, type(g)) for g in sympy_grid + ( + wrapper.codegen_shape_tuple( + tuple( + wrapper.generate_example_arg_value(g, type(g)) + for g in sympy_grid + ) ) - ) - if config.triton.autotune_at_compile_time - else None, + if config.triton.autotune_at_compile_time + else None + ), ) def writeline(line: str, example_grid: Optional[str] = None): @@ -300,17 +310,9 @@ def codegen(self, code: IndentedBuffer) -> None: # associated with a device, so we never expect the device to change. # CUDAStreamGuard sets the stream and the device. if self.last_seen_device_guard_index is None: - if config.abi_compatible: - code.writeline( - f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" - ) - else: - code.writeline( - maybe_hipify_code_wrapper( - f"{V.graph.device_ops.cpp_stream_guard()} stream_guard(" - + f"{V.graph.device_ops.cpp_getStreamFromExternal()}(stream, this->device_idx_));" - ) - ) + code.writeline( + f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" + ) else: assert ( self.last_seen_device_guard_index == self.device_idx @@ -319,10 +321,6 @@ def codegen(self, code: IndentedBuffer) -> None: if self.last_seen_device_guard_index is None: code.writeline( f"{V.graph.device_ops.cpp_aoti_device_guard()} device_guard({self.device_idx});" - if config.abi_compatible - else maybe_hipify_code_wrapper( - f"{V.graph.device_ops.cpp_device_guard()} device_guard({self.device_idx});" - ) ) else: code.writeline(f"device_guard.set_index({self.device_idx});") @@ -368,7 +366,7 @@ def __str__(self) -> str: @dataclasses.dataclass class AllocateLine(MemoryPlanningLine): - node: ir.Buffer + node: BufferLike def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: if self.node.get_name() in V.graph.removed_buffers: @@ -398,7 +396,7 @@ def codegen(self, code: IndentedBuffer) -> None: @dataclasses.dataclass class FreeIfNotReusedLine(MemoryPlanningLine): - node: ir.Buffer + node: BufferLike is_reused: bool = False def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: @@ -421,8 +419,8 @@ def codegen(self, code: IndentedBuffer) -> None: @dataclasses.dataclass class ReuseLine(MemoryPlanningLine): - node: ir.Buffer - reused_as: ir.Buffer + node: BufferLike + reused_as: BufferLike delete_old: bool = True def plan(self, state: MemoryPlanningState) -> MemoryPlanningLine: @@ -462,6 +460,7 @@ def __init__(self): self.wrapper_call = IndentedBuffer() self.kernel_autotune_defs = IndentedBuffer() self.kernel_autotune_calls = IndentedBuffer() + self.subgraph_definitions = IndentedBuffer() self.kernel_autotune_names: Set[str] = set() # If the generated source code is exactly the same, reuse the # pre-existing kernel for it @@ -524,12 +523,21 @@ def add_import_once(line: str) -> None: self._metas: Dict[str, str] = {} self._meta_vars: Set[str] = set() self.multi_kernel_state = MultiKernelState() + self.already_codegened_subgraphs: Set[str] = set() # intermediate tensor value printing utility self.debug_printer = DebugPrinterManager( debug_printer_level=config.aot_inductor.debug_intermediate_value_printer ) + @staticmethod + def create( + is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + ): + if is_subgraph: + return SubgraphPythonWrapperCodegen(subgraph_name, parent_wrapper) + return PythonWrapperCodegen() + def set_launcher_fn_name(self) -> None: self.launcher_fn_name = "call" @@ -831,6 +839,18 @@ def generate_user_defined_triton_kernel( kernel_name, args, grid_fn=grid_fn, arg_types=arg_types, raw_args=raw_args ) + def generate_tma_descriptor(self, desc): + ptr = f"{desc.tensor.codegen_reference()}.data_ptr()" + dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.dims) + block_dims = ", ".join(self.val_to_arg_str(dim) for dim in desc.block_dims) + element_size = self.val_to_arg_str(desc.element_size) + prefix = "triton.tools.experimental_descriptor" + fn_name = f"create_{desc.rank}d_tma_descriptor" + call = f"{prefix}.{fn_name}" + args = f"{ptr}, {dims}, {block_dims}, {element_size}" + line = f"{desc.name} = {call}({args})" + self.writeline(line) + def generate_scatter_fallback( self, output, @@ -886,6 +906,9 @@ def _generate(self, is_inference): if V.graph.aot_mode and V.graph.cpp_wrapper and V.graph.is_const_graph: result = IndentedBuffer() + # Add subgraph definitions to the result + result.splice(self.subgraph_definitions) + with contextlib.ExitStack() as stack: stack.enter_context(self.wrapper_call.indent()) if config.profiler_mark_wrapper_call: @@ -1267,6 +1290,9 @@ def define_kernel( if config.triton.autotune_at_compile_time: self.kernel_autotune_defs.splice(body) + def define_subgraph_launcher_fn(self, fn_code: str): + self.subgraph_definitions.splice(fn_code) + def define_user_defined_triton_kernel(self, kernel, configs, kwargs): from torch.utils._triton import patch_triton_dtype_repr @@ -1274,7 +1300,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): original_name = kernel.__name__ - from .common import KernelArgType, SizeArg, TensorArg + from .common import KernelArgType, SizeArg, TensorArg, TMADescriptorArg signature: List[KernelArgType] = [] constants: Dict[str, Any] = {} @@ -1286,9 +1312,17 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): arg = kwargs[key] if idx in kernel.constexprs: constants[key] = arg + elif kwargs[key] is None: + constants[key] = None else: non_constant_indices.append(idx) - if isinstance(arg, ir.Buffer): + if isinstance(arg, ir.TMADescriptor): + signature.append( + TMADescriptorArg( + name=key, + ) + ) + elif isinstance(arg, ir.Buffer): signature.append( TensorArg( name=key, @@ -1323,9 +1357,7 @@ def define_user_defined_triton_kernel(self, kernel, configs, kwargs): indices=non_constant_indices, argdefs=kernel.arg_names, ), - "device": DeviceProperties.create( - V.graph.scheduler.get_current_device_or_throw() - ), + "device": DeviceProperties.create(V.graph.get_current_device_or_throw()), # Triton compiler includes equal_to_1 args into constants even # when they are not constexpr. otherwise there may be a segfault # during launching the Inductor-compiled Triton kernel. @@ -1444,7 +1476,7 @@ def traverse(cur_kernel): f"{symbol_name}{annotation_code} = {symbol_str}" ) else: - compile_wrapper.writeline(f"{symbol_name} = {symbol!r}") + compile_wrapper.writeline(f"{symbol_name} = {symbol_str}") symbols_included.add(symbol_name) elif ( symbol_name in unqualified_loads @@ -1465,7 +1497,7 @@ def traverse(cur_kernel): traverse(kernel) - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() compile_wrapper.writeline(f"''', device_str='{current_device.type}')") _, lineno = inspect.getsourcelines(kernel.fn) srcfile = inspect.getsourcefile(kernel.fn) @@ -1498,19 +1530,46 @@ def generate_numel_expr(self, kernel_name: str, tree, suffix: Optional[str] = No # it suffices as a type hint for the purposes of producing the correct code for this type. return SymbolicCallArg(expr, tree.numel) - def generate_workspace_allocation(self, nbytes, device, zero_fill): - if isinstance(nbytes, sympy.Expr): - nbytes = V.graph.sizevars.size_hint(nbytes) - line = self.make_allocation( - "workspace", device, torch.uint8, shape=(nbytes,), stride=(1,) - ) - self.writeline(line) + def generate_workspace_allocation(self, ws: WorkspaceArg): + name = ws.get_name() + line = AllocateLine(self, ws) + if ws.zero_mode == WorkspaceZeroMode.UNINITIALIZED: + self.writeline(line) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_ON_CALL: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + elif ws.zero_mode == WorkspaceZeroMode.ZERO_PER_GRAPH: + prior = V.graph.allocated_workspaces.get(name) + if prior: + assert isinstance(prior, AllocateLine) + # expand existing allocation + prior.node = WorkspaceArg.maximum(prior.node, ws) + else: + self.writeline(line) + self.writeline(self.make_zero_buffer(name)) + V.graph.allocated_workspaces[name] = line + else: + raise AssertionError(ws.zero_mode) + if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(line) - if zero_fill: - self.writeline(f"workspace.zero_(){self.ending}") - if config.triton.autotune_at_compile_time: - self.kernel_autotune_calls.writeline(f"workspace.zero_(){self.ending}") + self.kernel_autotune_calls.writeline( + self.make_allocation( + name, + ws.device, + ws.dtype, + shape=(V.graph.sizevars.size_hint(ws.count),), + stride=(1,), + ) + ) + if ws.zero_mode != WorkspaceZeroMode.UNINITIALIZED: + self.kernel_autotune_calls.writeline(self.make_zero_buffer(name)) + + def generate_workspace_deallocation(self, ws: WorkspaceArg): + if ws.zero_mode != WorkspaceZeroMode.ZERO_PER_GRAPH: + self.writeline(FreeIfNotReusedLine(self, ws)) + + def make_zero_buffer(self, name): + return f"{name}.zero_(){self.ending}" def wrap_kernel_call(self, name, call_args): return f"{name}({', '.join(call_args)}){self.ending}" @@ -1586,7 +1645,7 @@ def wrap_arg(arg): call_args = [wrap_arg(arg) for arg in call_args] if device_index is None: - current_device = V.graph.scheduler.get_current_device_or_throw() + current_device = V.graph.get_current_device_or_throw() device_index = current_device.index return device_index, call_args @@ -1603,13 +1662,19 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): buf_name = f"tmp_arg_{index}" buf = raw_arg - size = V.graph.sizevars.size_hints( - buf.get_size(), - fallback=config.unbacked_symint_fallback, + size = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_size() ) - stride = V.graph.sizevars.size_hints( - buf.get_stride(), - fallback=config.unbacked_symint_fallback, + stride = tuple( + V.graph.sizevars.atomically_apply_size_hint( + e, + fallback=config.unbacked_symint_fallback, + ) + for e in buf.get_stride() ) device = buf.get_device() dtype = buf.get_dtype() @@ -1633,18 +1698,11 @@ def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None): if arg in V.graph.sizevars.inv_precomputed_replacements: arg = V.graph.sizevars.inv_precomputed_replacements[arg] - # For multiple expressions that depend on an unbacked symint, - # we want to compute them consistently for a size hint we have chosen. - # So, recursively compute expressions via size hints of contained symbols. - free_symbols = arg.free_symbols - size_dict = { - symbol: V.graph.sizevars.size_hint( - symbol, - fallback=config.unbacked_symint_fallback, + return str( + V.graph.sizevars.atomically_apply_size_hint( + arg, fallback=config.unbacked_symint_fallback ) - for symbol in free_symbols - } - return str(arg.subs(size_dict)) + ) elif isinstance(arg, (str, int, float, bool)): return str(arg) @@ -1746,8 +1804,8 @@ def generate_kernel_call( if isinstance(arg_type, torch_dtype): # workspace allocation is already generated by `generate_workspace_allocation()` # in `TritonKernel.call_kernel()`. - if arg == "workspace": - arg_str = "workspace" + if re.match(r"^(workspace|semaphore)", arg): + arg_str = arg tensor_args[arg] = arg_str elif arg not in tensor_args: arg_str = self.generate_example_arg_value( @@ -1818,7 +1876,7 @@ def __repr__(self): return repr(s) # The following methods are for memory management - def make_buffer_allocation(self, buffer): + def make_buffer_allocation(self, buffer: BufferLike): device = buffer.get_device() dtype = buffer.get_dtype() shape = tuple(buffer.get_size()) @@ -1845,7 +1903,7 @@ def make_allocation(self, name, device, dtype, shape, stride): def make_tensor_alias(self, new_name, old_name, comment=""): return f"{self.declare}{new_name} = {old_name}{self.ending} {self.comment} {comment}" - def make_buffer_free(self, buffer): + def make_buffer_free(self, buffer: BufferLike): return f"del {buffer.get_name()}" def make_free_by_names(self, names_to_del: List[str]): @@ -1854,7 +1912,7 @@ def make_free_by_names(self, names_to_del: List[str]): def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str): return f"{self.declare_maybe_reference}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse" - def make_buffer_reuse(self, old: ir.Buffer, new: ir.Buffer, delete_old: bool): + def make_buffer_reuse(self, old: BufferLike, new: BufferLike, delete_old: bool): assert old.get_dtype() == new.get_dtype() old_name = old.get_name() new_name = new.get_name() @@ -1958,37 +2016,126 @@ def codegen_unbacked_symbol_decl(self, symbol): self.unbacked_symbol_decls.add(name) return self.declare + name - def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): - for inner_input, outer_input in zip(subgraph.graph.graph_inputs, outer_inputs): - self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}") + def codegen_subgraph_by_inlining(self, subgraph, outer_inputs, outer_outputs): + # TODO (desertfire) - This function is the old way of supporting + # subgraph codegen by inlining subgraphs in the output code. For python + # wrapper, we have moved to lifting subgraphs as functions, supported by + # `codegen_subgraph` function. + # + # However this does not work with cpp wrapper. With cpp wrapper, we make + # two passes and the kernels are shared from the first pass to the next. + # Therefore, both the Python and CppWrapper need to share the some + # codegen infra. For now, CppWrapperCpu has not been updated to lift the + # subgraph as functions. Therefore for cpp_wrapper first pass with + # PythonWrapper, we still fallback to the old way of inlining subgraphs + # in the output code. Once we update CppWrapperCpu, we can remove this + # function. + def _codegen_subgraph_prefix(): + assert len(subgraph.graph.graph_inputs) == len(outer_inputs) + for inner_input, outer_input in zip( + subgraph.graph.graph_inputs, outer_inputs + ): + self.writeline( + f"{self.declare}{inner_input} = {outer_input}{self.ending}" + ) - def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): - for inner_output, outer_output in zip( - subgraph.graph.graph_outputs, outer_outputs - ): - self.writeline( - f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" - ) + def _codegen_subgraph_suffix(): + assert len(subgraph.graph.graph_outputs) == len(outer_outputs) + for inner_output, outer_output in zip( + subgraph.graph.graph_outputs, outer_outputs + ): + self.writeline( + f"{outer_output} = {inner_output.codegen_reference()}{self.ending}" + ) - def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): try: self.push_codegened_graph(subgraph.graph) self.writeline(f"{self.comment} subgraph: {subgraph.name}") - self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + _codegen_subgraph_prefix() parent_graph = V.graph with V.set_graph_handler(subgraph.graph): subgraph.graph.codegen_subgraph( parent_graph=parent_graph, ) - self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + _codegen_subgraph_suffix() finally: self.pop_codegened_graph() + def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): + subgraph.graph.add_symbol_graph_inputs() + # NB: Because of symints, the len of graph_inputs might be larger than + # outer_inputs + explicit_graph_inputs = subgraph.graph.graph_input_names[: len(outer_inputs)] + for inner_input, outer_input in zip(explicit_graph_inputs, outer_inputs): + self.writeline(f"{self.declare}{inner_input} = {outer_input}{self.ending}") + + def codegen_subgraph_suffix(self, subgraph, outer_inputs, outer_outputs): + assert len(subgraph.graph.graph_outputs) == len(outer_outputs) + for inner_output, outer_output in zip( + subgraph.graph.get_output_names(), outer_outputs + ): + self.writeline(f"{outer_output} = {inner_output}{self.ending}") + + def codegen_subgraph_call(self, subgraph, outer_inputs, outer_outputs): + # Get the input and output names of the subgraph + input_names = subgraph.graph.graph_input_names + inner_inputs = ", ".join(input_names) + if len(input_names) == 1: + inner_inputs += "," + + output_names = subgraph.graph.get_output_names() + inner_outputs = ", ".join(output_names) + if len(output_names) == 1: + inner_outputs += "," + + # Create a list of inputs for the subgraph call + self.writeline(f"{subgraph.graph.name}_args = [{inner_inputs}]") + for inner_input in input_names[: len(outer_inputs)]: + self.writeline(f"del {inner_input}") + + # Call the subgraph launcher function + self.writeline( + f"({inner_outputs}) = {subgraph.graph.name}({subgraph.graph.name}_args)" + ) + + def codegen_subgraph(self, subgraph, outer_inputs, outer_outputs): + # Codegen subgraph by recursively calling the codegen for the subgraph. + # This lifts the subgraph as a function in the output code. + if V.graph.aot_mode: + self.codegen_subgraph_by_inlining(subgraph, outer_inputs, outer_outputs) + return + + self.push_codegened_graph(subgraph.graph) + self.writeline(f"{self.comment} subgraph: {subgraph.name}") + self.codegen_subgraph_prefix(subgraph, outer_inputs, outer_outputs) + + parent_graph = V.graph + subgraph.graph.cpp_wrapper = parent_graph.cpp_wrapper + + if subgraph.graph.name not in self.already_codegened_subgraphs: + # If it is already codegened, the parent wrapper already has + # subgraph fn by name subgraph.graph.name + with V.set_graph_handler(subgraph.graph): + # Call the codegen of subgraph recursively + subgraph_code, _ = subgraph.graph.codegen() + self.already_codegened_subgraphs.add(subgraph.graph.name) + self.define_subgraph_launcher_fn(subgraph_code) + + self.codegen_subgraph_call(subgraph, outer_inputs, outer_outputs) + + self.codegen_subgraph_suffix(subgraph, outer_inputs, outer_outputs) + + def codegen_invoke_subgraph(self, invoke_subgraph): + name = invoke_subgraph.get_name() + + self.writeline(f"{name} = [None] * {len(invoke_subgraph.outputs)}") + outer_inputs = [buf.codegen_reference() for buf in invoke_subgraph.inputs] + outer_outputs = [f"{name}[{i}]" for i in range(len(invoke_subgraph.outputs))] + self.codegen_subgraph(invoke_subgraph.subgraph, outer_inputs, outer_outputs) + def codegen_conditional(self, conditional): name = conditional.get_name() - self.writeline(f"{name} = [None] * {len(conditional.outputs)}") - outer_inputs = [buf.codegen_reference() for buf in conditional.operands] outer_outputs = [f"{name}[{i}]" for i in range(len(conditional.outputs))] @@ -2090,3 +2237,59 @@ def static_shape_for_buffer_or_none(buffer): @staticmethod def can_prove_buffer_has_static_shape(buffer): return PythonWrapperCodegen.static_shape_for_buffer_or_none(buffer) is not None + + +class SubgraphPythonWrapperCodegen(PythonWrapperCodegen): + """ + A wrapper codegen that generates code for a subgraph. For most of the + methods, we rely on the implementation in the PythonWrapperCodegen. But we + override a few functions to produce cleaner code (like avoiding writing + imports twice in the output code) + """ + + def __init__(self, subgraph_name, parent_wrapper): + # It is necessary to set the subgraph_name before calling super __init__ + # because __init__ calls set_launcher_fn_name + self.subgraph_name = subgraph_name + self.parent_wrapper = parent_wrapper + super().__init__() + + def set_launcher_fn_name(self) -> None: + # This sets up the name of the function containing the launcher code of + # the subgraph. + self.launcher_fn_name = self.subgraph_name + + def write_header(self) -> None: + pass + + def add_benchmark_harness(self, output): + pass + + def benchmark_compiled_module(self, output): + pass + + def write_async_compile_wait(self): + pass + + def next_kernel_suffix(self) -> str: + # Ensures that subgraphs kernels do not clash with each other + return self.parent_wrapper.next_kernel_suffix() + + @cache_on_self + def write_triton_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. + # if config.triton.autotune_at_compile_time: + # import_str = self.triton_header_str() + # self.kernel_autotune_calls.splice(import_str) + self.parent_wrapper.write_triton_header_once() + + @cache_on_self + def write_get_raw_stream_header_once(self) -> None: + # TODO: Uncomment in future. This will be needed to support subgraph + # codegen for cpp wrapper. + # if config.triton.autotune_at_compile_time: + # self.kernel_autotune_calls.writeline( + # V.graph.device_ops.import_get_raw_stream_as("get_raw_stream") + # ) + self.parent_wrapper.write_get_raw_stream_header_once() diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index d8b6d4c34d630..4af7b796091ea 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -1,5 +1,5 @@ -# mypy: allow-untyped-decorators -# mypy: allow-untyped-defs +from __future__ import annotations + import contextlib import functools import io @@ -10,13 +10,28 @@ import time import warnings from itertools import count -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Generator, + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import Never, ParamSpec, Protocol, TypedDict, Unpack from unittest import mock import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools import torch.fx import torch.utils._pytree as pytree from functorch.compile import min_cut_rematerialization_partition +from torch import fx from torch._dispatch.python import enable_python_dispatcher from torch._dynamo import ( compiled_autograd, @@ -29,6 +44,7 @@ from torch._dynamo.utils import ( counters, detect_fake_mode, + dynamo_timed, flatten_graph_inputs, lazy_format_graph_code, ) @@ -60,14 +76,15 @@ tensor_is_aligned, ) from torch._logging import trace_structured -from torch._ops import OpOverload +from torch._utils_internal import compile_time_strobelight_meta +from torch.fx import GraphModule from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols, SymExprPrinter from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.monitor import _WaitCounter from torch.utils._ordered_set import OrderedSet from .._dynamo.backends.common import aot_autograd -from ..fx._lazy_graph_module import _use_lazy_graph_module # type: ignore[attr-defined] +from ..fx._lazy_graph_module import _use_lazy_graph_module from ..fx.graph import _PyTreeCodeGen from . import config, metrics from .debug import DebugContext @@ -76,7 +93,6 @@ from .fx_passes.post_grad import post_grad_passes, view_to_reshape from .fx_passes.pre_grad import pre_grad_passes from .graph import GraphLowering -from .ir import ExternKernelNode from .utils import ( align_inputs_from_check_idxs, clone_preserve_strides, @@ -91,13 +107,33 @@ from .virtualized import V -if config.is_fbcode(): - from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log -else: +if TYPE_CHECKING: + from torch._ops import OpOverload + + from .ir import ExternKernelNode + + +_P = ParamSpec("_P") +_T = TypeVar("_T") + +if TYPE_CHECKING or not config.is_fbcode(): # no-op decorator - def time_and_log(attr: str): + def time_and_log(attr: str) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: return dynamo_utils.identity + def log_optimus_to_scuba(*args: object, **kwargs: object) -> None: + pass + +else: + from torch._inductor.fb.utils import log_optimus_to_scuba, time_and_log + +if TYPE_CHECKING: + from torch._functorch._aot_autograd.schemas import ( + FQN, + GraphInputName, + GraphSignature, + ) + log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") @@ -111,7 +147,7 @@ def time_and_log(attr: str): # for expanded dimensions (a dimension which used to have size 1 -> ?) # we can select one element from that dimension and write to it # to achieve writing to all values of that dimension of the input tensor -def get_expanded_dims(t): +def get_expanded_dims(t: torch.Tensor) -> List[int]: if not isinstance(t, torch.Tensor): return None return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1] @@ -142,7 +178,7 @@ def complex_memory_overlap(t: torch.Tensor) -> bool: return False -def get_static_input_idxs(num_fixed): +def get_static_input_idxs(num_fixed: int) -> List[int]: # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes # of cudagraphs. Rather than copying these into cudagraph-owned memory # like we do for normal inputs on each run, we will re-record a cudagraph if these @@ -156,12 +192,12 @@ def get_static_input_idxs(num_fixed): @functools.lru_cache(None) -def _step_logger(): +def _step_logger() -> Callable[..., None]: return dynamo_logging.get_step_logger(log) @functools.lru_cache(None) -def _warn_tf32_disabled(): +def _warn_tf32_disabled() -> None: if ( torch.cuda.is_available() and not torch.backends.cuda.matmul.allow_tf32 @@ -173,10 +209,12 @@ def _warn_tf32_disabled(): ) -def _unlift_graph(mod, gm, graph_signature): +def _unlift_graph( + mod: GraphModule, gm: GraphModule, graph_signature: GraphSignature +) -> GraphModule: from torch.export.unflatten import _assign_attr, _AttrKind - state_dict = {} + state_dict: Dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {} for name, param in mod.named_parameters(remove_duplicate=False): state_dict[name] = param _assign_attr( @@ -195,7 +233,7 @@ def _unlift_graph(mod, gm, graph_signature): ) placeholder_nodes = gm.graph.find_nodes(op="placeholder") - lifted_inputs = [] + lifted_inputs: List[Optional[FQN]] = [] # In AOTI, module parameters and buffers are not lifted as graph inputs. # As a result, mutation to buffers has side effect which makes their initial @@ -225,7 +263,7 @@ def _unlift_graph(mod, gm, graph_signature): user_input_mutations = graph_signature.user_inputs_to_mutate output_tokens = graph_signature.output_tokens for idx, out in enumerate(outputs): - value = None + value: Optional[Union[FQN, GraphInputName]] = None if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): if out.name in buffer_mutations: @@ -247,7 +285,7 @@ def _unlift_graph(mod, gm, graph_signature): return unlifted_gm -def _get_subgraph_names(gm): +def _get_subgraph_names(gm: GraphModule) -> Generator[str, None, None]: for node in sorted( itertools.chain( gm.graph.find_nodes(op="call_function", target=torch.ops.higher_order.cond), @@ -268,34 +306,38 @@ def _get_subgraph_names(gm): yield body_subgraph_name -def _recursive_pre_grad_passes(gm, example_inputs): - for subgraph_name in _get_subgraph_names(gm): - subgraph = getattr(gm, subgraph_name) - # as we don't have recursive example inputs, passing None here - new_subgraph = _recursive_pre_grad_passes(subgraph, example_inputs=None) - setattr(gm, subgraph_name, new_subgraph) - return pre_grad_passes(gm, example_inputs) +def _recursive_pre_grad_passes( + gm: GraphModule, example_inputs: Sequence[InputType] +) -> GraphModule: + with dynamo_timed("_recursive_pre_grad_passes"): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + # as we don't have recursive example inputs, passing empty set here + new_subgraph = _recursive_pre_grad_passes(subgraph, ()) + setattr(gm, subgraph_name, new_subgraph) + return pre_grad_passes(gm, example_inputs) -def _recursive_joint_graph_passes(gm): +def _recursive_joint_graph_passes(gm: GraphModule) -> None: for subgraph_name in _get_subgraph_names(gm): subgraph = getattr(gm, subgraph_name) _recursive_joint_graph_passes(subgraph) joint_graph_passes(gm) -def _recursive_post_grad_passes(gm, is_inference: bool = False): - for subgraph_name in _get_subgraph_names(gm): - subgraph = getattr(gm, subgraph_name) - _recursive_post_grad_passes(subgraph, is_inference) - post_grad_passes(gm, is_inference) +def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) -> None: + with dynamo_timed("_recursive_post_grad_passes"): + for subgraph_name in _get_subgraph_names(gm): + subgraph = getattr(gm, subgraph_name) + _recursive_post_grad_passes(subgraph, is_inference) + post_grad_passes(gm, is_inference) def split_const_gm( - gm: torch.fx.GraphModule, + gm: GraphModule, lifted_constants: Optional[Dict[str, Any]] = None, skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, -) -> Tuple[torch.fx.GraphModule, Dict[str, int]]: +) -> Tuple[GraphModule, Dict[str, int]]: """ This function takes an GraphModule input "gm". The gm will be split into 2 components, @@ -357,7 +399,7 @@ def split_const_gm( return const_gm, const_output_index -def is_tf32_warning_applicable(gm: torch.fx.GraphModule): +def is_tf32_warning_applicable(gm: GraphModule) -> bool: aten = torch.ops.aten tf32_ops = { aten.mm.default, @@ -376,7 +418,9 @@ def is_tf32_warning_applicable(gm: torch.fx.GraphModule): return False -def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): +def maybe_disable_comprehensive_padding( + example_inputs: Sequence[InputType], +) -> contextlib.AbstractContextManager[None, None]: """ For CPU backend, enable comprehensive padding causes some unit tests fail due to changing number of generated kernels. Skip for now. @@ -393,10 +437,10 @@ def maybe_disable_comprehensive_padding(example_inputs: List[torch.Tensor]): def fake_tensor_prop( - gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], + gm: GraphModule, + example_inputs: Sequence[InputType], force_allow_non_fake_inputs: bool = False, -): +) -> torch._subclasses.FakeTensorMode: """ If we can not detect fake mode from the context of inputs, create one. @@ -423,13 +467,15 @@ def fake_tensor_prop( # pass config dict back to user -def get_patched_config_dict(config_patches=None) -> Dict[str, Any]: +def get_patched_config_dict( + config_patches: Optional[Union[str, Dict[str, Any]]] = None +) -> Dict[str, Any]: with config.patch(config_patches): return config.get_config_copy() @contextlib.contextmanager -def with_fresh_cache_if_config(): +def with_fresh_cache_if_config() -> Generator[None, None, None]: if config.force_disable_caches: # Don't delete the cache dir because it has to survive beyond the # compile_fx call. Let's put the temp dirs under the default cache @@ -440,7 +486,50 @@ def with_fresh_cache_if_config(): yield -def compile_fx_inner(*args, **kwargs): +class _CompileFxKwargs(TypedDict, total=False): + cudagraphs: Optional[BoxedBool] + static_input_idxs: Sequence[int] + is_backward: bool + graph_id: Optional[int] + cpp_wrapper: bool + aot_mode: bool + is_inference: bool + user_visible_outputs: Optional[Dict[str, None]] + layout_opt: Optional[bool] + extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] + + +class _CompileFxKwargsEx(_CompileFxKwargs, total=False): + boxed_forward_device_index: Optional[BoxedDeviceIndex] + + +class _CompileFxCallableEx(Protocol): + def __call__( + self, + gm: GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargsEx], + ) -> Union[CompiledFxGraph, str]: + ... + + +def compile_fx_inner( + gm: GraphModule, + example_inputs: Sequence[InputType], + **kwargs: Unpack[_CompileFxKwargsEx], +) -> Union[CompiledFxGraph, str]: + kwargs.setdefault("cudagraphs", None) + kwargs.setdefault("static_input_idxs", ()) + kwargs.setdefault("is_backward", False) + kwargs.setdefault("graph_id", None) + kwargs.setdefault("cpp_wrapper", False) + kwargs.setdefault("aot_mode", False) + kwargs.setdefault("is_inference", False) + kwargs.setdefault("boxed_forward_device_index", None) + kwargs.setdefault("user_visible_outputs", None) + kwargs.setdefault("layout_opt", None) + kwargs.setdefault("extern_node_serializer", None) + # Need with_fresh_cache_if_config for compile_fx_inner even if we already have one for # compile_fx. The reason is the compilation for backward graph may happen after # compile_fx return and we may want to use the _LazyGraphModule for compiling @@ -453,29 +542,28 @@ def compile_fx_inner(*args, **kwargs): "compile_fx_inner", phase_name="inductor_compile", fwd_only=False ) ) + # NB: Why is this the dynamo_compile counter? The rule here is that + # if it gets an entry in the dynamo_compile table, we also want to + # tick up the wait counter. We have to displeasingly manually trigger + # the counter here because we may dropped into compile_fx directly + # from lazy backwards compilation. + stack.enter_context(_WaitCounter("pytorch.wait_counter.dynamo_compile").guard()) stack.enter_context(with_fresh_cache_if_config()) stack.enter_context(DebugContext()) return wrap_compiler_debug(_compile_fx_inner, compiler_name="inductor")( - *args, **kwargs + gm, + example_inputs, + **kwargs, ) @time_and_log(attr="compilation time (in seconds)") def _compile_fx_inner( - gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], - cudagraphs: Optional[BoxedBool] = None, - static_input_idxs: Optional[List[int]] = None, - is_backward: bool = False, - graph_id: Optional[int] = None, - cpp_wrapper: bool = False, - aot_mode: bool = False, - is_inference: bool = False, + gm: GraphModule, + example_inputs: Sequence[InputType], boxed_forward_device_index: Optional[BoxedDeviceIndex] = None, - user_visible_outputs: Optional[Dict[str, None]] = None, - layout_opt: Optional[bool] = None, - extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] = None, + **graph_kwargs: Unpack[_CompileFxKwargs], ) -> Union[CompiledFxGraph, str]: """ Inductor API that compiles a single graph. @@ -483,6 +571,8 @@ def _compile_fx_inner( If you change the argument list for this function, make sure you also update the call to save_args_for_compile_fx_inner below accordingly. """ + aot_mode: bool = graph_kwargs.setdefault("aot_mode", False) + if dynamo_utils.count_calls(gm.graph) == 0 and not aot_mode: # trigger the real recompilation for _LazyGraphModule before returning # the forward method. @@ -491,62 +581,35 @@ def _compile_fx_inner( _LazyGraphModule.force_recompile(gm) return make_boxed_func(gm.forward) - if static_input_idxs is None: - static_input_idxs = [] - + static_input_idxs: Sequence[int] = graph_kwargs.setdefault("static_input_idxs", ()) static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs) assert isinstance( next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) ), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" + if (cudagraphs := graph_kwargs.get("cudagraphs")) is None: + graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs) if config.save_args: save_args_for_compile_fx_inner( gm, example_inputs, - cudagraphs=cudagraphs, - static_input_idxs=static_input_idxs, - is_backward=is_backward, - graph_id=graph_id, - cpp_wrapper=cpp_wrapper, - aot_mode=aot_mode, - is_inference=is_inference, boxed_forward_device_index=boxed_forward_device_index, - user_visible_outputs=user_visible_outputs, - layout_opt=layout_opt, + **graph_kwargs, ) - if cudagraphs is None: - cudagraphs = BoxedBool(config.triton.cudagraphs) - - # Inputs to fx_codegen_and_compile - # Anything that affects codegen should go here, so if the signature - # of fx_codegen_and_compile changes, the dict should be updated accordingly - graph_kwargs = { - "cudagraphs": cudagraphs, - "static_input_idxs": static_input_idxs, - "is_backward": is_backward, - "graph_id": graph_id, - "cpp_wrapper": cpp_wrapper, - "aot_mode": aot_mode, - "is_inference": is_inference, - "user_visible_outputs": user_visible_outputs, - "layout_opt": layout_opt, - "extern_node_serializer": extern_node_serializer, - } - start = time.time() fx_graph_remote_cache = should_use_remote_fx_graph_cache() - inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) # type: ignore[arg-type] + inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) def codegen_and_compile( - gm, - example_inputs, - inputs_to_check, - fx_kwargs, - ): + gm: GraphModule, + example_inputs: Sequence[InputType], + inputs_to_check: Sequence[int], + fx_kwargs: _CompileFxKwargs, + ) -> Union[CompiledFxGraph, str]: """ This function calls fx_codegen_and_compile and also adds some extra metadata to the resulting compiled fx graph. The metadata is saved to FXGraphCache. @@ -585,7 +648,7 @@ def codegen_and_compile( check_for_mutation_ignore_cuda_graph_managed_tensor( gm, compiled_graph, - static_input_idxs, # type:ignore[arg-type] + static_input_idxs, ) ) has_mutation = has_mutation_str is not None @@ -651,7 +714,7 @@ def codegen_and_compile( ) else: compiled_graph = codegen_and_compile( - gm, example_inputs, inputs_to_check, graph_kwargs # type: ignore[arg-type] + gm, example_inputs, inputs_to_check, graph_kwargs ) if aot_mode: # AOT mode is special because codegen_and_compile returns a string. @@ -667,8 +730,8 @@ def codegen_and_compile( _step_logger()( logging.INFO, "torchinductor done compiling " - f"{'BACKWARDS' if is_backward else 'FORWARDS'} " - f"graph {graph_id}", + f"{'BACKWARDS' if graph_kwargs['is_backward'] else 'FORWARDS'} " + f"graph {graph_kwargs['graph_id']}", ) # aot autograd needs to know to pass in inputs as a list compiled_graph._boxed_call = True @@ -676,10 +739,10 @@ def codegen_and_compile( def fx_codegen_and_compile( - gm: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], + gm: GraphModule, + example_inputs: Sequence[InputType], cudagraphs: Optional[BoxedBool] = None, - static_input_idxs: Optional[List[int]] = None, + static_input_idxs: Optional[Sequence[int]] = None, is_backward: bool = False, graph_id: Optional[int] = None, cpp_wrapper: bool = False, @@ -714,7 +777,7 @@ def fx_codegen_and_compile( f"graph {graph_id}", ) - def log_graph_runnable(): + def log_graph_runnable() -> str: fd = io.StringIO() torch._dynamo.repro.after_aot.save_graph_repro( fd, gm, example_inputs, "inductor", save_dir=None @@ -768,7 +831,9 @@ def log_graph_runnable(): with V.set_fake_mode(fake_mode): # has some issues with memory in training - _recursive_post_grad_passes(gm, is_inference=is_inference) + cuda_context = get_cuda_device_context(gm) + with cuda_context: + _recursive_post_grad_passes(gm, is_inference=is_inference) V.debug.fx_graph_transformed(gm, example_inputs) post_grad_graphs_log.debug( "%s", @@ -925,7 +990,7 @@ def log_graph_runnable(): def get_input_idxs_to_check( - inputs: List[InputType], + inputs: Sequence[InputType], static_input_idxs: Sequence[int], ) -> Sequence[int]: """ @@ -992,7 +1057,7 @@ def cudagraphify( compiled_fn = None - def run(new_inputs): + def run(new_inputs: Sequence[InputType]) -> Any: nonlocal compiled_fn if compiled_fn is None: with dynamo_utils.dynamo_timed( @@ -1015,7 +1080,7 @@ def index_expanded_dims_and_copy_( dst: torch.Tensor, src: torch.Tensor, expanded_dims: List[int], -): +) -> None: "Index into expanded dimensions of both dst and src then copy_" dst = index_expanded_dims(dst, expanded_dims) src = index_expanded_dims(src, expanded_dims) @@ -1026,7 +1091,7 @@ def cudagraphify_impl( model: Callable[..., Any], inputs: List[torch.Tensor], static_input_idxs: Sequence[int] = (), -): +) -> Callable[[List[InputType]], Any]: """ Assumes inputs[static_input_idxs[i]] are always the same memory address """ @@ -1078,14 +1143,15 @@ def cudagraphify_impl( if config.size_asserts: - def run(new_inputs): + def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]: assert len(static_inputs) == len(new_inputs) for idx, (dst, src, expanded_dims) in enumerate( zip(static_inputs, new_inputs, inps_expanded_dims) ): if not isinstance(dst, torch.Tensor): - pass - elif idx in static_input_idxs: + continue + assert isinstance(src, torch.Tensor) + if idx in static_input_idxs: assert dst.data_ptr() == src.data_ptr() else: # TODO - could make one single op of multiple slices @@ -1101,12 +1167,12 @@ def run(new_inputs): idx for idx in range(len(static_inputs)) if idx not in static_input_idxs ] - def run(new_inputs): + def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]: for idx in copy_indices: expanded_dims = inps_expanded_dims[idx] - index_expanded_dims_and_copy_( - static_inputs[idx], new_inputs[idx], expanded_dims - ) + src = new_inputs[idx] + assert isinstance(src, torch.Tensor) + index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims) new_inputs.clear() graph.replay() return static_outputs @@ -1115,11 +1181,11 @@ def run(new_inputs): def compile_fx_aot( - model_: torch.fx.GraphModule, - example_inputs_: List[torch.Tensor], - inner_compile: Callable[..., Any] = compile_fx_inner, - config_patches: Optional[Dict[str, Any]] = None, -): + model_: GraphModule, + example_inputs_: List[InputType], + inner_compile: _CompileFxCallableEx = compile_fx_inner, + config_patches: Optional[Dict[str, str]] = None, +) -> str: config_patches: Dict[str, Any] = ( {"cpp_wrapper": True} if config_patches is None @@ -1147,6 +1213,7 @@ def compile_fx_aot( ), config_patches=config_patches, ) + assert isinstance(compiled_lib_path, str) assert os.path.exists( compiled_lib_path ), f"AOTInductor compiled library does not exist at {compiled_lib_path}" @@ -1157,15 +1224,15 @@ def compile_fx_aot( def fw_compiler_freezing( - aot_autograd_model: torch.fx.GraphModule, - aot_example_inputs: List[torch.Tensor], - dynamo_model: torch.fx.GraphModule, + aot_autograd_model: GraphModule, + aot_example_inputs: Sequence[InputType], + dynamo_model: GraphModule, num_example_inputs: int, inner_compile: Callable[..., Any], cudagraphs: BoxedBool, graph_id: int, forward_device: BoxedDeviceIndex, -): +) -> Callable[[List[object]], Sequence[torch.Tensor]]: from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze # partition_fn won't be called @@ -1248,7 +1315,7 @@ def fw_compiler_freezing( if V.aot_compilation is True: return optimized_function - def wrapper(args): + def wrapper(args: List[object]) -> Sequence[torch.Tensor]: args_new = [ args[i - unwrapped_args_offsets[min(i, max_offset_idx)]] for i in preserved_arg_indices @@ -1261,7 +1328,7 @@ def wrapper(args): return wrapper -def get_cpp_wrapper_config(): +def get_cpp_wrapper_config() -> Dict[str, object]: return { # Set autotune_at_compile_time to True as default if the option is not explicitly set "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time @@ -1273,13 +1340,43 @@ def get_cpp_wrapper_config(): } +def get_cuda_device_context(gm: torch.fx.GraphModule) -> ContextManager[None]: + """ + Returns a cuda device context manager if there is a single device in the graph + """ + if not torch.cuda.is_available(): + return contextlib.nullcontext() + + placeholder_nodes = gm.graph.find_nodes(op="placeholder") + input_devices: OrderedSet[torch.device] = OrderedSet( + node.meta["val"].device + for node in placeholder_nodes + if isinstance(node.meta.get("val"), torch.Tensor) + ) + + out_devices: OrderedSet[torch.device] = OrderedSet( + arg.meta["val"].device + for arg in output_node(gm).args[0] + if isinstance(arg, fx.Node) and isinstance(arg.meta.get("val"), torch.Tensor) + ) + cuda_devices: OrderedSet[torch.device] = OrderedSet( + device for device in (input_devices | out_devices) if device.type == "cuda" + ) + + return ( + torch.cuda.device(next(iter(cuda_devices))) # type: ignore[return-value] + if len(cuda_devices) == 1 + else contextlib.nullcontext() + ) + + def compile_fx( - model_: torch.fx.GraphModule, - example_inputs_: List[torch.Tensor], + model_: GraphModule, + example_inputs_: Sequence[InputType], inner_compile: Callable[..., Any] = compile_fx_inner, config_patches: Optional[Dict[str, Any]] = None, decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None, -): +) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str]: with _use_lazy_graph_module(dynamo_config.use_lazy_graph_module): """Main entrypoint to a compile given FX graph""" if config_patches: @@ -1299,8 +1396,9 @@ def compile_fx( **get_cpp_wrapper_config(), } ), V.set_real_inputs(example_inputs_): - inputs_ = example_inputs_ - if isinstance(model_, torch.fx.GraphModule): + inputs_: Sequence[InputType] = example_inputs_ + + if isinstance(model_, GraphModule): fake_inputs = [ node.meta.get("val") for node in model_.graph.nodes @@ -1313,15 +1411,17 @@ def compile_fx( for inp in fake_inputs ] - if all(v is not None for v in fake_inputs): + if any(v is not None for v in fake_inputs): # Validate devices before switching to fake tensors. for idx, fi, i in zip(count(), fake_inputs, inputs_): - if fi is not None and fi.device != i.device: - raise ValueError( - f"Device mismatch between fake input and example input at position #{idx}: " - f"{fi.device} vs {i.device}. If the model was exported via torch.export(), " - "make sure torch.export() and torch.aot_compile() run on the same device." - ) + if fi is not None: + assert isinstance(i, torch.Tensor) + if fi.device != i.device: + raise ValueError( + f"Device mismatch between fake input and example input at position #{idx}: " + f"{fi.device} vs {i.device}. If the model was exported via torch.export(), " + "make sure torch.export() and torch.aot_compile() run on the same device." + ) inputs_ = fake_inputs # type: ignore[assignment] return compile_fx( model_, @@ -1343,7 +1443,7 @@ def compile_fx( recursive_compile_fx, ) - if isinstance(model_, torch.fx.GraphModule): + if isinstance(model_, GraphModule): if isinstance(model_.graph._codegen, _PyTreeCodeGen): # this graph is the result of dynamo.export() return handle_dynamo_export_graph( @@ -1373,18 +1473,18 @@ def compile_fx( ) def fw_compiler_base( - model: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], + model: GraphModule, + example_inputs: List[InputType], is_inference: bool, - ): + ) -> CompiledFxGraph: with dynamo_utils.dynamo_timed("compile_fx..fw_compiler_base"): return _fw_compiler_base(model, example_inputs, is_inference) def _fw_compiler_base( - model: torch.fx.GraphModule, - example_inputs: List[torch.Tensor], + model: GraphModule, + example_inputs: List[InputType], is_inference: bool, - ): + ) -> CompiledFxGraph: if is_inference: # partition_fn won't be called _recursive_joint_graph_passes(model) @@ -1409,7 +1509,7 @@ def _fw_compiler_base( else: original_output_start_index = 0 - if isinstance(model_, torch.fx.GraphModule): + if isinstance(model_, GraphModule): *_, orig_model_outputs_node = model_.graph.nodes assert orig_model_outputs_node.op == "output" orig_model_outputs, _ = pytree.tree_flatten( @@ -1463,7 +1563,7 @@ def _fw_compiler_base( fw_compiler = functools.partial(fw_compiler_base, is_inference=False) if config.freezing and not torch.is_grad_enabled(): - inference_compiler = functools.partial( + inference_compiler: Callable[..., Any] = functools.partial( fw_compiler_freezing, dynamo_model=model_, num_example_inputs=num_example_inputs, @@ -1475,15 +1575,22 @@ def _fw_compiler_base( else: inference_compiler = functools.partial(fw_compiler_base, is_inference=True) - def partition_fn(graph, joint_inputs, **kwargs): - _recursive_joint_graph_passes(graph) + def partition_fn( + gm: GraphModule, + joint_inputs: Sequence[object], + **kwargs: object, + ) -> Tuple[GraphModule, GraphModule]: + cuda_context = get_cuda_device_context(gm) + with cuda_context: + _recursive_joint_graph_passes(gm) return min_cut_rematerialization_partition( - graph, joint_inputs, **kwargs, compiler="inductor" + gm, joint_inputs, **kwargs, compiler="inductor" ) + @compile_time_strobelight_meta(phase_name="backward") def bw_compiler( - model: torch.fx.GraphModule, example_inputs: List[torch.Tensor] - ): + model: GraphModule, example_inputs: List[InputType] + ) -> Union[CompiledFxGraph, str]: with dynamo_utils.dynamo_timed("compile_fx..bw_compiler"): user_visible_outputs = {} @@ -1563,9 +1670,9 @@ def bw_compiler( )(model_, example_inputs_) -def graph_returns_tuple(gm: torch.fx.GraphModule): +def graph_returns_tuple(gm: GraphModule) -> bool: """True if a FX graph returns a tuple""" - if not isinstance(gm, torch.fx.GraphModule): + if not isinstance(gm, GraphModule): return True # can't check this, assume true (rv,) = output_node(gm).args if isinstance(rv, (list, tuple)): @@ -1582,10 +1689,10 @@ def graph_returns_tuple(gm: torch.fx.GraphModule): def make_graph_return_tuple( - gm: torch.fx.GraphModule, - inputs: List[torch.Tensor], + gm: GraphModule, + inputs: Sequence[InputType], compile_gm: Callable[..., Any], -): +) -> Callable[..., Any]: """ Mutate gm so it returns a tuple. This is only needed for graphs not created by torchdynamo that return non-tuples. @@ -1601,17 +1708,17 @@ def make_graph_return_tuple( compiled_fn = compile_gm(gm, inputs) @functools.wraps(compiled_fn) - def wrapper(*args, **kwargs): + def wrapper(*args: Any, **kwargs: Any) -> Any: return pytree.tree_unflatten(compiled_fn(*args, **kwargs), spec) return wrapper def handle_dynamo_export_graph( - gm: torch.fx.GraphModule, - inputs: List[torch.Tensor], + gm: GraphModule, + inputs: Sequence[InputType], compile_gm: Callable[..., Any], -): +) -> Callable[..., Any]: """ `torch._dynamo.export` embeds pytrees in the FX graph codegen object, convert that to a normal FX graph so inductor can compile it. @@ -1623,14 +1730,14 @@ def handle_dynamo_export_graph( compiled_fn = compile_gm(gm, codegen.process_inputs(*inputs)) @functools.wraps(compiled_fn) - def wrapper(*args): + def wrapper(*args: Any) -> Any: return codegen.process_outputs(compiled_fn(*codegen.process_inputs(*args))) return wrapper def _check_triton_bf16_support(graph: GraphLowering) -> None: - def warn_and_skip(device) -> None: + def warn_and_skip(device: torch.device) -> Never: from torch._dynamo.exc import SkipFrame device_interface = get_interface_for_device(device.type) diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index a46bb772ad90d..82281e1b38acb 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -1,4 +1,3 @@ -# mypy: allow-untyped-defs import functools import itertools import logging @@ -13,7 +12,8 @@ import typing from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures.process import BrokenProcessPool -from typing import Any, Callable, Dict +from typing import Any, BinaryIO, Callable, Dict, Tuple, TypeVar +from typing_extensions import Never, ParamSpec # _thread_safe_fork is needed because the subprocesses in the pool can read # justknobs, e.g., in the Triton compiler. For internal, the import installs @@ -25,12 +25,15 @@ log = logging.getLogger(__name__) +_P = ParamSpec("_P") +_T = TypeVar("_T") -def _pack_msg(job_id, length): + +def _pack_msg(job_id: int, length: int) -> bytes: return struct.pack("nn", job_id, length) -def _unpack_msg(data): +def _unpack_msg(data: bytes) -> Tuple[int, int]: if not data: return -1, -1 return struct.unpack("nn", data) @@ -39,7 +42,7 @@ def _unpack_msg(data): msg_bytes = len(_pack_msg(0, 0)) -def _send_msg(write_pipe, job_id, job_data=b""): +def _send_msg(write_pipe: BinaryIO, job_id: int, job_data: bytes = b"") -> None: length = len(job_data) write_pipe.write(_pack_msg(job_id, length)) if length > 0: @@ -47,13 +50,13 @@ def _send_msg(write_pipe, job_id, job_data=b""): write_pipe.flush() -def _recv_msg(read_pipe): +def _recv_msg(read_pipe: BinaryIO) -> Tuple[int, bytes]: job_id, length = _unpack_msg(read_pipe.read(msg_bytes)) data = read_pipe.read(length) if length > 0 else b"" return job_id, data -def _get_ld_library_path(): +def _get_ld_library_path() -> str: path = os.environ.get("LD_LIBRARY_PATH", "") if config.is_fbcode(): from libfb.py.parutil import get_runtime_path @@ -73,7 +76,7 @@ class _SubprocExceptionInfo: use it for the message in the exception thrown in the main process. """ - def __init__(self, details) -> None: + def __init__(self, details: str) -> None: self.details = details @@ -82,7 +85,7 @@ class SubprocException(Exception): Thrown when a job in a subprocess raises an Exception. """ - def __init__(self, details) -> None: + def __init__(self, details: str) -> None: super().__init__(f"An exception occurred in a subprocess:\n\n{details}") @@ -136,11 +139,13 @@ def __init__(self, nprocs: int) -> None: # before any access. self.read_thread.start() - def submit(self, job_fn: Callable[..., Any], *args): - if args: - job_fn = functools.partial(job_fn, *args) + def submit( + self, job_fn: Callable[_P, _T], *args: _P.args, **kwargs: _P.kwargs + ) -> Future[_T]: + if args or kwargs: + job_fn = functools.partial(job_fn, *args, **kwargs) job_data = pickle.dumps(job_fn, pickle.HIGHEST_PROTOCOL) - future: Future[Any] + future: Future[_T] with self.futures_lock: job_id = next(self.job_id_count) self.pending_futures[job_id] = future = Future() @@ -151,7 +156,7 @@ def submit(self, job_fn: Callable[..., Any], *args): _send_msg(self.write_pipe, job_id, job_data) return future - def _read_thread(self): + def _read_thread(self) -> None: try: while True: job_id, data = _recv_msg(self.read_pipe) @@ -178,7 +183,7 @@ def _read_thread(self): except Exception: log.exception("failure in SubprocPool._read_thread") - def shutdown(self): + def shutdown(self) -> None: try: with self.write_lock: if not self.running: @@ -200,7 +205,7 @@ def shutdown(self): class SubprocMain: """Communicates with a SubprocPool in the parent process, called by __main__.py""" - def __init__(self, nprocs, read_pipe, write_pipe) -> None: + def __init__(self, nprocs: int, read_pipe: BinaryIO, write_pipe: BinaryIO) -> None: self.read_pipe = read_pipe self.write_pipe = write_pipe self.write_lock = threading.Lock() @@ -208,7 +213,7 @@ def __init__(self, nprocs, read_pipe, write_pipe) -> None: self.pool = self._new_pool(nprocs, True) self.running = True - def _new_pool(self, nprocs, warm): + def _new_pool(self, nprocs: int, warm: bool) -> ProcessPoolExecutor: pool = ProcessPoolExecutor( nprocs, mp_context=multiprocessing.get_context("fork"), @@ -219,14 +224,14 @@ def _new_pool(self, nprocs, warm): _warm_process_pool(pool, nprocs) return pool - def main(self): + def main(self) -> None: while True: job_id, data = _recv_msg(self.read_pipe) if job_id < 0: return self._shutdown() self.submit(job_id, data) - def _shutdown(self): + def _shutdown(self) -> None: with self.write_lock: self.running = False try: @@ -237,7 +242,7 @@ def _shutdown(self): self.read_pipe.close() self.pool.shutdown() - def submit(self, job_id, data): + def submit(self, job_id: int, data: bytes) -> None: while self.running: try: self._submit_inner(job_id, data) @@ -248,10 +253,10 @@ def submit(self, job_id, data): # recreating the pool and resubmitting. self.pool = self._new_pool(self.nprocs, False) - def _submit_inner(self, job_id, data): + def _submit_inner(self, job_id: int, data: bytes) -> None: future = self.pool.submit(functools.partial(SubprocMain.do_job, data)) - def callback(_): + def callback(_: Future[Any]) -> None: if not self.running: return try: @@ -263,11 +268,12 @@ def callback(_): with self.write_lock: if self.running: _send_msg(self.write_pipe, job_id, result) + return future.add_done_callback(callback) @staticmethod - def do_job(data): + def do_job(data: bytes) -> bytes: # do the pickle/unpickle in the sub-subproc job = pickle.loads(data) try: @@ -280,7 +286,7 @@ def do_job(data): AnyPool = typing.Union[ProcessPoolExecutor, SubprocPool] -def _warm_process_pool(pool: AnyPool, n: int): +def _warm_process_pool(pool: AnyPool, n: int) -> None: if isinstance(pool, SubprocPool): return # no need assert isinstance(pool, ProcessPoolExecutor) @@ -314,5 +320,5 @@ class TestException(RuntimeError): pass -def raise_testexc(): +def raise_testexc() -> Never: raise TestException diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index cb4a1f47c01f6..f39051db75ec7 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -24,9 +24,13 @@ def autotune_remote_cache_default() -> Optional[bool]: return _get_tristate_env("TORCHINDUCTOR_AUTOTUNE_REMOTE_CACHE") +def bundled_autotune_remote_cache_default() -> Optional[bool]: + return _get_tristate_env("TORCHINDUCTOR_BUNDLED_AUTOTUNE_REMOTE_CACHE") + + # Enable auto_functionalized_v2 (enabled by default) enable_auto_functionalized_v2 = ( - os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "0") == "1" + os.environ.get("TORCHDYNAMO_AUTO_FUNCTIONALIZED_V2", "1") == "1" ) # add some debug printouts @@ -49,15 +53,36 @@ def autotune_remote_cache_default() -> Optional[bool]: # None: Not set -- Off for OSS, JustKnobs based for internal fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() -# enable autotune local cache -autotune_local_cache = True +# Enable autotune local cache. +# +# See bundled_autotune_remote_cache for the effect this flag has on the bundled +# remote cache. +autotune_local_cache: bool = True -# enable autotune remote cache +# Enable autotune remote cache. +# +# Enables/disables the autotune remote cache regardless of the state of +# autotune_local_cache. If both local and remote are enabled then on write both +# are written and on read local is checked first and only on a cache miss is +# remote read. +# # False: Disables the cache # True: Enables the cache # None: Not set -- Off for OSS, JustKnobs based for internal autotune_remote_cache: Optional[bool] = autotune_remote_cache_default() +# Enable bundled autotune cache. +# +# Enables/disables the bundled autotune cache regardless of the state of +# autotune_remote_cache. However it does depend on the local cache for local +# state management - as a result if the local cache is disabled this will also +# disable the bundled autotune cache. +# +# False: Disables the cache +# True: Enables the cache (requires autotune_local_cache) +# None: Not set -- Off for OSS, JustKnobs based for internal +bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_default() + # Force disabled all inductor level caching -- This will override any other caching flag force_disable_caches = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1" @@ -78,11 +103,6 @@ def autotune_remote_cache_default() -> Optional[bool]: # use cpp wrapper instead of python wrapper cpp_wrapper = os.environ.get("TORCHINDUCTOR_CPP_WRAPPER", "0") == "1" -# codegen cpp wrapper code in an ABI compatible mode -abi_compatible = ( - os.environ.get("TORCHINDUCTOR_ABI_COMPATIBLE", "1" if is_fbcode() else "0") == "1" -) - c_shim_version = os.environ.get("TORCHINDUCTOR_C_SHIM_VERSION", "2") # dead code elimination @@ -251,7 +271,7 @@ def autotune_remote_cache_default() -> Optional[bool]: ] # enable operator reordering for peak memory optimization -reorder_for_peak_memory = os.environ.get("TORCHINDUCTOR_REORDER_FOR_PEAK_MEMORY") == "1" +reorder_for_peak_memory = True # runtime estimation function for ops # for built-in estimation function, pass in "default"; for user-defined estimation function, pass in the function handle @@ -418,6 +438,17 @@ def use_autoheuristic(name: str) -> bool: os.environ.get("TORCHINDUCTOR_LOOP_ORDERING_AFTER_FUSION", "0") == "1" ) +# If fusing two nodes only save less then score_fusion_memory_threshold memory, +# we should not bother fusing the nodes. +# +# This is especially helpful to resolve https://github.com/pytorch/pytorch/issues/133242 +# Previously we fuse two nodes because of common read of a scalar tensor. +# If we skip it, the loop ordering after fusion mechanism kicks in and can +# brings more savings. +# +# For the cases loop ordering after fusion does not help, we don't lose much. +score_fusion_memory_threshold = 10 + # For Triton Templates, select fastest of best template + epilogue vs best template + separate epilogue kernel benchmark_epilogue_fusion = ( os.environ.get("TORCHINDUCTOR_BENCHMARK_EPILOGUE_FUSION", "1") == "1" @@ -1033,6 +1064,11 @@ class aot_inductor: # TODO: Move this somewhere else, since it's no longer really a config metadata: Dict[str, str] = {} + # fbcode only. Whether to raise error if C++ codegen is too big to optimize + raise_error_on_ignored_optimization: bool = ( + os.environ.get("AOTINDUCTOR_RAISE_ERROR_ON_IGNORED_OPTIMIZATION", "1") == "1" + ) + class cuda: # CUDA arch to use for CUDA template kernel compilation. @@ -1266,6 +1302,11 @@ class trace: # External callable for matmul tuning candidates external_matmul: List[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [] + +class test_configs: + force_extern_kernel_in_multi_template = False + + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/_inductor/cpp_builder.py b/torch/_inductor/cpp_builder.py index adfd04b054012..866be87904515 100644 --- a/torch/_inductor/cpp_builder.py +++ b/torch/_inductor/cpp_builder.py @@ -501,7 +501,12 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]: else: cflags = ["Wno-unused-variable", "Wno-unknown-pragmas"] if _is_clang(cpp_compiler): - cflags.append("Werror=ignored-optimization-argument") + ignored_optimization_argument = ( + "Werror=ignored-optimization-argument" + if config.aot_inductor.raise_error_on_ignored_optimization + else "Wno-ignored-optimization-argument" + ) + cflags.append(ignored_optimization_argument) return cflags @@ -707,9 +712,7 @@ def _setup_standard_sys_libs( cflags.append("nostdinc") # Note that the order of include paths do matter, as a result # we need to have several branches interleaved here - if torch.version.hip is None: - # TODO(T203136598): Is there any harm in including sleef_include in the hip path? - include_dirs.append(build_paths.sleef_include) + include_dirs.append(build_paths.sleef_include) include_dirs.append(build_paths.openmp_include) include_dirs.append(build_paths.python_include) include_dirs.append(build_paths.cc_include) @@ -776,14 +779,9 @@ def _get_torch_related_args( if not aot_mode: libraries.append("torch_python") - if _IS_WINDOWS: + if _IS_WINDOWS and platform.machine().lower() != "arm64": libraries.append("sleef") - # Unconditionally import c10 for non-abi-compatible mode to use TORCH_CHECK - See PyTorch #108690 - if not config.abi_compatible: - libraries.append("c10") - libraries_dirs.append(TORCH_LIB_PATH) - return include_dirs, libraries_dirs, libraries diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 344f2bc58f56a..c249c6311b753 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -6,7 +6,8 @@ import re import subprocess import sys -from typing import Any, Callable, Dict, List +import warnings +from typing import Any, Callable, Dict, List, Union import torch from torch._inductor import config @@ -52,7 +53,7 @@ class VecISA: # In fbcode however, we are using the same compiler for pytorch and for inductor codegen, # making the runtime check unnecessary. _avx_code = """ -#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) +#if defined(CPU_CAPABILITY_AVX512) || defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_ZVECTOR) || defined(CPU_CAPABILITY_NEON) || defined(CPU_CAPABILITY_VSX) || defined(CPU_CAPABILITY_SVE) #include #include #endif @@ -137,10 +138,13 @@ def check_build(self, code: str) -> bool: return True - @functools.lru_cache(None) # noqa: B019 def __bool__(self) -> bool: - if config.cpp.vec_isa_ok is not None: - return config.cpp.vec_isa_ok + return self.__bool__impl(config.cpp.vec_isa_ok) + + @functools.lru_cache(None) # noqa: B019 + def __bool__impl(self, vec_isa_ok) -> bool: + if vec_isa_ok is not None: + return vec_isa_ok if config.is_fbcode(): return True @@ -150,10 +154,10 @@ def __bool__(self) -> bool: @dataclasses.dataclass class VecNEON(VecISA): - _bit_width = 256 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec256/vec256_float_neon.h + _bit_width = 128 # This is required to leverage the compute implemented in aten/src/ATen/cpu/vec/vec128/vec128_float_neon.h _macro = ["CPU_CAPABILITY_NEON", "AT_BUILD_ARM_VEC256_WITH_SLEEF"] _arch_flags = "" # Unused - _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + _dtype_nelements = {torch.float: 4, torch.bfloat16: 8, torch.float16: 8} def __str__(self) -> str: return "asimd" # detects the presence of advanced SIMD on armv8-a kernels @@ -161,6 +165,24 @@ def __str__(self) -> str: __hash__: Callable[[VecISA], Any] = VecISA.__hash__ +@dataclasses.dataclass +class VecSVE(VecISA): + # this function can be repurposed for SVE with variable vec length + _bit_width = 256 + _macro = [ + "CPU_CAPABILITY_SVE", + "CPU_CAPABILITY_SVE256", + "AT_BUILD_ARM_VEC256_WITH_SLEEF", + ] + _arch_flags = "-march=armv8-a+sve -msve-vector-bits=256" + _dtype_nelements = {torch.float: 8, torch.bfloat16: 16, torch.float16: 16} + + def __str__(self) -> str: + return "asimd" + + __hash__: Callable[[VecISA], Any] = VecISA.__hash__ + + @dataclasses.dataclass class VecAVX512(VecISA): _bit_width = 512 @@ -306,7 +328,36 @@ def _check_and_append_supported_isa( invalid_vec_isa = InvalidVecISA() -supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()] +supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()] + + +def get_isa_from_cpu_capability( + capability: Union[str, None], + vec_isa_list: List[VecISA], + invalid_vec_isa: InvalidVecISA, +): + # AMX setting is not supported in eager + # VecAMX will be prioritized for selection when setting ATEN_CPU_CAPABILITY to avx512 + # TODO add sve256 support + capability_to_isa_str = { + "default": "INVALID_VEC_ISA", + "zvector": "zvector", + "vsx": "vsx", + "avx2": "avx2", + "avx512": "avx512", + } + if capability in capability_to_isa_str.keys(): + isa_str = capability_to_isa_str[capability] + if isa_str == "INVALID_VEC_ISA": + return invalid_vec_isa + for vec_isa in vec_isa_list: + if isa_str in str(vec_isa): + return vec_isa + + if capability: + warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}") + + return vec_isa_list[0] # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content @@ -338,7 +389,10 @@ def valid_vec_isa_list() -> List[VecISA]: elif arch == "ppc64le": isa_list.append(VecVSX()) elif arch == "aarch64": - isa_list.append(VecNEON()) + if torch.cpu._is_arm_sve_supported(): + isa_list.append(VecSVE()) + else: + isa_list.append(VecNEON()) elif arch in ["x86_64", "AMD64"]: """ arch value is x86_64 on Linux, and the value is AMD64 on Windows. @@ -359,10 +413,12 @@ def pick_vec_isa() -> VecISA: if not _valid_vec_isa_list: return invalid_vec_isa - # If the simdlen is None, it indicates determine the vectorization length automatically + # If the simdlen is None, set simdlen based on the environment ATEN_CPU_CAPABILITY + # to control CPU vec ISA if config.cpp.simdlen is None: - assert _valid_vec_isa_list - return _valid_vec_isa_list[0] + return get_isa_from_cpu_capability( + os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa + ) for isa in _valid_vec_isa_list: if config.cpp.simdlen == isa.bit_width(): diff --git a/torch/_inductor/cudagraph_trees.py b/torch/_inductor/cudagraph_trees.py index 5a33de0e36689..1ccfc6e65055f 100644 --- a/torch/_inductor/cudagraph_trees.py +++ b/torch/_inductor/cudagraph_trees.py @@ -853,7 +853,7 @@ def __init__( def maybe_get_static_data_ptr( idx: int, - inputs: List[Union[torch.Tensor, int]], + inputs: List[InputType], static_input_idxs: List[int], ) -> Optional[int]: inp = inputs[idx] @@ -1576,7 +1576,7 @@ def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage: def _allocate_and_copy_recording_inputs( self, inputs: List[InputType] - ) -> List[Union[torch.Tensor, int]]: + ) -> List[InputType]: """ Allocate inputs for non static, non cudagraph managed tensors in the memory pool and copy over the tensor values. @@ -1913,22 +1913,32 @@ def __init__(self, device_index: int) -> None: # mod2(mod1(x)).sum().backward() self.running_forwards_with_pending_backwards = False + self.mode: Optional[CompilationMode] = None + + self.disable_invalidate_aliases = ( + False + if not torch._environment.is_fbcode() + else torch._utils_internal.justknobs_check( + "pytorch/inductor:disable_cudagraph_alias_invalidation" + ) + ) def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType: assert self.graph is not None, "Running CUDAGraph after shutdown" + self.mode = self.id_to_mode[function_id] out = self._run(new_inputs, function_id) # The forwards are only pending following invocation, not before - mode = self.id_to_mode[function_id] - if mode == CompilationMode.FORWARD: + if self.mode == CompilationMode.FORWARD: self.running_forwards_with_pending_backwards = True - elif mode == CompilationMode.BACKWARD: + elif self.mode == CompilationMode.BACKWARD: self.running_forwards_with_pending_backwards = False return out def set_to_running_backward(self) -> None: self.running_forwards_with_pending_backwards = False + self.mode = CompilationMode.BACKWARD def _get_cuda_graph_recorded_tensor_checker(self) -> Callable[[Tensor], bool]: return ( @@ -2348,10 +2358,24 @@ def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> No "before each model invocation" ) + @staticmethod + def format_dealloc_msg(stack_trace: Optional[str]) -> str: + stack_trace = ( + stack_trace.strip() if stack_trace else "[Could not find stack trace]" + ) + return ( + "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " + f"Stack trace: {stack_trace}. " + "To prevent overwriting, clone the tensor outside of torch.compile() " + "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." + ) + def dealloc_current_path_weakrefs(self) -> None: assert self.current_node is not None # TODO: we could also allow the these weak refs to continue to be allocated, # but that adds some complications. + + stor_stack_trace: Dict[int, Optional[str]] = {} for node in self.current_node._path_from_root: assert node.stack_traces is not None assert len(node.tensor_weakrefs) == len(node.stack_traces) @@ -2360,26 +2384,41 @@ def dealloc_current_path_weakrefs(self) -> None: if ten is None: continue - stack_trace = ( - stack_trace.strip() - if stack_trace - else "[Could not find stack trace]" + torch._C._set_storage_access_error_msg( + ten, self.format_dealloc_msg(stack_trace) ) - msg = ( - "Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run. " - f"Stack trace: {stack_trace}. " - "To prevent overwriting, clone the tensor outside of torch.compile() " - "or call torch.compiler.cudagraph_mark_step_begin() before each model invocation." - ) - torch._C._set_storage_access_error_msg(ten, msg) + + # we would to enable the following assertion, but an internal model failed with a command + # that does not repro. len(node.outputs_weakrefs) == len(node.stack_traces) + # so, pessimistically assume that they might differ by doing the debug info + # loop separately from the dealloc loop + if self.disable_invalidate_aliases: + continue + + for storage_ref, stack_trace in zip( + node.outputs_weakrefs, node.stack_traces + ): + if not storage_ref: + continue + + stor_stack_trace[storage_ref.data_ptr()] = stack_trace deleted = set() for storage_ref in self.current_node.path_live_weakrefs(): _storage_deref = storage_ref() if _storage_deref and storage_ref.data_ptr() not in deleted: deleted.add(storage_ref.data_ptr()) + + msg = self.format_dealloc_msg( + stor_stack_trace.get(storage_ref.data_ptr()) + ) torch._C._free_And_Remove_DeleterFn(_storage_deref) + if self.disable_invalidate_aliases: + continue + + torch._C._set_storage_data_ptr_access_error_msg(_storage_deref, msg) + def clear_current_path_state_and_set_to_none(self) -> None: assert isinstance(self.current_node, CUDAGraphNode) self.current_node.clear_path_state() diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index 0375ce5f75dc6..50024d25160e0 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -75,6 +75,7 @@ aten.native_group_norm, aten.native_layer_norm, aten.nll_loss2d_backward, + aten.permute_copy, aten._softmax, aten.sin_, aten.sqrt_, @@ -82,6 +83,7 @@ aten._to_copy, aten.tril_indices, aten.triu_indices, + aten.unbind_copy.int, aten.upsample_bilinear2d.vec, quantized.linear_dynamic_fp16_unpacked_weight, _quantized.wrapped_quantized_linear, diff --git a/torch/_inductor/dependencies.py b/torch/_inductor/dependencies.py index 912acb55b6aeb..9459138162375 100644 --- a/torch/_inductor/dependencies.py +++ b/torch/_inductor/dependencies.py @@ -230,6 +230,8 @@ def has_unbacked_symbols(self): return len(free_unbacked_symbols(self.get_numel())) > 0 def is_contiguous(self) -> bool: + if isinstance(self.index, sympy.Integer): + return True return isinstance(self.index, sympy.Symbol) and self.index in self.var_names def stride1_for_last_dim(self, result_for_complex_expression=True) -> bool: diff --git a/torch/_inductor/exc.py b/torch/_inductor/exc.py index 50c977e53e6f5..7b9f206955ee6 100644 --- a/torch/_inductor/exc.py +++ b/torch/_inductor/exc.py @@ -1,28 +1,28 @@ -# mypy: allow-untyped-defs from __future__ import annotations import os import tempfile import textwrap from functools import lru_cache +from typing import Any, List if os.environ.get("TORCHINDUCTOR_WRITE_MISSING_OPS") == "1": @lru_cache(None) - def _record_missing_op(target): + def _record_missing_op(target: Any) -> None: with open(f"{tempfile.gettempdir()}/missing_ops.txt", "a") as fd: fd.write(str(target) + "\n") else: - def _record_missing_op(target): # type: ignore[misc] + def _record_missing_op(target: Any) -> None: # type: ignore[misc] pass class OperatorIssue(RuntimeError): @staticmethod - def operator_str(target, args, kwargs): + def operator_str(target: Any, args: List[Any], kwargs: dict[str, Any]) -> str: lines = [f"target: {target}"] + [ f"args[{i}]: {arg}" for i, arg in enumerate(args) ] @@ -32,13 +32,13 @@ def operator_str(target, args, kwargs): class MissingOperatorWithoutDecomp(OperatorIssue): - def __init__(self, target, args, kwargs) -> None: + def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None: _record_missing_op(target) super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}") class MissingOperatorWithDecomp(OperatorIssue): - def __init__(self, target, args, kwargs) -> None: + def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None: _record_missing_op(target) super().__init__( f"missing decomposition\n{self.operator_str(target, args, kwargs)}" @@ -54,7 +54,9 @@ def __init__(self, target, args, kwargs) -> None: class LoweringException(OperatorIssue): - def __init__(self, exc: Exception, target, args, kwargs) -> None: + def __init__( + self, exc: Exception, target: Any, args: List[Any], kwargs: dict[str, Any] + ) -> None: super().__init__( f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}" ) diff --git a/torch/_inductor/fx_passes/b2b_gemm.py b/torch/_inductor/fx_passes/b2b_gemm.py index 5a854b5b9d994..64cb597188eed 100644 --- a/torch/_inductor/fx_passes/b2b_gemm.py +++ b/torch/_inductor/fx_passes/b2b_gemm.py @@ -369,12 +369,20 @@ def is_b2b_gemm_good_on( # basic checks if not all(["val" in A_node.meta, "val" in B_node.meta, "val" in C_node.meta]): return False - A, B, C = ( + fake_tensors = ( A_node.meta["val"], B_node.meta["val"], C_node.meta["val"], ) # torch._subclasses.fake_tensor.FakeTensor - if not all([A.is_cuda, B.is_cuda, C.is_cuda]): + + A, B, C = fake_tensors + + def check_all_attr_true(objects, attr): + return all(hasattr(obj, attr) and getattr(obj, attr) for obj in objects) + + if not check_all_attr_true(fake_tensors, "is_cuda") and not check_all_attr_true( + fake_tensors, "is_xpu" + ): return False if not all([len(A.shape) == 2, len(B.shape) == 2, len(C.shape) == 2]): return False @@ -506,7 +514,7 @@ def create_placeholder( """ Creates a placeholder input buffers for producing subgraph_output """ - input_buffer = InputBuffer(name, FixedLayout(device, dtype, [], [])) + input_buffer = InputBuffer(name=name, layout=FixedLayout(device, dtype, [], [])) return TensorBox.create(input_buffer) diff --git a/torch/_inductor/fx_passes/fuse_attention.py b/torch/_inductor/fx_passes/fuse_attention.py index 59e353ceb8f97..5c3811db27a07 100644 --- a/torch/_inductor/fx_passes/fuse_attention.py +++ b/torch/_inductor/fx_passes/fuse_attention.py @@ -42,9 +42,9 @@ def _sfdp_pattern_1(query, key, value, inv_scale): def _sfdp_replacement_1(query, key, value, inv_scale): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), + query, + key, + value, attn_mask=None, dropout_p=0.0, is_causal=False, @@ -64,9 +64,9 @@ def _sfdp_pattern_2(query, key, value, scale_factor): def _sfdp_replacement_2(query, key, value, scale_factor): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), + query, + key, + value, attn_mask=None, dropout_p=0.0, is_causal=False, @@ -86,9 +86,9 @@ def _sfdp_pattern_3(query, key, value, inv_scale_factor, dropout_p): def _sfdp_replacement_3(query, key, value, inv_scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), + query, + key, + value, attn_mask=None, dropout_p=dropout_p, is_causal=False, @@ -106,9 +106,9 @@ def _sfdp_pattern_4(query, key, value, scale_factor, dropout_p): def _sfdp_replacement_4(query, key, value, scale_factor, dropout_p): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), + query, + key, + value, attn_mask=None, dropout_p=dropout_p, is_causal=False, @@ -127,9 +127,9 @@ def _sfdp_pattern_5(query, key, value, attn_mask): def _sfdp_replacement_5(query, key, value, attn_mask): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), + query, + key, + value, attn_mask=attn_mask.to(dtype=query.dtype), dropout_p=0.0, is_causal=False, @@ -147,9 +147,9 @@ def _sfdp_pattern_6(query, key, value, attn_mask, dropout_p): def _sfdp_replacement_6(query, key, value, attn_mask, dropout_p): counters["inductor"]["fuse_attention"] += 1 return _scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), + query, + key, + value, attn_mask=attn_mask.to(dtype=query.dtype), dropout_p=dropout_p, is_causal=False, diff --git a/torch/_inductor/fx_passes/group_batch_fusion.py b/torch/_inductor/fx_passes/group_batch_fusion.py index 35aaac5bb2672..d68f3c8c0e156 100644 --- a/torch/_inductor/fx_passes/group_batch_fusion.py +++ b/torch/_inductor/fx_passes/group_batch_fusion.py @@ -1042,6 +1042,71 @@ def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): ] += 1 +class BatchMathOpsPreGradFusion(BatchPointwiseOpsFusionFactory): + """ + Batch simple match related ops such as nan_to_num in pre grad pass. + """ + + def __init__(self, op, **kwargs): + super().__init__(op, **kwargs) + self.op = op + + def match(self, node: torch.fx.Node): + input = get_arg_value(node, 0, "input") + if CallFunctionVarArgs(self.op).match(node) and is_node_meta_valid(node): + # check the input has the same shape and its uers have the same target + # check all clamp operators have the same min and max values, and + # nan_to_num operators use the same default value. + child = next(iter(node.users.keys())) + group_key = ( + str(input.meta["example_value"].shape) + + str(node.kwargs) + + str(child.target) + ) + else: + group_key = None + return group_key + + def fuse(self, graph: torch.fx.GraphModule, subset: List[torch.fx.Node]): + batch_nodes = [] + batch_inputs = [] + batch_inputs_metadata = [] + kwargs = subset[0].kwargs + + for node in subset: + batch_nodes.append(node) + input = get_arg_value(node, 0, "input") + batch_inputs.append(input) + batch_inputs_metadata.append(input.meta["example_value"]) + + with graph.inserting_before(subset[0]): + stack_inputs = graph.call_function( + torch.stack, args=(batch_inputs,), kwargs={"dim": 0} + ) + update_stack_example_value(stack_inputs, batch_inputs_metadata) + batch_op = graph.call_function( + self.op, + args=(stack_inputs,), + kwargs=kwargs, + ) + batch_op.meta["example_value"] = self.op( + stack_inputs.meta["example_value"], **kwargs + ) + unbind_op = graph.call_function( + torch.unbind, args=(batch_op,), kwargs={"dim": 0} + ) + unbind_op.meta["example_value"] = torch.unbind( + batch_op.meta["example_value"], dim=0 + ) + for i, node in enumerate(batch_nodes): + with graph.inserting_after(unbind_op): + getitem = graph.call_function(operator.getitem, args=(unbind_op, i)) + node.replace_all_uses_with(getitem) + getitem.meta.update(node.meta) + graph.erase_node(node) + counters["inductor"]["batch_" + self.op.__name__.lower().split(".")[0]] += 1 + + @register_fusion("batch_tanh") class BatchTanhPreGradFusion(BatchPointwiseOpsPreGradFusion): def __init__(self, **kwargs) -> None: @@ -1060,6 +1125,24 @@ def __init__(self, **kwargs) -> None: super().__init__(torch.nn.functional.relu, **kwargs) +@register_fusion("batch_detach") +class BatchDetachPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.detach, **kwargs) + + +@register_fusion("batch_nan_to_num") +class BatchNanToNumPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.nan_to_num, **kwargs) + + +@register_fusion("batch_clamp") +class BatchClampPreGradFusion(BatchMathOpsPreGradFusion): + def __init__(self, **kwargs): + super().__init__(torch.clamp, **kwargs) + + @register_fusion("batch_aten_tanh", pre_grad=False) class BatchTanhPostGradFusion(BatchPointwiseOpsPostGradFusion): def __init__(self, **kwargs) -> None: diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index a00bad3974791..65246d18d2e73 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -706,35 +706,25 @@ def should_pad_mm(match: Match) -> bool: def pad_mat1(mat1, *, m_padded_length, k_padded_length, is_bmm=False): - if m_padded_length == 0 and k_padded_length == 0: - return mat1 - elif k_padded_length != 0 and m_padded_length != 0: + if k_padded_length != 0 or m_padded_length != 0: # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding pad_arg = [0, k_padded_length, 0, m_padded_length] if is_bmm: pad_arg.extend((0, 0)) return aten.constant_pad_nd(mat1, pad_arg) - elif m_padded_length != 0: - return pad_dim(mat1, m_padded_length, 0 if not is_bmm else 1) else: - assert k_padded_length != 0 - return pad_dim(mat1, k_padded_length, 1 if not is_bmm else 2) + return mat1 def pad_mat2(mat2, *, k_padded_length, n_padded_length, is_bmm=False): - if k_padded_length == 0 and n_padded_length == 0: - return mat2 - elif k_padded_length != 0 and n_padded_length != 0: + if k_padded_length != 0 or n_padded_length != 0: # dim order is reversed for constant_pad_nd, for every dim we specify right and left padding pad_arg = [0, n_padded_length, 0, k_padded_length] if is_bmm: pad_arg.extend((0, 0)) return aten.constant_pad_nd(mat2, pad_arg) - elif k_padded_length != 0: - return pad_dim(mat2, k_padded_length, 0 if not is_bmm else 1) else: - assert n_padded_length != 0 - return pad_dim(mat2, n_padded_length, 1 if not is_bmm else 2) + return mat2 def pad_mm( diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 08573a3ffc96c..2490a0bb6f4e5 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -1,11 +1,10 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -import functools import itertools import logging import operator from collections import Counter, defaultdict -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Set import torch import torch._inductor as inductor @@ -54,10 +53,6 @@ from .split_cat import POST_GRAD_PATTERNS -if TYPE_CHECKING: - from sympy import Expr - - log = logging.getLogger(__name__) aten = torch.ops.aten prims = torch.ops.prims @@ -70,6 +65,19 @@ ] +def apply_pass(pass_fn: Callable[[], object], name: Optional[str] = None) -> None: + # TODO - we should just make this part of GraphTransformObserver + from torch._inductor.bisect_helper import BisectionManager + + debug_info: Optional[Callable[[], str]] = None + if name is not None: + debug_info = lambda: name # noqa: E731 + + if BisectionManager.disable_subsystem("inductor", "post_grad_passes", debug_info): + return + pass_fn() + + def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): """ Passes that run on after grad. This is called once on the forwards @@ -85,23 +93,28 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): gm.graph.eliminate_dead_code() if is_inference and config.reorder_for_locality: - reorder_for_locality(gm.graph) + apply_pass(lambda: reorder_for_locality(gm.graph), "reorder_for_locality") fake_tensor_updater = FakeTensorUpdater(gm.graph) - if config.post_grad_custom_pre_pass is not None: + if post_grad_custom_pre_pass := config.post_grad_custom_pre_pass: with GraphTransformObserver( gm, "post_grad_custom_pre_pass", config.trace.log_url_for_graph_xform ): - config.post_grad_custom_pre_pass(gm.graph) + apply_pass( + lambda: post_grad_custom_pre_pass(gm.graph), "post_grad_custom_pre_pass" + ) if config.pattern_matcher: lazy_init() optimus_scuba_log["before_recompile_post_grad"] = upload_graph(gm.graph) - group_batch_fusion_passes(gm.graph, pre_grad=False) - remove_noop_ops(gm.graph) - for patterns in pass_patterns: - patterns.apply(gm.graph) # type: ignore[arg-type] + apply_pass( + lambda: group_batch_fusion_passes(gm.graph, pre_grad=False), + "group_batch_fusion_passes", + ) + apply_pass(lambda: remove_noop_ops(gm.graph), "remove_noop_ops") + for i, patterns in enumerate(pass_patterns): + apply_pass(lambda: patterns.apply(gm.graph), f"pass_pattern_{i}") # type: ignore[arg-type] for pass_name in config.post_grad_fusion_options: # skip all patterns for group batch fusions if pass_name in POST_GRAD_FUSIONS: @@ -110,7 +123,7 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): inductor_before_change = save_inductor_dict( [pattern_matcher_pass.pass_name] ) - pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] + apply_pass(lambda: pattern_matcher_pass.apply(gm.graph), pass_name) # type: ignore[arg-type] if not is_same_dict(counters["inductor"], inductor_before_change): optimus_scuba_log[ f"{pattern_matcher_pass.pass_name}_post_grad" @@ -122,30 +135,40 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool): micro_pipeline_tp_pass(gm.graph) if config._fuse_ddp_communication: - fuse_ddp_communication( - gm.graph, - config._fuse_ddp_communication_passes, - config._fuse_ddp_bucket_size, + apply_pass( + lambda: fuse_ddp_communication( + gm.graph, + config._fuse_ddp_communication_passes, + config._fuse_ddp_bucket_size, + ), + "fuse_ddp_communication", ) - if config.post_grad_custom_post_pass is not None: + if post_grad_custom_post_pass := config.post_grad_custom_post_pass: with GraphTransformObserver( gm, "post_grad_custom_post_pass", config.trace.log_url_for_graph_xform ): - config.post_grad_custom_post_pass(gm.graph) + apply_pass( + lambda: post_grad_custom_post_pass(gm.graph), + "post_grad_custom_post_pass", + ) - stable_topological_sort(gm.graph) + apply_pass(lambda: stable_topological_sort(gm.graph), "stable_sort") - move_constructors_to_gpu(gm.graph) + apply_pass(lambda: move_constructors_to_gpu(gm.graph), "move_constructors_to_cuda") fake_tensor_updater.incremental_update() # Keep these last, since they introduces mutation. Look at # ./fx_passes/README.md for a discussion of mutation invariants. - reinplace_inplaceable_ops(gm.graph) - decompose_auto_functionalized(gm.graph) + apply_pass(lambda: reinplace_inplaceable_ops(gm.graph), "reinplace_inplaceable_ops") + apply_pass( + lambda: decompose_auto_functionalized(gm.graph), "decompose_auto_functionalized" + ) - comms.reinplace_fsdp_all_gather(gm.graph) + apply_pass( + lambda: comms.reinplace_fsdp_all_gather(gm.graph), "reinplace_fsdp_all_gather" + ) gm.recompile() optimus_scuba_log["after_recompile_post_grad"] = upload_graph(gm.graph) @@ -465,90 +488,6 @@ def repl(*shape): match.replace_by_example(repl, list(shape)) -def shape_of_mm(a, b): - m, _ = a.get_size() - _, n = b.get_size() - return [m, n] - - -@register_lowering_pattern( - CallFunction(aten.cat, ListOf(CallFunction(aten.mm, Arg(), Arg())), Arg()), -) -def cat_mm(match, inputs, dim): - return cat_tuned_op(match, inputs, dim, op=L[aten.mm], shape_of=shape_of_mm) - - -@register_lowering_pattern( - CallFunction( - aten.cat, ListOf(CallFunction(aten.addmm, Arg(), Arg(), Arg())), Arg() - ), -) -def cat_addmm(match, inputs, dim): - def shape_of(bias, a, b): - m, _ = a.get_size() - _, n = b.get_size() - return [m, n] - - return cat_tuned_op(match, inputs, dim, op=L[aten.addmm], shape_of=shape_of) - - -def cat_tuned_op(match, inputs, dim, *, op, shape_of): - """ - Memory planning to remove cat. We can't use the stock memory - planner since autotuning matmuls needs to know the output layout. - """ - if len(inputs) == 1: - return op(*inputs[0]) - - # TODO(jansel): rewrite this as a bmm? - if dim < 0: - dim += len(shape_of(*inputs[0])) - assert dim in (0, 1) - notdim = 1 - dim - - new_size: Optional[Union[List[Expr], List[int]]] = None - offsets_start = [] - offsets_end = [] - - # compute output sizes - for i in range(len(inputs)): - shape = shape_of(*inputs[i]) - if new_size is None: - new_size = shape - else: - new_size[notdim] = V.graph.sizevars.guard_equals( # type: ignore[call-overload] - shape[notdim], new_size[notdim] - ) - new_size[dim] += shape[dim] - offsets_start.append(new_size[dim] - shape[dim]) - offsets_end.append(new_size[dim]) - - assert new_size is not None - dtype = functools.reduce( - torch.promote_types, - [x.get_dtype() for x in itertools.chain.from_iterable(inputs)], - ) - device = inputs[0][0].get_device() - kernel = ir.ConcatKernel( - name=None, - layout=ir.FixedLayout(device, dtype, new_size), - inputs=[], - ) - kernel_tensor = ir.TensorBox.create(kernel) - - for i in range(len(inputs)): - dst = ir.SliceView.create(kernel_tensor, dim, offsets_start[i], offsets_end[i]) - src = op(*inputs[i], layout=dst.get_layout()).data.data - assert isinstance(src, (ir.ExternKernelOut, ir.TemplateBuffer)) - src.layout = ir.NonOwningLayout(dst) - kernel.inputs.append(src) - - kernel.name = V.graph.register_buffer(kernel) - kernel.inputs = ir.ConcatKernel.unwrap_storage(kernel.inputs) - V.graph.register_operation(kernel) - return kernel_tensor - - _cat_1 = CallFunction(aten.cat, Arg(), 1, _users=2) diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index bca3361962b07..16a6a74aea146 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -2,7 +2,7 @@ import copy import itertools import logging -from typing import Dict, Optional +from typing import Dict, Optional, Sequence import torch import torch.nn as nn @@ -112,7 +112,9 @@ def lazy_init(): from . import fb # type: ignore[attr-defined] # noqa: F401 -def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs=None): +def pre_grad_passes( + gm: torch.fx.GraphModule, example_inputs: Sequence[object] = () +) -> torch.fx.GraphModule: """ Apply passes on the input FX graph using Torch IR. @@ -138,7 +140,7 @@ def shape_prop(mod) -> None: gm=mod, # pyre-fixme[16]: Module `torch._dynamo.utils` has no attribute `detect_fake_mode` fake_mode=detect_fake_mode(example_inputs), - ).propagate(*example_inputs) + ).propagate(*tuple(example_inputs)) # normalization pass pass_execution_and_save( @@ -243,10 +245,14 @@ def shape_prop(mod) -> None: gm = fuse_fx(gm, example_inputs) numpy_compat_normalization(gm.graph) optimus_scuba_log["before_recompile_pre_grad"] = upload_graph(gm.graph) + # We should always do the normalization_pass first + if "normalization_pass" in config.pre_grad_fusion_options: + pattern_matcher_pass = PRE_GRAD_PATTERNS["normalization_pass"] + pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] group_batch_fusion_passes(gm.graph, pre_grad=True) for pass_name in config.pre_grad_fusion_options: # skip all patterns for group batch fusions - if pass_name in PRE_GRAD_FUSIONS: + if pass_name in PRE_GRAD_FUSIONS or pass_name == "normalization_pass": continue pattern_matcher_pass = PRE_GRAD_PATTERNS[pass_name] inductor_before_change = save_inductor_dict( diff --git a/torch/_inductor/fx_passes/quantization.py b/torch/_inductor/fx_passes/quantization.py index 3c918d480704e..d257536939fc6 100644 --- a/torch/_inductor/fx_passes/quantization.py +++ b/torch/_inductor/fx_passes/quantization.py @@ -912,31 +912,38 @@ def __init__( for int8_mixed_bf16_with_inplace_add in [False, True]: # Priority 1 to match: QConv2d Binary or Binary-Unary pattern with int8 output - binary_replace_patterns = { - BinaryUnaryAttr( - "sum", 1.0, "none", [], "" - ): generate_pattern_with_output_quant( - generate_pattern_with_binary( - aten.add.Tensor, - get_dequantize_qconv_pt2e_pattern(1), - dequantize_accum_pattern, - int8_mixed_bf16_with_inplace_add, - ), - ), - BinaryUnaryAttr( - "sum", 1.0, "relu", [], "" - ): generate_pattern_with_output_quant( - generate_pattern_with_unary( - generate_pattern_with_binary( - aten.add.Tensor, - get_dequantize_qconv_pt2e_pattern(1), - dequantize_accum_pattern, - int8_mixed_bf16_with_inplace_add, + swap_binary_inputs_list = [False, True] + binary_replace_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), ), - aten.relu.default, - ), - ), - } + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_output_quant( + generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + dequantize_accum_pattern, + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ), + ), + } + ) for binary_unary_attr, patterns in binary_replace_patterns.items(): _register_quantized_conv_binary_lowering( @@ -947,17 +954,24 @@ def __init__( ) # Priority 2 to match: QConv2d Binary-Unary pattern with fp32/bfloat16 output - binary_replace_float_out_patterns = { - BinaryUnaryAttr("sum", 1.0, "relu", [], ""): generate_pattern_with_unary( - generate_pattern_with_binary( - aten.add.Tensor, - get_dequantize_qconv_pt2e_pattern(1), - KeywordArg("accum_after_dequant"), - int8_mixed_bf16_with_inplace_add, - ), - aten.relu.default, - ), - } + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "relu", [], "" + ): generate_pattern_with_unary( + generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + aten.relu.default, + ) + } + ) for ( binary_unary_attr, @@ -979,14 +993,21 @@ def __init__( ) # Priority 3: QConv2d Binary pattern with fp32/bfloat16 output - binary_replace_float_out_patterns = { - BinaryUnaryAttr("sum", 1.0, "none", [], ""): generate_pattern_with_binary( - aten.add.Tensor, - get_dequantize_qconv_pt2e_pattern(1), - KeywordArg("accum_after_dequant"), - int8_mixed_bf16_with_inplace_add, - ), - } + binary_replace_float_out_patterns = {} + for swap_inputs in swap_binary_inputs_list: + binary_replace_float_out_patterns.update( + { + BinaryUnaryAttr( + "sum", 1.0, "none", [], "" + ): generate_pattern_with_binary( + aten.add.Tensor, + get_dequantize_qconv_pt2e_pattern(1), + KeywordArg("accum_after_dequant"), + int8_mixed_bf16_with_inplace_add, + swap_inputs=swap_inputs, + ), + } + ) for ( binary_unary_attr, diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index a1bee18a615c6..8a7f06ed2a4b7 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -7,6 +7,7 @@ from typing import Any, Callable, Dict, List, Tuple import torch +from torch._dispatch.python import enable_python_dispatcher from torch._higher_order_ops.triton_kernel_wrap import ( kernel_side_table, triton_kernel_wrapper_functional, @@ -708,6 +709,7 @@ def tensor_with_same_storage_already_reinplaced(arg): def reinplace_inplaceable_ops(graph: torch.fx.Graph) -> None: - canonicalize_view_scatter_ops(graph) - reinplace_inplaceable_ops_core(graph) - decompose_generalized_scatter(graph) + with enable_python_dispatcher(): + canonicalize_view_scatter_ops(graph) + reinplace_inplaceable_ops_core(graph) + decompose_generalized_scatter(graph) diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 194d1d6dbaa79..46f990f7d9af0 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -7,6 +7,7 @@ import torch from torch._dynamo.utils import counters +from torch.fx.experimental.symbolic_shapes import free_symbols from ..pattern_matcher import ( Arg, @@ -450,8 +451,6 @@ def normalize_reshape_default(match: Match, *args, **kwargs): return reshape_input = get_arg_value(reshape_node, 0) - from torch.fx.experimental.symbolic_shapes import free_symbols - if free_symbols(reshape_node.meta["example_value"].shape): log.debug("dynamic shape not supported: %s", reshape_node) return @@ -466,6 +465,67 @@ def normalize_reshape_default(match: Match, *args, **kwargs): match.graph.erase_node(reshape_node) +@register_graph_pattern( + CallMethodVarArgs("clamp", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +@register_graph_pattern( + CallFunctionVarArgs(torch.clamp, users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_clamp_default(match: Match, *args, **kwargs): + clamp_node = match.nodes[0] + if not is_node_meta_valid(clamp_node): + log.debug("example value absent for node: %s", clamp_node) + return + + if free_symbols(clamp_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", clamp_node) + return + if len(clamp_node.args) > 1: + args = (get_arg_value(clamp_node, 0),) + kwargs = { + "min": get_arg_value(clamp_node, 1, kwarg_name="min"), + "max": get_arg_value(clamp_node, 2, kwarg_name="max"), + } + else: + args = clamp_node.args + kwargs = clamp_node.kwargs + with match.graph.inserting_after(clamp_node): + new_clamp_node = match.graph.call_function( + torch.clamp, + args=args, + kwargs=kwargs, + ) + clamp_node.replace_all_uses_with(new_clamp_node) + new_clamp_node.meta.update(clamp_node.meta) + match.graph.erase_node(clamp_node) + + +@register_graph_pattern( + CallMethodVarArgs("detach", users=MULTIPLE), + pass_dict=construct_pattern_matcher_pass("normalization_pass"), +) +def normalize_detach_default(match: Match, *args, **kwargs): + detach_node = match.nodes[0] + if not is_node_meta_valid(detach_node): + log.debug("example value absent for node: %s", detach_node) + return + + if free_symbols(detach_node.meta["example_value"].shape): + log.debug("dynamic shape not supported: %s", detach_node) + return + + with match.graph.inserting_after(detach_node): + new_detach_node = match.graph.call_function( + torch.detach, + args=detach_node.args, + ) + detach_node.replace_all_uses_with(new_detach_node) + new_detach_node.meta.update(detach_node.meta) + match.graph.erase_node(detach_node) + + class TorchSplit(CallFunction): """ Matches a call to torch.split if it is in a normalized form. Ensures that all users of diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 7859e8126cc9d..5f2848a98d1ce 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1,3 +1,4 @@ +import contextlib import functools import itertools import logging @@ -15,6 +16,7 @@ DefaultDict, Dict, Iterable, + Iterator, List, NoReturn, Optional, @@ -63,6 +65,7 @@ get_wrapper_codegen_for_device, init_backend_registration, ) +from .codegen.wrapper import PythonWrapperCodegen from .exc import ( CppWrapperCodegenError, LoweringException, @@ -90,6 +93,8 @@ needs_realized_inputs, unsupported_output_tensor, ) +from .runtime import autotune_cache +from .runtime.autotune_cache import AutotuneCacheBundler from .scheduler import BaseSchedulerNode from .sizevars import SizeVarAllocator from .utils import ( @@ -97,6 +102,7 @@ gather_origins, get_cloned_parameter_buffer_name, get_sympy_Expr_dtype, + is_same_tensor, maybe_get_suppress_shape_guards_ctx, normalize_name, should_assume_input_aligned, @@ -106,7 +112,6 @@ if TYPE_CHECKING: from torch._higher_order_ops.effects import _EffectType - from .codegen.wrapper import PythonWrapperCodegen from torch._inductor.codecache import output_code_log @@ -310,7 +315,7 @@ def static_sizes_strides( def __init__( self, gm: torch.fx.GraphModule, - example_inputs: Optional[List[torch.Tensor]] = None, + example_inputs: Optional[Sequence[object]] = None, shape_env: Optional[ShapeEnv] = None, graph_id: Optional[int] = None, cpp_wrapper: bool = False, @@ -426,6 +431,11 @@ def __init__( self.graph_id = graph_id self.post_grad_graph_id = next(_post_grad_graph_counter) self.scheduler: torch._inductor.scheduler.Scheduler = None # type: ignore[assignment] + + # current_device is set only during codegen of a device-specific kernel + # a graph can have many devices + self.current_device: Optional[torch.device] = None + self.nodes_prefer_channels_last = ( self.find_nodes_prefer_channels_last() if self.layout_opt else OrderedSet() ) @@ -463,12 +473,31 @@ def __init__( # Below field is related to printing debug intermediate tensor values info for debugging self.all_codegen_kernel_names: OrderedSet[str] = OrderedSet() + # state used by wrapper.generate_workspace_allocation() + self.allocated_workspaces: Dict[str, Any] = {} + self.workspace_id = itertools.count() + def has_feature( self, device: Union[torch._inductor.ir.IRNode, device], feature: BackendFeature ) -> bool: assert isinstance(feature, BackendFeature), feature return feature in self.get_backend_features(get_device_type(device)) + def get_current_device_or_throw(self) -> torch.device: + if device := self.current_device: + return device + else: + raise RuntimeError("No current device") + + @contextlib.contextmanager + def set_current_device(self, device: torch.device) -> Iterator[None]: + prior = self.current_device + self.current_device = device + try: + yield + finally: + self.current_device = prior + @staticmethod def decide_layout_opt(gm: GraphModule, *, is_inference: bool) -> bool: """ @@ -644,16 +673,17 @@ def make_subgraph( gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor], subgraph_name: str, - ) -> "GraphLowering": + ) -> "SubgraphLowering": """ - Make a subgraph of the current graph with all inherited - parts, except the graph module (`gm`) and `example_inputs`. - The subgraphs are lowered separately, but intended to be - inlined in the parent graph's codegening. Hence the need - for maintaining the same `shape_env` and other properties. - The subgraph name is qualified by the parent graph's name. + Make a subgraph of the current graph with all inherited parts, except + the graph module (`gm`) and `example_inputs`. The subgraphs are lowered + separately and lifted into a separate function in the parent output + wrapper code. The subgraph name is qualified by the parent graph's + name. Note that the lifting of subgraph is supported for python wrapper + only. For cpp wrapper, we inline the subgraphs in the parent wrapper. """ - return GraphLowering( + return SubgraphLowering( + parent=self, gm=gm, example_inputs=example_inputs, shape_env=self._shape_env, @@ -741,14 +771,17 @@ def try_get_buffer( if buffer_name in self.constants: data = V.graph.constants[buffer_name] return ir.ConstantBuffer( - buffer_name, - ir.FixedLayout( + name=buffer_name, + layout=ir.FixedLayout( data.device, data.dtype, *V.graph.static_sizes_strides(data) ), ) return None + def add_symbol_graph_input(self, symbol: sympy.Expr) -> None: + raise RuntimeError("Should not be called for the main graph") + def get_buffer(self, buffer_name: str) -> Union[ir.TensorBox, ir.Buffer]: buf = self.try_get_buffer(buffer_name) if buf is not None: @@ -862,16 +895,7 @@ def allocate_non_dup_const_name( ) -> str: if not config.aot_inductor.use_runtime_constant_folding: for constant_name, value in self.constants.items(): - if ( - not data.is_mkldnn - and data.size() == value.size() - and data.stride() == value.stride() - and data.dtype == value.dtype - and data.device == value.device - and data.untyped_storage().data_ptr() - == value.untyped_storage().data_ptr() - and data.storage_offset() == value.storage_offset() - ): + if is_same_tensor(data, value): return constant_name if name is None: @@ -903,8 +927,10 @@ def add_tensor_constant( new_name = self.allocate_non_dup_const_name(name, data) return TensorBox.create( ir.ConstantBuffer( - new_name, - FixedLayout(data.device, data.dtype, *self.static_sizes_strides(data)), + name=new_name, + layout=FixedLayout( + data.device, data.dtype, *self.static_sizes_strides(data) + ), ) ) @@ -928,6 +954,7 @@ def placeholder( self, target: str, args: Tuple[object], kwargs: Dict[str, object] # type: ignore[override] ) -> Union[Expr, TensorBox, None]: example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + target = self.qualify_name(target) if isinstance(example, SymTypes): expr = example.node.expr self.graph_inputs[target] = expr @@ -957,11 +984,10 @@ def placeholder( else: sizes, strides = self.symbolic_sizes_strides(example) # type: ignore[assignment] # TODO(jansel): handle input aliasing - target = self.qualify_name(target) tensor = TensorBox.create( InputBuffer( - target, - FixedLayout(example.device, example.dtype, sizes, strides), + name=target, + layout=FixedLayout(example.device, example.dtype, sizes, strides), ) ) self.graph_inputs[target] = tensor @@ -1059,7 +1085,7 @@ def get_attr( if isinstance(value, torch._C.ScriptObject): self.torchbind_constants[target] = value self.constant_reprs[target] = "" - return TorchBindObject(target, value) + return TorchBindObject(name=target, value=value) assert isinstance(value, torch.Tensor) if ( @@ -1071,7 +1097,9 @@ def get_attr( with no_dispatch(): if value.shape == (): - return Constant(value.item(), value.dtype, value.device) + return Constant( + value=value.item(), dtype=value.dtype, device=value.device + ) if self.can_inline_constant(value): log.debug("Inlining constant: %s ", str(target)) # tensor lowering has constant inlining logic @@ -1235,7 +1263,9 @@ def significant_strides_equal( new_stride, old_layout.offset, ) - return ir.TensorBox(torch._inductor.ir.ReinterpretView(storage, new_layout)) + return ir.TensorBox( + torch._inductor.ir.ReinterpretView(data=storage, layout=new_layout) + ) def propagate_mutation( self, @@ -1304,6 +1334,8 @@ def run_node(self, n: torch.fx.Node) -> object: def debug(msg: str) -> None: log.debug("lowering %s %s", LazyString(n.format_node), msg) + from torch._inductor.bisect_helper import BisectionManager + buffer_watermark = len(self.buffers) operation_watermark = len(self.operations) @@ -1320,7 +1352,12 @@ def debug(msg: str) -> None: if ( n.op == "call_function" and n.target is not operator.getitem - and fallback_node_due_to_unsupported_type(n) + and ( + fallback_node_due_to_unsupported_type(n) + or BisectionManager.disable_subsystem( + "inductor", "lowerings", lambda: repr(n) + ) + ) ): debug("fallback_handler") result = fallback_handler(n.target, add_to_fallback_set=False)( @@ -1539,7 +1576,7 @@ def debug(msg: str) -> None: curr = result.data.data if isinstance(curr, Pointwise): # Use inner fn as a rough proxy. Good enough. - if curr.has_large_inner_fn(): + if curr.has_large_inner_fn(threshold=100): result.realize() # This is not complete, but it doesn't have to be: origin_node @@ -1552,20 +1589,20 @@ def debug(msg: str) -> None: # the origin_node here. if isinstance(result, TensorBox) and isinstance(result.data, ir.StorageBox): if isinstance(result.data.data, ir.Loops): - result.data.data.origin_node = n + result.data.data._post_init_setattr("origin_node", n) elif isinstance(result.data.data, ir.Buffer): - result.data.data.origin_node = n + result.data.data._post_init_setattr("origin_node", n) if isinstance(result.data.data, ir.ComputedBuffer) and isinstance( result.data.data.data, ir.Loops ): - result.data.data.data.origin_node = n + result.data.data.data._post_init_setattr("origin_node", n) # Not really multi-output, can straightforwardly recurse in elif ( isinstance(result.data.data, ir.MultiOutput) and not result.data.data.indices ): if isinstance(result.data.data.inputs[0], ir.Buffer): - result.data.data.inputs[0].origin_node = n + result.data.data.inputs[0]._post_init_setattr("origin_node", n) self.register_users_of(result) @@ -1695,7 +1732,12 @@ def validate_can_generate_cpp_wrapper(self) -> None: if not supported_dtype_of_cpp_wrapper(dtype, self.device_type): raise CppWrapperCodegenError(f"Unsupported input dtype {dtype}") - def init_wrapper_code(self) -> None: + def init_wrapper_code( + self, + is_subgraph: bool = False, + subgraph_name: Optional[str] = None, + parent_wrapper_code: Optional[PythonWrapperCodegen] = None, + ) -> None: device_types = self.device_types.copy() device_types.discard("cpu") device_types.discard("meta") @@ -1716,7 +1758,9 @@ def init_wrapper_code(self) -> None: assert ( wrapper_code_gen_cls is not None ), f"Device {self.device_type} not supported" - self.wrapper_code = wrapper_code_gen_cls() + self.wrapper_code = wrapper_code_gen_cls.create( + is_subgraph, subgraph_name, parent_wrapper_code + ) if self.const_module: # If we have const module, we could reuse the kernels @@ -1899,6 +1943,10 @@ def _compile_to_module(self) -> ModuleType: GraphLowering.save_output_code(code) output_code_log.debug("Output code: \n%s", code) + + inductor_meta = autotune_cache.inductor_meta_from_config() + AutotuneCacheBundler.begin_compile(inductor_meta, code=code) + try: linemap = [(line_no, node.stack_trace) for line_no, node in linemap] # type: ignore[misc] key, path = PyCodeCache.write(code) @@ -1978,5 +2026,79 @@ def is_unspec_arg(self, name: str) -> bool: return ( name in self.graph_inputs.keys() and self.graph_inputs[name].get_numel() == 1 + and len(self.graph_inputs[name].get_size()) == 0 and self.graph_inputs[name].get_device().type == "cpu" ) or name in self.zero_dim_cpu_tensor_list + + +class SubgraphLowering(GraphLowering): + """ + Mostly a helper class for the subgraph lowering. The main goal is to call + init_wrapper_code with the subgraph related arguments. + """ + + def __init__(self, parent: GraphLowering, *args: Any, **kwargs: Any) -> None: + self.parent = parent + super().__init__(*args, **kwargs) + + def init_wrapper_code( + self, + is_subgraph: bool = False, + subgraph_name: Optional[str] = None, + parent_wrapper_code: Optional[PythonWrapperCodegen] = None, + ) -> None: + super().init_wrapper_code( + is_subgraph=True, + subgraph_name=self.name, + parent_wrapper_code=self.parent.wrapper_code, + ) + + def add_symbol_graph_inputs(self) -> None: + """ + For subgraphs, it is possible that the aten graph does not have a symint + associated with the shape of the input tensors. To ensure that the + shape/stride symbol is available for the subgraph code (e.g. for + allocating intermediate tensor), we collect all the symbols from input + tensors of this subgraph (passed as inputs from the parent graph) and + add them as extra inputs to the subgraph. + + The parent wrapper `codegen_subgraph` then ensures to pass on the + corresponding symints from the parent function to the lifted subgraph + function. + """ + + def get_free_symbols(expr: sympy.Expr) -> OrderedSet[sympy.Symbol]: + # expr can be s0 + s1, recurse to get s0 and s1 + symbols: OrderedSet[ + sympy.Symbol + ] = OrderedSet() # Use a set to avoid duplicates + if isinstance(expr, sympy.Symbol): + symbols.add(expr) + elif isinstance(expr, sympy.Expr): + symbols.update(expr.free_symbols) + return symbols + + subgraph_symbols: OrderedSet[sympy.Symbol] = OrderedSet() + + graph_inputs_tensors = list( + filter( + lambda x: not isinstance(x[1], sympy.Expr), self.graph_inputs.items() + ) + ) + + for name_value in graph_inputs_tensors: + _, value = name_value + shapes = value.get_size() + for dim, shape in enumerate(shapes): + subgraph_symbols.update(get_free_symbols(shape)) + + strides = value.get_stride() + for dim, shape in enumerate(strides): + subgraph_symbols.update(get_free_symbols(shape)) + + # Add the extra symints in the subgraph + for symbol in subgraph_symbols: + if symbol.name in self.graph_input_names: + continue + self.graph_inputs[symbol.name] = symbol + self.graph_input_names.append(symbol.name) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index 0d44e245a53a9..695b94b27a737 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -83,6 +83,7 @@ convert_shape_to_symint, developer_warning, get_kernel_metadata, + ir_dataclass, is_dynamic, is_gpu, sympy_dot, @@ -233,6 +234,14 @@ def reindex(index: Sequence[_T]) -> Sequence[_V]: NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1] +def get_fill_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[int]: + """ + Convert strides to fill order (argsort) + """ + sorted_idx: Sequence[int] = argsort(seq) + return sorted_idx + + def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]: """ Convert stride order to fill order @@ -249,7 +258,7 @@ def get_stride_order(seq: Sequence[Union[int, torch.SymInt, Expr]]) -> Sequence[ """ Convert strides to stride order """ - sorted_idx: List[int] = argsort(seq) + sorted_idx: Sequence[int] = get_fill_order(seq) out = [0 for _ in range(len(seq))] for i, elem in enumerate(sorted_idx): out[elem] = i @@ -324,6 +333,11 @@ def is_cpu(x: object) -> bool: class IRNode: _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet() + # NB: These are kinda weird, + origins: OrderedSet[Any] = dataclasses.field(init=False) + traceback: str = dataclasses.field(init=False) + origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) + @staticmethod @contextlib.contextmanager def current_origins(origins: OrderedSet[torch.fx.Node]): @@ -334,9 +348,18 @@ def current_origins(origins: OrderedSet[torch.fx.Node]): finally: IRNode._current_origins = old + def _post_init_setattr(self, attr, value): + # Intended for use in __post_init__ for enforcing an invariant on a dataclass + # If you must, can also be used for setting provenance info + # We would like to try and minimize these usages though + object.__setattr__(self, attr, value) + def __post_init__(self): - self.origins = OrderedSet(self._current_origins) - self.traceback = traceback.format_stack() if config.debug_ir_traceback else None + self._post_init_setattr("origins", OrderedSet(self._current_origins)) + self._post_init_setattr( + "traceback", traceback.format_stack() if config.debug_ir_traceback else None + ) + self._post_init_setattr("origin_node", None) def get_read_names(self) -> OrderedSet[str]: raise NotImplementedError(f"NYI on {type(self)}") @@ -344,6 +367,9 @@ def get_read_names(self) -> OrderedSet[str]: def get_traceback(self): return self.traceback + def get_origin_node(self): + return self.origin_node + def get_defining_op(self): raise NotImplementedError @@ -421,7 +447,7 @@ def codegen_reference(self, writer=None): get_unbacked_symbol_uses: Callable[[], OrderedSet[sympy.Symbol]] -@dataclasses.dataclass +@ir_dataclass(frozen=False) class Operation: def __post_init__(self): self.operation_name: Optional[str] = None @@ -490,7 +516,7 @@ def get_workspace_size(self): return 0 -@dataclasses.dataclass +@ir_dataclass class Loops(IRNode): device: torch.device dtype: torch.dtype @@ -516,7 +542,6 @@ def __str__(self, names=("ranges",)): def __post_init__(self): super().__post_init__() - self.origin_node = None __repr__ = __str__ @@ -539,11 +564,14 @@ def is_extern(self): def create(cls, *args, **kwargs): origin_node = kwargs.pop("origin_node", None) tb = kwargs.pop("traceback", None) + # if "origin_node" in kwargs: + # breakpoint() r = cls(*args, **kwargs) - r.origin_node = origin_node - r.traceback = ( - tb or traceback.format_stack() if config.debug_ir_traceback else None - ) + # Need to explicitly set origin_node here to propagate it down. + # todo(chilli): I think it would be better for IRNode to directly set + # origin_node + r._post_init_setattr("origin_node", origin_node) + r._post_init_setattr("traceback", tb or r.traceback) return TensorBox.create(r) @staticmethod @@ -571,8 +599,11 @@ def inner_fn_str(self): self.inner_fn, *self.inner_fn_args() ) - def has_large_inner_fn(self): - return self.inner_fn_opcount().num_ops > config.realize_opcount_threshold + def has_large_inner_fn(self, threshold=None): + if threshold is None: + threshold = 0 + threshold = max(threshold, config.realize_opcount_threshold) + return self.inner_fn_opcount().num_ops > threshold def inner_fn_free_unbacked_symbols(self): index = self._index(self.ranges) @@ -621,6 +652,7 @@ def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> Op return ops.constant(0, dtype) +@ir_dataclass class Pointwise(Loops): def make_loader(self): # Make zero-element loops into a no-op @@ -643,10 +675,12 @@ def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) - return Pointwise(device, self.dtype, loader, self.ranges) + return Pointwise( + device=device, dtype=self.dtype, inner_fn=loader, ranges=self.ranges + ) -@dataclasses.dataclass +@ir_dataclass class Scatter(Pointwise): output_indexer: Callable[[List[Expr]], Expr] scatter_mode: Optional[str] = None @@ -656,12 +690,12 @@ def constant_to_device(self, device): loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) return Scatter( - device, - self.dtype, - loader, - self.ranges, - self.output_indexer, - self.scatter_mode, + device=device, + dtype=self.dtype, + inner_fn=loader, + ranges=self.ranges, + output_indexer=self.output_indexer, + scatter_mode=self.scatter_mode, ) def store_output(self, output_name, indexer, vars): @@ -763,7 +797,7 @@ def significant_strides_equal( return strides1 == strides2 -@dataclasses.dataclass +@ir_dataclass class Reduction(Loops): reduction_ranges: List[Expr] reduction_type: str @@ -817,14 +851,14 @@ def constant_to_device(self, device): loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) return Reduction( - device, - self.dtype, - loader, - self.ranges, - self.reduction_ranges, - self.reduction_type, - self.src_dtype, - ReductionHint.DEFAULT, + device=device, + dtype=self.dtype, + inner_fn=loader, + ranges=self.ranges, + reduction_ranges=self.reduction_ranges, + reduction_type=self.reduction_type, + src_dtype=self.src_dtype, + reduction_hint=ReductionHint.DEFAULT, ) @staticmethod @@ -981,14 +1015,14 @@ def outer_reduction_splits(reduction_numel_hint, numel_hint): return ReductionHint.DEFAULT, 1 r = Reduction( - device, - dst_dtype, - inner_fn, - ranges, - reduction_ranges, - reduction_type, - src_dtype, - ReductionHint.DEFAULT, + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=ReductionHint.DEFAULT, ) def get_read_indices(r): @@ -1156,7 +1190,9 @@ def fn(index): reduction_index = [sympy.Integer(0) for _ in reduction_ranges] return inner_fn(index, reduction_index) - return Pointwise.create(device, dst_dtype, fn, ranges) + return Pointwise.create( + device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges + ) if ( isinstance(reduction_numel, sympy.Integer) @@ -1165,12 +1201,12 @@ def fn(index): and sympy_product(ranges) != 1 ): return Pointwise.create( - device, - dst_dtype, - cls._unroll_reduction_fn( + device=device, + dtype=dst_dtype, + inner_fn=cls._unroll_reduction_fn( inner_fn, reduction_ranges, reduction_type, src_dtype ), - ranges, + ranges=ranges, ) # triton doesn't support reduce to single element well, so break it up @@ -1225,14 +1261,14 @@ def fn(index): return TensorBox.create( Reduction( - device, - dst_dtype, - inner_fn, - ranges, - reduction_ranges, - reduction_type, - src_dtype, - reduction_hint, + device=device, + dtype=dst_dtype, + inner_fn=inner_fn, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, ) ) @@ -1397,14 +1433,14 @@ def intermediate_fn(index, reduction_index): assert original_ranges == new_ranges[: len(original_ranges)] return TensorBox.create( Reduction( - device, - dst_dtype, - intermediate_fn, - original_ranges, - new_ranges[len(original_ranges) :], - reduction_type, - src_dtype, - reduction_hint, + device=device, + dtype=dst_dtype, + inner_fn=intermediate_fn, + ranges=original_ranges, + reduction_ranges=new_ranges[len(original_ranges) :], + reduction_type=reduction_type, + src_dtype=src_dtype, + reduction_hint=reduction_hint, ) ) @@ -1511,14 +1547,14 @@ def loader(idx, reduction_idx): return tuple(fn(idx, reduction_idx) for fn in inner_fns) super().__init__( - device, - dtype, - loader, - ranges, - reduction_ranges, - reduction_type, - dtype, - reduction_hint, + device=device, + dtype=dtype, + inner_fn=loader, + ranges=ranges, + reduction_ranges=reduction_ranges, + reduction_type=reduction_type, + src_dtype=dtype, + reduction_hint=reduction_hint, ) self.output_index = output_index @@ -1744,7 +1780,7 @@ def intermediate_loader_fn(index, reduction_index, loader): ) -@dataclasses.dataclass +@ir_dataclass class Scan(Loops): scan_ranges: List[Expr] size: List[Expr] @@ -1931,12 +1967,12 @@ def wrapper_fn(idx, reduction_idx): # This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA. -@dataclasses.dataclass +@ir_dataclass class SplitScan(Scan): pass -@dataclasses.dataclass +@ir_dataclass class Sort(Loops): # Sorts a tuple of key, value pairs sort_ranges: List[Expr] @@ -2162,7 +2198,7 @@ def is_stride_order_storage_and_layout( return False -@dataclasses.dataclass +@ir_dataclass class BaseView(IRNode): data: IRNode @@ -2250,10 +2286,15 @@ def constant_to_device(self, device): """Move this to a given device. Requires that all reads are to constants.""" loader = self.make_loader() loader = patch.object(ConstantBuffer, "override_device", device)(loader) - return Pointwise(device, self.get_dtype(), loader, self.get_size()) + return Pointwise( + device=device, + dtype=self.get_dtype(), + inner_fn=loader, + ranges=self.get_size(), + ) -@dataclasses.dataclass +@ir_dataclass class ExpandView(BaseView): size: List[Expr] @@ -2269,7 +2310,9 @@ def _normalize_size(x, new_size): if new_size[i] == -1: assert old_size[i] is not None new_size[i] = old_size[i] - elif old_size[i] is None or old_size[i] == 1: + elif old_size[i] is None or V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(old_size[i], 1), size_oblivious=True + ): pass else: # Sanity check: Expect broadcast compatibility @@ -2292,7 +2335,13 @@ def create(cls, x, new_size): assert skip >= 0 new_stride = [sympy.Integer(0)] * skip for stride, size in zip(old_layout.stride, old_layout.size): - new_stride.append(stride if size != 1 else sympy.Integer(0)) + new_stride.append( + stride + if not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(size, 1), size_oblivious=True + ) + else sympy.Integer(0) + ) new_layout = FixedLayout( old_layout.device, old_layout.dtype, @@ -2300,9 +2349,9 @@ def create(cls, x, new_size): new_stride, old_layout.offset, ) - return ReinterpretView(storage, new_layout) + return ReinterpretView(data=storage, layout=new_layout) - return ExpandView(x, new_size) + return ExpandView(data=x, size=new_size) def get_size(self): return self.size @@ -2324,7 +2373,7 @@ def reindex(index): return reindex -@dataclasses.dataclass +@ir_dataclass class PermuteView(BaseView): dims: List[Expr] @@ -2342,9 +2391,9 @@ def create(cls, x, dims): [old_layout.stride[i] for i in dims], old_layout.offset, ) - return ReinterpretView(storage, new_layout) + return ReinterpretView(data=storage, layout=new_layout) - return PermuteView(x, dims) + return PermuteView(data=x, dims=dims) @classmethod def _map_neg_dims(cls, dims): @@ -2368,6 +2417,7 @@ def reindex(index): return reindex +@ir_dataclass class SqueezeView(BaseView): @classmethod def create(cls, x, *, dim=None): @@ -2398,7 +2448,7 @@ def create(cls, x, *, dim=None): new_stride, old_layout.offset, ) - return ReinterpretView(storage, new_layout) + return ReinterpretView(data=storage, layout=new_layout) if dim is None: # redirect to a generic view @@ -2426,7 +2476,7 @@ def __init__(self, data): raise AssertionError("use SqueezeView.create()") -@dataclasses.dataclass +@ir_dataclass class GenericView(BaseView): size: List[Expr] reindex: Callable[..., Any] @@ -2450,13 +2500,13 @@ def __str__(self) -> str: @classmethod def create(cls, x, new_size, reindex): - return cls(x, list(new_size), reindex) + return cls(data=x, size=list(new_size), reindex=reindex) def get_size(self): return self.size -@dataclasses.dataclass +@ir_dataclass class View(GenericView): @staticmethod def handle_negative_index(idx, size): @@ -2488,7 +2538,7 @@ def create(cls, x, new_size): def fake_reindex(index): return tuple([0] * len(old_size)) - return cls(x, list(new_size), fake_reindex) + return cls(data=x, size=list(new_size), reindex=fake_reindex) # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes: if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)): @@ -2504,10 +2554,10 @@ def fake_reindex(index): FlexibleLayout.contiguous_strides(new_size), old_layout.offset, ) - return ReinterpretView(storage, new_layout) + return ReinterpretView(data=storage, layout=new_layout) reindex = cls.dynamic_reshape_indexer(old_size, new_size) - return cls(x, list(new_size), reindex) + return cls(data=x, size=list(new_size), reindex=reindex) @staticmethod def resolve_negative_size(old_size, new_size): @@ -2604,7 +2654,7 @@ def reindex(index): return reindex -@dataclasses.dataclass +@ir_dataclass class ReinterpretView(BaseView): """Pretend our storage has a different layout""" @@ -2613,7 +2663,7 @@ class ReinterpretView(BaseView): def __post_init__(self): super().__post_init__() if isinstance(self.data, BaseView): - self.data = self.data.unwrap_view() + object.__setattr__(self, "data", self.data.unwrap_view()) def __str__(self) -> str: return self.str_helper( @@ -2688,7 +2738,7 @@ def num_reads(self): return 1 -@dataclasses.dataclass +@ir_dataclass class DtypeView(BaseView): """Pretend our storage has a different type""" @@ -2705,8 +2755,8 @@ def create(cls, x, new_dtype): old_layout.stride, old_layout.offset, ) - return ReinterpretView(storage, new_layout) - return DtypeView(x, new_dtype) + return ReinterpretView(data=storage, layout=new_layout) + return DtypeView(data=x, target_dtype=new_dtype) def __str__(self) -> str: return self.str_helper([self.data, self.target_dtype]) @@ -2792,7 +2842,7 @@ def create(cls, x, dim, start, end, step=1, clamp=True): new_stride, old_layout.offset + old_layout.stride[dim] * start, ) - return ReinterpretView(storage, new_layout) + return ReinterpretView(data=storage, layout=new_layout) def reindex(index): assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" @@ -2801,9 +2851,10 @@ def reindex(index): return index # redirect to a generic view - return SliceView(x, size=new_size, reindex=reindex) + return SliceView(data=x, size=new_size, reindex=reindex) +@ir_dataclass class BaseConstant(IRNode): dtype: torch.dtype device: torch.device @@ -2830,7 +2881,7 @@ def is_extern(self): return False -@dataclasses.dataclass +@ir_dataclass class Constant(BaseConstant): value: Any dtype: torch.dtype @@ -2846,10 +2897,10 @@ def realize(self): pass def constant_to_device(self, device): - return Constant(self.value, self.dtype, device) + return Constant(value=self.value, dtype=self.dtype, device=device) -@dataclasses.dataclass +@ir_dataclass class IndexingConstant(BaseConstant): index: Any dtype: torch.dtype @@ -2862,7 +2913,7 @@ def loader(index): return loader def constant_to_device(self, device): - return IndexingConstant(self.index, self.dtype, device) + return IndexingConstant(index=self.index, dtype=self.dtype, device=device) def is_contiguous_strides_for_shape( @@ -2880,7 +2931,7 @@ def get_align_for_dtype(dtype: torch.dtype) -> int: return config.padding_alignment_bytes // dtype.itemsize -@dataclasses.dataclass +@ir_dataclass class Layout(IRNode): def __init__( self, @@ -3090,11 +3141,11 @@ def __init__( if stride is None: stride = FlexibleLayout.contiguous_strides(size) super().__init__( - device, - dtype, - size, # type: ignore[arg-type] - stride, - offset, # type: ignore[arg-type] + device=device, + dtype=dtype, + size=size, # type: ignore[arg-type] + stride=stride, + offset=offset, # type: ignore[arg-type] ) def make_indexer(self): @@ -3277,6 +3328,7 @@ def maybe_guard_aligned(self): return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) # type: ignore[arg-type] +@ir_dataclass class NoneLayout(IRNode): # This is janky, I figured out what fields to populate by just running # the model I was interested in and adding properties/methods as needed. @@ -3286,10 +3338,9 @@ class NoneLayout(IRNode): # If you have an ir.Node with NoneLayout, you probably need to setup # dependencies manually in scheduler - def __init__(self, device): - self.device = device - self.size = [0] - self.stride = [0] + device: torch.device + size: List[int] = dataclasses.field(default_factory=lambda: [0]) + stride: List[int] = dataclasses.field(default_factory=lambda: [0]) def storage_size(self): return 0 @@ -3378,7 +3429,7 @@ def make_indexer(self): return self.target.make_indexer() -@dataclasses.dataclass +@ir_dataclass(frozen=False) class Buffer(IRNode): # Name is sometimes None; e.g., ForceInPlace, where there isn't # a meaningful name @@ -3390,7 +3441,7 @@ class Buffer(IRNode): def __post_init__(self): super().__post_init__() - self.origin_node = None + self._post_init_setattr("origin_node", None) def make_indexer(self): return self.layout.make_indexer() @@ -3402,9 +3453,6 @@ def get_name(self) -> str: def get_device(self): return self.layout.device - def get_origin_node(self): - return self.origin_node - def get_defining_op(self) -> Optional[Operation]: return None @@ -3499,7 +3547,7 @@ def should_allocate(self): return False -@dataclasses.dataclass +@ir_dataclass(frozen=False) class OperationBuffer(Buffer, Operation): # An operation that produces a single output buffer def get_outputs(self) -> List[Buffer]: @@ -3533,10 +3581,11 @@ def loader(index): def constant_to_device(self, device): return ConstantBuffer( - V.graph.constant_name(self.get_name(), device), self.layout + name=V.graph.constant_name(self.get_name(), device), layout=self.layout ) +@ir_dataclass class NoneAsConstantBuffer(IRNode): def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() @@ -3545,23 +3594,18 @@ def codegen_reference(self, writer=None): return V.graph.wrapper_code.none_str +@ir_dataclass class ShapeAsConstantBuffer(IRNode): - def __init__(self, shape): - super().__init__() - self._shape = shape - - @property - def shape(self): - return self._shape + expr: Expr def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: - return free_unbacked_symbols(self.shape) + return free_unbacked_symbols(self.expr) def codegen_reference(self, writer=None): - return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.shape)) + return V.graph.wrapper_code.expr_printer(V.graph.sizevars.simplify(self.expr)) -@dataclasses.dataclass +@ir_dataclass(frozen=False) class ComputedBuffer(OperationBuffer): data: Loops @@ -3799,7 +3843,7 @@ def simplify_and_reorder(x_vars, support_vars, sizes, simplify_loops): support_vars = index_vars + reduce_vars should_merge_loops = ( - self.get_device().type != "cuda" or not config.loop_ordering_after_fusion + not is_gpu(self.get_device().type) or not config.loop_ordering_after_fusion ) iter_ranges, iter_reindex, _ = simplify_and_reorder( index_vars, @@ -3946,7 +3990,6 @@ def __init__( layout, inputs, make_kernel_render, - debug_extra=None, mutated_inputs: Optional[Iterable[IRNode]] = None, ): """ @@ -3959,7 +4002,6 @@ def __init__( and we mark them as mutated inputs. """ super().__init__(layout, inputs, make_kernel_render) - self.debug_extra = debug_extra self.mutated_inputs = mutated_inputs self.outputs: List[Buffer] = [self] if mutated_inputs is not None: @@ -3974,14 +4016,15 @@ def __init__( ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" device = self.inputs[0].get_device() self.outputs += [ - MutationOutput(NoneLayout(device), buf, self) for buf in mutated_inputs + MutationOutput(NoneLayout(device=device), buf, self) + for buf in mutated_inputs ] def get_outputs(self) -> List[Buffer]: return self.outputs def __str__(self) -> str: - out = f"TritonTemplateBuffer(layout={self.layout}, {self.debug_extra})" + out = f"TritonTemplateBuffer(layout={self.layout})" return out @@ -3997,11 +4040,20 @@ class ChoiceCaller: Children classes: TritonTemplateCaller, CUDATemplateCaller. """ - def __init__(self, name, input_nodes, layout): + def __init__( + self, + name: str, + input_nodes: List[Buffer], + layout: Layout, + description: str, + ): super().__init__() self.name = name self.layout = layout self.input_nodes = input_nodes + # An additional description used to describe the choice (useful for + # knowing what autotuning is choosing) + self.description = description def benchmark(self, *args, out) -> float: algo = self.to_callable() @@ -4046,11 +4098,27 @@ def __init__( layout: Layout, inputs: List[IRNode], choice_timings: Callable[[], Dict[ChoiceCaller, float]], + unfiltered_choices: List[ChoiceCaller], ): super().__init__(layout=layout, inputs=inputs, make_kernel_render=None) self._choice_timings_fn = choice_timings self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None self.original_inputs = inputs + self._output_plannable = all( + isinstance(choice, TritonTemplateCallerBase) + or ( + isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller) + and choice.has_out_variant + ) + for choice in unfiltered_choices + ) + + @property + def output_plannable(self) -> bool: + """ + Are all possible choices TritonTemplates or Extern Kernels with out variants + """ + return self._output_plannable @property def choice_timings(self) -> Dict[ChoiceCaller, float]: @@ -4106,7 +4174,7 @@ def __init__(self, layout, inputs, make_kernel_render, template, choice): self.choice = choice -@dataclasses.dataclass +@ir_dataclass(frozen=False) class InputsKernel(OperationBuffer): inputs: List[Buffer] @@ -4116,6 +4184,9 @@ def get_read_writes(self): for input in self.inputs: if isinstance(input, list): reads.update(StarDep(x.get_name()) for x in input) + elif isinstance(input, ShapeAsConstantBuffer): + # Skip creating dependncy for symbolics as they're visible globally + continue else: reads.add(StarDep(input.get_name())) @@ -4269,10 +4340,31 @@ def create(cls, inputs, dim): return kernel @classmethod - def can_realize_into_without_copy(cls, src): + def can_realize_into_without_copy(cls, src, dst=None): if isinstance(src, TensorBox): # unwrap a TensorBox - return cls.can_realize_into_without_copy(src.data) + return cls.can_realize_into_without_copy(src.data, dst) + + if isinstance(src.data, MultiTemplateBuffer): + if ( + not isinstance(src.data.layout, FixedLayout) + or not src.data.output_plannable + ): + return False + + # we call can_realize_into_without_copy in cat lowering before we've decided + # on output format, optimistically assume layout matches + if dst is None: + return True + + # otherwise, check equality of layouts + if not len(src.get_stride()) == len(dst.get_stride()): + return False + + return all( + V.graph.sizevars.statically_known_equals(s1, s2) + for s1, s2 in zip(src.get_stride(), dst.get_stride()) + ) return isinstance(src.data.layout, FlexibleLayout) and not isinstance( src.data, ExternKernelAlloc @@ -4286,16 +4378,17 @@ def realize_into(cls, src, dst): if not isinstance(dst, ReinterpretView): if is_storage_and_layout(dst): storage, layout = as_storage_and_layout(dst) - dst = ReinterpretView(storage, layout) + dst = ReinterpretView(data=storage, layout=layout) assert isinstance(dst, ReinterpretView), dst if isinstance(src, TensorBox): # unwrap a TensorBox return cls.realize_into(src.data, dst) + if isinstance(src, StorageBox): src.realize() # ExternKernelAlloc has specific requirements for output layout, should create a copy assert hasattr(src.data, "layout") - if cls.can_realize_into_without_copy(src): + if cls.can_realize_into_without_copy(src, dst): src.data.layout = NonOwningLayout(dst) return src.data # introduce a copy @@ -4314,7 +4407,7 @@ def should_allocate(self): return True -@dataclasses.dataclass +@ir_dataclass(frozen=False) class ExternKernel(InputsKernel): constant_args: Tuple[Any, ...] = () kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) @@ -4350,9 +4443,9 @@ def __init__( op_overload=None, ): super().__init__( - name, - layout, - inputs, + name=name, + layout=layout, + inputs=inputs, ) self.constant_args = constant_args self.kwargs = kwargs if kwargs else {} @@ -4435,9 +4528,9 @@ def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None): # Try to construct cpp_kernel_name from op_overload if kernel.namespace == "aten": # Calling with the default kernel name can lead to ambiguous behavior like the following example. - # repeat_interleave(const at::Tensor & repeats, c10::optional output_size=std::nullopt) + # repeat_interleave(const at::Tensor & repeats, std::optional output_size=std::nullopt) # repeat_interleave(const at::Tensor & self, int64_t repeats, - # c10::optional dim=std::nullopt, c10::optional output_size=std::nullopt) + # std::optional dim=std::nullopt, std::optional output_size=std::nullopt) opname = ( kernel.__name__.split(".")[0] if kernel._overloadname == "default" @@ -4475,11 +4568,7 @@ def set_python_kernel_name(self, python_kernel_name: Optional[str]): def get_kernel_name(self): return ( - ( - V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined] - if config.abi_compatible - else self.cpp_kernel_name - ) + V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined] if V.graph.cpp_wrapper else self.python_kernel_name ) @@ -4667,7 +4756,7 @@ def realize_input(cls, x): if x is None: return NoneAsConstantBuffer() if isinstance(x, (sympy.Expr, sympy.logic.boolalg.Boolean, int)): - return ShapeAsConstantBuffer(x) + return ShapeAsConstantBuffer(expr=x) if isinstance(x, Constant): return V.graph.add_tensor_constant( torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) @@ -4677,7 +4766,9 @@ def realize_input(cls, x): if isinstance(x, TensorBox): return cls.realize_input(x.data) if isinstance(x, ReinterpretView): - return ReinterpretView(cls.realize_input(x.data), x.get_layout()) + return ReinterpretView( + data=cls.realize_input(x.data), layout=x.get_layout() + ) if isinstance(x, BaseView): x.realize() if is_storage_and_layout(x.unwrap_view()): @@ -4734,11 +4825,13 @@ def require_strides( x, freeze=True, want_contiguous=False, - stride_order=get_stride_order( - V.graph.sizevars.size_hints(x.get_layout().stride) - ) - if is_stride_order_storage_and_layout(x, order) - else order, + stride_order=( + get_stride_order( + V.graph.sizevars.size_hints(x.get_layout().stride) + ) + if is_stride_order_storage_and_layout(x, order) + else order + ), allow_padding=allow_padding, ) return x @@ -5066,7 +5159,7 @@ def __str__(self) -> str: __repr__ = __str__ -@dataclasses.dataclass +@ir_dataclass(frozen=False) class ExternKernelOut(ExternKernel): def codegen(self, wrapper): self.codegen_comment(wrapper) @@ -5077,11 +5170,7 @@ def codegen(self, wrapper): and self.cpp_kernel_name == "torch::inductor::_mm_plus_mm" ): # For https://github.com/pytorch/pytorch/issues/128474 - kernel_name = ( - "aoti_torch__mm_plus_mm_out" - if config.abi_compatible - else "torch::inductor::_mm_plus_mm_out" - ) + kernel_name = "aoti_torch__mm_plus_mm_out" else: kernel_name = self.get_kernel_name() wrapper.generate_extern_kernel_out( @@ -5137,9 +5226,7 @@ def __init__(self, count: int, device: torch.device): # FIXME: Ideally we should only use at::_ops::randint_low_out::call here, # but the signature is different from is at::randint_out. Again, # we can simplify the code when only keeping an ABI-compatible version. - cpp_kernel_name="at::_ops::randint_low_out::call" - if config.abi_compatible - else "at::randint_out", + cpp_kernel_name="at::_ops::randint_low_out::call", op_overload=aten.randint.low_out, ) @@ -5212,6 +5299,80 @@ def should_allocate(self): return False +class TMADescriptor(ExternKernel): + """ + An IR node representing a host-side TMA descriptor in the Triton API + (the ones obtained via create_{1d,2d}_tma_descriptor calls). Mostly + useful for user-defined Triton kernels relying on host-side TMA; but + can, in principle, be used for Inductor's Triton templates, too. + """ + + # as TMA descriptors are immutable, + # we can dedup them by the input args + _CACHE: Dict[Any, TMADescriptor] = {} + + @classmethod + def create( + cls, + tensor: TensorBox, + dims: List[Union[int, torch.SymInt]], + block_dims: List[Union[int, torch.SymInt]], + element_size: Optional[int] = None, + ): + key = (id(tensor), dims, block_dims, element_size) + if key not in cls._CACHE: + cls._CACHE[key] = TMADescriptor(tensor, dims, block_dims, element_size) + return cls._CACHE[key] + + def __init__( + self, + tensor: TensorBox, + dims: List[Union[int, torch.SymInt]], + block_dims: List[Union[int, torch.SymInt]], + element_size: Optional[int] = None, + ): + assert len(dims) in (1, 2) + assert len(dims) == len(block_dims) + + if element_size is None: + element_size = tensor.get_dtype().itemsize + + self.tensor = tensor + self.dims = dims + self.block_dims = block_dims + self.element_size = element_size + self.rank = len(self.dims) + + inputs = [tensor] + constant_args = [ + *self.dims, + *self.block_dims, + self.element_size, + ] + + super().__init__( + None, + # link back to the underlying tensor in terms of ownership + # to avoid getting the underlying tensor deleted *before* + # the TMADescriptor node can be deleted. + NonOwningLayout( + ReinterpretView( + data=tensor, + layout=tensor.get_layout(), + ) + ), + inputs, + tuple(constant_args), + None, + ) + + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + def codegen(self, wrapper): + wrapper.generate_tma_descriptor(self) + + class UserDefinedTritonKernel(ExternKernel): def get_kernel_and_configs(self): from triton.runtime.autotuner import Autotuner @@ -5243,9 +5404,25 @@ def codegen(self, wrapper): for idx, kwarg in enumerate(self.ordered_kwargs_for_cpp_kernel): if kernel.arg_names.index(kwarg) in kernel.constexprs: constexpr_indices.append(idx) + """ + Filter out None args. + + see https://github.com/pytorch/pytorch/issues/115344 + + Two cases for a None arg: + 1. The arg is already tl.constexpr, so leave it in + 2. The arg is not tl.constexpr so we have to remove it + """ + constexpr_indices_set = set(constexpr_indices) + raw_args = [ + arg + for idx, arg in enumerate(raw_args) + if (arg is not None) or (arg is None and idx in constexpr_indices_set) + ] # Call to kernel self.codegen_comment(wrapper) + wrapper.generate_user_defined_triton_kernel( new_name, raw_args, self.grid, configs, triton_meta, constexpr_indices ) @@ -5258,13 +5435,15 @@ def get_unbacked_symbol_uses(self) -> OrderedSet[sympy.Symbol]: def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() - def __init__(self, *, kernel_idx, grid, kernel_args): + def __init__(self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args): inputs = [] kwargs = {} constant_args = [] for k, v in kernel_args.items(): if isinstance(v, TensorBox): t = InputsKernel.unwrap_storage_for_input(self.realize_input(v)) + if k in tma_descriptor_metadata: + t = TMADescriptor.create(t, *tma_descriptor_metadata[k]) inputs.append(t) kwargs[k] = t else: @@ -5276,7 +5455,7 @@ def __init__(self, *, kernel_idx, grid, kernel_args): super().__init__( None, - NoneLayout(self.device), # type: ignore[arg-type] + NoneLayout(device=self.device), # type: ignore[arg-type] inputs, tuple(constant_args), kwargs, @@ -5285,6 +5464,7 @@ def __init__(self, *, kernel_idx, grid, kernel_args): self.grid = grid kernel, configs = self.get_kernel_and_configs() + # If we are autotuning, not all arguments will be passed self.ordered_kwargs_for_cpp_kernel = [ arg for arg in kernel.arg_names if arg in kernel_args @@ -5301,7 +5481,7 @@ def __init__(self, *, kernel_idx, grid, kernel_args): ] self.mutation_outputs = [ - MutationOutput(NoneLayout(self.device), buf, self) + MutationOutput(NoneLayout(device=self.device), buf, self) for buf in self.mutable_args ] V.graph.register_operation(self) @@ -5321,7 +5501,7 @@ class InplaceBernoulliFallback(ExternKernel): def codegen(self, wrapper): (x,) = (t.codegen_reference() for t in self.inputs) - if V.graph.cpp_wrapper and config.abi_compatible: + if V.graph.cpp_wrapper: # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here, # which needs to be explicitly generated for cpp wrapper wrapper.writeline( @@ -5344,7 +5524,7 @@ def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: def __init__(self, op_overload, x, *constant_args): super().__init__( None, - NoneLayout(x.get_device()), # type: ignore[arg-type] + NoneLayout(device=x.get_device()), # type: ignore[arg-type] self.unwrap_storage([x]), constant_args, op_overload=op_overload, @@ -5352,9 +5532,6 @@ def __init__(self, op_overload, x, *constant_args): V.graph.mark_buffer_mutated(x.get_name()) self.name = V.graph.register_buffer(self) V.graph.register_operation(self) - if not config.abi_compatible: - # TODO: this should be simplified once we switch to ABI-compatible only - self.cpp_kernel_name = "at::native::bernoulli_" # Used to deal with torch.complex types @@ -5388,9 +5565,7 @@ def __init__( inputs, constant_args, python_kernel_name="aten.copy_", - cpp_kernel_name=( - "aoti_torch_copy_" if config.abi_compatible else "at::_ops::copy_::call" - ), + cpp_kernel_name="aoti_torch_copy_", ) V.graph.mark_buffer_mutated(inputs[0].get_name()) self.name = V.graph.register_buffer(self) @@ -5401,7 +5576,7 @@ def create(cls, dst, src, non_blocking: bool = False): inputs = [cls.realize_input(t) for t in [dst, src]] constant_args = (non_blocking,) result = InplaceCopyFallback( - NoneLayout(dst.get_device()), # type: ignore[arg-type] + NoneLayout(device=dst.get_device()), # type: ignore[arg-type] inputs, constant_args, ) @@ -5440,7 +5615,7 @@ def __init__(self, variable, new_size): assert isinstance(new_size, int), "TODO: dynamic shapes" super().__init__( None, - NoneLayout(variable.get_device()), # type: ignore[arg-type] + NoneLayout(device=variable.get_device()), # type: ignore[arg-type] self.unwrap_storage([variable]), constant_args=(new_size,), ) @@ -5466,8 +5641,8 @@ def __init__(self, self_tensor, storage_tensor): V.graph.never_reuse_buffers.add(self.get_name()) device = storage_tensor.get_device() self.mutation_outputs = [ - MutationOutput(NoneLayout(device), self_tensor, self), - MutationOutput(NoneLayout(device), storage_tensor, self), + MutationOutput(NoneLayout(device=device), self_tensor, self), + MutationOutput(NoneLayout(device=device), storage_tensor, self), ] def get_inputs_that_alias_output(self): @@ -5536,7 +5711,7 @@ def __init__( super().__init__( None, - NoneLayout(x.get_device()), # type: ignore[arg-type] + NoneLayout(device=x.get_device()), # type: ignore[arg-type] self.unwrap_storage(tensors), constant_args, {"reduce": reduce, "include_self": include_self}, @@ -5581,12 +5756,10 @@ def __init__(self, op_overload, x, indices, values, accumulate): self.indices = indices valid_indices = [i for i in indices if i is not None] tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] - cpp_kernel_name = ( - "aoti_torch_index_put_out" if config.abi_compatible else "at::index_put_out" - ) + cpp_kernel_name = "aoti_torch_index_put_out" super().__init__( None, - NoneLayout(x.get_device()), # type: ignore[arg-type] + NoneLayout(device=x.get_device()), # type: ignore[arg-type] self.unwrap_storage(tensors), (accumulate,), python_kernel_name="aten.index_put_", @@ -5647,7 +5820,7 @@ def should_allocate(self): def __init__(self, sym, keypath, data): data.realize() - super().__init__(None, NoneLayout(torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] + super().__init__(None, NoneLayout(device=torch.device("cpu")), self.unwrap_storage([data])) # type: ignore[arg-type] self.sym = sym self.keypath = keypath @@ -5673,7 +5846,7 @@ def __init__(self, scalar, msg): super().__init__( # Buffer(name, layotu) None, - NoneLayout(torch.device("cpu")), # type: ignore[arg-type] + NoneLayout(device=torch.device("cpu")), # type: ignore[arg-type] # InputsKernel(inputs) [], ) # type: ignore[arg-type] @@ -5706,32 +5879,12 @@ def codegen(self, wrapper): wrapper.writeline(f"{self.get_name()} = None") -@dataclasses.dataclass +@ir_dataclass(frozen=False) class ExternKernelNode: name: str node: export_schema.Node -has_c_shim = OrderedSet( - [ - aten._embedding_bag.default, - aten._fft_c2c.default, - aten._scaled_dot_product_efficient_attention.default, - aten._scaled_dot_product_flash_attention.default, - aten._scaled_dot_product_cudnn_attention.default, - aten._scaled_mm.default, - aten.addmm.out, - aten.bmm.out, - aten.copy_.default, - aten.mm.out, - aten.repeat_interleave.Tensor, - aten.nonzero.default, - aten.view.dtype, - aten.view_as_real.default, - ] -) - - class FallbackKernel(ExternKernelAlloc): def __init__( self, @@ -5844,7 +5997,7 @@ def add_alias(t): self.alias_names.append(t.get_name()) if info.alias_info.is_write: self.mutation_outputs.append( - MutationOutput(NoneLayout(t.get_device()), t, self) + MutationOutput(NoneLayout(device=t.get_device()), t, self) ) if is_list_tensor: @@ -5898,7 +6051,7 @@ def go(expr, keypath): raise AssertionError(f"unrecognized keypath {keypath}") def go_outer(): - if V.graph.cpp_wrapper and config.abi_compatible: + if V.graph.cpp_wrapper: # Special handling for the top level buffer access, # because self.get_name() is actually never bound; the # individual output arguments are bound by @@ -6028,15 +6181,23 @@ def handle_single_output(return_type, output): target = self.op_overload returns = target._schema.returns # type: ignore[union-attr] if len(returns) == 1: + # FIXME: there is a corner case here, i.e. all_reduce_coalesced_'s return value + # is a list of tensors, but self.mutation_outputs is already flatterned. A proper + # fix would require changing all the uses of self.mutation_outputs. return_type = returns[0].real_type - output_arguments = [handle_single_output(return_type, self.outputs)] + output_arguments = [ + handle_single_output( + return_type, [*self.outputs, *self.mutation_outputs] + ) + ] else: # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])" - assert isinstance(self.outputs, tuple) - assert len(returns) == len(self.outputs) + # Not generating output args for self.mutation_outputs output_arguments = [ handle_single_output(return_schema.real_type, output) - for return_schema, output in zip(returns, self.outputs) + for return_schema, output in zip( + returns, [*self.outputs, *self.mutation_outputs] + ) ] node = ExternKernelNode( @@ -6061,7 +6222,7 @@ def codegen(self, wrapper): if V.graph.cpp_wrapper: from torchgen.aoti.fallback_ops import inductor_fallback_ops - if config.abi_compatible and str(kernel) not in inductor_fallback_ops: + if str(kernel) not in inductor_fallback_ops: # C shim v2 is torchgen-ed, which should cover all aten ops. # If you do hit a missed op, please update fallback_ops.py. log.warning( @@ -6072,9 +6233,6 @@ def codegen(self, wrapper): elif kernel.namespace == "_quantized": # type: ignore[union-attr] # Internal Quantized Fallback Ops assert isinstance(kernel, torch._ops.OpOverload) - if V.graph.cpp_wrapper: - if not config.abi_compatible: - self.use_runtime_dispatch = True else: # For non-aten OpOverload, i.e. custom ops if V.graph.cpp_wrapper: @@ -6085,10 +6243,7 @@ def codegen(self, wrapper): exported_args = None args = None - if config.abi_compatible: - exported_args = self.export_extern_kernel_node() - else: - args = [*self.codegen_args(), *self.codegen_kwargs()] + exported_args = self.export_extern_kernel_node() wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( self.get_name(), @@ -6100,7 +6255,7 @@ def codegen(self, wrapper): self.cpp_kernel_overload_name, self.op_overload, exported_args, - self.outputs, + [*self.outputs, *self.mutation_outputs], ) else: self.codegen_comment(wrapper) @@ -6138,7 +6293,7 @@ def create(cls, kernel, *args, **kwargs): device = cls.find_device(tensor_args, example_output) if example_output is None: packed = cls( - NoneLayout(device), + NoneLayout(device=device), kernel, tensor_args, non_tensor_args, @@ -6149,7 +6304,7 @@ def create(cls, kernel, *args, **kwargs): else: assert device, "Not sure where to find device info" packed = cls( - MultiOutputLayout(device), + MultiOutputLayout(device=device), kernel, tensor_args, non_tensor_args, @@ -6195,7 +6350,7 @@ def apply_constraint(self): return super().apply_constraint() -@dataclasses.dataclass +@ir_dataclass(frozen=False) class ComplexView(FallbackKernel): """View a complex number as two dtyped numbers or vice versa""" @@ -6226,7 +6381,7 @@ def __init__( ) -@dataclasses.dataclass +@ir_dataclass class MultiOutputLayout(IRNode): device: torch.device @@ -6280,6 +6435,8 @@ def get_inputs_that_alias_output(self): ] +# We just use a normal dataclass for MutableBox/TensorBox/StorageBox since +# they're mainly lowering-time constructs that we expect to mutate and such. @dataclasses.dataclass class MutableBox(IRNode): """ @@ -6436,7 +6593,7 @@ def num_reads(self): return self.data.num_reads() -@dataclasses.dataclass +@ir_dataclass(frozen=False) class Subgraph(IRNode): name: str graph_module: torch.fx.GraphModule @@ -6452,7 +6609,86 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool: return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers) -@dataclasses.dataclass +@ir_dataclass(frozen=False) +class InvokeSubgraph(ExternKernel): + subgraph: Optional[Subgraph] = None + operands: Optional[List[TensorBox]] = None + outputs: Optional[List[MultiOutput]] = None + + def __init__( + self, subgraph: Subgraph, operands: List[TensorBox], layout: MultiOutputLayout + ): + super().__init__( + name=None, + layout=layout, + inputs=operands, + ) + self.subgraph = subgraph + self.name = V.graph.register_buffer(self) + V.graph.register_operation(self) + + @classmethod + def create(cls, subgraph: Subgraph, operands): + # TODO(anijain2305) - Support sym expr as operands in future. + fx_operands = V.graph.current_node.args[-1] + fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr] + + # Realize the inputs. Also intermediates can have different strides than + # the inputs of the subgraph. So, force the intermediates to have same + # strides as that of subgraph inputs. + operands = [cls.realize_input(x) for x in operands] + + def handle_sym_expr(stride): + return [s.node.expr if isinstance(s, torch.SymInt) else s for s in stride] + + fake_strides = [fake_operand.stride() for fake_operand in fake_operands] + fake_strides = [handle_sym_expr(stride) for stride in fake_strides] + operands = [ + cls.require_exact_strides(x, fake_strides[idx]) + for idx, x in enumerate(operands) + ] + + if subgraph.graph is None: + # create and lower subgraphs + subgraph.graph = V.graph.make_subgraph( + gm=subgraph.graph_module, + example_inputs=fake_operands, + subgraph_name=subgraph.name, + ) + with V.set_graph_handler(subgraph.graph): + subgraph.graph.run(*fake_operands) + + outputs = subgraph.graph.graph_outputs # type: ignore[union-attr] + device = operands[0].get_device() + invoke_subgraph = InvokeSubgraph( + subgraph=subgraph, + operands=operands, + layout=MultiOutputLayout(device=device), + ) + + outputs = [ + MultiOutput( + FixedLayout( + device=output.get_device(), + dtype=output.get_dtype(), + size=output.get_size(), + stride=output.get_stride(), + offset=output.get_layout().offset, + ), + invoke_subgraph, + [(list, i)], + ) + for i, output in enumerate(outputs) + ] + + invoke_subgraph.outputs = outputs + return outputs + + def codegen(self, wrapper): + wrapper.codegen_invoke_subgraph(self) + + +@ir_dataclass(frozen=False) class Conditional(ExternKernel): predicate: Optional[IRNode] = None operands: Optional[List[TensorBox]] = None @@ -6546,7 +6782,7 @@ def create( operands=operands, true_subgraph=true_fn, false_subgraph=false_fn, - layout=MultiOutputLayout(device), + layout=MultiOutputLayout(device=device), ) outputs = [ @@ -6573,7 +6809,7 @@ def codegen(self, wrapper): wrapper.codegen_conditional(self) -@dataclasses.dataclass +@ir_dataclass(frozen=False) class WhileLoop(ExternKernel): carried_inputs: Optional[List[TensorBox]] = None additional_inputs: Optional[List[TensorBox]] = None @@ -6666,7 +6902,7 @@ def create( cond_subgraph=cond_fn, body_subgraph=body_fn, # asserted above that there is at least one operand - layout=MultiOutputLayout(device), + layout=MultiOutputLayout(device=device), ) outputs = [ @@ -6744,7 +6980,7 @@ def has_side_effects(self): return True -@dataclasses.dataclass +@ir_dataclass class TorchBindObject(IRNode): name: str value: torch._C.ScriptObject @@ -6808,7 +7044,7 @@ def create_inplace( device = tensor_args[0].get_device() packed = cls( - NoneLayout(device), + NoneLayout(device=device), kernel, tensor_args, non_tensor_args, @@ -6817,14 +7053,14 @@ def create_inplace( inps = pytree.tree_leaves(inputs) packed.mutation_outputs.extend( - [MutationOutput(NoneLayout(device), buf, packed) for buf in inps] + [MutationOutput(NoneLayout(device=device), buf, packed) for buf in inps] ) # For inplace collective ops, the input is guaranteed to be alias of the returned value of op. packed.alias_names.extend([inp.get_name() for inp in inps]) if "out" in kwargs: packed.mutation_outputs.append( - MutationOutput(NoneLayout(device), kwargs["out"], packed) + MutationOutput(NoneLayout(device=device), kwargs["out"], packed) ) # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op. packed.alias_names.append(kwargs["out"].get_name()) @@ -6870,7 +7106,7 @@ def create_out_of_place( if isinstance(example_output, list): device = cls.find_device(tensor_args, example_output) packed = cls( - MultiOutputLayout(device), + MultiOutputLayout(device=device), kernel, tensor_args, non_tensor_args, @@ -6931,14 +7167,14 @@ def create_wait(cls, kernel, inp: TensorBox) -> None: ) = cls.process_kernel(kernel, inp) assert not unbacked_bindings, f"{kernel} {unbacked_bindings}" packed = cls( - NoneLayout(inp.get_device()), + NoneLayout(device=inp.get_device()), kernel, tensor_args, non_tensor_args, unflatten_args, ) packed.mutation_outputs.append( - MutationOutput(NoneLayout(inp.get_device()), inp, packed) + MutationOutput(NoneLayout(device=inp.get_device()), inp, packed) ) def get_read_writes(self): diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index 71e3a21b005ef..94e9b86ea9405 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -6,6 +6,7 @@ from typing import cast, List, Optional, Sequence, Tuple, TYPE_CHECKING, TypedDict import torch +from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate from .. import config, ir from ..lowering import ( @@ -25,6 +26,7 @@ is_zeros, pad_listlike, sympy_product, + use_ck_conv_template, use_triton_template, ) from ..virtualized import V @@ -659,7 +661,17 @@ def channels_last_conv(): num_warps=cfg.num_warps, **cfg.kwargs, ) - + if use_ck_conv_template(layout): + CKGroupedConvFwdTemplate.add_ck_conv_choices( + choices, + layout, + input_nodes=(x, weight) + ((bias,) if bias is not None else tuple()), + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + n_spatial_dimensions=ndim, + ) return autotune_select_algorithm("convolution", choices, args, layout) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index c28f90dbf2620..2e8935cfed6b8 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -17,11 +17,10 @@ ExternKernel, FixedLayout, FlexibleLayout, - get_stride_order, + get_fill_order, InputBuffer, IRNode, StorageBox, - stride_order2fill_order, Subgraph, TensorBox, ) @@ -71,7 +70,7 @@ def create_placeholder( name: str, dtype: torch.dtype, device: torch.device ) -> TensorBox: """Creates a placeholder input buffers for producing subgraph_output.""" - input_buffer = InputBuffer(name, FixedLayout(device, dtype, [], [])) + input_buffer = InputBuffer(name=name, layout=FixedLayout(device, dtype, [], [])) return TensorBox.create(input_buffer) @@ -793,8 +792,7 @@ def flex_attention( # Construct output layout with strides matching the query. out_size = [B, Hq, seq_len_q, v_head_dim] - stride_order = get_stride_order(query.get_stride()) - fill_order = stride_order2fill_order(stride_order) + fill_order = get_fill_order(query.get_stride()) out_strides = construct_strides(out_size, fill_order) layout = FixedLayout( @@ -853,7 +851,7 @@ def flex_attention( # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function # We do need to explicitly pass it in for autotuning though. - + original_kernel_options = kernel_options.copy() for BLOCK_M, BLOCK_N, num_warps, num_stages in configs: if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0 or SPARSE_Q_BLOCK_SIZE % BLOCK_M != 0: continue @@ -861,12 +859,13 @@ def flex_attention( if num_stages == 2: continue + cur_kernel_options = original_kernel_options.copy() # Performance tuning - kernel_options.setdefault("BLOCK_M", BLOCK_M) - kernel_options.setdefault("BLOCK_N", BLOCK_N) + cur_kernel_options.setdefault("BLOCK_M", BLOCK_M) + cur_kernel_options.setdefault("BLOCK_N", BLOCK_N) # Blocksparse options - kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) - kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) flex_attention_template.maybe_append_choice( choices=choices, @@ -891,7 +890,7 @@ def flex_attention( num_stages=num_stages, num_warps=num_warps, call_sizes=query.get_size(), - **kernel_options, + **cur_kernel_options, ) inputs_for_autotuning = ( [ @@ -1784,7 +1783,7 @@ def flex_attention_backward(*args, **kwargs): if BLOCK2 % BLOCK1 == 0 ] ) - + original_kernel_options = kernel_options.copy() for BLOCK1, BLOCK2, num_warps, num_stages in configs: if ( SPARSE_KV_BLOCK_SIZE % BLOCK1 != 0 @@ -1795,13 +1794,14 @@ def flex_attention_backward(*args, **kwargs): continue # Performance tuning - kernel_options.setdefault("BLOCK_M1", BLOCK1) - kernel_options.setdefault("BLOCK_N1", BLOCK2) - kernel_options.setdefault("BLOCK_M2", BLOCK2) - kernel_options.setdefault("BLOCK_N2", BLOCK1) + cur_kernel_options = original_kernel_options.copy() + cur_kernel_options.setdefault("BLOCK_M1", BLOCK1) + cur_kernel_options.setdefault("BLOCK_N1", BLOCK2) + cur_kernel_options.setdefault("BLOCK_M2", BLOCK2) + cur_kernel_options.setdefault("BLOCK_N2", BLOCK1) # Blocksparse options - kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) - kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_Q_BLOCK_SIZE", SPARSE_Q_BLOCK_SIZE) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) flex_attention_backward_template.maybe_append_choice( choices=choices, @@ -1829,7 +1829,7 @@ def flex_attention_backward(*args, **kwargs): call_sizes=query.get_size() + key.get_size()[1:3], num_stages=num_stages, num_warps=num_warps, - **kernel_options, + **cur_kernel_options, ) inputs_for_autotuning = ( [ diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index 291a78eeb4c27..7b1c3466b1290 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -479,6 +479,7 @@ def create_flex_decoding_kernel(*args, **kwargs): # Mark SPARSE_KV_BLOCK_SIZE as static shapes and add guards. SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) + original_kernel_options = kernel_options.copy() # Note, we don't need to pass in the captured buffers explicitly # because they're implicitly added by the score_mod function # We do need to explicitly pass it in for autotuning though. @@ -486,9 +487,10 @@ def create_flex_decoding_kernel(*args, **kwargs): if SPARSE_KV_BLOCK_SIZE % BLOCK_N != 0: continue + cur_kernel_options = original_kernel_options.copy() # Performance tuning - kernel_options.setdefault("BLOCK_N", BLOCK_N) - kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) + cur_kernel_options.setdefault("BLOCK_N", BLOCK_N) + cur_kernel_options.setdefault("SPARSE_KV_BLOCK_SIZE", SPARSE_KV_BLOCK_SIZE) # Work around https://github.com/pytorch/pytorch/issues/129625 if num_stages == 2: @@ -515,7 +517,7 @@ def create_flex_decoding_kernel(*args, **kwargs): num_stages=num_stages, num_warps=num_warps, call_sizes=query.get_size(), - **kernel_options, + **cur_kernel_options, ) inputs_for_flex_decoding = ( diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 6be50b9b6004c..d7aed0214e951 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -32,7 +32,7 @@ from ..utils import ( get_gpu_shared_memory, use_aten_gemm_kernels, - use_ck_template, + use_ck_gemm_template, use_cpp_packed_gemm_template, use_cutlass_template, use_max_autotune, @@ -204,7 +204,7 @@ def tuned_mm(mat1, mat2, *, layout=None): if static_shape and is_nonzero and use_cutlass_template(layout, m, n, k): CUTLASS3xGemmTemplate.add_cutlass_gemm_choices(choices, layout, [mat1, mat2]) - if is_nonzero and use_ck_template(layout, m, n, k): + if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices(choices, layout, [mat1, mat2]) if use_cpp_packed_gemm_template(layout, mat1, mat2): @@ -411,7 +411,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): beta=beta, ) - if is_nonzero and use_ck_template(layout, m, n, k): + if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices( choices, layout, diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py index 98dc7d007f948..32f2362172652 100644 --- a/torch/_inductor/kernel/mm_scaled.py +++ b/torch/_inductor/kernel/mm_scaled.py @@ -16,7 +16,7 @@ realize_inputs, TritonTemplate, ) -from ..utils import use_aten_gemm_kernels, use_ck_template, use_triton_template +from ..utils import use_aten_gemm_kernels, use_ck_gemm_template, use_triton_template from .mm_common import _is_static_problem, mm_args, mm_grid, scaled_mm_configs @@ -294,7 +294,7 @@ def tuned_scaled_mm( **kwargs, ) - if is_nonzero and use_ck_template(layout, m, n, k): + if is_nonzero and use_ck_gemm_template(layout, m, n, k): CKGemmTemplate.add_ck_gemm_choices(choices, layout, input_nodes) if ( diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index da8b8465a2cfc..0c30ec780b52f 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations +import collections import functools import itertools import re @@ -100,6 +101,7 @@ class LoopBody: indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] root_block: LoopBodyBlock memory_usage: Dict[MemoryUsageType, List[MemoryEntry]] + op_counts: collections.Counter[str] def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): super().__init__() @@ -130,6 +132,7 @@ def _init_with_tracing(self, fn, args): self.indirect_vars = [] self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} self.memory_usage = {t: [] for t in MemoryUsageType} + self.op_counts = collections.Counter() self.root_block = LoopBodyBlock(self, fn, args) # traces del self.indexing_exprs_name # not used after _init_with_tracing @@ -148,6 +151,7 @@ def _init_with_copy(self, other: LoopBody, args): self.indirect_vars = other.indirect_vars self.indirect_var_ranges = other.indirect_var_ranges self.memory_usage = other.memory_usage + self.op_counts = other.op_counts self.root_block = other.root_block.clone(self) submodules = {**other.submodules} @@ -157,6 +161,9 @@ def _init_with_copy(self, other: LoopBody, args): **{k: v.clone(self) for k, v in submodules.items()}, # type: ignore[attr-defined] } + def has_op(self, name: str): + return self.op_counts.get(name, 0) > 0 + def merge_loops(self) -> LoopBody: """ Merge both iteration and reduction loops and return a new LoopBody. @@ -609,8 +616,9 @@ def output(result): from .index_propagation import IndexPropagation from .sizevars import SimplifyIndexing - handler: Any = SimplifyIndexing( - CaptureIndexing(proxy_ops), self.body.var_ranges + handler: Any = CountOps( + SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges), + body.op_counts, ) if config.constant_and_index_propagation: handler = IndexPropagation( @@ -649,3 +657,13 @@ def clone(self, body: LoopBody): copy = LoopBodyBlock.__new__(LoopBodyBlock) copy.__dict__.update({**self.__dict__, "body": body}) return copy + + +class CountOps: + def __init__(self, inner: Any, counts: collections.Counter[str]): + self._inner = inner + self._counts = counts + + def __getattr__(self, name): + self._counts[name] += 1 + return getattr(self._inner, name) diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index ba524bb18091c..6062cbcbb9fd5 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -1,5 +1,6 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs +import dataclasses import functools import itertools import logging @@ -296,7 +297,7 @@ def promote(arg): if isinstance(arg, TensorBox): return to_dtype(arg, dtype) elif isinstance(arg, ir.Constant): - return ir.Constant(arg.value, dtype, device) + return ir.Constant(value=arg.value, dtype=dtype, device=device) else: return arg @@ -439,9 +440,13 @@ def broadcast_symbolic_shapes(a, b): for x, y in itertools.zip_longest( reversed(a), reversed(b), fillvalue=sympy.Integer(1) ): - if y == 1: + if V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(y, 1), size_oblivious=True + ): output.append(x) - elif x == 1: + elif V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(x, 1), size_oblivious=True + ): output.append(y) else: V.graph.sizevars.guard_equals(x, y) @@ -469,9 +474,11 @@ def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=No def const_func(x): if isinstance(x, sympy.Basic): - return ir.IndexingConstant(x, dtype, decode_device(None)) + return ir.IndexingConstant( + index=x, dtype=dtype, device=decode_device(None) + ) else: - return ir.Constant(x, dtype, decode_device(None)) + return ir.Constant(value=x, dtype=dtype, device=decode_device(None)) return [const_func(x) for x in inputs] ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant))) @@ -480,13 +487,16 @@ def const_func(x): if isinstance(x, (int, float)): out.append( ExpandView.create( - ir.Constant(x, ex.get_dtype(), ex.get_device()), list(ex.get_size()) + ir.Constant(value=x, dtype=ex.get_dtype(), device=ex.get_device()), + list(ex.get_size()), ) ) elif isinstance(x, sympy.Basic): out.append( ExpandView.create( - IndexingConstant(x, ex.get_dtype(), ex.get_device()), + IndexingConstant( + index=x, dtype=ex.get_dtype(), device=ex.get_device() + ), list(ex.get_size()), ) ) @@ -858,7 +868,25 @@ def broadcast_tensors(*inputs): for x in inputs: sizes = x.get_size() if len(sizes) != len(target) or any( - ((a == 1 and b != 1) or (a != 1 and b == 1)) for a, b in zip(sizes, target) + ( + ( + V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + or ( + not V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(a, 1), size_oblivious=True + ) + and V.graph.sizevars.shape_env.evaluate_expr( + sympy.Eq(b, 1), size_oblivious=True + ) + ) + ) + for a, b in zip(sizes, target) ): x = expand(x, target) outputs.append(x) @@ -1094,7 +1122,7 @@ def as_strided(x, size, stride, storage_offset=None): [sympy.expand(s) for s in stride], sympy.expand(storage_offset or 0), ) - return TensorBox(ir.ReinterpretView(storage, new_layout)) + return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout)) @register_lowering(aten.as_strided_, type_promotion_kind=None) @@ -2573,7 +2601,7 @@ def clone_preserve_reinterpret_view(x): if reinterpret_view_layouts: x = x.data # unwrap TensorBox for layout in reinterpret_view_layouts[::-1]: - x = ir.ReinterpretView(x, layout) + x = ir.ReinterpretView(data=x, layout=layout) x = TensorBox(x) return x @@ -2991,7 +3019,7 @@ def empty_strided( pointwise.realize() buffer = pointwise.data.data # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode - buffer.data.ranges = [0] * len(size) + buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size)) assert isinstance(buffer, ir.ComputedBuffer) size = [sympy.expand(s) for s in size] stride = ( @@ -3134,6 +3162,7 @@ def index_output_size_and_inner_fn( indexed_size, x_loader, check, + wrap_neg=True, ): # Note that behavior of indexing differs when there are non consecutive # tensors. In this case, the tensor index is pulled to the beginning. @@ -3185,6 +3214,7 @@ def fn(idx): loader(idx[start_offset : start_offset + rank]), size, check=check, + wrap_neg=wrap_neg, ) ) new_index = [ @@ -3207,7 +3237,7 @@ def index_impl(x, indices, check): ) -def index_impl_helper(x, indices, check): +def index_impl_helper(x, indices, check, wrap_neg=True): assert isinstance(indices, (list, tuple)) x_loader = x.make_loader() indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device()) @@ -3235,6 +3265,7 @@ def index_impl_helper(x, indices, check): indexed_size, None, check=check, + wrap_neg=wrap_neg, ) def inner_fn(idx): @@ -3391,9 +3422,9 @@ def index_put_impl_(self, indices, values, accumulate, check): scatter_mode="atomic_add" if accumulate else None, ) buffer = ir.ComputedBuffer( - None, - ir.MutationLayoutSHOULDREMOVE(self), - scatter, + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, ) buffer.name = V.graph.register_buffer(buffer) V.graph.register_operation(buffer) @@ -3414,7 +3445,9 @@ def index_put_impl_(self, indices, values, accumulate, check): @register_lowering(aten._unsafe_masked_index, type_promotion_kind=None) def _unsafe_masked_index(self, mask, indices, fill): - ranges, _, _unsafe_index_fn = index_impl_helper(self, indices, check=False) + ranges, _, _unsafe_index_fn = index_impl_helper( + self, indices, check=False, wrap_neg=False + ) mask_loader = mask.make_loader() self_loader = self.make_loader() @@ -3612,9 +3645,9 @@ def backend_reduce_str(reduce): scatter_mode=None, ) buffer = ir.ComputedBuffer( - None, - ir.MutationLayoutSHOULDREMOVE(self), - zero_out, + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=zero_out, ) buffer.name = V.graph.register_buffer(buffer) V.graph.register_operation(buffer) @@ -3631,9 +3664,9 @@ def backend_reduce_str(reduce): scatter_mode=backend_reduce_str(reduce), ) buffer = ir.ComputedBuffer( - None, - ir.MutationLayoutSHOULDREMOVE(self), - scatter, + name=None, + layout=ir.MutationLayoutSHOULDREMOVE(self), + data=scatter, ) buffer.name = V.graph.register_buffer(buffer) V.graph.register_operation(buffer) @@ -5276,7 +5309,7 @@ def mean(x, axis=None, keepdim=False, *, dtype=None): x = to_dtype(x, torch.float) sum_result = sum_(x, axis, keepdim) denom = sympy_product(size[i] for i in axis) - denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) denom = ExpandView.create(denom, list(sum_result.get_size())) return to_dtype(div(sum_result, denom), output_dtype) @@ -5297,7 +5330,7 @@ def var_mean_sum_(x, axis, correction, keepdim, return_mean): denom = sympy_product(size[i] for i in axis) if correction: denom = sympy.Max(denom - correction, 0) - denom = ir.IndexingConstant(denom, x.get_dtype(), x.get_device()) + denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device()) denom = ExpandView.create(denom, list(sum_result.get_size())) x_var = div(sum_result, denom) if not return_mean: @@ -6065,14 +6098,17 @@ def make_triton_fallback(op): ) register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True) foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul) +register_foreach_pointwise(aten._foreach_mul.Tensor, mul) foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul) register_foreach_pointwise(aten._foreach_sub.List, sub) register_foreach_pointwise(aten._foreach_sub.Scalar, sub) register_foreach_pointwise(aten._foreach_neg.default, neg) register_foreach_pointwise(aten._foreach_abs.default, abs) register_foreach_pointwise(aten._foreach_pow.Scalar, pow) +register_foreach_pointwise(aten._foreach_pow.List, pow) register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow) foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div) +register_foreach_pointwise(aten._foreach_div.Tensor, div) foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div) register_foreach_pointwise(aten._foreach_sqrt, sqrt) register_foreach_pointwise(aten._foreach_maximum.List, maximum) @@ -6322,13 +6358,21 @@ def inner_fn(idx): @register_lowering(triton_kernel_wrapper_mutation) -def triton_kernel_wrap_(*, kernel_idx, constant_args_idx, grid, kwargs): +def triton_kernel_wrap_( + *, + kernel_idx, + constant_args_idx, + grid, + tma_descriptor_metadata, + kwargs, +): from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table constant_args = kernel_side_table.get_constant_args(constant_args_idx) ir.UserDefinedTritonKernel( kernel_idx=kernel_idx, grid=grid, + tma_descriptor_metadata=tma_descriptor_metadata, kernel_args={**kwargs, **constant_args}, ) return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)} @@ -6358,6 +6402,12 @@ def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs): return list(map(TensorBox.create, result)) +@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None) +def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, operands): + result = ir.InvokeSubgraph.create(subgraph_fn, operands) + return list(map(TensorBox.create, result)) + + @register_lowering(associative_scan_op, type_promotion_kind=None) def associative_scan(combine_fn: ir.Subgraph, xs, dim: int): from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph @@ -6436,6 +6486,7 @@ def _all_reduce(inp, reduce_op, group_name): @register_lowering(_c10d_functional.all_reduce_) def _all_reduce_(inp, reduce_op, group_name): + inp = ir.ExternKernel.require_contiguous(inp) ir._CollectiveKernel.create_inplace( _c10d_functional.all_reduce_.default, inp, reduce_op, group_name ) diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index e0849791b3485..ce260177c2d10 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -1,9 +1,10 @@ from __future__ import annotations +import collections import dataclasses import heapq import logging -from typing import Callable, Dict, List, Set, Tuple, TYPE_CHECKING, Union +from typing import Callable, Dict, List, Set, Tuple, TYPE_CHECKING, TypedDict, Union from torch._utils_internal import signpost_event from torch.utils._ordered_set import OrderedSet @@ -28,48 +29,35 @@ class MemoryPlanningInfoForBuffer: succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( default_factory=OrderedSet ) - outdegree: int = 0 # this is used only in topological_sort_lpmf @dataclasses.dataclass class MemoryPlanningInfoForNode: - pred_buffers: List[Union[SchedulerBuffer, FreeableInputBuffer]] = dataclasses.field( - default_factory=list - ) - pred_nodes: List[BaseSchedulerNode] = dataclasses.field(default_factory=list) - succ_nodes: List[BaseSchedulerNode] = dataclasses.field(default_factory=list) - indegree: int = 0 index: int = 0 size: int = 0 - memory_to_free: int = 0 # this is used only in topological_sort_lpmf - size_with_reads: int = 0 # this is used only in topological_sort_dfs + pred_buffers: OrderedSet[ + Union[SchedulerBuffer, FreeableInputBuffer] + ] = dataclasses.field(default_factory=OrderedSet) + pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) + succ_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( + default_factory=OrderedSet + ) @dataclasses.dataclass class FreeableInputBuffer: - dep: Dep + name: str mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field( default_factory=MemoryPlanningInfoForBuffer ) def get_name(self) -> str: - return self.dep.name + return self.name def __hash__(self) -> int: - return hash(self.dep.name) - - -def dep_size_hint(dep: Dep) -> int: - res = 0 - try: - if not dep.has_unbacked_symbols(): - res = dep.numbytes_hint() - except KeyError: - # In at least one test (test/inductor/test_torchbind.py) we - # create a StarDep that doesn't exist in the graph and calling - # `has_unbacked_symbols()` throws an error. - pass - return res + return hash(self.name) def get_freeable_input_buf( @@ -78,22 +66,52 @@ def get_freeable_input_buf( ) -> Dict[str, FreeableInputBuffer]: """ Create and keep track of all input buffers that can be freed during the program + + Returns: + A dictionary containing all freeble input buffers, keyed by their names. """ - name_to_input_buf: Dict[str, FreeableInputBuffer] = {} + + # this function is copied from torch/_inductor/scheduler.py + # TODO: would be nice to remove the try/except block for both places + def _dep_size_hint(dep: Dep) -> int: + res = 0 + try: + if not dep.has_unbacked_symbols(): + res = dep.numbytes_hint() + except KeyError: + # In at least one test (test/inductor/test_torchbind.py) we + # create a StarDep that doesn't exist in the graph and calling + # `has_unbacked_symbols()` throws an error. + pass + return res + + # get freeable input buffers' successor nodes and their sizes + # note that different deps can have the same name, so we use name as keys + dep_name_to_succ_nodes: Dict[ + str, OrderedSet[BaseSchedulerNode] + ] = collections.defaultdict(OrderedSet) + dep_name_to_size: Dict[str, int] = dict() for node in nodes: for dep in node.read_writes.reads: - if ( - dep.name in graph_inputs - and not dep.name.startswith("primals_") - and dep.name not in name_to_input_buf - ): - name_to_input_buf[dep.name] = FreeableInputBuffer(dep) - name_to_input_buf[dep.name].mpi_buffer.size_free = dep_size_hint(dep) - - return name_to_input_buf + if dep.name in graph_inputs and not dep.name.startswith("primals_"): + dep_name_to_succ_nodes[dep.name].add(node) + dep_name_to_size[dep.name] = _dep_size_hint(dep) + + # create FreeableInputBuffer objects and add them to the returned dictionary + name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = dict() + for dep_name, succ_nodes in dep_name_to_succ_nodes.items(): + name_to_freeable_input_buf[dep_name] = FreeableInputBuffer( + dep_name, + MemoryPlanningInfoForBuffer( + size_free=dep_name_to_size[dep_name], succ_nodes=succ_nodes + ), + ) + return name_to_freeable_input_buf -def compute_size_for_scheduler_buffer(name_to_buf: Dict[str, SchedulerBuffer]) -> None: +def compute_size_for_scheduler_buffer( + name_to_buf: Dict[str, SchedulerBuffer] +) -> Dict[str, Tuple[int, int]]: """ Compute the size of each scheduler buffer, including (1) memory allocated when it is created and (2) memory deallocated when it is freed. @@ -107,63 +125,127 @@ def compute_size_for_scheduler_buffer(name_to_buf: Dict[str, SchedulerBuffer]) - buf0: at creation, 30 bytes allocated, when deleted, 0 bytes freed buf1: at creation, 0 bytes allocated, when deleted, 10 bytes freed buf2: at creation, 0 bytes allocated, when deleted, 20 bytes freed + + Returns: + A dictionary mapping a scheduler buffer to a tuple of (size_alloc, size_free). """ - from .scheduler import BaseSchedulerNode, OutputNode + from .ir import MultiOutput + from .scheduler import OutputNode - # compute the size of SchedulerBuffer without MultiOutputLayout layout - for sched_buf in name_to_buf.values(): - if not isinstance(sched_buf.node.layout, MultiOutputLayout): - sched_buf.mpi_buffer = MemoryPlanningInfoForBuffer() - sched_buf.mpi_buffer.size_alloc = V.graph.sizevars.size_hint( - sched_buf.node.get_numel(), fallback=0 - ) * get_dtype_size(sched_buf.node.get_dtype()) - sched_buf.mpi_buffer.size_free = sched_buf.mpi_buffer.size_alloc + sched_buf_to_size: Dict[str, Tuple[int, int]] = dict() - # compute the size of SchedulerBuffer with MultiOutputLayout layout - for sched_buf in name_to_buf.values(): + def _compute_and_update_buf_size( + sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False + ) -> int: if isinstance(sched_buf.node.layout, MultiOutputLayout): - sched_buf.mpi_buffer = MemoryPlanningInfoForBuffer() + size_alloc = 0 for user in sched_buf.users: if isinstance(user.node, OutputNode): continue - assert isinstance(user.node, BaseSchedulerNode) for buf in user.node.get_outputs(): - sched_buf.mpi_buffer.size_alloc += buf.mpi_buffer.size_alloc - buf.mpi_buffer.size_alloc = 0 + if isinstance(buf.node, MultiOutput): + size_alloc += _compute_and_update_buf_size(buf, True) + sched_buf_to_size[sched_buf.get_name()] = ( + 0 if user_of_MultiOutputLayout else size_alloc, + 0, + ) + return size_alloc + else: + buf_size = V.graph.sizevars.size_hint( + sched_buf.node.get_numel(), fallback=0 + ) * get_dtype_size(sched_buf.node.get_dtype()) + sched_buf_to_size[sched_buf.get_name()] = ( + 0 if user_of_MultiOutputLayout else buf_size, + buf_size, + ) + return buf_size + for sched_buf in name_to_buf.values(): + # skip if sched_buf is already processed as an user of another SchedulerBuffer + # whose layout is of the type MultiOutputLayout + if sched_buf.get_name() not in sched_buf_to_size: + _compute_and_update_buf_size(sched_buf) + + return sched_buf_to_size -def map_successor_nodes_with_predecessor_buffers( + +def assign_memory_planning_info_for_scheduler_buffers( nodes: List[BaseSchedulerNode], - name_to_input_buf: Dict[str, FreeableInputBuffer], name_to_buf: Dict[str, SchedulerBuffer], ) -> None: """ - For scheduling and memory estimation, for each scheduler node, we maintain - a list of its dependency buffers (SchedulerBuffer and FreeableInputBuffer). - This is similar to node.read_writes.reads, which is a list of Dep. - Reversely, for each SchedulerBuffer / FreeableInputBuffer, assign its successor nodes. + For each SchedulerBuffer, assign its size info and successor nodes. A buffer's successor nodes determines when a buffer can be freed. """ + # get buffer sizes + sched_buf_to_size = compute_size_for_scheduler_buffer(name_to_buf) + + # get buffer's successor nodes + # note that different deps can have the same name, so we use name as keys + dep_name_to_succ_nodes: Dict[ + str, OrderedSet[BaseSchedulerNode] + ] = collections.defaultdict(OrderedSet) for node in nodes: - node.mpi_node = MemoryPlanningInfoForNode() - node.mpi_node.pred_buffers = [] - for dep_name in {dep.name for dep in node.unmet_dependencies}: - sched_buf = name_to_buf.get(dep_name) - if sched_buf: - node.mpi_node.pred_buffers.append(sched_buf) - sched_buf.mpi_buffer.succ_nodes.add(node) - for dep_name in { - dep.name for dep in node.read_writes.reads - node.unmet_dependencies - }: - input_buf = name_to_input_buf.get(dep_name) - if input_buf: - node.mpi_node.pred_buffers.append(input_buf) - input_buf.mpi_buffer.succ_nodes.add(node) + for dep in node.unmet_dependencies: + dep_name_to_succ_nodes[dep.name].add(node) + + # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer + # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) + for buf_name in name_to_buf.keys(): + name_to_buf[buf_name].mpi_buffer = MemoryPlanningInfoForBuffer( + size_alloc=sched_buf_to_size[buf_name][0], + size_free=sched_buf_to_size[buf_name][1], + succ_nodes=dep_name_to_succ_nodes[buf_name], + ) + + +def assign_memory_planning_info_for_scheduler_nodes( + nodes: List[BaseSchedulerNode], + name_to_fused_node: Dict[str, BaseSchedulerNode], + name_to_buf: Dict[str, SchedulerBuffer], + name_to_freeable_input_buf: Dict[str, FreeableInputBuffer], +) -> None: + """ + Assign to each scheduler node its predecessor and successor nodes. + """ + from .scheduler import SchedulerBuffer + + for index, node in enumerate(nodes): + size_alloc = sum(buffer.mpi_buffer.size_alloc for buffer in node.get_outputs()) + pred_buffers: OrderedSet[ + Union[SchedulerBuffer, FreeableInputBuffer] + ] = OrderedSet() + for dep in node.read_writes.reads: + if dep.name in name_to_buf and dep in node.unmet_dependencies: + pred_buffers.add(name_to_buf[dep.name]) + elif dep.name in name_to_freeable_input_buf: + pred_buffers.add(name_to_freeable_input_buf[dep.name]) + pred_nodes = OrderedSet( + { + name_to_fused_node[pred_buffer.defining_op.get_name()] + for pred_buffer in pred_buffers + if (isinstance(pred_buffer, SchedulerBuffer)) + } + ) + succ_nodes = OrderedSet( + { + succ_node + for buffer in node.get_outputs() + for succ_node in buffer.mpi_buffer.succ_nodes + } + ) + node.mpi_node = MemoryPlanningInfoForNode( + index=index, + size=size_alloc, + pred_buffers=pred_buffers, + pred_nodes=pred_nodes, + succ_nodes=succ_nodes, + ) def estimate_peak_memory( nodes: List[BaseSchedulerNode], - name_to_input_buf: Dict[str, FreeableInputBuffer], + name_to_freeable_input_buf: Dict[str, FreeableInputBuffer], graph_outputs: Set[str], ) -> Tuple[int, List[int]]: """ @@ -172,63 +254,79 @@ def estimate_peak_memory( Returns: int: peak memory - List[int]: memory usage at each node. + List[int]: memory usage at each node (or each step). """ - # map each scheduler buffer to its size, start time, and end time + # map each scheduler buffer to its size, start step, and end step @dataclasses.dataclass class BufferInfo: buffer: Union[SchedulerBuffer, FreeableInputBuffer] size_alloc: int size_free: int - start_time: int - end_time: int - - name_to_buf_info: Dict[str, BufferInfo] = {} - node_name_to_time: Dict[str, int] = {} - - # assign start_time - for buf_name, input_buf in name_to_input_buf.items(): - name_to_buf_info[buf_name] = BufferInfo( - input_buf, - input_buf.mpi_buffer.size_free, - input_buf.mpi_buffer.size_free, - 0, - 0, + start_step: int + end_step: int + + # get the execution step of each node, this will be used to determine + # the end_step of buffers + node_to_step: Dict[BaseSchedulerNode, int] = dict() + for step, node in enumerate(nodes): + node_to_step[node] = step + + # get buffers' size and liveliness information + buf_info_list: List[BufferInfo] = [] + # 1. for freeable input buffers + for buf_name, input_buf in name_to_freeable_input_buf.items(): + end_step = ( + len(nodes) - 1 + if buf_name in graph_outputs + else max( + node_to_step[succ_node] for succ_node in input_buf.mpi_buffer.succ_nodes + ) ) - for t, node in enumerate(nodes): - node_name_to_time[node.get_name()] = t - for sched_buf in node.get_outputs(): - name_to_buf_info[sched_buf.get_name()] = BufferInfo( - sched_buf, - sched_buf.mpi_buffer.size_alloc, - sched_buf.mpi_buffer.size_free, - t, - t, + buf_info_list.append( + BufferInfo( + input_buf, + input_buf.mpi_buffer.size_free, + input_buf.mpi_buffer.size_free, + 0, + end_step, ) + ) - # assign end_time - for buf_name, buf_info in name_to_buf_info.items(): - succ_node_time = [ - node_name_to_time[succ_node.get_name()] - for succ_node in buf_info.buffer.mpi_buffer.succ_nodes - if succ_node.get_name() in node_name_to_time - ] - if succ_node_time: - buf_info.end_time = max(succ_node_time) - - # the end time of output buffers should be at the end of the horizon - for buf_name in graph_outputs: - if buf_name in name_to_buf_info: - name_to_buf_info[buf_name].end_time = len(nodes) - 1 + # 2. for scheduler buffers + for step, node in enumerate(nodes): + for sched_buf in node.get_outputs(): + # note: it is possible for a non-graph-output sched_buf to have no succ_nodes and + # to be only used by its defining op (e.g., due to fusion when all consumers of + # the buffer are fused with its defining op). In such cases, end_step is step. + end_step = ( + len(nodes) - 1 + if sched_buf.get_name() in graph_outputs + else max( + [ + node_to_step[succ_node] + for succ_node in sched_buf.mpi_buffer.succ_nodes + ], + default=step, + ) + ) + buf_info_list.append( + BufferInfo( + sched_buf, + sched_buf.mpi_buffer.size_alloc, + sched_buf.mpi_buffer.size_free, + step, + end_step, + ) + ) - # incremental memory changes at each time period + # incremental memory changes at each step memory = [0 for _ in range(len(nodes) + 1)] # for each buffer, update memory when created and when freed - for buf_name, buf_info in name_to_buf_info.items(): - memory[buf_info.start_time] += buf_info.size_alloc - memory[buf_info.end_time + 1] -= buf_info.size_free + for buf_info in buf_info_list: + memory[buf_info.start_step] += buf_info.size_alloc + memory[buf_info.end_step + 1] -= buf_info.size_free # get peak memory by compute the cumulative memories max_memory = 0 @@ -242,42 +340,19 @@ class BufferInfo: return (max_memory, memories_at_nodes) -def assign_predcessor_and_successor_nodes_to_nodes( - nodes: List[BaseSchedulerNode], name_to_fused_node: Dict[str, BaseSchedulerNode] -) -> None: - """ - Assign to each scheduler node its predecessor and successor nodes. - """ - from .scheduler import SchedulerBuffer - - for node in nodes: - node.mpi_node.pred_nodes = list( - { - name_to_fused_node[pred_buffer.defining_op.get_name()] - for pred_buffer in node.mpi_node.pred_buffers - if ( - isinstance(pred_buffer, SchedulerBuffer) - and pred_buffer.defining_op.get_name() in name_to_fused_node - ) - } - ) - node.mpi_node.succ_nodes = list( - { - succ_node - for buffer in node.get_outputs() - for succ_node in buffer.mpi_buffer.succ_nodes - } - ) - - def topological_sort_lpmf( nodes: List[BaseSchedulerNode], - name_to_input_buf: Dict[str, FreeableInputBuffer], + name_to_freeable_input_buf: Dict[str, FreeableInputBuffer], name_to_buf: Dict[str, SchedulerBuffer], graph_outputs: Set[str], ) -> List[BaseSchedulerNode]: """ A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First". + + The idea is from this paper: + Buffer memory optimization for video codec application modeled in Simulink + https://www.cs.york.ac.uk/rts/docs/DAC-1964-2006/PAPERS/2006/DAC06/PDFFILES/P0689.PDF + The algorithm maintain the max memory so far. At every iteration, for each scheduleable node, it computes: - how much memory needs to be allocated for the output buffers of this node; @@ -291,53 +366,61 @@ def topological_sort_lpmf( (ii) otherwise, pick the one with the lowest mem1 value. """ + class NodeInfo(TypedDict): + indegree: int + memory_to_free: int + + class BufferInfo(TypedDict): + outdegree: int + + node_info: Dict[BaseSchedulerNode, NodeInfo] = dict() + buf_info: Dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict() + # compute nodes' number of unmet dependencies (for schedulability) # initialize the list of nodes ready to be scheduled - nodes_to_schedule: Set[BaseSchedulerNode] = set() + nodes_to_schedule: OrderedSet[BaseSchedulerNode] = OrderedSet() for node in nodes: - # note that .unmet_dependencies could have deps with the same name - # and in that case, it should only be counted once - node.mpi_node.indegree = len(node.mpi_node.pred_nodes) - if node.mpi_node.indegree == 0: + node_info[node] = { + "indegree": len(node.mpi_node.pred_nodes), + "memory_to_free": 0, + } + if node_info[node]["indegree"] == 0: nodes_to_schedule.add(node) # compute buffers' number of unmet successors (used to decide when to free) - for buf in list(name_to_buf.values()) + list(name_to_input_buf.values()): - buf.mpi_buffer.outdegree = len(buf.mpi_buffer.succ_nodes) - if buf.get_name() in graph_outputs: - buf.mpi_buffer.outdegree += 1 + for buf in list(name_to_buf.values()) + list(name_to_freeable_input_buf.values()): + buf_info[buf] = { + "outdegree": len(buf.mpi_buffer.succ_nodes) + + (1 if buf.get_name() in graph_outputs else 0) + } # initialize memory estimations live_memory = sum( - input_buf.mpi_buffer.size_free for input_buf in name_to_input_buf.values() + input_buf.mpi_buffer.size_free + for input_buf in name_to_freeable_input_buf.values() ) # this is the total output memory, which is a lower bound for peak memory - output_memory = sum( - name_to_buf[buf_name].mpi_buffer.size_free - for buf_name in graph_outputs - if buf_name in name_to_buf - ) + # we do not include the memory of non freeable input buffers + output_memory = 0 + for buf_name in graph_outputs: + if buf_name in name_to_buf: + output_memory += name_to_buf[buf_name].mpi_buffer.size_free + elif buf_name in name_to_freeable_input_buf: + output_memory += name_to_freeable_input_buf[buf_name].mpi_buffer.size_free max_memory = max(live_memory, output_memory) # compute the amount of memory that is allocated when a node is scheduled # and the amount of memory that can be freed when a node is scheduled for i, node in enumerate(nodes): - node.mpi_node.index = i # keep track of the original order - node.mpi_node.size = sum( - buffer.mpi_buffer.size_alloc for buffer in node.get_outputs() - ) - node.mpi_node.memory_to_free = 0 # 1. if a buffer read by this node is last used by this node - # then the buffer can be freed for buf in node.mpi_node.pred_buffers: - if buf.mpi_buffer.outdegree == 1: - node.mpi_node.memory_to_free += buf.mpi_buffer.size_free - # 2. if a buffer written by this node is used internally and - # not needed afterwards, it can be freed + if buf_info[buf]["outdegree"] == 1: + node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free + # 2. if a buffer written by this node is used internally and not used later for buf in node.get_outputs(): - if buf.mpi_buffer.outdegree == 0: - node.mpi_node.memory_to_free += buf.mpi_buffer.size_free + if buf_info[buf]["outdegree"] == 0: + node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free # schedule nodes one at a time schedule: List[BaseSchedulerNode] = [] @@ -348,7 +431,7 @@ def topological_sort_lpmf( nodes_to_schedule, key=lambda node: ( max(live_memory + node.mpi_node.size, max_memory), - node.mpi_node.size - node.mpi_node.memory_to_free, + node.mpi_node.size - node_info[node]["memory_to_free"], node.mpi_node.index, ), ) @@ -359,22 +442,22 @@ def topological_sort_lpmf( # update memory usage live_memory += selected_node.mpi_node.size max_memory = max(max_memory, live_memory) - live_memory -= selected_node.mpi_node.memory_to_free + live_memory -= node_info[node]["memory_to_free"] # update successor nodes and nodes_to_schedule for succ_node in selected_node.mpi_node.succ_nodes: - assert succ_node.mpi_node.indegree > 0 - succ_node.mpi_node.indegree -= 1 - if succ_node.mpi_node.indegree == 0: + assert node_info[succ_node]["indegree"] > 0 + node_info[succ_node]["indegree"] -= 1 + if node_info[succ_node]["indegree"] == 0: nodes_to_schedule.add(succ_node) # update predecessor nodes for buf in selected_node.mpi_node.pred_buffers: - assert buf.mpi_buffer.outdegree > 0 - buf.mpi_buffer.outdegree -= 1 - if buf.mpi_buffer.outdegree == 1: + assert buf_info[buf]["outdegree"] > 0 + buf_info[buf]["outdegree"] -= 1 + if buf_info[buf]["outdegree"] == 1: for succ_node in buf.mpi_buffer.succ_nodes: - succ_node.mpi_node.memory_to_free += buf.mpi_buffer.size_free + node_info[succ_node]["memory_to_free"] += buf.mpi_buffer.size_free if num_iters > len(nodes): raise RuntimeError("Failed to schedule, while loop ran too long for lpmf") @@ -392,34 +475,39 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo idea aims to reduce the liveness duration of buffers created. """ + class NodeInfo(TypedDict): + indegree: int + order: int + + node_info: Dict[BaseSchedulerNode, NodeInfo] = dict() + @dataclasses.dataclass - class HeapElement: + class NodeWithPriority: priority: List[int] node: BaseSchedulerNode - def __lt__(self, other: HeapElement) -> bool: + def __lt__(self, other: NodeWithPriority) -> bool: if self.priority == other.priority: return self.node.mpi_node.index < other.node.mpi_node.index return self.priority < other.priority def _node_priority(node: BaseSchedulerNode) -> List[int]: - assert node.mpi_node.indegree == 0 - ids = sorted( - {pred_node.mpi_node.index for pred_node in node.mpi_node.pred_nodes} + # priority is the order in which predecessor nodes are executed + assert node_info[node]["indegree"] == 0 + exec_orders = sorted( + {node_info[pred_node]["order"] for pred_node in node.mpi_node.pred_nodes} ) - ids.append(node.mpi_node.index) - return ids + return exec_orders # compute nodes' number of unmet dependencies (for schedulability) # initialize the list of nodes ready to be scheduled - nodes_to_schedule: List[HeapElement] = [] - for t, node in enumerate(nodes): - node.mpi_node.index = t - # note that .unmet_dependencies could have deps with the same name - # and in that case, it should only be counted once - node.mpi_node.indegree = len(node.mpi_node.pred_nodes) - if node.mpi_node.indegree == 0: - heapq.heappush(nodes_to_schedule, HeapElement(_node_priority(node), node)) + nodes_to_schedule: List[NodeWithPriority] = [] + for node in nodes: + node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1} + if node_info[node]["indegree"] == 0: + heapq.heappush( + nodes_to_schedule, NodeWithPriority(_node_priority(node), node) + ) # schedule nodes one at a time schedule: List[BaseSchedulerNode] = [] @@ -427,22 +515,23 @@ def _node_priority(node: BaseSchedulerNode) -> List[int]: while num_iters < len(nodes) and nodes_to_schedule: # select a node to schedule selected_node = heapq.heappop(nodes_to_schedule).node - selected_node.mpi_node.index = len(schedule) + node_info[selected_node]["order"] = len(schedule) schedule.append(selected_node) num_iters += 1 # update successor nodes and nodes_to_schedule for succ_node in selected_node.mpi_node.succ_nodes: - assert succ_node.mpi_node.indegree > 0 - succ_node.mpi_node.indegree -= 1 - if succ_node.mpi_node.indegree == 0: + assert node_info[succ_node]["indegree"] > 0 + node_info[succ_node]["indegree"] -= 1 + if node_info[succ_node]["indegree"] == 0: heapq.heappush( nodes_to_schedule, - HeapElement(_node_priority(succ_node), succ_node), + NodeWithPriority(_node_priority(succ_node), succ_node), ) if num_iters > len(nodes): raise RuntimeError("Failed to schedule, while loop ran too long for bfs") + return schedule @@ -458,6 +547,7 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo seen: OrderedSet[BaseSchedulerNode] = OrderedSet() name_to_node: Dict[str, BaseSchedulerNode] = dict() result: List[BaseSchedulerNode] = [] + size_with_reads: Dict[BaseSchedulerNode, int] = dict() def visit(n: BaseSchedulerNode) -> None: if n not in seen: @@ -468,7 +558,7 @@ def visit(n: BaseSchedulerNode) -> None: if dep.name in name_to_node ] for node in sorted( - dep_nodes, key=lambda x: (x.mpi_node.size_with_reads, x.mpi_node.index) + dep_nodes, key=lambda n: (size_with_reads[n], n.mpi_node.index) ): visit(node) result.append(n) @@ -477,18 +567,13 @@ def visit(n: BaseSchedulerNode) -> None: for name in node.get_buffer_names(): name_to_node[name] = node - for t, node in enumerate(nodes): - node.mpi_node.index = t - node.mpi_node.size = sum( - buffer.mpi_buffer.size_alloc for buffer in node.get_outputs() - ) - node.mpi_node.size_with_reads = node.mpi_node.size + sum( + for node in nodes: + size_with_reads[node] = node.mpi_node.size + sum( pred_buf.mpi_buffer.size_free for pred_buf in node.mpi_node.pred_buffers ) - for node in sorted( - nodes, key=lambda x: (x.mpi_node.size_with_reads, x.mpi_node.index) - ): + for node in sorted(nodes, key=lambda n: (size_with_reads[n], n.mpi_node.index)): visit(node) + return result @@ -509,7 +594,7 @@ def reorder_for_peak_memory( resulting topological order has the lowest peak memory estimation. """ - torch_log.warning("Reordering for peak memory -- %d nodes", len(nodes)) + torch_log.info("Reordering for peak memory -- %d nodes", len(nodes)) @dataclasses.dataclass class PeakMemoryResult: @@ -519,19 +604,20 @@ class PeakMemoryResult: # preparation -- as nodes are scheduled one at a time, these help # keep track of when a buffer can be freed, and when a node can be scheduled - name_to_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf( + name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf( nodes, graph_inputs ) - compute_size_for_scheduler_buffer(name_to_buf) - map_successor_nodes_with_predecessor_buffers(nodes, name_to_input_buf, name_to_buf) - assign_predcessor_and_successor_nodes_to_nodes(nodes, name_to_fused_node) + assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf) + assign_memory_planning_info_for_scheduler_nodes( + nodes, name_to_fused_node, name_to_buf, name_to_freeable_input_buf + ) # keep track of the peak memory estimates of different methods peak_memory_diff_methods: List[PeakMemoryResult] = [] # the default estimated_peak_memory, _ = estimate_peak_memory( - nodes, name_to_input_buf, graph_outputs + nodes, name_to_freeable_input_buf, graph_outputs ) peak_memory_diff_methods.append( PeakMemoryResult(nodes, estimated_peak_memory, "baseline") @@ -542,12 +628,14 @@ class PeakMemoryResult: for method in methods: try: if method == topological_sort_lpmf: - order = method(nodes, name_to_input_buf, name_to_buf, graph_outputs) + order = method( + nodes, name_to_freeable_input_buf, name_to_buf, graph_outputs + ) else: order = method(nodes) assert len(order) == len(nodes) peak_memory, _ = estimate_peak_memory( - order, name_to_input_buf, graph_outputs + order, name_to_freeable_input_buf, graph_outputs ) peak_memory_diff_methods.append( PeakMemoryResult(order, peak_memory, method.__name__) @@ -566,4 +654,5 @@ class PeakMemoryResult: # get the optimal one best_result = min(peak_memory_diff_methods, key=lambda x: x.peak_memory) + return best_result.order diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index fe77279800e3d..bc374729dc656 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -411,10 +411,12 @@ def purge_old_log_files(): table.write_header() -@lru_cache def enabled_metric_tables() -> Set[str]: - config_str = config.enabled_metric_tables + return enabled_metric_tables_impl(config.enabled_metric_tables) + +@lru_cache +def enabled_metric_tables_impl(config_str: str) -> Set[str]: enabled = set() for name in config_str.split(","): name = name.strip() diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 890689e7865c3..f0aef30437015 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -7,7 +7,6 @@ from torch._prims_common import make_channels_last_strides_for from torch.utils._ordered_set import OrderedSet -from . import config from .ir import ( ExternKernelAlloc, FixedLayout, @@ -188,6 +187,9 @@ def _prepare_linear_fusion_create( x: "TensorBox", weight: "TensorBox", bias: "TensorBox", + quantize_args: Optional[List["TensorBox"]] = None, + other: Optional["TensorBox"] = None, + binary_sum: bool = False, ): """ This function is a helper function to prepare inputs, layout and constant args @@ -209,7 +211,22 @@ def _prepare_linear_fusion_create( x = cls.require_stride_order(x, req_stride_order) assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" - inputs = [x, weight] + inputs = [x] + + if quantize_args is not None: + x_scale, x_zero_point, w_scale, w_zero_point = quantize_args + x_scale.realize() + x_zero_point.realize() + w_scale.realize() + w_zero_point.realize() + inputs = inputs + [x_scale, x_zero_point] + [weight] + [w_scale, w_zero_point] + else: + inputs += [weight] + + if other is not None: + if binary_sum: + other = cls.require_stride_order(other, req_stride_order) + inputs = inputs + [other] output_stride = FlexibleLayout.contiguous_strides(output_size) kernel_layout = FixedLayout( @@ -224,19 +241,16 @@ def _prepare_linear_fusion_create( inputs.append(bias) else: constant_args.insert(0, bias) - return inputs, constant_args, kernel_layout, req_stride_order + return inputs, constant_args, kernel_layout, req_stride_order, other def _create_output_node(packed): - if not config.abi_compatible: - return packed - output_ir = MultiOutput( packed.get_layout(), packed, [], ) - packed.layout = MultiOutputLayout(packed.get_device()) + packed.layout = MultiOutputLayout(device=packed.get_device()) packed.outputs = [output_ir] return output_ir @@ -254,9 +268,7 @@ def __init__( constant_args, None, op_overload=torch.ops.mkldnn._convolution_pointwise.default, - cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise" - if config.abi_compatible - else None, + cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise", ) self.cpp_op_schema = """ at::Tensor( @@ -272,24 +284,8 @@ def __init__( std::optional algorithm)""" def codegen(self, wrapper): - if config.abi_compatible: - wrapper.include_extra_header( - "torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h" - ) - super().codegen(wrapper) - else: - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - op_overload=self.op_overload, - raw_args=[*self.inputs, *self.constant_args], - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) @classmethod def create( @@ -335,9 +331,7 @@ def __init__( constant_args, None, op_overload=torch.ops.mkldnn._convolution_pointwise.binary, - cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise_binary" - if config.abi_compatible - else None, + cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise_binary", ) self.cpp_op_schema = """ at::Tensor( @@ -357,25 +351,8 @@ def __init__( self.cpp_constant_args = cpp_constant_args def codegen(self, wrapper): - if config.abi_compatible: - wrapper.include_extra_header( - "torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h" - ) - super().codegen(wrapper) - else: - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - self.op_overload, - [*self.inputs, *self.constant_args], - ) - if isinstance(self.layout, Layout): - self.codegen_size_asserts(wrapper) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) @classmethod def create( @@ -435,9 +412,7 @@ def __init__( constant_args, None, op_overload=torch.ops.mkldnn._convolution_pointwise_.binary, - cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise_binary_" - if config.abi_compatible - else None, + cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_pointwise_binary_", ) # TODO: op.call: input[0] should be at::Tensor& self.cpp_op_schema = """ @@ -457,28 +432,13 @@ def __init__( std::optional unary_algorithm)""" self.mutation_outputs = [ - MutationOutput(NoneLayout(inputs[0].get_device()), inputs[0], self), - MutationOutput(NoneLayout(inputs[1].get_device()), inputs[1], self), + MutationOutput(NoneLayout(device=inputs[0].get_device()), inputs[0], self), + MutationOutput(NoneLayout(device=inputs[1].get_device()), inputs[1], self), ] def codegen(self, wrapper): - if config.abi_compatible: - wrapper.include_extra_header( - "torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h" - ) - super().codegen(wrapper) - else: - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - self.op_overload, - [*self.inputs, *self.constant_args], - ) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: return OrderedSet() @@ -518,7 +478,7 @@ def create( unary_algorithm, ] packed = ConvolutionBinaryInplace( - kernel_layout=NoneLayout(inputs[1].get_device()), # type: ignore[arg-type] + kernel_layout=NoneLayout(device=inputs[1].get_device()), # type: ignore[arg-type] inputs=inputs, constant_args=constant_args, ) @@ -541,9 +501,7 @@ def __init__( constant_args, None, op_overload=torch.ops.mkldnn._convolution_transpose_pointwise.default, - cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_transpose_pointwise" - if config.abi_compatible - else None, + cpp_kernel_name="aoti_torch_cpu_mkldnn__convolution_transpose_pointwise", ) self.cpp_op_schema = """ at::Tensor( @@ -560,20 +518,8 @@ def __init__( std::optional algorithm)""" def codegen(self, wrapper): - if config.abi_compatible: - wrapper.include_extra_header( - "torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h" - ) - super().codegen(wrapper) - else: - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - ) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) @classmethod def create( @@ -1135,7 +1081,7 @@ def create( V.graph.mark_buffer_mutated(qaccum.get_name()) packed = QConvPointWiseBinaryPT2E( - layout=NoneLayout(qaccum.get_device()), + layout=NoneLayout(device=qaccum.get_device()), inputs=inputs, constant_args=constant_args, ) @@ -1167,16 +1113,8 @@ def __init__( const int64_t prepack_batch_size)""" def codegen(self, wrapper): - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - op_overload=self.op_overload, - raw_args=[*self.inputs, *self.constant_args], - ) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) @classmethod def create(cls, x, packed_w, orig_w, B, batch_size): @@ -1215,9 +1153,7 @@ def __init__( constant_args, None, op_overload=torch.ops.mkldnn._linear_pointwise.default, - cpp_kernel_name="aoti_torch_cpu__linear_pointwise" - if config.abi_compatible - else None, + cpp_kernel_name="aoti_torch_cpu__linear_pointwise", ) self.cpp_kernel_key = "linear_pointwise" self.cpp_op_schema = """ @@ -1230,23 +1166,8 @@ def __init__( std::optional algorithm)""" def codegen(self, wrapper): - if config.abi_compatible: - wrapper.include_extra_header( - "torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h" - ) - super().codegen(wrapper) - else: - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - op_overload=self.op_overload, - raw_args=[*self.inputs, *self.constant_args], - outputs=self.outputs, - ) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) @classmethod def create(cls, x, w, B, attr, scalars, algorithm): @@ -1266,7 +1187,7 @@ def create(cls, x, w, B, attr, scalars, algorithm): constant_args.insert(0, None) packed = LinearUnary( - layout=FlexibleLayout( + layout=FixedLayout( device=x.get_device(), dtype=x.get_dtype(), size=output_size, @@ -1295,9 +1216,7 @@ def __init__( constant_args, None, op_overload=torch.ops.mkldnn._linear_pointwise.binary, - cpp_kernel_name="aoti_torch_cpu__linear_pointwise_binary" - if config.abi_compatible - else None, + cpp_kernel_name="aoti_torch_cpu__linear_pointwise_binary", ) self.cpp_op_schema = """ at::Tensor( @@ -1309,24 +1228,8 @@ def __init__( """ def codegen(self, wrapper): - if config.abi_compatible: - wrapper.include_extra_header( - "torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h" - ) - super().codegen(wrapper) - else: - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - self.codegen_args(), - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - op_overload=self.op_overload, - raw_args=[*self.inputs, *self.constant_args], - outputs=self.outputs, - ) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) @classmethod def create(cls, x, y, w, B, attr): @@ -1347,7 +1250,7 @@ def create(cls, x, y, w, B, attr): constant_args.insert(0, B) packed = LinearBinary( - layout=FlexibleLayout( + layout=FixedLayout( device=x.get_device(), dtype=x.get_dtype(), size=output_size, @@ -1368,7 +1271,6 @@ def __init__( inputs, constant_args=(), has_bias=True, - x_scale_zp_are_tensors=False, ) -> None: """ if bias is not None @@ -1381,21 +1283,15 @@ def __init__( fp32_output, unary_attr, unary_scalars, unary_algorithm] """ self.has_bias = has_bias - self.x_scale_zp_are_tensors = x_scale_zp_are_tensors super().__init__( layout, inputs, constant_args, None, - op_overload=torch.ops.onednn.qlinear_pointwise.tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qlinear_pointwise.default, - ) - x_scale_type_str, x_zp_type_str = ( - ("at::Tensor", "at::Tensor") - if x_scale_zp_are_tensors - else ("double", "int64_t") + op_overload=(torch.ops.onednn.qlinear_pointwise.tensor), + cpp_kernel_name=("aoti_torch_cpu__qlinear_pointwise_tensor"), ) + x_scale_type_str, x_zp_type_str = ("at::Tensor", "at::Tensor") self.cpp_op_schema = f""" at::Tensor( at::Tensor act, @@ -1413,104 +1309,9 @@ def __init__( c10::string_view post_op_algorithm)""" def codegen(self, wrapper): - # Parser the inputs and constant - # The raw_args setup can be skipped if there is a C shim implementation - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - x_raw = self.inputs[0] - packed_weight = args[1] - packed_weight_raw = self.inputs[1] - bias = args[2] if self.has_bias else const_args[0] - bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] - w_scale, w_zp = args[-2], args[-1] - w_scale_raw, w_zp_raw = self.inputs[-2], self.inputs[-1] - if self.x_scale_zp_are_tensors: - assert len(args) >= 4 - x_scale, x_zp = args[-4], args[-3] - x_scale_raw, x_zp_raw = self.inputs[-4], self.inputs[-3] - ( - o_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-6:] - ( - o_scale_raw, - o_zp_raw, - output_dtype_raw, - unary_attr_raw, - unary_scalars_raw, - unary_algorithm_raw, - ) = self.constant_args[-6:] - else: - assert len(const_args) >= 8 - ( - x_scale, - x_zp, - o_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-8:] - ( - x_scale_raw, - x_zp_raw, - o_scale_raw, - o_zp_raw, - output_dtype_raw, - unary_attr_raw, - unary_scalars_raw, - unary_algorithm_raw, - ) = self.constant_args[-8:] + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - bias, - o_scale, - o_zp, - output_dtype, - unary_attr, - unary_scalars, - unary_algorithm, - ) - raw_args = ( - x_raw, - x_scale_raw, - x_zp_raw, - packed_weight_raw, - w_scale_raw, - w_zp_raw, - bias_raw, - o_scale_raw, - o_zp_raw, - output_dtype_raw, - unary_attr_raw, - unary_scalars_raw, - unary_algorithm_raw, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - self.op_overload, - raw_args, - ) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) @@ -1518,8 +1319,8 @@ def codegen(self, wrapper): def create( cls, qx: "TensorBox", - x_scale: float, - x_zero_point: int, + x_scale: "TensorBox", + x_zero_point: "TensorBox", qw: "TensorBox", # packed_weight w_scale: "TensorBox", w_zero_point: "TensorBox", @@ -1531,25 +1332,14 @@ def create( post_op_args, post_op_algorithm, ): - (inputs, constant_args, kernel_layout, _) = _prepare_linear_fusion_create( + (inputs, constant_args, kernel_layout, _, _) = _prepare_linear_fusion_create( cls, qx, qw, bias, + [x_scale, x_zero_point, w_scale, w_zero_point], ) - if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): - x_scale.realize() - x_zero_point.realize() - inputs = inputs + [x_scale, x_zero_point] - x_scale_zp_are_tensors = True - else: - assert isinstance(x_scale, float) and isinstance(x_zero_point, int) - constant_args = constant_args + [x_scale, x_zero_point] - x_scale_zp_are_tensors = False - w_scale.realize() - w_zero_point.realize() - inputs = inputs + [w_scale, w_zero_point] constant_args = constant_args + [ output_scale, output_zero_point, @@ -1570,7 +1360,6 @@ def create( inputs=inputs, constant_args=constant_args, has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, ) @@ -1581,34 +1370,28 @@ def __init__( inputs, constant_args=(), has_bias=True, - x_scale_zp_are_tensors=False, ) -> None: """ if bias is not None - - inputs = [x, w, b, weight_scale, weight_zp, x2] - - const_args is: [x_scale, x_zp, o_scale, o_zp, + - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2, bias] + - const_args is: [o_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] else - - inputs = [x, w, weight_scale, weight_zp, x2] - - const_args is: [bias, x_scale, x_zp, o_scale, o_zp, + - inputs = [x, w, x_scale, x_zp, weight_scale, weight_zp, x2] + - const_args is: [bias, o_scale, o_zp, fp32_output, binary_attr, aplha, unary_attr, unary_scalars, unary_algorithm] """ self.has_bias = has_bias - self.x_scale_zp_are_tensors = x_scale_zp_are_tensors + self.idx_for_inplace_sum = 6 super().__init__( layout, inputs, constant_args, None, - op_overload=torch.ops.onednn.qlinear_pointwise.binary_tensor - if x_scale_zp_are_tensors - else torch.ops.onednn.qlinear_pointwise.binary, - ) - x_scale_type_str, x_zp_type_str = ( - ("at::Tensor", "at::Tensor") - if x_scale_zp_are_tensors - else ("double", "int64_t") + op_overload=(torch.ops.onednn.qlinear_pointwise.binary_tensor), + cpp_kernel_name="aoti_torch_cpu__qlinear_pointwise_binary_tensor", ) + x_scale_type_str, x_zp_type_str = ("at::Tensor", "at::Tensor") self.cpp_op_schema = f""" at::Tensor( at::Tensor act, @@ -1631,141 +1414,15 @@ def __init__( c10::string_view unary_post_op_algorithm)""" def codegen(self, wrapper): - # Parser the inputs and constant - # The raw_args setup can be skipped if there is a C shim implementation - args = [x.codegen_reference() for x in self.inputs] - const_args = [] - const_args.extend(self.codegen_const_args()) - - x = args[0] - x_raw = self.inputs[0] - packed_weight = args[1] - packed_weight_raw = self.inputs[1] - bias = args[2] if self.has_bias else const_args[0] - bias_raw = self.inputs[2] if self.has_bias else self.constant_args[0] - w_scale, w_zp, other = args[-3], args[-2], args[-1] - w_scale_raw, w_zp_raw, other_raw = ( - self.inputs[-3], - self.inputs[-2], - self.inputs[-1], - ) - if self.x_scale_zp_are_tensors: - assert len(args) >= 5 - x_scale, x_zp = args[-5], args[-4] - x_scale_raw, x_zp_raw = self.inputs[-5], self.inputs[-4] - ( - o_scale, - o_zp, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-10:] - ( - o_scale_raw, - o_zp_raw, - output_dtype_raw, - other_scale_raw, - other_zp_raw, - binary_attr_raw, - alpha_raw, - unary_attr_raw, - unary_scalars_raw, - unary_algorithm_raw, - ) = self.constant_args[-10:] - else: - assert len(const_args) >= 8 - ( - x_scale, - x_zp, - o_scale, - o_zp, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) = const_args[-12:] - ( - x_scale_raw, - x_zp_raw, - o_scale_raw, - o_zp_raw, - output_dtype_raw, - other_scale_raw, - other_zp_raw, - binary_attr_raw, - alpha_raw, - unary_attr_raw, - unary_scalars_raw, - unary_algorithm_raw, - ) = self.constant_args[-12:] - - codegen_args = ( - x, - x_scale, - x_zp, - packed_weight, - w_scale, - w_zp, - other, - bias, - o_scale, - o_zp, - output_dtype, - other_scale, - other_zp, - binary_attr, - alpha, - unary_attr, - unary_scalars, - unary_algorithm, - ) - raw_args = ( - x_raw, - x_scale_raw, - x_zp_raw, - packed_weight_raw, - w_scale_raw, - w_zp_raw, - other_raw, - bias_raw, - o_scale_raw, - o_zp_raw, - output_dtype_raw, - other_scale_raw, - other_zp_raw, - binary_attr_raw, - alpha_raw, - unary_attr_raw, - unary_scalars_raw, - unary_algorithm_raw, - ) - wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed( - self.get_name(), - self.python_kernel_name, - self.cpp_kernel_name, - codegen_args, - self.cpp_op_schema, - self.cpp_kernel_key, - self.cpp_kernel_overload_name, - self.op_overload, - raw_args, - ) + wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h") + super().codegen(wrapper) if isinstance(self.layout, Layout): self.codegen_size_asserts(wrapper) def get_mutation_names(self): binary_post_op = self.constant_args[-5] if binary_post_op == "sum": - return [self.inputs[-1].get_name()] + return [self.inputs[self.idx_for_inplace_sum].get_name()] else: return [] @@ -1773,8 +1430,8 @@ def get_mutation_names(self): def create( cls, qx: "TensorBox", - x_scale: float, - x_zero_point: int, + x_scale: "TensorBox", + x_zero_point: "TensorBox", qw: "TensorBox", # packed_weight w_scale: "TensorBox", w_zero_point: "TensorBox", @@ -1796,28 +1453,17 @@ def create( constant_args, kernel_layout, req_stride_order, + other, ) = _prepare_linear_fusion_create( cls, qx, qw, bias, + [x_scale, x_zero_point, w_scale, w_zero_point], + other, + binary_post_op == "sum", ) - if isinstance(x_scale, TensorBox) and isinstance(x_zero_point, TensorBox): - x_scale.realize() - x_zero_point.realize() - inputs = inputs + [x_scale, x_zero_point] - x_scale_zp_are_tensors = True - else: - assert isinstance(x_scale, float) and isinstance(x_zero_point, int) - constant_args = constant_args + [x_scale, x_zero_point] - x_scale_zp_are_tensors = False - w_scale.realize() - w_zero_point.realize() - inputs = inputs + [w_scale, w_zero_point] - if binary_post_op == "sum": - other = cls.require_stride_order(other, req_stride_order) - inputs.append(other) constant_args = constant_args + [ output_scale, output_zero_point, @@ -1834,14 +1480,13 @@ def create( if binary_post_op == "sum": V.graph.mark_buffer_mutated(other.get_name()) packed = QLinearPointwiseBinaryPT2E( - layout=NoneLayout(other.get_device()), + layout=NoneLayout(device=other.get_device()), inputs=inputs, constant_args=constant_args, has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, ) # Return other since it has been inplace changed. - return packed.inputs[-1] + return packed.inputs[packed.idx_for_inplace_sum] assert output_dtype is not None if output_dtype in [torch.float32, torch.bfloat16]: @@ -1854,7 +1499,6 @@ def create( inputs=inputs, constant_args=constant_args, has_bias=(bias is not None), - x_scale_zp_are_tensors=x_scale_zp_are_tensors, ) @@ -1932,7 +1576,7 @@ def create( ] packed = MkldnnRnnLayer( - MultiOutputLayout(x.get_device()), + MultiOutputLayout(device=x.get_device()), inputs=inputs, constant_args=constant_args, ) diff --git a/torch/_inductor/optimize_indexing.py b/torch/_inductor/optimize_indexing.py index 96bf8641f3c9a..cd7ac7207dd42 100644 --- a/torch/_inductor/optimize_indexing.py +++ b/torch/_inductor/optimize_indexing.py @@ -1,5 +1,5 @@ -# mypy: allow-untyped-defs import math +from typing import Any, Dict, List import sympy @@ -10,7 +10,7 @@ from .utils import dominated_nodes -def val_expressable_in_32_bits(val): +def val_expressable_in_32_bits(val: Any) -> bool: if getattr(val, "is_Boolean", False): return True @@ -32,17 +32,23 @@ def val_expressable_in_32_bits(val): raise TypeError(f"Unexpected value {val}") -def range_expressable_in_32_bits(range): +def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool: return val_expressable_in_32_bits(range.lower) and val_expressable_in_32_bits( range.upper ) -def try_to_reduce_precision(node, bounds, indirect_vars, indices, replacement_vals): +def try_to_reduce_precision( + node: Any, + bounds: Dict[Any, Any], + indirect_vars: List[Any], + indices: Dict[Any, sympy.Expr], + replacement_vals: Dict[Any, ValueRanges[sympy.Expr]], +) -> None: # if a downstream use of a node explicitly converts to int32, or float16/float32/float64, # then it's precision is set for that chain of uses, and we don't need to consider those # dominated values - def skip_filter(node): + def skip_filter(node: Any) -> bool: return node.target == "to_dtype" and node.args[2] in ( torch.int32, torch.float32, @@ -87,7 +93,7 @@ def skip_filter(node): node.args = tuple(args) -def indexing_dtype_strength_reduction(loop_body: LoopBody): +def indexing_dtype_strength_reduction(loop_body: LoopBody) -> None: """ Performs Value Range Analysis on LoopBody's fx graph to reduce precision of intermediaries from int64 to int32 diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index ada0fcb32141d..a45a7505b6cd1 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -70,7 +70,7 @@ TypeVar, Union, ) -from typing_extensions import Self, TypeGuard +from typing_extensions import Self, TypeIs import torch import torch._guards @@ -139,6 +139,15 @@ def __init__(self) -> None: MULTIPLE = Multiple() +def _transfer_meta(new_meta: Dict[str, Any], old_meta: Dict[str, Any]) -> None: + # transfer metadata after pattern matching occurs. + # skip "val" and "tensor_meta" because this info is too specific; it's unlikely + # to remain accurate after pattern matching has occurred. + new_meta.update( + (k, v) for k, v in old_meta.items() if k in torch.fx.proxy._COPY_META_FIELDS + ) + + class Match: """ Represents a successfully matched pattern. @@ -157,7 +166,7 @@ class Match: nodes: List[torch.fx.Node] targets: Dict[_TargetExpr, torch.fx.node.Target] ctx: MatchContext - replacement_graph: Optional[torch.fx.Graph] + replacement_graph: Optional[torch.fx.GraphModule] def __init__( self, @@ -253,6 +262,10 @@ def replace_by_example( replacement = trace_fn( replacement_fn, torch.fx.map_arg(args, lambda arg: arg.meta["val"]) # type: ignore[arg-type] ) + if len(self.nodes) == 1: + for n in replacement.graph.nodes: + _transfer_meta(new_meta=n.meta, old_meta=self.nodes[0].meta) + ReplacementPatternEntry.replace_with_graph( self, self.ctx.graph, @@ -292,10 +305,10 @@ def __bool__(self) -> bool: MatchResult = Union[Match, FailedMatch] -def is_match(m: MatchResult) -> TypeGuard[Match]: +def is_match(m: MatchResult) -> TypeIs[Match]: """ - TypeGuards cannot act on `self`. Thus this function exists to let mypy - recognize FailedMatch.__bool__ as a TypeGuard. + TypeIs cannot act on `self`. Thus this function exists to let mypy + recognize FailedMatch.__bool__ as a TypeIs. """ return bool(m) @@ -569,18 +582,25 @@ def simple_flatten( def pytree_flatten( args: Sequence[Any], kwargs: Mapping[Any, Any] ) -> Tuple[Sequence[Any], Union[_SimpleSpec, pytree.TreeSpec]]: - def norm_spec(s: pytree.TreeSpec) -> pytree.TreeSpec: - if s.type is None: - return s - mapping = {immutable_list: list, tuple: list, immutable_dict: dict} - return pytree.TreeSpec( - mapping.get(s.type, s.type), - s.context, - list(map(norm_spec, s.children_specs)), - ) + type_mapping = {immutable_list: tuple, list: tuple, immutable_dict: dict} + + def convert_type(x: Any) -> Any: + cls = type(x) + convert_fn = type_mapping.get(cls) + if convert_fn is not None: + return pytree.tree_map( + convert_type, + convert_fn(x), + is_leaf=lambda x: type(x) in type_mapping, + ) + return x - flat, spec = pytree.tree_flatten([args, kwargs]) - spec = norm_spec(spec) + normalized_args_tree = pytree.tree_map( + convert_type, + (args, kwargs), + is_leaf=lambda x: type(x) in type_mapping, + ) + flat, spec = pytree.tree_flatten(normalized_args_tree) return flat, spec def __repr__(self) -> str: @@ -1049,6 +1069,7 @@ def run_node(self, node: torch.fx.Node) -> Any: target = node.target args, kwargs = self.fetch_args_kwargs_from_env(node) result = graph.call_function(target, args, kwargs) # type: ignore[arg-type] + _transfer_meta(new_meta=result.meta, old_meta=node.meta) if "val" in node.meta and "val" not in result.meta: result.meta["val"] = node.meta["val"] if isinstance(node.meta["val"], torch.Tensor): @@ -1330,7 +1351,13 @@ def search_fn_new(*args_new: Any) -> Any: if is_match(specific_pattern_match) and extra_check(specific_pattern_match): # trace the pattern using the shapes from the user program - match.replacement_graph = trace_fn(replace_fn, args) # type: ignore[assignment] + match.replacement_graph = trace_fn(replace_fn, args) + if len(match.nodes) == 1: + for n in match.replacement_graph.graph.nodes: + _transfer_meta( + new_meta=n.meta, + old_meta=match.nodes[0].meta, + ) return True return False diff --git a/torch/_inductor/quantized_lowerings.py b/torch/_inductor/quantized_lowerings.py index 80910e67d3a61..ea81048b41e15 100644 --- a/torch/_inductor/quantized_lowerings.py +++ b/torch/_inductor/quantized_lowerings.py @@ -1,5 +1,5 @@ -# mypy: allow-untyped-defs import logging +from typing import Any import torch from torch._inductor.kernel.mm_common import mm_args @@ -22,13 +22,12 @@ torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False ) - quantized = torch.ops.quantized _quantized = torch.ops._quantized aten = torch.ops.aten -def register_quantized_ops(): +def register_quantized_ops() -> None: lowering.add_needs_realized_inputs( [ quantized.max_pool2d, @@ -36,15 +35,20 @@ def register_quantized_ops(): _quantized.wrapped_fbgemm_linear_fp16_weight, ] ) - lowering.make_fallback(quantized.max_pool2d) lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight) -def register_woq_mm_ops(): - @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None) - def int8pack_mm(input, weight, scale, *, layout=None): +def register_woq_mm_ops() -> None: + @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None) # type: ignore[misc] + def int8pack_mm( + input: torch.Tensor, + weight: torch.Tensor, + scale: torch.Tensor, + *, + layout: Any = None, + ) -> Any: _, _, _, layout, mat1, mat2 = mm_args( input, weight, layout=layout, mat2_transposed=True ) @@ -63,7 +67,7 @@ def int8pack_mm(input, weight, scale, *, layout=None): # scale is applied as an epilogue, and the scale tensor is expanded (with a view op) # for broadcasting, as it's 1D. - def _mul_epilogue(buf): + def _mul_epilogue(buf: torch.Tensor) -> Any: return create_epilogue_with_attr( buf, "mul", other=realize_inputs(expand(scale, layout.size)) ) @@ -74,7 +78,7 @@ def _mul_epilogue(buf): aten_layout, [mat1, mat2, scale], trans_w=True, - epilogue_creator=_mul_epilogue, + epilogue_creator=_mul_epilogue, # type: ignore[arg-type] ) if ( diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index 5ca4c79849de8..43adfb34165cf 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -283,6 +283,10 @@ class RemoteAutotuneCache(RedisRemoteCache): pass +class RemoteBundledAutotuneCache(RedisRemoteCache): + pass + + class RemoteFxGraphCache(RedisRemoteCache): pass diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index c674d0e40132f..ba3f55a23781b 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -5,11 +5,12 @@ import logging import os import os.path -from typing import Dict, List, Optional, Tuple +import re +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from typing_extensions import override import torch -from torch.utils._triton import has_triton_package +from torch.utils._triton import has_triton, has_triton_package from ..remote_cache import ( create_cache, @@ -20,6 +21,9 @@ ) +if TYPE_CHECKING: + from ..remote_cache import Sample + if has_triton_package(): from triton import Config @@ -29,6 +33,33 @@ _InductorMetaTy = Dict[str, object] +def inductor_meta_from_config() -> _InductorMetaTy: + from torch._inductor import config + + backend_hash = None + if has_triton(): + try: + backend_hash = torch.utils._triton.triton_hash_with_backend() + except RuntimeError: + # This can get the error: + # RuntimeError: 0 active drivers ([]). There should only be one. + pass + + is_hip = None + if torch.version.hip is not None: + is_hip = True + + return { + "autotune_local_cache": config.autotune_local_cache, + "autotune_remote_cache": config.autotune_remote_cache, + "backend_hash": backend_hash, + "bundled_autotune_remote_cache": config.bundled_autotune_remote_cache, + "coordinate_descent_tuning": config.coordinate_descent_tuning, + "is_fbcode": config.is_fbcode(), + "is_hip": is_hip, + } + + @dataclasses.dataclass class AutotuneCache: configs_hash: str @@ -50,7 +81,7 @@ def create( return None # Read the best config options from the most local cache and return it. - def _read(self, inductor_meta: _InductorMetaTy) -> Optional[Dict[str, JsonDataTy]]: + def _read(self) -> Optional[Dict[str, JsonDataTy]]: if local_cache := self.local_cache: cache, key = local_cache if best_config := cache.get(key): @@ -70,7 +101,7 @@ def _read(self, inductor_meta: _InductorMetaTy) -> Optional[Dict[str, JsonDataTy def read_best( self, inductor_meta: _InductorMetaTy, configs: List[Config] ) -> Optional[Config]: - if best := self._read(inductor_meta): + if best := self._read(): return _load_cached_autotuning( best, self.configs_hash, configs, inductor_meta ) @@ -134,6 +165,7 @@ def save( if local_cache := self.local_cache: cache, key = local_cache cache.put(key, data) + AutotuneCacheBundler.put(key, data) if log.isEnabledFor(logging.DEBUG): type_str = "coordesc" if found_by_coordesc else "heuristic" @@ -144,6 +176,212 @@ def save( cache.put(key, data) +class _AutotuneCacheBundlerImpl: + """ + Caches a set of LocalAutotuneCacheBackend entries together in a single + cache. + """ + + _key: str + _cache: RemoteCache[JsonDataTy] + + # All known entries from LocalAutotuneCache.put() + _entries: Dict[str, JsonDataTy] + + def end_compile(self) -> None: + # TODO: Do we need to compute time_taken_ms and encode that somehow? + if self._entries: + self._cache.put(self._key, self._entries) + + def put(self, basename: str, data: JsonDataTy) -> None: + # Do we need to worry about duplicates? We only have a single local fs + # entry - so probably not. + self._entries[basename] = data + + def __init__(self, key: str, cache: RemoteCache[JsonDataTy]) -> None: + self._key = key + self._cache = cache + self._entries = {} + + def sync(self) -> None: + # We don't currently use this - but we could async load starting at + # `begin_compile` and wait for the load to be finished here. + pass + + @classmethod + def _should_use_bundled_autotune_remote_cache( + cls, inductor_meta: _InductorMetaTy + ) -> bool: + # The bundled autotune cache is only available if you've also got local + # caching enabled (because we feed the bundled data to the local cache). + if not inductor_meta.get("autotune_local_cache", True): + return False + + # Check if the we're enabled via config + if ( + bundled_autotune_remote_cache := inductor_meta.get( + "bundled_autotune_remote_cache" + ) + ) is not None: + return bool(bundled_autotune_remote_cache) + + if not cls._get_is_fbcode(inductor_meta): + return False + if torch._utils_internal.is_fb_unit_test(): + return False + if inductor_meta.get("is_hip"): + return False + + try: + from torch._inductor.fb.remote_cache import REMOTE_CACHE_VERSION + except ModuleNotFoundError: + return False + + jk = torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:bundled_autotune_remote_cache_version" + ) + return REMOTE_CACHE_VERSION >= jk + + def _load_cache(self) -> bool: + from torch._inductor import codecache + + # The single key is defined on construction of the cache. + entries = self._cache.get(self._key) + if entries is None or not isinstance(entries, dict): + # We couldn't load the cache - so mark _entries as non-None so we + # store local cache values. + return False + + cache_dir = torch._inductor.runtime.runtime_utils.cache_dir() + + # Go through the entries we got from the cache and save them locally. + time_saved_ns = 0 + for basename, data in entries.items(): + # Reconstruct the final filename (see put()) + root, ext = _splitext_nodot(basename) + _, _, filename = codecache.get_path(root, ext) + if isinstance(data, dict) and (tsns := data.get("time_saved_ns")): + time_saved_ns += int(tsns) # type: ignore[arg-type] + local_cache = LocalAutotuneCache() + local_cache.put(filename, data) + + codecache.add_ephemeral_timeout_increase_for_distributed(time_saved_ns) + + return True + + @staticmethod + def _get_is_fbcode(inductor_meta: _InductorMetaTy) -> bool: + return bool(inductor_meta.get("is_fbcode", False)) + + @staticmethod + def _get_backend_hash(inductor_meta: _InductorMetaTy) -> str: + backend_hash = inductor_meta["backend_hash"] + assert isinstance(backend_hash, str) + return backend_hash + + +class AutotuneCacheBundler: + _bundler: Optional[_AutotuneCacheBundlerImpl] = None + + def __init__(self) -> None: + pass + + # Call this before we start any autotune computation for an inductor python + # file. On a cache hit it copies the individual results into the local + # autotune caches. + @classmethod + def begin_compile( + cls, + inductor_meta: _InductorMetaTy, + *, + code: Optional[str] = None, + code_hash: Optional[str] = None, + ) -> None: + assert cls._bundler is None + + if code is not None: + assert code_hash is None, "Cannot specify both code and code_hash" + code_hash = _comment_stripped_hash(code) + assert code_hash is not None + + if not _AutotuneCacheBundlerImpl._should_use_bundled_autotune_remote_cache( + inductor_meta + ): + return + + cache = create_cache( + "bundled-autotune-v1", + _AutotuneCacheBundlerImpl._get_is_fbcode(inductor_meta), + "FbRemoteBundledAutotuneCache", + "RemoteBundledAutotuneCache", + ) + if not cache: + return + + # We're starting a compilation phase. We have a cache key for the code + # we're compiling. We'll get the individual autotune bundles later (via + # self.put()). For now create the AutotuneCacheBundler and try to load + # from the cache. + + salt = "bundled-autotune-best-configs-v1" + backend_hash = _AutotuneCacheBundlerImpl._get_backend_hash(inductor_meta) + # TODO: The autotune cache includes configs_hash in the key. The problem + # is that the configs_hash includes info from the individual pointwise() + # calls (size_hints, for example) which we can't know yet. I *think* + # that info is basically present in the `code_hash` (since it's a + # parameter to the pointwise decorator) - but is there other info we + # need to include from inductor_meta? + key = code_hash + backend_hash + salt + key = hashlib.sha256(key.encode("utf-8")).hexdigest() + + bundler = _AutotuneCacheBundlerImpl(key, cache) + if not bundler._load_cache(): + # We couldn't load from the cache - so save the data so we can store + # the saved autotunes. + cls._bundler = bundler + + # If we get a cache hit don't bother saving any of the individual + # autotune results. + + # Call this after all individual autotune results are finished for a + # inductor python file. If we gathered any individual results then we bundle + # those and put it into the cache. + @classmethod + def end_compile(cls) -> None: + if bundler := cls._bundler: + cls._bundler = None + bundler.end_compile() + + @classmethod + def sync(cls) -> None: + if bundler := cls._bundler: + bundler.sync() + + @classmethod + def put(cls, filename: str, data: JsonDataTy) -> None: + if bundler := cls._bundler: + # The filename comes in as something like + # "/tmp/tmp{random}/{aa}/{basename}.py" (where aa is + # basename[1:3]). Strip it down and make sure that it looks like a path + # we could reconstruct (because it's possible for the caller to + # customize the path). + basename = os.path.basename(filename) + root, ext = _splitext_nodot(basename) + _, _, expected = torch._inductor.codecache.get_path(root, ext) + if filename != expected: + return + + # TODO: check cache_dir() vs filename, then strip dirname + bundler.put(basename, data) + + +# Remove the comments from the code (which include things like run ids and file +# paths) and then hash the result. +def _comment_stripped_hash(code: str) -> str: + code = re.sub(r"#.*$", "", code, count=0, flags=re.MULTILINE) + return torch._inductor.codecache.code_hash(code) + + def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool: if (config := inductor_meta.get("autotune_remote_cache")) is not None: return bool(config) @@ -221,3 +459,28 @@ def __init__(self) -> None: backend = _LocalAutotuneCacheBackend() serde = RemoteCacheJsonSerde() super().__init__(backend, serde) + + @override + def _get(self, key: str, sample: Optional[Sample]) -> Optional[JsonDataTy]: + AutotuneCacheBundler.sync() + result = super()._get(key, sample) + if result is not None: + # What? Why are we doing a put() here? Imagine we have a new model + # that reuses some existing kernels that have already been + # compiled. If we didn't do a `put` here (on cache hit) then the new + # model would only bundle *newly* compiled kernels, not existing + # kernels that were already compiled and cached. + AutotuneCacheBundler.put(key, result) + return result + + @override + def _put(self, key: str, value: JsonDataTy, sample: Optional[Sample]) -> None: + AutotuneCacheBundler.put(key, value) + super()._put(key, value, sample) + + +def _splitext_nodot(basename: str) -> Tuple[str, str]: + root, ext = os.path.splitext(basename) + if ext: + ext = ext[1:] + return root, ext diff --git a/torch/_inductor/runtime/cache_dir_utils.py b/torch/_inductor/runtime/cache_dir_utils.py new file mode 100644 index 0000000000000..1a2aabc572cfb --- /dev/null +++ b/torch/_inductor/runtime/cache_dir_utils.py @@ -0,0 +1,23 @@ +import getpass +import os +import re +import tempfile + + +# Factoring out to file without torch dependencies + + +def cache_dir() -> str: + cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") + if cache_dir is None: + os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir() + os.makedirs(cache_dir, exist_ok=True) + return cache_dir + + +def default_cache_dir() -> str: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + return os.path.join( + tempfile.gettempdir(), + "torchinductor_" + sanitized_username, + ) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index ad2dcb97d80b0..cd62fde545084 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import collections import typing -from dataclasses import fields from enum import auto, Enum from typing import Dict, List, Optional, Union @@ -27,48 +26,58 @@ class TileHint(Enum): DEFAULT = 1 -# Attempt to import AttrsDescriptor from Triton -try: - from triton.compiler.compiler import AttrsDescriptor - - attrs_descriptor_available = True - # Determine if 'ids_of_folded_args' is a valid field for AttrsDescriptor - attr_desc_fields = {f.name for f in fields(AttrsDescriptor)} - ids_of_folded_args_available = "ids_of_folded_args" in attr_desc_fields - divisible_by_8_available = "divisible_by_8" in attr_desc_fields -except ImportError: - attrs_descriptor_available = False - -# Define `instance_descriptor` function with clear conditional handling -if attrs_descriptor_available: - - def instance_descriptor( - divisible_by_16=None, - equal_to_1=None, - ids_of_folded_args=None, - divisible_by_8=None, - ): - # Prepare the arguments for AttrsDescriptor - kwargs = { - "divisible_by_16": divisible_by_16, - "equal_to_1": equal_to_1, - } - - # Conditionally add 'ids_of_folded_args' if it's available in AttrsDescriptor - if ids_of_folded_args_available: - kwargs["ids_of_folded_args"] = ids_of_folded_args - if divisible_by_8_available: - kwargs["divisible_by_8"] = divisible_by_8 - - # Instantiate AttrsDescriptor with the prepared arguments - return AttrsDescriptor(**kwargs) +def _is_triton_available(): + try: + import triton # noqa: F401 + + return True + except ImportError: + return False + + +# Define `AttrsDescriptorWrapper` function with clear conditional handling +if _is_triton_available(): + try: + from triton.backends.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "tt.divisibility": divisible_by_16, + "tt.equal_to": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + res = AttrsDescriptor.from_dict(kwargs) + assert res.property_values["tt.divisibility"] == 16 + assert res.property_values["tt.equal_to"] == 1 + return res + + except ImportError: + from triton.compiler.compiler import AttrsDescriptor + + def AttrsDescriptorWrapper( + divisible_by_16=None, + equal_to_1=None, + ): + # Prepare the arguments for AttrsDescriptor + kwargs = { + "divisible_by_16": divisible_by_16, + "equal_to_1": equal_to_1, + } + + # Instantiate AttrsDescriptor with the prepared arguments + return AttrsDescriptor(**kwargs) else: # Define a namedtuple as a fallback when AttrsDescriptor is not available - instance_descriptor = collections.namedtuple( # type: ignore[no-redef] - "instance_descriptor", - ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], - defaults=[(), (), (), ()], + AttrsDescriptorWrapper = collections.namedtuple( # type: ignore[no-redef, name-match] + "AttrsDescriptor", + ["divisible_by_16", "equal_to_1"], + defaults=[(), ()], ) diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 446dbc71c61d1..e7e25876632cd 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -3,13 +3,13 @@ import contextlib import functools -import getpass import operator -import os -import re -import tempfile import torch +from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 + cache_dir, + default_cache_dir, +) def conditional_product(*args): @@ -86,22 +86,6 @@ def get_max_y_grid(): return 65535 -def cache_dir() -> str: - cache_dir = os.environ.get("TORCHINDUCTOR_CACHE_DIR") - if cache_dir is None: - os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir = default_cache_dir() - os.makedirs(cache_dir, exist_ok=True) - return cache_dir - - -def default_cache_dir(): - sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) - return os.path.join( - tempfile.gettempdir(), - "torchinductor_" + sanitized_username, - ) - - try: import colorama diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 211495443388e..6a55df5fe3944 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -128,9 +128,9 @@ def autotune_hints_to_configs( triton_config( size_hints, *xyz, - num_elements_per_warp=device_props.warp_size - if device_props.warp_size - else 32, + num_elements_per_warp=( + device_props.warp_size if device_props.warp_size else 32 + ), ) ) @@ -246,6 +246,10 @@ def __init__( self.precompile_time_taken_ns = 0 self.autotune_time_taken_ns = 0 + # Dumps the launch configs after autotuning. + self.dump_launch_params = ( + os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", "0") == "1" + ) def precompile(self, warm_cache_only=False): with self.lock: @@ -381,6 +385,9 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): if k == "waves_per_eu": compile_meta["waves_per_eu"] = v continue + if k == "kpack": + compile_meta["kpack"] = v + continue compile_meta["constants"][k] = v compile_meta["num_warps"] = cfg.num_warps compile_meta["num_stages"] = cfg.num_stages @@ -478,13 +485,38 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): raise binary._init_handles() + """ + https://github.com/pytorch/pytorch/issues/115344 + + self.fn.constexprs doesn't properly deal with None args, so when we filter out + an arg in UserDefinedTritonKernel.codegen, we need to filter it here as well. + We also don't want to modify self.fn. + + We know that we removed something from the signature if: + 1. It's in compile_meta["constants"] + 2. It isn't a constant we already know about + Note: The value of interest has already been added to compile_meta['constants'], + so we use self.fn.constexprs instead. + 3. It isn't in the compile_meta signature + """ + none_args = set(compile_meta["constants"].keys()) + known_constants = { + arg for i, arg in enumerate(self.fn.arg_names) if i in self.fn.constexprs + } + none_args = none_args.difference(known_constants) + none_args = none_args.difference(set(compile_meta["signature"].keys())) + call_args = [ arg for i, arg in enumerate(self.fn.arg_names) - if i not in self.fn.constexprs + if i not in self.fn.constexprs and arg not in none_args ] - def_args = [name for name in self.fn.arg_names if name not in cfg.kwargs] + def_args = [ + name + for name in self.fn.arg_names + if name not in cfg.kwargs and name not in none_args + ] binary_shared = ( binary.shared if hasattr(binary, "shared") else binary.metadata.shared ) @@ -494,9 +526,11 @@ def _precompile_config(self, cfg: Config, warm_cache_only: bool): "bin": binary, "launch_enter_hook": CompiledKernel.launch_enter_hook, "launch_exit_hook": CompiledKernel.launch_exit_hook, - "metadata": binary.packed_metadata - if hasattr(binary, "packed_metadata") - else binary.metadata, + "metadata": ( + binary.packed_metadata + if hasattr(binary, "packed_metadata") + else binary.metadata + ), "shared": binary_shared, } @@ -724,10 +758,9 @@ def copy_args_to_cpu_if_needed(self, *args, **kwargs): budget = torch.cuda.max_memory_allocated() - torch.cuda.memory_allocated() def maybe_copy(name, arg): - if name in self.mutated_arg_names: + if name in self.mutated_arg_names and arg.is_cuda: nonlocal budget assert isinstance(arg, torch.Tensor) - assert not arg.is_cpu size = arg.numel() * arg.element_size() if size > budget: cpu_arg = torch.empty_strided( @@ -828,9 +861,11 @@ def save_gpu_kernel(self, grid, stream, launcher): key = self.inductor_meta.get("kernel_name", None) # unique kernel name assert key is not None, "kernel_name can not be None" params = { - "mangled_name": launcher.bin.metadata.name - if hasattr(launcher.bin.metadata, "name") - else launcher.bin.metadata["name"], + "mangled_name": ( + launcher.bin.metadata.name + if hasattr(launcher.bin.metadata, "name") + else launcher.bin.metadata["name"] + ), "grid_x": grid_x, "grid_y": grid_y, "grid_z": grid_z, @@ -838,12 +873,16 @@ def save_gpu_kernel(self, grid, stream, launcher): "y_block": launcher.config.kwargs.get("YBLOCK", None), "z_block": launcher.config.kwargs.get("ZBLOCK", None), "r_block": launcher.config.kwargs.get("RBLOCK", None), - "num_warps": launcher.bin.num_warps - if hasattr(launcher.bin, "num_warps") - else launcher.bin.metadata.num_warps, - "shared_mem": launcher.bin.shared - if hasattr(launcher.bin, "shared") - else launcher.bin.metadata.shared, + "num_warps": ( + launcher.bin.num_warps + if hasattr(launcher.bin, "num_warps") + else launcher.bin.metadata.num_warps + ), + "shared_mem": ( + launcher.bin.shared + if hasattr(launcher.bin, "shared") + else launcher.bin.metadata.shared + ), "stream": stream, # User defined triton kernels will have arbitrary kwarg names "meta": launcher.config.kwargs, @@ -935,7 +974,7 @@ def run( if launcher.store_cubin and (not benchmark_run or not self.cuda_kernel_saved): self.save_gpu_kernel(grid, stream, launcher) - if os.environ.get("TORCHINDUCTOR_DUMP_LAUNCH_PARAMS", 0) == "1": + if self.dump_launch_params: _dump_launch_params(args, kwargs, launcher, self.fn.__name__) # it is faster than entering and exiting a context manager, even if the context @@ -1003,7 +1042,7 @@ def end_graph(output_file): cur_file = inspect.stack()[1].filename summary_str = ( f"SUMMARY ({cur_file})\n" - f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb/(overall_time/1e3):.2f}GB/s" + f"{overall_time:.2f}ms \t {overall_gb:.2f} GB\t {overall_gb / (overall_time / 1e3):.2f}GB/s" ) print(summary_str) print() @@ -1018,7 +1057,7 @@ def end_graph(output_file): file.write(f"TRITON KERNELS BANDWIDTH INFO ({cur_file})\n") for ms, num_gb, gb_per_s, kernel_name in sorted_calls: # also display the runtime percentage for each kernel - percentage = f"{ms/overall_time*100:.2f}%" + percentage = f"{ms / overall_time * 100:.2f}%" suffix = f" \t {percentage} \t {kernel_name}" bw_info_str = create_bandwidth_info_str( ms, @@ -1325,6 +1364,7 @@ def triton_config( x *= math.ceil(block_size / conditional_product(x, y, z)) x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + x = min(x, size_hints[0]) cfg = {"XBLOCK": x} if y: @@ -1579,7 +1619,6 @@ def reduction( size_hints = [1, *size_hints[1:]] assert triton_meta is not None - rnumel = size_hints[-1] if len(size_hints) != 2: raise NotImplementedError(f"size_hints: {size_hints}") @@ -1716,7 +1755,6 @@ def user_autotune( ) for c in configs ] - return cached_autotune( None, configs, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 2f5bc89c81e69..ce70857b9d13c 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -175,6 +175,9 @@ class BaseSchedulerNode: def __init__(self, scheduler: Scheduler) -> None: self.scheduler: Scheduler = scheduler + self.debug_device_str: Callable[ + [BaseSchedulerNode], List[str] + ] = lambda *args, **kwargs: [] def _init_from_node(self, node: ir.Operation) -> None: self.node: Optional[ir.Operation] = node @@ -226,6 +229,9 @@ def debug_str(self) -> str: def debug_str_extra(self) -> str: return "" + def _debug_str_for_device(self) -> List[str]: + return self.debug_device_str(self) + def debug_str_short(self) -> str: maybe_data = getattr(self.node, "data", None) data_str = "" @@ -953,8 +959,7 @@ def debug_str_extra(self) -> str: lines.append(textwrap.indent(self._body.debug_str(), " ")) assert self.node is not None - if ir.is_triton(self.node.get_device()): - lines.extend(debug_triton_code(self)) + lines.extend(self._debug_str_for_device()) return "\n".join(lines) @@ -1178,9 +1183,7 @@ def debug_str_extra(self) -> str: ] node = self.snodes[0].node if node is not None: - device = node.get_device() - if ir.is_triton(device): - lines.extend(debug_triton_code(self)) + lines.extend(self._debug_str_for_device()) return textwrap.indent("\n".join(lines).rstrip(), " ") @@ -1841,7 +1844,6 @@ def _init(self, nodes: List[ir.Operation]) -> None: self.debug_draw_graph() # used during codegen: - self.current_device: Optional[torch.device] = None self.buffer_names_to_free: OrderedSet[str] = OrderedSet() # fx graph node to the position it appears in the graph @@ -1856,11 +1858,13 @@ def _init(self, nodes: List[ir.Operation]) -> None: } ) - def get_current_device_or_throw(self) -> torch.device: - if device := self.current_device: - return device - else: - raise RuntimeError("No current device") + @property + def current_device(self) -> Optional[torch.device]: + return V.graph.current_device + + @current_device.setter + def current_device(self, device: Optional[torch.device]) -> None: + V.graph.current_device = device def debug_draw_graph(self) -> None: """Generate an image of the graph for debugging""" @@ -2373,7 +2377,21 @@ def replace_operation_buffer( node.node, ir.MultiTemplateBuffer ): multi_node = node.node - min_node_unfused, _ = multi_node.get_min_choice() + if not config.test_configs.force_extern_kernel_in_multi_template: + min_node_unfused, _ = multi_node.get_min_choice() + else: + min_node_unfused = next( + ( + timing + for timing in multi_node.choice_timings + if isinstance( + timing, + torch._inductor.select_algorithm.ExternKernelCaller, + ) + ), + None, # type: ignore[arg-type] + ) + assert min_node_unfused is not None if isinstance( min_node_unfused, @@ -2904,7 +2922,10 @@ def has_shared_data_after_reordering_loop( node2.get_name(), ) - return self.score_fusion_memory(node1, node2) > 0 + return ( + self.score_fusion_memory(node1, node2) + >= config.score_fusion_memory_threshold + ) def unfusable_node(self, node: BaseSchedulerNode) -> bool: """ @@ -2972,7 +2993,10 @@ def can_fuse(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: return False del device2 - no_shared_data = self.score_fusion_memory(node1, node2) == 0 + no_shared_data = ( + self.score_fusion_memory(node1, node2) + < config.score_fusion_memory_threshold + ) if no_shared_data: no_shared_data = not self.has_shared_data_after_reordering_loop( node1, node2 @@ -3452,6 +3476,7 @@ def _codegen(self) -> None: ) seen.add(key) + self.current_device = None for node in self.nodes: if log.isEnabledFor(logging.DEBUG): try: @@ -3743,34 +3768,3 @@ def benchmark_combo_kernel( and memory copy time in milliseconds on randomly generated inputs. """ raise NotImplementedError - - -def debug_triton_code(node: Union[SchedulerNode, FusedSchedulerNode]) -> List[str]: - lines = [] - multi_template = node.get_template_node() - assert multi_template is None or isinstance(multi_template, ir.MultiTemplateBuffer) - if multi_template and multi_template.make_kernel_render is None: - lines.append(f"{node.get_name()} Unfinalized multi template buffer") - else: - from torch._inductor.codegen.cuda_combined_scheduling import ( - CUDACombinedScheduling, - ) - - from .codegen.simd import SIMDScheduling - - snodes = (node,) if isinstance(node, SchedulerNode) else node.snodes - device = snodes[0].get_device() - backend = node.scheduler.get_backend(device) - assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)) - V.graph.scheduler.current_device = device - - # Don't increment kernel count when generating debug string. - # This will confuse some unit tests that check the number of - # generated kernels. - old_generated_kernel_count = metrics.generated_kernel_count - triton_code = backend.generate_kernel_code_from_nodes(snodes).strip() - metrics.generated_kernel_count = old_generated_kernel_count - - lines.append(f"{node.get_name()} Triton code:") - lines.append(textwrap.indent(triton_code, " ")) - return lines diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 75e994f8120d6..62c0133ddfb66 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -347,8 +347,7 @@ def stride(self, name, index=None): if isinstance(index, int): return texpr(self.rename_indexing(val[index])) - else: - return ", ".join([texpr(self.rename_indexing(i)) for i in val]) + return ", ".join([texpr(self.rename_indexing(i)) for i in val]) def modification( self, subgraph_number: int, output_name: str, **fixed_inputs @@ -376,8 +375,14 @@ def modification( subgraph = self.subgraphs[subgraph_number] def add_input(name): + # This also implicitly adds name as an input to the kernel return self.args.input(name) + def print_and_rename_indexing(index): + # This also implicitly adds the indexing symbols as an input to + # the kernel + return self.kexpr(self.rename_indexing(index)) + name = f"PlaceholderSubstitution_{subgraph_number}" class PlaceholderSubstitution(V.WrapperHandler): # type: ignore[name-defined] @@ -387,7 +392,7 @@ def load(self, name: str, index: sympy.Expr): if name not in fixed_inputs: # If it's not a fixed input, it's a load from a captured # tensor - index_str = outer_self.kexpr(index) + index_str = print_and_rename_indexing(index) var = add_input(name) return f"tl.load({var} + {index_str})" @@ -646,7 +651,7 @@ def generate( # type: ignore[override] defines.write(f"{name} : tl.constexpr = {val}\n") defines = defines.getvalue() - fake_out = ir.Buffer("buf_out", layout) + fake_out = ir.Buffer(name="buf_out", layout=layout) kernel_name = f"triton_{self.name}" numel = sympy_product(layout.size) @@ -676,7 +681,7 @@ def generate( # type: ignore[override] with patch.object( V.graph, "get_dtype", self._fake_get_dtype(fake_out) - ), TritonTemplateKernel( + ), V.graph.set_current_device(layout.device), TritonTemplateKernel( kernel_name=kernel_name, output_node=fake_out, use_jit=False, @@ -858,16 +863,15 @@ def __init__( input_nodes, layout, make_kernel_render, - debug_extra, + description, bmreq, log_info: Optional[ Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] ] = None, mutated_inputs=None, ) -> None: - super().__init__(name, input_nodes, layout) + super().__init__(name, input_nodes, layout, description) self.make_kernel_render = make_kernel_render - self.debug_extra = debug_extra self.bmreq: TritonBenchmarkRequest = bmreq if log_info is None: log_info = {} @@ -891,7 +895,7 @@ def precompile(self): self.bmreq.precompile() def __str__(self) -> str: - return f"TritonTemplateCaller({self.bmreq.module_path}, {self.debug_extra})" + return f"TritonTemplateCaller({self.bmreq.module_path}, {self.description})" def call_name(self): return f"template_kernels.{self.name}" @@ -910,7 +914,6 @@ def output_node(self): layout=self.layout, inputs=self.input_nodes, make_kernel_render=self.make_kernel_render, - debug_extra=self.debug_extra, mutated_inputs=self.mutated_inputs, ) ) @@ -946,7 +949,7 @@ def __init__( *, has_out_variant=True, ) -> None: - super().__init__(choice.name, input_nodes, layout) + super().__init__(choice.name, input_nodes, layout, description="") self.choice = choice self.kwargs = kwargs or {} self.has_out_variant = has_out_variant @@ -973,8 +976,7 @@ def to_callable(self): fn = self.choice.to_callable() if self.kwargs: return functools.partial(fn, **self.kwargs) - else: - return fn + return fn def hash_key(self): return "-".join( @@ -989,7 +991,7 @@ def hash_key(self): ) def output_node(self): - if config.abi_compatible and self.choice.use_fallback_kernel: + if self.choice.use_fallback_kernel: assert ( self.choice.op_overload is not None ), "Please provide an op_overload to use ir.FallbackKernel" @@ -1407,6 +1409,7 @@ def get_timings(): layout, input_nodes, get_timings, + choices, ) ) @@ -1648,11 +1651,9 @@ def get_choice_info(choice): for choice in top_k: result = timings[choice] if result: - kernel_info = ( - choice.debug_extra if hasattr(choice, "debug_extra") else "" - ) + kernel_description = choice.description sys.stderr.write( - f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_info}\n" + f" {choice.name} {result:.4f} ms {best_time / result:.1%} {kernel_description}\n" ) else: sys.stderr.write( @@ -1664,7 +1665,7 @@ def get_choice_info(choice): ) sys.stderr.write( f"{autotune_type_str} AUTOTUNE benchmarking takes {elapse:.4f} seconds and {precompile_elapse:.4f}" - " seconds precompiling\n" + f" seconds precompiling for {len(timings)} choices\n" ) @staticmethod @@ -1674,7 +1675,7 @@ def benchmark_example_value(node): benchmarking. """ if isinstance(node, ir.Layout): - node = ir.Buffer("fake", node) + node = ir.Buffer(name="fake", layout=node) # triton templates want the base tensor. if isinstance(node, ir.BaseView): node = node.unwrap_view() diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 8d3c6d411d278..8775036cf1059 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -247,9 +247,11 @@ def _simplify_loops_impl( # for which "strides" don't make sense so we ignore them here. # NOTE: These expressions may still block merging dims in the sound # substitution test performed in can_merge_dims. - self.stride_vars(x, index_vars) - if isinstance(x, sympy.Expr) - else [0] * len(index_vars) + ( + self.stride_vars(x, index_vars) + if isinstance(x, sympy.Expr) + else [0] * len(index_vars) + ) for x in index_formulas ] assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0])) @@ -415,14 +417,29 @@ def guard_equals(self, left: Expr, right: Expr) -> Expr: left = sympy_subs(left, self.inv_precomputed_replacements) # type: ignore[arg-type] if isinstance(right, Expr): right = sympy_subs(right, self.inv_precomputed_replacements) # type: ignore[arg-type] - assert self.shape_env.evaluate_expr(sympy.Eq(left, right)) + + expr = sympy.Eq(left, right) + static_expr = self.shape_env._maybe_evaluate_static(expr) + + if static_expr is not None: + assert bool(static_expr) + return left + + assert self.shape_env.defer_runtime_assert(expr, "guard_equals") return left def guard_leq(self, left: Expr, right: Expr) -> None: return self.guard_lt(left, right + 1) def guard_lt(self, left: Expr, right: Expr) -> None: - assert self.shape_env.evaluate_expr(sympy.Lt(left, right)) + expr = sympy.Lt(left, right) + static_expr = self.shape_env._maybe_evaluate_static(expr) + + if static_expr is not None: + assert bool(static_expr) + return + + assert self.shape_env.defer_runtime_assert(expr, "guard_lt") def guarded_order(self, seq): """ @@ -623,6 +640,24 @@ def _stride_vars( ) return strides + def atomically_apply_size_hint( + self, expr: Union[Expr, int], *, fallback: Optional[int] = None + ) -> Union[Expr, int]: + if isinstance(expr, int): + return int(expr) + + # For multiple expressions that depend on an unbacked symint, + # we want to compute them consistently for a size hint we have chosen. + # So, recursively compute expressions via size hints of contained symbols. + # For example: u1 * u2 - 10 ==> fallback * fallback - 10 + assert isinstance(expr, Expr), type(expr) + free_symbols = expr.free_symbols + size_dict = { + symbol: V.graph.sizevars.size_hint(symbol, fallback=fallback) + for symbol in free_symbols + } + return expr.subs(size_dict) + def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: """Extract offset part of an indexing expression""" index = self.simplify(index) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 571463c01e592..383291c56ee21 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -40,7 +40,7 @@ Union, ValuesView, ) -from typing_extensions import Concatenate, ParamSpec +from typing_extensions import Concatenate, dataclass_transform, ParamSpec from unittest import mock import sympy @@ -88,7 +88,7 @@ def get_gpu_type(): _T = TypeVar("_T") VarRanges = Dict[sympy.Expr, sympy.Expr] -InputType = Union[torch.Tensor, int] +InputType = Optional[Union[torch.Tensor, int, torch.SymInt]] GPU_ALIGN_BYTES = 16 @@ -477,7 +477,8 @@ def {name}_cache_on_self(self): try: return self.{key} except AttributeError: - self.{key} = rv = fn(self) + rv = fn(self) + object.__setattr__(self, "{key}", rv) return rv """.lstrip(), ctx, @@ -1174,13 +1175,10 @@ class CKGemmOperation: # type: ignore[no-redef] return package_dirname, gen_ops_library, gen_ops_preselected, CKGemmOperation -def use_ck_template(layout, m, n, k): +def use_ck_template(layout): # config knobs check 1 if not use_max_autotune(): return False - # config knobs check 2 - if not _use_autotune_backend("CK"): - return False # platform check if not torch.version.hip: return False @@ -1200,16 +1198,8 @@ def use_ck_template(layout, m, n, k): if not requested_supported_archs: return False # supported input dtypes - if layout.dtype not in [torch.float16, torch.bfloat16]: + if layout.dtype not in [torch.float16, torch.bfloat16, torch.float32]: return False - # TBD: investigate if we need to disable backend based on number of available CUs similar to `is_big_gpu` - # check if shape is static and gemm size is not 0 - from .virtualized import V - - gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) - if gemm_size <= 0: - return False - # TBD: investigate if backend needs to be disabled for small gemms similar to CUTLASS ck_package_dirname, _, _, _ = try_import_ck_lib() @@ -1231,6 +1221,20 @@ def use_ck_template(layout, m, n, k): return True +def use_ck_gemm_template(layout, m, n, k): + from .virtualized import V + + return ( + use_ck_template(layout) + and _use_autotune_backend("CK") + and V.graph.sizevars.size_hint(m * n * k, fallback=-1) > 0 + ) + + +def use_ck_conv_template(layout): + return use_ck_template(layout) and _use_conv_autotune_backend("CK") + + def _use_template_for_cpu(layout): return use_max_autotune() and layout.device.type == "cpu" @@ -1823,7 +1827,7 @@ def get_cloned_parameter_buffer_name(name: str): def is_gpu(device: str): assert isinstance(device, str) or device is None, device - return device in ["cuda", "xpu"] + return device in GPU_TYPES def device_need_guard(device: str): @@ -1969,7 +1973,7 @@ def run_and_get_cpp_code(fn, *args, **kwargs): return result, s -def shape_env_from_inputs(inputs: List[torch.Tensor]): +def shape_env_from_inputs(inputs: Sequence[InputType]): shape_env = None fake_mode = detect_fake_mode(inputs) @@ -2023,9 +2027,9 @@ def copy_misaligned_inputs( def remove_unaligned_input_idxs( - inputs: List[InputType], + inputs: Sequence[InputType], static_input_idxs: Sequence[int], -): +) -> Sequence[int]: """ We require all inputs to be aligned, so introduce a copy for any that aren't. @@ -2091,12 +2095,47 @@ def should_use_remote_fx_graph_cache(): except ModuleNotFoundError: return False - jk_name = "pytorch/remote_cache:fx_graph_memcache_version" - if torch.version.hip is not None: - jk_name = "pytorch/remote_cache:fx_graph_memcache_version_amd" - - return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int(jk_name) + return REMOTE_CACHE_VERSION >= torch._utils_internal.justknobs_getval_int( + "pytorch/remote_cache:fx_graph_memcache_version" + ) def normalize_name(name: str) -> str: return re.sub(r"[^a-zA-Z0-9_]", "_", name) + + +def is_same_tensor(data: torch.Tensor, value: torch.Tensor): + return ( + not data.is_mkldnn + and data.size() == value.size() + and data.stride() == value.stride() + and data.dtype == value.dtype + and data.device == value.device + and data.untyped_storage().data_ptr() == value.untyped_storage().data_ptr() + and data.storage_offset() == value.storage_offset() + ) + + +def is_same_mkldnn_tensor(data: torch.Tensor, value: torch.Tensor): + return ( + data.is_mkldnn + and data.size() == value.size() + and data.dtype == value.dtype + and data.device == value.device + and torch.ops.mkldnn.data_ptr(data) == torch.ops.mkldnn.data_ptr(value) + ) + + +@dataclass_transform(frozen_default=True) +def ir_dataclass(cls=None, /, *, frozen: bool = True): + def wrap(cls: _T) -> _T: + if sys.version_info >= (3, 10): + return dataclasses.dataclass(cls, kw_only=True, frozen=frozen) # type: ignore[call-overload] + else: + # Polyfill for python=3.9. kw_only simply introduces an extra check + # that only kwargs are used (and is not available on 3.9) + return dataclasses.dataclass(cls, frozen=frozen) + + if cls is None: + return wrap + return wrap(cls) diff --git a/torch/_logging/_internal.py b/torch/_logging/_internal.py index 7e3d6093ad229..a2e55e262ced3 100644 --- a/torch/_logging/_internal.py +++ b/torch/_logging/_internal.py @@ -43,6 +43,8 @@ LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT" TRACE_ENV_VAR = "TORCH_TRACE" +LOG_TRACE_HANDLER: Optional["LazyTraceHandler"] = None + @dataclass class LogRegistry: @@ -972,12 +974,14 @@ def _init_logs(log_file_name=None): # initializing it until we actually need to log anything. This is # important because JK initializes a C++ singleton, which will pork our # process if we subsequently fork. - handler = LazyTraceHandler(trace_dir_name) + global LOG_TRACE_HANDLER + if LOG_TRACE_HANDLER is None: + LOG_TRACE_HANDLER = LazyTraceHandler(trace_dir_name) # This log is ALWAYS at debug level. We will additionally test if there # are any handlers before deciding to actually call logging on this. Do # not manually call trace_log.setLevel(logging.DEBUG) - trace_log_handler = _track_handler(handler) + trace_log_handler = _track_handler(LOG_TRACE_HANDLER) trace_log_handler.setFormatter(TorchLogsFormatter(trace=True)) trace_log.addHandler(trace_log_handler) diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index b5cc43ffdccf2..0da6b58bdb413 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -2,6 +2,7 @@ # mypy: allow-untyped-defs import math from enum import Enum +from functools import wraps from typing import List, Optional, Sequence, Tuple, Union import torch @@ -718,6 +719,10 @@ def sym_constrain_range_for_size(size, min=None, max=None): # Avoid importing sympy at a module level from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size + if min is None and max is None: + torch._check_is_size(size) + return + if isinstance(size, (SymFloat, SymBool)): raise ValueError("Constraining SymFloat or Symbool is nyi") if type(size) is int: @@ -3394,10 +3399,6 @@ def meta_embedding_bag( mode == MODE_SUM, lambda: "embedding_bag: per_sample_weights only supported with mode='sum'", ) - torch._check( - per_sample_weights.dtype == weight.dtype, - lambda: f"expected weight ({weight.dtype}) and per_sample_weights ({per_sample_weights.dtype}) to have same dtype", - ) torch._check( per_sample_weights.ndim == 1, lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D", @@ -4793,7 +4794,7 @@ def gather_shape_check(self, dim, index): torch._check( ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i), lambda: f"Size does not match at dimension {i} expected index {index.shape}" - + f" to be smaller than self {self.shape} apart from dimension {dim}", + + f" to be no larger than self {self.shape} apart from dimension {dim}", ) @@ -4898,13 +4899,13 @@ def scatter_shape_check(self, dim, index, src_opt=None): ) torch._check( not is_wrong_shape, - lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" - + f" apart from dimension {dim} and to be smaller than src {src_opt.shape}", + lambda: f"Expected index {index.shape} to be no larger than self {self.shape}" + + f" apart from dimension {dim} and to be no larger than src {src_opt.shape}", ) else: torch._check( not is_wrong_shape, - lambda: f"Expected index {index.shape} to be smaller than self {self.shape}" + lambda: f"Expected index {index.shape} to be no larger than self {self.shape}" + f" apart from dimension {dim}", ) @@ -5984,6 +5985,24 @@ def topk_meta(self, k, dim=-1, largest=True, sorted=True): return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64) +@register_meta(aten._segment_reduce_backward) +@out_wrapper() +def meta__segment_reduce_backward( + grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None +): + assert ( + lengths is not None or offsets is not None + ), "segment_reduce(): Either lengths or offsets must be defined" + data_contig = data.contiguous() + grad_contig = grad.contiguous() + return torch.empty_like( + data_contig, + dtype=grad_contig.dtype, + device=grad_contig.device, + layout=grad_contig.layout, + ) + + @register_meta([aten.kthvalue.default, aten.kthvalue.values]) @out_wrapper("values", "indices") def kthvalue_meta(self, k, dim=-1, keepdim=False): @@ -6487,6 +6506,76 @@ def _f(x, y): _create_binary_float_meta_func(aten.special_legendre_polynomial_p) +def _register_inplace_meta(fn): + @wraps(fn) + def _fn(self, *args, **kwargs): + out = fn(self, *args, **kwargs) + check_inplace_broadcast(self.shape, out.shape) + return self + + inplace_name = f"{fn.__name__}_" + _fn.__name__ = inplace_name + _fn = register_meta(getattr(aten, inplace_name))(_fn) # type: ignore[assignment] + + return _fn + + +@register_meta(aten.lerp) +@out_wrapper() +def lerp(start, end, weight): + torch._check( + start.dtype == end.dtype, + lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}", + ) + args = [start, end] + if isinstance(weight, TensorLike): + torch._check( + start.dtype == weight.dtype, + lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}", + ) + args.append(weight) + return elementwise_meta( + *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta(aten.addcmul) +@out_wrapper() +def addcmul(input, tensor1, tensor2, *, value=1): + return elementwise_meta( + input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +@register_meta(aten.addcdiv) +@out_wrapper() +def addcdiv(input, tensor1, tensor2, *, value=1): + torch._check( + not ( + utils.is_integer_dtype(tensor1.dtype) + and utils.is_integer_dtype(tensor2.dtype) + ), + lambda: ( + "Integer division with addcdiv is no longer supported, and in a future ", + "release addcdiv will perform a true division of tensor1 and tensor2. ", + "The historic addcdiv behavior can be implemented as ", + "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ", + "for integer inputs and as ", + "(input + value * tensor1 / tensor2) for float inputs. ", + "The future addcdiv behavior is just the latter implementation: ", + "(input + value * tensor1 / tensor2), for all dtypes.", + ), + ) + return elementwise_meta( + input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + + +lerp_ = _register_inplace_meta(aten.lerp) +addcmul_ = _register_inplace_meta(aten.addcmul) +addcdiv_ = _register_inplace_meta(aten.addcdiv) + + # We must also trigger meta registrations from PrimTorch ref # decompositions import torch._refs diff --git a/torch/_prims/context.py b/torch/_prims/context.py index 7026d02790715..4756924691367 100644 --- a/torch/_prims/context.py +++ b/torch/_prims/context.py @@ -129,6 +129,8 @@ def __torch_function__( func = torch._decomp.decomposition_table.get(orig_func, None) elif func is None and isinstance(orig_func, torch._ops.OpOverloadPacket): default = getattr(orig_func, "default", None) + if default is None and orig_func._dir: + default = getattr(orig_func, orig_func._dir[0], None) if default is not None: func = torch._decomp.decomposition_table.get(default, None) diff --git a/torch/_prims_common/wrappers.py b/torch/_prims_common/wrappers.py index 19c7cfc3bb16f..865925e7dadd6 100644 --- a/torch/_prims_common/wrappers.py +++ b/torch/_prims_common/wrappers.py @@ -2,7 +2,16 @@ import inspect import warnings from functools import wraps -from typing import Callable, NamedTuple, Optional, overload, Sequence, Tuple, TypeVar +from typing import ( + Callable, + List, + NamedTuple, + Optional, + overload, + Sequence, + Tuple, + TypeVar, +) from typing_extensions import ParamSpec import torch @@ -59,7 +68,9 @@ def _maybe_convert_to_dtype(a, dtype): if a is None: return None - raise ValueError(f"Received type {type(a)} that is neither a tensor or a number!") + raise ValueError( + f"Received unsupported type {type(a)}. Expected TensorLike, Number, or Sequence." + ) def _maybe_convert_to_type(a: NumberType, typ: type) -> NumberType: @@ -288,11 +299,17 @@ def maybe_check_copy_devices(out): else: result = fn(*args, **kwargs) assert ( - isinstance(result, TensorLike) - and is_tensor - or isinstance(result, Tuple) # type: ignore[arg-type] - and len(result) == len(out_names) # type: ignore[arg-type] + (isinstance(result, TensorLike) and is_tensor) + or ( + isinstance(result, Tuple) # type: ignore[arg-type] + and len(result) == len(out_names) # type: ignore[arg-type] + ) + or ( + fn.__name__ == "unbind" + and isinstance(result, (List, Tuple)) # type: ignore[arg-type] + ) ) + # unbind_copy is a special case: see https://github.com/pytorch/pytorch/issues/130829 if out is not None: # Naively you might expect this assert to be true, but # it's not: @@ -310,7 +327,7 @@ def maybe_check_copy_devices(out): # the output tensor, but not the result--which will # be a normal meta tensor, but this is perfectly # harmless. - if is_tensor: + if is_tensor and fn.__name__ != "unbind": assert isinstance(out, TensorLike) # These two operations are done in-place _maybe_resize_out( @@ -318,7 +335,10 @@ def maybe_check_copy_devices(out): ) _safe_copy_out(copy_from=result, copy_to=out, exact_dtype=exact_dtype) # type: ignore[arg-type] else: - assert isinstance(out, Tuple) # type: ignore[arg-type] + if fn.__name__ != "unbind": + assert isinstance(out, Tuple) # type: ignore[arg-type] + else: + assert isinstance(out, (List, Tuple)) # type: ignore[arg-type] torch._check_type( len(out) == len(result), # type: ignore[arg-type] lambda: f"expected tuple of {len(result)} elements but got {len(out)}", # type: ignore[arg-type] diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 8ac7df248a1eb..08fdf5098eb13 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -285,6 +285,7 @@ "native_group_norm", "native_layer_norm", "permute", + "permute_copy", "ravel", "repeat", "reshape", @@ -304,6 +305,7 @@ "tensor_split", "transpose", "transpose_copy", + "unbind_copy", "unfold", "unfold_copy", "unsqueeze", @@ -421,12 +423,12 @@ def _broadcast_shapes(*_shapes): ) common_shape[idx] = shape[idx] elif guard_size_oblivious(shape[idx] != 1): - if common_shape[idx] != shape[idx]: - raise RuntimeError( - f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " - f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " - f"should be broadcastable to {common_shape}" - ) + torch._check( + common_shape[idx] == shape[idx], + lambda: f"Attempting to broadcast a dimension of length {shape[idx]} at {idx}! " + f"Mismatching argument at index {arg_idx} had {shape}; but expected shape " + f"should be broadcastable to {common_shape}", + ) return common_shape @@ -6378,8 +6380,10 @@ def select_scatter(x: TensorLikeType, src: TensorLikeType, dim: int, index: int) # no sparse support. See narrow_copy_sparse in core. narrow_copy = _make_copy_from_view(aten.narrow) squeeze_copy = _make_copy_from_view(aten.squeeze) +permute_copy = _make_copy_from_view(aten.permute) t_copy = _make_copy_from_view(aten.t) transpose_copy = _make_copy_from_view(aten.transpose) +unbind_copy = _make_copy_from_view(aten.unbind) unsqueeze_copy = _make_copy_from_view(aten.unsqueeze) view_copy = _make_copy_from_view(aten.view) diff --git a/torch/_strobelight/compile_time_profiler.py b/torch/_strobelight/compile_time_profiler.py index 13132188a1930..81ebef2df6b13 100644 --- a/torch/_strobelight/compile_time_profiler.py +++ b/torch/_strobelight/compile_time_profiler.py @@ -93,7 +93,7 @@ class StrobelightCompileTimeProfiler: profiler: Optional[Any] = None max_stack_length: int = int( - os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 127) + os.environ.get("COMPILE_STROBELIGHT_MAX_STACK_LENGTH", 500) ) max_profile_time: int = int( os.environ.get("COMPILE_STROBELIGHT_MAX_PROFILE_TIME", 60 * 30) @@ -125,6 +125,8 @@ def enable(cls, profiler_class: Any = StrobelightCLIFunctionProfiler) -> None: cls._cls_init() # profiler_class should have public API similar to that of StrobelightCLIFunctionProfiler. # we have pass different functionProfilerClass for meta-internal fbcode targets. + # NB: the actual implementation in Meta is at + # fbcode/caffe2/fb/strobelight/function_profiler.py cls.profiler = profiler_class( sample_each=cls.sample_each, max_profile_duration_sec=cls.max_profile_time, diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index b80b200a3c52b..14cb7b7362f97 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -630,7 +630,7 @@ def multi_device_op_default(fake_mode, func, *args, **kwargs): @register_op_impl(aten.slice_scatter.out) def multi_device_op_out(fake_mode, func, *args, **kwargs): with in_kernel_invocation_manager(fake_mode): - out = func(*args, **kwargs) + func(*args, **kwargs) _, new_kwargs = normalize_function( func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 37f337c0362c5..565b3a2fc66ae 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -32,7 +32,7 @@ TypeVar, Union, ) -from typing_extensions import Self, TypeGuard +from typing_extensions import Self, TypeIs from weakref import ReferenceType import torch @@ -151,14 +151,15 @@ def unset_fake_temporarily() -> Generator[Optional[TorchDispatchMode], None, Non torch._C._set_dispatch_mode(old) -def get_plain_tensors(subclass: Tensor) -> List[Tensor]: - assert is_traceable_wrapper_subclass(subclass) - plain_tensors: List[Tensor] = [] +def get_plain_tensors( + subclass: Tensor, out_append_list: Optional[List[Tensor]] = None +) -> List[Tensor]: + # This function is used in Runtime, do not add redundant asserts + plain_tensors: List[Tensor] = [] if out_append_list is None else out_append_list todo = [subclass] while todo: curr = todo.pop() if not is_traceable_wrapper_subclass(curr): - assert isinstance(curr, Tensor) plain_tensors.append(curr) continue @@ -169,7 +170,7 @@ def get_plain_tensors(subclass: Tensor) -> List[Tensor]: return plain_tensors -def is_fake(x: object) -> TypeGuard[Tensor]: +def is_fake(x: object) -> TypeIs[Tensor]: if isinstance(x, FakeTensor): return True if is_traceable_wrapper_subclass(x): @@ -475,18 +476,15 @@ def from_meta_and_device( @functools.lru_cache(None) -def init_gpu_context() -> None: +def init_gpu_context(device: torch.device) -> None: # Backward will error with cuda Fake Tensors if no cuda tensors have been initialized first - if torch.cuda.is_available(): + if torch.cuda.is_available() or torch.xpu.is_available(): ( - torch.empty(1, device="cuda") + torch.empty(1, device=device) if torch.version.hip is None - else torch.zeros(1, device="cuda") + else torch.zeros(1, device=device) ) - if torch.xpu.is_available(): - (torch.empty(1, device="xpu")) - @contextlib.contextmanager def in_kernel_invocation_manager( @@ -691,7 +689,7 @@ def __new__( assert device.type != "meta" # normalize device. if device.type in ["cuda", "xpu"]: - init_gpu_context() + init_gpu_context(device) if ( device.type @@ -1020,9 +1018,10 @@ def strip_shape_env(self) -> None: @dataclass_slots @dataclass(frozen=True) -class _DispatchCacheEntry: +class _DispatchCacheEntryOutputInfo: """ - Entry type for the FakeTensor dispatch cache. Accounts for two possibilities: + Entry type for the FakeTensor dispatch cache for an output. Accounts for two + possibilities: 1) The op is inplace, and a hit means we need to alias the argument at a given index. 2) We need to synthesize a new FakeTensor given tensor metadata. For view @@ -1034,6 +1033,21 @@ class _DispatchCacheEntry: view_idx: Optional[int] +@dataclass_slots +@dataclass(frozen=True) +class _DispatchCacheEntry: + """ + Entry type for the FakeTensor dispatch cache. It supports two types of outputs + 1) tensor + 2) tuple of tensors + + is_output_tuple flag helps in differentiating the return type + """ + + output_infos: Tuple[_DispatchCacheEntryOutputInfo] + is_output_tuple: bool = False + + @dataclass_slots @dataclass(frozen=True) class _BypassDispatchCache(Exception): @@ -1200,7 +1214,7 @@ def reset_nt_tensor_id_counter(self) -> None: # In this case, it's insufficient to test only one FakeTensor: you need # to distinguish between our fake tensor and other fake tensors. That's # what this function does. - def is_our_fake(self, t: object) -> TypeGuard[FakeTensor]: + def is_our_fake(self, t: object) -> TypeIs[FakeTensor]: return isinstance(t, FakeTensor) and t.fake_mode is self # If we should avoid device init. This changes the behavior of various APIs: @@ -1474,7 +1488,7 @@ def _prep_args_for_hash( result.append(type(arg)) result.append(arg) - def _make_cache_entry( + def _validate_output_for_cache_entry( self, state: _CacheKeyState, key: _DispatchCacheKey, @@ -1482,15 +1496,7 @@ def _make_cache_entry( args: Sequence[object], kwargs: Mapping[str, object], output: Optional[FakeTensor], - ) -> _DispatchCacheEntry: - """ - Make a cache entry object for the given 'output' Tensor. Raises - _BypassDispatchCache if the output tensor has characteristics that - prevent caching it. - """ - if output is None: - return _DispatchCacheEntry(inplace_idx=None, metadata=None, view_idx=None) - + ) -> None: # Some ops return tuples of Tensors, but it's rare, so avoid # the complexity of caching other types. if not isinstance(output, FakeTensor): @@ -1514,10 +1520,19 @@ def _make_cache_entry( if id(kval) == id(output): raise _BypassDispatchCache("kwarg aliases output") + def _get_output_info_for_cache_entry( + self, + state: _CacheKeyState, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + kwargs: Mapping[str, object], + output: FakeTensor, + ) -> _DispatchCacheEntryOutputInfo: # If this is an in-place op, the entry records which input arg is aliased. for idx in range(len(args)): if id(args[idx]) == id(output): - return _DispatchCacheEntry( + return _DispatchCacheEntryOutputInfo( inplace_idx=idx, metadata=None, view_idx=None ) @@ -1538,7 +1553,7 @@ def _make_cache_entry( else state.convert_output(metadata.storage_bytes) ) - entry = _DispatchCacheEntry( + entry = _DispatchCacheEntryOutputInfo( inplace_idx=None, metadata=metadata, view_idx=view_idx, @@ -1549,7 +1564,12 @@ def _make_cache_entry( # we can synthesize a tensor here and do the checks on that instance. # This approach keeps the (more frequent) cache-hit path as lightweight # as possible. - synth_output = self._output_from_cache_entry(state, entry, key, func, args) + entry_for_synth_output = _DispatchCacheEntry( + output_infos=(entry,), is_output_tuple=False + ) + synth_output = self._output_from_cache_entry( + state, entry_for_synth_output, key, func, args + ) # Make sure the dispatch_key_set from the synthesized output tensor will # be the same. @@ -1560,17 +1580,66 @@ def _make_cache_entry( return entry - def _output_from_cache_entry( + def _make_cache_entry( self, state: _CacheKeyState, - entry: _DispatchCacheEntry, key: _DispatchCacheKey, func: OpOverload, args: Sequence[object], - ) -> Optional[FakeTensor]: + kwargs: Mapping[str, object], + output: Optional[FakeTensor], + ) -> _DispatchCacheEntry: """ - Create a new FakeTensor from the cache entry. + Make a cache entry object for the given 'output' Tensor. Raises + _BypassDispatchCache if the output tensor has characteristics that + prevent caching it. """ + if output is None: + output_info = _DispatchCacheEntryOutputInfo( + inplace_idx=None, metadata=None, view_idx=None + ) + return _DispatchCacheEntry( + output_infos=(output_info,), is_output_tuple=False + ) + + if isinstance(output, tuple): + for out_element in output: + self._validate_output_for_cache_entry( + state, key, func, args, kwargs, out_element + ) + else: + self._validate_output_for_cache_entry( + state, key, func, args, kwargs, output + ) + + if isinstance(output, tuple): + output_infos = [] + for out_elem in output: + output_infos.append( + self._get_output_info_for_cache_entry( + state, key, func, args, kwargs, out_elem + ) + ) + return _DispatchCacheEntry( + output_infos=tuple(output_infos), is_output_tuple=True + ) + + else: + output_info = self._get_output_info_for_cache_entry( + state, key, func, args, kwargs, output + ) + return _DispatchCacheEntry( + output_infos=(output_info,), is_output_tuple=False + ) + + def _get_output_tensor_from_cache_entry( + self, + state: _CacheKeyState, + entry: _DispatchCacheEntryOutputInfo, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + ) -> Optional[FakeTensor]: if entry.inplace_idx is not None: # This is an in-place op; return the aliased arg. inplace_arg = args[entry.inplace_idx] @@ -1597,11 +1666,8 @@ def check_value( shape = tuple(check_value(v, state) for v in metadata.shape) stride = tuple(check_value(v, state) for v in metadata.stride) storage_offset = check_value(metadata.storage_offset, state) - storage_bytes = ( - None - if metadata.storage_bytes is None - else check_value(metadata.storage_bytes, state) - ) + if metadata.storage_bytes is not None: + check_value(metadata.storage_bytes, state) maybe_suppress: Callable[[], typing.ContextManager] = contextlib.nullcontext if self.shape_env is not None: @@ -1632,9 +1698,39 @@ def check_value( return FakeTensor(self, empty, metadata.device) + def _output_from_cache_entry( + self, + state: _CacheKeyState, + entry: _DispatchCacheEntry, + key: _DispatchCacheKey, + func: OpOverload, + args: Sequence[object], + ) -> Union[Optional[FakeTensor], Tuple[Optional[FakeTensor], ...]]: + """ + Create a new FakeTensor from the cache entry. + """ + + if entry.is_output_tuple: + outputs = [] + for output_info in entry.output_infos: + outputs.append( + self._get_output_tensor_from_cache_entry( + state, + output_info, + key, + func, + args, + ) + ) + return tuple(outputs) + else: + return self._get_output_tensor_from_cache_entry( + state, entry.output_infos[0], key, func, args + ) + def _crosscheck_cache_output( self, - output: Optional[FakeTensor], + output: Union[Optional[FakeTensor], Tuple[Optional[FakeTensor], ...]], func: OpOverload, types: Sequence[Type], args: Sequence[object], @@ -1653,7 +1749,13 @@ def _crosscheck_cache_output( ) from e try: if (true_output is not None) and (output is not None): - assert_metadata_eq(assert_eq, true_output, output) + if isinstance(true_output, tuple): + assert len(true_output) == len(output) + for a, b in zip(true_output, output): + assert_metadata_eq(assert_eq, a, b) + else: + assert not isinstance(output, tuple) + assert_metadata_eq(assert_eq, true_output, output) else: assert true_output is None assert output is None @@ -1920,6 +2022,13 @@ def go(t: object, real_t: Tensor) -> None: if isinstance(t.node.expr, sympy.Symbol): assert self.shape_env is not None self.shape_env.set_unbacked_var_to_val(t.node.expr, real_t) + elif ( + isinstance(s := t.node.expr, sympy.Eq) + and isinstance(s.lhs, sympy.Symbol) + and s.rhs == 1 + ): + assert self.shape_env is not None + self.shape_env.set_unbacked_var_to_val(s, int(real_t)) if real_out is not nil: if ( diff --git a/torch/_subclasses/fake_utils.py b/torch/_subclasses/fake_utils.py index 28fc7a4028917..c610ee9dbab40 100644 --- a/torch/_subclasses/fake_utils.py +++ b/torch/_subclasses/fake_utils.py @@ -82,6 +82,7 @@ def __init__( *, check_strides=True, check_aliasing=True, + only_check_ops_with_meta=True, ): super().__init__() self.ignore_op_fn = ( @@ -89,6 +90,7 @@ def __init__( ) self.check_strides = check_strides self.check_aliasing = check_aliasing + self.only_check_ops_with_meta = only_check_ops_with_meta def __torch_dispatch__(self, func, types, args=(), kwargs=None): kwargs = kwargs or {} @@ -105,6 +107,10 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): aten.set_.source_Storage_storage_offset, ) and not self.ignore_op_fn(func) + and ( + not self.only_check_ops_with_meta + or torch._subclasses.fake_impls.has_meta(func) + ) and torch.Tag.dynamic_output_shape not in func.tags and torch.Tag.inplace_view not in func.tags and torch.Tag.data_dependent_output not in func.tags diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 763dea06e5028..7a3cab1b09571 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -163,7 +163,8 @@ def __new__(cls, elem, mode): out.elem = elem if ( - torch.is_inference_mode_enabled() + not mode.export + and torch.is_inference_mode_enabled() and torch._inductor.config.enable_auto_functionalized_v2 ): if out.is_base_tensor(): @@ -309,6 +310,9 @@ def to_dense(self): # type: ignore[override] def layout(self): return self.elem.layout + def __bool__(self): + return bool(self.item()) + class FunctionalTensorMode(TorchDispatchMode): def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False): diff --git a/torch/_tensor.py b/torch/_tensor.py index d5215550ed093..f06eb7dffe433 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -300,6 +300,20 @@ def _reduce_ex_internal(self, proto): torch.serialization._serialization_tls.materialize_fake_tensors ) + if self.device.type == "xla" or ( + not torch._C._has_storage(self) + and self.device.type == torch._C._get_privateuse1_backend_name() + ): + if skip_data: + raise RuntimeError( + "Cannot serialize tensors on backends with no storage under skip_data context manager" + ) + cpu_tensor = self.cpu() + return ( + torch._utils._rebuild_device_tensor_from_cpu_tensor, + (cpu_tensor, self.dtype, str(self.device), self.requires_grad), + ) + # Legacy comment that does not hold anymore. # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -310,10 +324,7 @@ def _reduce_ex_internal(self, proto): # 2. Python list is not a good fit due to performance reason. # `tolist()` converts every single element in the tensor into python objects # and serialize them one by one. - if self.device.type in ["xla", "mtia", "maia"] or ( - not torch._C._has_storage(self) - and self.device.type == torch._C._get_privateuse1_backend_name() - ): + if self.device.type in ["mtia", "maia"]: # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 41f9b2246f38b..427d32ed9f9e6 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -193,7 +193,7 @@ def merge_dicts(*dicts): add_docstr( torch.abs, r""" -abs(input, *, out=None) -> Tensor +abs(input: Tensor, *, out: Optional[Tensor]) -> Tensor Computes the absolute value of each element in :attr:`input`. @@ -217,7 +217,7 @@ def merge_dicts(*dicts): add_docstr( torch.absolute, r""" -absolute(input, *, out=None) -> Tensor +absolute(input: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.abs` """, @@ -226,7 +226,7 @@ def merge_dicts(*dicts): add_docstr( torch.acos, r""" -acos(input, *, out=None) -> Tensor +acos(input: Tensor, *, out: Optional[Tensor]) -> Tensor Computes the inverse cosine of each element in :attr:`input`. @@ -253,7 +253,7 @@ def merge_dicts(*dicts): add_docstr( torch.arccos, r""" -arccos(input, *, out=None) -> Tensor +arccos(input: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.acos`. """, @@ -262,7 +262,7 @@ def merge_dicts(*dicts): add_docstr( torch.acosh, r""" -acosh(input, *, out=None) -> Tensor +acosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the inverse hyperbolic cosine of the elements of :attr:`input`. @@ -293,7 +293,7 @@ def merge_dicts(*dicts): add_docstr( torch.arccosh, r""" -arccosh(input, *, out=None) -> Tensor +arccosh(input: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.acosh`. """, @@ -302,7 +302,7 @@ def merge_dicts(*dicts): add_docstr( torch.index_add, r""" -index_add(input, dim, index, source, *, alpha=1, out=None) -> Tensor +index_add(input: Tensor, dim: int, index: Tensor, source: Tensor, *, alpha: Union[Number, _complex] = 1, out: Optional[Tensor]) -> Tensor # noqa: B950 See :meth:`~Tensor.index_add_` for function description. """, @@ -311,7 +311,7 @@ def merge_dicts(*dicts): add_docstr( torch.index_copy, r""" -index_copy(input, dim, index, source, *, out=None) -> Tensor +index_copy(input: Tensor, dim: int, index: Tensor, source: Tensor, *, out: Optional[Tensor]) -> Tensor See :meth:`~Tensor.index_add_` for function description. """, @@ -320,7 +320,7 @@ def merge_dicts(*dicts): add_docstr( torch.index_reduce, r""" -index_reduce(input, dim, index, source, reduce, *, include_self=True, out=None) -> Tensor +index_reduce(input: Tensor, dim: int, index: Tensor, source: Tensor, reduce: str, *, include_self: bool = True, out: Optional[Tensor]) -> Tensor # noqa: B950 See :meth:`~Tensor.index_reduce_` for function description. """, @@ -578,12 +578,15 @@ def merge_dicts(*dicts): add_docstr( torch.adjoint, r""" -adjoint(Tensor) -> Tensor +adjoint(input: Tensor) -> Tensor Returns a view of the tensor conjugated and with the last two dimensions transposed. ``x.adjoint()`` is equivalent to ``x.transpose(-2, -1).conj()`` for complex tensors and to ``x.transpose(-2, -1)`` for real tensors. +Args: + {input} + Example:: >>> x = torch.arange(4, dtype=torch.float) >>> A = torch.complex(x, x).reshape(2, 2) @@ -732,7 +735,7 @@ def merge_dicts(*dicts): add_docstr( torch.allclose, r""" -allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False) -> bool +allclose(input: Tensor, other: Tensor, rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> bool This function checks if :attr:`input` and :attr:`other` satisfy the condition: @@ -766,7 +769,7 @@ def merge_dicts(*dicts): add_docstr( torch.all, r""" -all(input) -> Tensor +all(input: Tensor) -> Tensor Tests if all elements in :attr:`input` evaluate to `True`. @@ -821,7 +824,7 @@ def merge_dicts(*dicts): add_docstr( torch.any, r""" -any(input) -> Tensor +any(input: Tensor, *, out: Optional[Tensor]) -> Tensor Tests if any element in :attr:`input` evaluates to `True`. @@ -876,7 +879,7 @@ def merge_dicts(*dicts): add_docstr( torch.angle, r""" -angle(input, *, out=None) -> Tensor +angle(input: Tensor, *, out: Optional[Tensor]) -> Tensor Computes the element-wise angle (in radians) of the given :attr:`input` tensor. @@ -946,7 +949,7 @@ def merge_dicts(*dicts): add_docstr( torch.as_tensor, r""" -as_tensor(data, dtype=None, device=None) -> Tensor +as_tensor(data: Any, dtype: Optional[dtype] = None, device: Optional[DeviceLikeType]) -> Tensor Converts :attr:`data` into a tensor, sharing data and preserving autograd history if possible. @@ -998,7 +1001,7 @@ def merge_dicts(*dicts): add_docstr( torch.asin, r""" -asin(input, *, out=None) -> Tensor +asin(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the arcsine of the elements of :attr:`input`. @@ -1025,7 +1028,7 @@ def merge_dicts(*dicts): add_docstr( torch.arcsin, r""" -arcsin(input, *, out=None) -> Tensor +arcsin(input: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.asin`. """, @@ -1034,7 +1037,7 @@ def merge_dicts(*dicts): add_docstr( torch.asinh, r""" -asinh(input, *, out=None) -> Tensor +asinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the inverse hyperbolic sine of the elements of :attr:`input`. @@ -1061,7 +1064,7 @@ def merge_dicts(*dicts): add_docstr( torch.arcsinh, r""" -arcsinh(input, *, out=None) -> Tensor +arcsinh(input: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.asinh`. """, @@ -1070,7 +1073,7 @@ def merge_dicts(*dicts): add_docstr( torch.atan, r""" -atan(input, *, out=None) -> Tensor +atan(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the arctangent of the elements of :attr:`input`. @@ -1097,7 +1100,7 @@ def merge_dicts(*dicts): add_docstr( torch.arctan, r""" -arctan(input, *, out=None) -> Tensor +arctan(input: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.atan`. """, @@ -1106,7 +1109,7 @@ def merge_dicts(*dicts): add_docstr( torch.atan2, r""" -atan2(input, other, *, out=None) -> Tensor +atan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor Element-wise arctangent of :math:`\text{{input}}_{{i}} / \text{{other}}_{{i}}` with consideration of the quadrant. Returns a new tensor with the signed angles @@ -1138,7 +1141,7 @@ def merge_dicts(*dicts): add_docstr( torch.arctan2, r""" -arctan2(input, other, *, out=None) -> Tensor +arctan2(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.atan2`. """, ) @@ -1146,7 +1149,7 @@ def merge_dicts(*dicts): add_docstr( torch.atanh, r""" -atanh(input, *, out=None) -> Tensor +atanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the inverse hyperbolic tangent of the elements of :attr:`input`. @@ -1178,7 +1181,7 @@ def merge_dicts(*dicts): add_docstr( torch.arctanh, r""" -arctanh(input, *, out=None) -> Tensor +arctanh(input: Tensor, *, out: Optional[Tensor]) -> Tensor Alias for :func:`torch.atanh`. """, @@ -1187,7 +1190,7 @@ def merge_dicts(*dicts): add_docstr( torch.asarray, r""" -asarray(obj, *, dtype=None, device=None, copy=None, requires_grad=False) -> Tensor +asarray(obj: Any, *, dtype: Optional[dtype], device: Optional[DeviceLikeType], copy: Optional[bool] = None, requires_grad: bool = False) -> Tensor # noqa: B950 Converts :attr:`obj` to a tensor. @@ -1352,7 +1355,7 @@ def merge_dicts(*dicts): add_docstr( torch.bernoulli, r""" -bernoulli(input, *, generator=None, out=None) -> Tensor +bernoulli(input: Tensor, *, generator: Optional[Generator], out: Optional[Tensor]) -> Tensor Draws binary random numbers (0 or 1) from a Bernoulli distribution. @@ -1542,7 +1545,7 @@ def merge_dicts(*dicts): add_docstr( torch.bitwise_or, r""" -bitwise_or(input, other, *, out=None) -> Tensor +bitwise_or(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor Computes the bitwise OR of :attr:`input` and :attr:`other`. The input tensor must be of integral or Boolean types. For bool tensors, it computes the logical OR. @@ -5885,7 +5888,7 @@ def merge_dicts(*dicts): add_docstr( torch.log10, r""" -log10(input, *, out=None) -> Tensor +log10(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the logarithm to the base 10 of the elements of :attr:`input`. @@ -5947,7 +5950,7 @@ def merge_dicts(*dicts): add_docstr( torch.log2, r""" -log2(input, *, out=None) -> Tensor +log2(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the logarithm to the base 2 of the elements of :attr:`input`. @@ -6132,7 +6135,7 @@ def merge_dicts(*dicts): add_docstr( torch.logical_xor, r""" -logical_xor(input, other, *, out=None) -> Tensor +logical_xor(input: Tensor, other: Tensor, *, out: Optional[Tensor]) -> Tensor Computes the element-wise logical XOR of the given input tensors. Zeros are treated as ``False`` and nonzeros are treated as ``True``. @@ -8190,7 +8193,7 @@ def merge_dicts(*dicts): add_docstr( torch.numel, r""" -numel(input) -> int +numel(input: Tensor) -> int Returns the total number of elements in the :attr:`input` tensor. @@ -8529,7 +8532,7 @@ def merge_dicts(*dicts): add_docstr( torch.prod, r""" -prod(input, *, dtype=None) -> Tensor +prod(input: Tensor, *, dtype: Optional[_dtype]) -> Tensor Returns the product of all elements in the :attr:`input` tensor. @@ -8602,7 +8605,7 @@ def merge_dicts(*dicts): add_docstr( torch.qr, r""" -qr(input, some=True, *, out=None) -> (Tensor, Tensor) +qr(input: Tensor, some: bool = True, *, out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None]) -> (Tensor, Tensor) Computes the QR decomposition of a matrix or a batch of matrices :attr:`input`, and returns a namedtuple (Q, R) of tensors such that :math:`\text{input} = Q R` @@ -8686,7 +8689,7 @@ def merge_dicts(*dicts): add_docstr( torch.rad2deg, r""" -rad2deg(input, *, out=None) -> Tensor +rad2deg(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with each of the elements of :attr:`input` converted from angles in radians to degrees. @@ -9872,7 +9875,7 @@ def merge_dicts(*dicts): add_docstr( torch.msort, r""" -msort(input, *, out=None) -> Tensor +msort(input: Tensor, *, out: Optional[Tensor]) -> Tensor Sorts the elements of the :attr:`input` tensor along its first dimension in ascending order by value. @@ -10355,7 +10358,7 @@ def merge_dicts(*dicts): add_docstr( torch.square, r""" -square(input, *, out=None) -> Tensor +square(input: Tensor, *, out: Optional[Tensor]) -> Tensor Returns a new tensor with the square of the elements of :attr:`input`. @@ -10378,7 +10381,7 @@ def merge_dicts(*dicts): add_docstr( torch.squeeze, r""" -squeeze(input, dim=None) -> Tensor +squeeze(input: Tensor, dim: Optional[Union[int, List[int]]]) -> Tensor Returns a tensor with all specified dimensions of :attr:`input` of size `1` removed. @@ -12608,7 +12611,7 @@ def merge_dicts(*dicts): add_docstr( torch.combinations, r""" -combinations(input, r=2, with_replacement=False) -> seq +combinations(input: Tensor, r: int = 2, with_replacement: bool = False) -> seq Compute combinations of length :math:`r` of the given tensor. The behavior is similar to python's `itertools.combinations` when `with_replacement` is set to `False`, and diff --git a/torch/_utils.py b/torch/_utils.py index f0d38daa81149..e5c3a14ca81d7 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -67,6 +67,17 @@ def _to(self, device, non_blocking=False): if self.device == device: return self + if device.type == "cpu": + pin_memory = non_blocking and self.device.type in ( + "cuda", + torch._C._get_privateuse1_backend_name(), + ) + untyped_storage = torch.empty( + self.nbytes(), dtype=torch.uint8, device=device, pin_memory=pin_memory + ).untyped_storage() + untyped_storage.copy_(self, non_blocking) + return untyped_storage + device_module = getattr(torch, device.type, None) assert ( device_module is not None @@ -330,6 +341,13 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets): return torch._nested_view_from_buffer(buffer, sizes, strides, storage_offsets) +def _rebuild_device_tensor_from_cpu_tensor(data, dtype, device, requires_grad): + device = _get_restore_location(device) + tensor = data.to(dtype=dtype, device=device) + tensor.requires_grad = requires_grad + return tensor + + def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): device = _get_restore_location(device) tensor = torch.from_numpy(data).to(dtype=dtype, device=device) diff --git a/torch/_utils_internal.py b/torch/_utils_internal.py index 1c7eaa955be23..9286a67563a5e 100644 --- a/torch/_utils_internal.py +++ b/torch/_utils_internal.py @@ -4,12 +4,16 @@ import os import sys import tempfile -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing_extensions import ParamSpec import torch from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler +_T = TypeVar("_T") +_P = ParamSpec("_P") + log = logging.getLogger(__name__) if os.environ.get("TORCH_COMPILE_STROBELIGHT", False): @@ -76,12 +80,16 @@ def throw_abstract_impl_not_imported_error(opname, module, context): # NB! This treats "skip" kwarg specially!! -def compile_time_strobelight_meta(phase_name): - def compile_time_strobelight_meta_inner(function): +def compile_time_strobelight_meta( + phase_name: str, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def compile_time_strobelight_meta_inner( + function: Callable[_P, _T], + ) -> Callable[_P, _T]: @functools.wraps(function) - def wrapper_function(*args, **kwargs): - if "skip" in kwargs: - kwargs["skip"] = kwargs["skip"] + 1 + def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T: + if "skip" in kwargs and isinstance(skip := kwargs["skip"], int): + kwargs["skip"] = skip + 1 if not StrobelightCompileTimeProfiler.enabled: return function(*args, **kwargs) @@ -357,6 +365,9 @@ def maybe_upload_prof_stats_to_manifold(profile_path: str) -> Optional[str]: def log_chromium_event_internal( - event, stack, compile_id, logger_uuid, start_timestamp=None + event: Dict[str, Any], + stack: List[str], + logger_uuid: str, + start_time_ns: int, ): return None diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 063b57d859e75..918db8ba0be9b 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -150,6 +150,9 @@ def _tensor_rebuild_functions(): # Reasoning is that we don't have control over the numpy functions, but # this utility is provided by pytorch torch._utils._rebuild_device_tensor_from_numpy, + # In 2.6, we should no longer have a dependency on numpy and the above + # _rebuild_device_tensor_from_numpy function. + torch._utils._rebuild_device_tensor_from_cpu_tensor, } @@ -166,6 +169,7 @@ def _get_allowed_globals(): "torch.device": torch.device, "_codecs.encode": encode, # for bytes "builtins.bytearray": bytearray, # for bytearray + "builtins.set": set, # for set } # dtype for t in torch.storage._dtype_to_storage_type_map().keys(): diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 6aba6bbad42ef..d940d04f8b7bc 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -355,6 +355,24 @@ def __enter__(self): torch.autocast_increment_nesting() torch.set_autocast_cache_enabled(self._cache_enabled) + # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this + # API to other functional modes. We only expose to PreDispatchTorchFunctionMode + # for preserving autocast in torch.export.export. + if torch._C._is_torch_function_mode_enabled(): + stacks = torch.overrides._get_current_function_mode_stack() + for mode in stacks: + if isinstance( + mode, + torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode, + ): + args = ( + self.device, + self.fast_dtype, + self._enabled, + self._cache_enabled, + ) + return mode.__torch_function__(torch.amp._enter_autocast, (), args) + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[override] if torch._jit_internal.is_scripting(): return @@ -365,6 +383,18 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any): # type: ignore[ov torch.set_autocast_enabled(self.device, self.prev) torch.set_autocast_dtype(self.device, self.prev_fastdtype) torch.set_autocast_cache_enabled(self.prev_cache_enabled) + + # only dispatch to PreDispatchTorchFunctionMode to avoid exposing this + # API to other functional modes. We only expose to PreDispatchTorchFunctionMode + # for preserving autocast in torch.export.export. + if torch._C._is_torch_function_mode_enabled(): + stacks = torch.overrides._get_current_function_mode_stack() + for mode in stacks: + if isinstance( + mode, + torch.fx.experimental.proxy_tensor.PreDispatchTorchFunctionMode, + ): + return mode.__torch_function__(torch.amp._exit_autocast, (), ()) return False def __call__(self, func): diff --git a/torch/ao/quantization/pt2e/utils.py b/torch/ao/quantization/pt2e/utils.py index 1cf91fbf3c8c0..8e966ffbff6c5 100644 --- a/torch/ao/quantization/pt2e/utils.py +++ b/torch/ao/quantization/pt2e/utils.py @@ -363,8 +363,7 @@ def _get_aten_graph_module_for_pattern( pattern: Callable, example_inputs: Tuple[Any, ...], is_cuda: bool = False, - *, - using_training_ir: bool, + using_training_ir: bool = True, **kwargs, ) -> GraphModule: """ diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py index ee497e38c39a4..89735523c0b6c 100644 --- a/torch/ao/quantization/utils.py +++ b/torch/ao/quantization/utils.py @@ -649,6 +649,7 @@ def determine_qparams( device = min_val_neg.device scale = torch.ones(min_val_neg.size(), dtype=torch.double, device=device) zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device) + eps = eps.to(device) if qscheme == torch.per_tensor_symmetric or qscheme == torch.per_channel_symmetric: max_val_pos = torch.max(-min_val_neg, max_val_pos) diff --git a/torch/backends/_nnapi/serializer.py b/torch/backends/_nnapi/serializer.py index 34bcc42f89278..0087fc47344cf 100644 --- a/torch/backends/_nnapi/serializer.py +++ b/torch/backends/_nnapi/serializer.py @@ -266,7 +266,7 @@ def broadcast_shapes(shape1, shape2): def get_conv_pool_shape(image_shape, args, out_ch, transpose): - batch, in_c, in_h, in_w = image_shape + batch, _in_c, in_h, in_w = image_shape # TODO: Handle dilation if args.dilation_h != 1 or args.dilation_w != 1: @@ -443,7 +443,6 @@ def add_tensor_operand_for_weight( operand_id = len(self.operands) self.operands.append(toper) tsize = tensor_size(toper.op_type, toper.shape) - psize = ((tsize - 1) | 0x3) + 1 self.values.append((operand_id, OperandValueSourceType.NUMBERED_BUFFER)) buf_num = len(self.used_weights) offset = 0 @@ -917,7 +916,7 @@ def add_node(self, node): adder(self, node) def _identity(self, node): - in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) + in_id, _in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) jitval = node.outputsAt(0) self.jitval_operand_map[jitval] = in_id @@ -1039,8 +1038,8 @@ def add_flatten(self, node): in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0)) - start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType") - end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType") + _start_ctype, start_dim = self.get_constant_value(node.inputsAt(1), "IntType") + _end_ctype, end_dim = self.get_constant_value(node.inputsAt(2), "IntType") # channels last with channels == 1 or (height & width both 1) is_trivial_flatten = len(in_oper.shape) == 4 and ( @@ -1526,7 +1525,7 @@ def add_prelu_op(self, node): def add_pool2d_node(self, node, opcode): assert node.inputsSize() == 6 assert node.outputsSize() == 1 - image, kernel, stride, padding, dilation, ceil_mode = node.inputs() + image, kernel, stride, padding, dilation, _ceil_mode = node.inputs() stride = stride or kernel @@ -1574,7 +1573,7 @@ def add_avg_pool2d(self, node): kernel, stride, padding, - ceil_mode, + _ceil_mode, count_include_pad, divisor_override, ) = node.inputs() @@ -1673,7 +1672,7 @@ def add_upsample_nearest2d(self, node): scale_ctype, scale_arg = self.get_constant_value(scale_jit) # type: ignore[possibly-undefined] else: scale_h_ctype, scale_h_arg = self.get_constant_value(scale_h_jit) # type: ignore[possibly-undefined] - scale_w_ctype, scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined] + scale_w_ctype, _scale_w_arg = self.get_constant_value(scale_w_jit) # type: ignore[possibly-undefined] # The only way for the 4-argument overload of upsample_nearest2d to # have been added to the graph without error is if the scale_h and @@ -1892,7 +1891,7 @@ def add_qlinear(self, node): self.add_operation(NNAPI_OperationCode.FULLY_CONNECTED, inputs, outputs) def get_optional_bias(self, jit_bias, weight_tensor, transpose=False): - ctype, value = self.get_constant_value(jit_bias) + ctype, _value = self.get_constant_value(jit_bias) if ctype.kind() == "NoneType": bias_idx = 1 if transpose else 0 nnapi_bias_tensor = torch.zeros( @@ -1919,7 +1918,7 @@ def add_conv2d(self, node): ) = node.inputs() _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") - bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor) + bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor) args = self.get_conv_pool_args_2d_from_jit( weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups ) @@ -1958,7 +1957,7 @@ def add_conv_underscore(self, node): _, weight_tensor = self.get_constant_value(jit_weight, "TensorType") _, transpose = self.get_constant_value(jit_transpose) - bias_id, bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose) + bias_id, _bias_oper = self.get_optional_bias(jit_bias, weight_tensor, transpose) args = self.get_conv_pool_args_2d_from_jit( weight_tensor.shape[2:4], jit_stride, jit_pad, jit_dilation, jit_groups ) @@ -1979,7 +1978,7 @@ def add_log_softmax(self, node): assert node.inputsSize() == 3 assert node.outputsSize() == 1 - (jit_input, jit_dim, jit_half_to_float) = node.inputs() + jit_input, jit_dim, _jit_half_to_float = node.inputs() input_id, input_oper = self.get_tensor_operand_by_jitval_fixed_size(jit_input) _, dim = self.get_constant_value(jit_dim, "IntType") @@ -2117,7 +2116,7 @@ def add_conv2d_common( if depthwise: # Depthwise convolution - one, kern_h, kern_w, out_c = weight_oper.shape + one, _kern_h, _kern_w, out_c = weight_oper.shape assert one == 1 assert out_c % in_c == 0 channel_multiplier = out_c // in_c @@ -2125,7 +2124,7 @@ def add_conv2d_common( assert out_c == in_c else: # Full convolution - out_c, kern_h, kern_w, kern_d = weight_oper.shape + out_c, _kern_h, _kern_w, kern_d = weight_oper.shape assert kern_d == in_c assert out_c == bias_oper.shape[0] diff --git a/torch/backends/cuda/__init__.py b/torch/backends/cuda/__init__.py index 3d49f22b2415a..2b7aa44946671 100644 --- a/torch/backends/cuda/__init__.py +++ b/torch/backends/cuda/__init__.py @@ -216,6 +216,7 @@ def preferred_linalg_library( "cublas": torch._C._BlasBackend.Cublas, "cublaslt": torch._C._BlasBackend.Cublaslt, "hipblaslt": torch._C._BlasBackend.Cublaslt, # alias + "ck": torch._C._BlasBackend.Ck, } _BlasBackends_str = ", ".join(_BlasBackends.keys()) @@ -224,16 +225,17 @@ def preferred_blas_library( backend: Union[None, str, torch._C._BlasBackend] = None ) -> torch._C._BlasBackend: r""" - Override the library PyTorch uses for BLAS operations. Choose between cuBLAS and cuBLASLt. + Override the library PyTorch uses for BLAS operations. Choose between cuBLAS, cuBLASLt, and CK [ROCm-only]. .. warning:: This flag is experimental and subject to change. When PyTorch runs a CUDA BLAS operation it defaults to cuBLAS even if both cuBLAS and cuBLASLt are available. - For PyTorch built for ROCm, hipBLAS and hipBLASLt may offer different performance. + For PyTorch built for ROCm, hipBLAS, hipBLASLt, and CK may offer different performance. This flag (a :class:`str`) allows overriding which BLAS library to use. * If `"cublas"` is set then cuBLAS will be used wherever possible. * If `"cublaslt"` is set then cuBLASLt will be used wherever possible. + * If `"ck"` is set then CK will be used wherever possible. * When no input is given, this function returns the currently preferred library. * User may use the environment variable TORCH_BLAS_PREFER_CUBLASLT=1 to set the preferred library to cuBLASLt globally. diff --git a/torch/backends/opt_einsum/__init__.py b/torch/backends/opt_einsum/__init__.py index ac63fa4bcf440..73c107cc1e448 100644 --- a/torch/backends/opt_einsum/__init__.py +++ b/torch/backends/opt_einsum/__init__.py @@ -16,7 +16,14 @@ @_lru_cache def is_available() -> bool: - r"""Return a bool indicating if opt_einsum is currently available.""" + r"""Return a bool indicating if opt_einsum is currently available. + + You must install opt-einsum in order for torch to automatically optimize einsum. To + make opt-einsum available, you can install it along with torch: ``pip install torch[opt-einsum]`` + or by itself: ``pip install opt-einsum``. If the package is installed, torch will import + it automatically and use it accordingly. Use this function to check whether opt-einsum + was installed and properly imported by torch. + """ return _opt_einsum is not None diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 7da8e911b83b2..60ed04aac946f 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -12,6 +12,7 @@ "substitute_in_graph", "list_backends", "disable", + "set_stance", "cudagraph_mark_step_begin", "wrap_numpy", "is_compiling", @@ -48,7 +49,8 @@ def allow_in_graph(fn): If you are using :func:`torch.compile` (with backend="inductor" (the default)), or :func:`torch.export.export`, and trying to black-box a Python function throughout all tracing, do not use this API. - Instead, please create a custom operator (see :ref:`custom-ops-landing-page`) + Instead, please create a custom operator (see `PyTorch Custom Operators Landing Page + `_) .. warning:: @@ -104,7 +106,7 @@ def allow_in_graph(fn): torch.compiler.allow_in_graph(my_custom_function) @torch.compile(...) - def fn(a): + def fn(x): x = torch.add(x, 1) x = my_custom_function(x) x = torch.add(x, 1) @@ -216,7 +218,7 @@ def assume_constant_result(fn): def disable(fn=None, recursive=True): """ - This function provides both a decorator and a context manager to disable compilation on a function + This function provides a decorator to disable compilation on a function It also provides the option of recursively disabling called functions Args: @@ -228,6 +230,58 @@ def disable(fn=None, recursive=True): return torch._dynamo.disable(fn, recursive) +def set_stance(stance: str, force_backend=None): + """ + Set the current stance of the compiler. + Can be used as a function, context manager, or decorator. + Do not use this function inside a `torch.compile` region - an error will be raised otherwise. + + .. code-block:: python + + @torch.compile + def foo(x): + ... + + @torch.compiler.set_stance("force_eager") + def bar(): + # will not be compiled + foo(...) + + bar() + + with torch.compiler.set_stance("force_eager"): + # will also not be compiled + foo(...) + + torch.compiler.set_stance("force_eager") + # will also not be compiled + foo(...) + torch.compiler.set_stance("default") + + # will be compiled + foo(...) + + Args: + stance: The stance to set the compiler to. Valid values are: + + - "default": The default stance, used for normal compilation. + - "force_eager": Ignore all `torch.compile` directives. + - "eager_on_recompile": Run code eagerly when a recompile is necessary. + If there is cached compiled code valid for the input, it will still be used. + - "fail_on_recompile": Raise an error when recompiling a function. + + force_backend: If `stance` is "default", this argument can be used to force `torch.compile` + to use a specific backend. Otherwise, an error is raised. + """ + import torch._dynamo + + return torch._dynamo.set_stance(stance, force_backend=force_backend) + + +# forbid in graph +set_stance._dynamo_forbidden = True # type: ignore[attr-defined] + + def cudagraph_mark_step_begin(): """ Indicates that a new iteration of inference or training is about to begin. diff --git a/torch/contrib/_tensorboard_vis.py b/torch/contrib/_tensorboard_vis.py index ed1445dd7bce6..2a1f88c36996f 100644 --- a/torch/contrib/_tensorboard_vis.py +++ b/torch/contrib/_tensorboard_vis.py @@ -37,7 +37,7 @@ def visualize(graph, name_prefix='', pb_graph=None, executors_it=None): return pb_graph # Set up an input node - input_node = pb_graph.node.add(op='input', name=name_prefix + 'input') + pb_graph.node.add(op='input', name=name_prefix + 'input') for i, value in enumerate(graph.param_node().outputs()): value_map[value.unique()] = name_prefix + 'input:' + str(i) diff --git a/torch/cpu/__init__.py b/torch/cpu/__init__.py index 8443e0447aa25..f62ddda893b3b 100644 --- a/torch/cpu/__init__.py +++ b/torch/cpu/__init__.py @@ -60,6 +60,11 @@ def _init_amx() -> bool: return torch._C._cpu._init_amx() +def _is_arm_sve_supported() -> bool: + r"""Returns a bool indicating if CPU supports Arm SVE.""" + return torch._C._cpu._is_arm_sve_supported() + + def is_available() -> bool: r"""Returns a bool indicating if CPU is currently available. diff --git a/torch/csrc/Exceptions.h b/torch/csrc/Exceptions.h index 36e67dcde39cf..8c1f37a10c264 100644 --- a/torch/csrc/Exceptions.h +++ b/torch/csrc/Exceptions.h @@ -353,7 +353,6 @@ using Arg = typename invoke_traits::template arg::type; template auto wrap_pybind_function_impl_( - // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) Func&& f, std::index_sequence, std::bool_constant) { @@ -363,7 +362,7 @@ auto wrap_pybind_function_impl_( return [f = std::forward(f)](Arg... args) { HANDLE_TH_ERRORS conditional_gil_scoped_release no_gil; - return c10::guts::invoke(f, std::forward>(args)...); + return std::invoke(f, std::forward>(args)...); END_HANDLE_TH_ERRORS_PYBIND }; } diff --git a/torch/csrc/Generator.cpp b/torch/csrc/Generator.cpp index 666a05b6af09d..1da0a3229db0a 100644 --- a/torch/csrc/Generator.cpp +++ b/torch/csrc/Generator.cpp @@ -79,7 +79,8 @@ static PyObject* THPGenerator_pynew( } else if (device.type() == at::kPrivateUse1) { self->cdata = at::GetGeneratorForPrivateuse1(device.index()); } else { - AT_ERROR( + TORCH_CHECK( + false, "Device type ", c10::DeviceTypeName(device.type()), " is not supported for torch.Generator() api."); diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index 6a3f9ddc48310..416e5b5d72b4c 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -2079,7 +2079,8 @@ Call this whenever a new thread is created in order to propagate values from py::enum_(py_module, "_BlasBackend") .value("Cublas", at::BlasBackend::Cublas) - .value("Cublaslt", at::BlasBackend::Cublaslt); + .value("Cublaslt", at::BlasBackend::Cublaslt) + .value("Ck", at::BlasBackend::Ck); py_module.def("_set_blas_preferred_backend", [](at::BlasBackend b) { at::globalContext().setBlasPreferredBackend(b); @@ -2366,6 +2367,23 @@ Call this whenever a new thread is created in order to propagate values from "DisableTorchFunction", (PyObject*)THPModule_DisableTorchFunctionType(), /* incref= */ false)); + py::enum_( + py_module, "_TorchFunctionState") + .value("ENABLED", at::impl::TorchFunctionDisabledState::ENABLED) + .value( + "SUBCLASSES_DISABLED", + at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED) + .value( + "ALL_DISABLED", at::impl::TorchFunctionDisabledState::ALL_DISABLED); + + py_module.def( + "_set_torch_function_state", + [](at::impl::TorchFunctionDisabledState state) { + at::impl::PythonTorchFunctionTLS::set_disabled_state(state); + }); + py_module.def("_get_torch_function_state", []() { + return at::impl::PythonTorchFunctionTLS::get_disabled_state(); + }); torch::set_disabled_torch_function_impl( PyObject_GetAttrString(module, "_disabled_torch_function_impl")); ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr); diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp index effe761d49a8e..688d1d6ef5752 100644 --- a/torch/csrc/Storage.cpp +++ b/torch/csrc/Storage.cpp @@ -482,8 +482,7 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) { return THPByteUtils_newReal(value); /* Slice index */ } else if (PySlice_Check(index)) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Py_ssize_t start, stop, slicelength, step; + Py_ssize_t start = 0, stop = 0, slicelength = 0, step = 0; if (PySlice_Unpack(index, &start, &stop, &step) < 0) { return nullptr; } @@ -554,8 +553,7 @@ static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) { storage_set(storage, nindex, rvalue); return 0; } else if (PySlice_Check(index)) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - Py_ssize_t start, stop, step; + Py_ssize_t start = 0, stop = 0, step = 0; Py_ssize_t len = static_cast(storage.nbytes()); if (PySlice_Unpack(index, &start, &stop, &step) < 0) { return -1; diff --git a/torch/csrc/StorageSharing.cpp b/torch/csrc/StorageSharing.cpp index bba836dc916bc..9f1f71ae7fe99 100644 --- a/torch/csrc/StorageSharing.cpp +++ b/torch/csrc/StorageSharing.cpp @@ -294,7 +294,8 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) { c10::StorageImpl* storage_impl = storage.unsafeGetStorageImpl(); if (storage_impl->received_cuda()) { - AT_ERROR( + TORCH_CHECK( + false, "Attempted to send CUDA tensor received from another process; this is not currently supported. Consider cloning before sending."); } @@ -313,7 +314,6 @@ static PyObject* THPStorage_shareCuda(PyObject* self, PyObject* noargs) { THPObjectPtr _event_sync_required(Py_None); Py_INCREF(Py_None); if (storage.data()) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) auto shandle = c10::cuda::CUDACachingAllocator::shareIpcHandle(storage.mutable_data()); _handle = PyBytes_FromStringAndSize( @@ -470,8 +470,7 @@ static PyObject* THPStorage_newSharedCuda(PyObject* _unused, PyObject* args) { } auto ipc_event_handle = reinterpret_cast( s_ipc_event_handle.c_str()); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - cudaEvent_t event; + cudaEvent_t event = nullptr; cudaIpcOpenEventHandle(&event, *ipc_event_handle); C10_CUDA_CHECK( cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(device), event, 0)); diff --git a/torch/csrc/api/include/torch/all.h b/torch/csrc/api/include/torch/all.h index 56ed75c833117..026f4f9f579e9 100644 --- a/torch/csrc/api/include/torch/all.h +++ b/torch/csrc/api/include/torch/all.h @@ -10,7 +10,6 @@ #include #include #include -#include #include #include #include diff --git a/torch/csrc/api/include/torch/cuda.h b/torch/csrc/api/include/torch/cuda.h index 537ddf02479c2..31ad826214d2c 100644 --- a/torch/csrc/api/include/torch/cuda.h +++ b/torch/csrc/api/include/torch/cuda.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace cuda { +namespace torch::cuda { /// Returns the number of CUDA devices available. size_t TORCH_API device_count(); @@ -26,5 +25,4 @@ void TORCH_API manual_seed_all(uint64_t seed); /// Waits for all kernels in all streams on a CUDA device to complete. void TORCH_API synchronize(int64_t device_index = -1); -} // namespace cuda -} // namespace torch +} // namespace torch::cuda diff --git a/torch/csrc/api/include/torch/data.h b/torch/csrc/api/include/torch/data.h index ac718acd4fa31..78aae1d25c27c 100644 --- a/torch/csrc/api/include/torch/data.h +++ b/torch/csrc/api/include/torch/data.h @@ -6,9 +6,8 @@ #include // Some "exports". -namespace torch { -namespace data { -using datasets::BatchDataset; -using datasets::Dataset; -} // namespace data -} // namespace torch + +namespace torch::data { +using datasets::BatchDataset; // NOLINT +using datasets::Dataset; // NOLINT +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/dataloader.h b/torch/csrc/api/include/torch/data/dataloader.h index 158813043af61..c60abc79c847e 100644 --- a/torch/csrc/api/include/torch/data/dataloader.h +++ b/torch/csrc/api/include/torch/data/dataloader.h @@ -12,8 +12,7 @@ #include #include -namespace torch { -namespace data { +namespace torch::data { /// Creates a `DataLoader` instance for a stateless `dataset`, a `sampler` and /// some `options`. @@ -23,7 +22,7 @@ std::enable_if_t< std::unique_ptr>> make_data_loader(Dataset dataset, Sampler sampler, DataLoaderOptions options) { return std::make_unique>( - std::move(dataset), std::move(sampler), std::move(options)); + std::move(dataset), std::move(sampler), options); } /// Creates a `DataLoader` instance for a stateless `dataset` and some @@ -41,8 +40,7 @@ make_data_loader( size.has_value(), "Expected the dataset to be sized in " "order to construct the Sampler"); - return make_data_loader( - std::move(dataset), Sampler(*size), std::move(options)); + return make_data_loader(std::move(dataset), Sampler(*size), options); } /// Creates a `DataLoader` for a stateful `dataset` and some `options`. @@ -51,7 +49,6 @@ std::unique_ptr> make_data_loader( Dataset dataset, DataLoaderOptions options = DataLoaderOptions()) { return std::make_unique>( - std::move(dataset), std::move(options)); + std::move(dataset), options); } -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/dataloader/base.h b/torch/csrc/api/include/torch/data/dataloader/base.h index cb17843ba0b33..aff15f34fec0c 100644 --- a/torch/csrc/api/include/torch/data/dataloader/base.h +++ b/torch/csrc/api/include/torch/data/dataloader/base.h @@ -17,12 +17,10 @@ #include #include #include -#include #include #include -namespace torch { -namespace data { +namespace torch::data { template class DataLoaderBase { public: @@ -35,7 +33,7 @@ class DataLoaderBase { DataLoaderBase( DataLoaderOptions options, std::unique_ptr main_thread_dataset = nullptr) - : options_(std::move(options)), + : options_(options), main_thread_dataset_(std::move(main_thread_dataset)), sequencer_(new_sequencer()) {} @@ -82,8 +80,7 @@ class DataLoaderBase { // Send one 'quit' message per worker. Since a worker dies (exits its // thread) after receiving this message, each `QuitWorker()` message will be // read by exactly one worker. - for (const auto w : c10::irange(options_.workers)) { - (void)w; // Suppress unused variable warning + for ([[maybe_unused]] const auto w : c10::irange(options_.workers)) { push_job(QuitWorker()); } for (auto& worker : workers_) { @@ -146,8 +143,7 @@ class DataLoaderBase { /// Schedules `requested_jobs` many new batches to be fetched. The actual /// number of jobs scheduled may be less if the DataLoader exhausts. void prefetch(size_t requested_jobs) { - for (const auto r : c10::irange(requested_jobs)) { - (void)r; // Suppress unused variable + for ([[maybe_unused]] const auto r : c10::irange(requested_jobs)) { if (auto batch_request = get_batch_request()) { this->push_job(std::move(*batch_request)); } else { @@ -220,7 +216,7 @@ class DataLoaderBase { } /// The options the DataLoader was configured with. - // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const FullDataLoaderOptions options_; /// The dataset for the main thread, only has a value if the number of @@ -251,5 +247,4 @@ class DataLoaderBase { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool joined_ = false; }; -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/dataloader/stateful.h b/torch/csrc/api/include/torch/data/dataloader/stateful.h index 6ae027119a0c9..964a1ffcc7f6c 100644 --- a/torch/csrc/api/include/torch/data/dataloader/stateful.h +++ b/torch/csrc/api/include/torch/data/dataloader/stateful.h @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace data { +namespace torch::data { /// A dataloader for stateful datasets. /// @@ -59,5 +58,4 @@ class StatefulDataLoader : public DataLoaderBase< return this->options_.batch_size; } }; -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/dataloader/stateless.h b/torch/csrc/api/include/torch/data/dataloader/stateless.h index 422b1097ee71b..d8f94d471ce03 100644 --- a/torch/csrc/api/include/torch/data/dataloader/stateless.h +++ b/torch/csrc/api/include/torch/data/dataloader/stateless.h @@ -10,8 +10,7 @@ #include #include -namespace torch { -namespace data { +namespace torch::data { /// A dataloader for stateless datasets. /// @@ -38,7 +37,7 @@ class StatelessDataLoader : public DataLoaderBase< Dataset dataset, Sampler sampler, DataLoaderOptions options) - : super(std::move(options)), sampler_(std::move(sampler)) { + : super(options), sampler_(std::move(sampler)) { for (const auto w : c10::irange(this->options_.workers)) { // Here we copy the dataset into the worker thread closure. Each worker // has its own copy of the dataset. This means the dataset must be @@ -78,5 +77,4 @@ class StatelessDataLoader : public DataLoaderBase< /// The `Sampler` used to produce batch requests. Sampler sampler_; }; -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/dataloader_options.h b/torch/csrc/api/include/torch/data/dataloader_options.h index a0c96aee07713..34dd3a00dc47a 100644 --- a/torch/csrc/api/include/torch/data/dataloader_options.h +++ b/torch/csrc/api/include/torch/data/dataloader_options.h @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace data { +namespace torch::data { /// Options to configure a `DataLoader`. struct DataLoaderOptions { @@ -61,5 +60,4 @@ struct FullDataLoaderOptions { bool enforce_ordering; bool drop_last; }; -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/datasets/base.h b/torch/csrc/api/include/torch/data/datasets/base.h index f17b3fe8af475..e5232ab0d7a3c 100644 --- a/torch/csrc/api/include/torch/data/datasets/base.h +++ b/torch/csrc/api/include/torch/data/datasets/base.h @@ -11,20 +11,14 @@ #include #include -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { template class MapDataset; template MapDataset map(D, T); // NOLINT -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { namespace detail { template struct is_optional : std::false_type {}; @@ -99,6 +93,4 @@ class Dataset : public BatchDataset> { /// yields that many elements from the stream. template >> using StreamDataset = BatchDataset; -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/include/torch/data/datasets/chunk.h b/torch/csrc/api/include/torch/data/datasets/chunk.h index 01d940aa3e488..c51822536d9e6 100644 --- a/torch/csrc/api/include/torch/data/datasets/chunk.h +++ b/torch/csrc/api/include/torch/data/datasets/chunk.h @@ -6,12 +6,11 @@ #include #include #include +#include #include -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { /// Interface for chunk reader, which performs data chunking and reading of /// entire chunks. @@ -138,7 +137,6 @@ class BatchDataBuffer { // If we still have data remaining after filling the last pushed batch, add // them to the queue too. - // NOLINTNEXTLINE(bugprone-infinite-loop) while (remaining_size > 0) { UnwrappedBatchType current_batch; @@ -213,8 +211,8 @@ class BatchDataBuffer { explicit UnwrappedBatchData(UnwrappedBatchType data) : batch_data(std::move(data)) {} - // NOLINTNEXTLINE(modernize-pass-by-value) - explicit UnwrappedBatchData(std::exception_ptr e) : exception(e) {} + explicit UnwrappedBatchData(std::exception_ptr e) + : exception(std::move(e)) {} /// batch data to return UnwrappedBatchType batch_data; @@ -233,6 +231,7 @@ class BatchDataBuffer { std::condition_variable cv_read_; std::condition_variable cv_write_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) ExampleSampler& example_sampler_; // configurable maximun number of elements the queue can hold at one time. @@ -333,11 +332,10 @@ class ChunkDataset final : chunk_reader_(std::move(chunk_reader)), chunk_sampler_(std::move(chunk_sampler)), example_sampler_(std::move(example_sampler)), - options_(std::move(options)), + options_(options), preprocessing_policy_(std::move(preprocessing_policy)), quit_worker_(false), - running_preloaders_(0), - load_checkpoint_(false) {} + running_preloaders_(0) {} ~ChunkDataset() override { // stop batch buffer first. @@ -496,6 +494,7 @@ class ChunkDataset final std::vector preload_threads_; /// The options the Dataset was configured with. + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const ChunkDatasetOptions options_; // function pointer wrapper to apply custom processing over chunk data. This @@ -522,8 +521,6 @@ class ChunkDataset final // boolean value to indicate whether we need to load the checkpoint for // chunk_sampler_. - bool load_checkpoint_; + bool load_checkpoint_{false}; }; -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/include/torch/data/datasets/map.h b/torch/csrc/api/include/torch/data/datasets/map.h index ebd4374cca8f3..b23e881391ab6 100644 --- a/torch/csrc/api/include/torch/data/datasets/map.h +++ b/torch/csrc/api/include/torch/data/datasets/map.h @@ -9,12 +9,10 @@ #include #include -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { namespace detail { template -using optional_if_t = typename std::conditional, T>::type; +using optional_if_t = std::conditional_t, T>; } // namespace detail /// A `MapDataset` is a dataset that applies a transform to a source dataset. @@ -103,16 +101,14 @@ MapDataset map( DatasetType dataset, TransformType transform) { static_assert( - std::is_same< - typename std::conditional< + std::is_same_v< + std::conditional_t< DatasetType::is_stateful, typename DatasetType::BatchType::value_type, - typename DatasetType::BatchType>::type, - typename TransformType::InputBatchType>::value, + typename DatasetType::BatchType>, + typename TransformType::InputBatchType>, "BatchType type of dataset does not match input type of transform"); return {std::move(dataset), std::move(transform)}; } -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/include/torch/data/datasets/mnist.h b/torch/csrc/api/include/torch/data/datasets/mnist.h index 5d9e352f36d07..c19a862ba99f7 100644 --- a/torch/csrc/api/include/torch/data/datasets/mnist.h +++ b/torch/csrc/api/include/torch/data/datasets/mnist.h @@ -9,9 +9,7 @@ #include #include -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { /// The MNIST dataset. class TORCH_API MNIST : public Dataset { public: @@ -43,6 +41,4 @@ class TORCH_API MNIST : public Dataset { private: Tensor images_, targets_; }; -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/include/torch/data/datasets/shared.h b/torch/csrc/api/include/torch/data/datasets/shared.h index aff84b586c89c..725cfb5ffdf4a 100644 --- a/torch/csrc/api/include/torch/data/datasets/shared.h +++ b/torch/csrc/api/include/torch/data/datasets/shared.h @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { /// A dataset that wraps another dataset in a shared pointer and implements the /// `BatchDataset` API, delegating all calls to the shared instance. This is @@ -78,6 +76,4 @@ template SharedBatchDataset make_shared_dataset(Args&&... args) { return std::make_shared(std::forward(args)...); } -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/include/torch/data/datasets/stateful.h b/torch/csrc/api/include/torch/data/datasets/stateful.h index fb2379c673340..adc210fcf3d5e 100644 --- a/torch/csrc/api/include/torch/data/datasets/stateful.h +++ b/torch/csrc/api/include/torch/data/datasets/stateful.h @@ -6,16 +6,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { /// A stateful dataset is a dataset that maintains some internal state, which /// will be `reset()` at the beginning of each epoch. Subclasses can override @@ -65,6 +61,4 @@ serialize::InputArchive& operator>>( return archive; } -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/include/torch/data/datasets/tensor.h b/torch/csrc/api/include/torch/data/datasets/tensor.h index 4968e263009f3..1c9fd2130fe64 100644 --- a/torch/csrc/api/include/torch/data/datasets/tensor.h +++ b/torch/csrc/api/include/torch/data/datasets/tensor.h @@ -7,9 +7,7 @@ #include #include -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { /// A dataset of tensors. /// Stores a single tensor internally, which is then indexed inside `get()`. @@ -22,7 +20,7 @@ struct TensorDataset : public Dataset { /// Returns a single `TensorExample`. TensorExample get(size_t index) override { - return tensor[index]; + return tensor[static_cast(index)]; } /// Returns the number of tensors in the dataset. @@ -33,6 +31,4 @@ struct TensorDataset : public Dataset { Tensor tensor; }; -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/include/torch/data/detail/data_shuttle.h b/torch/csrc/api/include/torch/data/detail/data_shuttle.h index 9c3ef12116012..b1e7ac8768688 100644 --- a/torch/csrc/api/include/torch/data/detail/data_shuttle.h +++ b/torch/csrc/api/include/torch/data/detail/data_shuttle.h @@ -9,9 +9,7 @@ #include #include -namespace torch { -namespace data { -namespace detail { +namespace torch::data::detail { /// Encapsulates the full life cycle of DataLoader jobs. /// @@ -82,6 +80,4 @@ class DataShuttle { Queue results_; }; -} // namespace detail -} // namespace data -} // namespace torch +} // namespace torch::data::detail diff --git a/torch/csrc/api/include/torch/data/detail/queue.h b/torch/csrc/api/include/torch/data/detail/queue.h index 60236ab3f520c..71752d1af3f78 100644 --- a/torch/csrc/api/include/torch/data/detail/queue.h +++ b/torch/csrc/api/include/torch/data/detail/queue.h @@ -10,9 +10,7 @@ #include #include -namespace torch { -namespace data { -namespace detail { +namespace torch::data::detail { /// A basic locked, blocking MPMC queue. /// @@ -46,7 +44,7 @@ class Queue { if (!cv_.wait_for( lock, *timeout, [this] { return !this->queue_.empty(); })) { // clang-format off - AT_ERROR( + TORCH_CHECK(false, "Timeout in DataLoader queue while waiting for next batch" " (timeout was ", timeout->count(), " ms)"); // clang-format on @@ -79,6 +77,4 @@ class Queue { std::mutex mutex_; std::condition_variable cv_; }; -} // namespace detail -} // namespace data -} // namespace torch +} // namespace torch::data::detail diff --git a/torch/csrc/api/include/torch/data/detail/sequencers.h b/torch/csrc/api/include/torch/data/detail/sequencers.h index c59f4cd7e290d..779e21f3a4b68 100644 --- a/torch/csrc/api/include/torch/data/detail/sequencers.h +++ b/torch/csrc/api/include/torch/data/detail/sequencers.h @@ -6,10 +6,7 @@ #include #include -namespace torch { -namespace data { -namespace detail { -namespace sequencers { +namespace torch::data::detail::sequencers { namespace detail { template bool buffer_contains_result(const std::vector>& buffer) { @@ -107,7 +104,4 @@ struct OrderedSequencer : public Sequencer { /// A fixed-size buffer (after construction). std::vector> buffer_; }; -} // namespace sequencers -} // namespace detail -} // namespace data -} // namespace torch +} // namespace torch::data::detail::sequencers diff --git a/torch/csrc/api/include/torch/data/example.h b/torch/csrc/api/include/torch/data/example.h index 57219a24cd0b0..af4b08371a82b 100644 --- a/torch/csrc/api/include/torch/data/example.h +++ b/torch/csrc/api/include/torch/data/example.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace data { +namespace torch::data { /// An `Example` from a dataset. /// @@ -51,5 +50,4 @@ struct Example { }; using TensorExample = Example; -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/iterator.h b/torch/csrc/api/include/torch/data/iterator.h index 94293c452d53c..a0ee28a73e018 100644 --- a/torch/csrc/api/include/torch/data/iterator.h +++ b/torch/csrc/api/include/torch/data/iterator.h @@ -11,8 +11,7 @@ #include #include -namespace torch { -namespace data { +namespace torch::data { namespace detail { // For increased safety and more separated logic, this implementation of // `Iterator` consists of a `ValidIterator` and a `SentinelIterator`. A @@ -101,12 +100,14 @@ struct ValidIterator : public IteratorImpl { template struct SentinelIterator : public IteratorImpl { void next() override { - AT_ERROR( + TORCH_CHECK( + false, "Incrementing the DataLoader's past-the-end iterator is not allowed"); } Batch& get() override { - AT_ERROR( + TORCH_CHECK( + false, "Dereferencing the DataLoader's past-the-end iterator is not allowed"); } @@ -174,5 +175,4 @@ class Iterator { /// Points either to a `ValidIterator` or to a `SentinelIterator`. std::shared_ptr> impl_; }; -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/data/samplers/base.h b/torch/csrc/api/include/torch/data/samplers/base.h index 8ab48d9d5931f..67c1ad5ea7cbe 100644 --- a/torch/csrc/api/include/torch/data/samplers/base.h +++ b/torch/csrc/api/include/torch/data/samplers/base.h @@ -7,16 +7,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { /// A `Sampler` is an object that yields an index with which to access a /// dataset. template > @@ -42,6 +38,4 @@ class Sampler { virtual void load(serialize::InputArchive& archive) = 0; }; -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h b/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h index a5247b008d750..7132856fe2359 100644 --- a/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h +++ b/torch/csrc/api/include/torch/data/samplers/custom_batch_request.h @@ -3,9 +3,7 @@ #include #include -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { /// A base class for custom index types. struct TORCH_API CustomBatchRequest { CustomBatchRequest() = default; @@ -16,6 +14,4 @@ struct TORCH_API CustomBatchRequest { /// The number of elements accessed by this index. virtual size_t size() const = 0; }; -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/include/torch/data/samplers/distributed.h b/torch/csrc/api/include/torch/data/samplers/distributed.h index bce36aaa4df71..64be81645dcc6 100644 --- a/torch/csrc/api/include/torch/data/samplers/distributed.h +++ b/torch/csrc/api/include/torch/data/samplers/distributed.h @@ -6,16 +6,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { /// A `Sampler` that selects a subset of indices to sample from and defines a /// sampling behavior. In a distributed setting, this selects a subset of the @@ -33,7 +29,7 @@ class DistributedSampler : public Sampler { : size_(size), num_replicas_(num_replicas), rank_(rank), - epoch_(0), + allow_duplicates_(allow_duplicates) {} /// Set the epoch for the current enumeration. This can be used to alter the @@ -62,7 +58,7 @@ class DistributedSampler : public Sampler { // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) size_t rank_; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) - size_t epoch_; + size_t epoch_{0}; // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) bool allow_duplicates_; }; @@ -134,6 +130,4 @@ class TORCH_API DistributedSequentialSampler : public DistributedSampler<> { std::vector all_indices_; }; -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/include/torch/data/samplers/random.h b/torch/csrc/api/include/torch/data/samplers/random.h index 4b023b6c703af..fc81aae7c3b52 100644 --- a/torch/csrc/api/include/torch/data/samplers/random.h +++ b/torch/csrc/api/include/torch/data/samplers/random.h @@ -7,16 +7,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { /// A `Sampler` that returns random indices. class TORCH_API RandomSampler : public Sampler<> { @@ -49,6 +45,4 @@ class TORCH_API RandomSampler : public Sampler<> { at::Tensor indices_; int64_t index_ = 0; }; -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/include/torch/data/samplers/sequential.h b/torch/csrc/api/include/torch/data/samplers/sequential.h index 252ecc3ad3d75..2b57f90d116f5 100644 --- a/torch/csrc/api/include/torch/data/samplers/sequential.h +++ b/torch/csrc/api/include/torch/data/samplers/sequential.h @@ -7,16 +7,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { /// A `Sampler` that returns indices sequentially. class TORCH_API SequentialSampler : public Sampler<> { @@ -45,6 +41,4 @@ class TORCH_API SequentialSampler : public Sampler<> { size_t index_{0}; }; -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/include/torch/data/samplers/serialize.h b/torch/csrc/api/include/torch/data/samplers/serialize.h index 7585217a9cf26..8c87a9b3d00e2 100644 --- a/torch/csrc/api/include/torch/data/samplers/serialize.h +++ b/torch/csrc/api/include/torch/data/samplers/serialize.h @@ -3,9 +3,7 @@ #include #include -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { /// Serializes a `Sampler` into an `OutputArchive`. template serialize::OutputArchive& operator<<( @@ -23,6 +21,4 @@ serialize::InputArchive& operator>>( sampler.load(archive); return archive; } -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/include/torch/data/samplers/stream.h b/torch/csrc/api/include/torch/data/samplers/stream.h index 201c914e49e5c..c5eb8214cdf64 100644 --- a/torch/csrc/api/include/torch/data/samplers/stream.h +++ b/torch/csrc/api/include/torch/data/samplers/stream.h @@ -7,16 +7,12 @@ #include -namespace torch { -namespace serialize { +namespace torch::serialize { class InputArchive; class OutputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { /// A wrapper around a batch size value, which implements the /// `CustomBatchRequest` interface. @@ -58,6 +54,4 @@ class TORCH_API StreamSampler : public Sampler { size_t epoch_size_; }; -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/include/torch/data/transforms/base.h b/torch/csrc/api/include/torch/data/transforms/base.h index 0bc1f2ea7b141..b2ee9ed81f6b5 100644 --- a/torch/csrc/api/include/torch/data/transforms/base.h +++ b/torch/csrc/api/include/torch/data/transforms/base.h @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace data { -namespace transforms { +namespace torch::data::transforms { /// A transformation of a batch to a new batch. template @@ -48,6 +46,4 @@ class Transform return output_batch; } }; -} // namespace transforms -} // namespace data -} // namespace torch +} // namespace torch::data::transforms diff --git a/torch/csrc/api/include/torch/data/transforms/collate.h b/torch/csrc/api/include/torch/data/transforms/collate.h index 181bcae0031b6..8905fc7f7c936 100644 --- a/torch/csrc/api/include/torch/data/transforms/collate.h +++ b/torch/csrc/api/include/torch/data/transforms/collate.h @@ -5,9 +5,7 @@ #include -namespace torch { -namespace data { -namespace transforms { +namespace torch::data::transforms { /// A `Collation` is a transform that reduces a batch into a single value. /// The result is a `BatchDataset` that has the type of the single value as its @@ -30,6 +28,4 @@ using Collation = BatchTransform; /// \endrst template > using Collate = BatchLambda; -} // namespace transforms -} // namespace data -} // namespace torch +} // namespace torch::data::transforms diff --git a/torch/csrc/api/include/torch/data/transforms/lambda.h b/torch/csrc/api/include/torch/data/transforms/lambda.h index 252b29807a8ef..c9cfa15431b26 100644 --- a/torch/csrc/api/include/torch/data/transforms/lambda.h +++ b/torch/csrc/api/include/torch/data/transforms/lambda.h @@ -6,9 +6,7 @@ #include #include -namespace torch { -namespace data { -namespace transforms { +namespace torch::data::transforms { /// A `BatchTransform` that applies a user-provided functor to a batch. template @@ -51,6 +49,4 @@ class Lambda : public Transform { FunctionType function_; }; -} // namespace transforms -} // namespace data -} // namespace torch +} // namespace torch::data::transforms diff --git a/torch/csrc/api/include/torch/data/transforms/stack.h b/torch/csrc/api/include/torch/data/transforms/stack.h index 4be1bd920b715..26063db4ea853 100644 --- a/torch/csrc/api/include/torch/data/transforms/stack.h +++ b/torch/csrc/api/include/torch/data/transforms/stack.h @@ -7,9 +7,7 @@ #include #include -namespace torch { -namespace data { -namespace transforms { +namespace torch::data::transforms { template > struct Stack; @@ -44,6 +42,4 @@ struct Stack return torch::stack(data); } }; -} // namespace transforms -} // namespace data -} // namespace torch +} // namespace torch::data::transforms diff --git a/torch/csrc/api/include/torch/data/transforms/tensor.h b/torch/csrc/api/include/torch/data/transforms/tensor.h index 2e135c5281315..7b6280bd96859 100644 --- a/torch/csrc/api/include/torch/data/transforms/tensor.h +++ b/torch/csrc/api/include/torch/data/transforms/tensor.h @@ -7,9 +7,7 @@ #include #include -namespace torch { -namespace data { -namespace transforms { +namespace torch::data::transforms { /// A `Transform` that is specialized for the typical `Example` /// combination. It exposes a single `operator()` interface hook (for @@ -72,6 +70,4 @@ struct Normalize : public TensorTransform { torch::Tensor mean, stddev; }; -} // namespace transforms -} // namespace data -} // namespace torch +} // namespace torch::data::transforms diff --git a/torch/csrc/api/include/torch/data/worker_exception.h b/torch/csrc/api/include/torch/data/worker_exception.h index 40680b8330c45..afaf369e55376 100644 --- a/torch/csrc/api/include/torch/data/worker_exception.h +++ b/torch/csrc/api/include/torch/data/worker_exception.h @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace data { +namespace torch::data { /// An exception thrown when a DataLoader's worker thread throws an exception, /// which is caught. A `WorkerException` stores an `exception_ptr` to the @@ -13,6 +12,7 @@ namespace data { struct WorkerException : public std::exception { /// Constructs a `WorkerException` from an `exception_ptr`. explicit WorkerException(std::exception_ptr original) + // NOLINTNEXTLINE(bugprone-throw-keyword-missing) : original_exception(std::move(original)), message("Caught exception in DataLoader worker thread.") { try { @@ -34,5 +34,4 @@ struct WorkerException : public std::exception { std::string message; }; -} // namespace data -} // namespace torch +} // namespace torch::data diff --git a/torch/csrc/api/include/torch/detail/TensorDataContainer.h b/torch/csrc/api/include/torch/detail/TensorDataContainer.h index 4da7cb1f4460f..d5e8f0f9234b4 100644 --- a/torch/csrc/api/include/torch/detail/TensorDataContainer.h +++ b/torch/csrc/api/include/torch/detail/TensorDataContainer.h @@ -16,9 +16,7 @@ #include -namespace torch { - -namespace detail { +namespace torch::detail { enum class TensorDataContainerType { Scalar, InitList, Tensor }; @@ -110,7 +108,6 @@ struct TensorDataContainer { // NOTE: For tensors with zero-size dimensions (e.g. `torch::tensor({{}, // {}})`), the innermost empty braced-init-list `{}` matches the default // constructor of the innermost `TensorDataContainer`. - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDataContainer() : sizes_({0}), // NOTE: In Python, the dtype of tensors with zero-size dimensions (e.g. @@ -125,12 +122,9 @@ struct TensorDataContainer { scalar_type_(at::k##S), \ type_(TensorDataContainerType::Scalar), \ scalar_(value) {} - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AT_FORALL_COMPLEX_TYPES(TENSOR) #undef TENSOR - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) TensorDataContainer(std::initializer_list init_list) : sizes_(), scalar_type_(init_list.begin()->scalar_type()), @@ -157,7 +151,7 @@ struct TensorDataContainer { elem.scalar_type()); } sizes_.reserve(first_elem.sizes().size() + 1); - sizes_.push_back(init_list.size()); + sizes_.push_back(static_cast(init_list.size())); sizes_.insert( sizes_.end(), first_elem.sizes().begin(), first_elem.sizes().end()); } @@ -174,9 +168,7 @@ struct TensorDataContainer { tensor_ = at::tensor(values, at::dtype(scalar_type_).device(at::kCPU)); \ } \ } - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, TENSOR) - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AT_FORALL_COMPLEX_TYPES(TENSOR) #undef TENSOR @@ -194,9 +186,7 @@ struct TensorDataContainer { #define TENSOR(T, S) \ TensorDataContainer(const std::vector& values) \ : TensorDataContainer(at::ArrayRef(values)) {} - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AT_FORALL_SCALAR_TYPES_AND2(Half, BFloat16, TENSOR) - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) AT_FORALL_COMPLEX_TYPES(TENSOR) #undef TENSOR @@ -328,7 +318,7 @@ struct TensorDataContainer { " in its first dimension, but got Tensor with size ", tensor.sizes()[0], " in its first dimension"); - size_t index = 0; + int64_t index = 0; for (const auto& elem : init_list_) { at::Tensor slice = tensor[index]; elem.fill_tensor(slice); @@ -358,6 +348,4 @@ inline std::ostream& operator<<( return stream; } -} // namespace detail - -} // namespace torch +} // namespace torch::detail diff --git a/torch/csrc/api/include/torch/detail/static.h b/torch/csrc/api/include/torch/detail/static.h index c85fc7fff4b4d..d855f0007498c 100644 --- a/torch/csrc/api/include/torch/detail/static.h +++ b/torch/csrc/api/include/torch/detail/static.h @@ -6,14 +6,11 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { class Module; -} // namespace nn -} // namespace torch +} // namespace torch::nn -namespace torch { -namespace detail { +namespace torch::detail { /// Detects if a type T has a forward() method. template struct has_forward { @@ -43,9 +40,10 @@ struct has_forward { template constexpr bool check_not_lvalue_references() { - return (!std::is_lvalue_reference::value || - std::is_const::type>::value) && - check_not_lvalue_references(); + return ( + !std::is_lvalue_reference_v || + std::is_const_v>)&&check_not_lvalue_references(); } template <> @@ -55,11 +53,8 @@ inline constexpr bool check_not_lvalue_references() { /// A type trait whose `value` member is true if `M` derives from `Module`. template -using is_module = - std::is_base_of::type>; +using is_module = std::is_base_of>; template -using enable_if_module_t = - typename std::enable_if::value, T>::type; -} // namespace detail -} // namespace torch +using enable_if_module_t = std::enable_if_t::value, T>; +} // namespace torch::detail diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h index 02d409a3d64c1..195b776b672d8 100644 --- a/torch/csrc/api/include/torch/enum.h +++ b/torch/csrc/api/include/torch/enum.h @@ -140,8 +140,7 @@ TORCH_ENUM_DECLARE(GRU) TORCH_ENUM_DECLARE(Valid) TORCH_ENUM_DECLARE(Same) -namespace torch { -namespace enumtype { +namespace torch::enumtype { struct _compute_enum_name { TORCH_ENUM_PRETTY_PRINT(Linear) @@ -208,5 +207,4 @@ at::Reduction::Reduction reduction_get_enum(V variant_enum) { } } -} // namespace enumtype -} // namespace torch +} // namespace torch::enumtype diff --git a/torch/csrc/api/include/torch/expanding_array.h b/torch/csrc/api/include/torch/expanding_array.h index 62c12d2e0ac8b..e7c834626dd7f 100644 --- a/torch/csrc/api/include/torch/expanding_array.h +++ b/torch/csrc/api/include/torch/expanding_array.h @@ -27,18 +27,18 @@ class ExpandingArray { /// the length is checked against the `ExpandingArray`'s extent parameter `D` /// at runtime. /*implicit*/ ExpandingArray(std::initializer_list list) - : ExpandingArray(at::ArrayRef(list)) {} + : ExpandingArray(c10::ArrayRef(list)) {} /// Constructs an `ExpandingArray` from an `std::vector`. The extent of /// the length is checked against the `ExpandingArray`'s extent parameter `D` /// at runtime. /*implicit*/ ExpandingArray(std::vector vec) - : ExpandingArray(at::ArrayRef(vec)) {} + : ExpandingArray(c10::ArrayRef(vec)) {} - /// Constructs an `ExpandingArray` from an `at::ArrayRef`. The extent of + /// Constructs an `ExpandingArray` from an `c10::ArrayRef`. The extent of /// the length is checked against the `ExpandingArray`'s extent parameter `D` /// at runtime. - /*implicit*/ ExpandingArray(at::ArrayRef values) { + /*implicit*/ ExpandingArray(c10::ArrayRef values) { // clang-format off TORCH_CHECK( values.size() == D, @@ -78,7 +78,7 @@ class ExpandingArray { } /// Returns an `ArrayRef` to the underlying `std::array`. - operator at::ArrayRef() const { + operator c10::ArrayRef() const { return values_; } @@ -100,7 +100,7 @@ std::ostream& operator<<( if (expanding_array.size() == 1) { return stream << expanding_array->at(0); } - return stream << static_cast>(expanding_array); + return stream << static_cast>(expanding_array); } /// A utility class that accepts either a container of `D`-many @@ -118,18 +118,18 @@ class ExpandingArrayWithOptionalElem /// of the underlying type `T`. The extent of the length is checked against /// the `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. /*implicit*/ ExpandingArrayWithOptionalElem(std::initializer_list list) - : ExpandingArrayWithOptionalElem(at::ArrayRef(list)) {} + : ExpandingArrayWithOptionalElem(c10::ArrayRef(list)) {} /// Constructs an `ExpandingArrayWithOptionalElem` from an `std::vector` of /// the underlying type `T`. The extent of the length is checked against the /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. /*implicit*/ ExpandingArrayWithOptionalElem(std::vector vec) - : ExpandingArrayWithOptionalElem(at::ArrayRef(vec)) {} + : ExpandingArrayWithOptionalElem(c10::ArrayRef(vec)) {} - /// Constructs an `ExpandingArrayWithOptionalElem` from an `at::ArrayRef` of + /// Constructs an `ExpandingArrayWithOptionalElem` from an `c10::ArrayRef` of /// the underlying type `T`. The extent of the length is checked against the /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. - /*implicit*/ ExpandingArrayWithOptionalElem(at::ArrayRef values) + /*implicit*/ ExpandingArrayWithOptionalElem(c10::ArrayRef values) : ExpandingArray>(0) { // clang-format off TORCH_CHECK( @@ -174,7 +174,7 @@ std::ostream& operator<<( str_array.emplace_back( elem.has_value() ? c10::str(elem.value()) : "None"); } - stream << at::ArrayRef(str_array); + stream << c10::ArrayRef(str_array); } return stream; } diff --git a/torch/csrc/api/include/torch/fft.h b/torch/csrc/api/include/torch/fft.h index ef6d9b1bc2362..00db0df9428a6 100644 --- a/torch/csrc/api/include/torch/fft.h +++ b/torch/csrc/api/include/torch/fft.h @@ -1,9 +1,11 @@ #pragma once #include +#include -namespace torch { -namespace fft { +#include + +namespace torch::fft { /// Computes the 1 dimensional fast Fourier transform over a given dimension. /// See https://pytorch.org/docs/main/fft.html#torch.fft.fft. @@ -18,7 +20,7 @@ inline Tensor fft( std::optional n = std::nullopt, int64_t dim = -1, std::optional norm = std::nullopt) { - return torch::fft_fft_symint(self, n, dim, norm); + return torch::fft_fft_symint(self, std::move(n), dim, norm); } /// Computes the 1 dimensional inverse Fourier transform over a given dimension. @@ -34,7 +36,7 @@ inline Tensor ifft( std::optional n = std::nullopt, int64_t dim = -1, std::optional norm = std::nullopt) { - return torch::fft_ifft_symint(self, n, dim, norm); + return torch::fft_ifft_symint(self, std::move(n), dim, norm); } /// Computes the 2-dimensional fast Fourier transform over the given dimensions. @@ -115,7 +117,7 @@ inline Tensor rfft( std::optional n = std::nullopt, int64_t dim = -1, std::optional norm = std::nullopt) { - return torch::fft_rfft_symint(self, n, dim, norm); + return torch::fft_rfft_symint(self, std::move(n), dim, norm); } /// Computes the inverse of torch.fft.rfft @@ -134,7 +136,7 @@ inline Tensor irfft( std::optional n = std::nullopt, int64_t dim = -1, std::optional norm = std::nullopt) { - return torch::fft_irfft_symint(self, n, dim, norm); + return torch::fft_irfft_symint(self, std::move(n), dim, norm); } /// Computes the 2-dimensional FFT of real input. Returns a onesided Hermitian @@ -218,7 +220,7 @@ inline Tensor hfft( std::optional n = std::nullopt, int64_t dim = -1, std::optional norm = std::nullopt) { - return torch::fft_hfft_symint(self, n, dim, norm); + return torch::fft_hfft_symint(self, std::move(n), dim, norm); } /// Computes the inverse FFT of a real-valued Fourier domain signal. @@ -237,7 +239,7 @@ inline Tensor ihfft( std::optional n = std::nullopt, int64_t dim = -1, std::optional norm = std::nullopt) { - return torch::fft_ihfft_symint(self, n, dim, norm); + return torch::fft_ihfft_symint(self, std::move(n), dim, norm); } /// Computes the 2-dimensional FFT of a Hermitian symmetric input signal. @@ -385,5 +387,4 @@ inline Tensor ifftshift( return torch::fft_ifftshift(x, dim); } -} // namespace fft -} // namespace torch +} // namespace torch::fft diff --git a/torch/csrc/api/include/torch/jit.h b/torch/csrc/api/include/torch/jit.h index 703eed0d04248..19651f23ba381 100644 --- a/torch/csrc/api/include/torch/jit.h +++ b/torch/csrc/api/include/torch/jit.h @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace jit { +namespace torch::jit { /// Compiles script code into an executable graph. /// @@ -32,5 +31,4 @@ namespace jit { /// \endrst TORCH_API std::shared_ptr compile(const std::string& source); -} // namespace jit -} // namespace torch +} // namespace torch::jit diff --git a/torch/csrc/api/include/torch/linalg.h b/torch/csrc/api/include/torch/linalg.h deleted file mode 100644 index 60cf06f6fedbf..0000000000000 --- a/torch/csrc/api/include/torch/linalg.h +++ /dev/null @@ -1,1065 +0,0 @@ -#pragma once - -#include - -namespace torch { -namespace linalg { - -#ifndef DOXYGEN_SHOULD_SKIP_THIS -namespace detail { - -inline Tensor cholesky(const Tensor& self) { - return torch::linalg_cholesky(self); -} - -inline Tensor cholesky_out(Tensor& result, const Tensor& self) { - return torch::linalg_cholesky_out(result, self); -} - -inline Tensor det(const Tensor& self) { - return torch::linalg_det(self); -} - -inline std::tuple slogdet(const Tensor& input) { - return torch::linalg_slogdet(input); -} - -inline std::tuple slogdet_out( - Tensor& sign, - Tensor& logabsdet, - const Tensor& input) { - return torch::linalg_slogdet_out(sign, logabsdet, input); -} - -inline std::tuple eig(const Tensor& self) { - return torch::linalg_eig(self); -} - -inline std::tuple eig_out( - Tensor& eigvals, - Tensor& eigvecs, - const Tensor& self) { - return torch::linalg_eig_out(eigvals, eigvecs, self); -} - -inline Tensor eigvals(const Tensor& self) { - return torch::linalg_eigvals(self); -} - -inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { - return torch::linalg_eigvals_out(result, self); -} - -inline std::tuple eigh( - const Tensor& self, - c10::string_view uplo) { - return torch::linalg_eigh(self, uplo); -} - -inline std::tuple eigh_out( - Tensor& eigvals, - Tensor& eigvecs, - const Tensor& self, - c10::string_view uplo) { - return torch::linalg_eigh_out(eigvals, eigvecs, self, uplo); -} - -inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { - return torch::linalg_eigvalsh(self, uplo); -} - -inline Tensor& eigvalsh_out( - Tensor& result, - const Tensor& self, - c10::string_view uplo) { - return torch::linalg_eigvalsh_out(result, self, uplo); -} - -inline Tensor householder_product(const Tensor& input, const Tensor& tau) { - return torch::linalg_householder_product(input, tau); -} - -inline Tensor& householder_product_out( - Tensor& result, - const Tensor& input, - const Tensor& tau) { - return torch::linalg_householder_product_out(result, input, tau); -} - -inline std::tuple lu_factor( - const Tensor& self, - const bool pivot) { - return torch::linalg_lu_factor(self, pivot); -} - -inline std::tuple lu_factor_out( - Tensor& LU, - Tensor& pivots, - const Tensor& self, - const bool pivot) { - return torch::linalg_lu_factor_out(LU, pivots, self, pivot); -} - -inline std::tuple lu( - const Tensor& self, - const bool pivot) { - return torch::linalg_lu(self, pivot); -} - -inline std::tuple lu_out( - Tensor& P, - Tensor& L, - Tensor& U, - const Tensor& self, - const bool pivot) { - return torch::linalg_lu_out(P, L, U, self, pivot); -} - -inline std::tuple lstsq( - const Tensor& self, - const Tensor& b, - std::optional cond, - std::optional driver) { - return torch::linalg_lstsq(self, b, cond, driver); -} - -inline Tensor matrix_exp(const Tensor& self) { - return torch::linalg_matrix_exp(self); -} - -inline Tensor norm( - const Tensor& self, - const std::optional& opt_ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return torch::linalg_norm(self, opt_ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor norm( - const Tensor& self, - c10::string_view ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return torch::linalg_norm(self, ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor& norm_out( - Tensor& result, - const Tensor& self, - const std::optional& opt_ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return torch::linalg_norm_out( - result, self, opt_ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor& norm_out( - Tensor& result, - const Tensor& self, - c10::string_view ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return torch::linalg_norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor vector_norm( - const Tensor& self, - Scalar ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return torch::linalg_vector_norm(self, ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor& vector_norm_out( - Tensor& result, - const Tensor& self, - Scalar ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return torch::linalg_vector_norm_out( - result, self, ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor matrix_norm( - const Tensor& self, - const Scalar& ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype) { - return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); -} - -inline Tensor& matrix_norm_out( - const Tensor& self, - const Scalar& ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype, - Tensor& result) { - return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); -} - -inline Tensor matrix_norm( - const Tensor& self, - std::string ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype) { - return torch::linalg_matrix_norm(self, ord, dim, keepdim, dtype); -} - -inline Tensor& matrix_norm_out( - const Tensor& self, - std::string ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype, - Tensor& result) { - return torch::linalg_matrix_norm_out(result, self, ord, dim, keepdim, dtype); -} - -inline Tensor matrix_power(const Tensor& self, int64_t n) { - return torch::linalg_matrix_power(self, n); -} - -inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) { - return torch::linalg_matrix_power_out(result, self, n); -} - -inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) { - return torch::linalg_matrix_rank(input, tol, hermitian); -} - -inline Tensor matrix_rank( - const Tensor& input, - const Tensor& tol, - bool hermitian) { - return torch::linalg_matrix_rank(input, tol, hermitian); -} - -inline Tensor matrix_rank( - const Tensor& input, - std::optional atol, - std::optional rtol, - bool hermitian) { - return torch::linalg_matrix_rank(input, atol, rtol, hermitian); -} - -inline Tensor matrix_rank( - const Tensor& input, - const std::optional& atol, - const std::optional& rtol, - bool hermitian) { - return torch::linalg_matrix_rank(input, atol, rtol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - double tol, - bool hermitian) { - return torch::linalg_matrix_rank_out(result, input, tol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - const Tensor& tol, - bool hermitian) { - return torch::linalg_matrix_rank_out(result, input, tol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - std::optional atol, - std::optional rtol, - bool hermitian) { - return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - const std::optional& atol, - const std::optional& rtol, - bool hermitian) { - return torch::linalg_matrix_rank_out(result, input, atol, rtol, hermitian); -} - -inline Tensor multi_dot(TensorList tensors) { - return torch::linalg_multi_dot(tensors); -} - -inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) { - return torch::linalg_multi_dot_out(result, tensors); -} - -inline Tensor pinv(const Tensor& input, double rcond, bool hermitian) { - return torch::linalg_pinv(input, rcond, hermitian); -} - -inline Tensor& pinv_out( - Tensor& result, - const Tensor& input, - double rcond, - bool hermitian) { - return torch::linalg_pinv_out(result, input, rcond, hermitian); -} - -inline std::tuple qr( - const Tensor& input, - c10::string_view mode) { - return torch::linalg_qr(input, mode); -} - -inline std::tuple qr_out( - Tensor& Q, - Tensor& R, - const Tensor& input, - c10::string_view mode) { - return torch::linalg_qr_out(Q, R, input, mode); -} - -inline std::tuple solve_ex( - const Tensor& input, - const Tensor& other, - bool left, - bool check_errors) { - return torch::linalg_solve_ex(input, other, left, check_errors); -} - -inline std::tuple solve_ex_out( - Tensor& result, - Tensor& info, - const Tensor& input, - const Tensor& other, - bool left, - bool check_errors) { - return torch::linalg_solve_ex_out( - result, info, input, other, left, check_errors); -} - -inline Tensor solve(const Tensor& input, const Tensor& other, bool left) { - return torch::linalg_solve(input, other, left); -} - -inline Tensor& solve_out( - Tensor& result, - const Tensor& input, - const Tensor& other, - bool left) { - return torch::linalg_solve_out(result, input, other, left); -} - -inline Tensor solve_triangular( - const Tensor& input, - const Tensor& other, - bool upper, - bool left, - bool unitriangular) { - return torch::linalg_solve_triangular( - input, other, upper, left, unitriangular); -} - -inline Tensor& solve_triangular_out( - Tensor& result, - const Tensor& input, - const Tensor& other, - bool upper, - bool left, - bool unitriangular) { - return torch::linalg_solve_triangular_out( - result, input, other, upper, left, unitriangular); -} - -inline std::tuple svd( - const Tensor& input, - bool full_matrices, - std::optional driver) { - return torch::linalg_svd(input, full_matrices, driver); -} - -inline std::tuple svd_out( - Tensor& U, - Tensor& S, - Tensor& Vh, - const Tensor& input, - bool full_matrices, - std::optional driver) { - return torch::linalg_svd_out(U, S, Vh, input, full_matrices, driver); -} - -inline Tensor svdvals( - const Tensor& input, - std::optional driver) { - return torch::linalg_svdvals(input, driver); -} - -inline Tensor& svdvals_out( - Tensor& result, - const Tensor& input, - std::optional driver) { - return torch::linalg_svdvals_out(result, input, driver); -} - -inline Tensor tensorinv(const Tensor& self, int64_t ind) { - return torch::linalg_tensorinv(self, ind); -} - -inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { - return torch::linalg_tensorinv_out(result, self, ind); -} - -inline Tensor tensorsolve( - const Tensor& self, - const Tensor& other, - OptionalIntArrayRef dims) { - return torch::linalg_tensorsolve(self, other, dims); -} - -inline Tensor& tensorsolve_out( - Tensor& result, - const Tensor& self, - const Tensor& other, - OptionalIntArrayRef dims) { - return torch::linalg_tensorsolve_out(result, self, other, dims); -} - -inline Tensor inv(const Tensor& input) { - return torch::linalg_inv(input); -} - -inline Tensor& inv_out(Tensor& result, const Tensor& input) { - return torch::linalg_inv_out(result, input); -} - -} // namespace detail -#endif /* DOXYGEN_SHOULD_SKIP_THIS */ - -/// Cholesky decomposition -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.cholesky -/// -/// Example: -/// ``` -/// auto A = torch::randn({4, 4}); -/// auto A = torch::matmul(A, A.t()); -/// auto L = torch::linalg::cholesky(A); -/// assert(torch::allclose(torch::matmul(L, L.t()), A)); -/// ``` -inline Tensor cholesky(const Tensor& self) { - return detail::cholesky(self); -} - -inline Tensor cholesky_out(Tensor& result, const Tensor& self) { - return detail::cholesky_out(result, self); -} - -// C10_DEPRECATED_MESSAGE("linalg_det is deprecated, use det instead.") -inline Tensor linalg_det(const Tensor& self) { - return detail::det(self); -} - -/// See the documentation of torch.linalg.det -inline Tensor det(const Tensor& self) { - return detail::det(self); -} - -/// Computes the sign and (natural) logarithm of the determinant -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.slogdet -inline std::tuple slogdet(const Tensor& input) { - return detail::slogdet(input); -} - -inline std::tuple slogdet_out( - Tensor& sign, - Tensor& logabsdet, - const Tensor& input) { - return detail::slogdet_out(sign, logabsdet, input); -} - -/// Computes eigenvalues and eigenvectors of non-symmetric/non-hermitian -/// matrices -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eig -inline std::tuple eig(const Tensor& self) { - return detail::eig(self); -} - -inline std::tuple eig_out( - Tensor& eigvals, - Tensor& eigvecs, - const Tensor& self) { - return detail::eig_out(eigvals, eigvecs, self); -} - -/// Computes eigenvalues of non-symmetric/non-hermitian matrices -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigvals -inline Tensor eigvals(const Tensor& self) { - return detail::eigvals(self); -} - -inline Tensor& eigvals_out(Tensor& result, const Tensor& self) { - return detail::eigvals_out(result, self); -} - -/// Computes eigenvalues and eigenvectors -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigh -inline std::tuple eigh( - const Tensor& self, - c10::string_view uplo) { - return detail::eigh(self, uplo); -} - -inline std::tuple eigh_out( - Tensor& eigvals, - Tensor& eigvecs, - const Tensor& self, - c10::string_view uplo) { - return detail::eigh_out(eigvals, eigvecs, self, uplo); -} - -/// Computes eigenvalues -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.eigvalsh -inline Tensor eigvalsh(const Tensor& self, c10::string_view uplo) { - return detail::eigvalsh(self, uplo); -} - -inline Tensor& eigvalsh_out( - Tensor& result, - const Tensor& self, - c10::string_view uplo) { - return detail::eigvalsh_out(result, self, uplo); -} - -/// Computes the product of Householder matrices -/// -/// See -/// https://pytorch.org/docs/main/linalg.html#torch.linalg.householder_product -inline Tensor householder_product(const Tensor& input, const Tensor& tau) { - return detail::householder_product(input, tau); -} - -inline Tensor& householder_product_out( - Tensor& result, - const Tensor& input, - const Tensor& tau) { - return detail::householder_product_out(result, input, tau); -} - -inline std::tuple lstsq( - const Tensor& self, - const Tensor& b, - std::optional cond, - std::optional driver) { - return detail::lstsq(self, b, cond, driver); -} - -/// Computes the matrix exponential -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_exp -inline Tensor matrix_exp(const Tensor& input) { - return detail::matrix_exp(input); -} - -// C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.") -inline Tensor linalg_norm( - const Tensor& self, - const std::optional& opt_ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); -} - -// C10_DEPRECATED_MESSAGE("linalg_norm is deprecated, use norm instead.") -inline Tensor linalg_norm( - const Tensor& self, - c10::string_view ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); -} - -// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out -// instead.") -inline Tensor& linalg_norm_out( - Tensor& result, - const Tensor& self, - const std::optional& opt_ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); -} - -// C10_DEPRECATED_MESSAGE("linalg_norm_out is deprecated, use norm_out -// instead.") -inline Tensor& linalg_norm_out( - Tensor& result, - const Tensor& self, - c10::string_view ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); -} - -/// Computes the LU factorization with partial pivoting -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.lu_factor -inline std::tuple lu_factor( - const Tensor& input, - const bool pivot = true) { - return detail::lu_factor(input, pivot); -} - -inline std::tuple lu_factor_out( - Tensor& LU, - Tensor& pivots, - const Tensor& self, - const bool pivot = true) { - return detail::lu_factor_out(LU, pivots, self, pivot); -} - -/// Computes the LU factorization with partial pivoting -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.lu -inline std::tuple lu( - const Tensor& input, - const bool pivot = true) { - return detail::lu(input, pivot); -} - -inline std::tuple lu_out( - Tensor& P, - Tensor& L, - Tensor& U, - const Tensor& self, - const bool pivot = true) { - return detail::lu_out(P, L, U, self, pivot); -} - -inline Tensor norm( - const Tensor& self, - const std::optional& opt_ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm(self, opt_ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor norm( - const Tensor& self, - std::string ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm(self, ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor& norm_out( - Tensor& result, - const Tensor& self, - const std::optional& opt_ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm_out(result, self, opt_ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor& norm_out( - Tensor& result, - const Tensor& self, - std::string ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::norm_out(result, self, ord, opt_dim, keepdim, opt_dtype); -} - -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.vector_norm -inline Tensor vector_norm( - const Tensor& self, - Scalar ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::vector_norm(self, ord, opt_dim, keepdim, opt_dtype); -} - -inline Tensor& vector_norm_out( - Tensor& result, - const Tensor& self, - Scalar ord, - OptionalIntArrayRef opt_dim, - bool keepdim, - std::optional opt_dtype) { - return detail::vector_norm_out( - result, self, ord, opt_dim, keepdim, opt_dtype); -} - -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_norm -inline Tensor matrix_norm( - const Tensor& self, - const Scalar& ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype) { - return detail::matrix_norm(self, ord, dim, keepdim, dtype); -} - -inline Tensor& matrix_norm_out( - const Tensor& self, - const Scalar& ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype, - Tensor& result) { - return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); -} - -inline Tensor matrix_norm( - const Tensor& self, - std::string ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype) { - return detail::matrix_norm(self, ord, dim, keepdim, dtype); -} - -inline Tensor& matrix_norm_out( - const Tensor& self, - std::string ord, - IntArrayRef dim, - bool keepdim, - std::optional dtype, - Tensor& result) { - return detail::matrix_norm_out(self, ord, dim, keepdim, dtype, result); -} - -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_power -inline Tensor matrix_power(const Tensor& self, int64_t n) { - return detail::matrix_power(self, n); -} - -inline Tensor& matrix_power_out(const Tensor& self, int64_t n, Tensor& result) { - return detail::matrix_power_out(self, n, result); -} - -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.matrix_rank -inline Tensor matrix_rank(const Tensor& input, double tol, bool hermitian) { - return detail::matrix_rank(input, tol, hermitian); -} - -inline Tensor matrix_rank( - const Tensor& input, - const Tensor& tol, - bool hermitian) { - return detail::matrix_rank(input, tol, hermitian); -} - -inline Tensor matrix_rank( - const Tensor& input, - std::optional atol, - std::optional rtol, - bool hermitian) { - return detail::matrix_rank(input, atol, rtol, hermitian); -} - -inline Tensor matrix_rank( - const Tensor& input, - const std::optional& atol, - const std::optional& rtol, - bool hermitian) { - return detail::matrix_rank(input, atol, rtol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - double tol, - bool hermitian) { - return detail::matrix_rank_out(result, input, tol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - const Tensor& tol, - bool hermitian) { - return detail::matrix_rank_out(result, input, tol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - std::optional atol, - std::optional rtol, - bool hermitian) { - return detail::matrix_rank_out(result, input, atol, rtol, hermitian); -} - -inline Tensor& matrix_rank_out( - Tensor& result, - const Tensor& input, - const std::optional& atol, - const std::optional& rtol, - bool hermitian) { - return detail::matrix_rank_out(result, input, atol, rtol, hermitian); -} - -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.multi_dot -inline Tensor multi_dot(TensorList tensors) { - return detail::multi_dot(tensors); -} - -inline Tensor& multi_dot_out(TensorList tensors, Tensor& result) { - return detail::multi_dot_out(tensors, result); -} - -/// Computes the pseudo-inverse -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.pinv -inline Tensor pinv( - const Tensor& input, - double rcond = 1e-15, - bool hermitian = false) { - return detail::pinv(input, rcond, hermitian); -} - -inline Tensor& pinv_out( - Tensor& result, - const Tensor& input, - double rcond = 1e-15, - bool hermitian = false) { - return detail::pinv_out(result, input, rcond, hermitian); -} - -/// Computes the QR decomposition -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.qr -inline std::tuple qr( - const Tensor& input, - c10::string_view mode = "reduced") { - // C++17 Change the initialisation to "reduced"sv - // Same for qr_out - return detail::qr(input, mode); -} - -inline std::tuple qr_out( - Tensor& Q, - Tensor& R, - const Tensor& input, - c10::string_view mode = "reduced") { - return detail::qr_out(Q, R, input, mode); -} - -/// Computes the LDL decomposition -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.ldl_factor_ex -inline std::tuple ldl_factor_ex( - const Tensor& input, - bool hermitian, - bool check_errors) { - return torch::linalg_ldl_factor_ex(input, hermitian, check_errors); -} - -inline std::tuple ldl_factor_ex_out( - Tensor& LD, - Tensor& pivots, - Tensor& info, - const Tensor& input, - bool hermitian, - bool check_errors) { - return torch::linalg_ldl_factor_ex_out( - LD, pivots, info, input, hermitian, check_errors); -} - -/// Solve a system of linear equations using the LDL decomposition -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.ldl_solve -inline Tensor ldl_solve( - const Tensor& LD, - const Tensor& pivots, - const Tensor& B, - bool hermitian) { - return torch::linalg_ldl_solve(LD, pivots, B, hermitian); -} - -inline Tensor& ldl_solve_out( - Tensor& result, - const Tensor& LD, - const Tensor& pivots, - const Tensor& B, - bool hermitian) { - return torch::linalg_ldl_solve_out(result, LD, pivots, B, hermitian); -} - -/// Solves a system linear system AX = B -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.solve_ex -inline std::tuple solve_ex( - const Tensor& input, - const Tensor& other, - bool left, - bool check_errors) { - return detail::solve_ex(input, other, left, check_errors); -} - -inline std::tuple solve_ex_out( - Tensor& result, - Tensor& info, - const Tensor& input, - const Tensor& other, - bool left, - bool check_errors) { - return detail::solve_ex_out(result, info, input, other, left, check_errors); -} - -/// Computes a tensor `x` such that `matmul(input, x) = other`. -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.solve -inline Tensor solve(const Tensor& input, const Tensor& other, bool left) { - return detail::solve(input, other, left); -} - -inline Tensor& solve_out( - Tensor& result, - const Tensor& input, - const Tensor& other, - bool left) { - return detail::solve_out(result, input, other, left); -} - -/// Computes a solution of a linear system AX = B for input = A and other = B -/// whenever A is square upper or lower triangular and does not have zeros in -/// the diagonal -/// -/// See -/// https://pytorch.org/docs/main/linalg.html#torch.linalg.solve_triangular -inline Tensor solve_triangular( - const Tensor& input, - const Tensor& other, - bool upper, - bool left, - bool unitriangular) { - return detail::solve_triangular(input, other, upper, left, unitriangular); -} - -inline Tensor& solve_triangular_out( - Tensor& result, - const Tensor& input, - const Tensor& other, - bool upper, - bool left, - bool unitriangular) { - return detail::solve_triangular_out( - result, input, other, upper, left, unitriangular); -} - -/// Computes the singular values and singular vectors -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.svd -inline std::tuple svd( - const Tensor& input, - bool full_matrices, - std::optional driver) { - return detail::svd(input, full_matrices, driver); -} - -inline std::tuple svd_out( - Tensor& U, - Tensor& S, - Tensor& Vh, - const Tensor& input, - bool full_matrices, - std::optional driver) { - return detail::svd_out(U, S, Vh, input, full_matrices, driver); -} - -/// Computes the singular values -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.svdvals -inline Tensor svdvals( - const Tensor& input, - std::optional driver) { - return detail::svdvals(input, driver); -} - -inline Tensor& svdvals_out( - Tensor& result, - const Tensor& input, - std::optional driver) { - return detail::svdvals_out(result, input, driver); -} - -/// Computes the inverse of a tensor -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.tensorinv -/// -/// Example: -/// ``` -/// auto a = torch::eye(4*6).reshape({4, 6, 8, 3}); -/// int64_t ind = 2; -/// auto ainv = torch::linalg::tensorinv(a, ind); -/// ``` -inline Tensor tensorinv(const Tensor& self, int64_t ind) { - return detail::tensorinv(self, ind); -} - -inline Tensor& tensorinv_out(Tensor& result, const Tensor& self, int64_t ind) { - return detail::tensorinv_out(result, self, ind); -} - -/// Computes a tensor `x` such that `tensordot(input, x, dims=x.dim()) = other`. -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.tensorsolve -/// -/// Example: -/// ``` -/// auto a = torch::eye(2*3*4).reshape({2*3, 4, 2, 3, 4}); -/// auto b = torch::randn(2*3, 4); -/// auto x = torch::linalg::tensorsolve(a, b); -/// ``` -inline Tensor tensorsolve( - const Tensor& input, - const Tensor& other, - OptionalIntArrayRef dims) { - return detail::tensorsolve(input, other, dims); -} - -inline Tensor& tensorsolve_out( - Tensor& result, - const Tensor& input, - const Tensor& other, - OptionalIntArrayRef dims) { - return detail::tensorsolve_out(result, input, other, dims); -} - -/// Computes a tensor `inverse_input` such that `dot(input, inverse_input) = -/// eye(input.size(0))`. -/// -/// See https://pytorch.org/docs/main/linalg.html#torch.linalg.inv -inline Tensor inv(const Tensor& input) { - return detail::inv(input); -} - -inline Tensor& inv_out(Tensor& result, const Tensor& input) { - return detail::inv_out(result, input); -} - -} // namespace linalg -} // namespace torch diff --git a/torch/csrc/api/include/torch/mps.h b/torch/csrc/api/include/torch/mps.h index 1b2eabd6832ba..576b8835a413e 100644 --- a/torch/csrc/api/include/torch/mps.h +++ b/torch/csrc/api/include/torch/mps.h @@ -15,8 +15,7 @@ using MTLCommandBuffer_t = void*; using DispatchQueue_t = void*; #endif -namespace torch { -namespace mps { +namespace torch::mps { /// Returns true if MPS device is available. bool TORCH_API is_available(); @@ -40,5 +39,4 @@ MTLCommandBuffer_t TORCH_API get_command_buffer(); /// with the PyTorch MPS backend. DispatchQueue_t TORCH_API get_dispatch_queue(); -} // namespace mps -} // namespace torch +} // namespace torch::mps diff --git a/torch/csrc/api/include/torch/nested.h b/torch/csrc/api/include/torch/nested.h index 2e4365e0031cc..0340d1f2b34f4 100644 --- a/torch/csrc/api/include/torch/nested.h +++ b/torch/csrc/api/include/torch/nested.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nested { +namespace torch::nested { /// Nested tensor /// @@ -91,5 +90,4 @@ inline at::Tensor to_padded_tensor( return at::nested_to_padded_tensor(self, padding, output_size); } -} // namespace nested -} // namespace torch +} // namespace torch::nested diff --git a/torch/csrc/api/include/torch/nn/functional/activation.h b/torch/csrc/api/include/torch/nn/functional/activation.h index 5ae6fcc317602..5073c62c52e78 100644 --- a/torch/csrc/api/include/torch/nn/functional/activation.h +++ b/torch/csrc/api/include/torch/nn/functional/activation.h @@ -10,9 +10,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -372,7 +370,7 @@ inline Tensor glu(const Tensor& input, const GLUFuncOptions& options = {}) { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { -inline Tensor gelu(const Tensor& input, string approximate) { +inline Tensor gelu(const Tensor& input, const string& approximate) { return torch::gelu(input, approximate); } } // namespace detail @@ -693,7 +691,7 @@ inline std::tuple multi_head_attention_forward( // encoder-decoder attention // This is inline in_proj function with in_proj_weight and in_proj_bias auto _b = in_proj_bias; - auto _start = 0; + int64_t _start = 0; auto _end = embed_dim; auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end); if (_b.defined()) { @@ -720,7 +718,7 @@ inline std::tuple multi_head_attention_forward( } else { // This is inline in_proj function with in_proj_weight and in_proj_bias auto _b = in_proj_bias; - auto _start = 0; + int64_t _start = 0; auto _end = embed_dim; auto _w = in_proj_weight.slice(/*dim=*/0, _start, _end); if (_b.defined()) { @@ -903,8 +901,7 @@ inline std::tuple multi_head_attention_forward( attn_output_weights = attn_output_weights.view({bsz * num_heads, tgt_len, src_len}); } - // NOLINTNEXTLINE(bugprone-argument-comment) - attn_output_weights = F::softmax(attn_output_weights, /*dim=*/-1); + attn_output_weights = F::softmax(attn_output_weights, /*options=*/-1); attn_output_weights = F::dropout( attn_output_weights, F::DropoutFuncOptions().p(dropout_p).training(training)); @@ -961,6 +958,4 @@ inline std::tuple multi_head_attention_forward( options.average_attn_weights()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/batchnorm.h b/torch/csrc/api/include/torch/nn/functional/batchnorm.h index bc6f141281b39..66d5a6bd69d0a 100644 --- a/torch/csrc/api/include/torch/nn/functional/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/functional/batchnorm.h @@ -4,9 +4,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -78,6 +76,4 @@ inline Tensor batch_norm( options.eps()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/conv.h b/torch/csrc/api/include/torch/nn/functional/conv.h index 8f85fb286731a..1c2b5b73c48dc 100644 --- a/torch/csrc/api/include/torch/nn/functional/conv.h +++ b/torch/csrc/api/include/torch/nn/functional/conv.h @@ -3,9 +3,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -296,6 +294,4 @@ inline Tensor conv_transpose3d( options.dilation()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/distance.h b/torch/csrc/api/include/torch/nn/functional/distance.h index 84f6009fae9d7..c5cb133aa609b 100644 --- a/torch/csrc/api/include/torch/nn/functional/distance.h +++ b/torch/csrc/api/include/torch/nn/functional/distance.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -83,6 +81,4 @@ inline Tensor pdist(const Tensor& input, double p = 2.0) { return torch::pdist(input, p); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/dropout.h b/torch/csrc/api/include/torch/nn/functional/dropout.h index 6b7953a266c4d..d365ff8400477 100644 --- a/torch/csrc/api/include/torch/nn/functional/dropout.h +++ b/torch/csrc/api/include/torch/nn/functional/dropout.h @@ -4,9 +4,7 @@ #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -229,6 +227,4 @@ inline Tensor feature_alpha_dropout( std::move(input), options.p(), options.training(), options.inplace()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/embedding.h b/torch/csrc/api/include/torch/nn/functional/embedding.h index 602268ab2eba3..fb8aa8d45b2b9 100644 --- a/torch/csrc/api/include/torch/nn/functional/embedding.h +++ b/torch/csrc/api/include/torch/nn/functional/embedding.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { inline Tensor one_hot(const Tensor& tensor, int64_t num_classes = -1) { return torch::one_hot(tensor, num_classes); @@ -133,8 +131,7 @@ inline Tensor embedding_bag( input_.dim()); } - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int mode_enum; + int mode_enum = 0; if (std::holds_alternative(mode)) { mode_enum = 0; } else if (std::holds_alternative(mode)) { @@ -206,6 +203,4 @@ inline Tensor embedding_bag( options.padding_idx()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/fold.h b/torch/csrc/api/include/torch/nn/functional/fold.h index 4f1716b2881bc..23b19d0bb8d58 100644 --- a/torch/csrc/api/include/torch/nn/functional/fold.h +++ b/torch/csrc/api/include/torch/nn/functional/fold.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -97,6 +95,4 @@ inline Tensor unfold(const Tensor& input, const UnfoldFuncOptions& options) { options.stride()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/instancenorm.h b/torch/csrc/api/include/torch/nn/functional/instancenorm.h index 17efaea7a5e55..92f9694650319 100644 --- a/torch/csrc/api/include/torch/nn/functional/instancenorm.h +++ b/torch/csrc/api/include/torch/nn/functional/instancenorm.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -58,6 +56,4 @@ inline Tensor instance_norm( options.eps()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/linear.h b/torch/csrc/api/include/torch/nn/functional/linear.h index ffeafcd712af0..4d9e7fe6d4b7a 100644 --- a/torch/csrc/api/include/torch/nn/functional/linear.h +++ b/torch/csrc/api/include/torch/nn/functional/linear.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { inline Tensor bilinear( const Tensor& input1, @@ -32,6 +30,4 @@ inline Tensor linear( } } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h index 6a425e606caf2..405e224a14648 100644 --- a/torch/csrc/api/include/torch/nn/functional/loss.h +++ b/torch/csrc/api/include/torch/nn/functional/loss.h @@ -4,9 +4,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -47,8 +45,7 @@ inline Tensor kl_div( const Tensor& target, KLDivFuncOptions::reduction_t reduction, bool log_target = false) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - torch::Reduction::Reduction reduction_enum; + torch::Reduction::Reduction reduction_enum{}; if (std::holds_alternative(reduction)) { TORCH_WARN( @@ -1039,6 +1036,4 @@ inline Tensor binary_cross_entropy_with_logits( options.pos_weight()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/normalization.h b/torch/csrc/api/include/torch/nn/functional/normalization.h index 965cfcd9ac83f..3df0189890864 100644 --- a/torch/csrc/api/include/torch/nn/functional/normalization.h +++ b/torch/csrc/api/include/torch/nn/functional/normalization.h @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -206,6 +204,4 @@ inline Tensor group_norm( options.eps()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/padding.h b/torch/csrc/api/include/torch/nn/functional/padding.h index 1bb6f95382904..5ef8b6ff34492 100644 --- a/torch/csrc/api/include/torch/nn/functional/padding.h +++ b/torch/csrc/api/include/torch/nn/functional/padding.h @@ -3,9 +3,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -53,6 +51,4 @@ inline Tensor pad(const Tensor& input, const PadFuncOptions& options) { return detail::pad(input, options.pad(), options.mode(), options.value()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h index a245002428e2d..4d005f3568969 100644 --- a/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/functional/pixelshuffle.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -42,6 +40,4 @@ inline Tensor pixel_unshuffle( return detail::pixel_unshuffle(input, options.downscale_factor()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/pooling.h b/torch/csrc/api/include/torch/nn/functional/pooling.h index 798467c0e0a68..72aaca76f6f4d 100644 --- a/torch/csrc/api/include/torch/nn/functional/pooling.h +++ b/torch/csrc/api/include/torch/nn/functional/pooling.h @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { #ifndef DOXYGEN_SHOULD_SKIP_THIS namespace detail { @@ -1057,8 +1055,8 @@ inline Tensor lp_pool2d( ExpandingArray<2> kernel_size, ExpandingArray<2> stride, bool ceil_mode) { - int kw = (*kernel_size)[0]; - int kh = (*kernel_size)[1]; + auto kw = (*kernel_size)[0]; + auto kh = (*kernel_size)[1]; Tensor out = detail::avg_pool2d( input.pow(norm_type), kernel_size, @@ -1106,9 +1104,9 @@ inline Tensor lp_pool3d( ExpandingArray<3> kernel_size, ExpandingArray<3> stride, bool ceil_mode) { - int kd = (*kernel_size)[0]; - int kw = (*kernel_size)[1]; - int kh = (*kernel_size)[2]; + auto kd = (*kernel_size)[0]; + auto kw = (*kernel_size)[1]; + auto kh = (*kernel_size)[2]; Tensor out = detail::avg_pool3d( input.pow(norm_type), kernel_size, @@ -1148,6 +1146,4 @@ inline Tensor lp_pool3d( options.ceil_mode()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/upsampling.h b/torch/csrc/api/include/torch/nn/functional/upsampling.h index 75707ef091a78..ace73152d88ca 100644 --- a/torch/csrc/api/include/torch/nn/functional/upsampling.h +++ b/torch/csrc/api/include/torch/nn/functional/upsampling.h @@ -7,9 +7,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { inline std::vector _interp_output_size( int64_t dim, @@ -18,7 +16,8 @@ inline std::vector _interp_output_size( std::optional>, std::optional>, std::optional> closed_over_args) { - auto [input, size, scale_factor, recompute_scale_factor] = closed_over_args; + auto [input, size, scale_factor, recompute_scale_factor] = + std::move(closed_over_args); if (size == std::nullopt && scale_factor == std::nullopt) { TORCH_CHECK(false, "either size or scale_factor should be defined"); } @@ -284,6 +283,4 @@ inline Tensor interpolate( options.antialias()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/functional/vision.h b/torch/csrc/api/include/torch/nn/functional/vision.h index a6c53e0c0a9ad..78a015dcff856 100644 --- a/torch/csrc/api/include/torch/nn/functional/vision.h +++ b/torch/csrc/api/include/torch/nn/functional/vision.h @@ -3,9 +3,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { inline Tensor affine_grid( const Tensor& theta, @@ -60,8 +58,7 @@ inline Tensor grid_sample( GridSampleFuncOptions::mode_t mode, GridSampleFuncOptions::padding_mode_t padding_mode, std::optional align_corners) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t mode_enum, padding_mode_enum; + int64_t mode_enum = 0, padding_mode_enum = 0; if (std::holds_alternative(mode)) { mode_enum = 0; @@ -119,6 +116,4 @@ inline Tensor grid_sample( options.align_corners()); } -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/modules/_functions.h b/torch/csrc/api/include/torch/nn/modules/_functions.h index 5bf1ce2dcb285..f7cc8d0eb9354 100644 --- a/torch/csrc/api/include/torch/nn/modules/_functions.h +++ b/torch/csrc/api/include/torch/nn/modules/_functions.h @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace nn { -namespace functions { +namespace torch::nn::functions { class CrossMapLRN2d : public torch::autograd::Function { public: @@ -21,6 +19,4 @@ class CrossMapLRN2d : public torch::autograd::Function { torch::autograd::variable_list grad_output); }; -} // namespace functions -} // namespace nn -} // namespace torch +} // namespace torch::nn::functions diff --git a/torch/csrc/api/include/torch/nn/modules/activation.h b/torch/csrc/api/include/torch/nn/modules/activation.h index 08e1039610745..806fbd2f0f876 100644 --- a/torch/csrc/api/include/torch/nn/modules/activation.h +++ b/torch/csrc/api/include/torch/nn/modules/activation.h @@ -8,8 +8,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -871,5 +870,4 @@ class TORCH_API MultiheadAttentionImpl /// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(MultiheadAttention); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/adaptive.h b/torch/csrc/api/include/torch/nn/modules/adaptive.h index 609e690d4c7de..7833b01297d2d 100644 --- a/torch/csrc/api/include/torch/nn/modules/adaptive.h +++ b/torch/csrc/api/include/torch/nn/modules/adaptive.h @@ -8,8 +8,9 @@ #include #include -namespace torch { -namespace nn { +#include + +namespace torch::nn { /// The output of a single invocation of an AdaptiveLogSoftmaxWithLoss /// module's `forward()` method. @@ -51,7 +52,7 @@ class TORCH_API AdaptiveLogSoftmaxWithLossImpl : AdaptiveLogSoftmaxWithLossImpl(AdaptiveLogSoftmaxWithLossOptions( in_features, n_classes, - cutoffs)) {} + std::move(cutoffs))) {} explicit AdaptiveLogSoftmaxWithLossImpl( AdaptiveLogSoftmaxWithLossOptions options_); @@ -105,5 +106,4 @@ class TORCH_API AdaptiveLogSoftmaxWithLossImpl /// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(AdaptiveLogSoftmaxWithLoss); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h index 0f5e32746936e..cf6e824189618 100644 --- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h @@ -7,10 +7,7 @@ #include #include -#include - -namespace torch { -namespace nn { +namespace torch::nn { /// Base class for all (dimension-specialized) batchnorm and instancenorm /// modules. @@ -104,11 +101,8 @@ class BatchNormImplBase : public NormImplBase { Tensor forward(const Tensor& input) { this->_check_input_dim(input); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double exponential_average_factor; - if (this->options.momentum() == std::nullopt) { - exponential_average_factor = 0.0; - } else { + double exponential_average_factor = 0.0; + if (this->options.momentum().has_value()) { exponential_average_factor = this->options.momentum().value(); } @@ -246,5 +240,4 @@ class TORCH_API BatchNorm3dImpl : public BatchNormImplBase<3, BatchNorm3dImpl> { /// learn about PyTorch's module storage semantics. TORCH_MODULE(BatchNorm3d); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/common.h b/torch/csrc/api/include/torch/nn/modules/common.h index f172c82e7e632..e967e23171872 100644 --- a/torch/csrc/api/include/torch/nn/modules/common.h +++ b/torch/csrc/api/include/torch/nn/modules/common.h @@ -70,28 +70,30 @@ /// seq->forward(1); // This correctly populates the default arguments for /// `MImpl::forward` /// ``` -#define FORWARD_HAS_DEFAULT_ARGS(...) \ - template \ - friend struct torch::nn::AnyModuleHolder; \ - bool _forward_has_default_args() override { \ - return true; \ - } \ - unsigned int _forward_num_required_args() override { \ - std::pair args_info[] = {__VA_ARGS__}; \ - return args_info[0].first; \ - } \ - std::vector _forward_populate_default_args( \ - std::vector&& arguments) override { \ - std::pair args_info[] = {__VA_ARGS__}; \ - unsigned int num_all_args = std::rbegin(args_info)->first + 1; \ - TORCH_INTERNAL_ASSERT( \ - arguments.size() >= _forward_num_required_args() && \ - arguments.size() <= num_all_args); \ - std::vector ret = std::move(arguments); \ - ret.reserve(num_all_args); \ - for (auto& arg_info : args_info) { \ - if (arg_info.first > ret.size() - 1) \ - ret.emplace_back(std::move(arg_info.second)); \ - } \ - return ret; \ +#define FORWARD_HAS_DEFAULT_ARGS(...) \ + template \ + friend struct torch::nn::AnyModuleHolder; \ + bool _forward_has_default_args() override { \ + return true; \ + } \ + unsigned int _forward_num_required_args() override { \ + std::vector> args_info{ \ + __VA_ARGS__}; \ + return std::begin(args_info)->first; \ + } \ + std::vector _forward_populate_default_args( \ + std::vector&& arguments) override { \ + std::vector> args_info{ \ + __VA_ARGS__}; \ + unsigned int num_all_args = std::rbegin(args_info)->first + 1; \ + TORCH_INTERNAL_ASSERT( \ + arguments.size() >= _forward_num_required_args() && \ + arguments.size() <= num_all_args); \ + std::vector ret = std::move(arguments); \ + ret.reserve(num_all_args); \ + for (auto& arg_info : args_info) { \ + if (arg_info.first > ret.size() - 1) \ + ret.emplace_back(std::move(arg_info.second)); \ + } \ + return ret; \ } diff --git a/torch/csrc/api/include/torch/nn/modules/container/any.h b/torch/csrc/api/include/torch/nn/modules/container/any.h index ab4a589aeded1..89b14f0d4e893 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any.h @@ -1,25 +1,15 @@ #pragma once -#include #include #include -#include -#include #include -#include -#include - -#include - #include #include -#include #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Stores a type erased `Module`. /// @@ -261,8 +251,8 @@ inline AnyModule AnyModule::clone(std::optional device) const { template AnyModule& AnyModule::operator=(std::shared_ptr module) { - // NOLINTNEXTLINE(cppcoreguidelines-c-copy-assignment-signature) - return (*this = AnyModule(std::move(module))); + *this = AnyModule(std::move(module)); + return *this; } template @@ -336,7 +326,7 @@ std::unique_ptr AnyModule::make_holder( "Modules stored inside AnyModule must not take references. " "Use pointers instead."); static_assert( - !std::is_void::value, + !std::is_void_v, "AnyModule cannot store modules that return void " "(you can return a dummy value)."); return std::make_unique< @@ -346,7 +336,7 @@ std::unique_ptr AnyModule::make_holder( template ModuleType& AnyModule::get_() const { - using M = typename std::remove_reference::type; + using M = std::remove_reference_t; static_assert( torch::detail::has_forward::value, "Can only call AnyModule::get with a type T that has a forward method"); @@ -361,12 +351,12 @@ ModuleType& AnyModule::get_( *content_) .module; } - AT_ERROR( + TORCH_CHECK( + false, "Attempted to cast module of type ", c10::demangle(type_info().name()), " to type ", c10::demangle(typeid(ModuleType).name())); } -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h index edeb8e6b764c5..7482ef3b452d9 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h @@ -1,9 +1,9 @@ #pragma once +#include #include -namespace torch { -namespace nn { +namespace torch::nn { class Module; @@ -46,7 +46,8 @@ struct AnyModuleHolder : public AnyModulePlaceholder { if (auto* maybe_value = value.template try_get>()) { return std::move(*maybe_value); } - AT_ERROR( + TORCH_CHECK( + false, "Expected argument #", index, " to be of type ", @@ -54,6 +55,7 @@ struct AnyModuleHolder : public AnyModulePlaceholder { ", but received value of type ", c10::demangle(value.type_info().name())); } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) std::vector& arguments_; }; @@ -63,6 +65,7 @@ struct AnyModuleHolder : public AnyModulePlaceholder { AnyValue operator()(Ts&&... ts) { return AnyValue(module_->forward(std::forward(ts)...)); } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) std::shared_ptr& module_; }; @@ -129,5 +132,4 @@ struct AnyModuleHolder : public AnyModulePlaceholder { std::shared_ptr module; }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/any_value.h b/torch/csrc/api/include/torch/nn/modules/container/any_value.h index d154130618f2d..92f6a5d7789eb 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/any_value.h +++ b/torch/csrc/api/include/torch/nn/modules/container/any_value.h @@ -1,20 +1,13 @@ #pragma once -#include -#include -#include #include -#include -#include - #include #include #include #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyValue ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -37,8 +30,9 @@ class AnyValue { } /// Constructs the `AnyValue` from value type. - template - // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) + template < + typename T, + typename = std::enable_if_t>> explicit AnyValue(T&& value) : content_( std::make_unique>>(std::forward(value))) { @@ -50,10 +44,10 @@ class AnyValue { template T* try_get() { static_assert( - !std::is_reference::value, + !std::is_reference_v, "AnyValue stores decayed types, you cannot cast it to a reference type"); static_assert( - !std::is_array::value, + !std::is_array_v, "AnyValue stores decayed types, you must cast it to T* instead of T[]"); if (typeid(T).hash_code() == type_info().hash_code()) { return &static_cast&>(*content_).value; @@ -69,7 +63,8 @@ class AnyValue { if (auto* maybe_value = try_get()) { return *maybe_value; } - AT_ERROR( + TORCH_CHECK( + false, "Attempted to cast AnyValue to ", c10::demangle(typeid(T).name()), ", but its actual type is ", @@ -98,6 +93,7 @@ class AnyValue { virtual std::unique_ptr clone() const { TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`"); } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const std::type_info& type_info; }; @@ -107,8 +103,9 @@ class AnyValue { template struct Holder : public Placeholder { /// A template because T&& would not be universal reference here. - template - // NOLINTNEXTLINE(bugprone-forwarding-reference-overload) + template < + typename U, + typename = std::enable_if_t>> explicit Holder(U&& value_) noexcept : Placeholder(typeid(T)), value(std::forward(value_)) {} std::unique_ptr clone() const override { @@ -121,5 +118,4 @@ class AnyValue { std::unique_ptr content_; }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/functional.h b/torch/csrc/api/include/torch/nn/modules/container/functional.h index 3f381a63944f5..fac31d204f5ae 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/functional.h +++ b/torch/csrc/api/include/torch/nn/modules/container/functional.h @@ -1,16 +1,13 @@ #pragma once #include -#include #include -#include #include #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Wraps a function in a `Module`. /// @@ -101,5 +98,4 @@ class TORCH_API FunctionalImpl : public torch::nn::Cloneable { /// module storage semantics. TORCH_MODULE(Functional); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/moduledict.h b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h index b96b7611936f1..16c9c94489b0d 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/moduledict.h +++ b/torch/csrc/api/include/torch/nn/modules/container/moduledict.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// An OrderedDict of `Module`s that registers its elements by their `key`s. /// @@ -258,5 +257,4 @@ class ModuleDictImpl : public Cloneable { /// module storage semantics. TORCH_MODULE(ModuleDict); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/modulelist.h b/torch/csrc/api/include/torch/nn/modules/container/modulelist.h index b115abe1e9551..6147a73db4b4b 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/modulelist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/modulelist.h @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// A list of `Module`s that registers its elements. /// @@ -99,7 +98,7 @@ class ModuleListImpl : public Cloneable { /// and letting the container deal with the boxing. template > void push_back(M&& module) { - using Type = typename std::remove_reference::type; + using Type = std::remove_reference_t; push_back(std::make_shared(std::forward(module))); } @@ -242,7 +241,7 @@ class ModuleListImpl : public Cloneable { /// and letting the container deal with the boxing. template > void insert(size_t index, M&& module) { - using Type = typename std::remove_reference::type; + using Type = std::remove_reference_t; insert(index, std::make_shared(std::forward(module))); } @@ -270,5 +269,4 @@ class ModuleListImpl : public Cloneable { /// module storage semantics. TORCH_MODULE(ModuleList); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/named_any.h b/torch/csrc/api/include/torch/nn/modules/container/named_any.h index 00d39de17f401..542471f61f2df 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/named_any.h +++ b/torch/csrc/api/include/torch/nn/modules/container/named_any.h @@ -1,25 +1,13 @@ #pragma once -#include -#include #include -#include #include -#include -#include - -#include - -#include #include #include -#include #include -#include -namespace torch { -namespace nn { +namespace torch::nn { /// Stores a type erased `Module` with name. /// @@ -57,7 +45,7 @@ class NamedAnyModule { NamedAnyModule(std::string name, M&& module) : NamedAnyModule( std::move(name), - std::make_shared::type>( + std::make_shared>( std::forward(module))) {} /// Creates a `NamedAnyModule` from a `Module` that is unwrapped from @@ -90,5 +78,4 @@ class NamedAnyModule { AnyModule module_; }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h b/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h index f201825deb5ba..df6d003750ab9 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h +++ b/torch/csrc/api/include/torch/nn/modules/container/parameterdict.h @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { class ParameterDictImpl : public Cloneable { public: @@ -27,22 +26,22 @@ class ParameterDictImpl : public Cloneable { /// Pretty prints the `ParameterDict` module into the given `stream`. void pretty_print(std::ostream& stream) const override { - stream << "torch::nn::ParameterDict(" << std::endl; + stream << "torch::nn::ParameterDict(" << '\n'; for (const auto& pair : parameters_) { stream << "(" << pair.key() << ")" << ": Parameter containing: [" << pair.value().scalar_type() << " of size " << pair.value().sizes() << "]"; ; - stream << std::endl; + stream << '\n'; } stream << ")"; } /// Insert the parameter along with the key into ParameterDict /// The parameter is set to be require grad by default - Tensor& insert(std::string key, Tensor param) { + Tensor& insert(const std::string& key, const Tensor& param) { bool requires_grad = param.requires_grad(); - return register_parameter(std::move(key), std::move(param), requires_grad); + return register_parameter(key, param, requires_grad); } /// Remove key from the ParameterDict and return its value, throw exception @@ -144,5 +143,4 @@ class ParameterDictImpl : public Cloneable { TORCH_MODULE(ParameterDict); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h index cb816d1bb2a1e..2ea2b52fa0fb9 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h +++ b/torch/csrc/api/include/torch/nn/modules/container/parameterlist.h @@ -5,8 +5,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { class ParameterListImpl : public Cloneable { public: using Iterator = typename std::vector< @@ -35,13 +34,13 @@ class ParameterListImpl : public Cloneable { /// Pretty prints the `ParameterList` module into the given `stream`. void pretty_print(std::ostream& stream) const override { - stream << "torch::nn::ParameterList(" << std::endl; + stream << "torch::nn::ParameterList(" << '\n'; for (const auto& pair : parameters_) { stream << "(" << pair.key() << ")" << ": Parameter containing: [" << pair.value().scalar_type() << " of size " << pair.value().sizes() << "]"; ; - stream << std::endl; + stream << '\n'; } stream << ")"; } @@ -165,5 +164,4 @@ class ParameterListImpl : public Cloneable { void push_back_var() {} }; TORCH_MODULE(ParameterList); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/container/sequential.h b/torch/csrc/api/include/torch/nn/modules/container/sequential.h index 6ee12bc477d82..f5ddb4e370f61 100644 --- a/torch/csrc/api/include/torch/nn/modules/container/sequential.h +++ b/torch/csrc/api/include/torch/nn/modules/container/sequential.h @@ -18,8 +18,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// A list of `Module`s that acts as a `Module` itself. /// @@ -185,7 +184,8 @@ class SequentialImpl : public Cloneable { if (auto* return_value = input.template try_get()) { return std::move(*return_value); } - AT_ERROR( + TORCH_CHECK( + false, "The type of the return value is ", c10::demangle(input.type_info().name()), ", but you asked for type ", @@ -384,5 +384,4 @@ class Sequential : public torch::nn::ModuleHolder { Sequential(std::initializer_list named_modules) : ModuleHolder(std::make_shared(named_modules)) {} }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index e44fd44b954ab..20dc17e8e6fc4 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -17,8 +17,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Base class for all (dimension-specialized) convolution modules. template @@ -447,5 +446,4 @@ class TORCH_API ConvTranspose3dImpl /// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(ConvTranspose3d); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/distance.h b/torch/csrc/api/include/torch/nn/modules/distance.h index 774b01d7e447c..7166ba15d1821 100644 --- a/torch/csrc/api/include/torch/nn/modules/distance.h +++ b/torch/csrc/api/include/torch/nn/modules/distance.h @@ -8,8 +8,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { /// Returns the cosine similarity between :math:`x_1` and :math:`x_2`, computed /// along `dim`. @@ -82,5 +81,4 @@ class TORCH_API PairwiseDistanceImpl : public Cloneable { /// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(PairwiseDistance); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/dropout.h b/torch/csrc/api/include/torch/nn/modules/dropout.h index a2ebabded6fab..c23f0501dc3b5 100644 --- a/torch/csrc/api/include/torch/nn/modules/dropout.h +++ b/torch/csrc/api/include/torch/nn/modules/dropout.h @@ -7,11 +7,7 @@ #include -#include -#include - -namespace torch { -namespace nn { +namespace torch::nn { namespace detail { @@ -186,5 +182,4 @@ class TORCH_API FeatureAlphaDropoutImpl /// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(FeatureAlphaDropout); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h index ff61941d3a35b..f8af433bcc4c1 100644 --- a/torch/csrc/api/include/torch/nn/modules/embedding.h +++ b/torch/csrc/api/include/torch/nn/modules/embedding.h @@ -9,8 +9,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Embedding // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -70,10 +69,8 @@ class Embedding : public torch::nn::ModuleHolder { embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional"); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t rows, cols; - rows = embeddings.size(0); - cols = embeddings.size(1); + auto rows = embeddings.size(0); + auto cols = embeddings.size(1); Embedding embedding(EmbeddingOptions(rows, cols) ._weight(embeddings) @@ -149,10 +146,8 @@ class EmbeddingBag : public torch::nn::ModuleHolder { embeddings.dim() == 2, "Embeddings parameter is expected to be 2-dimensional"); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t rows, cols; - rows = embeddings.size(0); - cols = embeddings.size(1); + auto rows = embeddings.size(0); + auto cols = embeddings.size(1); EmbeddingBag embeddingbag( EmbeddingBagOptions(rows, cols) @@ -167,5 +162,4 @@ class EmbeddingBag : public torch::nn::ModuleHolder { return embeddingbag; } }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/fold.h b/torch/csrc/api/include/torch/nn/modules/fold.h index 6b415a99b5ea8..4ad49f191fbba 100644 --- a/torch/csrc/api/include/torch/nn/modules/fold.h +++ b/torch/csrc/api/include/torch/nn/modules/fold.h @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Applies fold over a 3-D input. /// See https://pytorch.org/docs/main/nn.html#torch.nn.Fold to learn about @@ -83,5 +82,4 @@ class TORCH_API UnfoldImpl : public Cloneable { /// learn about PyTorch's module storage semantics. TORCH_MODULE(Unfold); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/instancenorm.h b/torch/csrc/api/include/torch/nn/modules/instancenorm.h index cee9142da6b6c..228f181715fc7 100644 --- a/torch/csrc/api/include/torch/nn/modules/instancenorm.h +++ b/torch/csrc/api/include/torch/nn/modules/instancenorm.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include @@ -7,6 +8,7 @@ namespace torch::nn { /// Base class for all (dimension-specialized) instance norm modules template +// NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility) class InstanceNormImpl : public torch::nn::NormImplBase { private: diff --git a/torch/csrc/api/include/torch/nn/modules/linear.h b/torch/csrc/api/include/torch/nn/modules/linear.h index 4a88ea80afe63..cb54396837840 100644 --- a/torch/csrc/api/include/torch/nn/modules/linear.h +++ b/torch/csrc/api/include/torch/nn/modules/linear.h @@ -8,10 +8,10 @@ #include #include +#include #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Identity ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -136,9 +136,10 @@ TORCH_MODULE(Flatten); class TORCH_API UnflattenImpl : public Cloneable { public: UnflattenImpl(int64_t dim, std::vector sizes) - : UnflattenImpl(UnflattenOptions(dim, sizes)) {} + : UnflattenImpl(UnflattenOptions(dim, std::move(sizes))) {} UnflattenImpl(std::string dimname, UnflattenOptions::namedshape_t namedshape) - : UnflattenImpl(UnflattenOptions(dimname, namedshape)) {} + : UnflattenImpl( + UnflattenOptions(std::move(dimname), std::move(namedshape))) {} explicit UnflattenImpl(UnflattenOptions options_); void reset() override; @@ -210,5 +211,4 @@ class TORCH_API BilinearImpl : public Cloneable { /// learn about PyTorch's module storage semantics. TORCH_MODULE(Bilinear); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/loss.h b/torch/csrc/api/include/torch/nn/modules/loss.h index 747b548b75844..52be4f612b59f 100644 --- a/torch/csrc/api/include/torch/nn/modules/loss.h +++ b/torch/csrc/api/include/torch/nn/modules/loss.h @@ -12,8 +12,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ L1Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -801,5 +800,4 @@ struct TORCH_API BCEWithLogitsLossImpl /// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(BCEWithLogitsLoss); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/normalization.h b/torch/csrc/api/include/torch/nn/modules/normalization.h index 9bc0b7f9e7fc4..7fe0396319d7b 100644 --- a/torch/csrc/api/include/torch/nn/modules/normalization.h +++ b/torch/csrc/api/include/torch/nn/modules/normalization.h @@ -8,10 +8,10 @@ #include #include +#include #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LayerNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -31,7 +31,7 @@ namespace nn { class TORCH_API LayerNormImpl : public torch::nn::Cloneable { public: LayerNormImpl(std::vector normalized_shape) - : LayerNormImpl(LayerNormOptions(normalized_shape)) {} + : LayerNormImpl(LayerNormOptions(std::move(normalized_shape))) {} explicit LayerNormImpl(LayerNormOptions options_); void reset() override; @@ -194,5 +194,4 @@ class TORCH_API GroupNormImpl : public torch::nn::Cloneable { /// learn about PyTorch's module storage semantics. TORCH_MODULE(GroupNorm); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/padding.h b/torch/csrc/api/include/torch/nn/modules/padding.h index f051e9a19305c..855608438ce0b 100644 --- a/torch/csrc/api/include/torch/nn/modules/padding.h +++ b/torch/csrc/api/include/torch/nn/modules/padding.h @@ -6,8 +6,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { /// Base class for all (dimension-specialized) ReflectionPad modules. template @@ -374,5 +373,4 @@ class TORCH_API ConstantPad3dImpl /// to learn about PyTorch's module storage semantics. TORCH_MODULE(ConstantPad3d); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h index 7ad916d332f45..ce981c3a1c341 100644 --- a/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/modules/pixelshuffle.h @@ -6,8 +6,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PixelShuffle // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -84,5 +83,4 @@ struct TORCH_API PixelUnshuffleImpl /// to learn about PyTorch's module storage semantics. TORCH_MODULE(PixelUnshuffle); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/pooling.h b/torch/csrc/api/include/torch/nn/modules/pooling.h index 0fac60edbcde4..c9482adb702bb 100644 --- a/torch/csrc/api/include/torch/nn/modules/pooling.h +++ b/torch/csrc/api/include/torch/nn/modules/pooling.h @@ -8,8 +8,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { /// Base class for all (dimension-specialized) avgpool modules. template @@ -775,5 +774,4 @@ class TORCH_API LPPool3dImpl : public LPPoolImpl<3, LPPool3dImpl> { /// learn about PyTorch's module storage semantics. TORCH_MODULE(LPPool3d); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h index bc33ecac834c0..a2a251b413675 100644 --- a/torch/csrc/api/include/torch/nn/modules/rnn.h +++ b/torch/csrc/api/include/torch/nn/modules/rnn.h @@ -16,8 +16,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { namespace detail { /// Base class for all RNN implementations (intended for code sharing). @@ -397,5 +396,4 @@ class TORCH_API GRUCellImpl : public detail::RNNCellImplBase { /// learn about PyTorch's module storage semantics. TORCH_MODULE(GRUCell); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/transformer.h b/torch/csrc/api/include/torch/nn/modules/transformer.h index c8c417c7564b3..2f22f087bf518 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformer.h +++ b/torch/csrc/api/include/torch/nn/modules/transformer.h @@ -10,8 +10,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -139,5 +138,4 @@ class TORCH_API TransformerImpl : public Cloneable { /// module storage semantics. TORCH_MODULE(Transformer); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/transformercoder.h b/torch/csrc/api/include/torch/nn/modules/transformercoder.h index 5ca4ddea64b8d..e06dd81b9234c 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformercoder.h +++ b/torch/csrc/api/include/torch/nn/modules/transformercoder.h @@ -10,10 +10,9 @@ #include -#include +#include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -40,7 +39,7 @@ class TORCH_API TransformerEncoderImpl TransformerEncoderLayer encoder_layer, int64_t num_layers) : TransformerEncoderImpl( - TransformerEncoderOptions(encoder_layer, num_layers)) {} + TransformerEncoderOptions(std::move(encoder_layer), num_layers)) {} explicit TransformerEncoderImpl(TransformerEncoderOptions options_); Tensor forward( @@ -101,7 +100,7 @@ class TORCH_API TransformerDecoderImpl TransformerDecoderLayer decoder_layer, int64_t num_layers) : TransformerDecoderImpl( - TransformerDecoderOptions(decoder_layer, num_layers)) {} + TransformerDecoderOptions(std::move(decoder_layer), num_layers)) {} explicit TransformerDecoderImpl(TransformerDecoderOptions options_); void reset() override; @@ -150,5 +149,4 @@ class TORCH_API TransformerDecoderImpl /// module storage semantics. TORCH_MODULE(TransformerDecoder); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/transformerlayer.h b/torch/csrc/api/include/torch/nn/modules/transformerlayer.h index b2d8131870161..74f1143e5c163 100644 --- a/torch/csrc/api/include/torch/nn/modules/transformerlayer.h +++ b/torch/csrc/api/include/torch/nn/modules/transformerlayer.h @@ -14,8 +14,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoderLayer // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -191,5 +190,4 @@ class TORCH_API TransformerDecoderLayerImpl /// `ModuleHolder` to learn about PyTorch's module storage semantics. TORCH_MODULE(TransformerDecoderLayer); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/upsampling.h b/torch/csrc/api/include/torch/nn/modules/upsampling.h index 8520bf632f83e..6651357913080 100644 --- a/torch/csrc/api/include/torch/nn/modules/upsampling.h +++ b/torch/csrc/api/include/torch/nn/modules/upsampling.h @@ -11,8 +11,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Upsample ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -51,5 +50,4 @@ class TORCH_API UpsampleImpl : public Cloneable { /// learn about PyTorch's module storage semantics. TORCH_MODULE(Upsample); -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/modules/utils.h b/torch/csrc/api/include/torch/nn/modules/utils.h index 6eaa0c1fb2c73..b89abb748c898 100644 --- a/torch/csrc/api/include/torch/nn/modules/utils.h +++ b/torch/csrc/api/include/torch/nn/modules/utils.h @@ -6,10 +6,7 @@ #include -namespace torch { -namespace nn { -namespace modules { -namespace utils { +namespace torch::nn::modules::utils { // Reverse the order of `t` and repeat each element for `n` times. // This can be used to translate padding arg used by Conv and Pooling modules @@ -17,14 +14,13 @@ namespace utils { // // This mirrors `_reverse_repeat_tuple` in `torch/nn/modules/utils.py`. inline std::vector _reverse_repeat_vector( - at::ArrayRef t, + c10::ArrayRef t, int64_t n) { TORCH_INTERNAL_ASSERT(n >= 0); std::vector ret; ret.reserve(t.size() * n); for (auto rit = t.rbegin(); rit != t.rend(); ++rit) { - for (const auto i : c10::irange(n)) { - (void)i; // Suppress unused variable + for ([[maybe_unused]] const auto i : c10::irange(n)) { ret.emplace_back(*rit); } } @@ -32,14 +28,14 @@ inline std::vector _reverse_repeat_vector( } inline std::vector _list_with_default( - torch::ArrayRef> out_size, - torch::IntArrayRef defaults) { + c10::ArrayRef> out_size, + c10::IntArrayRef defaults) { TORCH_CHECK( defaults.size() > out_size.size(), "Input dimension should be at least ", out_size.size() + 1); std::vector ret; - torch::IntArrayRef defaults_slice = + c10::IntArrayRef defaults_slice = defaults.slice(defaults.size() - out_size.size(), out_size.size()); for (const auto i : c10::irange(out_size.size())) { auto v = out_size.at(i); @@ -49,7 +45,4 @@ inline std::vector _list_with_default( return ret; } -} // namespace utils -} // namespace modules -} // namespace nn -} // namespace torch +} // namespace torch::nn::modules::utils diff --git a/torch/csrc/api/include/torch/nn/options/activation.h b/torch/csrc/api/include/torch/nn/options/activation.h index ac6cbc4ea4dea..480e09ad4de2b 100644 --- a/torch/csrc/api/include/torch/nn/options/activation.h +++ b/torch/csrc/api/include/torch/nn/options/activation.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `ELU` module. /// @@ -710,5 +709,4 @@ struct TORCH_API MultiheadAttentionForwardFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/adaptive.h b/torch/csrc/api/include/torch/nn/options/adaptive.h index d4754747a1d29..4335fb725c6f4 100644 --- a/torch/csrc/api/include/torch/nn/options/adaptive.h +++ b/torch/csrc/api/include/torch/nn/options/adaptive.h @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `AdaptiveLogSoftmaxWithLoss` module. /// @@ -37,5 +36,4 @@ struct TORCH_API AdaptiveLogSoftmaxWithLossOptions { TORCH_ARG(bool, head_bias) = false; }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/batchnorm.h b/torch/csrc/api/include/torch/nn/options/batchnorm.h index 943673e2aae74..a870ba3767c5a 100644 --- a/torch/csrc/api/include/torch/nn/options/batchnorm.h +++ b/torch/csrc/api/include/torch/nn/options/batchnorm.h @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `BatchNorm` module. struct TORCH_API BatchNormOptions { @@ -91,5 +90,4 @@ struct TORCH_API BatchNormFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/conv.h b/torch/csrc/api/include/torch/nn/options/conv.h index 0b5b5b1b3f955..f10d5e9a31061 100644 --- a/torch/csrc/api/include/torch/nn/options/conv.h +++ b/torch/csrc/api/include/torch/nn/options/conv.h @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { namespace detail { @@ -411,5 +410,4 @@ using ConvTranspose3dFuncOptions = ConvTransposeFuncOptions<3>; } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/distance.h b/torch/csrc/api/include/torch/nn/options/distance.h index 654cd6626498d..c9cfc2e0aae2f 100644 --- a/torch/csrc/api/include/torch/nn/options/distance.h +++ b/torch/csrc/api/include/torch/nn/options/distance.h @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `CosineSimilarity` module. /// @@ -67,5 +66,4 @@ namespace functional { using PairwiseDistanceFuncOptions = PairwiseDistanceOptions; } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/dropout.h b/torch/csrc/api/include/torch/nn/options/dropout.h index 7f41f5672382c..865920c599cc3 100644 --- a/torch/csrc/api/include/torch/nn/options/dropout.h +++ b/torch/csrc/api/include/torch/nn/options/dropout.h @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `Dropout` module. /// @@ -126,5 +125,4 @@ struct TORCH_API FeatureAlphaDropoutFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/embedding.h b/torch/csrc/api/include/torch/nn/options/embedding.h index a3d2fdb72f54d..be689f12b3bd9 100644 --- a/torch/csrc/api/include/torch/nn/options/embedding.h +++ b/torch/csrc/api/include/torch/nn/options/embedding.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `Embedding` module. /// @@ -238,5 +237,4 @@ struct TORCH_API EmbeddingBagFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/fold.h b/torch/csrc/api/include/torch/nn/options/fold.h index 21c24bff845ac..958105e159bb6 100644 --- a/torch/csrc/api/include/torch/nn/options/fold.h +++ b/torch/csrc/api/include/torch/nn/options/fold.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `Fold` module. /// @@ -17,8 +16,7 @@ namespace nn { /// ``` struct TORCH_API FoldOptions { FoldOptions(ExpandingArray<2> output_size, ExpandingArray<2> kernel_size) - : output_size_(std::move(output_size)), - kernel_size_(std::move(kernel_size)) {} + : output_size_(output_size), kernel_size_(kernel_size) {} /// describes the spatial shape of the large containing tensor of the sliding /// local blocks. It is useful to resolve the ambiguity when multiple input @@ -63,8 +61,7 @@ using FoldFuncOptions = FoldOptions; /// Unfold model(UnfoldOptions({2, 4}).dilation(2).padding({2, 1}).stride(2)); /// ``` struct TORCH_API UnfoldOptions { - UnfoldOptions(ExpandingArray<2> kernel_size) - : kernel_size_(std::move(kernel_size)) {} + UnfoldOptions(ExpandingArray<2> kernel_size) : kernel_size_(kernel_size) {} /// the size of the sliding blocks TORCH_ARG(ExpandingArray<2>, kernel_size); @@ -95,5 +92,4 @@ namespace functional { using UnfoldFuncOptions = UnfoldOptions; } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/instancenorm.h b/torch/csrc/api/include/torch/nn/options/instancenorm.h index d93e10d0c95a2..2c90a060340b7 100644 --- a/torch/csrc/api/include/torch/nn/options/instancenorm.h +++ b/torch/csrc/api/include/torch/nn/options/instancenorm.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `InstanceNorm` module. struct TORCH_API InstanceNormOptions { @@ -85,5 +84,4 @@ struct TORCH_API InstanceNormFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/linear.h b/torch/csrc/api/include/torch/nn/options/linear.h index 5952d97806b37..6c045910b848c 100644 --- a/torch/csrc/api/include/torch/nn/options/linear.h +++ b/torch/csrc/api/include/torch/nn/options/linear.h @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `Linear` module. /// @@ -91,5 +90,4 @@ struct TORCH_API BilinearOptions { TORCH_ARG(bool, bias) = true; }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h index 5a6e7aa3ab20b..88d954c5e18b5 100644 --- a/torch/csrc/api/include/torch/nn/options/loss.h +++ b/torch/csrc/api/include/torch/nn/options/loss.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `L1Loss` module. /// @@ -798,5 +797,4 @@ namespace functional { using BinaryCrossEntropyWithLogitsFuncOptions = BCEWithLogitsLossOptions; } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/normalization.h b/torch/csrc/api/include/torch/nn/options/normalization.h index 4b6dcd6ffe0c2..6097a2923af2f 100644 --- a/torch/csrc/api/include/torch/nn/options/normalization.h +++ b/torch/csrc/api/include/torch/nn/options/normalization.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `LayerNorm` module. /// @@ -188,5 +187,4 @@ struct TORCH_API GroupNormFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/padding.h b/torch/csrc/api/include/torch/nn/options/padding.h index 8b8312f78ee64..efe71cff29005 100644 --- a/torch/csrc/api/include/torch/nn/options/padding.h +++ b/torch/csrc/api/include/torch/nn/options/padding.h @@ -6,8 +6,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for a `D`-dimensional ReflectionPad module. template @@ -215,5 +214,4 @@ struct TORCH_API PadFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h index 859da98616db1..8de36fb614861 100644 --- a/torch/csrc/api/include/torch/nn/options/pixelshuffle.h +++ b/torch/csrc/api/include/torch/nn/options/pixelshuffle.h @@ -4,8 +4,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `PixelShuffle` module. /// @@ -61,5 +60,4 @@ using PixelShuffleFuncOptions = PixelShuffleOptions; using PixelUnshuffleFuncOptions = PixelUnshuffleOptions; } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/pooling.h b/torch/csrc/api/include/torch/nn/options/pooling.h index 75408890e7cd1..3934f326c8a5d 100644 --- a/torch/csrc/api/include/torch/nn/options/pooling.h +++ b/torch/csrc/api/include/torch/nn/options/pooling.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for a `D`-dimensional avgpool module. template @@ -592,5 +591,4 @@ namespace functional { using LPPool3dFuncOptions = LPPool3dOptions; } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/rnn.h b/torch/csrc/api/include/torch/nn/options/rnn.h index 133acc500276d..44d9b5ab6b617 100644 --- a/torch/csrc/api/include/torch/nn/options/rnn.h +++ b/torch/csrc/api/include/torch/nn/options/rnn.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { namespace detail { @@ -232,5 +231,4 @@ struct TORCH_API GRUCellOptions { TORCH_ARG(bool, bias) = true; }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/transformer.h b/torch/csrc/api/include/torch/nn/options/transformer.h index 41db38fe0757a..a5ecba9d22637 100644 --- a/torch/csrc/api/include/torch/nn/options/transformer.h +++ b/torch/csrc/api/include/torch/nn/options/transformer.h @@ -8,8 +8,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `Transformer` module /// @@ -60,5 +59,4 @@ struct TORCH_API TransformerOptions { TORCH_ARG(AnyModule, custom_decoder); }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/transformercoder.h b/torch/csrc/api/include/torch/nn/options/transformercoder.h index 64f6b998f4c65..343cce605b60f 100644 --- a/torch/csrc/api/include/torch/nn/options/transformercoder.h +++ b/torch/csrc/api/include/torch/nn/options/transformercoder.h @@ -8,8 +8,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `TransformerEncoder` /// @@ -72,5 +71,4 @@ struct TORCH_API TransformerDecoderOptions { TORCH_ARG(AnyModule, norm); }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/transformerlayer.h b/torch/csrc/api/include/torch/nn/options/transformerlayer.h index cbd6af26a1da6..d20f60567b9e2 100644 --- a/torch/csrc/api/include/torch/nn/options/transformerlayer.h +++ b/torch/csrc/api/include/torch/nn/options/transformerlayer.h @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { using activation_t = std::variant< enumtype::kReLU, @@ -68,5 +67,4 @@ struct TORCH_API TransformerDecoderLayerOptions { TORCH_ARG(activation_t, activation) = torch::kReLU; }; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/upsampling.h b/torch/csrc/api/include/torch/nn/options/upsampling.h index df8eb194180ac..a0d6bb57182c4 100644 --- a/torch/csrc/api/include/torch/nn/options/upsampling.h +++ b/torch/csrc/api/include/torch/nn/options/upsampling.h @@ -8,8 +8,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { /// Options for the `Upsample` module. /// @@ -106,5 +105,4 @@ struct TORCH_API InterpolateFuncOptions { } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/options/vision.h b/torch/csrc/api/include/torch/nn/options/vision.h index a5204f0dffb62..bbbcbee92ff30 100644 --- a/torch/csrc/api/include/torch/nn/options/vision.h +++ b/torch/csrc/api/include/torch/nn/options/vision.h @@ -5,9 +5,7 @@ #include #include -namespace torch { -namespace nn { -namespace functional { +namespace torch::nn::functional { /// Options for `torch::nn::functional::grid_sample`. /// @@ -31,6 +29,4 @@ struct TORCH_API GridSampleFuncOptions { TORCH_ARG(std::optional, align_corners) = std::nullopt; }; -} // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn::functional diff --git a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h index 22f8f678a8e74..c5144497c7576 100644 --- a/torch/csrc/api/include/torch/nn/parallel/data_parallel.h +++ b/torch/csrc/api/include/torch/nn/parallel/data_parallel.h @@ -15,14 +15,12 @@ #include #include -#include #include #include #include #include -namespace torch { -namespace nn { +namespace torch::nn { namespace { @@ -62,8 +60,9 @@ namespace { struct ReduceAdd : public autograd::Node { explicit ReduceAdd(const at::Device& destination_device) : destination_device_(destination_device){}; - ~ReduceAdd() override {} + ~ReduceAdd() override = default; + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) autograd::variable_list apply(autograd::variable_list&& inputs) override { TORCH_CHECK( !torch::autograd::compute_requires_grad(inputs), @@ -293,5 +292,4 @@ Tensor data_parallel( } } // namespace parallel -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/include/torch/nn/pimpl.h b/torch/csrc/api/include/torch/nn/pimpl.h index 9f0fe629baae9..3c1206e4edb82 100644 --- a/torch/csrc/api/include/torch/nn/pimpl.h +++ b/torch/csrc/api/include/torch/nn/pimpl.h @@ -42,7 +42,7 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator { /// actually used. ModuleHolder() : impl_(default_construct()) { static_assert( - std::is_default_constructible::value, + std::is_default_constructible_v, "You are trying to default construct a module which has " "no default constructor. Use = nullptr to give it the empty state " "(e.g. `Linear linear = nullptr;` instead of `Linear linear;`)."); @@ -182,7 +182,7 @@ serialize::InputArchive& operator>>( #ifdef __CUDACC__ #define TORCH_UNUSED_EXCEPT_CUDA #else -#define TORCH_UNUSED_EXCEPT_CUDA C10_UNUSED +#define TORCH_UNUSED_EXCEPT_CUDA [[maybe_unused]] #endif /// Defines a class `Name` which inherits from `nn::ModuleHolder` to provide a diff --git a/torch/csrc/api/include/torch/nn/utils/clip_grad.h b/torch/csrc/api/include/torch/nn/utils/clip_grad.h index 8a2a569c03335..a5fbbcbd854cd 100644 --- a/torch/csrc/api/include/torch/nn/utils/clip_grad.h +++ b/torch/csrc/api/include/torch/nn/utils/clip_grad.h @@ -2,11 +2,11 @@ #include +#include #include +#include -namespace torch { -namespace nn { -namespace utils { +namespace torch::nn::utils { // Clips gradient norm of a vector of Tensors. // See @@ -109,8 +109,7 @@ inline double clip_grad_norm_( double norm_type = 2.0, bool error_if_nonfinite = false) { std::vector params = {std::move(parameter)}; - return clip_grad_norm_( - std::move(params), max_norm, norm_type, error_if_nonfinite); + return clip_grad_norm_(params, max_norm, norm_type, error_if_nonfinite); } // Clips gradient of an iterable of parameters at specified value. @@ -139,9 +138,7 @@ inline void clip_grad_value_( // single Tensor. inline void clip_grad_value_(Tensor parameter, double clip_value) { std::vector params = {std::move(parameter)}; - clip_grad_value_(std::move(params), clip_value); + clip_grad_value_(params, clip_value); } -} // namespace utils -} // namespace nn -} // namespace torch +} // namespace torch::nn::utils diff --git a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h index b8bfee33473f2..bb79a743902af 100644 --- a/torch/csrc/api/include/torch/nn/utils/convert_parameters.h +++ b/torch/csrc/api/include/torch/nn/utils/convert_parameters.h @@ -3,9 +3,7 @@ #include #include -namespace torch { -namespace nn { -namespace utils { +namespace torch::nn::utils { // This helper function is to check if the parameters are located // in the same device. Currently, the conversion between model parameters @@ -77,6 +75,4 @@ inline void vector_to_parameters( } } -} // namespace utils -} // namespace nn -} // namespace torch +} // namespace torch::nn::utils diff --git a/torch/csrc/api/include/torch/nn/utils/rnn.h b/torch/csrc/api/include/torch/nn/utils/rnn.h index 6f2a68984c80a..53c378c028972 100644 --- a/torch/csrc/api/include/torch/nn/utils/rnn.h +++ b/torch/csrc/api/include/torch/nn/utils/rnn.h @@ -5,10 +5,7 @@ #include -namespace torch { -namespace nn { -namespace utils { -namespace rnn { +namespace torch::nn::utils::rnn { inline Tensor invert_permutation(const Tensor& permutation) { if (!permutation.defined()) { @@ -244,7 +241,7 @@ inline PackedSequence pack_padded_sequence( /// Tuple of Tensor containing the padded sequence, and a Tensor /// containing the list of lengths of each sequence in the batch. inline std::tuple pad_packed_sequence( - PackedSequence sequence, + const PackedSequence& sequence, bool batch_first = false, double padding_value = 0.0, std::optional total_length = torch::nullopt) { @@ -339,7 +336,7 @@ inline PackedSequence pack_sequence( bool enforce_sorted = true) { Tensor lengths = torch::empty({(int64_t)sequences.size()}, kInt64); for (const auto i : c10::irange(sequences.size())) { - lengths[i] = sequences[i].size(0); + lengths[static_cast(i)] = sequences[i].size(0); } return pack_padded_sequence( at::pad_sequence(sequences), @@ -348,7 +345,4 @@ inline PackedSequence pack_sequence( /*enforce_sorted=*/enforce_sorted); } -} // namespace rnn -} // namespace utils -} // namespace nn -} // namespace torch +} // namespace torch::nn::utils::rnn diff --git a/torch/csrc/api/include/torch/optim/adagrad.h b/torch/csrc/api/include/torch/optim/adagrad.h index 4b2ff3c676b3d..80e85dc0dfcd1 100644 --- a/torch/csrc/api/include/torch/optim/adagrad.h +++ b/torch/csrc/api/include/torch/optim/adagrad.h @@ -9,15 +9,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace optim { +namespace torch::optim { struct TORCH_API AdagradOptions : public OptimizerCloneableOptions { @@ -59,11 +56,9 @@ struct TORCH_API AdagradParamState class TORCH_API Adagrad : public Optimizer { public: explicit Adagrad( - std::vector param_groups, + const std::vector& param_groups, AdagradOptions defaults = {}) - : Optimizer( - std::move(param_groups), - std::make_unique(defaults)) { + : Optimizer(param_groups, std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); TORCH_CHECK( defaults.lr_decay() >= 0, @@ -93,7 +88,8 @@ class TORCH_API Adagrad : public Optimizer { } explicit Adagrad(std::vector params, AdagradOptions defaults = {}) - : Adagrad({OptimizerParamGroup(std::move(params))}, defaults) {} + : Adagrad({OptimizerParamGroup(std::move(params))}, std::move(defaults)) { + } torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; @@ -105,5 +101,4 @@ class TORCH_API Adagrad : public Optimizer { _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adagrad); } }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/adam.h b/torch/csrc/api/include/torch/optim/adam.h index 6e5e02d82c544..6c06e4030cf4c 100644 --- a/torch/csrc/api/include/torch/optim/adam.h +++ b/torch/csrc/api/include/torch/optim/adam.h @@ -7,15 +7,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace optim { +namespace torch::optim { struct TORCH_API AdamOptions : public OptimizerCloneableOptions { AdamOptions(double lr = 1e-3); @@ -54,11 +51,9 @@ struct TORCH_API AdamParamState class TORCH_API Adam : public Optimizer { public: explicit Adam( - std::vector param_groups, + const std::vector& param_groups, AdamOptions defaults = {}) - : Optimizer( - std::move(param_groups), - std::make_unique(defaults)) { + : Optimizer(param_groups, std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); auto betas = defaults.betas(); @@ -76,7 +71,7 @@ class TORCH_API Adam : public Optimizer { defaults.weight_decay()); } explicit Adam(std::vector params, AdamOptions defaults = {}) - : Adam({OptimizerParamGroup(std::move(params))}, defaults) {} + : Adam({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; @@ -88,5 +83,4 @@ class TORCH_API Adam : public Optimizer { _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(Adam); } }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/adamw.h b/torch/csrc/api/include/torch/optim/adamw.h index a63d7fc32d455..d656921a719d0 100644 --- a/torch/csrc/api/include/torch/optim/adamw.h +++ b/torch/csrc/api/include/torch/optim/adamw.h @@ -7,15 +7,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace optim { +namespace torch::optim { struct TORCH_API AdamWOptions : public OptimizerCloneableOptions { AdamWOptions(double lr = 1e-3); @@ -54,11 +51,9 @@ struct TORCH_API AdamWParamState class TORCH_API AdamW : public Optimizer { public: explicit AdamW( - std::vector param_groups, + const std::vector& param_groups, AdamWOptions defaults = {}) - : Optimizer( - std::move(param_groups), - std::make_unique(defaults)) { + : Optimizer(param_groups, std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); auto betas = defaults.betas(); @@ -76,7 +71,7 @@ class TORCH_API AdamW : public Optimizer { defaults.weight_decay()); } explicit AdamW(std::vector params, AdamWOptions defaults = {}) - : AdamW({OptimizerParamGroup(std::move(params))}, defaults) {} + : AdamW({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; @@ -88,5 +83,4 @@ class TORCH_API AdamW : public Optimizer { _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(AdamW); } }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/lbfgs.h b/torch/csrc/api/include/torch/optim/lbfgs.h index 0832afff5f8f2..3d5f1832cf600 100644 --- a/torch/csrc/api/include/torch/optim/lbfgs.h +++ b/torch/csrc/api/include/torch/optim/lbfgs.h @@ -8,10 +8,10 @@ #include #include #include +#include #include -namespace torch { -namespace optim { +namespace torch::optim { struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions { LBFGSOptions(double lr = 1); @@ -58,11 +58,9 @@ struct TORCH_API LBFGSParamState class TORCH_API LBFGS : public Optimizer { public: explicit LBFGS( - std::vector param_groups, + const std::vector& param_groups, LBFGSOptions defaults = {}) - : Optimizer( - std::move(param_groups), - std::make_unique(defaults)) { + : Optimizer(param_groups, std::make_unique(defaults)) { TORCH_CHECK( param_groups_.size() == 1, "LBFGS doesn't support per-parameter options (parameter groups)"); @@ -70,12 +68,12 @@ class TORCH_API LBFGS : public Optimizer { auto max_eval_val = (defaults.max_iter() * 5) / 4; static_cast(param_groups_[0].options()) .max_eval(max_eval_val); - static_cast(*defaults_.get()).max_eval(max_eval_val); + static_cast(*defaults_).max_eval(max_eval_val); } _numel_cache = std::nullopt; } explicit LBFGS(std::vector params, LBFGSOptions defaults = {}) - : LBFGS({OptimizerParamGroup(std::move(params))}, defaults) {} + : LBFGS({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} Tensor step(LossClosure closure) override; void save(serialize::OutputArchive& archive) const override; @@ -99,5 +97,4 @@ class TORCH_API LBFGS : public Optimizer { _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(LBFGS); } }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/optimizer.h b/torch/csrc/api/include/torch/optim/optimizer.h index f6599248244a2..e6115cb2f6d78 100644 --- a/torch/csrc/api/include/torch/optim/optimizer.h +++ b/torch/csrc/api/include/torch/optim/optimizer.h @@ -29,8 +29,7 @@ class InputArchive; } // namespace torch #endif // DOXYGEN_SHOULD_SKIP_THIS -namespace torch { -namespace optim { +namespace torch::optim { class TORCH_API OptimizerParamState { public: @@ -115,7 +114,7 @@ class TORCH_API Optimizer { Optimizer(Optimizer&& optimizer) = default; explicit Optimizer( - std::vector param_groups, + const std::vector& param_groups, std::unique_ptr defaults) : defaults_(std::move(defaults)) { for (const auto& param_group : param_groups) { @@ -215,5 +214,4 @@ TORCH_API serialize::InputArchive& operator>>( serialize::InputArchive& archive, Optimizer& optimizer); -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/rmsprop.h b/torch/csrc/api/include/torch/optim/rmsprop.h index 69a2e27993d5b..7b6b9dea5649f 100644 --- a/torch/csrc/api/include/torch/optim/rmsprop.h +++ b/torch/csrc/api/include/torch/optim/rmsprop.h @@ -9,17 +9,15 @@ #include #include #include +#include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace optim { +namespace torch::optim { struct TORCH_API RMSpropOptions : public OptimizerCloneableOptions { @@ -59,11 +57,9 @@ struct TORCH_API RMSpropParamState class TORCH_API RMSprop : public Optimizer { public: explicit RMSprop( - std::vector param_groups, + const std::vector& param_groups, RMSpropOptions defaults = {}) - : Optimizer( - std::move(param_groups), - std::make_unique(defaults)) { + : Optimizer(param_groups, std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); TORCH_CHECK(defaults.eps() >= 0, "Invalid epsilon value: ", defaults.eps()); TORCH_CHECK( @@ -79,7 +75,8 @@ class TORCH_API RMSprop : public Optimizer { } explicit RMSprop(std::vector params, RMSpropOptions defaults = {}) - : RMSprop({OptimizerParamGroup(std::move(params))}, defaults) {} + : RMSprop({OptimizerParamGroup(std::move(params))}, std::move(defaults)) { + } torch::Tensor step(LossClosure closure = nullptr) override; void save(serialize::OutputArchive& archive) const override; @@ -91,5 +88,4 @@ class TORCH_API RMSprop : public Optimizer { _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(RMSprop); } }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h b/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h index 26d324fbecce1..fdab69d3615c4 100644 --- a/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h +++ b/torch/csrc/api/include/torch/optim/schedulers/lr_scheduler.h @@ -4,8 +4,7 @@ #include -namespace torch { -namespace optim { +namespace torch::optim { class TORCH_API LRScheduler { public: @@ -35,5 +34,4 @@ class TORCH_API LRScheduler { torch::optim::Optimizer& optimizer_; }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h b/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h index ae8892ff4fda6..17c89816d79d3 100644 --- a/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h +++ b/torch/csrc/api/include/torch/optim/schedulers/reduce_on_plateau_scheduler.h @@ -5,14 +5,9 @@ #include -#include - #include -#include - -namespace torch { -namespace optim { +namespace torch::optim { class TORCH_API ReduceLROnPlateauScheduler { public: @@ -37,28 +32,28 @@ class TORCH_API ReduceLROnPlateauScheduler { private: void reset(); void reduce_lr(int epoch); - bool in_cooldown(); + bool in_cooldown() const; bool is_better(float a); void init_is_better( SchedulerMode mode, double threshold, ThresholdMode threshold_mode); + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) Optimizer& optimizer; - SchedulerMode mode; - float mode_worse; + SchedulerMode mode{}; + float mode_worse{}; float factor; int patience; - double threshold; - ThresholdMode threshold_mode; - int cooldown; - int cooldown_counter; + double threshold{}; + ThresholdMode threshold_mode{}; + int cooldown{}; + int cooldown_counter{}; std::vector min_lrs; double eps; - float best; + float best{}; bool verbose; - int last_epoch; - int num_bad_epochs; + int last_epoch{}; + int num_bad_epochs{}; }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/schedulers/step_lr.h b/torch/csrc/api/include/torch/optim/schedulers/step_lr.h index 289bb4bd84e54..f46b274f518bd 100644 --- a/torch/csrc/api/include/torch/optim/schedulers/step_lr.h +++ b/torch/csrc/api/include/torch/optim/schedulers/step_lr.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace optim { +namespace torch::optim { class TORCH_API StepLR : public LRScheduler { public: @@ -18,5 +17,4 @@ class TORCH_API StepLR : public LRScheduler { const unsigned step_size_; const double gamma_; }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/serialize.h b/torch/csrc/api/include/torch/optim/serialize.h index 7c34450999b62..50f66782f2763 100644 --- a/torch/csrc/api/include/torch/optim/serialize.h +++ b/torch/csrc/api/include/torch/optim/serialize.h @@ -10,8 +10,7 @@ #include #include -namespace torch { -namespace optim { +namespace torch::optim { namespace detail { // Utility function to save state template @@ -24,7 +23,7 @@ void serialize( std::string tensorimpl_key = std::to_string(reinterpret_cast(item.first)); const DerivedOptimizerParamState& curr_state = - static_cast(*(item.second.get())); + static_cast(*(item.second)); curr_state.serialize(param_state_archive); archive.write(tensorimpl_key, param_state_archive); } @@ -41,6 +40,7 @@ void serialize( archive.read(tensorimpl_key, param_state_archive); DerivedOptimizerParamState param_state; param_state.serialize(param_state_archive); + // NOLINTNEXTLINE(performance-no-int-to-ptr) state[reinterpret_cast(std::stoull(tensorimpl_key))] = std::make_unique(param_state); } @@ -193,6 +193,7 @@ void serialize(serialize::InputArchive& archive, Optimizer& optimizer) { for (const auto idx : c10::irange(params.size())) { auto param_group_old_key = + // NOLINTNEXTLINE(performance-no-int-to-ptr) reinterpret_cast(std::stoull(param_group_old_keys[idx])); if (saved_state.find(param_group_old_key) != saved_state.end()) { optimizer.state()[params[idx].unsafeGetTensorImpl()] = @@ -282,16 +283,16 @@ std::deque list_to_deque(const c10::List& list) { archive.write(#name, ivalue); \ } -#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name) \ - { \ - c10::IValue ivalue; \ - bool exists = archive.try_read(#name, ivalue); \ - if (exists) { \ - name(ivalue.to()); \ - } else { \ - bool is_tensor_type = std::is_base_of::value; \ - TORCH_INTERNAL_ASSERT(is_tensor_type); \ - } \ +#define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG(T, name) \ + { \ + c10::IValue ivalue; \ + bool exists = archive.try_read(#name, ivalue); \ + if (exists) { \ + name(ivalue.to()); \ + } else { \ + constexpr bool is_tensor_type = std::is_base_of_v; \ + TORCH_INTERNAL_ASSERT(is_tensor_type); \ + } \ } #define _TORCH_OPTIM_DESERIALIZE_TORCH_ARG_OPTIONAL(T, name) \ @@ -311,5 +312,4 @@ std::deque list_to_deque(const c10::List& list) { name(list_to_deque(list)); \ } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/optim/sgd.h b/torch/csrc/api/include/torch/optim/sgd.h index 85e9aba7ba48f..34896fb15653d 100644 --- a/torch/csrc/api/include/torch/optim/sgd.h +++ b/torch/csrc/api/include/torch/optim/sgd.h @@ -10,15 +10,12 @@ #include #include -namespace torch { -namespace serialize { +namespace torch::serialize { class OutputArchive; class InputArchive; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize -namespace torch { -namespace optim { +namespace torch::optim { struct TORCH_API SGDOptions : public OptimizerCloneableOptions { SGDOptions(double lr); @@ -53,11 +50,9 @@ struct TORCH_API SGDParamState class TORCH_API SGD : public Optimizer { public: explicit SGD( - std::vector param_groups, + const std::vector& param_groups, SGDOptions defaults) - : Optimizer( - std::move(param_groups), - std::make_unique(defaults)) { + : Optimizer(param_groups, std::make_unique(defaults)) { TORCH_CHECK(defaults.lr() >= 0, "Invalid learning rate: ", defaults.lr()); TORCH_CHECK( defaults.momentum() >= 0, @@ -74,7 +69,7 @@ class TORCH_API SGD : public Optimizer { } explicit SGD(std::vector params, SGDOptions defaults) - : SGD({OptimizerParamGroup(std::move(params))}, defaults) {} + : SGD({OptimizerParamGroup(std::move(params))}, std::move(defaults)) {} torch::Tensor step(LossClosure closure = nullptr) override; @@ -87,5 +82,4 @@ class TORCH_API SGD : public Optimizer { _TORCH_OPTIM_SERIALIZE_WITH_TEMPLATE_ARG(SGD); } }; -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/include/torch/ordered_dict.h b/torch/csrc/api/include/torch/ordered_dict.h index 31a2ab65131c1..ab8bf851263a1 100644 --- a/torch/csrc/api/include/torch/ordered_dict.h +++ b/torch/csrc/api/include/torch/ordered_dict.h @@ -349,7 +349,7 @@ Value& OrderedDict::operator[](const Key& key) { if (auto* value = find(key)) { return *value; } - AT_ERROR(key_description_, " '", key, "' is not defined"); + TORCH_CHECK(false, key_description_, " '", key, "' is not defined"); } template @@ -357,7 +357,7 @@ const Value& OrderedDict::operator[](const Key& key) const { if (auto* value = find(key)) { return *value; } - AT_ERROR(key_description_, " '", key, "' is not defined"); + TORCH_CHECK(false, key_description_, " '", key, "' is not defined"); } template diff --git a/torch/csrc/api/include/torch/python.h b/torch/csrc/api/include/torch/python.h index cc9d6a51a6de4..1d5ec77df13a3 100644 --- a/torch/csrc/api/include/torch/python.h +++ b/torch/csrc/api/include/torch/python.h @@ -21,8 +21,7 @@ #include #include -namespace torch { -namespace python { +namespace torch::python { namespace detail { inline Device py_object_to_device(py::object object) { PyObject* obj = object.ptr(); @@ -83,7 +82,9 @@ void bind_cpp_module_wrapper( // which replaces its methods with those of the C++ module. wrapper_class.attr("__init__") = py::cpp_function( [cpp_module, cpp_class]( - py::object self, py::args args, py::kwargs kwargs) { + const py::object& self, + const py::args& args, + const py::kwargs& kwargs) { cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs)); }, py::is_method(wrapper_class)); @@ -141,7 +142,7 @@ py::class_ add_module_bindings( "_modules", [](ModuleType& module) { return module.named_children(); }) .def("modules", [](ModuleType& module) { return module.modules(); }) .def("named_modules", - [](ModuleType& module, py::object /* unused */, std::string prefix, bool remove_duplicate /* unused */) { + [](ModuleType& module, const py::object& /* unused */, std::string prefix, bool remove_duplicate /* unused */) { return module.named_modules(std::move(prefix)); }, py::arg("memo") = py::none(), @@ -163,8 +164,8 @@ py::class_ add_module_bindings( py::arg("non_blocking") = false) .def("to", [](ModuleType& module, - py::object device, - py::object dtype, + const py::object& device, + const py::object& dtype, bool non_blocking) { if (device.is_none()) { module.to(detail::py_object_to_dtype(dtype), non_blocking); @@ -257,5 +258,4 @@ detail::PyModuleClass bind_module( .def("forward", &ModuleType::forward) .def("__call__", &ModuleType::forward); } -} // namespace python -} // namespace torch +} // namespace torch::python diff --git a/torch/csrc/api/include/torch/python/init.h b/torch/csrc/api/include/torch/python/init.h index a52857985af3a..03edca27f4705 100644 --- a/torch/csrc/api/include/torch/python/init.h +++ b/torch/csrc/api/include/torch/python/init.h @@ -2,9 +2,7 @@ #include -namespace torch { -namespace python { +namespace torch::python { /// Initializes Python bindings for the C++ frontend. void init_bindings(PyObject* module); -} // namespace python -} // namespace torch +} // namespace torch::python diff --git a/torch/csrc/api/include/torch/serialize/input-archive.h b/torch/csrc/api/include/torch/serialize/input-archive.h index 3650cfcfea23f..f399ac63d5e7e 100644 --- a/torch/csrc/api/include/torch/serialize/input-archive.h +++ b/torch/csrc/api/include/torch/serialize/input-archive.h @@ -22,8 +22,7 @@ struct Module; } // namespace jit } // namespace torch -namespace torch { -namespace serialize { +namespace torch::serialize { /// A recursive representation of tensors that can be deserialized from a file /// or stream. In most cases, users should not have to interact with this class, @@ -113,5 +112,4 @@ class TORCH_API InputArchive final { jit::Module module_; std::string hierarchy_prefix_; }; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize diff --git a/torch/csrc/api/include/torch/serialize/output-archive.h b/torch/csrc/api/include/torch/serialize/output-archive.h index 12e0f54971cb3..29052bfe6c687 100644 --- a/torch/csrc/api/include/torch/serialize/output-archive.h +++ b/torch/csrc/api/include/torch/serialize/output-archive.h @@ -19,8 +19,7 @@ struct Module; } // namespace jit } // namespace torch -namespace torch { -namespace serialize { +namespace torch::serialize { class TORCH_API OutputArchive final { public: explicit OutputArchive(std::shared_ptr cu); @@ -78,5 +77,4 @@ class TORCH_API OutputArchive final { std::shared_ptr cu_; jit::Module module_; }; -} // namespace serialize -} // namespace torch +} // namespace torch::serialize diff --git a/torch/csrc/api/include/torch/sparse.h b/torch/csrc/api/include/torch/sparse.h index a30e74477e365..753a07de8a6f0 100644 --- a/torch/csrc/api/include/torch/sparse.h +++ b/torch/csrc/api/include/torch/sparse.h @@ -1,7 +1,3 @@ #pragma once #include - -namespace torch { -namespace sparse {} -} // namespace torch diff --git a/torch/csrc/api/include/torch/special.h b/torch/csrc/api/include/torch/special.h index d8346e1aa1d8c..7ab96c123f4a2 100644 --- a/torch/csrc/api/include/torch/special.h +++ b/torch/csrc/api/include/torch/special.h @@ -3,8 +3,7 @@ #include #include -namespace torch { -namespace special { +namespace torch::special { /// Computes the natural logarithm of the absolute value of the gamma function /// See https://pytorch.org/docs/main/special.html#torch.special.gammaln. @@ -1401,5 +1400,4 @@ inline Tensor spherical_bessel_j0(const Tensor& x) { inline Tensor& spherical_bessel_j0_out(Tensor& y, const Tensor& x) { return torch::special_spherical_bessel_j0_out(y, x); } -} // namespace special -} // namespace torch +} // namespace torch::special diff --git a/torch/csrc/api/include/torch/types.h b/torch/csrc/api/include/torch/types.h index 850100ea69a06..3e9d0166071b0 100644 --- a/torch/csrc/api/include/torch/types.h +++ b/torch/csrc/api/include/torch/types.h @@ -38,8 +38,8 @@ namespace torch { // the `func()` function defined in `at::` namespace is always hidden. using namespace at; // NOLINT -using std::nullopt; -using std::optional; +using std::nullopt; // NOLINT +using std::optional; // NOLINT using Dtype = at::ScalarType; diff --git a/torch/csrc/api/include/torch/utils.h b/torch/csrc/api/include/torch/utils.h index 004a0064636ef..a517043fa3ff8 100644 --- a/torch/csrc/api/include/torch/utils.h +++ b/torch/csrc/api/include/torch/utils.h @@ -5,8 +5,8 @@ #include #include #include -#include +// NOLINTBEGIN(misc-unused-using-decls) namespace torch { /// A RAII, thread-local guard that disabled gradient calculation. @@ -89,7 +89,7 @@ using at::get_num_interop_threads; using at::set_num_interop_threads; // Returns true if both t1, t2 are undefined or both are defined and equal -inline bool equal_if_defined(Tensor t1, Tensor t2) { +inline bool equal_if_defined(const Tensor& t1, const Tensor& t2) { return ( (!t1.defined() && !t2.defined()) || (t1.defined() && t2.defined() && torch::equal(t1, t2))); @@ -114,3 +114,4 @@ using at::RecordFunctionGuard; using at::removeCallback; } // namespace torch +// NOLINTEND(misc-unused-using-decls) diff --git a/torch/csrc/api/src/cuda.cpp b/torch/csrc/api/src/cuda.cpp index eafbc15eed6d2..5d7624a997641 100644 --- a/torch/csrc/api/src/cuda.cpp +++ b/torch/csrc/api/src/cuda.cpp @@ -41,7 +41,8 @@ void manual_seed(uint64_t seed) { void manual_seed_all(uint64_t seed) { auto num_gpu = device_count(); for (const auto i : c10::irange(num_gpu)) { - auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(i); + auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator( + static_cast(i)); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); @@ -52,12 +53,13 @@ void manual_seed_all(uint64_t seed) { void synchronize(int64_t device_index) { TORCH_CHECK(is_available(), "No CUDA GPUs are available"); - int64_t num_gpus = cuda::device_count(); + auto num_gpus = cuda::device_count(); TORCH_CHECK( - device_index == -1 || device_index < num_gpus, + device_index < 0 || static_cast(device_index) < num_gpus, "Device index out of range: ", device_index); - at::detail::getCUDAHooks().deviceSynchronize(device_index); + at::detail::getCUDAHooks().deviceSynchronize( + static_cast(device_index)); } } // namespace torch::cuda diff --git a/torch/csrc/api/src/data/datasets/mnist.cpp b/torch/csrc/api/src/data/datasets/mnist.cpp index ff9f5c351e854..3a862257b3639 100644 --- a/torch/csrc/api/src/data/datasets/mnist.cpp +++ b/torch/csrc/api/src/data/datasets/mnist.cpp @@ -9,9 +9,7 @@ #include #include -namespace torch { -namespace data { -namespace datasets { +namespace torch::data::datasets { namespace { constexpr uint32_t kTrainSize = 60000; constexpr uint32_t kTestSize = 10000; @@ -36,18 +34,20 @@ constexpr uint32_t flip_endianness(uint32_t value) { uint32_t read_int32(std::ifstream& stream) { static const bool is_little_endian = check_is_little_endian(); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t value; + uint32_t value = 0; AT_ASSERT(stream.read(reinterpret_cast(&value), sizeof value)); return is_little_endian ? flip_endianness(value) : value; } uint32_t expect_int32(std::ifstream& stream, uint32_t expected) { const auto value = read_int32(stream); - // clang-format off - TORCH_CHECK(value == expected, - "Expected to read number ", expected, " but found ", value, " instead"); - // clang-format on + TORCH_CHECK( + value == expected, + "Expected to read number ", + expected, + " but found ", + value, + " instead"); return value; } @@ -101,14 +101,15 @@ MNIST::MNIST(const std::string& root, Mode mode) targets_(read_targets(root, mode == Mode::kTrain)) {} Example<> MNIST::get(size_t index) { - return {images_[index], targets_[index]}; + return { + images_[static_cast(index)], + targets_[static_cast(index)]}; } std::optional MNIST::size() const { return images_.size(0); } -// NOLINTNEXTLINE(bugprone-exception-escape) bool MNIST::is_train() const noexcept { return images_.size(0) == kTrainSize; } @@ -121,6 +122,4 @@ const Tensor& MNIST::targets() const { return targets_; } -} // namespace datasets -} // namespace data -} // namespace torch +} // namespace torch::data::datasets diff --git a/torch/csrc/api/src/data/samplers/distributed.cpp b/torch/csrc/api/src/data/samplers/distributed.cpp index eaae80bf06954..9f240570f75ed 100644 --- a/torch/csrc/api/src/data/samplers/distributed.cpp +++ b/torch/csrc/api/src/data/samplers/distributed.cpp @@ -8,9 +8,7 @@ #include #include -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { DistributedRandomSampler::DistributedRandomSampler( size_t size, @@ -22,7 +20,7 @@ DistributedRandomSampler::DistributedRandomSampler( end_index_(0), sample_index_(0) { // shuffle first time. - reset(size_); + DistributedRandomSampler::reset(size_); } std::optional> DistributedRandomSampler::next( @@ -37,7 +35,9 @@ std::optional> DistributedRandomSampler::next( } auto iter = all_indices_.begin(); - std::vector res(iter + sample_index_, iter + end); + std::vector res( + iter + static_cast(sample_index_), + iter + static_cast(end)); sample_index_ = end; return res; } @@ -162,6 +162,4 @@ size_t DistributedSequentialSampler::index() const noexcept { return sample_index_; } -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/src/data/samplers/random.cpp b/torch/csrc/api/src/data/samplers/random.cpp index 10c478aa38da5..dba9af5c49ec4 100644 --- a/torch/csrc/api/src/data/samplers/random.cpp +++ b/torch/csrc/api/src/data/samplers/random.cpp @@ -6,9 +6,7 @@ #include #include -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { RandomSampler::RandomSampler(int64_t size, Dtype index_dtype) : indices_(torch::randperm(size, index_dtype)) {} @@ -18,7 +16,7 @@ void RandomSampler::reset(std::optional new_size) { // This allocates a new chunk of memory every time (just FYI). It should be // amortized over the entire epoch hopefully. const auto size = new_size.value_or(static_cast(indices_.numel())); - indices_ = torch::randperm(size, indices_.options()); + indices_ = torch::randperm(static_cast(size), indices_.options()); index_ = 0; } @@ -38,14 +36,14 @@ optional> RandomSampler::next(size_t batch_size) { slice = slice.to(torch::kInt64); const auto* data = slice.const_data_ptr(); std::copy(data, data + index_batch.size(), index_batch.begin()); - index_ += index_batch.size(); + index_ += static_cast(index_batch.size()); return index_batch; } void RandomSampler::save(serialize::OutputArchive& archive) const { archive.write( "index", - torch::tensor(static_cast(index_), torch::kInt64), + torch::tensor(index_, torch::kInt64), /*is_buffer=*/true); archive.write( "indices", @@ -70,6 +68,4 @@ size_t RandomSampler::index() const noexcept { return index_; } -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/src/data/samplers/sequential.cpp b/torch/csrc/api/src/data/samplers/sequential.cpp index 64cf0f5e0a6ba..cd906e9c866bc 100644 --- a/torch/csrc/api/src/data/samplers/sequential.cpp +++ b/torch/csrc/api/src/data/samplers/sequential.cpp @@ -6,9 +6,7 @@ #include #include -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { SequentialSampler::SequentialSampler(size_t size) : size_(size) {} void SequentialSampler::reset(std::optional new_size) { @@ -50,6 +48,4 @@ size_t SequentialSampler::index() const noexcept { return index_; } -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/src/data/samplers/stream.cpp b/torch/csrc/api/src/data/samplers/stream.cpp index bce63f13eae56..2281e8b9329d8 100644 --- a/torch/csrc/api/src/data/samplers/stream.cpp +++ b/torch/csrc/api/src/data/samplers/stream.cpp @@ -6,9 +6,7 @@ #include -namespace torch { -namespace data { -namespace samplers { +namespace torch::data::samplers { BatchSize::BatchSize(size_t size) : size_(size) {} size_t BatchSize::size() const noexcept { @@ -56,6 +54,4 @@ void StreamSampler::load(serialize::InputArchive& archive) { examples_retrieved_so_far_ = tensor.item(); } -} // namespace samplers -} // namespace data -} // namespace torch +} // namespace torch::data::samplers diff --git a/torch/csrc/api/src/nn/init.cpp b/torch/csrc/api/src/nn/init.cpp index 8919f0fdfcf62..d4e4025f62e54 100644 --- a/torch/csrc/api/src/nn/init.cpp +++ b/torch/csrc/api/src/nn/init.cpp @@ -1,6 +1,5 @@ #include -#include #include #include @@ -55,11 +54,11 @@ double calculate_kaiming_std( double calculate_gain(NonlinearityType nonlinearity, double param) { if (std::holds_alternative(nonlinearity)) { - return 5.0 / 3.0; // NOLINT + return 5.0 / 3.0; } else if (std::holds_alternative(nonlinearity)) { - return std::sqrt(2.0); // NOLINT + return std::sqrt(2.0); } else if (std::holds_alternative(nonlinearity)) { - return std::sqrt(2.0 / (1 + pow(param, 2))); // NOLINT + return std::sqrt(2.0 / (1 + pow(param, 2))); } return 1.0; @@ -83,6 +82,7 @@ Tensor dirac_(Tensor tensor) { tensor.zero_(); for (const auto d : c10::irange(min_dim)) { + // NOLINTNEXTLINE(bugprone-switch-missing-default-case) switch (tensor.ndimension()) { case 3: // Temporal convolution tensor[d][d][sizes[2] / 2] = 1; @@ -134,7 +134,7 @@ Tensor orthogonal_(Tensor tensor, double gain) { } // Compute the qr factorization - auto [q, r] = torch::linalg::qr(flattened); + auto [q, r] = torch::linalg_qr(flattened); // Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf auto d = torch::diag(r, 0); auto ph = d.sign(); @@ -158,7 +158,7 @@ Tensor sparse_(Tensor tensor, double sparsity, double std) { const auto rows = tensor.size(0); const auto columns = tensor.size(1); - const int64_t num_zeros = std::ceil(sparsity * rows); + const int64_t num_zeros = std::ceil(sparsity * static_cast(rows)); tensor.normal_(0, std); for (const auto column : c10::irange(columns)) { auto row_indices = torch::randperm(rows, tensor.options().dtype(kLong)); @@ -207,16 +207,16 @@ Tensor xavier_normal_(Tensor tensor, double gain) { NoGradGuard guard; Fan fan(tensor); - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out)); + const auto std = + gain * std::sqrt(2.0 / static_cast(fan.in + fan.out)); return tensor.normal_(0, std); } Tensor xavier_uniform_(Tensor tensor, double gain) { NoGradGuard guard; Fan fan(tensor); - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - const auto std = gain * std::sqrt(2.0 / (fan.in + fan.out)); + const auto std = + gain * std::sqrt(2.0 / static_cast(fan.in + fan.out)); // Calculate uniform bounds from standard deviation with const auto a = std::sqrt(3.0) * std; return tensor.uniform_(-a, a); @@ -243,7 +243,7 @@ std::tuple _calculate_fan_in_and_fan_out( } else { const auto num_input_fmaps = tensor.size(1); const auto num_output_fmaps = tensor.size(0); - auto receptive_field_size = 1; + int64_t receptive_field_size = 1; if (tensor.dim() > 2) { receptive_field_size = tensor[0][0].numel(); } diff --git a/torch/csrc/api/src/nn/module.cpp b/torch/csrc/api/src/nn/module.cpp index 78a31eea36ac0..563ed4789cb12 100644 --- a/torch/csrc/api/src/nn/module.cpp +++ b/torch/csrc/api/src/nn/module.cpp @@ -34,6 +34,7 @@ Module::Module() : parameters_("Parameter"), buffers_("Buffer"), children_("Submodule") {} Module::Module(std::string name) : Module() { + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) name_ = std::move(name); } @@ -64,7 +65,8 @@ const std::string& Module::name() const noexcept { std::shared_ptr Module::clone( const std::optional& device) const { - AT_ERROR( + TORCH_CHECK( + false, "clone() has not been implemented for ", name(), ". Subclass torch::nn::Cloneable<", @@ -378,7 +380,8 @@ std::shared_ptr Module::shared_from_this_checked() const { try { ptr = shared_from_this(); } catch (const std::bad_weak_ptr&) { - AT_ERROR( + TORCH_CHECK( + false, "It looks like you attempted to retrieve your top-level module " "as a shared_ptr, but it is not stored in a shared_ptr. " "Use std::make_shared<", diff --git a/torch/csrc/api/src/nn/modules/_functions.cpp b/torch/csrc/api/src/nn/modules/_functions.cpp index 10dba0f0907a9..3bd956098f2ce 100644 --- a/torch/csrc/api/src/nn/modules/_functions.cpp +++ b/torch/csrc/api/src/nn/modules/_functions.cpp @@ -67,7 +67,8 @@ Variable CrossMapLRN2d::forward( ctx->saved_data["scale"] .toTensor() .mul_( - ctx->saved_data["alpha"].toDouble() / ctx->saved_data["size"].toInt()) + ctx->saved_data["alpha"].toDouble() / + static_cast(ctx->saved_data["size"].toInt())) .add_(ctx->saved_data["k"].toInt()); torch::pow_out( @@ -83,7 +84,7 @@ Variable CrossMapLRN2d::forward( variable_list CrossMapLRN2d::backward( AutogradContext* ctx, variable_list grad_outputs) { - auto grad_output = grad_outputs[0]; + auto const& grad_output = grad_outputs[0]; auto input = ctx->get_saved_variables()[0]; auto output = ctx->get_saved_variables()[1]; auto grad_input = torch::empty({0}, grad_output.options()); @@ -100,7 +101,8 @@ variable_list CrossMapLRN2d::backward( input.options()); auto accum_ratio = torch::empty({input_height, input_width}, input.options()); double cache_ratio_value = 2 * ctx->saved_data["alpha"].toDouble() * - ctx->saved_data["beta"].toDouble() / ctx->saved_data["size"].toInt(); + ctx->saved_data["beta"].toDouble() / + static_cast(ctx->saved_data["size"].toInt()); int64_t inversePrePad = static_cast( ctx->saved_data["size"].toInt() - (ctx->saved_data["size"].toInt() - 1) / 2); diff --git a/torch/csrc/api/src/nn/modules/adaptive.cpp b/torch/csrc/api/src/nn/modules/adaptive.cpp index 491d6269ad261..55f004e71b1b9 100644 --- a/torch/csrc/api/src/nn/modules/adaptive.cpp +++ b/torch/csrc/api/src/nn/modules/adaptive.cpp @@ -18,8 +18,7 @@ AdaptiveLogSoftmaxWithLossImpl::AdaptiveLogSoftmaxWithLossImpl( shortlist_size(0), n_clusters(0), head_size(0) { - // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) - reset(); + AdaptiveLogSoftmaxWithLossImpl::reset(); } void AdaptiveLogSoftmaxWithLossImpl::reset() { @@ -43,7 +42,7 @@ void AdaptiveLogSoftmaxWithLossImpl::reset() { cutoffs.push_back(options.n_classes()); shortlist_size = cutoffs[0]; - n_clusters = cutoffs.size() - 1; + n_clusters = static_cast(cutoffs.size() - 1); head_size = shortlist_size + n_clusters; head = this->register_module( @@ -54,7 +53,8 @@ void AdaptiveLogSoftmaxWithLossImpl::reset() { for (const auto i : c10::irange(n_clusters)) { int64_t hsz = static_cast(std::floor( - options.in_features() / std::pow(options.div_value(), (i + 1)))); + static_cast(options.in_features()) / + std::pow(options.div_value(), (i + 1)))); int64_t osz = cutoffs[i + 1] - cutoffs[i]; Sequential projection( @@ -129,7 +129,7 @@ ASMoutput AdaptiveLogSoftmaxWithLossImpl::forward( const Tensor cluster_output = tail[i - 1]->as()->forward(input_subset); - int64_t cluster_index = shortlist_size + i - 1; + int64_t cluster_index = shortlist_size + static_cast(i) - 1; gather_inds.index_fill_(0, row_indices, cluster_index); diff --git a/torch/csrc/api/src/nn/modules/container/functional.cpp b/torch/csrc/api/src/nn/modules/container/functional.cpp index 215ba8739b943..e615592e3f4f3 100644 --- a/torch/csrc/api/src/nn/modules/container/functional.cpp +++ b/torch/csrc/api/src/nn/modules/container/functional.cpp @@ -5,8 +5,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { FunctionalImpl::FunctionalImpl(Function function) : function_(std::move(function)) {} @@ -27,5 +26,4 @@ Tensor FunctionalImpl::operator()(Tensor input) { bool FunctionalImpl::is_serializable() const { return false; } -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/modules/dropout.cpp b/torch/csrc/api/src/nn/modules/dropout.cpp index 2bbd2073f5fb0..2b7c5aa3a289e 100644 --- a/torch/csrc/api/src/nn/modules/dropout.cpp +++ b/torch/csrc/api/src/nn/modules/dropout.cpp @@ -5,10 +5,8 @@ #include -#include #include #include -#include namespace F = torch::nn::functional; diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp index 150f9ded397ec..f8659b527629f 100644 --- a/torch/csrc/api/src/nn/modules/embedding.cpp +++ b/torch/csrc/api/src/nn/modules/embedding.cpp @@ -4,18 +4,15 @@ #include #include -#include #include #include -#include namespace F = torch::nn::functional; namespace torch::nn { EmbeddingImpl::EmbeddingImpl(EmbeddingOptions options_) : options(std::move(options_)) { - // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) - reset(); + EmbeddingImpl::reset(); } void EmbeddingImpl::reset() { diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp index 56933cd468e4a..60a63076925f9 100644 --- a/torch/csrc/api/src/nn/modules/linear.cpp +++ b/torch/csrc/api/src/nn/modules/linear.cpp @@ -25,8 +25,7 @@ Tensor IdentityImpl::forward(const Tensor& input) { // ============================================================================ LinearImpl::LinearImpl(const LinearOptions& options_) : options(options_) { - // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) - reset(); + LinearImpl::reset(); } void LinearImpl::reset() { diff --git a/torch/csrc/api/src/nn/modules/normalization.cpp b/torch/csrc/api/src/nn/modules/normalization.cpp index 4bc332395799b..f2e10e7facd52 100644 --- a/torch/csrc/api/src/nn/modules/normalization.cpp +++ b/torch/csrc/api/src/nn/modules/normalization.cpp @@ -13,8 +13,7 @@ namespace torch::nn { LayerNormImpl::LayerNormImpl(LayerNormOptions options_) : options(std::move(options_)) { - // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) - reset(); + LayerNormImpl::reset(); } void LayerNormImpl::reset() { diff --git a/torch/csrc/api/src/nn/modules/transformer.cpp b/torch/csrc/api/src/nn/modules/transformer.cpp index 53f7b83cdc35b..455b81b91ae9b 100644 --- a/torch/csrc/api/src/nn/modules/transformer.cpp +++ b/torch/csrc/api/src/nn/modules/transformer.cpp @@ -222,8 +222,7 @@ TransformerEncoderImpl::TransformerEncoderImpl( void TransformerEncoderImpl::reset() { layers = this->register_module("layers", ModuleList()); - for (const auto i : c10::irange(options.num_layers())) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(options.num_layers())) { layers->push_back(options.encoder_layer()->clone()); } @@ -289,8 +288,7 @@ TransformerDecoderImpl::TransformerDecoderImpl( void TransformerDecoderImpl::reset() { layers = this->register_module("layers", ModuleList()); - for (const auto i : c10::irange(options.num_layers())) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(options.num_layers())) { layers->push_back(options.decoder_layer()->clone()); } diff --git a/torch/csrc/api/src/nn/options/activation.cpp b/torch/csrc/api/src/nn/options/activation.cpp index 8476a4ff61a27..e6d1f9376ff98 100644 --- a/torch/csrc/api/src/nn/options/activation.cpp +++ b/torch/csrc/api/src/nn/options/activation.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { SELUOptions::SELUOptions(bool inplace) : inplace_(inplace) {} @@ -60,5 +59,4 @@ MultiheadAttentionForwardFuncOptions::MultiheadAttentionForwardFuncOptions( out_proj_bias_(std::move(out_proj_bias)) {} } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/adaptive.cpp b/torch/csrc/api/src/nn/options/adaptive.cpp index 1a8fcc4dc61ed..82d3e3b50de6b 100644 --- a/torch/csrc/api/src/nn/options/adaptive.cpp +++ b/torch/csrc/api/src/nn/options/adaptive.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { AdaptiveLogSoftmaxWithLossOptions::AdaptiveLogSoftmaxWithLossOptions( int64_t in_features, @@ -11,5 +10,4 @@ AdaptiveLogSoftmaxWithLossOptions::AdaptiveLogSoftmaxWithLossOptions( n_classes_(n_classes), cutoffs_(std::move(cutoffs)) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/batchnorm.cpp b/torch/csrc/api/src/nn/options/batchnorm.cpp index a0f7f22638985..3d608742bc618 100644 --- a/torch/csrc/api/src/nn/options/batchnorm.cpp +++ b/torch/csrc/api/src/nn/options/batchnorm.cpp @@ -1,10 +1,8 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { BatchNormOptions::BatchNormOptions(int64_t num_features) : num_features_(num_features) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/conv.cpp b/torch/csrc/api/src/nn/options/conv.cpp index cda9480369a0f..fccb6240cfe90 100644 --- a/torch/csrc/api/src/nn/options/conv.cpp +++ b/torch/csrc/api/src/nn/options/conv.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { template struct ConvOptions<1>; template struct ConvOptions<2>; @@ -19,5 +18,4 @@ template struct ConvTransposeFuncOptions<3>; } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/dropout.cpp b/torch/csrc/api/src/nn/options/dropout.cpp index a12ea3bfcf4e4..bb7443373820a 100644 --- a/torch/csrc/api/src/nn/options/dropout.cpp +++ b/torch/csrc/api/src/nn/options/dropout.cpp @@ -1,9 +1,7 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { DropoutOptions::DropoutOptions(double p) : p_(p) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/embedding.cpp b/torch/csrc/api/src/nn/options/embedding.cpp index 3b9509d19a026..d5c2fc0b2b6fb 100644 --- a/torch/csrc/api/src/nn/options/embedding.cpp +++ b/torch/csrc/api/src/nn/options/embedding.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { EmbeddingOptions::EmbeddingOptions( int64_t num_embeddings, int64_t embedding_dim) @@ -11,5 +10,4 @@ EmbeddingBagOptions::EmbeddingBagOptions( int64_t num_embeddings, int64_t embedding_dim) : num_embeddings_(num_embeddings), embedding_dim_(embedding_dim) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/instancenorm.cpp b/torch/csrc/api/src/nn/options/instancenorm.cpp index 405c264195545..4d878282fc777 100644 --- a/torch/csrc/api/src/nn/options/instancenorm.cpp +++ b/torch/csrc/api/src/nn/options/instancenorm.cpp @@ -1,10 +1,8 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { InstanceNormOptions::InstanceNormOptions(int64_t num_features) : num_features_(num_features) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/linear.cpp b/torch/csrc/api/src/nn/options/linear.cpp index 67e167ee11710..3087974141d2e 100644 --- a/torch/csrc/api/src/nn/options/linear.cpp +++ b/torch/csrc/api/src/nn/options/linear.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { LinearOptions::LinearOptions(int64_t in_features, int64_t out_features) : in_features_(in_features), out_features_(out_features) {} @@ -27,5 +26,4 @@ UnflattenOptions::UnflattenOptions(std::string dimname, namedshape_t namedshape) dimname_(std::move(dimname)), namedshape_(std::move(namedshape)) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/normalization.cpp b/torch/csrc/api/src/nn/options/normalization.cpp index 3b1600c6a69b7..6131ae8dcd08c 100644 --- a/torch/csrc/api/src/nn/options/normalization.cpp +++ b/torch/csrc/api/src/nn/options/normalization.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { LayerNormOptions::LayerNormOptions(std::vector normalized_shape) : normalized_shape_(std::move(normalized_shape)) {} @@ -22,5 +21,4 @@ GroupNormFuncOptions::GroupNormFuncOptions(int64_t num_groups) } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/padding.cpp b/torch/csrc/api/src/nn/options/padding.cpp index 30b62adddd273..8f4777b00d10a 100644 --- a/torch/csrc/api/src/nn/options/padding.cpp +++ b/torch/csrc/api/src/nn/options/padding.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { template struct ReflectionPadOptions<1>; template struct ReflectionPadOptions<2>; @@ -21,5 +20,4 @@ PadFuncOptions::PadFuncOptions(std::vector pad) } // namespace functional -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/pooling.cpp b/torch/csrc/api/src/nn/options/pooling.cpp index bbe27592a53c4..97ff5a03e6979 100644 --- a/torch/csrc/api/src/nn/options/pooling.cpp +++ b/torch/csrc/api/src/nn/options/pooling.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { template struct AvgPoolOptions<1>; template struct AvgPoolOptions<2>; @@ -27,5 +26,4 @@ template struct LPPoolOptions<1>; template struct LPPoolOptions<2>; template struct LPPoolOptions<3>; -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/rnn.cpp b/torch/csrc/api/src/nn/options/rnn.cpp index b948c0afac1d1..3674bc525dedf 100644 --- a/torch/csrc/api/src/nn/options/rnn.cpp +++ b/torch/csrc/api/src/nn/options/rnn.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace nn { +namespace torch::nn { namespace detail { @@ -45,5 +44,4 @@ LSTMCellOptions::LSTMCellOptions(int64_t input_size, int64_t hidden_size) GRUCellOptions::GRUCellOptions(int64_t input_size, int64_t hidden_size) : input_size_(input_size), hidden_size_(hidden_size) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/nn/options/transformer.cpp b/torch/csrc/api/src/nn/options/transformer.cpp index 2afb9bda543c4..7a3d53a18d0eb 100644 --- a/torch/csrc/api/src/nn/options/transformer.cpp +++ b/torch/csrc/api/src/nn/options/transformer.cpp @@ -2,8 +2,7 @@ #include #include -namespace torch { -namespace nn { +namespace torch::nn { TransformerEncoderLayerOptions::TransformerEncoderLayerOptions( int64_t d_model, @@ -48,5 +47,4 @@ TransformerOptions::TransformerOptions( num_encoder_layers_(num_encoder_layers), num_decoder_layers_(num_decoder_layers) {} -} // namespace nn -} // namespace torch +} // namespace torch::nn diff --git a/torch/csrc/api/src/optim/adagrad.cpp b/torch/csrc/api/src/optim/adagrad.cpp index 45b9da08b2c57..2279af7898b19 100644 --- a/torch/csrc/api/src/optim/adagrad.cpp +++ b/torch/csrc/api/src/optim/adagrad.cpp @@ -10,8 +10,7 @@ #include -namespace torch { -namespace optim { +namespace torch::optim { AdagradOptions::AdagradOptions(double lr) : lr_(lr) {} @@ -151,5 +150,4 @@ void Adagrad::load(serialize::InputArchive& archive) { } } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/adam.cpp b/torch/csrc/api/src/optim/adam.cpp index 10a9a258a600a..924ba504d8f31 100644 --- a/torch/csrc/api/src/optim/adam.cpp +++ b/torch/csrc/api/src/optim/adam.cpp @@ -11,8 +11,7 @@ #include #include -namespace torch { -namespace optim { +namespace torch::optim { AdamOptions::AdamOptions(double lr) : lr_(lr) {} @@ -181,5 +180,4 @@ void Adam::load(serialize::InputArchive& archive) { } } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/adamw.cpp b/torch/csrc/api/src/optim/adamw.cpp index 7ba7b50877cd7..b6928ae168ce9 100644 --- a/torch/csrc/api/src/optim/adamw.cpp +++ b/torch/csrc/api/src/optim/adamw.cpp @@ -11,8 +11,7 @@ #include #include -namespace torch { -namespace optim { +namespace torch::optim { AdamWOptions::AdamWOptions(double lr) : lr_(lr) {} @@ -182,5 +181,4 @@ void AdamW::load(serialize::InputArchive& archive) { } } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/lbfgs.cpp b/torch/csrc/api/src/optim/lbfgs.cpp index dbf17f718614a..db81239552dc6 100644 --- a/torch/csrc/api/src/optim/lbfgs.cpp +++ b/torch/csrc/api/src/optim/lbfgs.cpp @@ -13,8 +13,7 @@ #include #include -namespace torch { -namespace optim { +namespace torch::optim { LBFGSOptions::LBFGSOptions(double lr) : lr_(lr) {} @@ -56,7 +55,7 @@ void LBFGSOptions::set_lr(const double lr) { } template -bool if_container_equal(T lhs, T rhs) { +static bool if_container_equal(T lhs, T rhs) { if (!(lhs.size() == rhs.size())) return false; for (const auto i : c10::irange(lhs.size())) { @@ -132,7 +131,7 @@ Tensor LBFGS::_gather_flat_grad() { int64_t LBFGS::_numel() { if (_numel_cache == std::nullopt) { - auto res = 0; + int64_t res = 0; for (const auto& p : param_groups_.at(0).params()) { res += p.numel(); } @@ -142,7 +141,7 @@ int64_t LBFGS::_numel() { } void LBFGS::_add_grad(const double step_size, const Tensor& update) { - auto offset = 0; + int64_t offset = 0; for (auto& p : param_groups_.at(0).params()) { auto numel = p.numel(); // view as to avoid deprecated pointwise semantics @@ -176,8 +175,7 @@ std::tuple LBFGS::_directional_evaluate( double t, const Tensor& d) { _add_grad(t, d); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double loss; + double loss = 0; { torch::AutoGradMode enable_grad(true); loss = closure().item(); @@ -194,17 +192,11 @@ static double _cubic_interpolate( double x2, double f2, double g2, - std::optional> bounds = std::nullopt) { + std::optional> bounds = std::nullopt) { // ported from https://github.com/torch/optim/blob/master/polyinterp.lua // Compute bounds of interpolation area - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double xmin_bound, xmax_bound; - if (bounds != std::nullopt) { - std::tie(xmin_bound, xmax_bound) = *bounds; - } else { - std::tie(xmin_bound, xmax_bound) = - (x1 <= x2) ? std::make_tuple(x1, x2) : std::make_tuple(x2, x1); - } + auto [xmin_bound, xmax_bound] = + (bounds != std::nullopt) ? (*bounds) : std::minmax({x1, x2}); // Code for most common case: cubic interpolation of 2 points // w/ function and derivative values for both // Solution in this case (where x2 is the farthest point): @@ -215,12 +207,9 @@ static double _cubic_interpolate( auto d1 = (g1 + g2) - (3 * (f1 - f2) / (x1 - x2)); auto d2_square = std::pow(d1, 2) - g1 * g2; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double d2; if (d2_square >= 0) { - d2 = std::sqrt(d2_square); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - double min_pos; + auto d2 = std::sqrt(d2_square); + double min_pos = 0; if (x1 <= x2) { min_pos = x2 - ((x2 - x1) * ((g2 + d2 - d1) / (g2 - g1 + 2 * d2))); } else { @@ -304,7 +293,7 @@ static std::tuple _strong_wolfe( t, f_new, val(gtd_new), - std::make_tuple(min_step, max_step)); + std::make_pair(min_step, max_step)); // next step t_prev = tmp; f_prev = f_new; @@ -653,5 +642,4 @@ void LBFGS::load(serialize::InputArchive& archive) { std::move(state); } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/optimizer.cpp b/torch/csrc/api/src/optim/optimizer.cpp index b5288dea5cff0..c5cac1243284a 100644 --- a/torch/csrc/api/src/optim/optimizer.cpp +++ b/torch/csrc/api/src/optim/optimizer.cpp @@ -7,8 +7,7 @@ #include #include -namespace torch { -namespace optim { +namespace torch::optim { bool OptimizerParamGroup::has_options() const { return options_ != nullptr; @@ -16,12 +15,12 @@ bool OptimizerParamGroup::has_options() const { OptimizerOptions& OptimizerParamGroup::options() { TORCH_CHECK(has_options()); - return *options_.get(); + return *options_; } const OptimizerOptions& OptimizerParamGroup::options() const { TORCH_CHECK(has_options()); - return *options_.get(); + return *options_; } void OptimizerParamGroup::set_options( @@ -154,11 +153,11 @@ size_t Optimizer::size() const noexcept { } OptimizerOptions& Optimizer::defaults() noexcept { - return *defaults_.get(); + return *defaults_; } const OptimizerOptions& Optimizer::defaults() const noexcept { - return *defaults_.get(); + return *defaults_; } std::vector& Optimizer::param_groups() noexcept { @@ -199,5 +198,4 @@ serialize::InputArchive& operator>>( return archive; } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/rmsprop.cpp b/torch/csrc/api/src/optim/rmsprop.cpp index 4a55bdf00abce..b6a12dafb3f24 100644 --- a/torch/csrc/api/src/optim/rmsprop.cpp +++ b/torch/csrc/api/src/optim/rmsprop.cpp @@ -9,8 +9,7 @@ #include -namespace torch { -namespace optim { +namespace torch::optim { RMSpropOptions::RMSpropOptions(double lr) : lr_(lr) {} @@ -178,5 +177,4 @@ void RMSprop::load(serialize::InputArchive& archive) { } } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp b/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp index 1c2aa1b91eef6..b29f4ce6e5826 100644 --- a/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp +++ b/torch/csrc/api/src/optim/schedulers/lr_scheduler.cpp @@ -1,8 +1,7 @@ #include #include -namespace torch { -namespace optim { +namespace torch::optim { LRScheduler::LRScheduler(torch::optim::Optimizer& optimizer) : optimizer_(optimizer) {} @@ -39,5 +38,4 @@ std::vector LRScheduler::get_current_lrs() const { return learnings_rates; } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/schedulers/reduce_on_plateau_scheduler.cpp b/torch/csrc/api/src/optim/schedulers/reduce_on_plateau_scheduler.cpp index 53734b2eb99b9..3bbd65bccfa7e 100644 --- a/torch/csrc/api/src/optim/schedulers/reduce_on_plateau_scheduler.cpp +++ b/torch/csrc/api/src/optim/schedulers/reduce_on_plateau_scheduler.cpp @@ -2,8 +2,7 @@ #include -namespace torch { -namespace optim { +namespace torch::optim { ReduceLROnPlateauScheduler::ReduceLROnPlateauScheduler( Optimizer& optimizer, @@ -74,7 +73,7 @@ void ReduceLROnPlateauScheduler::reduce_lr(int epoch) { if (verbose) { std::cout << std::setprecision(4) << "Epoch " << epoch << ": reducing learning rate of group " << i << " to " - << new_lr << std::endl; + << new_lr << '\n'; } } } @@ -87,7 +86,7 @@ void ReduceLROnPlateauScheduler::reset() { this->best = mode_worse; } -bool ReduceLROnPlateauScheduler::in_cooldown() { +bool ReduceLROnPlateauScheduler::in_cooldown() const { return cooldown_counter > 0; } @@ -119,5 +118,4 @@ void ReduceLROnPlateauScheduler::init_is_better( this->threshold_mode = threshold_mode; this->threshold = threshold; } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/schedulers/step_lr.cpp b/torch/csrc/api/src/optim/schedulers/step_lr.cpp index 497ebe08fed3b..dd5975c2adb27 100644 --- a/torch/csrc/api/src/optim/schedulers/step_lr.cpp +++ b/torch/csrc/api/src/optim/schedulers/step_lr.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace optim { +namespace torch::optim { StepLR::StepLR( torch::optim::Optimizer& optimizer, @@ -22,5 +21,4 @@ std::vector StepLR::get_lrs() { } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/serialize.cpp b/torch/csrc/api/src/optim/serialize.cpp index 6473127d96f7a..ca9f3142a591c 100644 --- a/torch/csrc/api/src/optim/serialize.cpp +++ b/torch/csrc/api/src/optim/serialize.cpp @@ -9,8 +9,7 @@ #include #include -namespace torch { -namespace optim { +namespace torch::optim { void serialize( serialize::OutputArchive& archive, const std::string& key, @@ -50,5 +49,4 @@ void serialize( steps.push_back(step.item()); } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/optim/sgd.cpp b/torch/csrc/api/src/optim/sgd.cpp index 337bfeb2fa214..dc3e5002790b4 100644 --- a/torch/csrc/api/src/optim/sgd.cpp +++ b/torch/csrc/api/src/optim/sgd.cpp @@ -11,8 +11,7 @@ #include -namespace torch { -namespace optim { +namespace torch::optim { SGDOptions::SGDOptions(double lr) : lr_(lr) {} @@ -131,5 +130,4 @@ void SGD::load(serialize::InputArchive& archive) { } } } -} // namespace optim -} // namespace torch +} // namespace torch::optim diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp index 2d3cfe32c1134..691ac98e42a3f 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -93,13 +93,13 @@ void InputArchive::read(const std::string& key, InputArchive& archive) { void InputArchive::load_from( const std::string& filename, std::optional device /*= std::nullopt*/) { - module_ = torch::jit::load(filename, std::move(device)); + module_ = torch::jit::load(filename, device); } void InputArchive::load_from( std::istream& stream, std::optional device /*= std::nullopt*/) { - module_ = torch::jit::load(stream, std::move(device)); + module_ = torch::jit::load(stream, device); } void InputArchive::load_from( diff --git a/torch/csrc/api/src/xpu.cpp b/torch/csrc/api/src/xpu.cpp index adbfa79c3fd70..75837b831d9c8 100644 --- a/torch/csrc/api/src/xpu.cpp +++ b/torch/csrc/api/src/xpu.cpp @@ -27,7 +27,8 @@ void manual_seed(uint64_t seed) { void manual_seed_all(uint64_t seed) { auto num_gpu = device_count(); for (const auto i : c10::irange(num_gpu)) { - auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator(i); + auto gen = at::detail::getXPUHooks().getDefaultXPUGenerator( + static_cast(i)); { // See Note [Acquire lock when using random generators] std::lock_guard lock(gen.mutex()); diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 3f24c6ecb4095..00a856925db04 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1693,8 +1693,7 @@ Tensor repeat_backward( } const auto input_dims = input_shape.size(); auto num_unsqueezed = grad.dim() - input_dims; - for (const auto i : c10::irange(num_unsqueezed)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(num_unsqueezed)) { grad = grad.sum(0, false); } @@ -6882,7 +6881,8 @@ std::tuple scatter_reduce_backward( grad_self = (self == result) * grad_distributed; grad_src = (src == value) * grad_distributed.gather(dim, index); } else { - AT_ERROR( + TORCH_CHECK( + false, "Expected 'reduce' to be one of 'sum', 'prod', 'mean', 'amax', 'amin' but got ", reduce, "."); @@ -6977,7 +6977,8 @@ std::tuple index_reduce_backward( grad_self = self_is_result * grad_distributed; grad_src = source_is_result * grad_distributed.index_select(dim, index); } else { - AT_ERROR( + TORCH_CHECK( + false, "Expected 'reduce' to be one of 'prod', 'amax', 'amin' or 'mean' but got ", reduce, "."); @@ -7045,12 +7046,9 @@ mkldnn_rnn_layer_differentiable_backward( at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor& workspace) { - const Tensor& grad_output_r = - c10::value_or_else(grad_output_r_opt, [] { return Tensor(); }); - const Tensor& grad_hy_r = - c10::value_or_else(grad_hy_r_opt, [] { return Tensor(); }); - const Tensor& grad_cy_r = - c10::value_or_else(grad_cy_r_opt, [] { return Tensor(); }); + const Tensor& grad_output_r = grad_output_r_opt.value_or(Tensor()); + const Tensor& grad_hy_r = grad_hy_r_opt.value_or(Tensor()); + const Tensor& grad_cy_r = grad_cy_r_opt.value_or(Tensor()); if (!grad_output_r.defined() && !grad_hy_r.defined() && !grad_cy_r.defined()) { return std::make_tuple( diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index cbda6552fe7a6..e270df51221bf 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -43,7 +43,7 @@ std::vector allCPUTypes() { } std::vector allCUDATypes() { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); return allTypesForBackends({Backend::CUDA, Backend::SparseCUDA}); } @@ -52,7 +52,7 @@ std::vector allXPUTypes() { } std::vector allPrivateUser1Types() { - at::globalContext().lazyInitPrivateUse1(); + at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); return allTypesForBackends( {Backend::PrivateUse1, Backend::SparsePrivateUse1}); } @@ -63,7 +63,8 @@ const Variable& checked_cast_variable( const char* name, int pos) { if (!t.defined()) { - AT_ERROR( + TORCH_CHECK( + false, "Expected a proper Tensor but got None (or an undefined Tensor in C++) ", "for argument #", pos, @@ -76,7 +77,8 @@ const Variable& checked_cast_variable( Variable& checked_cast_variable(Tensor& t, const char* name, int pos) { if (!t.defined()) { - AT_ERROR( + TORCH_CHECK( + false, "Expected a proper Tensor but got None (or an undefined Tensor in C++) ", "for argument #", pos, @@ -243,7 +245,7 @@ const Tensor& resize_( std::optional optional_memory_format) { auto& self_ = unpack(self, "self", 0); if (self.requires_grad()) { - AT_ERROR("cannot resize variables that require grad"); + TORCH_CHECK(false, "cannot resize variables that require grad"); } { at::AutoDispatchBelowAutograd mode; @@ -252,7 +254,7 @@ const Tensor& resize_( } if (self._fw_grad(/* level */ 0).defined()) { - AT_ERROR("cannot resize variables that has a forward grad"); + TORCH_CHECK(false, "cannot resize variables that has a forward grad"); } return self; @@ -266,7 +268,7 @@ const Tensor& resize_as_( auto& self_ = unpack(self, "self", 0); auto& the_template_ = unpack(the_template, "the_template", 1); if (self.requires_grad()) { - AT_ERROR("cannot resize variables that require grad"); + TORCH_CHECK(false, "cannot resize variables that require grad"); } { at::AutoDispatchBelowAutograd mode; @@ -279,7 +281,7 @@ const Tensor& resize_as_( // Handle fw grad if (self._fw_grad(/* level */ 0).defined()) { - AT_ERROR("cannot resize variables that has a forward grad"); + TORCH_CHECK(false, "cannot resize variables that has a forward grad"); } return self; @@ -303,7 +305,8 @@ Tensor& detach_(c10::DispatchKeySet ks, Tensor& self) { RECORD_FUNCTION("detach_", std::vector({self})); if (self.is_view()) { // See NOTE [ View + Inplace detection ] - AT_ERROR( + TORCH_CHECK( + false, "Can't detach views in-place. Use detach() instead. " "If you are using DistributedDataParallel (DDP) for training, " "and gradient_as_bucket_view is set as True, gradients are " diff --git a/torch/csrc/autograd/VariableTypeUtils.h b/torch/csrc/autograd/VariableTypeUtils.h index aec108b0126c2..e6aebfafb1adc 100644 --- a/torch/csrc/autograd/VariableTypeUtils.h +++ b/torch/csrc/autograd/VariableTypeUtils.h @@ -93,7 +93,8 @@ inline void check_inplace(at::ITensorListRef tensors, bool requires_grad) { } inline void throw_error_out_requires_grad(const char* name) { - AT_ERROR( + TORCH_CHECK( + false, name, "(): functions with out=... arguments don't support automatic differentiation, " "but one of the arguments requires grad."); diff --git a/torch/csrc/autograd/functions/basic_ops.h b/torch/csrc/autograd/functions/basic_ops.h index 87530756a3b44..d9e11b1f45fc4 100644 --- a/torch/csrc/autograd/functions/basic_ops.h +++ b/torch/csrc/autograd/functions/basic_ops.h @@ -44,7 +44,8 @@ struct TORCH_API NotImplemented : public Error { // @once_differentiable struct TORCH_API DelayedError : public Node { DelayedError(std::string msg, int64_t num_inputs) : msg(std::move(msg)) { - for (const auto _ [[maybe_unused]] : c10::irange(num_inputs)) { + for ([[maybe_unused]] const auto _ [[maybe_unused]] : + c10::irange(num_inputs)) { add_input_metadata(Node::undefined_input()); } } diff --git a/torch/csrc/autograd/functions/pybind.h b/torch/csrc/autograd/functions/pybind.h index 94b3c9c679969..4e1262271de01 100644 --- a/torch/csrc/autograd/functions/pybind.h +++ b/torch/csrc/autograd/functions/pybind.h @@ -8,8 +8,7 @@ #include #include +// NOLINTNEXTLINE(misc-unused-alias-decls) namespace py = pybind11; -namespace pybind11 { -namespace detail {} -} // namespace pybind11 +namespace pybind11::detail {} // namespace pybind11::detail diff --git a/torch/csrc/autograd/graph_task.h b/torch/csrc/autograd/graph_task.h index e4a7ae4dad18e..018beaffdaaff 100644 --- a/torch/csrc/autograd/graph_task.h +++ b/torch/csrc/autograd/graph_task.h @@ -48,6 +48,9 @@ struct GraphTask : std::enable_shared_from_this { struct Capture { Capture(const Capture&) = delete; Capture(Capture&&) = default; + Capture& operator=(const Capture&) = delete; + Capture& operator=(Capture&&) = default; + ~Capture() = default; Capture(int input_idx, int output_idx) : input_idx_(input_idx), output_idx_(output_idx) {} diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp index 481cff3cc5c0c..d3f3b1bd21436 100644 --- a/torch/csrc/autograd/profiler_python.cpp +++ b/torch/csrc/autograd/profiler_python.cpp @@ -1017,7 +1017,7 @@ class PostProcess { ska::flat_hash_map> tid_map; auto it = out.rbegin(); - for (C10_UNUSED auto _ : c10::irange(initial_size, out.size())) { + for ([[maybe_unused]] auto _ : c10::irange(initial_size, out.size())) { const auto python_tid = std::get>((*it)->extra_fields_).python_tid_; if ((*it)->start_tid_ == NoTID && SOFT_ASSERT(E == EventType::PyCall)) { diff --git a/torch/csrc/autograd/python_function.cpp b/torch/csrc/autograd/python_function.cpp index 415de56a49095..0e83ffcc09e02 100644 --- a/torch/csrc/autograd/python_function.cpp +++ b/torch/csrc/autograd/python_function.cpp @@ -297,7 +297,7 @@ auto PyNode::compiled_autograd_should_lift() const -> bool { void PyNode::compiled_args(CompiledNodeArgs& args) { static PyObject* method_name = PyUnicode_InternFromString("_compiled_autograd_key"); - THPObjectPtr pykey(PyObject_CallMethodNoArgs(obj, method_name)); + THPObjectPtr pykey(PyObject_CallMethodObjArgs(obj, method_name, nullptr)); if (!pykey) throw_python_error(); TORCH_CHECK( @@ -733,8 +733,18 @@ static void _wrap_outputs( PyTuple_SetItem(outputs, i, obj); } else { if (is_executable) { + // If one of the grad outputs is undefined, a correctly-shaped zeros + // should be used instead. To construct these for NJT, zeros_like() must + // be used until we have factory function support. // NOLINTNEXTLINE(bugprone-unchecked-optional-access) - self->output_info.emplace_back(*wrapped_outputs[i]); + bool is_differentiable = + (non_differentiable.count( + wrapped_outputs[i]->unsafeGetTensorImpl()) == 0 && + isDifferentiableType(wrapped_outputs[i]->scalar_type())); + bool use_zeros_like = is_differentiable && num_outputs > 1 && + wrapped_outputs[i]->is_nested(); + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + self->output_info.emplace_back(*wrapped_outputs[i], use_zeros_like); } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i])); diff --git a/torch/csrc/autograd/python_function.h b/torch/csrc/autograd/python_function.h index b1c3eb25ee584..8c4f2f68dc57a 100644 --- a/torch/csrc/autograd/python_function.h +++ b/torch/csrc/autograd/python_function.h @@ -13,7 +13,6 @@ #include #include -#include #include namespace torch::jit { diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 5a5d10ac2670a..8f113a6a70286 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -207,7 +207,7 @@ PyObject* ParameterClass = nullptr; static PyObject* THPVariable_NewWithVar( PyTypeObject* type, - Variable _var, + const at::TensorBase& _var, c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj = false); @@ -254,8 +254,7 @@ void activateGPUTrace() { c10::impl::GPUTrace::set_trace(getPyInterpreter()); } -// TODO: Make this take Variable by const reference -PyObject* THPVariable_Wrap(at::TensorBase var) { +PyObject* THPVariable_Wrap(const at::TensorBase& var) { if (!var.defined()) { Py_RETURN_NONE; } @@ -263,7 +262,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) { if (c10::impl::HermeticPyObjectTLS::get_state()) { return THPVariable_NewWithVar( (PyTypeObject*)THPVariableClass, - std::move(var), + var, c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED); } @@ -282,7 +281,7 @@ PyObject* THPVariable_Wrap(at::TensorBase var) { // object if all C++ references go to zero var.unsafeGetTensorImpl()->pyobj_slot()->set_owns_pyobj(false); reinterpret_cast(obj)->cdata = - MaybeOwned::owned(std::move(var)); + MaybeOwned::owned(Variable(var)); // NB: incref is not necessary, because we are "stealing" the previous // ownership from the Variable to return it here for the wrap return obj; @@ -308,16 +307,14 @@ PyObject* THPVariable_Wrap(at::TensorBase var) { } if (C10_LIKELY(var.device().type() != c10::kXLA)) { - return THPVariable_NewWithVar( - (PyTypeObject*)THPVariableClass, std::move(var), status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); } if (auto clazz = getPythonTensorClass(var.device())) { - return THPVariable_NewWithVar((PyTypeObject*)clazz, std::move(var), status); + return THPVariable_NewWithVar((PyTypeObject*)clazz, var, status); } - return THPVariable_NewWithVar( - (PyTypeObject*)THPVariableClass, std::move(var), status); + return THPVariable_NewWithVar((PyTypeObject*)THPVariableClass, var, status); } bool isResurrectable(THPVariable* self) { @@ -619,7 +616,7 @@ static PyObject* view_func_impl( } } } - return THPVariable_Wrap(std::move(out)); + return THPVariable_Wrap(out); END_HANDLE_TH_ERRORS } @@ -655,7 +652,7 @@ static PyObject* rev_view_func_impl(PyObject* self_, PyObject* arg) { TORCH_CHECK(view_info.has_view_fn(), "No _rev_view_func() found"); out = view_info.rev_view_fn()(new_view); } - return THPVariable_Wrap(std::move(out)); + return THPVariable_Wrap(out); END_HANDLE_TH_ERRORS } @@ -683,6 +680,10 @@ static PyObject* THPVariable_as_subclass( "cls must be a type (got ", Py_TYPE(cls)->tp_name, ")"); + // guard completely turns off torch dispatch modes, doesn't just pop off the + // stack + torch_dispatch_mode::StashTorchDispatchStackGuard td_g; + c10::impl::DisablePythonDispatcher dpd_g; return THPVariable_NewWithVar( (PyTypeObject*)cls, self.alias(), @@ -1894,7 +1895,7 @@ PyObject* THPVariable_pynew( // these to be passed on directly. return THPVariable_NewWithVar( type, - std::move(tensor), + tensor, c10::impl::PyInterpreterStatus::MAYBE_UNINITIALIZED, /*allow_preexisting_pyobj=*/true); END_HANDLE_TH_ERRORS @@ -2008,7 +2009,7 @@ void THPVariable_subclass_dealloc(PyObject* self) { // It's ALWAYS safe (albeit slower) to call this with MAYBE_UNINITIALIZED. static PyObject* THPVariable_NewWithVar( PyTypeObject* type, - Variable _var, + const at::TensorBase& _var, c10::impl::PyInterpreterStatus status, bool allow_preexisting_pyobj) { // Make sure that the reinterpret into a THPVariable* will be valid @@ -2078,7 +2079,7 @@ static PyObject* THPVariable_NewWithVar( " which is not a subclass of the " "requested type"); // We may (in fact, we typically will) need to resurrect this - return THPVariable_Wrap(std::move(_var)); + return THPVariable_Wrap(_var); } PyObject* obj = type->tp_alloc(type, 0); @@ -2088,7 +2089,7 @@ static PyObject* THPVariable_NewWithVar( new (&v->cdata) MaybeOwned(); if (c10::impl::HermeticPyObjectTLS::get_state()) { // Do NOT initialize pyobj field on the tensor, you own the C++ - v->cdata = MaybeOwned::owned(std::move(_var)); + v->cdata = MaybeOwned::owned(Variable(_var)); TORCH_INTERNAL_ASSERT( !check_has_torch_dispatch(obj), "While HermeticPyObject was enabled, we attempted to create a tensor " @@ -2100,7 +2101,7 @@ static PyObject* THPVariable_NewWithVar( "Python op registration."); } else { // Normal codepath - v->cdata = MaybeOwned::owned(std::move(_var)); + v->cdata = MaybeOwned::owned(Variable(_var)); const auto& var = THPVariable_Unpack(v); var.unsafeGetTensorImpl()->pyobj_slot()->init_pyobj( getPyInterpreter(), obj, status); diff --git a/torch/csrc/autograd/python_variable.h b/torch/csrc/autograd/python_variable.h index 51ade77f03ece..32cc5c930ca0a 100644 --- a/torch/csrc/autograd/python_variable.h +++ b/torch/csrc/autograd/python_variable.h @@ -37,7 +37,7 @@ TORCH_PYTHON_API extern PyObject* THPVariableClass; TORCH_PYTHON_API extern PyObject* ParameterClass; bool THPVariable_initModule(PyObject* module); -TORCH_PYTHON_API PyObject* THPVariable_Wrap(at::TensorBase var); +TORCH_PYTHON_API PyObject* THPVariable_Wrap(const at::TensorBase& var); inline bool THPVariable_CheckTypeExact(PyTypeObject* tp) { // Check that a python object is a `Tensor`, but not a `Tensor` subclass. diff --git a/torch/csrc/autograd/saved_variable.h b/torch/csrc/autograd/saved_variable.h index 2866b56715609..0d28c95e19a26 100644 --- a/torch/csrc/autograd/saved_variable.h +++ b/torch/csrc/autograd/saved_variable.h @@ -29,7 +29,9 @@ class TORCH_API SavedVariable { const std::optional& variable, bool is_output, bool is_inplace_on_view = false); + SavedVariable(const SavedVariable&) = delete; SavedVariable(SavedVariable&&) = default; + SavedVariable& operator=(const SavedVariable&) = delete; SavedVariable& operator=(SavedVariable&&) = default; ~SavedVariable() { if (fw_grad_) { diff --git a/torch/csrc/autograd/variable_info.cpp b/torch/csrc/autograd/variable_info.cpp index bffd3250fb088..5bde41544910f 100644 --- a/torch/csrc/autograd/variable_info.cpp +++ b/torch/csrc/autograd/variable_info.cpp @@ -2,6 +2,7 @@ #include #else #include +#include #endif #include @@ -9,13 +10,16 @@ namespace torch::autograd { -VariableInfo::VariableInfo(const Variable& var) +VariableInfo::VariableInfo(const Variable& var, bool use_zeros_like) : layout(var.layout()), device(var.device()), scalar_type(var.scalar_type()), size(var.sym_sizes().vec()), requires_grad(var.requires_grad()), - is_empty(false) {} + is_empty(false), + the_var( + use_zeros_like ? std::optional(var.detach()) + : std::nullopt) {} VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {} @@ -23,6 +27,8 @@ Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { if (is_empty) { // Return undefined tensor. return at::Tensor(); + } else if (the_var.has_value()) { + return at::zeros_like(*the_var); } else { return at::zeros_symint( size, at::TensorOptions(scalar_type).device(device).layout(layout)); diff --git a/torch/csrc/autograd/variable_info.h b/torch/csrc/autograd/variable_info.h index 63e88deb0d547..e26804e7e55fc 100644 --- a/torch/csrc/autograd/variable_info.h +++ b/torch/csrc/autograd/variable_info.h @@ -6,7 +6,7 @@ namespace torch::autograd { struct TORCH_API VariableInfo { explicit VariableInfo(); - explicit VariableInfo(const Variable& var); + explicit VariableInfo(const Variable& var, bool use_zeros_like = false); Variable zeros(at::OptionalDeviceGuard& device_guard) const; @@ -16,6 +16,8 @@ struct TORCH_API VariableInfo { std::vector size; bool requires_grad; bool is_empty; + // needed for e.g. NJTs since they only support zeros_like() + std::optional the_var; }; } // namespace torch::autograd diff --git a/torch/csrc/cpu/Module.cpp b/torch/csrc/cpu/Module.cpp index 84eb864d2ceca..23abb3abae946 100644 --- a/torch/csrc/cpu/Module.cpp +++ b/torch/csrc/cpu/Module.cpp @@ -14,6 +14,7 @@ void initModule(PyObject* module) { cpu.def("_is_avx512_bf16_supported", at::cpu::is_avx512_bf16_supported); cpu.def("_is_amx_tile_supported", at::cpu::is_amx_tile_supported); cpu.def("_init_amx", at::cpu::init_amx); + cpu.def("_is_arm_sve_supported", at::cpu::is_arm_sve_supported); cpu.def("_L1d_cache_size", at::cpu::L1d_cache_size); cpu.def("_L2_cache_size", at::cpu::L2_cache_size); } diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index 5220e86233bd6..faa5692b058df 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -14,7 +14,7 @@ CUDAPluggableAllocatorDeleterContext::CUDAPluggableAllocatorDeleterContext( size_t size, int device, cudaStream_t stream) - : free_fn_(free_fn), + : free_fn_(std::move(free_fn)), data_(data), size_(size), device_(device), diff --git a/torch/csrc/cuda/GdsFile.cpp b/torch/csrc/cuda/GdsFile.cpp index b95b86b3374f9..945da3be65102 100644 --- a/torch/csrc/cuda/GdsFile.cpp +++ b/torch/csrc/cuda/GdsFile.cpp @@ -12,8 +12,7 @@ namespace { // filesystem error and a negative CUfileOpError enum value otherwise). template < class T, - typename std::enable_if::value, std::nullptr_t>::type = - nullptr> + std::enable_if_t, std::nullptr_t> = nullptr> std::string cuGDSFileGetErrorString(T status) { status = std::abs(status); return IS_CUFILE_ERR(status) ? std::string(CUFILE_ERRSTR(status)) @@ -24,8 +23,7 @@ std::string cuGDSFileGetErrorString(T status) { // CUfileError_t template < class T, - typename std::enable_if::value, std::nullptr_t>::type = - nullptr> + std::enable_if_t, std::nullptr_t> = nullptr> std::string cuGDSFileGetErrorString(T status) { std::string errStr = cuGDSFileGetErrorString(static_cast(status.err)); if (IS_CUDA_ERR(status)) diff --git a/torch/csrc/cuda/MemPool.cpp b/torch/csrc/cuda/MemPool.cpp index 83c9b9c1c1bf5..d5e0030ee7b7f 100644 --- a/torch/csrc/cuda/MemPool.cpp +++ b/torch/csrc/cuda/MemPool.cpp @@ -1,6 +1,7 @@ #include #include +#include #include #include @@ -11,9 +12,16 @@ using shared_ptr_class_ = py::class_>; void THCPMemPool_init(PyObject* module) { auto torch_C_m = py::handle(module).cast(); shared_ptr_class_<::c10::cuda::MemPool>(torch_C_m, "_MemPool") - .def(py::init()) + .def( + py::init([](c10::cuda::CUDACachingAllocator::CUDAAllocator* allocator, + bool is_user_created) { + torch::utils::device_lazy_init(at::kCUDA); + return std::make_shared<::c10::cuda::MemPool>( + allocator, is_user_created); + })) .def_property_readonly("id", &::c10::cuda::MemPool::id) - .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator); + .def_property_readonly("allocator", &::c10::cuda::MemPool::allocator) + .def("use_count", &::c10::cuda::MemPool::use_count); shared_ptr_class_<::c10::cuda::MemPoolContext>(torch_C_m, "_MemPoolContext") .def(py::init()) .def_static( diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 23a244a60f08c..ae1f20fac118d 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -890,7 +890,7 @@ PyObject* THCPModule_attachOutOfMemoryObserver( } Py_XDECREF(result); }; - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); c10::cuda::CUDACachingAllocator::attachOutOfMemoryObserver(std::move(obs)); Py_RETURN_NONE; END_HANDLE_TH_ERRORS @@ -1250,6 +1250,13 @@ static void registerCudaPluggableAllocator(PyObject* module) { ->release_storage_and_set_meta_custom_data_ptr_error_msg_(s); }); + m.def( + "_set_storage_data_ptr_access_error_msg", + [](size_t storage_impl_ptr, std::string s) { + c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; + storage_impl->release_data_and_set_meta_custom_data_ptr_error_msg_(s); + }); + m.def("_has_Standard_Deleter", [](size_t storage_impl_ptr) { // NOLINTNEXTLINE(performance-no-int-to-ptr) c10::StorageImpl* storage_impl = (c10::StorageImpl*)storage_impl_ptr; @@ -1266,8 +1273,7 @@ static void registerCudaPluggableAllocator(PyObject* module) { m.def( "_tensors_data_ptrs_at_indices_equal", [](py::list& tensors, py::list& data_ptrs, py::list& indices) { - for (size_t i = 0, end = indices.size(); i < end; ++i) { - auto index = indices[i].cast(); + for (auto index : indices) { auto t = tensors[index].cast(); auto data_ptr = data_ptrs[index].cast(); if (reinterpret_cast(t.data_ptr()) != data_ptr) { @@ -1419,7 +1425,7 @@ static PyObject* THCPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level poison_fork(); - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda")); if (!m) @@ -1451,7 +1457,6 @@ PyObject* THCPModule_getCurrentBlasHandle_wrap( PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); return PyLong_FromVoidPtr(handle); END_HANDLE_TH_ERRORS @@ -1531,6 +1536,32 @@ PyObject* THCPModule_cuda_tunableop_tuning_is_enabled( END_HANDLE_TH_ERRORS } +PyObject* THCPModule_cuda_record_untuned_enable( + PyObject* _unused, + PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkBool(arg), + "cuda_record_untuned_enable expects a bool, but got ", + THPUtils_typename(arg)); + at::cuda::tunable::getTuningContext()->EnableRecordUntuned( + THPUtils_unpackBool(arg)); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* THCPModule_cuda_record_untuned_is_enabled( + PyObject* _unused, + PyObject* noarg) { + HANDLE_TH_ERRORS + if (at::cuda::tunable::getTuningContext()->IsRecordUntunedEnabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + PyObject* THCPModule_cuda_tunableop_write_file_on_exit( PyObject* _unused, PyObject* arg) { @@ -1945,6 +1976,14 @@ static struct PyMethodDef _THCPModule_methods[] = { THCPModule_cuda_tunableop_tuning_is_enabled, METH_NOARGS, nullptr}, + {"_cuda_record_untuned_enable", + THCPModule_cuda_record_untuned_enable, + METH_O, + nullptr}, + {"_cuda_record_untuned_is_enabled", + THCPModule_cuda_record_untuned_is_enabled, + METH_NOARGS, + nullptr}, {"_cuda_tunableop_write_file_on_exit", THCPModule_cuda_tunableop_write_file_on_exit, METH_O, diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index 52331909fe1dc..4dce2a710f3b7 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -94,7 +94,6 @@ std::vector& broadcast_out( } std::vector broadcast(const Tensor& tensor, IntArrayRef devices) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector diff_device_dst_tensors; diff_device_dst_tensors.reserve(devices.size()); for (auto device : devices) { @@ -109,7 +108,6 @@ std::vector broadcast(const Tensor& tensor, IntArrayRef devices) { } } _broadcast_out_impl(tensor, diff_device_dst_tensors); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector dst_tensors; dst_tensors.reserve(devices.size()); auto it = diff_device_dst_tensors.begin(); @@ -172,7 +170,6 @@ tensor_list2d broadcast_coalesced( buffer_size = std::min(torch::cuda::nccl::get_max_count(), buffer_size); #endif - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) tensor_list2d outputs(devices.size()); outputs[0] = tensors.vec(); for (auto& o : outputs) @@ -239,7 +236,6 @@ std::vector& scatter_out( "Expected at least one output tensor to scatter to"); dim = at::maybe_wrap_dim(dim, tensor); int64_t total_size = 0; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector chunk_sizes; chunk_sizes.reserve(out_tensors.size()); for (const auto i : c10::irange(out_tensors.size())) { @@ -374,7 +370,6 @@ static inline at::Tensor& _gather_out_impl( at::TensorList tensors, at::Tensor& out_tensor, int64_t dim) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector chunk_sizes; chunk_sizes.reserve(tensors.size()); for (auto& tensor : tensors) { @@ -397,7 +392,6 @@ at::Tensor& gather_out( auto& first = tensors.front(); const auto first_size = first.sizes(); dim = at::maybe_wrap_dim(dim, first); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector expected_size(first_size.begin(), first_size.end()); for (const auto i : c10::irange(tensors.size())) { const auto& tensor = tensors[i]; @@ -452,7 +446,6 @@ at::Tensor gather( auto& first = tensors.front(); const auto first_size = first.sizes(); dim = at::maybe_wrap_dim(dim, first); - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) std::vector expected_size(first_size.begin(), first_size.end()); auto memory_format = first.suggest_memory_format(); for (const auto i : c10::irange(tensors.size())) { diff --git a/torch/csrc/cuda/memory_snapshot.cpp b/torch/csrc/cuda/memory_snapshot.cpp index 76ff111936edf..05da63b5bbbc9 100644 --- a/torch/csrc/cuda/memory_snapshot.cpp +++ b/torch/csrc/cuda/memory_snapshot.cpp @@ -138,7 +138,7 @@ void _record_memory_history( } else if (record_context) { when = c10::cuda::CUDACachingAllocator::RecordContext::STATE; } - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); _initRecordAnnotations(); c10::cuda::CUDACachingAllocator::recordHistory( enabled, recorder, trace_alloc_max_entries, when); @@ -189,7 +189,7 @@ void _record_memory_history( when = c10::cuda::CUDACachingAllocator::RecordContext::STATE; } } - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); _initRecordAnnotations(); c10::cuda::CUDACachingAllocator::recordHistory( enabled.has_value(), recorder, max_entries, when); diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index 34b2896f7eaa5..7be7b08efc6a6 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -11,6 +11,7 @@ #include +#include #include #include #include @@ -112,6 +113,18 @@ ncclDataType_t to_nccl_data_type(c10::ScalarType type) { return ncclDataType_t::ncclUint8; case at::kBool: return ncclDataType_t::ncclUint8; +#if defined(USE_ROCM) + case at::kFloat8_e4m3fnuz: + return ncclDataType_t::ncclUint8; + case at::kFloat8_e5m2fnuz: + return ncclDataType_t::ncclUint8; +#else + case at::kFloat8_e4m3fn: + return ncclDataType_t::ncclUint8; + case at::kFloat8_e5m2: + return ncclDataType_t::ncclUint8; +#endif + #if HAS_NCCL_BF16_DATATYPE case at::kBFloat16: return ncclDataType_t::ncclBfloat16; @@ -155,40 +168,35 @@ bool nccl_use_nonblocking() { return nccl_use_nonblocking_; } -static int _parse_nccl_nonblocking_timeout() { - const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT"); - int timeout = -1; - if (val) { - const std::string config(val); - timeout = std::stoi(config); - if (!nccl_use_nonblocking() && timeout > 0) { - TORCH_WARN( - "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false."); - timeout = -1; +// Default value: 30 minutes +static int nccl_nonblocking_timeout() { + static int timeout = -2; // -2 means not initialized + if (timeout == -2) { + const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT"); + if (val && strlen(val) > 0) { + timeout = strtol(val, nullptr, 0); + } else { + // Default value consistent with kBackendDefaultTimeout + timeout = 30 * 60; } } return timeout; } -static int nccl_nonblocking_timeout() { - static int timeout = _parse_nccl_nonblocking_timeout(); - return timeout; -} - static inline void NCCL_CHECK_TIMEOUT(ncclResult status, ncclComm_t comm) { #ifdef NCCL_HAS_COMM_NONBLOCKING ncclResult_t result = to_nccl_result(status); auto startTimepoint = std::chrono::steady_clock::now(); while (result == ncclInProgress) { - if (nccl_nonblocking_timeout() > 0) { - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - startTimepoint) - .count(); - if (timeElapsed > nccl_nonblocking_timeout()) { - throw std::runtime_error("NCCL timeout."); - } + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - startTimepoint) + .count(); + if (timeElapsed > nccl_nonblocking_timeout()) { + throw std::runtime_error( + "NCCL timeout when waiting for nonblocking call to become successful."); } + sched_yield(); // yield to other threads ncclCommGetAsyncError(to_nccl_comm(comm), &result); } if (result != ncclSuccess) { @@ -213,15 +221,15 @@ static inline void NCCL_CHECK_TIMEOUT( if (result == ncclInProgress) { for (const auto i : c10::irange(comms.size())) { do { - if (nccl_nonblocking_timeout() > 0) { - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - startTimepoint) - .count(); - if (timeElapsed > nccl_nonblocking_timeout()) { - throw std::runtime_error("NCCL timeout."); - } + auto currentTimepoint = std::chrono::steady_clock::now(); + auto timeElapsed = std::chrono::duration_cast( + currentTimepoint - startTimepoint) + .count(); + if (timeElapsed > nccl_nonblocking_timeout()) { + throw std::runtime_error( + "NCCL timeout when waiting for nonblocking call to become successful."); } + sched_yield(); // yield to other threads ncclCommGetAsyncError(to_nccl_comm(comms[i]), &result); } while (result == ncclInProgress); if (result != ncclSuccess) { @@ -263,7 +271,7 @@ struct NcclCommList { ~NcclCommList() { if (comms) { for (const auto i : c10::irange(ndevices)) { - int dummy_var; + int dummy_var = 0; if (C10_CUDA_ERROR_HANDLED(cudaGetDevice(&dummy_var)) != cudaSuccess) { /* there are cases when this destructor is called after the CUDA driver is already unloaded from the process. @@ -366,7 +374,7 @@ void check_inputs( auto dtype = inputs[0].scalar_type(); for (const auto i : c10::irange(len)) { - auto input = inputs[i]; + const auto& input = inputs[i]; auto output = outputs[i]; check_tensor( @@ -398,7 +406,7 @@ void check_inputs( auto dtype = inputs[0].scalar_type(); for (const auto i : c10::irange(len)) { - auto input = inputs[i]; + const auto& input = inputs[i]; check_tensor( input, @@ -421,25 +429,24 @@ void check_inputs( } // namespace detail -AutoNcclGroup::AutoNcclGroup() { +AutoNcclGroup::AutoNcclGroup() : comm_(nullptr), comm_nonblocking_(false) { #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2) // nccl < 2.0 cannot be called concurrently with cudaFree (c10::cuda::getFreeMutex())->lock(); #endif - comm_nonblocking_ = false; - comm_ = nullptr; + #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) detail::NCCL_CHECK(ncclGroupStart()); #endif } -AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) { +AutoNcclGroup::AutoNcclGroup(ncclComm_t comm, bool comm_nonblocking) + : comm_(comm), comm_nonblocking_(comm_nonblocking) { #if defined(NCCL_MAJOR) && (NCCL_MAJOR < 2) // nccl < 2.0 cannot be called concurrently with cudaFree (c10::cuda::getFreeMutex())->lock(); #endif - comm_ = comm; - comm_nonblocking_ = comm_nonblocking; + #if defined(NCCL_MAJOR) && (NCCL_MAJOR >= 2) detail::NCCL_CHECK(ncclGroupStart()); #endif @@ -503,14 +510,14 @@ void get_unique_id(ncclUniqueId& id) { using namespace torch::cuda::nccl::detail; NCCL_CHECK(ncclGetUniqueId(to_nccl_unique_id(&id))); #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } ncclComm_t comm_init_rank(int nranks, const ncclUniqueId& comm_id, int rank) { #ifdef USE_NCCL using namespace torch::cuda::nccl::detail; - ncclComm_t comm; + ncclComm_t comm = nullptr; ncclUniqueId id = comm_id; NCCL_CHECK(ncclCommInitRank( to_nccl_comm(&comm), nranks, *(to_nccl_unique_id(&id)), rank)); @@ -548,7 +555,7 @@ struct GetSecondArgType; template struct GetSecondArgType { - typedef typename std::decay::type type; + typedef std::decay_t type; }; constexpr auto count_max = @@ -558,12 +565,10 @@ constexpr auto count_max = // https://github.com/NVIDIA/nccl/issues/696. The issue of skipping send/recv // is that it can cause deadlock when a rank send and recv 0 bytes so it's // completely skipping the collective, causing mismatch across ranks -// Note: on AMD GPU, we're running into hang w/ 0 byte send/recv, so we're -// skipping this for ROCm for now -#if !defined(USE_ROCM) && defined(NCCL_MAJOR) && \ +#if defined(NCCL_MAJOR) && \ ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR > 13))) template -constexpr bool _nccl_should_send_recv(C10_UNUSED T _unused_) { +constexpr bool _nccl_should_send_recv([[maybe_unused]] T _unused_) { return true; } #else @@ -620,7 +625,7 @@ void broadcast( stream)); } #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -669,7 +674,7 @@ void reduce( stream)); } #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -720,7 +725,7 @@ void all_reduce( stream)); } #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -762,7 +767,7 @@ void reduce_scatter( stream)); } #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -812,7 +817,7 @@ void all_gather( #endif } #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -827,7 +832,7 @@ void all2all_single_equal_split( ((NCCL_MAJOR > 2) || ((NCCL_MAJOR == 2) && (NCCL_MINOR >= 7))) using namespace torch::cuda::nccl::detail; - int numranks; + int numranks = 0; auto type = to_nccl_data_type(input); size_t count = input.numel() / size; size_t rankdiff = input.nbytes() / size; @@ -857,10 +862,10 @@ void all2all_single_equal_split( #endif #endif #else - AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); + TORCH_CHECK(false, "all2all is only supported for NCCL lib version >= 2.7.0"); #endif #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -882,7 +887,7 @@ void all2all_single_unequal_split( auto type = to_nccl_data_type(_type); auto comm = to_nccl_comm(_comm); -#ifdef NCCL_ALLTOALLV_SUPPORTED +#if defined(USE_ROCM) || defined(NCCL_ALLTOALLV_SUPPORTED) // NCCL_ALLTOALLV_SUPPORTED is used so NCCL can differentiate send/recv // operations issued as a part of the collective (e.g. alltoallv) vs those // inside traditional p2p operations. @@ -897,7 +902,7 @@ void all2all_single_unequal_split( comm, stream.stream())); #else - int numranks; + int numranks = 0; NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclGroupStart()); for (const auto r : c10::irange(numranks)) { @@ -927,10 +932,10 @@ void all2all_single_unequal_split( #endif #endif #else - AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); + TORCH_CHECK(false, "all2all is only supported for NCCL lib version >= 2.7.0"); #endif #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -1018,10 +1023,10 @@ void all2all( #endif #endif #else - AT_ERROR("all2all is only supported for NCCL lib version >= 2.7.0"); + TORCH_CHECK(false, "all2all is only supported for NCCL lib version >= 2.7.0"); #endif #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -1054,10 +1059,10 @@ void send( comm); #endif #else - AT_ERROR("Send is only supported for NCCL lib version >= 2.7.0"); + TORCH_CHECK(false, "Send is only supported for NCCL lib version >= 2.7.0"); #endif #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -1090,10 +1095,10 @@ void recv( comm); #endif #else - AT_ERROR("Recv is only supported for NCCL lib version >= 2.7.0"); + TORCH_CHECK(false, "Recv is only supported for NCCL lib version >= 2.7.0"); #endif #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -1109,7 +1114,7 @@ void gather( using namespace torch::cuda::nccl::detail; auto comm = to_nccl_comm(_comm); - int numranks, cur_rank; + int numranks = 0, cur_rank = 0; NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclCommUserRank(comm, &cur_rank)); @@ -1139,10 +1144,10 @@ void gather( #endif #else - AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0"); + TORCH_CHECK(false, "gather is only supported for NCCL lib version >= 2.7.0"); #endif #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } @@ -1158,7 +1163,7 @@ void scatter( using namespace torch::cuda::nccl::detail; auto comm = to_nccl_comm(_comm); - int numranks, cur_rank; + int numranks = 0, cur_rank = 0; #ifndef NCCL_HAS_COMM_NONBLOCKING NCCL_CHECK(ncclCommCount(comm, &numranks)); NCCL_CHECK(ncclCommUserRank(comm, &cur_rank)); @@ -1192,10 +1197,10 @@ void scatter( NCCL_CHECK_TIMEOUT(ncclGroupEnd(), _comm); #endif #else - AT_ERROR("scatter is only supported for NCCL lib version >= 2.7.0"); + TORCH_CHECK(false, "scatter is only supported for NCCL lib version >= 2.7.0"); #endif #else - AT_ERROR("PyTorch built without NCCL support"); + TORCH_CHECK(false, "PyTorch built without NCCL support"); #endif } diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp index fe8d48c266fad..4f8dfd6456df8 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp +++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp @@ -366,7 +366,8 @@ void DistEngine::execute_graph_task_until_ready_queue_empty( // block and can be deallocated (release any references to grad tensors // as part of inputs_) NodeTask task = cpu_ready_queue->pop(); - if (!(local_graph_task = task.base_.lock())) { + local_graph_task = task.base_.lock(); + if (!local_graph_task) { continue; } if (task.fn_ && !local_graph_task->has_error_.load()) { @@ -629,11 +630,11 @@ size_t DistEngine::numBackwardPasses() const { return initializedContextIds_.size(); } -std::unordered_map DistEngine::getDebugInfo() const { - std::unordered_map debugInfo; - debugInfo[kNumBackwardPasses] = numBackwardPasses(); - debugInfo[kNumAutogradContexts] = - DistAutogradContainer::getInstance().numAutogradContexts(); +std::unordered_map DistEngine::getDebugInfo() const { + std::unordered_map debugInfo; + debugInfo[kNumBackwardPasses] = static_cast(numBackwardPasses()); + debugInfo[kNumAutogradContexts] = static_cast( + DistAutogradContainer::getInstance().numAutogradContexts()); return debugInfo; } diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.h b/torch/csrc/distributed/autograd/engine/dist_engine.h index 928a4cde282d1..362c78fa07b1f 100644 --- a/torch/csrc/distributed/autograd/engine/dist_engine.h +++ b/torch/csrc/distributed/autograd/engine/dist_engine.h @@ -52,7 +52,7 @@ class TORCH_API DistEngine { // Returns key-value pairs consisting of useful debugging information related // to distributed autograd. - std::unordered_map getDebugInfo() const; + std::unordered_map getDebugInfo() const; DistEngine(const DistEngine&) = delete; DistEngine& operator=(const DistEngine&) = delete; diff --git a/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp index a4da69153f26e..263bdf9eeb662 100644 --- a/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp +++ b/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp @@ -3,6 +3,7 @@ namespace torch::distributed::autograd { torch::autograd::variable_list SendRpcBackward::apply( + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) torch::autograd::variable_list&& inputs) { TORCH_INTERNAL_ASSERT( inputs.empty(), "SendRpcBackward should receive no inputs"); diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp index ee5b0029f2fff..df1d88cde4886 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp @@ -63,10 +63,8 @@ std::unique_ptr PropagateGradientsReq::fromMessage( bool retainGraph = tupleElements.back().toBool(); // Build AutogradMetadata. - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int64_t autogradContextId, autogradMessageId; - autogradMessageId = tupleElements[tupleElements.size() - 2].toInt(); - autogradContextId = tupleElements[tupleElements.size() - 3].toInt(); + int64_t autogradMessageId = tupleElements[tupleElements.size() - 2].toInt(); + int64_t autogradContextId = tupleElements[tupleElements.size() - 3].toInt(); AutogradMetadata autogradMetadata(autogradContextId, autogradMessageId); diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp index e41cb036ada9d..fd5ab54e58cfa 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp @@ -106,7 +106,7 @@ std::unique_ptr RpcWithAutograd::fromMessage( static_cast(tupleElements[0].toInt()); AutogradMetadata autogradMetadata( tupleElements[1].toInt(), tupleElements[2].toInt()); - worker_id_t workerId = tupleElements[3].toInt(); + worker_id_t workerId = static_cast(tupleElements[3].toInt()); auto c10DeviceMap = tupleElements[4].to>(); diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp index 90e15c3de612f..1fb1756d1b606 100644 --- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp +++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_profiling_resp.cpp @@ -116,7 +116,7 @@ std::unique_ptr RpcWithProfilingResp::fromMessage( rpc::MessageType wrappedMsgType = static_cast(tupleElements[0].toInt()); rpc::ProfilingId profilingId = rpc::ProfilingId::fromIValue(tupleElements[1]); - int profiledEventsSize = tupleElements[2].toInt(); + auto profiledEventsSize = tupleElements[2].toInt(); std::vector remoteEvents; remoteEvents.reserve(profiledEventsSize); for (const auto i : c10::irange( diff --git a/torch/csrc/distributed/c10d/Backoff.cpp b/torch/csrc/distributed/c10d/Backoff.cpp index a0ef2ba0b8b34..850cb45181b91 100644 --- a/torch/csrc/distributed/c10d/Backoff.cpp +++ b/torch/csrc/distributed/c10d/Backoff.cpp @@ -1,13 +1,12 @@ #include -#include #include namespace c10d { namespace { constexpr std::chrono::milliseconds kZeroInterval{0}; -int32_t randSeed() { +std::random_device::result_type randSeed() { std::random_device rd; return rd(); } @@ -47,7 +46,7 @@ std::chrono::milliseconds ExponentialBackoffWithJitter::nextBackoff() { std::chrono::milliseconds maxSampleInterval = currentInterval_ + randomization; - std::uniform_int_distribution<> dist( + std::uniform_int_distribution dist( minSampleInterval.count(), maxSampleInterval.count()); std::chrono::milliseconds backoffInterval{dist(gen_)}; diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h index f1cc296874fcf..ef2b61eb8ec34 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory-inl.h @@ -5,12 +5,11 @@ #endif #include - +#if !defined(USE_ROCM) +#include +#endif namespace c10d::symmetric_memory { -constexpr size_t max_num_threads_per_block = 1024; -constexpr size_t max_num_blocks = 8; - template __inline__ size_t get_alignment(T ptr_or_size) { auto val = reinterpret_cast(ptr_or_size); @@ -32,114 +31,164 @@ __inline__ size_t get_alignment(size_t size) { return get_alignment(reinterpret_cast(size)); } +template +inline constexpr bool dependent_bool_value = Value; + +template +inline constexpr bool dependent_false = dependent_bool_value; + +template +inline constexpr bool dependent_false_nt = + dependent_bool_value; + +enum class MemOpSem { + Relaxed, + Acquire, + Release, + AcqRel, +}; + +#define CAS_ASM(addr, compare, val, old_val, sem) \ + asm volatile("atom.global" sem ".sys.cas.b32 %0, [%1], %2, %3;" \ + : "=r"(old_val) \ + : "l"(addr), "r"(compare), "r"(val) \ + : "memory"); + +template __device__ __forceinline__ uint32_t -cas_sys(uint32_t* addr, uint32_t compare, uint32_t val) { +cas(uint32_t* addr, uint32_t compare, uint32_t val) { #if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) CUDA_KERNEL_ASSERT(false); #else uint32_t old_val; - asm volatile("atom.global.sys.cas.b32 %0, [%1], %2, %3;" - : "=r"(old_val) - : "l"(addr), "r"(compare), "r"(val) - : "memory"); + if constexpr (Sem == MemOpSem::Relaxed) { + CAS_ASM(addr, compare, val, old_val, ".relaxed"); + } else if constexpr (Sem == MemOpSem::Acquire) { + CAS_ASM(addr, compare, val, old_val, ".acquire"); + } else if constexpr (Sem == MemOpSem::Release) { + CAS_ASM(addr, compare, val, old_val, ".release"); + } else { + static_assert(dependent_false_nt); + } return old_val; #endif } -__device__ __forceinline__ uint32_t -cas_release_sys(uint32_t* addr, uint32_t compare, uint32_t val) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) +__device__ __forceinline__ void trap() { +#if defined(USE_ROCM) + assert(0); +#else + __trap(); +#endif +} + +__device__ __forceinline__ size_t global_timer_ns() { +#if defined(USE_ROCM) CUDA_KERNEL_ASSERT(false); + return 0; #else - uint32_t old_val; - asm volatile("atom.global.release.sys.cas.b32 %0, [%1], %2, %3;" - : "=r"(old_val) - : "l"(addr), "r"(compare), "r"(val) - : "memory"); - return old_val; + size_t val; + asm volatile("mov.u64 %0, %globaltimer;" : "=l"(val) : : "memory"); + return val; #endif } -__device__ __forceinline__ void release_signal(uint32_t* addr) { - while (cas_release_sys(addr, 0, 1) != 0) +constexpr size_t ns_per_ms = 1e6; + +template +__device__ __forceinline__ bool try_put_signal( + uint32_t* addr, + size_t timeout_ms) { + size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms; + while (cas(addr, 0, 1) != 0) { + if (timeout_ms != 0 && global_timer_ns() > deadline) { + return false; + } + } + return true; +} + +template +__device__ __forceinline__ bool try_wait_signal( + uint32_t* addr, + size_t timeout_ms) { + size_t deadline = global_timer_ns() + timeout_ms * ns_per_ms; + while (cas(addr, 1, 0) != 1) { + if (timeout_ms != 0 && global_timer_ns() > deadline) { + return false; + } + } + return true; +} + +template +__device__ __forceinline__ void put_signal(uint32_t* addr) { + while (cas(addr, 0, 1) != 0) ; } +template __device__ __forceinline__ void wait_signal(uint32_t* addr) { - while (cas_sys(addr, 1, 0) != 1) + while (cas(addr, 1, 0) != 1) ; } -__device__ __forceinline__ uint32_t acquire_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - uint32_t val; - asm volatile("ld.acquire.sys.global.u32 %0, [%1];" - : "=r"(val) - : "l"(addr) - : "memory"); - return val; -#endif -} +// Synchronizes blocks with matching blockIdx across participating devices. +// Note: sync_remote_block itself is not a system level barrier/fence. It is a +// building block for expressing different synchronization patterns. +// +// Pattern 0: Ensures that all writes to symm_mem buffers from previous +// kernels across all devices are visible to the current kernel: +// +// sync_remote_blocks(...); +// __syncthreads(); +// +// Pattern 1: Ensures that all writes to symm_mem buffers from the current +// block are visible to all remote blocks with matching blockIdx: +// +// __syncthreads(); +// sync_remote_blocks(...); +// __syncthreads(); +// +// Pattern 2: Ensures that symm_mem buffers read by the current kernel are safe +// for writing by subsequent kernels across all devices. +// +// __syncthreads(); +// sync_remote_blocks(...); +template +__device__ __forceinline__ void sync_remote_blocks( + uint32_t** signal_pads, + size_t rank, + size_t world_size); -// Perform a barrier to establish observation order between memory operations -// issued before and after the barrier. -__device__ __forceinline__ void barrier( +template <> +__device__ __forceinline__ void sync_remote_blocks( uint32_t** signal_pads, size_t rank, size_t world_size) { if (threadIdx.x < world_size) { auto target_rank = threadIdx.x; - release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank); - wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank); + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); } - __syncthreads(); } -// Perform a barrier and establish causality order between memory operations -// issued before the calling kernel on all devices and memory operations -// issued after this function by all thread in the calling kernel. -// -// NOTE: this function does NOT ensure that memory operations issues in the -// current kernel are visible to all threads in the current kernel. -// -// | mem ops (guaranteed to be visible by all threads at point T) -// | kernel K -// | +- mem ops (not guaranteed to be visible all threads at point T) -// | +- barrier_and_acquire_previous_kernel_writes() -// | +- point T -// v -__device__ __forceinline__ void barrier_and_acquire_previous_kernel_writes( +template <> +__device__ __forceinline__ void sync_remote_blocks( uint32_t** signal_pads, size_t rank, size_t world_size) { if (threadIdx.x < world_size) { auto target_rank = threadIdx.x; - release_signal(signal_pads[target_rank] + blockIdx.x * world_size + rank); - wait_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank); - } - __syncthreads(); - // At this point, we established observation order between memory operations - // issued before and after the barrier. Now we convert the observation order - // into causality order by having every thread acquire the signals released - // by threads on peer devices. Due to the implicit synchronizes-with - // relationships at task/kernel boundaries, acquiring the signal released by - // thread T in kernel K transitively acquires memory operations issued prior - // to kernel K. - // - // https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#memory-fence-interference - for (size_t target_rank = 0; target_rank < world_size; ++target_rank) { - acquire_signal(signal_pads[rank] + blockIdx.x * world_size + target_rank); + put_signal( + signal_pads[target_rank] + blockIdx.x * world_size + rank); + wait_signal( + signal_pads[rank] + blockIdx.x * world_size + target_rank); } } -template -inline constexpr bool dependent_bool_value = Value; - -template -inline constexpr bool dependent_false = dependent_bool_value; - template union Vec; @@ -147,6 +196,7 @@ template <> union Vec<4> { uint16_t u16[2]; uint32_t u32, as_scalar; + float f32; }; template <> @@ -154,6 +204,7 @@ union Vec<8> { uint16_t u16[4]; uint32_t u32[2]; uint64_t u64, as_scalar; + float f32[2]; }; template <> @@ -162,6 +213,7 @@ union alignas(16) Vec<16> { uint32_t u32[4]; uint64_t u64[2]; uint4 u128, as_scalar; + float f32[4]; }; template @@ -179,49 +231,50 @@ __device__ __inline__ Vec multimem_ld_reduce_add(T* mc_ptr) { } #if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST) -#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \ - template <> \ - struct MultimemLdReduce { \ - template \ - __device__ __inline__ Vec operator()(type* mc_ptr) { \ - CUDA_KERNEL_ASSERT(false); \ - } \ +#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type, acc_prec) \ + template <> \ + struct MultimemLdReduce { \ + template \ + __device__ __inline__ Vec operator()(type* mc_ptr) { \ + CUDA_KERNEL_ASSERT(false); \ + } \ }; #else -#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type) \ - template <> \ - struct MultimemLdReduce { \ - template \ - __device__ __inline__ Vec operator()(type* mc_ptr) { \ - Vec vec; \ - if constexpr (Alignment == 16) { \ - asm("multimem.ld_reduce.relaxed.sys.global.add.v4." asm_type \ - " {%0,%1,%2,%3}, [%4];" \ - : "=r"(vec.u32[0]), \ - "=r"(vec.u32[1]), \ - "=r"(vec.u32[2]), \ - "=r"(vec.u32[3]) \ - : "l"(mc_ptr) \ - : "memory"); \ - } else if constexpr (Alignment == 8) { \ - asm("multimem.ld_reduce.relaxed.sys.global.add.v2." asm_type \ - " {%0,%1}, [%2];" \ - : "=r"(vec.u32[0]), "=r"(vec.u32[1]) \ - : "l"(mc_ptr) \ - : "memory"); \ - } else if constexpr (Alignment == 4) { \ - asm("multimem.ld_reduce.relaxed.sys.global.add." asm_type " %0, [%1];" \ - : "=r"(vec.u32) \ - : "l"(mc_ptr) \ - : "memory"); \ - } \ - return vec; \ - } \ +#define SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(type, asm_type, acc_prec) \ + template <> \ + struct MultimemLdReduce { \ + template \ + __device__ __inline__ Vec operator()(type* mc_ptr) { \ + Vec vec; \ + if constexpr (Alignment == 16) { \ + asm("multimem.ld_reduce.relaxed.sys.global.add" acc_prec \ + ".v4" asm_type " {%0,%1,%2,%3}, [%4];" \ + : "=r"(vec.u32[0]), \ + "=r"(vec.u32[1]), \ + "=r"(vec.u32[2]), \ + "=r"(vec.u32[3]) \ + : "l"(mc_ptr) \ + : "memory"); \ + } else if constexpr (Alignment == 8) { \ + asm("multimem.ld_reduce.relaxed.sys.global.add" acc_prec \ + ".v2" asm_type " {%0,%1}, [%2];" \ + : "=r"(vec.u32[0]), "=r"(vec.u32[1]) \ + : "l"(mc_ptr) \ + : "memory"); \ + } else if constexpr (Alignment == 4) { \ + asm("multimem.ld_reduce.relaxed.sys.global.add" acc_prec asm_type \ + " %0, [%1];" \ + : "=r"(vec.u32) \ + : "l"(mc_ptr) \ + : "memory"); \ + } \ + return vec; \ + } \ }; #endif -SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, "bf16x2"); -SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, "f32"); +SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(at::BFloat16, ".bf16x2", ".acc::f32"); +SPECIALIZE_MULTIMEM_LD_REDUCE_VEC_32(float, ".f32", ""); template __device__ __inline__ void multimem_st(T* mc_ptr, Vec& vec) { @@ -253,4 +306,145 @@ __device__ __inline__ void multimem_st(T* mc_ptr, Vec& vec) { #endif } +template +__device__ __inline__ Vec ld_vec(const T* addr) { +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); +#else + Vec vec; + if constexpr (Alignment == 16) { + asm("ld.global.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(vec.u32[0]), "=r"(vec.u32[1]), "=r"(vec.u32[2]), "=r"(vec.u32[3]) + : "l"(addr) + : "memory"); + } else if constexpr (Alignment == 8) { + asm("ld.global.v2.u32 {%0,%1}, [%2];" + : "=r"(vec.u32[0]), "=r"(vec.u32[1]) + : "l"(addr) + : "memory"); + } else if constexpr (Alignment == 4) { + asm("ld.global.u32 %0, [%1];" : "=r"(vec.u32) : "l"(addr) : "memory"); + } else { + static_assert(dependent_false); + } + return vec; +#endif +} + +template +__device__ __inline__ void st_vec(T* addr, const Vec& vec) { +#if defined(USE_ROCM) || !defined(NVCC_SUPPORTS_MULTICAST) + CUDA_KERNEL_ASSERT(false); +#else + if constexpr (Alignment == 16) { + asm("st.global.v4.u32 [%0], {%1,%2,%3,%4};" + : + : "l"(addr), + "r"(vec.u32[0]), + "r"(vec.u32[1]), + "r"(vec.u32[2]), + "r"(vec.u32[3]) + : "memory"); + } else if constexpr (Alignment == 8) { + asm("st.global.v2.u32 [%0], {%1,%2};" + : + : "l"(addr), "r"(vec.u32[0]), "r"(vec.u32[1]) + : "memory"); + } else if constexpr (Alignment == 4) { + asm("st.global.u32 [%0], %1;" : : "l"(addr), "r"(vec.u32) : "memory"); + } else { + static_assert(dependent_false); + } +#endif +} + +#if defined(USE_ROCM) +using __nv_bfloat162 = uint32_t; +#endif + +template +__device__ __inline__ T add_bf16x2(T a, T b) { + static_assert(sizeof(T) == 4); +#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) + CUDA_KERNEL_ASSERT(false); + return T{}; +#else + auto res = __hadd2( + *reinterpret_cast<__nv_bfloat162*>(&a), + *reinterpret_cast<__nv_bfloat162*>(&b)); + return *reinterpret_cast(&res); +#endif +} + +template +__device__ __inline__ Vec add_vec( + const Vec& a, + const Vec& b) { + Vec c{}; + if constexpr (std::is_same_v) { + if constexpr (Alignment == 16) { + c.f32[0] = a.f32[0] + b.f32[0]; + c.f32[1] = a.f32[1] + b.f32[1]; + c.f32[2] = a.f32[2] + b.f32[2]; + c.f32[3] = a.f32[3] + b.f32[3]; + } else if constexpr (Alignment == 8) { + c.f32[0] = a.f32[0] + b.f32[0]; + c.f32[1] = a.f32[1] + b.f32[1]; + } else if constexpr (Alignment == 4) { + c.f32 = a.f32 + b.f32; + } else { + static_assert(dependent_false); + } + } else if constexpr (std::is_same_v) { + if constexpr (Alignment == 16) { + c.u32[0] = add_bf16x2(a.u32[0], b.u32[0]); + c.u32[1] = add_bf16x2(a.u32[1], b.u32[1]); + c.u32[2] = add_bf16x2(a.u32[2], b.u32[2]); + c.u32[3] = add_bf16x2(a.u32[3], b.u32[3]); + } else if constexpr (Alignment == 8) { + c.u32[0] = add_bf16x2(a.u32[0], b.u32[0]); + c.u32[1] = add_bf16x2(a.u32[1], b.u32[1]); + } else if constexpr (Alignment == 4) { + c.u32 = add_bf16x2(a.u32, b.u32); + } else { + static_assert(dependent_false); + } + } else { + static_assert(dependent_false); + } + return c; +} + +// With world_size specialization: perform balanced load from all peers before +// performing reduction. +template +__device__ inline std::enable_if_t<(k_world_size > 0), Vec> +load_and_reduce(T** ptrs, size_t rank, size_t world_size, size_t offset) { + Vec vecs[k_world_size]; +#pragma unroll k_world_size + for (size_t step = 0; step < k_world_size; ++step) { + size_t remote_rank = (rank + step) % k_world_size; + vecs[remote_rank] = ld_vec(ptrs[remote_rank] + offset); + } + auto acc = vecs[0]; +#pragma unroll k_world_size - 1 + for (size_t r = 1; r < world_size; ++r) { + acc = add_vec(acc, vecs[r]); + } + return acc; +} + +// Without world_size specialization: perform ordered (unbalanced) load and +// accumulate on each load. +template +__device__ inline std::enable_if_t<(k_world_size <= 0), Vec> +load_and_reduce(T** ptrs, size_t rank, size_t world_size, size_t offset) { + Vec acc{}; + for (size_t step = 0; step < world_size; ++step) { + auto vec = ld_vec(ptrs[step] + offset); + acc = add_vec(acc, vec); + } + return acc; +} + } // namespace c10d::symmetric_memory diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu index 3ad6f5b6e6619..b535d51bfe12e 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.cu @@ -1,5 +1,7 @@ #include +#include + #include #include #include @@ -360,8 +362,11 @@ at::Tensor CUDASymmetricMemory::get_buffer( c10::IntArrayRef sizes, c10::ScalarType dtype, int64_t storage_offset) { - const auto numel = - std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies()); + const size_t numel = std::accumulate( + sizes.begin(), + sizes.end(), + static_cast(1), + std::multiplies()); const auto element_size = c10::elementSize(dtype); const auto req_size = (numel + storage_offset) * element_size; TORCH_CHECK( @@ -371,10 +376,11 @@ at::Tensor CUDASymmetricMemory::get_buffer( " bytes) exceeds the allocated size (", buffer_size_, " bytes)"); + auto data_ptr = reinterpret_cast(buffers_[rank]) + + storage_offset * element_size; auto device = c10::Device(c10::DeviceType::CUDA, local_device_idx_); auto options = at::TensorOptions().dtype(dtype).device(device); - return at::for_blob(buffers_[rank], sizes) - .storage_offset(storage_offset) + return at::for_blob(data_ptr, sizes) .options(options) .target_device(device) .make_tensor(); @@ -397,50 +403,53 @@ void check_channel(int channel, int world_size) { ")"); } -__device__ __forceinline__ void release_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 0 || atomicCAS_system(addr, 0, 1) != 0); -#endif -} - -__device__ __forceinline__ void acquire_signal(uint32_t* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - volatile uint32_t* signal = addr; - uint32_t val; - do { - val = *signal; - } while (val != 1 || atomicCAS_system(addr, 1, 0) != 1); -#endif -} - static __global__ void barrier_kernel( uint32_t** signal_pads, int channel, int rank, - int world_size) { + int world_size, + size_t timeout_ms) { if (threadIdx.x < world_size) { auto target_rank = threadIdx.x; - release_signal(signal_pads[target_rank] + world_size * channel + rank); - acquire_signal(signal_pads[rank] + world_size * channel + target_rank); + if (target_rank == rank) { + return; + } + auto put_success = try_put_signal( + signal_pads[target_rank] + world_size * channel + rank, timeout_ms); + if (!put_success) { + printf( + "[FATAL] CUDASymmetricMemory::barrier: rank %d failed to send signal " + "to rank %d on channel %d after %lu microseconds\n", + rank, + target_rank, + channel, + timeout_ms); + trap(); + } + auto wait_success = try_wait_signal( + signal_pads[rank] + world_size * channel + target_rank, timeout_ms); + if (!wait_success) { + printf( + "[FATAL] CUDASymmetricMemory::barrier: rank %d failed to receive signal " + "from rank %d on channel %d after %lu microseconds\n", + rank, + target_rank, + channel, + timeout_ms); + trap(); + } } } -void CUDASymmetricMemory::barrier(int channel) { +void CUDASymmetricMemory::barrier(int channel, size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); barrier_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( reinterpret_cast(signal_pads_dev_), channel, rank_, - world_size_); + world_size_, + timeout_ms); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -449,13 +458,28 @@ static __global__ void put_signal_kernel( int dst_rank, int channel, int rank, - int world_size) { + int world_size, + size_t timeout_ms) { if (threadIdx.x == 0) { - release_signal(signal_pads[dst_rank] + world_size * channel + rank); + bool success = try_put_signal( + signal_pads[dst_rank] + world_size * channel + rank, timeout_ms); + if (!success) { + printf( + "[FATAL] CUDASymmetricMemory::put_signal: rank %d failed to send signal " + "to rank %d on channel %d after %lu microseconds\n", + rank, + dst_rank, + channel, + timeout_ms); + trap(); + } } } -void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { +void CUDASymmetricMemory::put_signal( + int dst_rank, + int channel, + size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); put_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( @@ -463,7 +487,8 @@ void CUDASymmetricMemory::put_signal(int dst_rank, int channel) { dst_rank, channel, rank_, - world_size_); + world_size_, + timeout_ms); C10_CUDA_KERNEL_LAUNCH_CHECK(); } @@ -472,14 +497,33 @@ static __global__ void wait_signal_kernel( int src_rank, int channel, int rank, - int world_size) { + int world_size, + size_t timeout_ms) { if (threadIdx.x == 0) { - acquire_signal(signal_pads[rank] + world_size * channel + src_rank); + bool success = try_wait_signal( + signal_pads[rank] + world_size * channel + src_rank, timeout_ms); + if (!success) { + printf( + "[FATAL] CUDASymmetricMemory::wait_signal rank %d failed to receive signal " + "from rank %d on channel %d after %lu microseconds\n", + rank, + src_rank, + channel, + timeout_ms); +#if !defined(USE_ROCM) + __trap(); +#else + assert(0); +#endif + } } __threadfence_system(); } -void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { +void CUDASymmetricMemory::wait_signal( + int src_rank, + int channel, + size_t timeout_ms) { check_channel(channel, world_size_); c10::cuda::CUDAGuard guard(local_device_idx_); wait_signal_kernel<<<1, C10_WARP_SIZE, 0, at::cuda::getCurrentCUDAStream()>>>( @@ -487,7 +531,8 @@ void CUDASymmetricMemory::wait_signal(int src_rank, int channel) { src_rank, channel, rank_, - world_size_); + world_size_, + timeout_ms); C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp index 019f4532716d0..4fa907a952881 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemory.hpp @@ -45,9 +45,9 @@ class CUDASymmetricMemory : public SymmetricMemory { c10::ScalarType dtype, int64_t storage_offset) override; - void barrier(int channel) override; - void put_signal(int dst_rank, int channel) override; - void wait_signal(int src_rank, int channel) override; + void barrier(int channel, size_t timeout_ms) override; + void put_signal(int dst_rank, int channel, size_t timeout_ms) override; + void wait_signal(int src_rank, int channel, size_t timeout_ms) override; int get_rank() override; int get_world_size() override; diff --git a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu index cedcca2c97612..8a77632723822 100644 --- a/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu +++ b/torch/csrc/distributed/c10d/CUDASymmetricMemoryOps.cu @@ -16,6 +16,39 @@ #include #include +#define INT_SWITCH_CASE(name, val, ...) \ + case val: { \ + constexpr int name = val; \ + __VA_ARGS__(); \ + break; \ + } + +#define DISPATCH_WORLD_SIZES(world_size, ...) \ + switch (world_size) { \ + INT_SWITCH_CASE(k_world_size, 8, __VA_ARGS__); \ + INT_SWITCH_CASE(k_world_size, 4, __VA_ARGS__); \ + INT_SWITCH_CASE(k_world_size, 2, __VA_ARGS__); \ + default: { \ + constexpr int k_world_size = -1; \ + __VA_ARGS__(); \ + } \ + } + +#define DISPATCH_ALIGNMENTS_16_8_4(alignment, ...) \ + switch (alignment) { \ + INT_SWITCH_CASE(k_alignment, 16, __VA_ARGS__); \ + INT_SWITCH_CASE(k_alignment, 8, __VA_ARGS__); \ + INT_SWITCH_CASE(k_alignment, 4, __VA_ARGS__); \ + default: { \ + TORCH_CHECK(false, "Not implemented for aligment=", alignment); \ + } \ + } + +#define AT_DISPATCH_FLOAT_AND_BFLOAT16(scalar_type, name, ...) \ + AT_DISPATCH_SWITCH( \ + scalar_type, name, AT_DISPATCH_CASE(at::kBFloat16, __VA_ARGS__); \ + AT_DISPATCH_CASE(at::kFloat, __VA_ARGS__)); + namespace { using namespace c10d::symmetric_memory; @@ -53,6 +86,8 @@ void init_elementwise_launch_config( size_t element_size, size_t alignment, size_t splits, + size_t max_num_blocks, + size_t max_num_threads, int& num_blocks, int& num_threads) { // Align to preserve alignment in each split @@ -60,17 +95,16 @@ void init_elementwise_launch_config( const size_t numel_per_split = aligned_numel / splits; const size_t numel_per_thread = alignment / element_size; - if (numel_per_split <= max_num_threads_per_block * numel_per_thread) { + if (numel_per_split <= max_num_threads * numel_per_thread) { num_blocks = 1; num_threads = at::round_up( at::ceil_div(numel_per_split, numel_per_thread), static_cast(C10_WARP_SIZE)); } else { num_blocks = std::min( - at::ceil_div( - numel_per_split, max_num_threads_per_block * numel_per_thread), + at::ceil_div(numel_per_split, max_num_threads * numel_per_thread), max_num_blocks); - num_threads = max_num_threads_per_block; + num_threads = max_num_threads; } } @@ -84,7 +118,8 @@ static __global__ void multimem_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); const size_t numel_per_rank = at::round_up(numel, alignment * world_size) / world_size; @@ -99,11 +134,9 @@ static __global__ void multimem_all_reduce_kernel( auto vec = multimem_ld_reduce_add(input_mc_ptr + start + i); multimem_st(input_mc_ptr + start + i, vec); } - // Establish observation order - all writes are in-flight beyond this point. - barrier(signal_pads, rank, world_size); - // Establish causality order - all writes are visible to all devices beyond - // this point. - __threadfence_system(); + + __syncthreads(); + sync_remote_blocks(signal_pads, rank, world_size); } at::Tensor multimem_all_reduce_( @@ -133,36 +166,29 @@ at::Tensor multimem_all_reduce_( input.element_size(), alignment, symm_mem->get_world_size(), + 8, + 1024, num_blocks, num_threads); -#define DISPATCH(scalar_t, kernel_alignment) \ - if (alignment == kernel_alignment) { \ - multimem_all_reduce_kernel \ - <<>>( \ - reinterpret_cast(symm_mem->get_multicast_ptr()) + \ - input.storage_offset(), \ - input.numel(), \ - reinterpret_cast(symm_mem->get_signal_pad_ptrs_dev()), \ - symm_mem->get_rank(), \ - symm_mem->get_world_size()); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } - - AT_DISPATCH_SWITCH( - input.scalar_type(), - "multimem_all_reduce", - AT_DISPATCH_CASE(at::kBFloat16, [&] { - DISPATCH(scalar_t, 16); - DISPATCH(scalar_t, 8); - DISPATCH(scalar_t, 4); - }) AT_DISPATCH_CASE(at::kFloat, [&] { - DISPATCH(scalar_t, 16); - DISPATCH(scalar_t, 8); - DISPATCH(scalar_t, 4); - })); - -#undef DISPATCH + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "multimem_all_reduce_", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + multimem_all_reduce_kernel + <<>>( + reinterpret_cast(symm_mem->get_multicast_ptr()) + + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); return input; } @@ -177,23 +203,34 @@ static __global__ void multimem_one_shot_all_reduce_kernel( static_assert(alignment % sizeof(T) == 0); constexpr size_t numel_per_thread = alignment / sizeof(T); - barrier_and_acquire_previous_kernel_writes(signal_pads, rank, world_size); + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; auto stride = blockDim.x * gridDim.x * numel_per_thread; for (size_t i = offset; i < numel; i += stride) { auto vec = multimem_ld_reduce_add(input_mc_ptr + i); - *reinterpret_cast(output_ptr + i) = vec.as_scalar; + st_vec(output_ptr + i, vec); } + + __syncthreads(); + sync_remote_blocks(signal_pads, rank, world_size); } -at::Tensor multimem_one_shot_all_reduce( +at::Tensor multimem_one_shot_all_reduce_out( const at::Tensor& input, std::string reduce_op, - std::string group_name) { + std::string group_name, + at::Tensor out) { TORCH_CHECK( input.is_contiguous(), "multimem_one_shot_all_reduce: input must be contiguous."); + TORCH_CHECK( + out.is_contiguous(), + "multimem_one_shot_all_reduce: output must be contiguous."); + TORCH_CHECK( + out.sizes() == input.sizes(), + "multimem_one_shot_all_reduce: input/output size mismatch."); TORCH_CHECK( reduce_op == "sum", "multimem_one_shot_all_reduce: only sum is supported for now."); @@ -206,8 +243,6 @@ at::Tensor multimem_one_shot_all_reduce( symm_mem->has_multicast_support(), "multimem_one_shot_all_reduce: requires multicast support."); - auto output = at::empty_like(input); - const size_t alignment = get_and_verify_alignment(input, "multimem_one_shot_all_reduce"); @@ -217,49 +252,276 @@ at::Tensor multimem_one_shot_all_reduce( input.element_size(), alignment, 1, + 8, + 1024, num_blocks, num_threads); -#define DISPATCH(scalar_t, kernel_alignment) \ - if (alignment == kernel_alignment) { \ - multimem_one_shot_all_reduce_kernel \ - <<>>( \ - reinterpret_cast(symm_mem->get_multicast_ptr()) + \ - input.storage_offset(), \ - output.data_ptr(), \ - input.numel(), \ - reinterpret_cast(symm_mem->get_signal_pad_ptrs_dev()), \ - symm_mem->get_rank(), \ - symm_mem->get_world_size()); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "multimem_one_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + multimem_one_shot_all_reduce_kernel + <<>>( + reinterpret_cast(symm_mem->get_multicast_ptr()) + + input.storage_offset(), + out.data_ptr(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + return out; +} + +at::Tensor multimem_one_shot_all_reduce( + const at::Tensor& input, + std::string reduce_op, + std::string group_name) { + auto out = at::empty_like(input); + return multimem_one_shot_all_reduce_out(input, reduce_op, group_name, out); +} + +// One-shot all-reduce is register-intensive because it stages values loaded +// from peers in registers before performing reduction. Setting the thread +// count to 512 to prevent/alleviate register spill. +constexpr size_t one_shot_all_reduce_max_num_blocks = 8; +constexpr size_t one_shot_all_reduce_max_num_threads = 512; + +template +static __launch_bounds__(one_shot_all_reduce_max_num_threads) __global__ + void one_shot_all_reduce_kernel( + T** input_ptrs, + T* output_ptr, + size_t input_offset, + size_t numel, + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + static_assert(alignment % sizeof(T) == 0); + constexpr size_t numel_per_thread = alignment / sizeof(T); + + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); + + auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; + auto stride = blockDim.x * gridDim.x * numel_per_thread; + + for (size_t i = offset; i < numel; i += stride) { + auto vec = load_and_reduce( + input_ptrs, rank, world_size, input_offset + i); + st_vec(output_ptr + i, vec); } - AT_DISPATCH_SWITCH( - input.scalar_type(), - "multimem_all_reduce", - AT_DISPATCH_CASE(at::kBFloat16, [&] { - DISPATCH(scalar_t, 16); - DISPATCH(scalar_t, 8); - DISPATCH(scalar_t, 4); - }) AT_DISPATCH_CASE(at::kFloat, [&] { - DISPATCH(scalar_t, 16); - DISPATCH(scalar_t, 8); - DISPATCH(scalar_t, 4); - })); - - return output; + __syncthreads(); + sync_remote_blocks(signal_pads, rank, world_size); +} + +at::Tensor one_shot_all_reduce_out( + const at::Tensor& input, + std::string reduce_op, + std::string group_name, + at::Tensor out) { + TORCH_CHECK( + input.is_contiguous(), "one_shot_all_reduce: input must be contiguous."); + TORCH_CHECK( + out.is_contiguous(), "one_shot_all_reduce: output must be contiguous."); + TORCH_CHECK( + out.sizes() == input.sizes(), + "one_shot_all_reduce: input/output size mismatch."); + TORCH_CHECK( + reduce_op == "sum", + "one_shot_all_reduce: only sum is supported for now."); + + auto symm_mem = c10d::symmetric_memory::rendezvous(input); + TORCH_CHECK( + symm_mem != nullptr, + "one_shot_all_reduce: input must be allocated with empty_strided_p2p()."); + + const size_t alignment = + get_and_verify_alignment(input, "one_shot_all_reduce"); + + int num_blocks = 0, num_threads = 0; + init_elementwise_launch_config( + input.numel(), + input.element_size(), + alignment, + 1, + one_shot_all_reduce_max_num_blocks, + one_shot_all_reduce_max_num_threads, + num_blocks, + num_threads); + + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "one_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() { + one_shot_all_reduce_kernel + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + out.data_ptr(), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + return out; +} + +at::Tensor one_shot_all_reduce( + const at::Tensor& input, + std::string reduce_op, + std::string group_name) { + auto out = at::empty_like(input); + return one_shot_all_reduce_out(input, reduce_op, group_name, out); +} + +constexpr size_t two_shot_all_reduce_max_num_blocks = 24; +constexpr size_t two_shot_all_reduce_max_num_threads = 512; + +template +static __launch_bounds__(two_shot_all_reduce_max_num_threads) __global__ + void two_shot_all_reduce_kernel( + T** input_ptrs, + size_t input_offset, + size_t numel, + uint32_t** signal_pads, + size_t rank, + size_t world_size) { + static_assert(alignment % sizeof(T) == 0); + constexpr size_t numel_per_thread = alignment / sizeof(T); + + sync_remote_blocks(signal_pads, rank, world_size); + __syncthreads(); + + const size_t numel_per_rank = + at::round_up(numel, alignment * world_size) / world_size; + const size_t start = numel_per_rank * rank; + + auto offset = (blockDim.x * blockIdx.x + threadIdx.x) * numel_per_thread; + auto stride = blockDim.x * gridDim.x * numel_per_thread; + for (size_t i = offset; i < numel_per_rank; i += stride) { + if (start + i >= numel) { + continue; + } + auto vec = load_and_reduce( + input_ptrs, rank, world_size, input_offset + start + i); + for (size_t step = 0; step < world_size; ++step) { + size_t remote_rank = (rank + step) % world_size; + st_vec( + input_ptrs[remote_rank] + input_offset + start + i, vec); + } + } + + __syncthreads(); + sync_remote_blocks(signal_pads, rank, world_size); +} + +at::Tensor two_shot_all_reduce_( + at::Tensor input, + std::string reduce_op, + std::string group_name) { + TORCH_CHECK( + input.is_contiguous(), "two_shot_all_reduce: input must be contiguous."); + TORCH_CHECK( + reduce_op == "sum", + "two_shot_all_reduce: only sum is supported for now."); + + auto symm_mem = c10d::symmetric_memory::rendezvous(input); + TORCH_CHECK( + symm_mem != nullptr, + "two_shot_all_reduce: input must be allocated with empty_strided_p2p()."); + + const size_t alignment = + get_and_verify_alignment(input, "two_shot_all_reduce"); + + int num_blocks = 0, num_threads = 0; + init_elementwise_launch_config( + input.numel(), + input.element_size(), + alignment, + symm_mem->get_world_size(), + two_shot_all_reduce_max_num_blocks, + two_shot_all_reduce_max_num_threads, + num_blocks, + num_threads); + + AT_DISPATCH_FLOAT_AND_BFLOAT16( + input.scalar_type(), "two_shot_all_reduce", [&]() { + DISPATCH_ALIGNMENTS_16_8_4(alignment, [&]() { + DISPATCH_WORLD_SIZES(symm_mem->get_world_size(), [&]() { + two_shot_all_reduce_kernel + <<>>( + reinterpret_cast( + symm_mem->get_buffer_ptrs_dev()), + input.storage_offset(), + input.numel(), + reinterpret_cast( + symm_mem->get_signal_pad_ptrs_dev()), + symm_mem->get_rank(), + symm_mem->get_world_size()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + return input; } TORCH_LIBRARY_FRAGMENT(symm_mem, m) { m.def( - "multimem_all_reduce_(Tensor input, str reduce_op, str group_name) -> Tensor", + "multimem_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)", torch::dispatch(c10::DispatchKey::CUDA, ::multimem_all_reduce_), {at::Tag::pt2_compliant_tag}); + // NOTE: [multimem_one_shot_all_reduce] + // multimem.ld_reduce does not guarantee a fixed accumulation order. This + // means that while multimem_one_shot_all_reduce is faster and has higher + // numerical accuracy than one_shot_all_reduce, it doesn't guarantee + // identical results across ranks. There may be use cases that can take + // advantage of this property, but it should not be used without + // understanding the caveats. m.def( "multimem_one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor", torch::dispatch(c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce), {at::Tag::pt2_compliant_tag}); + + m.def( + "multimem_one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)", + torch::dispatch( + c10::DispatchKey::CUDA, ::multimem_one_shot_all_reduce_out), + {at::Tag::pt2_compliant_tag}); + + m.def( + "one_shot_all_reduce(Tensor input, str reduce_op, str group_name) -> Tensor", + torch::dispatch(c10::DispatchKey::CUDA, ::one_shot_all_reduce), + {at::Tag::pt2_compliant_tag}); + + m.def( + "one_shot_all_reduce_out(Tensor input, str reduce_op, str group_name, Tensor(a!) out) -> Tensor(a!)", + torch::dispatch(c10::DispatchKey::CUDA, ::one_shot_all_reduce_out), + {at::Tag::pt2_compliant_tag}); + + m.def( + "two_shot_all_reduce_(Tensor(a!) input, str reduce_op, str group_name) -> Tensor(a!)", + torch::dispatch(c10::DispatchKey::CUDA, ::two_shot_all_reduce_), + {at::Tag::pt2_compliant_tag}); } } // namespace diff --git a/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp b/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp index afb39bdff92e8..1ed72a9aa116a 100644 --- a/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp +++ b/torch/csrc/distributed/c10d/CudaDMAConnectivity.cpp @@ -12,6 +12,7 @@ namespace { constexpr int max_nvlinks = 64; std::string get_bus_id(int device_idx) { + // NOLINTNEXTLINE(*array*) char bus_id[80]; cudaDeviceProp prop{}; C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device_idx)); @@ -27,7 +28,7 @@ std::string get_bus_id(int device_idx) { struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { c10::intrusive_ptr detect() override { - int num_devices; + int num_devices = 0; C10_CUDA_CHECK(cudaGetDeviceCount(&num_devices)); std::vector> matrix; @@ -46,22 +47,36 @@ struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { bus_ids.push_back(std::move(bus_id)); } - // Obtain the nvml device for all bus_ids + static const char* warning_msg = + "PyTorch features that use NVLinkDetector may assume no NVLink presence."; + auto driver_api = c10::cuda::DriverAPI::get(); + if (driver_api->nvmlInit_v2_() != NVML_SUCCESS) { + LOG(WARNING) + << "NVLinkDetector: Failed to initialize NVML via nvmlInit_v2. " + << warning_msg; + return c10::make_intrusive( + c10::DeviceType::CUDA, "nvlink", std::move(matrix)); + } + + // Obtain the nvml device for all bus_ids std::vector nvml_devices(num_devices, nullptr); for (int i = 0; i < num_devices; ++i) { - TORCH_CHECK_EQ( - driver_api->nvmlDeviceGetHandleByPciBusId_v2_( - bus_ids[i].c_str(), &nvml_devices[i]), - NVML_SUCCESS); + auto res = driver_api->nvmlDeviceGetHandleByPciBusId_v2_( + bus_ids[i].c_str(), &nvml_devices[i]); + if (res != NVML_SUCCESS) { + LOG(WARNING) << "NVLinkDetector: Failed to obtain NVML device via " + << "nvmlDeviceGetHandleByPciBusId_v2. " << warning_msg; + return c10::make_intrusive( + c10::DeviceType::CUDA, "nvlink", std::move(matrix)); + } } std::vector switch_link_count(num_devices, 0); for (int i = 0; i < num_devices; ++i) { for (int link = 0; link < max_nvlinks; ++link) { - nvmlReturn_t ret; - nvmlIntNvLinkDeviceType_t deviceType; - ret = driver_api->nvmlDeviceGetNvLinkRemoteDeviceType_( + nvmlIntNvLinkDeviceType_t deviceType{}; + auto ret = driver_api->nvmlDeviceGetNvLinkRemoteDeviceType_( nvml_devices[i], link, &deviceType); if (ret != NVML_SUCCESS) { // We've exhausted the NVLinks connected to this device. This error @@ -74,10 +89,14 @@ struct C10_EXPORT NVLinkDetector : public c10d::DMAConnectivityDetector { // Remote device is GPU if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) { nvmlPciInfo_t pciInfo; - TORCH_CHECK_EQ( - driver_api->nvmlDeviceGetNvLinkRemotePciInfo_v2_( - nvml_devices[i], link, &pciInfo), - NVML_SUCCESS); + auto res = driver_api->nvmlDeviceGetNvLinkRemotePciInfo_v2_( + nvml_devices[i], link, &pciInfo); + if (res != NVML_SUCCESS) { + LOG(WARNING) << "NVLinkDetector: Failed to obtain NVML device via " + << "nvmlDeviceGetHandleByPciBusId_v2. " << warning_msg; + return c10::make_intrusive( + c10::DeviceType::CUDA, "nvlink", std::move(matrix)); + } auto it = bus_id_to_device_idx.find(pciInfo.busId); if (it != bus_id_to_device_idx.end()) { if (i != it->second) { diff --git a/torch/csrc/distributed/c10d/DMAConnectivity.cpp b/torch/csrc/distributed/c10d/DMAConnectivity.cpp index d920eb567197f..50c34f62426eb 100644 --- a/torch/csrc/distributed/c10d/DMAConnectivity.cpp +++ b/torch/csrc/distributed/c10d/DMAConnectivity.cpp @@ -1,10 +1,11 @@ #include +#include namespace { std::string get_detector_key( c10::DeviceType device_type, - std::string connection_type) { + const std::string& connection_type) { std::ostringstream oss; oss << device_type << "/" << connection_type; return oss.str(); @@ -12,6 +13,8 @@ std::string get_detector_key( class DetectorMap { public: + DetectorMap(const DetectorMap&) = delete; + DetectorMap& operator=(const DetectorMap&) = delete; static DetectorMap& get() { static DetectorMap instance; return instance; @@ -52,8 +55,6 @@ class DetectorMap { private: DetectorMap() = default; - DetectorMap(const DetectorMap&) = delete; - DetectorMap& operator=(const DetectorMap&) = delete; std::unordered_map< std::string, @@ -73,7 +74,7 @@ DMAConnectivity::DMAConnectivity( std::string connection_type, std::vector> matrix) : device_type(device_type), - connection_type(connection_type), + connection_type(std::move(connection_type)), matrix(std::move(matrix)) {} void register_dma_connectivity_detector( diff --git a/torch/csrc/distributed/c10d/Functional.cpp b/torch/csrc/distributed/c10d/Functional.cpp index 5c62849f841e4..1117718ee5093 100644 --- a/torch/csrc/distributed/c10d/Functional.cpp +++ b/torch/csrc/distributed/c10d/Functional.cpp @@ -418,7 +418,7 @@ class AllToAllSingle : public torch::autograd::Function { static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_out_list) { + const torch::autograd::variable_list& grad_out_list) { const std::vector& output_split_sizes = ctx->saved_data["output_split_sizes"].toIntVector(); const std::vector& input_split_sizes = @@ -476,12 +476,12 @@ class ReduceScatterTensor static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_out_list) { + const torch::autograd::variable_list& grad_out_list) { const int64_t group_size = ctx->saved_data["group_size"].toInt(); const std::string& group_name = ctx->saved_data["group_name"].toStringRef(); DCHECK(grad_out_list.size() == 1); - auto grad_out = grad_out_list[0]; + const auto& grad_out = grad_out_list[0]; auto out = c10::Dispatcher::singleton() @@ -532,12 +532,12 @@ class AllGatherIntoTensor static torch::autograd::variable_list backward( torch::autograd::AutogradContext* ctx, - torch::autograd::variable_list grad_out_list) { + const torch::autograd::variable_list& grad_out_list) { const int64_t group_size = ctx->saved_data["group_size"].toInt(); const std::string& group_name = ctx->saved_data["group_name"].toStringRef(); DCHECK(grad_out_list.size() == 1); - auto grad_out = grad_out_list[0]; + const auto& grad_out = grad_out_list[0]; auto out = c10::Dispatcher::singleton() diff --git a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp index 3441c38be32ab..8ef195ca5ecbd 100644 --- a/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp +++ b/torch/csrc/distributed/c10d/GlooDeviceFactory.cpp @@ -5,6 +5,7 @@ #include #include +#include #if GLOO_HAVE_TRANSPORT_TCP #include @@ -66,10 +67,6 @@ C10_REGISTER_CREATOR(GlooDeviceRegistry, TCP, makeTCPDevice); #endif #if GLOO_HAVE_TRANSPORT_TCP_TLS -static std::string cstr_to_std_string(const char* chars) { - return std::string(chars != nullptr ? chars : ""); -} - static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice( const std::string& interface, const std::string& hostname) { @@ -84,14 +81,20 @@ static std::shared_ptr<::gloo::transport::Device> makeTCPTLSDevice( } else { attr.hostname = hostname; } - const auto pkey = - cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY")); - const auto cert = - cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT")); + const auto pkey_env = + c10::utils::get_env("GLOO_DEVICE_TRANSPORT_TCP_TLS_PKEY"); + const auto pkey = pkey_env.has_value() ? pkey_env.value() : std::string(); + const auto cert_env = + c10::utils::get_env("GLOO_DEVICE_TRANSPORT_TCP_TLS_CERT"); + const auto cert = cert_env.has_value() ? cert_env.value() : std::string(); + const auto caFile_env = + c10::utils::get_env("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE"); const auto caFile = - cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_FILE")); + caFile_env.has_value() ? caFile_env.value() : std::string(); + const auto caPath_env = + c10::utils::get_env("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH"); const auto caPath = - cstr_to_std_string(std::getenv("GLOO_DEVICE_TRANSPORT_TCP_TLS_CA_PATH")); + caPath_env.has_value() ? caPath_env.value() : std::string(); return ::gloo::transport::tcp::tls::CreateDevice( attr, pkey, cert, caFile, caPath); } @@ -129,9 +132,10 @@ namespace { std::shared_ptr<::gloo::transport::Device> makeGlooDevice( const std::string& interfaceName, const std::string& hostName) { - static auto transportName = getenv("GLOO_DEVICE_TRANSPORT"); - if (transportName) { - return GlooDeviceRegistry()->Create(transportName, interfaceName, hostName); + static auto transportName = c10::utils::get_env("GLOO_DEVICE_TRANSPORT"); + if (transportName.has_value()) { + return GlooDeviceRegistry()->Create( + transportName.value().c_str(), interfaceName, hostName); } #ifdef __linux__ diff --git a/torch/csrc/distributed/c10d/GroupRegistry.cpp b/torch/csrc/distributed/c10d/GroupRegistry.cpp index b13b4fa07c28e..2a735a4c99592 100644 --- a/torch/csrc/distributed/c10d/GroupRegistry.cpp +++ b/torch/csrc/distributed/c10d/GroupRegistry.cpp @@ -10,10 +10,11 @@ namespace { class GroupRegistry { public: void register_group( - const std::string& group_name, + std::string group_name, c10::intrusive_ptr group) { std::unique_lock write_lock(lock_); - auto [_, inserted] = registry_.try_emplace(group_name, std::move(group)); + auto [_, inserted] = + registry_.try_emplace(std::move(group_name), std::move(group)); TORCH_CHECK( inserted, "A process group is already registered under the name", @@ -70,12 +71,11 @@ bool get_thread_isolation_mode() { void register_process_group( const std::string& group_name, - c10::intrusive_ptr group) { + const c10::intrusive_ptr& group) { if (thread_isolation_mode) { - RankLocal<::GroupRegistry>::get().register_group( - group_name, std::move(group)); + RankLocal<::GroupRegistry>::get().register_group(group_name, group); } else { - process_registry.register_group(group_name, std::move(group)); + process_registry.register_group(group_name, group); } } diff --git a/torch/csrc/distributed/c10d/GroupRegistry.hpp b/torch/csrc/distributed/c10d/GroupRegistry.hpp index b22fb1ae8faf3..dc64adeaf6618 100644 --- a/torch/csrc/distributed/c10d/GroupRegistry.hpp +++ b/torch/csrc/distributed/c10d/GroupRegistry.hpp @@ -10,7 +10,7 @@ bool get_thread_isolation_mode(); C10_EXPORT void register_process_group( const std::string& group_name, - c10::intrusive_ptr group); + const c10::intrusive_ptr& group); C10_EXPORT c10::intrusive_ptr resolve_process_group( const std::string& group_name); diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 47ace12db6c3f..a86039c6ef4d4 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -2,9 +2,8 @@ #include #include -#include #include -#include +#include #ifdef USE_C10D_NCCL #include @@ -14,10 +13,6 @@ #include -namespace { -constexpr int64_t kCommInitBusyWaitMillis = 10; -} // namespace - namespace c10d { ncclComm_t NCCLComm::getNcclComm() { @@ -35,39 +30,23 @@ ncclComm_t NCCLComm::getNcclComm() { ". ", commFailureMsg)); } - // only wait for initialization if nonblocking mode is enabled - if (!initialized_ && nccl_use_nonblocking()) { - waitUntilInitialized(nccl_nonblocking_timeout()); + // In non-blocking mode, ensure comm is ready. + if (nonBlocking_) { + // If timeout is reached, throw an exception. + C10D_NCCL_CHECK_TIMEOUT_SLEEP(ncclInProgress, ncclComm_, std::nullopt); + // ncclComm_ should be initialized by now } - - return ncclComm_; -} - -void NCCLComm::waitUntilInitialized(int timeoutSecs) { - auto startTimepoint = std::chrono::steady_clock::now(); - while (!initialized_) { - if (ncclComm_) { - ncclResult_t result; - ncclCommGetAsyncError(ncclComm_, &result); - if (result == ncclSuccess) { - LOG(INFO) << "Rank " << rank_ << ": NCCL communicator is initialized."; - initialized_ = true; - break; - } - } - auto currentTimepoint = std::chrono::steady_clock::now(); - auto timeElapsed = std::chrono::duration_cast( - currentTimepoint - startTimepoint) - .count(); - if (timeElapsed > timeoutSecs) { - std::string err = "NCCL timeout in communicator initialization."; - TORCH_CHECK_WITH(DistBackendError, false, err); - } - std::this_thread::sleep_for( - std::chrono::milliseconds(kCommInitBusyWaitMillis)); + if (!initialized_) { + // TODO: see if we can consolidate other `initialized_` flipping here. + // Maintaining it elsewhere is some work. + initialized_ = true; + LOG(INFO) << "Rank " << rank_ << ": NCCL communicator " << repr() + << " is initialized."; } + return ncclComm_; } +// TODO: why do we have `!defined(FBCODE_CAFFE2)` here? #if defined(NCCL_HAS_COMM_SPLIT) && !defined(FBCODE_CAFFE2) // last argument to split() API is not used to support // multiple implementations @@ -77,16 +56,54 @@ std::shared_ptr NCCLComm::split( int rank, ncclConfig_t& config, std::vector& ranks_ull) { + TORCH_CHECK( + color_id >= NCCL_SPLIT_NOCOLOR, + "Color must be a non-negative value or NCCL_SPLIT_NOCOLOR (-1)" + ", but got ", + color_id); + LOG(INFO) << "Rank " << source->rank_ << ": split from parent comm " + << source->repr() << " with color_id " << color_id << " and rank " + << rank; auto comm = std::make_shared(); + // This call will block until the source communicator is initialized + auto sourceComm = source->getNcclComm(); +#ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK( - ncclCommSplit( - source->ncclComm_, color_id, rank, &(comm->ncclComm_), &config), + ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config), + std::nullopt); +#else + // After calling ncclCommSplit in non-blocking mode, we should wait for the + // source communicator to be out of ncclInProgress state. + // Reason 1: + // it's unsafe to call new operations on the parent comm while it's in + // ncclInProgress state. + // Reason 2: + // as of NCCL 2.23, the ptr value of child comm will not be filled until the + // state of parent comm is ncclSuccess. This may change in the future. See: + // https://github.com/NVIDIA/nccl/issues/1472 + C10D_NCCL_CHECK_TIMEOUT_SLEEP( + ncclCommSplit(sourceComm, color_id, rank, &(comm->ncclComm_), &config), + sourceComm, // wait on parent comm std::nullopt); + if (color_id >= 0) { + // Waiting for parent comm above still does not seem to guarantee the child + // comm ptr is valid. Therefore we add a manual wait here for safety. + // TODO: remove this wait after NCCL fix the semantics. + auto startTime = std::chrono::steady_clock::now(); + auto timeout = nccl_nonblocking_timeout(); + while (!comm->ncclComm_) { + C10D_CHECK_TIMEOUT(startTime, timeout); + C10D_SCHED_SLEEP(); + } + } + // comm->ncclComm_ should have valid ptr by now, but not necessarily + // initialized. Rely on getNcclComm() to wait for its initialization. +#endif ++source->ncclCommSplitCounter_; comm->rank_ = rank; - if (!nccl_use_nonblocking()) { - comm->initialized_ = true; - } + comm->nonBlocking_ = config.blocking == 0; + LOG(INFO) << "Rank " << source->rank_ << ": created child comm " + << comm->repr() << " with color_id " << color_id; return comm; } #endif @@ -96,7 +113,7 @@ std::string getNcclVersion() { static std::string versionString; c10::call_once(ncclGetVersionFlag, []() { - int version; + int version = 0; ncclResult_t status = ncclGetVersion(&version); // can't compute the version if call did not return successfully or version // code < 100 (corresponding to 0.1.0) @@ -114,7 +131,7 @@ std::string getNcclVersion() { std::to_string(ncclMinor) + "." + std::to_string(ncclPatch); #ifdef NCCL_SUFFIX const auto ncclSuffix = std::string(NCCL_SUFFIX); - if (ncclSuffix.length()) { + if (!ncclSuffix.empty()) { versionString += "." + ncclSuffix; } #endif @@ -132,16 +149,14 @@ size_t hashTensors(const std::vector& tensors) { size_t data_size = tensor.storage().nbytes(); if (data_size > 0 && tensor.storage().data_ptr()) { auto src = static_cast(tensor.storage().data_ptr().get()); - char* dst = (char*)std::calloc(data_size, sizeof(char)); + std::vector dst(data_size); // This is needed so that we trigger a device synchronization so we can // get the collective finished if launched on GPU and hash its output. - cudaMemcpy(dst, src, data_size, cudaMemcpyDeviceToHost); + cudaMemcpy(dst.data(), src, data_size, cudaMemcpyDeviceToHost); for (size_t i = 0; i < data_size; ++i) { // Update the hash for each byte in the tensor - hash = c10::hash_combine( - hash, c10::get_hash(((char*)dst)[i], data_size)); + hash = c10::hash_combine(hash, c10::get_hash(dst[i], data_size)); } - free(dst); } } } @@ -149,35 +164,21 @@ size_t hashTensors(const std::vector& tensors) { } #endif -bool nccl_use_nonblocking() { - static bool nccl_use_nonblocking_ = - c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING") == true; - if (nccl_use_nonblocking_) { - TORCH_WARN_ONCE("Using experimental non-blocking NCCL communicator."); - } - return nccl_use_nonblocking_; -} - -int _parse_nccl_nonblocking_timeout() { - const char* val = getenv("TORCH_NCCL_NONBLOCKING_TIMEOUT"); - int timeout = -1; - if (val) { - const std::string config(val); - timeout = std::stoi(config); - if (!nccl_use_nonblocking() && timeout > 0) { - TORCH_WARN( - "TORCH_NCCL_NONBLOCKING_TIMEOUT has no effect when TORCH_NCCL_USE_COMM_NONBLOCKING is false."); - timeout = -1; +// Default value: 30 minutes +int nccl_nonblocking_timeout() { + static int timeout = -2; // -2 means not initialized + if (timeout == -2) { + const auto val = c10::utils::get_env("TORCH_NCCL_NONBLOCKING_TIMEOUT"); + if (val.has_value() && !val.value().empty()) { + timeout = stoi(val.value()); + } else { + // Default value consistent with kBackendDefaultTimeout + timeout = 30 * 60; } } return timeout; } -int nccl_nonblocking_timeout() { - static int timeout = _parse_nccl_nonblocking_timeout(); - return timeout; -} - std::string ncclGetErrorWithVersion(ncclResult_t error) { return std::string(ncclGetErrorString(error)) + ", NCCL version " + getNcclVersion(); @@ -197,7 +198,7 @@ std::string getNcclErrorDetailStr( std::string interpret; std::string err; #ifdef ENABLE_NCCL_GET_LAST_ERROR - auto ret = ncclGetLastError(NULL); + auto ret = ncclGetLastError(nullptr); if (ret) { err = "\nLast error:\n" + std::string(ret); } else { @@ -242,7 +243,7 @@ std::string getNcclErrorDetailStr( control_plane::RegisterHandler dumpHandler{ "dump_nccl_trace_pickle", [](const control_plane::Request& req, control_plane::Response& res) { - const auto params = req.params(); + const auto& params = req.params(); size_t validParamCount = 0; // valid params @@ -290,7 +291,7 @@ control_plane::RegisterHandler dumpHandler{ control_plane::RegisterHandler jsonDumpHandler{ "dump_nccl_trace_json", [](const control_plane::Request& req, control_plane::Response& res) { - const auto params = req.params(); + const auto& params = req.params(); size_t validParamCount = 0; // valid params @@ -344,7 +345,12 @@ void DebugInfoWriter::write(const std::string& ncclTrace) { return; } - file.write(ncclTrace.data(), ncclTrace.size()); + file.write(ncclTrace.data(), static_cast(ncclTrace.size())); + if (!file) { + LOG(ERROR) << "Error opening file for writing NCCLPG debug info: " + << filename_; + return; + } LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_; } @@ -389,7 +395,7 @@ std::optional NCCLTraceBuffer::record( } if (all_pg_status_.find(pg_id) == all_pg_status_.end()) { // Current pg_status is not in FR. - all_pg_status_[pg_id] = pg_status; + all_pg_status_[pg_id] = std::move(pg_status); } auto traceback = torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); @@ -404,8 +410,8 @@ std::optional NCCLTraceBuffer::record( op_id, std::move(profiling_name), std::move(traceback), - std::move(start), - std::move(end), + start, + end, c10::getTime(), timeout_ms.count(), isP2P, @@ -422,14 +428,14 @@ std::optional NCCLTraceBuffer::record( for (const auto& input : inputs) { c10::IntArrayRef sizes = input.sizes(); te.input_dtypes_.push_back(input.dtype().toScalarType()); - te.input_dims_.push_back(sizes.size()); + te.input_dims_.push_back(static_cast(sizes.size())); te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); } for (const auto& output : outputs) { c10::IntArrayRef sizes = output.sizes(); te.output_dtypes_.push_back(output.dtype().toScalarType()); - te.output_dims_.push_back(sizes.size()); + te.output_dims_.push_back(static_cast(sizes.size())); te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); } @@ -451,7 +457,7 @@ void NCCLTraceBuffer::record_pg_ranks( return; } std::lock_guard guard(mutex_); - pg_name_to_ranks_[pg_name] = ranks; + pg_name_to_ranks_[pg_name] = std::move(ranks); } void NCCLTraceBuffer::update_state(Entry& r) { @@ -473,8 +479,14 @@ std::vector NCCLTraceBuffer::dump_entries() { std::lock_guard guard(mutex_); std::vector result; result.reserve(entries_.size()); - result.insert(result.end(), entries_.begin() + next_, entries_.end()); - result.insert(result.end(), entries_.begin(), entries_.begin() + next_); + result.insert( + result.end(), + entries_.begin() + static_cast(next_), + entries_.end()); + result.insert( + result.end(), + entries_.begin(), + entries_.begin() + static_cast(next_)); // query any remaining events for (auto& r : result) { update_state(r); @@ -527,7 +539,7 @@ void NCCLTraceBuffer::retire_id( return; } if (duration.has_value()) { - entry->duration_ = duration.value(); + entry->duration_ = duration; } } } @@ -564,7 +576,7 @@ const c10::List NCCLTraceBuffer::getCollectiveTrace( if (includeStacktraces) { auto& tb = stracebacks.tracebacks.at(i); auto frames = new_list(); - for (int64_t frame : tb) { + for (auto frame : tb) { frames.push_back(all_frames.at(frame)); } dict.insert(frames_key, frames); @@ -583,11 +595,11 @@ const c10::List NCCLTraceBuffer::getCollectiveTrace( } auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { + auto read_sizes = [&](const c10::SmallVector& dims) { auto sizes = new_list(); for (auto dim : dims) { auto arg_sizes = new_list(); - for (C10_UNUSED auto i : c10::irange(dim)) { + for ([[maybe_unused]] auto i : c10::irange(dim)) { arg_sizes.push_back(*it++); } sizes.push_back(arg_sizes); @@ -599,14 +611,14 @@ const c10::List NCCLTraceBuffer::getCollectiveTrace( std::vector input_dtypes_strs; input_dtypes_strs.reserve(e.input_dtypes_.size()); for (const auto& input_dtype : e.input_dtypes_) { - input_dtypes_strs.push_back(c10::toString(input_dtype)); + input_dtypes_strs.emplace_back(c10::toString(input_dtype)); } dict.insert(input_dtypes_key, input_dtypes_strs); dict.insert(output_sizes_key, read_sizes(e.output_dims_)); std::vector output_dtypes_strs; output_dtypes_strs.reserve(e.output_dtypes_.size()); for (const auto& output_dtype : e.output_dtypes_) { - output_dtypes_strs.push_back(c10::toString(output_dtype)); + output_dtypes_strs.emplace_back(c10::toString(output_dtype)); } dict.insert(output_dtypes_key, output_dtypes_strs); if (e.time_discovered_completed_.has_value()) { @@ -721,10 +733,10 @@ std::string NCCLTraceBuffer::dump_json( j[duration_key_str] = *e.duration_; } auto it = e.sizes_.begin(); - auto read_sizes = [&](const c10::SmallVector& dims) { - auto sizes = std::list>(); + auto read_sizes = [&](const c10::SmallVector& dims) { + auto sizes = std::list>(); for (auto dim : dims) { - auto arg_sizes = std::list(); + auto arg_sizes = std::list(); for (auto i : c10::irange(dim)) { (void)i; arg_sizes.push_back(*it++); @@ -737,14 +749,14 @@ std::string NCCLTraceBuffer::dump_json( std::vector input_dtypes_strs; input_dtypes_strs.reserve(e.input_dtypes_.size()); for (const auto& input_dtype : e.input_dtypes_) { - input_dtypes_strs.push_back(c10::toString(input_dtype)); + input_dtypes_strs.emplace_back(c10::toString(input_dtype)); } j[input_dtypes_key_str] = input_dtypes_strs; j[output_sizes_key_str] = read_sizes(e.output_dims_); std::vector output_dtypes_strs; output_dtypes_strs.reserve(e.output_dtypes_.size()); for (const auto& output_dtype : e.output_dtypes_) { - output_dtypes_strs.push_back(c10::toString(output_dtype)); + output_dtypes_strs.emplace_back(c10::toString(output_dtype)); } j[output_dtypes_key_str] = output_dtypes_strs; if (e.time_discovered_completed_.has_value()) { @@ -768,7 +780,7 @@ std::string NCCLTraceBuffer::dump_json( entries.emplace_back(j); } - if (entries.size() > 0) { + if (!entries.empty()) { result[entries_key_str] = entries; } } @@ -809,7 +821,7 @@ std::string NCCLTraceBuffer::dump( per_comm_dict.insert(ncclId, inner_dict); } } - if (per_comm_dict.size() > 0) { + if (!per_comm_dict.empty()) { result.insert(nccl_comm_key, per_comm_dict); } return pickle_str(result); diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index 070cbd34b3797..af32ab83ef57c 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -2,12 +2,12 @@ #ifdef USE_C10D_NCCL -#include -#include +#include +#include +#include #include #include -#include #include #include @@ -16,6 +16,8 @@ #include #include +constexpr int64_t kCommInitBusyWaitMillis = 2; + #if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && \ (NCCL_MINOR >= 14) #define NCCL_HAS_COMM_NONBLOCKING @@ -101,60 +103,77 @@ } \ } while (0) +// Error out if (current time - startTime) is greater than timeout (sec). +#define C10D_CHECK_TIMEOUT(startTime, timeout) \ + do { \ + auto currentTime = std::chrono::steady_clock::now(); \ + auto timeElapsed = std::chrono::duration_cast( \ + currentTime - startTime) \ + .count(); \ + if (timeElapsed > timeout) { \ + std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ + } \ + } while (0) + // Macro to throw on a non-successful NCCL return value, non-blocking. -#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \ - ncclResult_t result = cmd; \ - auto startTimepoint = std::chrono::steady_clock::now(); \ - while (result == ncclInProgress) { \ - if (nccl_nonblocking_timeout() > 0) { \ - auto currentTimepoint = std::chrono::steady_clock::now(); \ - auto timeElapsed = std::chrono::duration_cast( \ - currentTimepoint - startTimepoint) \ - .count(); \ - if (timeElapsed > nccl_nonblocking_timeout()) { \ - std::string err = "NCCL timeout in: " + std::string(__FILE__) + ":" + \ - std::to_string(__LINE__) + ", " + \ - ncclGetErrorWithVersion(result) + "\n" + \ - getNcclErrorDetailStr(result, failureReason); \ - TORCH_CHECK_WITH(DistBackendError, false, err); \ - } \ +#define C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, yield_fn) \ + do { \ + ncclResult_t result = cmd; \ + auto startTimepoint = std::chrono::steady_clock::now(); \ + auto timeout = nccl_nonblocking_timeout(); \ + while (result == ncclInProgress) { \ + C10D_CHECK_TIMEOUT(startTimepoint, timeout); \ + yield_fn; \ + ncclCommGetAsyncError(comm, &result); \ } \ - ncclCommGetAsyncError(comm, &result); \ - } \ - if (result != ncclSuccess) { \ - std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ - std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ - "\n" + getNcclErrorDetailStr(result, failureReason); \ - TORCH_CHECK_WITH(DistBackendError, false, err); \ - } + if (result != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(result) + \ + "\n" + getNcclErrorDetailStr(result, failureReason); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ + } \ + } while (0) + +// Sleep for kCommInitBusyWaitMillis milliseconds. +#define C10D_SCHED_SLEEP() \ + std::this_thread::sleep_for( \ + std::chrono::milliseconds(kCommInitBusyWaitMillis)) + +// Macro to throw exception on a non-successful NCCL return value or timeout. +// This macro uses sched_yield() to yield the CPU. +// Thus suitable for NCCL calls that would quickly turn ncclSuccess, e.g. +// collectives. +#define C10D_NCCL_CHECK_TIMEOUT(cmd, comm, failureReason) \ + C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, sched_yield()) + +// Macro to throw exception on a non-successful NCCL return value or timeout. +// This macro uses sleep to yield the CPU. +// Thus suitable for NCCL calls that would take longer to turn ncclSuccess, e.g. +// ncclCommInitRankConfig, ncclCommFinalize, etc. +#define C10D_NCCL_CHECK_TIMEOUT_SLEEP(cmd, comm, failureReason) \ + C10D_NCCL_CHECK_TIMEOUT_BASE(cmd, comm, failureReason, C10D_SCHED_SLEEP()) #define C10D_NCCL_CHECK_TIMEOUT_GROUPEND(cmd, comm, failureReason) \ - ncclResult_t state = cmd; \ - auto startTimepoint = std::chrono::steady_clock::now(); \ - if (state == ncclInProgress) { \ - do { \ - if (nccl_nonblocking_timeout() > 0) { \ - auto currentTimepoint = std::chrono::steady_clock::now(); \ - auto timeElapsed = std::chrono::duration_cast( \ - currentTimepoint - startTimepoint) \ - .count(); \ - if (timeElapsed > nccl_nonblocking_timeout()) { \ - std::string err = "NCCL timeout in: " + std::string(__FILE__) + \ - ":" + std::to_string(__LINE__) + ", " + \ - ncclGetErrorWithVersion(state) + "\n" + \ - getNcclErrorDetailStr(state, failureReason); \ - TORCH_CHECK_WITH(DistBackendError, false, err); \ - } \ - } \ - ncclCommGetAsyncError(comm->getNcclComm(), &state); \ - } while (state == ncclInProgress); \ - } \ - if (state != ncclSuccess) { \ - std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ - std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \ - "\n" + getNcclErrorDetailStr(state, failureReason); \ - TORCH_CHECK_WITH(DistBackendError, false, err); \ - } + do { \ + ncclResult_t state = cmd; \ + auto startTimepoint = std::chrono::steady_clock::now(); \ + auto timeout = nccl_nonblocking_timeout(); \ + if (state == ncclInProgress) { \ + do { \ + C10D_CHECK_TIMEOUT(startTimepoint, timeout); \ + sched_yield(); \ + ncclCommGetAsyncError(comm->getNcclComm(), &state); \ + } while (state == ncclInProgress); \ + } \ + if (state != ncclSuccess) { \ + std::string err = "NCCL error in: " + std::string(__FILE__) + ":" + \ + std::to_string(__LINE__) + ", " + ncclGetErrorWithVersion(state) + \ + "\n" + getNcclErrorDetailStr(state, failureReason); \ + TORCH_CHECK_WITH(DistBackendError, false, err); \ + } \ + } while (0) // Macro to print and abort on a non-successful NCCL return value. #define C10D_NCCL_ASSERT(cmd) \ @@ -217,7 +236,6 @@ DEFINE_CONSTANT(started_state, "started"); TORCH_API size_t hashTensors(const std::vector& tensors); TORCH_API std::string getNcclVersion(); TORCH_API std::string ncclGetErrorWithVersion(ncclResult_t error); -bool nccl_use_nonblocking(); int nccl_nonblocking_timeout(); // Provides additional detail into NCCL error codes based on when these are @@ -245,7 +263,7 @@ class TORCH_API DebugInfoWriter { } protected: - DebugInfoWriter(std::string namePrefix, int rank) { + DebugInfoWriter(const std::string& namePrefix, int rank) { filename_ = c10::str(namePrefix, rank); } std::string filename_; @@ -258,14 +276,9 @@ class TORCH_API DebugInfoWriter { // RAII wrapper for NCCL communicator class NCCLComm { public: - explicit NCCLComm(ncclComm_t ncclComm) - : ncclComm_(ncclComm), - aborted_(false), - ncclAsyncErr_(ncclSuccess), - commFailureReason_(std::nullopt), - initialized_(false) {} + explicit NCCLComm(ncclComm_t ncclComm) : ncclComm_(ncclComm) {} - NCCLComm() : NCCLComm(nullptr) {} + NCCLComm() = default; ~NCCLComm() noexcept { // Add lock in this destructor, as aborted_ needs to be read after memory @@ -294,6 +307,8 @@ class NCCLComm { comm->ncclId_ = commId; comm->rank_ = rank; comm->initialized_ = true; + // Old style comm is always blocking. + comm->nonBlocking_ = false; return comm; } @@ -304,26 +319,19 @@ class NCCLComm { ncclUniqueId commId, ncclConfig_t& config) { auto comm = std::make_shared(); - bool isInitialized = false; - if (nccl_use_nonblocking()) { - config.blocking = 0; - LOG(INFO) << "Rank " << rank - << ": creating NCCL communicator in nonblocking mode"; - C10D_NCCL_CHECK_NONBLOCKING( - ncclCommInitRankConfig( - &(comm->ncclComm_), numRanks, commId, rank, &config), - std::nullopt); - } else { - C10D_NCCL_CHECK( - ncclCommInitRankConfig( - &(comm->ncclComm_), numRanks, commId, rank, &config), - std::nullopt); - // under blocking mode, comm is initialized after NCCL CHECK - isInitialized = true; - } + comm->nonBlocking_ = config.blocking == 0; + LOG(INFO) << "Rank " << rank << ": creating NCCL communicator with mode: " + << (comm->nonBlocking_ ? "nonblocking" : "blocking"); + C10D_NCCL_CHECK_NONBLOCKING( + ncclCommInitRankConfig( + &(comm->ncclComm_), numRanks, commId, rank, &config), + std::nullopt); comm->ncclId_ = commId; comm->rank_ = rank; - comm->initialized_ = isInitialized; + // Under blocking mode, comm is initialized immediately after NCCL init + // returns; Under nonblocking mode, we check whether comm is initialized the + // *next* time ncclComm_ is accessed. + comm->initialized_ = !comm->nonBlocking_; return comm; } @@ -359,6 +367,7 @@ class NCCLComm { NCCLComm& operator=(NCCLComm&& other) = delete; // Move constructable + // NOLINTNEXTLINE(.*-noexcept-move-.*) NCCLComm(NCCLComm&& other) { // Using other's lock, as it reads other's states // Can not use this.mutex_, as this object is being constructed. @@ -367,6 +376,7 @@ class NCCLComm { std::swap(aborted_, other.aborted_); std::swap(ncclAsyncErr_, other.ncclAsyncErr_); std::swap(initialized_, other.initialized_); + std::swap(nonBlocking_, other.nonBlocking_); } ncclComm_t getNcclComm(); @@ -425,6 +435,11 @@ class NCCLComm { #endif } + bool isInitialized() const { + std::unique_lock lock(mutex_); + return initialized_; + } + bool isAborted() const { std::unique_lock lock(mutex_); return aborted_; @@ -463,16 +478,18 @@ class NCCLComm { " has already been registered on ncclComm_ ", ncclComm_); - void* handle; + void* handle = nullptr; + // Use getNcclComm to make sure comm is ready before calling nccl APIs + auto comm = getNcclComm(); C10D_NCCL_CHECK( - ncclCommRegister(ncclComm_, ptr, size, &handle), + ncclCommRegister(comm, ptr, size, &handle), c10::str( "Failed to register segment with ptr ", ptr, ", size ", size, " on ncclComm_ ", - ncclComm_)); + comm)); registeredSegmentHandles_[ptr] = handle; return ncclSuccess; #else @@ -491,15 +508,17 @@ class NCCLComm { ncclComm_); void* handle = registeredSegmentHandles_[ptr]; + // Use getNcclComm to make sure comm is ready before calling nccl APIs + auto comm = getNcclComm(); C10D_NCCL_CHECK( - ncclCommDeregister(ncclComm_, handle), + ncclCommDeregister(comm, handle), c10::str( "Failed to deregister segment handle ", handle, ", with ptr ", ptr, " on ncclComm_ ", - ncclComm_)); + comm)); registeredSegmentHandles_.erase(ptr); return ncclSuccess; #else @@ -507,28 +526,36 @@ class NCCLComm { #endif } + std::string repr() const { + return c10::str((void*)ncclComm_); + } + friend class ProcessGroupNCCL; protected: - // a helper function to wait until the communicator is initialized; - void waitUntilInitialized(int timeoutSecs); - ncclComm_t ncclComm_; // Unique nccl_id for this communicator. - ncclUniqueId ncclId_; - bool aborted_; + ncclUniqueId ncclId_{}; + bool aborted_{false}; uint64_t ncclCommSplitCounter_{0}; - ncclResult_t ncclAsyncErr_; + ncclResult_t ncclAsyncErr_{ncclSuccess}; mutable std::mutex mutex_; // Rank that this communicator corresponds to. - int rank_; + int rank_{}; // Optional reason for communicator failure, provided by ProcessGroupNCCL for // better error messaging. - std::optional commFailureReason_; + std::optional commFailureReason_{}; bool initialized_{false}; + // Whether this communicator is using nonblocking mode. Recorded during comm + // creation or split. For safety, we give a default value of true (more + // protection). + bool nonBlocking_{true}; #ifdef NCCL_HAS_COMM_REGISTER // Stores handlers for tensors registered by NCCL std::unordered_map registeredSegmentHandles_; #endif + + private: + ncclComm_t ncclComm_{nullptr}; }; // Helper that automatically cleans up premul sums. @@ -539,7 +566,7 @@ struct ncclRedOpRAII { : op_(op), comm_(comm), premul_sum_(true) {} ncclRedOpRAII(const ncclRedOpRAII&) = delete; ncclRedOpRAII& operator=(const ncclRedOpRAII&) = delete; - ncclRedOpRAII(ncclRedOpRAII&& tmp) : ncclRedOpRAII() { + ncclRedOpRAII(ncclRedOpRAII&& tmp) noexcept : ncclRedOpRAII() { std::swap(tmp.op_, this->op_); std::swap(tmp.comm_, this->comm_); std::swap(tmp.premul_sum_, this->premul_sum_); @@ -554,8 +581,8 @@ struct ncclRedOpRAII { operator ncclRedOp_t() const { return op_; } - ncclRedOp_t op_; - ncclComm_t comm_; + ncclRedOp_t op_{}; + ncclComm_t comm_{}; bool premul_sum_ = false; }; @@ -626,9 +653,9 @@ struct NCCLTraceBuffer { std::optional time_discovered_completed_; // size information for input/output tensors - c10::SmallVector input_dims_; + c10::SmallVector input_dims_; std::vector input_dtypes_; - c10::SmallVector output_dims_; + c10::SmallVector output_dims_; std::vector output_dtypes_; c10::SmallVector sizes_; // flattened from inputs, outputs bool retired_ = false; // is this work entry no longer in the workMetaList_? diff --git a/torch/csrc/distributed/c10d/Ops.cpp b/torch/csrc/distributed/c10d/Ops.cpp index ae822ad397504..6251bfa1817dd 100644 --- a/torch/csrc/distributed/c10d/Ops.cpp +++ b/torch/csrc/distributed/c10d/Ops.cpp @@ -427,6 +427,7 @@ IMPL_ALLTOALL_BASE(CPU) IMPL_ALLTOALL_BASE(CUDA) IMPL_ALLTOALL_BASE(PrivateUse1) +// NOLINTBEGIN(performance-unnecessary-value-param) #define IMPL_BARRIER(DEV) \ c10::intrusive_ptr barrier##DEV( \ at::Tensor /* unused */, \ @@ -441,9 +442,11 @@ IMPL_ALLTOALL_BASE(PrivateUse1) IMPL_BARRIER(CPU) IMPL_BARRIER(CUDA) IMPL_BARRIER(PrivateUse1) +// NOLINTEND(performance-unnecessary-value-param) // NOLINTEND(cppcoreguidelines-pro-type-const-cast) void monitored_barrier_CPU( + // NOLINTNEXTLINE(performance-unnecessary-value-param) at::Tensor /* unused */, const c10::intrusive_ptr<::c10d::ProcessGroup>& process_group, const std::vector& device_ids, diff --git a/torch/csrc/distributed/c10d/ParamCommsUtils.hpp b/torch/csrc/distributed/c10d/ParamCommsUtils.hpp index 027b13c73ae9c..d011b0e42ed10 100644 --- a/torch/csrc/distributed/c10d/ParamCommsUtils.hpp +++ b/torch/csrc/distributed/c10d/ParamCommsUtils.hpp @@ -121,7 +121,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { worldSize); \ c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ std::initializer_list paramList = { \ - c10::IValue(seq), \ + seq, \ pgName, \ rank, \ collName, \ @@ -163,7 +163,7 @@ class TORCH_API ParamCommsDebugInfo : public c10::DebugInfoBase { c10::DebugInfoGuard g(c10::DebugInfoKind::PARAM_COMMS_INFO, paramCommsInfo); \ std::initializer_list paramList = { \ c10::IValue(InputTensors), \ - c10::IValue(seq), \ + seq, \ pgName, \ rank, \ collName, \ diff --git a/torch/csrc/distributed/c10d/ProcessGroup.cpp b/torch/csrc/distributed/c10d/ProcessGroup.cpp index f565de2013260..63d64447dfdb9 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace c10d { @@ -101,10 +102,10 @@ c10::intrusive_ptr ProcessGroup::getBackend( } ProcessGroup::ProcessGroup( - const c10::intrusive_ptr<::c10d::Store>& store, + c10::intrusive_ptr<::c10d::Store> store, int rank, int size) - : store_(store), + : store_(std::move(store)), rank_(rank), size_(size), backendType_(BackendType::UNDEFINED), diff --git a/torch/csrc/distributed/c10d/ProcessGroup.hpp b/torch/csrc/distributed/c10d/ProcessGroup.hpp index 92b655f016eff..463d1f046db52 100644 --- a/torch/csrc/distributed/c10d/ProcessGroup.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroup.hpp @@ -94,7 +94,7 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { explicit ProcessGroup(int rank, int size); explicit ProcessGroup( - const c10::intrusive_ptr<::c10d::Store>& store, + c10::intrusive_ptr<::c10d::Store> store, int rank, int size); ~ProcessGroup() override; diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 51fa248ec403b..3cb765a658912 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -647,7 +647,6 @@ void socketInitialize() { bool doesHostnameResolveToUsableAddress(const std::string& hostname) { socketInitialize(); struct addrinfo hints {}; - memset(&hints, 0, sizeof(hints)); hints.ai_family = AF_UNSPEC; hints.ai_socktype = SOCK_STREAM; struct addrinfo* result = nullptr; @@ -869,7 +868,7 @@ namespace { class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { public: AsyncBroadcastWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector& inputs, int rootRank, int rootTensor, @@ -881,7 +880,7 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:broadcast", inputs), - context(context), + context(std::move(context)), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), @@ -1018,7 +1017,7 @@ namespace { class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { public: AsyncAllreduceWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector& inputs, ReduceOp reduceOp, uint32_t tag, @@ -1029,7 +1028,7 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:all_reduce", inputs), - context(context), + context(std::move(context)), inputs(inputs), reduceOp(std::move(reduceOp)), tag(tag) {} @@ -1102,7 +1101,7 @@ class AsyncAllreduceCoalescedWork : public AsyncAllreduceWork { class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { public: AsyncSparseAllreduceWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector& inputs, uint32_t tag, uint64_t seq) @@ -1112,7 +1111,7 @@ class AsyncSparseAllreduceWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:sparse_all_reduce", inputs), - context(context), + context(std::move(context)), inputs(inputs), tag(tag) {} @@ -1619,7 +1618,7 @@ namespace { class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { public: AsyncReduceWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector& inputs, int rootRank, int rootTensor, @@ -1632,7 +1631,7 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:reduce", inputs), - context(context), + context(std::move(context)), inputs(inputs), rootRank(rootRank), rootTensor(rootTensor), @@ -1797,7 +1796,7 @@ namespace { class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { public: AsyncAllgatherWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector>& outputs, std::vector& inputs, uint32_t tag, @@ -1808,7 +1807,7 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:all_gather", inputs), - context(context), + context(std::move(context)), outputs(outputs), inputs(inputs), tag(tag) {} @@ -2069,7 +2068,7 @@ namespace { class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { public: AsyncAllgatherCoalescedWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector>& output_lists, std::vector& input_list, uint32_t tag, @@ -2080,7 +2079,7 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:all_gather", input_list), - context(context), + context(std::move(context)), output_lists(output_lists), input_list(input_list), tag(tag) {} @@ -2211,7 +2210,7 @@ namespace { class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { public: AsyncGatherWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector>& outputs, std::vector& inputs, int root, @@ -2223,7 +2222,7 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:gather", inputs), - context(context), + context(std::move(context)), outputs(outputs), inputs(inputs), root(root), @@ -2416,7 +2415,7 @@ namespace { class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { public: AsyncScatterWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector& outputs, std::vector>& inputs, int root, @@ -2429,7 +2428,7 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork { "gloo:scatter", !inputs.empty() ? std::optional>(inputs[0]) : std::nullopt), - context(context), + context(std::move(context)), outputs(outputs), inputs(inputs), root(root), @@ -2611,7 +2610,7 @@ namespace { class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { public: AsyncAlltoallWork( - const std::shared_ptr& context, + std::shared_ptr context, at::Tensor& outputTensor, at::Tensor& inputTensor, std::vector& outputCounts, @@ -2624,7 +2623,7 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:all_to_all", std::optional>({inputTensor})), - context(context), + context(std::move(context)), outputTensor(outputTensor), inputTensor(inputTensor), outputCounts(std::move(outputCounts)), @@ -2882,7 +2881,7 @@ namespace { class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { public: AsyncBarrierWork( - const std::shared_ptr& context, + std::shared_ptr context, std::vector> priorWork, uint32_t tag, uint64_t seq) @@ -2892,7 +2891,7 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork { seq, "gloo:barrier", std::nullopt), - context(context), + context(std::move(context)), priorWork(std::move(priorWork)), tag(tag) {} diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index 9f1e63d58adf2..1ebead6b598e7 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -30,24 +30,9 @@ constexpr const char* GLOO_BACKEND_NAME = "gloo"; // All functions on this class are expected to be called in the same // order across processes in the group. This is the only way that we // can guarantee to match up the same calls across processes. For -// multi-threaded usage of process groups, you can use consider using +// multi-threaded usage of process groups, you can consider using // multiple process group instances. // -// The Gloo algorithms that this class calls into are cached by their -// signature (see description of AlgorithmKey above). This cache works -// as follows: every function call instantiates an AlgorithmKey and -// looks in the cache for existing entries. If there is one, it is -// removed from the cache and returned to the caller. If there are -// none, a new entry is created and returned. If an entry was created -// before, but is still in use, the call will block and wait until the -// entry is returned to the cache. -// -// In the future, we hope to extend this to allow multiple entries per -// key, to enable parallelism for a single key. The number of entries -// per key must always be identical for all processes. This maximum -// number can be automatically tuned, but only if we let a single -// process take charge, and have it broadcast the limits. -// class TORCH_API ProcessGroupGloo : public Backend { public: // AsyncWork is the Gloo specific superclass for asynchronous work items. diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 24d53e1ac7433..c9564a31f057c 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -1,13 +1,12 @@ #ifdef USE_C10D_NCCL +#include #include -#include #include #include #include #include #include -#include #include #include @@ -86,7 +85,7 @@ ncclDataType_t getNcclDataType(at::ScalarType type) { return it->second; } -bool complexViewAsRealAllowed(const ReduceOp reduceOp) { +bool complexViewAsRealAllowed(const ReduceOp& reduceOp) { switch (reduceOp) { case ReduceOp::SUM: return true; @@ -109,7 +108,7 @@ ncclRedOpRAII unpackPreMulSum( const ncclComm_t& comm) { const auto* preMulSupplement = reinterpret_cast(reduceOp.supplement_.get()); - ncclRedOp_t preMulSum; + ncclRedOp_t preMulSum{}; bool has_tensor = preMulSupplement->tensor_factor.defined(); auto residence = has_tensor ? ncclScalarDevice : ncclScalarHostImmediate; const T* ptr_factor = has_tensor @@ -160,8 +159,7 @@ ncclRedOpRAII getNcclReduceOp( default: C10_THROW_ERROR( TypeError, "PreMulSum Data type must be half, float, or double"); - ncclRedOp_t unused; - return unused; + return ncclRedOp_t{}; } #else C10_THROW_ERROR(ValueError, "PreMulSum requires NCCL>=2.11.1"); @@ -259,7 +257,7 @@ std::string buildNcclUniqueIdStr(const ncclUniqueId& ncclID) { return oss.str(); } -std::string getNcclAbortedCommStoreKey(const std::string ncclIdStr) { +std::string getNcclAbortedCommStoreKey(const std::string& ncclIdStr) { return std::string(kNCCLAbortedCommStoreKey) + ":" + ncclIdStr; } @@ -307,7 +305,7 @@ static bool allocatorHooksAttached = false; std::atomic ProcessGroupNCCL::shouldDump_(false); -void cacheAllocatorRegisterHook( +static void cacheAllocatorRegisterHook( const c10::cuda::CUDACachingAllocator::TraceEntry& te) { // Register after SEGMENT_ALLOC if (te.action_ != @@ -325,7 +323,7 @@ void cacheAllocatorRegisterHook( } } -void cacheAllocatorDeregisterHook( +static void cacheAllocatorDeregisterHook( const c10::cuda::CUDACachingAllocator::TraceEntry& te) { // deregister before SEGMENT_FREE if (te.action_ != @@ -343,8 +341,9 @@ void cacheAllocatorDeregisterHook( } } -std::unordered_map> -getNCCLCommDumpMap() { +static std:: + unordered_map> + getNCCLCommDumpMap() { #if defined(IS_NCCLX) && defined(NCCL_COMM_DUMP) std::unordered_map< std::string /* ncclUniqueID */, @@ -367,9 +366,74 @@ getNCCLCommDumpMap() { } return ncclDumpMap; #else - return std::unordered_map< - std::string, - std::unordered_map>(); + /* + The following code is designed to work with NCCL versions above 2.23.4, which + support the profiler plugin. + For information on the NCCL profiler plugin, please refer to + https://github.com/NVIDIA/nccl/tree/v2.23.4-1/ext-profiler/example. + The plugin is a shared library (.so file) that is loaded by NCCL and PyTorch. + Users must define the dump function in the plugin, which should dump the + internal buffers of the profiler plugin. + + env variables: + 1. TORCH_NCCL_ENABLE_PROFILER_PLUGIN is a boolean flag to enable the plugin. + 2. NCCL_PROFILER_PLUGIN is the path to the plugin. + 3. NCCL_PROFILER_PLUGIN_FUN is the name of the dump function in the plugin. + + Hint: + 1. The function name would be mangled in C++. Use readelf -s -W .so to + find the mangled name. + */ + std::unordered_map> + ncclDumpMap; + + const bool isProfilerPluginEnabled = + getCvarBool({"TORCH_NCCL_ENABLE_PROFILER_PLUGIN"}, false); + if (!isProfilerPluginEnabled) { + return ncclDumpMap; + } + + const std::string profilerPluginPath = getCvarString( + {"NCCL_PROFILER_PLUGIN"}, + "/packages/training_platform/libnccl_profiler_plugin.so"); + LOG(INFO) << "NCCL_PROFILER_PLUGIN: " << profilerPluginPath; + if (profilerPluginPath.empty()) { + return ncclDumpMap; + } + + void* handle = dlopen(profilerPluginPath.c_str(), RTLD_LAZY | RTLD_LOCAL); + if (handle == nullptr) { + LOG(WARNING) << "Failed to open handle to process: "; + LOG(WARNING) << "dlopen failed:" << dlerror(); + return ncclDumpMap; + } + + const std::string profilerPluginFun = getCvarString( + {"NCCL_PROFILER_PLUGIN_FUN"}, "_Z22ncclProfilerPluginDumpB5cxx11v"); + if (profilerPluginFun.empty()) { + LOG(WARNING) << "NCCL_PROFILER_PLUGIN_FUN is empty"; + return ncclDumpMap; + } + std:: + unordered_map> ( + *dumpFn)() = + (std::unordered_map< + std::string, + std::unordered_map>(*)()) + dlsym(handle, profilerPluginFun.c_str()); + if (dumpFn == nullptr) { + LOG(WARNING) << "Failed to find " << profilerPluginFun; + return ncclDumpMap; + } + + try { + // nonblocking call + ncclDumpMap = (*dumpFn)(); + } catch (const std::exception& e) { + LOG(WARNING) << "Failed to call " << profilerPluginFun << ": " << e.what(); + } + + return ncclDumpMap; #endif } @@ -401,7 +465,7 @@ gil_checker_t& get_gil_checker() { return gil_checker; } -std::future launchAsyncGilCheck() { +static std::future launchAsyncGilCheck() { std::promise resultPromise; std::future resultFuture = resultPromise.get_future(); TORCH_CHECK(get_gil_checker(), "Can't check GIL with null GIL checker"); @@ -423,7 +487,7 @@ std::future launchAsyncGilCheck() { } const int64_t ProcessGroupNCCL::kWatchdogThreadSleepMillis = 100; -constexpr int64_t kSynchronizeBusyWaitMillis = 10; +constexpr int64_t kSynchronizeBusyWaitMillis = 1; thread_local uint64_t ProcessGroupNCCL::ncclActiveGroupCounter_ = 0; std::ostream& operator<<( @@ -447,12 +511,13 @@ std::ostream& operator<<( } ProcessGroupNCCL::WorkNCCL::WorkNCCL( - const std::string& pgUID, - const std::string& pgDesc, + std::string pgUID, + std::string pgDesc, at::Device& device, int rank, OpType opType, uint64_t seq, + bool isP2P, const char* profilingTitle, const std::optional>& inputs, bool desyncDebug, @@ -460,11 +525,12 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( bool cudaEventCacheEnabled, DebugLevel distDebugLevel) : Work(rank, opType, profilingTitle, inputs), - pgUID_(pgUID), - pgDesc_(pgDesc), + pgUID_(std::move(pgUID)), + pgDesc_(std::move(pgDesc)), device_(device), workStartTime_(std::chrono::steady_clock::now()), seq_(seq), + isP2P_(isP2P), timingEnabled_(enableTiming), distDebugLevel_(distDebugLevel) { // Creates the CUDA event wrappers @@ -483,6 +549,8 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL( ncclEndEvent_ = std::make_shared( enableTiming ? cudaEventDefault : cudaEventDisableTiming); } + futureWorkResult_ = + c10::make_intrusive(c10::AnyEnumType::get()); } ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) @@ -499,10 +567,12 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) ownedEphermeralTimeout_(w.ownedEphermeralTimeout_), workStartTime_(w.workStartTime_), seq_(w.seq_), + isP2P_(w.isP2P_), startTraceUpdated_(w.startTraceUpdated_), numelIn_(w.numelIn_), numelOut_(w.numelOut_), store_(w.store_), + futureWorkResult_(w.futureWorkResult_), timingEnabled_(w.timingEnabled_), trace_id_(w.trace_id_), distDebugLevel_(w.distDebugLevel_) { @@ -553,7 +623,7 @@ const std::string& ProcessGroupNCCL::WorkNCCL::logPrefix() const { void ProcessGroupNCCL::WorkNCCL::setException( std::exception_ptr exception_ptr) { std::unique_lock lock(mutex_); - exception_ = exception_ptr; + exception_ = std::move(exception_ptr); } // Helper that checks if the NCCL kernels are completed on the GPUs @@ -632,6 +702,14 @@ void ProcessGroupNCCL::WorkNCCL::handleException( LOG(ERROR) << logPrefix() << exceptionMsg; C10_LOG_API_USAGE_ONCE("ProcessGroupNCCL.WorkNCCL.handleException"); + auto logger = c10d::C10dLogger::getLogger(); + if (logger) { + ::c10d::C10dLoggingData data; + data.strings["work_nccl_exception"] = + getExceptionMsgFromExceptionPtr(exception_); + logger->log(data); + } + if (SHOULD_TEAR_DOWN(errorHandling)) { auto tearDownMsg = c10::str( "To avoid data inconsistency, we are taking the entire process down."); @@ -643,27 +721,6 @@ void ProcessGroupNCCL::WorkNCCL::handleException( void ProcessGroupNCCL::WorkNCCL::synchronize() { synchronizeStream(); - - // Device synchronize only after we've completed timeout checks. - // TODO: Is this necessary for barrier if we block the cpu thread till - // the completion of the work? - if (barrierTensor_.defined()) { - // If we use the work to do barrier, we should block here - // `dist.barrier()` only requires all CPU processes to enter this - // function, hence we only need to make sure the dummy all-reduce has - // completed. So we would only need to sync the **current stream** back to - // host, and do not need to synchronize the entire device (which may have - // kernels running on other streams). - // Using `cudaStreamSynchronize` instead of `cudaDeviceSynchronize` can: - // - lower chance of hang; - // - CurrentCUDAStream is usually the context of the next operation in - // Python, thus blocking current stream would already block the next - // compute kernel; - // - achieve better barrier performance. - auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); - // CUDAStream wrapper will correctly use a DeviceGuard here - currentStream.synchronize(); - } } void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { @@ -679,7 +736,7 @@ void ProcessGroupNCCL::WorkNCCL::synchronizeStream() { // Same as calling synchronize() when blockingWait_ is false bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { RECORD_PARAM_COMMS( - static_cast(this->seq_), // seq + std::make_tuple(static_cast(this->seq_), this->isP2P_), // seq std::make_tuple(pgUID_, pgDesc_), // PG name tuple rank_, // rank "wait", // collective name @@ -692,6 +749,9 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { -1, static_cast(1)); // number of device? + // synchronize() will block the current stream on the NCCL stream + synchronize(); + // In case of blockingWait or a timeout value is specified by the user, we // block the CPU thread until the work is completed or timed out. if (blockingWait_ || timeout != kNoTimeout) { @@ -713,17 +773,22 @@ bool ProcessGroupNCCL::WorkNCCL::wait(std::chrono::milliseconds timeout) { std::this_thread::sleep_for( std::chrono::milliseconds(kSynchronizeBusyWaitMillis)); } - - if (exception()) { - // Abort NCCL communicators - abort(); - // Throw exception (from main thread here) - handleException(TearDown); - } + } else if (isBarrierOp_ && !isCompleted()) { + // For barrier wait when timeout is unspecified, we block the CPU thread on + // current stream. This is to minimize the CPU barrier wait time in healthy + // path + auto currentStream = at::cuda::getCurrentCUDAStream(device_.index()); + // CUDAStream wrapper will correctly use a DeviceGuard here + currentStream.synchronize(); } - // syncrhoize() will block the current stream on the NCCL stream - synchronize(); + // If exception is detected, throw it from the main CPU thread + if (exception()) { + // Abort NCCL communicators + abort(); + // Throw exception (from main thread here) + handleException(TearDown); + } // TODO(kwen2501): this should be moved to c10d tests, to qualify a NCCL // upgrade. Once a NCCL version is qualified, this code should not be needed @@ -749,7 +814,7 @@ void ProcessGroupNCCL::WorkNCCL::abort() { ncclCommDevIdxMapMutex.unlock(); } -ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} +ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() = default; // CUDA event is used to record the start/end of one Work. // Instead of let the CUDA event gets destroyed, we now reuse it after the Work @@ -757,27 +822,32 @@ ProcessGroupNCCL::CUDAEventCache::CUDAEventCache() {} // This is to avoid the potential deadlock caused by CudaEventDestroy. std::shared_ptr ProcessGroupNCCL::CUDAEventCache::create( bool timing) { + // register the deleter as a callback when the WorkNCCL object is destroyed. auto deleter = [this, timing](at::cuda::CUDAEvent* event) { std::lock_guard lock(this->cacheMutex_); + // We put the event back to the cache deque once the WorkNCCL object is + // destroyed. this->eventsArray_[timing ? 1 : 0].push_back(event); }; at::cuda::CUDAEvent* event = nullptr; { std::lock_guard lock(cacheMutex_); - auto events = eventsArray_[timing ? 1 : 0]; + auto& events = eventsArray_[timing ? 1 : 0]; + // If we still have events in the cache, we reuse it. Otherwise, we create a + // new one. if (!events.empty()) { - event = events.back(); - events.pop_back(); + event = events.front(); + events.pop_front(); + } else { + event = new at::cuda::CUDAEvent( + timing ? cudaEventDefault : cudaEventDisableTiming); } } - if (!event) { - event = new at::cuda::CUDAEvent( - timing ? cudaEventDefault : cudaEventDisableTiming); - } return std::shared_ptr(event, std::move(deleter)); } ProcessGroupNCCL::CUDAEventCache& ProcessGroupNCCL::CUDAEventCache::get() { + // Return a singleton instance of CUDAEventCache. static ProcessGroupNCCL::CUDAEventCache cache; return cache; } @@ -792,14 +862,14 @@ constexpr const char* MULTI_DEVICE_ERROR_MSG = "ProcessGroupNCCL continues supporting multi-process and multi-thread modes."; ProcessGroupNCCL::ProcessGroupNCCL( - const c10::intrusive_ptr& store, + c10::intrusive_ptr store, int rank, int size, c10::intrusive_ptr options) : Backend(rank, size), - store_(store), - options_(options), - ncclCommCounter_(0), + store_(std::move(store)), + options_(std::move(options)), + traceKeyStart_(getTraceStartKey("NCCL", rank)), traceKeyEnd_(getTraceEndKey("NCCL", rank)), terminateProcessGroup_(false), @@ -831,7 +901,6 @@ ProcessGroupNCCL::ProcessGroupNCCL( // both timeout and other errors. dumpOnTimeoutOrEx_ = getCvarBool(TORCH_NCCL_DUMP_ON_TIMEOUT, false) || (dist_debug_level_ >= DebugLevel::Detail); - sleepAfterException_ = getCvarBool(TORCH_NCCL_SLEEP_AFTER_EXCEPTION, false); // logging C++ stack isn't safe. Introduce a variable to control it. logCppStackOnUncleanShutdown_ = getCvarBool(TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN, true); @@ -871,15 +940,9 @@ ProcessGroupNCCL::ProcessGroupNCCL( #endif if (blockingWait_) { - if (asyncErrorHandling_ != NoHandling || desyncDebug_) { - LOG(INFO) - << logPrefix() << "TORCH_NCCL_BLOCKING_WAIT and " - << "TORCH_NCCL_ASYNC_ERROR_HANDLING|TORCH_NCCL_DESYNC_DEBUG" - << "should not both be enabled. " - << "Only TORCH_NCCL_BLOCKING_WAIT is being used in this process."; - asyncErrorHandling_ = NoHandling; - desyncDebug_ = false; - } + LOG(INFO) + << logPrefix() + << "TORCH_NCCL_BLOCKING_WAIT is enabled, NO watchdog thread is created."; } else { if (desyncDebug_ && asyncErrorHandling_ == NoHandling) { LOG(INFO) @@ -892,8 +955,13 @@ ProcessGroupNCCL::ProcessGroupNCCL( } #ifdef ENABLE_NCCL_ERROR_CHECKING - ncclCommWatchdogThread_ = - std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); + // in blockingWait mode, we don't need to enable the watchdog thread to check + // the timeout or nccl error because the main thread would throw an exception + // and it is the user's responsibility to handle the exception. + if (!blockingWait_) { + ncclCommWatchdogThread_ = + std::thread(&ProcessGroupNCCL::ncclCommWatchdog, this); + } #endif init(); @@ -971,9 +1039,10 @@ ProcessGroupNCCL::ProcessGroupNCCL( // SEGMENT_FREE action occurs. // We attach hooks only once at the first PG creation. // Attaching hooks fails if CUDACachingAllocator is not initialized, so - // lazyInitCUDA is called (and is a no-op if CUDA is already initialized). + // Init for CUDA is called (and is a no-op if CUDA is already + // initialized). if (useTensorRegisterAllocatorHook_ && !allocatorHooksAttached) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( &cacheAllocatorRegisterHook); c10::cuda::CUDACachingAllocator::attachAllocatorTraceTracker( @@ -989,6 +1058,39 @@ void ProcessGroupNCCL::eagerConnectSingleDevice(at::Device device) { getNCCLComm(key, device, OpType::ALLREDUCE); } +bool ProcessGroupNCCL::useNonblocking() { +#ifndef NCCL_HAS_COMM_NONBLOCKING + return false; +#endif + // Already parsed, return the cached value + if (useNonblocking_.has_value()) { + return useNonblocking_.value(); + } + // Get environment variable. + auto nbEnv = c10::utils::check_env("TORCH_NCCL_USE_COMM_NONBLOCKING"); + + // 1st priority: Respect the user's setting + if (options_->config.blocking != NCCL_CONFIG_UNDEF_INT) { + useNonblocking_ = options_->config.blocking == 0; + } + // 2nd priority: Respect the environment variable + else if (nbEnv.has_value()) { + useNonblocking_ = nbEnv.value(); + } + // 3rd priority: automatically use nonblocking if we are in eager init mode + else if (getBoundDeviceId()) { + useNonblocking_ = true; + } + // 4th priority: otherwise, nonblocking = false to preserve old behavior + else { + useNonblocking_ = false; + } + + LOG(INFO) << logPrefix() + << "Using non-blocking mode: " << useNonblocking_.value(); + return useNonblocking_.value(); +} + void ProcessGroupNCCL::performNocolorSplit(at::Device device) { // If our backend doesn't support splitting, this is a no-op for // ranks not in the new subgroup (and ranks that would be in it will @@ -997,6 +1099,8 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) { const auto key = getKeyFromDevice(device); LOG(INFO) << logPrefix() << "Performing nocolor split on backend device " << device << ", key " << key << ", i am " << this; + bool useNb = useNonblocking(); + options_->config.blocking = useNb ? 0 : 1; auto comm = getNCCLComm(key, device, OpType::ALLREDUCE); NCCLComm::split( comm.get(), @@ -1007,6 +1111,21 @@ void ProcessGroupNCCL::performNocolorSplit(at::Device device) { #endif } +bool ProcessGroupNCCL::isInitialized() { + if (devNCCLCommMap_.empty()) { + return false; + } + std::lock_guard lock(mutex_); + bool initialized = true; + for (const auto& [_, comm] : devNCCLCommMap_) { + if (!comm->isInitialized()) { + initialized = false; + break; + } + } + return initialized; +} + c10::intrusive_ptr ProcessGroupNCCL:: initIntraNodeComm() { using IntraNodeComm = intra_node_comm::IntraNodeComm; @@ -1097,9 +1216,10 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( ::c10d::C10dLoggingData data; if (log) { - data.integers["pg_id"] = local_id_; + data.integers["pg_id"] = static_cast(local_id_); data.integers["rank"] = rank_; data.integers["global_rank"] = globalRank(); + data.integers["world_size"] = getSize(); data.strings["flight_recorder_version"] = c10d::version_val_str; } @@ -1164,7 +1284,7 @@ void ProcessGroupNCCL::waitForFutureOrTimeout( void ProcessGroupNCCL::abortCommsFromMap( std::unordered_map>& ncclCommsMap, - std::optional abortReason) { + const std::optional& abortReason) { // The process may control multiple devices, loop through the communicators on // each device for (auto& it : ncclCommsMap) { @@ -1180,7 +1300,7 @@ void ProcessGroupNCCL::abortCommsFromMap( gpuGuard.set_index(deviceIndex); } LOG(INFO) << logPrefix() << "ProcessGroupNCCL destroying ncclComm_ " - << ncclComm->ncclComm_ << " on CUDA device: " << devName; + << ncclComm->repr() << " on CUDA device: " << devName; ncclComm->ncclCommAbort(abortReason); // Note that we don't remove the aborted communicators from the // cache. The reason is that if we do remove the communicator @@ -1198,9 +1318,11 @@ void ProcessGroupNCCL::abortCommsFromMap( } // Abort all communicators on this rank -bool ProcessGroupNCCL::abort(std::optional abortReason) { - // This will log counter for how long the abort actually takes. - STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort); +// Note: original name of this method is `abort`. It was renamed to +// `abortComms` to distinguish from the `abort` method below. The `abort` +// method calls `abortComms` but does more destruction than the latter. +bool ProcessGroupNCCL::abortComms( + const std::optional& abortReason) { // Remove record from global ncclCommDevIdxMapMutex before aboarting, // so that a new cache segment would not register to already aborded // communicators. Note that ncclCommDevIdxMap is a global container which may @@ -1219,7 +1341,11 @@ bool ProcessGroupNCCL::abort(std::optional abortReason) { return true; } -void ProcessGroupNCCL::shutdown(std::optional reason) { +// Abort this backend. +void ProcessGroupNCCL::abort() { + // This will log counter for how long the abort actually takes. + STATIC_SCOPED_WAIT_COUNTER(pytorch.ProcessGroupNCCL__abort); + // Don't join threads here since the purpose of this method is to abort all // communicators and signal the threads to exit. Joining on the threads could // potentially block and hence avoid it in this method. @@ -1229,8 +1355,8 @@ void ProcessGroupNCCL::shutdown(std::optional reason) { // lauch abort asynchrounously and wait for it to complete or timeout LOG(INFO) << logPrefix() << "Launching ProcessGroupNCCL abort asynchrounously."; - std::future fut = std::async( - std::launch::async, [this, &reason]() { return this->abort(reason); }); + std::future fut = + std::async(std::launch::async, [this]() { return this->abortComms(); }); waitForFutureOrTimeout( fut, options_->timeout, "ProcessGroup abort", true, false); @@ -1242,6 +1368,15 @@ void ProcessGroupNCCL::shutdown(std::optional reason) { monitorWakeUpCV_.notify_one(); } +// Destroy (shutdown) this backend -- normal exit. +void ProcessGroupNCCL::shutdown() { + // kwen2501 (Aug 2024): moved code of `shutdown()` to `abort()` because it + // actually implemented an abort behavior. + // TODO: implementation of `shutdown` should use ncclCommDestroy() instead + // of ncclCommAbort(). Ideally non-blocking API mode should be used. + this->abort(); +} + ProcessGroupNCCL::~ProcessGroupNCCL() { LOG(INFO) << logPrefix() << "ProcessGroupNCCL destructor entered."; @@ -1262,14 +1397,16 @@ ProcessGroupNCCL::~ProcessGroupNCCL() { // Wait for all threads to finish before returning #ifdef ENABLE_NCCL_ERROR_CHECKING - if (ncclCommWatchdogThread_.joinable()) { - ncclCommWatchdogThread_.join(); - LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; - } - if (ncclHeartbeatMonitorThread_.joinable()) { - ncclHeartbeatMonitorThread_.join(); - LOG(INFO) << logPrefix() - << "ProcessGroupNCCL heart beat monitor thread joined."; + if (!blockingWait_) { + if (ncclCommWatchdogThread_.joinable()) { + ncclCommWatchdogThread_.join(); + LOG(INFO) << logPrefix() << "ProcessGroupNCCL watchdog thread joined."; + } + if (ncclHeartbeatMonitorThread_.joinable()) { + ncclHeartbeatMonitorThread_.join(); + LOG(INFO) << logPrefix() + << "ProcessGroupNCCL heart beat monitor thread joined."; + } } #endif if (onCompletionHookThread_.joinable()) { @@ -1300,13 +1437,13 @@ bool ProcessGroupNCCL::dumpDebuggingInfo() { return false; } -void ProcessGroupNCCL::terminateProcess(std::string errMsg) { +void ProcessGroupNCCL::terminateProcess(const std::string& errMsg) { // Logging with `FATAL`, after errMsg printed, it calls `std::abort()` // to terminate the program execution. LOG(FATAL) << logPrefix() << errMsg; } -int computeDeltaMS( +static long computeDeltaMS( std::chrono::time_point start, std::chrono::time_point end) { return std::chrono::duration_cast(end - start) @@ -1483,12 +1620,15 @@ void ProcessGroupNCCL::heartbeatMonitor() { } LOG(ERROR) << errorMsg; - auto& cpp_dumper = get_cpp_trace_dumper(); - if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) { - LOG(INFO) << "Dumping c++ stacktraces:"; - cpp_dumper.value()([](const std::string& line) { LOG(INFO) << line; }); - } + // We perform some checks to help users debug the timeout/hang issue: + // 1. Dump the nccl trace (flight recorder) to help debug the issue + // (timeout after waitTimeoutDumpInMilSec_, which is one minute). + // 2. Check if there is a GIL deadlock (timeout after 300ms). + // 3. Try to dump the c++ stacktraces (blocking and would hang, + // users can turn this off by set + // TORCH_NCCL_LOG_CPP_STACK_ON_UNCLEAN_SHUTDOWN=0). + // Dump the nccl trace (flight recorder). if (checkDumpSignal && shouldDump_.load()) { // Store debug info to storage if no other thread does it. (By default to // local disk) @@ -1502,8 +1642,11 @@ void ProcessGroupNCCL::heartbeatMonitor() { "Flight recorder dump in heartbeatMonitor", false, true); + // Indicate to watchdog thread that we have finished dumping. + promiseFlightRecorderDump_.set_value(); } + // GIL deadlock check. if (get_gil_checker() != nullptr) { auto fut = launchAsyncGilCheck(); auto kGilCheckTimeout = std::chrono::milliseconds(300); @@ -1513,13 +1656,24 @@ void ProcessGroupNCCL::heartbeatMonitor() { futStatus != std::future_status::deferred, "Expected the future to have been launched eagerly."); LOG(ERROR) + << logPrefix() << "Could not acquire GIL within 300 ms on exit, possible GIL induced hang"; } } else { LOG(INFO) + << logPrefix() << "GIL checker was not registered, perhaps this is a no-python build?"; } + // Dump the c++ stacktraces. + auto& cpp_dumper = get_cpp_trace_dumper(); + if (logCppStackOnUncleanShutdown_ && cpp_dumper.has_value()) { + LOG(INFO) << logPrefix() << "Dumping c++ stacktraces:"; + cpp_dumper.value()( + [&](const std::string& line) { LOG(INFO) << logPrefix() << line; }); + LOG(INFO) << logPrefix() << "Finished c++ stacktraces dump."; + } + // There are two possible cases for the watchdog thread exit: // Case one: desync report runs quickly, and it follows the step: // collective timeout -> desync -> exception handling -> destructors @@ -1695,7 +1849,7 @@ void ProcessGroupNCCL::addEphemeralTimeout( } bool ProcessGroupNCCL::verifyWorkTimeoutForTest( - const c10::intrusive_ptr work, + const c10::intrusive_ptr& work, const std::chrono::milliseconds& timeout) { // Since collective returns a c10d::Work, we need to cast it to WorkNCCL. if (auto workNCCL = c10::dynamic_intrusive_pointer_cast(work)) { @@ -1774,23 +1928,60 @@ void ProcessGroupNCCL::watchdogHandler() { // aborted, So cannot check exception based on them. But watchdog needs to // finish the check for the works that have already been enqueued to // workMetaList_ + + // check NCCL errors first if (!terminateProcessGroup_.load()) { work.checkAndSetException(); } - bool timedOut = work.checkTimeout(); - - // If work hits an exception (either an error or timeout) if (work.exception()) { // log as soon as exception is detected LOG(ERROR) << c10::str( logPrefix(), - "Exception (either an error or timeout) detected by watchdog at work: ", + "NCCL error is detected by watchdog at work: ", work.seq_, ", last enqueued NCCL work: ", pgStatus_->lastEnqueuedSeq, ", last completed NCCL work: ", pgStatus_->lastCompletedSeq, "."); + if (work.futureWorkResult_ && !work.futureWorkResult_->completed()) { + work.futureWorkResult_->markCompleted( + at::IValue(static_cast(WorkResult::COMM_ERROR))); + } + } else if (work.checkTimeout()) { + LOG(ERROR) << c10::str( + logPrefix(), + "Work timeout is detected by watchdog at work: ", + work.seq_, + ", last enqueued NCCL work: ", + pgStatus_->lastEnqueuedSeq, + ", last completed NCCL work: ", + pgStatus_->lastCompletedSeq, + "."); + if (work.futureWorkResult_ && !work.futureWorkResult_->completed()) { + work.futureWorkResult_->markCompleted( + at::IValue(static_cast(WorkResult::TIMEOUT))); + } + // Report desync state in case of timeout + if (desyncDebug_) { + try { + collectiveDebugInfoMode_.store(true); + auto desyncMsg = getNCCLWatchdogDebugInfo(); + LOG(ERROR) << logPrefix() << desyncMsg; + } catch (const std::exception& e) { + LOG(ERROR) << logPrefix() + << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " + << " Please file an issue. Error: " << e.what(); + } catch (...) { + LOG(ERROR) + << logPrefix() + << "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." + << " Please file an issue."; + } + } + } + // If work hits an exception (either an error or timeout) + if (work.exception()) { // try to notify other ranks via global TCPStore to dump the flight // recorder when a collective timeout or exception happens. Flight // recorder behavior is independent of desync Debug. @@ -1808,12 +1999,18 @@ void ProcessGroupNCCL::watchdogHandler() { } // signal the monitor thread on PG0 to start dumping shouldDump_.store(true); - if (sleepAfterException_) { - // This sleep is used to give time for dumping before throwing - // exception - std::this_thread::sleep_for( - std::chrono::seconds(heartbeatTimeoutInSec_)); - LOG(INFO) << logPrefix() << "slept for " << heartbeatTimeoutInSec_ + // Give time for dumping before throwing exception + auto start = std::chrono::steady_clock::now(); + auto status = promiseFlightRecorderDump_.get_future().wait_for( + std::chrono::milliseconds(waitTimeoutDumpInMilSec_)); + if (status == std::future_status::timeout) { + LOG(WARNING) << logPrefix() << "timed out after waiting for " + << waitTimeoutDumpInMilSec_ << "ms" + << " flight recorder dumps to finish."; + } else if (status == std::future_status::ready) { + auto end = std::chrono::steady_clock::now(); + LOG(INFO) << logPrefix() << "slept for " + << computeDeltaMS(start, end) << "ms" << " giving time for flight recorder dumps to finish."; } } catch (const std::exception& e) { @@ -1828,37 +2025,7 @@ void ProcessGroupNCCL::watchdogHandler() { work.abort(); // PG level abort, which would abort all other communicators on this // rank - abort(); - } - - // Report desync state in case of timeout - if (timedOut) { - LOG(ERROR) << c10::str( - logPrefix(), - "Timeout at NCCL work: ", - work.seq_, - ", last enqueued NCCL work: ", - pgStatus_->lastEnqueuedSeq, - ", last completed NCCL work: ", - pgStatus_->lastCompletedSeq, - "."); - if (desyncDebug_) { - try { - collectiveDebugInfoMode_.store(true); - auto desyncMsg = getNCCLWatchdogDebugInfo(); - LOG(ERROR) << logPrefix() << desyncMsg; - } catch (const std::exception& e) { - LOG(ERROR) - << logPrefix() - << "Failed to retrieve TORCH_NCCL_DESYNC_DEBUG report. " - << " Please file an issue. Error: " << e.what(); - } catch (...) { - LOG(ERROR) - << logPrefix() - << "Failed to rerieve TORCH_NCCL_DESYNC_DEBUG report with unknown error." - << " Please file an issue."; - } - } + abortComms(); } // Throw exception work.handleException(asyncErrorHandling_); @@ -1879,12 +2046,17 @@ void ProcessGroupNCCL::watchdogHandler() { // multiple times after the start if (pgStatus_->lastStartedSeq < static_cast(work.seq_) && work.isStarted()) { - pgStatus_->lastStartedSeq = work.seq_; + pgStatus_->lastStartedSeq = static_cast(work.seq_); pgStatus_->lastStartedWorkName = opTypeToString(work.opType_); } // Clean up completed work if (work.isCompleted()) { + if (work.futureWorkResult_ && work.finishedGPUExecutionInternal() && + !work.futureWorkResult_->completed()) { + work.futureWorkResult_->markCompleted( + at::IValue(static_cast(WorkResult::SUCCESS))); + } { // Reset the timeout and first work if the work is completed. std::lock_guard timeoutLock(mtxTimeoutExtension_); @@ -1893,7 +2065,7 @@ void ProcessGroupNCCL::watchdogHandler() { ephemeralTimeoutInflight_ -= work.ownedEphermeralTimeout_; } } - pgStatus_->lastCompletedSeq = work.seq_; + pgStatus_->lastCompletedSeq = static_cast(work.seq_); pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); pgStatus_->lastCompletedNumelIn = work.numelIn_; pgStatus_->lastCompletedNumelOut = work.numelOut_; @@ -1986,7 +2158,7 @@ void ProcessGroupNCCL::runHookLoop() { // already finished successfully at this point. We just need to abort // the process Abort all NCCL Communicators on this ProcessGroupNCCL // instance. - abort(errorStr); + abortComms(errorStr); } } @@ -2202,7 +2374,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( } // GPU world size and GPU rank - int numRanks, rank; + int numRanks = -1, rank = -1; if (!singleP2POp) { // Collective, all-to-all, or batch P2P @@ -2219,11 +2391,13 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( rank = p2pRank; } +#ifdef NCCL_HAS_COMM_NONBLOCKING + bool useNb = useNonblocking(); + options_->config.blocking = useNb ? 0 : 1; +#endif + #ifdef NCCL_HAS_COMM_SPLIT if (options_->split_from) { - TORCH_CHECK( - options_->split_color != 0, - "Must specify a non-zero color when splitting"); // Find a valid, healthy communicator to split from if possible. std::lock_guard lock(options_->split_from->mutex_); auto& other_comms = options_->split_from->devNCCLCommMap_; @@ -2231,6 +2405,8 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( if (dit != other_comms.end()) { auto& parentComm = dit->second; if (parentComm != nullptr && !parentComm->isAborted()) { + LOG(INFO) << logPrefix() << "Splitting NCCL communicator from " + << parentComm->repr(); ncclComm = NCCLComm::split( parentComm.get(), options_->split_color, @@ -2287,7 +2463,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( std::make_tuple(pg_uid_, pg_desc_), groupRanks()); RECORD_PARAM_COMMS( - 0, // seq + std::make_tuple(0, false), // seq std::make_tuple(pg_uid_, pg_desc_), // PG name tuple rank, // rank "init", // collective name @@ -2301,7 +2477,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( size_); // worldSize LOG(INFO) << logPrefix() << "ProcessGroupNCCL created ncclComm_ " - << ncclComm->ncclComm_ + << ncclComm->repr() << " on CUDA device: " << static_cast(deviceIndex); // At this point NCCL should have been initialized, hence we can accurately @@ -2315,7 +2491,7 @@ std::shared_ptr ProcessGroupNCCL::getNCCLComm( C10D_NCCL_CHECK(ncclGroupStart(), std::nullopt); } - ncclStreams_.emplace(deviceKey, std::move(streamVal)); + ncclStreams_.emplace(deviceKey, streamVal); // Note: these events are created with the (default) cudaEventDisableTiming // flag This flag provides the best performance when used with @@ -2404,7 +2580,7 @@ void check_gpu_single_tensor( // condition may be a challenge because the test would need to pass tensors on // different devices in the same process. int64_t check_gpu_tensors_same_device(const std::vector& tensors) { - if (tensors.size() == 0) { + if (tensors.empty()) { C10_THROW_ERROR(ValueError, "Tensor list must be nonempty"); } @@ -2450,6 +2626,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( at::Device& device, int rank, OpType opType, + bool isP2P, const char* profilingTitle, const std::vector& inputs, const std::vector& outputs, // TODO(kwen2501): necessary? @@ -2460,7 +2637,8 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( device, rank, opType, - seqCollective_, + isP2P ? seqP2P_ : seqCollective_, + isP2P, profilingTitle, profilingTitle != nullptr ? std::optional>(inputs) : std::nullopt, @@ -2511,6 +2689,11 @@ c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: return future_; } +c10::intrusive_ptr ProcessGroupNCCL::WorkNCCL:: + getFutureResult() { + return futureWorkResult_; +} + float ProcessGroupNCCL::WorkNCCL::getDuration() const { TORCH_CHECK(timingEnabled_, "getDuration only works if timing was enabled"); TORCH_CHECK( @@ -2541,8 +2724,10 @@ void ProcessGroupNCCL::assignTimeoutToWork( } void ProcessGroupNCCL::workEnqueue( - c10::intrusive_ptr work) { - if (!terminateProcessGroup_.load()) { + const c10::intrusive_ptr& work) { + // in blockingWait_ mode, we don't need watchdog thread, so no need to enqueue + // the work + if (!terminateProcessGroup_.load() && !blockingWait_) { std::lock_guard lock(workMetaListMutex_); // Avoid view tensors to be processed in cleanup thread. // View tensors' destruction invokes autograd_meta, which @@ -2574,14 +2759,6 @@ void ProcessGroupNCCL::startCoalescing() { // start, which has one minor downside- we burn a seq_ if someone ever does a // 'start' and 'end' coalescing region without doing an operation inbetween. - // Don't bump op_id_ here, because startCoalescing isn't a logical operation. - // Bump it for each logical op inside the coalescing group. - if (coalescing_state_ & CoalP2P) { - seqP2P_++; - } else { - seqCollective_++; - } - coalescedDevice_.set_index(-1); coalescedComm_ = nullptr; coalescing_state_ |= CoalActive; @@ -2615,8 +2792,15 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { c10::cuda::currentStreamCaptureStatusMayInitCtx(); bool enqueue = (coalescing_state_) && capture_status == c10::cuda::CaptureStatus::None; - auto work = - initWork(device, rank_, optype, "nccl:coalesced", {}, {}, enqueue); + auto work = initWork( + device, + rank_, + optype, + coalescing_state_ & CoalP2P, + "nccl:coalesced", + {}, + {}, + enqueue); work->ncclComm_ = comm; work->blockingWait_ = blockingWait_; work->avoidRecordStreams_ = avoidRecordStreams_; @@ -2628,7 +2812,7 @@ c10::intrusive_ptr ProcessGroupNCCL::endCoalescing(OpType optype) { work->ncclStartEvent_->record(ncclStream); } - if (nccl_use_nonblocking()) { + if (useNonblocking()) { groupEndNonblocking(comm); } else { groupEnd(); @@ -2678,19 +2862,29 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( avoidRecordStreams |= avoidRecordStreams_; nanCheck &= enableNanCheck_; + auto device = getDevice(inputs[0]); + // Guard must be created before `currentStreamCaptureStatusMayInitCtx`; + // otherwise, extra CUDA context could be created on device 0. + at::cuda::OptionalCUDAGuard gpuGuard(device); + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); // Bump collective counter - seqCollective_++; + if (!coalescing_state_) { + seqCollective_++; + } op_id_++; - auto device = getDevice(inputs[0]); const auto key = getKeyFromDevice(device); auto ncclComm = getNCCLComm(key, device, opType); if (coalescing_state_ & CoalActive) { + if ((coalescing_state_ & CoalColl) == 0) { + // First op in coalesced operations + seqCollective_++; + } coalescing_state_ |= CoalColl; if (coalescedDevice_.index() < 0) { coalescedDevice_ = device; @@ -2713,8 +2907,8 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( bool enqueue = !coalescing_state_ && capture_status == c10::cuda::CaptureStatus::None; - auto work = - initWork(device, rank_, opType, profilingTitle, inputs, outputs, enqueue); + auto work = initWork( + device, rank_, opType, false, profilingTitle, inputs, outputs, enqueue); // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); @@ -2724,8 +2918,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard(device); - if (nanCheck) { for (const auto& input : inputs) { checkForNan(input, ncclStream); @@ -2850,6 +3042,19 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( bool avoidRecordStreams) { // Environment setting by the user may add onto collective call's option avoidRecordStreams |= avoidRecordStreams_; + + // Currently, the API permits one scenario where inputs.size() and + // outputs.size() are > 0. + // 1. If the call was a _coalesced call, all inputs must be on the same + // device. + // The group of nccl calls applies the collective separately to each input, + // but the group as a whole should be efficient, and might even execute as + // a single fused kernel. + auto device = getDevice(inputs[0]); + // Guard must be created before `currentStreamCaptureStatusMayInitCtx`; + // otherwise, extra CUDA context could be created on device 0. + at::cuda::OptionalCUDAGuard gpuGuard(device); + c10::cuda::CaptureStatus capture_status = c10::cuda::currentStreamCaptureStatusMayInitCtx(); errorIfCapturingNonCapturableNCCL(capture_status); @@ -2858,20 +3063,12 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( seqCollective_++; // For coalescingManager collectives, there is no individual c++ call per - // collective so there is no flight record and we increment seq*_ and op_id_ - // together. Compare this to startCoalesing/endCoalescing flow where we - // increment seq_ once per group and increment op_id_ once per indvidual - // operation within the group + // collective so there is no flight record and we increment seqCollective_ and + // op_id_ together. Compare this to startCoalescing/endCoalescing flow where + // we increment either seqP2P_ or seqCollective_ once per group and increment + // op_id_ once per indvidual operation within the group op_id_++; - // Currently, the API permits one scenario where inputs.size() and - // outputs.size() are > 0. - // 1. If the call was a _coalesced call, all inputs must be on the same - // device. - // The group of nccl calls applies the collective separately to each input, - // but the group as a whole should be efficient, and might even execute as - // a single fused kernel. - auto device = getDevice(inputs[0]); const auto key = getKeyFromDevice(device); auto ncclComm = getNCCLComm(key, device, opType); @@ -2897,7 +3094,14 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( syncStream(device, ncclEvents_[key], ncclStream); auto work = initWork( - device, rank_, opType, profilingTitle, inputs, outputs, /*record=*/true); + device, + rank_, + opType, + false, + profilingTitle, + inputs, + outputs, + /*record=*/true); // Store references to outputs to be used by WorkNCCL::result and operator<<. work->outputs_ = std::make_shared>(outputs); @@ -2907,8 +3111,6 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( std::make_shared>(inputs); } - at::cuda::OptionalCUDAGuard gpuGuard(device); - // Start event should only be recorded before the ncclGroupStart() (which // happens inside AutoNcclGroup guard below) if (work->timingEnabled_) { @@ -2930,8 +3132,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collectiveCoalesced( #endif { - torch::cuda::nccl::AutoNcclGroup nccl_group_guard( - comm, nccl_use_nonblocking()); + torch::cuda::nccl::AutoNcclGroup nccl_group_guard(comm, useNonblocking()); for (const auto i : c10::irange(inputs.size())) { // Both `inputs' and `outputs' are created on a worker stream and used in // different ncclStreams. Hence, both must record the ncclStream to @@ -3063,6 +3264,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( } auto device = getDevice(tensor); + at::cuda::OptionalCUDAGuard gpuGuard(device); + std::string key; int p2pRank = 0, p2pTargetRank = 0; bool isSendRecvSelf = false; @@ -3083,8 +3286,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( p2pTargetRank = isSendRecvSelf ? 0 : 1 - p2pRank; if (!coalescing_state_) { - // Bump P2P sequence number. Don't do so if it's a batch P2P, it will be - // bumped in `startCoalescing`. + // Bump P2P sequence number. seqP2P_++; } } @@ -3096,6 +3298,10 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( auto ncclComm = getNCCLComm(key, device, opType, p2pRank, isSendRecvSelf); if (coalescing_state_ & CoalActive) { + // Bump seqP2P_ once per coalesced group, not once per individual op. + if ((coalescing_state_ & CoalP2P) == 0) { + seqP2P_++; + } coalescing_state_ |= CoalP2P; if (coalescedDevice_.index() < 0) { coalescedDevice_ = device; @@ -3150,7 +3356,14 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // cases such as profiling. work = initWork( - device, rank_, opType, profilingTitle, {tensor}, {}, /*record=*/false); + device, + rank_, + opType, + true, + profilingTitle, + {tensor}, + {}, + /*record=*/false); // This bypasses something in Work() that crashes if {tensor} is given as // output, not sure what work->outputs_ = std::make_shared>(); @@ -3174,9 +3387,6 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( /*isP2P=*/true); } - // is gpuGuard needed for the if block below, or can i swap them - at::cuda::OptionalCUDAGuard gpuGuard(device); - // Only check for NaN for send ops, for recv ops `tensor` can be a random // placeholder if (enableNanCheck_ && opType == OpType::SEND) { @@ -3208,10 +3418,14 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( fn(tensor, comm_, ncclStream, p2pTargetRank), ncclComm->getNcclCommFailureReason()); #else - C10D_NCCL_CHECK_TIMEOUT( - fn(tensor, comm_, ncclStream, p2pTargetRank), - ncclComm->getNcclComm(), - ncclComm->getNcclCommFailureReason()); + // In non-blocking mode, we need to use ncclGroup semantics to ensure that the + // kernel is enqueued for single-P2P ops. Otherwise, the event record below + // may not capture the kernel, leading to data corruption. + ncclGroupStart(); + C10D_NCCL_CHECK_NONBLOCKING( + fn(tensor, comm_, ncclStream, p2pTargetRank), std::nullopt); + C10D_NCCL_CHECK_TIMEOUT_GROUPEND( + ncclGroupEnd(), ncclComm, ncclComm->getNcclCommFailureReason()); #endif if (!coalescing_state_) { @@ -3464,8 +3678,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce( "Float8 dtypes are not currenlty supported for NCCL reductions"); // @lint-ignore CLANGTIDY RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple tensors, // inputTensors tensors, // outputTensors @@ -3494,8 +3709,10 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce_coalesced( // @lint-ignore CLANGTIDY RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective and assume only one collective + // in coalesed range std::make_tuple(pg_uid_, pg_desc_), // PG name tuple tensors, // inputTensors tensors, // outputTensors @@ -3547,8 +3764,9 @@ c10::intrusive_ptr ProcessGroupNCCL::broadcast( // @lint-ignore CLANGTIDY RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple tensors, // inputTensors tensors, // outputTensors @@ -3646,8 +3864,9 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( } check_gpu_single_tensor(tensor); RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple tensors, // inputTensors tensors, // outputTensors @@ -3741,8 +3960,9 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather( auto outputTensors_ = outputTensors.back(); RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensors, // inputTensors outputTensors, // outputTensors @@ -3838,6 +4058,26 @@ c10::intrusive_ptr ProcessGroupNCCL::allgather_into_tensor_coalesced( std::vector& outputs, std::vector& inputs, const AllgatherOptions& opts) { + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective and assume only one collective + // in coalesed range + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputs, // inputTensors + outputs, // outputTensors + rank_, // rank + "allgather_into_tensor_coalesced", // collective name + getTensorsNumel(inputs), // inNelems + getTensorsNumel(outputs), // outNelems + inputs[0].scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + return collectiveCoalesced( inputs, outputs, @@ -3872,8 +4112,9 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter( "Float8 dtypes are not currenlty supported for NCCL reductions"); RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensors, // inputTensors outputTensors, // outputTensors @@ -3985,8 +4226,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_reduce_scatter_base( !isFloat8Type(tensor.scalar_type()), "Float8 dtypes are not currenlty supported for NCCL reductions"); RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensor, // inputTensor outputTensor, // outputTensor @@ -4047,6 +4289,27 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( TORCH_CHECK( !isFloat8Type(inputs.back().scalar_type()), "Float8 dtypes are not currenlty supported for NCCL reductions"); + + // @lint-ignore CLANGTIDY + RECORD_PARAM_COMMS_DATA( + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective and assume only one collective + // in coalesed range + std::make_tuple(pg_uid_, pg_desc_), // PG name tuple + inputs, // inputTensors + outputs, // outputTensors + rank_, // rank + "reduce_scatter_tensor_coalesced", // collective name + getTensorsNumel(inputs), // inNelems + getTensorsNumel(outputs), // outNelems + inputs[0].scalar_type(), // dType + std::vector(), // inSplitSizes + std::vector(), // outSplitSizes + globalRankStart, // globalRankStart + globalRankStride, // globalRankStride + this->getSize()); // worldSize + return collectiveCoalesced( inputs, outputs, @@ -4076,8 +4339,9 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce_scatter_tensor_coalesced( c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { RECORD_PARAM_COMMS( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple rank_, // rank "barrier", // collective name @@ -4121,8 +4385,8 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { " using GPU ", barDevIdx, " to perform barrier as devices used by this process are currently unknown. ", - "This can potentially cause a hang if this rank to GPU mapping is incorrect.", - "Specify device_ids in barrier() to force use of a particular device,", + "This can potentially cause a hang if this rank to GPU mapping is incorrect. ", + "Specify device_ids in barrier() to force use of a particular device, ", "or call init_process_group() with a device_id."); } @@ -4130,7 +4394,8 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { ValueError, barDevIdx >= 0, "Failed to infer a GPU device id to perform barrier. "); - auto barDevice = at::Device(at::DeviceType::CUDA, barDevIdx); + auto barDevice = at::Device( + at::DeviceType::CUDA, static_cast(barDevIdx)); // Create a dummy tensor on the device // Note: we use zeros() instead of empty() to prevent barrier from triggering @@ -4144,7 +4409,7 @@ c10::intrusive_ptr ProcessGroupNCCL::barrier(const BarrierOptions& opts) { // Work will take over barrierTensors auto ncclWork = dynamic_cast(work.get()); TORCH_CHECK(ncclWork); - ncclWork->barrierTensor_ = std::move(barrierTensor); + ncclWork->isBarrierOp_ = true; return work; } @@ -4156,11 +4421,11 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( const AllToAllOptions& /* unused */) { check_gpu_single_tensor(outputTensor, true); check_gpu_single_tensor(inputTensor, true); - if (outputSplitSizes.size() == 0 && inputSplitSizes.size() == 0) { + if (outputSplitSizes.empty() && inputSplitSizes.empty()) { RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensor, // inputTensor outputTensor, // outputTensor @@ -4200,9 +4465,9 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall_base( c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_); RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + - 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensor, // inputTensor outputTensor, // outputTensor @@ -4279,8 +4544,9 @@ c10::intrusive_ptr ProcessGroupNCCL::alltoall( } RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensors, // inputTensors outputTensors, // outputTensors @@ -4331,8 +4597,10 @@ c10::intrusive_ptr ProcessGroupNCCL::send( check_gpu_single_tensor(tensor, true); RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqP2P_) + (coalescing_state_ & CoalP2P ? 0 : 1), + true), // the 1st p2p in coalesced range sets coalescing_state_ and + // bumps seqP2P_ std::make_tuple(pg_uid_, pg_desc_), // PG name tuple tensors, // inputTensors tensors, // outputTensors @@ -4353,8 +4621,14 @@ c10::intrusive_ptr ProcessGroupNCCL::send( ncclComm_t comm, at::cuda::CUDAStream& stream, int dst) { - torch::cuda::nccl::send(input, comm, stream, dst); - return ncclSuccess; + auto ncclDataType = getNcclDataType(input.scalar_type()); + return ncclSend( + input.data_ptr(), + input.numel(), + ncclDataType, + dst, + comm, + stream.stream()); }, dstRank, OpType::SEND, @@ -4372,8 +4646,10 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( check_gpu_single_tensor(tensor, true); RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqP2P_) + (coalescing_state_ & CoalP2P ? 0 : 1), + true), // the 1st p2p in coalesced range sets coalescing_state_ and + // bumps seqP2P_ std::make_tuple(pg_uid_, pg_desc_), // PG name tuple tensors, // inputTensors tensors, // outputTensors @@ -4394,8 +4670,14 @@ c10::intrusive_ptr ProcessGroupNCCL::recv( ncclComm_t comm, at::cuda::CUDAStream& stream, int src) { - torch::cuda::nccl::recv(output, comm, stream, src); - return ncclSuccess; + auto ncclDataType = getNcclDataType(output.scalar_type()); + return ncclRecv( + output.data_ptr(), + output.numel(), + ncclDataType, + src, + comm, + stream.stream()); }, srcRank, OpType::RECV, @@ -4413,11 +4695,12 @@ void ProcessGroupNCCL::groupEnd() { --ncclActiveGroupCounter_; } -void ProcessGroupNCCL::groupEndNonblocking(std::shared_ptr comm) { +void ProcessGroupNCCL::groupEndNonblocking( + const std::shared_ptr& comm) { #ifndef NCCL_HAS_COMM_NONBLOCKING C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); #else - if (!nccl_use_nonblocking()) { + if (!useNonblocking()) { C10D_NCCL_CHECK(ncclGroupEnd(), std::nullopt); } else { C10D_NCCL_CHECK_TIMEOUT_GROUPEND(ncclGroupEnd(), comm, std::nullopt); @@ -4462,7 +4745,7 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( outputs = outputTensors[0]; } else { // if not in the root rank, initialize outputs as empty list - if (outputTensors.size() != 0) { + if (!outputTensors.empty()) { invalidArgument("requires empty output on non-root"); } outputs = {}; @@ -4472,8 +4755,9 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( } RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensors, // inputTensors outputTensors, // outputTensors @@ -4502,13 +4786,14 @@ c10::intrusive_ptr ProcessGroupNCCL::gather( const auto root = opts.rootRank; if (getRank() == root) { if (!avoidRecordStreams_) { - for (auto output : outputs) { + for (auto const& output : outputs) { c10::cuda::CUDACachingAllocator::recordStream( output.storage().data_ptr(), stream); } } } - torch::cuda::nccl::gather(inputTensor, outputs, comm, stream, root); + torch::cuda::nccl::gather( + inputTensor, outputs, comm, stream, static_cast(root)); return ncclSuccess; }, [](at::cuda::CUDAStream&, @@ -4555,7 +4840,7 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( } else { // if not in the root rank, initialize inputTensors as empty place holder // with an empty list - if (inputTensors.size() != 0) { + if (!inputTensors.empty()) { invalidArgument("requires empty input on non-root"); } inputs = {}; @@ -4565,8 +4850,9 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( } RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple inputTensors, // inputTensors outputTensors, // outputTensors @@ -4598,13 +4884,14 @@ c10::intrusive_ptr ProcessGroupNCCL::scatter( at::cuda::CUDAStream& stream) { if (getRank() == root) { if (!avoidRecordStreams) { - for (auto input : inputs) { + for (auto const& input : inputs) { c10::cuda::CUDACachingAllocator::recordStream( input.storage().data_ptr(), stream); } } } - torch::cuda::nccl::scatter(inputs, outputTensor, comm, stream, root); + torch::cuda::nccl::scatter( + inputs, outputTensor, comm, stream, static_cast(root)); return ncclSuccess; }, [](at::cuda::CUDAStream&, @@ -4643,8 +4930,9 @@ c10::intrusive_ptr ProcessGroupNCCL::_allgather_base( } RECORD_PARAM_COMMS_DATA( - static_cast( - this->getSequenceNumberForGroup() + 1), // seq + 1 to match collective + std::make_tuple( + static_cast(seqCollective_) + 1, + false), // seq + 1 to match collective std::make_tuple(pg_uid_, pg_desc_), // PG name tuple input_tensor, // inputTensors output_tensor, // outputTensors diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 84c8de4fc9484..839463a9d8be1 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -63,13 +64,6 @@ static std::vector TORCH_NCCL_ASYNC_ERROR_HANDLING = { static std::vector TORCH_NCCL_DUMP_ON_TIMEOUT = { "TORCH_NCCL_DUMP_ON_TIMEOUT"}; -// TODO: remove this change after a safe rollout. -// Control whether we sleep after an exception is thrown. -// This change is temporary and is used to safely remove the current sleep that -// exists after an exception is thrown. -static std::vector TORCH_NCCL_SLEEP_AFTER_EXCEPTION = { - "TORCH_NCCL_SLEEP_AFTER_EXCEPTION"}; - // Control whether Desync Debug is enabled. This variable must be set // together with TORCH_NCCL_ASYNC_ERROR_HANDLING. static std::vector TORCH_NCCL_DESYNC_DEBUG = { @@ -272,12 +266,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Constructor takes a list of CUDA devices WorkNCCL( - const std::string& pgUID, - const std::string& pgDesc, + std::string pgUID, + std::string pgDesc, at::Device& device, int rank, OpType opType, uint64_t seq, + bool isP2P = false, const char* profilingTitle = nullptr, const std::optional>& inputs = std::nullopt, bool desyncDebug = false, @@ -325,6 +320,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Get a Future object that will be marked as completed internally. c10::intrusive_ptr getFuture() override; + // Get a Future result of each work (e.g. success, different error types). + // instead of the tensor output. + c10::intrusive_ptr getFutureResult() override; + float getDuration() const override; uint64_t getSequencenumber() const override; @@ -362,17 +361,17 @@ class TORCH_API ProcessGroupNCCL : public Backend { // The NCCL communicator used for this work item. std::shared_ptr ncclComm_; - // Tensors used for barrier op - at::Tensor barrierTensor_; + // whether this work is a barrier op + bool isBarrierOp_{false}; // Clone of blockingWait_ from ProcessGroupNCCL. - bool blockingWait_ = false; + bool blockingWait_{false}; // Clone of avoidRecordStreams_ from ProcessGroupNCCL. - bool avoidRecordStreams_ = false; + bool avoidRecordStreams_{false}; // Clone of opTimeout_ from ProcessGroupNCCL. - std::chrono::milliseconds opTimeout_; + std::chrono::milliseconds opTimeout_{}; // Ephemeral timeouts are owned by exactly one work, // and reset after that work completes. @@ -384,8 +383,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Time point representing when the work started. std::chrono::time_point workStartTime_; - // Record the collective sequential number. + // Record the sequential number of collective or p2p. uint64_t seq_; + bool isP2P_; // Indicates if the nccl start event has been updated to the store trace. // This will be used by desync debug. @@ -439,6 +439,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // The future returned by getFuture. c10::intrusive_ptr future_; + // the future result (e.g., success or failure) of the work + c10::intrusive_ptr futureWorkResult_; + bool timingEnabled_; // unique id used to tell the trace buffer that this // work has completed @@ -455,10 +458,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { private: std::mutex cacheMutex_; - // NOTE: We intentionaly store raw pointers so that + // NOTE: We intentionally store raw pointers so that // we do not attempt to destroy the event objects on process exit, // because cuda may be gone. - std::vector + std::deque eventsArray_[2]; // 0 for timing=false, 1 for timing=true }; @@ -484,7 +487,25 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Optional "parent" backend and color to create communicators from // via `ncclCommSplit` std::shared_ptr split_from; - int64_t split_color{0}; + // Color to use for `ncclCommSplit`, values: + // * Non-negative value: in group; + // * NCCL_SPLIT_NOCOLOR (-1): not in group; + // * NCCL_SPLIT_NOCOLOR - 1: uninitialized. + // [Note 1]: the type must be `int` instead of `int64_t` because NCCL API + // accepts int. Otherwise, an implicit conversion may happen at the API call + // and the value may become negative. + // [Note 2]: this member is pybinded to Python, the value passed from Python + // must be within the numerical range of C++ int. Otherwise, Python will + // raise a RuntimeError saying type is incompatible. See also + // `_process_group_color` in `distributed_c10d.py`. +#ifdef NCCL_HAS_COMM_SPLIT + int split_color{NCCL_SPLIT_NOCOLOR - 1}; +#else + // [Note 3]: for older NCCL versions, NCCL_SPLIT_NOCOLOR is not defined. But + // `split_color` is pybinded to Python, so we need to define it. So we use + // the int value of `NCCL_SPLIT_NOCOLOR` (-1) instead. + int split_color{-2}; +#endif std::vector global_ranks_in_group; std::string group_name; }; @@ -504,7 +525,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // communicator. These NCCL communicators are cached and reused if possible. // ProcessGroupNCCL( - const c10::intrusive_ptr& store, + c10::intrusive_ptr store, int rank, int size, c10::intrusive_ptr options = Options::create()); @@ -518,7 +539,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { int size, const std::string& groupName, c10::intrusive_ptr options = Options::create()) - : ProcessGroupNCCL(store, rank, size, options) {} + : ProcessGroupNCCL(store, rank, size, std::move(options)) {} ~ProcessGroupNCCL() override; @@ -641,7 +662,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { void groupEnd(); - void groupEndNonblocking(std::shared_ptr comm); + void groupEndNonblocking(const std::shared_ptr& comm); c10::intrusive_ptr gather( std::vector>& outputTensors, @@ -680,21 +701,24 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Helper function for iteratively aborting communicators in the provided map void abortCommsFromMap( std::unordered_map>& ncclCommsMap, - std::optional abortReason); + const std::optional& abortReason); c10::intrusive_ptr initIntraNodeComm(); + // Destroy (shutdown) this backend -- normal exit. + void shutdown(); + // Provides an API to abort the ProcessGroup (similar to ncclCommAbort) // instead of relying on ProcessGroupNCCL destructor. - // return true if abort is successful, otherwise false - bool abort(std::optional abortReason = std::nullopt); - - void shutdown(std::optional reason = std::nullopt); + void abort(); void eagerConnectSingleDevice(at::Device device) override; void performNocolorSplit(at::Device device); + // If all comms on this PG are fully initialized, return true. + bool isInitialized(); + // This method adds a temporary extension for the timeout period, // applying to all collectives between the calling of this API and // the completion of the first collective on the GPU. While this feature @@ -710,7 +734,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // `opTimeout_` of the provided WorkNCCL instance is the same as the specified // timeout. bool verifyWorkTimeoutForTest( - const c10::intrusive_ptr work, + const c10::intrusive_ptr& work, const std::chrono::milliseconds& timeout); protected: @@ -740,6 +764,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { at::Device& device, int rank, OpType opType, + bool isP2P, const char* profilingTitle = nullptr, const std::vector& inputs = {}, const std::vector& outputs = {}, @@ -750,6 +775,13 @@ class TORCH_API ProcessGroupNCCL : public Backend { // operations, we might need to use a side thread to do it. bool dumpDebuggingInfo(); + // Abort all communicators on this rank. + bool abortComms(const std::optional& abortReason = std::nullopt); + + // A helper function to check if nonblocking API mode should be used. + // Use this helper instead of directly checking `useNonblocking_` variable. + bool useNonblocking(); + private: int globalRankStart; int globalRankStride; @@ -900,7 +932,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Function that directly trigger std::abort so that the whole process // gets terminated. - virtual void terminateProcess(std::string errMsg); + virtual void terminateProcess(const std::string& errMsg); // A helper function to wait for a future to complete or timeout. void waitForFutureOrTimeout( @@ -1001,7 +1033,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::mutex mutex_; // Heartbeat of watchdog thread. - std::atomic_uint64_t heartbeat_; + std::atomic_uint64_t heartbeat_{}; // The time interval used for deciding whether there is no watchdog heartbeat. int heartbeatTimeoutInSec_; @@ -1009,6 +1041,9 @@ class TORCH_API ProcessGroupNCCL : public Backend { // timeout for the dump to finish. int waitTimeoutDumpInMilSec_; + // promise to coordinate flight recorder dump. + std::promise promiseFlightRecorderDump_; + // Interval of check coordinated signals in ProcessGroupNCCL from other ranks // e.g., trigger the dump of the debugging info for timeout when notified. int coordCheckIntervalMilSec_; @@ -1017,10 +1052,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { int ncclTraceBufferSize_; // We gate the heartbeat monitor thread so that we can roll it out gradually. - std::atomic monitorThreadEnabled_; + std::atomic monitorThreadEnabled_{}; // We gate the cudaEventCache so that we can roll it out gradually. - std::atomic cudaEventCacheEnabled_; + std::atomic cudaEventCacheEnabled_{}; // Monitor thread which checks the heartbeat of Watchdog thread. // If the monitor thread finds there is no heartbeat, it will dump debug info @@ -1043,7 +1078,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::atomic collectiveDebugInfoMode_; // Whether there are hooks pending to be fired - std::atomic hasPendingHooks_; + std::atomic hasPendingHooks_{}; // This is the signal from watchdog threads to indicate whether the monitor // thread should dump. Making it static so that it is accessiable from all the @@ -1086,7 +1121,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::list completedWorkList_; // Add Work Pointer to workVector - void workEnqueue(c10::intrusive_ptr); + void workEnqueue(const c10::intrusive_ptr&); // The CUDA streams used by NCCL kernels std::unordered_map ncclStreams_; @@ -1157,11 +1192,11 @@ class TORCH_API ProcessGroupNCCL : public Backend { // Whether or not to create start CUDAEvent and enable timing for start // and end events. Note that enableTiming_ is always true if desyncDebug_ // is set to true. - std::atomic enableTiming_; + std::atomic enableTiming_{}; // Flag to enable the print of hash value of input/output of collectives for // verification. - std::atomic enableCollecticeHashDebug_; + std::atomic enableCollecticeHashDebug_{}; // Whether or not TORCH_NCCL_AVOID_RECORD_STREAMS was set bool avoidRecordStreams_ = false; @@ -1206,6 +1241,10 @@ class TORCH_API ProcessGroupNCCL : public Backend { std::shared_ptr pgStatus_ = std::make_shared(); + + // Internal cached value: use NCCL non-blocking API mode or not. + // Use `useNonblocking()` method instead of accessing this variable directly. + std::optional useNonblocking_{std::nullopt}; }; // Dumps the NCCL comm traces and additional information about the Process diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp index d52adada45868..dab6aa6d26ece 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp @@ -1,6 +1,7 @@ #ifdef USE_C10D_UCC #include +#include #include #include #include @@ -157,11 +158,10 @@ void read_config() { torch_ucc_config.enable_comms_logger = false; // read all torch_ucc env. variables and update the map - char* env; - for (auto& torch_ucc_env : torch_ucc_envs_map) { - env = std::getenv(torch_ucc_env.first.c_str()); - if (env) { - torch_ucc_envs_map[torch_ucc_env.first] = std::string(env); + for (auto& [env_name, value] : torch_ucc_envs_map) { + auto env = c10::utils::get_env(env_name.c_str()); + if (env.has_value()) { + value = std::move(env.value()); } } diff --git a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp index 6107261e16725..a0d2738ab6928 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp @@ -154,8 +154,7 @@ struct CollectiveFingerPrint { // tensor>] std::vector outputs; outputs.reserve(backend->getSize()); - for (const auto i : c10::irange(backend->getSize())) { - std::ignore = i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(backend->getSize())) { outputs.emplace_back(at::zeros_like(tensor_shape)); } output_tensors.emplace_back(outputs); diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.cpp b/torch/csrc/distributed/c10d/SymmetricMemory.cpp index fbfefdc10cb6a..7911a9d875b3a 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.cpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.cpp @@ -8,6 +8,8 @@ static bool is_finalizing_ = false; class AllocatorMap { public: + AllocatorMap(const AllocatorMap&) = delete; + AllocatorMap& operator=(const AllocatorMap&) = delete; static AllocatorMap& get() { static AllocatorMap instance; return instance; @@ -35,8 +37,6 @@ class AllocatorMap { private: AllocatorMap() = default; - AllocatorMap(const AllocatorMap&) = delete; - AllocatorMap& operator=(const AllocatorMap&) = delete; std::unordered_map< c10::DeviceType, @@ -71,8 +71,12 @@ static at::Tensor empty_strided_p2p_persistent( "is still active."); } - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t numel = std::accumulate( + size.begin(), + size.end(), + size_t(1), + // NOLINTNEXTLINE(modernize-use-transparent-functors) + std::multiplies()); const size_t element_size = c10::elementSize(dtype); const size_t alloc_size = numel * element_size; @@ -105,8 +109,7 @@ static at::Tensor empty_strided_p2p_persistent( } // namespace -namespace c10d { -namespace symmetric_memory { +namespace c10d::symmetric_memory { bool is_finalizing() { return is_finalizing_; @@ -156,8 +159,12 @@ at::Tensor empty_strided_p2p( return empty_strided_p2p_persistent( size, stride, dtype, device, group_name, *alloc_id); } - const size_t numel = - std::accumulate(size.begin(), size.end(), 1, std::multiplies()); + const size_t numel = std::accumulate( + size.begin(), + size.end(), + size_t(1), + // NOLINTNEXTLINE(modernize-use-transparent-functors) + std::multiplies()); const size_t element_size = c10::elementSize(dtype); const size_t alloc_size = numel * element_size; @@ -195,5 +202,4 @@ TORCH_API bool has_multicast_support( auto allocator = get_allocator(device_type); return allocator->has_multicast_support(device_idx); } -} // namespace symmetric_memory -} // namespace c10d +} // namespace c10d::symmetric_memory diff --git a/torch/csrc/distributed/c10d/SymmetricMemory.hpp b/torch/csrc/distributed/c10d/SymmetricMemory.hpp index 30dc457518c63..55b212ef90154 100644 --- a/torch/csrc/distributed/c10d/SymmetricMemory.hpp +++ b/torch/csrc/distributed/c10d/SymmetricMemory.hpp @@ -60,9 +60,9 @@ class TORCH_API SymmetricMemory : public c10::intrusive_ptr_target { c10::ScalarType dtype, int64_t storage_offset) = 0; - virtual void barrier(int channel) = 0; - virtual void put_signal(int dst_rank, int channel) = 0; - virtual void wait_signal(int src_rank, int channel) = 0; + virtual void barrier(int channel, size_t timeout_ms) = 0; + virtual void put_signal(int dst_rank, int channel, size_t timeout_ms) = 0; + virtual void wait_signal(int src_rank, int channel, size_t timeout_ms) = 0; virtual int get_rank() = 0; virtual int get_world_size() = 0; diff --git a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp index c3fa09ab38bef..b5f4a8e547e22 100644 --- a/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp +++ b/torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp @@ -780,7 +780,7 @@ class UvClient : public UvTcpSocket { } bool parse_ping_command() { - uint32_t nonce; + uint32_t nonce = 0; if (!stream.read_value(nonce)) { return false; } diff --git a/torch/csrc/distributed/c10d/TraceUtils.h b/torch/csrc/distributed/c10d/TraceUtils.h index 9684ebe468a87..b211fc83564a8 100644 --- a/torch/csrc/distributed/c10d/TraceUtils.h +++ b/torch/csrc/distributed/c10d/TraceUtils.h @@ -8,13 +8,10 @@ #include #include #include -#include #include #include -#include #include -#include #include namespace c10d { @@ -27,7 +24,7 @@ struct ProcessGroupStatus { int64_t lastEnqueuedSeq{-1}; // the sequential number of the last collective started as the kernel int64_t lastStartedSeq{-1}; - // the sequential number of the last colletive completed marked by + // the sequential number of the last collective completed marked by // the watchdog thread // initialized to be -1 to indicate no collective has been completed int64_t lastCompletedSeq{-1}; @@ -129,7 +126,7 @@ inline std::string analyzeLaggingRanks(const TraceMap& traceMap) { std::string report = "\n\t - To our best knowledge, the lagging/dead/mismatched ranks " "that caused the desync are:"; - if (startRanks.size()) { + if (!startRanks.empty()) { report += c10::str( "\n\t - [", ranksToString(startRanks), @@ -137,7 +134,7 @@ inline std::string analyzeLaggingRanks(const TraceMap& traceMap) { lagSeq, " (count from 1)"); } - if (endRanks.size()) { + if (!endRanks.empty()) { report += c10::str( "\n\t [", ranksToString(endRanks), @@ -169,7 +166,7 @@ inline std::string dumpSnapshot(TraceMap& traceMap) { } } - if (collectivesStart.size()) { + if (!collectivesStart.empty()) { report += c10::str("\n\t #", seq, " started ranks:"); for (auto& mapPair : collectivesStart) { report += c10::str( @@ -179,7 +176,7 @@ inline std::string dumpSnapshot(TraceMap& traceMap) { mapPair.first); } } - if (collectivesEnd.size()) { + if (!collectivesEnd.empty()) { report += c10::str("\n\t #", seq, " finished ranks:"); for (auto& mapPair : collectivesEnd) { report += c10::str( @@ -218,7 +215,7 @@ inline std::string retrieveDesyncReport( int worldSize) { std::string report; - uint64_t thisSeq; + uint64_t thisSeq = 0; std::string thisCol; std::vector missingRanks; @@ -226,7 +223,7 @@ inline std::string retrieveDesyncReport( for (const auto rank : c10::irange(worldSize)) { // Build traceMapStart. - uint64_t seqStart; + uint64_t seqStart = 0; { std::string traceKeyStart = getTraceStartKey(pgName, rank); if (!store->check({traceKeyStart})) { @@ -250,7 +247,7 @@ inline std::string retrieveDesyncReport( if (!store->check({traceKeyEnd})) { continue; } - uint64_t seq; + uint64_t seq = 0; std::string col; if (!parseTraceValue(store, traceKeyEnd, seq, col)) { return report; @@ -323,7 +320,7 @@ inline std::string get_python_cpp_trace() { auto frame_id = s_tb[idx]; const auto& frame = s_tbs.all_frames.at(frame_id); oss << "#" << idx << " " << frame.funcname << " from " << frame.filename - << ":" << frame.lineno << std::endl; + << ":" << frame.lineno << '\n'; } return oss.str(); } diff --git a/torch/csrc/distributed/c10d/UCCTracing.cpp b/torch/csrc/distributed/c10d/UCCTracing.cpp index 5558f1a929267..c61acdf824daf 100644 --- a/torch/csrc/distributed/c10d/UCCTracing.cpp +++ b/torch/csrc/distributed/c10d/UCCTracing.cpp @@ -1,5 +1,6 @@ #ifdef USE_C10D_UCC +#include #include #include @@ -32,9 +33,9 @@ void ProcessGroupUCCLogger::flushComms(int rank, int world_size) { } std::string fullpath = "/tmp/" + dirname; - char* user_path = std::getenv("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR"); - if (user_path) { - fullpath = user_path; + auto user_path = c10::utils::get_env("TORCH_UCC_COMMS_TRACE_OUTPUT_DIR"); + if (user_path.has_value()) { + fullpath = std::move(user_path.value()); } std::string trace_filename = c10::str(fullpath, "/rank", rank, ".json"); std::ofstream _outfile; @@ -149,7 +150,7 @@ void CommTraceLogger::recordComms( // record the trace to kineto trace if applicable RECORD_PARAM_COMMS( - static_cast(seqnum), // seq + std::make_tuple(static_cast(seqnum), false), // (seq, isP2P) std::make_tuple("0", ""), // pg_name tuple rank, commName.c_str(), diff --git a/torch/csrc/distributed/c10d/Utils.hpp b/torch/csrc/distributed/c10d/Utils.hpp index ea4a4653bc35f..08fc975c5c48b 100644 --- a/torch/csrc/distributed/c10d/Utils.hpp +++ b/torch/csrc/distributed/c10d/Utils.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include @@ -92,7 +93,7 @@ inline std::vector split( inline std::string getCvarString( const std::vector& env, const char* def) { - const char* ret = def; + std::string ret(def); if (env.empty()) { TORCH_CHECK(false, "No environment variables passed"); @@ -103,14 +104,14 @@ inline std::string getCvarString( * versions of a variable get higher priority than the latter * versions of the same variable */ for (ssize_t i = static_cast(env.size()) - 1; i >= 0; i--) { - const char* val = std::getenv(env[i].c_str()); - if (val == nullptr) { + auto val = c10::utils::get_env(env[i].c_str()); + if (!val) { continue; } else if (i) { WARN_ENV_VAR_ONCE(env[i], env[0]); } - ret = val; + ret = val.value(); } return ret; @@ -157,15 +158,14 @@ inline bool getCvarBool(const std::vector& env, bool def) { * versions of a variable get higher priority than the latter * versions of the same variable */ for (ssize_t i = static_cast(env.size()) - 1; i >= 0; i--) { - char* val_ = std::getenv(env[i].c_str()); - if (val_ == nullptr) { + auto val = c10::utils::get_env(env[i].c_str()); + if (!val.has_value()) { continue; } else if (i) { WARN_ENV_VAR_ONCE(env[i], env[0]); } - std::string val = std::string(val_); - for (auto& x : val) { + for (auto& x : val.value()) { // NOLINTNEXTLINE(*-narrowing-conversions) x = std::tolower(x); } diff --git a/torch/csrc/distributed/c10d/Work.cpp b/torch/csrc/distributed/c10d/Work.cpp index 8beb8f2936208..d7890566acbb3 100644 --- a/torch/csrc/distributed/c10d/Work.cpp +++ b/torch/csrc/distributed/c10d/Work.cpp @@ -98,8 +98,11 @@ void Work::abort() { TORCH_CHECK(false, "Work::abort not implemented."); } -c10::intrusive_ptr Work::getFuture() { - TORCH_CHECK(false, "Work::getFuture not implemented.") +c10::intrusive_ptr Work::getFuture(){ + TORCH_CHECK(false, "Work::getFuture not implemented.")} + +c10::intrusive_ptr Work::getFutureResult() { + TORCH_CHECK(false, "Work::getFutureResult not implemented.") } void Work::finish(std::exception_ptr exception) { diff --git a/torch/csrc/distributed/c10d/Work.hpp b/torch/csrc/distributed/c10d/Work.hpp index c10e5007b9f54..5fd6c6c737885 100644 --- a/torch/csrc/distributed/c10d/Work.hpp +++ b/torch/csrc/distributed/c10d/Work.hpp @@ -34,6 +34,14 @@ enum class OpType : std::uint8_t { UNKNOWN = 100, }; +// TODO: support different types of failures/errors +enum class WorkResult : std::uint8_t { + SUCCESS = 0, + TIMEOUT = 1, + COMM_ERROR = 2, + UNKNOWN = 100, +}; + // Converts OpType to human readable string. TORCH_API std::string opTypeToString(OpType opType); @@ -108,6 +116,11 @@ class TORCH_API Work : public torch::CustomClassHolder { // work. Only NCCL backend is currently supported. virtual c10::intrusive_ptr getFuture(); + // Get a Future object that would be marked as either success or failure + // This API can be used by the user to track the completion of the work + // and hanlde the exception if any. + virtual c10::intrusive_ptr getFutureResult(); + virtual float getDuration() const; virtual uint64_t getSequencenumber() const; diff --git a/torch/csrc/distributed/c10d/c10d.h b/torch/csrc/distributed/c10d/c10d.h index 5151a33f7ee35..4f1f92af9976b 100644 --- a/torch/csrc/distributed/c10d/c10d.h +++ b/torch/csrc/distributed/c10d/c10d.h @@ -2,12 +2,8 @@ #include -namespace torch { -namespace distributed { -namespace c10d { +namespace torch::distributed::c10d { PyMethodDef* python_functions(); -} // namespace c10d -} // namespace distributed -} // namespace torch +} // namespace torch::distributed::c10d diff --git a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp index 047459b965589..e4a2d301a5661 100644 --- a/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp +++ b/torch/csrc/distributed/c10d/control_plane/WorkerServer.cpp @@ -83,7 +83,7 @@ bool file_exists(const std::string& path) { #ifdef _WIN32 return std::filesystem::exists(path); #else - struct stat rc; + struct stat rc {}; return lstat(path.c_str(), &rc) == 0; #endif } diff --git a/torch/csrc/distributed/c10d/debug.cpp b/torch/csrc/distributed/c10d/debug.cpp index a4b2fa6180aaf..d5d77094e1718 100644 --- a/torch/csrc/distributed/c10d/debug.cpp +++ b/torch/csrc/distributed/c10d/debug.cpp @@ -4,6 +4,7 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. +#include #include #include @@ -19,15 +20,15 @@ namespace detail { namespace { DebugLevel loadDebugLevelFromEnvironment() { - char* env_value = std::getenv("TORCH_DISTRIBUTED_DEBUG"); + auto env_value = c10::utils::get_env("TORCH_DISTRIBUTED_DEBUG"); - if (env_value == nullptr) { + if (!env_value.has_value()) { return DebugLevel::Off; } DebugLevel level{}; - std::string level_str{env_value}; + std::string level_str = std::move(env_value.value()); std::transform( level_str.begin(), diff --git a/torch/csrc/distributed/c10d/error.h b/torch/csrc/distributed/c10d/error.h index fff2b45c4c952..fef7a630410f4 100644 --- a/torch/csrc/distributed/c10d/error.h +++ b/torch/csrc/distributed/c10d/error.h @@ -45,12 +45,10 @@ struct formatter { } // namespace fmt -namespace c10d { -namespace detail { +namespace c10d::detail { inline std::error_code lastError() noexcept { return std::error_code{errno, std::generic_category()}; } -} // namespace detail -} // namespace c10d +} // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 2b70f0edbb185..b1cebfe0502be 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #ifndef _WIN32 #include @@ -96,16 +97,17 @@ class IntrusivePtrNoGilDestructor { public: IntrusivePtrNoGilDestructor() = default; IntrusivePtrNoGilDestructor(const IntrusivePtrNoGilDestructor&) = default; - IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) = default; + IntrusivePtrNoGilDestructor(IntrusivePtrNoGilDestructor&&) noexcept = default; IntrusivePtrNoGilDestructor& operator=(const IntrusivePtrNoGilDestructor&) = default; - IntrusivePtrNoGilDestructor& operator=(IntrusivePtrNoGilDestructor&&) = - default; + IntrusivePtrNoGilDestructor& operator=( + IntrusivePtrNoGilDestructor&&) noexcept = default; /* implicit */ IntrusivePtrNoGilDestructor(c10::intrusive_ptr impl) : impl_(std::move(impl)) {} // This ctor is very important; see // https://github.com/pybind/pybind11/issues/2957 explicit IntrusivePtrNoGilDestructor(T* impl) + // NOLINTNEXTLINE(bugprone-exception-escape) : impl_(c10::intrusive_ptr::unsafe_steal_from_new(impl)) {} ~IntrusivePtrNoGilDestructor() { if (impl_) { @@ -123,7 +125,7 @@ class IntrusivePtrNoGilDestructor { T* operator->() const noexcept { return impl_.get(); } - C10_NODISCARD T* get() const noexcept { + [[nodiscard]] T* get() const noexcept { return impl_.get(); } void reset() noexcept { @@ -908,8 +910,8 @@ This class does not support ``__members__`` property.)"); module.def( "_register_process_group", [](const std::string& group_name, - c10::intrusive_ptr<::c10d::ProcessGroup> group) { - ::c10d::register_process_group(group_name, std::move(group)); + const c10::intrusive_ptr<::c10d::ProcessGroup>& group) { + ::c10d::register_process_group(group_name, group); }, py::arg("group_name"), py::arg("group")); @@ -928,7 +930,7 @@ This class does not support ``__members__`` property.)"); const c10::intrusive_ptr<::c10d::Work>& work) { dynamic_cast<::c10d::PyProcessGroup::PyWork*>(work.get()) ->ref_py_object(); - ::c10d::register_work(tensor, std::move(work)); + ::c10d::register_work(tensor, work); }, py::arg("tensor"), py::arg("work")); @@ -1075,17 +1077,23 @@ This class does not support ``__members__`` property.)"); py::arg("sizes"), py::arg("dtype"), py::arg("storage_offset") = 0) - .def("barrier", &SymmetricMemory::barrier, py::arg("channel") = 0) + .def( + "barrier", + &SymmetricMemory::barrier, + py::arg("channel") = 0, + py::arg("timeout_ms") = 0) .def( "put_signal", &SymmetricMemory::put_signal, py::arg("dst_rank"), - py::arg("channel") = 0) + py::arg("channel") = 0, + py::arg("timeout_ms") = 0) .def( "wait_signal", &SymmetricMemory::wait_signal, py::arg("src_rank"), - py::arg("channel") = 0) + py::arg("channel") = 0, + py::arg("timeout_ms") = 0) .def( "stream_write_value32", &SymmetricMemory::stream_write_value32, @@ -1534,6 +1542,9 @@ Example:: bool useLibUV) { std::optional numWorkers = std::nullopt; if (worldSize.has_value() && worldSize.value() > -1) { + if (worldSize.value() == 0) { + throw py::value_error("TCPStore world size cannot be 0"); + } numWorkers = static_cast(worldSize.value()); } @@ -2166,7 +2177,7 @@ communication mechanism. // python-related libs. self->registerOnCompletionHook( [hookWrapper = ::c10d::PythonOnCompletionHook(std::move( - hook))](std::shared_ptr<::c10d::WorkInfo> workInfo) { + hook))](const std::shared_ptr<::c10d::WorkInfo>& workInfo) { hookWrapper(workInfo); }); }, @@ -2759,33 +2770,20 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`). &::c10d::ProcessGroupNCCL::setBoundDeviceId) .def( "perform_nocolor_split", - &::c10d::ProcessGroupNCCL::performNocolorSplit); + &::c10d::ProcessGroupNCCL::performNocolorSplit) + .def( + "abort", + &::c10d::ProcessGroupNCCL::abort, + py::call_guard()) + .def( + "_is_initialized", + &::c10d::ProcessGroupNCCL::isInitialized, + py::call_guard()); module.def( "_get_intra_node_comm_usage_counter", &::c10d::intra_node_comm::getIntraNodeCommUsageCounter); - using IntraNodeComm = ::c10d::intra_node_comm::IntraNodeComm; - py::class_>( - module, "_IntraNodeComm") - .def( - py::init([](const c10::intrusive_ptr<::c10d::Store>& store, - size_t rank, - size_t world_size, - std::optional buffer_size) { - auto comm = c10::make_intrusive( - store, rank, world_size, buffer_size); - if (!comm->rendezvous()) { - throw std::runtime_error("IntraNodeComm::rendezvous failed"); - } - return comm; - }), - py::arg("store"), - py::arg("rank"), - py::arg("world_size"), - py::arg("buffer_size") = std::nullopt) - .def("barrier", &IntraNodeComm::barrier, py::arg("ranks") = py::none()); - #ifdef NCCL_HAS_COMM_CTA_CGA py::class_( processGroupNCCL, @@ -2922,6 +2920,12 @@ Example:: .value("_ALLREDUCE_SPARSE", ::c10d::OpType::_ALLREDUCE_SPARSE) .value("UNKNOWN", ::c10d::OpType::UNKNOWN); + py::enum_<::c10d::WorkResult>(module, "WorkResult") + .value("SUCCESS", ::c10d::WorkResult::SUCCESS) + .value("TIMEOUT", ::c10d::WorkResult::TIMEOUT) + .value("COMM_ERROR", ::c10d::WorkResult::COMM_ERROR) + .value("UNKNOWN", ::c10d::WorkResult::UNKNOWN); + py::class_<::c10d::WorkInfo, std::shared_ptr<::c10d::WorkInfo>>( module, "WorkInfo") .def_readonly("op_type", &::c10d::WorkInfo::opType) @@ -3006,6 +3010,27 @@ such as `dist.all_reduce(tensor, async_op=True)`. However, if timeout is set, it will block the CPU thread until the NCCL work is completed or timed out. If timeout, exception will be thrown. )") + .def( + "get_future_result", + [](::c10d::Work& work) -> std::shared_ptr { + return std::make_shared( + work.getFutureResult()); + }, + R"( + Returns: + A ``torch.futures.Future`` object of int type which maps to the enum type of WorkResult + As an example, a future object can be retrieved + by ``fut = process_group.allreduce(tensor).get_future_result()``. + + Example:: + users can use ``fut.wait()`` to blocking wait for the completion of the work and + get the WorkResult by ``fut.value()``. + Also, users can use ``fut.then(call_back_func)`` to register a callback function to be called + when the work is completed, without blocking the current thread. + + .. warning :: + ``get_future_result`` API supports NCCL + )") .def( "get_future", [](::c10d::Work& work) -> std::shared_ptr { diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cpp b/torch/csrc/distributed/c10d/intra_node_comm.cpp index 05bb50313e846..6a61f16e9aea5 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.cpp @@ -1,29 +1,14 @@ #include -#include -#include -#include +#include #include -#include -#include - -#include -#include -#include -#include -#include -#include - -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) -#include -#include -#endif - -#include +// #include namespace c10d::intra_node_comm { +bool isIntraNodeCommSupported(); + static std::vector ENABLE_INTRA_NODE_COMM = { "ENABLE_INTRA_NODE_COMM"}; // Forces detectedTopology() to return Topology::FULLY_CONNECTED, so @@ -33,145 +18,23 @@ static std::vector TEST_INTRA_NODE_COMM = {"TEST_INTRA_NODE_COMM"}; static int intraNodeCommIdx = 0; -//////////////////////////////////////////////////////////////////////////////// -// CUDA Functions -//////////////////////////////////////////////////////////////////////////////// - -bool isIntraNodeCommSupported(); - -std::optional getHybridCubeMesh(NvlMesh nvlMesh); - -void* initP2pState(); - -void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank); - -//////////////////////////////////////////////////////////////////////////////// -// Topology Detection -//////////////////////////////////////////////////////////////////////////////// - -static std::ostream& operator<<(std::ostream& os, const NvlMesh& nvlMesh) { - std::ostringstream oss; - for (size_t i = 0; i < kMaxDevices; ++i) { - for (size_t j = 0; j < kMaxDevices; ++j) { - oss << nvlMesh[i][j] << " "; - } - oss << '\n'; - } - os << oss.str(); - return os; -} - -static bool isSame(NvlMesh lhs, NvlMesh rhs) { - for (size_t i = 0; i < kMaxDevices; ++i) { - for (size_t j = 0; j < kMaxDevices; ++j) { - if (lhs[i][j] != rhs[i][j]) { - return false; - } - } - } - return true; -} - /** * Query the nvlink connection among devices. */ -static NvlMesh getNvlMesh(const std::vector& rankToBusId) { -#if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED) - using namespace c10::cuda; - +static NvlMesh getNvlMesh(const std::vector& rankToDeviceIdx) { + auto connectivity = detect_dma_connectivity(c10::DeviceType::CUDA, "nvlink"); NvlMesh nvlMesh = {}; - auto driverApi = DriverAPI::get(); - if (driverApi == nullptr) { - return nvlMesh; - } - - const auto worldSize = rankToBusId.size(); - std::vector devices(worldSize, nullptr); - std::unordered_map busIdToRank; - std::vector switchLinkCount(worldSize, 0); - - for (size_t r = 0; r < worldSize; ++r) { - busIdToRank.emplace(rankToBusId[r], r); - TORCH_CHECK( - driverApi->nvmlDeviceGetHandleByPciBusId_v2_( - rankToBusId[r].c_str(), &devices[r]) == NVML_SUCCESS); - } - - // TODO: find a better way to determine this - constexpr size_t kMaxNvLinks = 20; - - // For each device, loop over devices connected to it via NVLink - for (size_t idx = 0; idx < worldSize; ++idx) { - for (size_t link = 0; link < kMaxNvLinks; ++link) { - nvmlReturn_t ret; - nvmlIntNvLinkDeviceType_t deviceType; - ret = driverApi->nvmlDeviceGetNvLinkRemoteDeviceType_( - devices[idx], link, &deviceType); - if (ret != NVML_SUCCESS) { - // We've exhausted the NVLinks connected to this device. - // This error is benign. There doesn't seem to be a reliable - // way to obtain the maximum link value that can be passed to - // the API, so we simply increment the link value until the - // API fails or we hit a predefined maximum value. - break; - } - // Remote device is GPU - if (deviceType == NVML_NVLINK_DEVICE_TYPE_GPU) { - nvmlPciInfo_t pciInfo; - ret = driverApi->nvmlDeviceGetNvLinkRemotePciInfo_v2_( - devices[idx], link, &pciInfo); - if (ret != NVML_SUCCESS) { - // Unexpected error. Return an empty NvlMesh - return {}; - } - auto it = busIdToRank.find(pciInfo.busId); - if (it != busIdToRank.end()) { - if (idx != it->second) { - nvlMesh[idx][it->second] += 1; - } - } - // Remote device is NVSwitch - } else if (deviceType == NVML_NVLINK_DEVICE_TYPE_SWITCH) { - switchLinkCount[idx] += 1; - } - } - } - // Process NVSwitch connections. For simplicity, we assume - // all NVSwitches are interconnected. - for (size_t i = 0; i < worldSize; ++i) { - for (size_t j = 0; j < worldSize; ++j) { - if (i == j) { - continue; + for (size_t srcRank = 0; srcRank < kMaxDevices; ++srcRank) { + for (size_t dstRank = 0; dstRank < kMaxDevices; ++dstRank) { + if (srcRank < rankToDeviceIdx.size() && + dstRank < rankToDeviceIdx.size()) { + nvlMesh[srcRank][dstRank] = + connectivity + ->matrix[rankToDeviceIdx[srcRank]][rankToDeviceIdx[dstRank]]; } - nvlMesh[i][j] += std::min(switchLinkCount[i], switchLinkCount[j]); } } return nvlMesh; -#else - return {}; -#endif -} - -/** - * Determine if the devices form a hybrid cube mesh - * topology given a NvlMesh. - */ -static bool isHybridCubeMesh(const NvlMesh nvlMesh) { - std::array numNeighbors = {}; - for (size_t i = 0; i < kMaxDevices; ++i) { - for (size_t j = 0; j < kMaxDevices; ++j) { - if (nvlMesh[i][j] > 0) { - numNeighbors[i] += 1; - } - } - } - for (size_t i = 0; i < kMaxDevices; ++i) { - // TODO: this is insufficent and needs revisit - if (numNeighbors[i] != 4) { - return false; - } - } - return true; } /** @@ -193,18 +56,10 @@ static Topology detectTopology(const NvlMesh nvlMesh, size_t worldSize) { LOG(INFO) << "IntraNodeComm: Topology::FULLY_CONNECTED"; return Topology::FULLY_CONNECTED; } - if (worldSize == kMaxDevices && getHybridCubeMesh(nvlMesh) != std::nullopt) { - LOG(INFO) << "IntraNodeComm: Topology::HYBRID_CUBE_MESH"; - return Topology::HYBRID_CUBE_MESH; - } LOG(INFO) << "IntraNodeComm: Topology::UNKNOWN"; return Topology::UNKNOWN; }; -//////////////////////////////////////////////////////////////////////////////// -// Rendezvous and Initialization -//////////////////////////////////////////////////////////////////////////////// - IntraNodeComm::IntraNodeComm( c10::intrusive_ptr store, size_t rank, @@ -213,8 +68,7 @@ IntraNodeComm::IntraNodeComm( : store_(std::move(store)), rank_(rank), worldSize_(worldSize), - bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize), - barrierReady_(at::cuda::CUDAEvent()) {} + bufferSize_(bufferSize.has_value() ? *bufferSize : kDefaultBufferSize) {} IntraNodeComm::~IntraNodeComm() { if (!isInitialized_) { @@ -281,30 +135,21 @@ bool IntraNodeComm::rendezvous() { } deviceIdx_ = at::cuda::current_device(); - c10::cuda::CUDAGuard guard(deviceIdx_); - // First hand shake: exchange hostname and device bus ID + // Exchange hostname and device bus ID struct DevInfo { char hostname[HOST_NAME_MAX + 1]; - char busId[80]; + int deviceIdx; }; DevInfo devInfo{}; gethostname(devInfo.hostname, sizeof(devInfo.hostname)); - cudaDeviceProp prop{}; - AT_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceIdx_)); - snprintf( - devInfo.busId, - sizeof(devInfo.busId), - NVML_DEVICE_PCI_BUS_ID_FMT, - prop.pciDomainID, - prop.pciBusID, - prop.pciDeviceID); + devInfo.deviceIdx = deviceIdx_; auto peerDevInfos = storeAllGather(store_, "handshake-0", rank_, worldSize_, devInfo); - std::vector rankToBusId; + std::vector rankToDeviceIdx; for (const auto& info : peerDevInfos) { if (strcmp(info.hostname, peerDevInfos.front().hostname) != 0) { LOG(WARNING) << "Aborting IntraNodeComm::rendezvous because some " @@ -312,39 +157,24 @@ bool IntraNodeComm::rendezvous() { << info.hostname << ", " << devInfo.hostname << ")"; return false; } - rankToBusId.emplace_back(info.busId); - } - - // Verify unique devices - { - std::unordered_set uniqueBusIds(rankToBusId.begin(), rankToBusId.end()); - TORCH_CHECK( - uniqueBusIds.size() == worldSize_, - "IntraNodeComm::rendezvous: detected overlapping devices across ranks. " - "Please properly set device via torch.cuda.set_device() before " - "initiating rendezvous."); + rankToDeviceIdx.emplace_back(info.deviceIdx); } // Query nvlink connection - auto nvlMesh = getNvlMesh(rankToBusId); + auto nvlMesh = getNvlMesh(rankToDeviceIdx); // Detect topology - Topology topology = detectTopology(nvlMesh, worldSize_); + topology_ = detectTopology(nvlMesh, worldSize_); + if (topology_ != Topology::FULLY_CONNECTED) { + return false; + } auto groupName = "IntraNodeComm" + std::to_string(intraNodeCommIdx++); set_group_info(groupName, rank_, worldSize_, store_); auto allocator = get_allocator(c10::DeviceType::CUDA); symmetricMemoryPtr_ = allocator->alloc(bufferSize_, deviceIdx_, groupName); symmetricMemory_ = allocator->rendezvous(symmetricMemoryPtr_); - TORCH_CHECK(symmetricMemory_->get_signal_pad_size() >= kP2pStateSize); - - void* topoInfo = initTopoInfo(topology, nvlMesh, rank_); - isInitialized_ = true; - topology_ = topology; - p2pStatesDev_ = symmetricMemory_->get_signal_pad_ptrs_dev(); - buffersDev_ = symmetricMemory_->get_buffer_ptrs_dev(); - topoInfo_ = topoInfo; return true; #endif return false; diff --git a/torch/csrc/distributed/c10d/intra_node_comm.cu b/torch/csrc/distributed/c10d/intra_node_comm.cu index fa40ea7cdb5f3..a32c64281512c 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.cu +++ b/torch/csrc/distributed/c10d/intra_node_comm.cu @@ -1,393 +1,13 @@ #include -#include -#include -#include - #include namespace c10d { namespace intra_node_comm { -static constexpr size_t kBytesPerThread = 16; -static constexpr size_t kMaxAllReduceBlocks = 24; -static constexpr size_t kThreadsPerBlock = 1024; -static constexpr size_t kWarpSize = 32; - -static constexpr size_t kHcmThreshBytes = 256 * 1024; static constexpr size_t kOneShotThreshBytes = 256 * 1024; static constexpr size_t kTwoShotThreshBytes = 10 * 1024 * 1024; -#if defined(USE_ROCM) -using __nv_bfloat162 = uint32_t; -#endif - -struct __align__(16) bf16x8 { - __nv_bfloat162 vals[4]; -}; - -#define DEVICE_INLINE __device__ inline __attribute__((always_inline)) - -DEVICE_INLINE __nv_bfloat162 -bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { -#if defined(USE_ROCM) - CUDA_KERNEL_ASSERT(false); - return 0; -#elif (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); - __nv_bfloat162 res; - return res; -#else - return __hadd2(x, y); -#endif -} - -DEVICE_INLINE bf16x8 add_bf16x8(bf16x8 a, bf16x8 b) { - bf16x8 c; - c.vals[0] = bf16hadd2(a.vals[0], b.vals[0]); - c.vals[1] = bf16hadd2(a.vals[1], b.vals[1]); - c.vals[2] = bf16hadd2(a.vals[2], b.vals[2]); - c.vals[3] = bf16hadd2(a.vals[3], b.vals[3]); - return c; -} - -template -DEVICE_INLINE void streamLoad128(bf16x8& val, const T* addr) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - unsigned long long int low, high; - asm("ld.global.v2.u64 {%0, %1}, [%2];" : "=l"(low), "=l"(high) : "l"(addr)); - reinterpret_cast(&val)[0] = low; - reinterpret_cast(&val)[1] = high; -#endif -} - -__device__ inline void streamStore128(at::BFloat16* addr, const bf16x8& val) { -#if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) - CUDA_KERNEL_ASSERT(false); -#else - unsigned long long int low, high; - low = reinterpret_cast(&val)[0]; - high = reinterpret_cast(&val)[1]; - asm("st.global.v2.u64 [%0], {%1, %2};" : : "l"(addr), "l"(low), "l"(high)); -#endif -} - -template -DEVICE_INLINE void load128(bf16x8& val, const T* addr) { - *reinterpret_cast(&val) = reinterpret_cast(addr)[0]; -} - -template -DEVICE_INLINE void store128(T* addr, const bf16x8& val) { - *reinterpret_cast(addr) = reinterpret_cast(&val)[0]; -} - -//////////////////////////////////////////////////////////////////////////////// -// Fully Connected Algos -//////////////////////////////////////////////////////////////////////////////// - -struct P2pState { - uint32_t signals0[kMaxAllReduceBlocks][kMaxDevices]; - uint32_t signals1[kMaxAllReduceBlocks][kMaxDevices]; -}; - -static_assert(sizeof(P2pState) <= kP2pStateSize); - -template -static __global__ void oneShotAllReduceKernel( - at::BFloat16* input, - size_t N, - size_t N_aligned, - P2pState** p2pStates, - at::BFloat16** buffers, - size_t rank) { - const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16); - const size_t offset = - (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread; - const size_t stride = blockDim.x * gridDim.x * numelPerThread; - - barrier_and_acquire_previous_kernel_writes( - reinterpret_cast(p2pStates), rank, kWorldSize); - - // The source pointers. Distributed round-robin for the different warps - const at::BFloat16* srcs[kWorldSize]; - size_t srcRanks[kWorldSize]; -#pragma unroll kWorldSize - for (int ii = 0; ii < kWorldSize; ++ii) { - int srcRank = (rank + ii) % kWorldSize; - srcs[ii] = buffers[srcRank]; - srcRanks[ii] = srcRank; - } - - for (size_t i = offset; i < N_aligned; i += stride) { - bf16x8 vals[kWorldSize]; -#pragma unroll kWorldSize - for (size_t ii = 0; ii < kWorldSize; ++ii) { - // Make sure the values in `vals` are ordered by rank so that the - // reduction results are consistent across ranks. - streamLoad128(vals[srcRanks[ii]], &srcs[ii][i]); - } - - bf16x8 sums; - memset(reinterpret_cast(&sums), 0, sizeof(sums)); - -#pragma unroll kWorldSize - for (size_t ii = 0; ii < kWorldSize; ++ii) { - sums = add_bf16x8(sums, vals[ii]); - } - if constexpr (kAligned) { - streamStore128(&input[i], sums); - } else { - for (size_t ii = 0; ii < numelPerThread; ++ii) { - if (i + ii < N) { - input[i + ii] = reinterpret_cast(&sums)[ii]; - } - } - } - } - - barrier(reinterpret_cast(p2pStates), rank, kWorldSize); -} - -template -static __launch_bounds__(1024) __global__ void twoShotAllReduceKernel( - at::BFloat16* input, - size_t N_aligned, - P2pState** p2pStates, - at::BFloat16** buffers, - size_t rank) { - const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16); - const size_t offset = - (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread; - const size_t stride = blockDim.x * gridDim.x * numelPerThread; - const size_t N_per_rank = N_aligned / kWorldSize; - const size_t N_start = N_per_rank * rank; - - // Wait for all other ranks to enter the kernel - barrier_and_acquire_previous_kernel_writes( - reinterpret_cast(p2pStates), rank, kWorldSize); - - // The source pointers. Distributed round-robin for the different warps - at::BFloat16* srcs[kWorldSize]; - size_t srcRanks[kWorldSize]; -#pragma unroll kWorldSize - for (int ii = 0; ii < kWorldSize; ++ii) { - int srcRank = (rank + ii) % kWorldSize; - srcs[ii] = buffers[srcRank]; - srcRanks[ii] = srcRank; - } - - for (size_t i = offset; i < N_per_rank; i += stride) { - bf16x8 vals[kWorldSize]; -#pragma unroll kWorldSize - for (size_t ii = 0; ii < kWorldSize; ++ii) { - // Make sure the values in `vals` are ordered by rank so that the - // reduction results are consistent across ranks. - int srcRank = (ii + kWorldSize - rank) % kWorldSize; - streamLoad128(vals[srcRank], &srcs[ii][N_start + i]); - } - - bf16x8 sums; - memset(reinterpret_cast(&sums), 0, sizeof(sums)); - -#pragma unroll kWorldSize - for (size_t ii = 0; ii < kWorldSize; ++ii) { - sums = add_bf16x8(sums, vals[ii]); - } - streamStore128(&srcs[0][N_start + i], sums); - // Store local sums into input now so we can avoid - // a global memory access later for it. - streamStore128(&input[N_start + i], sums); - } - __syncthreads(); - - barrier_and_acquire_previous_kernel_writes( - reinterpret_cast(p2pStates), rank, kWorldSize); - - for (size_t i = offset; i < N_per_rank; i += stride) { -#pragma unroll kWorldSize - 1 - for (size_t ii = 1; ii < kWorldSize; ++ii) { - size_t k = N_start + i + (srcRanks[ii] - rank) * N_per_rank; - bf16x8 val; - streamLoad128(val, &srcs[ii][k]); - streamStore128(&input[k], val); - } - } - - barrier(reinterpret_cast(p2pStates), rank, kWorldSize); -} - -//////////////////////////////////////////////////////////////////////////////// -// Hybrid Cube Mesh Algos -//////////////////////////////////////////////////////////////////////////////// - -/** - * NOTE [hybrid cube mesh] - * - * In a hybrid cube mesh topology, every device has exactly 4 neighbors - * (directly connected via NVLink). For every device X, it has exactly 1 - * neighbor Y that is a neighbor of the 3 non-neighbor of X. We call Y the - * relay neighbor of X. This property is symmetrical: X is also guaranteed to - * be the relay neighbor of Y. - * - * With this property, we can perform a variant of one-shot allreduce algo that - * only moves data across NVLinks: - * - * - Each device one-shot allreduce among itself and 3 non-relay neighbors. - * - Each device exchange data with its relay neighbor. - * - * HybridCubeMesh is a data structure for describing the topology: - * - * - hcm[X][0:3] are the 3 neighbors of X. - * - hcm[X][3] is the relay neighbor of X. - * - For load balancing purpose, we also ensure that if hcm[X][k] = Y, - * hcm[Y][k] = X. - */ -std::optional getHybridCubeMesh(NvlMesh nvlMesh) { - std::array, kMaxDevices> neighbors = {}; - std::array neighborMasks = {}; - for (size_t i = 0; i < kMaxDevices; ++i) { - for (size_t j = 0; j < kMaxDevices; ++j) { - if (nvlMesh[i][j] > 0) { - neighbors[i].insert(j); - neighborMasks[i] |= (1ul << j); - } - } - } - HybridCubeMesh hcm = {}; - for (auto& row : hcm) { - row.fill(-1); - } - // A topology is an HCM if: - // - Every device has exactly 4 neighbors. - // - For every device, it has exactly 1 relay neighbor that is - // a neighbor of the 3 non-neighbor of the device. - for (size_t i = 0; i < kMaxDevices; ++i) { - if (neighbors[i].size() != 4) { - return std::nullopt; - } - // Condition 1: check the number of neighbors - std::vector relayNeighbors; - for (size_t j = 0; j < kMaxDevices; ++j) { - if ((neighborMasks[i] & neighborMasks[j]) == 0) { - relayNeighbors.push_back(j); - } - } - // Condition 2: check the number of relay neighbors - if (relayNeighbors.size() != 1) { - return std::nullopt; - } - neighbors[i].erase(relayNeighbors[0]); - hcm[i][3] = relayNeighbors[0]; - } - - for (size_t i = 0; i < kMaxDevices; ++i) { - for (size_t k = 0; k < 3; ++k) { - // We can only fill hcm[i][k] with j if hcm[j][k] is not filled - for (size_t j : neighbors[i]) { - if (hcm[j][k] == -1) { - hcm[i][k] = j; - hcm[j][k] = i; - break; - } - } - TORCH_CHECK(hcm[i][k] != -1); - neighbors[i].erase(hcm[i][k]); - } - } - return hcm; -} - -template -static __global__ void hybridCubeMeshAllReduceKernel( - at::BFloat16* input, - size_t N, - size_t N_aligned, - P2pState** p2pStates, - at::BFloat16** buffers, - int hcmInfo[4], - size_t bufferSize, - size_t rank) { - const size_t numelPerThread = kBytesPerThread / sizeof(at::BFloat16); - const size_t offset = - (blockDim.x * blockIdx.x + threadIdx.x) * numelPerThread; - const size_t stride = blockDim.x * gridDim.x * numelPerThread; - const int relayRank = hcmInfo[3]; - - // Wait for HCM neigbors to enter the kernel - if (threadIdx.x < 3) { - auto targetRank = hcmInfo[threadIdx.x]; - release_signal(&p2pStates[targetRank]->signals0[blockIdx.x][rank]); - acquire_signal(&p2pStates[rank]->signals0[blockIdx.x][targetRank]); - } - __syncthreads(); - - const at::BFloat16* srcs[4] = { - buffers[rank], - buffers[hcmInfo[0]], - buffers[hcmInfo[1]], - buffers[hcmInfo[2]], - }; - // Use the half second half of the buffer as relay - at::BFloat16* localRelay = - buffers[rank] + (bufferSize / sizeof(at::BFloat16) / 2); - at::BFloat16* remoteRelay = - buffers[relayRank] + (bufferSize / sizeof(at::BFloat16) / 2); - - for (size_t i = offset; i < N_aligned; i += stride) { - bf16x8 vals[4]; - -#pragma unroll 4 - for (size_t ii = 0; ii < 4; ++ii) { - streamLoad128(vals[ii], &srcs[ii][i]); - } - - bf16x8 sums; - memset(reinterpret_cast(&sums), 0, sizeof(sums)); - -#pragma unroll 4 - for (size_t ii = 0; ii < 4; ++ii) { - sums = add_bf16x8(sums, vals[ii]); - } - // Cached store for local sums - store128(&localRelay[i], sums); - } - __syncthreads(); - - if (threadIdx.x == 0) { - release_signal(&p2pStates[relayRank]->signals0[blockIdx.x][rank]); - acquire_signal(&p2pStates[rank]->signals0[blockIdx.x][relayRank]); - } - __syncthreads(); - - for (size_t i = offset; i < N_aligned; i += stride) { - bf16x8 localSum, remoteSum; - // Cached load for local sums - load128(localSum, &localRelay[i]); - streamLoad128(remoteSum, &remoteRelay[i]); - localSum = add_bf16x8(localSum, remoteSum); - if constexpr (kAligned) { - streamStore128(&input[i], localSum); - } else { - for (size_t ii = 0; ii < numelPerThread; ++ii) { - if (i + ii < N) { - input[i + ii] = reinterpret_cast(&localSum)[ii]; - } - } - } - } -} - -static inline size_t divUp(uint32_t a, uint32_t b) { - return (a + b - 1) / b; -} - -static inline size_t alignUp(uint32_t a, uint32_t b) { - return divUp(a, b) * b; -} - static void checkInput(const at::Tensor& input, int deviceIdx) { TORCH_CHECK( input.dtype() == at::kBFloat16, @@ -402,31 +22,6 @@ static void checkInput(const at::Tensor& input, int deviceIdx) { input.get_device()); } -static void getLaunchConfig( - size_t N_aligned, - size_t elemSize, - dim3& blocks, - dim3& threads) { - blocks = dim3(0, 1, 1); - threads = dim3(0, 1, 1); - - const auto numelPerThread = kBytesPerThread / elemSize; - const auto numelPerWarp = numelPerThread * kWarpSize; - TORCH_CHECK(N_aligned % numelPerThread == 0); - TORCH_CHECK(N_aligned % numelPerWarp == 0); - if (N_aligned < numelPerThread * kThreadsPerBlock) { - threads.x = N_aligned / numelPerWarp * kWarpSize; - blocks.x = 1; - } else { - auto warpsRequired = N_aligned / numelPerWarp; - auto threadsRequired = N_aligned / numelPerThread; - blocks.x = - std::min(divUp(threadsRequired, kThreadsPerBlock), kMaxAllReduceBlocks); - auto warpsPerBlock = divUp(warpsRequired, blocks.x); - threads.x = std::min(kThreadsPerBlock, warpsPerBlock * kWarpSize); - } -} - bool isIntraNodeCommSupported() { #if defined(USE_ROCM) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) return false; @@ -435,80 +30,23 @@ bool isIntraNodeCommSupported() { #endif } -void* initP2pState() { - void* state = nullptr; - AT_CUDA_CHECK(cudaMalloc(&state, sizeof(P2pState))); - AT_CUDA_CHECK(cudaMemset(state, 0, sizeof(P2pState))); - return state; -} - -void* initTopoInfo(Topology topology, NvlMesh nvlMesh, size_t rank) { - void* topoInfo = nullptr; - if (topology != Topology::HYBRID_CUBE_MESH) { - return topoInfo; - } - auto hcm = getHybridCubeMesh(nvlMesh); - int hcmInfo[4]; - std::copy((*hcm)[rank].begin(), (*hcm)[rank].begin() + 4, hcmInfo); - AT_CUDA_CHECK(cudaMalloc(&topoInfo, sizeof(hcmInfo))); - AT_CUDA_CHECK( - cudaMemcpy(topoInfo, hcmInfo, sizeof(hcmInfo), cudaMemcpyHostToDevice)); - return topoInfo; -} - at::Tensor IntraNodeComm::oneShotAllReduce( const at::Tensor& input, at::cuda::CUDAStream& stream) { checkInput(input, deviceIdx_); - const size_t numelPerWarp = - kBytesPerThread / input.element_size() * kWarpSize; - const size_t N_aligned = alignUp(input.numel(), numelPerWarp); - const bool isAligned = (N_aligned == static_cast(input.numel())); - TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size()); - - dim3 blocks, threads; - getLaunchConfig(N_aligned, input.element_size(), blocks, threads); - - at::cuda::OptionalCUDAGuard guard(input.get_device()); + auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("symm_mem::one_shot_all_reduce_out", "") + .typed(); - AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs()[rank_], - input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, - stream)); + auto symmMemTensor = at::from_blob( + symmetricMemoryPtr_, + input.sizes(), + at::TensorOptions().dtype(input.dtype()).device(input.device())); -#define X(kWorldSize, kAligned) \ - if (worldSize_ == kWorldSize) { \ - oneShotAllReduceKernel \ - <<>>( \ - input.data_ptr(), \ - input.numel(), \ - N_aligned, \ - reinterpret_cast(p2pStatesDev_), \ - reinterpret_cast(buffersDev_), \ - rank_); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } - -#define DISPATCH_ALL_WORLD_SIZES(kAligned) \ - X(2, kAligned); \ - X(3, kAligned); \ - X(4, kAligned); \ - X(5, kAligned); \ - X(6, kAligned); \ - X(7, kAligned); \ - X(8, kAligned); - - if (isAligned) { - DISPATCH_ALL_WORLD_SIZES(true); - } else { - DISPATCH_ALL_WORLD_SIZES(false); - } - -#undef DISPATCH_ALL_WORLD_SIZES -#undef X + symmMemTensor.copy_(input); + op.call(symmMemTensor, "sum", "", input); return input; } @@ -517,126 +55,42 @@ at::Tensor IntraNodeComm::twoShotAllReduce( at::cuda::CUDAStream& stream) { checkInput(input, deviceIdx_); - size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize; - size_t N_aligned = alignUp(input.numel(), worldSize_ * numelPerWarp); - size_t N_per_rank = N_aligned / worldSize_; - TORCH_CHECK(N_aligned <= bufferSize_ / input.element_size()); - - dim3 blocks, threads; - getLaunchConfig(N_per_rank, input.element_size(), blocks, threads); - - auto output = N_aligned == static_cast(input.numel()) - ? input - : input.new_empty(N_aligned); - - at::cuda::OptionalCUDAGuard guard(input.get_device()); - AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs()[rank_], - input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, - stream)); - -#define X(kWorldSize) \ - if (worldSize_ == kWorldSize) { \ - twoShotAllReduceKernel<<>>( \ - output.data_ptr(), \ - N_aligned, \ - reinterpret_cast(p2pStatesDev_), \ - reinterpret_cast(buffersDev_), \ - rank_); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); \ - } - X(2); - X(3); - X(4); - X(5); - X(6); - X(7); - X(8); -#undef X - - if (output.data_ptr() != input.data_ptr()) { - AT_CUDA_CHECK(cudaMemcpyAsync( - input.data_ptr(), - output.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, - stream)); - } - return input; -} - -at::Tensor IntraNodeComm::hybridCubeMeshAllReduce( - const at::Tensor& input, - at::cuda::CUDAStream& stream) { - checkInput(input, deviceIdx_); - - size_t numelPerWarp = kBytesPerThread / input.element_size() * kWarpSize; - size_t N_aligned = alignUp(input.numel(), numelPerWarp); - TORCH_CHECK(N_aligned * 2 <= bufferSize_ / input.element_size()); - - dim3 blocks, threads; - getLaunchConfig(N_aligned, input.element_size(), blocks, threads); + auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("symm_mem::two_shot_all_reduce_", "") + .typed(); - at::cuda::OptionalCUDAGuard guard(input.get_device()); - AT_CUDA_CHECK(cudaMemcpyAsync( - symmetricMemory_->get_buffer_ptrs()[rank_], - input.data_ptr(), - input.numel() * input.element_size(), - cudaMemcpyDeviceToDevice, - stream)); + auto symmMemTensor = at::from_blob( + symmetricMemoryPtr_, + input.sizes(), + at::TensorOptions().dtype(input.dtype()).device(input.device())); -#define X(kAligned) \ - hybridCubeMeshAllReduceKernel<<>>( \ - input.data_ptr(), \ - input.numel(), \ - N_aligned, \ - reinterpret_cast(p2pStatesDev_), \ - reinterpret_cast(buffersDev_), \ - static_cast(topoInfo_), \ - bufferSize_, \ - rank_); \ - C10_CUDA_KERNEL_LAUNCH_CHECK(); - - if (N_aligned == static_cast(input.numel())) { - X(true); - } else { - X(false); - } -#undef X + symmMemTensor.copy_(input); + op.call(symmMemTensor, "sum", ""); + input.copy_(symmMemTensor); return input; } AllReduceAlgo IntraNodeComm::selectAllReduceAlgo(const at::Tensor& input) { - // Only support bf16 for now - if (input.dtype() != at::kBFloat16) { + // Only support float and bf16 for now + if (input.dtype() != at::kBFloat16 && input.dtype() != at::kFloat) { return AllReduceAlgo::NONE; } - const auto inputSize = input.numel() * input.element_size(); - const auto bytesPerWarp = kBytesPerThread * kWarpSize; + const auto inputSize = + static_cast(input.numel() * input.element_size()); + const size_t ptrAlignment = get_alignment( + static_cast(input.storage_offset() * input.element_size())); + const size_t sizeAlignment = get_alignment(inputSize); + const size_t alignment = std::min(ptrAlignment, sizeAlignment); - if (topology_ == Topology::HYBRID_CUBE_MESH) { - TORCH_CHECK( - worldSize_ == 8, "hyperCubeAllReduce only supports exactly 8 GPUs"); - const auto hcmInputSize = alignUp(inputSize, bytesPerWarp); - const auto hcmBufferSizeReq = hcmInputSize * 2; - if (hcmInputSize <= kHcmThreshBytes && hcmBufferSizeReq <= bufferSize_) { - return AllReduceAlgo::HCM; - } - } if (topology_ == Topology::FULLY_CONNECTED) { - const auto oneShotInputSize = alignUp(inputSize, bytesPerWarp); - const auto oneShotBufferSizeReq = oneShotInputSize; - if (oneShotInputSize <= kOneShotThreshBytes && - oneShotBufferSizeReq <= bufferSize_) { + // Both symm_mem::one_shot_all_reduce and symm_mem::two_shot_all_reduce_ + // currently requires the input to be at least 4-bytes aligned. + if (alignment >= 4 && inputSize <= kOneShotThreshBytes && + inputSize <= bufferSize_) { return AllReduceAlgo::ONE_SHOT; } - - const auto twoShotInputSize = alignUp(inputSize, bytesPerWarp * worldSize_); - const auto twoShotBufferSizeReq = twoShotInputSize; - if (twoShotInputSize <= kTwoShotThreshBytes && - twoShotBufferSizeReq <= bufferSize_) { + if (alignment >= 4 && inputSize <= kTwoShotThreshBytes && + inputSize <= bufferSize_) { return AllReduceAlgo::TWO_SHOT; } } @@ -652,15 +106,11 @@ at::Tensor IntraNodeComm::allReduce( // We don't care about overflowing. ++usageCounter; auto stream = at::cuda::getCurrentCUDAStream(); - c10::cuda::CUDACachingAllocator::recordStream( - input.storage().data_ptr(), stream); switch (algo) { case AllReduceAlgo::ONE_SHOT: return oneShotAllReduce(input, stream); case AllReduceAlgo::TWO_SHOT: return twoShotAllReduce(input, stream); - case AllReduceAlgo::HCM: - return hybridCubeMeshAllReduce(input, stream); default: C10_THROW_ERROR(ValueError, "IntraNodeComm: invalid algo"); } @@ -670,42 +120,5 @@ int64_t getIntraNodeCommUsageCounter() { return usageCounter; } -static __global__ void barrierKernel( - P2pState** p2pStates, - uint64_t mask, - size_t rank, - size_t worldSize) { - if (threadIdx.x < worldSize && (mask & (1ULL << threadIdx.x))) { - auto targetRank = threadIdx.x; - release_signal(&p2pStates[targetRank]->signals0[0][rank]); - acquire_signal(&p2pStates[rank]->signals0[0][targetRank]); - } -} - -void IntraNodeComm::barrier(std::optional> ranks) { - barrierReady_.block(at::cuda::getCurrentCUDAStream()); - if (!ranks.has_value()) { - ranks = std::vector(worldSize_); - std::iota(ranks->begin(), ranks->end(), 0); - } - uint64_t mask = 0; - for (const auto& r : ranks.value()) { - TORCH_CHECK(r >= 0 && r < static_cast(worldSize_)); - mask |= (1ULL << r); - } - barrierKernel<<<1, kWarpSize, 0, at::cuda::getCurrentCUDAStream()>>>( - reinterpret_cast(p2pStatesDev_), mask, rank_, worldSize_); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - barrierReady_.record(); -} - -at::Tensor IntraNodeComm::getBuffer( - size_t rank, - const std::vector& sizes, - c10::ScalarType dtype, - int64_t storageOffset) { - return symmetricMemory_->get_buffer(rank, sizes, dtype, storageOffset); -} - } // namespace intra_node_comm } // namespace c10d diff --git a/torch/csrc/distributed/c10d/intra_node_comm.hpp b/torch/csrc/distributed/c10d/intra_node_comm.hpp index 37fe285cb929e..4c31149de44c1 100644 --- a/torch/csrc/distributed/c10d/intra_node_comm.hpp +++ b/torch/csrc/distributed/c10d/intra_node_comm.hpp @@ -1,7 +1,6 @@ #pragma once #include -#include #include #include #include @@ -13,22 +12,18 @@ using namespace c10d::symmetric_memory; constexpr size_t kMaxDevices = 8; constexpr size_t kDefaultBufferSize = 10ull * 1024 * 1024; -constexpr size_t kP2pStateSize = 2048; using NvlMesh = std::array, kMaxDevices>; -using HybridCubeMesh = std::array, kMaxDevices>; enum class Topology : uint8_t { UNKNOWN = 0, FULLY_CONNECTED = 1, - HYBRID_CUBE_MESH = 2 }; enum class AllReduceAlgo : uint8_t { NONE = 0, ONE_SHOT = 1, TWO_SHOT = 2, - HCM = 3 }; // NOTE: this class will be be removed soon in favor of SymmetricMemory @@ -51,14 +46,6 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { */ bool rendezvous(); - Topology getTopology() { - return topology_; - } - - size_t getBufferSize() { - return bufferSize_; - } - /** * Selects a AllReduceAlgo that we think will outperform nccl. * Returns AllReduceAlgo::NONE if we don't think we can outperform nccl. @@ -67,17 +54,6 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { at::Tensor allReduce(const at::Tensor& input, AllReduceAlgo algo); - /** - * Perform a barrier among the specified ranks. - */ - void barrier(std::optional> ranks = std::nullopt); - - at::Tensor getBuffer( - size_t rank, - const std::vector& sizes, - c10::ScalarType dtype, - int64_t storageOffset); - private: at::Tensor oneShotAllReduce( const at::Tensor& input, @@ -87,64 +63,26 @@ class TORCH_API IntraNodeComm : public c10::intrusive_ptr_target { const at::Tensor& input, at::cuda::CUDAStream& stream); - at::Tensor hybridCubeMeshAllReduce( - const at::Tensor& input, - at::cuda::CUDAStream& stream); - c10::intrusive_ptr store_; size_t rank_; size_t worldSize_; size_t bufferSize_; - at::cuda::CUDAEvent barrierReady_; /** * Members initialized after rendezvous */ bool isInitialized_ = false; - int deviceIdx_; + int deviceIdx_{0}; Topology topology_ = Topology::UNKNOWN; void* symmetricMemoryPtr_ = nullptr; c10::intrusive_ptr symmetricMemory_ = nullptr; - void* p2pStatesDev_{}; - void* buffersDev_{}; - void* topoInfo_{}; }; -/** - * NOTE [IntraNodeComm Stream Semantics] - * - * ProcessGroupNCCL launches kernels differently from the conventional PyTorch - * CUDA semantics: it always launches collective kernels onto a dedicated - * communication stream. Therefore, it needs to: - * - * - Synchronize the calling stream and the comm stream. - * - Ensure the memory safety of the operands (via record_stream or stashing). - * - Synchronize the waiting stream with the comm stream. - * - * Unconditionally performing these tasks makes sense when we expect most of the - * communication to benefit from compute/comm overlap. However, IntraNodeComm - * primarily aims to optimize small, latency-sensitive, blocking communication, - * in which the overhead incurred by the above steps can be quite pronounced. - * - * Thus, IntraNodeComm follows the conventional PyTorch CUDA semantics and - * launches kernels onto the stream specified by the user. Although the user - * can perform neccessary synchronization via wait_stream, to provide a UX - * consistent to that of ProcessGroupNCCL, the neccessary stream - * synchronization can also be performed via IntraNodeWork::wait(). - */ class IntraNodeCommWork : public c10d::Work { public: - IntraNodeCommWork() : c10d::Work() { - event_.record(); - } - bool wait(std::chrono::milliseconds timeout = kNoTimeout) override { - event_.block(at::cuda::getCurrentCUDAStream()); return true; } - - private: - at::cuda::CUDAEvent event_; }; TORCH_API int64_t getIntraNodeCommUsageCounter(); diff --git a/torch/csrc/distributed/c10d/logger.cpp b/torch/csrc/distributed/c10d/logger.cpp index 48f8786842f01..a43e428e899e0 100644 --- a/torch/csrc/distributed/c10d/logger.cpp +++ b/torch/csrc/distributed/c10d/logger.cpp @@ -61,7 +61,7 @@ Logger::Logger(std::shared_ptr reducer) ddp_logging_data_ = std::make_unique(); } -c10::once_flag log_graph_static_flag; +static c10::once_flag log_graph_static_flag; void Logger::log_if_graph_static(bool is_static) { c10::call_once(log_graph_static_flag, [this, is_static]() { @@ -116,7 +116,7 @@ void Logger::set_env_variables() { void Logger::set_parameter_stats() { // The number of parameter tensors ddp_logging_data_->ints_map["num_parameter_tensors"] = - reducer_->params_.size(); + static_cast(reducer_->params_.size()); // Total parameters size (Bytes) ddp_logging_data_->ints_map["total_parameter_size_bytes"] = 0; // Parameters' data types, there may be multiple data diff --git a/torch/csrc/distributed/c10d/logging.h b/torch/csrc/distributed/c10d/logging.h index a7cc82f702eea..6b15aa358f261 100644 --- a/torch/csrc/distributed/c10d/logging.h +++ b/torch/csrc/distributed/c10d/logging.h @@ -12,8 +12,7 @@ #include #include -namespace c10d { -namespace detail { +namespace c10d::detail { enum class LogLevel { Trace, Debug, Info, Warning, Error }; @@ -24,8 +23,7 @@ std::string formatLogMessage(fmt::string_view fmt, T&&... args) { return fmt::vformat(fmt, fmt::make_format_args(args...)); } -} // namespace detail -} // namespace c10d +} // namespace c10d::detail #define C10D_ERROR(...) \ if (c10d::detail::isLogLevelEnabled(c10d::detail::LogLevel::Error)) \ diff --git a/torch/csrc/distributed/c10d/python_comm_hook.cpp b/torch/csrc/distributed/c10d/python_comm_hook.cpp index c5b24e01fb515..adf73452bd7b4 100644 --- a/torch/csrc/distributed/c10d/python_comm_hook.cpp +++ b/torch/csrc/distributed/c10d/python_comm_hook.cpp @@ -7,6 +7,7 @@ namespace c10d { +// NOLINTNEXTLINE(bugprone-exception-escape) PythonCommHook::~PythonCommHook() { py::gil_scoped_acquire ag; state_.dec_ref(); diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp index 6c5f7a79ff9fb..bf21bab37ce3f 100644 --- a/torch/csrc/distributed/c10d/reducer.cpp +++ b/torch/csrc/distributed/c10d/reducer.cpp @@ -1044,11 +1044,11 @@ void Reducer::mark_bucket_ready(size_t bucket_index) { } void Reducer::install_futures( - c10::List> futs) { + const c10::List>& futs) { // Append instead of overwrite so that this method can be called multiple // times in one iteration. if (!installed_futures_) { - installed_futures_ = std::move(futs); + installed_futures_ = futs; } else { installed_futures_->append(futs); } @@ -1698,7 +1698,7 @@ void Reducer::runGradCallbackForVariable( cb(variable.mutable_grad()); } else { // Under distributed autograd - context_ptr->runGradCallbackForVariable(variable, std::move(cb)); + context_ptr->runGradCallbackForVariable(variable, cb); } #endif } @@ -1759,15 +1759,17 @@ void Reducer::sync_bucket_indices( num_buckets = indices_accessor[indices_accessor_Index]; // Broadcast bucket_sizes - auto bucket_sizes_tensor = at::empty({(int64_t)num_buckets}, at::kInt); + auto bucket_sizes_tensor = + at::empty({static_cast(num_buckets)}, at::kInt); auto bucket_sizes_accessor = bucket_sizes_tensor.accessor(); for (const auto i : c10::irange(num_buckets)) { // For rank != 0, it is possible that local num buckets bucket_sizes.size() // is smaller than broadcasted num_buckets - bucket_sizes_accessor[i] = - bucket_sizes.at(std::min(i, (bucket_sizes.size() - 1))); + bucket_sizes_accessor[static_cast(i)] = static_cast( + bucket_sizes.at(std::min(i, (bucket_sizes.size() - 1)))); } - auto bucket_sizes_tensor_device = at::empty({(int64_t)num_buckets}, options); + auto bucket_sizes_tensor_device = + at::empty({static_cast(num_buckets)}, options); bucket_sizes_tensor_device.copy_(bucket_sizes_tensor, /*non_blocking=*/true); std::vector bucket_sizes_tensor_list = { bucket_sizes_tensor_device}; @@ -2238,7 +2240,7 @@ void verify_params_across_processes( std::vector> param_size_output_tensors; param_size_output_tensors.emplace_back(); auto world_size = process_group->getSize(); - for (C10_UNUSED const auto i : c10::irange(world_size)) { + for ([[maybe_unused]] const auto i : c10::irange(world_size)) { param_size_output_tensors.front().emplace_back( at::empty_like(param_size_tensor)); } diff --git a/torch/csrc/distributed/c10d/reducer.hpp b/torch/csrc/distributed/c10d/reducer.hpp index aa3c40ae95bbf..e0f6b4570fa31 100644 --- a/torch/csrc/distributed/c10d/reducer.hpp +++ b/torch/csrc/distributed/c10d/reducer.hpp @@ -137,7 +137,8 @@ class TORCH_API Reducer { // Install futures that should be awaited at end of backwards. Currently these // are only used by user-defined custom buffer reduction hooks, but can be // generalized to any user-originating futures that need to be awaited. - void install_futures(c10::List> futs); + void install_futures( + const c10::List>& futs); // Returns true if we should rebuild buckets, else false. We only rebuild // buckets once after the first iteration and never rebuild them if diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp index db4519d7b2ad3..cad9630345cf5 100644 --- a/torch/csrc/distributed/c10d/socket.cpp +++ b/torch/csrc/distributed/c10d/socket.cpp @@ -206,25 +206,23 @@ std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len) { // if we can't resolve the hostname, display the IP address if (addr->sa_family == AF_INET) { struct sockaddr_in* psai = (struct sockaddr_in*)&addr; + // NOLINTNEXTLINE(*array*) char ip[INET_ADDRSTRLEN]; if (inet_ntop(addr->sa_family, &(psai->sin_addr), ip, INET_ADDRSTRLEN) != - NULL) { + nullptr) { return fmt::format("{}:{}", ip, psai->sin_port); } } else if (addr->sa_family == AF_INET6) { struct sockaddr_in6* psai = (struct sockaddr_in6*)&addr; + // NOLINTNEXTLINE(*array*) char ip[INET6_ADDRSTRLEN]; if (inet_ntop( addr->sa_family, &(psai->sin6_addr), ip, INET6_ADDRSTRLEN) != - NULL) { + nullptr) { return fmt::format("[{}]:{}", ip, psai->sin6_port); } } - - C10_THROW_ERROR( - DistNetworkError, - fmt::format( - "failed to format addr, unknown family={}", addr->sa_family)); + return "?UNKNOWN?"; } if (addr->sa_family == AF_INET) { return fmt::format("{}:{}", host, port); @@ -279,7 +277,7 @@ struct formatter { addr.ai_addr = addr_ptr; addr.ai_addrlen = addr_len; - auto remote = socket.remote(); + auto const& remote = socket.remote(); std::string remoteStr = remote ? *remote : "none"; return fmt::format_to( @@ -591,6 +589,11 @@ bool SocketListenOp::tryListen(int family) { } } + recordError( + "The server could not be initialized on any address for port={}, family={}", + port_, + family); + return false; } @@ -598,7 +601,7 @@ bool SocketListenOp::tryListen(const ::addrinfo& addr) { SocketImpl::Handle hnd = ::socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); if (hnd == SocketImpl::invalid_socket) { - recordError( + C10D_DEBUG( "The server socket cannot be initialized on {} {}.", addr, getSocketError()); @@ -820,7 +823,7 @@ bool SocketConnectOp::tryConnect(int family) { deadline_ = Clock::now() + opts_->connect_timeout(); - bool retry; // NOLINT(cppcoreguidelines-init-variables) + bool retry = false; do { retry = false; @@ -924,6 +927,11 @@ SocketConnectOp::ConnectResult SocketConnectOp::tryConnect( addr, err); + return ConnectResult::Retry; + } else if (err == std::errc::timed_out) { + C10D_WARNING( + "The server socket on {} has timed out, will retry.", addr, err); + return ConnectResult::Retry; } else { recordError( diff --git a/torch/csrc/distributed/c10d/socket.h b/torch/csrc/distributed/c10d/socket.h index de9bd6989c290..81659f11f049f 100644 --- a/torch/csrc/distributed/c10d/socket.h +++ b/torch/csrc/distributed/c10d/socket.h @@ -16,8 +16,7 @@ #include #include -namespace c10d { -namespace detail { +namespace c10d::detail { class SocketOptions { public: @@ -103,5 +102,4 @@ class Socket { std::unique_ptr impl_; }; -} // namespace detail -} // namespace c10d +} // namespace c10d::detail diff --git a/torch/csrc/distributed/c10d/socket_fmt.h b/torch/csrc/distributed/c10d/socket_fmt.h index 8c7832ebf933c..491d9241eaf97 100644 --- a/torch/csrc/distributed/c10d/socket_fmt.h +++ b/torch/csrc/distributed/c10d/socket_fmt.h @@ -22,11 +22,9 @@ as it exposes the underlying platform specific socket headers. #include #endif -namespace c10d { -namespace detail { +namespace c10d::detail { // Returns a human-readable representation of the given socket address. std::string formatSockAddr(const struct ::sockaddr* addr, socklen_t len); -} // namespace detail -} // namespace c10d +} // namespace c10d::detail diff --git a/torch/csrc/distributed/rpc/agent_utils.cpp b/torch/csrc/distributed/rpc/agent_utils.cpp index 8fc83dc54bd91..ab4ef317d6b6a 100644 --- a/torch/csrc/distributed/rpc/agent_utils.cpp +++ b/torch/csrc/distributed/rpc/agent_utils.cpp @@ -16,7 +16,7 @@ std::unordered_map collectNames( std::unordered_map nameToId; nameToId.reserve(worldSize); nameToId.emplace(selfName, selfId); - // NOLINTNEXTLINE(bugprone-too-small-loop-variable) + // NOLINTNEXTLINE(*loop*) for (worker_id_t workerId = 0; workerId < worldSize; ++workerId) { if (workerId == selfId) { continue; @@ -45,8 +45,7 @@ static std::vector splitString( const std::string& delim) { std::vector tokens; size_t start = 0; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - size_t end; + size_t end = 0; // Iterate through each delimiter while ((end = s.find(delim, start)) != std::string::npos) { tokens.emplace_back(s.substr(start, end - start)); diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.cpp b/torch/csrc/distributed/rpc/python_rpc_handler.cpp index 0d737378ace8a..99dce71358329 100644 --- a/torch/csrc/distributed/rpc/python_rpc_handler.cpp +++ b/torch/csrc/distributed/rpc/python_rpc_handler.cpp @@ -23,7 +23,7 @@ constexpr auto kInternalModule = "torch.distributed.rpc.internal"; auto dur = std::chrono::duration_cast( \ std::chrono::high_resolution_clock::now() - startTime); \ RpcAgent::getCurrentRpcAgent()->addGilWaitTime(dur); \ - } // NOLINT + } // PythonTypeResolver that inherits from Script::Resolver to // support resolving types together with ScriptTypeParser. diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h index 6fc247b3804e2..19e1871ead871 100644 --- a/torch/csrc/distributed/rpc/script_call.h +++ b/torch/csrc/distributed/rpc/script_call.h @@ -61,6 +61,7 @@ class TORCH_API ScriptCall : public RpcCommandBase { // an annotated torchscript function defined by users. std::optional qualifiedName_; std::vector stack_; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) const bool isAsyncExecution_; }; diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp index c624aa2c8a60e..2e1e54e312057 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp @@ -161,14 +161,14 @@ C10_DEFINE_REGISTRY_WITHOUT_WARNING( const std::string& TensorPipeAgent::guessAddress() { static const std::string uvAddress = []() { - char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str()); - if (ifnameEnv != nullptr) { + auto ifnameEnv = c10::utils::get_env(kSocketIfnameEnvVar.c_str()); + if (ifnameEnv.has_value()) { auto [error, result] = - tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv); + tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv.value()); if (error) { LOG(WARNING) << "Failed to look up the IP address for interface " - << ifnameEnv << " (" << error.what() << "), defaulting to " - << kDefaultUvAddress; + << ifnameEnv.value() << " (" << error.what() + << "), defaulting to " << kDefaultUvAddress; return kDefaultUvAddress; } return result; @@ -263,7 +263,7 @@ constexpr static int kNumUvThreads = 16; std::unique_ptr makeMultiplexedUvChannel() { std::vector> contexts; std::vector> listeners; - for (const auto laneIdx C10_UNUSED : c10::irange(kNumUvThreads)) { + for ([[maybe_unused]] const auto laneIdx : c10::irange(kNumUvThreads)) { auto context = tensorpipe::transport::uv::create(); std::string address = TensorPipeAgent::guessAddress(); contexts.push_back(std::move(context)); @@ -444,8 +444,8 @@ void TensorPipeAgent::startImpl() { } // Assign priorities in reverse order of occurrence in the vector, so that // a transport that comes before another receives a higher priority. - priority = - opts_.transports->size() - 1 - (iter - opts_.transports->begin()); + priority = static_cast(opts_.transports->size()) - 1 - + (iter - opts_.transports->begin()); } std::unique_ptr reg = TensorPipeTransportRegistry()->Create(key); @@ -474,7 +474,8 @@ void TensorPipeAgent::startImpl() { } // Assign priorities in reverse order of occurrence in the vector, so // that a channel that comes before another receives a higher priority. - priority = opts_.channels->size() - 1 - (iter - opts_.channels->begin()); + priority = static_cast(opts_.channels->size()) - 1 - + (iter - opts_.channels->begin()); } std::unique_ptr reg = TensorPipeChannelRegistry()->Create(key); diff --git a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp index 549b3c2706780..99c3b9a5963b5 100644 --- a/torch/csrc/distributed/rpc/tensorpipe_utils.cpp +++ b/torch/csrc/distributed/rpc/tensorpipe_utils.cpp @@ -140,9 +140,11 @@ std::tuple tensorpipeSerialize( buffers.type = std::make_unique(rpcMessage->type()); buffers.id = std::make_unique(rpcMessage->id()); // kTpMessageTypeIdx = 0 + // NOLINTNEXTLINE(modernize-use-emplace) tpMessage.payloads.push_back( tensorpipe::Message::Payload{buffers.type.get(), sizeof(MessageType)}); // kTpMessageIdIdx = 1 + // NOLINTNEXTLINE(modernize-use-emplace) tpMessage.payloads.push_back( tensorpipe::Message::Payload{buffers.id.get(), sizeof(int64_t)}); @@ -152,6 +154,7 @@ std::tuple tensorpipeSerialize( // it uses non-const pointers even though it doesn't modify them when writing. char* payloadPtr = buffers.payload.data(); // kTpMessagePayloadIdx = 2 + // NOLINTNEXTLINE(modernize-use-emplace) tpMessage.payloads.push_back( tensorpipe::Message::Payload{payloadPtr, buffers.payload.size()}); @@ -175,6 +178,7 @@ std::tuple tensorpipeSerialize( pickler.pushIValue(buffers.tensors); pickler.stop(); // kTpMessagePickleIdx = 3 + // NOLINTNEXTLINE(modernize-use-emplace) tpMessage.payloads.push_back(tensorpipe::Message::Payload{ buffers.pickle.data(), buffers.pickle.size()}); const std::vector& tensorDataVec = pickler.tensorData(); diff --git a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp index f545bf078468b..75b55bc801a06 100644 --- a/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp +++ b/torch/csrc/distributed/rpc/testing/faulty_tensorpipe_agent.cpp @@ -27,8 +27,11 @@ FaultyTensorPipeAgent::FaultyTensorPipeAgent( std::move(reverseDeviceMaps), std::move(devices), std::move(callback)), + // NOLINTNEXTLINE(bugprone-use-after-move) numFailSends_(opts.numFailSends), + // NOLINTNEXTLINE(bugprone-use-after-move) messageTypesToFail_(parseMessagesToFailInput(opts.messagesToFail)), + // NOLINTNEXTLINE(bugprone-use-after-move) messageTypesToDelay_(parseMessagesToDelay(opts.messagesToDelay)) {} std::vector FaultyTensorPipeAgent::parseMessagesToFailInput( diff --git a/torch/csrc/distributed/rpc/testing/init.cpp b/torch/csrc/distributed/rpc/testing/init.cpp index 366e0f6312ec8..bc9541e56a49b 100644 --- a/torch/csrc/distributed/rpc/testing/init.cpp +++ b/torch/csrc/distributed/rpc/testing/init.cpp @@ -68,7 +68,7 @@ PyObject* faulty_agent_init(PyObject* _unused, PyObject* noargs) { module, "FaultyTensorPipeAgent", rpc_module.attr("TensorPipeAgent")) .def( py::init( - [](const c10::intrusive_ptr<::c10d::Store> store, + [](const c10::intrusive_ptr<::c10d::Store>& store, std::string name, worker_id_t rank, int world_size, diff --git a/torch/csrc/distributed/rpc/unpickled_python_call.cpp b/torch/csrc/distributed/rpc/unpickled_python_call.cpp index 1388891e8288b..733ad5cd51121 100644 --- a/torch/csrc/distributed/rpc/unpickled_python_call.cpp +++ b/torch/csrc/distributed/rpc/unpickled_python_call.cpp @@ -13,7 +13,7 @@ UnpickledPythonCall::UnpickledPythonCall( pythonUdf_ = pythonRpcHandler.deserialize(serializedPyObj); } -// NOTLINTNEXTLINE(bugprone-exception-escape) +// NOLINTNEXTLINE(bugprone-exception-escape) UnpickledPythonCall::~UnpickledPythonCall() { // explicitly setting PyObject* to nullptr to prevent py::object's dtor to // decref on the PyObject again. diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp index bf89decf51930..2dc4bbece04b6 100644 --- a/torch/csrc/dynamo/cache_entry.cpp +++ b/torch/csrc/dynamo/cache_entry.cpp @@ -6,7 +6,7 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) : backend{backend} { - this->check_fn = guarded_code.attr("check_fn"); + this->guard_manager = guarded_code.attr("guard_manager"); this->code = guarded_code.attr("code"); this->compile_id = guarded_code.attr("compile_id"); py::object trace_annotation = guarded_code.attr("trace_annotation"); @@ -16,11 +16,8 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) } else { this->trace_annotation = "Unknown"; } - // TODO - clean this up when enable_cpp_guard_manager is True by default - if (py::hasattr(this->check_fn, "root")) { - this->root_mgr = torch::dynamo::convert_to_root_guard_manager( - this->check_fn.attr("root")); - } + this->root_mgr = torch::dynamo::convert_to_root_guard_manager( + this->guard_manager.attr("root")); } C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( @@ -28,9 +25,9 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor") // NOLINTNEXTLINE(bugprone-exception-escape) CacheEntry::~CacheEntry() { - // prevent check_fn from use-after-free when invalidating - this->check_fn.attr("cache_entry") = py::none(); - this->check_fn.attr("extra_state") = py::none(); + // prevent guard_manager from use-after-free when invalidating + this->guard_manager.attr("cache_entry") = py::none(); + this->guard_manager.attr("extra_state") = py::none(); } C10_DIAGNOSTIC_POP() C10_DIAGNOSTIC_POP() diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h index 7d1d92084444c..9747c0baa421a 100644 --- a/torch/csrc/dynamo/cache_entry.h +++ b/torch/csrc/dynamo/cache_entry.h @@ -18,11 +18,12 @@ of the cache is as follows: -> ExtraState -> CacheEntry (list) - -> check_fn + -> guard_manager (a wrapper that contains the actual guard manager at its +attr named root) -> code -> FrameState -CacheEntry is a linked list node containing the check_fn for guards +CacheEntry is a linked list node containing the guard_manager for guards and the optimized code. The FrameState is a PyDict that enables sharing between different frames. This @@ -41,8 +42,8 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED( C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor") typedef struct VISIBILITY_HIDDEN CacheEntry { // check the guards: lambda: : bool - py::object check_fn; - // modified user bytecode (protected by check_fn's guards) + py::object guard_manager; + // modified user bytecode (protected by guard_manager's guards) py::object code; // CompileId corresponding to this compilation py::object compile_id; diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index 69837edad59b3..2f7f364300105 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -224,10 +224,29 @@ struct LiftedIValueArgs { const std::optional& active_node_call_idx; }; +// Hold GIL while using +struct PyTLSWrapper { + PyTLSWrapper(PyObject* state) : state(state) {} + PyTLSWrapper(const PyTLSWrapper&) = delete; + PyTLSWrapper& operator=(const PyTLSWrapper&) = delete; + PyTLSWrapper(PyTLSWrapper&&) = default; + PyTLSWrapper& operator=(PyTLSWrapper&&) = default; + + static PyTLSWrapper create(); + + PyObject* get(std::string_view key) const; + + private: + PyObject* state; +}; + struct AutogradCompilerCall { - AutogradCompilerCall() - : tensor_args(active_node_call_idx), - lifted_ivalue_args(active_node_call_idx) {} + AutogradCompilerCall() = delete; + AutogradCompilerCall(PyTLSWrapper&& state) + : active_node_call_idx(std::nullopt), + tensor_args(active_node_call_idx), + lifted_ivalue_args(active_node_call_idx), + state(std::move(state)) {} void add_size_input(const c10::SymInt& s) { all_size_inputs.emplace_back( default_dyn_type, s.guard_int(__FILE__, __LINE__)); @@ -245,6 +264,7 @@ struct AutogradCompilerCall { active_node_call_idx = node_call_idx; } + std::optional active_node_call_idx; TensorArgs tensor_args; std::vector all_size_inputs; LiftedIValueArgs lifted_ivalue_args; @@ -252,9 +272,11 @@ struct AutogradCompilerCall { std::vector hooks; NodeCalls node_calls; SizeInput::DynType default_dyn_type = SizeInput::STATIC; + // NodeCall id of each size, only when verbose logging is enabled std::vector size_input_origins; - std::optional active_node_call_idx; + + const PyTLSWrapper state; }; class CompiledNodeArgs { diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index f8decb7d197b2..f253342845ad4 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -819,6 +819,10 @@ static PyObject* set_eval_frame_py(PyObject* dummy, PyObject* callback) { return set_eval_frame(callback, PyThreadState_GET()); } +static PyObject* get_eval_frame_callback_py(PyObject* dummy, PyObject* args) { + return eval_frame_callback_get(); +} + static PyObject* reset_code(PyObject* dummy, PyObject* code) { if (!PyCode_Check(code)) { DEBUG_TRACE0("arg error"); @@ -863,6 +867,7 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) { static PyMethodDef _methods[] = { {"set_eval_frame", set_eval_frame_py, METH_O, NULL}, + {"get_eval_frame_callback", get_eval_frame_callback_py, METH_NOARGS, NULL}, {"reset_code", reset_code, METH_O, NULL}, {"unsupported", unsupported, METH_VARARGS, NULL}, {"skip_code", skip_code, METH_O, NULL}, diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index 73e665e221b63..7ee7961096556 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -126,19 +126,13 @@ void lookup( if (valid) { try { - // TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is - // True by default - if (cache_entry.root_mgr != nullptr) { - valid = torch::dynamo::run_root_guard_manager( - cache_entry.root_mgr, f_locals); - } else { - valid = cache_entry.check_fn(locals).cast(); - } + valid = torch::dynamo::run_root_guard_manager( + cache_entry.root_mgr, f_locals); } catch (py::error_already_set& e) { if (guard_error_hook) { py::handle guard_error_hook_handle(guard_error_hook); guard_error_hook_handle( - cache_entry.check_fn, + cache_entry.guard_manager, cache_entry.code, locals, index, @@ -174,12 +168,12 @@ CacheEntry* create_cache_entry( auto new_iter = extra_state->cache_entry_list.begin(); new_iter->_owner = extra_state; new_iter->_owner_loc = new_iter; - // Set check_fn references to extra_state and CacheEntry + // Set guard_manager references to extra_state and CacheEntry // Warning: lifetime is controlled by C++! - py::handle check_fn = py::handle(guarded_code).attr("check_fn"); - check_fn.attr("cache_entry") = + py::handle guard_manager = py::handle(guarded_code).attr("guard_manager"); + guard_manager.attr("cache_entry") = py::cast(*new_iter, py::return_value_policy::reference); - check_fn.attr("extra_state") = + guard_manager.attr("extra_state") = py::cast(extra_state, py::return_value_policy::reference); return &*new_iter; } diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp index 9312a549da3c0..610bd99054bf4 100644 --- a/torch/csrc/dynamo/guards.cpp +++ b/torch/csrc/dynamo/guards.cpp @@ -1466,8 +1466,8 @@ class DYNAMIC_INDICES : public LeafGuard { } static PyObject* issubset_str = PyUnicode_InternFromString("issubset"); - PyObject* call_result = PyObject_CallMethodOneArg( - indices, issubset_str, _dynamic_indices.ptr()); // new ref + PyObject* call_result = PyObject_CallMethodObjArgs( + indices, issubset_str, _dynamic_indices.ptr(), nullptr); // new ref bool result = PyObject_IsTrue(call_result); Py_DECREF(call_result); Py_DECREF(indices); @@ -1693,6 +1693,15 @@ class GuardManager { // guards and does not change the fail count. For simplicity, we duplicate // the code here. virtual bool check_nopybind(PyObject* value) { // borrowed ref + + if (!this->check_leaf_guards_nopybind(value)) { + return false; + } + + return this->check_accessors_nopybind(value); + } + + bool check_leaf_guards_nopybind(PyObject* value) { // Iterate over leaf guards for (const auto& guard : _leaf_guards) { if (!guard->check_nopybind(value)) { // early exit @@ -1702,6 +1711,10 @@ class GuardManager { } } + return true; + } + + bool check_accessors_nopybind(PyObject* value) { bool matches_dict_tag = false; uint64_t new_tag = 0; if (_is_dict) { @@ -1754,6 +1767,7 @@ class GuardManager { // swapping). _dict_tag = new_tag; } + return result; } @@ -1762,6 +1776,19 @@ class GuardManager { virtual GuardDebugInfo check_verbose_nopybind( PyObject* value) { // borrowed ref int num_guards_executed = 0; + + const GuardDebugInfo& debug_info = + check_leaf_guards_verbose_nopybind(value, num_guards_executed); + if (!debug_info.result) { + return debug_info; + } + + return check_accessors_verbose_nopybind(value, num_guards_executed); + } + + GuardDebugInfo check_leaf_guards_verbose_nopybind( + PyObject* value, + int& num_guards_executed) { // Iterate over leaf guards for (const auto& guard : _leaf_guards) { const GuardDebugInfo& debug_info = guard->check_verbose_nopybind(value); @@ -1772,6 +1799,12 @@ class GuardManager { } } + return GuardDebugInfo(true, num_guards_executed); + } + + GuardDebugInfo check_accessors_verbose_nopybind( + PyObject* value, + int& num_guards_executed) { // Iterate over accessors for (const auto& accessor : _accessors) { const GuardDebugInfo& debug_info = @@ -1921,7 +1954,22 @@ class RootGuardManager : public GuardManager { _local_state = state; } - if (!GuardManager::check_nopybind(value)) { + if (!GuardManager::check_leaf_guards_nopybind(value)) { + _reset_relational_guard_state(); + return false; + } + + // Run accessor guards without TorchFunction enabled + // Dynamo should only be adding guards on values without + // torch function at this point, because if there + // was a torch function, we should've traced through it + const at::impl::TorchFunctionDisabledState old_state = + at::impl::PythonTorchFunctionTLS::get_disabled_state(); + at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::TorchFunctionDisabledState::ALL_DISABLED); + + if (!GuardManager::check_accessors_nopybind(value)) { + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state); _reset_relational_guard_state(); return false; } @@ -1929,10 +1977,13 @@ class RootGuardManager : public GuardManager { // Iterate over epilogue leaf guards. for (const auto& guard : _epilogue_lambda_guards) { if (!guard->check_nopybind(value)) { // early exit + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state); _reset_relational_guard_state(); return false; } } + + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state); _reset_relational_guard_state(); return true; } @@ -1953,13 +2004,33 @@ class RootGuardManager : public GuardManager { _local_state = state; } - GuardDebugInfo debug_info = GuardManager::check_verbose_nopybind(value); - if (!debug_info.result) { + int num_guards_executed = 0; + + // Run leaf guards + // This includes the GlobalStateGuard and the Torch Function Mode stack + // guard, which require Torch Function to be in its unmodified state + const GuardDebugInfo& debug_info_leaf = + GuardManager::check_leaf_guards_verbose_nopybind( + value, num_guards_executed); + + if (!debug_info_leaf.result) { _reset_relational_guard_state(); - return debug_info; + return debug_info_leaf; } - int num_guards_executed = debug_info.num_guards_executed; + const at::impl::TorchFunctionDisabledState old_state = + at::impl::PythonTorchFunctionTLS::get_disabled_state(); + at::impl::PythonTorchFunctionTLS::set_disabled_state( + at::impl::TorchFunctionDisabledState::ALL_DISABLED); + const GuardDebugInfo& debug_info_accessors = + GuardManager::check_accessors_verbose_nopybind( + value, num_guards_executed); + + if (!debug_info_accessors.result) { + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state); + _reset_relational_guard_state(); + return debug_info_accessors; + } // Iterate over epilogue leaf guards for (const auto& guard : _epilogue_lambda_guards) { @@ -1967,11 +2038,13 @@ class RootGuardManager : public GuardManager { guard->check_verbose_nopybind(value); num_guards_executed++; if (!tmp_debug_info.result) { + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state); _reset_relational_guard_state(); return GuardDebugInfo( false, tmp_debug_info.verbose_code_parts, num_guards_executed); } } + at::impl::PythonTorchFunctionTLS::set_disabled_state(old_state); _reset_relational_guard_state(); return GuardDebugInfo(true, num_guards_executed); } diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index 5993c25caace1..16a3f1e2c9736 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -67,7 +67,7 @@ void initDynamoBindings(PyObject* torch) { auto m = py::handle(eval_frame).cast(); py::class_(m, "_CacheEntry") - .def_readonly("check_fn", &CacheEntry::check_fn) + .def_readonly("guard_manager", &CacheEntry::guard_manager) .def_readonly("code", &CacheEntry::code) .def_readonly("compile_id", &CacheEntry::compile_id) .def_readonly("trace_annotation", &CacheEntry::trace_annotation) diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index f3539b8782e9c..024603270f787 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -69,13 +69,17 @@ static PyObject* convert_hook_list(std::vector& inputs) { return pyinput; } +// see https://github.com/pytorch/pytorch/pull/34845 +static void throw_python_error() { + python_error err; + err.persist(); + // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference) + throw err; +} + static PyObject* check(PyObject* pyresult) { if (C10_UNLIKELY(pyresult == nullptr)) { - // see https://github.com/pytorch/pytorch/pull/34845 - python_error err; - err.persist(); - // NOLINTNEXTLINE(misc-throw-by-value-catch-by-reference) - throw err; + throw_python_error(); } return pyresult; } @@ -84,20 +88,56 @@ static void check(bool result) { if (C10_UNLIKELY(!result)) check(nullptr); } - -// snapshot of python verbose logging toggle -static PyObject* python_verbose_logger = nullptr; -struct VerboseLogger { - static std::optional maybe_create() { - if (python_verbose_logger == nullptr) { - return std::nullopt; +struct PythonLogger { + PythonLogger() = delete; + explicit PythonLogger(PyObject* logger) : logger_(logger) { + TORCH_INTERNAL_ASSERT(logger_ != nullptr); + } + + enum Level : unsigned int { + DEBUG = 0, + INFO = 1, + WARNING = 2, + ERROR = 3, + CRITICAL = 4, + COUNT // Keep this as the last enum + }; + + // must be called while GIL is held + void log(Level level, std::string_view msg) const { + THPObjectPtr pymethod(PyUnicode_FromString(levelNames_[level].data())); + TORCH_INTERNAL_ASSERT(pymethod != nullptr); + THPObjectPtr pyfunc(PyObject_GetAttr(logger_, pymethod.get())); + if (pyfunc == nullptr) { + throw_python_error(); + } + PyObject* result = PyObject_CallFunction(pyfunc.get(), "s", msg.data()); + if (result == nullptr) { + throw_python_error(); } - return VerboseLogger(); } - void verbose_log_fn(std::string_view msg) const { - TORCH_CHECK(python_verbose_logger != nullptr); - check(PyObject_CallFunction(python_verbose_logger, "s", msg.data())); + private: + static constexpr std::array levelNames_ = { + "debug", // Level::DEBUG + "info", // Level::INFO + "warning", // Level::WARNING + "error", // Level::ERROR + "critical" // Level::CRITICAL + }; + + // Note: logger_ must stay valid for the lifetime of this object + PyObject* logger_; +}; + +struct VerboseLogger : public PythonLogger { + VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {} + + static std::optional maybe_create(PyObject* vlogger) { + if (vlogger == Py_None) { + return std::nullopt; + } + return VerboseLogger(vlogger); } void log_node_check( @@ -137,7 +177,7 @@ struct VerboseLogger { } } oss << "]"; - verbose_log_fn(oss.str()); + log(PythonLogger::DEBUG, oss.str()); } void log_dynamic_shapes_check(size_t size_idx) const { @@ -149,10 +189,10 @@ struct VerboseLogger { TORCH_CHECK(it != cumulative_sizes_per_node.end()); size_t start_idx = it == cumulative_sizes_per_node.begin() ? 0 : std::prev(it)->first; - verbose_log_fn( + log(PythonLogger::DEBUG, "Cache miss due to changed shapes: marking size idx " + - std::to_string(size_idx - start_idx) + " of " + it->second + - " as dynamic"); + std::to_string(size_idx - start_idx) + " of " + it->second + + " as dynamic"); } // track which size index belongs to which node @@ -324,8 +364,22 @@ struct InputBuffers : public std::unordered_map { } }; -static PyObject* the_autograd_compiler = nullptr; -static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args); +/* static */ PyTLSWrapper PyTLSWrapper::create() { + TORCH_INTERNAL_ASSERT( + at::impl::ThreadLocalPythonObjects::contains("compiled_autograd_state")); + PyObject* compiled_autograd_state = + check(at::impl::ThreadLocalPythonObjects::get("compiled_autograd_state") + ->ptr(getPyInterpreter())); + return PyTLSWrapper(compiled_autograd_state); +} + +// Refer to fields in python class CompiledAutogradTLS +// May return Py_None +PyObject* PyTLSWrapper::get(std::string_view key) const { + return check(PyObject_GetAttrString(state, key.data())); +} + +static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args); static PyObject* clear_cache(PyObject* dummy, PyObject* args) { HANDLE_TH_ERRORS; @@ -343,28 +397,11 @@ static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) { END_HANDLE_TH_ERRORS; } -static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) { - HANDLE_TH_ERRORS; - PyObject* logger = nullptr; - if (!PyArg_ParseTuple(args, "O", &logger)) { - Py_RETURN_FALSE; - } - - if (logger == Py_None) { - python_verbose_logger = nullptr; - } else { - python_verbose_logger = logger; - } - Py_RETURN_TRUE; - END_HANDLE_TH_ERRORS; -} - // NOLINTNEXTLINE(*array*) static PyMethodDef _methods[] = { - {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr}, + {"notify_autograd_engine", notify_autograd_engine, METH_NOARGS, nullptr}, {"clear_cache", clear_cache, METH_NOARGS, nullptr}, {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr}, - {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; static struct PyModuleDef _module = { @@ -493,7 +530,8 @@ static TraceState call_begin_capture( static PyObject* call_end_capture(PyObject* self, const variable_list& inputs) { static PyObject* method_name = PyUnicode_InternFromString("end_capture"); THPObjectPtr pyinput(THPVariable_WrapList(inputs)); - return check(PyObject_CallMethodOneArg(self, method_name, pyinput.get())); + return check( + PyObject_CallMethodObjArgs(self, method_name, pyinput.get(), nullptr)); } struct ClosingTHPObjectPtr : public THPObjectPtr { @@ -504,7 +542,7 @@ struct ClosingTHPObjectPtr : public THPObjectPtr { return; } static PyObject* method_name = PyUnicode_InternFromString("close"); - if (PyObject_CallMethodNoArgs(get(), method_name) == nullptr) { + if (PyObject_CallMethodObjArgs(get(), method_name, nullptr) == nullptr) { PyErr_WriteUnraisable(get()); PyErr_Clear(); } @@ -523,7 +561,7 @@ CacheNode* _compiled_autograd_impl( THPObjectPtr* graph_arg_hooks) { std::unordered_map& dependencies = graph_task.dependencies_; std::vector> worklist{graph_root}; - AutogradCompilerCall compiler_call; + AutogradCompilerCall compiler_call(PyTLSWrapper::create()); for (const auto i : c10::irange(output_edges.size())) { compiler_call.node_calls @@ -538,7 +576,8 @@ CacheNode* _compiled_autograd_impl( check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1); int i = 0; - std::optional vlogger = VerboseLogger::maybe_create(); + std::optional vlogger = + VerboseLogger::maybe_create(compiler_call.state.get("vlogger")); while (!worklist.empty()) { std::shared_ptr fn = std::move(worklist.back()); worklist.pop_back(); @@ -597,6 +636,8 @@ CacheNode* _compiled_autograd_impl( // TODO(jansel): some dynamic sizes seem to be ints not symints if (!cache->check_dynamic_sizes(compiler_call, vlogger)) { // cache miss, need to capture FX graph + PyObject* the_autograd_compiler = compiler_call.state.get("compiler"); + TORCH_INTERNAL_ASSERT(the_autograd_compiler != Py_None); ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); @@ -736,6 +777,7 @@ CacheNode* _compiled_autograd_impl( return cache; } +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) struct LockGuardWithErrorLogs { LockGuardWithErrorLogs(std::mutex& mtx) : mtx_(mtx) { // Note: the standard allows try_lock to fail spuriously during races for @@ -794,28 +836,16 @@ variable_list compiled_autograd( return outputs; } -static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) { +static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args) { HANDLE_TH_ERRORS; - PyObject* obj = nullptr; - if (!PyArg_ParseTuple(args, "O", &obj)) { - return nullptr; - } - - PyObject* prior = the_autograd_compiler; - if (obj == Py_None) { // disable - the_autograd_compiler = nullptr; // decref not needed due to `prior` + PyTLSWrapper state = PyTLSWrapper::create(); + PyObject* compiler = state.get("compiler"); + if (compiler == Py_None) { // disable Engine::set_compiled_autograd(nullptr); } else { // enable - Py_INCREF(obj); - the_autograd_compiler = obj; Engine::set_compiled_autograd(&compiled_autograd); } - - if (prior == nullptr) { - Py_RETURN_NONE; - } else { - return prior; - } + Py_RETURN_NONE; END_HANDLE_TH_ERRORS; } diff --git a/torch/csrc/inductor/aoti_package/model_package_loader.cpp b/torch/csrc/inductor/aoti_package/model_package_loader.cpp index 5f758787e658f..5443e729d1520 100644 --- a/torch/csrc/inductor/aoti_package/model_package_loader.cpp +++ b/torch/csrc/inductor/aoti_package/model_package_loader.cpp @@ -10,17 +10,11 @@ #include #include -// TODO: Investigate why this is necessary, but fixes build problems in FRL -#if __has_include("filesystem") -#include -namespace fs = std::filesystem; -#else -#include -namespace fs = std::experimental::filesystem; -#endif - #ifndef _WIN32 #include +#else +#include +namespace fs = std::filesystem; #endif // TODO: C++17 has the filesystem header, which may replace these @@ -42,7 +36,7 @@ bool file_exists(std::string& path) { #ifdef _WIN32 return fs::exists(path); #else - struct stat rc; + struct stat rc {}; return lstat(path.c_str(), &rc) == 0; #endif } diff --git a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp index 37c69ccfa813d..99dc7499022fe 100644 --- a/torch/csrc/inductor/aoti_runner/model_container_runner.cpp +++ b/torch/csrc/inductor/aoti_runner/model_container_runner.cpp @@ -5,17 +5,11 @@ #include #include -// TODO: Investigate why this is necessary, but fixes build problems in FRL -#if __has_include("filesystem") -#include -namespace fs = std::filesystem; -#else -#include -namespace fs = std::experimental::filesystem; -#endif - #ifndef _WIN32 #include +#else +#include +namespace fs = std::filesystem; #endif namespace { @@ -23,7 +17,7 @@ bool file_exists(std::string& path) { #ifdef _WIN32 return fs::exists(path); #else - struct stat rc; + struct stat rc {}; return lstat(path.c_str(), &rc) == 0; #endif } diff --git a/torch/csrc/inductor/aoti_runtime/device_utils.h b/torch/csrc/inductor/aoti_runtime/device_utils.h index 76731999968dd..5b1fc36c97ea4 100644 --- a/torch/csrc/inductor/aoti_runtime/device_utils.h +++ b/torch/csrc/inductor/aoti_runtime/device_utils.h @@ -38,12 +38,10 @@ using DeviceStreamType = cudaStream_t; throw std::runtime_error("CPU runtime error"); \ } -namespace torch { -namespace aot_inductor { +namespace torch::aot_inductor { using DeviceStreamType = void*; -} // namespace aot_inductor -} // namespace torch +} // namespace torch::aot_inductor #endif // USE_CUDA diff --git a/torch/csrc/inductor/aoti_runtime/model_container.h b/torch/csrc/inductor/aoti_runtime/model_container.h index 279e480530871..a712a9e3d0173 100644 --- a/torch/csrc/inductor/aoti_runtime/model_container.h +++ b/torch/csrc/inductor/aoti_runtime/model_container.h @@ -19,8 +19,7 @@ class AOTInductorModelContainer { AOTInductorModelContainer( size_t num_models, const std::string& device_str, - const std::optional& cubin_dir = std::nullopt) - : use_secondary_(false), constant_folded_(false) { + const std::optional& cubin_dir = std::nullopt) { constants_map_ = std::make_shared(); constants_array_ = std::make_shared>(); @@ -413,10 +412,10 @@ class AOTInductorModelContainer { // If true, // constants_map_secondary/constant_blob_secondary/constants_array_secondary // is being used. - bool use_secondary_; + bool use_secondary_{false}; // Determine whether we have ran constant folding - bool constant_folded_; + bool constant_folded_{false}; // Holds the mapping of constants to at::Tensor. // The underlying data of at::Tensor is in either constant_blob_ (for CUDA). diff --git a/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h b/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h index 998ef1fd1483c..e379d372ffaa0 100644 --- a/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h +++ b/torch/csrc/inductor/aoti_torch/c/shim_mkldnn.h @@ -129,6 +129,58 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__linear_pointwise_binary( const char* attr, AtenTensorHandle* ret0); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qlinear_pointwise_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle* B, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + const char* post_op_name, + const double** post_op_args, + int64_t post_op_args_len_, + const char* post_op_algorithm, + AtenTensorHandle* ret0); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu__qlinear_pointwise_binary_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle* other, + AtenTensorHandle* B, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + double other_scale, + int64_t other_zero_point, + const char* binary_post_op, + double binary_alpha, + const char* unary_post_op, + const double** unary_post_op_args, + int64_t unary_post_op_args_len_, + const char* unary_post_op_algorithm, + AtenTensorHandle* ret0); + +#if AT_MKL_ENABLED() + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__mkl_linear( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle origin_W, + AtenTensorHandle* B, + int64_t prepack_batch_size, + AtenTensorHandle* ret0); + +#endif // AT_MKL_ENABLED + #ifdef __cplusplus } // extern "C" #endif diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp index 26821c2445ba5..62365e676d63a 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -14,7 +14,7 @@ namespace torch::aot_inductor { void OSSProxyExecutor::prefill_stack_with_static_arguments( int index, - at::TypePtr schema_arg_type, + const at::TypePtr& schema_arg_type, const nlohmann::json& serialized_arg, OSSOpKernel& op_kernel) { auto& stack = op_kernel.stack_; @@ -33,7 +33,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( } case c10::TypeKind::IntType: { TORCH_CHECK(serialized_arg_type == "as_int"); - stack.emplace_back(c10::IValue()); + stack.emplace_back(); dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); break; } @@ -41,7 +41,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( TORCH_CHECK( serialized_arg_type == "as_int" || serialized_arg_type == "as_sym_int"); - stack.emplace_back(c10::IValue()); + stack.emplace_back(); dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); break; } @@ -107,14 +107,14 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( TORCH_CHECK(serialized_arg_type == "as_ints"); dynamic_args.emplace_back( index, DynamicArgType::ListIntType, serialized_arg_val.size()); - stack.emplace_back(c10::IValue()); + stack.emplace_back(); } else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { TORCH_CHECK( serialized_arg_type == "as_ints" || serialized_arg_type == "as_sym_ints"); dynamic_args.emplace_back( index, DynamicArgType::ListIntType, serialized_arg_val.size()); - stack.emplace_back(c10::IValue()); + stack.emplace_back(); } else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) { TORCH_CHECK(serialized_arg_type == "as_floats"); std::vector ret; @@ -133,7 +133,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( if (serialized_arg_type == "as_ints") { dynamic_args.emplace_back( index, DynamicArgType::ListIntType, serialized_arg_val.size()); - stack.emplace_back(c10::IValue()); + stack.emplace_back(); } else if (serialized_arg_type == "as_floats") { std::vector ret; for (const auto& arg : serialized_arg_val) { @@ -192,7 +192,7 @@ void OSSProxyExecutor::prefill_stack_with_static_arguments( schema_arg_type->castRaw()->getElementType(); if (serialized_arg_type == "as_none") { - stack.emplace_back(c10::nullopt); + stack.emplace_back(std::nullopt); if (inner_type->kind() == c10::TypeKind::TensorType) { // Tensor is None dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0); @@ -259,7 +259,7 @@ void OSSProxyExecutor::get_output_info_from_serialized( auto& serialized_output_val = serialized_output.begin().value(); auto& schema_return = schema_returns[output_index]; - at::TypePtr schema_return_type = schema_return.real_type(); + const at::TypePtr& schema_return_type = schema_return.real_type(); switch (schema_return_type->kind()) { case c10::TypeKind::TensorType: { @@ -408,13 +408,13 @@ void OSSProxyExecutor::call_function( list_item_types.has_value(), "Could not find list of item types for optional tensor list input"); - for (std::string item_type : list_item_types.value()) { + for (const std::string& item_type : list_item_types.value()) { if (item_type == "as_tensor") { at::Tensor* tensor = tensor_handle_to_tensor_pointer( flatten_tensor_args[tensor_id++]); optional_tensor_list.emplace_back(*tensor); } else if (item_type == "as_none") { - optional_tensor_list.emplace_back(c10::nullopt); + optional_tensor_list.emplace_back(std::nullopt); } } stack[arg_index] = optional_tensor_list; @@ -422,6 +422,7 @@ void OSSProxyExecutor::call_function( } case DynamicArgType::ListIntType: { std::vector vals; + vals.reserve(length); for (int j = 0; j < length; j++) { vals.push_back(flatten_int_args[int_id++]); } @@ -468,10 +469,10 @@ void OSSProxyExecutor::call_function( schema_return.type()->kind() == c10::TypeKind::ListType && schema_return.type()->isSubtypeOf(at::ListType::ofTensors())) { auto tensors = stack[index++].toTensorList(); - for (size_t i = 0; i < tensors.size(); ++i) { + for (auto&& t : tensors) { at::Tensor* tensor = tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); - *tensor = tensors[i]; + *tensor = t; } } else { TORCH_CHECK( diff --git a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h index c1a0f9260edd5..d881866b5abaa 100644 --- a/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h +++ b/torch/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -82,7 +82,7 @@ class OSSProxyExecutor : public ProxyExecutor { private: void prefill_stack_with_static_arguments( int index, - at::TypePtr schema_arg_type, + const at::TypePtr& schema_arg_type, const nlohmann::json& serialized_arg, OSSOpKernel& op_kernel); diff --git a/torch/csrc/inductor/aoti_torch/shim_common.cpp b/torch/csrc/inductor/aoti_torch/shim_common.cpp index 1dd766b3a11a9..3b0292d5c1a3c 100644 --- a/torch/csrc/inductor/aoti_torch/shim_common.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_common.cpp @@ -47,29 +47,26 @@ #endif -#if __has_include("filesystem") -#include -namespace fs = std::filesystem; -#else -#include -namespace fs = std::experimental::filesystem; -#endif - #ifndef _WIN32 -#include #include #include #include +#include + +#else +#include +namespace fs = std::filesystem; #endif // HACK for failed builds in ARVR, where it cannot find these symbols within // std::experimental::filesystem namespace { std::string get_current_path() { -#if __has_include("filesystem") && !defined(__linux__) +#ifdef _WIN32 return fs::current_path().string(); #else - char currentPath[PATH_MAX]; + // NOLINTNEXTLINE(*array*) + char currentPath[PATH_MAX]{}; if (getcwd(currentPath, sizeof(currentPath)) != nullptr) { return std::string(currentPath); } else { @@ -79,16 +76,16 @@ std::string get_current_path() { } bool file_exists(std::string& path) { -#if __has_include("filesystem") && !defined(__linux__) +#ifdef _WIN32 return fs::exists(path); #else - struct stat rc; + struct stat rc {}; return lstat(path.c_str(), &rc) == 0; #endif } bool create_directories(const std::string& path) { -#if __has_include("filesystem") && !defined(__linux__) +#ifdef _WIN32 return fs::create_directories(path); #else if (mkdir(path.c_str(), 0777) == -1) { @@ -1055,11 +1052,11 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( if (!file_exists(tmp_folder)) { std::cout << "aoti_torch_save_tensor_handle: Path does not exist, creating it..." - << tmp_folder << std::endl; + << tmp_folder << '\n'; if (!create_directories(tmp_folder)) { std::cout << "aoti_torch_save_tensor_handle: Error creating directory: " - << tmp_folder << std::endl; + << tmp_folder << '\n'; return; } } @@ -1068,11 +1065,11 @@ AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle( auto bytes = torch::jit::pickle_save(c10::IValue(*t)); std::ofstream fout(tensor_filepath_to_save, std::ios::out | std::ios::binary); - fout.write(bytes.data(), bytes.size()); + fout.write(bytes.data(), static_cast(bytes.size())); fout.close(); std::cout << "aoti_torch_save_tensor_handle: Saved tensor to " - << tensor_filepath_to_save << std::endl; + << tensor_filepath_to_save << '\n'; #endif // !defined(C10_MOBILE) } @@ -1087,7 +1084,7 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( std::cout << " " << msg; } std::cout << " " - << "]:" << std::endl; + << "]:" << '\n'; // Print exact tensor values for small size tensors const int64_t numel = t->numel(); @@ -1096,8 +1093,8 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( } // Print summary stats of the tensor - std::cout << "Number of elements: " << numel << std::endl; - std::cout << "Dtype: " << t->dtype() << std::endl; + std::cout << "Number of elements: " << numel << '\n'; + std::cout << "Dtype: " << t->dtype() << '\n'; if (numel > 0) { // torch/aten `mean()` function only supports float and complex dtypes // See: @@ -1109,24 +1106,24 @@ AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle( at::isComplexType(at::typeMetaToScalarType(t->dtype())); at::ScalarType float_dtype = is_complex_type ? at::kComplexFloat : at::kFloat; - std::cout << "Mean value: " << mean_value(float_dtype) << std::endl; + std::cout << "Mean value: " << mean_value(float_dtype) << '\n'; if (!is_complex_type) { // "min_all_cuda" function is not implemented for 'ComplexFloat' type. // (similar for max) Skip printing min/max value for complex type tensors // here If encountered complex dtypes (rare occasions), suggest to print // out the whole value of the tensor. - std::cout << "Min value: " << t->min().item() << std::endl; - std::cout << "Max value: " << t->max().item() << std::endl; + std::cout << "Min value: " << t->min().item() << '\n'; + std::cout << "Max value: " << t->max().item() << '\n'; } } - std::cout << "Device: " << t->device() << std::endl; - std::cout << "Size: " << t->sizes() << std::endl; - std::cout << "Stride: " << t->strides() << std::endl; - std::cout << "Layout: " << t->layout() << std::endl; - std::cout << "Is contiguous: " << t->is_contiguous() << std::endl; - std::cout << "Requires grad: " << t->requires_grad() << std::endl; - - std::cout << std::endl; + std::cout << "Device: " << t->device() << '\n'; + std::cout << "Size: " << t->sizes() << '\n'; + std::cout << "Stride: " << t->strides() << '\n'; + std::cout << "Layout: " << t->layout() << '\n'; + std::cout << "Is contiguous: " << t->is_contiguous() << '\n'; + std::cout << "Requires grad: " << t->requires_grad() << '\n'; + + std::cout << '\n'; } // ProxyExecutor diff --git a/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp b/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp index b43e2ff773bf7..d8912f95127af 100644 --- a/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp +++ b/torch/csrc/inductor/aoti_torch/shim_mkldnn.cpp @@ -9,6 +9,7 @@ #endif #include #include +#include using namespace torch::aot_inductor; @@ -269,4 +270,119 @@ AOTITorchError aoti_torch_cpu__linear_pointwise_binary( }); } +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__qlinear_pointwise_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle* B, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + const char* post_op_name, + const double** post_op_args, + int64_t post_op_args_len_, + const char* post_op_algorithm, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::List> scalars_list; + scalars_list.reserve(post_op_args_len_); + for (int64_t i = 0; i < post_op_args_len_; i++) { + scalars_list.emplace_back(pointer_to_optional(post_op_args[i])); + } + + auto tmp_result = at::native::QLinearOnednn::run_pointwise_tensor( + *tensor_handle_to_tensor_pointer(X), + *tensor_handle_to_tensor_pointer(act_scale), + *tensor_handle_to_tensor_pointer(act_zero_point), + *tensor_handle_to_tensor_pointer(onednn_weight), + *tensor_handle_to_tensor_pointer(weight_scales), + *tensor_handle_to_tensor_pointer(weight_zero_points), + pointer_to_optional(B), + output_scale, + output_zero_point, + pointer_to_optional(output_dtype), + post_op_name, + scalars_list, + post_op_algorithm); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu__qlinear_pointwise_binary_tensor( + AtenTensorHandle X, + AtenTensorHandle act_scale, + AtenTensorHandle act_zero_point, + AtenTensorHandle onednn_weight, + AtenTensorHandle weight_scales, + AtenTensorHandle weight_zero_points, + AtenTensorHandle* other, + AtenTensorHandle* B, + double output_scale, + int64_t output_zero_point, + const int32_t* output_dtype, + double other_scale, + int64_t other_zero_point, + const char* binary_post_op, + double binary_alpha, + const char* unary_post_op, + const double** unary_post_op_args, + int64_t unary_post_op_args_len_, + const char* unary_post_op_algorithm, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::List> scalars_list; + scalars_list.reserve(unary_post_op_args_len_); + for (int64_t i = 0; i < unary_post_op_args_len_; i++) { + scalars_list.emplace_back(pointer_to_optional(unary_post_op_args[i])); + } + + auto tmp_result = at::native::QLinearOnednn::run_pointwise_binary_tensor( + *tensor_handle_to_tensor_pointer(X), + *tensor_handle_to_tensor_pointer(act_scale), + *tensor_handle_to_tensor_pointer(act_zero_point), + *tensor_handle_to_tensor_pointer(onednn_weight), + *tensor_handle_to_tensor_pointer(weight_scales), + *tensor_handle_to_tensor_pointer(weight_zero_points), + pointer_to_optional(other), + pointer_to_optional(B), + output_scale, + output_zero_point, + pointer_to_optional(output_dtype), + other_scale, + other_zero_point, + binary_post_op, + binary_alpha, + unary_post_op, + scalars_list, + unary_post_op_algorithm); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +#if AT_MKL_ENABLED() + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__mkl_linear( + AtenTensorHandle X, + AtenTensorHandle W, + AtenTensorHandle origin_W, + AtenTensorHandle* B, + int64_t prepack_batch_size, + AtenTensorHandle* ret0) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto tmp_result = at::native::mkl_linear( + *tensor_handle_to_tensor_pointer(X), + *tensor_handle_to_tensor_pointer(W), + *tensor_handle_to_tensor_pointer(origin_W), + pointer_to_optional(B), + prepack_batch_size); + *ret0 = new_tensor_handle(std::move(tmp_result)); + }); +} + +#endif // AT_MKL_ENABLED + #endif // AT_MKLDNN_ENABLED() diff --git a/torch/csrc/jit/OVERVIEW.md b/torch/csrc/jit/OVERVIEW.md index c2651750ebcc1..b15fe34d4397f 100644 --- a/torch/csrc/jit/OVERVIEW.md +++ b/torch/csrc/jit/OVERVIEW.md @@ -198,7 +198,7 @@ Note that the chosen overload is not shown in any way in the textual output. If Each node also has a set of attributes which are named integers, strings, floats, `Tensors`, subgraphs, or lists of these types. These are used by special primitive operators to encode additional data in the `Node`. For instance `prim::Constant` defines a compile-time constant value. For `Tensor` constants, it will have a single `Tensor` attribute with the name `attr::value` which contains the value of the constant. -Attributes are _rarely used_. Operators like convolution or matrix-multiply have no attributes and take their arguments through the input list. This includes things that might be typically thought of as constants, like the stride of the convolution. In PyTorch, any of this information is potentially a dynamic property of the program so `Nodes` are always encoded in a way that allows these values to be dynamically determined. However, we recognize that many inputs are almost always constants, so we make it easy to quickly check if an input is constant and get its value with `c10::optional Node::get(Symbol name)`, which returns an `IValue` (a concrete value for the input) in the case the node is constant and `nullopt` otherwise. +Attributes are _rarely used_. Operators like convolution or matrix-multiply have no attributes and take their arguments through the input list. This includes things that might be typically thought of as constants, like the stride of the convolution. In PyTorch, any of this information is potentially a dynamic property of the program so `Nodes` are always encoded in a way that allows these values to be dynamically determined. However, we recognize that many inputs are almost always constants, so we make it easy to quickly check if an input is constant and get its value with `std::optional Node::get(Symbol name)`, which returns an `IValue` (a concrete value for the input) in the case the node is constant and `nullopt` otherwise. ## Block ## diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp index a44eccb601ba3..9cd655ad930ef 100644 --- a/torch/csrc/jit/api/module.cpp +++ b/torch/csrc/jit/api/module.cpp @@ -451,7 +451,8 @@ IValue Module::create_class(const c10::QualifiedName& name, Stack stack) const { const auto classType = _ivalue()->compilation_unit()->get_class(c10::QualifiedName(name)); if (!classType) { - AT_ERROR( + TORCH_CHECK( + false, "Could not find class with name: '", name.qualifiedName(), "' in module."); diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h index 558dcdee57af2..8e9be1de48a5f 100644 --- a/torch/csrc/jit/api/module.h +++ b/torch/csrc/jit/api/module.h @@ -593,7 +593,7 @@ struct TORCH_API ModulePolicy { } // are we going to return everything? If so, we can optimize the calculate // of the size of the list. - static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false; + static constexpr bool all_slots = false; }; struct TORCH_API ParameterPolicy { @@ -606,7 +606,7 @@ struct TORCH_API ParameterPolicy { static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { return typ->is_parameter(i) && v.isTensor(); } - static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false; + static constexpr bool all_slots = false; }; struct TORCH_API BufferPolicy { @@ -620,7 +620,7 @@ struct TORCH_API BufferPolicy { return typ->getAttribute(i)->isSubtypeOf(*TensorType::get()) && typ->is_buffer(i); } - static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = false; + static constexpr bool all_slots = false; }; struct TORCH_API AttributePolicy { @@ -633,7 +633,7 @@ struct TORCH_API AttributePolicy { static bool valid(const ClassTypePtr& typ, size_t i, const IValue& v) { return true; } - static CONSTEXPR_EXCEPT_WIN_CUDA bool all_slots = true; + static constexpr bool all_slots = true; }; // take a Policy object, and make a version of it that returns the slot. diff --git a/torch/csrc/jit/api/object.h b/torch/csrc/jit/api/object.h index 2c0f7e3b164f0..8f0d11d718747 100644 --- a/torch/csrc/jit/api/object.h +++ b/torch/csrc/jit/api/object.h @@ -108,7 +108,7 @@ struct TORCH_API Object { if (auto method = find_method(name)) { return *method; } - AT_ERROR("Method '", name, "' is not defined."); + TORCH_CHECK(false, "Method '", name, "' is not defined."); } const std::vector get_methods() const { @@ -137,7 +137,7 @@ struct TORCH_API Object { prop.name, Method(_ivalue(), prop.getter), std::move(setter)}; } } - AT_ERROR("Property '", name, "' is not defined."); + TORCH_CHECK(false, "Property '", name, "' is not defined."); } const std::vector get_properties() const { diff --git a/torch/csrc/jit/frontend/error_report.cpp b/torch/csrc/jit/frontend/error_report.cpp index 67f461f953a28..1f87e5e0cd7ed 100644 --- a/torch/csrc/jit/frontend/error_report.cpp +++ b/torch/csrc/jit/frontend/error_report.cpp @@ -63,7 +63,7 @@ std::string ErrorReport::current_call_stack() { #ifndef C10_MOBILE return get_stacked_errors(calls); #else - AT_ERROR("Call stack not supported on mobile"); + TORCH_CHECK(false, "Call stack not supported on mobile"); #endif // C10_MOBILE } diff --git a/torch/csrc/jit/frontend/tracer.cpp b/torch/csrc/jit/frontend/tracer.cpp index 7a662f0a0d3ae..7e5100026bd2f 100644 --- a/torch/csrc/jit/frontend/tracer.cpp +++ b/torch/csrc/jit/frontend/tracer.cpp @@ -56,7 +56,8 @@ void genericAddOptionalInput( template void badArgType(const T& v) { - AT_ERROR( + TORCH_CHECK( + false, "Found an unsupported argument type in the JIT tracer: ", c10::demangle_type(), ". File a bug report."); @@ -323,7 +324,8 @@ Value* TracingState::getOutput(const IValue& iv, size_t i) { graph->insertNode(dict_node); return dict_node->output(); } else { - AT_ERROR( + TORCH_CHECK( + false, "Only tensors, lists, tuples of tensors, or dictionary of tensors can be output from traced functions"); } } @@ -416,7 +418,8 @@ static IValue addInput( return elems; } } else { - AT_ERROR( + TORCH_CHECK( + false, "Only tensors or (possibly nested) dict or tuples of tensors can be " "inputs to traced functions. Got ", type->repr_str()); @@ -472,7 +475,7 @@ std::pair, Stack> trace( // varied on subsequent invocations of the trace. Any other variables // will be treated as constants. if (isTracing()) { - AT_ERROR("Tracing can't be nested"); + TORCH_CHECK(false, "Tracing can't be nested"); } auto state = std::make_shared(); setTracingState(state); diff --git a/torch/csrc/jit/frontend/tracer.h b/torch/csrc/jit/frontend/tracer.h index 106a82e3a9ec3..885bb790fdf24 100644 --- a/torch/csrc/jit/frontend/tracer.h +++ b/torch/csrc/jit/frontend/tracer.h @@ -344,19 +344,21 @@ inline void addInputs( Node* n, const char* name, const std::vector& value) { - AT_ERROR("Tracing a list of bool type is currently not supported!"); + TORCH_CHECK(false, "Tracing a list of bool type is currently not supported!"); } template void addInputs(Node* n, const char* name, ArrayRef value) { - AT_ERROR("Tracing a list of arbitrary type is currently not supported!"); + TORCH_CHECK( + false, "Tracing a list of arbitrary type is currently not supported!"); } template void addInputs( Node* n, const char* name, const std::unordered_map& value) { - AT_ERROR("Tracing a dict of arbitrary types is currently not supported!"); + TORCH_CHECK( + false, "Tracing a dict of arbitrary types is currently not supported!"); } template @@ -387,7 +389,8 @@ template < std::decay_t, c10::intrusive_ptr>)>> void addOutput(Node* node, T&&) { - AT_ERROR( + TORCH_CHECK( + false, "Found an unsupported argument type ", c10::demangle_type(), " in the JIT tracer. File a bug report."); diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp index 52796de8e24d0..996521300ef65 100644 --- a/torch/csrc/jit/ir/ir.cpp +++ b/torch/csrc/jit/ir/ir.cpp @@ -292,8 +292,7 @@ SourceRange Node::sourceRange() const { } static std::ostream& indent(std::ostream& out, size_t level) { - for (const auto i : c10::irange(level)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(level)) { out << " "; } return out; @@ -747,9 +746,10 @@ void Block::destroy() { void Graph::cloneFrom(Graph& src) { auto env = [](Value* v) -> Value* { - AT_ERROR( + TORCH_CHECK( + false, "Graph::copy() encountered a use of a value " + v->debugName() + - " not in scope. Run lint!"); + " not in scope. Run lint!"); }; block()->cloneFrom(src.block(), env); } @@ -1768,8 +1768,7 @@ Node* Graph::createTupleSlice( new_vals.reserve(num_values); int64_t i = beg; - for (const auto j : c10::irange(num_values)) { - (void)j; // Suppress unused variable warning + for ([[maybe_unused]] const auto j : c10::irange(num_values)) { auto idx = insertConstant(IValue(static_cast(i))); auto tupleIndex = insertNode(createTupleIndex(tup, idx, tt->elements()[i])); @@ -1817,8 +1816,7 @@ Node* Graph::createListUnpack(Value* v, size_t size) { ListTypePtr list_type = v->type()->expect(); TypePtr elem_type = list_type->getElementType(); auto n = create(prim::ListUnpack, {v}, 0); - for (const auto i : c10::irange(size)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(size)) { n->addOutput()->setType(elem_type); } return n; diff --git a/torch/csrc/jit/mobile/compatibility/backport.cpp b/torch/csrc/jit/mobile/compatibility/backport.cpp index 5714842791d4a..d945d023a1a34 100644 --- a/torch/csrc/jit/mobile/compatibility/backport.cpp +++ b/torch/csrc/jit/mobile/compatibility/backport.cpp @@ -49,7 +49,7 @@ bool _backport_for_mobile( std::unique_ptr istream_adapter; file_stream.open(input_filename, std::ifstream::in | std::ifstream::binary); if (!file_stream) { - AT_ERROR("open file failed, file path: ", input_filename); + TORCH_CHECK(false, "open file failed, file path: ", input_filename); } auto writer_func = [&](const void* buf, size_t nbytes) -> size_t { out.write(static_cast(buf), nbytes); @@ -67,7 +67,7 @@ bool _backport_for_mobile( std::ifstream file_stream; file_stream.open(input_filename, std::ifstream::in | std::ifstream::binary); if (!file_stream) { - AT_ERROR("open file failed, file path: ", input_filename); + TORCH_CHECK(false, "open file failed, file path: ", input_filename); } PyTorchStreamWriter writer(output_filename); diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp index 359ee2ac557ba..9eb2e7db2c59d 100644 --- a/torch/csrc/jit/mobile/interpreter.cpp +++ b/torch/csrc/jit/mobile/interpreter.cpp @@ -362,7 +362,7 @@ bool InterpreterState::run(Stack& stack) { frame.step(); } break; default: - AT_ERROR(toString(inst.op), " is invalid."); + TORCH_CHECK(false, toString(inst.op), " is invalid."); } if (!prev_value) { diff --git a/torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp b/torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp index 5184ba563ce23..76a5ee2b6eb93 100644 --- a/torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp +++ b/torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp @@ -32,7 +32,7 @@ void for_each_tensor_in_ivalue( for_each_tensor_in_ivalue(it.value(), func); } } else { - AT_ERROR("Unhandled type of IValue. Got ", iv.tagKind()); + TORCH_CHECK(false, "Unhandled type of IValue. Got ", iv.tagKind()); } } } // namespace torch::jit::mobile diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp index e4836bd55fd68..2f7470f980487 100644 --- a/torch/csrc/jit/mobile/module.cpp +++ b/torch/csrc/jit/mobile/module.cpp @@ -38,7 +38,7 @@ Method Module::get_method(const std::string& name) const { if (auto method = find_method(name)) { return *method; } - AT_ERROR("Method '", name, "' is not defined."); + TORCH_CHECK(false, "Method '", name, "' is not defined."); } bool Module::compareMethodSchemas( diff --git a/torch/csrc/jit/mobile/nnc/context.h b/torch/csrc/jit/mobile/nnc/context.h index 3976d28ec8944..b9633ea5bfafc 100644 --- a/torch/csrc/jit/mobile/nnc/context.h +++ b/torch/csrc/jit/mobile/nnc/context.h @@ -22,10 +22,10 @@ struct TORCH_API InputSpec { explicit InputSpec(const c10::IValue& value); // Serialize the spec into an IValue. - C10_NODISCARD c10::IValue serialize() const; + [[nodiscard]] c10::IValue serialize() const; // Check whether the input tensor adheres to the spec. - C10_NODISCARD bool validate(const at::Tensor& input) const; + [[nodiscard]] bool validate(const at::Tensor& input) const; std::vector sizes_; c10::ScalarType dtype_{c10::ScalarType::Undefined}; @@ -40,10 +40,10 @@ struct TORCH_API OutputSpec { explicit OutputSpec(const c10::IValue& value); // Serialize the spec into an IValue. - C10_NODISCARD c10::IValue serialize() const; + [[nodiscard]] c10::IValue serialize() const; // Allocate an output tensor in accordance with the spec. - C10_NODISCARD at::Tensor allocate() const; + [[nodiscard]] at::Tensor allocate() const; std::vector sizes_; c10::ScalarType dtype_{c10::ScalarType::Undefined}; @@ -84,7 +84,7 @@ struct TORCH_API MemoryPlan { explicit MemoryPlan(const c10::IValue& value); - C10_NODISCARD c10::IValue serialize() const; + [[nodiscard]] c10::IValue serialize() const; void allocate(ExecutionState* state) const; @@ -207,10 +207,10 @@ class TORCH_API CompilationUnit { // Serialize all registered functions into an IValue. The IValue will be save // into the compiled TorchScript model file ahead-of-time on the host, and // will be deserialized at runtime on the target device. - C10_NODISCARD c10::IValue serialize() const; + [[nodiscard]] c10::IValue serialize() const; // Execute a registered function. - C10_NODISCARD c10::impl::GenericList run( + [[nodiscard]] c10::impl::GenericList run( const c10::QualifiedName& function_name, const c10::impl::GenericList& inputs) const; @@ -218,7 +218,7 @@ class TORCH_API CompilationUnit { void register_function(std::unique_ptr fn); private: - C10_NODISCARD Function* find_function(const c10::QualifiedName& qn) const; + [[nodiscard]] Function* find_function(const c10::QualifiedName& qn) const; std::unordered_map> functions_; }; diff --git a/torch/csrc/jit/mobile/promoted_prim_ops.cpp b/torch/csrc/jit/mobile/promoted_prim_ops.cpp index 857cb30429102..b3d961a5a85de 100644 --- a/torch/csrc/jit/mobile/promoted_prim_ops.cpp +++ b/torch/csrc/jit/mobile/promoted_prim_ops.cpp @@ -190,8 +190,7 @@ void toList(Stack& stack) { "Output annotation list dimension and runtime tensor dimension must match for tolist()"); // Wrap out_ty in a ListType dim times. - for (const auto i : c10::irange(dim_val)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(dim_val)) { out_ty = at::ListType::create(out_ty); } @@ -228,33 +227,36 @@ void dictIndex(Stack& stack) { auto dict = pop(stack).toGenericDict(); auto value = dict.find(key); if (value == dict.end()) { - AT_ERROR("KeyError: ", key); + TORCH_CHECK(false, "KeyError: ", key); } push(stack, value->value()); } -static const C10_UNUSED std::array op_reg = { - mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex), - mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor), - mobile::prim_op_fn_register("aten::format", aten_format), - mobile::prim_op_fn_register("prim::NumToTensor.Scalar", numToTensorScalar), - mobile::prim_op_fn_register( - "prim::RaiseException", - raiseExceptionWithMessage), - mobile::prim_op_fn_register("prim::device", device), - mobile::prim_op_fn_register("prim::dtype", dtype), - mobile::prim_op_fn_register("prim::layout", layout), - mobile::prim_op_fn_register("aten::__not__", _not), - mobile::prim_op_fn_register("aten::__is__", is), - mobile::prim_op_fn_register("aten::__isnot__", isNot), - mobile::prim_op_fn_register("aten::dim", dim), - mobile::prim_op_fn_register("prim::Uninitialized", unInitialized), - mobile::prim_op_fn_register("prim::is_cuda", isCuda), - mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex), - mobile::prim_op_fn_register("prim::unchecked_cast", noop), - // TODO: (@pavithran) size is overloaded with int[] and Tensor - // so this throws error expecting int not Tensor - // mobile::prim_op_fn_register("aten::size", size) +[[maybe_unused]] static const std::array + op_reg = { + mobile::prim_op_fn_register("prim::TupleIndex", tupleIndex), + mobile::prim_op_fn_register("aten::Bool.Tensor", boolTensor), + mobile::prim_op_fn_register("aten::format", aten_format), + mobile::prim_op_fn_register( + "prim::NumToTensor.Scalar", + numToTensorScalar), + mobile::prim_op_fn_register( + "prim::RaiseException", + raiseExceptionWithMessage), + mobile::prim_op_fn_register("prim::device", device), + mobile::prim_op_fn_register("prim::dtype", dtype), + mobile::prim_op_fn_register("prim::layout", layout), + mobile::prim_op_fn_register("aten::__not__", _not), + mobile::prim_op_fn_register("aten::__is__", is), + mobile::prim_op_fn_register("aten::__isnot__", isNot), + mobile::prim_op_fn_register("aten::dim", dim), + mobile::prim_op_fn_register("prim::Uninitialized", unInitialized), + mobile::prim_op_fn_register("prim::is_cuda", isCuda), + mobile::prim_op_fn_register("aten::__getitem__.Dict_str", dictIndex), + mobile::prim_op_fn_register("prim::unchecked_cast", noop), + // TODO: (@pavithran) size is overloaded with int[] and Tensor + // so this throws error expecting int not Tensor + // mobile::prim_op_fn_register("aten::size", size) }; } // namespace torch::jit diff --git a/torch/csrc/jit/mobile/register_ops_common_utils.h b/torch/csrc/jit/mobile/register_ops_common_utils.h index 344b4dd25b858..4bc04054c5075 100644 --- a/torch/csrc/jit/mobile/register_ops_common_utils.h +++ b/torch/csrc/jit/mobile/register_ops_common_utils.h @@ -14,14 +14,14 @@ inline void noop(Stack& n) {} int64_t normalizeIndex(int64_t idx, int64_t list_size); // reference function THPVariable_to in python_variable_methods.cpp -static C10_UNUSED at::Tensor to_dispatch( +[[maybe_unused]] static at::Tensor to_dispatch( at::Tensor self, std::optional device, std::optional scalarType, bool non_blocking, bool copy) { if (device && device->is_cuda()) { - at::globalContext().lazyInitCUDA(); + at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); } if (!device && !scalarType && !copy) { return self; diff --git a/torch/csrc/jit/mobile/type_parser.cpp b/torch/csrc/jit/mobile/type_parser.cpp index 8b92c91643e46..5f15110d9face 100644 --- a/torch/csrc/jit/mobile/type_parser.cpp +++ b/torch/csrc/jit/mobile/type_parser.cpp @@ -336,7 +336,7 @@ void TypeParser::advance() { lex(); } -C10_NODISCARD c10::string_view TypeParser::cur() const { +[[nodiscard]] c10::string_view TypeParser::cur() const { return next_token_; } diff --git a/torch/csrc/jit/mobile/type_parser.h b/torch/csrc/jit/mobile/type_parser.h index 420e43a5c406e..e2cb66e3fe879 100644 --- a/torch/csrc/jit/mobile/type_parser.h +++ b/torch/csrc/jit/mobile/type_parser.h @@ -33,7 +33,7 @@ class TORCH_API TypeParser { std::string next(); c10::string_view nextView(); void advance(); - C10_NODISCARD c10::string_view cur() const; + [[nodiscard]] c10::string_view cur() const; std::string pythonStr_; size_t start_; diff --git a/torch/csrc/jit/passes/autocast.cpp b/torch/csrc/jit/passes/autocast.cpp index 766c084302645..1d5cb636e4541 100644 --- a/torch/csrc/jit/passes/autocast.cpp +++ b/torch/csrc/jit/passes/autocast.cpp @@ -131,7 +131,7 @@ std::optional parseAutocast( // // TODO: better error message // - AT_ERROR("Unsupported autocast syntax"); + TORCH_CHECK(false, "Unsupported autocast syntax"); } return std::nullopt; @@ -330,7 +330,7 @@ void handleBlock(Block* block, AutocastContext initial_state) { parseAutocast(node->input(), current_state())) { if (node->hasUses()) { // TODO: better error message - AT_ERROR("`with autocast() as ...` is not supported"); + TORCH_CHECK(false, "`with autocast() as ...` is not supported"); } TORCH_INTERNAL_ASSERT( !incompatible_amp.has_value() || !incompatible_amp.value(), @@ -492,7 +492,7 @@ void handleBlock(Block* block, AutocastContext initial_state) { // Banned in autocast, see binary_cross_entropy_banned() case aten::binary_cross_entropy: if (current_state()) { - AT_ERROR("Unsafe to autocast"); + TORCH_CHECK(false, "Unsafe to autocast"); } } diff --git a/torch/csrc/jit/passes/bailout_graph.cpp b/torch/csrc/jit/passes/bailout_graph.cpp index 490fc366ad419..7f8d7eedbe6bf 100644 --- a/torch/csrc/jit/passes/bailout_graph.cpp +++ b/torch/csrc/jit/passes/bailout_graph.cpp @@ -102,7 +102,7 @@ struct BailOutGraphBuilderForNode { } else if (outer_node->kind() == prim::If) { buildBailOutIf(b->outputs(), outer_node); } else { - AT_ERROR("Unexpected outer node"); + TORCH_CHECK(false, "Unexpected outer node"); } } } diff --git a/torch/csrc/jit/passes/frozen_conv_folding.cpp b/torch/csrc/jit/passes/frozen_conv_folding.cpp index 2e43bd2354d84..6bc75bfcc8cf6 100644 --- a/torch/csrc/jit/passes/frozen_conv_folding.cpp +++ b/torch/csrc/jit/passes/frozen_conv_folding.cpp @@ -327,8 +327,8 @@ bool FoldFrozenConvMulOrDiv(Block* b) { // channels-out resize it to the shape that will broadcast to // weight_tensor when the op is run so we dont change weight size std::vector weight_compatible_size = {out_channels}; - for (const auto i : c10::irange(1, weight_tensor.ndimension())) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : + c10::irange(1, weight_tensor.ndimension())) { weight_compatible_size.push_back(1); } diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index a804abe8013a5..8dfa836f87bd8 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -829,8 +829,7 @@ struct GraphFuser { } bchunk->removeInput(producer_index); - for (const auto i : c10::irange(nchunks)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(nchunks)) { bchunk->eraseOutput(nchunks * producer_index); } diff --git a/torch/csrc/jit/passes/loop_unrolling.cpp b/torch/csrc/jit/passes/loop_unrolling.cpp index ebc4894a2ecbe..05a4ffb424e01 100644 --- a/torch/csrc/jit/passes/loop_unrolling.cpp +++ b/torch/csrc/jit/passes/loop_unrolling.cpp @@ -128,8 +128,7 @@ void repeatBody(Block* body, size_t times, Block* dest) { std::vector io = dest->inputs().vec(); TORCH_INTERNAL_ASSERT( !body->inputs().at(0)->hasUses(), "loop counter should be unused"); - for (const auto i : c10::irange(times)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(times)) { io[0] = body->inputs().at(0); io = insertBlockCopy(*graph, body, io); } diff --git a/torch/csrc/jit/passes/lower_tuples.cpp b/torch/csrc/jit/passes/lower_tuples.cpp index 94610679e98e9..ff8c1642f6281 100644 --- a/torch/csrc/jit/passes/lower_tuples.cpp +++ b/torch/csrc/jit/passes/lower_tuples.cpp @@ -107,7 +107,8 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) { auto construct_node = n->inputs().at(0)->node(); if (construct_node->kind() != prim::TupleConstruct) { if (must_remove_tuples) { - AT_ERROR(n->kind().toQualString(), " not matched to tuple construct"); + TORCH_CHECK( + false, n->kind().toQualString(), " not matched to tuple construct"); } return; } @@ -120,7 +121,8 @@ void removeTupleNodes(Node* n, bool must_remove_tuples) { auto maybe_int = constant_as(idx); if (!maybe_int) { if (must_remove_tuples) { - AT_ERROR(n->sourceRange(), "tuple index with non-constant index"); + TORCH_CHECK( + false, n->sourceRange(), "tuple index with non-constant index"); } return; } diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 7a4f95ec69763..1f67cb4f970f6 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -324,23 +324,19 @@ void unpackQuantizedWeightsHelper( const int64_t kSpatialDim = config_vals.at(0); // skip kSpatialDim unsigned idx = 1; - for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { stride_int.emplace_back(config_vals.at(idx)); idx++; } - for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { padding_int.emplace_back(config_vals.at(idx)); idx++; } - for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { dilation_int.emplace_back(config_vals.at(idx)); idx++; } - for (const auto i : c10::irange(kSpatialDim)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(kSpatialDim)) { output_padding_int.emplace_back(config_vals.at(idx)); idx++; } diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h index aeba208ea98e9..80cf46d7e021e 100644 --- a/torch/csrc/jit/passes/quantization/quantization_patterns.h +++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h @@ -75,8 +75,7 @@ std::string getQuantizeForScalar(const std::string& value) { )" + value + "_tensor : Tensor = aten::scalar_tensor(" + value + ", " + value + "_float_scalar_type"; - for (const auto i : c10::irange(3)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(3)) { quantize_pattern += ", " + value + "_none"; } quantize_pattern += ")"; diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index 847541c30d531..7c2f2a33a79e0 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -532,7 +532,7 @@ IValue toIValue(py::handle obj, const TypePtr& type, std::optional N) { #ifdef USE_RPC return obj.cast().toIValue(); #else - AT_ERROR("RRef is only supported with the distributed package"); + TORCH_CHECK(false, "RRef is only supported with the distributed package"); #endif } break; case TypeKind::PyObjectType: { @@ -697,7 +697,7 @@ py::object toPyObject(IValue ivalue) { std::move(ivalue).toRRef()); return py::cast(torch::distributed::rpc::PyRRef(RRefPtr)); #else - AT_ERROR("RRef is only supported with the distributed package"); + TORCH_CHECK(false, "RRef is only supported with the distributed package"); #endif } else if (ivalue.isObject()) { const auto obj = std::move(ivalue).toObject(); @@ -751,7 +751,8 @@ py::object toPyObject(IValue ivalue) { } else if (ivalue.isSymBool()) { return py::cast(std::move(ivalue).toSymBool()); } else { - AT_ERROR( + TORCH_CHECK( + false, "Missing cases in 'toPyObject'! Can't convert ", ivalue.tagKind(), " to a Python object"); diff --git a/torch/csrc/jit/python/pybind_utils.h b/torch/csrc/jit/python/pybind_utils.h index eee1cf05b1201..28e9621393750 100644 --- a/torch/csrc/jit/python/pybind_utils.h +++ b/torch/csrc/jit/python/pybind_utils.h @@ -342,7 +342,8 @@ inline TypedIValue toDictKeyIValue(py::handle key) { } else if (py::isinstance(key)) { return TypedIValue(py::cast(key), FloatType::get()); } else { - AT_ERROR("Dictionary inputs may only have string, int, or float keys"); + TORCH_CHECK( + false, "Dictionary inputs may only have string, int, or float keys"); } } @@ -687,8 +688,12 @@ inline IValue toTypeInferredIValue(py::handle input) { return c10::intrusive_ptr::reclaim_copy( ptr.release()); } - AT_ERROR( - "Tracer cannot infer type of ", py::str(input), "\n:", match.reason()); + TORCH_CHECK( + false, + "Tracer cannot infer type of ", + py::str(input), + "\n:", + match.reason()); } return toIValue(input, match.type()); } @@ -1086,9 +1091,10 @@ inline Stack evilDeprecatedBadCreateStackDoNotUse( at::ArrayRef inputs, size_t reserve_extra_space = 0) { if (tuple.size() != inputs.size()) { - AT_ERROR( + TORCH_CHECK( + false, "expected " + std::to_string(inputs.size()) + " inputs, but got " + - std::to_string(tuple.size())); + std::to_string(tuple.size())); } Stack result; result.reserve(tuple.size() + reserve_extra_space); diff --git a/torch/csrc/jit/python/python_list.cpp b/torch/csrc/jit/python/python_list.cpp index 2193f806bf3c6..e3e16c7d65cdb 100644 --- a/torch/csrc/jit/python/python_list.cpp +++ b/torch/csrc/jit/python/python_list.cpp @@ -134,7 +134,8 @@ void initScriptListBindings(PyObject* module) { auto seq = std::make_shared(self->type()); - for (const auto i [[maybe_unused]] : c10::irange(slicelength)) { + for ([[maybe_unused]] const auto i [[maybe_unused]] : + c10::irange(slicelength)) { seq->append(self->getItem(static_cast(start))); start += step; } diff --git a/torch/csrc/jit/python/python_tracer.cpp b/torch/csrc/jit/python/python_tracer.cpp index 65c8fd9079eb1..8761867434f17 100644 --- a/torch/csrc/jit/python/python_tracer.cpp +++ b/torch/csrc/jit/python/python_tracer.cpp @@ -115,7 +115,8 @@ std::pair, Stack> createGraphByTracingWithDict( // method. auto out = func(**inputs_dict); if (out.ptr() == Py_None) { - AT_ERROR( + TORCH_CHECK( + false, "The traced function didn't return any values! Side-effects are not " "captured in traces, so it would be a no-op."); } @@ -155,7 +156,8 @@ std::pair, Stack> createGraphByTracing( } auto out = func(*py_inputs); if (out.ptr() == Py_None) { - AT_ERROR( + TORCH_CHECK( + false, "The traced function didn't return any values! Side-effects are not " "captured in traces, so it would be a no-op."); } diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp index 2eb9a6f021770..690859f0a0a2a 100644 --- a/torch/csrc/jit/python/script_init.cpp +++ b/torch/csrc/jit/python/script_init.cpp @@ -866,6 +866,53 @@ void initJitScriptBindings(PyObject* module) { // Similar to Tensor's `__hash__`, which is `id()`. return std::hash{}(self._ivalue().get()); }) + .def( + "__deepcopy__", + [](const Object& self, const py::dict& memo) { + if (auto getstate_method = self.find_method("__getstate__")) { + auto object_state = toPyObject((*getstate_method)(Stack{})); + + if (auto qualname = self.type()->name()) { + auto class_type = getCustomClass(qualname->qualifiedName()); + auto self = Object(c10::ivalue::Object::create( + c10::StrongTypePtr( + std::shared_ptr(), + class_type), + 1)); + + if (auto setstate_method = + self.find_method("__setstate__")) { + auto setstate_schema = + setstate_method->function().getSchema(); + TORCH_INTERNAL_ASSERT( + setstate_schema.arguments().size() == 2, + "__setstate__ method for class ", + class_type->repr_str(), + " must have exactly 2 arguments!"); + auto state_type = + setstate_schema.arguments().at(1).type(); + (*setstate_method)( + Stack{toIValue(object_state, state_type)}); + return self; + } + std::stringstream err; + err << "Tried to deepcopy object "; + if (auto qualname = class_type->name()) { + err << qualname->qualifiedName() << " "; + } + err << "which does not have a __setstate__ method defined!"; + throw std::runtime_error(err.str()); + } + } + + std::stringstream err; + err << "Tried to deepcopy object "; + if (auto qualname = self.type()->name()) { + err << qualname->qualifiedName() << " "; + } + err << "which does not have a __getstate__ method defined!"; + throw std::runtime_error(err.str()); + }) .def(py::pickle( [](const Object& self) -> std::tuple { // __getstate__ diff --git a/torch/csrc/jit/runtime/decomposition_registry.cpp b/torch/csrc/jit/runtime/decomposition_registry.cpp index 98a654ccaab0c..62e20022f60d1 100644 --- a/torch/csrc/jit/runtime/decomposition_registry.cpp +++ b/torch/csrc/jit/runtime/decomposition_registry.cpp @@ -101,7 +101,7 @@ static void RunDecompositions(Block* block) { void RunDecompositions(std::shared_ptr g) { RunDecompositions(g->block()); - for (C10_UNUSED const auto _ : c10::irange(2)) { + for ([[maybe_unused]] const auto _ : c10::irange(2)) { PeepholeOptimize(g, /*disable_shape_peephole*/ true); ConstantPropagation(g); } diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp index 4ef3b404aab96..d7aa2b6bf63fb 100644 --- a/torch/csrc/jit/runtime/interpreter.cpp +++ b/torch/csrc/jit/runtime/interpreter.cpp @@ -115,6 +115,10 @@ struct TLSCurrentInterpreterGuard { InterpreterStateImpl* prev_state_; }; +bool in_torchscript_runtime() { + return tls_int_state_ptr_ != nullptr; +} + // InterpreterState state that and used to compute a Code struct InterpreterStateImpl : c10::intrusive_ptr_target { InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher) @@ -240,7 +244,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { std::size_t initialSize_{stack_.size()}; }; - struct C10_UNUSED DoNothing {}; + struct [[maybe_unused]] DoNothing {}; #if defined(__GNUC__) || defined(__clang__) #define JIT_USE_COMPUTED_GOTO @@ -321,14 +325,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { switch (inst.op) { case INST(ENTER): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); const auto& obj = peek(stack, 0, 1); TORCH_INTERNAL_ASSERT(obj.isObject()); entered_objects.push_back(obj); } INST_NEXT; case INST(EXIT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); auto obj = entered_objects.back().toObject(); auto& f = obj->type()->getMethod("__exit__"); push(stack, std::move(obj)); @@ -340,14 +344,14 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { continue; } case INST(OP): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); auto stackSizeGuard = stackSizeAssertGuard(); frame.function->operator_table_[inst.X](stack); stackSizeGuard.callAssert(); } INST_NEXT; case INST(OPN): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); stack.emplace_back(inst.N); auto stackSizeGuard = stackSizeAssertGuard(); frame.function->operator_table_[inst.X](stack); @@ -355,22 +359,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(LOAD): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); stack.emplace_back(reg(inst.X)); } INST_NEXT; case INST(MOVE): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); stack.emplace_back(std::move(reg(inst.X))); } INST_NEXT; case INST(STORE): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); reg(inst.X) = pop(stack); } INST_NEXT; case INST(STOREN): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); TORCH_INTERNAL_ASSERT(stack.size() >= inst.N); for (size_t i = inst.N; i > 0; --i) { reg(inst.X + i - 1) = pop(stack); @@ -378,28 +382,28 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(DROP): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); stack.pop_back(); } INST_NEXT; case INST(DROPR): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); reg(inst.X) = IValue(); } INST_NEXT; case INST(LOADC): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); stack.emplace_back(frame.function->constant_table_[inst.X]); } INST_NEXT; case INST(GET_ATTR): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); const auto& userObj = stack.back().toObjectRef(); stack.back() = userObj.getSlot(inst.X); } INST_NEXT; case INST(SET_ATTR): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); auto v = pop(stack); auto& userObj = stack.back().toObjectRef(); userObj.setSlot(inst.X, std::move(v)); @@ -407,7 +411,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(JF): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); if (pop(stack).toBool()) { inst = instFetch(1); } else { @@ -416,12 +420,12 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_DISPATCH; case INST(JMP): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); inst = instFetch(inst.X); } INST_DISPATCH; case INST(LOOP): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); // stack: iteration_count, max_iter, cond, loop_carried_deps... auto fr = stack.end() - (inst.N + 1); int64_t trip_count = fr[0].toInt(); @@ -442,13 +446,13 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_DISPATCH; case INST(CALL): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); Function* fn = frame.function->function_table_[inst.X]; callFunction(*fn, stack); continue; } case INST(INTERFACE_CALL): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); // note the hash table lookup to find the function // this can be more optimized if necessary, caching parts // of the hashing computation or storing the offset when @@ -489,7 +493,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { return false; } case INST(WAIT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); auto future = stack.back().toFuture(); if (!future->completed()) { getOrCreateFuture(); @@ -547,7 +551,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(PROFILE_OP): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); auto& frame_id_ref = frame.id; if (!frame_id_ref.has_value()) { frame_id_ref = Frame::genId(); @@ -559,7 +563,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(FAIL_GUARD): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); // patch FAIL_GUARD back to GUARD GRAPH_DEBUG( "Bailout ", inst.X, " triggered via bailout_requests_!"); @@ -568,7 +572,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(TYPECHECK): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); unsigned num_inputs = inst.N, i = 0; TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0); // Check every input's shape against profiled (expected) shape. @@ -588,7 +592,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(GUARD): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); if (!stack.back().isTensor()) { // stack.back() is an Uninitialized IValue and this is a guard // on a block output. Uninitialized IValues are never used @@ -609,7 +613,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(TAIL_CALL): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); GRAPH_DEBUG("running TAIL_CALL for ", inst.X); frame.function->function_table_[inst.X]->ensure_defined(); size_t remaining_bailout_depth = @@ -632,22 +636,22 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { continue; } case INST(LIST_UNPACK): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); listUnpack(stack, inst.X); } INST_NEXT; case INST(TUPLE_CONSTRUCT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); tupleConstruct(stack, inst.X); } INST_NEXT; case INST(TUPLE_SLICE): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); tupleSlice(stack, inst.X, inst.X + inst.N); } INST_NEXT; case INST(NAMED_TUPLE_CONSTRUCT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); namedTupleConstruct( stack, frame.function->type_table_[inst.X]->expect(), @@ -655,28 +659,28 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(LIST_CONSTRUCT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); const auto& type = frame.function->type_table_[inst.X]->expectRef(); listConstruct(stack, type, inst.N); } INST_NEXT; case INST(DICT_CONSTRUCT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); const auto& type = frame.function->type_table_[inst.X]->expectRef(); dictConstruct(stack, type, inst.N); } INST_NEXT; case INST(CREATE_OBJECT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); auto type = frame.function->type_table_[inst.X]->expect(); createObject(stack, type); } INST_NEXT; case INST(ISINSTANCE): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); at::ArrayRef types( &frame.function->type_table_[inst.X], &frame.function->type_table_[inst.X] + inst.N); @@ -684,84 +688,84 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(TUPLE_INDEX): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); tupleIndex(stack); } INST_NEXT; case INST(RAISE_EXCEPTION): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); raiseExceptionWithMessage(stack); } INST_NEXT; case INST(UNCHECKED_CAST): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); noop(stack); } INST_NEXT; case INST(__IS__): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); is(stack); } INST_NEXT; case INST(UN_INITIALIZED): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); unInitialized(stack); } INST_NEXT; case INST(__ISNOT__): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); isNot(stack); } INST_NEXT; case INST(FORMAT): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); format(stack, inst.X); } INST_NEXT; case INST(DEVICE): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); device(stack); } INST_NEXT; case INST(DTYPE): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); TORCH_INTERNAL_ASSERT(!stack.empty()); dtype(stack); } INST_NEXT; case INST(DIM): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); TORCH_INTERNAL_ASSERT(!stack.empty()); dim(stack); } INST_NEXT; case INST(__NOT__): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); _not(stack); } INST_NEXT; case INST(DICT_INDEX): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); dictIndex(stack); } INST_NEXT; case INST(TO_LIST): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); toList(stack); } INST_NEXT; case INST(NUM_TO_TENSOR): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); numToTensorScalar(stack); } INST_NEXT; case INST(IS_CUDA): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); isCuda(stack); } INST_NEXT; case INST(FORK): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); // Move inputs to a separate stack auto& forked_fn = toGraphFunction(*frame.function->function_table_[inst.X]); @@ -777,7 +781,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(AWAITABLE): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); auto fn_ptr = frame.function->function_table_[inst.X]; auto& fn = toGraphFunction(*fn_ptr); auto num_outputs = fn.graph()->outputs().size(); @@ -817,7 +821,7 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target { } INST_NEXT; case INST(WARN): { - auto _ = instGuard(); + [[maybe_unused]] auto _ = instGuard(); // Keeps track of which WARN instruction has been executed before, // we only want to execute each WARN once to match default Python // warning behavior. diff --git a/torch/csrc/jit/runtime/interpreter.h b/torch/csrc/jit/runtime/interpreter.h index ffafd3ab096a9..54e3fc8b86e8e 100644 --- a/torch/csrc/jit/runtime/interpreter.h +++ b/torch/csrc/jit/runtime/interpreter.h @@ -40,6 +40,8 @@ using Stack = std::vector; using c10::ivalue::Future; using TaskLauncher = std::function)>; +bool TORCH_API in_torchscript_runtime(); + struct TORCH_API Code { Code() = default; explicit Code(interpreter::CodeImpl* pImpl); diff --git a/torch/csrc/jit/runtime/operator.cpp b/torch/csrc/jit/runtime/operator.cpp index 1935d8ccf7402..35dead2a395c9 100644 --- a/torch/csrc/jit/runtime/operator.cpp +++ b/torch/csrc/jit/runtime/operator.cpp @@ -355,7 +355,8 @@ void registerOperator(Operator&& op) { if (op.schema().is_varret()) { Symbol s = Symbol::fromQualString(op.schema().name()); if (!printerHasSpecialCaseFor(s)) { - AT_ERROR( + TORCH_CHECK( + false, "Missing special case in python printer for non-schematized" " operator ", op.schema().name(), @@ -363,7 +364,8 @@ void registerOperator(Operator&& op) { } if (aliasAnalysisHasSpecialCaseFor(s) && op.aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE) { - AT_ERROR( + TORCH_CHECK( + false, "Conflict in special casing in alias analysis for non-schematized" " operator ", op.schema().name(), @@ -371,7 +373,8 @@ void registerOperator(Operator&& op) { } if (aliasAnalysisHasSpecialCaseFor(s) && op.aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA) { - AT_ERROR( + TORCH_CHECK( + false, "The operator ", op.schema().name(), " is special cased and cannot use explicit alias analysis."); diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp index 54ec8e8441fa7..4b5518a525b62 100644 --- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp +++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp @@ -198,7 +198,7 @@ static bool needsGradientInProfilingMode(Block* b) { // differentiable graph. Autodiff will inspect these properties and prune // off gradients that aren't required // `requires_grad` properties from `dnode->outputs()` will also be transferred -static C10_UNUSED void setRequiresGradOnDiffGraph(Node* dnode) { +[[maybe_unused]] static void setRequiresGradOnDiffGraph(Node* dnode) { auto gi = dnode->g(attr::Subgraph)->inputs(); for (size_t i = 0; i < dnode->inputs().size(); i++) { if (auto ty = dnode->input(i)->type()->cast()) { diff --git a/torch/csrc/jit/runtime/register_c10_ops.cpp b/torch/csrc/jit/runtime/register_c10_ops.cpp index ff6162d46e0c8..85e8c0a2b037c 100644 --- a/torch/csrc/jit/runtime/register_c10_ops.cpp +++ b/torch/csrc/jit/runtime/register_c10_ops.cpp @@ -52,7 +52,7 @@ Registerer& registerer() { } // global instance to run its constructor on startup -C10_UNUSED Registerer& dummy = registerer(); +[[maybe_unused]] Registerer& dummy = registerer(); } // namespace diff --git a/torch/csrc/jit/runtime/register_ops_utils.cpp b/torch/csrc/jit/runtime/register_ops_utils.cpp index abbdf44ec6051..32bbe97104996 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.cpp +++ b/torch/csrc/jit/runtime/register_ops_utils.cpp @@ -34,7 +34,7 @@ void listIndex(Stack& stack) { if (pos != list.end()) { push(stack, static_cast(std::distance(list.begin(), pos))); } else { - AT_ERROR("'", elem, "' is not in list"); + TORCH_CHECK(false, "'", elem, "' is not in list"); } } @@ -107,7 +107,7 @@ void listRemove(Stack& stack) { if (pos != list.end()) { list.erase(pos); } else { - AT_ERROR("list.remove(x): x not in list"); + TORCH_CHECK(false, "list.remove(x): x not in list"); } } @@ -205,7 +205,7 @@ void listPopImpl(Stack& stack, const char* empty_message) { const int64_t normalized_idx = normalizeIndex(idx, list_size); if (list_size == 0) { - AT_ERROR(empty_message); + TORCH_CHECK(false, empty_message); } push(stack, getItem(list, idx)); @@ -311,8 +311,7 @@ void listMulIntLeftInPlace(Stack& stack) { list.clear(); } else if (n > 1) { size_t list_size = list.size(); - for (const auto i : c10::irange(1, n)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(1, n)) { for (const auto j : c10::irange(list_size)) { list.push_back(list.get(j)); } @@ -330,8 +329,7 @@ void listMulIntLeft(Stack& stack) { const auto size = list.size() * n; ret.reserve(size); - for (const auto i : c10::irange(n)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(n)) { for (IValue e : list) { ret.push_back(std::move(e)); } @@ -348,8 +346,7 @@ void listMulIntRight(Stack& stack) { const auto size = list.size() * n; ret.reserve(size); - for (const auto i : c10::irange(n)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(n)) { for (IValue e : list) { ret.push_back(std::move(e)); } @@ -382,8 +379,7 @@ void listSlice(Stack& stack) { sliced_list.reserve(num_values); int i = start; - for (const auto j : c10::irange(num_values)) { - (void)j; // Suppress unused variable warning + for ([[maybe_unused]] const auto j : c10::irange(num_values)) { sliced_list.push_back(list.get(i)); i += step; } @@ -429,7 +425,8 @@ at::Generator make_generator_for_device( } #endif } else { - AT_ERROR( + TORCH_CHECK( + false, "Unsupported device for at::make_generator_for_device found: ", device.str()); } diff --git a/torch/csrc/jit/runtime/register_ops_utils.h b/torch/csrc/jit/runtime/register_ops_utils.h index ebdc5ba205cd5..340b597280a6e 100644 --- a/torch/csrc/jit/runtime/register_ops_utils.h +++ b/torch/csrc/jit/runtime/register_ops_utils.h @@ -71,7 +71,7 @@ inline double round_to_even(double a) { // and if the dest is an int the source must be integral type void checkImplicitTensorToNum(const at::Tensor& t, bool toInt); -static C10_UNUSED int64_t floordiv(int64_t a, int64_t b) { +[[maybe_unused]] static int64_t floordiv(int64_t a, int64_t b) { if (b == 0) { throw std::runtime_error("division by 0"); } @@ -85,16 +85,16 @@ static C10_UNUSED int64_t floordiv(int64_t a, int64_t b) { } } TORCH_API void checkDoubleInRange(double a); -static C10_UNUSED int64_t floor(double a) { +[[maybe_unused]] static int64_t floor(double a) { checkDoubleInRange(a); return std::floor(a); } -static C10_UNUSED int64_t ceil(double a) { +[[maybe_unused]] static int64_t ceil(double a) { checkDoubleInRange(a); return std::ceil(a); } -static C10_UNUSED int64_t gcd(int64_t a, int64_t b) { +[[maybe_unused]] static int64_t gcd(int64_t a, int64_t b) { while (b != 0) { int64_t r = a % b; a = b; @@ -200,7 +200,7 @@ void listRemove(Stack& stack) { if (pos != list.end()) { list.erase(pos); } else { - AT_ERROR("list.remove(x): x not in list"); + TORCH_CHECK(false, "list.remove(x): x not in list"); } } @@ -251,7 +251,7 @@ void listIndex(Stack& stack) { if (pos != list.end()) { push(stack, static_cast(std::distance(list.begin(), pos))); } else { - AT_ERROR("'", elem, "' is not in list"); + TORCH_CHECK(false, "'", elem, "' is not in list"); } } diff --git a/torch/csrc/jit/runtime/register_prim_ops.cpp b/torch/csrc/jit/runtime/register_prim_ops.cpp index d20c5d6a0fec5..085ad913a6c6f 100644 --- a/torch/csrc/jit/runtime/register_prim_ops.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops.cpp @@ -38,8 +38,7 @@ std::string stringSlice( int64_t i = start_val; std::string result = ""; - for (const auto j : c10::irange(num_vals)) { - (void)j; // Suppress unused variable warning + for ([[maybe_unused]] const auto j : c10::irange(num_vals)) { result += string[i]; i += step; } @@ -1042,7 +1041,7 @@ static const std::vector opGenArgs{ [](Stack& stack) { at::Tensor t = pop(stack).toTensor(); if (t.dim() == 0) { - AT_ERROR("len() of a 0-d tensor"); + TORCH_CHECK(false, "len() of a 0-d tensor"); } push(stack, t.sizes()[0]); }, @@ -1489,7 +1488,7 @@ void dictPop(Stack& stack) { if (has_default) { push(stack, default_value); } else { - AT_ERROR("KeyError: ", key); + TORCH_CHECK(false, "KeyError: ", key); } } else { // note: before erase @@ -1509,7 +1508,7 @@ void dictDelete(Stack& stack) { void dictPopItem(Stack& stack) { auto dict = pop(stack).toGenericDict(); if (dict.empty()) { - AT_ERROR("popitem(): dictionary is empty"); + TORCH_CHECK(false, "popitem(): dictionary is empty"); } auto head_item = dict.begin(); diff --git a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp index 035a5d35c4630..b09cc45ce33f7 100644 --- a/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp +++ b/torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp @@ -35,7 +35,8 @@ RegisterOperators reg({ prim::profile, [](const Node* node) -> Operation { return [](Stack& stack) { - AT_ERROR( + TORCH_CHECK( + false, "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT }; }, @@ -44,7 +45,8 @@ RegisterOperators reg({ prim::profile_ivalue, [](const Node* node) -> Operation { return [](Stack& stack) { - AT_ERROR( + TORCH_CHECK( + false, "Must be lowered to Interpreter's PROFILE instruction"); // NOLINT }; }, @@ -188,7 +190,7 @@ RegisterOperators reg({ prim::TypeCheck /* (...) -> (..., bool) */, [](const Node* /* node */) -> Operation { return [](Stack& /* stack */) { - AT_ERROR("prim::TypeCheck not yet implemented"); // NOLINT + TORCH_CHECK(false, "prim::TypeCheck not yet implemented"); // NOLINT }; }, aliasAnalysisSpecialCase()), @@ -196,19 +198,22 @@ RegisterOperators reg({ prim::FallbackGraph, [](const Node* node) -> Operation { return [](Stack& stack) { - AT_ERROR( + TORCH_CHECK( + false, "Must be converted to prim::FunctionCall by replaceFallbackGraphWithFallbackFunction"); // NOLINT }; }, aliasAnalysisSpecialCase()), Operator( "prim::Guard(Tensor(a) t) -> Tensor(a)", - [](Stack& stack) { AT_ERROR("Should be replaced by prim::BailOut"); }, + [](Stack& stack) { + TORCH_CHECK(false, "Should be replaced by prim::BailOut"); + }, aliasAnalysisFromSchema()), Operator( "prim::BailOut(...) -> Tensor(a)", [](Stack& /* stack */) { - AT_ERROR("prim::BailOut not yet implemented"); // NOLINT + TORCH_CHECK(false, "prim::BailOut not yet implemented"); // NOLINT }, aliasAnalysisFromSchema()), Operator( @@ -379,7 +384,7 @@ RegisterOperators logging_operators( }, aliasAnalysisFromSchema())}); -C10_UNUSED void hashValue(Stack& stack) { +[[maybe_unused]] void hashValue(Stack& stack) { auto value = pop(stack); push(stack, value.hash()); } @@ -578,7 +583,8 @@ at::Tensor interpolate( scale_factors_2, scale_factors_3); - AT_ERROR( + TORCH_CHECK( + false, "Input Error: Only 3D, 4D and 5D input Tensors supported", " (got ", input_dim, diff --git a/torch/csrc/jit/runtime/register_special_ops.cpp b/torch/csrc/jit/runtime/register_special_ops.cpp index 5fa06b0927451..783aaf87ef7d7 100644 --- a/torch/csrc/jit/runtime/register_special_ops.cpp +++ b/torch/csrc/jit/runtime/register_special_ops.cpp @@ -77,7 +77,8 @@ std::vector compute_sizes(const IValue& seq) { void checkSequenceSize(int64_t n, int64_t dim, int64_t seq_size) { if (seq_size != n) { - AT_ERROR( + TORCH_CHECK( + false, "Expected sequence of length ", n, " at dim ", diff --git a/torch/csrc/jit/runtime/script_profile.cpp b/torch/csrc/jit/runtime/script_profile.cpp index 3ad4716d32b59..a1e1ad6972e4a 100644 --- a/torch/csrc/jit/runtime/script_profile.cpp +++ b/torch/csrc/jit/runtime/script_profile.cpp @@ -102,7 +102,7 @@ auto initBindings() { return nullptr; } -const auto C10_UNUSED torchBindInitializer = initBindings(); +[[maybe_unused]] const auto torchBindInitializer = initBindings(); } // namespace diff --git a/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h b/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h index 1f81c67368c4d..81d4b06d15624 100644 --- a/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h +++ b/torch/csrc/jit/runtime/static/ProcessedNodeInputs.h @@ -43,7 +43,7 @@ class ProcessedNodeInputs { } } - C10_NODISCARD uint16_t size() const { + [[nodiscard]] uint16_t size() const { if (C10_LIKELY(repr_.is_inline())) { return repr_.inline_repr_.size; } else { @@ -51,7 +51,7 @@ class ProcessedNodeInputs { } } - C10_NODISCARD bool empty() const { + [[nodiscard]] bool empty() const { return size() == 0; } @@ -93,11 +93,11 @@ class ProcessedNodeInputs { HeapArrayPtr(HeapArrayPtr&&) noexcept = default; HeapArrayPtr& operator=(HeapArrayPtr&&) noexcept = default; - C10_NODISCARD bool empty() const { + [[nodiscard]] bool empty() const { return size() != 0; } - C10_NODISCARD uint16_t size() const { + [[nodiscard]] uint16_t size() const { return array_ ? array_[0] : 0; } @@ -137,7 +137,7 @@ class ProcessedNodeInputs { // awkward. #pragma pack(push, 2) union Repr { - C10_NODISCARD bool is_inline() const { + [[nodiscard]] bool is_inline() const { uint8_t tag = 0; // Use of reinterpret_cast to pointer to char or unsigned char // is defined behavior; see diff --git a/torch/csrc/jit/runtime/static/impl.cpp b/torch/csrc/jit/runtime/static/impl.cpp index 15f22bee7dfc0..908cbf17ce665 100644 --- a/torch/csrc/jit/runtime/static/impl.cpp +++ b/torch/csrc/jit/runtime/static/impl.cpp @@ -1581,8 +1581,7 @@ float BlockRunner::benchmark_model( const bool is_kwargs_empty = kwargs_list.empty(); const KeywordArgs empty_kwargs; - for (const auto _n_run : c10::irange(warmup_runs)) { - (void)_n_run; // Suppress unused variable warning + for ([[maybe_unused]] const auto _n_run : c10::irange(warmup_runs)) { const auto num_args = static_cast(args_list.size()); for (const auto j : c10::irange(num_args)) { operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]); @@ -1592,8 +1591,7 @@ float BlockRunner::benchmark_model( } } caffe2::Timer timer; - for (const auto _n_run : c10::irange(main_runs)) { - (void)_n_run; // Suppress unused variable warning + for ([[maybe_unused]] const auto _n_run : c10::irange(main_runs)) { const auto num_args = static_cast(args_list.size()); for (const auto j : c10::irange(num_args)) { operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]); @@ -1745,8 +1743,7 @@ BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops( results.first_iter_time = timer.MilliSeconds(); // warmup runs - for (const auto _n_run : c10::irange(warmup_runs)) { - (void)_n_run; // Suppress unused variable warning + for ([[maybe_unused]] const auto _n_run : c10::irange(warmup_runs)) { const auto num_args = static_cast(args_list.size()); for (const auto j : c10::irange(num_args)) { operator()(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]); @@ -1757,8 +1754,7 @@ BlockRunner::IndividualMetrics BlockRunner::benchmark_individual_ops( } // main runs - for (const auto i : c10::irange(main_runs)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(main_runs)) { const auto num_args = static_cast(args_list.size()); for (const auto j : c10::irange(num_args)) { set_inputs(args_list[j], is_kwargs_empty ? empty_kwargs : kwargs_list[j]); diff --git a/torch/csrc/jit/runtime/static/impl.h b/torch/csrc/jit/runtime/static/impl.h index eb8eceb41dc35..7087d39f2e16b 100644 --- a/torch/csrc/jit/runtime/static/impl.h +++ b/torch/csrc/jit/runtime/static/impl.h @@ -456,7 +456,7 @@ class TORCH_API StaticModule { return num_inputs() + num_constants() + num_intermediate_values(); } - C10_NODISCARD const std::vector& output_indices() const { + [[nodiscard]] const std::vector& output_indices() const { return output_indices_; } @@ -488,7 +488,7 @@ class TORCH_API StaticModule { }); } - C10_NODISCARD Node* findNodeWithKindForTesting(const std::string& kind) const; + [[nodiscard]] Node* findNodeWithKindForTesting(const std::string& kind) const; const std::optional& schema() const { return schema_; @@ -644,7 +644,7 @@ class TORCH_API BlockRunner { } // Output is readonly. The writing process happens inside ProcessedNodes - C10_NODISCARD const IValue& Output(uint32_t i) const { + [[nodiscard]] const IValue& Output(uint32_t i) const { DCHECK(i < outputs_.size()); return *outputs_[i]; } @@ -923,7 +923,7 @@ class TORCH_API ProcessedNode { } // Input is readonly - C10_NODISCARD const IValue& Input(uint32_t i) const { + [[nodiscard]] const IValue& Input(uint32_t i) const { return values_[inputs_[i]]; } @@ -933,7 +933,7 @@ class TORCH_API ProcessedNode { return values_[outputs_offset_ + i]; } - C10_NODISCARD const IValue& Output(uint32_t i) const { + [[nodiscard]] const IValue& Output(uint32_t i) const { DCHECK(i < num_outputs()); return values_[outputs_offset_ + i]; } @@ -943,12 +943,12 @@ class TORCH_API ProcessedNode { return static_cast(fn_->num_outputs()); } - C10_NODISCARD c10::ArrayRef outputs() const { + [[nodiscard]] c10::ArrayRef outputs() const { return c10::ArrayRef( values_ + outputs_offset_, num_outputs()); } - C10_NODISCARD uint16_t num_inputs() const { + [[nodiscard]] uint16_t num_inputs() const { return inputs_.size(); } @@ -990,7 +990,7 @@ class TORCH_API ProcessedNode { values_ = values; } - C10_NODISCARD uint16_t output_ivalue_index(uint16_t i) const { + [[nodiscard]] uint16_t output_ivalue_index(uint16_t i) const { DCHECK(i < num_outputs()); return outputs_offset_ + i; } @@ -1019,9 +1019,9 @@ class TORCH_API ProcessedNode { } private: - C10_NODISCARD bool verify_outputs_dont_overlap_each_other() const; + [[nodiscard]] bool verify_outputs_dont_overlap_each_other() const; - C10_NODISCARD bool verify_inputs_dont_overlap_outputs(bool force_check) const; + [[nodiscard]] bool verify_inputs_dont_overlap_outputs(bool force_check) const; Node* node_; const ProcessedFunction* fn_; diff --git a/torch/csrc/jit/runtime/static/memory_planner.h b/torch/csrc/jit/runtime/static/memory_planner.h index 8110a83dba968..018b8947a07cf 100644 --- a/torch/csrc/jit/runtime/static/memory_planner.h +++ b/torch/csrc/jit/runtime/static/memory_planner.h @@ -172,15 +172,15 @@ class MemoryPlanner { return managed_output_tensors_.size(); } - C10_NODISCARD size_t total_num_unmanaged() const { + [[nodiscard]] size_t total_num_unmanaged() const { return num_unmanaged_non_scalars() + num_unmanaged_scalars(); } - C10_NODISCARD size_t num_unmanaged_non_scalars() const { + [[nodiscard]] size_t num_unmanaged_non_scalars() const { return unmanaged_ivalues_.size() + unmanaged_borrowed_ivalues_.size(); } - C10_NODISCARD size_t num_unmanaged_scalars() const { + [[nodiscard]] size_t num_unmanaged_scalars() const { return num_unmanaged_scalar_ivalues_; } diff --git a/torch/csrc/jit/runtime/static/passes.cpp b/torch/csrc/jit/runtime/static/passes.cpp index 68fd8a270c026..7fa31fe933679 100644 --- a/torch/csrc/jit/runtime/static/passes.cpp +++ b/torch/csrc/jit/runtime/static/passes.cpp @@ -37,8 +37,8 @@ bool forwardHasOp( } namespace { -C10_UNUSED -void ConcatAddMulReplaceNaNClip(std::shared_ptr& graph) { +[[maybe_unused]] void ConcatAddMulReplaceNaNClip( + std::shared_ptr& graph) { // TODO:: check restrictions for inputs; outputs not used elsewhere std::string pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g, %h, %i, %j): @@ -91,8 +91,8 @@ void ConcatAddMulReplaceNaNClip(std::shared_ptr& graph) { fuse.runOnGraph(graph); } -C10_UNUSED -void CastedBatchOneHotLengths(std::shared_ptr& graph) { +[[maybe_unused]] void CastedBatchOneHotLengths( + std::shared_ptr& graph) { // TODO:: check restrictions for inputs; outputs not used elsewhere std::string pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g): @@ -122,8 +122,8 @@ void CastedBatchOneHotLengths(std::shared_ptr& graph) { fuse.runOnGraph(graph); } -C10_UNUSED -void ConcatBatchMatMulBatchGather(std::shared_ptr& graph) { +[[maybe_unused]] void ConcatBatchMatMulBatchGather( + std::shared_ptr& graph) { std::string pattern = R"IR( graph(%a, %b, %c, %d, %e, %f): %y0 : Tensor = aten::stack(%a, %b) @@ -171,7 +171,7 @@ void ConcatBatchMatMulBatchGather(std::shared_ptr& graph) { fuse.runOnGraph(graph); } -C10_UNUSED void ClipRangesGatherRangesLengthsToOffsets( +[[maybe_unused]] void ClipRangesGatherRangesLengthsToOffsets( std::shared_ptr& graph) { // TODO:: check restrictions for inputs; outputs not used elsewhere std::string pattern = R"IR( @@ -189,7 +189,8 @@ C10_UNUSED void ClipRangesGatherRangesLengthsToOffsets( fuse.runOnGraph(graph); } -C10_UNUSED void ClipRangesGather(std::shared_ptr& graph) { +[[maybe_unused]] void ClipRangesGather( + std::shared_ptr& graph) { // TODO:: check restrictions for inputs; outputs not used elsewhere // fuse without lengths-to-offsets std::string pattern = R"IR( @@ -206,7 +207,7 @@ C10_UNUSED void ClipRangesGather(std::shared_ptr& graph) { fuse.runOnGraph(graph); } -C10_UNUSED void PrecomputeMultiplierShiftForSigridHash( +[[maybe_unused]] void PrecomputeMultiplierShiftForSigridHash( std::shared_ptr& graph) { std::string pattern = R"IR( graph(%a, %b, %c, %d, %e): @@ -224,7 +225,7 @@ C10_UNUSED void PrecomputeMultiplierShiftForSigridHash( fuse.runOnGraph(graph); } -C10_UNUSED void ClipRangesToGatherToOffsets( +[[maybe_unused]] void ClipRangesToGatherToOffsets( std::shared_ptr& graph) { std::string pattern = R"IR( graph(%a, %b, %c, %d, %to0_in0, %to0_in1, %to0_in2): @@ -254,7 +255,8 @@ C10_UNUSED void ClipRangesToGatherToOffsets( fuse.runOnGraph(graph); } -C10_UNUSED void ToLengthsToOffsets(std::shared_ptr& graph) { +[[maybe_unused]] void ToLengthsToOffsets( + std::shared_ptr& graph) { std::string pattern = R"IR( graph(%a, %includelastoffset, %dtype, %nonblocking, %copy, %memoryformat): %y0 : Tensor = aten::to(%a, %dtype, %nonblocking, %copy, %memoryformat) @@ -281,8 +283,8 @@ C10_UNUSED void ToLengthsToOffsets(std::shared_ptr& graph) { fuse.runOnGraph(graph); } -C10_UNUSED -void ClipRangesGatherSigridHash(std::shared_ptr& graph) { +[[maybe_unused]] void ClipRangesGatherSigridHash( + std::shared_ptr& graph) { // TODO:: check restrictions for inputs; outputs not used elsewhere std::string pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g, %h): @@ -298,7 +300,7 @@ void ClipRangesGatherSigridHash(std::shared_ptr& graph) { fuse.runOnGraph(graph); } -C10_UNUSED void ClipRangesGatherRangesSigridHash( +[[maybe_unused]] void ClipRangesGatherRangesSigridHash( std::shared_ptr& graph) { std::string pattern = R"IR( graph(%a, %b, %c, %d, %e, %f, %g): @@ -316,7 +318,7 @@ C10_UNUSED void ClipRangesGatherRangesSigridHash( fuse.runOnGraph(graph); } -C10_UNUSED void ClipRangesGatherRangesX2SigridHashPrecompute( +[[maybe_unused]] void ClipRangesGatherRangesX2SigridHashPrecompute( std::shared_ptr& graph) { // Placeholder is a dummy op used to capture the first subgraph std::string pattern = R"IR( @@ -357,7 +359,7 @@ C10_UNUSED void ClipRangesGatherRangesX2SigridHashPrecompute( fuse.runOnGraph(graph); } -C10_UNUSED void SplitOutPrecomputeOpsForSparseNN( +[[maybe_unused]] void SplitOutPrecomputeOpsForSparseNN( std::shared_ptr& graph) { #ifdef FBCODE_CAFFE2 PrecomputeMultiplierShiftForSigridHash(graph); @@ -1295,12 +1297,12 @@ void UseSplitAndSqueeze(std::shared_ptr& graph) { } } -C10_UNUSED void RemoveUnnecessaryOutputs( +[[maybe_unused]] void RemoveUnnecessaryOutputs( std::shared_ptr& graph) { RemoveUnnecessaryEmbeddingBagOutputs(graph); } -C10_UNUSED void RemoveUnnecessaryEmbeddingBagOutputs( +[[maybe_unused]] void RemoveUnnecessaryEmbeddingBagOutputs( std::shared_ptr& graph) { std::string pattern = R"IR( graph(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset): diff --git a/torch/csrc/jit/runtime/vararg_functions.cpp b/torch/csrc/jit/runtime/vararg_functions.cpp index c102e0c61ad84..9428b892bd7a1 100644 --- a/torch/csrc/jit/runtime/vararg_functions.cpp +++ b/torch/csrc/jit/runtime/vararg_functions.cpp @@ -130,7 +130,7 @@ void format(Stack& stack, size_t num_inputs) { } ss << format.substr(begin, loc - begin); if (used_args >= args.size()) { - AT_ERROR("Too few arguments for format string: ", format); + TORCH_CHECK(false, "Too few arguments for format string: ", format); } ss << args[used_args]; begin = loc + 2; diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp index ee83ff78444f0..fd6dfa6f8cd47 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.cpp +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.cpp @@ -784,7 +784,8 @@ flatbuffers::Offset FlatbufferSerializer:: ival_pos) .Union(); } else { - AT_ERROR("Invalid IValue type for serialization: ", ivalue.tagKind()); + TORCH_CHECK( + false, "Invalid IValue type for serialization: ", ivalue.tagKind()); } return CreateIValue(fbb, ivalue_type, offset); } diff --git a/torch/csrc/jit/serialization/flatbuffer_serializer.h b/torch/csrc/jit/serialization/flatbuffer_serializer.h index 41fb52415a129..5474e48ccf1fc 100644 --- a/torch/csrc/jit/serialization/flatbuffer_serializer.h +++ b/torch/csrc/jit/serialization/flatbuffer_serializer.h @@ -32,15 +32,15 @@ class TORCH_API DetachedBuffer final { : data_(data), size_(size), data_owner_(internal_data_owner) {} /// Returns a pointer to the data. - C10_NODISCARD void* data() { + [[nodiscard]] void* data() { return data_; } /// Returns a pointer to the data. - C10_NODISCARD const void* data() const { + [[nodiscard]] const void* data() const { return data_; } /// Returns the size of the data, in bytes. - C10_NODISCARD size_t size() const { + [[nodiscard]] size_t size() const { return size_; } diff --git a/torch/csrc/jit/serialization/import.cpp b/torch/csrc/jit/serialization/import.cpp index 8770484554e9c..ad2b58695a7ce 100644 --- a/torch/csrc/jit/serialization/import.cpp +++ b/torch/csrc/jit/serialization/import.cpp @@ -264,7 +264,7 @@ Module ScriptModuleDeserializer::deserialize( } } if (reader_->hasRecord("model.json") && code_prefix_ == "code/") { - AT_ERROR("Legacy model format is not supported on mobile."); + TORCH_CHECK(false, "Legacy model format is not supported on mobile."); } auto tuple = readArchive("constants").toTuple(); for (auto constant : tuple->elements()) { diff --git a/torch/csrc/jit/serialization/pickle.cpp b/torch/csrc/jit/serialization/pickle.cpp index 0fdaf0bcf3672..4bf6189a5bf59 100644 --- a/torch/csrc/jit/serialization/pickle.cpp +++ b/torch/csrc/jit/serialization/pickle.cpp @@ -96,7 +96,8 @@ std::vector pickle_save(const at::IValue& ivalue) { writer); return container_data; #else - AT_ERROR( + TORCH_CHECK( + false, "pickle_save not supported on mobile " "(see https://github.com/pytorch/pytorch/pull/30108)"); #endif @@ -136,7 +137,8 @@ IValue pickle_load(const std::vector& data) { /*device=*/std::nullopt, reader); #else - AT_ERROR( + TORCH_CHECK( + false, "pickle_load not supported on mobile " "(see https://github.com/pytorch/pytorch/pull/30108)"); #endif @@ -153,10 +155,11 @@ c10::IValue pickle_load_obj(std::string_view data) { /*tensor_prefix=*/"", /*type_resolver=*/customClassResolver, /*obj_loader=*/torch::jit::ObjLoaderFunc, - /*device=*/c10::nullopt, + /*device=*/std::nullopt, reader); #else - AT_ERROR( + TORCH_CHECK( + false, "pickle_load not supported on mobile " "(see https://github.com/pytorch/pytorch/pull/30108)"); #endif diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp index 9fdd4d4ea777c..98f6f14398211 100644 --- a/torch/csrc/jit/serialization/pickler.cpp +++ b/torch/csrc/jit/serialization/pickler.cpp @@ -135,7 +135,7 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { } err << ". Please define serialization methods via def_pickle() for " "this class."; - AT_ERROR(err.str()); + TORCH_CHECK(false, err.str()); } else if (ivalue.isRRef()) { #ifdef USE_RPC TORCH_CHECK( @@ -154,7 +154,7 @@ void Pickler::pushIValueImpl(const IValue& ivalue) { pushIValue(enum_holder->value()); push(PickleOpCode::REDUCE); } else { - AT_ERROR("Unknown IValue type for pickling: ", ivalue.tagKind()); + TORCH_CHECK(false, "Unknown IValue type for pickling: ", ivalue.tagKind()); } } diff --git a/torch/csrc/jit/serialization/pickler.h b/torch/csrc/jit/serialization/pickler.h index 9be9b0fb2d8c1..cf45166d464d7 100644 --- a/torch/csrc/jit/serialization/pickler.h +++ b/torch/csrc/jit/serialization/pickler.h @@ -216,7 +216,7 @@ class TORCH_API Pickler { // the left of a '::', its type cannot be deduced by the compiler so one must // explicitly instantiate the template, i.e. push(int) works, push(int) // does not) - static CONSTEXPR_EXCEPT_WIN_CUDA size_t kBufferSize = 256; + static constexpr size_t kBufferSize = 256; template void push(std::common_type_t value) { const char* begin = reinterpret_cast(&value); diff --git a/torch/csrc/jit/serialization/python_print.cpp b/torch/csrc/jit/serialization/python_print.cpp index 39195e3752ff1..4077404d4bd08 100644 --- a/torch/csrc/jit/serialization/python_print.cpp +++ b/torch/csrc/jit/serialization/python_print.cpp @@ -436,8 +436,7 @@ struct PythonPrintImpl { size_t level = 0; // indent to the current indent level TaggedStringStream& indent() { - for (const auto i : c10::irange(level)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(level)) { body_ << " "; } return body_; @@ -455,7 +454,7 @@ struct PythonPrintImpl { auto it_b = list_b.begin(); if (list_a.size() != list_b.size()) { - AT_ERROR("Python printer expected 2 lists of same size"); + TORCH_CHECK(false, "Python printer expected 2 lists of same size"); } for (; it_a != list_a.end(); ++it_a, ++it_b) { @@ -1299,8 +1298,7 @@ struct PythonPrintImpl { IValue createBroadList(dtype value, const int64_t& N) { c10::List repeated; repeated.reserve(N); - for (const auto i : c10::irange(N)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(N)) { repeated.push_back(value); } return repeated; diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp index 296a57cf0169b..fc95f7fe9a4a6 100644 --- a/torch/csrc/jit/serialization/unpickler.cpp +++ b/torch/csrc/jit/serialization/unpickler.cpp @@ -199,7 +199,8 @@ static void restoreContainerTypeTags( } else if (is(*type)) { ivalue.toList().unsafeSetElementType(type->containedType(0)); } else { - AT_ERROR("Unknown type for tag restoration: " + type->annotation_str()); + TORCH_CHECK( + false, "Unknown type for tag restoration: " + type->annotation_str()); } } @@ -625,7 +626,8 @@ PickleOpCode Unpickler::readInstruction() { device.is_hpu() || device.is_mps() || device.is_privateuseone()) { tensor = tensor.to(device, tensor.scalar_type()); } else if (device.type() != DeviceType::CPU) { - AT_ERROR( + TORCH_CHECK( + false, "supported devices include CPU, CUDA, HPU and ", c10::get_privateuse1_backend(), " however got ", @@ -660,7 +662,8 @@ PickleOpCode Unpickler::readInstruction() { stack_.begin() + static_cast(key_pos), stack_.end()); } break; default: { - AT_ERROR( + TORCH_CHECK( + false, "Unknown opcode for unpickling at ", // NOLINTNEXTLINE(performance-no-int-to-ptr) reinterpret_cast(opcode), @@ -708,7 +711,7 @@ void Unpickler::readGlobal( stack_.back().toList().unsafeSetElementType(IntType::get()); }); } else { - AT_ERROR("Unknown pickler class id", class_name); + TORCH_CHECK(false, "Unknown pickler class id", class_name); } } else if (module_name == "torch.jit._pickle") { if (class_name == "build_tensor_from_id") { @@ -758,7 +761,7 @@ void Unpickler::readGlobal( } else if (class_name == "build_boollist") { elem_type = BoolType::get(); } else { - AT_ERROR("Unknown pickler class id ", class_name); + TORCH_CHECK(false, "Unknown pickler class id ", class_name); } // Unpickle a list specialization (e.g. List[Tensor], List[int], ...) globals_.emplace_back([this, elem_type] { @@ -1090,7 +1093,7 @@ void Unpickler::readSlowWithBuffer(char* dest, size_t sz) { AT_ASSERT(sz <= buffer_.size()); buffer_remaining_ = reader_(buffer_.data(), buffer_.size()); if (buffer_remaining_ < needed) { - AT_ERROR("Unexpected end of pickler archive."); + TORCH_CHECK(false, "Unexpected end of pickler archive."); } memcpy(dest + from_old_buf, buffer_.data(), needed); buffer_pos_ = needed; // assignment (0'ed from read) @@ -1128,7 +1131,7 @@ std::string Unpickler::readBytes(size_t length) { const size_t needed = length - from_old_buf; size_t nread = reader_(&data[from_old_buf], needed); if (nread != needed) { - AT_ERROR("Unexpected end of pickler archive."); + TORCH_CHECK(false, "Unexpected end of pickler archive."); } buffer_remaining_ = 0; // buffer_pos_ has no meaning with buffer_remaining_ == 0. @@ -1170,7 +1173,7 @@ void Unpickler::readListElements(IValue list_ivalue, size_t start) { list.emplace_back(elem); } } else { - AT_ERROR("Unknown IValue list kind: ", list_ivalue.tagKind()); + TORCH_CHECK(false, "Unknown IValue list kind: ", list_ivalue.tagKind()); } stack_.erase( stack_.begin() + static_cast(start), stack_.end()); diff --git a/torch/csrc/jit/tensorexpr/cpp_intrinsics.h b/torch/csrc/jit/tensorexpr/cpp_intrinsics.h index caeeed693ff38..3149335ea30f9 100644 --- a/torch/csrc/jit/tensorexpr/cpp_intrinsics.h +++ b/torch/csrc/jit/tensorexpr/cpp_intrinsics.h @@ -8,13 +8,13 @@ constexpr auto cpp_intrinsics_definition = R"( namespace std { template ::value, int>::type = 0> + std::enable_if_t, int> = 0> T rsqrt(T v) { return 1.0f / std::sqrt(v); } template ::value, int>::type = 0> + std::enable_if_t, int> = 0> T frac(T v) { T intpart; return std::modf(v, &intpart); diff --git a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp index a9f24139e029f..5fe52a0ff9e0e 100644 --- a/torch/csrc/jit/tensorexpr/llvm_codegen.cpp +++ b/torch/csrc/jit/tensorexpr/llvm_codegen.cpp @@ -1075,14 +1075,16 @@ void LLVMCodeGenImpl::visit(const CompareSelectPtr& v) { } template -typename std::enable_if::value, llvm::Value*>::type -getFromType(llvm::Type* type, T value) { +std::enable_if_t, llvm::Value*> getFromType( + llvm::Type* type, + T value) { return llvm::ConstantInt::get(type, value, std::is_signed::value); } template -typename std::enable_if::value, llvm::Value*>::type -getFromType(llvm::Type* type, T value) { +std::enable_if_t, llvm::Value*> getFromType( + llvm::Type* type, + T value) { return llvm::ConstantFP::get(type, value); } diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp index dfe11d859b34c..ca56a7f95b7ea 100644 --- a/torch/csrc/jit/tensorexpr/lowerings.cpp +++ b/torch/csrc/jit/tensorexpr/lowerings.cpp @@ -1990,7 +1990,7 @@ int nnc_lowerings_lazy_registration() { } // namespace NNCLoweringFunction getStandardLoweringFor(const std::string& schema_str) { - C10_UNUSED static const int once = nnc_lowerings_lazy_registration(); + [[maybe_unused]] static const int once = nnc_lowerings_lazy_registration(); const auto& lowerings = getNNCLoweringRegistry(); if (auto l = lowerings.find(parseSchema(schema_str))) { return *l; diff --git a/torch/csrc/jit/tensorexpr/operators/misc.cpp b/torch/csrc/jit/tensorexpr/operators/misc.cpp index fce41388561a2..fab35357c83b1 100644 --- a/torch/csrc/jit/tensorexpr/operators/misc.cpp +++ b/torch/csrc/jit/tensorexpr/operators/misc.cpp @@ -12,7 +12,7 @@ int64_t normalizeAndCheckIndex(int64_t idx, int64_t list_size) { } if (idx < 0 || idx >= list_size) { - AT_ERROR("Invalid index ", idx, " for list_size", list_size); + TORCH_CHECK(false, "Invalid index ", idx, " for list_size", list_size); } return idx; } diff --git a/torch/csrc/jit/testing/file_check.cpp b/torch/csrc/jit/testing/file_check.cpp index e9bf764c31575..97273ef4a110c 100644 --- a/torch/csrc/jit/testing/file_check.cpp +++ b/torch/csrc/jit/testing/file_check.cpp @@ -506,13 +506,10 @@ struct FileCheckImpl { end_range = start_range + check.search_str_.size(); break; } - case CHECK_DAG: { - AT_ERROR(); - } break; - case CHECK_NOT: { - AT_ERROR(); - } break; + default: + TORCH_CHECK(false); } + return SourceRange(source, start_range, end_range); } diff --git a/torch/csrc/lazy/backend/backend_device.h b/torch/csrc/lazy/backend/backend_device.h index 8c274b8fc8b1f..3a4a722323f0c 100644 --- a/torch/csrc/lazy/backend/backend_device.h +++ b/torch/csrc/lazy/backend/backend_device.h @@ -84,6 +84,7 @@ TORCH_API std::optional GetBackendDevice( // For variadic template. TORCH_API std::optional GetBackendDevice(); +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Winfinite-recursion") template std::optional GetBackendDevice( const T& tensor, @@ -94,5 +95,6 @@ std::optional GetBackendDevice( } return GetBackendDevice(forward_tensors...); } +C10_DIAGNOSTIC_POP() } // namespace torch::lazy diff --git a/torch/csrc/lazy/core/ir.cpp b/torch/csrc/lazy/core/ir.cpp index 599f491c008ea..fe9cfba2556c0 100644 --- a/torch/csrc/lazy/core/ir.cpp +++ b/torch/csrc/lazy/core/ir.cpp @@ -1,3 +1,4 @@ +#include #include #include #include @@ -57,7 +58,7 @@ hash_t OpKind::hash() const { } bool Node::enableDynamicShape() { - static bool enabled = std::getenv("LTC_ENABLE_DYNAMIC_SHAPES") != nullptr; + static bool enabled = c10::utils::has_env("LTC_ENABLE_DYNAMIC_SHAPES"); return enabled || FLAGS_ltc_enable_dynamic_shapes; } diff --git a/torch/csrc/lazy/core/ir_util.cpp b/torch/csrc/lazy/core/ir_util.cpp index e61acd79540c0..814bb5a54b0cd 100644 --- a/torch/csrc/lazy/core/ir_util.cpp +++ b/torch/csrc/lazy/core/ir_util.cpp @@ -1,5 +1,7 @@ #include +#include + #include namespace torch::lazy { @@ -8,17 +10,17 @@ std::vector Util::ComputePostOrder( const Node* node, EmissionMap* emap) { std::vector post_order; - std::vector queue; - queue.push_back(node); - while (!queue.empty()) { - node = queue.back(); + std::stack node_stack; + node_stack.push(node); + while (!node_stack.empty()) { + node = node_stack.top(); auto it = emap->find(node); if (it == emap->end()) { (*emap)[node] = kEmitting; for (auto& output : node->operands()) { auto oit = emap->find(output.node); if (oit == emap->end()) { - queue.push_back(output.node); + node_stack.push(output.node); } else { TORCH_CHECK( oit->second != kEmitting, @@ -36,10 +38,10 @@ std::vector Util::ComputePostOrder( } (*emap)[node] = kEmitted; post_order.push_back(node); - queue.pop_back(); + node_stack.pop(); } else { TORCH_CHECK(it->second == kEmitted); - queue.pop_back(); + node_stack.pop(); } } return post_order; diff --git a/torch/csrc/lazy/core/lazy_graph_executor.cpp b/torch/csrc/lazy/core/lazy_graph_executor.cpp index bb6808796216b..96af97eef0e3e 100644 --- a/torch/csrc/lazy/core/lazy_graph_executor.cpp +++ b/torch/csrc/lazy/core/lazy_graph_executor.cpp @@ -19,8 +19,6 @@ #include -#include - namespace torch::lazy { namespace { diff --git a/torch/csrc/lazy/core/shape_inference.cpp b/torch/csrc/lazy/core/shape_inference.cpp index f192b8431a1b2..f0ebaee9ddf04 100644 --- a/torch/csrc/lazy/core/shape_inference.cpp +++ b/torch/csrc/lazy/core/shape_inference.cpp @@ -85,7 +85,7 @@ static std::vector expand_param_if_needed( ss << "expected " << param_name << " to be a single integer value or a " << "list of " << expected_dim << " values to match the convolution " << "dimensions, but got " << param_name << "=" << list_param; - AT_ERROR(ss.str()); + TORCH_CHECK(false, ss.str()); } else { return list_param.vec(); } diff --git a/torch/csrc/lazy/core/shape_inference.h b/torch/csrc/lazy/core/shape_inference.h index 76ddea597a784..7a44454da654a 100644 --- a/torch/csrc/lazy/core/shape_inference.h +++ b/torch/csrc/lazy/core/shape_inference.h @@ -13,8 +13,7 @@ #include #include -namespace torch { -namespace lazy { +namespace torch::lazy { // Turn clang-format off, as we rely on the whole signature being on one line // for codegen. // clang-format off @@ -120,5 +119,4 @@ TORCH_API std::vector compute_shape_diagonal_scatter(const a TORCH_API std::vector compute_shape_slice_scatter_symint(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); TORCH_API std::vector compute_shape_as_strided_scatter_symint(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // clang-format on -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/core/thread_pool.cpp b/torch/csrc/lazy/core/thread_pool.cpp index 9481fc52f4946..3f87aaa96b519 100644 --- a/torch/csrc/lazy/core/thread_pool.cpp +++ b/torch/csrc/lazy/core/thread_pool.cpp @@ -19,8 +19,7 @@ class ThreadPool { public: explicit ThreadPool(size_t num_threads) { threads_.reserve(num_threads); - for (const auto i : c10::irange(num_threads)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(num_threads)) { threads_.emplace_back([this]() { c10::setThreadName("pt_thread_pool"); Worker(); diff --git a/torch/csrc/lazy/python/init.cpp b/torch/csrc/lazy/python/init.cpp index 616ce56b697e9..f30615355e0e2 100644 --- a/torch/csrc/lazy/python/init.cpp +++ b/torch/csrc/lazy/python/init.cpp @@ -18,10 +18,10 @@ #include #endif // FBCODE_CAFFE2 || OVRSOURCE #include +#include #include -namespace torch { -namespace lazy { +namespace torch::lazy { // TODO(whc) backend 'device' related APIs are not very clear, this code could // be simplified but it should probably be done together with @@ -190,10 +190,10 @@ void initLazyBindings(PyObject* module) { return torch::lazy::getLTCForceFallback(); }); lazy.def("_set_force_fallback", [](std::string newval) { - torch::lazy::getLTCForceFallback() = newval; + torch::lazy::getLTCForceFallback() = std::move(newval); }); lazy.def("_clear_ir_cache", []() { TrieCache::Get()->Clear(); }); - lazy.def("_dump_ir_cache", [](std::string filename) { + lazy.def("_dump_ir_cache", [](const std::string& filename) { TrieCache::Get()->DumpToDotFile(filename); }); lazy.def("_set_reuse_ir", [](bool val) { FLAGS_torch_lazy_reuse_ir = val; }); @@ -337,5 +337,4 @@ void initLazyBindings(PyObject* module) { #endif // USE_DEPLOY } -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/python/init.h b/torch/csrc/lazy/python/init.h index 5bdc5a9722908..12ab6e3ee7d50 100644 --- a/torch/csrc/lazy/python/init.h +++ b/torch/csrc/lazy/python/init.h @@ -3,10 +3,8 @@ #include #include -namespace torch { -namespace lazy { +namespace torch::lazy { TORCH_PYTHON_API void initLazyBindings(PyObject* module); -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/python/python_util.cpp b/torch/csrc/lazy/python/python_util.cpp index 1ae663c519f56..5568d5f79a7c3 100644 --- a/torch/csrc/lazy/python/python_util.cpp +++ b/torch/csrc/lazy/python/python_util.cpp @@ -8,8 +8,7 @@ #include #include -namespace torch { -namespace lazy { +namespace torch::lazy { std::optional GetPythonFrameTop() { if (!Py_IsInitialized()) { @@ -51,5 +50,4 @@ std::vector GetPythonFrames() { return frames; } -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/python/python_util.h b/torch/csrc/lazy/python/python_util.h index 271c694ee35dd..6399b224dbffb 100644 --- a/torch/csrc/lazy/python/python_util.h +++ b/torch/csrc/lazy/python/python_util.h @@ -4,12 +4,10 @@ #include #include -namespace torch { -namespace lazy { +namespace torch::lazy { std::optional TORCH_PYTHON_API GetPythonFrameTop(); std::vector TORCH_PYTHON_API GetPythonFrames(); -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/ts_backend/ops/device_data.cpp b/torch/csrc/lazy/ts_backend/ops/device_data.cpp index bd80fcd7fe613..8567f1d2ed8ce 100644 --- a/torch/csrc/lazy/ts_backend/ops/device_data.cpp +++ b/torch/csrc/lazy/ts_backend/ops/device_data.cpp @@ -5,8 +5,7 @@ #include -namespace torch { -namespace lazy { +namespace torch::lazy { DeviceData::DeviceData(std::shared_ptr data) : TsNode( @@ -26,7 +25,7 @@ const DeviceData* DeviceData::Cast(const Node* node) { return NodeCast(node); } -NodePtr DeviceData::Create(std::shared_ptr data) { +NodePtr DeviceData::Create(const std::shared_ptr& data) { NodePtr node = ReuseOrMakeNode(data); // ReuseOrMakeNode may return a reused node which has the same shape, // however, we need to replace the old data_ with the new one. @@ -38,5 +37,4 @@ NodePtr DeviceData::Create(std::shared_ptr data) { return node; } -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/ts_backend/ops/device_data.h b/torch/csrc/lazy/ts_backend/ops/device_data.h index 53e7814fc39a4..1cbfa5c3b63ae 100644 --- a/torch/csrc/lazy/ts_backend/ops/device_data.h +++ b/torch/csrc/lazy/ts_backend/ops/device_data.h @@ -4,8 +4,9 @@ #include #include -namespace torch { -namespace lazy { +#include + +namespace torch::lazy { class TORCH_API DeviceData : public TsNode { public: @@ -18,7 +19,7 @@ class TORCH_API DeviceData : public TsNode { // A DeviceData node can be reused if the shape matches, // but we will substitute the actual data_ pointer under // the hood. - bool CanBeReused(std::shared_ptr data) const { + bool CanBeReused(const std::shared_ptr& data) const { return data_->shape() == data->shape(); } @@ -29,14 +30,14 @@ class TORCH_API DeviceData : public TsNode { } void SetData(std::shared_ptr data) { - data_ = data; + data_ = std::move(data); } static const DeviceData* Cast(const Node* node); // To reuse IR nodes, use this method to create DeviceData nodes - // instead of calling the constructor directly. - static NodePtr Create(std::shared_ptr data); + // instead of calling the constructor directconst ly. + static NodePtr Create(const std::shared_ptr& data); TSOpVector Lower( std::shared_ptr function, @@ -46,5 +47,4 @@ class TORCH_API DeviceData : public TsNode { std::shared_ptr data_; }; -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/ts_backend/ops/generic.cpp b/torch/csrc/lazy/ts_backend/ops/generic.cpp index 774bccd0df022..6c14a44b96e46 100644 --- a/torch/csrc/lazy/ts_backend/ops/generic.cpp +++ b/torch/csrc/lazy/ts_backend/ops/generic.cpp @@ -1,7 +1,6 @@ #include -namespace torch { -namespace lazy { +namespace torch::lazy { Generic::Generic( OpKind op, @@ -32,5 +31,4 @@ Generic::Generic(OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) : TsNode(op, std::move(shape), num_outputs, hash_seed), hash_seed_(hash_seed) {} -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/ts_backend/ops/generic.h b/torch/csrc/lazy/ts_backend/ops/generic.h index c605aaa437cc9..507ac0e0cf81b 100644 --- a/torch/csrc/lazy/ts_backend/ops/generic.h +++ b/torch/csrc/lazy/ts_backend/ops/generic.h @@ -4,8 +4,7 @@ #include -namespace torch { -namespace lazy { +namespace torch::lazy { // Generic IR Node implementation for nodes which can simply be described by a // specific OpKind and a lowering function. IR nodes carrying @@ -50,5 +49,4 @@ inline NodePtr GenericOp( op, operands, std::move(shape), num_outputs, hash_seed); } -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/ts_backend/ops/to_copy.h b/torch/csrc/lazy/ts_backend/ops/to_copy.h index 3a5f47411dfdd..53e0d76689c76 100644 --- a/torch/csrc/lazy/ts_backend/ops/to_copy.h +++ b/torch/csrc/lazy/ts_backend/ops/to_copy.h @@ -2,8 +2,7 @@ #include -namespace torch { -namespace lazy { +namespace torch::lazy { // This IR was copied from code-generated output, but the entire _to_copy // operator cannot be trivially code genereated since it is only desirable to @@ -123,5 +122,4 @@ class ToCopy : public torch::lazy::TsNode { std::optional memory_format; }; -} // namespace lazy -} // namespace torch +} // namespace torch::lazy diff --git a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp index 7bd808c1333f1..4a1653cc176bd 100644 --- a/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp +++ b/torch/csrc/lazy/ts_backend/ts_backend_impl.cpp @@ -42,7 +42,7 @@ class TSBackendImpl : public torch::lazy::BackendImplInterface { public: TSBackendImpl() { // TODO(whc) unify how all our flags are set and parsed as envs - static bool env_use_cuda = std::getenv("LTC_TS_CUDA") != nullptr; + static bool env_use_cuda = c10::utils::has_env("LTC_TS_CUDA"); auto type = (env_use_cuda || FLAGS_torch_lazy_ts_cuda) ? at::kCUDA : at::kCPU; default_device_type_ = std::make_shared(type); diff --git a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp index dd132c5fac051..ca7f8e97ae343 100644 --- a/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp +++ b/torch/csrc/lazy/ts_backend/ts_eager_fallback.cpp @@ -106,7 +106,7 @@ c10::DispatchKey dispatch_key(c10::DeviceType device_type) { return c10::DispatchKey::CUDA; } default: { - AT_ERROR("Unsupported device type: ", device_type); + TORCH_CHECK(false, "Unsupported device type: ", device_type); } } } diff --git a/torch/csrc/lazy/ts_backend/ts_node.cpp b/torch/csrc/lazy/ts_backend/ts_node.cpp index 172e07f94306e..46cbc31ca058e 100644 --- a/torch/csrc/lazy/ts_backend/ts_node.cpp +++ b/torch/csrc/lazy/ts_backend/ts_node.cpp @@ -1,11 +1,12 @@ +#include #include #include namespace { std::string GetFirstUserFrameInPythonIfEnabled() { static const auto LTC_ENABLE_SOURCE_INFO = - std::getenv("LTC_ENABLE_SOURCE_INFO"); - if (!LTC_ENABLE_SOURCE_INFO) { + c10::utils::has_env("LTC_ENABLE_SOURCE_INFO"); + if (LTC_ENABLE_SOURCE_INFO) { return {}; } diff --git a/torch/csrc/mtia/Module.cpp b/torch/csrc/mtia/Module.cpp index 506ad0a0ee466..37624b3737d67 100644 --- a/torch/csrc/mtia/Module.cpp +++ b/torch/csrc/mtia/Module.cpp @@ -39,7 +39,7 @@ void initModule(PyObject* module) { m.def("_mtia_init", []() { TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level poison_fork(); - at::globalContext().lazyInitMTIA(); + at::globalContext().lazyInitDevice(c10::DeviceType::MTIA); }); m.def("_mtia_isBuilt", []() { diff --git a/torch/csrc/multiprocessing/init.h b/torch/csrc/multiprocessing/init.h index 0adf0b8ddbc36..1773060e07c48 100644 --- a/torch/csrc/multiprocessing/init.h +++ b/torch/csrc/multiprocessing/init.h @@ -2,10 +2,8 @@ #include -namespace torch { -namespace multiprocessing { +namespace torch::multiprocessing { PyMethodDef* python_functions(); -} // namespace multiprocessing -} // namespace torch +} // namespace torch::multiprocessing diff --git a/torch/csrc/profiler/collection.cpp b/torch/csrc/profiler/collection.cpp index df3bb265bf8e3..eefe5621a293e 100644 --- a/torch/csrc/profiler/collection.cpp +++ b/torch/csrc/profiler/collection.cpp @@ -195,7 +195,8 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { return {RawTensorMetadata(), sizes, strides}; } const auto& raw_metadata = *tensor_metadata_it++; - for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) { + for ([[maybe_unused]] const auto _ : + c10::irange(raw_metadata.size_dim_)) { if (tensor_size_strides_it.exhausted()) { LOG(WARNING) << "Expected Tensor Size mismatch with raw Tensor metadata. Reported shapes may be inaccurate!"; @@ -204,7 +205,8 @@ auto InputOutputEncoder::getIValueGenerator(const IOType& io_type) { sizes.push_back(*tensor_size_strides_it++); } if (raw_metadata.layout_ == at::kStrided) { - for (C10_UNUSED const auto _ : c10::irange(raw_metadata.size_dim_)) { + for ([[maybe_unused]] const auto _ : + c10::irange(raw_metadata.size_dim_)) { if (tensor_size_strides_it.exhausted()) { LOG(WARNING) << "Expected Tensor Strides mismatch with raw Tensor metadata. Reported shapes may be inaccurate!"; @@ -397,7 +399,7 @@ std::unique_ptr ThreadLocalSubqueue::begin_op( namespace { template struct StealOrDefault { - StealOrDefault(T& container) + explicit StealOrDefault(T& container) : container_{container}, it_{container.begin()} {} ~StealOrDefault() { @@ -429,7 +431,7 @@ void ThreadLocalSubqueue::TorchOpStorage::materialize( const kineto::DeviceAndResource& kineto_info) { // Plumb Autograd info to the top level annotation. auto it = op_events_.begin(); - for (C10_UNUSED const auto _ : + for ([[maybe_unused]] const auto _ : c10::irange(static_cast(op_events_.size()) - 1)) { auto& first = it->basic_fields_; auto& second = (++it)->basic_fields_; @@ -1060,7 +1062,7 @@ class TransferEvents { std::shared_ptr& r, std::shared_ptr parent) { r->visit(c10::overloaded( - [&](ExtraFields& i) { + [&]([[maybe_unused]] ExtraFields& i) { TORCH_INTERNAL_ASSERT(r->start_tid_ == noTID); r->start_tid_ = parent ? parent->start_tid_ : at::RecordFunction::currentThreadId(); @@ -1297,7 +1299,7 @@ int64_t adjust_durations_dfs(std::shared_ptr& r) { [&children_total_duration](ExtraFields& i) { i.duration_ns_ = children_total_duration; }, - [](ExtraFields& _) { + []([[maybe_unused]] ExtraFields& _) { // Pass- Allocation events can't have children }, [&](auto&) { @@ -1333,10 +1335,10 @@ int64_t adjust_timestamps_dfs( i.end_time_ns_ = new_start_time + (i.end_time_ns_ - r->start_time_ns_); }, - [](ExtraFields& i) { + []([[maybe_unused]] ExtraFields& i) { // Pass- We don't need to manually adjust end time for Vulkan events }, - [](ExtraFields& _) { + []([[maybe_unused]] ExtraFields& _) { // Pass- No duration or end time to adjust }, [&](auto&) { @@ -1475,20 +1477,26 @@ RecordQueue::getRecords( ProfilerStepInfo step = step_idx < step_info.size() ? step_info[step_idx] : defaultStep; for (const auto& i : ev) { - // If event has start time after step end time we can continue to the next - // step - while (i->start_time_ns_ > step.end_time_ns) { - step_idx++; - step = step_idx < step_info.size() ? step_info[step_idx] : defaultStep; - } - // If Step annotation starts before event and ends before event ends with - // intersection then we move the lefthand side of the step annotation to - // the event start time - if (right_intersection_only(step, i->start_time_ns_, i->endTimeNS())) { - auto currStepRes = out[step.out_idx]; - currStepRes->start_time_ns_ = i->start_time_ns_ + 1; - step_idx++; - step = step_idx < step_info.size() ? step_info[step_idx] : defaultStep; + // Only adjust timestamps if experimental config is enabled + if (config_.experimental_config.adjust_profiler_step) { + // If event has start time after step end time we can continue to the + // next step + while (i->start_time_ns_ > step.end_time_ns) { + step_idx++; + step = + step_idx < step_info.size() ? step_info[step_idx] : defaultStep; + } + // If Step annotation starts before event and ends before event ends + // with intersection then we move the lefthand side of the step + // annotation to the event start time + if (right_intersection_only(step, i->start_time_ns_, i->endTimeNS())) { + // NOLINTNEXTLINE(facebook-hte-LocalUncheckedArrayBounds) + auto currStepRes = out[step.out_idx]; + currStepRes->start_time_ns_ = i->start_time_ns_ + 1; + step_idx++; + step = + step_idx < step_info.size() ? step_info[step_idx] : defaultStep; + } } out.push_back(i); } diff --git a/torch/csrc/profiler/containers.h b/torch/csrc/profiler/containers.h index 6ff73917d9147..060c6e3b5341d 100644 --- a/torch/csrc/profiler/containers.h +++ b/torch/csrc/profiler/containers.h @@ -5,7 +5,6 @@ #include #include #include -#include #include #include @@ -52,7 +51,10 @@ class AppendOnlyList { AppendOnlyList() : buffer_last_{buffer_.before_begin()} {} AppendOnlyList(const AppendOnlyList&) = delete; + AppendOnlyList(AppendOnlyList&&) = delete; AppendOnlyList& operator=(const AppendOnlyList&) = delete; + AppendOnlyList& operator=(AppendOnlyList&&) = delete; + ~AppendOnlyList() = default; size_t size() const { return n_blocks_ * ChunkSize - (size_t)(end_ - next_); diff --git a/torch/csrc/profiler/kineto_shim.cpp b/torch/csrc/profiler/kineto_shim.cpp index ed37f83bf63ff..ef70242eafb35 100644 --- a/torch/csrc/profiler/kineto_shim.cpp +++ b/torch/csrc/profiler/kineto_shim.cpp @@ -96,8 +96,6 @@ TraceWrapper::TraceWrapper(const int64_t start_time, const std::string& name) } #endif // USE_KINETO -TraceWrapper::~TraceWrapper() = default; - activity_t* TraceWrapper::addCPUActivity( const std::string& name, const libkineto::ActivityType type, @@ -227,6 +225,7 @@ void prepareTrace( const ActivitySet& activities, const torch::profiler::impl::ExperimentalConfig& config) { #ifdef USE_KINETO + libkineto::api().resetKinetoTLS(); if (!libkineto::api().isProfilerRegistered()) { libkineto_init(/*cpuOnly=*/cpuOnly, /*logOnError=*/true); libkineto::api().suppressLogMessages(); diff --git a/torch/csrc/profiler/kineto_shim.h b/torch/csrc/profiler/kineto_shim.h index 44509e4a5e64e..085e9dd2fcb2d 100644 --- a/torch/csrc/profiler/kineto_shim.h +++ b/torch/csrc/profiler/kineto_shim.h @@ -67,9 +67,6 @@ void addMetadata( // Wraps: libkineto::CpuTraceBuffer struct TraceWrapper { TraceWrapper(const int64_t start_time, const std::string& name); - TraceWrapper(TraceWrapper&&) = default; - TraceWrapper(const TraceWrapper&) = delete; - ~TraceWrapper(); // The caller is expected to hold a mutex when calling `addCPUActivity`. activity_t* addCPUActivity( @@ -96,8 +93,6 @@ struct TraceWrapper { struct ActivityTraceWrapper { explicit ActivityTraceWrapper(std::unique_ptr&& trace); ActivityTraceWrapper() = default; - ActivityTraceWrapper(ActivityTraceWrapper&&) = default; - ActivityTraceWrapper(const ActivityTraceWrapper&) = delete; explicit operator bool() const; void save(const std::string& path); diff --git a/torch/csrc/profiler/orchestration/observer.cpp b/torch/csrc/profiler/orchestration/observer.cpp index cd9798a339ab5..39a8845cb8483 100644 --- a/torch/csrc/profiler/orchestration/observer.cpp +++ b/torch/csrc/profiler/orchestration/observer.cpp @@ -17,12 +17,14 @@ ExperimentalConfig::ExperimentalConfig( bool verbose, std::vector performance_events, bool enable_cuda_sync_events, + bool adjust_profiler_step, bool adjust_timestamps) : profiler_metrics{std::move(profiler_metrics)}, profiler_measure_per_kernel{profiler_measure_per_kernel}, verbose{verbose}, performance_events(std::move(performance_events)), enable_cuda_sync_events{enable_cuda_sync_events}, + adjust_profiler_step{adjust_profiler_step}, adjust_timestamps{adjust_timestamps} {} /*explicit*/ ExperimentalConfig::operator bool() const { diff --git a/torch/csrc/profiler/orchestration/observer.h b/torch/csrc/profiler/orchestration/observer.h index 35d9ce0d186d0..4475101efacc8 100644 --- a/torch/csrc/profiler/orchestration/observer.h +++ b/torch/csrc/profiler/orchestration/observer.h @@ -20,8 +20,10 @@ enum class C10_API_ENUM ActivityType { }; inline std::string actToString(ActivityType t) { - const std::string ActivityTypeNames[] = { - "CPU", "XPU", "CUDA", "MTIA", "PrivateUse1"}; + const std::array< + std::string, + static_cast(ActivityType::NUM_KINETO_ACTIVITIES)> + ActivityTypeNames = {"CPU", "XPU", "CUDA", "MTIA", "PrivateUse1"}; return ActivityTypeNames[static_cast(t)]; } @@ -55,6 +57,7 @@ struct TORCH_API ExperimentalConfig { bool verbose = false, std::vector performance_events = {}, bool enable_cuda_sync_events = false, + bool adjust_profiler_step = false, bool adjust_timestamps = false); explicit operator bool() const; @@ -72,6 +75,13 @@ struct TORCH_API ExperimentalConfig { * This feature is new and currently disabled by default. */ bool enable_cuda_sync_events; + /* + * Controls whether or not timestamp adjustment for ProfilerStep and parent + * Python events occurs after profiling. This occurs at an O(n) cost and + * affects only the start of profiler step events. + */ + bool adjust_profiler_step; + /* * Controls whether or not timestamp adjustment occurs after profiling. * The purpose of this is to adjust Vulkan event timelines to align with those @@ -86,7 +96,7 @@ struct TORCH_API ExperimentalConfig { }; struct TORCH_API ProfilerConfig { - ProfilerConfig( + explicit ProfilerConfig( ProfilerState state, bool report_input_shapes = false, bool profile_memory = false, diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp index 46e70d90e7adb..1a859c58980b2 100644 --- a/torch/csrc/profiler/python/init.cpp +++ b/torch/csrc/profiler/python/init.cpp @@ -335,7 +335,8 @@ void initPythonBindings(PyObject* module) { bool /* profiler_measure_per_kernel */, bool /* verbose */, std::vector /* performance_events */, - bool /* enable_cuda_sync_events */ + bool /* enable_cuda_sync_events */, + bool /* adjust_profiler_step */ >(), "An experimental config for Kineto features. Please note that" "backward compatibility is not guaranteed.\n" @@ -348,12 +349,15 @@ void initPythonBindings(PyObject* module) { " performance_events : a list of profiler events to be used for measurement.\n" " enable_cuda_sync_events : for CUDA profiling mode, enable adding CUDA synchronization events\n" " that expose CUDA device, stream and event synchronization activities. This feature is new\n" - " and currently disabled by default.\n", + " and currently disabled by default.\n" + " adjust_profiler_step (bool) : whether to adjust the profiler step to\n" + " match the parent python event duration. This feature is new and currently disabled by default.\n", py::arg("profiler_metrics") = std::vector(), py::arg("profiler_measure_per_kernel") = false, py::arg("verbose") = false, py::arg("performance_events") = std::vector(), - py::arg("enable_cuda_sync_events") = false) + py::arg("enable_cuda_sync_events") = false, + py::arg("adjust_profiler_step") = false) .def(py::pickle( [](const ExperimentalConfig& p) { // __getstate__ py::list py_metrics; @@ -372,11 +376,12 @@ void initPythonBindings(PyObject* module) { p.profiler_measure_per_kernel, p.verbose, p.enable_cuda_sync_events, + p.adjust_profiler_step, p.performance_events); }, [](const py::tuple& t) { // __setstate__ - if (t.size() >= 4) { - throw std::runtime_error("Expected atleast 4 values in state"); + if (t.size() >= 5) { + throw std::runtime_error("Expected atleast 5 values in state"); } py::list py_metrics = t[0].cast(); @@ -400,7 +405,8 @@ void initPythonBindings(PyObject* module) { t[1].cast(), t[2].cast(), std::move(performance_events), - t[3].cast()); + t[3].cast(), + t[4].cast()); })); py::class_(m, "ProfilerConfig") diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp index cb110253c3346..9df13f071429e 100644 --- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp +++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp @@ -339,6 +339,8 @@ static void finalizeExecutionTraceOutput(ExecutionTraceObserver& ob) { inline ExecutionTraceObserver::ID getObjectID( ExecutionTraceObserver& ob, const void* t) { + const std::lock_guard lock(ob.gMutex); + auto iter = ob.objectId.find(t); if (iter == ob.objectId.end()) { ExecutionTraceObserver::ID objectId = ob.getNewID(); @@ -569,26 +571,29 @@ static void recordOperatorStart( auto tid = fn.threadId(); try { - const std::lock_guard lock(ob.gMutex); - - // if current thread stack is empty, push the root node to the stack first - if (ob.opStack[tid].empty()) { - auto thread_node_id = ob.getNewID(); - ob.opStack[tid].push(thread_node_id); - writeJsonNode( - ob.out, - "[pytorch|profiler|execution_trace|thread]", - thread_node_id, - 0, // rf_id - kRootId, - 0, // fw_parent - -1, // seq_id - static_cast>( - RecordScope::USER_SCOPE), - tid, - 0); // fw_tid - ob.out << ","; + { + const std::lock_guard lock(ob.gMutex); + + // if current thread stack is empty, push the root node to the stack first + if (ob.opStack[tid].empty()) { + auto thread_node_id = ob.getNewID(); + ob.opStack[tid].push(thread_node_id); + writeJsonNode( + ob.out, + "[pytorch|profiler|execution_trace|thread]", + thread_node_id, + 0, // rf_id + kRootId, + 0, // fw_parent + -1, // seq_id + static_cast>( + RecordScope::USER_SCOPE), + tid, + 0); // fw_tid + ob.out << ","; + } } + fc.name = fn.name(); auto num_inputs = fn.num_inputs(); const auto inputs = fn.inputs(); @@ -619,17 +624,21 @@ static void recordOperatorStart( handleKernelBackendInfo(fc, fn); - fc.parentId = ob.opStack[tid].top(); - // get parent id from the forward stack, this can be different for - // autograd ops, which may execute on a different thread than the original - // thread (which should have the parent op on the stack). - auto fw_tid = fn.forwardThreadId(); - if (fw_tid != 0) { - fc.fwParentId = ob.opStack[fw_tid].top(); + { + const std::lock_guard lock(ob.gMutex); + + fc.parentId = ob.opStack[tid].top(); + // get parent id from the forward stack, this can be different for + // autograd ops, which may execute on a different thread than the original + // thread (which should have the parent op on the stack). + auto fw_tid = fn.forwardThreadId(); + if (fw_tid != 0) { + fc.fwParentId = ob.opStack[fw_tid].top(); + } + // all input nodes should have id > opId + fc.opId = ob.getNewID(); + ob.opStack[tid].push(fc.opId); } - // all input nodes should have id > opId - fc.opId = ob.getNewID(); - ob.opStack[tid].push(fc.opId); } catch (const std::exception& e) { LOG(WARNING) << "Exception in execution trace observer: " << e.what(); @@ -712,10 +721,6 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) { std::vector output_shapes; std::vector output_values; try { - const std::lock_guard lock(ob->gMutex); - // remove current op id from stack - - ob->opStack[fn.threadId()].pop(); for (const auto i : c10::irange(output_start, outputs.size())) { appendValueInfo( *ob, @@ -734,31 +739,37 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) { const std::string additiona_attrs = fn.isNcclMeta() ? getCommsNodeAttrs(fn) : ""; - - writeJsonNode( - ob->out, - fc.name, - fc.opId, - fn.handle(), - fc.parentId, - fc.fwParentId, - fn.seqNr(), - static_cast>(fn.scope()), - fn.threadId(), - fn.forwardThreadId(), - vectorToString(fc.inputValues), - vectorToString(fc.inputShapes), - vectorToString(fc.inputStrides), - vectorToString(fc.inputTypes), - vectorToString(output_values), - vectorToString(output_shapes), - vectorToString(output_strides), - vectorToString(output_types), - op_schema_str, - fc.kernelBackend, - fc.kernelFile, - additiona_attrs); - ob->out << ","; + { + const std::lock_guard lock(ob->gMutex); + + // remove current op id from stack + ob->opStack[fn.threadId()].pop(); + + writeJsonNode( + ob->out, + fc.name, + fc.opId, + fn.handle(), + fc.parentId, + fc.fwParentId, + fn.seqNr(), + static_cast>(fn.scope()), + fn.threadId(), + fn.forwardThreadId(), + vectorToString(fc.inputValues), + vectorToString(fc.inputShapes), + vectorToString(fc.inputStrides), + vectorToString(fc.inputTypes), + vectorToString(output_values), + vectorToString(output_shapes), + vectorToString(output_strides), + vectorToString(output_types), + op_schema_str, + fc.kernelBackend, + fc.kernelFile, + additiona_attrs); + ob->out << ","; + } } catch (const std::exception& e) { LOG(WARNING) << "Exception in execution trace observer: [" << fc.name << " (" << fc.opId << ")] " << e.what(); diff --git a/torch/csrc/profiler/stubs/base.cpp b/torch/csrc/profiler/stubs/base.cpp index bc11cab837968..6ee455ca7e97f 100644 --- a/torch/csrc/profiler/stubs/base.cpp +++ b/torch/csrc/profiler/stubs/base.cpp @@ -1,28 +1,31 @@ -#include - +#include #include +#include +#include +#include namespace torch::profiler::impl { -ProfilerStubs::~ProfilerStubs() = default; - namespace { struct DefaultStubs : public ProfilerStubs { - DefaultStubs(const char* name) : name_{name} {} + explicit DefaultStubs(const char* name) : name_{name} {} - void record(c10::DeviceIndex*, ProfilerVoidEventStub*, int64_t*) - const override { + void record( + c10::DeviceIndex* /*device*/, + ProfilerVoidEventStub* /*event*/, + int64_t* /*cpu_ns*/) const override { fail(); } - float elapsed(const ProfilerVoidEventStub*, const ProfilerVoidEventStub*) - const override { + float elapsed( + const ProfilerVoidEventStub* /*event*/, + const ProfilerVoidEventStub* /*event2*/) const override { fail(); - return 0.f; + return 0.F; } - void mark(const char*) const override { + void mark(const char* /*name*/) const override { fail(); } - void rangePush(const char*) const override { + void rangePush(const char* /*name*/) const override { fail(); } void rangePop() const override { @@ -31,7 +34,7 @@ struct DefaultStubs : public ProfilerStubs { bool enabled() const override { return false; } - void onEachDevice(std::function) const override { + void onEachDevice(std::function /*op*/) const override { fail(); } void synchronize() const override { @@ -41,7 +44,7 @@ struct DefaultStubs : public ProfilerStubs { private: void fail() const { - AT_ERROR(name_, " used in profiler but not enabled."); + TORCH_CHECK(false, name_, " used in profiler but not enabled."); } const char* const name_; diff --git a/torch/csrc/profiler/stubs/base.h b/torch/csrc/profiler/stubs/base.h index c8a0e6cd2ebbe..c64f4e5a6c9e9 100644 --- a/torch/csrc/profiler/stubs/base.h +++ b/torch/csrc/profiler/stubs/base.h @@ -33,7 +33,7 @@ struct TORCH_API ProfilerStubs { } virtual void onEachDevice(std::function op) const = 0; virtual void synchronize() const = 0; - virtual ~ProfilerStubs(); + virtual ~ProfilerStubs() = default; }; TORCH_API void registerCUDAMethods(ProfilerStubs* stubs); diff --git a/torch/csrc/profiler/unwind/communicate.h b/torch/csrc/profiler/unwind/communicate.h index 063fe542a3419..bdaca33b6db2f 100644 --- a/torch/csrc/profiler/unwind/communicate.h +++ b/torch/csrc/profiler/unwind/communicate.h @@ -1,15 +1,16 @@ #pragma once #include -#include #include #include +#include #include namespace torch::unwind { // helper to open a process with stdin/stdout/stderr streams. struct Communicate { Communicate(const char* command, const char** args) { - if (pipe(inpipe_) < 0 || pipe(outpipe_) < 0 || pipe(errpipe_) < 0) { + if (pipe(inpipe_.data()) < 0 || pipe(outpipe_.data()) < 0 || + pipe(errpipe_.data()) < 0) { throw UnwindError("pipe() failed"); } pid_t pid = fork(); @@ -29,17 +30,21 @@ struct Communicate { close(inpipe_[0]); close(outpipe_[1]); close(errpipe_[1]); - outbuf_.reset( - new __gnu_cxx::stdio_filebuf(inpipe_[1], std::ios::out)); - inbuf_.reset( - new __gnu_cxx::stdio_filebuf(outpipe_[0], std::ios::in)); - errbuf_.reset( - new __gnu_cxx::stdio_filebuf(errpipe_[0], std::ios::in)); - in_.reset(new std::istream(inbuf_.get())); - out_.reset(new std::ostream(outbuf_.get())); - err_.reset(new std::ostream(errbuf_.get())); + outbuf_ = std::make_unique<__gnu_cxx::stdio_filebuf>( + inpipe_[1], std::ios::out); + inbuf_ = std::make_unique<__gnu_cxx::stdio_filebuf>( + outpipe_[0], std::ios::in); + errbuf_ = std::make_unique<__gnu_cxx::stdio_filebuf>( + errpipe_[0], std::ios::in); + in_ = std::make_unique(inbuf_.get()); + out_ = std::make_unique(outbuf_.get()); + err_ = std::make_unique(errbuf_.get()); } } + Communicate(const Communicate&) = delete; + Communicate(Communicate&&) = delete; + Communicate& operator=(const Communicate&) = delete; + Communicate& operator=(Communicate&&) = delete; ~Communicate() { close(inpipe_[1]); close(outpipe_[0]); @@ -56,9 +61,9 @@ struct Communicate { } private: - int inpipe_[2]; - int outpipe_[2]; - int errpipe_[2]; + std::array inpipe_{-1, -1}; + std::array outpipe_{-1, -1}; + std::array errpipe_{-1, -1}; std::unique_ptr<__gnu_cxx::stdio_filebuf> outbuf_, inbuf_, errbuf_; std::unique_ptr in_; std::unique_ptr out_; diff --git a/torch/csrc/profiler/unwind/debug_info.h b/torch/csrc/profiler/unwind/debug_info.h index 067d7dc2e83e6..ac440ed198caf 100644 --- a/torch/csrc/profiler/unwind/debug_info.h +++ b/torch/csrc/profiler/unwind/debug_info.h @@ -259,6 +259,7 @@ struct DebugInfo { } } + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) Sections& s_; std::optional line_number_program_offset_; uint64_t offset_ = 0; diff --git a/torch/csrc/profiler/unwind/eh_frame_hdr.h b/torch/csrc/profiler/unwind/eh_frame_hdr.h index 09c23279d19ce..740f4beb2c85c 100644 --- a/torch/csrc/profiler/unwind/eh_frame_hdr.h +++ b/torch/csrc/profiler/unwind/eh_frame_hdr.h @@ -40,6 +40,7 @@ struct EHFrameHdr { throw UnwindError("unknown table encoding"); } } + // NOLINTNEXTLINE(performance-no-int-to-ptr) eh_frame_ = (void*)L.readEncodedOr(eh_frame_ptr_enc_, 0); fde_count_ = L.readEncodedOr(fde_count_enc_, 0); table_start_ = L.loc(); @@ -54,6 +55,7 @@ struct EHFrameHdr { .readEncoded(table_enc_); } void* fde(size_t i) const { + // NOLINTNEXTLINE(performance-no-int-to-ptr) return (void*)Lexer(table_start_, base_) .skip((2 * i + 1) * table_size_) .readEncoded(table_enc_); diff --git a/torch/csrc/profiler/unwind/fast_symbolizer.h b/torch/csrc/profiler/unwind/fast_symbolizer.h index d4201f10c013d..6a8e75c05bf63 100644 --- a/torch/csrc/profiler/unwind/fast_symbolizer.h +++ b/torch/csrc/profiler/unwind/fast_symbolizer.h @@ -7,8 +7,8 @@ #include #include #include -#include #include +#include namespace torch::unwind { diff --git a/torch/csrc/profiler/unwind/fde.h b/torch/csrc/profiler/unwind/fde.h index ea8b4ca94eaea..cb3de64486b89 100644 --- a/torch/csrc/profiler/unwind/fde.h +++ b/torch/csrc/profiler/unwind/fde.h @@ -1,4 +1,5 @@ #pragma once +#include #include #include #include @@ -41,6 +42,7 @@ struct FDE { Lexer L(data); auto length = L.read4or8Length(); void* fde_start = L.loc(); + // NOLINTNEXTLINE(performance-no-int-to-ptr) void* cie_data = (void*)((int64_t)fde_start - L.read()); Lexer LC(cie_data); auto cie_length = LC.read4or8Length(); @@ -54,17 +56,17 @@ struct FDE { if (hasAugmentation("eh")) { throw UnwindError("unsupported 'eh' augmentation string"); } - code_alignment_factor_ = LC.readULEB128(); - data_alignment_factor_ = LC.readSLEB128(); + code_alignment_factor_ = static_cast(LC.readULEB128()); + data_alignment_factor_ = static_cast(LC.readSLEB128()); if (version == 1) { ra_register_ = LC.read(); } else { - ra_register_ = LC.readULEB128(); + ra_register_ = static_cast(LC.readULEB128()); } // we assume this in the state TORCH_INTERNAL_ASSERT(ra_register_ == 16, "unexpected number of registers"); if (augmentation_string_ && *augmentation_string_ == 'z') { - augmentation_length_ = LC.readULEB128(); + augmentation_length_ = static_cast(LC.readULEB128()); Lexer A(LC.loc()); for (auto ap = augmentation_string_ + 1; *ap; ap++) { switch (*ap) { @@ -92,7 +94,7 @@ struct FDE { high_pc_ = low_pc_ + L.readEncodedValue(fde_enc); if (hasAugmentation("z")) { - augmentation_length_fde_ = L.readULEB128(); + augmentation_length_fde_ = static_cast(L.readULEB128()); } L.readEncodedOr(lsda_enc, 0); @@ -153,7 +155,7 @@ struct FDE { } last_reg_ = reg; last_offset_ = off; - state().cfa = Action::regPlusData(reg, off); + state().cfa = Action::regPlusData(static_cast(reg), off); } void def_cfa_register(int64_t reg) { def_cfa(reg, last_offset_); @@ -185,7 +187,8 @@ struct FDE { if (LOG) { (*out_) << "register " << reg << " " << rhs_reg << "\n"; } - state().registers.at(reg) = Action::regPlusData(reg, 0); + state().registers.at(reg) = + Action::regPlusData(static_cast(reg), 0); } TableState& state() { @@ -209,6 +212,7 @@ struct FDE { throw UnwindError("Address not in range"); } if (LOG) { + // NOLINTNEXTLINE(performance-no-int-to-ptr) (*out_) << "readUpTo " << (void*)addr << " for " << library_name_ << " at " << (void*)load_bias_ << "\n"; } @@ -312,6 +316,7 @@ struct FDE { case DW_CFA_expression: { auto reg = L.readULEB128(); auto len = L.readULEB128(); + // NOLINTNEXTLINE(performance-no-int-to-ptr) auto end = (void*)((uint64_t)L.loc() + len); auto op = L.read(); if ((op & 0xF0) == 0x70) { // DW_bregX @@ -327,6 +332,7 @@ struct FDE { } case DW_CFA_def_cfa_expression: { auto len = L.readULEB128(); + // NOLINTNEXTLINE(performance-no-int-to-ptr) auto end = (void*)((uint64_t)L.loc() + len); auto op = L.read(); if ((op & 0xF0) == 0x70) { // DW_bregX @@ -344,6 +350,7 @@ struct FDE { } default: { std::stringstream ss; + // NOLINTNEXTLINE(performance-no-int-to-ptr) ss << "unknown op code " << (void*)(uint64_t)lowbits; throw UnwindError(ss.str()); } @@ -372,7 +379,7 @@ struct FDE { int64_t code_alignment_factor_; int64_t data_alignment_factor_; - void* cie_data_; + void* cie_data_{nullptr}; int64_t ra_register_; uint8_t lsda_enc = DW_EH_PE_omit; @@ -388,7 +395,7 @@ struct FDE { // state accumulated while parsing instructions int64_t last_reg_ = 0; int64_t last_offset_ = 0; - uint64_t current_pc_; + uint64_t current_pc_ = 0; TableState initial_state_; // state after the initial instructions, used by restore diff --git a/torch/csrc/profiler/unwind/lexer.h b/torch/csrc/profiler/unwind/lexer.h index 117df6b9b0286..9224cd6e47e39 100644 --- a/torch/csrc/profiler/unwind/lexer.h +++ b/torch/csrc/profiler/unwind/lexer.h @@ -118,7 +118,7 @@ struct LexerImpl { void* loc() const { return (void*)next_; } - LexerImpl& skip(int64_t bytes) { + LexerImpl& skip(size_t bytes) { next_ += bytes; return *this; } diff --git a/torch/csrc/profiler/unwind/mem_file.h b/torch/csrc/profiler/unwind/mem_file.h index b5b6807a7bbce..2580e6f6da55a 100644 --- a/torch/csrc/profiler/unwind/mem_file.h +++ b/torch/csrc/profiler/unwind/mem_file.h @@ -81,7 +81,9 @@ struct MemFile { } MemFile(const MemFile&) = delete; + MemFile(MemFile&&) = delete; MemFile& operator=(const MemFile&) = delete; + MemFile& operator=(MemFile&&) = delete; [[nodiscard]] const char* data() const { return (const char*)mem_; } diff --git a/torch/csrc/profiler/unwind/unwind.cpp b/torch/csrc/profiler/unwind/unwind.cpp index 22ddf02d8452e..bed307245822f 100644 --- a/torch/csrc/profiler/unwind/unwind.cpp +++ b/torch/csrc/profiler/unwind/unwind.cpp @@ -1,7 +1,7 @@ +#include #include #include #include -#include #if !defined(__linux__) || !defined(__x86_64__) || !defined(__has_include) || \ !__has_include("ext/stdio_filebuf.h") @@ -65,6 +65,10 @@ struct UpgradeExclusive { rdlock_.unlock(); rdlock_.mutex()->lock(); } + UpgradeExclusive(const UpgradeExclusive&) = delete; + UpgradeExclusive(UpgradeExclusive&&) = delete; + UpgradeExclusive& operator=(const UpgradeExclusive&) = delete; + UpgradeExclusive& operator=(UpgradeExclusive&&) = delete; ~UpgradeExclusive() { rdlock_.mutex()->unlock(); rdlock_.lock(); @@ -121,8 +125,8 @@ static const char* process_name() { } struct Version { - uint64_t adds_ = LONG_LONG_MAX; - uint64_t subs_ = LONG_LONG_MAX; + uint64_t adds_ = LLONG_MAX; + uint64_t subs_ = LLONG_MAX; }; struct UnwindCache { @@ -498,7 +502,10 @@ Stats stats() { } // namespace torch::unwind -extern "C" void unwind_c(std::vector* result, int64_t rsp, int64_t rbp) { +extern "C" C10_USED void unwind_c( + std::vector* result, + int64_t rsp, + int64_t rbp) { std::shared_lock lock(torch::unwind::cache_mutex_); torch::unwind::UnwindState state{}; // NOLINTNEXTLINE(performance-no-int-to-ptr) diff --git a/torch/csrc/serialization.cpp b/torch/csrc/serialization.cpp index c922da900613d..7a6f61128b93e 100644 --- a/torch/csrc/serialization.cpp +++ b/torch/csrc/serialization.cpp @@ -168,7 +168,8 @@ void doRead(io fildes, void* raw_buf, size_t nbytes) { if (err == EINTR) { continue; } else { - AT_ERROR("read(): fd ", fildes, " failed with ", strerror(err)); + TORCH_CHECK( + false, "read(): fd ", fildes, " failed with ", strerror(err)); } } else if (r == 0) { break; @@ -180,7 +181,8 @@ void doRead(io fildes, void* raw_buf, size_t nbytes) { nbytes -= r; } if (nbytes != 0) { - AT_ERROR( + TORCH_CHECK( + false, "unexpected EOF, expected ", nbytes, " more bytes. The file might be corrupted."); @@ -208,7 +210,8 @@ void doWrite(io fildes, void* raw_buf, size_t nbytes) { if (err == EINTR) { continue; } else { - AT_ERROR("write(): fd ", fildes, " failed with ", strerror(err)); + TORCH_CHECK( + false, "write(): fd ", fildes, " failed with ", strerror(err)); } } buf += r; diff --git a/torch/csrc/utils/byte_order.cpp b/torch/csrc/utils/byte_order.cpp index 4432e74bd06c1..fb10cd665aa13 100644 --- a/torch/csrc/utils/byte_order.cpp +++ b/torch/csrc/utils/byte_order.cpp @@ -12,8 +12,7 @@ namespace { static inline void swapBytes16(void* ptr) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint16_t output; + uint16_t output = 0; memcpy(&output, ptr, sizeof(uint16_t)); #if defined(_MSC_VER) && !defined(_DEBUG) output = _byteswap_ushort(output); @@ -28,8 +27,7 @@ static inline void swapBytes16(void* ptr) { } static inline void swapBytes32(void* ptr) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t output; + uint32_t output = 0; memcpy(&output, ptr, sizeof(uint32_t)); #if defined(_MSC_VER) && !defined(_DEBUG) output = _byteswap_ulong(output); @@ -46,8 +44,7 @@ static inline void swapBytes32(void* ptr) { } static inline void swapBytes64(void* ptr) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint64_t output; + uint64_t output = 0; memcpy(&output, ptr, sizeof(uint64_t)); #if defined(_MSC_VER) output = _byteswap_uint64(output); @@ -70,8 +67,7 @@ static inline void swapBytes64(void* ptr) { } static inline uint16_t decodeUInt16(const uint8_t* data) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint16_t output; + uint16_t output = 0; memcpy(&output, data, sizeof(uint16_t)); return output; } @@ -83,8 +79,7 @@ static inline uint16_t decodeUInt16ByteSwapped(const uint8_t* data) { } static inline uint32_t decodeUInt32(const uint8_t* data) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint32_t output; + uint32_t output = 0; memcpy(&output, data, sizeof(uint32_t)); return output; } @@ -96,8 +91,7 @@ static inline uint32_t decodeUInt32ByteSwapped(const uint8_t* data) { } static inline uint64_t decodeUInt64(const uint8_t* data) { - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - uint64_t output; + uint64_t output = 0; memcpy(&output, data, sizeof(uint64_t)); return output; } @@ -149,6 +143,7 @@ TORCH_API void THP_decodeBuffer( bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint16_t x; c10::Half f; @@ -191,6 +186,7 @@ TORCH_API void THP_decodeBuffer( bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float f; @@ -208,6 +204,7 @@ TORCH_API void THP_decodeBuffer( bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double d; @@ -225,10 +222,12 @@ TORCH_API void THP_decodeBuffer, bool>( bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t x; float re; }; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint32_t y; float im; @@ -250,10 +249,12 @@ TORCH_API void THP_decodeBuffer, bool>( bool do_byte_swap, size_t len) { for (const auto i : c10::irange(len)) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t x; double re; }; + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) union { uint64_t y; double im; diff --git a/torch/csrc/utils/cpp_stacktraces.cpp b/torch/csrc/utils/cpp_stacktraces.cpp index 715271d76c826..06286f042ed3d 100644 --- a/torch/csrc/utils/cpp_stacktraces.cpp +++ b/torch/csrc/utils/cpp_stacktraces.cpp @@ -4,41 +4,18 @@ #include #include +#include namespace torch { namespace { bool compute_cpp_stack_traces_enabled() { - auto envar = std::getenv("TORCH_SHOW_CPP_STACKTRACES"); - if (envar) { - if (strcmp(envar, "0") == 0) { - return false; - } - if (strcmp(envar, "1") == 0) { - return true; - } - TORCH_WARN( - "ignoring invalid value for TORCH_SHOW_CPP_STACKTRACES: ", - envar, - " valid values are 0 or 1."); - } - return false; + auto envvar = c10::utils::check_env("TORCH_SHOW_CPP_STACKTRACES"); + return envvar.has_value() && envvar.value(); } bool compute_disable_addr2line() { - auto envar = std::getenv("TORCH_DISABLE_ADDR2LINE"); - if (envar) { - if (strcmp(envar, "0") == 0) { - return false; - } - if (strcmp(envar, "1") == 0) { - return true; - } - TORCH_WARN( - "ignoring invalid value for TORCH_DISABLE_ADDR2LINE: ", - envar, - " valid values are 0 or 1."); - } - return false; + auto envvar = c10::utils::check_env("TORCH_DISABLE_ADDR2LINE"); + return envvar.has_value() && envvar.value(); } } // namespace @@ -48,20 +25,19 @@ bool get_cpp_stacktraces_enabled() { } static torch::unwind::Mode compute_symbolize_mode() { - auto envar_c = std::getenv("TORCH_SYMBOLIZE_MODE"); - if (envar_c) { - std::string envar = envar_c; - if (envar == "dladdr") { + auto envar_c = c10::utils::get_env("TORCH_SYMBOLIZE_MODE"); + if (envar_c.has_value()) { + if (envar_c == "dladdr") { return unwind::Mode::dladdr; - } else if (envar == "addr2line") { + } else if (envar_c == "addr2line") { return unwind::Mode::addr2line; - } else if (envar == "fast") { + } else if (envar_c == "fast") { return unwind::Mode::fast; } else { TORCH_CHECK( false, "expected {dladdr, addr2line, fast} for TORCH_SYMBOLIZE_MODE, got ", - envar); + envar_c.value()); } } else { return compute_disable_addr2line() ? unwind::Mode::dladdr diff --git a/torch/csrc/utils/disable_torch_function.cpp b/torch/csrc/utils/disable_torch_function.cpp index e25badfe3e645..74558637b72c0 100644 --- a/torch/csrc/utils/disable_torch_function.cpp +++ b/torch/csrc/utils/disable_torch_function.cpp @@ -343,7 +343,7 @@ inline bool array_has_torch_function(PyObject* const* args, Py_ssize_t nargs) { } PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg) { - bool result; // NOLINT(cppcoreguidelines-init-variables) + bool result = false; if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) { // Fast path: // If we know that we have a tuple or list, we can skip an INCREF and diff --git a/torch/csrc/utils/init.cpp b/torch/csrc/utils/init.cpp index 391b331c4f10c..e80fbc281c8b2 100644 --- a/torch/csrc/utils/init.cpp +++ b/torch/csrc/utils/init.cpp @@ -39,7 +39,7 @@ void initThroughputBenchmarkBindings(PyObject* module) { const py::kwargs& kwargs) { // Depending on this being ScriptModule of nn.Module we will release // the GIL or not further down in the stack - return self.runOnce(std::move(args), kwargs); + return self.runOnce(args, kwargs); }) .def( "benchmark", diff --git a/torch/csrc/utils/invalid_arguments.cpp b/torch/csrc/utils/invalid_arguments.cpp index 4d69870145939..d26f8c2ee1da4 100644 --- a/torch/csrc/utils/invalid_arguments.cpp +++ b/torch/csrc/utils/invalid_arguments.cpp @@ -116,6 +116,7 @@ struct Option { Option(Option&& other) noexcept = default; Option& operator=(const Option&) = delete; Option& operator=(Option&&) = delete; + ~Option() = default; std::vector arguments; bool is_variadic; diff --git a/torch/csrc/utils/out_types.cpp b/torch/csrc/utils/out_types.cpp index 6dad9c91c18c9..4799f0ed47e35 100644 --- a/torch/csrc/utils/out_types.cpp +++ b/torch/csrc/utils/out_types.cpp @@ -16,7 +16,8 @@ void check_out_type_matches( } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) if (!scalarType_is_none && result.scalar_type() != scalarType.value()) { - AT_ERROR( + TORCH_CHECK( + false, "dtype ", // NOLINTNEXTLINE(bugprone-unchecked-optional-access) *scalarType, @@ -25,7 +26,8 @@ void check_out_type_matches( ")"); } if (layout && result.layout() != *layout) { - AT_ERROR( + TORCH_CHECK( + false, "layout ", *layout, " does not match layout of out parameter (", @@ -34,7 +36,8 @@ void check_out_type_matches( } // NOLINTNEXTLINE(bugprone-unchecked-optional-access) if (!device_is_none && result.device().type() != device.value().type()) { - AT_ERROR( + TORCH_CHECK( + false, "device type ", // NOLINTNEXTLINE(bugprone-unchecked-optional-access) device->type(), diff --git a/torch/csrc/utils/python_dispatch.cpp b/torch/csrc/utils/python_dispatch.cpp index aa87568078867..c5a659f371da0 100644 --- a/torch/csrc/utils/python_dispatch.cpp +++ b/torch/csrc/utils/python_dispatch.cpp @@ -97,6 +97,10 @@ struct EnableHermeticPyObject { c10::impl::tls_set_dispatch_key_included( at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_); } + EnableHermeticPyObject(const EnableHermeticPyObject&) = delete; + EnableHermeticPyObject(EnableHermeticPyObject&&) = delete; + EnableHermeticPyObject& operator=(const EnableHermeticPyObject&) = delete; + EnableHermeticPyObject& operator=(EnableHermeticPyObject&&) = delete; bool old_; bool old_excluded_python_; bool old_python_; @@ -638,7 +642,7 @@ void initDispatchBindings(PyObject* module) { if (!op.overload_name.empty()) { ss << "." << op.overload_name; } - names.emplace_back(ss.str()); + names.emplace_back(std::move(ss).str()); } return names; diff --git a/torch/csrc/utils/python_numbers.h b/torch/csrc/utils/python_numbers.h index d5b772b768e22..c22f752d78349 100644 --- a/torch/csrc/utils/python_numbers.h +++ b/torch/csrc/utils/python_numbers.h @@ -177,8 +177,7 @@ inline bool THPUtils_unpackNumberAsBool(PyObject* obj) { return !(real_val == 0 && imag_val == 0); } - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int overflow; + int overflow = 0; long long value = PyLong_AsLongLongAndOverflow(obj, &overflow); if (value == -1 && PyErr_Occurred()) { throw python_error(); diff --git a/torch/csrc/utils/python_torch_function_mode.h b/torch/csrc/utils/python_torch_function_mode.h index f0e6bb9acbe97..56d6329378734 100644 --- a/torch/csrc/utils/python_torch_function_mode.h +++ b/torch/csrc/utils/python_torch_function_mode.h @@ -11,6 +11,12 @@ struct StashTorchFunctionModeGuard { ~StashTorchFunctionModeGuard() { at::impl::PythonTorchFunctionTLS::push_onto_stack(cur_mode_); } + StashTorchFunctionModeGuard(const StashTorchFunctionModeGuard&) = delete; + StashTorchFunctionModeGuard(StashTorchFunctionModeGuard&&) = delete; + StashTorchFunctionModeGuard& operator=(const StashTorchFunctionModeGuard&) = + delete; + StashTorchFunctionModeGuard& operator=(StashTorchFunctionModeGuard&&) = + delete; const std::shared_ptr& get_cur_mode() { return cur_mode_; diff --git a/torch/csrc/utils/tensor_apply.cpp b/torch/csrc/utils/tensor_apply.cpp index 906b5422b3734..c8a731d8d5fe7 100644 --- a/torch/csrc/utils/tensor_apply.cpp +++ b/torch/csrc/utils/tensor_apply.cpp @@ -53,8 +53,7 @@ static void recursive_apply( } auto n = sizes[dim]; - for (const auto i : c10::irange(n)) { - (void)i; // Suppress unused variable warning + for ([[maybe_unused]] const auto i : c10::irange(n)) { recursive_apply(sizes, scalarType, dim + 1, fn, strided_data); for (auto& td : strided_data) { td.step(dim); diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index de58b1965492d..e6371498314ba 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -198,7 +198,7 @@ ScalarType infer_scalar_type(PyObject* obj) { // NOLINTNEXTLINE(bugprone-unchecked-optional-access) return *scalarType; } - AT_ERROR("Could not infer dtype of ", Py_TYPE(obj)->tp_name); + TORCH_CHECK(false, "Could not infer dtype of ", Py_TYPE(obj)->tp_name); } void recursive_store( @@ -345,6 +345,23 @@ Tensor internal_new_from_data( } #endif + if (PyObject_HasAttrString(data, "__dlpack__")) { + py::object tensor_o = + py::module::import("torch").attr("utils").attr("dlpack").attr( + "from_dlpack")(py::handle(data)); + Tensor tensor = py::cast(tensor_o); + const auto& inferred_scalar_type = + type_inference ? tensor.scalar_type() : scalar_type; + auto device = device_opt.has_value() ? *device_opt : tensor.device(); + pybind11::gil_scoped_release no_gil; + maybe_initialize_device(device); + return tensor.to( + device, + inferred_scalar_type, + /*non_blocking=*/false, + /*copy=*/copy_variables); + } + auto device = device_opt.has_value() ? *device_opt : options.device(); auto sizes = compute_sizes(data, scalar_type); @@ -853,6 +870,14 @@ class CheckSparseTensorInvariantsContext { ~CheckSparseTensorInvariantsContext() { at::globalContext().setCheckSparseTensorInvariants(state); } + CheckSparseTensorInvariantsContext( + const CheckSparseTensorInvariantsContext&) = delete; + CheckSparseTensorInvariantsContext(CheckSparseTensorInvariantsContext&&) = + delete; + CheckSparseTensorInvariantsContext& operator=( + const CheckSparseTensorInvariantsContext&) = delete; + CheckSparseTensorInvariantsContext& operator=( + CheckSparseTensorInvariantsContext&&) = delete; private: bool state; diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp index 255c74af79544..b8e9120b6c61e 100644 --- a/torch/csrc/utils/tensor_numpy.cpp +++ b/torch/csrc/utils/tensor_numpy.cpp @@ -52,14 +52,12 @@ bool is_numpy_dlpack_deleter_bugged() { #include #include #include -#include #include using namespace at; using namespace torch::autograd; -namespace torch { -namespace utils { +namespace torch::utils { bool is_numpy_available() { static bool available = []() { @@ -68,8 +66,7 @@ bool is_numpy_available() { } // Try to get exception message, print warning and return false std::string message = "Failed to initialize NumPy"; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - PyObject *type, *value, *traceback; + PyObject *type = nullptr, *value = nullptr, *traceback = nullptr; PyErr_Fetch(&type, &value, &traceback); if (auto str = value ? PyObject_Str(value) : nullptr) { if (auto enc_str = PyUnicode_AsEncodedString(str, "utf-8", "strict")) { @@ -403,10 +400,8 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { } // Extract the `obj.__cuda_array_interface__['typestr']` attribute - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - ScalarType dtype; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - int dtype_size_in_bytes; + ScalarType dtype{}; + int dtype_size_in_bytes = 0; { PyObject* py_typestr = nullptr; if (PyDict_GetItemStringRef(cuda_dict, "typestr", &py_typestr) < 0) { @@ -415,8 +410,7 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { if (py_typestr == nullptr) { throw TypeError("attribute `typestr` must exist"); } - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - PyArray_Descr* descr; + PyArray_Descr* descr = nullptr; TORCH_CHECK_VALUE( PyArray_DescrConverter(py_typestr, &descr), "cannot parse `typestr`"); dtype = numpy_dtype_to_aten(descr->type_num); @@ -429,8 +423,7 @@ at::Tensor tensor_from_cuda_array_interface(PyObject* obj) { } // Extract the `obj.__cuda_array_interface__['data']` attribute - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - void* data_ptr; + void* data_ptr = nullptr; { PyObject* py_data = nullptr; if (PyDict_GetItemStringRef(cuda_dict, "data", &py_data) < 0) { @@ -573,7 +566,6 @@ void validate_numpy_for_dlpack_deleter_bug() { bool is_numpy_dlpack_deleter_bugged() { return numpy_with_dlpack_deleter_bug_installed; } -} // namespace utils -} // namespace torch +} // namespace torch::utils #endif // USE_NUMPY diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index 7dacce7bce238..00f60a8a1f7fb 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -60,7 +60,7 @@ const char* backend_to_string(const at::Backend& backend) { case at::Backend::Meta: return "torch.meta"; default: - AT_ERROR("Unimplemented backend ", backend); + TORCH_CHECK(false, "Unimplemented backend ", backend); } } diff --git a/torch/csrc/utils/torch_dispatch_mode.h b/torch/csrc/utils/torch_dispatch_mode.h index 2eb8ba7a1cbbb..8fe5404b44a28 100644 --- a/torch/csrc/utils/torch_dispatch_mode.h +++ b/torch/csrc/utils/torch_dispatch_mode.h @@ -27,6 +27,12 @@ struct StashTorchDispatchModeGuard { std::move(saved_mode_)); } } + StashTorchDispatchModeGuard(const StashTorchDispatchModeGuard&) = delete; + StashTorchDispatchModeGuard(StashTorchDispatchModeGuard&&) = delete; + StashTorchDispatchModeGuard& operator=(const StashTorchDispatchModeGuard&) = + delete; + StashTorchDispatchModeGuard& operator=(StashTorchDispatchModeGuard&&) = + delete; const std::shared_ptr& get_cur_mode() { return saved_mode_; @@ -44,6 +50,12 @@ struct StashTorchDispatchStackGuard { c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_)); saved_state_ = std::move(old); } + StashTorchDispatchStackGuard(const StashTorchDispatchStackGuard&) = delete; + StashTorchDispatchStackGuard(StashTorchDispatchStackGuard&&) = delete; + StashTorchDispatchStackGuard& operator=(const StashTorchDispatchStackGuard&) = + delete; + StashTorchDispatchStackGuard& operator=(StashTorchDispatchStackGuard&&) = + delete; ~StashTorchDispatchStackGuard() { c10::impl::TorchDispatchModeTLS::set_state(std::move(saved_state_)); diff --git a/torch/csrc/xpu/Module.cpp b/torch/csrc/xpu/Module.cpp index 6e6c9a4564b65..bd417a8de5a17 100644 --- a/torch/csrc/xpu/Module.cpp +++ b/torch/csrc/xpu/Module.cpp @@ -39,6 +39,17 @@ static void poison_fork() { // XPU management methods +PyObject* THXPModule_getArchFlags(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS +#ifdef XPU_ARCH_FLAGS + static const char* flags = C10_STRINGIZE(XPU_ARCH_FLAGS); + return THPUtils_packString(flags); +#else + Py_RETURN_NONE; +#endif + END_HANDLE_TH_ERRORS +} + static PyObject* THXPModule_isInBadFork_wrap(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS return PyBool_FromLong(in_bad_fork); @@ -363,7 +374,7 @@ static PyObject* THXPModule_initExtension(PyObject* self, PyObject* noargs) { HANDLE_TH_ERRORS TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level poison_fork(); - at::globalContext().lazyInitXPU(); + at::globalContext().lazyInitDevice(c10::DeviceType::XPU); auto m = THPObjectPtr(PyImport_ImportModule("torch.xpu")); if (!m) @@ -404,6 +415,7 @@ static struct PyMethodDef _THXPModule_methods[] = { THXPModule_getDeviceCount_wrap, METH_NOARGS, nullptr}, + {"_xpu_getArchFlags", THXPModule_getArchFlags, METH_NOARGS, nullptr}, {"_xpu_isInBadFork", THXPModule_isInBadFork_wrap, METH_NOARGS, nullptr}, {"_xpu_getCurrentStream", THXPModule_getCurrentStream_wrap, diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 356de1a573097..7e17f9ccb6da0 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -186,7 +186,7 @@ def _check_capability(): work properly, but your PyTorch was compiled with CUDA_VERSION %d. Please install the correct PyTorch binary using instructions from https://pytorch.org - """ + """ # noqa: F841 old_gpu_warn = """ Found GPU%d %s which is of cuda capability %d.%d. @@ -195,7 +195,7 @@ def _check_capability(): """ if torch.version.cuda is not None: # on ROCm we don't want this check - CUDA_VERSION = torch._C._cuda_getCompiledVersion() + CUDA_VERSION = torch._C._cuda_getCompiledVersion() # noqa: F841 for d in range(device_count()): capability = get_device_capability(d) major = capability[0] @@ -750,11 +750,15 @@ def _raw_device_uuid_amdsmi() -> Optional[List[str]]: warnings.warn("Cannot get amd device handler") return None try: - uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler) + uuid = amdsmi.amdsmi_get_gpu_asic_info(handler)["asic_serial"][ + 2: + ] # Removes 0x prefix from serial except amdsmi.AmdSmiException: warnings.warn("Cannot get uuid for amd device") return None - uuids.append(str(uuid)) + uuids.append( + str(uuid).lower() + ) # Lower-case to match expected HIP_VISIBLE_DEVICES uuid input return uuids @@ -793,7 +797,7 @@ def _raw_device_uuid_nvml() -> Optional[List[str]]: def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]: r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs.""" - def uuid_to_orinal(candidate: str, uuids: List[str]) -> int: + def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int: best_match = -1 for idx, uuid in enumerate(uuids): if not uuid.startswith(candidate): @@ -806,7 +810,11 @@ def uuid_to_orinal(candidate: str, uuids: List[str]) -> int: rc: List[int] = [] for candidate in candidates: - idx = uuid_to_orinal(candidate, uuids) + if torch.version.hip: + candidate = candidate.replace( + "GPU-", "", 1 + ) # Remove GPU-prefix to match amdsmi asic serial + idx = uuid_to_ordinal(candidate, uuids) # First invalid ordinal stops parsing if idx < 0: break @@ -823,7 +831,12 @@ def _device_count_amdsmi() -> int: return 0 try: if type(visible_devices[0]) is str: - return -1 + uuids = _raw_device_uuid_amdsmi() + if uuids is None: + return -1 + # Create string version of visible devices to avoid mypy warnings + visible_device_str = cast(List[str], visible_devices) + visible_devices = _transform_uuid_to_ordinals(visible_device_str, uuids) else: raw_cnt = _raw_device_count_amdsmi() if raw_cnt <= 0: @@ -1082,7 +1095,13 @@ def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: idx = _get_device_index(device, optional=True) visible_devices = _parse_visible_devices() if type(visible_devices[0]) is str: - raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings") + uuids = _raw_device_uuid_amdsmi() + if uuids is None: + raise RuntimeError("Can't get device UUIDs") + visible_devices_str = cast( + List[str], visible_devices + ) # Create str variable for mypy + visible_devices = _transform_uuid_to_ordinals(visible_devices_str, uuids) idx_map = dict(enumerate(cast(List[int], visible_devices))) if idx not in idx_map: raise RuntimeError( diff --git a/torch/cuda/_memory_viz.py b/torch/cuda/_memory_viz.py index 2047ec4efb28f..b03a5236184e0 100644 --- a/torch/cuda/_memory_viz.py +++ b/torch/cuda/_memory_viz.py @@ -213,7 +213,6 @@ def segsum(data): Args: data: snapshot dictionary created from _snapshot() """ - segments = [] out = io.StringIO() out.write(f"Summary of segments >= {Bytes(PAGE_SIZE)} in size\n") total_reserved = 0 @@ -272,7 +271,6 @@ def segsum(data): out.write(f'segments: {len(data["segments"])}\n') out.write(f'total_reserved: {Bytes(total_reserved)}\n') out.write(f'total_allocated: {Bytes(total_allocated)}\n') - internal_external = f' ({Bytes(free_internal)} internal + {Bytes(free_external)} external)' if free_internal else '' out.write(f'total_free: {_report_free(free_external, free_internal)}\n') out.write(legend) assert free_internal + free_external + total_allocated == total_reserved @@ -478,10 +476,8 @@ def free(alloc, device): kv_to_elem = {} - - # create the device trace - for time, action, (tensor_key, version), size in memory_profile.timeline: + for _time, action, (tensor_key, version), size in memory_profile.timeline: if not isinstance(tensor_key, TensorKey): continue if action == Action.CREATE: diff --git a/torch/cuda/_sanitizer.py b/torch/cuda/_sanitizer.py index ab03485085878..01f40421425a1 100644 --- a/torch/cuda/_sanitizer.py +++ b/torch/cuda/_sanitizer.py @@ -16,6 +16,7 @@ import inspect import io import logging +import re import sys import textwrap import traceback @@ -41,6 +42,10 @@ logger = logging.getLogger(__name__) +# Note that this is only factories that take Tensor as input as they are +# the ones we care about. +FACTORY_FUNCTION_REGEX = re.compile("(new_.*|.*_like)") + class AccessType(enum.Enum): READ = enum.auto() @@ -486,6 +491,7 @@ def _handle_argument( self, value: Any, is_write: bool, + metadata_only: bool, name: Optional[str] = None, is_output: bool = False, ) -> None: @@ -493,7 +499,7 @@ def _handle_argument( data_ptr = value.data_ptr() if is_write: self.dataptrs_written.add(data_ptr) - else: + elif not metadata_only: self.dataptrs_read.add(data_ptr) self.tensor_aliases.setdefault(data_ptr, []) @@ -507,21 +513,42 @@ def parse_inputs( schema: torch.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any], + *, + is_factory: bool, ) -> None: for argument, value in zip_arguments(schema, args, kwargs): is_write = argument.alias_info is not None and argument.alias_info.is_write + # A change is metadata only if it is a view or a factory function that + # reads only metadata + metadata_only = is_factory or ( + argument.alias_info is not None and not argument.alias_info.is_write + ) pytree.tree_map_( functools.partial( - self._handle_argument, is_write=is_write, name=argument.name + self._handle_argument, + is_write=is_write, + name=argument.name, + metadata_only=metadata_only, ), value, ) - def parse_outputs(self, outputs: Any) -> None: - pytree.tree_map_( - functools.partial(self._handle_argument, is_write=True, is_output=True), - outputs, - ) + def parse_outputs( + self, schema: torch.FunctionSchema, outputs: Any, *, is_factory: bool + ) -> None: + for res, value in zip(schema.returns, (outputs,)): + metadata_only = is_factory or ( + res.alias_info is not None and not res.alias_info.is_write + ) + pytree.tree_map_( + functools.partial( + self._handle_argument, + is_write=not metadata_only, + is_output=True, + metadata_only=metadata_only, + ), + value, + ) class CUDASanitizerDispatchMode(TorchDispatchMode): @@ -563,12 +590,14 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + is_factory = bool(FACTORY_FUNCTION_REGEX.match(func._schema.name)) + argument_handler = ArgumentHandler() - argument_handler.parse_inputs(func._schema, args, kwargs) + argument_handler.parse_inputs(func._schema, args, kwargs, is_factory=is_factory) outputs = func(*args, **kwargs) - argument_handler.parse_outputs(outputs) + argument_handler.parse_outputs(func._schema, outputs, is_factory=is_factory) errors = self.event_handler._handle_kernel_launch( torch.cuda.current_stream().cuda_stream, argument_handler.dataptrs_read - argument_handler.dataptrs_written, @@ -602,9 +631,20 @@ def enable(self): self.dispatch.__enter__() self.enabled = True + def disable(self): + self.dispatch.__exit__(None, None, None) + self.enabled = False + def __del__(self): - if self.enabled: - self.dispatch.__exit__(None, None, None) + # Since this object lifetime is linked to the `torch.cuda._sanitizer` python + # module, it often gets deleted as part of the overall `torch` module cleanup + # At that time, depending on CPython version, the torch.* module might be in + # different states of being already cleaned up. + # Similarly other imports might already have been cleaned up so `sys` might + # be already gone as well. + # Skip exiting the mode if it outlived the runtime. + if (sys is not None) and (not sys.is_finalizing()) and self.enabled: + self.disable() def enable_cuda_sanitizer(): diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index b5de9f73df726..226278aabc1f8 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -357,11 +357,10 @@ def make_graphed_callables( # Capture backward graphs in reverse order per_callable_static_grad_outputs = [] per_callable_static_grad_inputs = [] - for static_input_surface, static_outputs, bwd_graph, module_params in zip( + for static_input_surface, static_outputs, bwd_graph in zip( reversed(per_callable_static_input_surfaces), reversed(per_callable_static_outputs), reversed(bwd_graphs), - reversed(per_callable_module_params), ): # For now, assumes all static_outputs require grad # assert all(o.requires_grad for o in static_outputs), "Outputs of graphed callables must require grad." diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 8c34761771337..145458de3040f 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -72,11 +72,13 @@ torch._C.__dict__["_cuda_endAllocateCurrentStreamToPool"] = _dummy_type( "_cuda_endAllocateCurrentStreamToPool" ) + torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool") from torch._C import ( # noqa: F401 _cuda_beginAllocateToPool, _cuda_CUDAAllocator, _cuda_endAllocateCurrentStreamToPool, + _cuda_releasePool, _MemPool, _MemPoolContext, ) @@ -978,6 +980,25 @@ def _get_current_allocator() -> _CUDAAllocator: return _CUDAAllocator(torch._C._cuda_getAllocator()) +class MemPoolContext(_MemPoolContext): + r"""MemPoolContext holds the currently active pool and stashes the previous + pool. On deletion it makes the previous pool active. + + Args: + pool(torch.cuda.MemPool): a MemPool object to be made active so that + allocations route to this pool. + + """ + + def __init__(self, pool: _MemPool): + super().__init__(pool) + + @staticmethod + def active_pool() -> Optional[_MemPool]: + r"""Returns the active MemPool""" + return _MemPoolContext.active_pool() + + class MemPool(_MemPool): r"""MemPool represents a pool of memory in a caching allocator. Currently, it's just the ID of the pool object maintained in the CUDACachingAllocator. @@ -1001,27 +1022,30 @@ def id(self) -> Tuple[int, int]: @property def allocator(self) -> Optional[_cuda_CUDAAllocator]: - r"""Returns the allocator this MemPool routes allocations to""" + r"""Returns the allocator this MemPool routes allocations to.""" return super().allocator + def use_count(self) -> int: + r"""Returns the reference count of this pool.""" + return super().use_count() -class MemPoolContext(_MemPoolContext): - r"""MemPoolContext holds the currently active pool and stashes the previous - pool. On deletion it makes the previous pool active. + def snapshot(self): + r"""Return a snapshot of the CUDA memory allocator pool state across all + devices. - Args: - pool(torch.cuda.MemPool): a MemPool object to be made active so that - allocations route to this pool. - - """ + Interpreting the output of this function requires familiarity with the + memory allocator internals. - def __init__(self, pool: MemPool): - super().__init__(pool) - - @staticmethod - def active_pool() -> Optional[_MemPool]: - r"""Returns the active MemPool""" - return _MemPoolContext.active_pool() + .. note:: + See :ref:`cuda-memory-management` for more details about GPU memory + management. + """ + try: + ctx = MemPoolContext(self) + snapshot = torch.cuda.memory_snapshot() + finally: + del ctx + return snapshot @contextlib.contextmanager @@ -1045,4 +1069,5 @@ def use_mem_pool(pool: MemPool, device: Union[Device, int] = None): yield finally: _cuda_endAllocateCurrentStreamToPool(device_index, pool.id) + _cuda_releasePool(device_index, pool.id) del ctx diff --git a/torch/cuda/tunable.py b/torch/cuda/tunable.py index 8b387102b43dc..6ac9f14ab55a2 100644 --- a/torch/cuda/tunable.py +++ b/torch/cuda/tunable.py @@ -112,6 +112,7 @@ C++ or Python APIs. """ +import warnings from typing import Optional, Tuple import torch @@ -122,6 +123,8 @@ "is_enabled", "tuning_enable", "tuning_is_enabled", + "record_untuned_enable", + "record_untuned_is_enabled", "set_max_tuning_duration", "get_max_tuning_duration", "set_max_tuning_iterations", @@ -133,6 +136,7 @@ "write_file_on_exit", "write_file", "read_file", + "tune_gemm_in_file", ] @@ -160,6 +164,19 @@ def tuning_is_enabled() -> bool: return torch._C._cuda_tunableop_tuning_is_enabled() # type: ignore[attr-defined] +def record_untuned_enable(val: bool = True) -> None: + r"""Enable recording untuned of TunableOp perations for offline tuning. + + When enabled, if a tuned entry isn't found, write it to the untuned file. + """ + torch._C._cuda_record_untuned_enable(val) # type: ignore[attr-defined] + + +def record_untuned_is_enabled() -> bool: + r"""Returns whether TunableOp operations are recorded for offline tuning.""" + return torch._C._cuda_record_untuned_is_enabled() # type: ignore[attr-defined] + + def set_max_tuning_duration(duration: int) -> None: r"""Set max time in milliseconds to spend tuning a given solution. @@ -240,3 +257,64 @@ def read_file(filename: Optional[str] = None) -> bool: if filename is None: filename = get_filename() return torch._C._cuda_tunableop_read_file(filename) # type: ignore[attr-defined] + + +def tune_gemm_in_file(filename: str) -> None: + r"""tune GEMM in file.""" + + assert is_enabled() + assert tuning_is_enabled() + + with open(filename) as file: + for line in file: + if line.startswith("Gemm"): + untuned_gemm = line.strip().split(",")[:] + [op_sig, data_type, layout] = untuned_gemm[0].split("_") + + transA = True if layout[0] == "T" else False + transB = True if layout[1] == "T" else False + + dtype = { + "float": torch.float32, + "double": torch.float64, + "BFloat16": torch.bfloat16, + "Half": torch.half, + "c10::complex": torch.complex128, + "c10::complex": torch.complex64, + "Float8_e4m3fn": torch.float8_e4m3fn, + "Float8_e5m2": torch.float8_e5m2, + "Float8_e4m3fnuz": torch.float8_e4m3fnuz, + "Float8_e5m2fnuz": torch.float8_e5m2fnuz, + }.get(data_type, torch.half) + + if op_sig == "GemmTunableOp": + [n, m, k] = [int(g) for g in untuned_gemm[1].split("_")[1:]] + matA = ( + torch.rand(k, m, dtype=dtype, device="cuda").t() + if transB + else torch.rand(m, k, dtype=dtype, device="cuda") + ) + matB = ( + torch.rand(n, k, dtype=dtype, device="cuda").t() + if transA + else torch.rand(k, n, dtype=dtype, device="cuda") + ) + torch.mm(matA, matB) + elif op_sig == "GemmStridedBatchedTunableOp": + [n, m, k] = [int(g) for g in untuned_gemm[1].split("_")[1:4]] + [b] = [int(g) for g in untuned_gemm[1].split("_")[5:6]] + matA = ( + torch.rand(b, k, m, dtype=dtype, device="cuda") + if transB + else torch.rand(b, m, k, dtype=dtype, device="cuda") + ) + matB = ( + torch.rand(b, n, k, dtype=dtype, device="cuda") + if transA + else torch.rand(b, k, n, dtype=dtype, device="cuda") + ) + matA = matA.transpose(1, 2) if transB else matA + matB = matB.transpose(1, 2) if transA else matB + torch.bmm(matA, matB) + else: + warnings.warn(f"error: unkown op {op_sig}") diff --git a/torch/custom_class.h b/torch/custom_class.h index f97e21f09e7fc..6893eeca93106 100644 --- a/torch/custom_class.h +++ b/torch/custom_class.h @@ -12,6 +12,8 @@ #include #include #include + +#include #include namespace torch { @@ -117,7 +119,7 @@ class class_ : public ::torch::detail::class_base { c10::tagged_capsule self, ParameterTypes... arg) { c10::intrusive_ptr classObj = - at::guts::invoke(func, std::forward(arg)...); + std::invoke(func, std::forward(arg)...); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(classObj)); }; @@ -325,7 +327,7 @@ class class_ : public ::torch::detail::class_base { c10::tagged_capsule self, SetStateArg arg) { c10::intrusive_ptr classObj = - at::guts::invoke(set_state, std::move(arg)); + std::invoke(set_state, std::move(arg)); auto object = self.ivalue.toObject(); object->setSlot(0, c10::IValue::make_capsule(classObj)); }; diff --git a/torch/custom_class_detail.h b/torch/custom_class_detail.h index 138cae75ef67b..81538d26a2258 100644 --- a/torch/custom_class_detail.h +++ b/torch/custom_class_detail.h @@ -6,6 +6,8 @@ #include #include +#include + namespace torch { namespace detail { @@ -80,7 +82,7 @@ struct WrapMethod { WrapMethod(R (CurrClass::*m)(Args...)) : m(std::move(m)) {} R operator()(c10::intrusive_ptr cur, Args... args) { - return c10::guts::invoke(m, *cur, args...); + return std::invoke(m, *cur, args...); } R (CurrClass::*m)(Args...); @@ -91,7 +93,7 @@ struct WrapMethod { WrapMethod(R (CurrClass::*m)(Args...) const) : m(std::move(m)) {} R operator()(c10::intrusive_ptr cur, Args... args) { - return c10::guts::invoke(m, *cur, args...); + return std::invoke(m, *cur, args...); } R (CurrClass::*m)(Args...) const; diff --git a/torch/distributed/_composable/fsdp/_fsdp_collectives.py b/torch/distributed/_composable/fsdp/_fsdp_collectives.py index 2a77aefaf1855..0032f6f635006 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_collectives.py +++ b/torch/distributed/_composable/fsdp/_fsdp_collectives.py @@ -2,7 +2,6 @@ from typing import cast, List, NamedTuple, Optional, Tuple, Union import torch -import torch._dynamo.compiled_autograd as ca import torch.distributed as dist from torch.distributed.device_mesh import _get_device_handle from torch.distributed.distributed_c10d import ReduceOp @@ -12,6 +11,7 @@ _get_dim0_padded_size, _raise_assert_with_print, _to_dtype_if_needed, + compiled_autograd_enabled, ) from ._fsdp_param import FSDPParam, ShardedState @@ -183,7 +183,7 @@ def foreach_all_gather( def _get_param_all_gather_inputs( fsdp_params: List[FSDPParam], ) -> List[List[torch.Tensor]]: - if ca.compiled_autograd_enabled: + if compiled_autograd_enabled(): return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params] # Intentionally try to run a fast-path that bypasses abstractions for the @@ -245,41 +245,76 @@ def foreach_all_gather_copy_out( param_all_gather_input_numels, all_gather_input_split_sizes, ) = all_gather_result - dtype, device = all_gather_output.dtype, all_gather_output.device + _dtype, device = all_gather_output.dtype, all_gather_output.device device_handle = _get_device_handle(device.type) if all_gather_event is not None: # sync op device_handle.current_stream().wait_event(all_gather_event) if isinstance(all_gather_work, dist.distributed_c10d.Work): # async op all_gather_work.wait() world_size, device = group.size(), all_gather_output.device + + split_with_sizes_out: List[torch.Tensor] = [] + shard_i_copy_infos: List[Tuple[FSDPParam, List[torch.Tensor]]] = [] for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip( param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params ): - if ca.compiled_autograd_enabled: - fsdp_param.init_all_gather_outputs( - all_gather_input_numels, - all_gather_input_dtypes, - world_size, - device, - # NOTE: Under compile, make sure we always recreate all_gather_outputs - # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2]. - force_recreate=True, - ) - else: - fsdp_param.init_all_gather_outputs( - all_gather_input_numels, all_gather_input_dtypes, world_size, device - ) # no-op after 1st call + # NOTE: Under compile, make sure we always recreate all_gather_outputs + # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2]. + force_recreate = compiled_autograd_enabled() + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, + all_gather_input_dtypes, + world_size, + device, + force_recreate=force_recreate, + ) + if not force_recreate: fsdp_param.alloc_all_gather_outputs() + param_all_gather_outputs = fsdp_param.all_gather_outputs + if fsdp_param.fsdp_placement.dim != 0: + # Copy to a temporary and then chunk-cat into the final all-gather + # output tensors + param_all_gather_outputs = [ + torch.empty_like(t) for t in param_all_gather_outputs + ] + shard_i_copy_infos.append((fsdp_param, param_all_gather_outputs)) + split_with_sizes_out.extend(param_all_gather_outputs) + all_gather_output = all_gather_output.view(world_size, -1) - gen = (t for fsdp_param in fsdp_params for t in fsdp_param.all_gather_outputs) if all_gather_output.dtype == torch.uint8: - out = [t.view(world_size, -1).view(torch.uint8) for t in gen] + out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out] else: - out = [t.view(world_size, -1) for t in gen] + out = [t.view(world_size, -1) for t in split_with_sizes_out] torch.ops.fsdp.split_with_sizes_copy( all_gather_output, all_gather_input_split_sizes, dim=1, out=out ) + for fsdp_param, param_all_gather_outputs in shard_i_copy_infos: + # Chunk-cat from the temporary to the final all-gather output tensors + shard_dim = fsdp_param.fsdp_placement.dim + for param_all_gather_output, target_all_gather_output in zip( + param_all_gather_outputs, fsdp_param.all_gather_outputs + ): + padded_sharded_size = ( + fsdp_param.padded_sharded_param_size + if fsdp_param.sharded_state == ShardedState.SHARDED + else cast( + torch.Tensor, fsdp_param._sharded_post_forward_param_data + ).size() + ) + pre_param_size = list(padded_sharded_size) + pre_param_size[0] *= world_size + chunks = torch.chunk( + param_all_gather_output.view(pre_param_size), world_size, dim=0 + ) + post_param_size = list(padded_sharded_size) + post_param_size[shard_dim] *= world_size + cat_out = target_all_gather_output.view(post_param_size) + torch.cat(chunks, dim=shard_dim, out=cat_out) + torch._C._autograd._unsafe_set_version_counter( + target_all_gather_output, target_all_gather_output._version - 1 + ) + @torch.no_grad() def foreach_reduce( @@ -313,6 +348,14 @@ def foreach_reduce( reduce_scatter_group, all_reduce_group, reduce_dtype ) world_size = reduce_scatter_group.size() + for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)): + if (shard_dim := fsdp_param.fsdp_placement.dim) == 0: + continue + assert ( + unsharded_grad.size(shard_dim) % world_size == 0 + ), f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}" + chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim) + unsharded_grads[i] = torch.cat(chunks, dim=0) padded_unsharded_sizes = tuple( _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads ) @@ -374,6 +417,8 @@ def foreach_reduce( for padded_unsharded_size, fsdp_param in zip( padded_unsharded_sizes, fsdp_params ): + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides new_sharded_grad = torch.as_strided( reduce_output, size=fsdp_param.sharded_size, @@ -404,7 +449,7 @@ def foreach_reduce( new_sharded_grad ) fsdp_param.sharded_param.grad = new_sharded_dtensor_grad - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): for hook in ( getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {}) or {} diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index 31b74079aaa8b..d967a55d25454 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -6,7 +6,6 @@ from typing import Any, cast, List, Optional import torch -import torch._dynamo.compiled_autograd as ca import torch.distributed as dist import torch.nn as nn from torch.distributed._composable.contract import _get_registry @@ -14,6 +13,36 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec +_compiled_autograd_enabled: bool = False + +if torch._running_with_deploy(): + + def detect_compiled_autograd(): + pass + + def compiled_autograd_enabled(): + return False + +else: + + def detect_compiled_autograd(): + assert ( + not torch.compiler.is_compiling() + ), "`detect_compiled_autograd()` is designed to be called in eager mode" + global _compiled_autograd_enabled + import torch._dynamo.compiled_autograd as ca + + _compiled_autograd_enabled = ( + ca.enabled() + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region() + ) + + def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled + + @dataclass class DataParallelMeshInfo: mesh: DeviceMesh @@ -98,13 +127,15 @@ def _chunk_with_empty( return chunks -def _get_dim0_chunked_size( - chunk: torch.Tensor, unchunked_size: torch.Size +def _get_dim_chunked_size( + chunk: torch.Tensor, unchunked_size: torch.Size, dim: int ) -> torch.Size: if chunk.numel() > 0: return chunk.size() - # For 0 numel, we need to preserve trailing dims for DTensor APIs - return cast(torch.Size, torch.Size([0]) + unchunked_size[1:]) + # For 0 numel, we need to preserve nonzero-sized dims for DTensor APIs + return cast( + torch.Size, unchunked_size[:dim] + torch.Size([0]) + unchunked_size[dim + 1 :] + ) def _from_local_no_grad( @@ -116,7 +147,7 @@ def _from_local_no_grad( it avoids some CPU overhead by avoiding default args and not being differentiable. """ - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): return DTensor( # Use the local tensor directly instead of constructing a new tensor # variable, e.g. with `view_as()`, since this is not differentiable diff --git a/torch/distributed/_composable/fsdp/_fsdp_param.py b/torch/distributed/_composable/fsdp/_fsdp_param.py index e80fc3649a390..ac66b6f3300d5 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param.py @@ -3,10 +3,9 @@ import itertools from dataclasses import dataclass, field from enum import auto, Enum -from typing import Any, cast, List, Optional, Sequence, Tuple +from typing import Any, Callable, cast, List, Optional, Sequence, Tuple import torch -import torch._dynamo.compiled_autograd as ca import torch.nn as nn from torch._prims_common import make_contiguous_strides_for from torch.distributed._functional_collectives import AsyncCollectiveTensor @@ -19,9 +18,10 @@ from ._fsdp_common import ( _chunk_with_empty, _from_local_no_grad, - _get_dim0_chunked_size, + _get_dim_chunked_size, _raise_assert_with_print, _to_dtype_if_needed, + compiled_autograd_enabled, FSDPMeshInfo, HSDPMeshInfo, ) @@ -32,8 +32,8 @@ FSDP considers the following tensors: - Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one on the module when applying FSDP -- Sharded parameter: sharding the original parameter on dim-0 as a DTensor - over the main mesh +- Sharded parameter: sharding the original parameter on dim-0 (or a + user-specified dim) as a DTensor over the main mesh - All-gather inputs: the ``torch.Tensor`` or ``Tensor`` s passed to all-gather, derived from the sharded parameter - All-gather output: the ``torch.Tensor`` or ``Tensor`` s resulting from @@ -221,6 +221,7 @@ def __init__( mesh_info: FSDPMeshInfo, post_forward_mesh_info: Optional[FSDPMeshInfo], device: torch.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], mp_policy: MixedPrecisionPolicy, offload_policy: OffloadPolicy, ): @@ -234,7 +235,7 @@ def __init__( self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory ) self.grad_offload_event: Optional[torch.Event] = None - self._init_sharded_param(param, device) + self._init_sharded_param(param, device, shard_placement_fn) if self.post_forward_mesh_info: self._init_sharded_post_forward_param_metadata(param) self._init_extensions() @@ -250,7 +251,12 @@ def __init__( ) @torch.no_grad() - def _init_sharded_param(self, param: nn.Parameter, device: torch.device): + def _init_sharded_param( + self, + param: nn.Parameter, + device: torch.device, + shard_placement_fn: Optional[Callable], + ): if param.device != device and param.device.type != "meta": raise AssertionError( f"Expects the parameter to already be moved to device {device} but got {param.device}" @@ -259,6 +265,14 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): raise NotImplementedError( f"FSDP does not support non-contiguous parameters yet: {param.shape=} {param.stride()=}" ) + fsdp_placement = shard_placement_fn(param) if shard_placement_fn else None + if fsdp_placement is None: + fsdp_placement = Shard(0) + elif fsdp_placement.dim < 0: + fsdp_placement = Shard(fsdp_placement.dim + param.ndim) + assert isinstance(fsdp_placement, Shard), f"{fsdp_placement}" + self.fsdp_placement = fsdp_placement + shard_dim = fsdp_placement.dim # TODO: Replace the sharded DTensor parameter construction logic with # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101 # TODO: Simplify the following sharded parameter padding logic after @@ -276,7 +290,6 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" ) - name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" assert dp_mesh.mesh_dim_names is not None, name_dims_error assert tp_mesh.mesh_dim_names is not None, name_dims_error @@ -286,16 +299,16 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): raise NotImplementedError( f"FSDP only supports 1D TP, not {self._tp_spec.placements}" ) - split_factor = self._tp_spec.num_shards_map[0] + split_factor = self._tp_spec.num_shards_map[shard_dim] assert ( 2 <= self._spmd_mesh.ndim <= 3 ), f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." self._spmd_placements: Tuple[Placement, ...] dp_shard_tp_placement = ( ( - _StridedShard(0, split_factor=split_factor) + _StridedShard(shard_dim, split_factor=split_factor) if split_factor > 1 - else Shard(0) + else fsdp_placement ), self._tp_spec.placements[0], ) @@ -309,8 +322,7 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): self._spmd_placements, tensor_meta=self._tp_spec.tensor_meta, ) - # NOTE: FSDP+TP does not support uneven sharding for now - # TODO: enable uneven sharding for FSDP+TP + # TODO: Enable uneven sharding for FSDP+TP. if split_factor > 1: # FSDP has strided sharding on tensor dim 0 num_shards = self._sharding_spec.num_shards_map[0] tensor_size_dim_0 = self._sharding_spec.shape[0] @@ -320,38 +332,52 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): f"tensor dim 0 has size {tensor_size_dim_0} which cannot be " f"evenly sharded into {num_shards} shards." ) - param_data = cast(DTensor, param)._local_tensor else: self._spmd_mesh = self.mesh_info.mesh if isinstance(self.mesh_info, HSDPMeshInfo): - self._spmd_placements = (Replicate(), Shard(0)) + self._spmd_placements = (Replicate(), fsdp_placement) else: - self._spmd_placements = (Shard(0),) + self._spmd_placements = (fsdp_placement,) self._sharding_spec = DTensorSpec( self._spmd_mesh, self._spmd_placements, - tensor_meta=TensorMeta( - param.size(), - param.stride(), - param.dtype, - ), + tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), ) param_data = param assert param_data.is_contiguous(), f"{param_data.shape=} {param_data.stride()=}" + shard_dim = fsdp_placement.dim + if shard_dim >= param_data.ndim: + raise AssertionError( + f"Shard dim {shard_dim} is invalid for {param_data.ndim}D tensor: {param.shape}" + ) self._orig_size = param_data.size() self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) shard_rank = self.mesh_info.shard_mesh_rank shard_world_size = self.mesh_info.shard_mesh_size - chunks = _chunk_with_empty(param_data, shard_world_size, dim=0) + if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0: + # If sharding on nonzero dim, require even sharding for now because + # the uneven sharding (1) requires extra copies before/after FSDP + # collectives and (2) introduces extra complexity to handle padding + # and unpadding + raise NotImplementedError( + f"FSDP does not support uneven sharding on dim {shard_dim}: " + f"{param_data.size()} (world size: {shard_world_size})" + ) + chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim) sharded_param = chunks[shard_rank] - self.sharded_size = _get_dim0_chunked_size(sharded_param, param_data.size()) + self.sharded_size = _get_dim_chunked_size( + sharded_param, param_data.size(), dim=shard_dim + ) self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) padded_sharded_size = chunks[0].size() # 0th always padded + self.padded_sharded_param_size = padded_sharded_size + # Pre-pad the sharded parameter to avoid padding before all-gather padded_sharded_param = param_data.new_zeros(padded_sharded_size) - self.padded_sharded_param_size = padded_sharded_param.size() if sharded_param.numel() > 0: - padded_sharded_param[: sharded_param.size(0)].copy_(sharded_param) + padded_sharded_param.narrow( + dim=shard_dim, start=0, length=sharded_param.size(shard_dim) + ).copy_(sharded_param) if self.offload_to_cpu and not padded_sharded_param.is_meta: padded_sharded_param = padded_sharded_param.cpu() if self.pin_memory: @@ -359,9 +385,12 @@ def _init_sharded_param(self, param: nn.Parameter, device: torch.device): device=self.device ) self._sharded_param_data = padded_sharded_param.view(-1) - self.sharded_param = nn.Parameter( - self.to_sharded_dtensor(padded_sharded_param[: sharded_param.size(0)]) + length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0 + sharded_param = padded_sharded_param.narrow( + dim=shard_dim, start=0, length=length ) + assert sharded_param.is_contiguous(), f"{self.fsdp_placement=}" + self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) self.sharded_param.requires_grad_(param.requires_grad) # Let `param_data` be freed normally when its ref count reaches 0 when # the `fully_shard` call returns to allow provided parameters to alias @@ -373,8 +402,10 @@ def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None assert mesh_info is not None # mypy param_data = param._local_tensor if isinstance(param, DTensor) else param chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0) - self.sharded_post_forward_size = _get_dim0_chunked_size( - chunks[mesh_info.shard_mesh_rank], param_data.size() + self.sharded_post_forward_size = _get_dim_chunked_size( + chunks[mesh_info.shard_mesh_rank], + param_data.size(), + dim=self.fsdp_placement.dim, ) self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for( self.sharded_post_forward_size @@ -434,7 +465,7 @@ def init_unsharded_param(self): - Sharded parameters - Placeholders for the `self._unsharded_param` nn.Parameter """ - if not ca.compiled_autograd_enabled and hasattr( + if not compiled_autograd_enabled() and hasattr( self, "_unsharded_param" ): # after the 1st all-gather inner_tensor = self._sharded_local_tensor @@ -452,7 +483,7 @@ def init_unsharded_param(self): self._extensions_data.clear() return inner_tensor = self._sharded_local_tensor - if not ca.compiled_autograd_enabled and hasattr( + if not compiled_autograd_enabled() and hasattr( inner_tensor, "fsdp_post_all_gather" ): all_gather_outputs = self._unflatten_all_gather_outputs() @@ -479,7 +510,7 @@ def init_unsharded_param(self): if self.is_dtensor: unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) if hasattr(self, "_unsharded_param"): - assert ca.compiled_autograd_enabled + assert compiled_autograd_enabled() with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter( self._unsharded_param ): @@ -618,7 +649,7 @@ def alloc_all_gather_outputs(self) -> None: alloc_storage(tensor) def free_unsharded_param(self) -> None: - if ca.compiled_autograd_enabled: + if compiled_autograd_enabled(): """ Assumptions under compile: - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`. @@ -643,7 +674,7 @@ def free_unsharded_param(self) -> None: def all_gather_inputs(self) -> List[torch.Tensor]: # 1D self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) if self.sharded_state == ShardedState.SHARDED: - if not ca.compiled_autograd_enabled and hasattr( + if not compiled_autograd_enabled() and hasattr( self._sharded_local_tensor, "fsdp_pre_all_gather" ): sharded_local_tensor = self._sharded_local_tensor @@ -707,7 +738,7 @@ def all_gather_inputs(self) -> List[torch.Tensor]: # 1D ) return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)] elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD: - if not ca.compiled_autograd_enabled and hasattr( + if not compiled_autograd_enabled() and hasattr( self._sharded_local_tensor, "fsdp_pre_all_gather" ): raise NotImplementedError @@ -720,7 +751,6 @@ def all_gather_inputs(self) -> List[torch.Tensor]: # 1D @property def unsharded_param(self) -> nn.Parameter: # ND - self._assert_in_states(ShardedState.UNSHARDED) return self._unsharded_param @property @@ -787,9 +817,16 @@ def reset_sharded_param(self): return updated_local_tensor = False padded_sharded_size = self.padded_sharded_param_size + shard_dim = self.fsdp_placement.dim + length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 if local_tensor.size() != padded_sharded_size: + assert ( + shard_dim == 0 + ), f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}" padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) - padded_local_tensor[: local_tensor.size(0)].copy_(local_tensor) + padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_( + local_tensor + ) local_tensor = padded_local_tensor updated_local_tensor = True if self.pin_memory and not local_tensor.is_pinned(): @@ -799,7 +836,11 @@ def reset_sharded_param(self): assert isinstance(self.sharded_param, DTensor) # mypy if updated_local_tensor: # Only change the local tensor object if needed - self.sharded_param._local_tensor = local_tensor[: self.sharded_size[0]] + self.sharded_param._local_tensor = local_tensor.narrow( + dim=shard_dim, start=0, length=length + ) + assert self.sharded_param._local_tensor.is_contiguous() + self._sharding_spec = self.sharded_param._spec def __repr__(self): return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})" diff --git a/torch/distributed/_composable/fsdp/_fsdp_param_group.py b/torch/distributed/_composable/fsdp/_fsdp_param_group.py index de01249148aae..e19ac1e814dc4 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_param_group.py +++ b/torch/distributed/_composable/fsdp/_fsdp_param_group.py @@ -1,14 +1,14 @@ # mypy: allow-untyped-defs import contextlib import logging -from typing import Any, cast, Dict, List, NamedTuple, Optional, Set, Tuple +from typing import Any, Callable, cast, Dict, List, NamedTuple, Optional, Set, Tuple import torch -import torch._dynamo.compiled_autograd as ca import torch.distributed as dist import torch.nn as nn from torch.distributed.device_mesh import _get_device_handle from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates +from torch.distributed.tensor import Shard from torch.profiler import record_function from torch.utils._pytree import tree_flatten, tree_unflatten from torch.utils.hooks import RemovableHandle @@ -20,7 +20,12 @@ foreach_all_gather_copy_out, foreach_reduce, ) -from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo, TrainingState +from ._fsdp_common import ( + compiled_autograd_enabled, + FSDPMeshInfo, + HSDPMeshInfo, + TrainingState, +) from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState @@ -113,6 +118,7 @@ def __init__( mesh_info: FSDPMeshInfo, post_forward_mesh_info: Optional[FSDPMeshInfo], device: torch.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], mp_policy: MixedPrecisionPolicy, offload_policy: OffloadPolicy, ): @@ -126,6 +132,7 @@ def __init__( mesh_info, post_forward_mesh_info, device, + shard_placement_fn, mp_policy, offload_policy, ) @@ -169,6 +176,9 @@ def __init__( # overridden to only do explicit prefetching and avoid inter-stream # fragmentation from using separate unshard streams self.unshard_async_op: bool = False + # Whether to unshard in backward: can be overridden by the user if the + # parameters in this group are not needed for backward (e.g. embedding) + self.unshard_in_backward: bool = True # - CUDA events for stream synchronization # Holds the all-gather output buffer, sync objects, and metadata @@ -231,6 +241,11 @@ def unshard(self, async_op: bool = False): return if self.is_unsharded: return # no-op + if ( + not self.unshard_in_backward + and self._training_state == TrainingState.PRE_BACKWARD + ): + return if self._reshard_after_forward_event is not None: # Resharded parameter data is allocated in the default stream and # used in the all-gather streams @@ -304,7 +319,7 @@ def reshard(self): def pre_forward( self, module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): logger.debug("%s", self._with_fqn("FSDP::pre_forward")) with record_function(self._with_fqn("FSDP::pre_forward")): self._training_state = TrainingState.FORWARD @@ -314,7 +329,7 @@ def pre_forward( return args, kwargs def post_forward(self, module: nn.Module, input: Any, output: Any): - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): logger.debug("%s", self._with_fqn("FSDP::post_forward")) with record_function(self._with_fqn("FSDP::post_forward")): self.reshard() @@ -332,17 +347,17 @@ def _record_post_forward(self) -> None: def pre_backward(self, default_prefetch: bool, *unused: Any): if self._training_state == TrainingState.PRE_BACKWARD: return - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): logger.debug("%s", self._with_fqn("FSDP::pre_backward")) with record_function(self._with_fqn("FSDP::pre_backward")): self._training_state = TrainingState.PRE_BACKWARD self.unshard(self.unshard_async_op) # no-op if prefetched self.wait_for_unshard() - if default_prefetch and not ca.compiled_autograd_enabled: + if default_prefetch and not compiled_autograd_enabled(): self._backward_prefetch() def post_backward(self, *unused: Any): - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): logger.debug("%s", self._with_fqn("FSDP::post_backward")) self._training_state = TrainingState.POST_BACKWARD with record_function(self._with_fqn("FSDP::post_backward_accumulate")): @@ -503,7 +518,7 @@ def _register_post_backward_hook( ) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: # Traceable FSDP2 relies on `root_post_backward_callback` to call each # `FSDPParamGroup.post_backward` - if (not torch._dynamo.config.skip_fsdp_hooks) or ca.compiled_autograd_enabled: + if (not torch._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled(): return args, kwargs if not torch.is_grad_enabled(): return args, kwargs @@ -659,7 +674,7 @@ def _get_param_module_infos( class RegisterPostBackwardFunction(torch.autograd.Function): @staticmethod def _assert_not_tracing_fsdp(): - if ca.compiled_autograd_enabled: + if compiled_autograd_enabled(): # TODO: Find a way to print the offending FSDP2 module. msg = """\ When Traceable FSDP2 is enabled, we rely on `root_post_backward_callback` to call diff --git a/torch/distributed/_composable/fsdp/_fsdp_state.py b/torch/distributed/_composable/fsdp/_fsdp_state.py index 4088f0e69dd8c..1659bf1133c9e 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_state.py +++ b/torch/distributed/_composable/fsdp/_fsdp_state.py @@ -15,7 +15,6 @@ ) import torch -import torch._dynamo.compiled_autograd as ca import torch.nn as nn from torch._logging import warning_once from torch.autograd import Variable @@ -30,7 +29,12 @@ from torch.utils._pytree import tree_flatten, tree_map from ._fsdp_api import MixedPrecisionPolicy -from ._fsdp_common import _cast_fp_tensor, TrainingState +from ._fsdp_common import ( + _cast_fp_tensor, + compiled_autograd_enabled, + detect_compiled_autograd, + TrainingState, +) from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup @@ -119,7 +123,7 @@ def _root_pre_forward( self._lazy_init() if self._state_ctx.iter_forward_root is not None: return args, kwargs - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): logger.debug("FSDP::root_pre_forward") self._state_ctx.iter_forward_root = self with torch.profiler.record_function("FSDP::root_pre_forward"): @@ -154,6 +158,7 @@ def _lazy_init(self) -> None: raise RuntimeError( f"FSDP requires a single root module but got {self._modules}" ) + detect_compiled_autograd() root_module = self._modules[0] visited_states: Set[FSDPState] = set() for module_name, module in root_module.named_modules(): @@ -277,17 +282,21 @@ def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: return grad def _root_post_backward_final_callback(self) -> None: - if not ca.compiled_autograd_enabled: + if not compiled_autograd_enabled(): logger.debug("FSDP::root_post_backward") with torch.profiler.record_function("FSDP::root_post_backward_callback"): for state in self._state_ctx.all_states: - if state._fsdp_param_group and state._fsdp_param_group.is_unsharded: + fsdp_param_group = state._fsdp_param_group + if fsdp_param_group and ( + fsdp_param_group.is_unsharded + or not fsdp_param_group.unshard_in_backward + ): # Run post-backward in case forward inputs did not require # gradient so the autograd backward did not run - state._fsdp_param_group.post_backward() + fsdp_param_group.post_backward() state._training_state = TrainingState.IDLE - if state._fsdp_param_group: - state._fsdp_param_group._training_state = TrainingState.IDLE + if fsdp_param_group: + fsdp_param_group._training_state = TrainingState.IDLE if self._state_ctx.is_last_backward: state._finalize_backward() if self._state_ctx.is_last_backward: diff --git a/torch/distributed/_composable/fsdp/fully_shard.py b/torch/distributed/_composable/fsdp/fully_shard.py index 7e40ac8a730f1..fbcc8e11c34f2 100644 --- a/torch/distributed/_composable/fsdp/fully_shard.py +++ b/torch/distributed/_composable/fsdp/fully_shard.py @@ -1,12 +1,23 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import functools -from typing import Any, cast, Dict, Iterable, List, NoReturn, Optional, Type, Union +from typing import ( + Any, + Callable, + cast, + Dict, + Iterable, + List, + NoReturn, + Optional, + Type, + Union, +) import torch import torch.nn as nn from torch.distributed._composable import contract -from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor import DeviceMesh, Shard from torch.distributed.utils import _get_root_modules from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy @@ -34,6 +45,7 @@ def fully_shard( *, mesh: Optional[DeviceMesh] = None, reshard_after_forward: Union[bool, int] = True, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None, mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), offload_policy: OffloadPolicy = OffloadPolicy(), ): @@ -95,6 +107,14 @@ def fully_shard( between forward and backward, the registered parameters must be the sharded parameters. For ``False`` or an ``int``, this can be done by manually resharding via :meth:`reshard`. + shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]): + This callable can be used to override the sharding placement for a + parameter to shard a parameter on a dimension other than dim-0. If + this callable returns a ``Shard`` placement (not ``None``), then + FSDP will shard according to that placement (e.g. ``Shard(1)``). + If sharding on a nonzero dim, we currently require even sharding, + i.e. the tensor dim size on that dim must be divisible by the FSDP + shard mesh size. mp_policy (MixedPrecisionPolicy): This controls the mixed precision policy, which offers parameter/reduction mixed precision for this module. See :class:`MixedPrecisionPolicy` for details. @@ -139,6 +159,7 @@ def fully_shard( mesh_info, post_forward_mesh_info, device, + shard_placement_fn, mp_policy, offload_policy, ) @@ -362,6 +383,17 @@ def set_reduce_scatter_divide_factor(self, factor: float) -> None: reduce_op = torch.distributed._make_nccl_premul_sum(mul_factor) fsdp_param_group.reduce_scatter_reduce_op = reduce_op + def set_unshard_in_backward(self, unshard_in_backward: bool) -> None: + """ + Sets whether the FSDP module's parameters need to be unsharded in + backward. This can be used in expert cases when the user knows that all + parameters in this FSDP module's parameter group are not needed for + backward computation (e.g. embedding). + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.unshard_in_backward = unshard_in_backward + def _set_unshard_async_op(self, async_op: bool): """ Sets whether to use ``async_op=True`` or ``False`` for the pre-forward diff --git a/torch/distributed/_composable/replicate.py b/torch/distributed/_composable/replicate.py index ac1956ce0962c..d86f3c5db33f0 100644 --- a/torch/distributed/_composable/replicate.py +++ b/torch/distributed/_composable/replicate.py @@ -82,8 +82,6 @@ def init( return self.has_initialized = True - - device_mesh = kwargs.get("device_mesh", None) self.module = module ignored_params = {p for m in ignored_modules for p in m.parameters()} for submodule in module.modules(): diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py index 3a5f0c552cbed..975f499023d13 100644 --- a/torch/distributed/_shard/api.py +++ b/torch/distributed/_shard/api.py @@ -274,7 +274,7 @@ def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_grou mod, param_name, spec, src_rank=src_rank, process_group=process_group ) elif isinstance(spec, Sharder): - parent_mod_path, _, mod_name = name.rpartition(".") + parent_mod_path, _, _mod_name = name.rpartition(".") if name == "": raise KeyError("Module path must not be empty for custom sharder!") mod = module.get_submodule(name) diff --git a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py index f8db8b6ebe96f..0548b81fb90af 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -25,7 +25,6 @@ def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): if len(args) != 2: raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}") - result = True st1 = args[0] st2 = args[1] if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)): diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index d50160ca8ecc3..23a0d2d21f953 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -857,7 +857,7 @@ def _init_from_local_tensor( local_shards: List[Shard] = [] for shard_metadata in sharded_tensor_metadata.shards_metadata: - rank, device = _parse_and_validate_remote_device( + rank, _device = _parse_and_validate_remote_device( process_group, shard_metadata.placement ) if rank == current_rank: diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 042bd14f59848..538846c817477 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -3,6 +3,7 @@ import uuid from contextlib import contextmanager from datetime import timedelta +from enum import Enum from functools import partial from typing import Any, Callable, Dict, Generator, List, Optional, Tuple @@ -133,25 +134,32 @@ def _get_backend_stream() -> torch.cuda.Stream: return _backend_stream -def _pipelined_all_gather_and_consume( - shard: torch.Tensor, - shard_consumer: Callable[[torch.Tensor, int], None], - ag_out: torch.Tensor, +def _pipelined_multi_all_gather_and_consume( + shard: List[torch.Tensor], + shard_consumer: Callable[[List[torch.Tensor], int], None], + ag_out: List[torch.Tensor], group_name: str, ) -> None: """ Perform the following logic with micro-pipelined computation and communication: - tensor = all_gather_tensor(shard, gather_dim=1, group=group) - chunks = tensor.chunk(group.size()) - for src_rank, chunk in enumerate(chunks): - shard_consumer(chunk, src_rank) + gathered = [ + all_gather_tensor(x, gather_dim=0, group=group) + for x in shard + ] + + shards = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) - NOTE: - - The shard passed to shard consumer will always be contiguous. + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) """ - p2p_workspace_size_req = shard.numel() * shard.element_size() + p2p_workspace_size_req = 0 + for x in shard: + p2p_workspace_size_req += x.numel() * x.element_size() symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) group_size = symm_mem.world_size rank = symm_mem.rank @@ -159,38 +167,144 @@ def _pipelined_all_gather_and_consume( symm_mem.barrier(channel=0) backend_stream = _get_backend_stream() backend_stream.wait_stream(torch.cuda.current_stream()) - local_p2p_buf = symm_mem.get_buffer(rank, shard.shape, shard.dtype) - chunks = ag_out.chunk(group_size) - - # While consuming local shard, copy it to the local p2p buffer - # in another stream. - shard_consumer(shard, rank) - chunks[rank].copy_(shard) - - with torch.cuda.stream(backend_stream): - local_p2p_buf.copy_(shard) - symm_mem.barrier(channel=1) - torch.cuda.current_stream().wait_stream(backend_stream) + for x, y in zip(shard, ag_out): + assert x.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `shard` must be contiguous" + ) + assert y.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `ag_out` must be contiguous" + ) + assert x.shape[0] * group_size == y.shape[0] + assert x.shape[1:] == y.shape[1:] + + def copy_shard(dst: List[torch.Tensor], src: List[torch.Tensor]) -> None: + for d, s in zip(dst, src): + d.copy_(s) + + def get_p2p_bufs(remote_rank: int) -> List[torch.Tensor]: + offset_bytes = 0 + bufs = [] + for x in shard: + buf = symm_mem.get_buffer( + remote_rank, + x.shape, + x.dtype, + storage_offset=offset_bytes // x.element_size(), + ) + bufs.append(buf) + offset_bytes += buf.numel() * buf.element_size() + return bufs + + local_p2p_bufs = get_p2p_bufs(rank) + + # shards[i] => shard from rank i + shards: List[List[torch.Tensor]] = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + # Parallelization strategy: after each rank copies its shard into its local + # p2p buffer, every rank issues independent p2p copy -> shard_consumer + # sequences to two streams. In addition to computation/communication + # overlapping, the strategy allows for computation/computation overlapping, + # greatly reducing quantization inefficiency. + # + # Notation: + # - "mv" for the copy to local buffer + # - "cp" for p2p copies + # - "b" for barriers + # + # Constraints: + # - The GPU scheduler may or may not overlap "mv" with the first shard_consumer. + # - "cp" from different streams cannot overlap. + # + # Ideal scenario 0 - "mv" overlaps with the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Ideal scenario 1 - "mv" is scheduled before the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ] [b][ cp ][ shard_consumer ] + # + # We haven't yet figured out a way to ensure "mv" and "b" are either + # overlapped with or scheduled before the first shard_consumer. Thus, to + # prevent suboptimal scenarios, we are giving up the chance to overlap "mv" + # and "b" with the first shard_consumer for now. + copy_shard(dst=local_p2p_bufs, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) # At this point, all ranks have copied their local shard to # their local p2p buffer. Each rank can now copy and consume # remote shards. + shard_consumer(shard, rank) + for step in range(1, group_size): if step % 2 == 0: stream = torch.cuda.current_stream() else: stream = backend_stream remote_rank = (step + rank) % group_size - remote_p2p_buf = symm_mem.get_buffer(remote_rank, shard.shape, shard.dtype) + remote_p2p_bufs = get_p2p_bufs(remote_rank) with torch.cuda.stream(stream): - chunks[remote_rank].copy_(remote_p2p_buf) - shard_consumer(chunks[remote_rank], remote_rank) + copy_shard(dst=shards[remote_rank], src=remote_p2p_bufs) + shard_consumer(shards[remote_rank], remote_rank) + + # Copy from input to the all-gather output. Opportunistically overlap it + # with the last shard_consumer. + if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with torch.cuda.stream(stream): + copy_shard(dst=shards[rank], src=shard) torch.cuda.current_stream().wait_stream(backend_stream) symm_mem.barrier(channel=0) +def _pipelined_all_gather_and_consume( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: str, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + ag_out = all_gather_tensor(shard, gather_dim=0, group=group) + shards = ag_out.chunk(group.size()) + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + + def adapter(shard: List[torch.Tensor], rank: int) -> None: + shard_consumer(shard[0], rank) + + _pipelined_multi_all_gather_and_consume( + [shard], + adapter, + [ag_out], + group_name, + ) + + def _pipelined_produce_and_all2all( chunk_producer: Callable[[int, torch.Tensor], None], output: torch.Tensor, @@ -233,22 +347,68 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: remote_rank = (rank - step) % group_size if step % 2 == 0: stream = torch.cuda.current_stream() - other_stream = backend_stream p2p_buf = local_p2p_buf_1 remote_p2p_buf = get_p2p_buf(remote_rank, 1) else: stream = backend_stream - other_stream = torch.cuda.current_stream() p2p_buf = local_p2p_buf_0 remote_p2p_buf = get_p2p_buf(remote_rank, 0) with torch.cuda.stream(stream): + # Parallelization strategy: every rank issues independent compute + # -> barrier -> p2p copy sequences on two streams. In addition to + # computation/communication overlapping, the strategy allows for + # computation/computation overlapping, greatly reducing + # quantization inefficiency. + # + # Ideally, stream activities would look like this ("b" for + # barriers, "cp" for p2p copies): + # + # [rank 0] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # Note that the barriers synchronize streams with the same ID + # across ranks. They don't synchronize streams on the same rank. + # + # Since the work on both streams is independent, there's no + # guarantee that the chunk_producer from stream 0 or stream 1 will + # be scheduled first. If there is a scheduling mismatch across + # ranks, the barrier forces all ranks to wait for the slowest. + # + # When scheduling mismatches occur among ranks, the stream + # activities might look like this (note that p2p copies from + # different streams cannot overlap with each other): + # + # [rank 0] + # stream 0: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # stream 1: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # stream 1: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # + # To prevent this, we need to ensure that the chunk_producer on + # stream 1 gets scheduled first on every rank. Without access to + # the underlying kernels, CUDA offers no API to control the + # scheduling order of two independent, overlapping kernels. Our + # solution is to issue a small sleep kernel in stream 0. The sleep + # duration is insignificant, but having an extra task in stream 0 + # will almost guarantee that the chunk_producer on stream 1 gets + # scheduled first. Once the first chunk_producer is scheduled in + # the correct order, there's very little room for the scheduling + # order of subsequent kernels to be inconsistent across ranks. + if step == 2: + torch.cuda._sleep(100) chunk_producer((rank + step) % group_size, p2p_buf) symm_mem.barrier(channel=step % 2) - # Make the other stream to wait for the barrier on the current - # stream to finish before chunk_producer to avoid the compute - # delaying the barrier. - other_stream.wait_stream(stream) out_chunks[remote_rank].copy_(remote_p2p_buf) + # The local P2P buffer can only be overwritten by the next + # chunk_producer after all peers have finished reading from it. + symm_mem.barrier(channel=step % 2) chunk_producer(rank, out_chunks[rank]) torch.cuda.current_stream().wait_stream(backend_stream) @@ -286,10 +446,44 @@ def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: ) +class _ScaleMode(Enum): + UNSCALED = "unscaled" + TENSOR_WISE = "tensor-wise" + ROW_WISE_SHARDED = "row-wise-sharded" + ROW_WISE_REPLICATED = "row-wise-replicated" + + +def _check_and_verify_fp8_all_gather_scale_mode( + shard: torch.Tensor, scale: Optional[torch.Tensor], gather_dim: int, group_size: int +) -> _ScaleMode: + full_shape = list(shard.shape) + full_shape[gather_dim] *= group_size + + if scale is None: + return _ScaleMode.UNSCALED + elif scale.shape[:-1] == shard.shape[:-1] and scale.shape[-1] == 1: + # Row-wise scaling + # + # NOTE: when the last dim of both A_shard and A_scale is one, we can't + # tell if A_scale is replicated tensor-wise scale or sharded row-wise + # scale. Treating it as row-wise scaling for safety. + return _ScaleMode.ROW_WISE_SHARDED + elif scale.numel() == 1: + return _ScaleMode.TENSOR_WISE + elif list(scale.shape[:-1]) == full_shape[:-1]: + return _ScaleMode.ROW_WISE_REPLICATED + else: + raise ValueError( + "Invalid scale shape for fp8 all-gather " + f"(shard shape: {shard.shape}, scale shape: {scale.shape})" + ) + + def _fused_all_gather_matmul_impl( mm_out_op: torch._ops.OpOverload, A_shard: torch.Tensor, Bs: List[torch.Tensor], + A_scale: Optional[torch.Tensor], kwargs_list: List[Dict[str, Any]], out_dtypes: List[Optional[torch.dtype]], gather_dim: int, @@ -313,36 +507,96 @@ def _fused_all_gather_matmul_impl( # The flattened tensor doesn't need to be contiguous (for computation # efficiency), as _pipelined_all_gather_and_consume guarantees that shards # passed to shard_consumer are contiguous. - x = A_shard.movedim(gather_dim, 0) - leading_dims = [group.size()] + list(x.shape[:-1]) - x = x.flatten(0, -2) + A_shard_flat = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(A_shard_flat.shape[:-1]) + A_shard_flat = A_shard_flat.flatten(0, -2) # Helper function for reverting the above transformation def unflatten(t: torch.Tensor) -> torch.Tensor: return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) - ag_out = x.new_empty( - x.shape[0] * group.size(), - x.shape[1], + A_flat = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], ) + outputs = [ - x.new_empty(x.shape[0] * group.size(), B.shape[1], dtype=out_dtype or B.dtype) + A_flat.new_empty(A_flat.shape[0], B.shape[1], dtype=out_dtype or B.dtype) for B, out_dtype in zip(Bs, out_dtypes) ] output_shards = [output.chunk(group.size()) for output in outputs] - # Computing block-wise matmul along the first dim of A - def shard_consumer(shard: torch.Tensor, rank: int) -> None: - for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): - mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) - - _pipelined_all_gather_and_consume( - x, - shard_consumer, - ag_out, - group_name, + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group.size() ) - return unflatten(ag_out), [unflatten(output) for output in outputs] + + # Computing block-wise matmul along the first dim of A + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + assert A_scale is not None + A_scale_shard = A_scale.movedim(gather_dim, 0).flatten(0, -2) + A_scale_flat = A_scale_shard.new_empty( + A_scale_shard.shape[0] * group.size(), + A_scale_shard.shape[1], + ) + + def row_wise_sharded_consumer(shard: List[torch.Tensor], rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard[0], + B, + scale_a=shard[1], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_multi_all_gather_and_consume( + [A_shard_flat, A_scale_shard], + row_wise_sharded_consumer, + [A_flat, A_scale_flat], + group_name, + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + assert A_scale is not None + A_scale_shards = ( + A_scale.movedim(gather_dim, 0).flatten(0, -2).chunk(group.size()) + ) + + def row_wise_replicated_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard, + B, + scale_a=A_scale_shards[rank], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_all_gather_and_consume( + A_shard_flat, + row_wise_replicated_consumer, + A_flat, + group_name, + ) + else: + if scale_mode == _ScaleMode.TENSOR_WISE: + assert A_scale is not None + for kwargs in kwargs_list: + kwargs["scale_a"] = A_scale + else: + assert scale_mode == _ScaleMode.UNSCALED + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + A_shard_flat, + default_consumer, + A_flat, + group_name, + ) + + return unflatten(A_flat), [unflatten(output) for output in outputs] @torch.library.impl(lib, "fused_all_gather_matmul", "Meta") @@ -388,6 +642,7 @@ def _fused_all_gather_matmul( torch.ops.aten.mm.out, A_shard, Bs, + None, [{} for B in Bs], [B.dtype for B in Bs], gather_dim, @@ -417,6 +672,25 @@ def _fused_all_gather_scaled_matmul_fallback( A = torch.ops._c10d_functional.wait_tensor(A) A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group_size + ) + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + A_scale_shard = A_scale + A_scale = torch.ops._c10d_functional.all_gather_into_tensor( + A_scale.contiguous(), group_size, group_name + ) + A_scale = torch.ops._c10d_functional.wait_tensor(A_scale) + A_scale = ( + A_scale.view(group_size, *A_scale_shard.shape) + .movedim(gather_dim + 1, 1) + .flatten(0, -2) + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + A_scale = A_scale.movedim(gather_dim, 0).flatten(0, -2) + else: + assert scale_mode == _ScaleMode.TENSOR_WISE + def scaled_matmul( A: torch.Tensor, B: torch.Tensor, @@ -429,7 +703,14 @@ def scaled_matmul( ) -> torch.Tensor: leading_dims = A.shape[:-1] res = torch.ops.aten._scaled_mm( - A.flatten(0, -2), B, A_scale, B_scale, out_dtype=out_dtype + A.flatten(0, -2), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, ) return res.unflatten(0, leading_dims) @@ -465,7 +746,10 @@ def _fused_all_gather_scaled_matmul( res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale) res = res.unflatten(0, leading_dims) - Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is + The input `A_scale` can be tensor-wise, row-wise-sharded or + row-wise-replicated. + + Optimal stride order for `A_shard` - if `A_shard.movedim(gather_dim, 0)` is contiguous, no extra copy is required for input layout transformation. Otherwise A_shard needs to be copied once. """ @@ -499,9 +783,9 @@ def _fused_all_gather_scaled_matmul( torch.ops.aten._scaled_mm.out, A_shard, Bs, + A_scale, [ { - "scale_a": A_scale, "scale_b": B_scale, "bias": bias, "scale_result": result_scale, @@ -548,6 +832,7 @@ def _fused_matmul_reduce_scatter_impl( mm_out_op: torch._ops.OpOverload, A: torch.Tensor, B: torch.Tensor, + A_scale: Optional[torch.Tensor], kwargs: Dict[str, Any], out_dtype: Optional[torch.dtype], reduce_op: str, @@ -571,16 +856,36 @@ def _fused_matmul_reduce_scatter_impl( out_shape = [*A.shape[:-1], B.shape[1]] out_shape[scatter_dim] //= group.size() - # Move the gather_dim to the front and flatten the tensor into a 2D matrix + # Move the scatter_dim to the front and flatten the tensor into a 2D matrix x = A.movedim(scatter_dim, 0) leading_dims = [group.size()] + list(x.shape[:-1]) leading_dims[1] //= group.size() x = x.flatten(0, -2) - shards = x.chunk(group.size()) + A_shards = x.chunk(group.size()) + + A_scale_shards = None + if A_scale is None: + pass + elif A_scale.numel() == 1: + A_scale_shards = [A_scale] * group.size() + else: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.movedim(scatter_dim, 0).contiguous().flatten(0, -2) + A_scale_shards = list(A_scale.chunk(group.size())) # Computing block-wise matmul along the first dim of A def chunk_producer(rank: int, out: torch.Tensor) -> None: - mm_out_op(shards[rank], B, **kwargs, out=out) + if A_scale_shards is not None: + mm_out_op( + A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out + ) + else: + mm_out_op(A_shards[rank], B, **kwargs, out=out) stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype) @@ -640,6 +945,7 @@ def _fused_matmul_reduce_scatter( mm_out_op=torch.ops.aten.mm.out, A=A, B=B, + A_scale=None, kwargs={}, out_dtype=A.dtype, reduce_op=reduce_op, @@ -662,6 +968,20 @@ def _fused_scaled_matmul_reduce_scatter_fallback( out_dtype: Optional[torch.dtype] = None, use_fast_accum: bool = False, ) -> torch.Tensor: + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + C = torch._scaled_mm( A.flatten(0, -2).contiguous(), B, @@ -716,8 +1036,8 @@ def _fused_scaled_matmul_reduce_scatter( mm_out_op=torch.ops.aten._scaled_mm.out, A=A, B=B, + A_scale=A_scale, kwargs={ - "scale_a": A_scale, "scale_b": B_scale, "bias": bias, "scale_result": result_scale, @@ -733,14 +1053,14 @@ def _fused_scaled_matmul_reduce_scatter( def restride_A_for_fused_matmul_reduce_scatter( t: torch.Tensor, - gather_dim: int, + scatter_dim: int, ) -> torch.Tensor: """ Restride the `A_shard` arg of `fused_matmul_reduce_scatter` for optimal perf. See the doc for `fused_matmul_reduce_scatter` for detail. """ perm = list(range(len(t.shape))) - perm.insert(0, perm.pop(gather_dim)) + perm.insert(0, perm.pop(scatter_dim)) return make_contiguous_for_perm(t, perm) diff --git a/torch/distributed/_tools/__init__.py b/torch/distributed/_tools/__init__.py index cd57eedba3751..22e974cdd64f1 100644 --- a/torch/distributed/_tools/__init__.py +++ b/torch/distributed/_tools/__init__.py @@ -3,3 +3,10 @@ from .memory_tracker import MemoryTracker from .mod_tracker import ModTracker from .runtime_estimator import RuntimeEstimator +from .sac_estimator import ( + MSPS, + SACEstimator, + SACGreedyOrderMeta, + SACStats, + SACTradeOffStats, +) diff --git a/torch/distributed/_tools/fake_collectives.py b/torch/distributed/_tools/fake_collectives.py new file mode 100644 index 0000000000000..80e5d9891bdae --- /dev/null +++ b/torch/distributed/_tools/fake_collectives.py @@ -0,0 +1,205 @@ +import torch +import torch.distributed as dist +from torch._C._distributed_c10d import ProcessGroup, Work +from torch.futures import Future +from torch.testing._internal.distributed.fake_pg import FakeStore +from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor +from torch.distributed._functional_collectives import * +from torch.utils._python_dispatch import TorchDispatchMode +from functools import wraps +from contextlib import contextmanager, nullcontext +import logging +from datetime import timedelta +from typing import cast, Optional, overload + + +class FakeWork(Work): + def __init__(self): + super().__init__() + + def get_future(self) -> Future: + future = Future() + future.set_result(None) + return future + + def wait(self, timeout: Optional[timedelta] = None) -> bool: + return True + + +def _broadcast_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + + +def _all_reduce_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + + +def _all_gather_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + + +def _all_gather_into_tensor_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + + +def _reduce_scatter_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + + +def _reduce_scatter_tensor_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + + +def _reduce_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return fakework_script_obj + + +def _reduce_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return fakework_script_obj + + +def _gather_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return fakework_script_obj + +def _scatter_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + +def _alltoall_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return (args[0], fakework_script_obj) + +def _send_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return fakework_script_obj + + +def _recv_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return fakework_script_obj + + +def _barrier_meta(*args): + fakework = FakeWork() + fakework.__setattr__("getFuture", fakework.get_future) + fakework_script_obj = fakework.boxed() + return fakework_script_obj + + +if not torch._running_with_deploy(): + # Library MUST be defined at module scope or it doesn't work + # Creating a "DEF" Library always crashes torch::deploy so we create our + # Library instances here guarded against running inside it + lib_impl = torch.library.Library("c10d", "IMPL") + lib_impl.impl("broadcast_", _broadcast_meta, "Meta") + lib_impl.impl("allreduce_", _all_reduce_meta, "Meta") + lib_impl.impl("allgather_", _all_gather_meta, "Meta") + lib_impl.impl("_allgather_base_", _all_gather_into_tensor_meta, "Meta") + lib_impl.impl("reduce_scatter_", _reduce_scatter_meta, "Meta") + lib_impl.impl("_reduce_scatter_base_", _reduce_scatter_tensor_meta, "Meta") + lib_impl.impl("reduce_", _reduce_meta, "Meta") + lib_impl.impl("gather_", _gather_meta, "Meta") + lib_impl.impl("scatter_", _scatter_meta, "Meta") + lib_impl.impl("alltoall_", _alltoall_meta, "Meta") + lib_impl.impl("barrier", _barrier_meta, "Meta") + lib_impl.impl("send", _send_meta, "Meta") + lib_impl.impl("recv_", _recv_meta, "Meta") + + +class IgnoreDistMode(TorchDispatchMode): + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + logging.info(f"Function name: {str(func.__name__)}") + logging.info(f"Function type: {type(func)}") + logging.info(f"Func: {func}") + + res = func(*args, **kwargs or {}) + return res + + +def run_test(): + try: + rank = dist.get_rank() + except: + rank = 0 + logging.getLogger().setLevel(logging.DEBUG if rank == 0 else logging.CRITICAL) + + # with nullcontext(): + with FakeTensorMode(): + with IgnoreDistMode(): + test_tensor_list = [torch.randn(1000, device="cuda") for _ in range(3)] + test_tensor = torch.randn(10000, device="cuda") + + # testing for collective operations + dist.broadcast(test_tensor, src=0) + dist.all_reduce(test_tensor) + dist.all_gather(test_tensor_list, test_tensor) + dist.all_gather_into_tensor(test_tensor, test_tensor) + dist.reduce_scatter(test_tensor, test_tensor_list) + dist.reduce_scatter_tensor(test_tensor, test_tensor) + dist.reduce(test_tensor, dst=0) + dist.gather(test_tensor, gather_list=test_tensor_list, dst=0) + dist.scatter(test_tensor, scatter_list=test_tensor_list, src=0) + dist.all_to_all(test_tensor_list, test_tensor_list) + dist.barrier() + dist.send(test_tensor, dst=1) + dist.recv(test_tensor, src=1) + + # testing for functional collectives + output = wait_tensor(test_tensor) + output = broadcast(test_tensor, src=0, group=dist.group.WORLD) + output = all_reduce(test_tensor, reduceOp="avg", group=dist.group.WORLD) + output = all_gather_tensor(test_tensor, gather_dim=0, group=dist.group.WORLD) + output = reduce_scatter_tensor(test_tensor, scatter_dim=0, reduceOp="sum", group=dist.group.WORLD) + output = all_to_all_single(test_tensor, output_split_sizes=[0], input_split_sizes=[1], group=dist.group.WORLD) + + dist.barrier() + + +if __name__ == "__main__": + gpu_id = 0 + world_size = 4 + dims = (world_size,) + names = ("dp",) + store = FakeStore() + dist.init_process_group("fake", rank=gpu_id, world_size=world_size, store=store) + device = f"cuda:{gpu_id}" + torch.cuda.set_device(device) + try: + run_test() + finally: + dist.destroy_process_group() diff --git a/torch/distributed/_tools/ilp_utils.py b/torch/distributed/_tools/ilp_utils.py new file mode 100644 index 0000000000000..43872339d5f32 --- /dev/null +++ b/torch/distributed/_tools/ilp_utils.py @@ -0,0 +1,291 @@ +import copy +from typing import cast, Dict, List, OrderedDict, Tuple, TypedDict + +import numpy as np + +import torch +from torch.distributed._tools.mem_tracker import ( + _MemRefType, + _ModMemStats, + _ModState, + MemTracker, +) +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats + + +class ModOrder(TypedDict): + fw_pre_order: List[str] + bw_pre_order: List[str] + fw_post_order: List[str] + bw_post_order: List[str] + + +class ModRuntime(TypedDict): + fw: float + bw: float + + +class ModStats(TypedDict): + fqn: str + # per-module params + param_per_module: int + # per-module grads + grad_per_module: int + # total accumulated gradients up to and including this module + grad_total: int + # per module fw activation size (excluding input and output) + act_fw_per_module: int + # per module bw activation size during peak_bw + act_bw_per_module: int + # per module activation grad size during peak_bw + act_grad_per_module: int + # total activation size up to but excluding the current module + # includes input of the current module (i.e., output of previous module) + act_total: int + # Inputs to the module + input_per_module: int + # Outputs of the module + output_per_module: int + # Total fw run-time of the module + fw_runtime_per_module: float + # Total bw run-time of the module + bw_runtime_per_module: float + # Is this module a leaf module + is_leaf: bool + # Total ac run-time of the module + sac_runtime: float + # Total ac_memory for the module + sac_memory: int + # Number of piecewise-linear functions used for approximating ac tradeoff curve + n_segments: int + # Slopes of the of piecewise-linear functions + slopes: List[float] + # Intercepts of the of piecewise-linear functions + intercepts: List[float] + # X breakpoints of the of piecewise-linear functions + breakpoints: List[float] + # Original trade-off curves + tradeoff_curve: OrderedDict[float, float] + + +class ModuleInfo(TypedDict): + mod_order: ModOrder + mod_stats: List[ModStats] + + +def aggregate_stats( + model: torch.nn.Module, + mem_tracker: MemTracker, + runtime_estimator: RuntimeEstimator, + sac_estimator: SACEstimator, + dev: torch.device, +) -> ModuleInfo: + """ + Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats. + + Args: + model: nn.Module object + runtime_estimator: RuntimeEstimator object with runtime stats + mem_tracker: MemTracker object with memory stats + sac_estimator: SACEstimator object with AC tradeoff stats + dev: device the model was run on (used to extract memory stats from MemTracker) + + Returns: + ModuleInfo: A dictionary with module order and module stats. + """ + + # Memory stats + mod_mem_stats: Dict[torch.nn.Module, _ModMemStats] = dict( + copy.deepcopy(mem_tracker.memory_tracking) + ) + + # Runtime stats + mod_runtime_stats: Dict[str, ModRuntime] = { + fqn: {"fw": v["fw"], "bw": v["bw"]} + for fqn, v in runtime_estimator.mod_runtimes.items() + } + + # Module order + mod_order: ModOrder = { + "fw_pre_order": list(runtime_estimator.mod_fw_pre_order), + "bw_pre_order": list(runtime_estimator.mod_bw_pre_order), + "fw_post_order": list(runtime_estimator.mod_fw_post_order), + "bw_post_order": list(runtime_estimator.mod_bw_post_order), + } + + # Selective Activation Checkpointing stats + sac_estimator.pwlf_sac_tradeoff_curve() + mod_sac_tradeoff_stats: Dict[str, SACTradeOffStats] = copy.deepcopy( + sac_estimator.sac_mod_tradeoff_stats + ) + + module_info: ModuleInfo = { + "mod_order": mod_order, + "mod_stats": [], + } + + for mod in model.modules(): + if mod_mem_stat := mod_mem_stats.get(mod, None): + if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None): + sac_runtime = tradeoff_stats.sac_runtime + sac_memory = tradeoff_stats.sac_memory + n_segments = tradeoff_stats.n_segments + slopes = tradeoff_stats.slopes + intercepts = tradeoff_stats.intercepts + breakpoints = tradeoff_stats.fit_breaks + tradeoff_curve = tradeoff_stats.tradeoff_curve + is_leaf = False + else: + sac_runtime = sac_memory = n_segments = 0 + slopes = intercepts = breakpoints = [] + tradeoff_curve: OrderedDict[float, float] = OrderedDict() # type: ignore[no-redef] + is_leaf = True + mod_stat: ModStats = { + "fqn": mod_mem_stat.mod_fqn, + "param_per_module": mod_mem_stat.parameter_mem, + "grad_per_module": mod_mem_stat.parameter_mem, + "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.GRAD + ], + "act_fw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.output_mem, + ), + "act_bw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT], + ), + "act_grad_per_module": ( + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP] + - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.TEMP + ] + ), + "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][ + _MemRefType.ACT + ], + "input_per_module": mod_mem_stat.input_mem, + "output_per_module": mod_mem_stat.output_mem, + "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"], + "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"], + "is_leaf": is_leaf, + "sac_runtime": sac_runtime, + "sac_memory": sac_memory, + "n_segments": n_segments, + "slopes": slopes, + "intercepts": intercepts, + "breakpoints": breakpoints, + "tradeoff_curve": tradeoff_curve, + } + module_info["mod_stats"].append(mod_stat) + + return module_info + + +class Node(ModStats): + index: int # index according to forward pre-order + pos_fw_post_order: int # index according to forward post-order + + +class Graph: + def __init__(self, n: int) -> None: + self.nodes: List[Node] = [] + self.name2node: Dict[str, Node] = {} + self.ad_matrix = np.zeros((n, n)) + self.fw_post_order: List[str] = [] + + def add_node(self, node: Node) -> None: + self.nodes.append(node) + self.name2node[node["fqn"]] = node + + +def parse_module_info(module_info: ModuleInfo) -> Graph: + """ + Parse module info and create a graph (tree) of modules. The graph will be + used by MILP solver to find optimal SAC and/or FSDP configurations. + """ + mod_stats = module_info["mod_stats"] + fw_pre_order = module_info["mod_order"]["fw_pre_order"] + # assertion and number of nodes + assert len(mod_stats) == len(fw_pre_order) + n_nodes = len(mod_stats) + + # create graph + g = Graph(n_nodes) + g.fw_post_order = module_info["mod_order"]["fw_post_order"] + + # sort the modules by pre-order and add them to the graph + module_info["mod_stats"] = sorted( + mod_stats, key=lambda x: fw_pre_order.index(x["fqn"]) + ) + for i, one_mod_stats in enumerate(mod_stats): + node: Node = cast(Node, one_mod_stats) + node["index"] = i + node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"]) + g.add_node(node) + + # set up ancestor-descendant matrix + for i in range(n_nodes): + for j in range(i, n_nodes): + if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]): + g.ad_matrix[i][j] = 1 + else: + break + + return g + + +def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + check if name_descendant is a submodule of name_ancestor, or if they are the same + """ + return name_descendant == name_ancestor or name_ancestor + "." in name_descendant + + +def is_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + if name_descendant is a submodule of name_ancestor, but not the same + """ + return name_ancestor + "." in name_descendant + + +def display_bytes(b: int, unit: str = "MiB") -> str: + """ + return a string that represent the number of bytes in a desired unit + """ + if unit == "KiB": + return f"{b/2**10:.2f} KiB" + if unit == "MiB": + return f"{b/2**20:.2f} MiB" + if unit == "GiB": + return f"{b/2**30:.2f} GiB" + return f"{b:.2f} bytes" + + +def get_peak_memory_runtime_baseline(graph: Graph) -> Tuple[int, float]: + """ + Get the baseline peak memory and runtime. + Baseline here means there is no FSDP or AC. + Memory includes the parameters, gradients, activations, and activation gradients. + Memory does not include e.g., optimizer states, embedding tables, etc. + + Returns: + int: peak memory in bytes + float: compute time in ms + """ + P_1 = graph.nodes[0]["param_per_module"] + num_nodes = len(graph.nodes) + peak_mem = 0 + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] + AG_i = graph.nodes[i]["act_grad_per_module"] + TA_i = graph.nodes[i]["act_total"] + peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i) + compute_time = ( + graph.nodes[0]["fw_runtime_per_module"] + + graph.nodes[0]["bw_runtime_per_module"] + ) + return (peak_mem, compute_time) diff --git a/torch/distributed/_tools/sac_estimator.py b/torch/distributed/_tools/sac_estimator.py new file mode 100644 index 0000000000000..f5942307ec628 --- /dev/null +++ b/torch/distributed/_tools/sac_estimator.py @@ -0,0 +1,997 @@ +import math +import os +import sys +import warnings +from collections import OrderedDict +from dataclasses import astuple, dataclass +from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple +from typing_extensions import Self + +import torch +from torch import nan, nn, UntypedStorage +from torch._guards import active_fake_mode +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.testing._internal.composite_compliance import ( + is_inplace, + is_inplace_view_fn, + is_view_fn, +) +from torch.utils._python_dispatch import ( + is_traceable_wrapper_subclass, + TorchDispatchMode, +) +from torch.utils._pytree import tree_flatten +from torch.utils.checkpoint import SAC_IGNORED_OPS + + +__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"] +aten = torch.ops.aten + +_ADDITIONAL_IGNORED_OPS = { + aten.lift_fresh.default, # type: ignore[attr-defined] + torch.ops.profiler._record_function_exit._RecordFunction, # type: ignore[attr-defined] + aten.clone.default, # type: ignore[attr-defined] # seems needed for torch.compile +} +OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + + +def _get_untyped_storages(t: torch.Tensor) -> Set[torch.UntypedStorage]: + """ + Retrieves untyped storages from a `torch.Tensor` or one of its traceable wrapper-subclass. + + Args: + t (torch.Tensor): Input `torch.Tensor` or traceable wrapper-subclass of `torch.Tensor`. + + Returns: + Set[torch.UntypedStorage]: Set of untyped storages. + + Warns: + UserWarning: If the flattened input is not a tensor or traceable wrapper-subclass. + """ + unflattened_tensors = [t] + flattened_tensor_storages = set() + while len(unflattened_tensors) > 0: + obj = unflattened_tensors.pop() + if is_traceable_wrapper_subclass(obj): + attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] + unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) + else: + if not hasattr(obj, "untyped_storage"): + warnings.warn( + f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", + category=UserWarning, + stacklevel=2, + ) + else: + flattened_tensor_storages.add(obj.untyped_storage()) + return flattened_tensor_storages + + +def _display_stats_tabular(headers: List[str], table_data: List[List[Any]]) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError("Please install tabulate.") from err + + # Use tabulate to print the table + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +# Based on: +# https://github.com/fairinternal/xformers/blob/0ded5697a2ea15711ce45131002d04e72053cc6d/xformers/checkpoint.py#L62 +@dataclass +class _SACMetadata: + """ + Stores metadata for a single operator for SAC. + + Attributes: + func (Any): The operator function. + time_taken (float): The time taken by the operator. + memory_used (float): The memory used by the operator. + curr_idx (int): The current operator index. + output_ids (Tuple[int, ...]): The storage IDs of the operator's outputs. + inplace_info (Tuple[int, ...]): Tuple of self and parent operator for in-place operator. + is_view_like (bool): Whether the operator is view-like. + is_rand_op (bool): Whether the operator is a random operator. + """ + + func: Any + time_taken: float + memory_used: float + curr_idx: int + output_ids: Tuple[int, ...] + inplace_info: Tuple[int, ...] + is_view_like: bool + is_rand_op: bool + + +@dataclass +class _SACModMetadata: + """ + Stores metadata for a module for SAC. + + Attributes: + start_idx (int): The starting index of the module's operators. + force_store_random (bool): Whether to force store random operators in the module. + sac_metadata (List[_SACMetadata]): List of metadata for each operator in the module. + """ + + start_idx: int + force_store_random: bool + sac_metadata: List[_SACMetadata] + + +@dataclass +class SACStats: + """ + A class for storing Activation Checkpointing statistics corresponding to a module. + + Attributes: + func_names (List[str]): List of operator names. + runtimes (List[float]): List of operator runtimes in millliseconds. + memory (List[int]): List of operator memory usage in bytes. + view_like_ops (List[int]): Indices of view-like operators. + rand_ops (List[int]): Indices of random operators. + saved_autograd_ops (List[int]): Indices of operator results saved by autograd engine. + inplace_ops (List[Tuple[int, int]]): Tuple of indices of op and its first parent for Inplace operators. + force_store_random (bool): Whether to force store random operator results. + """ + + func_names: List[str] + runtimes: List[float] + memory: List[int] + view_like_ops: List[int] + rand_ops: List[int] + saved_autograd_ops: List[int] + inplace_ops: List[Tuple[int, int]] + force_store_random: bool + + +class MSPS(NamedTuple): + """ + Represents Memory and Runtime Statistics for an operator/operator group. + + Attributes: + func_names (Set[str]): Set of operator/operator group names. + op_idx (int): Operator index (group head index incase of operator groups). + memory (int): Memory usage in bytes. + runtime (float): Runtime in milliseconds. + msps (float): Memory per second calculated as memory/runtime. + """ + + func_names: Set[str] + op_idx: int + memory: int + runtime: float + msps: float + + +@dataclass +class SACTradeOffStats: + """ + Stores statistics for activation-checkpointing trade-off. + + Attributes: + n_segments (int): Number of piecewise linear segments fitted to the trade-off curve. + slopes (List[float]): Slopes of the pieces of linear segments fitted to the trade-off curve. + intercepts (List[float]): Intercepts of the of the pieces of linear segments fitted to the trade-off curve. + fit_breaks (List[float]): Breakpoints of the of the pieces of linear segments fitted to the trade-off curve. + tradeoff_curve (OrderedDict[float, float]): Trade-off curve data of memory discarded vs recomputation time. + sac_memory (int): Total memory of operations available for activation checkpointing in bytes. + sac_runtime (float): Total runtime of operations available for activation checkpointing in milliseconds. + """ + + n_segments: int + slopes: List[float] + intercepts: List[float] + fit_breaks: List[float] + tradeoff_curve: OrderedDict[float, float] + sac_memory: int + sac_runtime: float + + +@dataclass +class SACGreedyOrderMeta: + """ + Stores metadata for Greedy-order SAC. + + Attributes: + recomputed_ops (Set[int]): Set of operator indices to be recomputed. + stored_ops (Set[int]): Set of operator indices to be stored. + inplace_op_groups (Dict[int, Set[int]]): Dictionary of inplace operator groups from group-head to operators. + random_ops_group (Dict[int, Set[int]]): Dictionary of random op group head to random ops. + msps_meta (List[MSPS]): List of Memory and Runtime Statistics for operators. + """ + + recomputed_ops: Set[int] + stored_ops: Set[int] + inplace_op_groups: Dict[int, Set[int]] + random_ops_group: Dict[int, Set[int]] + msps_meta: List[MSPS] + + +class SACEstimator(TorchDispatchMode): + """ + Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC). + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and + runtime trade-offs of functions or ``torch.nn.Module``s for Selective Activation Checkpointing (SAC). It provides + detailed statistics and metadata information for operators of each module and provides a greedy order for selecting + the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of discarded memory + vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood, it supports two + estimation modes, `operator-level-benchmark` and (`operator-level-cost-model` (roofline model). + + Attributes: + sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fuly qualified name) to ``SACStats``. + sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``. + sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``. + + Note: + 1) This class is designed to be used under ``FakeTensorMode``. + 2) Currently, it only supports estimation of compute time and memory usage, and does not consider communication. + + Example usage: + + .. code-block:: python + + sac_estimator = SACEstimator() + with FakeTensorMode(): + module = ... + inp = ... + with sac_estimator('operator-level-cost-model'): + output = module(inp) + sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True) + """ + + def __init__(self) -> None: + self.sac_mod_stats: Dict[str, SACStats] = {} + self.sac_mod_tradeoff_stats: Dict[str, SACTradeOffStats] = {} + self.sac_mod_greedy_order_meta: Dict[str, SACGreedyOrderMeta] = {} + self._mod_tracker = ModTracker() + self._sac_metadata: List[_SACMetadata] = [] + self._sac_mod_metadata: Dict[str, _SACModMetadata] = {} + self._leaf_modules: Set[str] = set() + self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks( + self._pack_hook, lambda x: x + ) + self._saved_tensor_ids: Set[int] = set() + self._estimate_runtime = RuntimeEstimator._roofline_estimate + + def _pack_hook(self, x: torch.Tensor) -> torch.Tensor: + # Hook function to track underlying storage IDs of tensors + # Updates the _saved_tensor_ids set with the IDs of the tensor's storages + # Used in conjunction with torch.autograd.graph.saved_tensors_hooks + untyped_storages = _get_untyped_storages(x) + storage_ids = (hash(st) for st in untyped_storages) + self._saved_tensor_ids.update(storage_ids) + return x + + def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None: + # Pre-forward hook function to prepare module metadata + # Tracks module FQN, force store random flag, and ``SACModMetadata`` + # Initializes metadata for non-leaf modules, marks leaf modules + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + num_children = sum(1 for _ in mod.children()) + if num_children > 0: + force_store_random = self._get_force_store_random(inputs) + self._sac_mod_metadata[mod_fqn] = _SACModMetadata( + start_idx=len(self._sac_metadata), + force_store_random=force_store_random, + sac_metadata=[], + ) + else: + self._leaf_modules.add(mod_fqn) + + def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None: + # 1. Retrieves the module's FQN and checks if it's a leaf module + # 2. If not a leaf module, computes: + # - ``SACStats`` using the module's metadata and force store random flag + # - ``SACGreedyOrderMeta`` using the computed SAC statistics + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + if mod_fqn in self._leaf_modules: + return + else: + self.sac_mod_stats[mod_fqn] = self._get_sac_stats( + data=self._sac_mod_metadata[mod_fqn].sac_metadata, + force_store_random=self._sac_mod_metadata[mod_fqn].force_store_random, + ) + self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta( + self.sac_mod_stats[mod_fqn] + ) + + def _get_force_store_random(self, inputs: Any) -> bool: + flat_inputs, _ = tree_flatten(inputs) + return all(not isinstance(x, torch.Tensor) for x in flat_inputs) + + def _get_sac_stats( + self, data: List[_SACMetadata], force_store_random: bool + ) -> SACStats: + # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd + # inserts those during backward and it breaks the fwd-bwd alignment + filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP] + + ( + ops, + runtimes_, + memory_, + new_ids, + output_ids, + inplace_ops_, + view_like_ops_, + rand_ops_, + ) = zip(*[astuple(x) for x in filtered_data], strict=True) + + # 2. Extract the metadata information + runtimes = list(runtimes_) + memory = list(memory_) + func_names = [op._overloadpacket.__name__ for op in ops] + view_like_ops = [i for i, x in enumerate(view_like_ops_) if x] + rand_ops = [i for i, x in enumerate(rand_ops_) if x] + saved_autograd_ops = [ + i + for i, out_ids in enumerate(output_ids) + if set(out_ids).issubset(self._saved_tensor_ids) + ] + + # 3. Remap the inplace indices as we have removed OPS_TO_ALWAYS_SKIP + # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op + # to itself if the original parent is in OPS_TO_ALWAYS_SKIP. + try: + inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x] + except ValueError as err: + raise ValueError( + f"The remapping of inplace ops failed since one of the inplace op parents" + f" must have been present in {OPS_TO_ALWAYS_SKIP}" + ) from err + + # 4. The last operation is always stored as the output of the checkpoint + # block, so we can avoid recomputing it. We set the memory to zero + # instead of adding a new constraint because we want both the 0 and 1 + # endpoints for memory_budget to be valid + # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op + # might not always be correct, which would yield suboptimal policies + last_op = len(ops) - 1 + skip_ops_ = set(view_like_ops) | set({x[0] for x in inplace_ops}) + reversed_skip_ops = sorted(skip_ops_, reverse=True) + for op in reversed_skip_ops: + if op == last_op: + last_op -= 1 + + memory[last_op] = 0 + + # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``. + return SACStats( + func_names=func_names, + runtimes=runtimes, + memory=memory, + view_like_ops=view_like_ops, + rand_ops=rand_ops, + saved_autograd_ops=saved_autograd_ops, + inplace_ops=inplace_ops, # type: ignore[arg-type] + force_store_random=force_store_random, + ) + + def _get_inplace_metadata( + self, func: Any, out_storages: Set[UntypedStorage] + ) -> Tuple[int, Tuple[int, ...], Dict[str, Tuple[int, ...]]]: + # 1. Get the current index of the metadata obtained so far + curr_idx = len(self._sac_metadata) + # 2. Get the set of active modules that are not leaf + active_mod_fqns: Set[str] = { + par for par in self._mod_tracker.parents if par not in self._leaf_modules + } + # 3. Output ids are the identifies of the storage objects corresponding to the tensors + output_ids = tuple(hash(st) for st in out_storages) + # 4. If the function is not inplace, return + if not is_inplace(func): + return curr_idx, output_ids, {mod_fqn: () for mod_fqn in active_mod_fqns} + + op_idx = curr_idx + # 5. Initialize the parent op ids of the inplace op for each of the active modules + mod_op_parent_idxs: Dict[str, int] = { + mod_fqn: -1 for mod_fqn in active_mod_fqns + } + for i, d in enumerate(self._sac_metadata): + # 6. Find the first occurence of a tensor corresponding to each module that + # shares the same storage as the current tensor + past_output_ids = d.output_ids + if set(output_ids).issubset(set(past_output_ids)): + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx == -1: + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + if i >= acm_stats.start_idx: + mod_op_parent_idxs[mod_fqn] = i + else: + assert mod_fqn == "Global" + mod_op_parent_idxs[mod_fqn] = i + # 7. If no parent tensor is found, then it's probably an inplace op on the arguments + # so one can just store the current-op idx as parent idx + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx < 0: + mod_op_parent_idxs[mod_fqn] = op_idx + mod_inplace_info = { + mod_fqn: (op_idx, mod_op_parent_idxs[mod_fqn]) + for mod_fqn in active_mod_fqns + } + return curr_idx, output_ids, mod_inplace_info # type: ignore[return-value] + + def __torch_dispatch__( # type: ignore[no-untyped-def] + self, func, types, args=..., kwargs=None + ): + # 1. Get the runtime estimate + out, op_time = self._estimate_runtime(func, args, kwargs) + flat_outs, _ = tree_flatten(out) + out_storages_cuda: Set[UntypedStorage] = set() + out_storages_cpu: Set[UntypedStorage] = set() + cuda_devices: Set[torch.device] = set() + for o in flat_outs: + if isinstance(o, torch.Tensor): + if o.device.type == "cuda": + out_storages_cuda.update(_get_untyped_storages(o)) + cuda_devices.add(o.device) + else: + out_storages_cpu.update(_get_untyped_storages(o)) + + # Check if there's more than 1 CUDA device + assert ( + len(cuda_devices) <= 1 + ), f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" + + # 2. Get the memory consumed by output + nbytes_cuda = sum( + math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + for st in out_storages_cuda + ) + nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu) + nbytes = nbytes_cuda + nbytes_cpu + # 3. Get the current operator index, output storage identifiers and inplace metadata + out_storages = out_storages_cuda | out_storages_cpu + curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata( + func, out_storages + ) + # 4. Determine if the function is in-place, random-op or a view-like + is_view_like = is_view_fn(func) or is_inplace_view_fn(func) + is_rand_op = torch.Tag.nondeterministic_seeded in func.tags + if is_view_like: + nbytes = 0 + # sdpa has non-deterministic seed, but might be deterministic + # if no dropout is applied + if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention": + is_rand_op = kwargs.get("dropout_p", 0) != 0 + # 5. Create metadata information per active non-leaf module + for mod_fqn in self._mod_tracker.parents: + if mod_fqn in self._leaf_modules: + continue + acm = _SACMetadata( + func=func, + time_taken=op_time, + memory_used=nbytes, + curr_idx=curr_idx, + output_ids=output_ids, + inplace_info=mod_inplace_info[mod_fqn], + is_view_like=is_view_like, + is_rand_op=is_rand_op, + ) + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + acm_stats.sac_metadata.append(acm) + else: + assert ( + mod_fqn == "Global" + ), f"Module {mod_fqn} not found in AC Mod Stats" + self._sac_metadata.append(acm) + + return out + + def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta: + # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage. + # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group + # The top-most op can itself be an inplace-op or can be a non-inplace op. + # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads. + inplace_op_groups: Dict[int, Set[int]] = {} + inplace_op_to_group_head: Dict[int, int] = dict(sac_stats.inplace_ops) + + # Initialize inplace_op_groups using inplace_op_to_group_head + for op_idx, group_head_idx in inplace_op_to_group_head.items(): + op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx}) + op_group.add(op_idx) + + # Like inplace ops, all of the random ops in the function/module should all be either recomputed or saved + # as a group. This is because, they affect the ranom seed generator. If force_store_random is set True, + # all of the random ops will be stored by default. For easy of manageability, we store the top-most random op + # as the leader of the random_ops_group. + random_ops_group: Dict[int, Set[int]] = {} + random_group_head_idx = min(sac_stats.rand_ops, default=-1) + has_rand_ops = bool(sac_stats.rand_ops) + if has_rand_ops: + random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops) + + # 1. Random ops are stored if force_store_random is set + # 2. View-like ops are recomputed by default + # 3. For inplace_op_groups: + # a) If the head of this group is an inplace op, then we have to store the entire group. + # b) If any op in the group is random and force_store_random is set, then entire group will be stored. + # c) If none of ops in the group are random and the head of the group is not an in-place op, then + # this group can be considered for recomputation in its entireity + stored_ops: Set[int] = set() + recomputed_ops: Set[int] = set() + # Case 1: + if has_rand_ops and sac_stats.force_store_random: + stored_ops.add(random_group_head_idx) + # Case 2: + recomputed_ops.update(set(sac_stats.view_like_ops)) + + for group_head_idx, op_group in inplace_op_groups.items(): + # Case 3a: + if group_head_idx in inplace_op_to_group_head: + stored_ops.add(group_head_idx) + # Case 3b: + if ( + sac_stats.force_store_random & len(op_group & set(sac_stats.rand_ops)) + > 0 + ): + stored_ops.add(group_head_idx) + + # The potential recompute candidates are populated as: + recompute_candidates: Set[int] = set() + # 1) The random group head if it is not stored + if has_rand_ops and random_group_head_idx not in stored_ops: + recompute_candidates.add(random_group_head_idx) + # 2) The in-place op group heads that are not stored + recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops) + # 3) The non-inplace and non-random ops that are neither stored nor recomputed by default + recompute_candidates.update( + set(range(len(sac_stats.memory))) + - recomputed_ops + - stored_ops + - set(inplace_op_to_group_head.keys()) + - set(sac_stats.rand_ops) + ) + + # We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second + msps_meta: List[MSPS] = [] + for cand_idx in recompute_candidates: + op_indices = {cand_idx} + if cand_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand_idx]) + if has_rand_ops and cand_idx == random_group_head_idx: + op_indices.update(sac_stats.rand_ops) + + mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices) + runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices) + func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices} + msps = (mem / runtime) if runtime > 0 else sys.float_info.max + msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps)) + # We choose canidates to be recomputed based on increasing msps + msps_meta.sort(key=lambda x: x.msps, reverse=True) + return SACGreedyOrderMeta( + recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta + ) + + def _get_sac_tradeoff_pwlf_stats( + self, + sac_stats: SACStats, + greedy_order_meta: SACGreedyOrderMeta, + n_segments: int = 2, + save_tradeoff_graph: bool = False, + filename: str = "ac_tradeoff", + ) -> SACTradeOffStats: + try: + import numpy as np # type: ignore[import-not-found] + import pwlf # type: ignore[import-untyped, import-not-found] + except ImportError as err: + raise ImportError("Please install pwlf and numpy package.") from err + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + # 1. Intitialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops + recomp_indices: Set[int] = set() + for r_idx in recomputed_ops: + recomp_indices.add(r_idx) + if r_idx in inplace_op_groups: + recomp_indices.update(inplace_op_groups[r_idx]) + if r_idx in random_ops_group: + recomp_indices.update(random_ops_group[r_idx]) + + discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices) + recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices) + # 2. Initialize the max recomputation time and total recomputation memory + sac_runtime = sum(sac_stats.runtimes) + sac_memory = sum(sac_stats.memory) + # 3. Tradeoff curve stores the KV pair of the dicarded memory to total memory and, + # recomputation time to total runtime incurred. + delta = 1e-2 + tradeoff_curve = OrderedDict() + # 4. Initialize the trade-off curve with the stats of of already chosen recomputed_ops + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 5. Update the trade-off curve with memory and runtime stats of SAC candidates in the + # greedy order of their ``MSPS``. + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 6. Finally, we add the memory and recomputation time of the always stored ops. + stored_indices: Set[int] = set() + for s_idx in stored_ops: + stored_indices.add(s_idx) + if s_idx in inplace_op_groups: + stored_indices.update(inplace_op_groups[s_idx]) + if s_idx in random_ops_group: + stored_indices.update(random_ops_group[s_idx]) + discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices) + recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices) + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + x_ = list(tradeoff_curve.keys()) + y_ = list(tradeoff_curve.values()) + # 7. We shift the y values to left and x values to right to upperbound the trade-off function + # TODO: Write a better explanation why this needs to be done + x = x_[: len(x_) - 1] + y = y_[1:] + tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y) + # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve. + n_segments = max(min(len(x) - 2, n_segments), 1) + tradeoff_pwlf.fit(n_segments=n_segments) + + # save prediction graph + def save_prediction_graph( + pwlf_: pwlf.PiecewiseLinFit, x: List[float], y: List[float], filename: str + ) -> None: + try: + import matplotlib.pyplot as plt # type: ignore[import-not-found] + import numpy as np # type: ignore[import-not-found] + except ImportError as err: + raise ImportError( + "Install matplotlib and numpy using pip: pip install matplotlib numpy" + ) from err + # predict for the determined points + xHat = np.linspace(min(x), max(x), num=10000) + yHat = pwlf_.predict(xHat) + + # plot the results + plt.figure() + plt.plot(x, y, "o", label="Shifted") + plt.plot(xHat, yHat, "-", label="Predicted") + plt.plot(x_, y_, "x", label="Original") + plt.ylabel("Recomp time / Total recomp time") + plt.xlabel("Memory discarded / Total memory") + plt.legend() + plt.title(f"{filename}") + plt.suptitle( + f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms", + fontsize=10, + ) + folder_name = "tradeoff_graphs" + if not os.path.exists(folder_name): + os.makedirs(folder_name) + # Save the plots in the folder + plt.savefig(os.path.join(folder_name, f"{filename}.png")) + + if save_tradeoff_graph: + save_prediction_graph(tradeoff_pwlf, x, y, filename) + # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions + slopes = tradeoff_pwlf.calc_slopes().tolist() + assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance( + tradeoff_pwlf.fit_breaks, np.ndarray + ) + intercepts = tradeoff_pwlf.intercepts.tolist() + fit_breaks = tradeoff_pwlf.fit_breaks.tolist() + return SACTradeOffStats( + n_segments=n_segments, + slopes=slopes, + intercepts=intercepts, + fit_breaks=fit_breaks, + tradeoff_curve=tradeoff_curve, + sac_memory=sac_memory, + sac_runtime=sac_runtime, + ) + + def display_sac_stats( + self, sac_stats: SACStats, print_tabular: bool = False + ) -> None: + """ + Displays the SAC statistics. + + Args: + sac_stats (SACStats): The SAC statistics to display. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + 1. Total Memory: The total memory usage in bytes. + 2. Total Runtime: The total runtime in milliseconds. + 3. Store Random: A flag indicating whether to force store random operator results. + + Followed by a table with the following columns: + 1. Op Idx: The operator index. + 2. Op Name: The operator name. + 3. Runtimes (ms): The operator runtime in milliseconds. + 4. Memory (B): The operator memory usage in bytes. + 5. View-like: A flag indicating whether the operator is view-like. + 6. Random: A flag indicating whether the operator is random. + 7. Saved Autograd: A flag indicating whether the operator's result is saved by autograd engine. + 8. In-place: The index of the operator's first parent, or None if not in-place. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + print( + f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms" + f" Store Random: {sac_stats.force_store_random}" + ) + table_data = [] + op_parent = dict(sac_stats.inplace_ops) + for i, fn_name in enumerate(sac_stats.func_names): + row = [ + str(i), + fn_name, + f"{sac_stats.runtimes[i]:.4f}", + str(sac_stats.memory[i]), + str(i in sac_stats.view_like_ops), + str(i in sac_stats.rand_ops), + str(i in sac_stats.saved_autograd_ops), + str(op_parent.get(i, None)), + ] + table_data.append(row) + # Define headers + headers = [ + "Op Idx", + "Op Name", + "Runtimes(ms)", + "Memory (B)", + "View-like", + "Random", + "Saved Autograd", + "In-place", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def display_sac_tradeoff_stats( + self, + greedy_order_meta: SACGreedyOrderMeta, + sac_stats: SACStats, + print_tabular: bool = False, + ) -> None: + """ + Displays the SAC trade-off statistics. + + Args: + greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata. + sac_stats (SACStats): The SAC statistics. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + A table with the following columns: + 1. Op Id(s): The operator index(es). + 2. Op Name(s): The operator name(s). + 3. Discarded Mem (%): The percentage of discarded memory. + 4. Discarded Mem (B): The discarded memory in bytes. + 5. Recomp time (%): The percentage of recomputed time. + 6. Recomp time (ms): The recomputed time in milliseconds. + 7. MSPS: The memory per second. + 8. Always Stored: A flag indicating whether the operator is always stored. + 9. Always Recomputed: A flag indicating whether the operator is always recomputed. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + table_data = [] + total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes) + discarded_mem: int = 0 + recomp_runtime: float = 0.0 + + def append_row( + op_indices: Set[int], + func_names: Set[str], + msps: Optional[float] = None, + stored: Optional[bool] = False, + recomputed: Optional[bool] = False, + ) -> None: + row = [ + str(op_indices), + str(func_names), + f"{discarded_mem / total_memory:.4f}", + str(discarded_mem), + f"{recomp_runtime / total_runtime:.4f}", + str(recomp_runtime), + f"{msps:.2e}" if msps is not None else str(nan), + str(stored), + str(recomputed), + ] + table_data.append(row) + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + + for op_idx in recomputed_ops: + op_indices: Set[int] = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, recomputed=True) + + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + op_indices = {cand.op_idx} + if cand.op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand.op_idx]) + if cand.op_idx in random_ops_group: + op_indices.update(random_ops_group[cand.op_idx]) + append_row(op_indices, cand.func_names, msps=cand.msps) + + for op_idx in stored_ops: + op_indices = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, stored=True) + + headers = [ + "Op Id(s)", + "Op Name(s)", + "Discarded Mem (%)", + "Discarded Mem (B)", + "Recomp time (%)", + "Recomp time (ms)", + "MSPS", + "Always Stored", + "Always Recomputed", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def pwlf_sac_tradeoff_curve( + self, + n_segments: int = 2, + save_tradeoff_graphs: bool = False, + ) -> None: + """ + Fits a piecewise linear function with the specified sumber of segments to the SAC trade-off curve of + discarded memory vs recomputation time. + + Args: + n_segments (int, optional): The number of segments to be used for fitting the piecewise linear function to + the trade-off curve. Defaults to 2. + save_tradeoff_graphs (bool, optional): Whether to save the trade-off graphs to file. Defaults to False. + + If save_tradeoff_graphs is True, the trade-off graphs are saved to file using the module FQN as the filename. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + self.sac_mod_tradeoff_stats[mod_fqn] = self._get_sac_tradeoff_pwlf_stats( + sac_stats=sac_stats, + greedy_order_meta=self.sac_mod_greedy_order_meta[mod_fqn], + n_segments=n_segments, + save_tradeoff_graph=save_tradeoff_graphs, + filename=mod_fqn, + ) + + def display_modulewise_sac_stats( + self, depth: int = 2, print_tabular: bool = False + ) -> None: + """ + Displays the SAC and trade-off statistics for each module. + + Args: + depth (int, optional): The maximum depth of modules to display. Defaults to 2. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + For each module with depth less than or equal to the specified depth: + 1. The SAC statistics for the module (using display_sac_stats). + 2. The SAC trade-off statistics for the module (using display_sac_tradeoff_stats). + + If print_tabular is True, the statistics are printed in a tabular format. + Otherwise, the statistics are printed in a plain text format. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + self.display_sac_stats(sac_stats, print_tabular) + print(f"AC Trade-off for Module: {mod_fqn} MSPS = Memory/Runtime") + self.display_sac_tradeoff_stats( + self.sac_mod_greedy_order_meta[mod_fqn], sac_stats, print_tabular + ) + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + SACEstimator: The SAC estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate_runtime = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate_runtime = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + return self + + def __enter__(self) -> Self: # type: ignore[no-untyped-def] + fake_mode = active_fake_mode() + assert isinstance( + fake_mode, FakeTensorMode + ), "SAC Estimator should be called in FakeTensorMode" + RuntimeEstimator.fake_mode = fake_mode + self._mod_tracker.register_user_hooks( + pre_fw_hook=self._pre_fw_hook, + post_fw_hook=self._post_fw_hook, + ) + self._mod_tracker.__enter__() + self._saved_tensor_hook_ctx.__enter__() + return super().__enter__() + + def __exit__(self, *args: Any) -> None: # type: ignore[no-untyped-def] + self._saved_tensor_hook_ctx.__exit__() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) diff --git a/torch/distributed/_tools/sac_ilp.py b/torch/distributed/_tools/sac_ilp.py new file mode 100644 index 0000000000000..490ac59f1a084 --- /dev/null +++ b/torch/distributed/_tools/sac_ilp.py @@ -0,0 +1,295 @@ +import logging +import math +from enum import IntEnum +from typing import Dict, List, Optional, Tuple + +from torch.distributed._tools.ilp_utils import Graph, is_submodule +from torch.distributed._tools.sac_estimator import SACStats + + +try: + from pulp import ( # type: ignore[import-untyped,import-not-found] + lpDot, + LpInteger, + LpMaximize, + LpMinimize, + LpProblem, + LpStatus, + lpSum, + LpVariable, + PULP_CBC_CMD, + value, + ) +except ImportError as err: + raise ImportError( + "Please install pulp package. See: https://github.com/coin-or/pulp." + ) from err + +# Create a logger object +logger = logging.getLogger(__name__) + +# Set the logging level to INFO +logger.setLevel(logging.INFO) + + +def sac_milp( + graph: Graph, + memory_budget: float, + world_size: int = 1, + ac_units: Optional[List[str]] = None, + fsdp_units: Optional[List[str]] = None, +) -> Tuple[Dict[str, float], float, int]: + """ + MILP to decide which modules to AC and how much memory to discard. + The objective is to minimize recomputation time. + The constraint is to ensure peak memory is under budget. + + Args: + graph: graph representation of the model as a module submodule tree + where each node is a submodule with memory & runtime stats + memory_budget: memory budget in GiB + world_size: number of GPUs. In the case of FSDP, world_size will be + used to compute the amount of parameter and gradient memory on each rank + ac_units: a list of user-specified AC units. + fsdp_units: a list of FSDP units. AC units cannot be supermodules of FSDP units. + + Returns: + Dict[str, float]: the optimal SAC solution, mapping from module fqn to + the percentage of activation memory to **discard** + float: the recomputation time of the optimal SAC solution + int: upper bound on the peak memory of the optimal SAC solution. + note that value of -1 means that the ILP solver failed to find a solution. + + """ + num_nodes = len(graph.nodes) + M = 10**2 # note: numerical issue may occur if M is too big + MEM_MULTIPLIER = 2**30 + + # Create a MILP problem + prob = LpProblem("SAC", LpMinimize) + + # Create decision variables + # y_i: indicator for if module i is AC'ed + y = LpVariable.matrix("y", list(range(num_nodes)), 0, 1, LpInteger) + # r_i: percentage of discarded activation memory + r = LpVariable.matrix("r", list(range(num_nodes)), 0, 1) + # d_i: discarded activation memory for module i + d = LpVariable.matrix("d", list(range(num_nodes)), 0) + # a_i: total activation memory at module i + a = LpVariable.matrix("a", list(range(num_nodes)), 0) + # m_i: memory at module i, combining parameters, gradients, and activations + m = LpVariable.matrix("m", list(range(num_nodes)), 0) + # rcp_i: percentage of recomputation time + rcp = LpVariable.matrix("rcp", list(range(num_nodes)), 0) + # rct_i: recomputation time for module i (in ms) + rct = LpVariable.matrix("rct", list(range(num_nodes)), 0) + # max_m: peak memory + max_m = LpVariable("max_m", 0) + + # Add constraints + # [Constraint] User specified AC units + if ac_units: + ac_units_set = set(ac_units) + for i in range(num_nodes): + if graph.nodes[i]["fqn"] not in ac_units_set: + prob += y[i] == 0 + + # [Constraint] AC units cannot be supmodules of user specified FSDP units + if fsdp_units: + for i in range(num_nodes): + if any( + is_submodule(fsdp_unit, graph.nodes[i]["fqn"]) + for fsdp_unit in fsdp_units + ): + prob += y[i] == 0 + + # [Constraint] No nested AC units + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if graph.ad_matrix[i][j] == 1: + prob += y[i] + y[j] <= 1 + + # [Constraint] Do not AC leaf modules + for i in range(num_nodes): + if graph.nodes[i]["is_leaf"]: + prob += y[i] == 0 + + # [Constraint] Express amount of discarded activation memory + for i in range(num_nodes): + # There are two measures for activation memory: ACM and IA + # 1. IA is the activation memory saved when not using AC + # 2. ACM is the total activation memory, including those + # that are not typically saved when not using AC + # Note: ACM >= IA + if (not graph.nodes[i]["is_leaf"]) and graph.nodes[i][ + "sac_memory" + ] < graph.nodes[i]["act_fw_per_module"]: + logger.warning("For module {%s}: ", graph.nodes[i]["fqn"]) + logger.warning( + "activation memory from memory tracker is {%d},", + graph.nodes[i]["act_fw_per_module"], + ) + logger.warning( + "activation memory from SAC estimator is {%d}.", + graph.nodes[i]["sac_memory"], + ) + logger.warning("Something is wrong. Please check!") + logger.warning("Overriding the latter with the former.") + graph.nodes[i]["sac_memory"] = graph.nodes[i]["act_fw_per_module"] + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += d[i] == ACM_i * r[i] - (ACM_i - IA_i) * y[i] + + # [Constraint] Ensure correctness of r_i + # There are two parts to its correctness + # 1. r_i > 0 only if y_i == 1 (discard only if it is an AC unit) + # 2. r_i needs to be large enough to cover the difference between + # ACM and IA. Otherwise, we are not saving any memory + for i in range(num_nodes): + prob += y[i] >= r[i] + if graph.nodes[i]["is_leaf"]: + continue + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += r[i] >= (ACM_i - IA_i) / ACM_i * y[i] + + # [Constraint] Express total activation memory in the backward pass + for i in range(num_nodes): + AG_i = graph.nodes[i]["act_grad_per_module"] / MEM_MULTIPLIER + TA_i = graph.nodes[i]["act_total"] / MEM_MULTIPLIER + # related to discarded amount of memory + pos = graph.nodes[i]["pos_fw_post_order"] + coeff = [0] * num_nodes + for p in range(pos): + j = graph.name2node[graph.fw_post_order[p]]["index"] + coeff[j] = 1 + prob += a[i] == TA_i + AG_i - lpDot(coeff, d) + + # [Constraint] Express the total amount of memory at each module + # Note that unsharded parameters and gradients are not included here + P_1 = graph.nodes[0]["param_per_module"] / MEM_MULTIPLIER + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] / MEM_MULTIPLIER + prob += m[i] == a[i] + (P_1 + TG_i) / world_size + + # [Constraint] Express peak memory + for i in range(num_nodes): + prob += max_m >= m[i] + + # [Constraint] Express percentage of recomputation time + for i in range(num_nodes): + for s in range(graph.nodes[i]["n_segments"]): + slope = graph.nodes[i]["slopes"][s] + intercept = graph.nodes[i]["intercepts"][s] + prob += rcp[i] >= slope * r[i] + intercept + + # [Constraint] Express recomputation time + # rct_i = (rcp_i * ACT_i) if y_i == 1 else 0 + for i in range(num_nodes): + ACT_i = graph.nodes[i]["sac_runtime"] + prob += rct[i] <= M * y[i] + prob += rct[i] <= ACT_i * rcp[i] + prob += rct[i] >= ACT_i * rcp[i] - M * (1 - y[i]) + + # [Constraint] Peak memory should be below budget + prob += max_m <= memory_budget + + # Set Objeictive + prob += lpSum(rct) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=180, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return {}, 0, -1 + + # Gather and return solution if optimal solution is found + ac_decisions = {} + for i in range(num_nodes): + if round(y[i].varValue) == 1: + ac_decisions[graph.nodes[i]["fqn"]] = round(r[i].varValue, 4) + recomputation_time = round(value(prob.objective), 2) + peak_mem = round(max_m.varValue * MEM_MULTIPLIER) + + return ac_decisions, recomputation_time, peak_mem + + +class SACDecision(IntEnum): + RECOMPUTE = 0 + SAVE = 1 + + +def get_optimal_checkpointing_policy_per_module( + sac_stats: SACStats, memory_budget: float +) -> List[int]: + """ + This is adapted from -- + https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375 + + Given the SACStats of a module, including list of operators, their memory, runtimes, and metadata, + decide via MILP an optimal set of operators to checkpoint under a given ``memory_budget``. + + Args: + sac_stats: the SACStats object of the module + memory_budget: a float between zero and one + + Returns: + List[int]: the decision whether each operator should be saved (1) or recomptued (0). + """ + if not (0 <= memory_budget <= 1): + raise ValueError( + f"`memory_budget` must be a float between 0 and 1. Got {memory_budget}." + ) + num_ops = len(sac_stats.func_names) + + # Create a MILP problem + prob = LpProblem("SAC-per-module", LpMaximize) + + # Create decision variables + # x[i] = 1 means the i-th operator should be saved, otherwise it should be recomputed + x = LpVariable.matrix("x", list(range(num_ops)), 0, 1, LpInteger) + + # Add constraints + # [Constraint] random ops should be saved if ``force_store_random`` is True + # otherwise, random ops should either be all recomputed or all saved + if sac_stats.force_store_random: + for i in sac_stats.rand_ops: + prob += x[i] == SACDecision.SAVE.value + else: + for i1, i2 in zip(sac_stats.rand_ops[:-1], sac_stats.rand_ops[1:]): + prob += x[i1] == x[i2] + + # [Constraint] view-like ops should always be recomputed + for i in sac_stats.view_like_ops: + prob += x[i] == SACDecision.RECOMPUTE.value + + # [Constraint] inplace ops should always be done in conjunction with its parent op + for op, op_parent in sac_stats.inplace_ops: + if op != op_parent: + prob += x[op] == x[op_parent] + else: + prob += x[op] == SACDecision.SAVE.value + + # [Constraint] saved memory should be under the ``memory_budget`` + max_memory = math.ceil(memory_budget * sum(sac_stats.memory)) + prob += lpDot(x, sac_stats.memory) <= max_memory + + # [Objective] minimize recomputation time, note the ILP is a maximization problem + # because x[i] == 1 means the op is saved (not recomputed), and thus recomputation + # time is sum(sac_stats.runtimes) - lpDot(x, sac_stats.runtimes) + prob += lpDot(x, sac_stats.runtimes) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=10, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return [] + + # Gather and return solution if optimal solution is found + return [round(x[i].varValue) for i in range(num_ops)] diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index b1296ae712f0c..b012c94ffcaa8 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -52,23 +52,11 @@ def allreduce_hook( return _allreduce_fut(process_group, bucket.buffer()) -def fp16_compress_hook( +def _compress_hook( + dtype: torch.dtype, process_group: dist.ProcessGroup, bucket: dist.GradBucket, ) -> torch.futures.Future[torch.Tensor]: - """ - Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size. - - This DDP communication hook implements a simple gradient compression - approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``) - and then divides it by the process group size. - It allreduces those ``float16`` gradient tensors. Once compressed gradient - tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). - - Example:: - >>> # xdoctest: +SKIP - >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) - """ group_to_use = process_group if process_group is not None else dist.group.WORLD world_size = group_to_use.size() @@ -77,7 +65,7 @@ def fp16_compress_hook( if isinstance(bucket, tuple) else bucket.buffer() ) - compressed_tensor = buffer.to(torch.float16).div_(world_size) + compressed_tensor = buffer.to(dtype).div_(world_size) def decompress(fut): decompressed_tensor = buffer @@ -99,7 +87,26 @@ def decompress(fut): return fut.then(decompress) -# TODO: create an internal helper function and extract the duplicate code in FP16_compress and BF16_compress. +def fp16_compress_hook( + process_group: dist.ProcessGroup, + bucket: dist.GradBucket, +) -> torch.futures.Future[torch.Tensor]: + """ + Compress by casting ``GradBucket`` to ``torch.float16`` divided by process group size. + + This DDP communication hook implements a simple gradient compression + approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``) + and then divides it by the process group size. + It allreduces those ``float16`` gradient tensors. Once compressed gradient + tensors are allreduced, the chained callback ``decompress`` casts it back to the input data type (such as ``float32``). + + Example:: + >>> # xdoctest: +SKIP + >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook) + """ + return _compress_hook(torch.float16, process_group, bucket) + + def bf16_compress_hook( process_group: dist.ProcessGroup, bucket: dist.GradBucket, @@ -118,34 +125,7 @@ def bf16_compress_hook( >>> # xdoctest: +SKIP >>> ddp_model.register_comm_hook(process_group, bf16_compress_hook) """ - group_to_use = process_group if process_group is not None else dist.group.WORLD - world_size = group_to_use.size() - - buffer = ( - cast(Tuple[torch.Tensor, ...], bucket)[0] - if isinstance(bucket, tuple) - else bucket.buffer() - ) - compressed_tensor = buffer.to(torch.bfloat16).div_(world_size) - - def decompress(fut): - decompressed_tensor = buffer - # Decompress in place to reduce the peak memory. - # See: https://github.com/pytorch/pytorch/issues/45968 - value = fut if isinstance(fut, torch.Tensor) else fut.value()[0] - decompressed_tensor.copy_(value) - return decompressed_tensor - - if torch._utils.is_compiling(): - grad = dist._functional_collectives.all_reduce( - compressed_tensor, "sum", group_to_use - ) - return decompress(grad) - else: - fut = dist.all_reduce( - compressed_tensor, group=group_to_use, async_op=True - ).get_future() - return fut.then(decompress) + return _compress_hook(torch.bfloat16, process_group, bucket) def fp16_compress_wrapper( diff --git a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py index 4727bbf9d45e6..d5b256c97df7d 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py @@ -34,7 +34,6 @@ def _reducer_allreduce_and_upcast_hook( """ ddp_weakref = hook_state.ddp_weakref reducer, process_group = ddp_weakref().reducer, ddp_weakref().process_group - gradient_is_bucket_view = ddp_weakref().gradient_as_bucket_view # Cast bucket if different than param_dtype. if ( ddp_weakref().mixed_precision.param_dtype @@ -53,8 +52,7 @@ def _reducer_allreduce_and_upcast_hook( ret_fut.set_result(bucket.buffer()) # Upcast parameters and gradients so optimizer step can run in fp32. - params, grads = bucket.parameters(), bucket.gradients() - for p, g in zip(params, grads): + for p in bucket.parameters(): p.data = p._fp_param # free storage for mp param as it will be allocated again in next # forward pass. @@ -70,7 +68,7 @@ def wait_for_stream_cb(): # they may participate in computation. However, they would not be recast # by hook above as they don't have a grad hook installed, so cast them # back here. - for n, p in ddp_weakref().module.named_parameters(): + for _, p in ddp_weakref().module.named_parameters(): if hasattr(p, "_ddp_mp_hook_state"): p._ddp_mp_hook_state[1].remove() delattr(p, "_ddp_mp_hook_state") diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py index 9846cbf265f0f..5943051419ae6 100644 --- a/torch/distributed/benchmarks/benchmark_ddp_rpc.py +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -87,7 +87,7 @@ def _retrieve_embedding_parameters(emb_rref): def _print_header(): _print_cont("\n") _print_cont("%10s" % "") - for p in [50, 75, 90, 95]: + for _ in [50, 75, 90, 95]: _print_cont("%14s%10s" % ("sec/epoch", "epoch/sec")) _print_cont("\n") @@ -112,7 +112,6 @@ def _run_printable(cmd): buffer = io.BytesIO() torch.save(proc.stdout.decode("utf-8"), buffer) input_tensor = torch.ByteTensor(list(buffer.getvalue())) - input_length = torch.IntTensor([input_tensor.size(0)]) output = [] buffer = io.BytesIO(np.asarray(input_tensor).tobytes()) @@ -173,7 +172,7 @@ def get_next_batch(rank): measurements = [] # Include warm-up cycles during training - for epoch in range(100 + WARMUP_CYCLES): + for _ in range(100 + WARMUP_CYCLES): start = time.time() batch_size = 0 diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py index d715c651e024b..e4b47ef659cd6 100644 --- a/torch/distributed/checkpoint/planner_helpers.py +++ b/torch/distributed/checkpoint/planner_helpers.py @@ -178,7 +178,7 @@ def create_read_items_for_chunk_list( dest_offsets = [] lengths = [] for ( - dim, + _dim, offset_for_saved_tensor, offset_for_current_tensor, length, diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index e6c2b0343f6be..62c1126b04c9d 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -49,6 +49,7 @@ def _init_device_mesh_stub(): is_initialized, new_group, ProcessGroup, + split_group, ) logger = logging.getLogger(__name__) @@ -499,11 +500,11 @@ def _init_process_groups(self): # functional collectives. See details in: # https://github.com/pytorch/pytorch/issues/93173#issuecomment-1907095208 dim_group_infos: List[Tuple[str, List[int], str]] = [] + default_group = _get_default_group() if self.mesh.ndim == 1 and self.mesh.numel() == get_world_size(): # Append the default pg to the first dim groups only if the default pg is compatible with `self.device_type`. # Otherwise, create new pg. - default_group = _get_default_group() ranks = list(range(get_world_size())) dim_group = ( new_group( @@ -530,36 +531,67 @@ def _init_process_groups(self): pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape( -1, self.mesh.size(dim) ) - # multi-dim mesh, create subgroups by looping over the pg_ranks - # for each dim and append the groups - for dim_mesh in pg_ranks_by_dim: - subgroup_ranks = dim_mesh.tolist() - # Respect dim group options specified via _MeshEnv.set_dim_group_options(). - # Inherit from the parent group if no options are specified for the group. - if dim in _mesh_resources.mesh_dim_group_options: - ( - backend, - pg_options, - ) = _mesh_resources.mesh_dim_group_options[dim] - else: - backend, pg_options = None, None + # Respect dim group options specified via _MeshEnv.set_dim_group_options(). + # Inherit from the parent group if no options are specified for the group. + if dim in _mesh_resources.mesh_dim_group_options: + ( + backend, + pg_options, + ) = _mesh_resources.mesh_dim_group_options[dim] + else: + backend, pg_options = None, None + + # If we have a 2D mesh with mesh_dim_names ("dp", "tp"), the group description + # of the subgroups would be `mesh_dim_dp` and `mesh_name_tp`. + # If the mesh doesn't not have a mesh_dim_names, then the group description of the + # subgroup would be `mesh_dim_0` and `mesh_dim_1`. + group_desc = ( + f"mesh_{self.mesh_dim_names[dim]}" + if self.mesh_dim_names + else f"mesh_dim_{dim}" + ) - # We temporarily revert the re-use subgroup, since it breaks two internal tests. - # Temporarily reverting to resolve test timeout while root-causing. - # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. - group_desc = ( - f"mesh_{self.mesh_dim_names[dim]}" - if self.mesh_dim_names - else f"mesh_dim_{dim}" + # If bound_device_id exists, it means the nccl communicator has been eagerly initialized + # so that we can use `split_group` to create subgroups through `ncclCommSplit`. + # In this case, we only need to make one API call (`split_group``) for the subgroup creation + # for each mesh dimension. In a 2 * 4 mesh, we only need to make 2 API calls per ranks to create + # all the subgroups. + # Otherwise, we need to make more than one API call (`new_group`) for subgroup creations. The + # numbers of API calls are equal to the number of subgroups for each mesh dimension. In a 2 * 4 + # mesh, we need to make 2 + 4 = 6 API calls per ranks to create all the subgroups. + dim_group = None + if ( + bound_device_id := getattr( + default_group, "bound_device_id", None ) - dim_group = new_group( - ranks=subgroup_ranks, - backend=backend, + ) is not None: + dim_group = split_group( + parent_pg=default_group, pg_options=pg_options, + split_ranks=pg_ranks_by_dim.tolist(), group_desc=group_desc, ) + # If the subgroup has been already created through `split_group`, we simply loop over `pg_ranks_by_dim` + # and append the `(group_tag, subgroup_ranks, and group_name)` tuple to the `dim_group_infos` list when + # the current rank is in the subgroup. + # Otherwise, we use `new_group` instead of `split_group` to create subgroups by looping over `pg_ranks_by_dim` + # along with appending information to the `dim_group_infos` list whenever necessary. + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + + # We temporarily revert the re-use subgroup, since it breaks two internal tests. + # Temporarily reverting to resolve test timeout while root-causing. + # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. + if bound_device_id is None: + dim_group = new_group( + ranks=subgroup_ranks, + backend=backend, + pg_options=pg_options, + group_desc=group_desc, + ) + # only add to dim_groups if the current rank in the subgroup if self.get_rank() in subgroup_ranks: if len(dim_group_infos) > dim: @@ -761,7 +793,11 @@ def from_group( group_ranks = get_process_group_ranks(group) if ( isinstance(mesh, torch.Tensor) and mesh.tolist() != group_ranks - ) or (mesh is not None and mesh != group_ranks): + ) or ( + mesh is not None + and not isinstance(mesh, torch.Tensor) + and mesh != group_ranks + ): raise ValueError( f"Invalid mesh {str(mesh)} for ProcessGroup with ranks {group_ranks}" ) diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index d6f792d2f9e2c..44ddef2375276 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -3,6 +3,7 @@ import collections.abc import contextlib +import ctypes import hashlib import io import itertools @@ -676,9 +677,9 @@ def pg_config_info(self) -> List[Dict[str, Any]]: "pg_name": self.pg_names[pg], "pg_desc": pg.group_desc, "backend_config": self.pg_backend_config[pg], - "ranks": list(ranks.keys()) - if len(ranks) != default_pg_size - else [], # 'ranks' is an empty list when all ranks are involved in a pg + "ranks": ( + list(ranks.keys()) if len(ranks) != default_pg_size else [] + ), # 'ranks' is an empty list when all ranks are involved in a pg "group_size": len(ranks), "group_count": self.group_count, } @@ -1706,6 +1707,20 @@ def _shutdown_backend(pg): backend._shutdown() +def _abort_backend(pg: ProcessGroup): + """ + Abort the backend of a process group. + Currently, only ProcessGroupNCCL backend is supported. + No op for other backends. + """ + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + backend = None + if isinstance(backend, ProcessGroupNCCL): + backend.abort() + + def _new_process_group_helper( group_size, group_rank, @@ -1761,14 +1776,9 @@ def _new_process_group_helper( # communicators based on pre-existing ones, which can save # initialization time. Due to lazy initialization of # communicators in some backends, we have to be careful and only - # split when we *know* the backends already are connected _on all - # ranks_. We can only know this if the group we are making is the - # entire world or if we have bound a device id to the world (which - # causes early connection initialization). - if is_initialized() and ( - len(global_ranks_in_group) == _get_default_group().size() - or _get_default_group().bound_device_id - ): + # split when we *know* the default PG has already started communicator initialization. + # We know this if we have bound a device id to the default pg (eager initialized). + if is_initialized() and _get_default_group().bound_device_id: split_from = _get_split_source(_get_default_group()) else: split_from = None @@ -2064,6 +2074,101 @@ def destroy_process_group(group: Optional[ProcessGroup] = None): _unregister_process_group(pg.group_name) +def _abort_process_group(group: Optional[ProcessGroup] = None): + """ + Abort a given process group. If group.WORLD (i.e. `None`) is given, all + process groups including the default one will be aborted. + + Args: + group (ProcessGroup, optional): The process group to be aborted. + + .. note:: this API is experimental and currently only works with the NCCL + backend. + + .. note:: this API should be used with `TORCH_NCCL_ASYNC_ERROR_HANDLING` + turned off (i.e. set to 0). Otherwise, ProcessGroupNCCL's watchdog may + automatically handle errors or timeouts for you including aborting the + ProcessGroup. + """ + global _world + + if group == GroupMember.NON_GROUP_MEMBER: + return + + pg = group or GroupMember.WORLD + + assert pg is not None + if _world.pg_map.get(pg, None) is None: + raise ValueError("Invalid process group specified or has been destroyed.") + + try: + backend = pg._get_backend(torch.device("cuda")) + except RuntimeError: + backend = None + + if not isinstance(backend, ProcessGroupNCCL): + logger.warning( + "`abort_process_group` currently only has implementation for ProcessGroupNCCL; " + "however, no NCCL backend is found. This call will be a no-op." + ) + return + + if group == GroupMember.WORLD: + # Abort all backends within a ncclGroupStart|End semantic. + # This ensures that different NCCL communicators' abort calls won't + # deadlock each other. + # For details, please see: https://github.com/pytorch/pytorch/issues/119797 + backend._group_start() + for pg_to_abort in sorted( + _world.pg_names, key=lambda x: _world.pg_names[x], reverse=True + ): + _abort_backend(pg_to_abort) + backend._group_end() + + _update_default_pg(None) + _world.pg_map.clear() + _world.pg_names.clear() + _world.pg_group_ranks.clear() + _world.pg_backend_config.clear() + _world.pg_to_tag.clear() + _world.tags_to_pg.clear() + _world.pg_coalesce_state.clear() + _unregister_all_process_groups() + + # when process group doesn't have an explicit name (only WORLD (default) + # process group can have an explicit name), we use global _world.group_count + # to generate the name. We need to reset the counter on destruction to + # allow consistent value to be generated when we re-create process + # groups after some trainers recover from failure + # + # We only reset this when WORLD is being destroyed because if this + # process group is in good state, we aren't dealing with failures. + _world.group_count = 0 + else: + _abort_backend(pg) + del _world.pg_map[pg] + del _world.pg_names[pg] + del _world.pg_group_ranks[pg] + del _world.pg_backend_config[pg] + if pg in _world.pg_coalesce_state.keys(): + warnings.warn( + "Some coalesced collectives haven't been launched when " + "ProcessGroup is aborted. They will be cleaned." + ) + del _world.pg_coalesce_state[pg] + + tag = _world.pg_to_tag.get(pg) + del _world.pg_to_tag[pg] + if tag is not None: + try: + _world.tags_to_pg[tag].remove(pg) + if tag.startswith("ptd:"): + _world.tags_to_pg[""].remove(pg) + except Exception: + pass + _unregister_process_group(pg.group_name) + + def get_rank(group: Optional[ProcessGroup] = None) -> int: """ Return the rank of the current process in the provided ``group``, default otherwise. @@ -4247,9 +4352,7 @@ def barrier(group=GroupMember.WORLD, async_op=False, device_ids=None): Async work handle, if async_op is set to True. None, if not async_op or if not part of the group - .. note:: `ProcessGroupNCCL` now relies on stream synchronization instead of - device synchronization to block the CPU. Thus, please do not assume that - `barrier()` would perform a device synchronization. + .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective. """ if _rank_not_in_group(group): _warn_not_in_group("barrier") @@ -4368,26 +4471,38 @@ def _create_process_group_wrapper( return wrapped_pg -# helper function for deterministically hashing a list of ranks -def _hash_ranks(ranks: List[int]): - return hashlib.sha1(bytes("_".join(map(str, ranks)), "utf-8")).hexdigest() +# helper function for deterministically hashing a list of ranks to a unique +# string +def _hash_ranks_to_str(ranks: List[int]) -> str: + rank_join: str = "_".join(map(str, ranks)) + # In case there is already a PG with the same rank composition + unique_str = "_".join([rank_join, str(len(_world.pg_names))]) + return hashlib.sha1(bytes(unique_str, "utf-8")).hexdigest() # Takes a list of ranks and computes an integer color def _process_group_color(ranks: List[int]) -> int: - # Convert our hash to an int, but avoid negative numbers by shifting a bit. - return int(_hash_ranks(ranks), 16) % (sys.maxsize >> 1) + # Convert list to tuple to make it hashable + ranks = tuple(ranks) + hash_value = hash(ranks) + # Split color must be: + # - a non-negative integer; + # - a type compatible with C's int because we are pybinding to the latter. + # Thus, we limit the hash value within c_int's max value. + max_c_int = 2 ** (ctypes.sizeof(ctypes.c_int) * 8 - 1) + color = abs(hash_value) % max_c_int + return color def _process_group_name(ranks, use_hashed_name): + # Create name for a process group. global _world if use_hashed_name: - pg_name = _hash_ranks(ranks) - while pg_name in _world.pg_names.values(): - pg_name = hashlib.sha1(bytes(pg_name + "_", "utf-8")).hexdigest() + pg_name = _hash_ranks_to_str(ranks) else: pg_name = str(_world.group_count) _world.group_count += 1 + # TODO: why is group count incremented only in the else path? return pg_name @@ -4461,7 +4576,7 @@ def split_group( raise RuntimeError( "No device associated with the default pg, not safe to split any process groups" ) - default_backend, default_store = _world.pg_map[default_pg] + _default_backend, default_store = _world.pg_map[default_pg] global_rank = default_pg.rank() global_world_size = default_pg.size() @@ -4607,6 +4722,7 @@ def new_group( pg_options=None, use_local_synchronization=False, group_desc=None, + device_id: Optional[torch.device] = None, ): """ Create a new distributed group. @@ -4659,6 +4775,9 @@ def new_group( in that non-member ranks don't need to call into API and don't join the barrier. group_desc (str, optional): a string to describe the process group. + device_id (torch.device, optional): a single, specific device + to "bind" this process to, The `new_group` call will try to initialize + a communication backend immediately for the device if this field is given. Returns: A handle of distributed group that can be given to collective calls or @@ -4682,6 +4801,7 @@ def new_group( None, use_local_synchronization=use_local_synchronization, group_desc=group_desc, + device_id=device_id, ) @@ -4693,6 +4813,7 @@ def _new_group_with_tag( pg_tag=None, use_local_synchronization=False, group_desc=None, + device_id: Optional[torch.device] = None, ): """ Variant of ``new_group`` that exposes tag creation. @@ -4703,7 +4824,12 @@ def _new_group_with_tag( global _world default_pg = _get_default_group() - device_id = default_pg.bound_device_id + if device_id is None: + device_id = default_pg.bound_device_id + elif default_pg.bound_device_id is not None: + assert ( + device_id == default_pg.bound_device_id + ), "Mismatched bound device between new pg and the default pg." default_backend, default_store = _world.pg_map[default_pg] global_rank = default_pg.rank() global_world_size = default_pg.size() diff --git a/torch/distributed/elastic/multiprocessing/errors/__init__.py b/torch/distributed/elastic/multiprocessing/errors/__init__.py index 2f5ed2d1ab0b8..e7ecd6fd63fb2 100644 --- a/torch/distributed/elastic/multiprocessing/errors/__init__.py +++ b/torch/distributed/elastic/multiprocessing/errors/__init__.py @@ -245,7 +245,7 @@ def get_first_failure(self) -> Tuple[GlobalRank, ProcessFailure]: def format_msg(self, boarder_delim="=", section_delim="-"): title = f"{self.name} FAILED" - root_rank, root_failure = self.get_first_failure() + root_rank, _root_failure = self.get_first_failure() root_failure_fmt: str = "" other_failures_fmt: List[str] = [] diff --git a/torch/distributed/elastic/timer/file_based_local_timer.py b/torch/distributed/elastic/timer/file_based_local_timer.py index f762614a8e6c5..ff5f0eed431cb 100644 --- a/torch/distributed/elastic/timer/file_based_local_timer.py +++ b/torch/distributed/elastic/timer/file_based_local_timer.py @@ -179,6 +179,8 @@ def __init__( self._timers: Dict[Tuple[int, str], FileTimerRequest] = {} self._stop_signaled = False self._watchdog_thread: Optional[threading.Thread] = None + + self._is_client_started = False if os.path.exists(self._file_path): os.remove(self._file_path) os.mkfifo(self._file_path) @@ -249,6 +251,7 @@ def _watchdog_loop(self) -> None: # 2. We are running the watchdog loop in a separate daemon # thread, which will not block the process to stop. with open(self._file_path) as fd: + self._is_client_started = True while not self._stop_signaled: try: run_once = self._run_once @@ -390,4 +393,4 @@ def _reap_worker(self, worker_pid: int, signal: int) -> bool: return False def get_last_progress_time(self) -> int: - return self._last_progress_time + return self._last_progress_time if self._is_client_started else int(time.time()) diff --git a/torch/distributed/elastic/utils/api.py b/torch/distributed/elastic/utils/api.py index bdb8f02e0176f..da3c53c936c54 100644 --- a/torch/distributed/elastic/utils/api.py +++ b/torch/distributed/elastic/utils/api.py @@ -38,7 +38,7 @@ def get_socket_with_port() -> socket.socket: s.bind(("localhost", 0)) s.listen(0) return s - except OSError as e: + except OSError: s.close() raise RuntimeError("Failed to create a socket") diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index 396f058c45a1a..cc3194bb463ab 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -274,7 +274,7 @@ def _named_parameters_with_duplicates( kwargs["remove_duplicate"] = False try: ret = list(module.named_parameters(**kwargs)) - except AssertionError as e: + except AssertionError: kwargs.pop("remove_duplicate") ret = list(module.named_parameters(**kwargs)) return ret diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index ede7d06ec9a1d..3070d9e503738 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -168,8 +168,9 @@ class _ShardParamInfo(NamedTuple): offset_in_shard: Optional[int] numel_in_shard: Optional[int] # Use to get part of the parameter in the local shard from a flattened - # version of the unsharded parameter, e.g. - # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]` + # version of the unsharded parameter, e.g. either + # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]` or + # `param.as_strided((param.numel(),), (1,))[intra_param_start_idx : intra_param_end_idx + 1]` intra_param_start_idx: Optional[int] intra_param_end_idx: Optional[int] # inclusive @@ -183,6 +184,10 @@ class FlatParamShardMetadata(NamedTuple): shard of the parameters; see :class:`FlatParameter`. param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's shard of the parameters; see :class:`FlatParameter`. + param_strides (Tuple[torch.Size, ...]): Parameter strides of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_contiguities (Tuple[bool, ...]): Parameter `.contiguous` call results + of this rank's shard of the parameters; see :class:`FlatParameter`. param_numels (Tuple[int, ...]): Parameter numels of this rank's shard of the parameters; see :class:`FlatParameter`. param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in @@ -192,6 +197,8 @@ class FlatParamShardMetadata(NamedTuple): param_names: Tuple[str, ...] param_shapes: Tuple[torch.Size, ...] + param_strides: Tuple[Tuple[int, ...], ...] + param_contiguities: Tuple[bool, ...] param_numels: Tuple[int, ...] param_offsets: Tuple[Tuple[int, int], ...] @@ -259,6 +266,9 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info entry; see :class:`ParamInfo` for details. _shapes (Tuple[torch.Size, ...]): Each parameter's original shape. + _strides (Tuple[torch.Size, ...]): Each parameter's original stride. + _contiguities (Tuple[bool, ...]): Each parameter's ``contiguous()`` + call result. _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN) prefixed from the ``_fully_sharded_module``. The names are guaranteed to be unique in the subtree rooted at that module. @@ -336,6 +346,8 @@ class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): _num_params: int _param_infos: Tuple[ParamInfo, ...] _shapes: Tuple[torch.Size, ...] + _strides: Tuple[Tuple[int, ...], ...] + _contiguities: Tuple[bool, ...] _fqns: Tuple[str, ...] _param_extensions: Tuple[Optional[Any], ...] _numels_with_padding: Tuple[int, ...] @@ -377,6 +389,8 @@ def _init_metadata( param_infos: List[ParamInfo], numels: List[int], shapes: List[torch.Size], + strides: List[Tuple[int, ...]], + contiguities: List[bool], fqns: List[str], shared_param_infos: List[SharedParamInfo], param_extensions: List[Optional[Any]], @@ -399,11 +413,15 @@ def _init_metadata( See the Attributes in the class docstring. """ assert len(param_infos) == len(shapes) + assert len(param_infos) == len(strides) + assert len(param_infos) == len(contiguities) assert len(param_infos) == len(fqns) assert len(param_infos) == len(param_extensions) self._num_params = len(param_infos) self._param_infos = param_infos self._shapes = shapes + self._strides = strides + self._contiguities = contiguities self._fqns = fqns self._param_extensions = param_extensions self._is_padding_mask = is_padding_mask @@ -638,6 +656,8 @@ def _init_flat_param_and_metadata( param_infos: List[ParamInfo] = [] numels: List[int] = [] shapes: List[torch.Size] = [] + strides: List[Tuple[int, ...]] = [] + contiguities: List[bool] = [] fqns: List[str] = [] shared_param_infos: List[SharedParamInfo] = [] shared_param_memo: Dict[ @@ -692,6 +712,8 @@ def _init_flat_param_and_metadata( param_infos.append(ParamInfo(param_name, submodule, submodule_name)) numels.append(param.numel()) shapes.append(param.shape) + strides.append(param.stride()) + contiguities.append(_is_truly_contiguous(param)) fqn = ( submodule_name + "." + param_name if submodule_name @@ -746,6 +768,8 @@ def _init_flat_param_and_metadata( param_infos, numels, shapes, + strides, + contiguities, fqns, shared_param_infos, param_extensions, @@ -828,7 +852,11 @@ def flatten_tensors( ) flat_tensors.append(padding_tensor) total_numel += numel_to_pad - flat_tensors.append(torch.flatten(_detach_if_needed(tensor))) + flat_tensors.append( + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + ) total_numel += tensor.numel() numel_to_pad = self.world_size - (total_numel % self.world_size) if numel_to_pad > 0 and numel_to_pad < self.world_size: @@ -839,7 +867,10 @@ def flatten_tensors( total_numel += numel_to_pad else: flat_tensors = [ - torch.flatten(_detach_if_needed(tensor)) for tensor in tensors + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + for tensor in tensors ] return torch.cat(flat_tensors, dim=0) @@ -986,10 +1017,10 @@ def _get_shard_metadata( sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1 # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices # into the unsharded flat parameter (inclusive) of the given parameter - for i, ( + for ( (unsharded_param_start_idx, unsharded_param_end_idx), is_padding, - ) in enumerate(zip(flat_param_offsets, self.flat_param._is_padding_mask)): + ) in zip(flat_param_offsets, self.flat_param._is_padding_mask): if is_padding: continue in_sharded_flat_param = ( @@ -1046,7 +1077,11 @@ def _get_unpadded_shard( shape (which is true in the expected usage), then this method does not allocate any new tensor memory. """ - chunks = torch.flatten(tensor).chunk(world_size) + chunks = ( + torch.flatten(tensor).chunk(world_size) + if _is_truly_contiguous(tensor) + else tensor.as_strided((tensor.numel(),), (1,)).chunk(world_size) + ) if len(chunks) < (rank + 1): # This rank gets an empty chunk fully padded with zeros since there # are not enough chunks across ranks @@ -1119,11 +1154,15 @@ def shard_metadata( """ fqns_list = [] shapes_list = [] + strides_list = [] + contiguities_list = [] numels_list = [] shard_param_offsets = [] - for fqn, shape, numel, shard_param_info in zip( + for fqn, shape, stride, contiguous, numel, shard_param_info in zip( self.flat_param._fqns, self.flat_param._shapes, + self.flat_param._strides, + self.flat_param._contiguities, self.flat_param._numels, self.flat_param._shard_param_infos, ): @@ -1131,6 +1170,8 @@ def shard_metadata( continue fqns_list.append(fqn) shapes_list.append(shape) + strides_list.append(stride) + contiguities_list.append(contiguous) numels_list.append(numel) shard_param_offsets.append( ( @@ -1141,6 +1182,8 @@ def shard_metadata( return FlatParamShardMetadata( tuple(fqns_list), tuple(shapes_list), + tuple(strides_list), + tuple(contiguities_list), tuple(numels_list), tuple(shard_param_offsets), ) @@ -1820,13 +1863,17 @@ def _get_unflat_views_unaligned( tensor = flat_param views = ( _ext_post_unflatten_transform( - subtensor.view(shape), + subtensor.view(shape) + if contiguous + else subtensor.as_strided(shape, stride), param_extension, self._fsdp_extension, ) - for (subtensor, shape, param_extension) in zip( + for (subtensor, shape, stride, contiguous, param_extension) in zip( torch.split(tensor, flat_param._numels, dim=0), flat_param._shapes, + flat_param._strides, + flat_param._contiguities, flat_param._param_extensions, ) ) @@ -1857,7 +1904,11 @@ def _get_unflat_views_aligned( continue views.append( _ext_post_unflatten_transform( - split.view(flat_param._shapes[idx]), + split.view(flat_param._shapes[idx]) + if flat_param._contiguities[idx] + else split.as_strided( + flat_param._shapes[idx], flat_param._strides[idx] + ), flat_param._param_extensions[idx], self._fsdp_extension, ) @@ -2150,8 +2201,8 @@ def _use_sharded_grad_views(self) -> None: else: param.grad = None assert flat_param._shared_params is not None - for i, (param, (_, _, _, prim_param_name, prim_module, _)) in enumerate( - zip(flat_param._shared_params, flat_param._shared_param_infos) + for param, (_, _, _, prim_param_name, prim_module, _) in zip( + flat_param._shared_params, flat_param._shared_param_infos ): in_sharded_flat_param = hasattr(prim_module, prim_param_name) if in_sharded_flat_param and param.requires_grad: @@ -2661,6 +2712,14 @@ def _convert_to_params( return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] +def _is_truly_contiguous(x: Tensor) -> bool: + # Special case: Pytorch thinks that 1x1 channels_last convolution weights are + # both contiguous and channels_last contiguous at the same time. + # CuDNN does not agree though and refuses to select faster kernels. + # It is the reason of having the extra check here. + return x.stride(-1) == 1 and x.is_contiguous() + + def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor: return ( param_or_tensor.detach() diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 35beee36ef583..df78c15105011 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -536,9 +536,7 @@ def _flatten_optim_state_dict( else: # Move the tensor in the original osd back to CPU to make the # original osd unaffected. - unflat_osd_state[fqn][state_name] = unflat_osd_state[fqn][ - state_name - ].cpu() + unflat_osd_state[fqn][state_name] = param_state.cpu() # Handle user-defined state, states that are not associated with parameters. for key in all_state_keys: @@ -1457,7 +1455,7 @@ def _unflatten_orig_param_states( # gather the tensor on its TP dimension before chunking them into DTensor again. if placement != Replicate(): placement_dim = placement.dim # type: ignore[attr-defined] - value_local = value.redistribute(placements=(Replicate(),)) + value.redistribute(placements=(Replicate(),)) reshape_size = list(flat_param._shapes[param_idx]) reshape_size[placement_dim] *= value.device_mesh.size(0) reshape_size = torch.Size(reshape_size) diff --git a/torch/distributed/fsdp/_state_dict_utils.py b/torch/distributed/fsdp/_state_dict_utils.py index f96872bfa6e7c..390d2774d958e 100644 --- a/torch/distributed/fsdp/_state_dict_utils.py +++ b/torch/distributed/fsdp/_state_dict_utils.py @@ -297,7 +297,7 @@ def _full_pre_state_dict_hook( ``nn.Module``. """ if getattr(fsdp_state, "_device_mesh", False): - root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh) + _mesh_resources.get_root_mesh(fsdp_state._device_mesh) _common_pre_state_dict_hook(module, fsdp_state) _common_unshard_pre_state_dict_hook( diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py index 476bf6a18a087..d6ca4084556f4 100644 --- a/torch/distributed/pipelining/__init__.py +++ b/torch/distributed/pipelining/__init__.py @@ -3,7 +3,6 @@ from .schedules import ( _ScheduleForwardOnly, Schedule1F1B, - ScheduleFlexibleInterleaved1F1B, ScheduleGPipe, ScheduleInterleaved1F1B, ScheduleInterleavedZeroBubble, @@ -20,7 +19,6 @@ "PipelineStage", "build_stage", "Schedule1F1B", - "ScheduleFlexibleInterleaved1F1B", "ScheduleGPipe", "ScheduleInterleaved1F1B", "ScheduleLoopedBFS", diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index 1ae9007c53be4..046c1fd1fe130 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -104,7 +104,7 @@ def get_param_groups( # but omits weights and any subgraphs connecting weights to this closure inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict) param_groups: Dict[Node, Dict[str, Set]] = dict() # keyed on intermediates - for i, param in enumerate(params): + for param in params: closure, intersected = reverse_closure( [param], inputs_closure, reverse_edges_dict ) @@ -140,16 +140,23 @@ def get_param_groups( def stage_backward_input( - stage_outputs: List[torch.Tensor], + stage_outputs_or_loss: List[torch.Tensor], output_grads: Optional[List[torch.Tensor]], input_values: List[torch.Tensor], weights: Iterator[Parameter], ): """ - compute the gradients for only the stage inputs with respect to the stage outputs + Compute the gradients for only the stage inputs with + respect to the stage outputs (if non-last stage) or loss (if last stage) + + After computing input gradients, we save the intermediate nodes in `param_groups` + for later use in stage_backward_weight. We don't need to save any other intermediate nodes + that aren't needed for dW because when we do dW calculation, we start from saved intermediates. + Detaching the stage_outputs_or_loss at the end of this function is important as + it frees up the memory that the autograd graph is anticipating to be used later (but doesn't actually need). """ stage_output_grad_fns: List[Node] = list( - filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs)) + filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss)) ) stage_input_grad_fns: List[Node] = list( filter(None, map(_get_grad_fn_or_grad_acc, input_values)) @@ -163,6 +170,7 @@ def stage_backward_input( stage_input_grad_fns, weight_grad_fns, reverse_edges_dict ) + handles = [] for param_group in param_groups: for i, intermediate in enumerate(param_group["intermediates"]): @@ -178,18 +186,19 @@ def hook(grad_inputs): # These are always "split" nodes that we need to recompute, so # save their inputs. - intermediate.register_prehook(get_hook(param_group, i)) + handle = intermediate.register_prehook(get_hook(param_group, i)) + handles.append(handle) # Stage 0 inputs do not require grads? Should we skip in that case? if all(tensor.requires_grad for tensor in input_values): if output_grads is None: # In case this is the loss and there are no output_grads, then we just use 1s output_grads = [ - torch.ones_like(stage_output) for stage_output in stage_outputs + torch.ones_like(stage_output) for stage_output in stage_outputs_or_loss ] dinputs = torch.autograd.grad( - stage_outputs, + stage_outputs_or_loss, inputs=input_values, grad_outputs=output_grads, retain_graph=True, @@ -201,8 +210,19 @@ def hook(grad_inputs): inp.grad = dinputs[i] else: inp.grad += dinputs[i] + + # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph + # this allows autograd to clear up the graph dedicated for this tensor and free up significant memory + for t in stage_outputs_or_loss: + t.detach_() + else: dinputs = None + + # hooks are no longer necessary, clean up for consistency + for handle in handles: + handle.remove() + return dinputs, param_groups @@ -241,6 +261,9 @@ def stage_backward_weight( grad_outputs=sum(param_group["grads"], tuple()), retain_graph=retain_graph, ) + # release grad memory early after use + del param_group["grads"] + for grad_acc, dw in zip(param_group["params"], dweights): weight, index = grad_acc_to_weight[grad_acc] if weight.grad is None: diff --git a/torch/distributed/pipelining/_unflatten.py b/torch/distributed/pipelining/_unflatten.py index d5aaba95ea38c..7c68eecb3bb62 100644 --- a/torch/distributed/pipelining/_unflatten.py +++ b/torch/distributed/pipelining/_unflatten.py @@ -18,7 +18,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph): seen_nodes, seen_modules, None, - [""], + [("", 0)], "", {}, module=new_module, diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index d58bd5f3de638..05219b057b159 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs # Copyright (c) Meta Platforms, Inc. and affiliates +import copy import csv import itertools import logging @@ -38,7 +39,6 @@ "PipelineScheduleSingle", "PipelineScheduleMulti", "Schedule1F1B", - "ScheduleFlexibleInterleaved1F1B", "ScheduleGPipe", "ScheduleInterleaved1F1B", "ScheduleLoopedBFS", @@ -115,7 +115,7 @@ def from_str(action): # Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) _action_regex = re.compile( - r"(\d+)([F,B,W]|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B{0,1})(\d*)" + r"(\d+)(F|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" ) @@ -158,6 +158,17 @@ def _format_pipeline_order(pipeline_order: Dict[int, List[Optional[_Action]]]) - Formats the pipeline order in a timestep (row) x rank (column) grid of actions and returns the formatted string """ + + # don't mutate the original + pipeline_order = copy.deepcopy(pipeline_order) + + # Replace None with "" + for rank in pipeline_order: + for i in range(len(pipeline_order[rank])): + if pipeline_order[rank][i] is None: + # TODO make a real 'None action' that prints as empty string and make mypy happy + pipeline_order[rank][i] = "" # type: ignore[call-overload] + # Calculate the maximum number of steps across all ranks num_steps = max(len(actions) for actions in pipeline_order.values()) step_labels = [ @@ -582,6 +593,12 @@ def __init__( self._stage.has_backward = self._has_backward self._stage_initialized = False + def _initialize_stage(self, args, kwargs): + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) + if self._has_backward: + self._stage._prepare_backward_infra(self._n_microbatches) + self._stage_initialized = True + def step(self, *args, target=None, losses: Optional[List] = None, **kwargs): """ Run one iteration of the pipeline schedule with *whole-batch* input. @@ -639,8 +656,7 @@ def _step_microbatches( arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) if not self._stage_initialized: - self._stage._prepare_forward_infra(self._n_microbatches) - self._stage_initialized = True + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) # Delay send waits fwd_sends_to_wait: List[dist.Work] = [] @@ -691,10 +707,7 @@ def _step_microbatches( arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) if not self._stage_initialized: - self._stage._prepare_forward_infra(self._n_microbatches) - if self._has_backward: - self._stage._prepare_backward_infra(self._n_microbatches) - self._stage_initialized = True + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) # Delay send waits fwd_sends_to_wait: List[dist.Work] = [] @@ -777,10 +790,7 @@ def _step_microbatches( arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) if not self._stage_initialized: - self._stage._prepare_forward_infra(self._n_microbatches) - if self._has_backward: - self._stage._prepare_backward_infra(self._n_microbatches) - self._stage_initialized = True + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) # Last stage has 1 warmup, second-to-last 2 warmups, ... # first stage `num_stages` warmups @@ -792,7 +802,6 @@ def _step_microbatches( # Chunk counters fwd_mb_index = 0 bwd_mb_index = 0 - weight_stage_mb_index = 0 # Warmup phase send_work = None @@ -1096,6 +1105,24 @@ def __init__( self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} self.use_full_backward = use_full_backward + def _initialize_stages(self, args: Tuple[Any, ...], kwargs): + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) + # or real value (if this stage and next stage are on the same device) + next_stage_args: Tuple[Any, ...] = tuple() + for stage in self._stages: + if stage.is_first: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, args, kwargs + ) + else: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, next_stage_args, kwargs + ) + + if self._has_backward: + stage._prepare_backward_infra(self._n_microbatches) + self._stages_initialized = True + def _dump_csv(self, filename): """Dump a CSV representation of the schedule into a file with the provided filename.""" with open(filename, "w", newline="") as csvfile: @@ -1229,12 +1256,7 @@ def _step_microbatches( arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) if not self._stages_initialized: - for stage in self._stages: - # TODO: why do i pass args/kwargs here? its not used? - stage._prepare_forward_infra(self._n_microbatches) - if self._has_backward: - stage._prepare_backward_infra(self._n_microbatches) - self._stages_initialized = True + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) # Based on the plan in Step 1 created in __init__: # 2. Perform communication based on the pipeline_order @@ -1463,12 +1485,7 @@ def _step_microbatches( """ arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) if not self._stages_initialized: - for stage in self._stages: - # TODO: why do i pass args/kwargs here? its not used? - stage._prepare_forward_infra(self._n_microbatches) - if self._has_backward: - stage._prepare_backward_infra(self._n_microbatches) - self._stages_initialized = True + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) # Based on the plan in Step 1 created in __init__: # 2. Perform communication based on the pipeline_order @@ -1828,6 +1845,14 @@ class ScheduleInterleaved1F1B(PipelineScheduleMulti): state and supports multiple stages per rank. When microbatches are ready for multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch (also called "depth first"). + + This schedule is mostly similar to the original paper. + It differs by being relaxing the requirement of num_microbatch % pp_size == 0. + Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and + it works as long as n_microbatches % num_rounds is 0. As a few examples, support + + 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. + 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. """ def __init__( @@ -1840,13 +1865,6 @@ def __init__( output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, ): self.pp_group_size = stages[0].group_size - # TODO: is this limitation a must? - if n_microbatches % self.pp_group_size != 0: - raise ValueError( - f"Interleaved 1F1B schedule requires the number of microbatches ({n_microbatches}) \ - to be a multiple of the number of pipeline ranks ({self.pp_group_size})." - ) - super().__init__( stages=stages, n_microbatches=n_microbatches, @@ -1855,16 +1873,20 @@ def __init__( kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, ) - self.n_local_stages = len(stages) self.rank = stages[0].group_rank - self.group = stages[0].group - + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) # 1. Create the pipeline_order (all ranks do this calculation) # This will be used to keep track of the current state of the entire pipeline # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] self.pipeline_order: Dict[int, List[Optional[_Action]]] = {} - for rank in range(self.pp_group_size): rank_ops = self._calculate_single_rank_operations(rank) self.pipeline_order[rank] = rank_ops @@ -1872,9 +1894,15 @@ def __init__( def _calculate_single_rank_operations(self, rank) -> List[Optional[_Action]]: def get_rank_warmup_ops(rank): # Warms up operations for last stage - warmups_ops_last_stage = (self.n_local_stages - 1) * self.pp_group_size + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round # Increment warmup operations by 2 for each hop away from the last stage - warmup_ops = warmups_ops_last_stage + 2 * ((self.pp_group_size - 1) - rank) + multiply_factor = 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + # We cannot have more warmup operations than there are number of microbatches, so cap it there return min(warmup_ops, self._n_microbatches * self.n_local_stages) @@ -1887,7 +1915,6 @@ def get_rank_warmup_ops(rank): # total ops encompass both forward and backward ops total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 - logger.debug( "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", rank, @@ -1900,14 +1927,15 @@ def get_rank_warmup_ops(rank): # Calculates the stage index based on step and pp_group_size def forward_stage_index(step): # Get the local index from 0 to n_local_stages-1 - local_index = (step // self.pp_group_size) % self.n_local_stages + local_index = (step // self.microbatches_per_round) % self.n_local_stages return (local_index * self.pp_group_size) + rank def backward_stage_index(step): local_index = ( self.n_local_stages - 1 - - ((step - warmup_ops) // self.pp_group_size) % self.n_local_stages + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages ) return (local_index * self.pp_group_size) + rank @@ -1923,19 +1951,15 @@ def backward_stage_index(step): ) -class ScheduleFlexibleInterleaved1F1B(PipelineScheduleMulti): +class ScheduleInterleavedZeroBubble(PipelineScheduleMulti): """ - The Flexible Interleaved 1F1B schedule. - - This schedule is mostly similar to the interleaved 1F1B schedule. - It differs by being relaxing the requirement of num_microbatch % pp_size == 0. - Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and - it works as long as n_microbatches % num_rounds is 0. As a few examples, support - - 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. - 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. + The Interleaved Zero Bubble schedule. + See https://arxiv.org/pdf/2401.10241 for details. + Will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses the backward for weights to fill in + the pipeline bubble. - When enable_zero_bubble is True, we will use the ZB1P schedule in https://openreview.net/pdf?id=tuzTN0eIO5 + In particular this is implementing the ZB1P schedule in the paper. """ def __init__( @@ -1946,7 +1970,6 @@ def __init__( args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, - enable_zero_bubble: bool = False, ): self.pp_group_size = stages[0].group_size super().__init__( @@ -1956,16 +1979,15 @@ def __init__( args_chunk_spec=args_chunk_spec, kwargs_chunk_spec=kwargs_chunk_spec, output_merge_spec=output_merge_spec, - use_full_backward=not enable_zero_bubble, + use_full_backward=False, ) self.n_local_stages = len(stages) self.rank = stages[0].group_rank self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) self.microbatches_per_round = n_microbatches // self.number_of_rounds - self.enable_zero_bubble = enable_zero_bubble if n_microbatches % self.number_of_rounds != 0: raise ValueError( - "Flexible Interleaved 1F1B requires the number of microbatches to be a " + "Zero bubble requires the number of microbatches to be a " f"multiple of the number of rounds ({self.number_of_rounds}), " f"but got {n_microbatches}." ) @@ -1991,7 +2013,7 @@ def get_rank_warmup_ops(rank): self.n_local_stages - 1 ) * self.microbatches_per_round # Increment warmup operations by 2 for each hop away from the last stage - multiply_factor = 1 if self.enable_zero_bubble else 2 + multiply_factor = 1 warmup_ops = warmups_ops_last_stage + multiply_factor * ( (self.pp_group_size - 1) - rank ) @@ -2033,21 +2055,7 @@ def backward_stage_index(step): ) return (local_index * self.pp_group_size) + rank - if self.enable_zero_bubble: - num_1f1b_microbatches = rank - - return _get_1f1b_rank_ops( - self.n_local_stages, - self.pp_group_size, - warmup_ops, - fwd_bwd_ops, - cooldown_ops, - rank, - forward_stage_index, - backward_stage_index, - num_1f1b_microbatches, - enable_zero_bubble=True, - ) + num_1f1b_microbatches = rank return _get_1f1b_rank_ops( self.n_local_stages, @@ -2058,12 +2066,12 @@ def backward_stage_index(step): rank, forward_stage_index, backward_stage_index, + num_1f1b_microbatches, + enable_zero_bubble=True, ) def _add_bubbles_to_actions(self, num_stages_global): actions = self.pipeline_order - if not self.enable_zero_bubble: - return actions def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): if op == _ComputationType.FORWARD: @@ -2129,35 +2137,6 @@ def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): return result -class ScheduleInterleavedZeroBubble(ScheduleFlexibleInterleaved1F1B): - """ - The Interleaved Zero Bubble schedule. - See https://arxiv.org/pdf/2401.10241 for details. - Will perform one forward and one backward on inputs for the microbatches in steady - state and supports multiple stages per rank. Uses the backward for weights to fill in - the pipeline bubble. - """ - - def __init__( - self, - stages: List[_PipelineStageBase], - n_microbatches: int, - loss_fn: Optional[Callable] = None, - args_chunk_spec: Optional[Tuple[TensorChunkSpec, ...]] = None, - kwargs_chunk_spec: Optional[Dict[str, TensorChunkSpec]] = None, - output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None, - ): - super().__init__( - stages=stages, - n_microbatches=n_microbatches, - loss_fn=loss_fn, - args_chunk_spec=args_chunk_spec, - kwargs_chunk_spec=kwargs_chunk_spec, - output_merge_spec=output_merge_spec, - enable_zero_bubble=True, - ) - - def get_schedule_class(schedule_name: str): """ Maps a schedule name (case insensitive) to its corresponding class object. @@ -2169,7 +2148,6 @@ def get_schedule_class(schedule_name: str): "1F1B": Schedule1F1B, "Interleaved1F1B": ScheduleInterleaved1F1B, "GPipe": ScheduleGPipe, - "FlexibleInterleaved1F1B": ScheduleFlexibleInterleaved1F1B, "LoopedBFS": ScheduleLoopedBFS, "InterleavedZeroBubble": ScheduleInterleavedZeroBubble, "PipelineScheduleSingle": PipelineScheduleSingle, diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index f87eabb39565f..7ea111c92e969 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -254,7 +254,9 @@ def map_recv_to_send(a): def _prepare_forward_infra( self, num_microbatches: int, - ): + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Any, ...]: raise NotImplementedError def _prepare_backward_infra(self, num_microbatches: int): @@ -693,6 +695,15 @@ def backward_one_chunk( self.grads_input = grads_input # Save a placeholder for the dw_runner self.dw_runner[bwd_chunk_id] = lambda: None + + if self.is_last: + # Autograd dependencies: + # rest_of_autograd_graph -> stage_output -> loss + # stage_output is no longer used in the last stage for backward and only needed + # to return to the user in merge_output_chunks, therefore + # this should be detached to release autograd graph context and free memory earlier + for t in stage_output: + t.detach_() logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) def backward_weight_one_chunk(self, bwd_chunk_id: int): @@ -717,7 +728,7 @@ def backward_weight_one_chunk(self, bwd_chunk_id: int): "param_groups": param_groups, "full_backward": False, } - weight_grads, _ = self.backward_maybe_with_nosync("weight", bwd_kwargs) + self.backward_maybe_with_nosync("weight", bwd_kwargs) else: # TODO: figure out a better way to do this: # if inputs does not require gradient, @@ -848,15 +859,21 @@ def _move_submod_to_device(self): def _prepare_forward_infra( self, num_microbatches: int, - ): + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Any, ...]: """ Create send/recv infrastructures for activations (during forward) """ + # TODO(whc) + # this method should be deleted once lazy buffer allocation is implemented + # for now, it ignores args/kwargs becuase it should not need to do shape inference for chunk in range(num_microbatches): self.args_recv_info[chunk] = self._create_act_recv_info() # Send info during forward for each activation self.act_send_info = self._create_act_send_info() + return tuple() def get_stage_index_of_submod( self, @@ -1234,10 +1251,13 @@ def _get_stage_shapes( class PipelineStage(_PipelineStageBase): """ A class representing a pipeline stage in a pipeline parallelism setup. - This class is created manually by providing a example input (and optionally output) - as opposed to the PipelineStage class that is outputed from pipeline(). - This class extends the `_PipelineStageBase` class and can similarly be used - in `PipelineScheule`. + + PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from + one chunk feed into inputs of the next chunk, with no skip connections. + + PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to + stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each + PipelineStage instance. Args: submodule (nn.Module): The PyTorch module wrapped by this stage. @@ -1256,33 +1276,49 @@ def __init__( stage_index: int, num_stages: int, device: torch.device, - input_args: Union[torch.Tensor, Tuple[torch.Tensor, ...]], + input_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None, output_args: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None, group: Optional[dist.ProcessGroup] = None, dw_builder: Optional[Callable[[], Callable[..., None]]] = None, ): super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) self.inputs: Optional[List[torch.Tensor]] = None - + self.inputs_meta: Optional[Tuple[torch.Tensor, ...]] = None # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) becuase it # might be breaking for existing users. - self.inputs_meta = ( - (input_args,) if isinstance(input_args, torch.Tensor) else input_args - ) - if output_args is None: - try: - output_args = submodule(*self.inputs_meta) - output_args = tree_map_only( - torch.Tensor, lambda x: x.to("meta"), output_args + if input_args is None: + assert output_args is None, ( + "If specifying output_args, input_args must also be specified. " + "Otherwise, shape inference will be performed at runtime" + ) + else: + self.inputs_meta = ( + (input_args,) if isinstance(input_args, torch.Tensor) else input_args + ) + if output_args is None: + logger.warning( + "Deprecation warning: passing input_args and performing init-time shape inference is deprecated. " + "PipelineStage now supports runtime shape inference using the real inputs provided to schedule step(). " + "Either delete `input_args` arg to `PipelineStage` to opt-into runtime shape inference, " + "or additionally pass `output_args` to `PipelineStage` to fully override shape inference. " ) - except Exception as e: - raise RuntimeError( - "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" - ) from e - assert output_args is not None # for mypy - self._configure_outputs_meta( - (output_args,) if isinstance(output_args, torch.Tensor) else output_args - ) + try: + with torch.no_grad(): + output_args = submodule(*self.inputs_meta) + output_args = tree_map_only( + torch.Tensor, lambda x: x.to("meta"), output_args + ) + except Exception as e: + raise RuntimeError( + "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" + ) from e + assert ( + output_args is not None + ), "If passing input_args, also pass output_args to override shape inference" + self._configure_outputs_meta( + (output_args,) if isinstance(output_args, torch.Tensor) else output_args + ) + # these are the buffers used in backwards send/recv, they are allocated later self.outputs_grad: List[torch.Tensor] = [] @@ -1293,22 +1329,132 @@ def stage_global_rank(peer_rank): else dist.get_global_rank(self.group, peer_rank) ) - self.prev_stage = stage_global_rank((self.group_rank - 1) % self.group_size) - self.next_stage = stage_global_rank((self.group_rank + 1) % self.group_size) + self.prev_rank = stage_global_rank((self.group_rank - 1) % self.group_size) + self.next_rank = stage_global_rank((self.group_rank + 1) % self.group_size) - logger.debug( - f"finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + dbg_str = ( + f"Finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 f"{self.is_last=}, {self.num_stages=}, " - f"inputs: {[inp.shape for inp in self.inputs_meta]}, " - f"output: {[output.shape for output in self.get_outputs_meta()]}" + ) + if self.inputs_meta is not None: + dbg_str += ( + f"inputs: {[inp.shape for inp in self.inputs_meta]}, " + f"output: {[output.shape for output in self.get_outputs_meta()]}" + ) + else: + dbg_str += " running shape-inference at runtime" + + logger.debug(dbg_str) + + def _shape_inference( + self, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ): + if kwargs is None: + kwargs = {} + assert args is not None, "Args may be an empty tuple but not None" + + # We skip recv communication if we're the first stage, but also if the previous stage is on the same rank + # and can pass its output shapes in as args instead of using send/recv. + if ( + self.is_first + # if not first stage, then check if prev stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping recv, because shape info passed in via `args`", + self.stage_index, + ) + args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args) + else: + assert ( + len(args) == 0 + ), "Can't supply input args for shape inference on non-first stage" + objects = [None] + logger.debug( + "Shape inference: stage %s receiving from stage %s", + self.stage_index, + self.stage_index - 1, + ) + dist.recv_object_list( + objects, src=self.prev_rank, group=self.group, device=self.device + ) + recv_args = objects[0] + assert isinstance(recv_args, tuple), type(recv_args) + args = recv_args + + # cache input shapes for use during recv buffer allocation + self.inputs_meta = args + args = tree_map_only( + torch.Tensor, lambda x: torch.zeros_like(x, device=self.device), args ) + # set attributes needed for forward + with torch.no_grad(): + logger.debug("Shape inference: stage %s running forward", self.stage_index) + outputs = self.submod(*args, **kwargs) + + # if single tensor, convert so it is always a list + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + + # communicate meta outputs not real outputs for two reasons + # 1 - its faster (esp. since obj coll pickles tensor data!) + # 2 - avoid activating a cuda context for the src rank when unpickling on the recv end! + outputs_meta = tuple( + tree_map_only(torch.Tensor, lambda x: x.to("meta"), outputs) + ) + self._configure_outputs_meta(outputs_meta) + + # Passing outputs to the next stage: + # two cases- + # 1. Usually: use send/recv communication to pass the output + # 2. Special case: for V-schedules, 2 'adjacent' stages (e.g. stage 3, 4 in an 8-stage 4-rank V) + # pass their shape info via return value and function args rather than send/recv. + if ( + self.is_last + # if not last stage, then check if next stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank + ): + # Case (2) above: pass shape info via return value and caller passes it as args to next stage's + # _shape_inference call + logger.debug( + "Shape inference: stage %s skipping send to next stage", + self.stage_index, + ) + + else: + # Case (1): send shapes via send operation, and ensure not to return it to the caller + logger.debug( + "Shape inference: stage %s sending to stage %s", + self.stage_index, + self.stage_index + 1, + ) + dist.send_object_list( + [outputs_meta], + dst=self.next_rank, + group=self.group, + device=self.device, + ) + outputs_meta = tuple() + + return outputs_meta + def _prepare_forward_infra( self, num_microbatches: int, - ) -> None: + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[Any, ...]: # TODO move self.device to an argument from step API (from its input tensors)? + assert num_microbatches is not None, "TODO fix num_microbatches" + + outputs: Tuple[Any, ...] = tuple() + if self.inputs_meta is None: + outputs = self._shape_inference(args, kwargs) + assert self.inputs_meta is not None # Receive info during forward # TODO: create args_recv_info lazily? (same needed for PipelineStage) for chunk_id in range(num_microbatches): @@ -1339,8 +1485,6 @@ def _prepare_forward_infra( # only need the rank that is being sent to self.act_send_info: Dict[int, List] = {} - # TODO: we didn't require output args at __init__ before, but now we do. enforce it. until we enable lazy-init - # get_outputs_meta will assert for us for idx in range(len(self.get_outputs_meta())): # We assume we always send to stage + 1 if not self.is_last: @@ -1348,6 +1492,8 @@ def _prepare_forward_infra( else: self.act_send_info[idx] = [] + return outputs + def _create_grad_recv_info( self, act_send_info: Dict, @@ -1382,15 +1528,15 @@ def _init_p2p_neighbors(self): send_tensor = torch.ones(1, device="cuda") # forward if not self.is_first: - ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_stage, self.group)) + ops.append(dist.P2POp(dist.irecv, recv_tensor, self.prev_rank, self.group)) if not self.is_last: - ops.append(dist.P2POp(dist.isend, send_tensor, self.next_stage, self.group)) + ops.append(dist.P2POp(dist.isend, send_tensor, self.next_rank, self.group)) # backward if not self.is_first: - ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_stage, self.group)) + ops.append(dist.P2POp(dist.isend, send_tensor, self.prev_rank, self.group)) if not self.is_last: - ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_stage, self.group)) + ops.append(dist.P2POp(dist.irecv, recv_tensor, self.next_rank, self.group)) return True @@ -1466,13 +1612,13 @@ def _validate_stage_shapes(pipeline_stages: List[PipelineStage]): ] logger.debug( - f"Rank: {pg_rank}" # noqa: G004 - f"Stage id: {stage_id}" - f"Stage num stages: {stage.num_stages}" - f"Stage rank: {rank}" - f"Stage world size: {world_size}" - f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}" # noqa: G003 - f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}" # noqa: G003 + f"Rank: {pg_rank}", # noqa: G004 + f"Stage id: {stage_id}", + f"Stage num stages: {stage.num_stages}", + f"Stage rank: {rank}", + f"Stage world size: {world_size}", + f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} input shapes: {stage_input_shapes}", # noqa: G003 + f"Stage {virtual_id * world_size}-{(virtual_id + 1) * world_size - 1} output shapes: {stage_output_shapes}", # noqa: G003 ) all_inputs.extend(stage_input_shapes) diff --git a/torch/distributed/tensor/README.md b/torch/distributed/tensor/README.md index 2fedb7cc3b426..3cfe16910853f 100644 --- a/torch/distributed/tensor/README.md +++ b/torch/distributed/tensor/README.md @@ -10,7 +10,7 @@ We propose distributed tensor primitives to allow easier distributed computation # torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py import os import torch -from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor +from torch.distributed.tensor import init_device_mesh, Shard, distribute_tensor # Create a mesh topology with the available devices: # 1. We can directly create the mesh using elastic launcher, (recommended) @@ -54,7 +54,7 @@ Here are some basic DTensor API examples that showcase: ```python # torchrun --standalone --nnodes=1 --nproc-per-node=4 dtensor_example.py import torch -from torch.distributed._tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh +from torch.distributed.tensor import DTensor, Shard, Replicate, distribute_tensor, distribute_module, init_device_mesh # construct a device mesh with available devices (multi-host or single host) device_mesh = init_device_mesh("cuda", (4,)) @@ -114,7 +114,7 @@ def distribute_module( ```python import torch.nn as nn -from torch.distributed._tensor import Shard, distribute_tensor, distribute_module, init_device_mesh +from torch.distributed.tensor import Shard, distribute_tensor, distribute_module, init_device_mesh class MyModule(nn.Module): def __init__(self) -> None: diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 4579a16826d0f..4383918ca35c2 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -77,7 +77,7 @@ def found_inf_reduce_handler( cast(List[object], op_info.local_args), op_info.args_tree_spec ) local_tensor_args = cast(Tuple[object, ...], local_tensor_args) - local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + op_call(*local_tensor_args, **op_info.local_kwargs) grad_dtensor = cast(list[dtensor.DTensor], args[0])[0] grad_placements = grad_dtensor.placements diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 059dd04bd2f4d..dbc0864f9c974 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -1,4 +1,5 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +import string from typing import cast, Dict, List, Optional, Tuple import torch @@ -234,7 +235,7 @@ def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputShardi ij,ij->ij - addition/mul ij,j->ij - broadcasted addition """ - alphabet = "abcdefghijklmnopqrstuvwxyz" + alphabet = string.ascii_lowercase # find the max_dim first in case we need to broadcasting input_specs = op_schema.args_spec max_dim = max(input.ndim for input in input_specs) diff --git a/torch/distributed/tensor/_ops/_conv_ops.py b/torch/distributed/tensor/_ops/_conv_ops.py index db2a8136e14da..f6e98fcf7a774 100644 --- a/torch/distributed/tensor/_ops/_conv_ops.py +++ b/torch/distributed/tensor/_ops/_conv_ops.py @@ -21,9 +21,9 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding: stride, padding, dilation, - transposed, - output_padding, - groups, + _transposed, + _output_padding, + _groups, ) = op_schema.args_schema assert isinstance(input_spec, DTensorSpec) @@ -37,7 +37,7 @@ def convolution_rules(op_schema: OpSchema) -> OutputSharding: assert isinstance(padding, List) assert isinstance(dilation, List) assert isinstance(weight_shape, torch.Size) - N, C_in, H_in, W_in = in_shape[0], in_shape[1], in_shape[2], in_shape[3] + N, H_in, W_in = in_shape[0], in_shape[2], in_shape[3] C_out = weight_shape[0] H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[ 0 @@ -73,13 +73,13 @@ def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: input_spec, weight_spec, bias_shape_opt, - stride, - padding, - dilation, - transposed, - output_padding, - groups, - output_mask, + _stride, + _padding, + _dilation, + _transposed, + _output_padding, + _groups, + _output_mask, ) = op_schema.args_schema assert isinstance(grad_output_spec, DTensorSpec) diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index eb3b651ac0d03..e6d6cc4909567 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -479,7 +479,7 @@ def softmax_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim) output_strategy = OpStrategy([]) - for idx, input_placement_strategy in enumerate(input_strategy.strategies): + for input_placement_strategy in input_strategy.strategies: redistribute_costs = [] input_src_spec = input_placement_strategy.output_spec @@ -1038,8 +1038,6 @@ def _add_target_input_spec(strategy) -> DTensorSpec: ) def topk_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> OpStrategy: input_strategy = cast(OpStrategy, op_schema.args_schema[0]) - k = cast(int, op_schema.args_schema[1]) - input_shape = input_strategy.shape topk_dim = ( cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 ) diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index fd9a7a430a70e..845664e82f19e 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -171,7 +171,6 @@ def scaled_dot_product_flash_attention_strategy( q_input_strategy = op_schema.args_schema[0] assert isinstance(q_input_strategy, OpStrategy) # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape single_mesh_dim_strategies = [] @@ -250,7 +249,6 @@ def scaled_dot_product_flash_attention_backward_strategy( q_input_strategy = op_schema.args_schema[1] assert isinstance(q_input_strategy, OpStrategy) # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape tensor_input_indices = [ i @@ -344,7 +342,7 @@ def scaled_dot_product_efficient_attention_strategy( q_input_strategy = op_schema.args_schema[0] assert isinstance(q_input_strategy, OpStrategy) # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape + has_attn_bias = op_schema.args_schema[3] is not None compute_log_sumexp = op_schema.args_schema[4] @@ -418,15 +416,8 @@ def scaled_dot_product_efficient_attention_backward_strategy( q_input_strategy = op_schema.args_schema[1] assert isinstance(q_input_strategy, OpStrategy) # assuming q/k/v have the same shape - qkv_shape = q_input_strategy.shape has_attn_bias = op_schema.args_schema[4] is not None - tensor_input_indices = [ - i - for i, arg_spec in enumerate(op_schema.args_schema) - if isinstance(arg_spec, OpStrategy) - ] - single_mesh_dim_strategies = [] # placement list stores placements of [outputs, inputs] diff --git a/torch/distributed/tensor/_ops/_tensor_ops.py b/torch/distributed/tensor/_ops/_tensor_ops.py index e9bcb3b0d1224..76f7d730c37e7 100644 --- a/torch/distributed/tensor/_ops/_tensor_ops.py +++ b/torch/distributed/tensor/_ops/_tensor_ops.py @@ -367,7 +367,6 @@ def replica_only_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType schema_info=RuntimeSchemaInfo(1), ) def scatter_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType: - input_strategy = cast(OpStrategy, op_schema.args_schema[0]) single_mesh_dim_strategies = [] # placement list stores placements of [output, input, index, src] diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 88414081a1785..3a3051c817fab 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -67,7 +67,7 @@ def _gen_transform_infos_non_cached( # Handle multi-dim device mesh placement redistribution # First, we need to build the logical shape for each mesh dim # for correct allgathering uneven shards on each mesh dim (with dynamic padding) - for i, (src, dst) in enumerate(zip(src_spec.placements, dst_spec.placements)): + for i, src in enumerate(src_spec.placements): current_logical_shape = mesh_dims_to_logical_shape[i] if isinstance(src, Shard): if i < device_mesh.ndim - 1: @@ -192,7 +192,7 @@ def redistribute_local_tensor( for transform_info in transform_infos: i = transform_info.mesh_dim current, target = transform_info.src_dst_placements - num_chunks = device_mesh.size(mesh_dim=i) + device_mesh.size(mesh_dim=i) if current == target: # short cut, just use the original local tensor @@ -220,7 +220,6 @@ def redistribute_local_tensor( elif target.is_shard(): # Case 2: target is Shard target_placement = cast(Shard, target) - target_dim = target_placement.dim if current.is_partial(): partial_spec = cast(Partial, current) new_local_tensor = partial_spec._reduce_shard_value( diff --git a/torch/distributed/tensor/_tp_conv.py b/torch/distributed/tensor/_tp_conv.py index ac11ef2162cbb..5ebb66b740f92 100644 --- a/torch/distributed/tensor/_tp_conv.py +++ b/torch/distributed/tensor/_tp_conv.py @@ -192,7 +192,6 @@ def tp_convolution_backward( ) # step2 reconstruct local gradient output tensor - N, C_out, H_out, _ = grad_out_tensor.shape padding_w = padding[1] if rank == 0: grad_out_tensor = torch.nn.functional.pad( diff --git a/torch/distributed/tensor/examples/comm_mode_features_example.py b/torch/distributed/tensor/examples/comm_mode_features_example.py index 9814397314533..c6c8cc7944761 100644 --- a/torch/distributed/tensor/examples/comm_mode_features_example.py +++ b/torch/distributed/tensor/examples/comm_mode_features_example.py @@ -269,7 +269,7 @@ def example_transformer_module_tracing(self) -> None: comm_mode = CommDebugMode() with comm_mode: - output = model(inp) + model(inp) # print the module level collective tracing information print(comm_mode.generate_comm_debug_tracing_table(noise_level=0)) @@ -592,7 +592,7 @@ def example_transformer_operation_tracing( comm_mode = CommDebugMode() with comm_mode: - output = model(inp) + model(inp) # print the operation level collective tracing information print(comm_mode.generate_comm_debug_tracing_table(noise_level=2)) @@ -628,7 +628,7 @@ def example_transformer_json_dump(self, is_seq_parallel: bool = False) -> None: comm_mode = CommDebugMode() with comm_mode: - output = model(inp) + model(inp) comm_mode.generate_json_dump(file_name="transformer_log.json", noise_level=1) comm_mode.generate_json_dump(file_name="transformer_log_2.json", noise_level=2) diff --git a/torch/distributed/tensor/examples/convnext_example.py b/torch/distributed/tensor/examples/convnext_example.py index 57d7bca8cc08b..ec035644f0d54 100644 --- a/torch/distributed/tensor/examples/convnext_example.py +++ b/torch/distributed/tensor/examples/convnext_example.py @@ -220,7 +220,7 @@ def train_convnext_example(): forward_time = 0.0 backward_time = 0.0 start = time.time() - for i in range(ITER_TIME): + for _ in range(ITER_TIME): t1 = time.time() y = model(x) torch.cuda.synchronize() diff --git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py index 9e6f4054e292b..fc7335b53f4e4 100644 --- a/torch/distributed/tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -130,7 +130,6 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size): # manually create the embedding table's local shards num_embeddings = 8 embedding_dim = 16 - emb_table_shape = torch.Size([num_embeddings, embedding_dim]) # tensor shape local_shard_shape = torch.Size( [num_embeddings // world_size, embedding_dim] # (local_rows, local_cols) @@ -270,7 +269,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size): device = torch.device(device_type) # note: without initializing this mesh, the following local_tensor will be put on # device cuda:0. - device_mesh = init_device_mesh(device_type=device_type, mesh_shape=(world_size,)) + init_device_mesh(device_type=device_type, mesh_shape=(world_size,)) # manually create the embedding table's local shards num_embeddings = 8 @@ -293,8 +292,6 @@ def run_torchrec_table_wise_sharding_example(rank, world_size): else torch.empty(0, device=device) ) table_to_local_tensor[i] = local_tensor - # tensor shape - local_shard_shape = local_tensor.shape # tensor offset local_shard_offset = torch.Size((0, 0)) # wrap local shards into a wrapper diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 101874796d963..8b967f877c3da 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -7,7 +7,7 @@ import weakref from abc import ABC, abstractmethod from dataclasses import dataclass -from enum import Enum +from enum import auto, Enum from typing import ( Any, Callable, @@ -33,6 +33,18 @@ __all__ = ["context_parallel"] + +class _CausalBehavior(Enum): + SKIP = None + NOT_IS_CAUSAL = False + IS_CAUSAL = True + + +class _RotateMethod(Enum): + ALL_TO_ALL = auto() + ALL_GATHER = auto() + + aten = torch.ops.aten logger = logging.getLogger(__name__) @@ -44,17 +56,12 @@ class _ContextParallelOptions: # for the experimental purpose. convert_to_f32: bool = True enable_load_balance = True + rotate_method: _RotateMethod = _RotateMethod.ALL_GATHER _cp_options = _ContextParallelOptions() -class _CausalBehavior(Enum): - SKIP = None - NOT_IS_CAUSAL = False - IS_CAUSAL = True - - def _is_causal_behavior( rank: int, world_size: int, i: int, is_causal: bool ) -> _CausalBehavior: @@ -258,6 +265,83 @@ def __call__( ... +class _RingRotater(ABC): + @abstractmethod + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + ... + + @abstractmethod + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + ... + + @abstractmethod + def next_buffer(self) -> torch.Tensor: + ... + + +class _AllToAllRotater(_RingRotater): + """Use all_to_all to send the kv to the next rank""" + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._buffer: Optional[torch.Tensor] = None + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + curr_buffer = curr_buffer.contiguous() + size = dist.get_world_size(self._pg) + dsts = list(range(1, size)) + [0] + self._buffer = ft_c.permute_tensor(curr_buffer, dsts, self._pg) + + def next_buffer(self) -> torch.Tensor: + assert self._buffer is not None + return _maybe_wait(self._buffer) + + +class _AllGatherRotater(_RingRotater): + """ + Allgather the kv and return the only the requried kv. + Only one communication will be done. + """ + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._aggregated_buffer: Optional[torch.Tensor] = None + self._idx = 0 + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + # We only need to perform the allgather once. + self._idx += 1 + if self._aggregated_buffer is None: + self._aggregated_buffer = ft_c.all_gather_tensor( + curr_buffer.contiguous(), gather_dim=0, group=self._pg + ) + + def next_buffer(self) -> torch.Tensor: + size = dist.get_world_size(self._pg) + rank = dist.get_rank(self._pg) + idx = rank - self._idx + + assert self._aggregated_buffer is not None + self._aggregated_buffer = _maybe_wait(self._aggregated_buffer) + return self._aggregated_buffer.chunk(dist.get_world_size(self._pg))[idx] + + +def _create_rotater( + pg: dist.ProcessGroup, seq_dim: int, method: Optional[_RotateMethod] = None +) -> _RingRotater: + if method is None: + method = _cp_options.rotate_method + + if method == _RotateMethod.ALL_TO_ALL: + return _AllToAllRotater(pg, seq_dim) + elif method == _RotateMethod.ALL_GATHER: + return _AllGatherRotater(pg, seq_dim) + else: + raise NotImplementedError(f"Unkonwn method {method}") + + def _ring_rotate( block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool ) -> torch.Tensor: @@ -382,17 +466,19 @@ def _templated_ring_attention( out: torch.Tensor logsumexp: torch.Tensor + rotater = _create_rotater(pg, 2) + for i in range(size): - if next_kv is not None: + if i > 0: # Wait for the kv from the (cp_rank - 1) rank. - next_kv = _maybe_wait(next_kv) + next_kv = rotater.next_buffer() key = next_kv[: key.numel()].reshape(key.shape) value = next_kv[key.numel() :].reshape(value.shape) if i < (size - 1): # Send the k, v to the next rank next_kv = torch.cat([key.flatten(), value.flatten()]) - next_kv = _ring_rotate(next_kv, pg, send_to_next=True) + next_kv = rotater.exchange_buffers(next_kv) is_causal_behavior = _is_causal_behavior( rank=rank, world_size=size, i=i, is_causal=is_causal @@ -546,10 +632,12 @@ def _templated_ring_attention_backward( key = key.contiguous() value = value.contiguous() + kv_rotater = _create_rotater(pg, 2) + dkv_rotater = _create_rotater(pg, 2, method=_RotateMethod.ALL_TO_ALL) for i in range(size): - if next_kv is not None: + if i > 0: # Wait for the kv from the (cp_rank - 1) rank. - buffer = _maybe_wait(next_kv) + buffer = kv_rotater.next_buffer() pointer = 0 key = buffer[pointer : pointer + key.numel()].reshape(key.shape) pointer += key.numel() @@ -559,7 +647,7 @@ def _templated_ring_attention_backward( if i != size - 1: # Send the kv to the next rank. next_kv = torch.cat([key.flatten(), value.flatten()]) - next_kv = _ring_rotate(next_kv, pg, send_to_next=True) + kv_rotater.exchange_buffers(next_kv) is_causal_behavior = _is_causal_behavior( rank=rank, world_size=size, i=i, is_causal=is_causal @@ -619,9 +707,8 @@ def _templated_ring_attention_backward( grad_value += grad_value_ else: pointer = 0 - assert next_grad_kv is not None # Wait for the kv gradient from (cp_rank - 1) rank. - next_grad_kv = _maybe_wait(next_grad_kv) + next_grad_kv = dkv_rotater.next_buffer() grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape( grad_key.shape ) @@ -653,7 +740,7 @@ def _templated_ring_attention_backward( next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()]) # Send the grad key, and grad value to the next rank. - next_grad_kv = _ring_rotate(next_grad_kv, pg, send_to_next=True) + dkv_rotater.exchange_buffers(next_grad_kv) if i <= rank or not _cp_options.enable_load_balance: grad_query += grad_query_ @@ -667,11 +754,10 @@ def _templated_ring_attention_backward( add=True, ) - assert next_grad_kv is not None assert grad_key_ is not None assert grad_value_ is not None grad_query = grad_query.to(query.dtype) - next_grad_kv = _maybe_wait(next_grad_kv).to(key.dtype) + next_grad_kv = dkv_rotater.next_buffer().to(key.dtype) grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape) grad_value = next_grad_kv[grad_value.numel() :].reshape(grad_value.shape) return ( @@ -1087,7 +1173,6 @@ def unshard( ), "The current implementation only works if ROUND_ROBIN_CYCLE is 2." buffer = buffer.contiguous() cp_world_size = mesh.size() - cp_rank = mesh.get_local_rank() all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)] ft_c.all_gather_inplace(all_buffers, buffer, mesh) diff --git a/torch/distributed/tensor/parallel/api.py b/torch/distributed/tensor/parallel/api.py index db4db018cce2d..19aa9b60a2c17 100644 --- a/torch/distributed/tensor/parallel/api.py +++ b/torch/distributed/tensor/parallel/api.py @@ -1,11 +1,12 @@ # Copyright (c) Meta Platforms, Inc. and affiliates +import warnings from fnmatch import fnmatch -from typing import Dict, Union +from typing import Dict, Optional, Union import torch import torch.distributed.tensor._random as random import torch.nn as nn -from torch.distributed.tensor import DeviceMesh +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.tensor._random import ( is_rng_supported_mesh, TensorParallelRNGTracker, @@ -21,8 +22,8 @@ def parallelize_module( # type: ignore[return] module: nn.Module, - device_mesh: DeviceMesh, - parallelize_plan: Union[ParallelStyle, Dict[str, ParallelStyle]], + device_mesh: Optional[DeviceMesh] = None, + parallelize_plan: Optional[Union[ParallelStyle, Dict[str, ParallelStyle]]] = None, ) -> nn.Module: """ Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan. @@ -39,14 +40,15 @@ def parallelize_module( # type: ignore[return] Args: module (:class:`nn.Module`): Module to be parallelized. - device_mesh (:class:`DeviceMesh`): - Object which describes the mesh topology - of devices for the DTensor. - parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]]): + device_mesh (:class:`DeviceMesh`, optional): + Object which describes the mesh topology of devices for the DTensor. + If not specified, the call must be under a DeviceMesh context. + parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional): The plan used to parallelize the module. It can be either a - :class:`ParallelStyle` object which contains how - we prepare input/output for Tensor Parallelism or it can be a - dict of module FQN and its corresponding :class:`ParallelStyle` object. + :class:`ParallelStyle` object which contains how we prepare + input/output for Tensor Parallelism or it can be a dict of module + FQN and its corresponding :class:`ParallelStyle` object. If not + specified, the call will do nothing at the moment. Return: A :class:`nn.Module` object parallelized. @@ -67,8 +69,16 @@ def parallelize_module( # type: ignore[return] """ torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") + device_mesh = device_mesh or _mesh_resources.get_current_mesh() _validate_tp_mesh_dim(device_mesh) + if parallelize_plan is None: + warnings.warn( + "No parallelize_plan is provided and auto-parallel is not supported " + "at the moment, so this parallelize_module call will do nothing." + ) + return module + # instantiate a TP RNG state tracker if it's not there if is_rng_supported_mesh(device_mesh) and not isinstance( random._rng_tracker, TensorParallelRNGTracker diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index 99f1e3ad6ef9a..693e80ed7adbd 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -279,7 +279,6 @@ def _nll_loss_forward_handler( ignore_index = cast(int, args[4]) channel_dim = 1 if x.dim() >= 2 else 0 - channel_dim_size = x.shape[channel_dim] spec = x._spec mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) diff --git a/torch/distributions/kumaraswamy.py b/torch/distributions/kumaraswamy.py index 367e5d52e44a2..b55e1c67e4ddb 100644 --- a/torch/distributions/kumaraswamy.py +++ b/torch/distributions/kumaraswamy.py @@ -48,7 +48,6 @@ def __init__(self, concentration1, concentration0, validate_args=None): self.concentration1, self.concentration0 = broadcast_all( concentration1, concentration0 ) - finfo = torch.finfo(self.concentration0.dtype) base_dist = Uniform( torch.full_like(self.concentration0, 0), torch.full_like(self.concentration0, 1), diff --git a/torch/distributions/wishart.py b/torch/distributions/wishart.py index e0a00f0cc6db0..99eb9251e09b4 100644 --- a/torch/distributions/wishart.py +++ b/torch/distributions/wishart.py @@ -312,7 +312,6 @@ def log_prob(self, value): def entropy(self): nu = self.df # has shape (batch_shape) p = self._event_shape[-1] # has singleton shape - V = self.covariance_matrix # has shape (batch_shape x event_shape) return ( (p + 1) * ( diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 336b36424f31c..dbe4f2b72ed2b 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -48,9 +48,10 @@ "ExportBackwardSignature", "ExportGraphSignature", "ExportedProgram", + "CustomDecompTable", "ModuleCallEntry", "ModuleCallSignature", - "core_aten_decompositions", + "default_decompositions", "dims", "export", "export_for_training", @@ -64,9 +65,10 @@ ] +from .decomp_utils import CustomDecompTable from .dynamic_shapes import Constraint, Dim, dims, ShapesCollection from .exported_program import ( - core_aten_decompositions, + default_decompositions, ExportedProgram, ModuleCallEntry, ModuleCallSignature, diff --git a/torch/export/_trace.py b/torch/export/_trace.py index 42a32f2fa6e29..c175dd0091c8d 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -28,10 +28,6 @@ make_fake_inputs, produce_guards_and_solve_constraints, ) -from torch._export.passes._node_metadata_hook import ( - _node_metadata_hook, - _set_node_metadata_hook, -) from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass from torch._export.passes.lift_constants_pass import ( ConstantAttrMap, @@ -40,8 +36,9 @@ ) from torch._export.utils import ( _collect_param_buffer_metadata, - _get_shape_env_from_gm, _populate_param_buffer_metadata_to_new_gm, + _update_gm_meta_if_possible, + apply_runtime_assertion_pass, placeholder_naming_pass, placeholder_prefixes, ) @@ -66,9 +63,9 @@ from torch._library.fake_class_registry import FakeScriptObject from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch._utils_internal import log_export_usage +from torch.export._unlift import _check_input_constraints_pre_hook from torch.export.dynamic_shapes import _check_dynamic_shapes, _combine_args from torch.export.exported_program import OutputKind -from torch.fx._utils import first_call_function_nn_module_stack from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, @@ -77,7 +74,7 @@ ShapeEnv, ) from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo -from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts +from torch.fx.graph_module import _get_attr from torch.utils._pytree import TreeSpec from torch.utils._sympy.value_ranges import ValueRangeError @@ -411,6 +408,100 @@ def _remap_constants( constants[target] = constant +def _produce_aten_artifact( + *, + gm, + mod, + constant_attrs, + graph_signature, + pre_dispatch, + fake_args, + fake_kwargs, + fake_params_buffers, +) -> ATenExportArtifact: + """ + This is a helper function that is shared between export_to_aten_ir and export_to_aten_ir_make_fx + to produce the aten artifact. (export compatible graph module + signature) + + It does: + 1. Applies runtime assertion pass + 2. Populate meta val when missing + 3. Lift constants as placeholders + 4. Replace raw autograd and autocast ops with HOPs + 5. Prettify names for placeholders + 6. Preserve requires_grad value on node meta val + """ + # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. + # Overwrite output specs afterwards. + flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) + gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature) + + total_non_user_inputs = ( + len(graph_signature.parameters) + + len(graph_signature.buffers) + + len(graph_signature.input_tokens) + ) + set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs) + + export_graph_signature = _convert_to_export_graph_signature( + graph_signature, gm, _get_non_persistent_buffers(mod) + ) + + constants = rewrite_script_object_meta(gm) + constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) + + if pre_dispatch: + from torch._export.passes.replace_autocast_with_hop_pass import ( + replace_autocast_with_hop_pass, + ) + from torch._export.passes.replace_set_grad_with_hop_pass import ( + replace_set_grad_with_hop_pass, + ) + + # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because + # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass. + # If replace_set_grad_with_hop_pass is before lift_constant_pass, + # and the constant_tensor is passed as input of the set grad hop, the placeholder's + # meta["val"] will be None and fails our verifier for placeholder. + gm, export_graph_signature = replace_set_grad_with_hop_pass( + gm, export_graph_signature + ) + + gm, export_graph_signature = replace_autocast_with_hop_pass( + gm, export_graph_signature + ) + + # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. + for _mod in gm.modules(): + if not isinstance(_mod, torch.fx.GraphModule): + continue + for node in _mod.graph.nodes: + if node.op in ["placeholder", "output"]: + node.meta.pop("nn_module_stack", None) + node.meta.pop("stack_trace", None) + + # Prettify names for placeholder nodes. + placeholder_naming_pass( + gm, + export_graph_signature, + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constants, + ) + + _preserve_requires_grad_pass( + gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args + ) + + return ATenExportArtifact( + gm, + export_graph_signature, + constants, + ) + + def _rename_constants_nodes( gm: torch.fx.GraphModule, graph_signature: ExportGraphSignature, @@ -503,25 +594,29 @@ def _get_module_hierarchy(mod: torch.nn.Module) -> Dict[str, str]: def _make_module_call_graph( - module_hierarchy: Dict[str, str], in_spec: TreeSpec, out_spec: TreeSpec, module_call_signatures: Dict[str, ModuleCallSignature], forward_arg_names: Optional[List[str]] = None, ) -> List[ModuleCallEntry]: - ret = [ + original = [ ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn)) - for fqn in module_hierarchy + for fqn in _EXPORT_MODULE_HIERARCHY # type: ignore[union-attr] ] - assert ret[0].fqn == "" - ret[0].signature = ModuleCallSignature( + assert original[0].fqn == "" + original[0].signature = ModuleCallSignature( inputs=[], outputs=[], in_spec=in_spec, out_spec=out_spec, forward_arg_names=forward_arg_names, ) - return ret + additional = [ + ModuleCallEntry(fqn=fqn, signature=signature) + for fqn, signature in module_call_signatures.items() + if fqn not in _EXPORT_MODULE_HIERARCHY # type: ignore[operator] + ] + return [*original, *additional] def _export_to_torch_ir( @@ -675,102 +770,15 @@ def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): except (ConstraintViolationError, ValueRangeError) as e: raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 - # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. - # Overwrite output specs afterwards. - flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) - if not torch._dynamo.config.do_not_emit_runtime_asserts: - stack_trace = ( - 'File "torch/fx/passes/runtime_assert.py", line 24, ' - "in insert_deferred_runtime_asserts" - ) - with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - shape_env = _get_shape_env_from_gm(gm) - if shape_env: - insert_deferred_runtime_asserts( - gm, - shape_env, - f"exported program: {first_call_function_nn_module_stack(gm.graph)}", - export=True, - ) - - # update output specs - gm.recompile() - graph_signature.user_outputs = _graph_output_names(gm) - - # NOTE: aot_export adds symint metadata for placeholders with int values; - # since these become specialized, we replace such metadata with the original values - index = 0 - total_non_user_inputs = ( - len(graph_signature.parameters) - + len(graph_signature.buffers) - + len(graph_signature.input_tokens) - ) - for node in gm.graph.nodes: - if node.op == "placeholder": - if index >= total_non_user_inputs: - user_arg = flat_fake_args[index - total_non_user_inputs] - if not isinstance(user_arg, torch.Tensor): - node.meta["val"] = user_arg - index += 1 - - export_graph_signature = _convert_to_export_graph_signature( - graph_signature, gm, _get_non_persistent_buffers(mod) - ) - - constants = rewrite_script_object_meta(gm) - constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) - - if pre_dispatch: - from torch._export.passes.replace_autocast_with_hop_pass import ( - replace_autocast_with_hop_pass, - ) - from torch._export.passes.replace_set_grad_with_hop_pass import ( - replace_set_grad_with_hop_pass, - ) - - # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because - # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass. - # If replace_set_grad_with_hop_pass is before lift_constant_pass, - # and the constant_tensor is passed as input of the set grad hop, the placeholder's - # meta["val"] will be None and fails our verifier for placeholder. - gm, export_graph_signature = replace_set_grad_with_hop_pass( - gm, export_graph_signature - ) - - gm, export_graph_signature = replace_autocast_with_hop_pass( - gm, export_graph_signature - ) - - # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. - for _mod in gm.modules(): - if not isinstance(_mod, torch.fx.GraphModule): - continue - for node in _mod.graph.nodes: - if node.op in ["placeholder", "output"]: - node.meta.pop("nn_module_stack", None) - node.meta.pop("stack_trace", None) - - # Prettify names for placeholder nodes. - placeholder_naming_pass( - gm, - export_graph_signature, - mod, - fake_args, - fake_kwargs, - fake_params_buffers, - constants, - ) - - _preserve_requires_grad_pass( - gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args - ) - - return ATenExportArtifact( - gm, - export_graph_signature, - constants, + return _produce_aten_artifact( + gm=gm, + mod=mod, + constant_attrs=constant_attrs, + graph_signature=graph_signature, + pre_dispatch=pre_dispatch, + fake_args=fake_args, + fake_kwargs=fake_kwargs, + fake_params_buffers=fake_params_buffers, ) @@ -924,7 +932,7 @@ def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None: - None or non-empty str for 'call_function', 'get_attr' - None for 'placeholder', 'output' """ - for i, mod in enumerate([graph_module] + list(graph_module.modules())): + for mod in [graph_module, *graph_module.modules()]: if not isinstance(mod, torch.fx.GraphModule): continue for node in graph_module.graph.nodes: @@ -1047,7 +1055,15 @@ def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs): def _process_export_inputs(mod, args, kwargs, dynamic_shapes): - original_state_dict = mod.state_dict(keep_vars=True) + # Explicitly not calling mode.state_dict() as we do not want the module state for serialization + # but the running module state so we can always match by id() the entries here with the graph inputs + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + original_state_dict = named_parameters | named_buffers + + non_persistent_buffers = _get_non_persistent_buffers(mod) + for k in non_persistent_buffers: + original_state_dict.pop(k, None) if not isinstance(args, tuple): raise UserError( @@ -1102,7 +1118,6 @@ def _get_module_call_graph( assert _EXPORT_MODULE_HIERARCHY is not None module_call_graph = _make_module_call_graph( - _EXPORT_MODULE_HIERARCHY, original_in_spec, out_spec, module_call_signatures, @@ -1529,20 +1544,6 @@ def wrapped_fn(*args): if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"): gm.meta.update(mod.meta) - flat_args = pytree.tree_leaves((fake_args, fake_kwargs)) - index = 0 - for node in gm.graph.nodes: - if node.op == "placeholder": - if index >= params_len: - user_arg = flat_args[index - params_len] - if not isinstance(user_arg, torch.Tensor): - node.meta["val"] = user_arg - index += 1 - - export_graph_signature = _convert_to_export_graph_signature( - graph_signature, gm, _get_non_persistent_buffers(mod) - ) - # See comment in _export_to_aten_ir() if produce_guards_callback: try: @@ -1550,55 +1551,44 @@ def wrapped_fn(*args): except (ConstraintViolationError, ValueRangeError) as e: raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 - fake_mode = detect_fake_mode(flat_args) - - if not torch._dynamo.config.do_not_emit_runtime_asserts: - stack_trace = ( - 'File "torch/fx/passes/runtime_assert.py", line 24, ' - "in insert_deferred_runtime_asserts" - ) - with _set_node_metadata_hook( - gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) - ): - insert_deferred_runtime_asserts( - gm, - fake_mode.shape_env, - f"exported program: {first_call_function_nn_module_stack(gm.graph)}", - export=True, - ) - - # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. - for _mod in gm.modules(): - if not isinstance(_mod, torch.fx.GraphModule): - continue - for node in _mod.graph.nodes: - if node.op in ["placeholder", "output"]: - node.meta.pop("nn_module_stack", None) - node.meta.pop("stack_trace", None) - - constants = rewrite_script_object_meta(gm) - constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs)) - - _preserve_requires_grad_pass( - gm, export_graph_signature, fake_params_buffers, constants, flat_args + return _produce_aten_artifact( + gm=gm, + mod=mod, + constant_attrs=constant_attrs, + graph_signature=graph_signature, + pre_dispatch=True, + fake_args=fake_args, + fake_kwargs=fake_kwargs, + fake_params_buffers=fake_params_buffers, ) - # Prettify names for placeholder nodes. - placeholder_naming_pass( - gm, - export_graph_signature, - mod, - fake_args, - fake_kwargs, - fake_params_buffers, - constants, - ) - return ATenExportArtifact( - gm, - export_graph_signature, - constants, - ) +def set_missing_meta_vals(gm, flat_args, num_params_buffers): + # Sets missing metadata to address two problems: + # 1. aot_export adds symint metadata for placeholders with int values; since + # these become specialized, we replace such metadata with the original values. + # 2. any tensor attributes that are not params / buffers, i.e., are constants + # need to have their metadata set before lifting them because it is needed + # for computing the exported program's signature. + index = 0 + fake_mode = detect_fake_mode(flat_args) + for node in gm.graph.nodes: + if node.op == "placeholder": + if index >= num_params_buffers: + user_arg = flat_args[index - num_params_buffers] + if not isinstance(user_arg, torch.Tensor): + node.meta["val"] = user_arg + index += 1 + if node.op == "get_attr": + val = _get_attr(gm, node.target) + if isinstance(val, torch.Tensor): + assert "val" not in node.meta, ( + f"Found attribute {node.target} that has already been fakified " + "but not yet lifted as an input. This should be impossible because " + "(1) we should have already fakified AND lifted params/buffers " + "(2) we should have NOT yet fakified OR lifted tensor constants. " + ) + node.meta["val"] = fake_mode.from_tensor(val, static_shapes=True) def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node: @@ -1638,13 +1628,22 @@ def __init__(self, mod): def forward(self, *args, **kwargs): nonlocal out_spec - if isinstance(self._export_root, torch.fx.GraphModule): + mod = self._export_root + if isinstance(mod, torch.fx.GraphModule): + # NOTE: We're going to run this graph module with an fx interpreter, + # which will not run any forward hooks. Thus, ideally, we should run + # all forward hooks here. But the general logic for running them is + # complicated (see nn/module.py), and probably not worth duplicating. + # Instead we only look for, and run, an export-specific forward hook. + if ( + _check_input_constraints_pre_hook + in mod._forward_pre_hooks.values() + ): + _check_input_constraints_pre_hook(mod, args, kwargs) with torch.fx.traceback.preserve_node_meta(): - tree_out = torch.fx.Interpreter(self._export_root).run( - *args, **kwargs - ) + tree_out = torch.fx.Interpreter(mod).run(*args, **kwargs) else: - tree_out = self._export_root(*args, **kwargs) + tree_out = mod(*args, **kwargs) flat_outs, out_spec = pytree.tree_flatten(tree_out) return tuple(flat_outs) @@ -1829,6 +1828,8 @@ def _export_for_training( _verify_stack_trace(gm) _verify_placeholder_names(gm, export_graph_signature) + _update_gm_meta_if_possible(gm, mod) + from torch._export.verifier import TrainingIRVerifier exported_program = ExportedProgram( @@ -1982,12 +1983,7 @@ def _export( from torch._export.verifier import Verifier - if ( - isinstance(mod, torch.fx.GraphModule) - and hasattr(mod, "meta") - and "custom" in mod.meta - ): - gm.meta.update({"custom": mod.meta["custom"]}) + _update_gm_meta_if_possible(gm, mod) exported_program = ExportedProgram( root=gm, diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 4f6a45585ca05..a422950fa4788 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -2,7 +2,7 @@ import copy import warnings from itertools import chain -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple import torch import torch.utils._pytree as pytree @@ -43,7 +43,7 @@ def _check_input_constraints_pre_hook(self, *args, **kwargs): def _unlift_inputs_as_getattr( gm: torch.fx.GraphModule, - lifted_inputs: List[Optional[str]], + lifted_inputs: Sequence[Optional[str]], ) -> Tuple[Dict[str, torch.fx.Node], Dict[str, torch.fx.Node]]: """ Unlift inputs referring to params/buffers/constants as getattr nodes in the @@ -72,7 +72,7 @@ def _unlift_inputs_as_getattr( def _insert_copy_for_mutations( gm: torch.fx.GraphModule, - mutated_outputs: List[Optional[str]], + mutated_outputs: Sequence[Optional[str]], unlifted_name_to_node: Dict[str, torch.fx.Node], input_name_to_node: Dict[str, torch.fx.Node], ) -> None: @@ -158,8 +158,8 @@ def _get_codegen( def _unlift( gm: torch.fx.GraphModule, - lifted_inputs: List[Optional[str]], - mutated_outputs: List[Optional[str]], + lifted_inputs: Sequence[Optional[str]], + mutated_outputs: Sequence[Optional[str]], in_spec: pytree.TreeSpec, out_spec: Optional[pytree.TreeSpec], state_dict: Dict[str, Any], @@ -280,6 +280,13 @@ def _create_stateful_graph_module( if ep is None: return stateful_gm + # When we have a constant that has requires_grad=True, we need to detach it + # when we unlift as the tensors that require gradients should be registered + # via parameters. But this is problematic when we have aliasing two constants + # because when we call detach, they will become different tensors. This dict + # keeps track of this logic. + original_tensor_to_detached_tensor = {} + # Fix up lifted tensor constants. # fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module # into a buffer in stateful_gm and creates an inconsistency with graph_signature. @@ -299,7 +306,9 @@ def _create_stateful_graph_module( f"torch.export will detach it and treat it as a constant tensor " f"but please register it as parameter instead." ) - buffer = buffer.detach() + detached_buffer = buffer.detach() + original_tensor_to_detached_tensor[buffer] = detached_buffer + buffer = detached_buffer *prefix, field = constant_fqn.rsplit(".") submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix) delattr(submod, field) @@ -309,6 +318,19 @@ def _create_stateful_graph_module( for const_name, value in ep.constants.items(): if not torch.fx.graph_module._has_attr(stateful_gm, const_name): if isinstance(value, torch.Tensor): + if value.requires_grad: + warnings.warn( + f"A model attribute `{const_name}` requires gradient " + f"but it's not properly registered as a parameter. " + f"torch.export will detach it and treat it as a constant tensor " + f"but please register it as parameter instead." + ) + if value in original_tensor_to_detached_tensor: + value = original_tensor_to_detached_tensor[value] + else: + detached_value = value.detach() + original_tensor_to_detached_tensor[value] = detached_value + value = detached_value _assign_attr( value, stateful_gm, diff --git a/torch/export/decomp_utils.py b/torch/export/decomp_utils.py new file mode 100644 index 0000000000000..1f4a8f1a25ab9 --- /dev/null +++ b/torch/export/decomp_utils.py @@ -0,0 +1,144 @@ +# mypy: allow-untyped-defs +from typing import Callable, Dict + +import torch +from torch._export.utils import ( + _collect_all_valid_cia_ops, + _collect_all_valid_cia_ops_for_aten_namespace, + _get_decomp_for_cia, + _is_aten_op, +) + + +__all__ = ["CustomDecompTable"] + + +class CustomDecompTable(Dict[torch._ops.OperatorBase, Callable]): + """ + This is a custom dictionary that is specifically used for handling decomp_table in export. + The reason we need this is because in the new world, you can only *delete* an op from decomp + table to preserve it. This is problematic for custom ops because we don't know when the custom + op will actually be loaded to the dispatcher. As a result, we need to record the custom ops operations + until we really need to materialize it (which is when we run decomposition pass.) + + Invariants we hold are: + 1. All aten decomp is loaded at the init time + 2. We materialize ALL ops when user ever reads from the table to make it more likely + that dispatcher picks up the custom op. + 3. If it is write operation, we don't necessarily materialize + 4. We load the final time during export, right before calling run_decompositions() + + """ + + def __init__(self): + super().__init__() + from torch._decomp import _core_aten_decompositions_post_autograd + + # For aten ops, we load them up in the beginning + self.decomp_table = _core_aten_decompositions_post_autograd() + + for op in _collect_all_valid_cia_ops_for_aten_namespace(): + self.decomp_table[op] = _get_decomp_for_cia(op) + + # This is to track the *pending* deleted custom ops that haven't been materialized yet + self.deleted_custom_ops = set() + # When this is true, there shouldn't be any pending operations in the table. + self.has_materialized = False + + def __getitem__(self, key): + self._materialize_if_needed() + return self.decomp_table.__getitem__(key) + + def __setitem__(self, key, value) -> None: + self.decomp_table.__setitem__(key, value) + + if key in self.deleted_custom_ops: + self.deleted_custom_ops.remove(key) + + def keys(self): + self._materialize_if_needed() + return self.decomp_table.keys() + + def __delitem__(self, key) -> None: + self.pop(key) + + def update(self, other_dict): # type: ignore[override] + for k, v in other_dict.items(): + self.decomp_table.__setitem__(k, v) + + def __missing__(self, key) -> bool: + return not self.__contains__(key) + + def __contains__(self, key) -> bool: + self._materialize_if_needed() + return self.decomp_table.__contains__(key) + + def __len__(self) -> int: + self._materialize_if_needed() + return self.decomp_table.__len__() + + def __iter__(self): + self._materialize_if_needed() + return self.decomp_table.__iter__() + + def __reversed__(self): + self._materialize_if_needed() + return self.decomp_table.__reversed__() + + def copy(self) -> "CustomDecompTable": + new_dict = CustomDecompTable() + new_dict.decomp_table = self.decomp_table.copy() + new_dict.deleted_custom_ops = self.deleted_custom_ops.copy() + new_dict.has_materialized = self.has_materialized + return new_dict + + def pop(self, *args): + def _pop_if_can(key): + if _is_aten_op(key): + return self.decomp_table.pop(key) + + if key in self.decomp_table: + # Even if we materialized it, we should add it to the deleted + # custom ops list so that when we materialize next time, + # we should respect user's intention. + self.deleted_custom_ops.add(key) + return self.decomp_table.pop(key) + + if key in self.deleted_custom_ops: + raise KeyError(f"{key} doesn't exist in the table") + + self.deleted_custom_ops.add(key) + # We would come here when user pops off something that is + # not in the table. In this case, we just pretend that it + # was in the table. + return _get_decomp_for_cia(key) + + if len(args) == 1: + return _pop_if_can(args[0]) + + if len(args) == 2: + try: + return _pop_if_can(args[0]) + except KeyError: + return args[1] + + def items(self): + self._materialize_if_needed() + return self.decomp_table.items() + + def materialize(self) -> Dict[torch._ops.OperatorBase, Callable]: + for op in _collect_all_valid_cia_ops(): + if _is_aten_op(op): + continue + elif op in self.decomp_table: + continue + elif op not in self.deleted_custom_ops: + self.decomp_table[op] = _get_decomp_for_cia(op) + + self.has_materialized = True + self.deleted_custom_ops = set() + return {**self.decomp_table} + + def _materialize_if_needed(self) -> None: + if not self.has_materialized: + self.materialize() diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 902e7a9108f61..c91fe46b71a02 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -374,7 +374,24 @@ def serializable_spec(self): } -Constraint = Union[_Constraint, _DerivedConstraint] +@dataclasses.dataclass +class _RelaxedConstraint(_ConstraintTarget): + """ + This represents a dim marked with Dim.AUTO/DYNAMIC (i.e. mark_dynamic() or maybe_mark_dynamic()), + which leaves relations & min/max ranges for inference, instead of requiring explicit specification. + The intention is for constraint violations to not be raised if produce_guards() finds equalities or + relations between a _RelaxedConstraint and another type of _Constraint. + """ + + @property + def serializable_spec(self): + return { + "t_id": self.t_id, + "dim": self.dim, + } + + +Constraint = Union[_Constraint, _DerivedConstraint, _RelaxedConstraint] def _process_equalities( @@ -385,6 +402,7 @@ def _process_equalities( source_pairs: List[Tuple["Source", "Source"]], derived_equalities: List[Tuple["Source", Union["Source", "Symbol"], Callable]], phantom_symbols: Dict[str, "Symbol"], + relaxed_sources: Set["Source"], ): """ Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become @@ -399,7 +417,7 @@ def _process_equalities( # When t.size()[dim] maps to src0, src1, ..., srcN, we add # constraints that make src0 "equal" to src1, ..., srcN. source_pairs.extend((source, other_source) for other_source in other_sources) - if not isinstance(constraint, _DerivedConstraint): + if isinstance(constraint, _Constraint): if constraint.name in names: shared_t_id, shared_dim = names[constraint.name] other_sources = get_sources(shared_t_id, shared_dim) @@ -408,7 +426,7 @@ def _process_equalities( ) else: names[constraint.name] = (constraint.t_id, constraint.dim) - else: + elif isinstance(constraint, _DerivedConstraint): # branch based on the root of the _DerivedConstraint if not isinstance(constraint.root, _PhantomRoot): # either root points to an input source @@ -431,6 +449,8 @@ def _process_equalities( # A derived equality (source, root, fn) informally corresponds to source = fn(root). # Here source describes an input and root might describe another input or a phantom symbol. derived_equalities.append((source, root, fn)) + elif isinstance(constraint, _RelaxedConstraint): + relaxed_sources.add(source) def _tree_map_with_path( @@ -662,7 +682,6 @@ def _check_dynamic_shapes( using combined args + kwargs as reference for inputs structure. """ from torch._dynamo.exc import UserError, UserErrorType - from torch._export.non_strict_utils import _flatten_dynamic_shapes if dynamic_shapes is None or len(dynamic_shapes) == 0: return @@ -768,24 +787,6 @@ def check_shape(path, t, dynamic_shape): _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") - # raise user warning if both Dim.AUTO & Dims are specified in dynamic_shapes - flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes) - flatter_dynamic_shapes, _ = tree_flatten(flat_dynamic_shapes) - if any(isinstance(s, _Dim) for s in flatter_dynamic_shapes) and any( - s == _DimHint.AUTO for s in flatter_dynamic_shapes - ): - raise UserError( - UserErrorType.INVALID_INPUT, - "Specifying both `Dim.AUTO/Dim.DYNAMIC` and `Dim/DerivedDim` in `dynamic_shapes` is not " - "well supported at the moment, and can easily lead to constraint violation errors or obscure errors " - "in torch.export. Dim/DerivedDims expect all equal or related dimensions to be specified, " - "and do not yet compose well with `Dim.AUTO`. We suggest using `Dim.AUTO/Dim.DYNAMIC` mixed with " - "`Dim.STATIC` for auto-dynamic + static shapes, plus torch._check(dim >= min), torch._check(dim <= max) " - "calls in your program to specify min/max ranges, or `Dim`/`DerivedDim` mixed with `Dim.STATIC` " - "if you want to assert on the exact specification of your program's dynamic shapes behavior.", - case_name="dynamic_shapes_validation", - ) - def _process_dynamic_shapes( combined_args: Dict[str, Any], @@ -919,6 +920,7 @@ def _create_static_dim(tensor, i, value): torch._dynamo.mark_static(tensor, i) elif dim == _DimHint.DYNAMIC: torch._dynamo.mark_dynamic(tensor, i) + constraints.append(_RelaxedConstraint(id(tensor), i)) elif dim is None: torch._dynamo.mark_static(tensor, i) elif isinstance(shape, (tuple, list)): @@ -935,6 +937,7 @@ def _create_static_dim(tensor, i, value): torch._dynamo.mark_static(tensor, i) elif dim == _DimHint.DYNAMIC: torch._dynamo.mark_dynamic(tensor, i) + constraints.append(_RelaxedConstraint(id(tensor), i)) elif dim is None: torch._dynamo.mark_static(tensor, i) elif shape is None: diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index e357f8f067eef..c9214494ab50d 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -44,18 +44,23 @@ import torch import torch.utils._pytree as pytree from torch._export.utils import ( + _collect_all_valid_cia_ops, _collect_and_set_constant_attrs, _collect_param_buffer_metadata, _detect_fake_mode_from_gm, + _get_decomp_for_cia, + _is_preservable_cia_op, _name_hoo_subgraph_placeholders, _overwrite_signature_for_non_persistent_buffers, _populate_param_buffer_metadata_to_new_gm, _rename_without_collisions, + _special_op_to_preserve_cia, ) from torch._export.verifier import Verifier from torch._guards import detect_fake_mode from torch._subclasses.fake_tensor import unset_fake_temporarily from torch.export._tree_utils import is_equivalent, reorder_kwargs +from torch.export.decomp_utils import CustomDecompTable from torch.fx._compatibility import compatibility from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.infra.pass_manager import PassManager @@ -79,7 +84,7 @@ "ExportedProgram", "ModuleCallEntry", "ModuleCallSignature", - "core_aten_decompositions", + "default_decompositions", ] @@ -196,8 +201,6 @@ def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): # replace their CompositeImplicitAutograd kernels with NotImplemented. # The only current users of this mode are variants of aten::to that we will # replace with aten::_to_copy in FunctionalTensorMode.__torch_dispatch__. - from torch._decomp import _get_decomp_for_cia - saved_tables = {} patched_ops = set() for op_overload, decomp_callable in cia_ops_to_callable.items(): @@ -248,13 +251,6 @@ def _override_composite_implicit_decomp(cia_ops_to_callable, safe=True): op._dispatch_cache.clear() -def _special_op_to_preserve_cia(*args, **kwargs): - """ - This is an special marker that tells our infra that we shouldn't decompose this op. - """ - return NotImplemented - - @contextmanager def _override_decomp_aten_to_variants(): # Preserve variants of aten::to understanding that they are mutating/aliasing @@ -273,8 +269,6 @@ def _override_decomp_aten_to_variants(): def _split_decomp_table_to_cia_and_python_decomp( decomp_table: Dict[torch._ops.OperatorBase, Callable] ) -> Tuple[Dict[torch._ops.OperatorBase, Callable], ...]: - from torch._decomp import _collect_all_valid_cia_ops, _is_preservable_cia_op - all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) cia_ops_to_callable = {} @@ -316,15 +310,13 @@ def _split_decomp_table_to_cia_and_python_decomp( return cia_ops_to_callable, decomp_table -def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]: +def default_decompositions() -> "CustomDecompTable": """ This is the default decomposition table which contains decomposition of all ATEN operators to core aten opset. Use this API together with :func:`run_decompositions()` """ - from torch._decomp import core_aten_decompositions - - return core_aten_decompositions() + return CustomDecompTable() def _decompose_and_get_gm_with_new_signature_constants( @@ -345,9 +337,13 @@ def _decompose_and_get_gm_with_new_signature_constants( ) from torch.fx.experimental.symbolic_shapes import ShapeEnv - # TODO Merge this path with inference IR decomp, but it will require some additional work - # so I will leave it for now. T200307782 - if ep.verifier.dialect == "TRAINING": + def _is_joint_ir_decomp(ep, joint_loss_index): + return ( + joint_loss_index is not None + or ep.graph_signature.backward_signature is not None + ) + + if not _is_joint_ir_decomp(ep, joint_loss_index): mod = ep.module() fake_args = [] @@ -356,7 +352,8 @@ def _decompose_and_get_gm_with_new_signature_constants( fake_args.append(node.meta["val"]) fake_args_unwrapped = pytree.tree_unflatten(fake_args, mod._in_spec) - fake_mode = _detect_fake_mode_from_gm(mod) + # TODO T204030333 + fake_mode = _detect_fake_mode_from_gm(ep.graph_module) if fake_mode is None: fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) @@ -531,9 +528,10 @@ def update_arg(old_arg, new_ph): ) for i, spec in enumerate(ep.graph_signature.input_specs) ] + output_specs = [ OutputSpec( - spec.kind, + OutputKind.LOSS_OUTPUT if joint_loss_index is not None else spec.kind, update_arg(spec.arg, new_outputs[i]), old_new_placeholder_map.get(spec.target, spec.target), ) @@ -644,6 +642,30 @@ def _common_getitem_elimination_pass( node_id[node] = node.name +def _get_updated_module_call_graph( + gm: torch.fx.GraphModule, + old_module_call_graph: List[ModuleCallEntry], +): + new_module_call_graph = copy.deepcopy(old_module_call_graph) + + # use node-level provenance metadata to create a map + # from old node names to new node names + provenance: Dict[str, str] = {} + for node in gm.graph.nodes: + if history := node.meta.get("from_node", []): + provenance[history[-1][0]] = node.name + + # map old names to new names in module call signatures + for entry in new_module_call_graph: + signature = entry.signature + if signature is None: + continue + for x in [*signature.inputs, *signature.outputs]: + x.name = provenance.get(x.name, x.name) + + return new_module_call_graph + + def _decompose_exported_program( ep, *, @@ -658,6 +680,15 @@ def _decompose_exported_program( joint_loss_index=joint_loss_index, ) + # The signatures of ep.module_call_graph refer to input / output nodes of + # the original graph module. However, the new graph module may have + # new nodes due to decompositions. So we need to update these signatures + # in the decomposed exported program's module_call_graph. + new_module_call_graph = _get_updated_module_call_graph( + gm, + ep.module_call_graph, + ) + # TODO unfortunately preserving graph-level metadata is not # working well with aot_export. So we manually copy it. # (The node-level meta is addressed above.) @@ -674,7 +705,7 @@ def _decompose_exported_program( graph_signature=new_graph_signature, state_dict=ep.state_dict, range_constraints=new_range_constraints, - module_call_graph=copy.deepcopy(ep.module_call_graph), + module_call_graph=new_module_call_graph, example_inputs=ep.example_inputs, constants=ep.constants, ) @@ -1022,7 +1053,6 @@ def _num_lifted_params_buffers(self): def run_decompositions( self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None, - _preserve_ops: Tuple[torch._ops.OpOverload, ...] = (), ) -> "ExportedProgram": """ Run a set of decompositions on the exported program and returns a new @@ -1053,43 +1083,16 @@ def run_decompositions( .. code-block:: python ep = torch.export.export(model, ...) - decomp_table = torch.export.core_aten_decompositions() + decomp_table = torch.export.default_decompositions() decomp_table[your_op] = your_custom_decomp ep = ep.run_decompositions(decomp_table=decomp_table) """ - from torch._decomp import ( - _decomp_table_to_post_autograd_aten, - _is_preservable_cia_op, - core_aten_decompositions, - ) - from torch._inductor import config - - # FIXME delete this option after PTC, Executorch syncing is - # bit annoying so can't get rid of it easily - if _preserve_ops != (): - warnings.warn( - "This API is deprecated and soon will be removed. " - "Please look at the docstring to see how to preserve " - "an operator." - ) - _decomp_table = ( - core_aten_decompositions() if decomp_table is None else dict(decomp_table) + default_decompositions() if decomp_table is None else dict(decomp_table) ) - if config.is_fbcode(): - # This means the decomp_table would only be containing post-autograd ops - # We should manually add CIA decomps - for k, v in _decomp_table_to_post_autograd_aten().items(): - _decomp_table[k] = v - - for op in _preserve_ops: - if op in _decomp_table: - del _decomp_table[op] - # This is needed when the op they want to preserve is a - # CIA op. - elif _is_preservable_cia_op(op): - _decomp_table[op] = _special_op_to_preserve_cia + if isinstance(_decomp_table, CustomDecompTable): + _decomp_table = _decomp_table.materialize() # Note [Seperating decomp_table into CIA decomps and non-CIA decomps] # At this point, we have a decomp_table that contains decomp behaviour for diff --git a/torch/export/passes/__init__.py b/torch/export/passes/__init__.py index c523b954e88e7..57466bee49d0a 100644 --- a/torch/export/passes/__init__.py +++ b/torch/export/passes/__init__.py @@ -41,7 +41,8 @@ def _get_new_device( for k, v in ep.state_dict.items(): if isinstance(v, torch.nn.Parameter): ep._state_dict[k] = torch.nn.Parameter( - v.to(_get_new_device(v.device, location)) + v.to(_get_new_device(v.device, location)), + v.requires_grad, ) else: ep._state_dict[k] = v.to(_get_new_device(v.device, location)) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 2c5046da504ff..76e804716b48a 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -19,6 +19,7 @@ from torch.export.exported_program import ( ConstantArgument, ExportedProgram, + ExportGraphSignature, InputKind, ModuleCallSignature, SymIntArgument, @@ -232,7 +233,15 @@ def __init__( self._run_with_interpeter = RUN_WITH_INTERPRETER _inplace_buffer_mutations(export_graph, self.graph_signature) + + self.ivals = _IVals() + # record any intermediate value x that is used, with the modules that used it, + # and generate instructions to read the corresponding attribute seen_modules = _outline_submodules(export_graph, self) + # for each read intermediate value x, find the module that created it, + # and generate instructions to update the corresponding attribute; + # finally, initialize all these attributes + self.ivals.create(seen_modules.values()) self.range_constraints = export_module.range_constraints self.equality_constraints: List = [] @@ -567,13 +576,14 @@ def unflatten( An instance of :class:`UnflattenedModule`, which has the same module hierarchy as the original eager module pre-export. """ - if module.verifier.dialect == "TRAINING": - raise RuntimeError("Unflattener doesn't support non-functional training IR yet") module = _remove_effect_tokens(module) return UnflattenedModule(module, flat_args_adapter) -def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None: +def _inplace_buffer_mutations( + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, +) -> None: """Transform buffer mutations from their functionalized form into a copy_ node in the graph. @@ -751,7 +761,7 @@ def __init__( seen_nodes, seen_modules, parent, - module_stack: List[str], + module_stack: List[Tuple[str, int]], module_id, module_call_graph: Dict[str, ModuleCallSignature], module: Optional[torch.nn.Module] = None, @@ -767,11 +777,16 @@ def __init__( self.module_call_graph = module_call_graph self.verbose = False - self.fqn = self.module_stack[-1] + self.fqn, num_calls = self.module_stack[-1] + # generate call name for self.fqn + self.child_fqn = _call_name(self.fqn, num_calls + 1) + if module is not None: self.module = module + self.ivals = module.ivals if hasattr(module, "ivals") else {} else: self.module = InterpreterModule(torch.fx.Graph()) + self.ivals = parent.ivals self.graph = self.module.graph @@ -781,19 +796,7 @@ def __init__( self.parent_call_module: Optional[torch.fx.Node] = None if parent is not None: - num_calls = len( - [x for x in self.seen_modules[self.module_id] if x.fqn == self.fqn] - ) - if self.fqn in module_call_graph and num_calls == 1: - raise ValueError( - f"Cannot unflatten multiple calls to module {self.fqn} while preserving its signature " - "because each of these calls might have generated a different specialized graph. " - f"If the reason you want to preserve the signature is to swap {self.fqn} with another module, " - "consider using _swap_modules() directly on the exported program instead of unflattening it." - ) - # generate call name for self.fqn - child_fqn = _call_name(self.fqn, num_calls + 1) - accessor = _compute_accessor(parent.fqn, child_fqn) + accessor = _compute_accessor(parent.fqn, self.child_fqn) _add_submodule(parent.module, accessor, self.module) self.parent_call_module = parent.graph.call_module(accessor) self.seen_modules[self.module_id].append( @@ -807,7 +810,7 @@ def __init__( ) ) - signature = module_call_graph.get(self.fqn) + signature = module_call_graph.get(self.child_fqn) if signature is not None and self.parent is not None: assert signature.in_spec.num_children == 2 args_spec = signature.in_spec.children_specs[0] @@ -946,6 +949,10 @@ def remap_input(self, x): # if module call signature needs to be preserved self.copy_sym_call_function(x) return self.node_map[x] + elif self.module_call_graph.get(self.fqn) is not None: + # x is an ival that is not in placeholders, so create a + # get_attr node corresponding to attribute __ival__x + return self.ivals.read(self.fqn, self.graph, x) else: raise RuntimeError( f"Could not run remap_input() on op type: {x.op} for node {x}" @@ -954,7 +961,7 @@ def remap_input(self, x): def finalize_outputs(self): orig_outputs = [] - signature = self.module_call_graph.get(self.fqn) + signature = self.module_call_graph.get(self.child_fqn) if signature is not None and self.parent is not None: for output in signature.outputs: if isinstance(output, (TensorArgument, SymIntArgument)): @@ -1070,8 +1077,9 @@ def run_from(self, node_idx): self.print() self.print("STEP", node_idx, node.format_node()) self.print(self.module_stack) + depth = len(self.module_stack) if node.op == "output": - if len(self.module_stack) == 1: + if depth == 1: # We want the output node of the original graph to be handled # specially by the outermost stack frame (in run_outer). So # skip finalization here. @@ -1097,10 +1105,11 @@ def run_from(self, node_idx): node_module_stack = self.module_stack else: node_module_stack = [ - path for path, ty in node.meta["nn_module_stack"].values() + (path, int(k.split("@")[-1]) if "@" in k else 0) + for k, (path, ty) in node.meta["nn_module_stack"].items() ] - if node_module_stack[: len(self.module_stack)] != self.module_stack: + if node_module_stack[:depth] != self.module_stack: # This means that the current module is done executing and the # current node is the beginning of a new module. # @@ -1116,10 +1125,11 @@ def run_from(self, node_idx): if _is_prefix(self.module_stack, node_module_stack): # This means that the current node represents the execution of a new # module. - next_module = node_module_stack[len(self.module_stack)] + next_module = node_module_stack[depth] self.print("Creating new stack frame for", next_module) # Run a nested version of module outliner from the current node # counter. Once it is complete, continue from that point. + next_module_key = list(node.meta["nn_module_stack"].keys())[depth] node_idx = _ModuleFrame( self.flat_graph, self.nodes, @@ -1127,7 +1137,7 @@ def run_from(self, node_idx): self.seen_modules, self, self.module_stack + [next_module], - list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], + next_module_key.split("@")[0], self.module_call_graph, ).run_from(node_idx) module_idx += 1 @@ -1159,7 +1169,7 @@ def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModu seen_nodes, seen_modules, None, - [""], + [("", 0)], "", { entry.fqn: entry.signature @@ -1193,6 +1203,106 @@ def _reorder_submodules( parent.register_module(name, child) +class _IVals: + """ + Collect the intermediate values of buffer mutations in a graph, + along with the module call fqns that create and use them. Later, + in each fqn associated with an intermediate value we will install + a corresponding attribute, so that it can be updated and read. + + Example: in the following graph, suppose that buf_in and buf_out + are the input and output values of a buffer. + + buf_in = placeholder() + ... + ival1 = f0(buf_in, ...) # inside self.n0(...) + ... + ival2 = f1(ival1, ...) # inside self.n1(...) + ... + buf_out = f2(ival2, ...) # inside self.n2(...) + return buf_out, ... + + Here ival1 and ival2 are intermediate values created inside + calls to n0 and n1 respectively, and used inside calls to + n1 and n2 respectively. + + Thus our analysis will produce {ival1: {n0, n1}, ival2: {n1, n2}}. + """ + + def __init__(self): + # ival node name -> set of fqns that create and use it + self.fqns = defaultdict(set) + # ival node name -> tensor storage for corresponding attribute + self.storage = {} + + def read(self, fqn, graph, node): + """ + Read attribute corresponding to a given intermediate value. + """ + # to read ival x, get attribute __ival__x + with graph.inserting_before(None): + ival_node = graph.get_attr("__ival__" + node.name, type_expr=node.type) + ival_node.meta = copy.copy(node.meta) + + if node.name not in self.storage: + # create empty tensor matching fake, using a cache + # to ensure the same tensor is returned per ival_name + fake = node.meta["val"] + self.storage[node.name] = torch.empty(fake.shape, dtype=fake.dtype) + self.fqns[node.name].add(fqn) + + return ival_node + + def update(self, fqn, graph, node): + """ + Update attribute corresponding to a given intermediate value. + """ + self.fqns[node.name].add(fqn) + + # to update ival x, get attribute __ival__x and copy x to __ival__x + with graph.inserting_after(node): + ival_node = graph.get_attr("__ival__" + node.name, type_expr=node.type) + ival_node.meta = copy.copy(node.meta) + with graph.inserting_after(ival_node): + new_ival_node = graph.create_node( + "call_function", torch.ops.aten.copy_, (ival_node, node) + ) + new_ival_node.meta = copy.copy(node.meta) + + def create(self, partitions): + """ + Update attributes corresponding to intermediate values that were read. + Finally, initialize attributes in all modules that read or update + corresponding intermediate values. + """ + + entries = [] + for shared_submodules in partitions: + for entry in shared_submodules: + entries.append(entry) + graph = entry.module.graph + for node in graph.nodes: + if node.name in self.storage: + self.update(entry.fqn, graph, node) + + # fqn -> list of ival node names read or updated through it + ivals = defaultdict(list) + for name, fqns in self.fqns.items(): + for fqn in fqns: + ivals[fqn].append(name) + + for entry in entries: + for name in ivals[entry.fqn]: + ival_name = f"__ival__{name}" + # for a ival named x created in module call m, + # create attribute m.__ival__x, initially empty + setattr( + entry.module, + ival_name, + self.storage[name], + ) + + def _deduplicate_modules(partitions): for shared_submodules in partitions: for i, entry in enumerate(shared_submodules): @@ -1296,7 +1406,7 @@ def _sink_params( state_name = None for sn in inputs_to_state[node.name]: sn_split = sn.split(".") - if sn_split[: len(scope)] == scope: + if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]: state_name = sn_split break diff --git a/torch/functional.py b/torch/functional.py index 9180262708eac..7327c29514ce3 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -257,17 +257,22 @@ def einsum(*args: Any) -> Tensor: .. note:: - This function uses opt_einsum (https://optimized-einsum.readthedocs.io/en/stable/) to speed up computation or to - consume less memory by optimizing contraction order. This optimization occurs when there are at least three - inputs, since the order does not matter otherwise. Note that finding _the_ optimal path is an NP-hard problem, - thus, opt_einsum relies on different heuristics to achieve near-optimal results. If opt_einsum is not available, - the default order is to contract from left to right. + Please install opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) in order to enroll into a more + performant einsum. You can install when installing torch like so: `pip install torch[opt-einsum]` or by itself + with `pip install opt-einsum`. - To bypass this default behavior, add the following line to disable the usage of opt_einsum and skip path - calculation: `torch.backends.opt_einsum.enabled = False` + If opt-einsum is available, this function will automatically speed up computation and/or consume less memory + by optimizing contraction order through our opt_einsum backend :mod:`torch.backends.opt_einsum` (The _ vs - is + confusing, I know). This optimization occurs when there are at least three inputs, since the order does not matter + otherwise. Note that finding `the` optimal path is an NP-hard problem, thus, opt-einsum relies on different + heuristics to achieve near-optimal results. If opt-einsum is not available, the default order is to contract + from left to right. + + To bypass this default behavior, add the following to disable opt_einsum and skip path calculation: + ``torch.backends.opt_einsum.enabled = False`` To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line: - `torch.backends.opt_einsum.strategy = 'auto'`. The default strategy is 'auto', and we also support 'greedy' and + ``torch.backends.opt_einsum.strategy = 'auto'``. The default strategy is 'auto', and we also support 'greedy' and 'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html). diff --git a/torch/fx/__init__.py b/torch/fx/__init__.py index dd04cdd09d7fa..74691bbe72ac6 100644 --- a/torch/fx/__init__.py +++ b/torch/fx/__init__.py @@ -7,6 +7,8 @@ :: import torch + + # Simple module for demonstration class MyModule(torch.nn.Module): def __init__(self) -> None: @@ -17,11 +19,13 @@ def __init__(self) -> None: def forward(self, x): return self.linear(x + self.param).clamp(min=0.0, max=1.0) + module = MyModule() from torch.fx import symbolic_trace + # Symbolic tracing frontend - captures the semantics of the module - symbolic_traced : torch.fx.GraphModule = symbolic_trace(module) + symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) # High-level intermediate representation (IR) - Graph representation print(symbolic_traced.graph) @@ -80,10 +84,32 @@ def forward(self, x): repository. ''' -from .graph_module import GraphModule -from ._symbolic_trace import symbolic_trace, Tracer, wrap, PH, ProxyableClassMeta -from .graph import Graph, CodeGen -from .node import Node, map_arg, has_side_effect -from .proxy import Proxy -from .interpreter import Interpreter as Interpreter, Transformer as Transformer -from .subgraph_rewriter import replace_pattern +from torch.fx._symbolic_trace import ( # noqa: F401 + PH, + ProxyableClassMeta, + symbolic_trace, + Tracer, + wrap, +) +from torch.fx.graph import CodeGen, Graph # noqa: F401 +from torch.fx.graph_module import GraphModule +from torch.fx.interpreter import Interpreter, Transformer +from torch.fx.node import has_side_effect, map_arg, Node +from torch.fx.proxy import Proxy +from torch.fx.subgraph_rewriter import replace_pattern + + +__all__ = [ + "symbolic_trace", + "Tracer", + "wrap", + "Graph", + "GraphModule", + "Interpreter", + "Transformer", + "Node", + "Proxy", + "replace_pattern", + "has_side_effect", + "map_arg", +] diff --git a/torch/fx/__init__.pyi b/torch/fx/__init__.pyi deleted file mode 100644 index 0a263dfc5071d..0000000000000 --- a/torch/fx/__init__.pyi +++ /dev/null @@ -1,15 +0,0 @@ -from torch.fx._symbolic_trace import ( - symbolic_trace as symbolic_trace, - Tracer as Tracer, - wrap as wrap, -) -from torch.fx.graph import Graph as Graph -from torch.fx.graph_module import GraphModule as GraphModule -from torch.fx.interpreter import Interpreter as Interpreter, Transformer as Transformer -from torch.fx.node import ( - has_side_effect as has_side_effect, - map_arg as map_arg, - Node as Node, -) -from torch.fx.proxy import Proxy as Proxy -from torch.fx.subgraph_rewriter import replace_pattern as replace_pattern diff --git a/torch/fx/_compatibility.py b/torch/fx/_compatibility.py index 27c1e600036df..8a2eeb0d2d695 100644 --- a/torch/fx/_compatibility.py +++ b/torch/fx/_compatibility.py @@ -1,16 +1,19 @@ -from typing import Any, Dict, Callable, TypeVar import textwrap +from typing import Any, Callable, Dict, TypeVar + + +_BACK_COMPAT_OBJECTS: Dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY: Dict[Any, None] = {} -_BACK_COMPAT_OBJECTS : Dict[Any, None] = {} -_MARKED_WITH_COMPATIBILITY : Dict[Any, None] = {} _T = TypeVar("_T") + def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: if is_backward_compatible: def mark_back_compat(fn: _T) -> _T: - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") docstring += """ .. note:: Backwards-compatibility for this API is guaranteed. @@ -24,7 +27,7 @@ def mark_back_compat(fn: _T) -> _T: else: def mark_not_back_compat(fn: _T) -> _T: - docstring = textwrap.dedent(getattr(fn, '__doc__', None) or '') + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") docstring += """ .. warning:: This API is experimental and is *NOT* backward-compatible. diff --git a/torch/fx/_lazy_graph_module.py b/torch/fx/_lazy_graph_module.py index 2a14fce3782e9..cc2f686ebba10 100644 --- a/torch/fx/_lazy_graph_module.py +++ b/torch/fx/_lazy_graph_module.py @@ -1,9 +1,9 @@ # mypy: allow-untyped-defs from contextlib import contextmanager -from torch.fx import GraphModule from torch.fx.graph_module import ( _format_import_block, + GraphModule, reduce_graph_module, reduce_package_graph_module, ) diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 6693863386513..38835c6ca374f 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -1,13 +1,13 @@ # mypy: allow-untyped-defs import builtins -import copy +import collections import contextlib +import copy import functools import inspect import math import os import warnings -import collections from itertools import chain from types import CodeType, FunctionType, ModuleType from typing import ( @@ -29,11 +29,12 @@ from torch._library.fake_class_registry import FakeScriptObject from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph from .graph_module import GraphModule -from ._lazy_graph_module import _make_graph_module from .node import Argument, base_types, map_aggregate -from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager +from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase + HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS @@ -49,6 +50,7 @@ def is_fx_tracing(): return _is_fx_tracing_flag + @compatibility(is_backward_compatible=True) class ProxyableClassMeta(type): """ @@ -58,6 +60,7 @@ class ProxyableClassMeta(type): import torch import torch.fx + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): def __init__(self, left, right): self.left, self.right = left, right @@ -72,10 +75,12 @@ def mul(self, other): r = self.right * other.right return TensorPair(l, r) - def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): + + def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): s = x.add(TensorPair(y, y)) return s.mul(x) + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) y = torch.randn(5, 3) ref_out = use_tensor_pair_ctor(x, y) @@ -214,6 +219,7 @@ class PHWithMeta(PHBase): """ Object representing an input placeholder to `concrete_args` """ + def __init__(self, ph_key: Optional[str] = None): super().__init__() @@ -308,6 +314,7 @@ def __init__( self.scope = Scope("", None) # Records the module call stack self.module_stack = collections.OrderedDict() + self.num_calls: Dict[str, int] = {} # Mapping of node name to module scope self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} @@ -403,7 +410,11 @@ def create_arg(self, a: Any) -> "Argument": # Tensor was not found in the Module hierarchy, stow it away in a # special attribute and set the qualname to refer to that if not qualname: - base_name = "_tensor_constant" if isinstance(a, torch.Tensor) else "_torchbind_obj" + base_name = ( + "_tensor_constant" + if isinstance(a, torch.Tensor) + else "_torchbind_obj" + ) qualname = self.get_fresh_qualname(base_name) assert isinstance(qualname, str) self.tensor_attrs[a] = qualname @@ -445,9 +456,9 @@ def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool appear with the qualified name ``foo.bar.baz`` here. """ return ( - (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) - and not isinstance(m, torch.nn.Sequential) - ) + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + ) and not isinstance(m, torch.nn.Sequential) @compatibility(is_backward_compatible=True) def path_of_module(self, mod: torch.nn.Module) -> str: @@ -511,16 +522,27 @@ def call_module( value was returned from the ``Module`` invocation. """ module_qualified_name = self.path_of_module(m) - with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: + with ScopeContextManager( + self.scope, Scope(module_qualified_name, type(m)) + ) as _scope: # module_stack is an ordered dict so writing then deleting the # entry is equivalent to push/pop on a list - self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) + num_calls = self.num_calls.get(module_qualified_name, 0) + module_key = ( + f"{_scope.module_path}@{num_calls}" + if num_calls > 0 + else _scope.module_path + ) + self.module_stack[module_key] = (module_qualified_name, _scope.module_type) + self.num_calls[module_qualified_name] = num_calls + 1 if not self.is_leaf_module(m, module_qualified_name): ret_val = forward(*args, **kwargs) else: - ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) + ret_val = self.create_proxy( + "call_module", module_qualified_name, args, kwargs + ) key, _ = self.module_stack.popitem(last=True) - assert key == _scope.module_path, f" Unexpected key {key}" + assert key == module_key, f" Unexpected key {key}" return ret_val @@ -547,6 +569,7 @@ def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any The return value from the getattr call. """ + def maybe_get_proxy_for_attr( attr_val, collection_to_search, parameter_proxy_cache ): @@ -616,15 +639,16 @@ def create_args_for_root(self, root_fn, is_module, concrete_args=None): sig = inspect.signature(fn_for_analysis) - # This covers the very specific case where we are passing in flat # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). # In this case, just take the concrete_args and pass them through. name_idx = 0 - if isinstance(concrete_args, tuple) and \ - len(concrete_args) > 0 and \ - (co.co_flags & HAS_VARSTUFF) and \ - total_args == 1: + if ( + isinstance(concrete_args, tuple) + and len(concrete_args) > 0 + and (co.co_flags & HAS_VARSTUFF) + and total_args == 1 + ): for concrete_arg in concrete_args: out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) if isinstance(concrete_arg, PHBase): @@ -718,12 +742,12 @@ def trace( _is_fx_tracing_flag = True try: if isinstance(root, torch.nn.Module): - # do real recompilation for _LazyGraphModule before retracing since the trace # method can not trace the _lazy_forward method. Got error: # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 # without this. from torch.fx._lazy_graph_module import _LazyGraphModule + _LazyGraphModule.force_recompile(root) self.root = root @@ -741,12 +765,12 @@ def trace( tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) self.graph = Graph(tracer_cls=tracer_cls) - if hasattr(fn, '__code__'): + if hasattr(fn, "__code__"): code = fn.__code__ self.graph._co_fields = { - 'co_name': code.co_name, - 'co_filename': code.co_filename, - 'co_firstlineno': code.co_firstlineno, + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, } # When we encounter a Tensor value that's not a parameter, we look if it @@ -754,11 +778,7 @@ def trace( # values to the qualified name here for efficiency. This is used downstream # in create_arg self.tensor_attrs: Dict[ - Union[ - torch.Tensor, - ScriptObject, - FakeScriptObject - ], str + Union[torch.Tensor, ScriptObject, FakeScriptObject], str ] = {} def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): @@ -835,7 +855,7 @@ def __deepcopy__(self, memo): new_tracer = Tracer.__new__(Tracer) for k, v in self.__dict__.items(): - if k in {'_autowrap_search'}: + if k in {"_autowrap_search"}: new_obj = copy.copy(v) else: new_obj = copy.deepcopy(v, memo) @@ -853,9 +873,7 @@ def replace_ph(x): cnt += 1 param = sig.parameters[name] default = ( - () - if param.default is inspect.Parameter.empty - else (param.default,) + () if param.default is inspect.Parameter.empty else (param.default,) ) out = self.create_proxy( "placeholder", f"{name}_{str(cnt)}", default, {} @@ -873,11 +891,7 @@ def replace_ph(x): return out # Union[int, bool] == bool in Python <= 3.6 - if ( - type(x) == bool - or type(x) in base_types - and type(x) != torch.Tensor - ): + if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor: torch._assert( out == x, f"{name} has been specialized to have value {x} but got another value", @@ -902,13 +916,15 @@ def replace_ph(x): default = () else: param = sig.parameters[name] - default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] + default = ( # type: ignore[assignment] + () if param.default is inspect.Parameter.empty else (param.default,) + ) return self.create_proxy( "placeholder", name, default, {}, - type_expr=fn_for_analysis.__annotations__.get(name, None) + type_expr=fn_for_analysis.__annotations__.get(name, None), ) @@ -1007,6 +1023,7 @@ def revert(self): def patch(self): self.frame_dict[self.fn_name] = self.new_fn + class _PatchedFnDel(_PatchedFn): def revert(self): del self.frame_dict[self.fn_name] @@ -1022,6 +1039,7 @@ def revert(self): def patch(self): setattr(self.frame_dict, self.fn_name, self.new_fn) + class _Patcher: def __init__(self) -> None: super().__init__() @@ -1102,6 +1120,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): CURRENT_PATCHER: Optional[_Patcher] = None + @contextlib.contextmanager def _new_patcher(): global CURRENT_PATCHER @@ -1128,7 +1147,10 @@ def _maybe_revert_all_patches(): finally: if current_patcher is not None: patches_made = current_patcher.reapply_all_patches() - assert patches_made == patches_removed, "CURRENT_PATCHER was changed during a revert_all_patches" + assert ( + patches_made == patches_removed + ), "CURRENT_PATCHER was changed during a revert_all_patches" + def _patch_wrapped_functions(patcher: _Patcher): """ @@ -1174,7 +1196,9 @@ def wrap(fn_or_name: Union[str, Callable]): def my_custom_function(x, y): return x * x + y * y - torch.fx.wrap('my_custom_function') + + torch.fx.wrap("my_custom_function") + def fn_to_be_traced(x, y): # When symbolic tracing, the below call to my_custom_function will be inserted into @@ -1244,14 +1268,14 @@ def f(a, b): if b == True: return a else: - return a*2 + return a * 2 FX can typically not trace through this due to the presence of control flow. However, we can use `concrete_args` to specialize on the value of `b` to trace through this:: - f = fx.symbolic_trace(f, concrete_args={'b': False}) - assert f(3, False) == 6 + f = fx.symbolic_trace(f, concrete_args={"b": False}) + assert f(3, False) == 6 Note that although you can still pass in different values of `b`, they will be ignored. @@ -1265,8 +1289,10 @@ def f(x): for v in x.values(): out += v return out - f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) - assert f({'a': 1, 'b': 2, 'c': 4}) == 7 + + + f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) + assert f({"a": 1, "b": 2, "c": 4}) == 7 Args: diff --git a/torch/fx/_utils.py b/torch/fx/_utils.py index 3dd3780fe0bb3..1f2cb0afdcd88 100644 --- a/torch/fx/_utils.py +++ b/torch/fx/_utils.py @@ -55,7 +55,7 @@ def get_node_context(node, num_nodes=2) -> str: """ node_contexts = [] cur = node - for i in range(num_nodes): + for _ in range(num_nodes): node_contexts.append(cur.format_node()) if cur.op == "root": break diff --git a/torch/fx/annotate.py b/torch/fx/annotate.py index d1b5b5f2d3761..b3c5056066251 100644 --- a/torch/fx/annotate.py +++ b/torch/fx/annotate.py @@ -1,7 +1,9 @@ # mypy: allow-untyped-defs from torch.fx.proxy import Proxy + from ._compatibility import compatibility + @compatibility(is_backward_compatible=False) def annotate(val, type): """ @@ -18,13 +20,15 @@ def annotate(val, type): """ if isinstance(val, Proxy): if val.node.type: - raise RuntimeError(f"Tried to annotate a value that already had a type on it!" - f" Existing type is {val.node.type} " - f"and new type is {type}. " - f"This could happen if you tried to annotate a function parameter " - f"value (in which case you should use the type slot " - f"on the function signature) or you called " - f"annotate on the same value twice") + raise RuntimeError( + f"Tried to annotate a value that already had a type on it!" + f" Existing type is {val.node.type} " + f"and new type is {type}. " + f"This could happen if you tried to annotate a function parameter " + f"value (in which case you should use the type slot " + f"on the function signature) or you called " + f"annotate on the same value twice" + ) else: val.node.type = type return val diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 9b347762dedba..4f9fe0f9a1407 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -1,22 +1,22 @@ # mypy: allow-untyped-defs import operator from collections import deque -from typing import Dict, List, Set, NamedTuple, Tuple, Deque +from typing import Deque, Dict, List, NamedTuple, Set, Tuple import torch -from torch.fx.passes.graph_manipulation import get_size_of_all_nodes from torch.fx.experimental.partitioner_utils import ( - Partition, Device, - PartitionerConfig, - get_partition_to_latency_mapping, + get_extra_size_of, get_latency_of_partitioned_graph, + get_partition_to_latency_mapping, NodeLatency, - get_extra_size_of, + Partition, + PartitionerConfig, PartitionMode, ) from torch.fx.graph_module import GraphModule -from torch.fx.node import Node, map_arg +from torch.fx.node import map_arg, Node +from torch.fx.passes.graph_manipulation import get_size_of_all_nodes from torch.fx.passes.split_module import split_module @@ -260,7 +260,9 @@ def find_device_for(partition: Partition): # Find devices for all the partitions without a device found_device = True for partition in no_device_partitions: - device_to_left_mem_bytes = dict(sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))) + device_to_left_mem_bytes = dict( + sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)) + ) found_device = find_device_for(partition) if not found_device: break @@ -463,8 +465,6 @@ def find_device_based_on_size(node) -> Device: # Check if no device is left if len(self.partitions) == len(self.devices): # No device is left - # Put the previous partitions into a list (non_single_node_partitions) - non_single_node_partitions = self.partitions[:] # Create the first single node partition for the current node self.create_single_node_partition(node) continue diff --git a/torch/fx/experimental/const_fold.py b/torch/fx/experimental/const_fold.py index 9b12a027f0563..d1ca4acde2b80 100644 --- a/torch/fx/experimental/const_fold.py +++ b/torch/fx/experimental/const_fold.py @@ -7,7 +7,12 @@ from torch.fx.passes.split_module import split_module -__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs'] +__all__ = [ + "FoldedGraphModule", + "get_unique_attr_name_in_module", + "split_const_subgraphs", +] + class FoldedGraphModule(torch.fx.GraphModule): """ diff --git a/torch/fx/experimental/debug.py b/torch/fx/experimental/debug.py index d3c482319f2ef..e59dcbb3296f9 100644 --- a/torch/fx/experimental/debug.py +++ b/torch/fx/experimental/debug.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs import torch.fx as fx + def set_trace(gm: fx.GraphModule) -> fx.GraphModule: """ Sets a breakpoint in `gm`'s generated python code. It drops into pdb when @@ -13,18 +14,14 @@ def set_trace(gm: fx.GraphModule) -> fx.GraphModule: Returns: the `gm` with breakpoint inserted. """ + def insert_pdb(body): return ["import pdb; pdb.set_trace()\n", *body] with gm.graph.on_generate_code( make_transformer=lambda cur_transform: ( # new code transformer to register - lambda body: ( - insert_pdb( - cur_transform(body) if cur_transform - else body - ) - ) + lambda body: (insert_pdb(cur_transform(body) if cur_transform else body)) ) ): gm.recompile() diff --git a/torch/fx/experimental/graph_gradual_typechecker.py b/torch/fx/experimental/graph_gradual_typechecker.py index fb49795a06fac..0be22bc0d795a 100644 --- a/torch/fx/experimental/graph_gradual_typechecker.py +++ b/torch/fx/experimental/graph_gradual_typechecker.py @@ -1,19 +1,20 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -from functools import reduce -import torch +import itertools import operator -from torch.fx.tensor_type import Dyn, is_consistent, TensorType, is_more_precise +from functools import reduce from typing import Callable, Dict -from torch.fx.node import Target, Node -from torch.nn.modules.batchnorm import BatchNorm2d -from torch.nn.modules.conv import Conv2d -from torch.fx.experimental.refinement_types import Equality -import itertools +import sympy + +import torch +from torch.fx.experimental.refinement_types import Equality from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d -import sympy _INFERENCE_RULES: Dict[Target, Callable] = {} _REFINEMENT_RULES: Dict[Target, Callable] = {} @@ -32,10 +33,12 @@ def expand_to_tensor_dim(t, n): return TensorType(tuple(dims)) elif isinstance(t, TensorType): if len(t.__args__) != n: - raise TypeError(f'Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}') + raise TypeError( + f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}" + ) return t else: - raise TypeError(f'Cannot match the type {t}') + raise TypeError(f"Cannot match the type {t}") def broadcast_types(t1, t2): @@ -80,32 +83,39 @@ def broadcast_types(t1, t2): (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) return (t1, t2) else: - raise TypeError(f'Cannot broadcast types {t1} and {t2}') + raise TypeError(f"Cannot broadcast types {t1} and {t2}") + def register_inference_rule(call_target): def register(fn): if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') + raise RuntimeError(f"Inference rule already registered for {call_target}!") _INFERENCE_RULES[call_target] = fn return fn + return register + def register_refinement_rule(call_target): def register(fn): if call_target in _REFINEMENT_RULES: - raise RuntimeError(f'Refinement rule already registered for {call_target}!') + raise RuntimeError(f"Refinement rule already registered for {call_target}!") _REFINEMENT_RULES[call_target] = fn return fn + return register + def register_algebraic_expressions_inference_rule(call_target): def register(fn): if call_target in _RULES: - raise RuntimeError(f'Rule already registered for {call_target}!') + raise RuntimeError(f"Rule already registered for {call_target}!") _RULES[call_target] = fn return fn + return register + @register_inference_rule(torch.add) @register_inference_rule(operator.add) def add_inference_rule(n: Node): @@ -142,15 +152,15 @@ def add_inference_rule(n: Node): (new_t1, new_t2) = broadcast_types(t1, t2) if new_t1 != t1 or new_t2 != t2: - n.meta['broadcast'] = True + n.meta["broadcast"] = True n.meta[str(n.args[0])] = new_t1 n.meta[str(n.args[1])] = new_t2 else: - n.meta['broadcast'] = False + n.meta["broadcast"] = False - new_t1 = t1 if not n.meta['broadcast'] else new_t1 - new_t2 = t2 if not n.meta['broadcast'] else new_t2 + new_t1 = t1 if not n.meta["broadcast"] else new_t1 + new_t2 = t2 if not n.meta["broadcast"] else new_t2 # we check for consistency between the new types if is_consistent(new_t1, new_t2): @@ -164,8 +174,11 @@ def add_inference_rule(n: Node): n.type = new_t1 return n.type else: - raise TypeError(f'Cannot add arguments {n.args[0]} ({ n.args[0].type}) and {n.args[1]} ({ n.args[1].type}) in node {n}.' - f' Types should match ') + raise TypeError( + f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}." + f" Types should match " + ) + @register_inference_rule(getattr) def get_attr_inference_rule(n: Node, traced): @@ -175,7 +188,6 @@ def get_attr_inference_rule(n: Node, traced): The most representitive type we have is "Dyn" but the system can be extended with more types, such as a type to represent shapes """ - attr_node = n.args[0] attr_name = n.args[1] if attr_name == "shape": @@ -186,6 +198,7 @@ def get_attr_inference_rule(n: Node, traced): # TODO. We leave it like this till we add a type to represent tensor sizes return n.type + @register_inference_rule(torch.transpose) def transpose_inference_rule(n: Node): """ @@ -212,9 +225,13 @@ def transpose_inference_rule(n: Node): n.type = get_greatest_upper_bound(n.type, final) return n.type else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) else: - raise TypeError(f'Cannot transpose {dim1} and {dim2} in type {t} for node {n}') + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) @register_inference_rule(torch.reshape) @@ -252,9 +269,10 @@ def reshape_inference_rule(n: Node): n.type = t2_type return t2_type else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") else: - raise TypeError(f'Cannot reshape in node {n} from {t1} to {t2_type}') + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + @register_inference_rule(BatchNorm2d) def bn2d_inference_rule(n: Node, module_instance): @@ -275,10 +293,11 @@ def bn2d_inference_rule(n: Node, module_instance): # we check the conditions on the incoming argument # and any existing annotation # we also check for consistency between both annotations - if is_consistent(arg_type.__args__[1], module_instance.num_features) and \ - is_consistent(n.type.__args__[1], module_instance.num_features) and \ - is_consistent(arg_type, n.type): - + if ( + is_consistent(arg_type.__args__[1], module_instance.num_features) + and is_consistent(n.type.__args__[1], module_instance.num_features) + and is_consistent(arg_type, n.type) + ): # we choose the more precise type # to be the node type # so if an incoming argument has more type information @@ -286,21 +305,35 @@ def bn2d_inference_rule(n: Node, module_instance): n.type = get_greatest_upper_bound(arg_type, n.type) return n.type else: - raise TypeError(f'Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}') + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) def calculate_out_dimension(d_in, module_instance, index): """ For calculating h_in and w_out according to the conv2D documentation """ - padding = (module_instance.padding, module_instance.padding) \ - if isinstance(module_instance.padding, int) else module_instance.padding - kernel_size = (module_instance.kernel_size, module_instance.kernel_size) \ - if isinstance(module_instance.kernel_size, int) else module_instance.kernel_size - stride = (module_instance.stride, module_instance.stride) \ - if isinstance(module_instance.stride, int) else module_instance.stride - dilation = (module_instance.dilation, module_instance.dilation) \ - if isinstance(module_instance.dilation, int) else module_instance.dilation + padding = ( + (module_instance.padding, module_instance.padding) + if isinstance(module_instance.padding, int) + else module_instance.padding + ) + kernel_size = ( + (module_instance.kernel_size, module_instance.kernel_size) + if isinstance(module_instance.kernel_size, int) + else module_instance.kernel_size + ) + stride = ( + (module_instance.stride, module_instance.stride) + if isinstance(module_instance.stride, int) + else module_instance.stride + ) + dilation = ( + (module_instance.dilation, module_instance.dilation) + if isinstance(module_instance.dilation, int) + else module_instance.dilation + ) DIMENSION_TYPES = (int, sympy.Symbol) @@ -308,14 +341,14 @@ def calculate_out_dimension(d_in, module_instance, index): return Dyn elif isinstance(d_in, DIMENSION_TYPES): - n = d_in + 2 * padding[index] - \ - dilation[index] * \ - (kernel_size[index] - 1) - 1 + n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1 return (n // stride[0]) + 1 else: - raise TypeError(f'{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}') + raise TypeError( + f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}" + ) def get_greatest_upper_bound(type1, type2): @@ -328,8 +361,11 @@ def get_greatest_upper_bound(type1, type2): return type1 elif isinstance(type1, TensorType) and isinstance(type2, TensorType): if not is_consistent(type1, type2): - raise TypeError(f'Inconsistent types {type1}, {type2}') - gub = [t1 if is_more_precise(t1, t2) else t2 for (t1, t2) in zip(type1.__args__, type2.__args__)] + raise TypeError(f"Inconsistent types {type1}, {type2}") + gub = [ + t1 if is_more_precise(t1, t2) else t2 + for (t1, t2) in zip(type1.__args__, type2.__args__) + ] return TensorType(tuple(gub)) @@ -353,12 +389,16 @@ def conv2d_inference_rule(n: Node, module_instance): h_in = arg_type.__args__[2] h_out = calculate_out_dimension(h_in, module_instance, 0) w_out = calculate_out_dimension(w_in, module_instance, 1) - new_type = TensorType((arg_type.__args__[0], module_instance.out_channels, h_out, w_out)) + new_type = TensorType( + (arg_type.__args__[0], module_instance.out_channels, h_out, w_out) + ) gub = get_greatest_upper_bound(new_type, curr_node_type) n.type = gub return n.type else: - raise TypeError(f'Cannot apply {module_instance} with input type { arg_type} and existing type {n.type} on {n}') + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) @register_inference_rule(torch.nn.ReLU) @@ -394,7 +434,7 @@ def maxpool2d_check(typ, module_instance): return TensorType(tuple(new_type_list)) else: - raise TypeError(f'Wrong size {typ} for {module_instance}') + raise TypeError(f"Wrong size {typ} for {module_instance}") @register_inference_rule(torch.nn.MaxPool2d) @@ -418,7 +458,6 @@ def maxpool2d_inference_rule(n: Node, module_instance): return n.type - def linear_check(tensor_type, module_instance): """ Checks that an input tensor type satisfies the conditions for linear operation @@ -430,9 +469,11 @@ def linear_check(tensor_type, module_instance): new_type_args[-1] = module_instance.out_features return TensorType(tuple(new_type_args)) else: - raise TypeError(f'Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}') + raise TypeError( + f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}" + ) else: - raise TypeError(f'Type {tensor_type} must have rank 2 or more.') + raise TypeError(f"Type {tensor_type} must have rank 2 or more.") @register_inference_rule(torch.nn.Linear) @@ -470,7 +511,8 @@ def adaptiveavgpool2d_check(tensor_type, module_instance): return TensorType(tuple(new_type_list)) else: - raise TypeError(f'Tensor ranks must be 3 or 4. Got {tensor_type}') + raise TypeError(f"Tensor ranks must be 3 or 4. Got {tensor_type}") + @register_inference_rule(torch.nn.AdaptiveAvgPool2d) def adaptiveavgpool2d_inference_rule(n: Node, module_instance): @@ -486,6 +528,7 @@ def adaptiveavgpool2d_inference_rule(n: Node, module_instance): n.type = get_greatest_upper_bound(n.type, output_type) return n.type + def flatten_check(tensor_type, start_dim, end_dim): l = len(tensor_type.__args__) @@ -504,7 +547,10 @@ def flatten_check(tensor_type, start_dim, end_dim): new_type_list = lhs + mid + rhs return TensorType(tuple(new_type_list)) else: - raise TypeError(f'Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}') + raise TypeError( + f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}" + ) + @register_inference_rule(torch.flatten) def flatten_inference_rule(n: Node): @@ -531,10 +577,11 @@ def flatten_inference_rule(n: Node): if isinstance(n.args[0].type, TensorType): output_type = flatten_check(n.args[0].type, start_dim, end_dim) - n.type = get_greatest_upper_bound(output_type , n.type) + n.type = get_greatest_upper_bound(output_type, n.type) return n.type + class GraphTypeChecker: def __init__(self, env, traced): self.env = env @@ -572,16 +619,16 @@ def type_check_node(self, n: Node): if n.type is None: n.type = Dyn - if n.op == 'placeholder': + if n.op == "placeholder": return n.type - elif n.op == 'get_attr': + elif n.op == "get_attr": t = get_parameter(self.traced, n.target) # type: ignore[arg-type] if isinstance(t.data, torch.Tensor): n.type = TensorType(t.data.shape) return n.type - elif n.op == 'call_function': + elif n.op == "call_function": if n.target == getattr: assert getattr in _INFERENCE_RULES return _INFERENCE_RULES[n.target](n, self.traced) @@ -589,18 +636,24 @@ def type_check_node(self, n: Node): elif n.target in _INFERENCE_RULES: return _INFERENCE_RULES[n.target](n) else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) - elif n.op == 'call_module': + elif n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _INFERENCE_RULES: return _INFERENCE_RULES[type(module_instance)](n, module_instance) else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "output": - elif n.op == 'output': def get_node_type(a): return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type @@ -635,6 +688,7 @@ def linear_refinement_rule(n: Node): res = [Equality(arg_type.__args__[0], n.type.__args__[0])] return res + @register_refinement_rule(BatchNorm2d) @register_refinement_rule(torch.nn.ReLU) def all_eq(n: Node): @@ -689,7 +743,11 @@ def element_wise_eq(n: Node): if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): arg_type1 = n.args[0].type arg_type2 = n.args[1].type - if isinstance(arg_type1, TensorType) and isinstance(arg_type2, TensorType) and isinstance(n.type, TensorType): + if ( + isinstance(arg_type1, TensorType) + and isinstance(arg_type2, TensorType) + and isinstance(n.type, TensorType) + ): args1, args2 = broadcast_types(arg_type1, arg_type2) # by this point, we know that args1 and args2 are the same size. a1 = args1.__args__ @@ -758,12 +816,14 @@ def conv_rule(n: Node, module_instance): n.type = new_type return new_type + class Refine: """ Symbolic shape inference. Generates constraints over type variables. Currently all constraints are equality constraints. """ + def __init__(self, traced): self.constraints = [] self.traced = traced @@ -806,7 +866,6 @@ def replace_dyn_with_fresh_var(self, typ): else: return typ - def convert_to_sympy_symbols(self, typ): """ Replace all unknown types with fresh type variables. @@ -836,22 +895,24 @@ def refine_node(self, n: Node): n.type = self.replace_dyn_with_fresh_var(n.type) - if n.op == 'call_function': + if n.op == "call_function": if n.target in _REFINEMENT_RULES: self.constraints += _REFINEMENT_RULES[n.target](n) else: pass - if n.op == 'call_module': + if n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _REFINEMENT_RULES: self.constraints += _REFINEMENT_RULES[type(module_instance)](n) else: pass - if n.op == 'output': + if n.op == "output": + def get_node_type(a): return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type @@ -860,28 +921,31 @@ def get_node_type(a): def infer_symbolic_relations(self, n: Node): n.type = self.convert_to_sympy_symbols(n.type) - if n.op == 'call_function': + if n.op == "call_function": if n.target in _RULES: return _RULES[n.target](n) else: pass - if n.op == 'call_module': + if n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _RULES: return _RULES[type(module_instance)](n, module_instance) else: pass - if n.op == 'output': + if n.op == "output": + def get_node_type(a): return a.type + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) return n.type else: pass + def get_parameter(traced, target: str): """ Returns the parameter given by ``target`` if it exists, diff --git a/torch/fx/experimental/merge_matmul.py b/torch/fx/experimental/merge_matmul.py index c1a634b2602a0..b3e1efcbd19e4 100644 --- a/torch/fx/experimental/merge_matmul.py +++ b/torch/fx/experimental/merge_matmul.py @@ -1,14 +1,13 @@ # mypy: allow-untyped-defs -import torch - -from torch.fx.node import Node -from torch.fx._symbolic_trace import symbolic_trace -from torch.fx.passes.tools_common import legalize_graph import itertools import operator - from typing import Dict, List, Tuple +import torch +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.node import Node +from torch.fx.passes.tools_common import legalize_graph + def split_result_tensors( result: torch.Tensor, inputs: List[torch.Tensor] @@ -146,7 +145,14 @@ def merge_matmul(in_mod: torch.nn.Module): # Multiply the concatenated LHS operands with the one RHS. This will produce # the same results as all the individual matmuls involving rhs in the original graph, # but they will all be concatenated together. - merge_mm = gm.graph.call_function(torch.matmul, (merge_mm_cat, rhs,), {}) + merge_mm = gm.graph.call_function( + torch.matmul, + ( + merge_mm_cat, + rhs, + ), + {}, + ) # Split the result of the merged matmul using the shapes of the LHS operands # to ascertain how large each chunk should be. diff --git a/torch/fx/experimental/meta_tracer.py b/torch/fx/experimental/meta_tracer.py index dd2c9b11ab76b..1b74f33f40b54 100644 --- a/torch/fx/experimental/meta_tracer.py +++ b/torch/fx/experimental/meta_tracer.py @@ -1,14 +1,15 @@ # mypy: allow-untyped-defs +import builtins +import functools +import warnings +from typing import Any, Callable, Dict, Optional, Union + import torch import torch.fx -import warnings -import functools -import builtins -from typing import Any, Callable, Dict, Optional, Union def embedding_override(self, input): - return torch.empty(*input.shape, self.weight.shape[-1], device='meta') + return torch.empty(*input.shape, self.weight.shape[-1], device="meta") def nn_layernorm_override(self, input): @@ -24,21 +25,22 @@ def torch_nn_relu_override(self, x): def functional_relu_override(x, inplace=False): - assert not inplace, 'dont support inplace functional.relu for metatensor analysis' + assert not inplace, "dont support inplace functional.relu for metatensor analysis" return x def torch_where_override(condition, x, y): # torch.where returns the broadcasted tensor of condition, x, and y, # so hack it by using addition - return condition.to(device='meta') + x.to(device='meta') + y.to(device='meta') + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") def torch_abs_override(input, *, out=None): - assert out is None, 'Dont support in-place abs for MetaTensor analysis' + assert out is None, "Dont support in-place abs for MetaTensor analysis" return input -manual_meta_overrides : Dict[Callable, Callable] = { + +manual_meta_overrides: Dict[Callable, Callable] = { torch.nn.Embedding: embedding_override, torch.nn.LayerNorm: nn_layernorm_override, torch.relu: torch_relu_override, @@ -48,6 +50,7 @@ def torch_abs_override(input, *, out=None): torch.abs: torch_abs_override, } + def gen_constructor_wrapper(target): @functools.wraps(target) def wrapper(*args, **kwargs): @@ -57,57 +60,66 @@ def check_has_proxy(v): if isinstance(v, torch.fx.Proxy): nonlocal proxy proxy = v + torch.fx.node.map_aggregate(args, check_has_proxy) torch.fx.node.map_aggregate(kwargs, check_has_proxy) if proxy is not None: - return proxy.tracer.create_proxy('call_function', target, args, kwargs) + return proxy.tracer.create_proxy("call_function", target, args, kwargs) else: return target(*args, **kwargs) + return wrapper, target + class MetaProxy(torch.fx.Proxy): def install_tensor_meta(self, tensor_meta): self._tensor_meta = tensor_meta def size(self, dim=None): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.size(*[dim] if dim else []) - return self.tracer.create_proxy('call_method', 'size', (self, dim) if dim else (self,), {}) + return self.tracer.create_proxy( + "call_method", "size", (self, dim) if dim else (self,), {} + ) def dim(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.dim() - return self.tracer.create_proxy('call_method', 'dim', (self,), {}) + return self.tracer.create_proxy("call_method", "dim", (self,), {}) @property def shape(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.shape - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'shape'), {}) + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "shape"), {} + ) @property def dtype(self): - if hasattr(self, '_tensor_meta') and self._tensor_meta is not None: + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: return self._tensor_meta.dtype - return self.tracer.create_proxy('call_function', builtins.getattr, (self, 'dtype'), {}) + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "dtype"), {} + ) @property def device(self): # Hack so we can track when devices are used. During meta-tensor propagation, # replace these values with a constant 'meta' - return MetaDeviceAttribute(self, 'device') + return MetaDeviceAttribute(self, "device") def __getattr__(self, k): - if k == '_tensor_meta': + if k == "_tensor_meta": return self.__getattribute__(k) # note: not added to the graph yet, if this is a method call # we peephole optimize to the method invocation return MetaAttribute(self, k) + class MetaAttribute(MetaProxy): def __init__(self, root, attr: str): - self.root = root self.attr = attr self.tracer = root.tracer @@ -118,33 +130,51 @@ def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) + class MetaDeviceAttribute(MetaAttribute): pass + def proxys_to_metas(v): if isinstance(v, MetaDeviceAttribute): - return 'meta' + return "meta" if isinstance(v, torch.fx.Proxy): - assert isinstance(v, MetaProxy), f'Expected MetaProxy but got {type(v)}' - assert hasattr(v, '_tensor_meta'), 'MetaProxy does not have an associated meta' + assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}" + assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta" return v._tensor_meta return v -class MetaTracer(torch.fx.Tracer): - allow_insert_stateless_mods : bool = True - - _TORCH_METHODS_TO_PATCH = ['arange', 'zeros', 'ones', 'full_like', 'eye'] - def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, proxy_factory_fn=None): - rv = super().create_proxy(kind, target, args, kwargs, name, type_expr, proxy_factory_fn) - - if kind == 'placeholder' and target in self.meta_args: +class MetaTracer(torch.fx.Tracer): + allow_insert_stateless_mods: bool = True + + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] + + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + rv = super().create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + + if kind == "placeholder" and target in self.meta_args: rv.install_tensor_meta(self.meta_args[target]) return rv @@ -154,54 +184,57 @@ def create_proxy(self, kind, target, args, kwargs, name=None, type_expr=None, pr # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, # this will break and you will likely see issues where we cannot infer # the size of the output. - if 'device' in kwargs: - kwargs['device'] = 'meta' + if "device" in kwargs: + kwargs["device"] = "meta" try: args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas) kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) - if kind == 'call_function': + if kind == "call_function": meta_target = manual_meta_overrides.get(target, target) meta_out = meta_target(*args_metas, **kwargs_metas) - elif kind == 'call_method': - meta_out = getattr(args_metas[0], target)(*args_metas[1:], **kwargs_metas) # type: ignore[index] - elif kind == 'call_module': - assert hasattr(self, 'orig_forward') + elif kind == "call_method": + meta_target = getattr(args_metas[0], target) # type: ignore[index] + meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index] + elif kind == "call_module": + assert hasattr(self, "orig_forward") self._disable_module_getattr = True try: mod = self.root.get_submodule(target) mod_type = type(mod) if mod_type in manual_meta_overrides: - meta_out = manual_meta_overrides[mod_type](mod, *args_metas, **kwargs_metas) # type: ignore[misc, arg-type] + meta_out = manual_meta_overrides[mod_type]( + mod, *args_metas, **kwargs_metas + ) # type: ignore[misc, arg-type] else: meta_out = self.orig_forward(*args_metas, **kwargs_metas) finally: self._disable_module_getattr = False - elif kind == 'get_attr': + elif kind == "get_attr": self._disable_module_getattr = True try: attr_itr = self.root - atoms = target.split('.') + atoms = target.split(".") for atom in atoms: attr_itr = getattr(attr_itr, atom) assert isinstance(attr_itr, torch.Tensor) - meta_out = attr_itr.to(device='meta') + meta_out = attr_itr.to(device="meta") finally: self._disable_module_getattr = False else: return rv # TODO - assert isinstance(rv, torch.fx.Proxy), 'Dont support composite output yet' + assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet" rv.install_tensor_meta(meta_out) except Exception as e: - warnings.warn(f'Could not compute metadata for {kind} target {target}: {e}') + warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") return rv def getattr(self, attr, attr_val, parameter_proxy_cache): - if getattr(self, '_disable_module_getattr', False): + if getattr(self, "_disable_module_getattr", False): return attr_val else: return super().getattr(attr, attr_val, parameter_proxy_cache) @@ -227,8 +260,12 @@ def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str: def path_of_module(self, mod: torch.nn.Module) -> str: try: return super().path_of_module(mod) - except NameError as e: - if self.allow_insert_stateless_mods and len(list(mod.parameters())) == 0 and len(list(mod.buffers())) == 0: + except NameError: + if ( + self.allow_insert_stateless_mods + and len(list(mod.parameters())) == 0 + and len(list(mod.buffers())) == 0 + ): path = self._insert_module_as_submodule(mod) self.prev_module = path return path @@ -237,12 +274,13 @@ def path_of_module(self, mod: torch.nn.Module) -> str: def proxy(self, node): return MetaProxy(node, self) - def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] + def trace(self, root, meta_args: Dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] assert isinstance(meta_args, dict) self.meta_args = meta_args self.patched_torch_methods = { - target: gen_constructor_wrapper(getattr(torch, target)) for target in self._TORCH_METHODS_TO_PATCH + target: gen_constructor_wrapper(getattr(torch, target)) + for target in self._TORCH_METHODS_TO_PATCH } self.orig_fns = set() @@ -252,18 +290,22 @@ def trace(self, root, meta_args : Dict[str, torch.Tensor], concrete_args=None): try: graph = super().trace(root, concrete_args) - graph._tracer_extras = {'meta_args': meta_args} + graph._tracer_extras = {"meta_args": meta_args} return graph finally: for name, (_, orig) in self.patched_torch_methods.items(): setattr(torch, name, orig) -def symbolic_trace(root : Union[torch.nn.Module, Callable[..., Any]], - meta_args : Optional[Dict[str, torch.Tensor]] = None, - concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.GraphModule: +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + meta_args: Optional[Dict[str, torch.Tensor]] = None, + concrete_args: Optional[Dict[str, Any]] = None, +) -> torch.fx.GraphModule: tracer = MetaTracer() graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type] - name = root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) gm = torch.fx.GraphModule(tracer.root, graph, name) return gm diff --git a/torch/fx/experimental/migrate_gradual_types/constraint.py b/torch/fx/experimental/migrate_gradual_types/constraint.py index 4693a62de2402..8aca3e482c95f 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -1,7 +1,16 @@ # mypy: allow-untyped-defs -from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_sub, op_mul, op_div, \ - op_mod, op_gt, op_lt, op_neq, op_eq -from torch.fx.tensor_type import TensorType, Dyn +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType class Constraint: @@ -22,7 +31,7 @@ def __eq__(self, other): return False def __repr__(self): - return f'And({self.conjucts})' + return f"And({self.conjucts})" class Disj(Constraint): @@ -34,12 +43,14 @@ def __init__(self, disjuncts): def __eq__(self, other): if isinstance(other, Disj): - return self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + return ( + self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + ) else: return False def __repr__(self): - return f'Or({self.disjuncts})' + return f"Or({self.disjuncts})" class Prod(Constraint): @@ -56,13 +67,14 @@ def __eq__(self, other): return False def __repr__(self): - return f'Product({self.products})' + return f"Product({self.products})" class T(Constraint): """ True """ + def __init__(self) -> None: pass @@ -70,12 +82,14 @@ def __eq__(self, other): return isinstance(other, T) def __repr__(self): - return 'True' + return "True" + class F(Constraint): """ False """ + def __init__(self) -> None: pass @@ -83,13 +97,14 @@ def __eq__(self, other): return isinstance(other, F) def __repr__(self): - return 'False' + return "False" class BinaryConstraint(Constraint): """ Represents all binary operations """ + def __init__(self, lhs, rhs, op): """ :param lhs: lhs of the constraint @@ -102,21 +117,25 @@ def __init__(self, lhs, rhs, op): def __eq__(self, other): if isinstance(other, BinaryConstraint): - return self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + return ( + self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + ) else: return False def __repr__(self): - return f'({self.lhs} {self.op} {self.rhs})' + return f"({self.lhs} {self.op} {self.rhs})" class BinConstraintT(BinaryConstraint): """ Binary constraints about tensors """ + def __init__(self, lhs, rhs, op): - assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and \ - (isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn) + assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and ( + isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn + ) super().__init__(lhs, rhs, op) def __eq__(self, other): @@ -127,6 +146,7 @@ class BinConstraintD(BinaryConstraint): """ Binary constraints about dimensions """ + def __init__(self, lhs, rhs, op): assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) @@ -137,11 +157,11 @@ def __eq__(self, other): return super().__eq__(other) - class TGreatestUpperBound(Constraint): """ Greatest Upper bound for tensors with dynamic type """ + def __init__(self, res, rhs1, rhs2): """ :param res: tensor variable that stores the result of the outout @@ -153,11 +173,15 @@ def __init__(self, res, rhs1, rhs2): self.rhs2 = rhs2 def __repr__(self): - return f'{self.res} = {self.rhs1}\u2294*{self.rhs2}' + return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}" def __eq__(self, other): if isinstance(other, TGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) else: return False @@ -166,6 +190,7 @@ class DGreatestUpperBound(Constraint): """ Greatest Upper bound for dimensions """ + def __init__(self, res, rhs1, rhs2): """ :param res: Dimension variable to store the result @@ -181,11 +206,15 @@ def __init__(self, res, rhs1, rhs2): self.rhs2 = rhs2 def __repr__(self): - return f'{self.res} = {self.rhs1}\u2294{self.rhs2}' + return f"{self.res} = {self.rhs1}\u2294{self.rhs2}" def __eq__(self, other): if isinstance(other, DGreatestUpperBound): - return self.res == other.res and self.rhs1 == other.rhs1 and self.rhs2 == other.rhs2 + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) else: return False @@ -194,6 +223,7 @@ class CanReshape(Constraint): """ can_reshape constraint """ + def __init__(self, src, target): """ :param src: tensor variable @@ -203,7 +233,7 @@ def __init__(self, src, target): self.target = target def __repr__(self): - return f'can-reshape({self.src}, {self.target})' + return f"can-reshape({self.src}, {self.target})" def __eq__(self, other): if isinstance(other, CanReshape): @@ -213,7 +243,6 @@ def __eq__(self, other): class IndexSelect(Constraint): - def __init__(self, tensor_size, input_var, dim_replace, index, output): """ Args: @@ -235,26 +264,28 @@ def __init__(self, tensor_size, input_var, dim_replace, index, output): self.output = output def __repr__(self): - - return f' {self.output} = ' \ - f'IndexSelect({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.dim_replace}, ' \ - f'{self.index})' + return ( + f" {self.output} = " + f"IndexSelect({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.dim_replace}, " + f"{self.index})" + ) def __eq__(self, other): if isinstance(other, IndexSelect): - return self.tensor_size == other.tensor_size and \ - self.dim_replace == other.dim_replace and \ - self.index == other.index and \ - self.output == other.output and \ - self.input_var == other.input_var + return ( + self.tensor_size == other.tensor_size + and self.dim_replace == other.dim_replace + and self.index == other.index + and self.output == other.output + and self.input_var == other.input_var + ) else: return False class Transpose(Constraint): - def __init__(self, tensor_size, input_var, index1, index2, output): """ Args: @@ -276,26 +307,28 @@ def __init__(self, tensor_size, input_var, index1, index2, output): self.output = output def __repr__(self): - - return f' {self.output} = ' \ - f'Transpose({self.input_var}, ' \ - f'tensor_size: {self.tensor_size}, ' \ - f'{self.index1}, ' \ - f'{self.index2})' + return ( + f" {self.output} = " + f"Transpose({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.index1}, " + f"{self.index2})" + ) def __eq__(self, other): if isinstance(other, Transpose): - return self.tensor_size == other.tensor_size and \ - self.index1 == other.index1 and \ - self.index2 == other.index2 and \ - self.output == other.output and \ - self.input_var == other.input_var + return ( + self.tensor_size == other.tensor_size + and self.index1 == other.index1 + and self.index2 == other.index2 + and self.output == other.output + and self.input_var == other.input_var + ) else: return False class GetItem(Constraint): - def __init__(self, tensor_size, index, res, input_var): """ Constraint for getting item given a tensor size @@ -312,19 +345,21 @@ def __init__(self, tensor_size, index, res, input_var): self.input_var = input_var def __repr__(self): - return f' {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})' + return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})" def __eq__(self, other): if isinstance(other, GetItem): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index == other.index and \ - self.input_var == other.input_var + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index == other.index + and self.input_var == other.input_var + ) else: return False -class GetItemTensor(Constraint): +class GetItemTensor(Constraint): def __init__(self, tensor_size, index_tuple, res, input_var): """ Constraint for getting item given a tensor size @@ -343,20 +378,32 @@ def __init__(self, tensor_size, index_tuple, res, input_var): self.input_var = input_var def __repr__(self): - return f' {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})' + return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})" def __eq__(self, other): if isinstance(other, GetItemTensor): - return self.res == other.res and \ - self.tensor_size == other.tensor_size and \ - self.index_tuple == other.index_tuple and \ - self.input_var == other.input_var + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index_tuple == other.index_tuple + and self.input_var == other.input_var + ) else: return False -class CalcConv(Constraint): - def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilation, matching_constraint_vars): +class CalcConv(Constraint): + def __init__( + self, + conv_result, + input_var, + c_out, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): """ :param conv_result: the convolution result :param input_var: input to convolution @@ -373,25 +420,41 @@ def __init__(self, conv_result, input_var, c_out, kernel, padding, stride, dilat self.matching_constraint = matching_constraint_vars def __repr__(self): - return f'{self.conv_result} =' \ - f' calc-conv({self.input_var},' \ - f' {self.c_out}, {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' + return ( + f"{self.conv_result} =" + f" calc-conv({self.input_var}," + f" {self.c_out}, {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) def __eq__(self, other): if isinstance(other, CalcConv): - return self.conv_result == other.conv_result and self.input_var == other.input_var and \ - self.c_out == other.c_out and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ + return ( + self.conv_result == other.conv_result + and self.input_var == other.input_var + and self.c_out == other.c_out + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation and self.matching_constraint == other.matching_constraint + ) else: return False class CalcMaxPool(Constraint): - - def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, matching_constraint_vars): + def __init__( + self, + maxpool_result, + input_var, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): """ :param maxpool_result: the result of maxpool :param input_var: input to convolution @@ -406,18 +469,25 @@ def __init__(self, maxpool_result, input_var, kernel, padding, stride, dilation, self.matching_constraint = matching_constraint_vars def __repr__(self): - return f'{self.maxpool_result} =' \ - f' calc-maxpool({self.input_var},' \ - f' {self.kernel}, ' \ - f'{self.padding}, {self.stride},' \ - f' {self.dilation})' + return ( + f"{self.maxpool_result} =" + f" calc-maxpool({self.input_var}," + f" {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) def __eq__(self, other): if isinstance(other, CalcMaxPool): - return self.maxpool_result == other.maxpool_result and self.input_var == other.input_var \ - and self.kernel == other.kernel and self.padding == other.padding \ - and self.stride == other.stride and self.dilation == other.dilation \ + return ( + self.maxpool_result == other.maxpool_result + and self.input_var == other.input_var + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation and self.matching_constraint == other.matching_constraint + ) else: return False @@ -437,21 +507,28 @@ def __init__(self, res1, res2, input1, input2): def __eq__(self, other): if isinstance(other, ApplyBroadcasting): - return self.res1 == other.res1 \ - and self.res2 == other.res2 \ - and self.input1 == other.input1 \ + return ( + self.res1 == other.res1 + and self.res2 == other.res2 + and self.input1 == other.input1 and self.input2 == other.input2 + ) else: return False def __repr__(self): - return f'{self.res1}, {self.res2} ='f' apply-broadcasting({self.input1},' f' {self.input2})' + return ( + f"{self.res1}, {self.res2} =" + f" apply-broadcasting({self.input1}," + f" {self.input2})" + ) class CalcProduct(Constraint): """ Given correct dimensions, calculate the product for flatten accounting for Dyn """ + def __init__(self, start, end, flattened, dims_to_flatten): """ :param start: start index @@ -471,20 +548,25 @@ def __init__(self, start, end, flattened, dims_to_flatten): def __eq__(self, other): if isinstance(other, CalcProduct): - return self.start == other.start and self.end == other.end and \ - self.dims_to_flatten == other.dims_to_flatten and self.flattened == other.flattened + return ( + self.start == other.start + and self.end == other.end + and self.dims_to_flatten == other.dims_to_flatten + and self.flattened == other.flattened + ) else: return False def __repr__(self): - return f'{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})' + return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})" class TVar: """ Tensor variable with no tensor constructor """ + def __init__(self, tvar): """ :param tvar: tensor variable @@ -492,7 +574,7 @@ def __init__(self, tvar): self.tvar = tvar def __repr__(self): - return f'TV({self.tvar})' + return f"TV({self.tvar})" def __eq__(self, other): if isinstance(other, TVar): @@ -505,6 +587,7 @@ class DVar: """ Dimension variable """ + def __init__(self, c): """ :param c: character or number @@ -512,7 +595,7 @@ def __init__(self, c): self.c = c def __repr__(self): - return f'DV({self.c})' + return f"DV({self.c})" def __eq__(self, other): if isinstance(other, DVar): @@ -525,6 +608,7 @@ class BVar: """ Boolean variable """ + def __init__(self, c): """ :param c: character or number @@ -532,7 +616,7 @@ def __init__(self, c): self.c = c def __repr__(self): - return f'BV({self.c})' + return f"BV({self.c})" def __eq__(self, other): if isinstance(other, BVar): @@ -554,5 +638,6 @@ def is_bool_expr(constraint): else: return isinstance(constraint, (BVar, Conj, Disj)) + def is_dim(d): return isinstance(d, (DVar, int)) or d == Dyn diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py index 952dde662f2ab..de7fd66894518 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_generator.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -1,34 +1,71 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -import torch import operator import warnings from typing import Callable, Dict, Iterable +import torch from torch.fx._symbolic_trace import _assert_is_none -from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, CalcProduct, \ - Disj, TGreatestUpperBound, CalcMaxPool, CalcConv, Conj, BinConstraintT, CanReshape, BinConstraintD, GetItem, T, F, \ - TVar, DVar, GetItemTensor, IndexSelect, Transpose, DGreatestUpperBound -from torch.fx.experimental.migrate_gradual_types.operation import \ - op_eq, op_matching, op_consistency, op_leq, op_precision, op_gt, op_div, op_sub, op_neq, op_lt, op_add, op_mul -from torch.fx.node import Target, Node -from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar, gen_tvar, \ - gen_bvar - +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + BinConstraintT, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_matching, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_bvar, + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, + gen_tvar, +) +from torch.fx.node import Node, Target from torch.fx.tensor_type import Dyn, TensorType -from torch.nn.modules.conv import Conv2d from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + _INFERENCE_RULES: Dict[Target, Callable] = {} MAX_TENSOR_RANK = 4 + def register_inference_rule(call_target): def register(fn): if call_target in _INFERENCE_RULES: - raise RuntimeError(f'Inference rule already registered for {call_target}!') + raise RuntimeError(f"Inference rule already registered for {call_target}!") _INFERENCE_RULES[call_target] = fn return fn + return register @@ -55,10 +92,11 @@ def get_attr_inference_rule(n: Node, symbols, constraints, counter): input = symbols[n.args[0]] attr = n.args[1] - if attr == 'device': + if attr == "device": return [BinConstraintT(input, output, op_eq)], counter else: - raise NotImplementedError('Not yet implemented') + raise NotImplementedError("Not yet implemented") + @register_inference_rule(torch.bmm) def bmm_inference_rule(n: Node, symbols, constraints, counter): @@ -79,26 +117,53 @@ def bmm_inference_rule(n: Node, symbols, constraints, counter): dims_input1, counter = gen_tensor_dims(3, counter) dims_input2, counter = gen_tensor_dims(3, counter) - inputs_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_output, Dyn, op_eq)]) - - input1_dyn = Conj([BinConstraintT(bmm_input1, Dyn, op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq)]) - - input2_dyn = Conj([BinConstraintT(bmm_input2, Dyn, op_eq), - BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq)]) - - consistency_constraints = [BinConstraintD(dims_input1[0], dims_input2[0], op_consistency)] + inputs_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq), + ] + ) + + input1_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq + ), + ] + ) + + input2_dyn = Conj( + [ + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq + ), + ] + ) + + consistency_constraints = [ + BinConstraintD(dims_input1[0], dims_input2[0], op_consistency) + ] batch_size, counter = gen_dvar(counter) - inputs_are_tensors = Conj([BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), - BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), - BinConstraintT(bmm_output, TensorType([batch_size, dims_input1[1], dims_input2[2]]), op_eq), - *consistency_constraints, DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0])]) + inputs_are_tensors = Conj( + [ + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, + TensorType([batch_size, dims_input1[1], dims_input2[2]]), + op_eq, + ), + *consistency_constraints, + DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]), + ] + ) return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter @@ -115,8 +180,6 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(n.args[1], int) assert isinstance(n.args[2], Node) - - index_select, counter = gen_tvar(counter) symbols[n] = index_select @@ -126,10 +189,30 @@ def index_select_inference_rule(n: Node, symbols, constraints, counter): is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) - c2 = Conj([is_size_1, Disj([IndexSelect(i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) - c3 = Conj([is_dyn, Disj([IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) - for i in range(MAX_TENSOR_RANK)])]) + c2 = Conj( + [ + is_size_1, + Disj( + [ + IndexSelect( + i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select + ) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + c3 = Conj( + [ + is_dyn, + Disj( + [ + IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) return [Disj([c2, c3])], counter @@ -158,14 +241,27 @@ def expand_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(symbols[arg], DVar) e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) - e2_constraint = BinConstraintT(e2, TensorType([arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]]), op_eq) + e2_constraint = BinConstraintT( + e2, + TensorType( + [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]] + ), + op_eq, + ) - constraints, counter = gen_broadcasting_constraints(e1, e2, symbols, counter, expand) + constraints, counter = gen_broadcasting_constraints( + e1, e2, symbols, counter, expand + ) # constraint the output size dims, counter = gen_tensor_dims(len(n.args[1:]), counter) nat_constraints = gen_nat_constraints(dims) - c = [BinConstraintT(expand, TensorType(dims), op_eq), *nat_constraints, e2_constraint, *e2_nat_constraints] + c = [ + BinConstraintT(expand, TensorType(dims), op_eq), + *nat_constraints, + e2_constraint, + *e2_nat_constraints, + ] constraints += c return constraints, counter @@ -206,7 +302,7 @@ def equality_inference_rule(n: Node, symbols, constraints, counter): my_size = [symbols[arg] for arg in n.args[0]] return [BinConstraintT(output, TensorType(my_size), op_eq)], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") @register_inference_rule("transpose") @@ -225,10 +321,17 @@ def transpose_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(from_arg, TVar) # input and output are dyn - is_dyn = Conj([BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)]) + is_dyn = Conj( + [BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)] + ) # or input is a tensor and we actually do the replacement - c3 = Disj([Transpose(i + 1, from_arg, n.args[1], n.args[2], output) for i in range(MAX_TENSOR_RANK)]) + c3 = Disj( + [ + Transpose(i + 1, from_arg, n.args[1], n.args[2], output) + for i in range(MAX_TENSOR_RANK) + ] + ) return [Disj([is_dyn, c3])], counter @@ -250,8 +353,11 @@ def type_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(from_arg, TVar) assert isinstance(to_arg, TVar) - return [BinConstraintT(from_arg, to_arg, op_consistency), - BinConstraintT(output, to_arg, op_eq)], counter + return [ + BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq), + ], counter + @register_inference_rule("masked_fill_") def masked_fill_inference_rule(n: Node, symbols, constraints, counter): @@ -273,9 +379,11 @@ def masked_fill_inference_rule(n: Node, symbols, constraints, counter): if isinstance(e1, TVar) and isinstance(e2, TVar): masked_fill_tensor, counter = gen_tvar(counter) symbols[n] = masked_fill_tensor - return gen_broadcasting_constraints(e1, e2, symbols, counter, masked_fill_tensor) + return gen_broadcasting_constraints( + e1, e2, symbols, counter, masked_fill_tensor + ) else: - raise NotImplementedError('Not yet implemented') + raise NotImplementedError("Not yet implemented") @register_inference_rule(torch.nn.functional.embedding) @@ -286,7 +394,9 @@ def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): # will treat this as a static shape. So we will not use matching. weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(embedding_dim_weights, TensorType(weight_dims), op_eq) + equality_constraint = BinConstraintT( + embedding_dim_weights, TensorType(weight_dims), op_eq + ) embedding_dim = weight_dims[1] constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) return [equality_constraint] + constraints, counter @@ -302,7 +412,6 @@ def embedding_inference_rule(n: Node, module_instance, symbols, constraints, cou def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): - embedding_output, counter = gen_tvar(counter) symbols[n] = embedding_output embedding_input = symbols[n.args[0]] @@ -318,9 +427,15 @@ def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): nat_constraints = gen_nat_constraints(new_dims) # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases - c_tensor_i = Conj([BinConstraintT(embedding_input, TensorType(new_dims), op_eq), - BinConstraintT(embedding_output, TensorType(new_dims + [embedding_dim]), op_eq)] + - nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT( + embedding_output, TensorType(new_dims + [embedding_dim]), op_eq + ), + ] + + nat_constraints + ) c2.append(c_tensor_i) return [Disj([c1, Disj(c2)])], counter @@ -348,9 +463,10 @@ def view_inference_rule(n: Node, symbols, constraints, counter): my_view, counter = gen_tvar(counter) symbols[n] = my_view - src_var = symbols[n.args[0]] - t2 = [symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:]] # target shape + t2 = [ + symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:] + ] # target shape t2_type = [] num_constraints = [] @@ -382,7 +498,6 @@ def size_inference_rule(n: Node, symbols, constraints, counter): Ex: size = input_ids.size() """ - if len(n.args) == 1: # generate the new variable size, counter = gen_tvar(counter) @@ -398,7 +513,10 @@ def size_inference_rule(n: Node, symbols, constraints, counter): size_index, counter = gen_dvar(counter) symbols[n] = size_index input = symbols[n.args[0]] - c2 = [GetItem(i + 1, n.args[1], size_index, input) for i in range(MAX_TENSOR_RANK)] + c2 = [ + GetItem(i + 1, n.args[1], size_index, input) + for i in range(MAX_TENSOR_RANK) + ] c3 = BinConstraintD(0, size_index, op_leq) input_dyn = BinConstraintT(input, Dyn, op_eq) @@ -452,9 +570,14 @@ def cumsum_inference_rule(n: Node, symbols, constraints, counter): nat_constraints = gen_nat_constraints(new_dims) - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims), op_eq), - BinConstraintT(output, TensorType(new_dims), op_eq)] + - [range_check(arg_1, i)] + nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq), + ] + + [range_check(arg_1, i)] + + nat_constraints + ) c2.append(c_tensor_i) dyn_or_tensor = Disj([c1, Disj(c2)]) @@ -481,7 +604,6 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): get_item_arg = symbols[n.args[0]] assert isinstance(get_item_arg, TVar) - # if the input is dynamic, we accept any index and return # a dynamic dimension as output input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) @@ -492,8 +614,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): # generate a getItem constraint which will be expanded based on the # tensor dimension. - c2 = [GetItem(i + 1, n.args[1], get_item_output, get_item_arg) for i in range(MAX_TENSOR_RANK)] - + c2 = [ + GetItem(i + 1, n.args[1], get_item_output, get_item_arg) + for i in range(MAX_TENSOR_RANK) + ] # since the output is a dimension, we make sure it's a natural number # added as a conjunction to the disjunction of c2 @@ -515,8 +639,10 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] c1 = Conj([input_dyn, output_dyn]) - c2 = [GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] - for i in range(MAX_TENSOR_RANK)] + c2 = [ + GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] + for i in range(MAX_TENSOR_RANK) + ] else: # TODO: we should figure out why there is a key-error here. return [], counter @@ -524,7 +650,7 @@ def getitem_inference_rule(n: Node, symbols, constraints, counter): return [Disj([c1, *c2])], counter else: - raise RuntimeError('Method not yet implemented') + raise RuntimeError("Method not yet implemented") @register_inference_rule(operator.gt) @@ -553,7 +679,7 @@ def gt_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise RuntimeError('Sort Mismatch') + raise RuntimeError("Sort Mismatch") elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): if isinstance(e1, DVar): @@ -567,7 +693,9 @@ def gt_inference_rule(n: Node, symbols, constraints, counter): elif isinstance(e1, TVar) and isinstance(e2, int): # then we made the wrong assumption about the argument being a tensor # so we should fix the assumption - warnings.warn(f'Made the wrong assumption for node {n}. Correctness not guaranteed.') + warnings.warn( + f"Made the wrong assumption for node {n}. Correctness not guaranteed." + ) new_e1, counter = gen_dvar(counter) symbols[n.args[0]] = new_e1 @@ -580,10 +708,10 @@ def gt_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") @register_inference_rule(operator.eq) @@ -609,7 +737,7 @@ def eq_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise RuntimeError('Sort Mismatch') + raise RuntimeError("Sort Mismatch") elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): if isinstance(e1, DVar): @@ -620,9 +748,10 @@ def eq_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) return [equality_constraint], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") + @register_inference_rule(operator.ne) def neq_inference_rule(n: Node, symbols, constraints, counter): @@ -641,7 +770,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): # implementing for size 3 and 4 if len(n.args[1]) == 3: - assert isinstance(n.args[1][0], (Node, int)) assert isinstance(n.args[1][1], (Node, int)) assert isinstance(n.args[1][2], (Node, int)) @@ -662,11 +790,19 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): neq_3 = BinConstraintD(d3, b[2], op_neq) # dimensions inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3]) - - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3]) + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3] + ) + + dims_inconsistent = Disj( + [dims_inconsistent1, dims_inconsistent2, dims_inconsistent3] + ) # we are covering size 3 and 4 only for now ne_constraint = Conj([input_is_size3, dims_inconsistent]) @@ -675,7 +811,6 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) elif len(n.args[1]) == 4: - assert isinstance(n.args[1][0], (Node, int)) assert isinstance(n.args[1][1], (Node, int)) assert isinstance(n.args[1][2], (Node, int)) @@ -703,12 +838,27 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): neq_4 = BinConstraintD(d4, b4, op_neq) # dimensions to inconsistent - dims_inconsistent1 = Conj([BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1]) - dims_inconsistent2 = Conj([BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2]) - dims_inconsistent3 = Conj([BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3]) - dims_inconsistent4 = Conj([BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4]) - - dims_inconsistent = Disj([dims_inconsistent1, dims_inconsistent2, dims_inconsistent3, dims_inconsistent4]) + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3] + ) + dims_inconsistent4 = Conj( + [BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4] + ) + + dims_inconsistent = Disj( + [ + dims_inconsistent1, + dims_inconsistent2, + dims_inconsistent3, + dims_inconsistent4, + ] + ) ne_constraint = Conj([input_is_size4, dims_inconsistent]) @@ -717,7 +867,7 @@ def neq_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") return [equality_constraint], counter @@ -748,7 +898,7 @@ def lt_inference_rule(n: Node, symbols, constraints, counter): return [equality_constraint], counter else: - raise RuntimeError('Sort Mismatch') + raise RuntimeError("Sort Mismatch") elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): if isinstance(e1, DVar): @@ -759,10 +909,10 @@ def lt_inference_rule(n: Node, symbols, constraints, counter): equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) return [equality_constraint], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") @register_inference_rule(torch.full) @@ -788,28 +938,42 @@ def arange_inference_rule(n: Node, symbols, constraints, counter): if len(n.args) == 1: end = symbols[n.args[0]] else: - raise NotImplementedError('Not yet implemented') + raise NotImplementedError("Not yet implemented") # int((end - start) / step) d1, counter = gen_dvar(counter) - size_constraint = BinConstraintD(d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq) + size_constraint = BinConstraintD( + d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq + ) arange, counter = gen_tvar(counter) symbols[n] = arange # either the a parameter is a number or it is Dyn - c1 = Disj([BinConstraintD(end, Dyn, op_eq), - BinConstraintD(start, Dyn, op_eq), - BinConstraintD(step, Dyn, op_eq)]) + c1 = Disj( + [ + BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq), + ] + ) c2 = BinConstraintD(d1, Dyn, op_eq) both_dyn = Conj([c1, c2]) - c11 = Conj([BinConstraintD(end, Dyn, op_neq), - BinConstraintD(start, Dyn, op_neq), - BinConstraintD(step, Dyn, op_neq)]) + c11 = Conj( + [ + BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq), + ] + ) c22 = BinConstraintD(d1, Dyn, op_neq) both_numbers = Conj([c11, c22, size_constraint]) - return [BinConstraintT(arange, TensorType([d1]), op_eq), Disj([both_dyn, both_numbers])], counter + return [ + BinConstraintT(arange, TensorType([d1]), op_eq), + Disj([both_dyn, both_numbers]), + ], counter + def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): # additional vars that don't correspond to expressions @@ -829,7 +993,6 @@ def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): @register_inference_rule(torch.add) @register_inference_rule(operator.add) def broadcasting_inference_rule(n: Node, symbols, constraints, counter): - op_code = None if n.target == operator.add or n.target == torch.add: op_code = op_add @@ -837,7 +1000,9 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): op_code = op_mul if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): - if isinstance(symbols[n.args[0]], TVar) and isinstance(symbols[n.args[1]], TVar): + if isinstance(symbols[n.args[0]], TVar) and isinstance( + symbols[n.args[1]], TVar + ): my_output, counter = gen_tvar(counter) symbols[n] = my_output e1 = symbols[n.args[0]] @@ -845,7 +1010,7 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)): if isinstance(symbols[n.args[0]], TVar): @@ -859,8 +1024,14 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e1 = symbols[n.args[0]] # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e1, n.args[1], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e1, n.args[1], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) return [c], counter elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)): @@ -875,16 +1046,22 @@ def broadcasting_inference_rule(n: Node, symbols, constraints, counter): e2 = symbols[n.args[1]] # we will propagate the runtime value here since this is regular addition - c = Conj([BinConstraintD(my_output, BinConstraintD(e2, n.args[0], op_code), op_eq), - BinConstraintD(0, my_output, op_leq)]) + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e2, n.args[0], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) return [c], counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") else: # TODO generate add constraints for scalar addition - raise NotImplementedError('Addition not yet implemented') + raise NotImplementedError("Addition not yet implemented") @register_inference_rule(torch.flatten) @@ -915,7 +1092,9 @@ def flatten_inference_rule(n: Node, symbols, constraints, counter): const = [] for i in range(1, MAX_TENSOR_RANK + 1): - c, counter = generate_flatten_constraints(start_dim, end_dim, input, flattened, i, counter) + c, counter = generate_flatten_constraints( + start_dim, end_dim, input, flattened, i, counter + ) const.append(c) return [Disj([both_dyn, *const])], counter @@ -937,7 +1116,9 @@ def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, co Input should be consistent with the normalized_shape """ assert isinstance(n.args[0], Node) - return gen_layer_norm_constraints(n, module_instance.normalized_shape, symbols, counter) + return gen_layer_norm_constraints( + n, module_instance.normalized_shape, symbols, counter + ) def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): @@ -955,13 +1136,18 @@ def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): new_dims_rhs, counter = gen_tensor_dims(i, counter) nat_constraints = gen_nat_constraints(new_dims_rhs) - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs), op_eq), - BinConstraintT(output, TensorType(new_dims_rhs), op_eq)] + - add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + - nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq), + ] + + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + + nat_constraints + ) c2.append(c_tensor_i) return [Disj([c1, Disj(c2)])], counter + @register_inference_rule(torch.nn.Dropout) @register_inference_rule(torch.nn.ReLU) def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): @@ -983,7 +1169,9 @@ def linear_inference_rule(n: Node, module_instance, symbols, constraints, counte If the input is Dyn, then so should the output """ assert isinstance(n.args[0], Node) - return linear_constraints(n, module_instance.in_features, module_instance.out_features, symbols, counter) + return linear_constraints( + n, module_instance.in_features, module_instance.out_features, symbols, counter + ) @register_inference_rule("dim") # type: ignore[attr-defined] @@ -1001,8 +1189,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter): for i in range(1, MAX_TENSOR_RANK + 1): new_dims_rhs_1, counter = gen_tensor_dims(i, counter) - c_tensor_i = Conj([BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintD(my_dim, i, op_eq)]) + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintD(my_dim, i, op_eq), + ] + ) c1.append(c_tensor_i) return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter @@ -1012,8 +1204,12 @@ def torch_dim_inference_rule(n: Node, symbols, constraints, counter): def torch_linear_inference_rule(n: Node, symbols, constraints, counter): assert isinstance(n.args[0], Node) weight_dims, counter = gen_tensor_dims(2, counter) - equality_constraint = BinConstraintT(symbols[n.args[1]], TensorType(weight_dims), op_eq) - constraints, counter = linear_constraints(n, weight_dims[1], weight_dims[0], symbols, counter) + equality_constraint = BinConstraintT( + symbols[n.args[1]], TensorType(weight_dims), op_eq + ) + constraints, counter = linear_constraints( + n, weight_dims[1], weight_dims[0], symbols, counter + ) return [equality_constraint] + constraints, counter @@ -1034,13 +1230,20 @@ def linear_constraints(n: Node, in_features, out_features, symbols, counter): nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - c_tensor_i = Conj([BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq)] + - add_linear_constraints(new_dims_rhs_1, new_dims_rhs_2, in_features, out_features) + - nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq), + ] + + add_linear_constraints( + new_dims_rhs_1, new_dims_rhs_2, in_features, out_features + ) + + nat_constraints + ) c2.append(c_tensor_i) return [Disj([c1, Disj(c2)])], counter + def add_layer_norm_constraints(input_dim, normalized_dim): """ The constraints say that the type has te form: [*, 1024, 1024] @@ -1130,7 +1333,13 @@ def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, coun d4, counter = gen_dvar(counter) nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - c2 = BinConstraintT(avg_pool, TensorType([d1, d2, module_instance.output_size[0], module_instance.output_size[1]]), op_eq) + c2 = BinConstraintT( + avg_pool, + TensorType( + [d1, d2, module_instance.output_size[0], module_instance.output_size[1]] + ), + op_eq, + ) return [c1, c2, *nat_constraints], counter @@ -1152,12 +1361,16 @@ def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counte # c2 = DConsistency(module_instance.in_channels, d2) c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) - c3 = CalcConv(my_conv, input_var, - module_instance.out_channels, - module_instance.kernel_size, - module_instance.padding, - module_instance.stride, - module_instance.dilation, [d1, d2, d3, d4]) + c3 = CalcConv( + my_conv, + input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) @@ -1176,8 +1389,15 @@ def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, count c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) - c2 = CalcMaxPool(maxpool, input_var, module_instance.kernel_size, module_instance.padding, - module_instance.stride, module_instance.dilation, [d1, d2, d3, d4]) + c2 = CalcMaxPool( + maxpool, + input_var, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) @@ -1190,8 +1410,7 @@ def __init__(self, traced, graph=None): self.traced_params = dict(self.traced.named_parameters()) self.constraints = [] self.symbol_dict = {} - self.graph = traced.graph if hasattr(traced, 'graph') else graph - + self.graph = traced.graph if hasattr(traced, "graph") else graph def generate_constraints(self, counter=0): """ @@ -1217,7 +1436,7 @@ def generate_constraints_node(self, n: Node, counter): - conv2d """ - if n.op == 'placeholder': + if n.op == "placeholder": x, counter = gen_tvar(counter) self.symbol_dict[n] = x @@ -1226,8 +1445,8 @@ def generate_constraints_node(self, n: Node, counter): if n.type != Dyn and (not isinstance(n.type, TensorType)): if n.type == torch.nn.parameter.Parameter: # since we have a parameter, the shape must be static - assert 'example_value' in n.meta - my_type = TensorType(n.meta['example_value'].size()) + assert "example_value" in n.meta + my_type = TensorType(n.meta["example_value"].size()) else: my_type = Dyn @@ -1235,30 +1454,38 @@ def generate_constraints_node(self, n: Node, counter): c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) return [c1, c2], counter - elif n.op == 'call_function': + elif n.op == "call_function": if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') - - elif n.op == 'call_module': + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + elif n.op == "call_module": module_instance = self.traced.get_submodule(n.target) if type(module_instance) in _INFERENCE_RULES: - return _INFERENCE_RULES[type(module_instance)](n, - module_instance, - self.symbol_dict, - self.constraints, counter) + return _INFERENCE_RULES[type(module_instance)]( + n, module_instance, self.symbol_dict, self.constraints, counter + ) else: - raise RuntimeError(f'No inference rule registered for class {type(module_instance)}!') + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) - elif n.op == 'call_method': + elif n.op == "call_method": if n.target in _INFERENCE_RULES: - return _INFERENCE_RULES[n.target](n, self.symbol_dict, self.constraints, counter) + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) else: - raise RuntimeError(f'No inference rule registered for target {n.target}!') + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) - elif n.op == 'get_attr': + elif n.op == "get_attr": t = self.traced_params.get(n.target, None) if isinstance(t, torch.Tensor): @@ -1274,7 +1501,7 @@ def generate_constraints_node(self, n: Node, counter): else: return [], counter - elif n.op == 'output': + elif n.op == "output": return [], counter else: diff --git a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py index 439e3d6195e65..7a854b1dabe86 100644 --- a/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py +++ b/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -1,30 +1,67 @@ # mypy: ignore-errors import copy import itertools -from torch.fx.experimental.migrate_gradual_types.constraint_generator import BinConstraintT, MAX_TENSOR_RANK -from torch.fx.experimental.migrate_gradual_types.constraint import T, BinConstraintD, Conj, Constraint, DVar, TVar, \ - Transpose -from torch.fx.experimental.migrate_gradual_types.constraint import Disj, TGreatestUpperBound -from torch.fx.experimental.migrate_gradual_types.constraint import DGreatestUpperBound -from torch.fx.experimental.migrate_gradual_types.constraint import CalcConv, CalcMaxPool -from torch.fx.experimental.migrate_gradual_types.constraint import CalcProduct, CanReshape -from torch.fx.experimental.migrate_gradual_types.constraint import ApplyBroadcasting, Prod, F, GetItem, GetItemTensor, IndexSelect -from torch.fx.experimental.migrate_gradual_types.operation import op_eq, op_precision, op_leq, op_matching -from torch.fx.experimental.migrate_gradual_types.operation import op_consistency, op_neq -from torch.fx.experimental.migrate_gradual_types.operation import op_mul, op_add, op_sub, op_div, op_mod -from torch.fx.experimental.migrate_gradual_types.util import gen_tensor_dims, gen_nat_constraints, gen_dvar -from torch.fx.tensor_type import TensorType, Dyn from typing import Callable, Dict, List +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + Constraint, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + Prod, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + BinConstraintT, + MAX_TENSOR_RANK, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_leq, + op_matching, + op_mod, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, +) +from torch.fx.tensor_type import Dyn, TensorType + + _TRANSFORMATION_RULES: Dict[Constraint, Callable] = {} def register_transformation_rule(call_target): def register(fn): if call_target in _TRANSFORMATION_RULES: - raise RuntimeError(f'Transformation rule already registered for {call_target}!') + raise RuntimeError( + f"Transformation rule already registered for {call_target}!" + ) _TRANSFORMATION_RULES[call_target] = fn return fn + return register @@ -54,10 +91,15 @@ def transform_transpose(constraint, counter): new_dims[constraint.index1] = dims[constraint.index2] new_dims[constraint.index2] = dims[constraint.index1] - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index1, is_valid_index2, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index1, + is_valid_index2, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) return transformed_constraint, counter @@ -78,10 +120,14 @@ def transform_index_select(constraint, counter): new_dims = copy.deepcopy(dims) new_dims[constraint.index] = constraint.dim_replace - transformed_constraint = Conj([BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index, - BinConstraintT(constraint.output, TensorType(new_dims), op_eq)]) + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) # print(constraints) return transformed_constraint, counter @@ -106,20 +152,24 @@ def transform_get_item(constraint, counter): dims, counter = gen_tensor_dims(constraint.tensor_size, counter) nat_constraints = gen_nat_constraints(dims) - is_valid_index = valid_index(constraint.index, dims) - all_constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - *nat_constraints, - is_valid_index] + all_constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + ] # if the index is valid, we generate a constraint for getting an item # otherwise this clause will have been UNSAT due to the wrong index if is_valid_index == T(): - all_constraints.append(BinConstraintD(constraint.res, dims[constraint.index], op_eq)) + all_constraints.append( + BinConstraintD(constraint.res, dims[constraint.index], op_eq) + ) return Conj(all_constraints), counter + def valid_index_tensor(index, dims): """ if the slice instances exceed the length of the dimensions @@ -134,6 +184,7 @@ def valid_index_tensor(index, dims): else: return T() + @register_transformation_rule(GetItemTensor) def transform_get_item_tensor(constraint, counter): """ @@ -151,7 +202,6 @@ def transform_get_item_tensor(constraint, counter): """ assert isinstance(constraint.index_tuple, tuple) - # generate a result tensor of the expected size dims, counter = gen_tensor_dims(constraint.tensor_size, counter) nat_constraints = gen_nat_constraints(dims) @@ -163,7 +213,6 @@ def transform_get_item_tensor(constraint, counter): dim_index = 0 for i in range(len(constraint.index_tuple)): - # append 1 to the right location of the resulting tensor if constraint.index_tuple[i] is None: resulting_tensor_dims[i] = 1 @@ -172,7 +221,7 @@ def transform_get_item_tensor(constraint, counter): pass else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") # append the remaining dimensions to the right location dim_index = 0 @@ -189,10 +238,12 @@ def transform_get_item_tensor(constraint, counter): return F(), counter else: - constraints = [BinConstraintT(constraint.input_var, TensorType(dims), op_eq), - BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), - *nat_constraints, - is_valid_index] + constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), + *nat_constraints, + is_valid_index, + ] return Conj(constraints), counter @@ -217,11 +268,14 @@ def generate_binconstraint_t(constraint, counter): dim, counter = gen_dvar(counter) new_dims.append(dim) - new_dim_constraints = [BinConstraintD(old_dim, new_dim, op_precision) for - new_dim, old_dim in zip(new_dims, constraint.lhs.__args__)] + \ - [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + \ - [BinConstraintD(1, new_dim, op_leq) for - new_dim in new_dims] + new_dim_constraints = ( + [ + BinConstraintD(old_dim, new_dim, op_precision) + for new_dim, old_dim in zip(new_dims, constraint.lhs.__args__) + ] + + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + + [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims] + ) return Conj(new_dim_constraints), counter # matching @@ -232,17 +286,39 @@ def generate_binconstraint_t(constraint, counter): d3 = constraint.rhs.__args__[2] d4 = constraint.rhs.__args__[3] - conj = [BinConstraintT(constraint.lhs, Dyn, op_eq), - BinConstraintD(d1, Dyn, op_eq), - BinConstraintD(d2, Dyn, op_eq), - BinConstraintD(d3, Dyn, op_eq), - BinConstraintD(d4, Dyn, op_eq)] - return Disj([Conj(conj), - BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq)]), counter + conj = [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintD(d1, Dyn, op_eq), + BinConstraintD(d2, Dyn, op_eq), + BinConstraintD(d3, Dyn, op_eq), + BinConstraintD(d4, Dyn, op_eq), + ] + return ( + Disj( + [ + Conj(conj), + BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq), + ] + ), + counter, + ) elif constraint.op == op_consistency: - c_dyn = Disj([BinConstraintT(constraint.lhs, Dyn, op_eq), BinConstraintT(constraint.rhs, Dyn, op_eq)]) - [c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4], counter = gen_consistency_constraints(constraint, counter) + c_dyn = Disj( + [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintT(constraint.rhs, Dyn, op_eq), + ] + ) + ( + ( + c_tensor_1, + c_tensor_2, + c_tensor_3, + c_tensor_4, + ), + counter, + ) = gen_consistency_constraints(constraint, counter) return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter @@ -251,7 +327,7 @@ def generate_binconstraint_t(constraint, counter): disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] for i in range(1, constraint.rhs + 1): dims = [] - for j in range(1, i + 1): + for _ in range(1, i + 1): dim_var, counter = gen_dvar(counter) dims.append(dim_var) disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) @@ -272,8 +348,16 @@ def generate_binconstraint_d(constraint, counter): return T(), counter elif constraint.op == op_consistency: - return Disj([BinConstraintD(constraint.lhs, constraint.rhs, op_eq), - BinConstraintD(constraint.rhs, Dyn, op_eq), BinConstraintD(constraint.lhs, Dyn, op_eq)]), counter + return ( + Disj( + [ + BinConstraintD(constraint.lhs, constraint.rhs, op_eq), + BinConstraintD(constraint.rhs, Dyn, op_eq), + BinConstraintD(constraint.lhs, Dyn, op_eq), + ] + ), + counter, + ) else: return constraint, counter @@ -309,8 +393,17 @@ def generate_gub(constraint, counter): Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound on dimensions """ - c1 = Conj([Disj([BinConstraintT(constraint.rhs1, Dyn, op_eq), - BinConstraintT(constraint.rhs2, Dyn, op_eq)]), BinConstraintT(constraint.res, Dyn, op_eq)]) + c1 = Conj( + [ + Disj( + [ + BinConstraintT(constraint.rhs1, Dyn, op_eq), + BinConstraintT(constraint.rhs2, Dyn, op_eq), + ] + ), + BinConstraintT(constraint.res, Dyn, op_eq), + ] + ) [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) @@ -322,9 +415,24 @@ def generate_d_gub(constraint, counter): """ Transform greatest upper bound for dimensions into equality constraints """ - c1 = Conj([BinConstraintD(constraint.rhs1, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs2, op_eq)]) - c2 = Conj([BinConstraintD(constraint.rhs2, Dyn, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) - c3 = Conj([BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), BinConstraintD(constraint.res, constraint.rhs1, op_eq)]) + c1 = Conj( + [ + BinConstraintD(constraint.rhs1, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs2, op_eq), + ] + ) + c2 = Conj( + [ + BinConstraintD(constraint.rhs2, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) + c3 = Conj( + [ + BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) return Disj([c1, c2, c3]), counter @@ -337,17 +445,26 @@ def generate_calc_conv(constraint, counter): c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) # the second dimension of the output is equal to the output channels - c2 = Conj([BinConstraintD(d[1], constraint.c_out, op_eq), BinConstraintD(d[1], Dyn, op_neq)]) + c2 = Conj( + [ + BinConstraintD(d[1], constraint.c_out, op_eq), + BinConstraintD(d[1], Dyn, op_neq), + ] + ) # the input corresponds to the output in the first dimension of the convolution c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) c4, c5 = calc_last_two_dims(constraint, d) - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter @@ -368,10 +485,14 @@ def generate_calc_maxpool(constraint, counter): c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) c4, c5 = calc_last_two_dims(constraint, d) - leq_constraints = Conj([BinConstraintD(0, d[0], op_leq), - BinConstraintD(0, d[1], op_leq), - BinConstraintD(0, d[2], op_leq), - BinConstraintD(0, d[3], op_leq)]) + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter @@ -388,7 +509,7 @@ def generate_calc_product(constraint, counter): n = len(constraint.dims_to_flatten) # this will be evaluated right here - boundary_check = (0 <= start and start < end and end <= n) + boundary_check = 0 <= start and start < end and end <= n c_boundary = T() if boundary_check else F() @@ -410,16 +531,40 @@ def generate_calc_product(constraint, counter): if len(total_constraints) > 4: all_constraints.append(F()) else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq)] + p)) + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ) + ] + + p + ) + ) else: new_var, counter = gen_dvar(counter) - mid_eq_prod = Conj([BinConstraintD(new_var, Prod(mid), op_eq), BinConstraintD(new_var, Dyn, op_neq)]) + mid_eq_prod = Conj( + [ + BinConstraintD(new_var, Prod(mid), op_eq), + BinConstraintD(new_var, Dyn, op_neq), + ] + ) mid_var = [new_var] total_constraints = lhs + mid_var + rhs if len(total_constraints) > 4: all_constraints.append(F()) else: - all_constraints.append(Conj([BinConstraintT(flattened, TensorType(lhs + mid_var + rhs), op_eq), mid_eq_prod] + p)) + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ), + mid_eq_prod, + ] + + p + ) + ) return Conj([Disj(all_constraints), c_boundary]), counter @@ -466,22 +611,40 @@ def generate_reshape(constraint, counter): if is_fully_static: # size 1 tensor - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - BinConstraintD(d1, Prod(target), op_eq)]))]) + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, BinConstraintD(d1, Prod(target), op_eq)]))] + ) all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) # size 2 tensor - all_tensor_2 = Conj([c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)]) + all_tensor_2 = Conj( + [c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)] + ) # size 3 tensor - all_tensor_3 = Conj([c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)]) + all_tensor_3 = Conj( + [c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)] + ) # size 4 tensor - all_tensor_4 = Conj([c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)]) - - return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter + all_tensor_4 = Conj( + [c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)] + ) + + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) # then there must be exactly one occurrence of dyn else: @@ -492,28 +655,57 @@ def generate_reshape(constraint, counter): new_target.append(n) # tensor 1 - c3_tensor1 = Disj([d1_eq_dyn, - (Conj([d1_neq_dyn, - is_dim_div_by_target(new_target, d1)]))]) + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, is_dim_div_by_target(new_target, d1)]))] + ) all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) # tensor 2 c21 = Disj([d1_eq_dyn, d2_eq_dyn]) - c22 = Conj([d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))]) + c22 = Conj( + [d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))] + ) all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) # tensor 3 c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) - c32 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3]))]) + c32 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3])), + ] + ) all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) # tensor 4 c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) - c42 = Conj([d1_neq_dyn, d2_neq_dyn, d3_neq_dyn, d4_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4]))]) + c42 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + d4_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4])), + ] + ) all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) - return Conj([Disj([c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4]), - nat_d1, nat_d2, nat_d3, nat_d4]), counter + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) @register_transformation_rule(ApplyBroadcasting) @@ -537,40 +729,58 @@ def generate_broadcasting(constraint, counter): # tensor possibility # generate dimensions to create tensors of size 1 - final_tensor_1_constraint, _, _, nat_dims_1, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 1, counter) + final_tensor_1_constraint, _, _, nat_dims_1, counter = gen_broadcasting_constraints( + e1, e2, e11, e12, 1, counter + ) # generate dimensions to create tensors of size 2 - final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, \ - final_tensor_2_constraint_padding_arg2, nat_dims_2, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) - - # generate dimensions to create tensors of size 3 - final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, \ - final_tensor_3_constraint_padding_arg2, nat_dims_3, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) - - # generate dimensions to create tensors of size 4 - final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, \ - final_tensor_4_constraint_padding_arg2, nat_dims_4, counter = \ - gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) - - final_result = Disj([ - e1_dyn_constraint, - e2_dyn_constraint, - final_tensor_1_constraint, + ( final_tensor_2_constraint_no_padding, final_tensor_2_constraint_padding_arg1, final_tensor_2_constraint_padding_arg2, + nat_dims_2, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) + + # generate dimensions to create tensors of size 3 + ( final_tensor_3_constraint_no_padding, final_tensor_3_constraint_padding_arg1, final_tensor_3_constraint_padding_arg2, + nat_dims_3, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) + + # generate dimensions to create tensors of size 4 + ( final_tensor_4_constraint_no_padding, final_tensor_4_constraint_padding_arg1, - final_tensor_4_constraint_padding_arg2 - ]) - - return Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), counter + final_tensor_4_constraint_padding_arg2, + nat_dims_4, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) + + final_result = Disj( + [ + e1_dyn_constraint, + e2_dyn_constraint, + final_tensor_1_constraint, + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2, + ] + ) + + return ( + Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), + counter, + ) def transform_constraint(constraint: Constraint, counter: int): @@ -591,8 +801,6 @@ def transform_constraint(constraint: Constraint, counter: int): return constraint, counter - - def calc_last_two_dims(constraint, d: List[DVar]): """ Generates constraints for the last two dimensions of a convolution or a maxpool output @@ -612,29 +820,49 @@ def calc_last_two_dims(constraint, d: List[DVar]): b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) - d3_not_dyn = Conj([BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)]) - d4_not_dyn = Conj([BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)]) + d3_not_dyn = Conj( + [BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)] + ) + d4_not_dyn = Conj( + [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)] + ) # transform parameters into tuples incase they are not already - padding = (constraint.padding, constraint.padding) \ - if isinstance(constraint.padding, int) else constraint.padding - kernel = (constraint.kernel, constraint.kernel) \ - if isinstance(constraint.kernel, int) else constraint.kernel - stride = (constraint.stride, constraint.stride) \ - if isinstance(constraint.stride, int) else constraint.stride - dilation = (constraint.dilation, constraint.dilation) \ - if isinstance(constraint.dilation, int) else constraint.dilation + padding = ( + (constraint.padding, constraint.padding) + if isinstance(constraint.padding, int) + else constraint.padding + ) + kernel = ( + (constraint.kernel, constraint.kernel) + if isinstance(constraint.kernel, int) + else constraint.kernel + ) + stride = ( + (constraint.stride, constraint.stride) + if isinstance(constraint.stride, int) + else constraint.stride + ) + dilation = ( + (constraint.dilation, constraint.dilation) + if isinstance(constraint.dilation, int) + else constraint.dilation + ) f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) - f3 = BinConstraintD(BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div) + f3 = BinConstraintD( + BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div + ) f4 = BinConstraintD(f3, 1, op_add) c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) - f33 = BinConstraintD(BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div) + f33 = BinConstraintD( + BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div + ) f44 = BinConstraintD(f33, 1, op_add) c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) @@ -652,8 +880,12 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]): one possibility about the values of the dimension variables """ # generate all possibilities of being equal or not equal to dyn for my_list - eq_possibilities = [BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list))] - neq_possibilities = [BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))] + eq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list)) + ] + neq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list)) + ] d_possibilities = [] for i in zip(eq_possibilities, neq_possibilities): @@ -721,10 +953,13 @@ def gen_all_reshape_possibilities(list_of_dims, target): all_constraints.append(Conj(p)) elif len(to_multiply) < len(list_of_dims): - all_constraints.append(Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))])) + all_constraints.append( + Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]) + ) else: - all_constraints.append(Conj(p + [BinConstraintD(Prod(list_of_dims), - Prod(target), op_eq)])) + all_constraints.append( + Conj(p + [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]) + ) return Disj(all_constraints) @@ -746,27 +981,36 @@ def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False if tensor_input1[index] is None: assert padding - if not padding: # then the inputs are the same length so they all have dimensions at "index" - return Conj([BinConstraintD(tensor_input1[index], 1, op_eq), - BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) + return Conj( + [ + BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) else: # we don't set the input dimension to 1, since it doesn't exist. - return Conj([BinConstraintD(res1[index], res2[index], op_eq), - BinConstraintD(res2[index], tensor_input2[index], op_eq)]) - - -def apply_padding(e1_var: TVar, - e11: BinConstraintT, - e2: BinConstraintT, - e12: BinConstraintT, - d2: List[DVar], - d11: List[DVar], - d12: List[DVar], - counter: int): + return Conj( + [ + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) + + +def apply_padding( + e1_var: TVar, + e11: BinConstraintT, + e2: BinConstraintT, + e12: BinConstraintT, + d2: List[DVar], + d11: List[DVar], + d12: List[DVar], + counter: int, +): """ We are considering the possibility where one input has less dimensions than another input, so we apply padding to the broadcasted results @@ -789,7 +1033,6 @@ def apply_padding(e1_var: TVar, # pad the shorter input with None so we can pass it to the broadcasting helper function for i in range(1, len(d2)): - d1, counter = gen_tensor_dims(i, counter) nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) @@ -804,30 +1047,37 @@ def apply_padding(e1_var: TVar, # for every padding size, we also consider broadcasting for j in range(len(d2) - i): - broadcast_padding.append(broadcast_dim(simulate_padding, d2, d11, d12, j, True)) + broadcast_padding.append( + broadcast_dim(simulate_padding, d2, d11, d12, j, True) + ) # we consider the possibilities for broadcasting for every dimension. Since we already # padded d1, we do not consider it while broadcasting - all_broadcasting_possibilities = generate_all_broadcasting_possibilities_no_padding(d1, - d2[(len(d2) - i):], - d11[(len(d2) - i):], - d12[(len(d2) - i):]) + all_broadcasting_possibilities = ( + generate_all_broadcasting_possibilities_no_padding( + d1, d2[(len(d2) - i) :], d11[(len(d2) - i) :], d12[(len(d2) - i) :] + ) + ) # combine all constraints into a conjunction - c = Conj([e1, e11, e2, e12, - *broadcast_padding, - all_broadcasting_possibilities, - *nat_constraints - ]) + c = Conj( + [ + e1, + e11, + e2, + e12, + *broadcast_padding, + all_broadcasting_possibilities, + *nat_constraints, + ] + ) res.append(c) return Disj(res), counter -def no_broadcast_dim_with_index(d1: List[DVar], - d2: List[DVar], - d3: List[DVar], - d4: List[DVar], - i: int): +def no_broadcast_dim_with_index( + d1: List[DVar], d2: List[DVar], d3: List[DVar], d4: List[DVar], i: int +): """ Args: d1: input 1 @@ -838,17 +1088,28 @@ def no_broadcast_dim_with_index(d1: List[DVar], Returns: Constraints for when no broadcasting occurs """ - return Conj([ - Disj([ - Conj([BinConstraintD(d1[i], 1, op_eq), - BinConstraintD(d2[i], 1, op_eq)]), - - Conj([BinConstraintD(d1[i], 1, op_neq), - BinConstraintD(d2[i], 1, op_neq)])]), - - BinConstraintD(d1[i], d3[i], op_eq), - BinConstraintD(d2[i], d4[i], op_eq)]) - + return Conj( + [ + Disj( + [ + Conj( + [ + BinConstraintD(d1[i], 1, op_eq), + BinConstraintD(d2[i], 1, op_eq), + ] + ), + Conj( + [ + BinConstraintD(d1[i], 1, op_neq), + BinConstraintD(d2[i], 1, op_neq), + ] + ), + ] + ), + BinConstraintD(d1[i], d3[i], op_eq), + BinConstraintD(d2[i], d4[i], op_eq), + ] + ) def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): @@ -871,14 +1132,16 @@ def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): return res, counter -def create_equality_constraints_for_broadcasting(e1: TVar, - e2: TVar, - e11: TVar, - e12: TVar, - d1: List[DVar], - d2: List[DVar], - d11: List[DVar], - d12: List[DVar]): +def create_equality_constraints_for_broadcasting( + e1: TVar, + e2: TVar, + e11: TVar, + e12: TVar, + d1: List[DVar], + d2: List[DVar], + d11: List[DVar], + d12: List[DVar], +): """ Create equality constraints for when no broadcasting occurs Args: @@ -920,10 +1183,17 @@ def gen_consistency_constraints(constraint: Constraint, counter: int): nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) - c_tensor_i = Conj([BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), - BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq)] + - [BinConstraintD(d1, d2, op_consistency) for - d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2)] + nat_constraints) + c_tensor_i = Conj( + [ + BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq), + ] + + [ + BinConstraintD(d1, d2, op_consistency) + for d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2) + ] + + nat_constraints + ) all_constraints.append(c_tensor_i) @@ -953,22 +1223,29 @@ def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): dims3, counter = gen_tensor_dims(i, counter) c3tensor = TensorType(dims3) - c += [BinConstraintT(constraint.rhs1, c1tensor, op_eq), - BinConstraintT(constraint.rhs2, c2tensor, op_eq), - BinConstraintT(constraint.res, c3tensor, op_eq)] + \ - gen_nat_constraints(dims1 + dims2 + dims3) + c += [ + BinConstraintT(constraint.rhs1, c1tensor, op_eq), + BinConstraintT(constraint.rhs2, c2tensor, op_eq), + BinConstraintT(constraint.res, c3tensor, op_eq), + ] + gen_nat_constraints(dims1 + dims2 + dims3) - assert len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + assert ( + len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + ) for i in range(len(c3tensor.__args__)): - c.append(DGreatestUpperBound(c3tensor.__args__[i], - c1tensor.__args__[i], - c2tensor.__args__[i])) + c.append( + DGreatestUpperBound( + c3tensor.__args__[i], c1tensor.__args__[i], c2tensor.__args__[i] + ) + ) all_constraints.append(Conj(c)) return all_constraints, counter -def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar]): +def generate_all_broadcasting_possibilities_no_padding( + d1: List[DVar], d2: List[DVar], d11: List[DVar], d12: List[DVar] +): """ Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. We look at all combinations for all dimensions in d1 and d2 @@ -996,7 +1273,9 @@ def generate_all_broadcasting_possibilities_no_padding(d1: List[DVar], d2: List[ return Conj(res2) -def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int): +def gen_broadcasting_constraints( + e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int +): """ Simulates broadcasting on e1 and e2 and returns the results respectively in e11 and e12. Because of gradual types, @@ -1019,22 +1298,33 @@ def gen_broadcasting_constraints(e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: in [d1, d2, d3, d4] = dims nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims))) - initialize_tensors_constraints = create_equality_constraints_for_broadcasting(e1, e2, e11, e12, - d1, d2, d3, d4) + initialize_tensors_constraints = create_equality_constraints_for_broadcasting( + e1, e2, e11, e12, d1, d2, d3, d4 + ) [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints # without padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_no_padding = Conj([*initialize_tensors_constraints, - generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4)]) + final_tensor_constraint_no_padding = Conj( + [ + *initialize_tensors_constraints, + generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4), + ] + ) # with padding, broadcast all possibilities for tensors of size i - final_tensor_constraint_padding_arg1, counter = \ - apply_padding(e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter) - - final_tensor_constraint_padding_arg2, counter = \ - apply_padding(e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter) - - return final_tensor_constraint_no_padding, \ - final_tensor_constraint_padding_arg1, \ - final_tensor_constraint_padding_arg2, nat_dims_i, counter + final_tensor_constraint_padding_arg1, counter = apply_padding( + e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter + ) + + final_tensor_constraint_padding_arg2, counter = apply_padding( + e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter + ) + + return ( + final_tensor_constraint_no_padding, + final_tensor_constraint_padding_arg1, + final_tensor_constraint_padding_arg2, + nat_dims_i, + counter, + ) diff --git a/torch/fx/experimental/migrate_gradual_types/operation.py b/torch/fx/experimental/migrate_gradual_types/operation.py index 432cd570bebbf..267100c8545c8 100644 --- a/torch/fx/experimental/migrate_gradual_types/operation.py +++ b/torch/fx/experimental/migrate_gradual_types/operation.py @@ -1,14 +1,14 @@ -op_add = '+' -op_sub = '-' -op_mul = '*' -op_div = '/' -op_eq = '=' -op_neq = '!=' -op_imp = '=>' -op_matching = '\u22b3' # (contains) -op_consistency = '~' -op_precision = '\u2291' # (square image of or equal to) -op_leq = '\u2264' # less-than or equal to -op_lt = '<' -op_gt = '>' -op_mod = '%' +op_add = "+" +op_sub = "-" +op_mul = "*" +op_div = "/" +op_eq = "=" +op_neq = "!=" +op_imp = "=>" +op_matching = "\u22b3" # (contains) +op_consistency = "~" +op_precision = "\u2291" # (square image of or equal to) +op_leq = "\u2264" # less-than or equal to +op_lt = "<" +op_gt = ">" +op_mod = "%" diff --git a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py index c8cf70006cd84..d1f9f33965e07 100644 --- a/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py +++ b/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -1,16 +1,49 @@ # mypy: allow-untyped-defs -from torch.fx.experimental.migrate_gradual_types.constraint import Conj, Disj, T, F, BinConstraintT, BVar, is_bool_expr -from torch.fx.experimental.migrate_gradual_types.constraint import BinConstraintD, TVar, DVar -from torch.fx.experimental.migrate_gradual_types.constraint import Prod, is_algebraic_expression, is_dim -from torch.fx.experimental.migrate_gradual_types.constraint_generator import ConstraintGenerator -from torch.fx.experimental.migrate_gradual_types.constraint_transformation import transform_constraint -from torch.fx.experimental.migrate_gradual_types.operation import op_add, op_eq, op_neq, op_gt, op_lt -from torch.fx.experimental.migrate_gradual_types.operation import op_leq, op_sub, op_div, op_mul, op_mod -from torch.fx.tensor_type import TensorType, Dyn +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BinConstraintT, + BVar, + Conj, + Disj, + DVar, + F, + is_algebraic_expression, + is_bool_expr, + is_dim, + Prod, + T, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + ConstraintGenerator, +) +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import ( + transform_constraint, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType + try: import z3 # type: ignore[import] - from torch.fx.experimental.migrate_gradual_types.z3_types import tensor_type, z3_dyn, D + + from torch.fx.experimental.migrate_gradual_types.z3_types import ( + D, + tensor_type, + z3_dyn, + ) + HAS_Z3 = True def transform_to_z3(constraint, counter, dimension_dict): @@ -41,35 +74,48 @@ def transform_to_z3(constraint, counter, dimension_dict): return (lhs == rhs), counter else: - raise NotImplementedError('Method not yet implemented') + raise NotImplementedError("Method not yet implemented") elif isinstance(constraint, BinConstraintD): if constraint.op == op_eq: - if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): - transformed_rhs, counter = transform_to_z3(constraint.rhs, counter, dimension_dict) + transformed_rhs, counter = transform_to_z3( + constraint.rhs, counter, dimension_dict + ) transformed_lhs = z3.Bool(constraint.lhs.c) return transformed_lhs == transformed_rhs, counter elif is_dim(constraint.lhs) and is_dim(constraint.rhs): # with dimension transformations we consider the encoding - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) return lhs == rhs, counter else: # then we have an algebraic expression which means that we disregard the # first element of the encoding - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs == rhs, counter # The assumption here is that the LHS and RHS must be dimensions elif constraint.op == op_neq: assert is_dim(constraint.lhs) assert is_dim(constraint.rhs) - lhs, counter = transform_dimension(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_dimension(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) if constraint.rhs == Dyn or constraint.lhs == Dyn: if constraint.rhs == Dyn: return lhs.arg(0) == 1, counter @@ -79,44 +125,83 @@ def transform_to_z3(constraint, counter, dimension_dict): # if one of the instances is a number elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): if isinstance(constraint.lhs, int): - return z3.Or([rhs.arg(0) == 0, z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + return ( + z3.Or( + [ + rhs.arg(0) == 0, + z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) elif isinstance(constraint.rhs, int): - return z3.Or([lhs.arg(0) == 0, z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)])]), counter + return ( + z3.Or( + [ + lhs.arg(0) == 0, + z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) else: - return z3.Or([z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), - z3.And([lhs.arg(0) != 0, rhs.arg(0) != 0, lhs.arg(1) != rhs.arg(1)])]), counter - + return ( + z3.Or( + [ + z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), + z3.And( + [ + lhs.arg(0) != 0, + rhs.arg(0) != 0, + lhs.arg(1) != rhs.arg(1), + ] + ), + ] + ), + counter, + ) elif constraint.op == op_leq: # if the dimensions are not dyn, this will come into effect # there would have been another constraint specifying if a given dimension # is dyn or not assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs <= rhs, counter elif constraint.op == op_gt: assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs > rhs, counter elif constraint.op == op_lt: assert is_dim(constraint.lhs) and is_dim(constraint.rhs) - lhs, counter = transform_algebraic_expression(constraint.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(constraint.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) return lhs < rhs, counter else: - raise NotImplementedError('operation not yet implemented') + raise NotImplementedError("operation not yet implemented") else: - raise NotImplementedError('Operation not yet implemented') - + raise NotImplementedError("Operation not yet implemented") def transform_var(tensor, counter, dimension_dict): """ @@ -166,13 +251,15 @@ def transform_dimension(dimension, counter, dimension_dict): return D(1, dimension), counter elif isinstance(dimension, DVar): if dimension.c in dimension_dict: - return D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), counter + return ( + D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), + counter, + ) else: counter += 1 dimension_dict[dimension.c] = counter return D(z3.Int(counter), z3.Int(dimension.c)), counter - def transform_algebraic_expression(expr, counter, dimension_dict): """ Transforms an algebraic expression to z3 format @@ -190,7 +277,6 @@ def transform_algebraic_expression(expr, counter, dimension_dict): return transformed.arg(1), counter elif isinstance(expr, Prod): - dims = [] for dim in expr.products: assert is_dim(dim) @@ -199,9 +285,12 @@ def transform_algebraic_expression(expr, counter, dimension_dict): return z3.Product(dims), counter elif is_algebraic_expression(expr): - - lhs, counter = transform_algebraic_expression(expr.lhs, counter, dimension_dict) - rhs, counter = transform_algebraic_expression(expr.rhs, counter, dimension_dict) + lhs, counter = transform_algebraic_expression( + expr.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + expr.rhs, counter, dimension_dict + ) if expr.op == op_sub: c = lhs - rhs @@ -219,14 +308,13 @@ def transform_algebraic_expression(expr, counter, dimension_dict): c = lhs % rhs else: - raise NotImplementedError('operation not yet implemented') + raise NotImplementedError("operation not yet implemented") return c, counter else: raise RuntimeError - def transform_all_constraints(traced, counter=0): """ Given a trace, generates constraints and transforms them to z3 format @@ -291,7 +379,6 @@ def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): # transform precision, matching, consistency till obtaining a fixed point new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) - # since the function returns a list of one element, we get the first element # we are only interested in the RHS in this case because the LHS just stores # the result @@ -304,19 +391,27 @@ def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): condition_constraint_rhs = condition_constraint.rhs # transform the condition constraint - condition_constraint_rhs, counter = iterate_till_fixed_point(condition_constraint_rhs, counter) + condition_constraint_rhs, counter = iterate_till_fixed_point( + condition_constraint_rhs, counter + ) transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) - transformed_condition_constraint, counter = transform_to_z3(condition_constraint_rhs, counter, dimension_dict) - - negation_transformed_condition_constraint = z3.Not(transformed_condition_constraint) + transformed_condition_constraint, counter = transform_to_z3( + condition_constraint_rhs, counter, dimension_dict + ) - return z3.And([transformed, transformed_condition_constraint]), \ - z3.And([transformed, negation_transformed_condition_constraint]) + negation_transformed_condition_constraint = z3.Not( + transformed_condition_constraint + ) + return z3.And([transformed, transformed_condition_constraint]), z3.And( + [transformed, negation_transformed_condition_constraint] + ) - def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, user_constraints=None): + def evaluate_conditional_with_constraints( + tracer_root, graph, node, counter=0, user_constraints=None + ): """ Given an IR and a node representing a conditional, evaluate the conditional and its negation @@ -329,8 +424,10 @@ def evaluate_conditional_with_constraints(tracer_root, graph, node, counter=0, u """ - transformed_positive, transformed_negative = \ - transform_all_constraints_trace_time(tracer_root, graph, node, counter) + ( + transformed_positive, + transformed_negative, + ) = transform_all_constraints_trace_time(tracer_root, graph, node, counter) s = z3.Solver() s.add(transformed_positive) diff --git a/torch/fx/experimental/migrate_gradual_types/util.py b/torch/fx/experimental/migrate_gradual_types/util.py index 99f94609f2650..bd40d2a463f5e 100644 --- a/torch/fx/experimental/migrate_gradual_types/util.py +++ b/torch/fx/experimental/migrate_gradual_types/util.py @@ -1,6 +1,10 @@ # mypy: allow-untyped-defs -from torch.fx.experimental.migrate_gradual_types.constraint import TVar, DVar, BinConstraintD, \ - BVar +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BVar, + DVar, + TVar, +) from torch.fx.experimental.migrate_gradual_types.operation import op_leq @@ -23,6 +27,7 @@ def gen_dvar(curr): curr += 1 return DVar(curr), curr + def gen_bvar(curr): """ Generate a boolean variable @@ -32,6 +37,7 @@ def gen_bvar(curr): curr += 1 return BVar(curr), curr + def gen_tensor_dims(n, curr): """ Generate a list of tensor dimensions diff --git a/torch/fx/experimental/migrate_gradual_types/z3_types.py b/torch/fx/experimental/migrate_gradual_types/z3_types.py index 897a79d569757..939f4865ab7d9 100644 --- a/torch/fx/experimental/migrate_gradual_types/z3_types.py +++ b/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -1,22 +1,23 @@ try: import z3 # type: ignore[import] + HAS_Z3 = True # dynamic type - dyn = z3.DeclareSort('Dyn') - dyn_type = z3.Const('dyn', dyn) + dyn = z3.DeclareSort("Dyn") + dyn_type = z3.Const("dyn", dyn) # dimension - dim = z3.Datatype('dim') - dim.declare('dim', ('0', z3.IntSort()), ('1', z3.IntSort())) + dim = z3.Datatype("dim") + dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort())) dim = dim.create() # tensors - tensor_type = z3.Datatype('TensorType') - tensor_type.declare('Dyn', ('dyn', dyn)) - tensor_type.declare('tensor1', ('0', dim)) - tensor_type.declare('tensor2', ('0', dim), ('1', dim)) - tensor_type.declare('tensor3', ('0', dim), ('1', dim), ('2', dim)) - tensor_type.declare('tensor4', ('0', dim), ('1', dim), ('2', dim), ('3', dim)) + tensor_type = z3.Datatype("TensorType") + tensor_type.declare("Dyn", ("dyn", dyn)) + tensor_type.declare("tensor1", ("0", dim)) + tensor_type.declare("tensor2", ("0", dim), ("1", dim)) + tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim)) + tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim)) tensor_type = tensor_type.create() # create dimension diff --git a/torch/fx/experimental/normalize.py b/torch/fx/experimental/normalize.py index 30b076a72bee2..cc6944d5a5afe 100644 --- a/torch/fx/experimental/normalize.py +++ b/torch/fx/experimental/normalize.py @@ -1,16 +1,16 @@ # mypy: allow-untyped-defs import operator -from typing import Any, Callable, Dict, Tuple, Optional +from typing import Any, Callable, Dict, Optional, Tuple import torch import torch.fx import torch.fx as fx -from torch.fx import Transformer, Proxy -from torch.fx.node import Argument, Target, Node, map_aggregate +from torch.fx import Proxy, Transformer +from torch.fx.node import Argument, map_aggregate, Node, Target from torch.fx.operator_schemas import ( - normalize_module, - normalize_function, create_type_hint, + normalize_function, + normalize_module, ) from .schema_type_annotation import AnnotateTypesWithSchema diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 8362c0cb88ac1..2fe600c247b84 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -1,37 +1,42 @@ # mypy: allow-untyped-defs -import torch.fx as fx -from torch.fx.node import Argument, Target -from torch.nn.utils.fusion import fuse_conv_bn_eval -from typing import Type, Dict, Any, Tuple, Iterable, Optional, List, cast -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.fx.passes.shape_prop import ShapeProp import copy -from collections import defaultdict -import torch.utils.mkldnn as th_mkldnn +import logging import operator import time -import logging +from collections import defaultdict from enum import Enum +from typing import Any, cast, Dict, Iterable, List, Optional, Tuple, Type + +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.mkldnn as th_mkldnn +from torch.fx.node import Argument, Target +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn.utils.fusion import fuse_conv_bn_eval -def _parent_name(target : str) -> Tuple[str, str]: + +def _parent_name(target: str) -> Tuple[str, str]: """ Splits a qualname into parent path and last atom. For example, `foo.bar.baz` -> (`foo.bar`, `baz`) """ - *parent, name = target.rsplit('.', 1) - return parent[0] if parent else '', name + *parent, name = target.rsplit(".", 1) + return parent[0] if parent else "", name + # Works for length 2 patterns with 2 modules -def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any]): +def matches_module_pattern( + pattern: Iterable[Type], node: fx.Node, modules: Dict[str, Any] +): if len(node.args) == 0: return False nodes: Tuple[Any, fx.Node] = (node.args[0], node) for expected_type, current_node in zip(pattern, nodes): if not isinstance(current_node, fx.Node): return False - if current_node.op != 'call_module': + if current_node.op != "call_module": return False if not isinstance(current_node.target, str): return False @@ -42,20 +47,25 @@ def matches_module_pattern(pattern: Iterable[Type], node: fx.Node, modules: Dict return True -def replace_node_module(node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module): +def replace_node_module( + node: fx.Node, modules: Dict[str, Any], new_module: torch.nn.Module +): assert isinstance(node.target, str) parent_name, name = _parent_name(node.target) modules[node.target] = new_module setattr(modules[parent_name], name, new_module) + def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: """ Fuses convolution/BN layers for inference purposes. Will deepcopy your model by default, but can modify the model inplace as well. """ - patterns = [(nn.Conv1d, nn.BatchNorm1d), - (nn.Conv2d, nn.BatchNorm2d), - (nn.Conv3d, nn.BatchNorm3d)] + patterns = [ + (nn.Conv1d, nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d), + ] if not inplace: model = copy.deepcopy(model) if not no_trace or not isinstance(model, torch.fx.GraphModule): @@ -80,6 +90,7 @@ def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Modu new_graph.erase_node(node) return fx.GraphModule(fx_model, new_graph) + def remove_dropout(model: nn.Module) -> nn.Module: """ Removes all dropout layers from the module. @@ -87,15 +98,24 @@ def remove_dropout(model: nn.Module) -> nn.Module: fx_model = fx.symbolic_trace(model) class DropoutRemover(torch.fx.Transformer): - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_module( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: if isinstance(self.submodules[target], nn.Dropout): assert len(args) == 1 return args[0] else: return super().call_module(target, args, kwargs) + return DropoutRemover(fx_model).transform() -def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[fx.Node], outputs: List[fx.Node]): + +def extract_subgraph( + orig_module: nn.Module, + nodes: List[fx.Node], + inputs: List[fx.Node], + outputs: List[fx.Node], +): """ Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. """ @@ -111,10 +131,21 @@ def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[ new_graph.lint() return fx.GraphModule(orig_module, new_graph) + mkldnn_supported = [ - nn.Conv2d, nn.Linear, nn.BatchNorm2d, nn.ReLU, nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d, - torch.relu, torch.transpose, torch.sigmoid, - F.relu, F.avg_pool2d, F.adaptive_avg_pool2d + nn.Conv2d, + nn.Linear, + nn.BatchNorm2d, + nn.ReLU, + nn.MaxPool2d, + nn.AvgPool2d, + nn.AdaptiveAvgPool2d, + torch.relu, + torch.transpose, + torch.sigmoid, + F.relu, + F.avg_pool2d, + F.adaptive_avg_pool2d, ] # These are operators that may not be convertible into MKLDNN ops (e.g. the # args are scalar values). Thus, we only include them in the subgraph if their @@ -124,7 +155,7 @@ def extract_subgraph(orig_module: nn.Module, nodes: List[fx.Node], inputs: List[ mkldnn_map = { nn.Conv2d: th_mkldnn.MkldnnConv2d, nn.Linear: th_mkldnn.MkldnnLinear, - nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a) + nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a), } @@ -136,7 +167,7 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): """ old_modules: Dict[nn.Module, nn.Module] = {} for node in nodes: - if node.op == 'call_module': + if node.op == "call_module": assert isinstance(node.target, str) cur_module = modules[node.target] if type(cur_module) in mkldnn_map: @@ -146,18 +177,24 @@ def modules_to_mkldnn(nodes: List[fx.Node], modules: Dict[str, nn.Module]): replace_node_module(node, modules, new_module) return old_modules -def reset_modules(nodes: List[fx.Node], modules: Dict[str, nn.Module], old_modules: Dict[nn.Module, nn.Module]): + +def reset_modules( + nodes: List[fx.Node], + modules: Dict[str, nn.Module], + old_modules: Dict[nn.Module, nn.Module], +): """ Maps each module that's been changed with `modules_to_mkldnn` back to its original. """ for node in nodes: - if node.op == 'call_module': - assert (isinstance(node.target, str)) + if node.op == "call_module": + assert isinstance(node.target, str) cur_module = modules[node.target] if cur_module in old_modules: replace_node_module(node, modules, old_modules[cur_module]) + class MklSubgraph: def __init__(self, fx_graph: fx.Graph): self.fx_graph = fx_graph @@ -165,6 +202,7 @@ def __init__(self, fx_graph: fx.Graph): self.start_nodes: List[fx.Node] = [] self.end_nodes: List[fx.Node] = [] + def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): """ This generates a heuristic that can be passed into `optimize_for_inference` that @@ -193,16 +231,24 @@ def benchmark(f): f() begin = time.time() for _ in range(iters): - out = f() + f() return time.time() - begin - mkl_time = benchmark(lambda: [i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs])]) + mkl_time = benchmark( + lambda: [ + i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs]) + ] + ) - reset_modules(submodule.graph.nodes, dict(submodule.named_modules()), old_modules) + reset_modules( + submodule.graph.nodes, dict(submodule.named_modules()), old_modules + ) no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) return mkl_time < no_mkl_time + return use_mkl_heuristic + def use_mkl_length(graph: MklSubgraph) -> bool: """ This is a heuristic that can be passed into `optimize_for_inference` that @@ -211,6 +257,7 @@ def use_mkl_length(graph: MklSubgraph) -> bool: """ return len(graph.nodes) > 2 + class UnionFind: def __init__(self, n): self.parent: List[Optional[int]] = [None] * n @@ -237,10 +284,11 @@ def join(self, a: int, b: int): self.parent[b] = a self.size[a] += self.size[b] + def optimize_for_inference( model: torch.nn.Module, pass_config: Optional[Dict[str, Any]] = None, - tracer: Type[fx.Tracer] = fx.Tracer + tracer: Type[fx.Tracer] = fx.Tracer, ) -> torch.nn.Module: """ Performs a set of optimization passes to optimize a model for the @@ -258,7 +306,7 @@ def optimize_for_inference( default_pass_config = { "conv_bn_fuse": True, "remove_dropout": True, - "mkldnn_layout_optimize": {'heuristic': use_mkl_length}, + "mkldnn_layout_optimize": {"heuristic": use_mkl_length}, } if pass_config is None: pass_config = {} @@ -278,7 +326,7 @@ def optimize_for_inference( cur_tracer = tracer() fx_graph = cur_tracer.trace(copy.deepcopy(model)) - fx_model = fx.GraphModule(cur_tracer.root, fx_graph) + fx.GraphModule(cur_tracer.root, fx_graph) modules: Dict[str, nn.Module] = dict(model.named_modules()) class MklSupport(Enum): @@ -292,15 +340,19 @@ class MklSupport(Enum): # a MKLDNN node if its inputs are MKLDNN nodes. for node in list(fx_graph.nodes): supports_mkldnn = MklSupport.NO - if node.op == 'call_module': + if node.op == "call_module": cur_module = modules[node.target] if type(cur_module) in mkldnn_supported: supports_mkldnn = MklSupport.YES sample_parameter = next(cur_module.parameters(), None) if sample_parameter is not None: - assert sample_parameter.dtype == torch.float, "this pass is only for torch.float modules" - assert sample_parameter.device == torch.device('cpu'), "this pass is only for CPU modules" - elif node.op == 'call_function': + assert ( + sample_parameter.dtype == torch.float + ), "this pass is only for torch.float modules" + assert sample_parameter.device == torch.device( + "cpu" + ), "this pass is only for CPU modules" + elif node.op == "call_function": if node.target in mkldnn_supported: supports_mkldnn = MklSupport.YES elif node.target in mkldnn_supported_unknown: @@ -308,15 +360,17 @@ class MklSupport(Enum): if supports_mkldnn != MklSupport.NO: if supports_mkldnn == MklSupport.UNKNOWN: - if not any(arg.target == 'to_dense' for arg in node.args): + if not any(arg.target == "to_dense" for arg in node.args): continue with fx_graph.inserting_before(node): - mkldnn_args = fx.map_arg(node.args, lambda n: fx_graph.call_method('to_mkldnn', (n, ))) + mkldnn_args = fx.map_arg( + node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,)) + ) node.args = cast(Tuple[fx.node.Argument], mkldnn_args) with fx_graph.inserting_after(node): - dense_x = fx_graph.create_node('call_method', 'to_dense', (node,)) + dense_x = fx_graph.create_node("call_method", "to_dense", (node,)) node.replace_all_uses_with(dense_x) dense_x.args = (node,) @@ -326,28 +380,26 @@ class MklSupport(Enum): # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b for node in fx_graph.nodes: - if node.op == 'call_method' and node.target == 'to_dense': + if node.op == "call_method" and node.target == "to_dense": prv_node = node.args[0] users = list(node.users) for user in users: - if user.op == 'call_method' and user.target == 'to_mkldnn': + if user.op == "call_method" and user.target == "to_mkldnn": user.replace_all_uses_with(prv_node) fx_graph.erase_node(user) if len(node.users) == 0: fx_graph.erase_node(node) - num_nodes = len(fx_graph.nodes) uf = UnionFind(num_nodes) def get_color(n): - if hasattr(n, 'color'): # Current node is part of a MKL subgraph + if hasattr(n, "color"): # Current node is part of a MKL subgraph return uf.find(n.color) - if hasattr(n, 'start_color'): # Current node is input to MKL subgraph + if hasattr(n, "start_color"): # Current node is input to MKL subgraph return uf.find(n.start_color) return None - # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists # of input nodes (which are only `to_mkldnn` calls), output nodes # (`to_dense` calls), and intermediate nodes, which are run entirely on @@ -360,14 +412,19 @@ def get_color(n): # nodes (i.e. colors), we need to join these 2 colors into 1. That's done # using a Disjoint Set Union. for cur_idx, node in enumerate(fx_graph.nodes): - if node.op == 'call_method' and node.target == 'to_mkldnn': + if node.op == "call_method" and node.target == "to_mkldnn": node.start_color = cur_idx uf.make_set(cur_idx) - elif node.op == 'call_method' and node.target == 'to_dense': + elif node.op == "call_method" and node.target == "to_dense": assert get_color(node.args[0]) is not None node.end_color = get_color(node.args[0]) else: - cur_colors = [get_color(i) for i in node.all_input_nodes if isinstance(i, fx.Node) if get_color(i) is not None] + cur_colors = [ + get_color(i) + for i in node.all_input_nodes + if isinstance(i, fx.Node) + if get_color(i) is not None + ] if len(cur_colors) == 0: continue @@ -377,17 +434,15 @@ def get_color(n): for other_color in cur_colors[1:]: uf.join(cur_colors[0], other_color) - mkldnn_graphs: Dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) for node in fx_graph.nodes: - if hasattr(node, 'color'): + if hasattr(node, "color"): mkldnn_graphs[uf.find(node.color)].nodes.append(node) - if hasattr(node, 'start_color'): + if hasattr(node, "start_color"): mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) - if hasattr(node, 'end_color'): + if hasattr(node, "end_color"): mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) - # Now that we have all the subgraphs, we need to decide which MKLDNN # subgraphs we actually want to keep in MKLDNN. for graph in mkldnn_graphs.values(): @@ -400,7 +455,7 @@ def get_color(n): mkldnn_conversions = 0 for node in fx_graph.nodes: - if node.target == 'to_mkldnn' or node.target == 'to_dense': + if node.target == "to_mkldnn" or node.target == "to_dense": mkldnn_conversions += 1 logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions) diff --git a/torch/fx/experimental/partitioner_utils.py b/torch/fx/experimental/partitioner_utils.py index 796c65a430228..e59921c58fa18 100644 --- a/torch/fx/experimental/partitioner_utils.py +++ b/torch/fx/experimental/partitioner_utils.py @@ -1,8 +1,8 @@ # mypy: allow-untyped-defs from enum import Enum -from typing import NamedTuple, Dict, List, Set +from typing import Dict, List, NamedTuple, Set -from torch.fx.node import Node, map_arg +from torch.fx.node import map_arg, Node class Partition: @@ -146,7 +146,7 @@ def get_top_nodes(partition: Partition) -> List[Node]: # this node is on the top bfs level in this partition if not any( n in partition.nodes and n.op not in {"placeholder", "get_attr"} - for n in input_nodes + for n in input_nodes ): top_nodes.append(node) return top_nodes @@ -282,7 +282,7 @@ def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: latency_so_far_sec += partition_to_latency_mapping[ partition ].overall_latency_sec - children = partition.children + if partition.children: max_latency_sec = 0.0 for child in partition.children: diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 88d86cb7838ba..02c0505bf0804 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -801,6 +801,10 @@ def can_handle_tensor(x: Tensor) -> bool: if r is not NotImplemented: return r + if func is torch.ops.aten.is_nonzero.default: + with proxy_mode: + return (args[0] != 0).item() # type: ignore[attr-defined] + tracer = proxy_mode.tracer f_flat_args_kwargs = [ ( @@ -1172,11 +1176,11 @@ def impure_pred(n: fx.Node) -> bool: def wrap_key( f: Callable[_P, R], tensors: _P.args, tracer: _ProxyTracer, pre_dispatch: bool ) -> Callable[_P, R]: - flat_tensors, tensors_spec = pytree.tree_flatten(tensors) + flat_tensors, _tensors_spec = pytree.tree_flatten(tensors) @functools.wraps(f) def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R: - flat_proxies, proxies_spec = pytree.tree_flatten(proxies) + flat_proxies, _proxies_spec = pytree.tree_flatten(proxies) assert len(flat_proxies) == len(flat_tensors) with disable_proxy_modes_tracing() as m: assert isinstance(m, ProxyTorchDispatchMode) @@ -1247,10 +1251,14 @@ def __torch_function__( class PreDispatchTorchFunctionMode(TorchFunctionMode): def __init__(self, tracer: _ProxyTracer) -> None: self.tracer = tracer + # The input to torch.amp.autocast_mode._exit_autocast graph node should be the + # enter_autocast node. So we have to save the enter autocast node here, and assign it + # to the exit_autocast call_function node. + self.enter_autocast_nodes: List[torch.fx.Node] = [] def __torch_function__( self, - func: OpOverload, + func: Union[OpOverload, Callable], types: Tuple[torch._C._TensorMeta, ...], args: Tuple[object, ...] = (), kwargs: Optional[Dict[str, object]] = None, @@ -1261,7 +1269,12 @@ def __torch_function__( # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, # instead of hardcoding it here. # T203648563 + if func == torch.amp.autocast_mode._exit_autocast: + enter_node = self.enter_autocast_nodes.pop() + args = (enter_node,) node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type] + if func == torch.amp.autocast_mode._enter_autocast: + self.enter_autocast_nodes.append(node) if func in [ torch._C._set_grad_enabled, torch.amp.autocast_mode._enter_autocast, @@ -1732,7 +1745,7 @@ def call_module( try: return Tracer.call_module(self, m, forward, args, kwargs) - except _ModuleNotInstalledAsSubmoduleError as e: + except _ModuleNotInstalledAsSubmoduleError: warnings.warn( f"Unable to find the path of the module {m}. " "This might be because the module was not properly registered " @@ -2214,7 +2227,14 @@ def maybe_handle_decomp( args: Tuple[object, ...], kwargs: Dict[str, object], ) -> object: + from torch._inductor.bisect_helper import BisectionManager + if op in CURRENT_DECOMPOSITION_TABLE: + if BisectionManager.disable_subsystem( + "aot_eager_decomp_partition", "decomposition", lambda: repr(op) + ): + return NotImplemented + with proxy_mode: proxy_mode.decomp_layers += 1 out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 0b6410be41c40..84a04acc2b996 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -331,7 +331,7 @@ def replay_shape_env_events(events): # We need to call create_mapping_fn every time, since the node list might # change after each event is replayed. event.run(shape_env) - except Exception as e: + except Exception: log.error("failed when running event: %s", event) raise diff --git a/torch/fx/experimental/refinement_types.py b/torch/fx/experimental/refinement_types.py index a33ddf3710a4a..4a262af8fad9f 100644 --- a/torch/fx/experimental/refinement_types.py +++ b/torch/fx/experimental/refinement_types.py @@ -5,10 +5,10 @@ def __init__(self, lhs, rhs): self.rhs = rhs def __str__(self): - return f'{self.lhs} = {self.rhs}' + return f"{self.lhs} = {self.rhs}" def __repr__(self): - return f'{self.lhs} = {self.rhs}' + return f"{self.lhs} = {self.rhs}" def __eq__(self, other): if isinstance(other, Equality): diff --git a/torch/fx/experimental/rewriter.py b/torch/fx/experimental/rewriter.py index 3647ca59153b4..76ec03f862898 100644 --- a/torch/fx/experimental/rewriter.py +++ b/torch/fx/experimental/rewriter.py @@ -1,16 +1,18 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs import ast -import inspect -import textwrap import copy import functools +import inspect +import textwrap from types import FunctionType -from typing import cast, Union, Callable, Dict, Optional, Any +from typing import Any, Callable, cast, Dict, Optional, Union + +import torch +from torch._sources import normalize_source_lines from torch.fx._symbolic_trace import Tracer from torch.fx.graph import Graph -from torch._sources import normalize_source_lines -import torch + class AST_Rewriter(ast.NodeTransformer): """ @@ -29,11 +31,10 @@ class AST_Rewriter(ast.NodeTransformer): # suitable for dynamo tracing anyways. @torch._dynamo.disable def rewrite(self, fn: FunctionType): - # Normalize the source lines sourcelines, _ = inspect.getsourcelines(fn) sourcelines = normalize_source_lines(sourcelines) - source = ''.join(sourcelines) + source = "".join(sourcelines) normalized_str = textwrap.dedent(source) # Rewrite the original AST @@ -64,6 +65,7 @@ def change_func_globals(f, globals): g = functools.update_wrapper(g, f) g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] return g + # Return the correct FunctionType object return change_func_globals(fn_compiled, globals=fn.__globals__) @@ -73,7 +75,7 @@ def visit_Assert(self, node): symbolically-traceable torch._assert function """ # Create the Call node - n = ast.parse('torch._assert()', mode='eval') + n = ast.parse("torch._assert()", mode="eval") assert isinstance(n, ast.Expression) call_node = n.body assert isinstance(call_node, ast.Call) @@ -96,13 +98,22 @@ def visit_AnnAssign(self, node): Output: y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) """ - return ast.Assign(targets=[node.target], value=ast.Call( - func=ast.Name(id='annotate', ctx=ast.Load()), - args=[node.value, node.annotation], keywords=[])) + return ast.Assign( + targets=[node.target], + value=ast.Call( + func=ast.Name(id="annotate", ctx=ast.Load()), + args=[node.value, node.annotation], + keywords=[], + ), + ) class RewritingTracer(Tracer): - def trace(self, root: Union[torch.nn.Module, Callable], concrete_args: Optional[Dict[str, Any]] = None) -> Graph: + def trace( + self, + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[Dict[str, Any]] = None, + ) -> Graph: return super().trace(_rewrite(root), concrete_args) @@ -111,7 +122,7 @@ def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Cal # Rewrite this module's `forward` as well as the `forward`s of # all of this module's recursive descendents. Return the new, # rewritten module hierarchy. - def rewrite_module(m : torch.nn.Module): + def rewrite_module(m: torch.nn.Module): class RewrittenModule(torch.nn.Module): def __init__(self, orig): super().__init__() @@ -120,8 +131,12 @@ def __init__(self, orig): self.__dict__[k] = copy.copy(rewrite_module(v)) else: self.__dict__[k] = copy.copy(v) - RewrittenModule.forward = AST_Rewriter().rewrite(cast(FunctionType, m.forward)) + + RewrittenModule.forward = AST_Rewriter().rewrite( + cast(FunctionType, m.forward) + ) return RewrittenModule(m) + return rewrite_module(fn) else: # Rewrite this single free function diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index 5c7ab78706cb9..519fec16cfc84 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -1,13 +1,14 @@ # mypy: allow-untyped-defs -import torch -import torch.fx import inspect from typing import Any, Dict, Optional, Tuple -from torch.fx.node import Argument, Target + +import torch +import torch.fx from torch._jit_internal import boolean_dispatched +from torch.fx import Transformer +from torch.fx.node import Argument, Target from torch.fx.operator_schemas import _torchscript_type_to_python_type -from torch.fx import Transformer class AnnotateTypesWithSchema(Transformer): """ @@ -27,16 +28,24 @@ class AnnotateTypesWithSchema(Transformer): traced = AnnotateTypesWithSchema(traced).transform() """ - def __init__(self, module : torch.nn.Module, annotate_functionals : bool = True, - annotate_modules : bool = True, annotate_get_attrs : bool = True): + + def __init__( + self, + module: torch.nn.Module, + annotate_functionals: bool = True, + annotate_modules: bool = True, + annotate_get_attrs: bool = True, + ): super().__init__(module) self.annotate_functionals = annotate_functionals self.annotate_modules = annotate_modules self.annotate_get_attrs = annotate_get_attrs - def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + def call_function( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): python_ret_type = None - if self.annotate_functionals and target.__module__ == 'torch.nn.functional': + if self.annotate_functionals and target.__module__ == "torch.nn.functional": target_for_analysis = target if target in boolean_dispatched: # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have @@ -45,51 +54,71 @@ def call_function(self, target : Target, args : Tuple[Argument, ...], kwargs : D # branch signature for analysis. Otherwise, leave this un-normalized assert not isinstance(target, str) dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] + if_true, if_false = dispatched["if_true"], dispatched["if_false"] # TODO: can we emit the union of these? What are the implications on TorchScript # compilation? - if inspect.signature(if_true).return_annotation != inspect.signature(if_false).return_annotation: + if ( + inspect.signature(if_true).return_annotation + != inspect.signature(if_false).return_annotation + ): return super().call_function(target, args, kwargs) target_for_analysis = if_true python_ret_type = self._extract_python_return_type(target_for_analysis) return_proxy = super().call_function(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) return return_proxy - def call_module(self, target : Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + def call_module( + self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ): python_ret_type = None assert isinstance(target, str) submod = self.fetch_attr(target) - if self.annotate_modules and hasattr(submod.__class__, '__name__'): + if self.annotate_modules and hasattr(submod.__class__, "__name__"): classname = submod.__class__.__name__ if getattr(torch.nn, classname, None) == submod.__class__: python_ret_type = self._extract_python_return_type(submod.forward) return_proxy = super().call_module(target, args, kwargs) - return_proxy.node.type = return_proxy.node.type if return_proxy.node.type else python_ret_type + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) return return_proxy - def get_attr(self, target : torch.fx.node.Target, args : Tuple[Argument, ...], kwargs : Dict[str, Any]): + def get_attr( + self, + target: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + ): attr_proxy = super().get_attr(target, args, kwargs) if self.annotate_get_attrs: module_itr = self.module assert isinstance(target, str) - atoms = target.split('.') + atoms = target.split(".") for i, atom in enumerate(atoms): if not hasattr(module_itr, atom): - raise RuntimeError(f'Node referenced nonextent target {".".join(atoms[:i])}!') + raise RuntimeError( + f'Node referenced nonextent target {".".join(atoms[:i])}!' + ) module_itr = getattr(module_itr, atom) maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) if maybe_inferred_ts_type.success(): - python_type = _torchscript_type_to_python_type(maybe_inferred_ts_type.type()) - attr_proxy.node.type = python_type if not attr_proxy.node.type else attr_proxy.node.type + python_type = _torchscript_type_to_python_type( + maybe_inferred_ts_type.type() + ) + attr_proxy.node.type = ( + python_type if not attr_proxy.node.type else attr_proxy.node.type + ) return attr_proxy - def _extract_python_return_type(self, target : Target) -> Optional[Any]: + def _extract_python_return_type(self, target: Target) -> Optional[Any]: """ Given a Python call target, try to extract the Python return annotation if it is available, otherwise return None @@ -109,4 +138,8 @@ def _extract_python_return_type(self, target : Target) -> Optional[Any]: except (ValueError, TypeError): return None - return sig.return_annotation if sig.return_annotation is not inspect.Signature.empty else None + return ( + sig.return_annotation + if sig.return_annotation is not inspect.Signature.empty + else None + ) diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 8c8af90ee5bff..c30cab7431c48 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1096,7 +1096,6 @@ def binary_magic_impl(self, other): get_proxy_mode, handle_sym_dispatch, ) - from torch.fx.experimental.symbolic_shapes import safe_expand op = method_to_operator(method) @@ -1136,7 +1135,6 @@ def binary_magic_impl(self, other): except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise - out = safe_expand(out) sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) pytype: Type # This is not strictly correct. In Python, a**b may return complex when @@ -1174,7 +1172,6 @@ def unary_magic_impl(self): get_proxy_mode, handle_sym_dispatch, ) - from torch.fx.experimental.symbolic_shapes import safe_expand op = method_to_operator(method) if get_proxy_mode(): @@ -1193,7 +1190,6 @@ def unary_magic_impl(self): out_hint = None if self.hint is not None: out_hint = op(self.hint) - out = safe_expand(out) pytype: Type if method in always_int_magic_methods: pytype = int @@ -1216,7 +1212,6 @@ def sym_ite_impl(pred_node, then_node, else_node): get_proxy_mode, handle_sym_dispatch, ) - from torch.fx.experimental.symbolic_shapes import safe_expand out_hint = then_node.hint if pred_node.hint else else_node.hint if get_proxy_mode(): @@ -1245,7 +1240,6 @@ def sym_ite_impl(pred_node, then_node, else_node): ) raise - out = safe_expand(out) fx_node, _ = pred_node.shape_env._create_fx_call_function( sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) ) @@ -1261,7 +1255,6 @@ def round_impl(self, ndigits=None): get_proxy_mode, handle_sym_dispatch, ) - from torch.fx.experimental.symbolic_shapes import safe_expand op = builtins.round if get_proxy_mode(): @@ -1276,8 +1269,6 @@ def round_impl(self, ndigits=None): log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise - out = safe_expand(out) - if ndigits is None: pytype = int else: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 3fc8d2ad8f547..83c651e29c585 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -159,6 +159,7 @@ class PendingUnbackedSymbolNotFound(RuntimeError): "resolve_unbacked_bindings", "is_accessor_node", "ValueRangesSLoc", + "SymIntEqByExpr", ] # FX node metadata keys for symbolic shape FX graph. @@ -193,6 +194,55 @@ def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None _SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic) +class SymIntEqByExpr: + """ + This is a wrapper around SymInt which has alternative semantics for + equality. Specifically, instead of erroring or guarding, we + instead will hash/compare equality based on the underlying sympy + expression; e.g., s0 and s1 will always compare as False. + + NB: This does NOT do fancy analysis that maybe_evaluate_static does; + we can only reason through equalities that occur because to expressions + canonicalize to the same expression via regular simplification. + """ + + val: Union[torch.SymInt, int] + + def __init__(self, val: Union[torch.SymInt, int]) -> None: + self.val = val + + def __repr__(self) -> str: + return repr(self.val) + + def _extract(self) -> sympy.Expr: + if isinstance(self.val, torch.SymInt): + return self.val.node.expr + else: + return sympy.Integer(self.val) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SymIntEqByExpr) + + # int equality fastpath + if type(self.val) is int and type(other.val) is int: + return self.val == other.val + + return self._extract() == other._extract() + + def __hash__(self) -> int: + return hash(self._extract()) + + +def _nested_int_aware_sort(tup: Tuple[Union[SymInt, int], int]) -> Tuple[int, int, int]: + return ( + # Order nested ints by their coefficients. + # 1 here to order nested ints after non-nested-ints. + (1, tup[0].node.nested_int_coeff(), tup[1]) + if is_nested_int(tup[0]) + else (0, *tup) + ) + + # Wrapper on lru_cache that reports statistics at process end def lru_cache( maxsize: Optional[int], @@ -894,10 +944,10 @@ def free_unbacked_symbols_with_path( and rhs in pending ): # TODO: DivideByKey needs to test divisibility at runtime! - r[s] = path + (DivideByKey(int(lhs)),) + r[rhs] = path + (DivideByKey(int(lhs)),) if real is not None: assert isinstance(real, int) - shape_env.set_unbacked_var_to_val(s, real // int(lhs)) + shape_env.set_unbacked_var_to_val(rhs, real // int(lhs)) pending.remove(rhs) # The annoyance here arises from the fact that SymBool is # allocated by allocating a SymInt and then testing if it's equal @@ -974,11 +1024,10 @@ def definitely_true(a: BoolLikeType) -> bool: that would cause the expression to return True. When is it appropriate to use definitely_true? First, if you can use - a higher level combinator like parallel_or/parallel_and, prefer using - those instead, they are definitely safe (modulo short-circuiting). + a higher level combinator prefer using those instead, they are definitely + safe (modulo short-circuiting). Second, it can be used if the program would behave equivalently if - definitely_true always returned False (parallel_or/parallel_and are - examples of this pattern, modulo short-circuiting). Finally, it even + definitely_true always returned False. Finally, it even be OK if the program wouldn't behave equivalently, so long as the change is semantics preserving. It can be semantics preserving if the program errors in more cases than it did previously (but otherwise @@ -1034,30 +1083,6 @@ def statically_known_true(x: Union[bool, SymBool]) -> bool: return x -def parallel_or(*args: BoolLikeType) -> BoolLikeType: - """ - Evaluate the logical OR of several arguments, avoiding guarding on - unbacked SymInts if another argument is definitely True. - """ - if any(statically_known_true(a) for a in args): - return True - if any(definitely_true(a) for a in args): - return True - return any(args) - - -def parallel_and(*args: BoolLikeType) -> BoolLikeType: - """ - Evaluate the logical FALSE of several arguments, avoiding guarding on - unbacked SymInts if another argument is definitely False. - """ - if any(statically_known_true(torch.sym_not(a)) for a in args): - return False - if any(definitely_false(a) for a in args): - return False - return all(args) - - def sym_eq(x: _T, y: _T) -> Union[bool, SymBool]: """ Like ==, but when run on list/tuple, it will recursively test equality @@ -1147,7 +1172,7 @@ def _constrain_range_for_size( raise ValueError("Constraining SymFloat/SymBool is nyi") assert isinstance(a, SymInt), "can only constrain range for SymInt" - assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert isinstance(a.node.expr, sympy.Symbol), f"constraining non-Symbols NYI: {a}" a.node.shape_env._constrain_range_for_size(a.node.expr, min, max) @@ -1450,6 +1475,7 @@ class EqualityConstraint(Constraint): Tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]] ] phantom_symbols: List[sympy.Symbol] + relaxed_sources: Set[Source] _parents: Dict[Source, Source] = field(init=False) _defs: Dict[Source, sympy.Expr] = field(init=False) @@ -1515,10 +1541,12 @@ def _rewrite(self, src: Source) -> sympy.Expr: def is_equal(self, source1: Source, source2: Source) -> bool: return ( # check whether source1 and source2 have the same root - self._find(source1) == self._find(source2) - or + # or are relaxed + (src1 := self._find(source1)) in self.relaxed_sources + or (src2 := self._find(source2)) in self.relaxed_sources + or src1 == src2 # check whether source1 is derived equal to source2 - self.is_derived(source1, source2, lambda x: x) + or self.is_derived(source1, source2, lambda x: x) ) def is_derived( @@ -1754,6 +1782,15 @@ def _fast_expand(expr: _SympyT) -> _SympyT: @lru_cache(256) def safe_expand(r: _SympyT) -> _SympyT: + """ + Expand the given symbolic expression by recursively rewriting product of + sums into sum of products (with the product being either a multiplication or + exponentiation). + + NOTE: using this on an intermediate expression may prevent simplification + down the line, e.g., if we eagerly expand `(a + b)^2` into `a^2 + 2ab + b^2`, + we won't be able to simplify `(a^2 + 2ab + b^2) / (a + b)` as easily. + """ if hasattr(r, "expand"): try: return _fast_expand(r) @@ -2732,7 +2769,10 @@ def relation_with_digit(expr: str, op: str, digit: int) -> None: relation_with_digit(right, flip(op), int(left)) else: assert op == "==", t - results[left]["eq"] = sympy.sympify(right) + try: + results[left]["eq"] = sympy.sympify(right) + except TypeError as e: # rhs source is not linked to Dim name + pass # order forced specializations based on name forced_specializations = { @@ -3035,7 +3075,7 @@ def _init( # deferred_runtime_asserts to compute its length self.num_deferred_runtime_asserts = 0 self.log = log - self.log.debug("create_env") + self.log.info("create_env") self.frozen = False self.runtime_asserts_frozen = False self.dim_constraints: Optional[DimConstraints] = None @@ -3360,12 +3400,10 @@ def _ignore_fresh_unbacked_symbols_tls(self) -> bool: return getattr(TLS, "ignore_fresh_unbacked_symbols", False) @record_shapeenv_event() - def _ignore_fresh_unbacked_symbols_enter(self) -> None: - TLS.ignore_fresh_unbacked_symbols = True - - @record_shapeenv_event() - def _ignore_fresh_unbacked_symbols_exit(self) -> None: - TLS.ignore_fresh_unbacked_symbols = False + def _ignore_fresh_unbacked_symbols_set(self, b: bool) -> bool: + prev = self._ignore_fresh_unbacked_symbols_tls() + TLS.ignore_fresh_unbacked_symbols = b + return prev @contextmanager def ignore_fresh_unbacked_symbols(self) -> Iterator[None]: @@ -3373,11 +3411,11 @@ def ignore_fresh_unbacked_symbols(self) -> Iterator[None]: Indicates that the newly allocated unbacked SymInts are being discarded """ - self._ignore_fresh_unbacked_symbols_enter() + prev = self._ignore_fresh_unbacked_symbols_set(True) try: yield finally: - self._ignore_fresh_unbacked_symbols_exit() + self._ignore_fresh_unbacked_symbols_set(prev) @record_shapeenv_event() def freeze(self) -> None: @@ -3739,21 +3777,10 @@ def _create_symbolic_sizes_strides_storage_offset( candidates = { ex_size[i] * ex_stride[i]: size[i] * stride[i] for i in range(len(size)) - if stride[i] is not None and ex_stride[i] >= 0 + if stride[i] is not None } # iterate over unbound strides in sorted order - def _nested_int_aware_sort( - tup: Tuple[Union[SymInt, int], int] - ) -> Tuple[int, int, int]: - return ( - # Order nested ints by their coefficients. - # 1 here to order nested ints after non-nested-ints. - (1, tup[0].node.nested_int_coeff(), tup[1]) - if is_nested_int(tup[0]) - else (0, *tup) - ) - val_list = sorted( [(ex_stride[i], i) for i in range(len(stride)) if stride[i] is None], key=_nested_int_aware_sort, @@ -3841,8 +3868,6 @@ def create_symintnode( guess """ - source_name = source.name() if source else None - if self._translation_validation_enabled and source is not None: # Create a new symbol for this source. symbol = self._create_symbol_for_source(source) @@ -3881,8 +3906,6 @@ def create_symfloatnode( source: Optional[Source] = None, ) -> Union[float, SymFloat]: """Create a SymFloat value from a symbolic expression""" - source_name = source.name() if source else None - if self._translation_validation_enabled and source is not None: # Create a new symbol for this source. symbol = self._create_symbol_for_source(source) @@ -4770,7 +4793,7 @@ def track_symfloat(source: Source, val: Union[float, SymFloat]) -> None: res = f"{source_ref(source)} == {sexpr}" exprs.append(res) if (s0 := self.source_to_var.get(srcname)) is not None: - if source != (canonical_source := self.var_to_sources[s0][0]): + if source != self.var_to_sources[s0][0]: verbose_exprs.append( f"{res} # duck sizing added this equality because these " f"variables had the same size {self.var_to_val[s0]} " @@ -5361,6 +5384,7 @@ def _update_divisible(self) -> None: @_lru_cache def simplify(self, expr: _SympyT) -> _SympyT: """Use known constraints and replacements to simplify the given expr""" + expr = safe_expand(expr) expr = self.replace(expr) # TODO it would seem that this pass is not necessary given the # below replacement of // with /, but for nested FloorDivs @@ -5525,20 +5549,19 @@ def _update_var_to_range( # because we would now give inconsistent results for all size # oblivous tests! if upper < 2 and symbol in self.size_like: - upper = 2 + vr = ValueRanges(lower, 2) # Updates the range and the guards corresponding to each bound of the symbol. if symbol not in self.var_to_range: - r = ValueRanges(lower, upper) - self.log.debug("_update_var_to_range %s = %s (new)", symbol, r) - self.var_to_range[symbol] = r + self.log.debug("_update_var_to_range %s = %s (new)", symbol, vr) + self.var_to_range[symbol] = vr if vr_sloc is None: sloc = self._get_sloc() vr_sloc = ValueRangesSLoc(sloc, sloc) self.var_to_range_sloc[symbol] = vr_sloc else: old = self.var_to_range[symbol] - new = old & ValueRanges(lower, upper) + new = old & vr if new != old: if vr_sloc is None: sloc = self._get_sloc() @@ -5894,7 +5917,7 @@ def _default_value_range(self) -> ValueRanges: return ValueRanges(lower, int_oo) def _default_unspecified_value_range(self) -> ValueRanges: - return ValueRanges(-int_oo, int_oo) + return ValueRanges.unknown_int() @_lru_cache def _simplify_floor_div(self, expr: sympy.Expr) -> sympy.Expr: @@ -6103,7 +6126,6 @@ def compute_concrete_val() -> sympy.Basic: # If an error is raised before the end of this function, we remove the FX node # inserted, and re-raise the error. guard = None - tb = None try: if orig_expr.is_number: @@ -6267,6 +6289,7 @@ def cleanup(self) -> None: for ra in ras: ra.stack.cleanup() + @lru_cache(256) @record_shapeenv_event(save_tracked_fakes=True) def defer_runtime_assert( self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None @@ -6304,7 +6327,6 @@ def defer_runtime_assert( # NB: Don't use new_expr as expr; it could contain gunk like shape0 # which we don't want to guard on - # OK, we're definitely doing a runtime assert now if ( self._translation_validation_enabled and fx_node is not None @@ -6318,10 +6340,9 @@ def defer_runtime_assert( if not self._suppress_guards_tls(): # If you're here because of this assert, read Note [Backwards runtime asserts] # in torch/_inductor/graph.py - assert not self.runtime_asserts_frozen, expr - + if self.runtime_asserts_frozen: + log.warning("runtime_asserts_frozen but then got %s", expr) self._check_frozen(expr, sympy.true) - # eliminate symbols on equality tests / refine ranges if isinstance(expr, sympy.Rel): self._maybe_guard_rel(expr) diff --git a/torch/fx/experimental/unification/__init__.py b/torch/fx/experimental/unification/__init__.py index 31446d0e61253..7db0e29d1d4f7 100644 --- a/torch/fx/experimental/unification/__init__.py +++ b/torch/fx/experimental/unification/__init__.py @@ -1,4 +1,4 @@ # mypy: disable-error-code=attr-defined -from .core import unify, reify # noqa: F403 +from .core import reify, unify # noqa: F403 from .more import unifiable # noqa: F403 -from .variable import var, isvar, vars, variables, Var # noqa: F403 +from .variable import isvar, Var, var, variables, vars # noqa: F403 diff --git a/torch/fx/experimental/unification/core.py b/torch/fx/experimental/unification/core.py index 0893c385bbc9a..e32f42c8968e8 100644 --- a/torch/fx/experimental/unification/core.py +++ b/torch/fx/experimental/unification/core.py @@ -2,10 +2,11 @@ from collections.abc import Iterator # type: ignore[import] from functools import partial +from .dispatch import dispatch from .unification_tools import assoc # type: ignore[import] from .utils import transitive_get as walk from .variable import isvar -from .dispatch import dispatch + __all__ = ["reify", "unify"] @@ -13,33 +14,47 @@ # Reification # ############### + @dispatch(Iterator, dict) def _reify(t, s): return map(partial(reify, s=s), t) # return (reify(arg, s) for arg in t) + + _reify + @dispatch(tuple, dict) # type: ignore[no-redef] def _reify(t, s): return tuple(reify(iter(t), s)) + + _reify + @dispatch(list, dict) # type: ignore[no-redef] def _reify(t, s): return list(reify(iter(t), s)) + + _reify + @dispatch(dict, dict) # type: ignore[no-redef] def _reify(d, s): return {k: reify(v, s) for k, v in d.items()} + + _reify + @dispatch(object, dict) # type: ignore[no-redef] def _reify(o, s): return o # catch all, just return the object + def reify(e, s): - """ Replace variables of expression with substitution + """Replace variables of expression with substitution >>> # xdoctest: +SKIP >>> x, y = var(), var() >>> e = (1, x, (3, y)) @@ -54,12 +69,14 @@ def reify(e, s): return reify(s[e], s) if e in s else e return _reify(e, s) + ############### # Unification # ############### seq = tuple, list, Iterator + @dispatch(seq, seq, dict) def _unify(u, v, s): if len(u) != len(v): @@ -69,6 +86,8 @@ def _unify(u, v, s): if s is False: return False return s + + # # @dispatch((set, frozenset), (set, frozenset), dict) # def _unify(u, v, s): @@ -98,8 +117,8 @@ def _unify(u, v, s): @dispatch(object, object, dict) def unify(u, v, s): # no check at the moment - """ Find substitution so that u == v while satisfying s - >>> x = var('x') + """Find substitution so that u == v while satisfying s + >>> x = var("x") >>> unify((1, x), (1, 2), {}) {~x: 2} """ @@ -112,8 +131,11 @@ def unify(u, v, s): # no check at the moment if isvar(v): return assoc(s, v, u) return _unify(u, v, s) + + unify + @dispatch(object, object) # type: ignore[no-redef] def unify(u, v): return unify(u, v, {}) diff --git a/torch/fx/experimental/unification/dispatch.py b/torch/fx/experimental/unification/dispatch.py index 93039ce75070f..82d62e1f16197 100644 --- a/torch/fx/experimental/unification/dispatch.py +++ b/torch/fx/experimental/unification/dispatch.py @@ -1,6 +1,8 @@ from functools import partial + from .multipledispatch import dispatch # type: ignore[import] + namespace = {} # type: ignore[var-annotated] dispatch = partial(dispatch, namespace=namespace) diff --git a/torch/fx/experimental/unification/match.py b/torch/fx/experimental/unification/match.py index 96583ef324ded..01861a086f64b 100644 --- a/torch/fx/experimental/unification/match.py +++ b/torch/fx/experimental/unification/match.py @@ -1,8 +1,8 @@ # mypy: allow-untyped-defs -from .core import unify, reify # type: ignore[attr-defined] -from .variable import isvar +from .core import reify, unify # type: ignore[attr-defined] +from .unification_tools import first, groupby # type: ignore[import] from .utils import _toposort, freeze -from .unification_tools import groupby, first # type: ignore[import] +from .variable import isvar class Dispatcher: @@ -16,7 +16,7 @@ def add(self, signature, func): self.ordering = ordering(self.funcs) def __call__(self, *args, **kwargs): - func, s = self.resolve(args) + func, _ = self.resolve(args) return func(*args, **kwargs) def resolve(self, args): @@ -28,32 +28,38 @@ def resolve(self, args): if s is not False: result = self.funcs[signature] return result, s - raise NotImplementedError("No match found. \nKnown matches: " - + str(self.ordering) + "\nInput: " + str(args)) + raise NotImplementedError( + "No match found. \nKnown matches: " + + str(self.ordering) + + "\nInput: " + + str(args) + ) def register(self, *signature): def _(func): self.add(signature, func) return self + return _ class VarDispatcher(Dispatcher): - """ A dispatcher that calls functions with variable names + """A dispatcher that calls functions with variable names >>> # xdoctest: +SKIP - >>> d = VarDispatcher('d') - >>> x = var('x') - >>> @d.register('inc', x) + >>> d = VarDispatcher("d") + >>> x = var("x") + >>> @d.register("inc", x) ... def f(x): ... return x + 1 - >>> @d.register('double', x) + >>> @d.register("double", x) ... def f(x): ... return x * 2 - >>> d('inc', 10) + >>> d("inc", 10) 11 - >>> d('double', 10) + >>> d("double", 10) 20 """ + def __call__(self, *args, **kwargs): func, s = self.resolve(args) d = {k.token: v for k, v in s.items()} @@ -64,8 +70,8 @@ def __call__(self, *args, **kwargs): def match(*signature, **kwargs): - namespace = kwargs.get('namespace', global_namespace) - dispatcher = kwargs.get('Dispatcher', Dispatcher) + namespace = kwargs.get("namespace", global_namespace) + dispatcher = kwargs.get("Dispatcher", Dispatcher) def _(func): name = func.__name__ @@ -77,11 +83,12 @@ def _(func): d.add(signature, func) return d + return _ def supercedes(a, b): - """ ``a`` is a more specific match than ``b`` """ + """``a`` is a more specific match than ``b``""" if isvar(b) and not isvar(a): return True s = unify(a, b) @@ -96,7 +103,7 @@ def supercedes(a, b): # Taken from multipledispatch def edge(a, b, tie_breaker=hash): - """ A should be checked before B + """A should be checked before B Tie broken by tie_breaker, defaults to ``hash`` """ if supercedes(a, b): @@ -109,7 +116,7 @@ def edge(a, b, tie_breaker=hash): # Taken from multipledispatch def ordering(signatures): - """ A sane ordering of signatures to check, first to last + """A sane ordering of signatures to check, first to last Topological sort of edges as given by ``edge`` and ``supercedes`` """ signatures = list(map(tuple, signatures)) diff --git a/torch/fx/experimental/unification/more.py b/torch/fx/experimental/unification/more.py index 2228448a71a1f..da2b1773f95ba 100644 --- a/torch/fx/experimental/unification/more.py +++ b/torch/fx/experimental/unification/more.py @@ -1,10 +1,10 @@ # mypy: allow-untyped-defs -from .core import unify, reify # type: ignore[attr-defined] +from .core import reify, unify # type: ignore[attr-defined] from .dispatch import dispatch def unifiable(cls): - """ Register standard unify and reify operations on class + """Register standard unify and reify operations on class This uses the type and __dict__ or __slots__ attributes to define the nature of the term See Also: @@ -15,7 +15,7 @@ def unifiable(cls): ... self.b = b >>> unifiable(A) - >>> x = var('x') + >>> x = var("x") >>> a = A(1, 2) >>> b = A(1, x) >>> unify(a, b, {}) @@ -33,22 +33,23 @@ def unifiable(cls): def reify_object(o, s): - """ Reify a Python object with a substitution + """Reify a Python object with a substitution >>> # xdoctest: +SKIP >>> class Foo(object): ... def __init__(self, a, b): ... self.a = a ... self.b = b + ... ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> x = var('x') + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") >>> f = Foo(1, x) >>> print(f) Foo(1, ~x) >>> print(reify_object(f, {x: 2})) Foo(1, 2) """ - if hasattr(o, '__slots__'): + if hasattr(o, "__slots__"): return _reify_object_slots(o, s) else: return _reify_object_dict(o, s) @@ -77,7 +78,7 @@ def _reify_object_slots(o, s): @dispatch(slice, dict) def _reify(o, s): - """ Reify a Python ``slice`` object """ + """Reify a Python ``slice`` object""" return slice(*reify((o.start, o.stop, o.step), s)) @@ -87,16 +88,17 @@ def _reify(o, s): def unify_object(u, v, s): - """ Unify two Python objects + """Unify two Python objects Unifies their type and ``__dict__`` attributes >>> # xdoctest: +SKIP >>> class Foo(object): ... def __init__(self, a, b): ... self.a = a ... self.b = b + ... ... def __str__(self): - ... return "Foo(%s, %s)"%(str(self.a), str(self.b)) - >>> x = var('x') + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") >>> f = Foo(1, x) >>> g = Foo(1, 2) >>> unify_object(f, g, {}) @@ -104,15 +106,17 @@ def unify_object(u, v, s): """ if type(u) != type(v): return False - if hasattr(u, '__slots__'): - return unify([getattr(u, slot) for slot in u.__slots__], - [getattr(v, slot) for slot in v.__slots__], - s) + if hasattr(u, "__slots__"): + return unify( + [getattr(u, slot) for slot in u.__slots__], + [getattr(v, slot) for slot in v.__slots__], + s, + ) else: return unify(u.__dict__, v.__dict__, s) @dispatch(slice, slice, dict) def _unify(u, v, s): - """ Unify a Python ``slice`` object """ + """Unify a Python ``slice`` object""" return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/torch/fx/experimental/unification/multipledispatch/__init__.py b/torch/fx/experimental/unification/multipledispatch/__init__.py index a0295af0ea6b6..bb7304069243f 100644 --- a/torch/fx/experimental/unification/multipledispatch/__init__.py +++ b/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -1,3 +1,7 @@ from .core import dispatch -from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, - MDNotImplementedError) +from .dispatcher import ( + Dispatcher, + halt_ordering, + MDNotImplementedError, + restart_ordering, +) diff --git a/torch/fx/experimental/unification/multipledispatch/conflict.py b/torch/fx/experimental/unification/multipledispatch/conflict.py index 7187330ead257..44a893ad56a40 100644 --- a/torch/fx/experimental/unification/multipledispatch/conflict.py +++ b/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -1,17 +1,28 @@ # mypy: allow-untyped-defs +import operator + from .utils import _toposort, groupby from .variadic import isvariadic -import operator -__all__ = ["AmbiguityWarning", "supercedes", "consistent", "ambiguous", "ambiguities", "super_signature", - "edge", "ordering"] + +__all__ = [ + "AmbiguityWarning", + "supercedes", + "consistent", + "ambiguous", + "ambiguities", + "super_signature", + "edge", + "ordering", +] + class AmbiguityWarning(Warning): pass def supercedes(a, b): - """ A is consistent and strictly more specific than B """ + """A is consistent and strictly more specific than B""" if len(a) < len(b): # only case is if a is empty and b is variadic return not a and len(b) == 1 and isvariadic(b[-1]) @@ -41,7 +52,7 @@ def supercedes(a, b): def consistent(a, b): - """ It is possible for an argument list to satisfy both A and B """ + """It is possible for an argument list to satisfy both A and B""" # Need to check for empty args if not a: @@ -51,8 +62,7 @@ def consistent(a, b): # Non-empty args check for mutual subclasses if len(a) == len(b): - return all(issubclass(aa, bb) or issubclass(bb, aa) - for aa, bb in zip(a, b)) + return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) else: p1 = 0 p2 = 0 @@ -70,45 +80,53 @@ def consistent(a, b): p1 += 1 # We only need to check for variadic ends # Variadic types are guaranteed to be the last element - return (isvariadic(cur_a) and p2 == len(b) or # type: ignore[possibly-undefined] - isvariadic(cur_b) and p1 == len(a)) # type: ignore[possibly-undefined] + return ( + isvariadic(cur_a) # type: ignore[possibly-undefined] + and p2 == len(b) + or isvariadic(cur_b) # type: ignore[possibly-undefined] + and p1 == len(a) + ) def ambiguous(a, b): - """ A is consistent with B but neither is strictly more specific """ + """A is consistent with B but neither is strictly more specific""" return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) def ambiguities(signatures): - """ All signature pairs such that A is ambiguous with B """ + """All signature pairs such that A is ambiguous with B""" signatures = list(map(tuple, signatures)) - return {(a, b) for a in signatures for b in signatures - if hash(a) < hash(b) - and ambiguous(a, b) - and not any(supercedes(c, a) and supercedes(c, b) - for c in signatures)} + return { + (a, b) + for a in signatures + for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) + } def super_signature(signatures): - """ A signature that would break ambiguities """ + """A signature that would break ambiguities""" n = len(signatures[0]) assert all(len(s) == n for s in signatures) - return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] - for i in range(n)] + return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)] def edge(a, b, tie_breaker=hash): - """ A should be checked before B + """A should be checked before B Tie broken by tie_breaker, defaults to ``hash`` """ # A either supercedes B and B does not supercede A or if B does then call # tie_breaker - return supercedes(a, b) and (not supercedes(b, a) or tie_breaker(a) > tie_breaker(b)) + return supercedes(a, b) and ( + not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) + ) def ordering(signatures): - """ A sane ordering of signatures to check, first to last + """A sane ordering of signatures to check, first to last Topological sort of edges as given by ``edge`` and ``supercedes`` """ signatures = list(map(tuple, signatures)) diff --git a/torch/fx/experimental/unification/multipledispatch/core.py b/torch/fx/experimental/unification/multipledispatch/core.py index 5b5bdbc963014..57a0eadaae157 100644 --- a/torch/fx/experimental/unification/multipledispatch/core.py +++ b/torch/fx/experimental/unification/multipledispatch/core.py @@ -4,12 +4,14 @@ from .dispatcher import Dispatcher, MethodDispatcher + global_namespace = {} # type: ignore[var-annotated] __all__ = ["dispatch", "ismethod"] + def dispatch(*types, **kwargs): - """ Dispatch function on the types of the inputs + """Dispatch function on the types of the inputs Supports dispatch on all non-keyword arguments. Collects implementations based on the function name. Ignores namespaces. If ambiguous type signatures occur a warning is raised when the function is @@ -38,6 +40,7 @@ def dispatch(*types, **kwargs): ... @dispatch(list) ... def __init__(self, data): ... self.data = data + ... ... @dispatch(int) ... def __init__(self, datum): ... self.data = [datum] @@ -46,7 +49,7 @@ def dispatch(*types, **kwargs): >>> MyClass(3).data [3] """ - namespace = kwargs.get('namespace', global_namespace) + namespace = kwargs.get("namespace", global_namespace) types = tuple(types) @@ -65,20 +68,21 @@ def _df(func): dispatcher.add(types, func) return dispatcher + return _df def ismethod(func): - """ Is func a method? + """Is func a method? Note that this has to work as the method is defined but before the class is defined. At this stage methods look like functions. """ if hasattr(inspect, "signature"): signature = inspect.signature(func) - return signature.parameters.get('self', None) is not None + return signature.parameters.get("self", None) is not None else: if sys.version_info.major < 3: spec = inspect.getargspec(func) # type: ignore[attr-defined] else: spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] - return spec and spec.args and spec.args[0] == 'self' + return spec and spec.args and spec.args[0] == "self" diff --git a/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/torch/fx/experimental/unification/multipledispatch/dispatcher.py index a1d28201d0419..4f160995cce0a 100644 --- a/torch/fx/experimental/unification/multipledispatch/dispatcher.py +++ b/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -1,21 +1,35 @@ # mypy: allow-untyped-defs -from warnings import warn import inspect +import itertools as itl from typing_extensions import deprecated -from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning +from warnings import warn + +from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature from .utils import expand_tuples -from .variadic import Variadic, isvariadic -import itertools as itl +from .variadic import isvariadic, Variadic + + +__all__ = [ + "MDNotImplementedError", + "ambiguity_warn", + "halt_ordering", + "restart_ordering", + "variadic_signature_matches_iter", + "variadic_signature_matches", + "Dispatcher", + "source", + "MethodDispatcher", + "str_signature", + "warning_text", +] -__all__ = ["MDNotImplementedError", "ambiguity_warn", "halt_ordering", "restart_ordering", "variadic_signature_matches_iter", - "variadic_signature_matches", "Dispatcher", "source", "MethodDispatcher", "str_signature", "warning_text"] class MDNotImplementedError(NotImplementedError): - """ A NotImplementedError for multiple dispatch """ + """A NotImplementedError for multiple dispatch""" def ambiguity_warn(dispatcher, ambiguities): - """ Raise warning when ambiguity is detected + """Raise warning when ambiguity is detected Parameters ---------- dispatcher : Dispatcher @@ -92,7 +106,7 @@ def variadic_signature_matches(types, full_signature): class Dispatcher: - """ Dispatch methods based on type signature + """Dispatch methods based on type signature Use ``dispatch`` to add implementations Examples -------- @@ -109,7 +123,8 @@ class Dispatcher: >>> f(3.0) 2.0 """ - __slots__ = '__name__', 'name', 'funcs', '_ordering', '_cache', 'doc' + + __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc" def __init__(self, name, doc=None): self.name = self.__name__ = name @@ -119,9 +134,9 @@ def __init__(self, name, doc=None): self._cache = {} def register(self, *types, **kwargs): - """ register dispatcher with new implementation + """register dispatcher with new implementation >>> # xdoctest: +SKIP - >>> f = Dispatcher('f') + >>> f = Dispatcher("f") >>> @f.register(int) ... def inc(x): ... return x + 1 @@ -139,9 +154,11 @@ def register(self, *types, **kwargs): >>> f([1, 2, 3]) [3, 2, 1] """ + def _df(func): - self.add(types, func, **kwargs) # type: ignore[call-arg] + self.add(types, func, **kwargs) # type: ignore[call-arg] return func + return _df @classmethod @@ -152,28 +169,27 @@ def get_func_params(cls, func): @classmethod def get_func_annotations(cls, func): - """ get annotations of function positional parameters - """ + """get annotations of function positional parameters""" params = cls.get_func_params(func) if params: Parameter = inspect.Parameter - params = (param for param in params - if param.kind in - (Parameter.POSITIONAL_ONLY, - Parameter.POSITIONAL_OR_KEYWORD)) + params = ( + param + for param in params + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ) - annotations = tuple( - param.annotation - for param in params) + annotations = tuple(param.annotation for param in params) if all(ann is not Parameter.empty for ann in annotations): return annotations def add(self, signature, func): - """ Add new types/method pair to dispatcher + """Add new types/method pair to dispatcher >>> # xdoctest: +SKIP - >>> D = Dispatcher('add') + >>> D = Dispatcher("add") >>> D.add((int, int), lambda x, y: x + y) >>> D.add((float, float), lambda x, y: x + y) >>> D(1, 2) @@ -202,24 +218,25 @@ def add(self, signature, func): for index, typ in enumerate(signature, start=1): if not isinstance(typ, (type, list)): - str_sig = ', '.join(c.__name__ if isinstance(c, type) - else str(c) for c in signature) - raise TypeError(f"Tried to dispatch on non-type: {typ}\n" - f"In signature: <{str_sig}>\n" - f"In function: {self.name}") + str_sig = ", ".join( + c.__name__ if isinstance(c, type) else str(c) for c in signature + ) + raise TypeError( + f"Tried to dispatch on non-type: {typ}\n" + f"In signature: <{str_sig}>\n" + f"In function: {self.name}" + ) # handle variadic signatures if isinstance(typ, list): if index != len(signature): - raise TypeError( - 'Variadic signature must be the last element' - ) + raise TypeError("Variadic signature must be the last element") if len(typ) != 1: raise TypeError( - 'Variadic signature must contain exactly one element. ' - 'To use a variadic union type place the desired types ' - 'inside of a tuple, e.g., [(int, str)]' + "Variadic signature must contain exactly one element. " + "To use a variadic union type place the desired types " + "inside of a tuple, e.g., [(int, str)]" ) new_signature.append(Variadic[typ[0]]) else: @@ -255,7 +272,8 @@ def __call__(self, *args, **kwargs): func = self.dispatch(*types) if not func: raise NotImplementedError( - f'Could not find signature for {self.name}: <{str_signature(types)}>') from e + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) from e self._cache[types] = func try: return func(*args, **kwargs) @@ -271,10 +289,12 @@ def __call__(self, *args, **kwargs): raise NotImplementedError( "Matching functions for " - f"{self.name}: <{str_signature(types)}> found, but none completed successfully",) from e + f"{self.name}: <{str_signature(types)}> found, but none completed successfully", + ) from e def __str__(self): return f"" + __repr__ = __str__ def dispatch(self, *types): @@ -304,7 +324,6 @@ def dispatch(self, *types): return None def dispatch_iter(self, *types): - n = len(types) for signature in self.ordering: if len(signature) == n and all(map(issubclass, types, signature)): @@ -315,21 +334,22 @@ def dispatch_iter(self, *types): result = self.funcs[signature] yield result - @deprecated("`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning) + @deprecated( + "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning + ) def resolve(self, types): - """ Determine appropriate implementation for this type signature + """Determine appropriate implementation for this type signature .. deprecated:: 0.4.4 Use ``dispatch(*types)`` instead """ return self.dispatch(*types) def __getstate__(self): - return {'name': self.name, - 'funcs': self.funcs} + return {"name": self.name, "funcs": self.funcs} def __setstate__(self, d): - self.name = d['name'] - self.funcs = d['funcs'] + self.name = d["name"] + self.funcs = d["funcs"] self._ordering = ordering(self.funcs) self._cache = {} @@ -344,23 +364,23 @@ def __doc__(self): for sig in self.ordering[::-1]: func = self.funcs[sig] if func.__doc__: - s = f'Inputs: <{str_signature(sig)}>\n' - s += '-' * len(s) + '\n' + s = f"Inputs: <{str_signature(sig)}>\n" + s += "-" * len(s) + "\n" s += func.__doc__.strip() docs.append(s) else: other.append(str_signature(sig)) if other: - docs.append('Other signatures:\n ' + '\n '.join(other)) + docs.append("Other signatures:\n " + "\n ".join(other)) - return '\n\n'.join(docs) + return "\n\n".join(docs) def _help(self, *args): return self.dispatch(*map(type, args)).__doc__ def help(self, *args, **kwargs): - """ Print docstring for the function corresponding to inputs """ + """Print docstring for the function corresponding to inputs""" print(self._help(*args)) def _source(self, *args): @@ -370,22 +390,23 @@ def _source(self, *args): return source(func) def source(self, *args, **kwargs): - """ Print source code for the function corresponding to inputs """ + """Print source code for the function corresponding to inputs""" print(self._source(*args)) def source(func): - s = f'File: {inspect.getsourcefile(func)}\n\n' + s = f"File: {inspect.getsourcefile(func)}\n\n" s = s + inspect.getsource(func) return s class MethodDispatcher(Dispatcher): - """ Dispatch methods based on type signature + """Dispatch methods based on type signature See Also: Dispatcher """ - __slots__ = ('obj', 'cls') + + __slots__ = ("obj", "cls") @classmethod def get_func_params(cls, func): @@ -402,26 +423,31 @@ def __call__(self, *args, **kwargs): types = tuple([type(arg) for arg in args]) func = self.dispatch(*types) if not func: - raise NotImplementedError(f'Could not find signature for {self.name}: <{str_signature(types)}>') + raise NotImplementedError( + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) return func(self.obj, *args, **kwargs) def str_signature(sig): - """ String representation of type signature + """String representation of type signature >>> str_signature((int, float)) 'int, float' """ - return ', '.join(cls.__name__ for cls in sig) + return ", ".join(cls.__name__ for cls in sig) def warning_text(name, amb): - """ The text for ambiguity warnings """ + """The text for ambiguity warnings""" text = f"\nAmbiguities exist in dispatched function {name}\n\n" text += "The following signatures may result in ambiguous behavior:\n" for pair in amb: - text += "\t" + \ - ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" + text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" text += "\n\nConsider making the following additions:\n\n" - text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) - + f')\ndef {name}(...)' for s in amb]) + text += "\n\n".join( + [ + "@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)" + for s in amb + ] + ) return text diff --git a/torch/fx/experimental/unification/multipledispatch/utils.py b/torch/fx/experimental/unification/multipledispatch/utils.py index 77702e8ccb7f4..9c91cca2067af 100644 --- a/torch/fx/experimental/unification/multipledispatch/utils.py +++ b/torch/fx/experimental/unification/multipledispatch/utils.py @@ -1,8 +1,10 @@ # mypy: allow-untyped-defs from collections import OrderedDict + __all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] + def raises(err, lamda): try: lamda() @@ -31,12 +33,12 @@ def expand_tuples(L): # Taken from theano/theano/gof/sched.py # Avoids licensing issues because this was written by Matthew Rocklin def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) inputs: edges - a dict of the form {a: {b, c}} where b and c depend on a outputs: L - an ordered list of nodes that satisfy the dependencies of edges - >>> _toposort({1: (2, 3), 2: (3, )}) + >>> _toposort({1: (2, 3), 2: (3,)}) [1, 2, 3] >>> # Closely follows the wikipedia page [2] >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", @@ -44,8 +46,7 @@ def _toposort(edges): >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms """ incoming_edges = reverse_dict(edges) - incoming_edges = OrderedDict((k, set(val)) - for k, val in incoming_edges.items()) + incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) L = [] @@ -64,7 +65,7 @@ def _toposort(edges): def reverse_dict(d): """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} >>> reverse_dict(d) # doctest: +SKIP {1: ('a',), 2: ('a', 'b'), 3: ('b',)} :note: dict order are not deterministic. As we iterate on the @@ -82,8 +83,8 @@ def reverse_dict(d): # Taken from toolz # Avoids licensing issues because this version was authored by Matthew Rocklin def groupby(func, seq): - """ Group a collection by a key function - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + """Group a collection by a key function + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] >>> groupby(len, names) # doctest: +SKIP {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} >>> iseven = lambda x: x % 2 == 0 diff --git a/torch/fx/experimental/unification/multipledispatch/variadic.py b/torch/fx/experimental/unification/multipledispatch/variadic.py index 49e546e1ea267..1b5604a152480 100644 --- a/torch/fx/experimental/unification/multipledispatch/variadic.py +++ b/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -1,15 +1,17 @@ # mypy: allow-untyped-defs from .utils import typename + __all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] + class VariadicSignatureType(type): # checking if subclass is a subclass of self def __subclasscheck__(cls, subclass): - other_type = (subclass.variadic_type if isvariadic(subclass) - else (subclass,)) + other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) return subclass is cls or all( - issubclass(other, cls.variadic_type) for other in other_type # type: ignore[attr-defined] + issubclass(other, cls.variadic_type) # type: ignore[attr-defined] + for other in other_type ) def __eq__(cls, other): @@ -24,8 +26,7 @@ def __eq__(cls, other): bool Whether or not `other` is equal to `self` """ - return (isvariadic(other) and - set(cls.variadic_type) == set(other.variadic_type)) # type: ignore[attr-defined] + return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type) # type: ignore[attr-defined] def __hash__(cls): return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] @@ -57,17 +58,20 @@ class VariadicSignatureMeta(type): generate a new type for Variadic signatures. See the Variadic class for examples of how this behaves. """ + def __getitem__(cls, variadic_type): if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): - raise ValueError("Variadic types must be type or tuple of types" - " (Variadic[int] or Variadic[(int, float)]") + raise ValueError( + "Variadic types must be type or tuple of types" + " (Variadic[int] or Variadic[(int, float)]" + ) if not isinstance(variadic_type, tuple): - variadic_type = variadic_type, + variadic_type = (variadic_type,) return VariadicSignatureType( - f'Variadic[{typename(variadic_type)}]', + f"Variadic[{typename(variadic_type)}]", (), - dict(variadic_type=variadic_type, __slots__=()) + dict(variadic_type=variadic_type, __slots__=()), ) diff --git a/torch/fx/experimental/unification/unification_tools.py b/torch/fx/experimental/unification/unification_tools.py index d06d9bef771c4..a47d900273f5e 100644 --- a/torch/fx/experimental/unification/unification_tools.py +++ b/torch/fx/experimental/unification/unification_tools.py @@ -1,25 +1,40 @@ # mypy: allow-untyped-defs import collections import operator -from functools import reduce from collections.abc import Mapping +from functools import reduce -__all__ = ['merge', 'merge_with', 'valmap', 'keymap', 'itemmap', - 'valfilter', 'keyfilter', 'itemfilter', - 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in'] + +__all__ = [ + "merge", + "merge_with", + "valmap", + "keymap", + "itemmap", + "valfilter", + "keyfilter", + "itemfilter", + "assoc", + "dissoc", + "assoc_in", + "update_in", + "get_in", +] def _get_factory(f, kwargs): - factory = kwargs.pop('factory', dict) + factory = kwargs.pop("factory", dict) if kwargs: - raise TypeError(f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'") + raise TypeError( + f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" + ) return factory def merge(*dicts, **kwargs): - """ Merge a collection of dictionaries + """Merge a collection of dictionaries - >>> merge({1: 'one'}, {2: 'two'}) + >>> merge({1: "one"}, {2: "two"}) {1: 'one', 2: 'two'} Later dictionaries have precedence @@ -41,7 +56,7 @@ def merge(*dicts, **kwargs): def merge_with(func, *dicts, **kwargs): - """ Merge dictionaries and apply function to combined values + """Merge dictionaries and apply function to combined values A key may occur in more than one dict, and all values mapped from the key will be passed to the function as a list, such as func([val1, val2, ...]). @@ -70,7 +85,7 @@ def merge_with(func, *dicts, **kwargs): def valmap(func, d, factory=dict): - """ Apply function to values of dictionary + """Apply function to values of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} >>> valmap(sum, bills) # doctest: +SKIP @@ -86,7 +101,7 @@ def valmap(func, d, factory=dict): def keymap(func, d, factory=dict): - """ Apply function to keys of dictionary + """Apply function to keys of dictionary >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} >>> keymap(str.lower, bills) # doctest: +SKIP @@ -102,7 +117,7 @@ def keymap(func, d, factory=dict): def itemmap(func, d, factory=dict): - """ Apply function to items of dictionary + """Apply function to items of dictionary >>> accountids = {"Alice": 10, "Bob": 20} >>> itemmap(reversed, accountids) # doctest: +SKIP @@ -118,7 +133,7 @@ def itemmap(func, d, factory=dict): def valfilter(predicate, d, factory=dict): - """ Filter items in dictionary by value + """Filter items in dictionary by value >>> iseven = lambda x: x % 2 == 0 >>> d = {1: 2, 2: 3, 3: 4, 4: 5} @@ -138,7 +153,7 @@ def valfilter(predicate, d, factory=dict): def keyfilter(predicate, d, factory=dict): - """ Filter items in dictionary by key + """Filter items in dictionary by key >>> iseven = lambda x: x % 2 == 0 >>> d = {1: 2, 2: 3, 3: 4, 4: 5} @@ -158,7 +173,7 @@ def keyfilter(predicate, d, factory=dict): def itemfilter(predicate, d, factory=dict): - """ Filter items in dictionary by item + """Filter items in dictionary by item >>> def isvalid(item): ... k, v = item @@ -182,13 +197,13 @@ def itemfilter(predicate, d, factory=dict): def assoc(d, key, value, factory=dict): - """ Return a new dict with new key value pair + """Return a new dict with new key value pair New dict has d[key] set to value. Does not modify the initial dictionary. - >>> assoc({'x': 1}, 'x', 2) + >>> assoc({"x": 1}, "x", 2) {'x': 2} - >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP + >>> assoc({"x": 1}, "y", 3) # doctest: +SKIP {'x': 1, 'y': 3} """ d2 = factory() @@ -198,22 +213,22 @@ def assoc(d, key, value, factory=dict): def dissoc(d, *keys, **kwargs): - """ Return a new dict with the given key(s) removed. + """Return a new dict with the given key(s) removed. New dict has d[key] deleted for each supplied key. Does not modify the initial dictionary. - >>> dissoc({'x': 1, 'y': 2}, 'y') + >>> dissoc({"x": 1, "y": 2}, "y") {'x': 1} - >>> dissoc({'x': 1, 'y': 2}, 'y', 'x') + >>> dissoc({"x": 1, "y": 2}, "y", "x") {} - >>> dissoc({'x': 1}, 'y') # Ignores missing keys + >>> dissoc({"x": 1}, "y") # Ignores missing keys {'x': 1} """ factory = _get_factory(dissoc, kwargs) d2 = factory() - if len(keys) < len(d) * .6: + if len(keys) < len(d) * 0.6: d2.update(d) for key in keys: if key in d2: @@ -227,13 +242,14 @@ def dissoc(d, *keys, **kwargs): def assoc_in(d, keys, value, factory=dict): - """ Return a new dict with new, potentially nested, key value pair - - >>> purchase = {'name': 'Alice', - ... 'order': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP + """Return a new dict with new, potentially nested, key value pair + + >>> purchase = { + ... "name": "Alice", + ... "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00]) # doctest: +SKIP {'credit card': '5555-1234-1234-1234', 'name': 'Alice', 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} @@ -242,7 +258,7 @@ def assoc_in(d, keys, value, factory=dict): def update_in(d, keys, func, default=None, factory=dict): - """ Update value in a (potentially) nested dictionary + """Update value in a (potentially) nested dictionary inputs: d - dictionary on which to operate @@ -257,14 +273,15 @@ def update_in(d, keys, func, default=None, factory=dict): specified by the keys, with the innermost value set to func(default). >>> inc = lambda x: x + 1 - >>> update_in({'a': 0}, ['a'], inc) + >>> update_in({"a": 0}, ["a"], inc) {'a': 1} - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> update_in(transaction, ["purchase", "costs"], sum) # doctest: +SKIP {'credit card': '5555-1234-1234-1234', 'name': 'Alice', 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} @@ -272,7 +289,7 @@ def update_in(d, keys, func, default=None, factory=dict): >>> # updating a value when k0 is not in d >>> update_in({}, [1, 2, 3], str, default="bar") {1: {2: {3: 'bar'}}} - >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0) + >>> update_in({1: "foo"}, [2, 3, 4], inc, 0) {1: 'foo', 2: {3: {4: 1}}} """ ks = iter(keys) @@ -300,7 +317,7 @@ def update_in(d, keys, func, default=None, factory=dict): def get_in(keys, coll, default=None, no_default=False): - """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. + """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless ``no_default`` is specified, then it raises KeyError or IndexError. @@ -308,20 +325,21 @@ def get_in(keys, coll, default=None, no_default=False): ``get_in`` is a generalization of ``operator.getitem`` for nested data structures such as dictionaries and lists. - >>> transaction = {'name': 'Alice', - ... 'purchase': {'items': ['Apple', 'Orange'], - ... 'costs': [0.50, 1.25]}, - ... 'credit card': '5555-1234-1234-1234'} - >>> get_in(['purchase', 'items', 0], transaction) + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> get_in(["purchase", "items", 0], transaction) 'Apple' - >>> get_in(['name'], transaction) + >>> get_in(["name"], transaction) 'Alice' - >>> get_in(['purchase', 'total'], transaction) - >>> get_in(['purchase', 'items', 'apple'], transaction) - >>> get_in(['purchase', 'items', 10], transaction) - >>> get_in(['purchase', 'total'], transaction, 0) + >>> get_in(["purchase", "total"], transaction) + >>> get_in(["purchase", "items", "apple"], transaction) + >>> get_in(["purchase", "items", 10], transaction) + >>> get_in(["purchase", "total"], transaction, 0) 0 - >>> get_in(['y'], {}, no_default=True) + >>> get_in(["y"], {}, no_default=True) Traceback (most recent call last): ... KeyError: 'y' @@ -352,9 +370,9 @@ def getter(index): def groupby(key, seq): - """ Group a collection by a key function + """Group a collection by a key function - >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] >>> groupby(len, names) # doctest: +SKIP {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} @@ -364,9 +382,14 @@ def groupby(key, seq): Non-callable keys imply grouping on a member. - >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'}, - ... {'name': 'Bob', 'gender': 'M'}, - ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP + >>> groupby( + ... "gender", + ... [ + ... {"name": "Alice", "gender": "F"}, + ... {"name": "Bob", "gender": "M"}, + ... {"name": "Charlie", "gender": "M"}, + ... ], + ... ) # doctest:+SKIP {'F': [{'gender': 'F', 'name': 'Alice'}], 'M': [{'gender': 'M', 'name': 'Bob'}, {'gender': 'M', 'name': 'Charlie'}]} @@ -388,9 +411,9 @@ def groupby(key, seq): def first(seq): - """ The first element in a sequence + """The first element in a sequence - >>> first('ABC') + >>> first("ABC") 'A' """ return next(iter(seq)) diff --git a/torch/fx/experimental/unification/utils.py b/torch/fx/experimental/unification/utils.py index 609fe59d43f45..7634c9b2ec90b 100644 --- a/torch/fx/experimental/unification/utils.py +++ b/torch/fx/experimental/unification/utils.py @@ -1,5 +1,7 @@ # mypy: allow-untyped-defs __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] + + def hashable(x): try: hash(x) @@ -9,7 +11,7 @@ def hashable(x): def transitive_get(key, d): - """ Transitive dict.get + """Transitive dict.get >>> d = {1: 2, 2: 3, 3: 4} >>> d.get(1) 2 @@ -32,13 +34,13 @@ def raises(err, lamda): # Taken from theano/theano/gof/sched.py # Avoids licensing issues because this was written by Matthew Rocklin def _toposort(edges): - """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) inputs: edges - a dict of the form {a: {b, c}} where b and c depend on a outputs: L - an ordered list of nodes that satisfy the dependencies of edges >>> # xdoctest: +SKIP - >>> _toposort({1: (2, 3), 2: (3, )}) + >>> _toposort({1: (2, 3), 2: (3,)}) [1, 2, 3] Closely follows the wikipedia page [2] [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", @@ -47,7 +49,7 @@ def _toposort(edges): """ incoming_edges = reverse_dict(edges) incoming_edges = {k: set(val) for k, val in incoming_edges.items()} - S = ({v for v in edges if v not in incoming_edges}) + S = {v for v in edges if v not in incoming_edges} L = [] while S: @@ -65,7 +67,7 @@ def _toposort(edges): def reverse_dict(d): """Reverses direction of dependence dict - >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} >>> reverse_dict(d) # doctest: +SKIP {1: ('a',), 2: ('a', 'b'), 3: ('b',)} :note: dict order are not deterministic. As we iterate on the @@ -89,12 +91,12 @@ def xfail(func): def freeze(d): - """ Freeze container to hashable form + """Freeze container to hashable form >>> freeze(1) 1 >>> freeze([1, 2]) (1, 2) - >>> freeze({1: 2}) # doctest: +SKIP + >>> freeze({1: 2}) # doctest: +SKIP frozenset([(1, 2)]) """ if isinstance(d, dict): diff --git a/torch/fx/experimental/unification/variable.py b/torch/fx/experimental/unification/variable.py index 66e97a3a76636..46e59851fdfa8 100644 --- a/torch/fx/experimental/unification/variable.py +++ b/torch/fx/experimental/unification/variable.py @@ -1,14 +1,16 @@ # mypy: allow-untyped-defs from contextlib import contextmanager -from .utils import hashable + from .dispatch import dispatch +from .utils import hashable + _global_logic_variables = set() # type: ignore[var-annotated] _glv = _global_logic_variables class Var: - """ Logic Variable """ + """Logic Variable""" _id = 1 @@ -25,6 +27,7 @@ def __new__(cls, *token): def __str__(self): return "~" + str(self.token) # type: ignore[attr-defined] + __repr__ = __str__ def __eq__(self, other): @@ -46,6 +49,7 @@ def vars(): def isvar(v): return True + isvar @@ -69,12 +73,12 @@ def variables(*variables): False >>> # Normal approach >>> from unification import unify - >>> x = var('x') + >>> x = var("x") >>> unify(x, 1) {~x: 1} >>> # Context Manager approach - >>> with variables('x'): - ... print(unify('x', 1)) + >>> with variables("x"): + ... print(unify("x", 1)) {'x': 1} """ old_global_logic_variables = _global_logic_variables.copy() diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index cad0a33425bf8..bab662e0655a2 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined] from torch.fx.tensor_type import TensorType -from torch.fx.experimental.unification import Var, unify # type: ignore[attr-defined] def infer_symbolic_types_single_pass(traced): @@ -13,6 +13,7 @@ def infer_symbolic_types_single_pass(traced): mgu = unify_eq(r.constraints) substitute_all_types(traced.graph, mgu) + def infer_symbolic_types(traced): """ Calls our symbolic inferencer twice. @@ -32,6 +33,7 @@ def infer_symbolic_types(traced): r.symbolic_relations() + def convert_eq(list_of_eq): """ Convert equality constraints in the right format @@ -109,6 +111,7 @@ def substitute_all_types(graph, mapping): for n in graph.nodes: n.type = substitute_solution_one_type(mapping, n.type) + def check_for_type_equality(g1, g2): """ A check equality to be used in fixed points. diff --git a/torch/fx/graph.py b/torch/fx/graph.py index b0df9f02fcb8e..2aed4a2c80d1f 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1,32 +1,47 @@ # mypy: allow-untyped-defs -from collections import defaultdict -from .node import Node, Argument, Target, map_arg, _type_repr, _get_qualified_name -import torch.utils._pytree as pytree -from . import _pytree as fx_pytree -from ._compatibility import compatibility -from torch._C import _NodeIter - -import os +import builtins import contextlib -from typing import TYPE_CHECKING, Callable, Any, List, Dict, NamedTuple, Optional, Tuple, Set, FrozenSet, Type, Iterable -from dataclasses import dataclass -from contextlib import contextmanager import copy import enum -import torch +import functools +import inspect import keyword -import re -import builtins import math +import os +import re import warnings -import inspect -import functools +from collections import defaultdict +from contextlib import contextmanager +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Iterable, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + TYPE_CHECKING, +) + +import torch +import torch.utils._pytree as pytree +from torch._C import _NodeIter + +from . import _pytree as fx_pytree +from ._compatibility import compatibility +from .node import _get_qualified_name, _type_repr, Argument, map_arg, Node, Target + __all__ = ["PythonCode", "CodeGen", "Graph"] if TYPE_CHECKING: + from ._symbolic_trace import Tracer # noqa: F401 from .graph_module import GraphModule # noqa: F401 - from ._symbolic_trace import Tracer # noqa: F401 # Mapping of builtins to their `typing` equivalent. @@ -38,7 +53,9 @@ tuple: Tuple, } -_legal_ops = dict.fromkeys(['call_function', 'call_method', 'get_attr', 'call_module', 'placeholder', 'output']) +_legal_ops = dict.fromkeys( + ["call_function", "call_method", "get_attr", "call_module", "placeholder", "output"] +) # Signature for functions thattransforms the body (`list[str]`) of the @@ -53,11 +70,13 @@ class _CustomBuiltin(NamedTuple): an import. For common objects of this sort, we bundle them in the globals of every FX graph. """ + # How to import this object from the standard library. import_str: str # The actual object, produced from that import string. obj: Any + _custom_builtins: Dict[str, _CustomBuiltin] = {} @@ -65,17 +84,17 @@ def _register_custom_builtin(name: str, import_str: str, obj: Any): _custom_builtins[name] = _CustomBuiltin(import_str, obj) -_register_custom_builtin('inf', 'from math import inf', math.inf) -_register_custom_builtin('nan', 'from math import nan', math.nan) -_register_custom_builtin('NoneType', 'NoneType = type(None)', type(None)) -_register_custom_builtin('torch', 'import torch', torch) -_register_custom_builtin('device', 'from torch import device', torch.device) -_register_custom_builtin('fx_pytree', 'import torch.fx._pytree as fx_pytree', fx_pytree) -_register_custom_builtin('pytree', 'import torch.utils._pytree as pytree', pytree) +_register_custom_builtin("inf", "from math import inf", math.inf) +_register_custom_builtin("nan", "from math import nan", math.nan) +_register_custom_builtin("NoneType", "NoneType = type(None)", type(None)) +_register_custom_builtin("torch", "import torch", torch) +_register_custom_builtin("device", "from torch import device", torch.device) +_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree) +_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree) def _is_magic(x: str) -> bool: - return x.startswith('__') and x.endswith('__') + return x.startswith("__") and x.endswith("__") def _snake_case(s: str) -> str: @@ -91,22 +110,22 @@ def _snake_case(s: str) -> str: # Replace occurrences where a lowercase letter is followed by an uppercase letter -_snake_case_sub = functools.partial(re.compile(r'(?<=[a-z])([A-Z])').sub, r'_\1') +_snake_case_sub = functools.partial(re.compile(r"(?<=[a-z])([A-Z])").sub, r"_\1") def _is_from_torch(obj: Any) -> bool: - module_name = getattr(obj, '__module__', None) + module_name = getattr(obj, "__module__", None) if module_name is not None: - base_module = module_name.partition('.')[0] + base_module = module_name.partition(".")[0] return ( - base_module == 'torch' and - not module_name.startswith("torch._dynamo.") and - not module_name.startswith("torch._inductor.") + base_module == "torch" + and not module_name.startswith("torch._dynamo.") + and not module_name.startswith("torch._inductor.") ) - name = getattr(obj, '__name__', None) + name = getattr(obj, "__name__", None) # exclude torch because torch.torch.torch.torch works. idk mang - if name is not None and name != 'torch': + if name is not None and name != "torch": for guess in [torch, torch.nn.functional]: if getattr(guess, name, None) is obj: return True @@ -122,13 +141,14 @@ class _Namespace: - Each name is unique within a given namespace. - Names generated do not shadow builtins, unless the object is indeed that builtin. """ + def __init__(self): self._obj_to_name: Dict[Any, str] = {} self._unassociated_names = set() self._used_names: Set[str] = set() self._base_count: Dict[str, int] = defaultdict(int) - self._illegal_char_regex = re.compile('[^0-9a-zA-Z_]+') + self._illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") self._name_suffix_regex = re.compile(r"(.*)_(\d+)$") def create_name(self, candidate: str, obj: Optional[Any]) -> str: @@ -142,13 +162,13 @@ def create_name(self, candidate: str, obj: Optional[Any]) -> str: return self._obj_to_name[obj] # delete all characters that are illegal in a Python identifier - candidate = self._illegal_char_regex.sub('_', candidate) + candidate = self._illegal_char_regex.sub("_", candidate) if not candidate: - candidate = '_unnamed' + candidate = "_unnamed" if candidate[0].isdigit(): - candidate = f'_{candidate}' + candidate = f"_{candidate}" match = self._name_suffix_regex.match(candidate) if match is None: @@ -158,13 +178,13 @@ def create_name(self, candidate: str, obj: Optional[Any]) -> str: base, num_str = match.group(1, 2) num = int(num_str) - candidate = base if num is None else f'{base}_{num}' + candidate = base if num is None else f"{base}_{num}" if not num: num = self._base_count[base] while candidate in self._used_names or self._is_illegal_name(candidate, obj): num += 1 - candidate = f'{base}_{num}' + candidate = f"{base}_{num}" self._used_names.add(candidate) self._base_count[base] = num @@ -204,36 +224,39 @@ def _rename_object(self, obj: Any, name: str): self._obj_to_name[obj] = name self._used_names.add(name) + dtype_abbrs = { - torch.bfloat16: 'bf16', - torch.float64: 'f64', - torch.float32: 'f32', - torch.float16: 'f16', - torch.float8_e4m3fn: 'f8e4m3fn', - torch.float8_e5m2: 'f8e5m2', - torch.float8_e4m3fnuz: 'f8e4m3fnuz', - torch.float8_e5m2fnuz: 'f8e5m2fnuz', - torch.complex32: 'c32', - torch.complex64: 'c64', - torch.complex128: 'c128', - torch.int8: 'i8', - torch.int16: 'i16', - torch.int32: 'i32', - torch.int64: 'i64', - torch.bool: 'b8', - torch.uint8: 'u8', - torch.uint16: 'u16', - torch.uint32: 'u32', - torch.uint64: 'u64', - torch.bits16: 'b16', + torch.bfloat16: "bf16", + torch.float64: "f64", + torch.float32: "f32", + torch.float16: "f16", + torch.float8_e4m3fn: "f8e4m3fn", + torch.float8_e5m2: "f8e5m2", + torch.float8_e4m3fnuz: "f8e4m3fnuz", + torch.float8_e5m2fnuz: "f8e5m2fnuz", + torch.complex32: "c32", + torch.complex64: "c64", + torch.complex128: "c128", + torch.int8: "i8", + torch.int16: "i16", + torch.int32: "i32", + torch.int64: "i64", + torch.bool: "b8", + torch.uint8: "u8", + torch.uint16: "u16", + torch.uint32: "u32", + torch.uint64: "u64", + torch.bits16: "b16", } + @compatibility(is_backward_compatible=True) @dataclass class PythonCode: """ Represents all the information necessary to exec or save a graph as Python code. """ + # Python source code for the forward function definition. src: str # Values in global scope during execution of `src_def`. @@ -244,15 +267,16 @@ class PythonCode: def _format_target(base: str, target: str) -> str: - elems = target.split('.') + elems = target.split(".") r = base for e in elems: if not e.isidentifier(): r = f'getattr({r}, "{e}")' else: - r = f'{r}.{e}' + r = f"{r}.{e}" return r + class _InsertPoint: def __init__(self, graph, new_insert): self.graph = graph @@ -264,9 +288,10 @@ def __enter__(self): def __exit__(self, type, value, tb): self.graph._insert = self.orig_insert + class _node_list: - def __init__(self, graph: 'Graph', direction: str = '_next'): - assert direction in ['_next', '_prev'] + def __init__(self, graph: "Graph", direction: str = "_next"): + assert direction in ["_next", "_prev"] self.graph = graph self.direction = direction @@ -278,39 +303,43 @@ def __iter__(self): yield from _NodeIter(self.graph._root, self.direction == "_prev") def __reversed__(self): - return _node_list(self.graph, '_next' if self.direction == '_prev' else '_prev') + return _node_list(self.graph, "_next" if self.direction == "_prev" else "_prev") + class _PyTreeInfo(NamedTuple): """ Contains extra info stored when we're using Pytrees """ + orig_args: List[str] in_spec: pytree.TreeSpec out_spec: Optional[pytree.TreeSpec] + @dataclass(frozen=True) class _ParsedStackTrace: """ Represents the top-most frame of a parsed stack trace """ + file: str lineno: str name: str code: str def get_summary_str(self): - return f'File: {self.file}:{self.lineno} in {self.name}, code: {self.code}' + return f"File: {self.file}:{self.lineno} in {self.name}, code: {self.code}" + # get File:lineno code from stack_trace def _parse_stack_trace(stack_trace: str): if stack_trace is None: return None pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") - lines = stack_trace.strip().split('\n') + lines = stack_trace.strip().split("\n") # stacktrace should have innermost frame last, so we # iterate backwards to find the first line that starts # with 'File ' - summary_str = "" for idx in range(len(lines) - 2, -1, -1): line = lines[idx].strip() matches = pattern.match(line) @@ -323,6 +352,7 @@ def _parse_stack_trace(stack_trace: str): return _ParsedStackTrace(file, lineno, name, code) return None + @compatibility(is_backward_compatible=False) class CodeGen: def __init__(self): @@ -336,16 +366,18 @@ def gen_fn_def(self, free_vars: List[str], maybe_return_annotation: str) -> str: """ # If the original function didn't have self as its first argument, we # would have added it. - if len(free_vars) == 0 or free_vars[0] != 'self': - free_vars.insert(0, 'self') - return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + if len(free_vars) == 0 or free_vars[0] != "self": + free_vars.insert(0, "self") + return ( + f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + ) def generate_output(self, output_args: Argument) -> str: """ Given the output arguments, generates the return statement of the FX function. Note: The returned statement should not be indented. """ - return f'return {repr(output_args)}' + return f"return {repr(output_args)}" def process_inputs(self, *args: Any) -> Any: """ @@ -374,8 +406,15 @@ def additional_globals(self) -> List[Tuple[str, Any]]: return [] def _gen_python_code( - self, nodes, root_module: str, namespace: _Namespace, *, - verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + self, + nodes, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, ) -> PythonCode: free_vars: List[str] = [] body: List[str] = [] @@ -383,9 +422,13 @@ def _gen_python_code( wrapped_fns: Dict[str, None] = {} # Wrap string in list to pass by reference - maybe_return_annotation : List[str] = [''] - include_stride = include_stride or (os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1") - include_device = include_device or (os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1") + maybe_return_annotation: List[str] = [""] + include_stride = include_stride or ( + os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1" + ) + include_device = include_device or ( + os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1" + ) def add_global(name_hint: str, obj: Any): """Add an obj to be tracked as a global. @@ -395,7 +438,9 @@ def add_global(name_hint: str, obj: Any): Returns: the global name that should be used to reference 'obj' in generated source. """ - if _is_from_torch(obj) and obj != torch.device: # to support registering torch.device + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device # HACK: workaround for how torch custom ops are registered. We # can't import them like normal modules so they must retain their # fully qualified name. @@ -414,19 +459,19 @@ def add_global(name_hint: str, obj: Any): for name, (_, obj) in _custom_builtins.items(): add_global(name, obj) - def type_repr(o : Any): + def type_repr(o: Any): if o == (): # Empty tuple is used for empty tuple type annotation Tuple[()] - return '()' + return "()" typename = _type_repr(o) - if hasattr(o, '__origin__'): + if hasattr(o, "__origin__"): # This is a generic type, e.g. typing.List[torch.Tensor] origin_type = _origin_type_map.get(o.__origin__, o.__origin__) origin_typename = add_global(_type_repr(origin_type), origin_type) - if hasattr(o, '__args__'): + if hasattr(o, "__args__"): # Assign global names for each of the inner type variables. args = [type_repr(arg) for arg in o.__args__] @@ -461,12 +506,13 @@ def f(s): if colored: return f"{codes[name]}{s}{codes['reset']}" return s + return f - yellow = make_wrapper_func("yellow") - cyan = make_wrapper_func("cyan") + yellow = make_wrapper_func("yellow") # noqa: F841 + cyan = make_wrapper_func("cyan") # noqa: F841 red = make_wrapper_func("red") - green = make_wrapper_func("green") + green = make_wrapper_func("green") # noqa: F841 dim_green = make_wrapper_func("dim_green") dim = make_wrapper_func("dim") dim_blue = make_wrapper_func("dim_blue") @@ -474,11 +520,13 @@ def f(s): def _get_repr(arg: Any) -> str: # Handle NamedTuples (if it has `_fields`) via add_global. - if isinstance(arg, tuple) and hasattr(arg, '_fields'): + if isinstance(arg, tuple) and hasattr(arg, "_fields"): qualified_name = _get_qualified_name(type(arg)) global_name = add_global(qualified_name, type(arg)) return f"{global_name}{repr(tuple(arg))}" - elif isinstance(arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + elif isinstance( + arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ): qualified_name = _get_qualified_name(arg) global_name = add_global(qualified_name, arg) return f"{global_name}" @@ -492,25 +540,35 @@ def _get_repr(arg: Any) -> str: size = list(arg.size()) dtype = str(arg.dtype).split(".")[-1] return f"torch.Tensor(size={size}, dtype={dtype})" + elif isinstance(arg, tuple): + if len(arg) == 1: + return f"({_get_repr(arg[0])},)" + else: + return "(" + ", ".join(_get_repr(a) for a in arg) + ")" + elif isinstance(arg, list): + return "[" + ", ".join(_get_repr(a) for a in arg) + "]" + elif isinstance(arg, slice): + return f"slice({_get_repr(arg.start)}, {_get_repr(arg.stop)}, {_get_repr(arg.step)})" else: return blue(repr(arg)) - - def _format_args(args: Tuple[Argument, ...], kwargs: Dict[str, Argument]) -> str: - args_s = ', '.join(_get_repr(a) for a in args) - kwargs_s = ', '.join(f'{k} = {_get_repr(v)}' for k, v in kwargs.items()) + def _format_args( + args: Tuple[Argument, ...], kwargs: Dict[str, Argument] + ) -> str: + args_s = ", ".join(_get_repr(a) for a in args) + kwargs_s = ", ".join(f"{k} = {_get_repr(v)}" for k, v in kwargs.items()) if args_s and kwargs_s: - return f'{args_s}, {kwargs_s}' + return f"{args_s}, {kwargs_s}" return args_s or kwargs_s # Run through reverse nodes and record the first instance of a use # of a given node. This represents the *last* use of the node in the # execution order of the program, which we will use to free unused # values - node_to_last_use : Dict[Node, Node] = {} - user_to_last_uses : Dict[Node, List[Node]] = {} + node_to_last_use: Dict[Node, Node] = {} + user_to_last_uses: Dict[Node, List[Node]] = {} - def register_last_uses(n : Node, user : Node): + def register_last_uses(n: Node, user: Node): if n not in node_to_last_use: node_to_last_use[n] = user user_to_last_uses.setdefault(user, []).append(n) @@ -519,16 +577,16 @@ def register_last_uses(n : Node, user : Node): map_arg(node.args, lambda n: register_last_uses(n, node)) map_arg(node.kwargs, lambda n: register_last_uses(n, node)) - def delete_unused_values(user : Node): + def delete_unused_values(user: Node): """ Delete values after their last use. This ensures that values that are not used in the remainder of the code are freed and the memory usage of the code is optimal. """ - if user.op == 'placeholder': + if user.op == "placeholder": return - if user.op == 'output': - body.append('\n') + if user.op == "output": + body.append("\n") return nodes_to_delete = user_to_last_uses.get(user, []) @@ -539,21 +597,23 @@ def delete_unused_values(user : Node): nodes_to_delete.append(user) if len(nodes_to_delete): - to_delete_str = ' = '.join([repr(n) for n in nodes_to_delete] + ['None']) - body.append(f'; {dim(to_delete_str)}\n') + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {dim(to_delete_str)}\n") else: - body.append('\n') + body.append("\n") prev_stacktrace = None - def append_stacktrace_summary(node : Node): + def append_stacktrace_summary(node: Node): """ Append a summary of the stacktrace to the generated code. This is useful for debugging. """ nonlocal prev_stacktrace - if node.op not in {'placeholder', 'output'}: + if node.op not in {"placeholder", "output"}: if node.stack_trace: if node.stack_trace != prev_stacktrace: prev_stacktrace = node.stack_trace @@ -566,93 +626,128 @@ def append_stacktrace_summary(node : Node): elif prev_stacktrace != "": prev_stacktrace = "" no_stacktrace_msg = "# No stacktrace found for following nodes" - body.append(f'\n{dim(no_stacktrace_msg)}\n') + body.append(f"\n{dim(no_stacktrace_msg)}\n") - def stringify_shape(shape : Iterable) -> str: + def stringify_shape(shape: Iterable) -> str: return f"[{', '.join(str(x) for x in shape)}]" - def emit_node(node : Node): - maybe_type_annotation = '' if node.type is None else f' : {type_repr(node.type)}' + def emit_node(node: Node): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) if verbose: # override annotation with more detailed information from torch.fx.experimental.proxy_tensor import py_sym_types from torch.fx.passes.shape_prop import TensorMetadata - meta_val = node.meta.get('val', node.meta.get('tensor_meta', node.meta.get('example_value', None))) + meta_val = node.meta.get( + "val", + node.meta.get("tensor_meta", node.meta.get("example_value", None)), + ) # use string as annotation, to make it valid python code if isinstance(meta_val, torch.Tensor): - stride_annotation = f"{stringify_shape(meta_val.stride())}" if include_stride else "" + stride_annotation = ( + f"{stringify_shape(meta_val.stride())}" + if include_stride + else "" + ) device_annotation = f"{meta_val.device}" if include_device else "" - maybe_type_annotation = \ - f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' \ + maybe_type_annotation = ( + f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' + ) elif isinstance(meta_val, py_sym_types): maybe_type_annotation = f': "Sym({meta_val})"' elif isinstance(meta_val, TensorMetadata): maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' - if node.op == 'placeholder': + if node.op == "placeholder": assert isinstance(node.target, str) - maybe_default_arg = '' if not node.args else f' = {_get_repr(node.args[0])}' - free_vars.append(f'{node.target}{maybe_type_annotation}{maybe_default_arg}') - raw_name = node.target.replace('*', '') + maybe_default_arg = ( + "" if not node.args else f" = {_get_repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + ) + raw_name = node.target.replace("*", "") if raw_name != repr(node): - body.append(f'{repr(node)} = {raw_name}\n') + body.append(f"{repr(node)} = {raw_name}\n") return - elif node.op == 'call_method': + elif node.op == "call_method": assert isinstance(node.target, str) body.append( - f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}' - f'({_format_args(node.args[1:], node.kwargs)})') + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) return - elif node.op == 'call_function': + elif node.op == "call_function": assert callable(node.target) # pretty print operators - if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in magic_methods: + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in magic_methods + ): assert isinstance(node.args, tuple) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}" + ) return # pretty print inplace operators; required for jit.script to work properly # not currently supported in normal FX graphs, but generated by torchdynamo - if getattr(node.target, "__module__", "") == '_operator' and node.target.__name__ in inplace_methods: - body.append(f'{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; ' - f'{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}') + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + f"{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}" + ) return qualified_name = _get_qualified_name(node.target) global_name = add_global(qualified_name, node.target) # special case for getattr: node.args could be 2-argument or 3-argument # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value - if global_name == 'getattr' and \ - isinstance(node.args, tuple) and \ - isinstance(node.args[1], str) and \ - node.args[1].isidentifier() and \ - len(node.args) == 2: - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}') + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}" + ) return - body.append(f'{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})') - if node.meta.get('is_wrapped', False): + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): wrapped_fns.setdefault(global_name) return - elif node.op == 'call_module': + elif node.op == "call_module": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = ' - f'{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})') + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) return - elif node.op == 'get_attr': + elif node.op == "get_attr": assert isinstance(node.target, str) - body.append(f'{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}') + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) return - elif node.op == 'output': + elif node.op == "output": if node.type is not None: maybe_return_annotation[0] = f" -> {type_repr(node.type)}" body.append(self.generate_output(node.args[0])) return - raise NotImplementedError(f'node: {node.op} {node.target}') + raise NotImplementedError(f"node: {node.op} {node.target}") for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends @@ -670,15 +765,13 @@ def emit_node(node : Node): # If the Graph has no non-placeholder nodes, no lines for the body # have been emitted. To continue to have valid Python code, emit a # single pass statement - body.append('pass\n') - - + body.append("pass\n") if len(wrapped_fns) > 0: - wrap_name = add_global('wrap', torch.fx.wrap) - wrap_stmts = '\n'.join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) else: - wrap_stmts = '' + wrap_stmts = "" if self._body_transformer: body = self._body_transformer(body) @@ -690,10 +783,10 @@ def emit_node(node : Node): # remove counter and generate lineno to node index mapping lineno_map: Dict[int, Optional[int]] = {} - prologue_len = prologue.count('\n') + 1 + prologue_len = prologue.count("\n") + 1 new_lines: List[str] = [] cur_idx = None - for line in ''.join(body).split('\n'): + for line in "".join(body).split("\n"): counter = re.search(r"# COUNTER: (\d+)", line) if counter and counter.group(1) is not None: cur_idx = int(counter.group(1)) @@ -701,8 +794,8 @@ def emit_node(node : Node): lineno_map[len(new_lines) + prologue_len] = cur_idx new_lines.append(line) - code = "\n".join(new_lines).lstrip('\n') - code = '\n'.join(' ' + line for line in code.split('\n')) + code = "\n".join(new_lines).lstrip("\n") + code = "\n".join(" " + line for line in code.split("\n")) fn_code = f""" {wrap_stmts} @@ -755,25 +848,35 @@ def gen_fn_def(self, free_vars, maybe_return_annotation): return super().gen_fn_def(free_vars, maybe_return_annotation) fn_args = self.pytree_info.orig_args - has_orig_self = (fn_args[0] == 'self') if len(fn_args) > 0 else False + has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False if has_orig_self: - free_vars.insert(0, 'self') + free_vars.insert(0, "self") fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) if len(free_vars) > 0: # pytree has placeholders in it # when kwargs is present, in_spec is tuple(args, kwargs) - has_args_kwargs_tuple = self.pytree_info.in_spec.type == tuple and \ - self.pytree_info.in_spec.num_children == 2 and \ - self.pytree_info.in_spec.children_specs[0].type == tuple and \ - self.pytree_info.in_spec.children_specs[1].type == dict - fn_kwargs = '{}' + has_args_kwargs_tuple = ( + self.pytree_info.in_spec.type == tuple + and self.pytree_info.in_spec.num_children == 2 + and self.pytree_info.in_spec.children_specs[0].type == tuple + and self.pytree_info.in_spec.children_specs[1].type == dict + ) + fn_kwargs = "{}" fn_signature = f"[{', '.join(fn_args)}], self._in_spec" if has_args_kwargs_tuple: count_args = self.pytree_info.in_spec.children_specs[0].num_children fn_args = self.pytree_info.orig_args[:count_args] - fn_kwargs = '{' + ', '.join(f"'{k}':{v}" for k, v in zip( - self.pytree_info.in_spec.children_specs[1].context, - self.pytree_info.orig_args[count_args:])) + '}' + fn_kwargs = ( + "{" + + ", ".join( + f"'{k}':{v}" + for k, v in zip( + self.pytree_info.in_spec.children_specs[1].context, + self.pytree_info.orig_args[count_args:], + ) + ) + + "}" + ) fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. @@ -790,16 +893,20 @@ def gen_fn_def(self, free_vars, maybe_return_annotation): def generate_output(self, output_args): if self.pytree_info and self.pytree_info.out_spec: - return f'return pytree.tree_unflatten({repr(output_args)}, self._out_spec)' + return f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)" else: return super().generate_output(output_args) + class _FindNodesLookupTable: """ Side table for the graph for the purpose of doing fast queries """ + def __init__(self): - self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict(dict) + self.table: Dict[Tuple[str, Optional[Target]], Dict[Node, None]] = defaultdict( + dict + ) def _key(self, node) -> Tuple[str, Optional[Target]]: return (node.op, node.target if node.op == "call_function" else None) @@ -813,7 +920,7 @@ def insert(self, node: Node) -> None: def remove(self, node: Node) -> None: self.table[self._key(node)].pop(node) - def find_nodes(self, *, op: str, target: Optional['Target'] = None): + def find_nodes(self, *, op: str, target: Optional["Target"] = None): if op == "call_function": assert target is not None return [*self.table[(op, target)].keys()] @@ -824,6 +931,7 @@ def find_nodes(self, *, op: str, target: Optional['Target'] = None): # op is call_method, get_attr, call_module return [node for node in self.table[(op, None)].keys() if node.target == target] + @compatibility(is_backward_compatible=True) class Graph: """ @@ -839,6 +947,7 @@ class Graph: import torch import torch.fx + class MyModule(torch.nn.Module): def __init__(self): super().__init__() @@ -846,7 +955,10 @@ def __init__(self): self.linear = torch.nn.Linear(4, 5) def forward(self, x): - return torch.topk(torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3) + return torch.topk( + torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 + ) + m = MyModule() gm = torch.fx.symbolic_trace(m) @@ -870,13 +982,17 @@ def forward(self, x): """ @compatibility(is_backward_compatible=True) - def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Optional[Type["Tracer"]] = None, - tracer_extras: Optional[Dict[str, Any]] = None): + def __init__( + self, + owning_module: Optional["GraphModule"] = None, + tracer_cls: Optional[Type["Tracer"]] = None, + tracer_extras: Optional[Dict[str, Any]] = None, + ): """ Construct an empty Graph. """ - self._root : Node = Node(self, '', 'root', '', (), {}) - self._used_names : Dict[str, int] = {} # base name -> number + self._root: Node = Node(self, "", "root", "", (), {}) + self._used_names: Dict[str, int] = {} # base name -> number self._insert = self._root.prepend self._len = 0 self._graph_namespace = _Namespace() @@ -884,7 +1000,7 @@ def __init__(self, owning_module: Optional["GraphModule"] = None, tracer_cls: Op self._tracer_cls = tracer_cls self._tracer_extras = tracer_extras self._codegen = CodeGen() - self._co_fields : Dict[str, Any] = {} + self._co_fields: Dict[str, Any] = {} self._find_nodes_lookup_table = _FindNodesLookupTable() @property @@ -911,7 +1027,9 @@ def nodes(self) -> _node_list: return _node_list(self) @compatibility(is_backward_compatible=False) - def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = True): + def find_nodes( + self, *, op: str, target: Optional["Target"] = None, sort: bool = True + ): """ Allows for fast query of nodes @@ -935,7 +1053,9 @@ def find_nodes(self, *, op: str, target: Optional['Target'] = None, sort: bool = return node_list @compatibility(is_backward_compatible=True) - def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node=False) -> 'Optional[Argument]': + def graph_copy( + self, g: "Graph", val_map: Dict[Node, Node], return_output_node=False + ) -> "Optional[Argument]": """ Copy all nodes from a given graph into ``self``. @@ -955,13 +1075,13 @@ def graph_copy(self, g : 'Graph', val_map : Dict[Node, Node], return_output_node for node in g.nodes: if node in val_map: continue - if node.op == 'output': + if node.op == "output": rv = map_arg(node.args[0], lambda n: val_map[n]) return rv if not return_output_node else (rv, node) - val_map[node] = self.node_copy(node, lambda n : val_map[n]) + val_map[node] = self.node_copy(node, lambda n: val_map[n]) return None - def __deepcopy__(self, memo=None) -> 'Graph': + def __deepcopy__(self, memo=None) -> "Graph": """ Explicitly implement __deepcopy__ to prevent excessive recursion depth from the default implementation. This uses graph_copy to copy the nodes @@ -975,16 +1095,22 @@ def __deepcopy__(self, memo=None) -> 'Graph': g._codegen = copy.deepcopy(self._codegen) assert isinstance(output_vals, tuple) output_val, old_output_node = output_vals - new_output_node = g.output(output_val, type_expr=getattr(old_output_node, 'type', None)) + new_output_node = g.output( + output_val, type_expr=getattr(old_output_node, "type", None) + ) new_output_node.meta = copy.copy(old_output_node.meta) return g @compatibility(is_backward_compatible=True) - def create_node(self, op: str, target: 'Target', - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - name: Optional[str] = None, - type_expr: Optional[Any] = None) -> Node: + def create_node( + self, + op: str, + target: "Target", + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Create a ``Node`` and add it to the ``Graph`` at the current insert-point. Note that the current insert-point can be set via :meth:`Graph.inserting_before` @@ -1020,7 +1146,10 @@ def create_node(self, op: str, target: 'Target', name = self._graph_namespace.create_name(candidate, None) n = Node(self, name, op, target, args, kwargs, type_expr) - if self.owning_module is not None and getattr(self.owning_module, "_create_node_hooks", None) is not None: + if ( + self.owning_module is not None + and getattr(self.owning_module, "_create_node_hooks", None) is not None + ): for f in self.owning_module._create_node_hooks: f(n) @@ -1042,9 +1171,8 @@ def process_inputs(self, *args): def process_outputs(self, out): return self._codegen.process_outputs(out) - @compatibility(is_backward_compatible=True) - def erase_node(self, to_erase : Node) -> None: + def erase_node(self, to_erase: Node) -> None: """ Erases a ``Node`` from the ``Graph``. Throws an exception if there are still users of that node in the ``Graph``. @@ -1054,15 +1182,20 @@ def erase_node(self, to_erase : Node) -> None: to_erase (Node): The ``Node`` to erase from the ``Graph``. """ if len(to_erase.users) > 0: - raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} ' - f'users in the graph: {to_erase.users}!') + raise RuntimeError( + f"Tried to erase Node {to_erase} but it still had {len(to_erase.users)} " + f"users in the graph: {to_erase.users}!" + ) if to_erase.graph != self: raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") if to_erase._erased: warnings.warn(f"erase_node({to_erase}) on an already erased node") return - if self.owning_module is not None and getattr(self.owning_module, "_erase_node_hooks", None) is not None: + if ( + self.owning_module is not None + and getattr(self.owning_module, "_erase_node_hooks", None) is not None + ): for f in self.owning_module._erase_node_hooks: f(to_erase) @@ -1087,9 +1220,9 @@ def inserting_before(self, n: Optional[Node] = None): then restore it when the with statement exits:: with g.inserting_before(n): - ... # inserting before node n - ... # insert point restored to what it was previously - g.inserting_before(n) # set the insert point permanently + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently Args: @@ -1111,9 +1244,9 @@ def inserting_after(self, n: Optional[Node] = None): then restore it when the with statement exits:: with g.inserting_after(n): - ... # inserting after node n - ... # insert point restored to what it was previously - g.inserting_after(n) # set the insert point permanently + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently Args: @@ -1129,8 +1262,12 @@ def inserting_after(self, n: Optional[Node] = None): return _InsertPoint(self, n.append) @compatibility(is_backward_compatible=True) - def placeholder(self, name: str, type_expr: Optional[Any] = None, - default_value : Any = inspect.Signature.empty) -> Node: + def placeholder( + self, + name: str, + type_expr: Optional[Any] = None, + default_value: Any = inspect.Signature.empty, + ) -> Node: """ Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents a function input. @@ -1155,7 +1292,7 @@ def placeholder(self, name: str, type_expr: Optional[Any] = None, as ``Graph.create_node``. """ args = () if default_value is inspect.Signature.empty else (default_value,) - return self.create_node('placeholder', name, args=args, type_expr=type_expr) + return self.create_node("placeholder", name, args=args, type_expr=type_expr) @compatibility(is_backward_compatible=True) def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: @@ -1182,7 +1319,10 @@ def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node The same insertion point and type expression rules apply for this method as ``Graph.create_node``. """ - def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> bool: + + def _get_attr_reference_exists( + mod: torch.nn.Module, qualified_name: str + ) -> bool: module_path, _, name = qualified_name.rpartition(".") try: @@ -1196,32 +1336,40 @@ def _get_attr_reference_exists(mod: torch.nn.Module, qualified_name: str) -> boo res = getattr(submod, name) - if (not isinstance(res, torch.nn.Module) - and not isinstance(res, torch.nn.Parameter) - and name not in submod._buffers): + if ( + not isinstance(res, torch.nn.Module) + and not isinstance(res, torch.nn.Parameter) + and name not in submod._buffers + ): return False return True - if (self.owning_module and - not _get_attr_reference_exists(self.owning_module, qualified_name)): - warnings.warn("Attempted to insert a get_attr Node with no " - "underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule, " - "GraphModule.add_parameter to add the " - "necessary Parameter, or " - "nn.Module.register_buffer to add the " - "necessary buffer", stacklevel=2) - return self.create_node('get_attr', qualified_name, type_expr=type_expr) + if self.owning_module and not _get_attr_reference_exists( + self.owning_module, qualified_name + ): + warnings.warn( + "Attempted to insert a get_attr Node with no " + "underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule, " + "GraphModule.add_parameter to add the " + "necessary Parameter, or " + "nn.Module.register_buffer to add the " + "necessary buffer", + stacklevel=2, + ) + return self.create_node("get_attr", qualified_name, type_expr=type_expr) @compatibility(is_backward_compatible=True) - def call_module(self, - module_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: + def call_module( + self, + module_name: str, + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node represents a call to the forward() function of a ``Module`` in the ``Module`` @@ -1252,21 +1400,26 @@ def call_module(self, The same insertion point and type expression rules apply for this method as :meth:`Graph.create_node`. """ - if (self.owning_module and - self.owning_module.get_submodule(module_name) is None): - warnings.warn("Attempted to insert a call_module Node with " - "no underlying reference in the owning " - "GraphModule! Call " - "GraphModule.add_submodule to add the " - "necessary submodule") - return self.create_node('call_module', module_name, args, kwargs, type_expr=type_expr) + if self.owning_module and self.owning_module.get_submodule(module_name) is None: + warnings.warn( + "Attempted to insert a call_module Node with " + "no underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule" + ) + return self.create_node( + "call_module", module_name, args, kwargs, type_expr=type_expr + ) @compatibility(is_backward_compatible=True) - def call_method(self, - method_name: str, - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: + def call_method( + self, + method_name: str, + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node represents a call to a given method on the 0th element of ``args``. @@ -1294,14 +1447,18 @@ def call_method(self, The same insertion point and type expression rules apply for this method as :meth:`Graph.create_node`. """ - return self.create_node('call_method', method_name, args, kwargs, type_expr=type_expr) + return self.create_node( + "call_method", method_name, args, kwargs, type_expr=type_expr + ) @compatibility(is_backward_compatible=True) - def call_function(self, - the_function: Callable[..., Any], - args: Optional[Tuple['Argument', ...]] = None, - kwargs: Optional[Dict[str, 'Argument']] = None, - type_expr: Optional[Any] = None) -> Node: + def call_function( + self, + the_function: Callable[..., Any], + args: Optional[Tuple["Argument", ...]] = None, + kwargs: Optional[Dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node represents a call to a Python callable, specified by ``the_function``. @@ -1329,20 +1486,24 @@ def call_function(self, The same insertion point and type expression rules apply for this method as :meth:`Graph.create_node`. """ - return self.create_node('call_function', the_function, args, kwargs, type_expr=type_expr) + return self.create_node( + "call_function", the_function, args, kwargs, type_expr=type_expr + ) @compatibility(is_backward_compatible=True) - def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = lambda x: x) -> Node: + def node_copy( + self, node: Node, arg_transform: Callable[[Node], "Argument"] = lambda x: x + ) -> Node: """ Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from the graph of node to the graph of self. Example:: # Copying all the nodes in `g` into `new_graph` - g : torch.fx.Graph = ... + g: torch.fx.Graph = ... new_graph = torch.fx.graph() value_remap = {} for node in g.nodes: - value_remap[node] = new_graph.node_copy(node, lambda n : value_remap[n]) + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) Args: @@ -1358,12 +1519,14 @@ def node_copy(self, node: Node, arg_transform: Callable[[Node], 'Argument'] = la kwargs = map_arg(node.kwargs, arg_transform) assert isinstance(args, tuple) assert isinstance(kwargs, dict) - result_node = self.create_node(node.op, node.target, args, kwargs, node.name, node.type) + result_node = self.create_node( + node.op, node.target, args, kwargs, node.name, node.type + ) result_node.meta = copy.copy(node.meta) return result_node @compatibility(is_backward_compatible=True) - def output(self, result: 'Argument', type_expr: Optional[Any] = None): + def output(self, result: "Argument", type_expr: Optional[Any] = None): """ Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents a ``return`` statement in Python code. ``result`` is the value that should @@ -1381,9 +1544,11 @@ def output(self, result: 'Argument', type_expr: Optional[Any] = None): The same insertion point and type expression rules apply for this method as ``Graph.create_node``. """ - return self.create_node(op='output', target='output', args=(result,), type_expr=type_expr) + return self.create_node( + op="output", target="output", args=(result,), type_expr=type_expr + ) - def _target_to_str(self, target : Target) -> str: + def _target_to_str(self, target: Target) -> str: if callable(target): op = target.__name__ else: @@ -1396,8 +1561,13 @@ def _target_to_str(self, target : Target) -> str: @compatibility(is_backward_compatible=True) def python_code( - self, root_module: str, *, - verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False + self, + root_module: str, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1458,36 +1628,50 @@ def override_node_repr(graph: Graph): with override_node_repr(self): return self._python_code( - root_module, namespace, - verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, ) def _python_code( - self, root_module: str, namespace: _Namespace, *, - verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, + self, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( - self.nodes, root_module, namespace, - verbose=verbose, include_stride=include_stride, include_device=include_device, colored=colored + self.nodes, + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, ) - def __str__(self) -> str: """ Return a human-readable (not machine-readable) string representation of this Graph """ - placeholder_names : List[str] = [] + placeholder_names: List[str] = [] # This is a one-element array just so ``format_node`` can modify the closed # over value - maybe_return_typename : List[str] = [''] + maybe_return_typename: List[str] = [""] node_strs = [node.format_node(placeholder_names) for node in self.nodes] - param_str = ', '.join(placeholder_names) - s = f'graph({param_str}){maybe_return_typename[0]}:' + param_str = ", ".join(placeholder_names) + s = f"graph({param_str}){maybe_return_typename[0]}:" for node_str in node_strs: if node_str: - s += '\n ' + node_str + s += "\n " + node_str return s @compatibility(is_backward_compatible=True) @@ -1500,15 +1684,17 @@ def print_tabular(self): try: from tabulate import tabulate except ImportError: - print("`print_tabular` relies on the library `tabulate`, " - "which could not be found on this machine. Run `pip " - "install tabulate` to install the library.") + print( + "`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) raise - node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] - for n in self.nodes] - print(tabulate(node_specs, - headers=['opcode', 'name', 'target', 'args', 'kwargs'])) + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes] + print( + tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"]) + ) @compatibility(is_backward_compatible=True) def lint(self): @@ -1522,23 +1708,34 @@ def lint(self): """ # Check topo order - def check_arg(arg : Node, n : Optional[Node] = None) -> None: - context_str = f' of Node \'{n}\' ' if n else ' ' + def check_arg(arg: Node, n: Optional[Node] = None) -> None: + context_str = f" of Node '{n}' " if n else " " if arg.graph is not self: - raise RuntimeError(f'Argument \'{arg}\'{context_str}does not belong to this Graph, ' - f'but was used as an argument! If you are copying nodes from another graph, make ' - f'sure to use ``arg_transform`` on node_copy() to remap values\n{self}') + raise RuntimeError( + f"Argument '{arg}'{context_str}does not belong to this Graph, " + f"but was used as an argument! If you are copying nodes from another graph, make " + f"sure to use ``arg_transform`` on node_copy() to remap values\n{self}" + ) if arg not in seen_values: - raise RuntimeError(f'Argument \'{arg}\'{context_str}was used before it has been ' - f'defined! Please check that Nodes in the graph are topologically ordered\n{self}') + raise RuntimeError( + f"Argument '{arg}'{context_str}was used before it has been " + f"defined! Please check that Nodes in the graph are topologically ordered\n{self}" + ) - seen_names : Set[str] = set() - seen_values : Set[Node] = set() + seen_names: Set[str] = set() + seen_values: Set[Node] = set() for node in self.nodes: - if node.op not in ['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output']: - raise RuntimeError(f'Node {node} had unknown opcode {node.op}!') + if node.op not in [ + "placeholder", + "call_method", + "call_module", + "call_function", + "get_attr", + "output", + ]: + raise RuntimeError(f"Node {node} had unknown opcode {node.op}!") if node.graph is not self: - raise RuntimeError(f'Node \'{node}\' does not belong to this Graph!') + raise RuntimeError(f"Node '{node}' does not belong to this Graph!") if node not in self._find_nodes_lookup_table: raise RuntimeError(f"Node '{node}' is not added to the side table") map_arg(node.args, lambda arg: check_arg(arg, node)) @@ -1546,7 +1743,7 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: seen_values.add(node) if node.name in seen_names: - raise RuntimeError(f'Node redefined name {node.name}!') + raise RuntimeError(f"Node redefined name {node.name}!") seen_names.add(node.name) # Check targets are legit @@ -1554,49 +1751,64 @@ def check_arg(arg : Node, n : Optional[Node] = None) -> None: num_warnings = 0 MAX_WARNINGS = 5 for node in self.nodes: - if node.op == 'call_function': + if node.op == "call_function": if not callable(node.target): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a Callable is expected') + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a Callable is expected" + ) else: if not isinstance(node.target, str): - raise ValueError(f'Node {node} target {node.target} has type {torch.typename(node.target)} but ' - 'a str is expected') - if node.op in ['get_attr', 'call_module']: - target_atoms = node.target.split('.') + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a str is expected" + ) + if node.op in ["get_attr", "call_module"]: + target_atoms = node.target.split(".") m_itr = self.owning_module for i, atom in enumerate(target_atoms): new_m_itr = getattr(m_itr, atom, None) - seen_qualname = '.'.join(target_atoms[:i]) + seen_qualname = ".".join(target_atoms[:i]) if new_m_itr is None: - raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute ' - f'{atom} of {seen_qualname}') - if (node.op == "call_module" - and not isinstance(new_m_itr, torch.nn.Module)): - raise RuntimeError(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module') - elif (node.op == "get_attr" - and not isinstance(new_m_itr, torch.nn.Module) - and not isinstance(new_m_itr, torch.nn.Parameter) - and atom not in m_itr._buffers): + raise RuntimeError( + f"Node {node} target {node.target} references nonexistent attribute " + f"{atom} of {seen_qualname}" + ) + if node.op == "call_module" and not isinstance( + new_m_itr, torch.nn.Module + ): + raise RuntimeError( + f"Node {node} target {node.target} {atom} of {seen_qualname} does " + "not reference an nn.Module" + ) + elif ( + node.op == "get_attr" + and not isinstance(new_m_itr, torch.nn.Module) + and not isinstance(new_m_itr, torch.nn.Parameter) + and atom not in m_itr._buffers + ): if num_warnings < MAX_WARNINGS: # Don't emit this warning too frequently, # for very large graphs this can become very expensive # from a performance perspective. - warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does ' - 'not reference an nn.Module, nn.Parameter, or buffer, which is ' - 'what \'get_attr\' Nodes typically target') + warnings.warn( + f"Node {node} target {node.target} {atom} of {seen_qualname} does " + "not reference an nn.Module, nn.Parameter, or buffer, which is " + "what 'get_attr' Nodes typically target" + ) num_warnings += 1 else: m_itr = new_m_itr if num_warnings > MAX_WARNINGS: warnings.warn( - f'Additional {num_warnings - MAX_WARNINGS} warnings ' - 'suppressed about get_attr references' + f"Additional {num_warnings - MAX_WARNINGS} warnings " + "suppressed about get_attr references" ) @compatibility(is_backward_compatible=True) - def eliminate_dead_code(self, is_impure_node: Optional[Callable[[Node], bool]] = None): + def eliminate_dead_code( + self, is_impure_node: Optional[Callable[[Node], bool]] = None + ): """ Remove all dead code from the graph, based on each node's number of users, and whether the nodes have any side effects. The graph must be @@ -1665,7 +1877,7 @@ def set_codegen(self, codegen: CodeGen): @compatibility(is_backward_compatible=False) def on_generate_code( self, - make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc] + make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc], ): """Register a transformer function when python code is generated @@ -1691,6 +1903,7 @@ def on_generate_code( gm: fx.GraphModule = ... + # This is a code transformer we want to register. This code # transformer prepends a pdb import and trace statement at the very # beginning of the generated torch.fx code to allow for manual @@ -1698,21 +1911,17 @@ def on_generate_code( def insert_pdb(body): return ["import pdb; pdb.set_trace()\\n", *body] + # Registers `insert_pdb`, and overwrites the current registered # code transformer (given by `_` to the lambda): - gm.graph.on_generate_code( - lambda _: insert_pdb - ) + gm.graph.on_generate_code(lambda _: insert_pdb) # Or alternatively, registers a code transformer which first # runs `body` through existing registered transformer, then # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( - lambda body: insert_pdb( - current_trans(body) if current_trans - else body - ) + lambda body: insert_pdb(current_trans(body) if current_trans else body) ) ) @@ -1750,47 +1959,51 @@ def on_generate_code_context_manager(): reflectable_magic_methods = { - 'add': '{} + {}', - 'sub': '{} - {}', - 'mul': '{} * {}', - 'floordiv': '{} // {}', - 'truediv': '{} / {}', - 'div': '{} / {}', - 'mod': '{} % {}', - 'pow': '{} ** {}', - 'lshift': '{} << {}', - 'rshift': '{} >> {}', - 'and_': '{} & {}', - 'or_': '{} | {}', - 'xor': '{} ^ {}', - 'getitem': '{}[{}]', - 'matmul': '{} @ {}', + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "div": "{} / {}", + "mod": "{} % {}", + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "getitem": "{}[{}]", + "matmul": "{} @ {}", } -magic_methods = dict({ - 'eq': '{} == {}', - 'ne': '{} != {}', - 'lt': '{} < {}', - 'gt': '{} > {}', - 'le': '{} <= {}', - 'ge': '{} >= {}', - 'pos': '+{}', - 'neg': '-{}', - 'invert': '~{}'}, **reflectable_magic_methods) +magic_methods = dict( + { + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "pos": "+{}", + "neg": "-{}", + "invert": "~{}", + }, + **reflectable_magic_methods, +) inplace_methods = { - 'iadd': '{} += {}', - 'iand': '{} &= {}', - 'ifloordiv': '{} //= {}', - 'ilshift': '{} <<= {}', - 'imod': '{} %= {}', - 'imul': '{} *= {}', - 'imatmul': '{} @= {}', - 'ior': '{} |= {}', - 'ipow': '{} **= {}', - 'irshift': '{} >>= {}', - 'isub': '{} -= {}', - 'itruediv': '{} /= {}', - 'ixor': '{} ^= {}', - 'setitem': '{}[{}] = {}', + "iadd": "{} += {}", + "iand": "{} &= {}", + "ifloordiv": "{} //= {}", + "ilshift": "{} <<= {}", + "imod": "{} %= {}", + "imul": "{} *= {}", + "imatmul": "{} @= {}", + "ior": "{} |= {}", + "ipow": "{} **= {}", + "irshift": "{} >>= {}", + "isub": "{} -= {}", + "itruediv": "{} /= {}", + "ixor": "{} ^= {}", + "setitem": "{}[{}] = {}", } diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 2328541511fd6..e2da576961774 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -19,6 +19,7 @@ from ._compatibility import compatibility from .graph import _custom_builtins, _is_from_torch, _PyTreeCodeGen, Graph, PythonCode + __all__ = [ "reduce_graph_module", "reduce_package_graph_module", @@ -386,11 +387,9 @@ def __call__(self, obj, *args, **kwargs): return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] except Exception as e: assert e.__traceback__ - topmost_framesummary: ( - traceback.FrameSummary - ) = traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[ - -1 - ] # type: ignore[arg-type] + topmost_framesummary: traceback.FrameSummary = ( + traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] + ) if "eval_with_key" in topmost_framesummary.filename: print( _WrappedCall._generate_error_message(topmost_framesummary), @@ -612,20 +611,20 @@ def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: module_str = ( f"torch.load(r'{module_file}', weights_only=False) # {module_repr}" ) - model_str += f"{tab*2}self.{module_name} = {module_str}\n" + model_str += f"{tab * 2}self.{module_name} = {module_str}\n" for buffer_name, buffer in self._buffers.items(): if buffer is None: continue - model_str += f"{tab*2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" + model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950 for param_name, param in self._parameters.items(): if param is None: continue - model_str += f"{tab*2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" + model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" # noqa: B950 model_str += ( - f"{tab*2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" ) model_str += f"{_addindent(self.code, 4)}\n" @@ -667,7 +666,6 @@ def add_submodule(self, target: str, m: torch.nn.Module) -> bool: mod: torch.nn.Module = self for item in prefix: - submod = getattr(mod, item, None) if submod is None: @@ -707,7 +705,6 @@ def delete_submodule(self, target: str) -> bool: # Get the parent module for item in path: - if not hasattr(mod, item): return False @@ -743,9 +740,7 @@ def delete_all_unused_submodules(self) -> None: used: List[str] = [] for node in self.graph.nodes: - if node.op == "call_module" or node.op == "get_attr": - # A list of strings representing the different parts # of the path. For example, `foo.bar.baz` gives us # ["foo", "bar", "baz"] diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index c75407583137d..12a2070b586f9 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,20 +1,24 @@ # mypy: allow-untyped-defs -from .graph_module import GraphModule -from ._lazy_graph_module import _make_graph_module -from .graph import Graph -from .node import Argument, Node, Target, map_arg, map_aggregate -from .proxy import Proxy -from ._symbolic_trace import Tracer -from ._compatibility import compatibility -from . import config -import torch.fx.traceback as fx_traceback -import torch -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union import inspect from contextlib import contextmanager +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import torch +import torch.fx.traceback as fx_traceback from torch.hub import tqdm -__all__ = ['Interpreter', 'Transformer'] +from . import config +from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module +from ._symbolic_trace import Tracer +from .graph import Graph +from .graph_module import GraphModule +from .node import Argument, map_aggregate, map_arg, Node, Target +from .proxy import Proxy + + +__all__ = ["Interpreter", "Transformer"] + @compatibility(is_backward_compatible=True) class Interpreter: @@ -43,22 +47,22 @@ class Interpreter: method equivalents). We could subclass Interpreter like so:: class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: + def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) - def call_method(self, target : Target, - args : Tuple, kwargs : Dict) -> Any: - if target == 'neg': + def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) + def fn(x): return torch.sigmoid(x).neg() + gm = torch.fx.symbolic_trace(fn) input = torch.randn(3, 4) result = NegSigmSwapInterpreter(gm).run(input) @@ -74,15 +78,21 @@ def fn(x): graph instead of `module.graph`, using the provided `module` argument to satisfy any requests for state. """ + @compatibility(is_backward_compatible=True) - def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, graph: Optional[Graph] = None): + def __init__( + self, + module: torch.nn.Module, + garbage_collect_values: bool = True, + graph: Optional[Graph] = None, + ): self.module = module self.submodules = dict(self.module.named_modules()) if graph is not None: self.graph = graph else: self.graph = self.module.graph - self.env : Dict[Node, Any] = {} + self.env: Dict[Node, Any] = {} self.name = "Interpreter" self.garbage_collect_values = garbage_collect_values self.extra_traceback = True @@ -92,10 +102,10 @@ def __init__(self, module: torch.nn.Module, garbage_collect_values: bool = True, # of a given node. This represents the *last* use of the node in the # execution order of the program, which we will use to free unused # values - node_to_last_use : Dict[Node, Node] = {} - self.user_to_last_uses : Dict[Node, List[Node]] = {} + node_to_last_use: Dict[Node, Node] = {} + self.user_to_last_uses: Dict[Node, List[Node]] = {} - def register_last_uses(n : Node, user : Node): + def register_last_uses(n: Node, user: Node): if n not in node_to_last_use: node_to_last_use[n] = user self.user_to_last_uses.setdefault(user, []).append(n) @@ -105,7 +115,12 @@ def register_last_uses(n : Node, user : Node): map_arg(node.kwargs, lambda n: register_last_uses(n, node)) @compatibility(is_backward_compatible=True) - def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: + def run( + self, + *args, + initial_env: Optional[Dict[Node, Any]] = None, + enable_io_processing: bool = True, + ) -> Any: """ Run `module` via interpretation and return the result. @@ -128,10 +143,16 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_p # position and extract those values. if enable_io_processing: args = self.graph.process_inputs(*args) - self.args_iter : Iterator[Any] = iter(args) - pbar = tqdm(total=len(self.graph.nodes), - desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", - initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) + self.args_iter: Iterator[Any] = iter(args) + pbar = tqdm( + total=len(self.graph.nodes), + desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", + initial=0, + position=0, + leave=True, + disable=config.disable_progress, + delay=0, + ) for node in self.graph.nodes: pbar.update(1) @@ -147,7 +168,7 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_p except Exception as e: if self.extra_traceback: msg = f"While executing {node.format_node()}" - msg = f'{e.args[0]}\n\n{msg}' if e.args else str(msg) + msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) msg += f"\nOriginal traceback:\n{node.stack_trace}" e.args = (msg,) + e.args[1:] if isinstance(e, KeyError): @@ -158,9 +179,13 @@ def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_p for to_delete in self.user_to_last_uses.get(node, []): del self.env[to_delete] - if node.op == 'output': + if node.op == "output": output_val = self.env[node] - return self.graph.process_outputs(output_val) if enable_io_processing else output_val + return ( + self.graph.process_outputs(output_val) + if enable_io_processing + else output_val + ) @compatibility(is_backward_compatible=True) def boxed_run(self, args_list): @@ -183,7 +208,7 @@ def _set_current_node(self, node): yield @compatibility(is_backward_compatible=True) - def run_node(self, n : Node) -> Any: + def run_node(self, n: Node) -> Any: """ Run a specific node ``n`` and return the result. Calls into placeholder, get_attr, call_function, @@ -204,7 +229,9 @@ def run_node(self, n : Node) -> Any: # Main Node running APIs @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def placeholder( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``placeholder`` node. Note that this is stateful: ``Interpreter`` maintains an internal iterator over @@ -222,7 +249,7 @@ def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D Any: The argument value that was retrieved. """ assert isinstance(target, str) - if target.startswith('*'): + if target.startswith("*"): # For a starred parameter e.g. `*args`, retrieve all # remaining values from the args list. return list(self.args_iter) @@ -233,10 +260,14 @@ def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D if len(args) > 0: return args[0] else: - raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si + raise RuntimeError( + f"Expected positional argument for parameter {target}, but one was not passed in!" + ) from si @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def get_attr( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``get_attr`` node. Will retrieve an attribute value from the ``Module`` hierarchy of ``self.module``. @@ -255,7 +286,9 @@ def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict return self.fetch_attr(target) @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_function( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``call_function`` node and return the result. @@ -275,7 +308,9 @@ def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : return target(*args, **kwargs) @compatibility(is_backward_compatible=True) - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_method( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``call_method`` node and return the result. @@ -297,7 +332,9 @@ def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D return getattr(self_obj, target)(*args_tail, **kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_module( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute a ``call_module`` node and return the result. @@ -320,7 +357,9 @@ def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D return submod(*args, **kwargs) @compatibility(is_backward_compatible=True) - def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def output( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: """ Execute an ``output`` node. This really just retrieves the value referenced by the ``output`` node and returns it. @@ -339,7 +378,7 @@ def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[s # Helper methods @compatibility(is_backward_compatible=True) - def fetch_attr(self, target : str): + def fetch_attr(self, target: str): """ Fetch an attribute from the ``Module`` hierarchy of ``self.module``. @@ -349,16 +388,18 @@ def fetch_attr(self, target : str): Return: Any: The value of the attribute. """ - target_atoms = target.split('.') + target_atoms = target.split(".") attr_itr = self.module for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): - raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i+1])}") + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}" + ) attr_itr = getattr(attr_itr, atom) return attr_itr @compatibility(is_backward_compatible=True) - def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: + def fetch_args_kwargs_from_env(self, n: Node) -> Tuple[Tuple, Dict]: """ Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` from the current execution environment. @@ -376,7 +417,7 @@ def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: return args, kwargs @compatibility(is_backward_compatible=True) - def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: + def map_nodes_to_values(self, args: Argument, n: Node) -> Argument: """ Recursively descend through ``args`` and look up the concrete value for each ``Node`` in the current execution environment. @@ -386,13 +427,18 @@ def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: n (Node): Node to which ``args`` belongs. This is only used for error reporting. """ - def load_arg(n_arg : Node) -> Any: + + def load_arg(n_arg: Node) -> Any: if n_arg not in self.env: - raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' - f'to diagnose such issues') + raise RuntimeError( + f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() " + f"to diagnose such issues" + ) return self.env[n_arg] + return map_arg(args, load_arg) + @compatibility(is_backward_compatible=True) class Transformer(Interpreter): """ @@ -409,23 +455,29 @@ class Transformer(Interpreter): method equivalents). We could subclass ``Transformer`` like so:: class NegSigmSwapXformer(Transformer): - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_function( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(n) - def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: - if target == 'neg': + def call_method( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: + if target == "neg": call_self, *args_tail = args return call_self.sigmoid(*args_tail, **kwargs) return super().call_method(n) + def fn(x): return torch.sigmoid(x).neg() + gm = torch.fx.symbolic_trace(fn) - transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() + transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() input = torch.randn(3, 4) torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) @@ -452,7 +504,9 @@ def is_leaf_module(self, _, __) -> bool: self.tracer.root = module @compatibility(is_backward_compatible=True) - def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + def placeholder( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Proxy: """ Execute a ``placeholder`` node. In ``Transformer``, this is overridden to insert a new ``placeholder`` into the output @@ -467,10 +521,14 @@ def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : D """ assert isinstance(target, str) default_value = next(iter(args)) if args else inspect.Signature.empty - return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) + return Proxy( + self.new_graph.placeholder(target, default_value=default_value), self.tracer + ) @compatibility(is_backward_compatible=True) - def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: + def get_attr( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Proxy: """ Execute a ``get_attr`` node. In ``Transformer``, this is overridden to insert a new ``get_attr`` node into the output @@ -487,16 +545,20 @@ def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict return self.tracer.create_proxy("get_attr", target, args, kwargs) @compatibility(is_backward_compatible=True) - def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_module( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: # Override so that the leaf module policy from `self.tracer` is respected. assert isinstance(target, str) submod = self.fetch_attr(target) return self.tracer.call_module(submod, submod.forward, args, kwargs) @compatibility(is_backward_compatible=True) - def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: + def call_function( + self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + ) -> Any: # Override so that functions that were wrapped are still wrapped. - return self.tracer.create_proxy('call_function', target, args, kwargs) + return self.tracer.create_proxy("call_function", target, args, kwargs) @compatibility(is_backward_compatible=True) def transform(self) -> GraphModule: @@ -507,8 +569,10 @@ def transform(self) -> GraphModule: with fx_traceback.preserve_node_meta(): result = super().run(enable_io_processing=False) if result is not None: - def strip_proxy(a : Union[Argument, Proxy]) -> Any: + + def strip_proxy(a: Union[Argument, Proxy]) -> Any: return a.node if isinstance(a, Proxy) else a + new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy)) # also preserve the metadata from the old output node, if it exists old_output_node = list(self.graph.nodes)[-1] @@ -516,5 +580,4 @@ def strip_proxy(a : Union[Argument, Proxy]) -> Any: for k, v in old_output_node.meta.items(): new_output_node.meta[k] = v - return _make_graph_module(self.module, self.new_graph) diff --git a/torch/fx/node.py b/torch/fx/node.py index 8c3461cbe23c7..469b63403848b 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -1,39 +1,71 @@ # Nodes represent a definition of a value in our graph of operators. -from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set -from ._compatibility import compatibility -from .immutable_collections import immutable_dict, immutable_list -import torch import builtins -import types import inspect +import types import warnings -from torch.fx.operator_schemas import normalize_function, normalize_module, ArgsKwargsPair -from .._ops import ops as _ops +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union + +import torch from torch._C import _NodeBase +from torch.fx.operator_schemas import ( + ArgsKwargsPair, + normalize_function, + normalize_module, +) + +from .._ops import ops as _ops +from ._compatibility import compatibility +from .immutable_collections import immutable_dict, immutable_list + if TYPE_CHECKING: from .graph import Graph -__all__ = ['Node', 'map_arg', 'map_aggregate', "has_side_effect"] - -BaseArgumentTypes = Union[str, int, float, bool, complex, torch.dtype, - torch.Tensor, torch.device, torch.memory_format, torch.layout, torch._ops.OpOverload, - torch.SymInt, torch.SymBool, torch.SymFloat] +__all__ = ["Node", "map_arg", "map_aggregate", "has_side_effect"] + +BaseArgumentTypes = Union[ + str, + int, + float, + bool, + complex, + torch.dtype, + torch.Tensor, + torch.device, + torch.memory_format, + torch.layout, + torch._ops.OpOverload, + torch.SymInt, + torch.SymBool, + torch.SymFloat, +] base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] Target = Union[Callable[..., Any], str] -Argument = Optional[Union[ - Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types - List[Any], # actually Argument - Dict[str, Any], # actually Argument - slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing - range, - 'Node', - BaseArgumentTypes -]] - -_legal_ops = dict.fromkeys(['placeholder', 'call_method', 'call_module', 'call_function', 'get_attr', 'output', 'root']) +Argument = Optional[ + Union[ + Tuple[Any, ...], # actually Argument, but mypy can't represent recursive types + List[Any], # actually Argument + Dict[str, Any], # actually Argument + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + "Node", + BaseArgumentTypes, + ] +] + +_legal_ops = dict.fromkeys( + [ + "placeholder", + "call_method", + "call_module", + "call_function", + "get_attr", + "output", + "root", + ] +) _side_effectful_need_to_be_preserved_pre_dispatch: Set[Callable] = { torch._C._set_grad_enabled, @@ -74,7 +106,8 @@ def _find_module_of_method(orig_method: Callable[..., Any]) -> str: for guess in [torch, torch.nn.functional]: if getattr(guess, name, None) is orig_method: return guess.__name__ - raise RuntimeError(f'cannot find module for {orig_method}') + raise RuntimeError(f"cannot find module for {orig_method}") + # Borrowed from CPython typing module # https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 @@ -86,22 +119,24 @@ def _type_repr(obj: object) -> str: else, we fall back on repr(obj). """ if isinstance(obj, type): - if obj.__module__ == 'builtins': + if obj.__module__ == "builtins": return obj.__qualname__ - return f'{obj.__module__}.{obj.__qualname__}' + return f"{obj.__module__}.{obj.__qualname__}" if obj is ...: - return '...' + return "..." if isinstance(obj, types.FunctionType): return obj.__name__ return repr(obj) + def _get_qualified_name(func: Callable[..., Any]) -> str: # things like getattr just appear in builtins if getattr(builtins, func.__name__, None) is func: return func.__name__ # torch.Tensor.{fn} - if (isinstance(func, (types.MethodDescriptorType, types.WrapperDescriptorType)) - and func is getattr(torch.Tensor, func.__name__, None)): + if isinstance( + func, (types.MethodDescriptorType, types.WrapperDescriptorType) + ) and func is getattr(torch.Tensor, func.__name__, None): return f"torch.Tensor.{func.__name__}" name = func.__name__ if name == "": @@ -111,33 +146,45 @@ def _get_qualified_name(func: Callable[..., Any]) -> str: except Exception as e: raise RuntimeError("Unable to represent lambda") from e module = _find_module_of_method(func) - module = module.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + module = module.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module # Fixup segment_reduce mismatch if module == "torch" and name == "segment_reduce": name = "_" + name - return f'{module}.{name}' + return f"{module}.{name}" + -def _format_arg(arg: object, max_list_len: float = float('inf')) -> str: - if hasattr(arg, '_custom_fx_repr_fn'): +def _format_arg(arg: object, max_list_len: float = float("inf")) -> str: + if hasattr(arg, "_custom_fx_repr_fn"): return arg._custom_fx_repr_fn() elif isinstance(arg, list): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - return f'[{items}{maybe_len}]' + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + return f"[{items}{maybe_len}]" elif isinstance(arg, tuple): - items = ', '.join(_format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len) - maybe_len = '' if len(arg) < max_list_len + 1 else f', ...[total_len={len(arg)}]' - maybe_comma = ',' if len(arg) == 1 else '' - return f'({items}{maybe_comma}{maybe_len})' + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + maybe_comma = "," if len(arg) == 1 else "" + return f"({items}{maybe_comma}{maybe_len})" elif isinstance(arg, dict): - items_str = ', '.join(f'{k}: {_format_arg(v)}' for k, v in arg.items()) - return f'{{{items_str}}}' + items_str = ", ".join(f"{k}: {_format_arg(v)}" for k, v in arg.items()) + return f"{{{items_str}}}" if isinstance(arg, Node): - return '%' + str(arg) + return "%" + str(arg) else: return str(arg) + @compatibility(is_backward_compatible=True) class Node(_NodeBase): """ @@ -166,23 +213,31 @@ class Node(_NodeBase): - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement in the Graph printout. """ - _args: Tuple['Argument', ...] - _kwargs: Dict[str, 'Argument'] - graph: 'Graph' + + _args: Tuple["Argument", ...] + _kwargs: Dict[str, "Argument"] + graph: "Graph" name: str op: str - target: 'Target' - _input_nodes: Dict['Node', None] - users: Dict['Node', None] + target: "Target" + _input_nodes: Dict["Node", None] + users: Dict["Node", None] type: Optional[Any] _sort_key: Any - _repr_fn: Optional[Callable[['Node'], str]] + _repr_fn: Optional[Callable[["Node"], str]] meta: Dict[str, Any] @compatibility(is_backward_compatible=True) - def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', - args: Tuple['Argument', ...], kwargs: Dict[str, 'Argument'], - return_type : Optional[Any] = None) -> None: + def __init__( + self, + graph: "Graph", + name: str, + op: str, + target: "Target", + args: Tuple["Argument", ...], + kwargs: Dict[str, "Argument"], + return_type: Optional[Any] = None, + ) -> None: """ Instantiate an instance of ``Node``. Note: most often, you want to use the Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather @@ -210,14 +265,18 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', of analyses. """ assert op in _legal_ops - if op == 'call_function': + if op == "call_function": if not callable(target): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a Callable is expected') + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a Callable is expected" + ) else: if not isinstance(target, str): - raise ValueError(f'Node [graph = {graph}, name = \'{name}\'] target {target} has type {torch.typename(target)} ' - 'but a str is expected') + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a str is expected" + ) super().__init__() # bypass Node.__setattr__ for perf and so that it doesn't need to handle half-built objects @@ -225,9 +284,13 @@ def __init__(self, graph: 'Graph', name: str, op: str, target: 'Target', assign(self, "graph", graph) assign(self, "name", name) # unique name of value being created - assign(self, "op", op) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + assign( + self, "op", op + ) # the kind of operation = placeholder|call_method|call_module|call_function|get_attr - assign(self, "target", target) # for method/module/function, the name of the method/module/function/attr + assign( + self, "target", target + ) # for method/module/function, the name of the method/module/function/attr # being invoked, e.g add, layer1, or torch.add # All `Node`-valued inputs. Key is the Node, value is don't-care. @@ -280,7 +343,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self._next = _next @property - def next(self) -> 'Node': + def next(self) -> "Node": """ Returns the next ``Node`` in the linked list of Nodes. @@ -291,7 +354,7 @@ def next(self) -> 'Node': return self._next @property - def prev(self) -> 'Node': + def prev(self) -> "Node": """ Returns the previous ``Node`` in the linked list of Nodes. @@ -302,7 +365,7 @@ def prev(self) -> 'Node': return self._prev @compatibility(is_backward_compatible=True) - def prepend(self, x: 'Node') -> None: + def prepend(self, x: "Node") -> None: """ Insert x before this node in the list of nodes in the graph. Example:: @@ -316,7 +379,9 @@ def prepend(self, x: 'Node') -> None: """ assert self.graph == x.graph, "Attempting to move a Node into a different Graph" if self == x: - warnings.warn("Trying to prepend a node to itself. This behavior has no effect on the graph.") + warnings.warn( + "Trying to prepend a node to itself. This behavior has no effect on the graph." + ) return x._remove_from_list() p = self._prev @@ -328,28 +393,28 @@ def prepend(self, x: 'Node') -> None: nsk = x._next._sort_key if len(psk) > len(nsk): idx: int - *prefix, idx = psk[:len(nsk) + 1] + *prefix, idx = psk[: len(nsk) + 1] x._sort_key = (*prefix, idx + 1) elif len(psk) < len(nsk): - *prefix, idx = nsk[:len(psk) + 1] + *prefix, idx = nsk[: len(psk) + 1] x._sort_key = (*prefix, idx - 1) else: # same length, increase length by 1 x._sort_key = (*psk, 0) - def __gt__(self, other: 'Node') -> bool: + def __gt__(self, other: "Node") -> bool: return self._sort_key > other._sort_key - def __lt__(self, other: 'Node') -> bool: + def __lt__(self, other: "Node") -> bool: return self._sort_key < other._sort_key - def __ge__(self, other: 'Node') -> bool: + def __ge__(self, other: "Node") -> bool: return self > other or self == other - def __le__(self, other: 'Node') -> bool: + def __le__(self, other: "Node") -> bool: return self < other or self == other @compatibility(is_backward_compatible=True) - def append(self, x: 'Node') -> None: + def append(self, x: "Node") -> None: """ Insert ``x`` after this node in the list of nodes in the graph. Equivalent to ``self.next.prepend(x)`` @@ -376,7 +441,7 @@ def args(self) -> Tuple[Argument, ...]: return self._args @args.setter - def args(self, a : Tuple[Argument, ...]) -> None: + def args(self, a: Tuple[Argument, ...]) -> None: """ Set the tuple of arguments to this Node. The interpretation of arguments depends on the node's opcode. See the ``fx.Graph`` docstring for more @@ -399,7 +464,7 @@ def kwargs(self) -> Dict[str, Argument]: return self._kwargs @kwargs.setter - def kwargs(self, k : Dict[str, Argument]) -> None: + def kwargs(self, k: Dict[str, Argument]) -> None: """ Set the dict of kwargs to this Node. The interpretation of arguments depends on the node's opcode. See the ``fx.Graph`` docstring for more @@ -410,7 +475,7 @@ def kwargs(self, k : Dict[str, Argument]) -> None: self.__update_args_kwargs(self._args, k) @property - def all_input_nodes(self) -> List['Node']: + def all_input_nodes(self) -> List["Node"]: """ Return all Nodes that are inputs to this Node. This is equivalent to iterating over ``args`` and ``kwargs`` and only collecting the values that @@ -424,7 +489,7 @@ def all_input_nodes(self) -> List['Node']: return list(self._input_nodes.keys()) @compatibility(is_backward_compatible=True) - def update_arg(self, idx : int, arg : Argument) -> None: + def update_arg(self, idx: int, arg: Argument) -> None: """ Update an existing positional argument to contain the new value ``arg``. After calling, ``self.args[idx] == arg``. @@ -439,7 +504,7 @@ def update_arg(self, idx : int, arg : Argument) -> None: self.args = tuple(args) @compatibility(is_backward_compatible=True) - def insert_arg(self, idx : int, arg : Argument) -> None: + def insert_arg(self, idx: int, arg: Argument) -> None: """ Insert an positional argument to the argument list with given index. @@ -448,7 +513,9 @@ def insert_arg(self, idx : int, arg : Argument) -> None: idx (int): The index of the element in ``self.args`` to be inserted before. arg (Argument): The new argument value to insert into ``args`` """ - assert 0 <= idx <= len(self.args), "insert_args index must be between 0 and len(self.args)" + assert ( + 0 <= idx <= len(self.args) + ), "insert_args index must be between 0 and len(self.args)" args_left = self.args[:idx] args_right = self.args[idx:] @@ -463,7 +530,7 @@ def insert_arg(self, idx : int, arg : Argument) -> None: new_use.users.setdefault(self) @compatibility(is_backward_compatible=True) - def update_kwarg(self, key : str, arg : Argument) -> None: + def update_kwarg(self, key: str, arg: Argument) -> None: """ Update an existing keyword argument to contain the new value ``arg``. After calling, ``self.kwargs[key] == arg``. @@ -490,13 +557,16 @@ def stack_trace(self) -> Optional[str]: return self.meta.get("stack_trace", None) @stack_trace.setter - def stack_trace(self, trace : Optional[str]) -> None: + def stack_trace(self, trace: Optional[str]) -> None: self.meta["stack_trace"] = trace - def __update_args_kwargs(self, new_args : Tuple['Argument', ...], new_kwargs : Dict[str, 'Argument']) -> None: + def __update_args_kwargs( + self, new_args: Tuple["Argument", ...], new_kwargs: Dict[str, "Argument"] + ) -> None: """ This API is internal. Do *not* call it directly. """ + def update_users_and_input_nodes(n: Any) -> Any: if isinstance(n, Node): self._input_nodes.setdefault(n) @@ -512,8 +582,12 @@ def update_users_and_input_nodes(n: Any) -> Any: # - Normalize list->immutable_list, dict->immutable_dict, etc # - Populate self._input_nodes # - Populate arg.users[self] for each arg - object.__setattr__(self, "_args", map_aggregate(new_args, update_users_and_input_nodes)) - object.__setattr__(self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes)) + object.__setattr__( + self, "_args", map_aggregate(new_args, update_users_and_input_nodes) + ) + object.__setattr__( + self, "_kwargs", map_aggregate(new_kwargs, update_users_and_input_nodes) + ) def __repr__(self) -> str: if self._repr_fn: @@ -529,8 +603,8 @@ def _pretty_print_target(self, target: object) -> str: """ if isinstance(target, str): return target - if hasattr(target, '__module__'): - name = getattr(target, '__name__', None) + if hasattr(target, "__module__"): + name = getattr(target, "__name__", None) if name is None: # Just to be defensive, if we don't have `__name__`, get the # qualname. Not sure if this happens for any members of `operator` @@ -538,16 +612,18 @@ def _pretty_print_target(self, target: object) -> str: # things in `operator` have `_operator` as their __module__. # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` return _get_qualified_name(target) # type: ignore[arg-type] - if target.__module__ == 'builtins': - return f'builtins.{name}' - elif target.__module__ == '_operator': - return f'operator.{name}' + if target.__module__ == "builtins": + return f"builtins.{name}" + elif target.__module__ == "_operator": + return f"operator.{name}" return _get_qualified_name(target) # type: ignore[arg-type] @compatibility(is_backward_compatible=True) - def format_node(self, - placeholder_names: Optional[List[str]] = None, - maybe_return_typename: Optional[List[str]] = None) -> Optional[str]: + def format_node( + self, + placeholder_names: Optional[List[str]] = None, + maybe_return_typename: Optional[List[str]] = None, + ) -> Optional[str]: """ Return a descriptive string representation of ``self``. @@ -576,37 +652,46 @@ def format_node(self, return a descriptive string representation of the current Node. """ - if self.op == 'placeholder': + if self.op == "placeholder": assert isinstance(self.target, str) arg_str = self.target - arg_str += arg_str + f': {_type_repr(self.type)}' if self.type else '' + arg_str += arg_str + f": {_type_repr(self.type)}" if self.type else "" if placeholder_names: placeholder_names.append(arg_str) return None - maybe_typename = f'{_type_repr(self.type)} ' if self.type else '' - default_val = '(default=' + str(self.args[0]) + ')' if self.args else '' - return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}' - elif self.op == 'get_attr': - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}]' - elif self.op == 'output': + maybe_typename = f"{_type_repr(self.type)} " if self.type else "" + default_val = "(default=" + str(self.args[0]) + ")" if self.args else "" + return f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}" + elif self.op == "get_attr": + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}]" + ) + elif self.op == "output": if self.type and maybe_return_typename: - maybe_return_typename[0] = f' -> {_type_repr(self.type)}' - return f'return {self.args[0]}' + maybe_return_typename[0] = f" -> {_type_repr(self.type)}" + return f"return {self.args[0]}" else: - maybe_typename = f'{_type_repr(self.type)} ' if self.type is not None else '' - return f'%{self.name} : {maybe_typename}[num_users={len(self.users)}] = ' \ - f'{self.op}[target={self._pretty_print_target(self.target)}](' \ - f'args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})' + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}](" + f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})" + ) @compatibility(is_backward_compatible=True) - def replace_all_uses_with(self, - replace_with: 'Node', - delete_user_cb: Callable[['Node'], bool] = lambda user: True, - *, - propagate_meta: bool = False - ) -> List['Node']: + def replace_all_uses_with( + self, + replace_with: "Node", + delete_user_cb: Callable[["Node"], bool] = lambda user: True, + *, + propagate_meta: bool = False, + ) -> List["Node"]: """ Replace all uses of ``self`` in the Graph with the Node ``replace_with``. @@ -625,9 +710,10 @@ def replace_all_uses_with(self, The list of Nodes on which this change was made. """ if propagate_meta: - assert len(replace_with.meta) == 0, \ - 'Called node.replace_all_uses_with(replace_with, propagate_meta=True), ' \ - 'but replace_with already has .meta keys' + assert len(replace_with.meta) == 0, ( + "Called node.replace_all_uses_with(replace_with, propagate_meta=True), " + "but replace_with already has .meta keys" + ) for k, v in self.meta.items(): replace_with.meta[k] = v to_process = list(self.users) @@ -638,7 +724,7 @@ def replace_all_uses_with(self, skipped.append(use_node) continue - def maybe_replace_node(n : Node) -> Node: + def maybe_replace_node(n: Node) -> Node: if n == self: return replace_with else: @@ -690,9 +776,12 @@ def is_impure(self) -> bool: @compatibility(is_backward_compatible=False) def normalized_arguments( - self, root : torch.nn.Module, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + self, + root: torch.nn.Module, + arg_types: Optional[Tuple[Any]] = None, + kwarg_types: Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, + ) -> Optional[ArgsKwargsPair]: """ Returns normalized arguments to Python targets. This means that `args/kwargs` will be matched up to the module/functional's @@ -715,17 +804,23 @@ def normalized_arguments( Returns NamedTuple ArgsKwargsPair, or `None` if not successful. """ - if self.op == 'call_function': + if self.op == "call_function": assert callable(self.target) - return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type] - elif self.op == 'call_module': + return normalize_function( + self.target, + self.args, # type: ignore[arg-type] + self.kwargs, + arg_types, + kwarg_types, + ) + elif self.op == "call_module": assert isinstance(self.target, str) return normalize_module(root, self.target, self.args, self.kwargs) # type: ignore[arg-type] return None @compatibility(is_backward_compatible=True) - def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None: + def replace_input_with(self, old_input: "Node", new_input: "Node") -> None: """ Loop through input nodes of ``self``, and replace all instances of ``old_input`` with ``new_input``. @@ -735,7 +830,8 @@ def replace_input_with(self, old_input: 'Node', new_input: 'Node') -> None: old_input (Node): The old input node to be replaced. new_input (Node): The new input node to replace ``old_input``. """ - def maybe_replace_node(n : Node) -> Node: + + def maybe_replace_node(n: Node) -> Node: return new_input if n == old_input else n m = self.graph.owning_module @@ -756,7 +852,7 @@ def _rename(self, candidate: str) -> None: self.graph._graph_namespace._rename_object(self, name) def __setattr__(self, name: str, value: Any) -> None: - if name == 'name' and hasattr(self, "name"): + if name == "name" and hasattr(self, "name"): m = self.graph.owning_module if getattr(m, "_replace_hook", None): assert isinstance(value, str) @@ -764,9 +860,9 @@ def __setattr__(self, name: str, value: Any) -> None: m._replace_hook(old=self, new=value, user=user) update = False if ( - hasattr(self, name) and - hasattr(self.graph, "_find_nodes_lookup_table") and - self in self.graph._find_nodes_lookup_table + hasattr(self, name) + and hasattr(self.graph, "_find_nodes_lookup_table") + and self in self.graph._find_nodes_lookup_table ): update = True self.graph._find_nodes_lookup_table.remove(self) @@ -774,6 +870,7 @@ def __setattr__(self, name: str, value: Any) -> None: if update: self.graph._find_nodes_lookup_table.insert(self) + @compatibility(is_backward_compatible=True) def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: """ @@ -782,6 +879,7 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument: assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) + @compatibility(is_backward_compatible=True) def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: """ @@ -790,7 +888,7 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: if isinstance(a, tuple): t = tuple([map_aggregate(elem, fn) for elem in a]) # Support NamedTuple (if it has `_fields`) by repacking into original type. - return t if not hasattr(a, '_fields') else type(a)(*t) # type: ignore[arg-type] + return t if not hasattr(a, "_fields") else type(a)(*t) # type: ignore[arg-type] elif isinstance(a, list): return immutable_list([map_aggregate(elem, fn) for elem in a]) elif isinstance(a, dict): @@ -799,6 +897,10 @@ def map_aggregate(a: Argument, fn: Callable[[Argument], Argument]) -> Argument: dict.__setitem__(rv, k, map_aggregate(v, fn)) return rv elif isinstance(a, slice): - return slice(map_aggregate(a.start, fn), map_aggregate(a.stop, fn), map_aggregate(a.step, fn)) + return slice( + map_aggregate(a.start, fn), + map_aggregate(a.stop, fn), + map_aggregate(a.step, fn), + ) else: return fn(a) diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 8a5beed5285d9..f654b6c060e81 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -1,63 +1,100 @@ # mypy: allow-untyped-defs -import torch +import enum import inspect import numbers import types import typing -import enum import warnings -from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING +from typing import ( + Any, + Callable, + cast, + Dict, + List, + NamedTuple, + Optional, + Tuple, + TYPE_CHECKING, +) + +import torch from torch._jit_internal import boolean_dispatched +from torch._ops import OpOverload, OpOverloadPacket + from ._compatibility import compatibility -from torch._ops import OpOverloadPacket, OpOverload + if TYPE_CHECKING: from .node import Argument -__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint", - "type_matches", "normalize_function", "normalize_module"] +__all__ = [ + "ArgsKwargsPair", + "check_for_mutable_operation", + "get_signature_for_torch_op", + "create_type_hint", + "type_matches", + "normalize_function", + "normalize_module", +] + @compatibility(is_backward_compatible=False) class ArgsKwargsPair(NamedTuple): """ Simple named tuple for wrapping args/kwargs pairs. """ + args: Tuple[Any, ...] kwargs: Dict[str, Any] -_manual_overrides : Dict[Callable, List[inspect.Signature]] = {} + +_manual_overrides: Dict[Callable, List[inspect.Signature]] = {} + def _nonzero_schemas(): signatures = [] def nonzero(self): pass + signatures.append(inspect.signature(nonzero)) - def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] + def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef] pass + signatures.append(inspect.signature(nonzero)) return signatures + _manual_overrides[torch.nonzero] = _nonzero_schemas() + class _FakeGlobalNamespace: def __getattr__(self, name): - if name == 'torch': + if name == "torch": return torch - raise RuntimeError('Expected a torch namespace lookup') - -_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, - 'number' : numbers.Number, 'Future' : torch.jit.Future, - 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, - '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), - 'Storage': torch.UntypedStorage, - 't': typing.TypeVar('t')} + raise RuntimeError("Expected a torch namespace lookup") + + +_type_eval_globals = { + "Tensor": torch.Tensor, + "Device": torch.device, + "Layout": torch.layout, + "number": numbers.Number, + "Future": torch.jit.Future, + "AnyEnumType": enum.Enum, + "QScheme": torch.qscheme, + "__torch__": _FakeGlobalNamespace(), + "NoneType": type(None), + "Storage": torch.UntypedStorage, + "t": typing.TypeVar("t"), +} for k in dir(typing): _type_eval_globals[k] = getattr(typing, k) -def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: + +def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any: """ Convert a TorchScript type to a Python type (including subtypes) via eval'ing the annotation_str. _type_eval_globals sets up expressions @@ -65,9 +102,13 @@ def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: """ return eval(ts_type.annotation_str, _type_eval_globals) -def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: + +def _torchscript_schema_to_signature_impl( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: from inspect import Parameter - parameters : List[Parameter] = [] + + parameters: List[Parameter] = [] for arg in ts_schema.arguments: arg_type = _torchscript_type_to_python_type(arg.type) default = arg.default_value if arg.has_default_value() else Parameter.empty @@ -76,8 +117,12 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) - # argument name. Downstream, if someone converts that positional argument to a keyword # argument, the name mismatch will break things, so here we're going to normalize the # name to "input" - name = arg.name if arg.name != 'self' else 'input' - kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD + name = arg.name if arg.name != "self" else "input" + kind = ( + Parameter.KEYWORD_ONLY + if arg.kwarg_only + else Parameter.POSITIONAL_OR_KEYWORD + ) # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument if name == "from": assert kind == Parameter.POSITIONAL_OR_KEYWORD @@ -87,9 +132,18 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) - # This renders all previous arguments to positional only for idx, p in enumerate(parameters): assert p.kind == Parameter.POSITIONAL_OR_KEYWORD - parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY, default=p.default, annotation=p.annotation) - parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type)) - return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] + parameters[idx] = Parameter( + name=p.name, + kind=Parameter.POSITIONAL_ONLY, + default=p.default, + annotation=p.annotation, + ) + parameters.append( + Parameter(name=name, kind=kind, default=default, annotation=arg_type) + ) + return_types = [ + _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns + ] if len(return_types) == 0: return_type = None elif len(return_types) == 1: @@ -99,9 +153,13 @@ def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) - return inspect.Signature(parameters, return_annotation=return_type) -_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {} -def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: +_SCHEMA_TO_SIGNATURE_CACHE: Dict[Tuple[str, str], inspect.Signature] = {} + + +def _torchscript_schema_to_signature( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: # Cached as it's called in the hot path of FakeTensor dispatch cache_key = ts_schema.name, ts_schema.overload_name cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key) @@ -112,8 +170,11 @@ def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> ins _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res return res + @compatibility(is_backward_compatible=False) -def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): +def check_for_mutable_operation( + target: Callable, args: Tuple["Argument", ...], kwargs: Dict[str, "Argument"] +): signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) if signatures and schemas: @@ -126,14 +187,16 @@ def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...] try: candidate_signature.bind(*args, **kwargs) matched_schemas.append((candidate_signature, schema)) - except TypeError as e: + except TypeError: continue def throw_if_mutable(schema): if schema.is_mutable: - raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' - f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' - f'are not supported') + raise RuntimeError( + f"Tried to trace mutable operation {schema}. FX only supports functional " + f"code, so operations that mutate operands in-place (e.g. via `out` arguments) " + f"are not supported" + ) if len(matched_schemas) == 0: # Did not match any schema. Cannot check for mutation @@ -147,8 +210,9 @@ def throw_if_mutable(schema): # do nothing. pass + @compatibility(is_backward_compatible=False) -def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): +def get_signature_for_torch_op(op: Callable, return_schemas: bool = False): """ Given an operator on the `torch` namespace, return a list of `inspect.Signature` objects corresponding to the overloads of that op.. May return `None` if a signature @@ -181,6 +245,7 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] return (signatures, schemas) if return_schemas else signatures + @compatibility(is_backward_compatible=False) def create_type_hint(x): """ @@ -198,11 +263,15 @@ def create_type_hint(x): if isinstance(x, (list, tuple)): # todo(chilli): Figure out the right way for mypy to handle this if isinstance(x, list): + def ret_type(x): return List[x] # type: ignore[valid-type] + else: + def ret_type(x): return Tuple[x, ...] + if len(x) == 0: return ret_type(Any) base_type = x[0] @@ -214,14 +283,17 @@ def ret_type(x): else: return ret_type(Any) return ret_type(base_type) - except Exception as e: + except Exception: # We tried to create a type hint for list but failed. - warnings.warn(f"We were not able to successfully create type hint from the type {x}") + warnings.warn( + f"We were not able to successfully create type hint from the type {x}" + ) return x + @compatibility(is_backward_compatible=False) -def type_matches(signature_type : Any, argument_type : Any): - sig_origin_type = getattr(signature_type, '__origin__', signature_type) +def type_matches(signature_type: Any, argument_type: Any): + sig_origin_type = getattr(signature_type, "__origin__", signature_type) if signature_type is argument_type: return True @@ -236,13 +308,14 @@ def type_matches(signature_type : Any, argument_type : Any): # int can be promoted to List[int] return True - if getattr(signature_type, '__origin__', None) in {list, List}: + if getattr(signature_type, "__origin__", None) in {list, List}: sig_el_type = signature_type.__args__[0] if not inspect.isclass(sig_el_type): warnings.warn( - f"Does not support nested parametric types, got {signature_type}. Please file a bug.") + f"Does not support nested parametric types, got {signature_type}. Please file a bug." + ) return False - if getattr(argument_type, '__origin__', None) in {list, List}: + if getattr(argument_type, "__origin__", None) in {list, List}: return issubclass(argument_type.__args__[0], sig_el_type) def is_homogeneous_tuple(t): @@ -267,11 +340,16 @@ def is_homogeneous_tuple(t): return False + @compatibility(is_backward_compatible=False) def normalize_function( - target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, - kwarg_types : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + target: Callable, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + arg_types: Optional[Tuple[Any]] = None, + kwarg_types: Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: """ Returns normalized arguments to PyTorch functions. This means that `args/kwargs` will be matched up to the functional's @@ -308,14 +386,19 @@ def normalize_function( # branch signature for analysis. Otherwise, leave this un-normalized assert not isinstance(target, str) dispatched = boolean_dispatched[target] - if_true, if_false = dispatched['if_true'], dispatched['if_false'] - if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: + if_true, if_false = dispatched["if_true"], dispatched["if_false"] + if ( + inspect.signature(if_true).parameters + != inspect.signature(if_false).parameters + ): return None target_for_analysis = if_true assert callable(target_for_analysis) sig = inspect.signature(inspect.unwrap(target_for_analysis)) - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) else: assert callable(target) torch_op_schemas = get_signature_for_torch_op(target) @@ -328,7 +411,7 @@ def normalize_function( try: candidate_signature.bind(*args, **kwargs) matched_schemas.append(candidate_signature) - except TypeError as e: + except TypeError: continue if len(matched_schemas) == 0: @@ -336,8 +419,9 @@ def normalize_function( pass elif len(matched_schemas) == 1: # Matched exactly one schema, unambiguous - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, - normalize_to_only_use_kwargs) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs + ) else: if arg_types is not None or kwarg_types is not None: arg_types = arg_types if arg_types else cast(Tuple[Any], ()) @@ -345,30 +429,49 @@ def normalize_function( for candidate_signature in torch_op_schemas: sig_matches = True try: - bound_types = candidate_signature.bind(*arg_types, **kwarg_types) + bound_types = candidate_signature.bind( + *arg_types, **kwarg_types + ) for arg_name, arg_type in bound_types.arguments.items(): param = candidate_signature.parameters[arg_name] - sig_matches = sig_matches and type_matches(param.annotation, arg_type) - except TypeError as e: + sig_matches = sig_matches and type_matches( + param.annotation, arg_type + ) + except TypeError: sig_matches = False if sig_matches: - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, - normalize_to_only_use_kwargs) + new_args_and_kwargs = ( + _args_kwargs_to_normalized_args_kwargs( + candidate_signature, + args, + kwargs, + normalize_to_only_use_kwargs, + ) + ) break else: # Matched more than one schema. In this situation, the caller must provide the types of # the arguments of the overload they expect. - schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) - raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' - f'the schema match was ambiguous! Please provide argument types to ' - f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') + schema_printouts = "\n".join( + str(schema) for schema in matched_schemas + ) + raise RuntimeError( + f"Tried to normalize arguments to {torch.typename(target)} but " + f"the schema match was ambiguous! Please provide argument types to " + f"the normalize_arguments() call. Available schemas:\n{schema_printouts}" + ) return new_args_and_kwargs + @compatibility(is_backward_compatible=False) def normalize_module( - root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, - normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: + root: torch.nn.Module, + target: str, + args: Tuple[Any], + kwargs: Optional[Dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: """ Returns normalized arguments to PyTorch modules. This means that `args/kwargs` will be matched up to the functional's @@ -391,22 +494,29 @@ def normalize_module( try: submod = root.get_submodule(target) except AttributeError as e: - raise RuntimeError(f"Tried to normalize node with target {target} but root did not " - f"have that target!") from e - if hasattr(submod.__class__, '__name__'): + raise RuntimeError( + f"Tried to normalize node with target {target} but root did not " + f"have that target!" + ) from e + if hasattr(submod.__class__, "__name__"): classname = submod.__class__.__name__ if getattr(torch.nn, classname, None) == submod.__class__: sig = inspect.signature(inspect.unwrap(submod.forward)) if kwargs is None: kwargs = {} - new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, - normalize_to_only_use_kwargs) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) return new_args_and_kwargs return None -def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], - kwargs : Dict[str, Any], - normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: + +def _args_kwargs_to_normalized_args_kwargs( + sig: inspect.Signature, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + normalize_to_only_use_kwargs: bool, +) -> Optional[ArgsKwargsPair]: """ Given a call target, args, and kwargs, return the arguments normalized into an ArgsKwargsPair, or None if the type signature is not supported by @@ -428,20 +538,22 @@ def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple # Don't currently support positional-only # or varargs (*args, **kwargs) signatures supported_parameter_types = { - inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + } if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): # Add an exception for one signature, which is common for random/uniform, i.e.: # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None # `from` is Python keyword and as such functions with that signature should have # positional-only args, but at the same time they could be dispatched as kwargs - if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']: + if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]: return None bound_args = sig.bind(*args, **kwargs) bound_args.apply_defaults() - new_kwargs : Dict[str, Any] = {} - new_args : List[Any] = [] + new_kwargs: Dict[str, Any] = {} + new_args: List[Any] = [] for i, param in enumerate(sig.parameters): if not normalize_to_only_use_kwargs and i < len(args): new_args.append(bound_args.arguments[param]) diff --git a/torch/fx/passes/__init__.py b/torch/fx/passes/__init__.py index f83a2f248fcde..433d8818e259a 100644 --- a/torch/fx/passes/__init__.py +++ b/torch/fx/passes/__init__.py @@ -1,12 +1,14 @@ -from . import graph_drawer -from . import graph_manipulation -from . import net_min_base -from . import operator_support -from . import param_fetch -from . import reinplace -from . import runtime_assert -from . import shape_prop -from . import split_module -from . import split_utils -from . import splitter_base -from . import tools_common +from . import ( + graph_drawer, + graph_manipulation, + net_min_base, + operator_support, + param_fetch, + reinplace, + runtime_assert, + shape_prop, + split_module, + split_utils, + splitter_base, + tools_common, +) diff --git a/torch/fx/passes/_tensorify_python_scalars.py b/torch/fx/passes/_tensorify_python_scalars.py new file mode 100644 index 0000000000000..1da82c9bd4155 --- /dev/null +++ b/torch/fx/passes/_tensorify_python_scalars.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +import logging +from typing import List, Union + +import torch +import torch.fx as fx +from torch._prims_common import get_computation_dtype +from torch._subclasses import fake_tensor # noqa: TCH001 +from torch._utils_internal import JustKnobsConfig +from torch.fx._utils import lazy_format_graph_code +from torch.fx.experimental.symbolic_shapes import ShapeEnv # noqa: TCH001 +from torch.fx.graph_module import GraphModule # noqa: TCH001 + +# TODO: refactor +from torch.fx.passes.runtime_assert import _get_sym_val +from torch.fx.proxy import MetaProxy +from torch.utils._sympy.reference import TensorReferenceAnalysis + + +__all__: List[str] = [] + +log = logging.getLogger(__name__) +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") + +# The general shape of this transformation is to look for Tensor operations +# that take a backed SymFloat as an argument, and then redo them as tensor +# compute (with ints and tensors as inputs). For example, add(Tensor, Scalar) +# can be translated into add(Tensor, Tensor). Because Dynamo has already +# arranged for floats to be Tensor inputs to the graph, for typical float +# compute you can entirely translate the Python float operations into Tensor +# operations with only Tensor inputs. +# +# This pass is also responsible for doing CSE on the fly as we do this, since +# you don't want to keep recomputing the same quantity over and over again if +# it's used multiple times. +# +# This pass runs on the JOINT graph produced by AOT Autograd, prior to partitioning. +# The primary goal of this pass is to eliminate floats by replacing TensorScalar +# operations with TensorTensor operations and then Dead Code Elimination (DCE) of +# the item calls, which effectively removes the floats. +# +# This needs to happen before partitioning because it influences partitioning decisions, +# specifically by ensuring that we don't need to save floats across partitions. +# Additionally, there is a separate pass that changes which device computations +# occur on. That pass must be run after this one, but still before partitioning. +# +# HISTORY NOTE: Originally, I wanted to formulate this pass as pushing item() +# calls down, transforming float compute into int compute as we went. If you +# manage to eliminate all float compute, this ends up being equivalent, but +# there is a critical difference when some floats cannot be eliminated: when +# we call item() on them, what should it's SymFloat be? Ideally, it would +# be the same backed SymFloat we had before. But without symbolic expresssion +# propogation on tensor quantities, repropagating would instead give you an +# unbacked SymFloat. Maybe it is a good idea to implement symbolic propagation +# on 0d scalar tensors, but I decided to go for something simpler to start. +# +# The boring stuff: +# +# * What operators can I Tensor-ify? (Anything with a Scalar argument) +# * How do I Tensor-ify a SymFloat sympy expression (Sympy -> Op Handler -> Tensor) +# +# TODO: make sure this runs before CPU->CUDA pass for cudagraph friendliness + + +SUPPORTED_OPS = { + torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Tensor, + torch.ops.aten.div.Tensor, +} + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def tensorify_python_scalars( + gm: GraphModule, shape_env: ShapeEnv, fake_mode: fake_tensor.FakeTensorMode +) -> None: + """ + Converts Python scalar operations into Tensor operations within the graph. This pass looks for + Tensor operations that involve SymFloat arguments and transforms them into equivalent operations + that use only Tensor inputs. + + Args: + gm: The FX graph module representing the computation graph. + shape_env: The shape environment responsible for symbolic shape tracking and propagation + during graph transformations. + + Returns: + None + """ + import sympy + + knob = JustKnobsConfig( + name="pytorch/compiler:tensorify_python_scalars", + env_name="TENSORIFY_PYTHON_SCALARS", + default=True, + ).get() + if not knob: + return None + + graph = gm.graph + tracer = fx.proxy.GraphAppendingTracer(graph) + expr_to_sym_proxy: dict[sympy.Expr, MetaProxy] = {} + expr_to_tensor_proxy: dict[sympy.Expr, MetaProxy] = {} + + first_non_placeholder = None + placeholders = set() + for node in graph.nodes: + if node.op != "placeholder": + break + else: + placeholders.add(node) + + Analysis = TensorReferenceAnalysis + + def _sympy_interp(expr: sympy.Expr) -> MetaProxy: + # sympy_interp() with hash consing, and special handling for + # generating constants correctly + from sympy import Integer, Number, Symbol + from sympy.logic.boolalg import BooleanAtom + + from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp + + # hash cons + if isinstance(expr, Symbol) and expr not in expr_to_tensor_proxy: + # This is guaranteed to be populated by invariant established by + # insert_deferred_runtime_asserts + expr_to_tensor_proxy[expr] = torch.ops.aten.scalar_tensor.default( + expr_to_sym_proxy[expr] + ) + + # cache constants, why not + if isinstance(expr, (Integer, Number, BooleanAtom)): + dtype = None + c: Union[bool, int, float] + if isinstance(expr, BooleanAtom): + dtype = torch.bool + c = bool(expr) + elif isinstance(expr, sympy.Integer): + dtype = torch.int64 + c = int(expr) + elif isinstance(expr, sympy.Number): + dtype = torch.float64 + c = float(expr) + + node = graph.call_function( + torch.ops.aten.scalar_tensor.default, (c,), {"dtype": dtype} + ) + with fake_mode: + node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype) + expr_to_tensor_proxy[expr] = MetaProxy( + node, + tracer=tracer, + fake_mode=fake_mode, + ) + + if expr in expr_to_tensor_proxy: + return expr_to_tensor_proxy[expr] + + # don't cache + if isinstance(expr, Symbol): + return sympy_interp(Analysis, expr_to_tensor_proxy, expr) # type: ignore[arg-type] + + # hash cons on arguments, run expr handler + expr_to_tensor_proxy[expr] = _run_sympy_handler( + Analysis, + [_sympy_interp(arg) for arg in expr.args], # type: ignore[arg-type] + expr, + ) + + return expr_to_tensor_proxy[expr] + + nodes = list(graph.nodes) + for i, node in enumerate(nodes[:-1]): + with graph.inserting_before( + nodes[i + 1] if node not in placeholders else first_non_placeholder + ): + # Look for tensor.item() calls on placeholders + if unbacked_bindings := node.meta.get("unbacked_bindings"): + for s in unbacked_bindings.keys(): + if ( + node is not None + and node.op == "call_function" + and node.target is torch.ops.aten._local_scalar_dense.default + ): + dtype = node.args[0].meta["val"].dtype + if dtype != torch.float64: + continue + + assert isinstance(node.args[0], fx.Node), node.args[0] + + expr_to_tensor_proxy[s] = MetaProxy( + node.args[0], tracer=tracer, fake_mode=fake_mode + ) + expr_to_sym_proxy[s] = MetaProxy( + node, tracer=tracer, fake_mode=fake_mode + ) + + elif (sym_expr := _get_sym_val(node)) is not None: + if sym_expr not in expr_to_sym_proxy and not isinstance( + sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) + ): + expr_to_sym_proxy[sym_expr] = MetaProxy( + node, tracer=tracer, fake_mode=fake_mode + ) + + # Look for functions to convert + if node.op == "call_function" and node.target in SUPPORTED_OPS: + args = [] + transform = False + compute_dtype = get_computation_dtype(node.meta["val"].dtype) + + for a in node.args: + if ( + isinstance(a, fx.Node) + and "val" in a.meta + and isinstance(zf := a.meta["val"], torch.SymFloat) + ): + transform = True + try: + proxy = _sympy_interp(zf.node.expr) + except NotImplementedError: + transform = False + break + + if proxy.node.meta["val"].dtype != compute_dtype: + proxy = torch.ops.prims.convert_element_type.default( + proxy, compute_dtype + ) + + args.append(proxy) + else: + args.append(MetaProxy(a, tracer=tracer, fake_mode=fake_mode)) + + if transform: + replacement_proxy = node.target(*args) + + if compute_dtype != node.meta["val"].dtype: + replacement_proxy = ( + torch.ops.prims.convert_element_type.default( + replacement_proxy, + node.meta["val"].dtype, + ) + ) + + node.replace_all_uses_with(replacement_proxy.node) + + graph.erase_node(node) + + # DCE symbols (which are guaranteed to be pure) only + for proxy in reversed(expr_to_sym_proxy.values()): + if len(proxy.node.users) == 0 and proxy.node.op != "placeholder": + graph.erase_node(proxy.node) + + graph_code_log.debug( + "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True) + ) diff --git a/torch/fx/passes/backends/cudagraphs.py b/torch/fx/passes/backends/cudagraphs.py index 0f48165b7dab4..b98178f0d5339 100644 --- a/torch/fx/passes/backends/cudagraphs.py +++ b/torch/fx/passes/backends/cudagraphs.py @@ -1,12 +1,13 @@ # mypy: allow-untyped-defs +import operator + import torch +from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupport from torch.fx.passes.tools_common import CALLABLE_NODE_OPS -from torch.fx.passes.fake_tensor_prop import FakeTensorProp from torch.utils import _pytree as pytree -import operator class CudaGraphsSupport(OperatorSupport): # TODO: why is submodules passed here @@ -27,7 +28,7 @@ def meta_fk(meta): def find_not_cuda(t): nonlocal found_not_cuda - if isinstance(t, torch.Tensor) and t.device.type != 'cuda': + if isinstance(t, torch.Tensor) and t.device.type != "cuda": found_not_cuda = True for n in node.all_input_nodes: @@ -40,6 +41,7 @@ def find_not_cuda(t): return not found_not_cuda + def partition_cudagraphs(gm, inputs): """ Partition an FX graph into sub-GraphModules that can be validly run under @@ -51,7 +53,9 @@ def partition_cudagraphs(gm, inputs): supported_ops = CudaGraphsSupport() # TODO: single node partition may be wrong due to the pessimization # from copying in and out the data. Check in benchmarks, perhaps - partitioner = CapabilityBasedPartitioner(gm, supported_ops, allows_single_node_partition=True) + partitioner = CapabilityBasedPartitioner( + gm, supported_ops, allows_single_node_partition=True + ) partitions = partitioner.propose_partitions() fused_graph = partitioner.fuse_partitions(partitions) return fused_graph diff --git a/torch/fx/passes/dialect/common/cse_pass.py b/torch/fx/passes/dialect/common/cse_pass.py index 577f445e7b316..6a501f041d193 100644 --- a/torch/fx/passes/dialect/common/cse_pass.py +++ b/torch/fx/passes/dialect/common/cse_pass.py @@ -1,20 +1,45 @@ # mypy: allow-untyped-defs -from typing import Dict, Tuple, Any +from typing import Any, Dict, Tuple import torch +from torch.fx import Graph, GraphModule, Node from torch.fx.passes.infra.pass_base import PassBase, PassResult from torch.utils._pytree import tree_flatten -from torch.fx import GraphModule, Graph -from torch.fx import Node aten = torch.ops.aten # stateful ops are banned from CSE -rand_ops = {aten.dropout, aten._fused_dropout, aten._standard_gamma, aten.bernoulli, aten.multinomial, aten.native_dropout, aten.normal, aten.poisson, aten.binomial, aten.rrelu, aten.rand_like, aten.rand, aten.randint, aten.randn, aten.randperm} # noqa: E501,B950 - -inplace_ops = {aten.add_, aten.sub_, aten.mul_, aten.div_, aten.pow_, aten.lerp_, aten.relu_, aten.sigmoid_, aten.tanh_} # noqa: E501 +rand_ops = { + aten.dropout, + aten._fused_dropout, + aten._standard_gamma, + aten.bernoulli, + aten.multinomial, + aten.native_dropout, + aten.normal, + aten.poisson, + aten.binomial, + aten.rrelu, + aten.rand_like, + aten.rand, + aten.randint, + aten.randn, + aten.randperm, +} # noqa: E501,B950 + +inplace_ops = { + aten.add_, + aten.sub_, + aten.mul_, + aten.div_, + aten.pow_, + aten.lerp_, + aten.relu_, + aten.sigmoid_, + aten.tanh_, +} # noqa: E501 @torch.fx._compatibility.compatibility(is_backward_compatible=False) @@ -24,7 +49,6 @@ def get_CSE_banned_ops(): @torch.fx._compatibility.compatibility(is_backward_compatible=False) class CSEPass(PassBase): - def __init__(self, banned_ops=None): """ This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. @@ -58,20 +82,32 @@ def f(a): result = p(traced_graph) print(result.graph_module) """ + def get_aten_target(node): - if hasattr(node.target, 'overloadpacket'): + if hasattr(node.target, "overloadpacket"): return node.target.overloadpacket return node.target modified = False new_graph = Graph() - env: Dict[Node, Node] = {} # map from node in the old graph to node in the new graph - hash_env: Dict[Tuple[torch._ops.OpOverload, int], Node] = {} # map from hash to a node in the new graph - token_map: Dict[Tuple[torch._ops.OpOverload, int], Dict[str, Any]] = {} # map from hash to token + env: Dict[ + Node, Node + ] = {} # map from node in the old graph to node in the new graph + hash_env: Dict[ + Tuple[torch._ops.OpOverload, int], Node + ] = {} # map from hash to a node in the new graph + token_map: Dict[ + Tuple[torch._ops.OpOverload, int], Dict[str, Any] + ] = {} # map from hash to token for n in graph_module.graph.nodes: # The placeholder, output, and get_attr nodes are copied to the new graph without change # do not CSE away random operations - if n.op == 'placeholder' or n.op == 'output' or n.op == 'get_attr' or get_aten_target(n) in self.banned_ops: + if ( + n.op == "placeholder" + or n.op == "output" + or n.op == "get_attr" + or get_aten_target(n) in self.banned_ops + ): new_node = new_graph.node_copy(n, lambda x: env[x]) env[n] = new_node else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' @@ -84,13 +120,19 @@ def substitute(arg_list): if isinstance(v, Node) and v in env: arg_list[i] = env[v] return tuple(arg_list), spec + args, args_spec = substitute(n.args) kwargs, kwargs_spec = substitute(n.kwargs) # each token corresponds to a unique node # nodes with the same token can be substituted - token = {"target": n.target, "args": args, "args_spec": args_spec, - "kwargs": kwargs, "kwargs_spec": kwargs_spec} + token = { + "target": n.target, + "args": args, + "args_spec": args_spec, + "kwargs": kwargs, + "kwargs_spec": kwargs_spec, + } # hash substituted args to a number, do not hash specs because specs are not hashable hash_arg = hash((args, kwargs)) diff --git a/torch/fx/passes/fake_tensor_prop.py b/torch/fx/passes/fake_tensor_prop.py index 2b40207e0f804..8036f5d0fd556 100644 --- a/torch/fx/passes/fake_tensor_prop.py +++ b/torch/fx/passes/fake_tensor_prop.py @@ -2,13 +2,15 @@ from typing import Optional import torch.fx +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.fx import Node -from torch.fx.node import map_aggregate from torch.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor -from torch.fx.experimental.proxy_tensor import snapshot_fake, py_sym_types +from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake +from torch.fx.node import map_aggregate + + +__all__ = ["FakeTensorProp"] -__all__ = ['FakeTensorProp'] @compatibility(is_backward_compatible=False) class FakeTensorProp(torch.fx.Interpreter): @@ -24,7 +26,10 @@ class FakeTensorProp(torch.fx.Interpreter): module (GraphModule): The module to be executed mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. """ - def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None): + + def __init__( + self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None + ): super().__init__(module) if mode is None: mode = FakeTensorMode() @@ -33,7 +38,10 @@ def __init__(self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] mode.reset_nt_tensor_id_counter() def run_node(self, n: Node): - from torch.fx.experimental.symbolic_shapes import rebind_unbacked, compute_unbacked_bindings + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + ) result = super().run_node(n) rebind_unbacked(self._mode.shape_env, n, result) @@ -52,8 +60,10 @@ def extract_val(obj): meta = map_aggregate(result, extract_val) if meta is not None: - n.meta['val'] = meta - if (shape_env := self._mode.shape_env) and (symbol_to_path := compute_unbacked_bindings(shape_env, result)): + n.meta["val"] = meta + if (shape_env := self._mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, result) + ): n.meta["unbacked_bindings"] = symbol_to_path return result diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py index 975b2b6171780..9a1710c9721ae 100644 --- a/torch/fx/passes/graph_drawer.py +++ b/torch/fx/passes/graph_drawer.py @@ -58,6 +58,7 @@ } if HAS_PYDOT: + @compatibility(is_backward_compatible=False) class FxGraphDrawer: """ @@ -87,7 +88,12 @@ def __init__( self._dot_graphs = { name: self._to_dot( - graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace + graph_module, + name, + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, ) } @@ -127,8 +133,8 @@ def get_dot_graph(self, submod_name=None) -> pydot.Dot: >>> symbolic_traced = torch.fx.symbolic_trace(module) >>> # setup output file >>> import ubelt as ub - >>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() - >>> fpath = dpath / 'linear.svg' + >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir() + >>> fpath = dpath / "linear.svg" >>> # draw the graph >>> g = FxGraphDrawer(symbolic_traced, "linear") >>> g.get_dot_graph().write_svg(fpath) @@ -148,7 +154,6 @@ def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: return self._dot_graphs def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: - template = { "shape": self.dot_graph_shape, "fillcolor": "#CAFFE3", @@ -161,7 +166,9 @@ def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: # Use a random color for each node; based on its name so it's stable. target_name = node._pretty_print_target(node.target) target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) - template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] + template["fillcolor"] = _HASH_COLOR_MAP[ + target_hash % len(_HASH_COLOR_MAP) + ] return template def _get_leaf_node( @@ -199,12 +206,11 @@ def _shorten_file_name( full_file_name: str, truncate_to_last_n: int = 2, ): - splits = full_file_name.split('/') + splits = full_file_name.split("/") if len(splits) >= truncate_to_last_n: - return '/'.join(splits[-truncate_to_last_n:]) + return "/".join(splits[-truncate_to_last_n:]) return full_file_name - def _get_node_label( self, module: torch.fx.GraphModule, @@ -219,8 +225,7 @@ def _get_str_for_args_kwargs(arg): elif isinstance(arg, dict): prefix, suffix = r"|kwargs={\l", r",\n}\l" arg_strs_list = [ - f"{k}: {_format_arg(v, max_list_len=8)}" - for k, v in arg.items() + f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items() ] else: # Fall back to nothing in unexpected case. return "" @@ -235,7 +240,6 @@ def _get_str_for_args_kwargs(arg): arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") return arg_strs.replace("{", r"\{").replace("}", r"\}") - label = "{" + f"name=%{node.name}|op_code={node.op}\n" if node.op == "call_module": @@ -244,7 +248,10 @@ def _get_str_for_args_kwargs(arg): extra = "" if hasattr(leaf_module, "__constants__"): extra = r"\n".join( - [f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] # type: ignore[union-attr] + [ + f"{c}: {getattr(leaf_module, c)}" + for c in leaf_module.__constants__ + ] # type: ignore[union-attr] ) label += extra + r"\n" else: @@ -252,7 +259,10 @@ def _get_str_for_args_kwargs(arg): if self.normalize_args: try: args, kwargs = normalize_function( # type: ignore[misc] - node.target, node.args, node.kwargs, normalize_to_only_use_kwargs=True # type: ignore[arg-type] + node.target, # type: ignore[arg-type] + node.args, # type: ignore[arg-type] + node.kwargs, + normalize_to_only_use_kwargs=True, ) except Exception: # Fallback to not normalizing if there's an exception. @@ -266,12 +276,12 @@ def _get_str_for_args_kwargs(arg): label += _get_str_for_args_kwargs(kwargs) label += f"|num_users={len(node.users)}" + r"\n" - tensor_meta = node.meta.get('tensor_meta') + tensor_meta = node.meta.get("tensor_meta") label += self._tensor_meta_to_label(tensor_meta) # for original fx graph # print buf=buf0, n_origin=6 - buf_meta = node.meta.get('buf_meta', None) + buf_meta = node.meta.get("buf_meta", None) if buf_meta is not None: label += f"|buf={buf_meta.name}" + r"\n" label += f"|n_origin={buf_meta.n_origin}" + r"\n" @@ -281,8 +291,10 @@ def _get_str_for_args_kwargs(arg): if parse_stack_trace and node.stack_trace is not None: parsed_stack_trace = _parse_stack_trace(node.stack_trace) fname = self._shorten_file_name(parsed_stack_trace.file) - label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" - + label += ( + f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + + r"\n" + ) return label + "}" @@ -322,19 +334,43 @@ def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: assert "qscheme" in tm.qparams qscheme = tm.qparams["qscheme"] if qscheme in { - torch.per_tensor_affine, - torch.per_tensor_symmetric, + torch.per_tensor_affine, + torch.per_tensor_symmetric, }: result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" + result += ( + "|" + + "q_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) elif qscheme in { - torch.per_channel_affine, - torch.per_channel_symmetric, - torch.per_channel_affine_float_qparams, + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, }: - result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" - result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" - result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" + result += ( + "|" + + "q_per_channel_scale" + + "=" + + str(tm.qparams["scale"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_axis" + + "=" + + str(tm.qparams["axis"]) + + r"\n" + ) else: raise RuntimeError(f"Unsupported qscheme: {qscheme}") result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" @@ -363,7 +399,6 @@ def _to_dot( # "TB" means top-to-bottom rank direction in layout dot_graph = pydot.Dot(name, rankdir="TB") - buf_name_to_subgraph = {} for node in graph_module.graph.nodes: @@ -372,16 +407,22 @@ def _to_dot( style = self._get_node_style(node) dot_node = pydot.Node( - node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style + node.name, + label=self._get_node_label( + graph_module, node, skip_node_names_in_args, parse_stack_trace + ), + **style, ) current_graph = dot_graph - buf_meta = node.meta.get('buf_meta', None) + buf_meta = node.meta.get("buf_meta", None) if buf_meta is not None and buf_meta.n_origin > 1: buf_name = buf_meta.name if buf_name not in buf_name_to_subgraph: - buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) + buf_name_to_subgraph[buf_name] = pydot.Cluster( + buf_name, label=buf_name + ) current_graph = buf_name_to_subgraph.get(buf_name) current_graph.add_node(dot_node) @@ -407,12 +448,14 @@ def get_module_params_or_buffers(): if node.op == "call_module": leaf_module = self._get_leaf_node(graph_module, node) - if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): + if not ignore_parameters_and_buffers and not isinstance( + leaf_module, torch.fx.GraphModule + ): get_module_params_or_buffers() for subgraph in buf_name_to_subgraph.values(): - subgraph.set('color', 'royalblue') - subgraph.set('penwidth', '2') + subgraph.set("color", "royalblue") + subgraph.set("penwidth", "2") dot_graph.add_subgraph(subgraph) for node in graph_module.graph.nodes: @@ -426,6 +469,7 @@ def get_module_params_or_buffers(): else: if not TYPE_CHECKING: + @compatibility(is_backward_compatible=False) class FxGraphDrawer: def __init__( @@ -439,5 +483,7 @@ def __init__( dot_graph_shape: Optional[str] = None, normalize_args: bool = False, ): - raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' - 'pydot through your favorite Python package manager.') + raise RuntimeError( + "FXGraphDrawer requires the pydot package to be installed. Please install " + "pydot through your favorite Python package manager." + ) diff --git a/torch/fx/passes/graph_manipulation.py b/torch/fx/passes/graph_manipulation.py index 36c59cb31af05..ce9904fc500e8 100644 --- a/torch/fx/passes/graph_manipulation.py +++ b/torch/fx/passes/graph_manipulation.py @@ -5,15 +5,18 @@ from torch.fx._compatibility import compatibility from torch.fx.graph import Graph from torch.fx.graph_module import GraphModule -from torch.fx.node import ( - map_arg, - Node, - Target, -) +from torch.fx.node import map_arg, Node, Target from torch.fx.passes.shape_prop import ShapeProp -__all__ = ['replace_target_nodes_with', 'size_bytes', 'get_size_of_all_nodes', 'get_tensor_meta', - 'get_size_of_node'] + +__all__ = [ + "replace_target_nodes_with", + "size_bytes", + "get_size_of_all_nodes", + "get_tensor_meta", + "get_size_of_node", +] + @compatibility(is_backward_compatible=False) def replace_target_nodes_with( @@ -58,7 +61,6 @@ def get_size_of_all_nodes( # Mark shape and dtype for each node (node.shape and node.dtype) ShapeProp(fx_module).propagate(*args) # Calculate the total size of the whole fx graph - total_size_of_graph = 0.0 for node in fx_module.graph.nodes: if node.op == "output": break @@ -92,7 +94,7 @@ def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: submodule = submodule_dict[node.target] parameters = submodule.named_parameters() # Parameters are named tuples - for name, p in parameters: + for _name, p in parameters: total_num_of_elems += p.numel() # Don't forget the output size # node.shape is the shape of this node's output diff --git a/torch/fx/passes/infra/__init__.py b/torch/fx/passes/infra/__init__.py index 657b6a93014f4..939157f1302e7 100644 --- a/torch/fx/passes/infra/__init__.py +++ b/torch/fx/passes/infra/__init__.py @@ -1,2 +1 @@ - from . import pass_manager diff --git a/torch/fx/passes/infra/partitioner.py b/torch/fx/passes/infra/partitioner.py index 4ffb5e3c36412..122545b8dccfe 100644 --- a/torch/fx/passes/infra/partitioner.py +++ b/torch/fx/passes/infra/partitioner.py @@ -1,22 +1,24 @@ # mypy: allow-untyped-defs -from torch.fx.passes.utils.fuser_utils import fuse_by_partitions import collections import itertools import logging - from copy import copy from typing import Dict, Iterable, List, Optional, Sequence, Set from torch.fx.graph_module import GraphModule -from torch.fx.node import Node, _get_qualified_name +from torch.fx.node import _get_qualified_name, Node from torch.fx.passes.operator_support import OperatorSupportBase +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) + class Partition: - def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None): + def __init__( + self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + ): self.id = id self.nodes = dict.fromkeys(nodes) if nodes is not None else {} @@ -32,6 +34,7 @@ def remove_node(self, node: Node): def size(self): return len(self.nodes) + class _DependencyViewer: def __init__(self, graph_module: GraphModule): self.upstreams = collections.defaultdict(set) @@ -55,15 +58,16 @@ def downstreams_of(self, node: Node) -> Set[Node]: def upstreams_of(self, node: Node) -> Set[Node]: return self.upstreams[node] -class CapabilityBasedPartitioner: - def __init__(self, - graph_module: GraphModule, - operator_support: OperatorSupportBase, - allows_single_node_partition: bool = False, - non_compute_ops: Optional[Sequence[str]] = None, - allowed_single_node_partition_ops: Optional[Sequence[str]] = None, - ) -> None: +class CapabilityBasedPartitioner: + def __init__( + self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False, + non_compute_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + ) -> None: self.graph_module = graph_module self.operator_support = operator_support self.allows_single_node_partition = allows_single_node_partition @@ -76,19 +80,21 @@ def __init__(self, self.dependency_viewer = _DependencyViewer(graph_module) def __is_node_supported(self, node: Node) -> bool: - return ( - self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node) + return self.operator_support.is_node_supported( + dict(self.graph_module.named_modules()), node ) def propose_partitions(self) -> List[Partition]: # partition_map is a mapping from partition id to a set of partition id's. # The value set contains all the partition ids that can be reached by doing a # DFS starting from the partition id in the key. - partition_map : Dict[int, Set] = collections.defaultdict(set) + partition_map: Dict[int, Set] = collections.defaultdict(set) # assumptions: nodes in candidate list is sorted in topological order - assignment: Dict[Node, int] = {} # mapping from node to partition_id - partitions_by_id: Dict[int, Partition] = {} # mapping from partition_id to partition + assignment: Dict[Node, int] = {} # mapping from node to partition_id + partitions_by_id: Dict[ + int, Partition + ] = {} # mapping from partition_id to partition new_partition_id = itertools.count() # try to merge partition other_id into partition self_id @@ -149,7 +155,9 @@ def dfs_iter_find_cycle(all_user_nodes: Set[Node]): # delete other partition del partitions_by_id[other_id] - partition_map[self_id] = partition_map[self_id].union(partition_map[other_id]) + partition_map[self_id] = partition_map[self_id].union( + partition_map[other_id] + ) del partition_map[other_id] return True @@ -223,16 +231,18 @@ def _update_partition_map(node: Node, id: int): for node in self.graph_module.graph.nodes: is_tuple_output = True for user in node.users: - if user.op != "call_function" or \ - _get_qualified_name(user.target) != "_operator.getitem": # type: ignore[arg-type] + if ( + user.op != "call_function" + or _get_qualified_name(user.target) != "_operator.getitem" + ): # type: ignore[arg-type] is_tuple_output = False break # node has tuple outputs, re-assign all following getitem node into node's partition if is_tuple_output: - id = assignment.get(node, None) # type: ignore[arg-type] + id = assignment.get(node, None) # type: ignore[arg-type] for user in node.users: - if assignment.get(user, None) != id: # type: ignore[arg-type] + if assignment.get(user, None) != id: # type: ignore[arg-type] nodes_reassignment[user] = id # type: ignore[assignment] for node, id in nodes_reassignment.items(): merge_single_node(node, id) @@ -250,7 +260,10 @@ def _update_partition_map(node: Node, id: int): assert callable(node.target) if _get_qualified_name(node.target) not in non_compute_ops: compute_node_count += 1 - if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops: + if ( + _get_qualified_name(node.target) + in self.allowed_single_node_partition_ops + ): compute_node_count += 1 if compute_node_count <= 1: partitions_to_remove.append(id) @@ -259,11 +272,17 @@ def _update_partition_map(node: Node, id: int): logger.debug("Partitions proposed:") for id, partition in partitions_by_id.items(): - logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes]) + logger.debug( + "partition #%s: %s", id, [node.name for node in partition.nodes] + ) - return [partition for partition in partitions_by_id.values() if partition.size() > 0] + return [ + partition for partition in partitions_by_id.values() if partition.size() > 0 + ] - def fuse_partitions(self, partitions: List[Partition], prefix: str = "fused_") -> GraphModule: + def fuse_partitions( + self, partitions: List[Partition], prefix: str = "fused_" + ) -> GraphModule: logger.debug("Fusing partitions...") # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ] return fuse_by_partitions( @@ -277,15 +296,23 @@ def remove_bookend_non_compute_ops(self, partitions: List[Partition]): non_compute_ops = set(self.non_compute_ops) def is_non_compute_node(node: Node): - return node.op == "call_function" and \ - _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + return ( + node.op == "call_function" + and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + ) # cache transparent nodes transparent_input_nodes: Dict[Node, bool] = {} transparent_output_nodes: Dict[Node, bool] = {} - def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): - if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + def is_transparent_input_node( + node: Node, partition: Set[Node], removed_nodes: Set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): return True if node in transparent_input_nodes: return transparent_input_nodes[node] @@ -299,14 +326,22 @@ def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: S transparent_input_nodes[node] = False return False - def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]): - if node.op == "placeholder" or (node not in partition) or (node in removed_nodes): + def is_transparent_output_node( + node: Node, partition: Set[Node], removed_nodes: Set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): return True if node in transparent_output_nodes: return transparent_output_nodes[node] if is_non_compute_node(node): for output_n in node.users: - if not is_transparent_output_node(output_n, partition, removed_nodes): + if not is_transparent_output_node( + output_n, partition, removed_nodes + ): transparent_output_nodes[node] = False return False transparent_output_nodes[node] = True @@ -320,9 +355,12 @@ def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: # the set. remove_node: Set[Node] = set() for node in partition.nodes: - if is_non_compute_node(node) and \ - (is_transparent_input_node(node, set(partition.nodes), remove_node) or - is_transparent_output_node(node, set(partition.nodes), remove_node)): + if is_non_compute_node(node) and ( + is_transparent_input_node(node, set(partition.nodes), remove_node) + or is_transparent_output_node( + node, set(partition.nodes), remove_node + ) + ): remove_node.add(node) if len(remove_node) != 0: diff --git a/torch/fx/passes/infra/pass_base.py b/torch/fx/passes/infra/pass_base.py index 3f5b64eafbb60..acf78d2581b5a 100644 --- a/torch/fx/passes/infra/pass_base.py +++ b/torch/fx/passes/infra/pass_base.py @@ -3,11 +3,12 @@ from collections import namedtuple from typing import Optional -from torch.fx.graph_module import GraphModule from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + +__all__ = ["PassResult", "PassBase"] -__all__ = ['PassResult', 'PassBase'] @compatibility(is_backward_compatible=False) class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): @@ -16,9 +17,11 @@ class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): graph_module: The modified graph module modified: A flag for if the pass has modified the graph module """ + def __new__(cls, graph_module, modified): return super().__new__(cls, graph_module, modified) + @compatibility(is_backward_compatible=False) class PassBase(abc.ABC): """ diff --git a/torch/fx/passes/infra/pass_manager.py b/torch/fx/passes/infra/pass_manager.py index 29540fa447eb1..cea5f4f25c77b 100644 --- a/torch/fx/passes/infra/pass_manager.py +++ b/torch/fx/passes/infra/pass_manager.py @@ -1,19 +1,21 @@ # mypy: allow-untyped-defs import inspect import logging -from queue import Queue from functools import wraps +from queue import Queue from typing import Callable, Dict, List import torch.nn as nn -from torch.fx.graph_module import GraphModule from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule from torch.fx.passes.infra.pass_base import PassResult + logger = logging.getLogger(__name__) logger.setLevel(logging.WARNING) -__all__ = ['pass_result_wrapper', 'this_before_that_pass_constraint', 'PassManager'] +__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"] + @compatibility(is_backward_compatible=False) def pass_result_wrapper(fn: Callable) -> Callable: @@ -46,6 +48,7 @@ def wrapped_fn(gm): return wrapped_fn + def _validate_pass_schedule_constraint( constraint: Callable[[Callable, Callable], bool], passes: List[Callable] ) -> None: @@ -59,6 +62,7 @@ def _validate_pass_schedule_constraint( f" list." ) + def _topological_sort_passes( passes: List[Callable], constraints: List[Callable] ) -> List[Callable]: @@ -75,7 +79,7 @@ def _topological_sort_passes( return passes # Contruct a graph mapping nodes to a list of their users - graph: Dict[Callable, List[Callable]] = {p : [] for p in passes} + graph: Dict[Callable, List[Callable]] = {p: [] for p in passes} indegree_map: Dict[Callable, int] = dict.fromkeys(passes, 0) candidates: Queue = Queue() for a in passes: @@ -108,11 +112,14 @@ def _topological_sort_passes( # Check if there are unvisited nodes (aka cycles in the graph) cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) if len(cycle_passes) != 0: - error = f"Circular dependency detected within the following passes: {cycle_passes}" + error = ( + f"Circular dependency detected within the following passes: {cycle_passes}" + ) raise RuntimeError(error) return sorted_passes + @compatibility(is_backward_compatible=False) def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: """ @@ -123,9 +130,7 @@ def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable ``` passes = [pass_b, pass_a] - constraints = [ - this_before_that_pass_constraint(pass_a, pass_b) - ] + constraints = [this_before_that_pass_constraint(pass_a, pass_b)] ``` Args: @@ -231,7 +236,9 @@ def add_checks(self, check: Callable) -> None: sig = inspect.signature(check) if len(list(sig.parameters.values())) != 1: - raise TypeError("PassManager check function should only take in one variable, a module") + raise TypeError( + "PassManager check function should only take in one variable, a module" + ) setattr(self, "check", check) # noqa: B010 diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index 6182972e670ea..81f8a845e83f7 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -5,7 +5,6 @@ import torch import torch.fx - from torch.fx._compatibility import compatibility from torch.fx.node import map_arg @@ -21,6 +20,7 @@ Tensors, ) + __all__ = [ "FxNetMinimizerBadModuleError", "FxNetMinimizerRunFuncError", @@ -37,7 +37,6 @@ class FxNetMinimizerBadModuleError(Exception): """ - @compatibility(is_backward_compatible=False) class FxNetMinimizerRunFuncError(Exception): """ @@ -45,7 +44,6 @@ class FxNetMinimizerRunFuncError(Exception): """ - @compatibility(is_backward_compatible=False) class FxNetMinimizerResultMismatchError(Exception): """ @@ -53,7 +51,6 @@ class FxNetMinimizerResultMismatchError(Exception): """ - @dataclass class _MinimizerSettingBase: """ @@ -109,14 +106,9 @@ def __init__( ], settings: _MinimizerSettingBase, module_exporter: Optional[ - Callable[ - [Tensors, torch.fx.GraphModule, str], - None - ] - ] = None, - exclusion_fn: Optional[ - Callable[[NodeList, int, int], None] + Callable[[Tensors, torch.fx.GraphModule, str], None] ] = None, + exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None, ): assert isinstance(module, torch.fx.GraphModule) @@ -159,14 +151,18 @@ def __init__( self.a_outputs[name] = sample_input[i] self.b_outputs[name] = sample_input[i] - def run_a(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors: + def run_a( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: """ Run `mod` with `inputs` and generate output. The output will be compared with output of run_b(). """ raise RuntimeError("run_a() is not implemented.") - def run_b(self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1) -> TensorOrTensors: + def run_b( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: """ Run `mod` with `inputs` and generate output. The output will be compared with output of run_a(). @@ -323,7 +319,7 @@ def _run_and_compare( split_module: torch.fx.GraphModule, submod_name: str, output_names: Names, - report_idx: int = -1 + report_idx: int = -1, ): """ Run the submodule in `split_module` that has name `submod_name` @@ -388,10 +384,14 @@ def _run_and_compare( report.append(f"Result mismatch for {result_key}") if self.module_exporter: self.module_exporter( - a_input, submodule, str(result_key[0]) + "_cpu", # type: ignore[index] + a_input, + submodule, + str(result_key[0]) + "_cpu", # type: ignore[index] ) self.module_exporter( - b_input, submodule, str(result_key[0]) + "_acc", # type: ignore[index] + b_input, + submodule, + str(result_key[0]) + "_acc", # type: ignore[index] ) raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") @@ -418,7 +418,7 @@ def _binary_search_impl( self.reports.append(report) report.append(f"Binary search iteration {self.iteration}") report.append( - f"From node index {start_idx}:{first_node_name} to {end_idx-1}:{output_node_name}. " + f"From node index {start_idx}:{first_node_name} to {end_idx - 1}:{output_node_name}. " f"Size of the interested node list is {len(nodes)}" ) cur_nodes: NodeSet = set(nodes) @@ -428,7 +428,6 @@ def _binary_search_impl( self._run_and_compare(split_module, submod_name, [output_node_name]) except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): - if len(nodes) == 1: report.append( f"This is the last node in the sub-module. " @@ -504,13 +503,13 @@ def _sequential_traverse(self, nodes: NodeList) -> NodeSet: split_module, submod_name = self._build_submodule(cur_nodes) self._run_and_compare(split_module, submod_name, [node.name]) self.print_report(report) - except (FxNetMinimizerResultMismatchError): + except FxNetMinimizerResultMismatchError: culprits.add(node) report.append(f"Found culprit from numeric error: {node}") self.print_report(report) if not self.settings.find_all: return culprits - except (FxNetMinimizerRunFuncError): + except FxNetMinimizerRunFuncError: culprits.update(cur_nodes) report.append(f"Found culprit from run error: {node}") self.print_report(report) @@ -519,8 +518,9 @@ def _sequential_traverse(self, nodes: NodeList) -> NodeSet: return culprits - - def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool) -> int: + def _block_traverse_impl( + self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool + ) -> int: """ Recursive block search implementation. find_last_node: If True, search for the last node which result in numerics difference @@ -529,7 +529,7 @@ def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, fi report: List[str] = [] mid = (start_idx + end_idx) // 2 - cur_nodes_list: NodeList = nodes[:mid + 1] if find_last_node else nodes[mid:] + cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:] if self.exclusion_fn: self.exclusion_fn(cur_nodes_list, -1, -1) @@ -561,16 +561,20 @@ def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, fi try: split_module, submod_name = self._build_submodule(cur_nodes) - self._run_and_compare(split_module, submod_name, [last_node_name], report_idx) + self._run_and_compare( + split_module, submod_name, [last_node_name], report_idx + ) except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): - report.append(f"Culprits found from node {first_node_name} to {last_node_name}.") + report.append( + f"Culprits found from node {first_node_name} to {last_node_name}." + ) if start_idx == mid: report.extend( [ "This is the last node in the sub-module. ", "Search in the current branch is successful with node :", - f"{start_idx}, node name: {nodes[start_idx].name}." + f"{start_idx}, node name: {nodes[start_idx].name}.", ] ) self.print_report(report) @@ -585,9 +589,13 @@ def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, fi if find_last_node: return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) else: - return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node) + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) else: - report.append(f"Culprits not found from node start to {mid}:{nodes[mid].name}.") + report.append( + f"Culprits not found from node start to {mid}:{nodes[mid].name}." + ) if start_idx == mid: report.extend( @@ -607,12 +615,15 @@ def _block_traverse_impl(self, nodes: NodeList, start_idx: int, end_idx: int, fi self.print_report(report) if find_last_node: - return self._block_traverse_impl(nodes, mid + 1, end_idx, find_last_node) + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) else: return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) - - def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> NodeSet: + def _block_traverse( + self, nodes: NodeList, find_last_node: Optional[bool] + ) -> NodeSet: """ Traverse topologically sorted node list Find minimium block (start_idx, end_idx) which contains the culprit @@ -639,10 +650,7 @@ def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> No self.print_report(last_node_report) end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True) last_node_report.extend( - [ - "Finish Pass 1", - f"Find end_idx = {end_idx}:{nodes[end_idx].name}" - ] + ["Finish Pass 1", f"Find end_idx = {end_idx}:{nodes[end_idx].name}"] ) self.print_report(last_node_report) @@ -650,25 +658,28 @@ def _block_traverse(self, nodes: NodeList, find_last_node: Optional[bool]) -> No if run_both or not find_last_node: first_node_report = ["Start searching for first node in culprit"] self.print_report(first_node_report) - start_idx = self._block_traverse_impl(nodes[0:end_idx + 1], start_idx, end_idx, False) + start_idx = self._block_traverse_impl( + nodes[0 : end_idx + 1], start_idx, end_idx, False + ) first_node_report.append("*" * 50) self.reports.append(first_node_report) first_node_report.extend( [ "Finish Pass 2", - f"Find start_idx = {start_idx}:{nodes[start_idx].name}" + f"Find start_idx = {start_idx}:{nodes[start_idx].name}", ] ) self.print_report(first_node_report) # step 3: form module with minimum culprits - culprits.update(nodes[start_idx:end_idx + 1]) - result_report = [f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})"] + culprits.update(nodes[start_idx : end_idx + 1]) + result_report = [ + f"Finish searching, found minimum block ({nodes[start_idx]},{nodes[end_idx]})" + ] self.reports.append(result_report) self.print_report(result_report) return culprits - def _defined_traverse(self, nodes: NodeList) -> NodeSet: """ run user defined `nodes` and determine if it is a culprit. @@ -735,7 +746,9 @@ def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: return culprits - def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) -> NodeSet: + def _skip_traverse_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: """ Skip certain nodes in graph based on settings """ @@ -754,19 +767,19 @@ def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) self.iteration += 1 report.append(f" Nodes block {self.iteration}.") report.append( - f"From node index {start_idx} to {end_idx-1}. " + f"From node index {start_idx} to {end_idx - 1}. " f"Size of the interested node list is {len(nodes)}" ) try: split_module, submod_name = self._build_submodule(cur_nodes) self._run_and_compare(split_module, submod_name, []) - except (FxNetMinimizerResultMismatchError): + except FxNetMinimizerResultMismatchError: culprits.update(cur_nodes) report.append(f"Found culprit from numeric error: {cur_nodes}") self.print_report(report) return culprits - except (FxNetMinimizerRunFuncError): + except FxNetMinimizerRunFuncError: culprits.update(cur_nodes) report.append(f"Found culprit from run error: {cur_nodes}") self.print_report(report) @@ -776,7 +789,6 @@ def _skip_traverse_impl(self, all_nodes: NodeList, start_idx: int, end_idx: int) self.print_report(report) return set() - def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet: """ Skip certain nodes in graph based on settings @@ -787,7 +799,7 @@ def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet: culprits = set() while idx < num_nodes: node = all_nodes[idx] - if (node.name in skip_nodes): # skip the node + if node.name in skip_nodes: # skip the node if idx > start_idx: culprits = self._skip_traverse_impl(all_nodes, start_idx, idx) start_idx = idx + 1 @@ -797,8 +809,6 @@ def _skip_traverse(self, all_nodes: NodeList, skip_nodes: List) -> NodeSet: return culprits - - def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: """ Collect nodes in the model that between nodes with name of `start` and `end`. @@ -911,8 +921,10 @@ def minimize( return self._accumulate_traverse(nodes) if self.settings.traverse_method == "skip": - if (skip_nodes is None): - raise RuntimeError("'skip_nodes' can't be None when 'traverse_method' is 'skip'.") + if skip_nodes is None: + raise RuntimeError( + "'skip_nodes' can't be None when 'traverse_method' is 'skip'." + ) return self._skip_traverse(nodes, skip_nodes) if self.settings.traverse_method == "defined": diff --git a/torch/fx/passes/operator_support.py b/torch/fx/passes/operator_support.py index 57edabc0a55ae..53e8be37cecf5 100644 --- a/torch/fx/passes/operator_support.py +++ b/torch/fx/passes/operator_support.py @@ -5,11 +5,19 @@ import torch import torch.fx from torch.fx._compatibility import compatibility + from .shape_prop import TensorMetadata -from .tools_common import get_node_target, CALLABLE_NODE_OPS +from .tools_common import CALLABLE_NODE_OPS, get_node_target -__all__ = ['OperatorSupportBase', 'OperatorSupport', 'create_op_support', 'chain', 'OpSupports', 'any_chain'] +__all__ = [ + "OperatorSupportBase", + "OperatorSupport", + "create_op_support", + "chain", + "OpSupports", + "any_chain", +] # fx.Node.target typename, as returned by `get_node_target()` TargetTypeName = str @@ -28,6 +36,7 @@ @compatibility(is_backward_compatible=False) class OperatorSupportBase(abc.ABC): """Interface for determining if a fx.Node is supported by a backend""" + @abc.abstractmethod def is_node_supported( self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node @@ -57,10 +66,7 @@ class OperatorSupport(OperatorSupportBase): _support_dict: SupportDict - def __init__( - self, - support_dict: t.Optional[SupportDict] = None - ): + def __init__(self, support_dict: t.Optional[SupportDict] = None): self._support_dict = support_dict or {} def is_node_supported( @@ -139,11 +145,13 @@ def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase `IsNodeSupported` has the same call signature as `OperatorSupportBase.is_node_supported` """ + class FunctionalOperatorSupport(OperatorSupportBase): def is_node_supported( - self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node ) -> bool: return is_node_supported(submodules, node) + return FunctionalOperatorSupport() @@ -153,11 +161,10 @@ def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: instance by evaluating each input `OperatorSupportBase` instance, and returns False if any of it reports False. """ + def _chain(submods, node) -> bool: - return all( - x.is_node_supported(submods, node) - for x in op_support - ) + return all(x.is_node_supported(submods, node) for x in op_support) + return create_op_support(_chain) @@ -167,11 +174,10 @@ def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: instance by evaluating each input `OperatorSupportBase` instance, and returns True if any of it reports True. """ + def _any_chain(submods, node) -> bool: - return any( - x.is_node_supported(submods, node) - for x in op_support - ) + return any(x.is_node_supported(submods, node) for x in op_support) + return create_op_support(_any_chain) @@ -180,6 +186,7 @@ class OpSupports: """A set of atomic `OperatorSupportBase` instances that can be combined together to form more complex operator support logic. """ + @classmethod def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: """Report a node as non-supported, if any of its arguments is of dtype""" @@ -193,6 +200,7 @@ def _decline_if_input_dtype( if arg_dtype == dtype: return False return True + return create_op_support(_decline_if_input_dtype) @classmethod @@ -200,16 +208,22 @@ def decline_if_node_in_names(cls, disallow_set: t.Set[str]) -> OperatorSupportBa """ If a node has a name that is in the disallow set, reported it as non-supported. """ + def _decline_if_node_in_names( submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node, ) -> bool: return node.name not in disallow_set + return create_op_support(_decline_if_node_in_names) def _get_arg_dtype(arg: torch.fx.Node) -> t.Any: assert isinstance(arg, torch.fx.Node) tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] - dtype = tensor_meta.dtype if isinstance(tensor_meta, TensorMetadata) else arg.meta["type"] + dtype = ( + tensor_meta.dtype + if isinstance(tensor_meta, TensorMetadata) + else arg.meta["type"] + ) return dtype diff --git a/torch/fx/passes/param_fetch.py b/torch/fx/passes/param_fetch.py index 5979e29fcc6b2..3eba16b06b035 100644 --- a/torch/fx/passes/param_fetch.py +++ b/torch/fx/passes/param_fetch.py @@ -1,35 +1,59 @@ -from torch.fx.graph_module import GraphModule from typing import Any, Callable, Dict, List, Tuple, Type + import torch import torch.nn as nn - from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +__all__ = [ + "default_matching", + "extract_attrs_for_lowering", + "lift_lowering_attrs_to_nodes", +] -__all__ = ['default_matching', 'extract_attrs_for_lowering', 'lift_lowering_attrs_to_nodes'] # Matching method matches the attribute name of current version to the attribute name of `target_version` @compatibility(is_backward_compatible=False) def default_matching(name: str, target_version: int) -> str: - """Default matching method - """ + """Default matching method""" return name + # This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. # The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. # If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. module_fetch_book: Dict[Type, Tuple[int, List[str], Callable[[str, int], str]]] = { torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), torch.nn.modules.conv.Conv2d: ( - 1, ["weight", "bias", "kernel_size", "stride", "padding", "dilation", "groups", "padding_mode"], default_matching + 1, + [ + "weight", + "bias", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + ], + default_matching, + ), + torch.nn.modules.batchnorm.BatchNorm2d: ( + 2, + ["weight", "bias", "running_mean", "running_var", "eps"], + default_matching, ), - torch.nn.modules.batchnorm.BatchNorm2d: (2, ["weight", "bias", "running_mean", "running_var", "eps"], default_matching), torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), torch.nn.modules.pooling.MaxPool2d: ( - 1, ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], default_matching + 1, + ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], + default_matching, ), torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), } + @compatibility(is_backward_compatible=False) def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` @@ -41,21 +65,25 @@ def extract_attrs_for_lowering(mod: nn.Module) -> Dict[str, Any]: if type(mod) in module_fetch_book: version, param_to_fetch, matching_method = module_fetch_book[type(mod)] if version < mod._version: - raise RuntimeError(f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " - "please upgrade the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") + raise RuntimeError( + f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) for attr in param_to_fetch: attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) else: - raise RuntimeError(f"{torch.typename(mod)} is not in the module_fetch_book yet, " - "please add it to the module_fetch_book, open an issue and @842974287 " - "or report a bug to AIACC team directly.") + raise RuntimeError( + f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) return attrs_for_lowering + @compatibility(is_backward_compatible=False) def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: - """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module. - """ + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.""" submodules = dict(fx_module.named_modules()) for node in fx_module.graph.nodes: @@ -63,4 +91,6 @@ def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: if isinstance(submodules[node.target], GraphModule): lift_lowering_attrs_to_nodes(submodules[node.target]) else: - node.attrs_for_lowering = extract_attrs_for_lowering(submodules[node.target]) + node.attrs_for_lowering = extract_attrs_for_lowering( + submodules[node.target] + ) diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index 3cc4ff5e07090..eb793aa6f11e9 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs +import logging from functools import wraps from inspect import unwrap from typing import Callable, List, Optional -import logging + logger = logging.getLogger(__name__) @@ -15,6 +16,7 @@ "these_before_those_pass_constraint", ] + # for callables which modify object inplace and return something other than # the object on which they act def inplace_wrapper(fn: Callable) -> Callable: @@ -31,11 +33,12 @@ def inplace_wrapper(fn: Callable) -> Callable: @wraps(fn) def wrapped_fn(gm): - val = fn(gm) + fn(gm) return gm return wrapped_fn + def log_hook(fn: Callable, level=logging.INFO) -> Callable: """ Logs callable output. @@ -48,16 +51,13 @@ def log_hook(fn: Callable, level=logging.INFO) -> Callable: ``` def my_pass(d: Dict) -> bool: changed = False - if 'foo' in d: - d['foo'] = 'bar' + if "foo" in d: + d["foo"] = "bar" changed = True return changed - pm = PassManager( - passes=[ - inplace_wrapper(log_hook(my_pass)) - ] - ) + + pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))]) ``` Args: @@ -67,6 +67,7 @@ def my_pass(d: Dict) -> bool: Returns: wrapped_fn (Callable[Type1, Type2]) """ + @wraps(fn) def wrapped_fn(gm): val = fn(gm) @@ -76,8 +77,11 @@ def wrapped_fn(gm): return wrapped_fn - -def loop_pass(base_pass: Callable, n_iter: Optional[int] = None, predicate: Optional[Callable] = None): +def loop_pass( + base_pass: Callable, + n_iter: Optional[int] = None, + predicate: Optional[Callable] = None, +): """ Convenience wrapper for passes which need to be applied multiple times. @@ -154,9 +158,7 @@ def these_before_those_pass_constraint(these: Callable, those: Callable): loop_pass(pass_a, 5), ] - constraints = [ - these_before_those_pass_constraint(pass_a, pass_b) - ] + constraints = [these_before_those_pass_constraint(pass_a, pass_b)] ``` Args: diff --git a/torch/fx/passes/reinplace.py b/torch/fx/passes/reinplace.py index 76435b9d318af..3b61446a92f7e 100644 --- a/torch/fx/passes/reinplace.py +++ b/torch/fx/passes/reinplace.py @@ -1,32 +1,38 @@ # mypy: allow-untyped-defs +import _operator +import itertools +from collections import defaultdict +from enum import Enum +from typing import Dict, Set + import torch +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode from torch.fx import Node from torch.fx._compatibility import compatibility -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensor -from torch.utils._pytree import tree_map_only -from torch.utils import _pytree as pytree from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map_only -import _operator -from enum import Enum -import itertools -from typing import Set, Dict -from collections import defaultdict -__all__ = ['reinplace'] +__all__ = ["reinplace"] + class _ViewType(Enum): NonView = 0 SingleOutputView = 1 MultiOutputView = 2 + def _is_view_op(tgt): if tgt is not None and isinstance(tgt, torch._ops.OpOverload): schema = tgt._schema if len(schema.arguments) > 0: first_arg = schema.arguments[0] # check if op is a view - return first_arg.alias_info is not None and not first_arg.alias_info.is_write + return ( + first_arg.alias_info is not None and not first_arg.alias_info.is_write + ) + def _get_view_type(tgt) -> _ViewType: if tgt is not None and isinstance(tgt, torch._ops.OpOverload): @@ -36,7 +42,7 @@ def _get_view_type(tgt) -> _ViewType: # check if op is a view if first_arg.alias_info is not None and not first_arg.alias_info.is_write: # check if op is a multi-output view - if '*' in first_arg.alias_info.after_set: + if "*" in first_arg.alias_info.after_set: return _ViewType.MultiOutputView else: return _ViewType.SingleOutputView @@ -54,12 +60,11 @@ def _get_view_type(tgt) -> _ViewType: # to sanity check that our aliasing information is correct. @compatibility(is_backward_compatible=False) class _FunctionalizationMetadataProp(torch.fx.Interpreter): - def run_node(self, node: Node): self.node_counter += 1 result = super().run_node(node) - node.meta['fake_result'] = result - node.meta['node_idx'] = self.node_counter + node.meta["fake_result"] = result + node.meta["node_idx"] = self.node_counter # (1) Update metadata with the list of nodes that are used by this node # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. @@ -69,11 +74,11 @@ def run_node(self, node: Node): node_args = node_args[1:] # (2) Update metadata to track aliasing information about view tensor nodes. - if node.op == 'call_function': + if node.op == "call_function": view_type = _get_view_type(node.target) if view_type == _ViewType.SingleOutputView: assert isinstance(node.args[0], Node) - node.meta['view_of'] = node.args[0] + node.meta["view_of"] = node.args[0] elif view_type == _ViewType.MultiOutputView: self.multi_output_view_nodes[node] = node.args[0] @@ -95,38 +100,52 @@ def run_node(self, node: Node): # Note: we could also track indexing info here for multi-output views. # I don't think this metadata is strictly needed for de-functionalization. assert isinstance(maybe_base_of_view, Node) - node.meta['view_of'] = maybe_base_of_view + node.meta["view_of"] = maybe_base_of_view - if 'view_of' in node.meta: + if "view_of" in node.meta: # We're linking the current node with its first argument as views. # Assert here that this is actually the case, and their storages are the same. - assert isinstance(node.meta['fake_result'], FakeTensor) - assert isinstance(node.meta['view_of'].meta['fake_result'], FakeTensor) - view_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) - base_storage = StorageWeakRef(node.meta['view_of'].meta['fake_result']._typed_storage()) + assert isinstance(node.meta["fake_result"], FakeTensor) + assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor) + view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage()) + base_storage = StorageWeakRef( + node.meta["view_of"].meta["fake_result"]._typed_storage() + ) assert view_storage == base_storage return result - - def propagate(self, *args): self.multi_output_view_nodes = {} self.node_counter = -1 with FakeTensorMode() as mode: - fake_args = [mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args] + fake_args = [ + mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args + ] return super().run(*fake_args) + def _schemas_match(functional_schema, inplace_schema): - names_match = inplace_schema.name.endswith("_") and inplace_schema.name[:-1] == functional_schema.name - arg_types_match = len(functional_schema.arguments) == len(inplace_schema.arguments) and all( - a1.type == a2.type for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments)) + names_match = ( + inplace_schema.name.endswith("_") + and inplace_schema.name[:-1] == functional_schema.name + ) + arg_types_match = len(functional_schema.arguments) == len( + inplace_schema.arguments + ) and all( + a1.type == a2.type + for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments) + ) # for the inplace op, its first argument should be mutable - assert inplace_schema.arguments[0].alias_info is not None and inplace_schema.arguments[0].alias_info.is_write + assert ( + inplace_schema.arguments[0].alias_info is not None + and inplace_schema.arguments[0].alias_info.is_write + ) # and its remaining arguments shouldn't be. assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) return names_match and arg_types_match + # TODO: this should be beefed up to be able to properly re-inplace with: # - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) # - out= ops (e.g. angle -> angle.out) @@ -143,17 +162,20 @@ def _maybe_get_inplace_op(op): op_namespace = op.__module__.split(".")[-1] op_base_name = op.overloadpacket.__name__ maybe_namespace_module = getattr(torch.ops, op_namespace) - maybe_inplace_op = None if maybe_namespace_module is None else getattr(maybe_namespace_module, f'{op_base_name}_', None) + maybe_inplace_op = ( + None + if maybe_namespace_module is None + else getattr(maybe_namespace_module, f"{op_base_name}_", None) + ) if maybe_inplace_op is None: return None inplace_overloads = [ - getattr(maybe_inplace_op, overload_name) for overload_name in maybe_inplace_op.overloads() + getattr(maybe_inplace_op, overload_name) + for overload_name in maybe_inplace_op.overloads() ] inplace_overloads_with_matching_schemas = [ - f - for f in inplace_overloads - if _schemas_match(op._schema, f._schema) + f for f in inplace_overloads if _schemas_match(op._schema, f._schema) ] # Just because foo() and foo_() are both existing operators, # They aren't guaranteed to have compatible schemas. @@ -165,6 +187,7 @@ def _maybe_get_inplace_op(op): inplace_op = inplace_overloads_with_matching_schemas[0] return inplace_op + _VIEW_INVERSE_MAP = { torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, @@ -172,6 +195,7 @@ def _maybe_get_inplace_op(op): torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, } + # This function, given a set of set of (aliased) tensor nodes, # Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index # in the node ordering. @@ -186,17 +210,21 @@ def _add_if_tensor(x, set_): usage_nodes = t.users for n in usage_nodes: # We only care about usages after the current node - if 'node_idx' not in n.meta or n.meta['node_idx'] <= op_index: + if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index: continue # We also don't care about intermediate view ops. # They only matter if their output is then used elsewhere # (either in an out-of-place op, or as an output to the function). if n in tensor_aliases: - if isinstance(n.target, torch._ops.OpOverload) or n.target == _operator.getitem: + if ( + isinstance(n.target, torch._ops.OpOverload) + or n.target == _operator.getitem + ): continue nodes_used_after.add(n) return nodes_used_after + # Given an op that we're trying to re-inplace, "b = foo(a)", # And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" # Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: @@ -204,23 +232,27 @@ def _add_if_tensor(x, set_): # (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" # (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata # as "alias" -def _get_view_inverse_node_usages(later_node_usages: Set[Node], self_aliases: Set[Node]) -> Set[Node]: +def _get_view_inverse_node_usages( + later_node_usages: Set[Node], self_aliases: Set[Node] +) -> Set[Node]: def matching_view_metadata(a, b): - return a.size() == b.size() and \ - a.stride() == b.stride() and \ - a.storage_offset() == b.storage_offset() + return ( + a.size() == b.size() + and a.stride() == b.stride() + and a.storage_offset() == b.storage_offset() + ) view_inverse_nodes = set() # Go through them in node order, so we can see chains of view_scatter ops. - for n in sorted(later_node_usages, key=lambda x: x.meta['node_idx']): + for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]): if n.target not in _VIEW_INVERSE_MAP: continue base = n.args[0] mutated_view = n.args[1] assert isinstance(base, Node) - assert isinstance(base.meta['fake_result'], FakeTensor) + assert isinstance(base.meta["fake_result"], FakeTensor) assert isinstance(mutated_view, Node) - assert isinstance(mutated_view.meta['fake_result'], FakeTensor) + assert isinstance(mutated_view.meta["fake_result"], FakeTensor) # Check that this view_inverse op actually corresponds to taking doing the inverse # of one of our existing self_alias nodes. original_view = _VIEW_INVERSE_MAP[n.target] @@ -229,18 +261,21 @@ def matching_view_metadata(a, b): # that was created from some op `alias = foo(base, args...)` # such that the current _scatter op "inverts" that foo call. # We can check that by running the original op again, and checking that the strides match. - if 'view_of' not in self_alias.meta: + if "view_of" not in self_alias.meta: continue - self_alias_base = self_alias.meta['view_of'] + self_alias_base = self_alias.meta["view_of"] try: # The we're trying to re-use the args from the view_scatter call inside of the corresponding # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse # of the current alias we're looking at. - view_replay_metadata = original_view(self_alias_base.meta['fake_result'], *n.args[2:], **n.kwargs) - expected_metadata = self_alias.meta['fake_result'] + view_replay_metadata = original_view( + self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs + ) + expected_metadata = self_alias.meta["fake_result"] # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. - if matching_view_metadata(self_alias_base.meta['fake_result'], base.meta['fake_result']) and \ - matching_view_metadata(view_replay_metadata, expected_metadata): + if matching_view_metadata( + self_alias_base.meta["fake_result"], base.meta["fake_result"] + ) and matching_view_metadata(view_replay_metadata, expected_metadata): view_inverse_nodes.add(n) except Exception: continue @@ -471,25 +506,29 @@ def f(x): # NOTE: later, we'll need to add an optimization for fully recovering performance # on programs that mutate inputs. input_storages = { - StorageWeakRef( - node.meta['fake_result']._typed_storage() - ) for node in gm.graph.nodes if (node.op == 'placeholder' and isinstance(node.meta['fake_result'], torch.Tensor))} + StorageWeakRef(node.meta["fake_result"]._typed_storage()) + for node in gm.graph.nodes + if ( + node.op == "placeholder" + and isinstance(node.meta["fake_result"], torch.Tensor) + ) + } # We also need to know for a given node, what are all of its aliasing nodes. storage_to_nodes: Dict[StorageWeakRef, Set[Node]] = defaultdict(set) for n in gm.graph.nodes: - if 'fake_result' in n.meta: + if "fake_result" in n.meta: # Tree-mapping because some ops can return lists of tensors. def _add_to_map(x): if isinstance(x, FakeTensor): storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) - pytree.tree_map_(_add_to_map, n.meta['fake_result']) + + pytree.tree_map_(_add_to_map, n.meta["fake_result"]) # inplace-ify functional ops, subject to the constraints written below. all_later_view_inverse_nodes_to_delete = set() - for idx, node in enumerate(gm.graph.nodes): - if node.op == 'call_function': - + for node in gm.graph.nodes: + if node.op == "call_function": # Today, the re-inplace pass on directly acts on: # - functional ops with an inplace variant # - {view}_scatter ops that can be potentially removed from the graph. @@ -512,8 +551,8 @@ def _add_to_map(x): # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), # this is probably an optimization to revisit later). self_arg = node.args[0] - self_flattened = pytree.tree_leaves(self_arg.meta['fake_result']) - node_flattened = pytree.tree_leaves(node.meta['fake_result']) + self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"]) + node_flattened = pytree.tree_leaves(node.meta["fake_result"]) self_has_wrong_metadata = False if len(self_flattened) == len(node_flattened): for self_meta, node_meta in zip(self_flattened, node_flattened): @@ -532,8 +571,9 @@ def _add_to_map(x): continue # Step 1b: ensure that the op we're trying to re-inplace isn't a program input - self_arg_name = self_arg.name - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) if self_arg_storage in input_storages: # TODO: later, add the optimization for handling `copy_()` calls in the graph. continue @@ -543,14 +583,20 @@ def _add_to_map(x): # so we prevent re-inplacing in this case. continue - self_arg_storage = StorageWeakRef(self_arg.meta['fake_result']._typed_storage()) + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) self_aliases = storage_to_nodes[self_arg_storage] # First, we find all later usages of any of the aliases of self_arg. - later_node_usages = _get_all_later_node_usages(self_aliases, node.meta['node_idx']) + later_node_usages = _get_all_later_node_usages( + self_aliases, node.meta["node_idx"] + ) # Then, we check if any of those later usages are actually view_scatter ops # that are safe to fully remove. - later_view_inverse_node_usages = _get_view_inverse_node_usages(later_node_usages, self_aliases) + later_view_inverse_node_usages = _get_view_inverse_node_usages( + later_node_usages, self_aliases + ) # Step 2: Check to see if the input to the op is re-used later in the graph. # If not (same goes for its aliases), then this op is safe to re-in place. @@ -566,7 +612,10 @@ def _add_to_map(x): # we would prefer to remove it from the graph entirely, # and instead copy_() the slice directly into the larger tensor. # See the description of the algorithm for a full example. - if node.target in _VIEW_INVERSE_MAP and node not in all_later_view_inverse_nodes_to_delete: + if ( + node.target in _VIEW_INVERSE_MAP + and node not in all_later_view_inverse_nodes_to_delete + ): view_op = _VIEW_INVERSE_MAP[node.target] # Before: # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) @@ -577,13 +626,23 @@ def _add_to_map(x): mutated_slice_node = node.args[1] remaining_slice_args = node.args[2:] slice_node = gm.graph.create_node( - 'call_function', view_op, (self_arg,) + tuple(remaining_slice_args), node.kwargs) - copy_node = gm.graph.create_node( - 'call_function', torch.ops.aten.copy_.default, (slice_node, mutated_slice_node,), {}) + "call_function", + view_op, + (self_arg,) + tuple(remaining_slice_args), + node.kwargs, + ) + gm.graph.create_node( + "call_function", + torch.ops.aten.copy_.default, + ( + slice_node, + mutated_slice_node, + ), + {}, + ) # Add the slice_scatter node to our "nodes to delete" list. all_later_view_inverse_nodes_to_delete.add(node) - else: # Step 3b: Check to see if this operator has an inplace variant. maybe_inplace_op = _maybe_get_inplace_op(node.target) @@ -598,22 +657,30 @@ def _add_to_map(x): # Hmm... morally I think we also want to keep the `fake_result` metadata # up to date here, but I'm not sure how easy it is to do. # Maybe it's fine to wait until the end of the pass to update it. - curr_node_storage = StorageWeakRef(node.meta['fake_result']._typed_storage()) - storage_to_nodes[self_arg_storage].update(storage_to_nodes[curr_node_storage]) - storage_to_nodes[curr_node_storage].update(storage_to_nodes[self_arg_storage]) + curr_node_storage = StorageWeakRef( + node.meta["fake_result"]._typed_storage() + ) + storage_to_nodes[self_arg_storage].update( + storage_to_nodes[curr_node_storage] + ) + storage_to_nodes[curr_node_storage].update( + storage_to_nodes[self_arg_storage] + ) # Need to remember the view_scatter view nodes we found so we can remove them alter. - all_later_view_inverse_nodes_to_delete.update(later_view_inverse_node_usages) + all_later_view_inverse_nodes_to_delete.update( + later_view_inverse_node_usages + ) # Step 4: # Now that we've replaced b = a.foo() with a.foo_(), # We need to replace any later usages of "b" with "a" for old in itertools.chain([node], later_view_inverse_node_usages): new = old.args[0] - nodes_to_update = [n for n in old.users if n.meta['node_idx'] > node.meta['node_idx']] + nodes_to_update = [ + n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"] + ] for node_to_update in nodes_to_update: - new_args = [] - args = node_to_update.args def replace_arg(a): if a == old: @@ -621,21 +688,29 @@ def replace_arg(a): return a # First, replace usages of "b" with "a" - node_to_update.args = tree_map_only(Node, replace_arg, node_to_update.args) - node_to_update.kwargs = tree_map_only(Node, replace_arg, node_to_update.kwargs) + node_to_update.args = tree_map_only( + Node, replace_arg, node_to_update.args + ) + node_to_update.kwargs = tree_map_only( + Node, replace_arg, node_to_update.kwargs + ) # Second, update our storage_to_nodes data structure. - old_flattened_res = pytree.tree_leaves(old.meta['fake_result']) - node_flattened_res = pytree.tree_leaves(node_to_update.meta['fake_result']) + old_flattened_res = pytree.tree_leaves(old.meta["fake_result"]) + node_flattened_res = pytree.tree_leaves( + node_to_update.meta["fake_result"] + ) old_res_storage = { - StorageWeakRef( - x._typed_storage() - ) for x in old_flattened_res if isinstance(x, FakeTensor)} + StorageWeakRef(x._typed_storage()) + for x in old_flattened_res + if isinstance(x, FakeTensor) + } node_res_storage = { - StorageWeakRef( - x._typed_storage() - ) for x in node_flattened_res if isinstance(x, FakeTensor)} + StorageWeakRef(x._typed_storage()) + for x in node_flattened_res + if isinstance(x, FakeTensor) + } # This will happen if we're updating a view op, e.g. # e.g. replacing @@ -647,14 +722,18 @@ def replace_arg(a): # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, # or multiple tensors that all share the same storage. # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. - if len(old_res_storage) == 1 and len(node_res_storage) == 1 and old_res_storage == node_res_storage: - new_flattened_res = pytree.tree_leaves(new.meta['fake_result']) + if ( + len(old_res_storage) == 1 + and len(node_res_storage) == 1 + and old_res_storage == node_res_storage + ): + new_flattened_res = pytree.tree_leaves(new.meta["fake_result"]) new_res_storage = { - StorageWeakRef( - x._typed_storage() - ) for x in new_flattened_res if isinstance(x, FakeTensor)} + StorageWeakRef(x._typed_storage()) + for x in new_flattened_res + if isinstance(x, FakeTensor) + } assert len(new_res_storage) == 1 - (old_ref,) = old_res_storage (new_ref,) = new_res_storage (node_ref,) = node_res_storage # Technically, "old_ref" and all its aliases will remain @@ -670,6 +749,5 @@ def replace_arg(a): for to_delete in all_later_view_inverse_nodes_to_delete: gm.graph.erase_node(to_delete) - gm.recompile() return gm diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index ffd7c6b908e8a..1e660827a538f 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -166,9 +166,9 @@ def _node_metadata_hook( nn_module_stack: Optional[Dict[str, Any]] = None, ) -> None: fake_args = pytree.tree_map( - lambda arg: _get_example_value(arg) - if isinstance(arg, torch.fx.Node) - else arg, + lambda arg: ( + _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg + ), node.args, ) try: @@ -530,10 +530,10 @@ def go(node, keypath): # effort basis should do. # # The second issue is a preexisting one. It can be mitigated - # with a normalisation algorithm. In general, it may also + # with a normalization algorithm. In general, it may also # be on a best effort basis, but since our grammar is not # terribly difficult, chances are we could even fully - # normalise SymPy expressions... who knows. + # normalize SymPy expressions... who knows. if i0 in constrained_unbacked_symbols: continue # constrain symbol just once diff --git a/torch/fx/passes/shape_prop.py b/torch/fx/passes/shape_prop.py index dcaee3f821139..4931e840707ee 100644 --- a/torch/fx/passes/shape_prop.py +++ b/torch/fx/passes/shape_prop.py @@ -1,17 +1,19 @@ # mypy: ignore-errors -import torch -import torch.fx import traceback +from typing import Any, Dict, NamedTuple, Optional, Tuple +import torch +import torch.fx from torch._dispatch.python import enable_python_dispatcher -from torch.fx.node import Node, map_aggregate -from typing import Any, Tuple, NamedTuple, Optional, Dict -from torch.fx._compatibility import compatibility from torch._guards import detect_fake_mode from torch._subclasses.meta_utils import is_sparse_any +from torch.fx._compatibility import compatibility +from torch.fx.node import map_aggregate, Node + + +__all__ = ["TensorMetadata", "ShapeProp"] -__all__ = ['TensorMetadata', 'ShapeProp'] @compatibility(is_backward_compatible=True) class TensorMetadata(NamedTuple): @@ -19,17 +21,20 @@ class TensorMetadata(NamedTuple): # about a tensor within a PyTorch program. # General Tensor metadata - shape : torch.Size - dtype : torch.dtype - requires_grad : bool - stride : Tuple[int, ...] - memory_format : Optional[torch.memory_format] + shape: torch.Size + dtype: torch.dtype + requires_grad: bool + stride: Tuple[int, ...] + memory_format: Optional[torch.memory_format] # Quantization metadata - is_quantized : bool + is_quantized: bool qparams: Dict[str, Any] -def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> TensorMetadata: + +def _extract_tensor_metadata( + result: torch.Tensor, include_contiguity=True +) -> TensorMetadata: """ Extract a TensorMetadata NamedTuple describing `result`. """ @@ -59,7 +64,11 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: qparams["scale"] = result.q_scale() # type: ignore[assignment] qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] - elif qscheme in {torch.per_channel_affine, torch.per_channel_affine_float_qparams, torch.per_channel_symmetric}: + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + }: # In this branch, scale and zero_point are expected to be tensors, # we store the values as immutable_list in TensorMetadata for # easier serialization downstream @@ -68,7 +77,9 @@ def _extract_tensor_metadata(result : torch.Tensor, include_contiguity=True) -> qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] return TensorMetadata( - shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams) + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams + ) + @compatibility(is_backward_compatible=True) class ShapeProp(torch.fx.Interpreter): @@ -117,12 +128,14 @@ def forward(self, x): fake_mode (FakeTensorMode): A fake mode for copying the gm """ + def __init__(self, gm, fake_mode=None): super().__init__(gm) if fake_mode is None: fake_mode = detect_fake_mode() if fake_mode is not None: from torch._dynamo.utils import deepcopy_to_fake_tensor + # Note: # We need fake execution cause the inputs are fake, however, we cannot fakify the module # - because we need to write to the tensor_meta of the real module. So we fakify to @@ -140,7 +153,7 @@ def __init__(self, gm, fake_mode=None): self.real_module = self.module - def run_node(self, n : Node) -> Any: + def run_node(self, n: Node) -> Any: try: if self.fake_module is not None: # Hacky swap. Alternatively, we could do this with overriding @@ -157,8 +170,7 @@ def run_node(self, n : Node) -> Any: except Exception as e: traceback.print_exc() raise RuntimeError( - f"ShapeProp error for: node={n.format_node()} with " - f"meta={n.meta}" + f"ShapeProp error for: node={n.format_node()} with " f"meta={n.meta}" ) from e found_tensor = False @@ -173,9 +185,9 @@ def extract_tensor_meta(obj): meta = map_aggregate(result, extract_tensor_meta) if found_tensor: - n.meta['tensor_meta'] = meta + n.meta["tensor_meta"] = meta - n.meta['type'] = type(result) + n.meta["type"] = type(result) return result def propagate(self, *args): @@ -190,7 +202,10 @@ def propagate(self, *args): Any: The value returned from executing the Module """ if self.fake_mode is not None: - fake_args = [self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t for t in args] + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] else: fake_args = args return super().run(*fake_args) diff --git a/torch/fx/passes/split_module.py b/torch/fx/passes/split_module.py index 1881beaf2ece1..7df05aac83fa1 100644 --- a/torch/fx/passes/split_module.py +++ b/torch/fx/passes/split_module.py @@ -1,19 +1,20 @@ # mypy: allow-untyped-defs import inspect -from typing import Any, Callable, Dict, List, Optional, Set -from collections import OrderedDict import logging +from collections import OrderedDict +from typing import Any, Callable, Dict, List, Optional, Set import torch from torch.fx._compatibility import compatibility +from torch.fx._utils import lazy_format_graph_code from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx._utils import lazy_format_graph_code __all__ = ["Partition", "split_module"] log = _LOGGER = logging.getLogger(__name__) + @compatibility(is_backward_compatible=True) class Partition: def __init__(self, name: str): @@ -39,6 +40,15 @@ def __repr__(self) -> str: ) +def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any: + attr_val = mod + for atom in qualname.split("."): # type: ignore[union-attr] + if not hasattr(attr_val, atom): + raise AttributeError(f"Node target {qualname} not found!") + attr_val = getattr(attr_val, atom) + return attr_val + + # Creates subgraphs out of main graph @compatibility(is_backward_compatible=True) def split_module( @@ -146,9 +156,7 @@ def forward(self, x, y): log.debug( "%s", - lazy_format_graph_code( - "pre split_module", m, colored=True - ), + lazy_format_graph_code("pre split_module", m, colored=True), ) def construct_graph( @@ -161,21 +169,27 @@ def construct_graph( node.args[0] if len(node.args) > 0 else inspect.Signature.empty ) if keep_original_node_name: - args = () if default_value is inspect.Signature.empty else (default_value,) - base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type) # type: ignore[arg-type] + args = ( + () if default_value is inspect.Signature.empty else (default_value,) + ) + base_mod_env[node.name] = base_mod_graph.create_node( + "placeholder", + node.name, + args=args, # type: ignore[arg-type] + type_expr=node.type, + ) else: base_mod_env[node.name] = base_mod_graph.placeholder( - node.target, type_expr=node.type, default_value=default_value # type: ignore[arg-type] + node.target, # type: ignore[arg-type] + type_expr=node.type, + default_value=default_value, ) base_mod_env[node.name].meta = node.meta.copy() elif node.op == "get_attr": base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type] base_mod_env[node.name].meta = node.meta.copy() - attr_val = m - for atom in node.target.split("."): # type: ignore[union-attr] - if not hasattr(attr_val, atom): - raise AttributeError(f"Node target {node.target} not found!") - attr_val = getattr(attr_val, atom) + assert isinstance(node.target, str) + attr_val = _get_attr_from_qualname(m, node.target) base_mod_attrs[node.target] = attr_val # type: ignore[index] return base_mod_env, base_mod_attrs @@ -185,9 +199,7 @@ def construct_graph( orig_nodes: Dict[str, Node] = {} symbol_to_node: Dict[sympy.Symbol, Node] = {} - def record_cross_partition_use( - def_node: Node, use_node: Optional[Node] - ): # noqa: B950 + def record_cross_partition_use(def_node: Node, use_node: Optional[Node]): from torch.fx.experimental.symbolic_shapes import free_symbols defined = getattr(def_node, "_fx_partition", None) @@ -195,7 +207,10 @@ def record_cross_partition_use( log.debug( "record_cross_partition_use %s (%s) %s (%s)", - def_node.name, defined, use_node.name if use_node is not None else "-", used + def_node.name, + defined, + use_node.name if use_node is not None else "-", + used, ) if defined != used: @@ -234,7 +249,9 @@ def record_cross_partition_use( def instantiate_node_partition_mapping(node): partition_name = str(split_callback(node)) - log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name) + log.debug( + "instantiate_node_partition_mapping %s (%s)", node.name, partition_name + ) # add node to partitions partition = partitions.get(partition_name) @@ -249,7 +266,7 @@ def instantiate_node_partition_mapping(node): GLOBAL_STATE_NODES = [ torch.amp._enter_autocast, torch.amp._exit_autocast, - torch._C._set_grad_enabled + torch._C._set_grad_enabled, ] # For grad regions: @@ -280,10 +297,10 @@ def instantiate_node_partition_mapping(node): # rely on later, but this needs some extra work. Quick fix first. # See https://github.com/pytorch/pytorch/issues/130534 if ( - (val := node.meta.get("example_value")) is not None and - isinstance(val, torch.SymInt) and - isinstance(s0 := val.node.expr, sympy.Symbol) and - s0 not in symbol_to_node + (val := node.meta.get("example_value")) is not None + and isinstance(val, (torch.SymInt, torch.SymFloat)) + and isinstance(s0 := val.node.expr, sympy.Symbol) + and s0 not in symbol_to_node ): symbol_to_node[val.node.expr] = node @@ -344,9 +361,10 @@ def instantiate_node_partition_mapping(node): if assert_monotonically_increasing: pid = split_callback(node) - assert highest_partition <= pid, \ - ("autocast or set_grad_enabled require monotonically increasing partitions:" - f"highest: {highest_partition}, this node's: {pid}") + assert highest_partition <= pid, ( + "autocast or set_grad_enabled require monotonically increasing partitions:" + f"highest: {highest_partition}, this node's: {pid}" + ) highest_partition = pid # do not capture cross-partition dependencies for global state nodes as they will be @@ -392,19 +410,42 @@ def instantiate_node_partition_mapping(node): kwargs={}, type_expr=node.type, ) - new_node.meta = node.meta.copy() # is it really a good idea to copy this? + new_node.meta = ( + node.meta.copy() + ) # is it really a good idea to copy this? partition.environment[node] = new_node # add placeholders to partition inputs for partition_name in sorted_partitions: partition = partitions[partition_name] + new_inputs: Dict[str, None] = {} for inp in partition.inputs: - placeholder = partition.graph.placeholder( - inp, - type_expr=orig_nodes[inp].type, - ) + orig_node = orig_nodes[inp] + # We don't pass in get_attr nodes as inputs to the partition, but + # instead set them as targets and use getattr within the module + + if orig_node.op == "get_attr": + assert isinstance(orig_node.target, str) + + orig_attr = _get_attr_from_qualname(m, orig_node.target) + if isinstance(orig_attr, torch.nn.Module): + placeholder = partition.graph.get_attr(orig_node.target) + partition.targets[orig_node.target] = orig_attr + else: + placeholder = partition.graph.placeholder( + inp, + type_expr=orig_nodes[inp].type, + ) + new_inputs[inp] = None + else: + placeholder = partition.graph.placeholder( + inp, + type_expr=orig_nodes[inp].type, + ) + new_inputs[inp] = None placeholder.meta = orig_nodes[inp].meta.copy() partition.environment[orig_nodes[inp]] = placeholder + partition.inputs = new_inputs # Transform nodes and collect targets for partition's submodule for node in m.graph.nodes: @@ -421,14 +462,8 @@ def instantiate_node_partition_mapping(node): if node.op not in ["call_module", "get_attr"]: target = node.target else: - target_atoms = node.target.split(".") - target_attr = m - for atom in target_atoms: - if not hasattr(target_attr, atom): - raise AttributeError(f"Operator target {node.target} not found!") - target_attr = getattr(target_attr, atom) - # target = target_atoms[-1] - target = "_".join(target_atoms) + target_attr = _get_attr_from_qualname(m, node.target) + target = node.target.replace(".", "_") partition.targets[target] = target_attr # Fill in the passed-in mapping from new qualname to old qualname if qualname_map is not None: @@ -467,7 +502,9 @@ def instantiate_node_partition_mapping(node): kwargs={}, type_expr=exit_node.type, ) - new_node.meta = exit_node.meta.copy() # is it really a good idea to copy this? + new_node.meta = ( + exit_node.meta.copy() + ) # is it really a good idea to copy this? # original module environment dict mapping node names to nodes orig_mod_env: Dict[str, Node] = {} @@ -520,13 +557,15 @@ def instantiate_node_partition_mapping(node): if keep_original_order: # first get the attr nodes required by this partition orig_mod_attr_nodes: List[Node] = [ - orig_mod_env[key] for key in partition.inputs if key not in original_order + orig_mod_env[key] + for key in partition.inputs + if key not in original_order ] for node in original_order: if node in already_constructed_attr_nodes: continue # already added this attr to the base graph - base_mod_env, based_mod_attrs = construct_graph( + base_mod_env, _based_mod_attrs = construct_graph( node, base_mod_env, base_mod_attrs ) already_constructed_attr_nodes.add(node) @@ -568,8 +607,6 @@ def instantiate_node_partition_mapping(node): ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) log.debug( "%s", - lazy_format_graph_code( - "post split_module", ret, colored=True - ), + lazy_format_graph_code("post split_module", ret, colored=True), ) return ret diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 1c003966983f3..e2bece6f72f27 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -10,6 +10,7 @@ from .tools_common import NodeList + __all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"] diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 70b117c8ca374..31cb357df353d 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -1,40 +1,44 @@ # mypy: allow-untyped-defs import argparse import copy +import logging from collections import defaultdict from dataclasses import dataclass -from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple -import logging +from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Sequence, Tuple import torch -from torch.fx.passes.graph_manipulation import get_size_of_node -from torch.fx.node import map_arg from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg +from torch.fx.passes.graph_manipulation import get_size_of_node -from .operator_support import ( - get_node_target, - OperatorSupportBase, -) from .graph_drawer import FxGraphDrawer +from .operator_support import get_node_target, OperatorSupportBase from .shape_prop import ShapeProp from .split_utils import split_by_tags from .tools_common import ( - FxNetAccFusionsFinder, CALLABLE_NODE_OPS, - Tensors, + FxNetAccFusionsFinder, + is_node_output_tensor, NodeList, NodeSet, - is_node_output_tensor, + Tensors, ) -__all__ = ['FxNetAccNodesFinder', 'FxNetSplitterInternalError', 'Subgraph', 'SplitResult', 'generate_inputs_for_submodules'] +__all__ = [ + "FxNetAccNodesFinder", + "FxNetSplitterInternalError", + "Subgraph", + "SplitResult", + "generate_inputs_for_submodules", +] _LOGGER = logging.getLogger(__name__) DEFAULT_MIN_ACC_MODULE_SIZE = 1 DEFAULT_SKIP_FUSION = False DEFAULT_ALLOW_NON_TENSOR = False + class _SplitterSettingBase: def __init__( self, @@ -80,11 +84,17 @@ def __init__( "we might not care about non-tensor data flow and we can set this option " "to true to disable the functionality that prevent non-tensor data flow.", ) - args, unknown = parser.parse_known_args() + args, _unknown = parser.parse_known_args() - self.min_acc_module_size: int = args.min_acc_module_size if args.min_acc_module_size else min_acc_module_size + self.min_acc_module_size: int = ( + args.min_acc_module_size + if args.min_acc_module_size + else min_acc_module_size + ) self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion - self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + self.allow_non_tensor: bool = ( + args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + ) self.max_acc_splits: int = max_acc_splits @@ -114,9 +124,7 @@ def __init__( self.allow_non_tensor = allow_non_tensor self.acc_nodes: NodeSet = set() - def reduce_acc_nodes_non_tensor_input_helper( - self, cpu_worklist: NodeList - ): + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): """ Transitively excludes nodes from ACC supported set. For every node in the worklist: @@ -190,10 +198,12 @@ def __call__(self) -> NodeSet: return self.acc_nodes + @compatibility(is_backward_compatible=False) class FxNetSplitterInternalError(Exception): pass + @compatibility(is_backward_compatible=False) @dataclass class Subgraph: @@ -201,6 +211,7 @@ class Subgraph: nodes: NodeList device_ordinal: Optional[int] = None + @compatibility(is_backward_compatible=False) class SplitResult(NamedTuple): """ @@ -243,7 +254,9 @@ def generate_inputs_for_submodules( submodule_to_names = {mod: name for name, mod in model.named_modules()} def pre_forward(module, module_inputs): - results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs + results[submodule_to_names[module]] = ( + copy.deepcopy(module_inputs) if deepcopy else module_inputs + ) for name, mod in model.named_modules(): if name in target_submodules: @@ -308,7 +321,7 @@ def forward(self, sin_1, cos_1): """ # PCIe bandwidth for the backend, default to 100 GB/s - PCIe_BW = 100 * 2 ** 30 + PCIe_BW = 100 * 2**30 def __init__( self, @@ -335,7 +348,9 @@ def __init__( self.settings = settings self.operator_support = operator_support self.sample_input = sample_input - self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)() + self.acc_nodes = FxNetAccNodesFinder( + self.module, self.operator_support, self.settings.allow_non_tensor + )() if self.settings.skip_fusion: self.fusions = {} @@ -357,11 +372,11 @@ def __init__( # =============================================================== def get_node_submodule_map(self) -> Dict[str, str]: - """ Returns a map from node name to submodule name, e.g. - node: main_module_impl_impl_over_arch_unary_multiple_embedding - _pooling_embedding_pooling_sparse_entity_equivalence_key - _proxy_embedding_bag - maps to submodule name of: _run_on_acc_1 + """Returns a map from node name to submodule name, e.g. + node: main_module_impl_impl_over_arch_unary_multiple_embedding + _pooling_embedding_pooling_sparse_entity_equivalence_key + _proxy_embedding_bag + maps to submodule name of: _run_on_acc_1 """ return self._node_submodule_map @@ -411,9 +426,7 @@ def _lower_model_to_backend( return mod - def _find_culprit( - self, mod: torch.fx.GraphModule, inputs: Tensors - ) -> str: + def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str: """ When an error occurs during lowering or running the lowered mod, we use this function to find culprits in the `mod` that causes the error. @@ -492,7 +505,9 @@ def get_dtype(arg): supported_nodes.append(node) supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) else: - unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + unsupported_node_types[target].add( + (arg_dtypes_tuple, kwarg_dtypes_tuple) + ) if dump_graph: self._draw_graph_based_on_node_support(self.module, supported_nodes) @@ -527,7 +542,11 @@ def split_preview(self, dump_graph: bool = False): reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" for i, subgraph in enumerate(subgraphs): - reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: " + reports += ( + f"_run_on_acc_{i}: " + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{i}: " + ) reports += f"{len(subgraph.nodes)} node(s)\n" self.tag(subgraphs) @@ -535,9 +554,7 @@ def split_preview(self, dump_graph: bool = False): split_mod.eval() if dump_graph: - drawer = FxGraphDrawer( - split_mod, "preview", ignore_getattr=True - ) + drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True) dot_graphs = drawer.get_all_dot_graphs() for name, dot_graph in dot_graphs.items(): # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. @@ -564,9 +581,7 @@ def get_inputs(self, inputs): handle.remove() return sub_inputs - submod_inputs = get_submod_inputs( - split_mod, submod, self.sample_input - ) + submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input) ShapeProp(submod).propagate(*submod_inputs) total_input_bytes = 0 @@ -649,9 +664,7 @@ def find_reverse_deps( return result - def update_reverse_deps_for_fusions( - self, deps: Dict[torch.fx.Node, NodeSet] - ): + def update_reverse_deps_for_fusions(self, deps: Dict[torch.fx.Node, NodeSet]): processed_node = set() for node, fusion in self.fusions.items(): @@ -853,7 +866,11 @@ def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph def tag(self, subgraphs: List[Subgraph]): self.tags = [] for subgraph in subgraphs: - tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}" + tag = ( + f"_run_on_acc_{len(self.tags)}" + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{len(self.tags)}" + ) self.tags.append(tag) for node in subgraph.nodes: if hasattr(node, "tag"): @@ -863,7 +880,9 @@ def tag(self, subgraphs: List[Subgraph]): self._node_submodule_map[node.name] = tag def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: - split_module = split_by_tags(self.module, self.tags, return_tuple=self._return_tuple) + split_module = split_by_tags( + self.module, self.tags, return_tuple=self._return_tuple + ) if remove_tag: for node in self.module.graph.nodes: if hasattr(node, "tag"): @@ -875,14 +894,16 @@ def __call__(self) -> torch.fx.GraphModule: subgraphs = self.remove_small_acc_subgraphs(subgraphs) acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count - print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs") + print( + f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs" + ) self.tag(subgraphs) return self.split() def generate_split_results(self) -> SplitResult: split_module = self() submodule_names = [] - for name, mod in split_module.named_children(): + for name, _mod in split_module.named_children(): submodule_names.append(name) if ( self.settings.max_acc_splits > 0 @@ -894,5 +915,7 @@ def generate_split_results(self) -> SplitResult: "result in performance issues." ) - submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names) + submodule_inputs = generate_inputs_for_submodules( + split_module, self.sample_input, submodule_names + ) return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/torch/fx/passes/tests/test_pass_manager.py b/torch/fx/passes/tests/test_pass_manager.py index 60ed6671179b2..157dc4017eda5 100644 --- a/torch/fx/passes/tests/test_pass_manager.py +++ b/torch/fx/passes/tests/test_pass_manager.py @@ -26,9 +26,7 @@ def test_this_before_that_pass_constraint(self) -> None: def test_these_before_those_pass_constraint(self) -> None: passes = [lambda x: 2 * x for _ in range(10)] constraint = these_before_those_pass_constraint(passes[-1], passes[0]) - pm = PassManager( - [inplace_wrapper(p) for p in passes] - ) + pm = PassManager([inplace_wrapper(p) for p in passes]) # add unfulfillable constraint pm.add_constraint(constraint) @@ -46,7 +44,7 @@ def test_two_pass_managers(self) -> None: pm1.add_pass(p) pm1.add_constraint(constraint) output1 = pm1(1) - self.assertEqual(output1, 2 ** 3) + self.assertEqual(output1, 2**3) passes = [lambda x: 3 * x for _ in range(3)] constraint = these_before_those_pass_constraint(passes[0], passes[1]) @@ -55,4 +53,4 @@ def test_two_pass_managers(self) -> None: pm2.add_pass(p) pm2.add_constraint(constraint) output2 = pm2(1) - self.assertEqual(output2, 3 ** 3) + self.assertEqual(output2, 3**3) diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py index aac071ace8c2d..4ed56be63b092 100644 --- a/torch/fx/passes/tools_common.py +++ b/torch/fx/passes/tools_common.py @@ -1,15 +1,22 @@ # mypy: allow-untyped-defs -from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional import collections -from dataclasses import dataclass import operator +from dataclasses import dataclass +from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union import torch import torch.fx -from torch.fx.node import _get_qualified_name from torch.fx._compatibility import compatibility +from torch.fx.node import _get_qualified_name -__all__ = ['get_acc_ops_name', 'get_node_target', 'is_node_output_tensor', 'FxNetAccFusionsFinder', 'legalize_graph'] + +__all__ = [ + "get_acc_ops_name", + "get_node_target", + "is_node_output_tensor", + "FxNetAccFusionsFinder", + "legalize_graph", +] Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]] TensorOrTensors = Union[torch.Tensor, Tensors] @@ -26,12 +33,16 @@ def get_acc_ops_name(k): elif k.__module__ and "acc_ops" in k.__module__: return f"acc_ops.{k.__name__}" else: - module = k.__module__.replace('torch._ops', 'torch.ops') # WAR for bug in how torch.ops assigns module + module = k.__module__.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module return f"{module if module else ''}.{k.__name__}" @compatibility(is_backward_compatible=False) -def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str: +def get_node_target( + submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node +) -> str: """ Given a `node` returns its target typename. @@ -66,6 +77,7 @@ def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.No assert isinstance(node.target, str) return node.target + @compatibility(is_backward_compatible=False) def is_node_output_tensor(node: torch.fx.Node) -> bool: """Checks if the node output produces a Tensor or not. @@ -77,6 +89,7 @@ def is_node_output_tensor(node: torch.fx.Node) -> bool: type_ = node.meta.get("type", None) return type_ is not None and issubclass(type_, torch.Tensor) + @compatibility(is_backward_compatible=False) class FxNetAccFusionsFinder: """ @@ -297,7 +310,9 @@ def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: # If the new graph's size is not as large as the old one, then there must be # a cycle (i.e. some node's dependencies were not satisfied.) if len(new_graph.nodes) < len(gm.graph.nodes): - raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}") + raise RuntimeError( + f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}" + ) new_graph._codegen = gm.graph._codegen gm.graph = new_graph return gm diff --git a/torch/fx/passes/utils/__init__.py b/torch/fx/passes/utils/__init__.py index 2a7970ba4c283..ee5e7e66868a0 100644 --- a/torch/fx/passes/utils/__init__.py +++ b/torch/fx/passes/utils/__init__.py @@ -1 +1 @@ -from .common import lift_subgraph_as_module, HolderModule, compare_graphs +from .common import compare_graphs, HolderModule, lift_subgraph_as_module diff --git a/torch/fx/passes/utils/common.py b/torch/fx/passes/utils/common.py index ba2ae45aabf5d..bb628372337b4 100644 --- a/torch/fx/passes/utils/common.py +++ b/torch/fx/passes/utils/common.py @@ -3,7 +3,6 @@ from torch.fx._compatibility import compatibility from torch.fx.graph import Graph - from torch.fx.graph_module import GraphModule from torch.fx.passes.utils.matcher_utils import SubgraphMatcher from torch.nn import Module diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 3268cc4a493c7..8bcb9dee71c2f 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -1,15 +1,16 @@ # mypy: allow-untyped-defs import copy from queue import SimpleQueue -from typing import List, Dict, Optional as _Optional, Tuple +from typing import Dict, List, Optional as _Optional, Tuple import torch.fx -from torch.fx.graph_module import GraphModule +from torch.fx._compatibility import compatibility from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule from torch.fx.node import Node -from torch.fx.passes.tools_common import NodeList, NodeSet, legalize_graph +from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet from torch.fx.passes.utils import lift_subgraph_as_module -from torch.fx._compatibility import compatibility + @compatibility(is_backward_compatible=False) def topo_sort(nodes: NodeList) -> NodeList: @@ -35,7 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList: if indegree_map[n] == 0: candidates.put(n) - assert len(nodes) == len(sorted_nodes), "topological sorted nodes doesn't have same length as input nodes" + assert len(nodes) == len( + sorted_nodes + ), "topological sorted nodes doesn't have same length as input nodes" return sorted_nodes @@ -96,7 +99,6 @@ def fuse_as_graphmodule( module_name: str, partition_lookup_table: _Optional[Dict[Node, None]] = None, ) -> Tuple[GraphModule, Tuple[Node, ...], Tuple[Node, ...]]: - """ Fuse nodes in graph_module into a GraphModule. @@ -121,9 +123,13 @@ def fuse_as_graphmodule( # assumption: nodes are already sorted in topo order for node in nodes: - assert node.graph.owning_module is gm, f"{node} doesn't belong to passed in graph module {gm._get_name()}" + assert ( + node.graph.owning_module is gm + ), f"{node} doesn't belong to passed in graph module {gm._get_name()}" assert not node._erased, f"{node} has been removed from owning graph" - assert node in gm.graph._find_nodes_lookup_table, f"{node} is not found in graph module {gm._get_name()}" + assert ( + node in gm.graph._find_nodes_lookup_table + ), f"{node} is not found in graph module {gm._get_name()}" # validates partition doesn't introduce dependency circles in the graph assert validate_partition(nodes), "Invalid partition, found dependency cycles" @@ -134,8 +140,10 @@ def fuse_as_graphmodule( subgraph = Graph() - node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph - node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph + node_to_placeholder: Dict[ + Node, Node + ] = {} # mapping of nodes from old graph to placeholder in new graph + node_map: Dict[Node, Node] = {} # mapping of nodes from old graph to new graph # handles inputs through graph.node_copy's arg_transform functions def remap_inputs(x): @@ -184,7 +192,9 @@ def remap_inputs(x): # lint to ensure correctness subgraph.lint() fused_gm: GraphModule - fused_gm, _ = lift_subgraph_as_module(gm, subgraph, comp_name="", class_name=module_name) + fused_gm, _ = lift_subgraph_as_module( + gm, subgraph, comp_name="", class_name=module_name + ) # sub_gm's input nodes in the original module original_inputs: Tuple[Node, ...] = tuple(node_to_placeholder.keys()) @@ -196,16 +206,18 @@ def remap_inputs(x): @compatibility(is_backward_compatible=False) -def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, ...], orig_outputs: Tuple[Node, ...]): +def insert_subgm( + gm: GraphModule, + sub_gm: GraphModule, + orig_inputs: Tuple[Node, ...], + orig_outputs: Tuple[Node, ...], +): # add sub_gm into gm submodule_name = sub_gm.__class__.__name__ gm.add_submodule(submodule_name, sub_gm) # Create a call_module node in main graph. - module_node = gm.graph.call_module( - submodule_name, - args=orig_inputs, - kwargs=None) + module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None) if len(orig_outputs) == 1: # main_remapping[comp.orig_outputs[0]] = module_node @@ -216,24 +228,30 @@ def insert_subgm(gm: GraphModule, sub_gm: GraphModule, orig_inputs: Tuple[Node, proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) - module_node.meta["val"] = tuple(orig_output.meta.get("val", None) for orig_output in orig_outputs) + module_node.meta["val"] = tuple( + orig_output.meta.get("val", None) for orig_output in orig_outputs + ) return gm + @compatibility(is_backward_compatible=False) def erase_nodes(gm: GraphModule, nodes: NodeList): - # erase original nodes in inversed topological order for node in reversed(nodes): gm.graph.erase_node(node) @compatibility(is_backward_compatible=False) -def fuse_by_partitions(gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_") -> GraphModule: +def fuse_by_partitions( + gm: GraphModule, partitions: List[Dict[Node, None]], prefix: str = "fused_" +) -> GraphModule: for partition_id, partition in enumerate(partitions): sorted_nodes = topo_sort(list(partition)) submodule_name = prefix + str(partition_id) - sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name, partition) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, sorted_nodes, submodule_name, partition + ) insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index ba09ad177d29e..cc05b8f512b15 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -10,6 +10,7 @@ from torch.fx import Graph, Node from torch.fx._compatibility import compatibility + __all__ = ["SubgraphMatcher", "InternalMatch"] diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 0a4f072644cdb..f77db98880b76 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -1,19 +1,21 @@ +import logging +import os from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Type + +from torch.fx._compatibility import compatibility from torch.fx.graph import Graph from torch.fx.node import Node -from torch.fx._compatibility import compatibility -from typing import Dict, List, Any, Type, Optional, Callable -import logging -import os -__all__ = ['get_source_partitions', 'check_subgraphs_connected', 'SourcePartition'] +__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"] + # Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs def _init_logger() -> logging.Logger: logger = logging.getLogger(__name__) - level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper() + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() logger.setLevel(level) console = logging.StreamHandler() formatter = logging.Formatter("%(filename)s > %(message)s") @@ -24,6 +26,7 @@ def _init_logger() -> logging.Logger: logger.propagate = False return logger + logger = _init_logger() @@ -77,8 +80,9 @@ def get_source_partitions( # be different from "source_fn_stack", for example for the add_ node # decomposed from batch norm. We should remove the check on "source_fn_stack" # after we fix "torch_fn". T199561090 - if ((source_fn_st := node.meta.get("source_fn_stack", None)) is None and - (torch_fn := node.meta.get("torch_fn", None)) is not None): + if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and ( + torch_fn := node.meta.get("torch_fn", None) + ) is not None: node_fqn, source_fn = torch_fn source_fn_name = source_fn.split(".")[1] if source_fn_name in wanted_sources: @@ -86,7 +90,6 @@ def get_source_partitions( partition = diff_modules.setdefault(node_fqn, []) partition.append(node) - if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None: source_fn = source_fn_st[-1] if source_fn[1] in wanted_sources: @@ -140,7 +143,9 @@ def make_partition(nodes: List[Node], module_type: Type) -> SourcePartition: @compatibility(is_backward_compatible=False) # type: ignore[misc] -def check_subgraphs_connected(subgraph1: SourcePartition, subgraph2: SourcePartition) -> bool: +def check_subgraphs_connected( + subgraph1: SourcePartition, subgraph2: SourcePartition +) -> bool: """ Given two subgraphs A and B (in the form of a list of nodes), checks if A has nodes connecting to at least one node in B -- aka there exists a node diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 86927595eac91..ccbe065754740 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -1,29 +1,37 @@ # mypy: ignore-errors -import enum -import dis +import collections import copy -import sys -import torch +import dis +import enum import inspect -import operator -import collections import logging +import operator +import sys +from dataclasses import fields, is_dataclass +from typing import Any, Callable, Dict, Iterator, Optional, OrderedDict, Tuple -from dataclasses import is_dataclass, fields - - -from .graph import magic_methods, reflectable_magic_methods, Graph +import torch +import torch.fx.traceback as fx_traceback from torch.utils._traceback import CapturedTraceback -from typing import Tuple, Dict, OrderedDict, Optional, Any, Iterator, Callable -from .node import Target, Node, Argument, base_types, map_aggregate + from ._compatibility import compatibility +from .graph import Graph, magic_methods, reflectable_magic_methods +from .node import Argument, base_types, map_aggregate, Node, Target from .operator_schemas import check_for_mutable_operation -import torch.fx.traceback as fx_traceback -__all__ = ['TracerBase', 'GraphAppendingTracer', 'TraceError', - 'Proxy', 'Attribute', 'ParameterProxy', 'Scope', - 'ScopeContextManager'] + +__all__ = [ + "TracerBase", + "GraphAppendingTracer", + "TraceError", + "Proxy", + "MetaProxy", + "Attribute", + "ParameterProxy", + "Scope", + "ScopeContextManager", +] log = logging.getLogger(__name__) @@ -31,7 +39,7 @@ @compatibility(is_backward_compatible=False) class Scope: - """ Scope object that records the module path and the module type + """Scope object that records the module path and the module type of a module. Scope is used to track the information of the module that contains a Node in a Graph of GraphModule. For example:: @@ -41,6 +49,7 @@ def forward(self, x): # scope for this would be (module_path="sub", module_type=Sub) return x.transpose(1, 2) + class M(torch.nn.Module): def __init__(self) -> None: self.sub = Sub() @@ -62,7 +71,7 @@ def __init__(self, module_path: str, module_type: Any): @compatibility(is_backward_compatible=False) class ScopeContextManager: - """ A context manager to track the Scope of Node during symbolic tracing. + """A context manager to track the Scope of Node during symbolic tracing. When entering a forward function of a Module, we'll update the scope information of the current module, and when we exit, we'll restore the previous scope information. """ @@ -102,28 +111,28 @@ def __exit__(self, *args): "quantization_tag", # TODO deprecated "_numeric_debug_handle", # TODO deprecated "custom", - "partitioner_tag" + "partitioner_tag", ] @compatibility(is_backward_compatible=True) class TracerBase: graph: Graph - record_stack_traces : bool = False + record_stack_traces: bool = False # Feature flag for mutable schema checking # Enableby default in 1.12 - check_mutable_operations : bool = False + check_mutable_operations: bool = False # Feature flag for assert tracing - trace_asserts : bool = False + trace_asserts: bool = False # Feature flag for proxying accesses to buffer values - proxy_buffer_attributes : bool = False + proxy_buffer_attributes: bool = False # Name of the function to be traced. It will only be used when # ``root`` is an instance of ``nn.Module`` traced_func_name: str = "forward" # Maps the containing module's name to the operator name - scope : Scope + scope: Scope # Records the module call stack module_stack: OrderedDict[str, Tuple[str, Any]] @@ -132,9 +141,15 @@ class TracerBase: node_name_to_scope: Dict[str, Tuple[str, type]] @compatibility(is_backward_compatible=True) - def create_node(self, kind : str, target : Target, - args : Tuple[Argument, ...], kwargs : Dict[str, Argument], name : Optional[str] = None, - type_expr : Optional[Any] = None) -> Node: + def create_node( + self, + kind: str, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: """ Inserts a graph node given target, args, kwargs, and name. @@ -143,7 +158,7 @@ def create_node(self, kind : str, target : Target, want to disallow in-place operations from being recorded. """ - if kind == 'call_function' and self.check_mutable_operations: + if kind == "call_function" and self.check_mutable_operations: check_for_mutable_operation(target, args, kwargs) node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) @@ -182,20 +197,27 @@ def create_node(self, kind : str, target : Target, node.meta["seq_nr"] = new_seq_nr elif self.module_stack: - node.meta['nn_module_stack'] = copy.copy(self.module_stack) + node.meta["nn_module_stack"] = copy.copy(self.module_stack) log.debug("create_node %s", node) return node @compatibility(is_backward_compatible=True) - def proxy(self, node: Node) -> 'Proxy': + def proxy(self, node: Node) -> "Proxy": return Proxy(node, self) @compatibility(is_backward_compatible=True) - def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: Dict[str, Any], - name: Optional[str] = None, type_expr : Optional[Any] = None, - proxy_factory_fn: Callable[[Node], 'Proxy'] = None): - ''' + def create_proxy( + self, + kind: str, + target: Target, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Callable[[Node], "Proxy"] = None, + ): + """ Create a Node from the given arguments, then return the Node wrapped in a Proxy object. @@ -203,7 +225,7 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: represents the parameter of a function. If we need to encode a default parameter, we use the ``args`` tuple. ``args`` is otherwise empty for ``placeholder`` Nodes. - ''' + """ args_ = self.create_arg(args) kwargs_ = self.create_arg(kwargs) @@ -218,8 +240,7 @@ def create_proxy(self, kind: str, target: Target, args: Tuple[Any, ...], kwargs: proxy = proxy_factory_fn(node) if self.record_stack_traces and not proxy.node.stack_trace: - proxy.node.stack_trace = ''.join(CapturedTraceback.extract().format()) - + proxy.node.stack_trace = "".join(CapturedTraceback.extract().format()) return proxy @@ -233,20 +254,23 @@ def _find_user_frame(self): # the user code during tracing. frame = inspect.currentframe() - pt_files = ['torch/fx/proxy.py', - 'torch/fx/_symbolic_trace.py', - 'torch/fx/experimental/proxy_tensor.py', - 'torch/_ops.py', - 'torch/_tensor.py', - 'torch/utils/_python_dispatch.py', - 'torch/_prims_common/wrappers.py', - 'torch/_refs/__init__.py', - 'torch/_refs/nn/functional/__init__.py', - 'torch/utils/_stats.py', - ] + pt_files = [ + "torch/fx/proxy.py", + "torch/fx/_symbolic_trace.py", + "torch/fx/experimental/proxy_tensor.py", + "torch/_ops.py", + "torch/_tensor.py", + "torch/utils/_python_dispatch.py", + "torch/_prims_common/wrappers.py", + "torch/_refs/__init__.py", + "torch/_refs/nn/functional/__init__.py", + "torch/utils/_stats.py", + ] while frame: frame = frame.f_back - if frame and all(not frame.f_code.co_filename.endswith(file) for file in pt_files): + if frame and all( + not frame.f_code.co_filename.endswith(file) for file in pt_files + ): break if not frame: @@ -264,11 +288,11 @@ def create_arg(self, a: Any) -> Argument: """ if isinstance(a, Proxy): return a.node # most common arg type goes first - elif hasattr(a, '__fx_create_arg__'): + elif hasattr(a, "__fx_create_arg__"): return a.__fx_create_arg__(self) # aggregates elif isinstance(a, tuple): - if hasattr(a, '_fields'): + if hasattr(a, "_fields"): # NamedTuple constructors don't seem to like getting a generator # expression as an argument to their constructor, so build this # intermediate tuple and unpack it into the NamedTuple constructor @@ -278,10 +302,13 @@ def create_arg(self, a: Any) -> Argument: elif isinstance(a, list): return [self.create_arg(elem) for elem in a] elif isinstance(a, dict): + def no_node(arg): if isinstance(arg, Node): - raise RuntimeError("Keys for dictionaries used as an argument cannot contain a " - f"Node. Got key: {k}") + raise RuntimeError( + "Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {k}" + ) r = {} for k, v in a.items(): @@ -294,16 +321,27 @@ def no_node(arg): r[k] = self.create_arg(v) return r elif isinstance(a, slice): - return slice(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + return slice( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) elif isinstance(a, range): - return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step)) + return range( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): return a elif is_dataclass(a): - kwargs = {field.name: self.create_arg(getattr(a, field.name)) for field in fields(a)} + kwargs = { + field.name: self.create_arg(getattr(a, field.name)) + for field in fields(a) + } return self.create_node("call_function", a.__class__, (), kwargs) elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: @@ -312,37 +350,41 @@ def no_node(arg): raise NotImplementedError(f"argument of type: {type(a)}") @compatibility(is_backward_compatible=True) - def to_bool(self, obj: 'Proxy') -> bool: + def to_bool(self, obj: "Proxy") -> bool: """Called when a proxy object is being converted to a boolean, such as when used in control flow. Normally we don't know what to do because we don't know the value of the proxy, but a custom tracer can attach more information to the graph node using create_node and can choose to return a value. """ - raise TraceError('symbolically traced variables cannot be used as inputs to control flow') + raise TraceError( + "symbolically traced variables cannot be used as inputs to control flow" + ) @compatibility(is_backward_compatible=True) - def iter(self, obj: 'Proxy') -> Iterator: + def iter(self, obj: "Proxy") -> Iterator: """Called when a proxy object is being iterated over, such as when used in control flow. Normally we don't know what to do because we don't know the value of the proxy, but a custom tracer can attach more information to the graph node using create_node and can choose to return an iterator. """ - raise TraceError('Proxy object cannot be iterated. This can be ' - 'attempted when the Proxy is used in a loop or' - ' as a *args or **kwargs function argument. ' - 'See the torch.fx docs on pytorch.org for a ' - 'more detailed explanation of what types of ' - 'control flow can be traced, and check out the' - ' Proxy docstring for help troubleshooting ' - 'Proxy iteration errors') + raise TraceError( + "Proxy object cannot be iterated. This can be " + "attempted when the Proxy is used in a loop or" + " as a *args or **kwargs function argument. " + "See the torch.fx docs on pytorch.org for a " + "more detailed explanation of what types of " + "control flow can be traced, and check out the" + " Proxy docstring for help troubleshooting " + "Proxy iteration errors" + ) @compatibility(is_backward_compatible=True) - def keys(self, obj: 'Proxy') -> Any: + def keys(self, obj: "Proxy") -> Any: """Called when a proxy object is has the keys() method called. This is what happens when ** is called on a proxy. This should return an iterator it ** is suppose to work in your custom tracer. """ - return Attribute(obj, 'keys')() + return Attribute(obj, "keys")() # used in Proxy object when just appending to the graph while not tracing. @@ -355,14 +397,17 @@ def __init__(self, graph: Graph): self.module_stack = collections.OrderedDict() self.node_name_to_scope = {} + @compatibility(is_backward_compatible=False) def assert_fn(x): assert x + @compatibility(is_backward_compatible=True) class TraceError(ValueError): pass + @compatibility(is_backward_compatible=True) class Proxy: """ @@ -394,7 +439,7 @@ class Proxy: """ @compatibility(is_backward_compatible=True) - def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): + def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None): if tracer is None: # This allows you to create a Proxy object around a raw Node tracer = GraphAppendingTracer(node.graph) @@ -402,9 +447,9 @@ def __init__(self, node: Node, tracer: 'Optional[TracerBase]' = None): self.node = node def __repr__(self) -> str: - return f'Proxy({self.node.name})' + return f"Proxy({self.node.name})" - def __getattr__(self, k) -> 'Attribute': + def __getattr__(self, k) -> "Attribute": # note: not added to the graph yet, if this is a method call # we peephole optimize to the method invocation return Attribute(self, k) @@ -417,6 +462,7 @@ def __deepcopy__(self, memo) -> Dict: # will go to __getattr__(self, "__deepcopy__") and return a # Attribute(__deepcopy__), and may go into an infinite loop in some cases. import copy + new_dict = {} for k, v in self.__dict__.items(): try: @@ -424,7 +470,10 @@ def __deepcopy__(self, memo) -> Dict: except Exception: log.warning( "Shallow copy %s of Proxy because it cannot be deepcopied. " - "Proxy is created for node %s", k, self.node.name) + "Proxy is created for node %s", + k, + self.node.name, + ) new_obj = copy.copy(v) new_dict[k] = new_obj assert "node" in new_dict @@ -438,10 +487,12 @@ def __setstate__(self, d): # This is called when being unpickled/loaded. self.__dict__ = d - def __call__(self, *args, **kwargs) -> 'Proxy': - return self.tracer.create_proxy('call_method', '__call__', (self,) + args, kwargs) + def __call__(self, *args, **kwargs) -> "Proxy": + return self.tracer.create_proxy( + "call_method", "__call__", (self,) + args, kwargs + ) - def __iter__(self) -> Iterator['Proxy']: + def __iter__(self) -> Iterator["Proxy"]: frame = inspect.currentframe() assert frame is not None calling_frame = frame.f_back @@ -449,17 +500,20 @@ def __iter__(self) -> Iterator['Proxy']: inst_list = list(dis.get_instructions(calling_frame.f_code)) if sys.version_info >= (3, 11): from bisect import bisect_left - inst_idx = bisect_left(inst_list, calling_frame.f_lasti, key=lambda x: x.offset) + + inst_idx = bisect_left( + inst_list, calling_frame.f_lasti, key=lambda x: x.offset + ) else: inst_idx = calling_frame.f_lasti // 2 inst = inst_list[inst_idx] - if inst.opname == 'UNPACK_SEQUENCE': + if inst.opname == "UNPACK_SEQUENCE": return (self[i] for i in range(inst.argval)) # type: ignore[index] return self.tracer.iter(self) def __abs__(self): - return self.tracer.create_proxy('call_function', operator.abs, (self,), {}) + return self.tracer.create_proxy("call_function", operator.abs, (self,), {}) def __bool__(self) -> bool: if self.tracer.trace_asserts: @@ -472,19 +526,23 @@ def __bool__(self) -> bool: insts = list(dis.get_instructions(calling_frame.f_code)) if sys.version_info >= (3, 11): from bisect import bisect_left + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) else: cur = calling_frame.f_lasti // 2 inst = insts[cur] - if inst.opname == 'POP_JUMP_IF_TRUE': + if inst.opname == "POP_JUMP_IF_TRUE": first = insts[cur + 1] assert inst.arg is not None last = insts[inst.arg // 2 - 1] - starts_with_assert = (first.opname == 'LOAD_GLOBAL' and first.argval == 'AssertionError' - or first.opname == 'LOAD_ASSERTION_ERROR') - if starts_with_assert and last.opname == 'RAISE_VARARGS': - self.tracer.create_proxy('call_function', assert_fn, (self,), {}) + starts_with_assert = ( + first.opname == "LOAD_GLOBAL" + and first.argval == "AssertionError" + or first.opname == "LOAD_ASSERTION_ERROR" + ) + if starts_with_assert and last.opname == "RAISE_VARARGS": + self.tracer.create_proxy("call_function", assert_fn, (self,), {}) return True return self.tracer.to_bool(self) @@ -494,39 +552,90 @@ def keys(self): return self.tracer.keys(self) def __len__(self): - raise RuntimeError("'len' is not supported in symbolic tracing by default. If you want " - "this call to be recorded, please call torch.fx.wrap('len') at " - "module scope") + raise RuntimeError( + "'len' is not supported in symbolic tracing by default. If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope" + ) @classmethod def __torch_function__(cls, orig_method, types, args=None, kwargs=None): args = args if args else () kwargs = kwargs if kwargs else {} - tracers : Dict[Any, None] = {} + tracers: Dict[Any, None] = {} def find_tracer(a): if isinstance(a, cls): tracers[a.tracer] = None + torch.fx.node.map_aggregate(args, find_tracer) torch.fx.node.map_aggregate(kwargs, find_tracer) if len(tracers) > 1: - raise RuntimeError(f'Found multiple different tracers {list(tracers.keys())} while ' - f'trying to trace operations {orig_method}') + raise RuntimeError( + f"Found multiple different tracers {list(tracers.keys())} while " + f"trying to trace operations {orig_method}" + ) tracer = next(iter(tracers.keys())) if isinstance(orig_method, torch._C.ScriptMethod): args = (orig_method.owner,) + args - return tracer.create_proxy('call_method', orig_method.name, args, kwargs) + return tracer.create_proxy("call_method", orig_method.name, args, kwargs) if torch.overrides.is_tensor_method_or_property(orig_method): - return tracer.create_proxy('call_method', orig_method.__name__, args, kwargs) + return tracer.create_proxy( + "call_method", orig_method.__name__, args, kwargs + ) else: if isinstance(orig_method, torch._ops.HigherOrderOperator): # TODO: Define how to symbolically trace HigherOrderOperators raise RuntimeError("Unable to symbolically trace HigherOrderOperators") - return tracer.create_proxy('call_function', orig_method, args, kwargs, - name=tracer.graph._target_to_str(orig_method.__name__)) + return tracer.create_proxy( + "call_function", + orig_method, + args, + kwargs, + name=tracer.graph._target_to_str(orig_method.__name__), + ) + + +@compatibility(is_backward_compatible=False) +class MetaProxy(Proxy): + """ + A Proxy subclass that propagates metadata (meta['val']) during graph tracing. + """ + + def __init__( + self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None + ): + super().__init__(node, tracer) + self.fake_mode = fake_mode + + def __repr__(self) -> str: + return f"MetaProxy({self.node.name})" + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + meta_proxy = None + for arg in args: + if isinstance(arg, MetaProxy): + meta_proxy = arg + break + + assert ( + meta_proxy is not None + ), "No MetaProxy found in arguments, but one is expected." + + proxy = super().__torch_function__(orig_method, types, args, kwargs) + with meta_proxy.fake_mode: + proxy.node.meta["val"] = orig_method( + *[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args], + **kwargs, + ) + return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode) @compatibility(is_backward_compatible=True) @@ -543,11 +652,15 @@ def node(self): # the node for attributes is added lazily, since most will just be method calls # which do not rely on the getitem call if self._node is None: - self._node = self.tracer.create_proxy('call_function', getattr, (self.root, self.attr), {}).node + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node return self._node def __call__(self, *args, **kwargs): - return self.tracer.create_proxy('call_method', self.attr, (self.root,) + args, kwargs) + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) @compatibility(is_backward_compatible=False) @@ -557,6 +670,7 @@ class ParameterProxy(Proxy): attribute accesses pass through to the underlying module parameter object, so that conditional tests on these attributes will not throw exception during tracing """ + def __init__(self, tracer: TracerBase, node: Node, name, param): super().__init__(node, tracer) assert isinstance(param, torch.nn.Parameter) @@ -564,7 +678,7 @@ def __init__(self, tracer: TracerBase, node: Node, name, param): self.name = name def __repr__(self) -> str: - return f'ParameterProxy({self.name})' + return f"ParameterProxy({self.name})" @property def shape(self): @@ -588,25 +702,31 @@ def nelement(self): for method in magic_methods: + def _scope(method): def impl(*args, **kwargs): tracer = args[0].tracer target = getattr(operator, method) - return tracer.create_proxy('call_function', target, args, kwargs) + return tracer.create_proxy("call_function", target, args, kwargs) + impl.__name__ = method as_magic = f'__{method.strip("_")}__' setattr(Proxy, as_magic, impl) + _scope(method) + def _define_reflectable(orig_method_name): method_name = f'__r{orig_method_name.strip("_")}__' def impl(self, rhs): target = getattr(operator, orig_method_name) - return self.tracer.create_proxy('call_function', target, (rhs, self), {}) + return self.tracer.create_proxy("call_function", target, (rhs, self), {}) + impl.__name__ = method_name impl.__qualname__ = method_name setattr(Proxy, method_name, impl) + for orig_method_name in reflectable_magic_methods: _define_reflectable(orig_method_name) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index c0d88821d7faf..b823fda3123fa 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -1,18 +1,36 @@ -from .graph_module import GraphModule -from .graph import Graph -from .node import Node -from ._symbolic_trace import symbolic_trace -from ._compatibility import compatibility - import copy from dataclasses import dataclass -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + List, + NamedTuple, + Optional, + Set, + TYPE_CHECKING, + Union, +) + import torch +from ._compatibility import compatibility +from ._symbolic_trace import symbolic_trace +from .graph import Graph +from .graph_module import GraphModule +from .node import Node + + if TYPE_CHECKING: from .passes.utils.matcher_with_name_node_map_utils import InternalMatch -__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"] +__all__ = [ + "Match", + "replace_pattern", + "replace_pattern_with_filters", + "ReplacedPatterns", +] + @compatibility(is_backward_compatible=True) class Match(NamedTuple): @@ -21,6 +39,7 @@ class Match(NamedTuple): # Maps nodes in the pattern subgraph to nodes in the larger graph nodes_map: Dict[Node, Node] + @compatibility(is_backward_compatible=False) @dataclass class ReplacedPatterns: @@ -31,6 +50,7 @@ class ReplacedPatterns: # List of nodes that were added into the graph replacements: List[Node] + def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: gm.delete_all_unused_submodules() @@ -48,7 +68,6 @@ def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: for node in gm.graph.nodes: if node.op == "call_module" or node.op == "get_attr": - gm_attr = try_get_attr(gm, node.target) replacement_attr = try_get_attr(replacement, node.target) @@ -70,11 +89,14 @@ def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: # CASE 3: The target doesn't exist as an attribute in `gm` # or `replacement` else: - raise RuntimeError('Attempted to create a "', node.op, - '" node during subgraph rewriting ' - f"with target {node.target}, but " - "the referenced attribute does not " - "exist in the replacement GraphModule") + raise RuntimeError( + 'Attempted to create a "', + node.op, + '" node during subgraph rewriting ' + f"with target {node.target}, but " + "the referenced attribute does not " + "exist in the replacement GraphModule", + ) gm.graph.lint() @@ -83,7 +105,7 @@ def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: def replace_pattern( gm: GraphModule, pattern: Union[Callable, GraphModule], - replacement: Union[Callable, GraphModule] + replacement: Union[Callable, GraphModule], ) -> List[Match]: """ Matches all possible non-overlapping sets of operators and their @@ -116,6 +138,7 @@ class Match(NamedTuple): import torch from torch.fx import symbolic_trace, subgraph_rewriter + class M(torch.nn.Module): def __init__(self) -> None: super().__init__() @@ -125,12 +148,15 @@ def forward(self, x, w1, w2): m2 = torch.cat([w1, w2]).sum() return x + torch.max(m1) + torch.max(m2) + def pattern(w1, w2): return torch.cat([w1, w2]).sum() + def replacement(w1, w2): return torch.stack([w1, w2]) + traced_module = symbolic_trace(M()) subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) @@ -199,7 +225,9 @@ def forward(self, x, w1, w2): return add_2 """ match_and_replacements = _replace_pattern(gm, pattern, replacement) - return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements] + return [ + Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements + ] # Experimental API, not backward compatible @@ -208,10 +236,14 @@ def replace_pattern_with_filters( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], replacement: Union[Callable, Graph, GraphModule, None] = None, - match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + match_filters: Optional[ + List[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility - replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, ) -> List[ReplacedPatterns]: """ See replace_pattern for documentation. This function is an overload with an additional match_filter argument. @@ -226,20 +258,25 @@ def replace_pattern_with_filters( replacement graph based on the match. """ - return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals, replacement_callback) + return _replace_pattern( + gm, pattern, replacement, match_filters, ignore_literals, replacement_callback + ) def _replace_pattern( gm: GraphModule, pattern: Union[Callable, Graph, GraphModule], replacement: Union[Callable, Graph, GraphModule, None] = None, - match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None, + match_filters: Optional[ + List[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, ignore_literals: bool = False, # Placed at the end to avoid breaking backward compatibility - replacement_callback: Optional[Callable[["InternalMatch", Graph, Graph], Graph]] = None, + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, ) -> List[ReplacedPatterns]: - - from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch + from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher if match_filters is None: match_filters = [] @@ -254,15 +291,23 @@ def _replace_pattern( else: pattern_graph = symbolic_trace(pattern).graph - matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False, - remove_overlapping_matches=True, ignore_literals=ignore_literals) + matcher = SubgraphMatcher( + pattern_graph, + match_output=False, + match_placeholder=False, + remove_overlapping_matches=True, + ignore_literals=ignore_literals, + ) _matches: List[InternalMatch] = matcher.match(original_graph) # Filter out matches that don't match the filter _matches = [ - m for m in _matches - if all(match_filter(m, original_graph, pattern_graph) - for match_filter in match_filters) + m + for m in _matches + if all( + match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters + ) ] if isinstance(replacement, GraphModule): @@ -272,20 +317,28 @@ def _replace_pattern( elif callable(replacement): common_replacement_graph = symbolic_trace(replacement).graph else: - assert replacement_callback is not None, "Must provide either a replacement GraphModule or a replacement callback" + assert ( + replacement_callback is not None + ), "Must provide either a replacement GraphModule or a replacement callback" common_replacement_graph = None # As we progressively replace nodes, we'll need to keep track of how the match results should change match_changed_node: Dict[Node, Node] = {} match_and_replacements = [] - for i, match in enumerate(_matches): + for match in _matches: if replacement_callback is not None: - replacement_graph = replacement_callback(match, original_graph, pattern_graph) + replacement_graph = replacement_callback( + match, original_graph, pattern_graph + ) else: - assert common_replacement_graph is not None, "Must provide either a replacement GraphModule or a replacement callback" + assert ( + common_replacement_graph is not None + ), "Must provide either a replacement GraphModule or a replacement callback" replacement_graph = common_replacement_graph - replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"] + replacement_placeholders = [ + n for n in replacement_graph.nodes if n.op == "placeholder" + ] # Build connecting between replacement graph's input and original graph input producer node @@ -300,7 +353,9 @@ def _replace_pattern( # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn gn_ind = match.placeholder_nodes.index(gn) match.placeholder_nodes[gn_ind] = match_changed_node[gn] - map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)] + map_key = list(match.nodes_map.keys())[ + list(match.nodes_map.values()).index(gn) + ] match.nodes_map[map_key] = match_changed_node[gn] else: val_map[rn] = gn @@ -322,13 +377,17 @@ def _replace_pattern( break with original_graph.inserting_before(first_user_node): # type: ignore[possibly-undefined] - copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map) + copied_returning_nodes = original_graph.graph_copy( + replacement_graph, val_map + ) if isinstance(copied_returning_nodes, Node): - copied_returning_nodes = (copied_returning_nodes, ) + copied_returning_nodes = (copied_returning_nodes,) # Get a list of nodes that have been replaced into the graph - replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes] + replacement_nodes: List[Node] = [ + v for v in val_map.values() if v not in match.placeholder_nodes + ] # Hook the output Node of the replacement subgraph in to the # original Graph at the correct location @@ -346,7 +405,7 @@ def _replace_pattern( ReplacedPatterns( anchor=match.anchors[0], nodes_map=match.nodes_map, - replacements=replacement_nodes + replacements=replacement_nodes, ) ) diff --git a/torch/fx/tensor_type.py b/torch/fx/tensor_type.py index 83b5a9f8faf65..4f375e461ef28 100644 --- a/torch/fx/tensor_type.py +++ b/torch/fx/tensor_type.py @@ -19,7 +19,7 @@ def __init__(self, dim): self.__args__ = dim def __repr__(self): - return f'TensorType[{self.__args__}]' + return f"TensorType[{self.__args__}]" def __eq__(self, other): if isinstance(other, self.__class__): @@ -38,8 +38,9 @@ class _DynType: """ _DynType defines a type which stands for the absence of type information. """ + def __init__(self) -> None: - self.__name__ = '_DynType' + self.__name__ = "_DynType" def __eq__(self, other): return isinstance(other, self.__class__) @@ -53,6 +54,7 @@ def __repr__(self): Dyn = _DynType() + @compatibility(is_backward_compatible=False) def is_consistent(t1, t2): """ @@ -73,8 +75,10 @@ def is_consistent(t1, t2): return True if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all(is_consistent(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + return len(t1.__args__) == len(t2.__args__) and all( + is_consistent(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) else: return False @@ -98,8 +102,10 @@ def is_more_precise(t1, t2): return True if isinstance(t1, TensorType) and isinstance(t2, TensorType): - return len(t1.__args__) == len(t2.__args__) and \ - all(is_more_precise(elem1, elem2) for elem1, elem2 in zip(t1.__args__, t2.__args__)) + return len(t1.__args__) == len(t2.__args__) and all( + is_more_precise(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) else: return False diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index 4e72a8011f63a..84c94c75cf66f 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -1,12 +1,21 @@ # mypy: allow-untyped-defs import traceback from contextlib import contextmanager -from typing import List, Any, Dict +from typing import Any, Dict, List + from ._compatibility import compatibility -__all__ = ['preserve_node_meta', 'has_preserved_node_meta', - 'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr', - 'format_stack', 'set_current_meta', 'get_current_meta'] + +__all__ = [ + "preserve_node_meta", + "has_preserved_node_meta", + "set_stack_trace", + "set_grad_fn_seq_nr", + "reset_grad_fn_seq_nr", + "format_stack", + "set_current_meta", + "get_current_meta", +] current_meta: Dict[str, Any] = {} should_preserve_node_meta = False @@ -30,7 +39,7 @@ def preserve_node_meta(): @compatibility(is_backward_compatible=False) -def set_stack_trace(stack : List[str]): +def set_stack_trace(stack: List[str]): global current_meta if should_preserve_node_meta and stack: @@ -43,7 +52,9 @@ def set_grad_fn_seq_nr(seq_nr): if should_preserve_node_meta: # The seq_nr is captured by eager mode in the grad_fn during forward - current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr] + current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [ + seq_nr + ] current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1 @@ -90,7 +101,9 @@ def set_current_meta(node): if "from_node" not in current_meta: current_meta["from_node"] = [(node.name, node.target)] elif current_meta["from_node"][-1][0] != node.name: - current_meta["from_node"] = current_meta["from_node"] + [(node.name, node.target)] + current_meta["from_node"] = current_meta["from_node"] + [ + (node.name, node.target) + ] yield finally: diff --git a/torch/hub.py b/torch/hub.py index c037c6e9dc139..70d867a149ffa 100644 --- a/torch/hub.py +++ b/torch/hub.py @@ -720,7 +720,7 @@ def download_url_to_file( # We deliberately do not use NamedTemporaryFile to avoid restrictive # file permissions being applied to the downloaded file. dst = os.path.expanduser(dst) - for seq in range(tempfile.TMP_MAX): + for _ in range(tempfile.TMP_MAX): tmp_dst = dst + "." + uuid.uuid4().hex + ".partial" try: f = open(tmp_dst, "w+b") diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index d489b51d3cd5d..1216cee929d93 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -570,10 +570,6 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): method_stubs = stubs_fn(nn_module) property_stubs = get_property_stubs(nn_module) hook_stubs, pre_hook_stubs = get_hook_stubs(nn_module) - - user_annotated_ignored_attributes = getattr( - nn_module, "__jit_ignored_attributes__", [] - ) ignored_properties = jit_ignored_properties(nn_module) def init_fn(script_module): @@ -838,9 +834,6 @@ def infer_methods_to_compile(nn_module): (TODO add a link when the rules are published). """ check_module_initialized(nn_module) - user_annotated_ignored_attributes = getattr( - nn_module, "__jit_ignored_attributes__", [] - ) ignored_properties = jit_ignored_properties(nn_module) methods: List[str] = [] diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 1f90e5a6d84d2..1d8dccb049dd2 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1600,7 +1600,7 @@ def _recursive_compile_class(obj, loc): _qual_name = _qualified_name(obj) # We're starting a new compilation, so update the error call stack in # case it fails - error_stack = torch._C.CallStack(_qual_name, loc) + error_stack = torch._C.CallStack(_qual_name, loc) # noqa: F841 rcb = _jit_internal.createResolutionCallbackForClassMethods(obj) return _compile_and_register_class(obj, rcb, _qual_name) diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index 56c4a8cb36e3a..c40e27d73e5dc 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -255,7 +255,6 @@ def pool2d_shape_check( outputWidth: int, ): ndim = len(input) - nOutputPlane = nInputPlane assert kW > 0 and kH > 0 assert dW > 0 and dH > 0 @@ -608,12 +607,10 @@ def matmul(tensor1: List[int], tensor2: List[int]): # We are multiplying b1 x n x m1 by x2 x m2 x p (where b1 can be a list); # we track m1 vs m2 separately even though they must match for nicer error messages n = tensor1[-2] if dim_tensor1 > 1 else 1 - m1 = tensor1[-1] batch_tensor1: List[int] = [] # TODO: handling of slice for i in range(dim_tensor1 - 2): batch_tensor1.append(tensor1[i]) - m2 = tensor2[-1] if dim_tensor2 > 1 else 1 p = tensor2[-1] batch_tensor2: List[int] = [] # TODO: handling of slice diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index 1dbcdb6a3ca2a..ef5292fe93ecb 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -55,7 +55,6 @@ def _get_interpreter_name_for_var(var): i += 1 f_locals = frame.f_locals - f_globals = frame.f_globals for k, v in f_locals.items(): if isinstance(v, torch.Tensor) and var is v: @@ -136,7 +135,7 @@ def wrapper(*args): else: return tuple(out_vars) - graph, out = torch._C._create_graph_by_tracing( + graph, _out = torch._C._create_graph_by_tracing( wrapper, in_vars + module_state, _create_interpreter_name_lookup_fn(), @@ -241,7 +240,6 @@ def verify(model, args, loss_fn=torch.sum, devices=None): if not isinstance(args, tuple): args = (args,) - saved_args = _clone_inputs(args) if is_module: saved_state = copy.deepcopy(model.state_dict()) diff --git a/torch/jit/unsupported_tensor_ops.py b/torch/jit/unsupported_tensor_ops.py index 46b0a000bd618..903c8aafba26b 100644 --- a/torch/jit/unsupported_tensor_ops.py +++ b/torch/jit/unsupported_tensor_ops.py @@ -40,7 +40,7 @@ def func(x): scope: Dict[str, Any] = {} execWrapper(funcs_str, globals(), scope) try: - cu = torch.jit.CompilationUnit(funcs_str) + torch.jit.CompilationUnit(funcs_str) except Exception as e: if "nonexistent attribute" not in repr(e): continue diff --git a/torch/library.h b/torch/library.h index bfda4955eadde..2761573e2cccf 100644 --- a/torch/library.h +++ b/torch/library.h @@ -206,6 +206,9 @@ class TORCH_API CppFunction final { ~CppFunction(); + CppFunction(const CppFunction&) = delete; + CppFunction& operator=(const CppFunction&) = delete; + CppFunction(CppFunction&&) noexcept = default; CppFunction& operator=(CppFunction&&) = default; @@ -563,6 +566,7 @@ class TORCH_API Library final { Library& operator=(const Library&) = delete; Library(Library&&) = default; Library& operator=(Library&&) = default; + ~Library() = default; // Some notes about the API design here. We had the following constraints: // diff --git a/torch/library.py b/torch/library.py index 5457f6e078148..d4224e62e456a 100644 --- a/torch/library.py +++ b/torch/library.py @@ -424,6 +424,7 @@ def _destroy(self): if not hasattr(namespace, name): continue delattr(namespace, name) + namespace._dir.remove(name) def _del_library( diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 4962d0430992e..db808c0131330 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -1784,7 +1784,6 @@ def normalize( ) -> Tensor: if dtype is None: dtype = input.dtype - dim_ = _canonical_dim(dim, input.ndim)[0] # TODO: eliminate mask_input as unnecessary when using masked divide. mask_input = _combine_input_and_mask(sum, input, mask) if mask_input.layout == torch.strided: diff --git a/torch/masked/maskedtensor/_ops_refs.py b/torch/masked/maskedtensor/_ops_refs.py index 7344a6e801aae..d0cb64fa3c7bd 100644 --- a/torch/masked/maskedtensor/_ops_refs.py +++ b/torch/masked/maskedtensor/_ops_refs.py @@ -375,7 +375,7 @@ def ones_like(func, *args, **kwargs): @register_dispatch_func([torch.ops.aten._softmax_backward_data]) def _softmax_backward_data(func, *args, **kwargs): _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4) - grad, output, dim, input_dtype = args + grad, output, dim, _input_dtype = args if is_masked_tensor(grad) and is_masked_tensor(output): if not _masks_match(grad, output): raise ValueError( diff --git a/torch/masked/maskedtensor/binary.py b/torch/masked/maskedtensor/binary.py index 7b64cfa0fbd98..a0c024408ba4d 100644 --- a/torch/masked/maskedtensor/binary.py +++ b/torch/masked/maskedtensor/binary.py @@ -96,8 +96,8 @@ def _binary_helper(fn, args, kwargs, inplace): "Input masks must match. If you need support for this, please open an issue on Github." ) - data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data()) - mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask()) + data_args, _data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data()) + mask_args, _mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask()) args0_layout = data_args[0].layout same_layout = ( diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 366cf45eb2d50..d1cc620325933 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -3,7 +3,7 @@ import warnings from typing import Any -from typing_extensions import TypeGuard +from typing_extensions import TypeIs import torch from torch.overrides import get_default_nowrap_functions @@ -15,7 +15,7 @@ ] -def is_masked_tensor(obj: Any, /) -> TypeGuard["MaskedTensor"]: +def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]: r"""Returns True if the input is a MaskedTensor, else False Args: @@ -334,7 +334,7 @@ def get_data(self): class GetData(torch.autograd.Function): @staticmethod def forward(ctx, self): - return self._masked_data + return self._masked_data.detach() @staticmethod def backward(ctx, grad_output): diff --git a/torch/masked/maskedtensor/unary.py b/torch/masked/maskedtensor/unary.py index 790d86ef92e4c..e04ee6e810a74 100644 --- a/torch/masked/maskedtensor/unary.py +++ b/torch/masked/maskedtensor/unary.py @@ -120,8 +120,12 @@ def _unary_helper(fn, args, kwargs, inplace): "MaskedTensor unary ops do not support additional Tensor arguments" ) - mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_mask) - data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x._masked_data) + mask_args, _mask_kwargs = _map_mt_args_kwargs( + args, kwargs, lambda x: x._masked_mask + ) + data_args, _data_kwargs = _map_mt_args_kwargs( + args, kwargs, lambda x: x._masked_data + ) if args[0].layout == torch.sparse_coo: data_args[0] = data_args[0].coalesce() diff --git a/torch/multiprocessing/pool.py b/torch/multiprocessing/pool.py index 6915203566469..32a47efac0d6e 100644 --- a/torch/multiprocessing/pool.py +++ b/torch/multiprocessing/pool.py @@ -33,7 +33,7 @@ def _repopulate_pool(self): Bring the number of pool processes up to the specified number, for use after reaping workers which have exited. """ - for i in range(self._processes - len(self._pool)): + for _ in range(self._processes - len(self._pool)): # changed worker -> clean_worker args = ( self._inqueue, diff --git a/torch/nested/_internal/nested_tensor.py b/torch/nested/_internal/nested_tensor.py index d766cbd7d8bda..42b96f7bcd275 100644 --- a/torch/nested/_internal/nested_tensor.py +++ b/torch/nested/_internal/nested_tensor.py @@ -214,6 +214,18 @@ def _max_seqlen(self): def _min_seqlen(self): return self._get_min_seqlen() + # Convenience accessors that return a min / max seqlen if one is present and do NOT + # compute / cache them if they're not. + @property + def _maybe_max_seqlen(self) -> Optional[int]: + mt = self._max_seqlen_tensor + return None if mt is None else _load_val_from_tensor(mt) + + @property + def _maybe_min_seqlen(self) -> Optional[int]: + mt = self._min_seqlen_tensor + return None if mt is None else _load_val_from_tensor(mt) + def __repr__(self): # type: ignore[override] # We should implement this in torch/_tensor_str.py instead grad_fn_str = ( diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index d88498800a99b..ceeac7eaa6ef2 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -207,6 +207,17 @@ def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]: # Handle pointwise fallbacks if torch.Tag.pointwise in func.tags: + from torch.fx.experimental.symbolic_shapes import is_nested_int + + # No pointwise ops legitimately accept nested int inputs. Without this check, + # they will be incorrectly interpreted as tensors. + # See https://github.com/pytorch/pytorch/issues/138496 + for arg in args: + if is_nested_int(arg): + raise RuntimeError( + f"NestedTensor {func.__name__}: invalid argument {arg}" + ) + # Assume there aren't additional tensors that aren't the "unary/binary" args num_tensor_args = sum(isinstance(x, torch.Tensor) for x in args) if num_tensor_args == 1: @@ -293,17 +304,11 @@ def jagged_binary_pointwise(func, *args, **kwargs): mismatch_error_msg.format(func.__name__, a.shape, b.shape) ) - from .nested_tensor import _load_val_from_tensor, nested_from_padded + from .nested_tensor import nested_from_padded # handle broadcasting via padded dense -> jagged conversion - min_seqlen = None - if nt._min_seqlen_tensor is not None: - min_seqlen = _load_val_from_tensor(nt._min_seqlen_tensor) - - max_seqlen = None - if nt._max_seqlen_tensor is not None: - max_seqlen = _load_val_from_tensor(nt._max_seqlen_tensor) - + min_seqlen = nt._maybe_min_seqlen + max_seqlen = nt._maybe_max_seqlen padded_max_S = max_seqlen total_L = nt._values.shape[nt._ragged_idx - 1] if padded_max_S is None: @@ -670,7 +675,7 @@ def _softmax_default(func, *args, **kwargs): new_kwargs["dim"], reduce_on_batch, reduce_on_ragged, - reduce_on_non_batch, + _reduce_on_non_batch, ) = _wrap_jagged_dims( inp.dim(), (new_kwargs["dim"],), @@ -975,7 +980,7 @@ def cat_default(func, *args, **kwargs): ) -@register_jagged_func(torch.ops.aten.matmul.default, "self: jt, other: any") +@register_jagged_func(torch.ops.aten.matmul.default, "self: jt_all, other: any") def matmul_default(func, *args, **kwargs): _, new_kwargs = normalize_function( # type: ignore[misc] func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True @@ -984,20 +989,95 @@ def matmul_default(func, *args, **kwargs): inp = new_kwargs.pop("input") other = new_kwargs.pop("other") - if inp.is_nested and not other.is_nested: - return NestedTensor( - func(inp._values, other, **new_kwargs), **extract_kwargs(inp) + def _unbind_impl(a, b): + return [ + func(a_comp, b_comp) for (a_comp, b_comp) in zip(a.unbind(), b.unbind()) + ] + + def _padded_impl(a, b): + assert a.is_nested and not b.is_nested + nt = a + + from .nested_tensor import nested_from_padded + + min_seqlen = nt._maybe_min_seqlen + max_seqlen = nt._maybe_max_seqlen + padded_max_S = max_seqlen + total_L = nt._values.shape[nt._ragged_idx - 1] + if padded_max_S is None: + # use upper bound on max seqlen if it's not present + padded_max_S = total_L + + padded_shape = ( + *nt.shape[: nt._ragged_idx], + padded_max_S, + *nt.shape[nt._ragged_idx + 1 :], ) + padded_nt = nt.to_padded_tensor(0.0, output_size=padded_shape) + return nested_from_padded( + func(padded_nt, b), + offsets=nt._offsets, + ragged_idx=nt._ragged_idx, + sum_S=total_L, + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, + ) + + # TODO: Back these with proper kernels (e.g. grouped GEMM) + # NJT x dense + if inp.is_nested and not other.is_nested: + # (B, j1, D) x (B, D, E) => (B, j1, E) + if inp.dim() >= 3 and inp.dim() == other.dim(): + # convert to padded for this + return _padded_impl(inp, other) + # Support broadcasting the dense: + # (B, j1, D) x (D, E) => (B, j1, E) + # (B, j1, D, E) x (E, F) => (B, j1, D, F) + # etc. + elif other.dim() == 2 and inp.dim() > other.dim(): + return NestedTensor( + func(inp._values, other, **new_kwargs), **extract_kwargs(inp) + ) + # NJT x NJT elif inp.is_nested and other.is_nested: - # BMM with equivalent ragged dims between the two inputs + # Support ragged batch dim: + # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F), etc. if inp.dim() > 3 and other.dim() > 3 and raggedness_matches(inp, other._size): return NestedTensor(func(inp._values, other._values), **extract_kwargs(inp)) + # Support reducing over ragged with dense output: + # (B, D, j1) x (B, j1, E) => (B, D, E) + elif ( + inp.dim() == 3 + and other.dim() == 3 + and inp._ragged_idx == 2 + and other._ragged_idx == 1 + and inp.size(inp._ragged_idx) == other.size(other._ragged_idx) + ): + # do unbind for this; can't use padded conversion due to j1 in last dim + return torch.stack(_unbind_impl(inp, other)) raise RuntimeError( f"matmul(): not supported between inputs of shapes {inp._size} and {other.shape}" ) +@register_jagged_func(torch.ops.aten.bmm.default, "self: jt_all, mat2: any") +def bmm_default(func, *args, **kwargs): + _, new_kwargs = normalize_function( # type: ignore[misc] + func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True + ) + + inp = new_kwargs.pop("input") + other = new_kwargs.pop("mat2") + + if inp.dim() != 3: + raise ValueError("bmm(): input must be 3D") + if other.dim() != 3: + raise ValueError("bmm(): mat2 must be 3D") + + return matmul_default(torch.ops.aten.matmul.default, inp, other) + + @register_jagged_func( torch.ops.aten.expand.default, "self: jt, size: any, implicit: any?" ) @@ -1519,7 +1599,7 @@ def mean_dim(func, *args, **kwargs): new_kwargs["dim"], reduce_on_batch, reduce_on_ragged, - reduce_on_non_batch, + _reduce_on_non_batch, ) = _wrap_jagged_dims( inp.dim(), new_kwargs["dim"], @@ -1871,6 +1951,17 @@ def _nested_select_backward_default(func, *args, **kwargs): return grad_input +@register_jagged_func(torch.ops.aten.record_stream.default, "self: jt_all, s: any") +def record_stream_default(func, *args, **kwargs): + inp = args[0] + stream = args[1] + # ensure all components live until stream computation completes + func(inp._values, stream) + func(inp._offsets, stream) + if inp._lengths is not None: + func(inp._lengths, stream) + + # Make the dummy available on the C++ side. @register_jagged_func(torch.ops.aten._nested_get_jagged_dummy.default, "self: any") def _nested_get_jagged_dummy(func, *args, **kwargs): diff --git a/torch/nested/_internal/sdpa.py b/torch/nested/_internal/sdpa.py index 578904af94697..f65535033890b 100644 --- a/torch/nested/_internal/sdpa.py +++ b/torch/nested/_internal/sdpa.py @@ -323,7 +323,6 @@ def _cumulative_and_max_seq_len_nnz(qkv: torch.Tensor) -> Tuple[torch.Tensor, in cumulative_seqlen = ( qkv.lengths().cumsum(0).to(dtype=torch.int32, device=qkv.device) ) - batch_size = qkv.size(0) max_seqlen = qkv._get_max_seqlen() # TODO: Explore performance impact when compiling n_elem = int(cumulative_seqlen[-1].item()) @@ -568,8 +567,8 @@ def _sdpa_nested_preprocessing(query, key, value): output_nt_info = { "offsets": q_t.offsets(), - "_max_seqlen": q_t._get_max_seqlen(), - "_min_seqlen": q_t._get_min_seqlen(), + "max_seqlen": q_t._get_max_seqlen(), + "min_seqlen": q_t._get_min_seqlen(), } return ( @@ -710,7 +709,12 @@ def jagged_scaled_dot_product_attention( is_causal=is_causal, scale=scale, ) - return nested_view_from_values_offsets(output, query.offsets()) + return nested_view_from_values_offsets( + output, + query.offsets(), + min_seqlen=query._maybe_min_seqlen, # type: ignore[attr-defined] + max_seqlen=query._maybe_max_seqlen, # type: ignore[attr-defined] + ) compute_logsumexp = query.requires_grad or key.requires_grad or value.requires_grad @@ -745,10 +749,10 @@ def jagged_scaled_dot_product_attention( ( attention, - logsumexp, - philox_seed, - philox_offset, - debug_attn_mask, + _logsumexp, + _philox_seed, + _philox_offset, + _debug_attn_mask, ) = torch.ops.aten._flash_attention_forward( query_buffer_reshaped, key_buffer_reshaped, @@ -766,9 +770,7 @@ def jagged_scaled_dot_product_attention( # Reshape output to convert nnz to batch_size and seq_len attention = nested_view_from_values_offsets( attention, # output from flash_attn is [total_q, num_heads, head_size_og] - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + **output_nt_info, ).transpose(1, 2) return _post_process_flash_output(attention, og_size) elif backend_choice == SDPBackend.EFFICIENT_ATTENTION: @@ -807,25 +809,18 @@ def jagged_scaled_dot_product_attention( # Reshape output to convert nnz to batch_size and seq_len return nested_view_from_values_offsets( attention.squeeze(0), - output_nt_info["offsets"], - min_seqlen=output_nt_info["_min_seqlen"], - max_seqlen=output_nt_info["_max_seqlen"], + **output_nt_info, ).transpose(1, 2) elif backend_choice == SDPBackend.MATH: # save the offsets and shape of the inputs, so we can reshape the final output # query @ key = attn: [B, D1, j0, D'] @ [B, D1, D' j1] = [B, D1, j0, j1] # attn @ value = out: [B, D1, j0, j1] @ [B, D1, j1, D2] = [B, D1, j0, D2] offsets = query.offsets() + min_seqlen = query._maybe_min_seqlen + max_seqlen = query._maybe_max_seqlen d1 = query._size[1] d2 = value._size[-1] - min_seqlen_tensor = query._metadata_cache.get( - "min_seqlen", None - ) # type: ignore[attr-defined] - max_seqlen_tensor = query._metadata_cache.get( - "max_seqlen", None - ) # type: ignore[attr-defined] - # convert jagged layout Nested Tensor to strided layout Nested Tensor # which support the math implementation of SDPA def get_strided_layout_nested_tensor(jagged_layout_nt): @@ -844,24 +839,14 @@ def get_strided_layout_nested_tensor(jagged_layout_nt): query, key, value, attn_mask, dropout_p, is_causal, scale=scale )[0] - from torch.nested._internal.nested_tensor import _load_val_from_tensor - # convert strided layout Nested Tensor back to jagged layout Nested Tensor attn_out = attn_out.transpose(1, 2).contiguous().values() attn_out = attn_out.view(-1, d1, d2) attn_out = nested_view_from_values_offsets( attn_out, offsets, - min_seqlen=( - None - if min_seqlen_tensor is None - else _load_val_from_tensor(min_seqlen_tensor) - ), - max_seqlen=( - None - if max_seqlen_tensor is None - else _load_val_from_tensor(max_seqlen_tensor) - ), + min_seqlen=min_seqlen, + max_seqlen=max_seqlen, ).transpose(1, 2) return attn_out diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index d25511a65e7ba..618a4aa00e5ea 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -7,16 +7,14 @@ import itertools import math import operator -from contextlib import nullcontext +import warnings from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor -from torch._higher_order_ops.flex_attention import ( - flex_attention as flex_attention_hop, - TransformGetItemToIndex, -) +from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex +from torch._higher_order_ops.flex_attention import flex_attention as flex_attention_hop from torch._higher_order_ops.utils import _set_compilation_env from torch.fx.experimental.proxy_tensor import ( _temp_remove_metadata_torch_function_mode, @@ -195,8 +193,6 @@ def _adjust_num_blocks_and_indices( new_num_rows: int, new_num_cols: int, ): - num_rows = indices.shape[-2] - num_columns = indices.shape[-1] indices = indices[:, :, :new_num_rows, :new_num_cols] num_blocks = num_blocks[:, :, :new_num_rows] num_blocks = torch.where(num_blocks < new_num_cols, num_blocks, new_num_cols) @@ -658,6 +654,19 @@ def _convert_mask_to_block_mask( ) -> Tuple[Tensor, Optional[Tensor]]: assert mask.dtype == torch.bool mask = _broadcast_to_dim(mask, 4) + + def padding_needed_for_multiple(x, multiple): + return _round_up_to_multiple(x, multiple) - x + + mask = torch.nn.functional.pad( + mask, + ( + 0, + padding_needed_for_multiple(mask.shape[-1], KV_BLOCK_SIZE), + 0, + padding_needed_for_multiple(mask.shape[-2], Q_BLOCK_SIZE), + ), + ) B, H, Q, KV = mask.shape assert Q % Q_BLOCK_SIZE == 0 assert KV % KV_BLOCK_SIZE == 0 @@ -756,7 +765,6 @@ def create_mask( Q_LEN: int, KV_LEN: int, device: str = "cuda", - _compile: bool = False, ) -> Tensor: r"""This function creates a mask tensor from a mod_fn function. @@ -779,15 +787,9 @@ def create_mask( h = torch.arange(0, H, device=device) m = torch.arange(0, Q_LEN, device=device) n = torch.arange(0, KV_LEN, device=device) - # TODO: fix this - # Lack instantiation support for __torch_function__ mode support under compile - if _compile: - ctx = nullcontext() - else: - ctx = TransformGetItemToIndex() # type: ignore[assignment] mod_type = _get_mod_type(mod_fn) - with ctx: + with TransformGetItemToIndex(): if mod_type == _ModificationType.SCORE_MOD: score_mod = mod_fn score_mod = _vmap_for_bhqkv(score_mod, prefix=(0,)) # first input is score @@ -803,30 +805,6 @@ def create_mask( raise AssertionError -def _create_block_mask_inner( - mask_mod: Callable, - B: int, - H: int, - Q_LEN: int, - KV_LEN: int, - device: str, - Q_BLOCK_SIZE: int, - KV_BLOCK_SIZE: int, -): - r"""Work around for being unable to instantiate __torch_function__ mode under compile. - `create_block_mask` will compile this inner function and wrap the call to this - with the __torch_function__ mode. - """ - mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device, _compile=True) - partial_block_mask, full_block_mask = _convert_mask_to_block_mask( - mask_tensor, - Q_BLOCK_SIZE=Q_BLOCK_SIZE, - KV_BLOCK_SIZE=KV_BLOCK_SIZE, - separate_full_blocks=True, - ) - return partial_block_mask, full_block_mask - - def create_block_mask( mask_mod: _mask_mod_signature, B: Optional[int], @@ -851,7 +829,6 @@ def create_block_mask( KV_LEN (int): Sequence length of key/value. device (str): Device to run the mask creation on. BLOCK_SIZE (int or Tuple[int, int]): Block size for the block mask. If a single int is provided it is used for both query and key/value. - _compile (bool): Whether to compile the mask_mod function. Default is False. Returns: BlockMask: A BlockMask object that contains the block mask information. @@ -872,7 +849,6 @@ def causal_mask(b, h, q_idx, kv_idx): assert ( mod_type == _ModificationType.MASK_MOD ), f"create-block_mask requires a mask_mod function! Got {mask_mod}" - inner_func = _create_block_mask_inner if B is None: B = 1 if H is None: @@ -883,20 +859,25 @@ def causal_mask(b, h, q_idx, kv_idx): else: Q_BLOCK_SIZE, KV_BLOCK_SIZE = BLOCK_SIZE - if Q_LEN < 128: - Q_BLOCK_SIZE = Q_LEN - else: - Q_LEN = _round_up_to_multiple(Q_LEN, Q_BLOCK_SIZE) - KV_LEN = _round_up_to_multiple(KV_LEN, KV_BLOCK_SIZE) if _compile: - inner_func = torch.compile(inner_func, fullgraph=True) - with TransformGetItemToIndex(): - partial_block_mask, full_block_mask = inner_func( - mask_mod, B, H, Q_LEN, KV_LEN, device, Q_BLOCK_SIZE, KV_BLOCK_SIZE + warnings.warn( + "_compile flag on create_block_mask was originally added to work around a torch.compile limitation. That limitation has since been addressed. So, to compile create_block_mask, we suggest doing torch.compile(create_block_mask). This still works for now, but will be removed in the future.", + DeprecationWarning, ) - block_mask = _create_sparse_block_from_block_mask( - (partial_block_mask, full_block_mask), mask_mod, Q_BLOCK_SIZE, KV_BLOCK_SIZE + return torch.compile(create_block_mask)( + mask_mod, B, H, Q_LEN, KV_LEN, device, BLOCK_SIZE ) + + mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device) + partial_block_mask, full_block_mask = _convert_mask_to_block_mask( + mask_tensor, + Q_BLOCK_SIZE=Q_BLOCK_SIZE, + KV_BLOCK_SIZE=KV_BLOCK_SIZE, + separate_full_blocks=True, + ) + block_mask = _create_sparse_block_from_block_mask( + (partial_block_mask, full_block_mask), mask_mod, Q_BLOCK_SIZE, KV_BLOCK_SIZE + ) return block_mask diff --git a/torch/nn/functional.py b/torch/nn/functional.py index aeaf82682d62d..22b5a46f62e1b 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -6236,7 +6236,7 @@ def multi_head_attention_forward( # if need_weights: - B, Nt, E = q.shape + _B, _Nt, E = q.shape q_scaled = q * math.sqrt(1.0 / float(E)) assert not ( diff --git a/torch/nn/modules/_functions.py b/torch/nn/modules/_functions.py index 847afcef4da2e..dd66c2b323c81 100644 --- a/torch/nn/modules/_functions.py +++ b/torch/nn/modules/_functions.py @@ -224,11 +224,7 @@ def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1): ctx.scale = ctx.scale or input.new() output = input.new() - - batch_size = input.size(0) channels = input.size(1) - input_height = input.size(2) - input_width = input.size(3) output.resize_as_(input) ctx.scale.resize_as_(input) diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f4796e50e415a..dc3a85a03d3ca 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -2630,7 +2630,7 @@ def parameters(self, recurse: bool = True) -> Iterator[Parameter]: (20L, 1L, 5L, 5L) """ - for name, param in self.named_parameters(recurse=recurse): + for _name, param in self.named_parameters(recurse=recurse): yield param def named_parameters( @@ -2725,7 +2725,7 @@ def children(self) -> Iterator["Module"]: Yields: Module: a child module """ - for name, module in self.named_children(): + for _name, module in self.named_children(): yield module def named_children(self) -> Iterator[Tuple[str, "Module"]]: diff --git a/torch/nn/modules/rnn.py b/torch/nn/modules/rnn.py index caadd5bc8e427..b2636fc8966af 100644 --- a/torch/nn/modules/rnn.py +++ b/torch/nn/modules/rnn.py @@ -1043,7 +1043,6 @@ def forward(self, input, hx=None): # noqa: F811 orig_input = input # xxx: isinstance check needs to be in conditional for TorchScript to compile batch_sizes = None - do_permute = False num_directions = 2 if self.bidirectional else 1 real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size if isinstance(orig_input, PackedSequence): diff --git a/torch/nn/modules/transformer.py b/torch/nn/modules/transformer.py index 984329ebd2e55..0f7274c540001 100644 --- a/torch/nn/modules/transformer.py +++ b/torch/nn/modules/transformer.py @@ -34,10 +34,6 @@ def _generate_square_subsequent_mask( The masked positions are filled with float('-inf'). Unmasked positions are filled with float(0.0). """ - if device is None: - device = torch.device("cpu") - if dtype is None: - dtype = torch.float32 return torch.triu( torch.full((sz, sz), float("-inf"), dtype=dtype, device=device), diagonal=1, diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py index 21119de4459c0..aad7e6c5402cf 100644 --- a/torch/nn/parallel/distributed.py +++ b/torch/nn/parallel/distributed.py @@ -1496,7 +1496,7 @@ def _pre_forward(self, *inputs, **kwargs): # Disable the python reducer if compiled_autograd is not enabled. if self._accum_grad_hooks: - for index, h in enumerate(self._accum_grad_hooks): + for h in self._accum_grad_hooks: h.remove() self._accum_grad_hooks.clear() diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index 9c998fb07f2c1..6b5afa860b863 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing_extensions import TypeGuard +from typing_extensions import TypeIs from torch import device, dtype, Tensor @@ -8,7 +8,7 @@ class Parameter(Tensor): def is_lazy( param: Tensor, -) -> TypeGuard[UninitializedParameter | UninitializedBuffer]: ... +) -> TypeIs[UninitializedParameter | UninitializedBuffer]: ... class UninitializedParameter(Tensor): def __init__(self, data: Tensor = ..., requires_grad: bool = ...) -> None: ... diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 8c4e0a459b5ed..b15a45a4d17bf 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -24,14 +24,12 @@ "symbolic_opset19", "symbolic_opset20", # Enums - "ExportTypes", "OperatorExportTypes", "TrainingMode", "TensorProtoDataType", "JitScalarType", # Public functions "export", - "export_to_pretty_string", "is_in_onnx_export", "select_model_mode_for_export", "register_custom_op_symbolic", @@ -57,7 +55,6 @@ from torch._C import _onnx as _C_onnx from torch._C._onnx import OperatorExportTypes, TensorProtoDataType, TrainingMode -from ._exporter_states import ExportTypes from ._internal.exporter._onnx_program import ONNXProgram from ._internal.onnxruntime import ( is_onnxrt_backend_supported, @@ -70,7 +67,6 @@ from .utils import ( _run_symbolic_function, _run_symbolic_method, - export_to_pretty_string, is_in_onnx_export, register_custom_op_symbolic, select_model_mode_for_export, @@ -115,7 +111,6 @@ # Set namespace for exposed private names DiagnosticOptions.__module__ = "torch.onnx" ExportOptions.__module__ = "torch.onnx" -ExportTypes.__module__ = "torch.onnx" JitScalarType.__module__ = "torch.onnx" ONNXProgram.__module__ = "torch.onnx" ONNXRuntimeOptions.__module__ = "torch.onnx" @@ -154,6 +149,7 @@ def export( external_data: bool = True, dynamic_shapes: dict[str, Any] | tuple[Any, ...] | list[Any] | None = None, report: bool = False, + optimize: bool = False, verify: bool = False, profile: bool = False, dump_exported_program: bool = False, @@ -285,6 +281,7 @@ def forward(self, x): Only one parameter `dynamic_axes` or `dynamic_shapes` should be set at the same time. report: Whether to generate a markdown report for the export process. + optimize: Whether to optimize the exported model. verify: Whether to verify the exported model using ONNX Runtime. profile: Whether to profile the export process. dump_exported_program: Whether to dump the :class:`torch.export.ExportedProgram` to a file. @@ -295,7 +292,7 @@ def forward(self, x): training: Deprecated option. Instead, set the training mode of the model before exporting. operator_export_type: Deprecated option. Only ONNX is supported. - do_constant_folding: Deprecated option. The exported graph is always optimized. + do_constant_folding: Deprecated option. custom_opsets: Deprecated. A dictionary: @@ -354,6 +351,7 @@ def forward(self, x): external_data=external_data, dynamic_shapes=dynamic_shapes, report=report, + optimize=optimize, verify=verify, profile=profile, dump_exported_program=dump_exported_program, diff --git a/torch/onnx/_internal/_exporter_legacy.py b/torch/onnx/_internal/_exporter_legacy.py index 0a6c80616bc9c..0eeed0f6dbc76 100644 --- a/torch/onnx/_internal/_exporter_legacy.py +++ b/torch/onnx/_internal/_exporter_legacy.py @@ -126,7 +126,7 @@ def _initiate_registry_from_torchlib(self) -> None: Args: torchlib_registry: The torchlib registry to use for populating the registry. """ - import onnxscript._framework_apis.torch_2_5 as onnxscript_apis + import onnxscript._framework_apis.torch_2_6 as onnxscript_apis for meta in onnxscript_apis.get_torchlib_ops(): internal_name_instance = registration.OpName.from_qualified_name( diff --git a/torch/onnx/_internal/_lazy_import.py b/torch/onnx/_internal/_lazy_import.py index b12e53ef29262..b0c23abd31bcd 100644 --- a/torch/onnx/_internal/_lazy_import.py +++ b/torch/onnx/_internal/_lazy_import.py @@ -30,7 +30,7 @@ def __getattr__(self, attr): if TYPE_CHECKING: import onnx import onnxscript - import onnxscript._framework_apis.torch_2_5 as onnxscript_apis + import onnxscript._framework_apis.torch_2_6 as onnxscript_apis onnxscript_ir = onnxscript.ir @@ -38,4 +38,4 @@ def __getattr__(self, attr): onnx = _LazyModule("onnx") onnxscript = _LazyModule("onnxscript") onnxscript_ir = _LazyModule("onnxscript.ir") - onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_5") + onnxscript_apis = _LazyModule("onnxscript._framework_apis.torch_2_6") diff --git a/torch/onnx/_internal/exporter/_building.py b/torch/onnx/_internal/exporter/_building.py index 7729aa000b281..234ef5486a5ae 100644 --- a/torch/onnx/_internal/exporter/_building.py +++ b/torch/onnx/_internal/exporter/_building.py @@ -207,6 +207,8 @@ def _determine_input_dtype( return ir.DataType.STRING if isinstance(arg, (ir.Tensor, ir.TensorProtocol)): return arg.dtype + if isinstance(arg, complex): + return ir.DataType.FLOAT if arg is None: return ir.DataType.UNDEFINED @@ -261,9 +263,15 @@ def _get_or_create_constant( dtype: ir.DataType, opset: onnxscript.values.Opset, ) -> ir.Value: + # float representation of complex numbers + if isinstance(arg, complex): + # Convert the complex number to a float + arg = (arg.real, arg.imag) + if isinstance(arg, list): # Make the arg hashable arg = tuple(arg) # type: ignore[assignment] + constant_value = constant_farm.get((arg, dtype)) # type: ignore[arg-type] if constant_value is None: constant_tensor = ir.tensor(value=arg, dtype=dtype) # type: ignore[arg-type] @@ -412,7 +420,7 @@ def _process_python_sequences( # when the expected input type is INT64 # We assume this only happens for 1D cases if all(isinstance(val, ir.Value) for val in arg): - named_inputs[name] = opset.Concat(*arg) + named_inputs[name] = opset.Concat(*arg, axis=0) continue dtype = _determine_input_dtype(param, arg, type_binding) @@ -423,7 +431,7 @@ def _process_python_sequences( elif val is None: # Skip None values continue - elif isinstance(arg, (ir.Tensor, ir.TensorProtocol)): + elif isinstance(val, (ir.Tensor, ir.TensorProtocol)): new_args.append(opset.Constant(value=val)) else: # Turn the Python constant into 1D tensor for the constant @@ -431,9 +439,9 @@ def _process_python_sequences( val, (bool, int, float) ), f"Expected int or float, got {type(val)}" new_args.append( - _get_or_create_constant(constant_farm, [arg], dtype, opset) # type: ignore[arg-type] + _get_or_create_constant(constant_farm, [val], dtype, opset) # type: ignore[arg-type] ) - named_inputs[name] = opset.Concat(*new_args) + named_inputs[name] = opset.Concat(*new_args, axis=0) continue return named_inputs diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index 0446e03adc7f4..b29f403ec941a 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -4,8 +4,10 @@ from __future__ import annotations import abc +import contextlib import dataclasses import datetime +import logging import pathlib from typing import Any, Callable, TYPE_CHECKING @@ -17,6 +19,9 @@ import os +logger = logging.getLogger(__name__) + + def _verbose_printer(verbose: bool | None) -> Callable[..., None]: """Prints messages based on `verbose`.""" if verbose is False: @@ -33,6 +38,22 @@ def _take_first_line(text: str) -> str: return first_line +@contextlib.contextmanager +def _patch_dynamo_unsupported_functions(): + """Patch PyTorch to bypass some functions torch.export.export does not support.""" + # TODO: Remove the patches once dynamo supports these functions. + import torch.jit + + # Replace torch.jit.isinstance with isinstance + jit_isinstance = torch.jit.isinstance + torch.jit.isinstance = isinstance + logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing") + try: + yield + finally: + torch.jit.isinstance = jit_isinstance + + @dataclasses.dataclass class Result: exported_program: torch.export.ExportedProgram | None @@ -119,22 +140,23 @@ class TorchExportStrategy(CaptureStrategy): def _capture( self, model, args, kwargs, dynamic_shapes ) -> torch.export.ExportedProgram: - try: - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes - ) - except torch._dynamo.exc.UserError as exc: - # Refine the dynamic shapes based on the suggested fixes. + with _patch_dynamo_unsupported_functions(): try: - new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( - exc.msg, dynamic_shapes + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes + ) + except torch._dynamo.exc.UserError as exc: + # Refine the dynamic shapes based on the suggested fixes. + try: + new_shapes = torch.export.dynamic_shapes.refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + except Exception: + # If the dynamic shapes cannot be refined, re-raise the exception. + raise exc from None + return torch.export.export( + model, args, kwargs=kwargs, dynamic_shapes=new_shapes ) - except Exception: - # If the dynamic shapes cannot be refined, re-raise the exception. - raise exc from None - return torch.export.export( - model, args, kwargs=kwargs, dynamic_shapes=new_shapes - ) def _enter(self, model) -> None: model_repr = _take_first_line(repr(model)) diff --git a/torch/onnx/_internal/exporter/_compat.py b/torch/onnx/_internal/exporter/_compat.py index 3fddef36b8b42..411384f1d360d 100644 --- a/torch/onnx/_internal/exporter/_compat.py +++ b/torch/onnx/_internal/exporter/_compat.py @@ -132,6 +132,7 @@ def export_compat( keep_initializers_as_inputs: bool = False, external_data: bool = True, report: bool = False, + optimize: bool = False, verify: bool = False, profile: bool = False, dump_exported_program: bool = False, @@ -196,6 +197,11 @@ def export_compat( keep_initializers_as_inputs=keep_initializers_as_inputs, ) onnx_program = _onnx_program.ONNXProgram(ir.load(f), None) + + # NOTE: It it's falling back to the legacy exporter, we don't need to + # optimize the model, so we return it here. Users can still optimize + # the model using the optimize() if they want. + return onnx_program else: raise @@ -203,7 +209,8 @@ def export_compat( onnx_program.model = onnxscript_apis.convert_version( onnx_program.model, opset_version ) - onnx_program.model = onnxscript_apis.optimize(onnx_program.model) + if optimize: + onnx_program.optimize() if f is not None: onnx_program.save( diff --git a/torch/onnx/_internal/exporter/_decomp.py b/torch/onnx/_internal/exporter/_decomp.py index de0dd0c0bcb73..66ea959722689 100644 --- a/torch/onnx/_internal/exporter/_decomp.py +++ b/torch/onnx/_internal/exporter/_decomp.py @@ -58,7 +58,7 @@ def create_onnx_friendly_decomposition_table( decomposition_table: dict[torch._ops.OperatorBase, Callable] = {} for op_overload, decomp_fn in itertools.chain( - torch._decomp._decomp_table_to_post_autograd_aten().items(), # type: ignore[attr-defined] + torch._export.utils._decomp_table_to_post_autograd_aten().items(), # type: ignore[attr-defined] torch._decomp.decomposition_table.items(), # type: ignore[attr-defined] ): # Skip decomposition for op_overload as long as that op_overload has a corresponding ONNX diff --git a/torch/onnx/_internal/exporter/_onnx_program.py b/torch/onnx/_internal/exporter/_onnx_program.py index 3f423f787a723..4a0fad5506aaf 100644 --- a/torch/onnx/_internal/exporter/_onnx_program.py +++ b/torch/onnx/_internal/exporter/_onnx_program.py @@ -114,6 +114,14 @@ def model_proto(self) -> onnx.ModelProto: """Return the ONNX ``ModelProto`` object.""" return ir.serde.serialize_model(self.model) + def optimize(self) -> None: + """Optimize the ONNX model. + + This method optimizes the ONNX model by performing constant folding and + eliminating redundancies in the graph. The optimization is done in-place. + """ + self.model = onnxscript_apis.optimize(self.model) + def save( self, destination: str | os.PathLike, diff --git a/torch/onnx/_internal/exporter/_testing.py b/torch/onnx/_internal/exporter/_testing.py index 19f0c73734839..5860256599db3 100644 --- a/torch/onnx/_internal/exporter/_testing.py +++ b/torch/onnx/_internal/exporter/_testing.py @@ -54,6 +54,11 @@ def assert_onnx_program( kwargs = {} torch_module = exported_program.module() torch_outputs, _ = _pytree.tree_flatten(torch_module(*args, **kwargs)) + # ONNX outputs are always real, so we need to convert torch complex outputs to real representations + torch_outputs = [ + torch.view_as_real(output) if torch.is_complex(output) else output + for output in torch_outputs + ] onnx_outputs = program(*args, **kwargs) # TODO(justinchuby): Include output names in the error message torch.testing.assert_close( diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index 54af3142cc230..0a98cb32ceda5 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -68,7 +68,7 @@ def register_pytree_node( def _register_huggingface_model_output_extension(self): try: from transformers import modeling_outputs # type: ignore[import] - except ImportError as e: + except ImportError: return def model_output_flatten( diff --git a/torch/onnx/_internal/fx/passes/_utils.py b/torch/onnx/_internal/fx/passes/_utils.py index 853557362e049..a7b05786ab171 100644 --- a/torch/onnx/_internal/fx/passes/_utils.py +++ b/torch/onnx/_internal/fx/passes/_utils.py @@ -61,7 +61,6 @@ def set_node_name( new_name: The new name to use. name_to_node_cache: A cache of node names to nodes. """ - module = node.graph.owning_module node_name_to_set = collections.deque([(node, new_name)]) while node_name_to_set: diff --git a/torch/onnx/_internal/fx/passes/functionalization.py b/torch/onnx/_internal/fx/passes/functionalization.py index 3b68de48080c6..14455546411f8 100644 --- a/torch/onnx/_internal/fx/passes/functionalization.py +++ b/torch/onnx/_internal/fx/passes/functionalization.py @@ -84,12 +84,11 @@ def wrapped(*inputs): out = function(*inputs_functional) finally: torch._disable_functionalization() - flat_inputs = pytree.tree_leaves(inputs) + flat_inputs_functional = pytree.tree_leaves(inputs_functional) - for inpt, input_functional in zip(flat_inputs, flat_inputs_functional): + for input_functional in flat_inputs_functional: if isinstance(input_functional, torch.Tensor): torch._sync(input_functional) - inpt_new = torch._from_functional_tensor(input_functional) pytree.tree_map(torch._sync, out) out_unwrapped = pytree.tree_map(torch._from_functional_tensor, out) return out_unwrapped diff --git a/torch/onnx/_internal/fx/passes/modularization.py b/torch/onnx/_internal/fx/passes/modularization.py index e1ec411aea19e..f729a7b60d35e 100644 --- a/torch/onnx/_internal/fx/passes/modularization.py +++ b/torch/onnx/_internal/fx/passes/modularization.py @@ -139,8 +139,8 @@ def from_dynamo_produced_raw_meta( cls, raw_meta: _DYNAMO_NN_MODULE_META_TYPE ) -> _ModuleMeta: """Create a module meta from raw meta produced by FX dynamo tracer.""" - module_name, (qualified_name, module_class) = raw_meta - return _ModuleMeta(module_name, module_class, raw_meta) + module_name, (_qualified_name, module_class) = raw_meta + return _ModuleMeta(module_name.split("@")[0], module_class, raw_meta) @classmethod def from_raw_meta( diff --git a/torch/onnx/_internal/fx/passes/type_promotion.py b/torch/onnx/_internal/fx/passes/type_promotion.py index 81cb6ccb7439d..1113395b51782 100644 --- a/torch/onnx/_internal/fx/passes/type_promotion.py +++ b/torch/onnx/_internal/fx/passes/type_promotion.py @@ -43,7 +43,7 @@ def _try_getclosurevars(func): try: return inspect.getclosurevars(func) - except TypeError as e: + except TypeError: return None diff --git a/torch/onnx/_internal/io_adapter.py b/torch/onnx/_internal/io_adapter.py index 16c1313a2d5a2..7334c79620de4 100644 --- a/torch/onnx/_internal/io_adapter.py +++ b/torch/onnx/_internal/io_adapter.py @@ -136,15 +136,35 @@ def apply( # TODO: make_fx lose stack info https://github.com/pytorch/pytorch/issues/90276 -def _replace_tuple_with_list(spec: pytree.TreeSpec) -> pytree.TreeSpec: - _type = list if spec.type == tuple else spec.type - return pytree.TreeSpec( - _type, spec.context, list(map(_replace_tuple_with_list, spec.children_specs)) +# TODO(XuehaiPan): Dynamo does not support `dummy_leaf = object()` as a sentinel value in the frame. +class _DummyLeaf: # use a class instead. + pass + + +def _replace_list_with_tuple(spec: pytree.TreeSpec) -> pytree.TreeSpec: + def replace_list_with_tuple(x: Any) -> Any: + if type(x) is list: + return pytree.tree_map( + replace_list_with_tuple, + tuple(x), + is_leaf=lambda x: type(x) is list, + ) + return x + + dummy_leaf = _DummyLeaf() + dummy_tree = pytree.tree_unflatten([dummy_leaf] * spec.num_leaves, spec) + dummy_tree = pytree.tree_map( + replace_list_with_tuple, + dummy_tree, + is_leaf=lambda x: type(x) is list, ) + return pytree.tree_structure(dummy_tree) -def _open_top_level_list_if_single_element(spec: pytree.TreeSpec) -> pytree.TreeSpec: - if spec.type == list and spec.num_children == 1: +def _open_top_level_sequence_if_single_element( + spec: pytree.TreeSpec, +) -> pytree.TreeSpec: + if spec.type in (tuple, list) and spec.num_children == 1: return spec.children_specs[0] return spec @@ -167,10 +187,10 @@ def _assert_identical_pytree_spec( pass_if_any_checks: Sequence[Callable[[], bool]] = [ lambda: spec1 == spec2, # FIXME: Bug in `dynamo.export`. Sometimes outputs returned in 'list' instead of 'tuple'. - lambda: _replace_tuple_with_list(spec1) == _replace_tuple_with_list(spec2), + lambda: _replace_list_with_tuple(spec1) == _replace_list_with_tuple(spec2), # FIXME: Bug in `dynamo.export`. Sometimes single function return is wrapped in list. - lambda: _open_top_level_list_if_single_element(spec1) == spec2, - lambda: spec1 == _open_top_level_list_if_single_element(spec2), + lambda: _open_top_level_sequence_if_single_element(spec1) == spec2, + lambda: spec1 == _open_top_level_sequence_if_single_element(spec2), ] if not any(check() for check in pass_if_any_checks): diff --git a/torch/onnx/_internal/onnx_proto_utils.py b/torch/onnx/_internal/onnx_proto_utils.py index 5fc181b180824..19c31ab16f380 100644 --- a/torch/onnx/_internal/onnx_proto_utils.py +++ b/torch/onnx/_internal/onnx_proto_utils.py @@ -4,19 +4,21 @@ from __future__ import annotations import glob -import io import os import shutil -import zipfile -from typing import Any, Mapping +from typing import Any, Mapping, TYPE_CHECKING import torch import torch.jit._trace import torch.serialization -from torch.onnx import _constants, _exporter_states, errors +from torch.onnx import errors from torch.onnx._internal import jit_utils, registration +if TYPE_CHECKING: + import io + + def export_as_test_case( model_bytes: bytes, inputs_data, outputs_data, name: str, dir: str ) -> str: @@ -54,7 +56,6 @@ def export_as_test_case( _export_file( model_bytes, os.path.join(test_case_dir, "model.onnx"), - _exporter_states.ExportTypes.PROTOBUF_FILE, {}, ) data_set_dir = os.path.join(test_case_dir, "test_data_set_0") @@ -163,47 +164,12 @@ def export_data(data, value_info_proto, f: str) -> None: def _export_file( model_bytes: bytes, f: io.BytesIO | str, - export_type: str, export_map: Mapping[str, bytes], ) -> None: """export/write model bytes into directory/protobuf/zip""" - if export_type == _exporter_states.ExportTypes.PROTOBUF_FILE: - assert len(export_map) == 0 - with torch.serialization._open_file_like(f, "wb") as opened_file: - opened_file.write(model_bytes) - elif export_type in { - _exporter_states.ExportTypes.ZIP_ARCHIVE, - _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE, - }: - compression = ( - zipfile.ZIP_DEFLATED - if export_type == _exporter_states.ExportTypes.COMPRESSED_ZIP_ARCHIVE - else zipfile.ZIP_STORED - ) - with zipfile.ZipFile(f, "w", compression=compression) as z: - z.writestr(_constants.ONNX_ARCHIVE_MODEL_PROTO_NAME, model_bytes) - for k, v in export_map.items(): - z.writestr(k, v) - elif export_type == _exporter_states.ExportTypes.DIRECTORY: - if isinstance(f, io.BytesIO) or not os.path.isdir(f): # type: ignore[arg-type] - raise ValueError( - f"f should be directory when export_type is set to DIRECTORY, instead get type(f): {type(f)}" - ) - if not os.path.exists(f): # type: ignore[arg-type] - os.makedirs(f) # type: ignore[arg-type] - - model_proto_file = os.path.join(f, _constants.ONNX_ARCHIVE_MODEL_PROTO_NAME) # type: ignore[arg-type] - with torch.serialization._open_file_like(model_proto_file, "wb") as opened_file: - opened_file.write(model_bytes) - - for k, v in export_map.items(): - weight_proto_file = os.path.join(f, k) # type: ignore[arg-type] - with torch.serialization._open_file_like( - weight_proto_file, "wb" - ) as opened_file: - opened_file.write(v) - else: - raise ValueError("Unknown export type") + assert len(export_map) == 0 + with torch.serialization._open_file_like(f, "wb") as opened_file: + opened_file.write(model_bytes) def _add_onnxscript_fn( diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 799f2d6f81a56..3d02159155120 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -1988,7 +1988,7 @@ def _embedding_bag_helper( # FIXME(justinchuby): We need to handle what happens when we call b.op on a node return block_input_iter = utils._add_input_to_block(loop_block) - cond = utils._add_input_to_block(loop_block) + utils._add_input_to_block(loop_block) indices_start = loop_context.op( "Gather", offsets_starts, block_input_iter, axis_i=0 diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py index 7bf27b273832f..809a98c5f9dee 100644 --- a/torch/onnx/symbolic_opset11.py +++ b/torch/onnx/symbolic_opset11.py @@ -534,7 +534,7 @@ def stack(g: jit_utils.GraphContext, tensor_list, dim): @_onnx_symbolic("aten::_unique2") @symbolic_helper.parse_args("v", "i", "i", "i") def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_counts): - u, indices, inverse_indices, counts = g.op( + u, _indices, inverse_indices, counts = g.op( "Unique", self, sorted_i=sorted, outputs=4 ) return u, inverse_indices, counts @@ -545,7 +545,7 @@ def _unique2(g: jit_utils.GraphContext, self, sorted, return_inverse, return_cou def unique_dim( g: jit_utils.GraphContext, self, dim, sorted, return_inverse, return_counts ): - u, indices, inverse_indices, counts = g.op( + u, _indices, inverse_indices, counts = g.op( "Unique", self, axis_i=dim, sorted_i=sorted, outputs=4 ) return u, inverse_indices, counts @@ -945,7 +945,6 @@ def index(g: jit_utils.GraphContext, self, index): @_onnx_symbolic("aten::index_fill") def index_fill(g: jit_utils.GraphContext, self, dim, index, value): - dim_value = symbolic_helper._parse_arg(dim, "i") expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -957,8 +956,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value): @_onnx_symbolic("aten::index_copy") def index_copy(g: jit_utils.GraphContext, self, dim, index, source): - dim_value = symbolic_helper._parse_arg(dim, "i") - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) return scatter(g, self, dim, expanded_index, source) diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py index 7aaefd37201dd..21489fbb79725 100644 --- a/torch/onnx/symbolic_opset12.py +++ b/torch/onnx/symbolic_opset12.py @@ -346,8 +346,7 @@ def unfold(g: jit_utils.GraphContext, input, dimension, size, step): loop_block = loop_context.block block_input_iter = utils._add_input_to_block(loop_block) - # FIXME(justinchuby): cond is unused? - cond = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 starts = loop_context.op("Gather", low_indices, block_input_iter) ends = loop_context.op("Gather", hi_indices, block_input_iter) diff --git a/torch/onnx/symbolic_opset13.py b/torch/onnx/symbolic_opset13.py index e31416ae2bc90..aa40c55780420 100644 --- a/torch/onnx/symbolic_opset13.py +++ b/torch/onnx/symbolic_opset13.py @@ -211,7 +211,7 @@ def tensor_split( loop_block = loop_context.block block_input_iter = utils._add_input_to_block(loop_block) - cond = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 final_splits = utils._add_input_to_block(loop_block) start = loop_context.op( @@ -689,7 +689,7 @@ def repeat_interleave( loop_block = loop_context.block block_input_iter = utils._add_input_to_block(loop_block) - cond = utils._add_input_to_block(loop_block) + cond = utils._add_input_to_block(loop_block) # noqa: F841 final_splits = utils._add_input_to_block(loop_block) r_split = loop_context.op("SequenceAt", r_splits, block_input_iter) diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index 2bcba2f93d04a..997e0cfb4a153 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -2955,7 +2955,6 @@ def index_put(g: jit_utils.GraphContext, self, indices_list_value, values, accum @_onnx_symbolic("aten::index_fill") def index_fill(g: jit_utils.GraphContext, self, dim, index, value): - dim_value = symbolic_helper._parse_arg(dim, "i") expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) @@ -2968,8 +2967,7 @@ def index_fill(g: jit_utils.GraphContext, self, dim, index, value): @_onnx_symbolic("aten::index_copy") def index_copy(g: jit_utils.GraphContext, self, dim, index, source): - dim_value = symbolic_helper._parse_arg(dim, "i") - expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( + _expanded_index_shape, expanded_index = symbolic_helper._index_fill_reshape_helper( g, self, dim, index ) return scatter(g, self, dim, expanded_index, source) @@ -3674,14 +3672,14 @@ def new_full( def eye(g: jit_utils.GraphContext, *args): if len(args) == 5: # aten::eye(n, dtype, layout, device, pin_memory) - n, dtype, layout, device, pin_memory = args + n, dtype, layout, device, _pin_memory = args dim_size = symbolic_helper._unsqueeze_helper(g, n, [0]) shape = g.op("Concat", dim_size, dim_size, axis_i=0) tensor = zeros(g, shape, dtype, layout, device) return g.op("EyeLike", tensor) if len(args) == 6: # aten::eye(n, m, dtype, layout, device, pin_memory) - n, m, dtype, layout, device, pin_memory = args + n, m, dtype, layout, device, _pin_memory = args shape = g.op( "Concat", symbolic_helper._unsqueeze_helper(g, n, [0]), @@ -5567,14 +5565,14 @@ def linalg_matrix_norm( g, g.op("Abs", self), axes_i=[dim[0]], keepdims_i=keepdim ) if ord_value > 0: - result, indices = max( + result, _indices = max( g, sum, dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), keepdim=keepdim, ) else: - result, indices = min( + result, _indices = min( g, sum, dim_or_y=g.op("Constant", value_t=torch.LongTensor([dim[1]])), @@ -6391,7 +6389,7 @@ def prim_loop(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: opset_version = GLOBALS.export_onnx_opset_version old_blocks = tuple(node.blocks()) - new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( g, "Loop", *inputs, outputs=node.outputsSize(), n_blocks=len(old_blocks) ) @@ -6500,7 +6498,7 @@ def prim_if(g: jit_utils.GraphContext, *inputs, **attrs) -> list[_C.Value]: return final_b_list else: old_blocks = tuple(n.blocks()) - new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( + _new_op_outputs, new_block_contexts, new_node = jit_utils.add_op_with_blocks( g, "If", *inputs, outputs=n.outputsSize(), n_blocks=len(old_blocks) ) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index 4e3510472cd96..7561438924591 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -20,13 +20,7 @@ import torch.jit._trace import torch.serialization from torch import _C -from torch.onnx import ( # noqa: F401 - _constants, - _deprecation, - _exporter_states, - errors, - symbolic_helper, -) +from torch.onnx import _constants, _deprecation, errors, symbolic_helper # noqa: F401 from torch.onnx._globals import GLOBALS from torch.onnx._internal import diagnostics, jit_utils, onnx_proto_utils, registration @@ -41,7 +35,6 @@ "model_signature", "warn_on_static_input_change", "unpack_quantized_tensor", - "export_to_pretty_string", "unconvertible_ops", "register_custom_op_symbolic", "unregister_custom_op_symbolic", @@ -1066,7 +1059,7 @@ def _model_to_graph( input_names=input_names, module=module, ) - except Exception as e: + except Exception: _C._jit_onnx_log("Torch IR graph at exception: ", graph) raise @@ -1146,84 +1139,6 @@ def _model_to_graph( return graph, params_dict, torch_out -@torch._disable_dynamo -@_deprecation.deprecated("2.5", "the future", "use onnx.printer.to_text() instead") -def export_to_pretty_string( - model, - args, - export_params=True, - verbose=False, - training=_C_onnx.TrainingMode.EVAL, - input_names=None, - output_names=None, - operator_export_type=_C_onnx.OperatorExportTypes.ONNX, - export_type=None, - google_printer=False, - opset_version=None, - keep_initializers_as_inputs=None, - custom_opsets=None, - add_node_names=True, - do_constant_folding=True, - dynamic_axes=None, -): - """Similar to :func:`export`, but returns a text representation of the ONNX model. - - Only differences in args listed below. All other args are the same - as :func:`export`. - - Args: - add_node_names (bool, default True): Whether or not to set - NodeProto.name. This makes no difference unless - ``google_printer=True``. - google_printer (bool, default False): If False, will return a custom, - compact representation of the model. If True will return the - protobuf's `Message::DebugString()`, which is more verbose. - - Returns: - A UTF-8 str containing a human-readable representation of the ONNX model. - """ - if opset_version is None: - opset_version = _constants.ONNX_DEFAULT_OPSET - if custom_opsets is None: - custom_opsets = {} - GLOBALS.export_onnx_opset_version = opset_version - GLOBALS.operator_export_type = operator_export_type - - with exporter_context(model, training, verbose): - val_keep_init_as_ip = _decide_keep_init_as_input( - keep_initializers_as_inputs, operator_export_type, opset_version - ) - val_add_node_names = _decide_add_node_names( - add_node_names, operator_export_type - ) - val_do_constant_folding = _decide_constant_folding( - do_constant_folding, operator_export_type, training - ) - args = _decide_input_format(model, args) - graph, params_dict, torch_out = _model_to_graph( - model, - args, - verbose, - input_names, - output_names, - operator_export_type, - val_do_constant_folding, - training=training, - dynamic_axes=dynamic_axes, - ) - - return graph._pretty_print_onnx( # type: ignore[attr-defined] - params_dict, - opset_version, - False, - operator_export_type, - google_printer, - val_keep_init_as_ip, - custom_opsets, - val_add_node_names, - ) - - @_deprecation.deprecated("2.5", "the future", "avoid using this function") def unconvertible_ops( model, @@ -1423,9 +1338,6 @@ def _export( ): assert GLOBALS.in_onnx_export is False - if export_type is None: - export_type = _exporter_states.ExportTypes.PROTOBUF_FILE - if isinstance(model, torch.nn.DataParallel): raise ValueError( "torch.nn.DataParallel is not supported by ONNX " @@ -1516,10 +1428,6 @@ def _export( dynamic_axes=dynamic_axes, ) - # TODO: Don't allocate a in-memory string for the protobuf - defer_weight_export = ( - export_type is not _exporter_states.ExportTypes.PROTOBUF_FILE - ) if custom_opsets is None: custom_opsets = {} @@ -1540,12 +1448,13 @@ def _export( getattr(model, "training", False), # type: ignore[arg-type] ) _C._jit_pass_onnx_assign_scoped_names_for_node_and_value(graph) + defer_weight_export = False if export_params: ( proto, export_map, - val_use_external_data_format, - node_names, + _val_use_external_data_format, + _node_names, ) = graph._export_onnx( # type: ignore[attr-defined] params_dict, opset_version, @@ -1563,13 +1472,13 @@ def _export( ( proto, export_map, - val_use_external_data_format, - node_names, + _, + _, ) = graph._export_onnx( # type: ignore[attr-defined] {}, opset_version, dynamic_axes, - False, + defer_weight_export, operator_export_type, not verbose, val_keep_init_as_ip, @@ -1585,7 +1494,7 @@ def _export( ) if verbose: _C._jit_onnx_log("Exported graph: ", graph) - onnx_proto_utils._export_file(proto, f, export_type, export_map) + onnx_proto_utils._export_file(proto, f, export_map) finally: assert GLOBALS.in_onnx_export GLOBALS.in_onnx_export = False diff --git a/torch/onnx/verification.py b/torch/onnx/verification.py index f489252f5a7b2..26810b116ffc0 100644 --- a/torch/onnx/verification.py +++ b/torch/onnx/verification.py @@ -26,7 +26,7 @@ import torch import torch._C._onnx as _C_onnx from torch import _C -from torch.onnx import _constants, _experimental, _exporter_states, utils +from torch.onnx import _constants, _experimental, utils from torch.onnx._globals import GLOBALS from torch.onnx._internal import onnx_proto_utils from torch.types import Number @@ -893,8 +893,7 @@ def verify_aten_graph( graph, export_options, onnx_params_dict ) model_f: str | io.BytesIO = io.BytesIO() - export_type = _exporter_states.ExportTypes.PROTOBUF_FILE - onnx_proto_utils._export_file(proto, model_f, export_type, export_map) + onnx_proto_utils._export_file(proto, model_f, export_map) # NOTE: Verification is unstable. Try catch to emit information for debugging. try: @@ -1783,7 +1782,7 @@ def find_mismatch( args = utils._decide_input_format(model, inputs_for_export) model = utils._pre_trace_quant_model(model, args) - graph, params, torch_out, module = utils._create_jit_graph(model, args) + graph, params, _torch_out, _module = utils._create_jit_graph(model, args) params_dict = utils._get_named_param_dict(graph, params) utils._apply_friendly_debug_names(graph, params_dict) diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index dc3941008ab8a..65f41d6ab182e 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -9,6 +9,7 @@ _disable_dynamo_if_unsupported, _get_scalar_dtype, _maximize_doc, + _params_doc, Optimizer, ParamsT, TensorListList, @@ -223,8 +224,7 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): unlike other optimizers, Adafactor does not require a learning rate, and Shazeer, Noam, and Mitchell Stern do not use lr at all. Deviating from the paper, this implementation uses lr for applying weight diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index ef45706176a34..60c37680aeb57 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -14,6 +14,7 @@ _get_capturable_supported_devices, _get_scalar_dtype, _maximize_doc, + _params_doc, _use_grad_for_differentiable, _view_as_real, Optimizer, @@ -219,16 +220,15 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} + lr (float, Tensor, optional): coefficient that scale delta before it is applied + to the parameters (default: 1.0) rho (float, optional): coefficient used for computing a running average of squared gradients (default: 0.9). A higher value of `rho` will result in a slower average, which can be helpful for preventing oscillations in the learning process. eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-6). - lr (float, Tensor, optional): coefficient that scale delta before it is applied - to the parameters (default: 1.0) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) {_foreach_doc} {_capturable_doc} diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 7427471c1bfd4..c45df14727c69 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -12,6 +12,7 @@ _get_scalar_dtype, _get_value, _maximize_doc, + _params_doc, _use_grad_for_differentiable, _view_as_real, Optimizer, @@ -216,8 +217,7 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-2) lr_decay (float, optional): learning rate decay (default: 0) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 708ea28a57166..23337e6352568 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -17,6 +17,7 @@ _get_scalar_dtype, _get_value, _maximize_doc, + _params_doc, _stack_if_compiling, _use_grad_for_differentiable, _view_as_real, @@ -288,8 +289,7 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR is not yet supported for all our implementations. Please use a float LR if you are not also specifying fused=True or capturable=True. diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index b1c80a2ae3dca..4459d033c1e36 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -15,6 +15,7 @@ _get_scalar_dtype, _get_value, _maximize_doc, + _params_doc, _use_grad_for_differentiable, _view_as_real, Optimizer, @@ -203,8 +204,7 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 2e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 0c49f528e8f13..fc6aec32b2e30 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -17,6 +17,7 @@ _get_scalar_dtype, _get_value, _maximize_doc, + _params_doc, _stack_if_compiling, _use_grad_for_differentiable, _view_as_real, @@ -285,8 +286,7 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR is not yet supported for all our implementations. Please use a float LR if you are not also specifying fused=True or capturable=True. diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 79de96aa86cd2..32a52cf9ac4ee 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -15,6 +15,7 @@ _get_scalar_dtype, _get_value, _maximize_doc, + _params_doc, _use_grad_for_differentiable, _view_as_real, Optimizer, @@ -174,8 +175,7 @@ def step(self, closure=None): averaging`_. Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-2) lambd (float, optional): decay term (default: 1e-4) alpha (float, optional): power for eta update (default: 0.75) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index f0a10efefd12f..abbeb51edfb00 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -247,8 +247,7 @@ def step(self, epoch: Optional[int] = None): else: values = self.get_lr() - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data + for param_group, lr in zip(self.optimizer.param_groups, values): if isinstance(param_group["lr"], Tensor): param_group["lr"].fill_(lr) else: @@ -909,14 +908,26 @@ def __init__( group["lr"] = group["initial_lr"] # "Undo" the step performed by other schedulers - for scheduler in self._schedulers: - scheduler.last_epoch -= 1 + self.recursive_undo() # Perform the initial step for only the first scheduler self._schedulers[0]._initial_step() self._last_lr = schedulers[0].get_last_lr() + def recursive_undo(self, sched=None): + """ + Recursively undo any step performed by the initialisation of + schedulers. + """ + scheds = self if sched is None else sched + + if hasattr(scheds, "_schedulers"): + for s in scheds._schedulers: + self.recursive_undo(s) + elif hasattr(scheds, "last_epoch"): + scheds.last_epoch -= 1 + def step(self): # type: ignore[override] """Perform a step.""" self.last_epoch += 1 @@ -1318,8 +1329,10 @@ def __init__( raise ValueError( f"expected {len(optimizer.param_groups)} min_lrs, got {len(min_lr)}" ) + self.default_min_lr = None self.min_lrs = list(min_lr) else: + self.default_min_lr = min_lr self.min_lrs = [min_lr] * len(optimizer.param_groups) self.patience = patience @@ -1375,6 +1388,20 @@ def step(self, metrics: SupportsFloat, epoch=None): # type: ignore[override] self._last_lr = [group["lr"] for group in self.optimizer.param_groups] def _reduce_lr(self, epoch): + if len(self.optimizer.param_groups) != len(self.min_lrs): + if self.default_min_lr is None: + raise RuntimeError( + "The number of param groups in the `optimizer` " + f"({len(self.optimizer.param_groups)}) differs " + f"from when `ReduceLROnPlateau` was initialized " + f"({len(self.min_lrs)}), usually due to a new " + "param group being added to the optimizer. Please " + "modify the `min_lrs` field to match the length " + "of the `optimizer` param groups." + ) + else: + self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups) + for i, param_group in enumerate(self.optimizer.param_groups): old_lr = float(param_group["lr"]) new_lr = max(old_lr * self.factor, self.min_lrs[i]) @@ -1837,8 +1864,7 @@ def step(self, epoch=None): self.last_epoch = math.floor(epoch) with _enable_get_lr_call(self): - for i, data in enumerate(zip(self.optimizer.param_groups, self.get_lr())): - param_group, lr = data + for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()): param_group["lr"] = lr self._last_lr = [group["lr"] for group in self.optimizer.param_groups] diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index e26b3bf302587..2dd7e130c0d6c 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -16,6 +16,7 @@ _get_scalar_dtype, _get_value, _maximize_doc, + _params_doc, _stack_if_compiling, _use_grad_for_differentiable, _view_as_real, @@ -251,8 +252,7 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 2e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 8f7993842c100..f3b7e7dac0af8 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -232,6 +232,10 @@ def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]: # Common doc strings among optimizers +_params_doc = r"""params (iterable): iterable of parameters or named_parameters to optimize + or iterable of dicts defining parameter groups. When using named_parameters, + all parameters in all groups should be named""" + _foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer is used. If unspecified by the user (so foreach is None), we will try to use foreach over the for-loop implementation on CUDA, since it is usually @@ -308,7 +312,9 @@ def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> Removabl return handle -ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]] +ParamsT: TypeAlias = Union[ + Iterable[torch.Tensor], Iterable[Dict[str, Any]], Iterable[Tuple[str, torch.Tensor]] +] _P = ParamSpec("_P") R = TypeVar("R") @@ -649,6 +655,8 @@ def state_dict(self) -> StateDict: parameter group is a Dict. Each parameter group contains metadata specific to the optimizer, such as learning rate and weight decay, as well as a List of parameter IDs of the parameters in the group. + If a param group was initialized with ``named_parameters()`` the names + content will also be saved in the state dict. NOTE: The parameter IDs may look like indices but they are just IDs associating state with param_group. When loading from a state_dict, @@ -673,12 +681,14 @@ def state_dict(self) -> StateDict: 'weight_decay': 0, ... 'params': [0] + 'param_names' ['param0'] (optional) }, { 'lr': 0.001, 'weight_decay': 0.5, ... 'params': [1, 2, 3] + 'param_names': ['param1', 'layer.weight', 'layer.bias'] (optional) } ] } @@ -834,6 +844,17 @@ def load_state_dict(self, state_dict: StateDict) -> None: Args: state_dict (dict): optimizer state. Should be an object returned from a call to :meth:`state_dict`. + + .. note:: + The names of the parameters (if they exist under the "param_names" key of each param group + in :meth:`state_dict`) will not affect the loading process. + To use the parameters' names for custom cases (such as when the parameters in the loaded state dict + differ from those initialized in the optimizer), + a custom ``register_load_state_dict_pre_hook`` should be implemented to adapt the loaded dict + accordingly. + If ``param_names`` exist in loaded state dict ``param_groups`` they will be saved and override + the current names, if present, in the optimizer state. If they do not exist in loaded state dict, + the optimizer ``param_names`` will remain unchanged. """ # shallow copy, to be consistent with module API state_dict = state_dict.copy() @@ -905,6 +926,8 @@ def update_group( group: Dict[str, Any], new_group: Dict[str, Any] ) -> Dict[str, Any]: new_group["params"] = group["params"] + if "param_names" in group and "param_names" not in new_group: + new_group["param_names"] = group["param_names"] return new_group param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)] @@ -982,10 +1005,6 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float] Args: closure (Callable): A closure that reevaluates the model and returns the loss. Optional for most optimizers. - - .. note:: - Unless otherwise specified, this function should not modify the - ``.grad`` field of the parameters. """ raise NotImplementedError @@ -1014,6 +1033,25 @@ def add_param_group(self, param_group: Dict[str, Any]) -> None: else: param_group["params"] = list(params) + extracted_param_tensors = [] + extracted_param_names = [] + for param in param_group["params"]: + if isinstance(param, tuple): + param_name = param[0] + extracted_param_names.append(param_name) + extracted_param_tensors.append(param[1]) + else: + extracted_param_tensors.append(param) + + param_group["params"] = extracted_param_tensors + if len(extracted_param_names) != 0: + if len(extracted_param_names) == len(extracted_param_tensors): + param_group["param_names"] = extracted_param_names + else: + raise ValueError( + "all optimizer params should be with/without names. Some param names are missing" + ) + for param in param_group["params"]: if not isinstance(param, torch.Tensor): raise TypeError( @@ -1045,6 +1083,14 @@ def add_param_group(self, param_group: Dict[str, Any]) -> None: param_set: Set[torch.Tensor] = set() for group in self.param_groups: param_set.update(set(group["params"])) + if ("param_names" in param_group) != ("param_names" in group): + current_group_txt = ( + "with names" if "param_names" in param_group else "without names" + ) + raise ValueError( + "all optimizer param groups should be with/without names. " + f"cannot add param group {current_group_txt} to the optimizer" + ) if not param_set.isdisjoint(set(param_group["params"])): raise ValueError("some parameters appear in more than one parameter group") diff --git a/torch/optim/radam.py b/torch/optim/radam.py index a2d0c31a91736..9a36a2be1841d 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -16,6 +16,7 @@ _get_scalar_dtype, _get_value, _maximize_doc, + _params_doc, _use_grad_for_differentiable, _view_as_real, Optimizer, @@ -225,8 +226,7 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) @@ -237,8 +237,8 @@ def step(self, closure=None): decay as in AdamW to obtain RAdamW (default: False) {_foreach_doc} {_maximize_doc} - {_differentiable_doc} {_capturable_doc} + {_differentiable_doc} .. _On the variance of the adaptive learning rate and beyond: https://arxiv.org/abs/1908.03265 diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 876f4e1d697bf..f839ba0f021c6 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -15,6 +15,7 @@ _get_capturable_supported_devices, _get_scalar_dtype, _maximize_doc, + _params_doc, _use_grad_for_differentiable, _view_as_real, Optimizer, @@ -201,9 +202,10 @@ def step(self, closure=None): .. math:: \begin{aligned} &\rule{110mm}{0.4pt} \\ - &\textbf{input} : \alpha \text{ (alpha)},\: \gamma \text{ (lr)}, + &\textbf{input} : \alpha \text{ (alpha)}, \: \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\ - &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\ + &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)}, + \: centered, \: \epsilon \text{ (epsilon)} \\ &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \: \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex] &\rule{110mm}{0.4pt} \\ @@ -241,19 +243,18 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-2) - momentum (float, optional): momentum factor (default: 0) alpha (float, optional): smoothing constant (default: 0.99) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + momentum (float, optional): momentum factor (default: 0) centered (bool, optional) : if ``True``, compute the centered RMSProp, the gradient is normalized by an estimation of its variance - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + {_capturable_doc} {_foreach_doc} {_maximize_doc} - {_capturable_doc} {_differentiable_doc} """ diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index e28f3535a0b99..538c8ac0a861d 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -15,6 +15,7 @@ _get_capturable_supported_devices, _get_scalar_dtype, _maximize_doc, + _params_doc, _use_grad_for_differentiable, _view_as_real, Optimizer, @@ -202,16 +203,15 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, optional): learning rate (default: 1e-2) etas (Tuple[float, float], optional): pair of (etaminus, etaplus), that are multiplicative increase and decrease factors (default: (0.5, 1.2)) step_sizes (Tuple[float, float], optional): a pair of minimal and maximal allowed step sizes (default: (1e-6, 50)) - {_foreach_doc} {_capturable_doc} + {_foreach_doc} {_maximize_doc} {_differentiable_doc} diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 46af5ae77537e..ab70f08b44113 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -12,6 +12,7 @@ _foreach_doc, _fused_doc, _maximize_doc, + _params_doc, _use_grad_for_differentiable, DeviceDict, Optimizer, @@ -185,13 +186,13 @@ def step(self, closure=None): """ + rf""" Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-3) momentum (float, optional): momentum factor (default: 0) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) dampening (float, optional): dampening for momentum (default: 0) - nesterov (bool, optional): enables Nesterov momentum (default: False) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + nesterov (bool, optional): enables Nesterov momentum. Only applicable + when momentum is non-zero. (default: False) {_maximize_doc} {_foreach_doc} {_differentiable_doc} diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index 22ef7841270f6..23ac70678e2ec 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -5,7 +5,7 @@ from torch import Tensor from . import _functional as F -from .optimizer import _maximize_doc, Optimizer, ParamsT +from .optimizer import _maximize_doc, _params_doc, Optimizer, ParamsT __all__ = ["SparseAdam"] @@ -170,8 +170,7 @@ def step(self, closure=None): Args: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups + {_params_doc} lr (float, Tensor, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) diff --git a/torch/overrides.py b/torch/overrides.py index d75ac0d9ce157..7040ecc522539 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -150,7 +150,6 @@ def get_ignored_functions() -> Set[Callable]: torch.wait, torch.as_tensor, torch.from_numpy, - torch.get_device, torch.tensor, torch.default_generator, torch.has_cuda, @@ -653,6 +652,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.gather: lambda input, dim, index, out=None, sparse_grad=False: -1, torch.gcd: lambda input, other, out=None: -1, torch.ge: lambda input, other, out=None: -1, + torch.get_device: lambda input: -1, torch.greater_equal: lambda input, other, out=None: -1, torch.geqrf: lambda input, out=None: -1, torch.i0: lambda input, out=None: -1, @@ -2085,6 +2085,16 @@ def __torch_function__(self, func, types, args=(), kwargs=None): return func(*args, **kwargs) +@contextlib.contextmanager +def _enable_torch_function(): + old_state = torch._C._get_torch_function_state() + try: + torch._C._set_torch_function_state(torch._C._TorchFunctionState.ENABLED) + yield + finally: + torch._C._set_torch_function_state(old_state) + + @contextlib.contextmanager def enable_reentrant_dispatch(): # NB: this can't simply be diff --git a/torch/package/_mangling.py b/torch/package/_mangling.py index 0cf3791d16044..09d7901c2d6cc 100644 --- a/torch/package/_mangling.py +++ b/torch/package/_mangling.py @@ -53,7 +53,7 @@ def demangle(name: str) -> str: mangled name, irrespective of which PackageMangler created it. """ if is_mangled(name): - first, sep, last = name.partition(".") + _first, sep, last = name.partition(".") # If there is only a base mangle prefix, e.g. '', # then return an empty string. return last if len(sep) != 0 else "" diff --git a/torch/package/_package_pickler.py b/torch/package/_package_pickler.py index 8856ad6c37ccf..b80d92c12eb21 100644 --- a/torch/package/_package_pickler.py +++ b/torch/package/_package_pickler.py @@ -49,6 +49,7 @@ def __init__(self, importer: Importer, *args, **kwargs): self.dispatch[FunctionType] = PackagePickler.save_global # type: ignore[assignment] def save_global(self, obj, name=None): + # ruff: noqa: F841 # unfortunately the pickler code is factored in a way that # forces us to copy/paste this function. The only change is marked # CHANGED below. diff --git a/torch/package/find_file_dependencies.py b/torch/package/find_file_dependencies.py index 80cfccbec50a6..dd5c5bb9ea99f 100644 --- a/torch/package/find_file_dependencies.py +++ b/torch/package/find_file_dependencies.py @@ -89,7 +89,7 @@ def visit_Call(self, node): self.references[(name, alias)] = True else: self.references[(name, None)] = True - except Exception as e: + except Exception: return diff --git a/torch/package/package_exporter.py b/torch/package/package_exporter.py index 7b377b95454da..2ece831fab005 100644 --- a/torch/package/package_exporter.py +++ b/torch/package/package_exporter.py @@ -427,7 +427,7 @@ def _write_source_string( def _import_module(self, module_name: str): try: return self.importer.import_module(module_name) - except ModuleNotFoundError as e: + except ModuleNotFoundError: if not is_mangled(module_name): raise msg = ( @@ -662,7 +662,7 @@ def _check_mocked_error(module: Optional[str], field: Optional[str]): memo: DefaultDict[int, str] = defaultdict(None) memo_count = 0 # pickletools.dis(data_value) - for opcode, arg, pos in pickletools.genops(data_value): + for opcode, arg, _pos in pickletools.genops(data_value): if pickle_protocol == 4: if ( opcode.name == "SHORT_BINUNICODE" diff --git a/torch/package/package_importer.py b/torch/package/package_importer.py index cf557d72bd4f7..f779ee1f08660 100644 --- a/torch/package/package_importer.py +++ b/torch/package/package_importer.py @@ -463,7 +463,6 @@ def _install_on_parent(self, parent: str, name: str, module: types.ModuleType): # note: copied from cpython's import code, with call to create module replaced with _make_module def _do_find_and_load(self, name): - path = None parent = name.rpartition(".")[0] module_name_no_parent = name.rpartition(".")[-1] if parent: @@ -475,7 +474,7 @@ def _do_find_and_load(self, name): parent_module = self.modules[parent] try: - path = parent_module.__path__ # type: ignore[attr-defined] + parent_module.__path__ # type: ignore[attr-defined] except AttributeError: # when we attempt to import a package only containing pybinded files, diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 2095b882f5de9..864b7ab095ad0 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -192,7 +192,7 @@ def _extract_parameters_and_gradients( def extract_parameters(node: _ProfilerEvent) -> Iterator[TensorKey]: - for p, p_grad in _extract_parameters_and_gradients(node): + for p, _p_grad in _extract_parameters_and_gradients(node): if p is not None: yield p diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index 4b0708c4a78f5..39abcbd37c212 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -884,7 +884,7 @@ def _save_triton_kernels(): for kernel_file in kernel_files: if kernel_file is None: continue - path, name = os.path.split(kernel_file) + name = os.path.basename(kernel_file) dst = os.path.join(resource_dir, name) shutil.copyfile(kernel_file, dst) diff --git a/torch/quantization/__init__.py b/torch/quantization/__init__.py index 8789fea17a17f..11114de431386 100644 --- a/torch/quantization/__init__.py +++ b/torch/quantization/__init__.py @@ -16,7 +16,7 @@ def default_eval_fn(model, calib_data): Default evaluation function takes a torch.utils.data.Dataset or a list of input Tensors and run the model on the dataset """ - for data, target in calib_data: + for data, _target in calib_data: model(data) diff --git a/torch/random.py b/torch/random.py index 783331145633f..38d37e03dfeae 100644 --- a/torch/random.py +++ b/torch/random.py @@ -147,6 +147,10 @@ def fork_rng( see details in [Note: support the custom device with privateuse1] """ + if device_type == "meta": + yield + return + device_type = torch.device(device_type).type device_mod = getattr(torch, device_type, None) if device_mod is None: diff --git a/torch/serialization.py b/torch/serialization.py index d937680c031c7..17517db6e7fd1 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -28,7 +28,7 @@ Type, Union, ) -from typing_extensions import TypeAlias, TypeGuard # Python 3.10+ +from typing_extensions import TypeAlias, TypeIs import torch import torch._weights_only_unpickler as _weights_only_unpickler @@ -620,7 +620,7 @@ def storage_to_tensor_type(storage): return getattr(module, storage_type.__name__.replace("Storage", "Tensor")) -def _is_path(name_or_buffer) -> TypeGuard[Union[str, os.PathLike]]: +def _is_path(name_or_buffer) -> TypeIs[Union[str, os.PathLike]]: return isinstance(name_or_buffer, (str, os.PathLike)) @@ -806,7 +806,7 @@ def save( # documentation. We need it so that Sphinx doesn't leak `pickle`s path from # the build environment (e.g. ` str: "is not supported yet. Please call torch.load outside the skip_data context manager." ) + true_values = ["1", "y", "yes", "true"] + # Add ability to force safe only or non-safe weight loads via environment variables + force_weights_only_load = ( + os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0") in true_values + ) + force_no_weights_only_load = ( + os.getenv("TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD", "0") in true_values + ) + + if force_weights_only_load and force_no_weights_only_load: + raise RuntimeError( + "Only one of `TORCH_FORCE_WEIGHTS_ONLY_LOAD` or `TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD` " + "should be set, but both were set." + ) + elif force_weights_only_load: + weights_only = True + elif force_no_weights_only_load: + if weights_only is None: + warnings.warn( + "Environment variable TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD detected, since the" + "`weights_only` argument was not explicitly passed to `torch.load`, forcing weights_only=False.", + UserWarning, + stacklevel=2, + ) + weights_only = False + if weights_only is None: weights_only, warn_weights_only = False, True else: warn_weights_only = False - # Add ability to force safe only weight loads via environment variable - if os.getenv("TORCH_FORCE_WEIGHTS_ONLY_LOAD", "0").lower() in [ - "1", - "y", - "yes", - "true", - ]: - weights_only = True - if weights_only: if pickle_module is not None: raise RuntimeError( @@ -1493,7 +1510,7 @@ def persistent_load(saved_id): tar.extract("storages", path=tmpdir) with open(os.path.join(tmpdir, "storages"), "rb", 0) as f: num_storages = pickle_module.load(f, **pickle_load_args) - for i in range(num_storages): + for _ in range(num_storages): args = pickle_module.load(f, **pickle_load_args) key, location, storage_type = args dtype = storage_type._dtype @@ -1527,7 +1544,7 @@ def persistent_load(saved_id): num_tensors = pickle_module.load(f, **pickle_load_args) for _ in range(num_tensors): args = pickle_module.load(f, **pickle_load_args) - key, storage_id, original_tensor_type = args + key, storage_id, _original_tensor_type = args storage = deserialized_objects[storage_id] (ndim,) = struct.unpack(" torch.Tensor: ) B_t = B.t() assert isinstance(B_t, torch.sparse.SparseSemiStructuredTensor) - row, col = A.shape + row, _col = A.shape A_padded = B_t._pad_dense_input(A) result = B_t._mm(A_padded.t(), bias=bias).t() return result[:row, :] diff --git a/torch/sparse/_triton_ops.py b/torch/sparse/_triton_ops.py index 346130c700892..ebc59b18d5a72 100644 --- a/torch/sparse/_triton_ops.py +++ b/torch/sparse/_triton_ops.py @@ -44,8 +44,8 @@ def check_mm_compatible_shapes(f_name, lhs, rhs): f"but got lhs.dim() == {lhs.dim()} and rhs.dim() == {rhs.dim()}.", ) - m, kl = lhs.shape[-2:] - kr, n = rhs.shape[-2:] + _m, kl = lhs.shape[-2:] + kr, _n = rhs.shape[-2:] check( kl == kr, @@ -360,13 +360,13 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None): indices_format = indices_data[0] assert blocks.ndim == 3 - P, Ms, Ks = blocks.shape + _P, Ms, Ks = blocks.shape if indices_format == "scatter_mm": c_offsets, pq = indices_data[1:] assert others.ndim == 3 - Q, Ks_, Ns = others.shape + _Q, Ks_, Ns = others.shape assert Ks == Ks_ if accumulators is None: @@ -749,6 +749,7 @@ def bsr_dense_addmm_meta( num_stages=None, sparsity=None, dtype=None, + out_dtype=None, _version=0, **extra, ): @@ -757,15 +758,31 @@ def bsr_dense_addmm_meta( # bsr_dense_addmm_meta functionality. if dtype is None: dtype = torch.float16 + if out_dtype is None: + out_dtype = dtype if sparsity is None: sparsity = 0.5 if {SPLIT_N, num_warps, num_stages, GROUP_SIZE_ROW} == {None}: device_name = torch.cuda.get_device_name() key = (M, K, N, Ms, Ks, beta == 0, beta == 1, alpha == 1) + if dtype is out_dtype: + version_dtype = dtype + else: + version_dtype = dtype, out_dtype meta = get_meta( - "bsr_dense_addmm", key, device_name, version=(_version, dtype, sparsity) + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, sparsity), ) if meta is None and sparsity != 0.5: + meta = get_meta( + "bsr_dense_addmm", + key, + device_name, + version=(_version, version_dtype, 0.5), + ) + if meta is None and dtype is not out_dtype: meta = get_meta( "bsr_dense_addmm", key, device_name, version=(_version, dtype, 0.5) ) @@ -775,8 +792,15 @@ def bsr_dense_addmm_meta( "bsr_dense_addmm", (*key[:2], "*", *key[3:]), device_name, - version=(_version, dtype, 0.5), + version=(_version, version_dtype, 0.5), ) + if matching_meta is None and dtype is not out_dtype: + matching_meta = get_meta( + "bsr_dense_addmm", + (*key[:2], "*", *key[3:]), + device_name, + version=(_version, dtype, 0.5), + ) for mkey in sorted(matching_meta or {}): meta_ = matching_meta[mkey] n = mkey[2] @@ -794,7 +818,7 @@ def bsr_dense_addmm_meta( # message warn_once( "bsr_dense_addmm uses non-optimal triton kernel parameters" - f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=}" + f" for {M=} {K=} {N=} {Ms=}, {Ks=} {beta=} {alpha=} {dtype=} {out_dtype=}" ) SPLIT_N = SPLIT_N or max(N // Ms, 1) @@ -993,8 +1017,6 @@ def bsr_scatter_mm_indices_data( """ assert bsr.dense_dim() == 0 assert bsr.ndim == 2 # no batch dims - crow_indices = bsr.crow_indices() - col_indices = bsr.col_indices() blocksize = bsr.values().shape[-2:] M, K = bsr.shape Ms, Ks = blocksize @@ -1213,7 +1235,8 @@ def bsr_dense_addmm( beta, alpha, sparsity=sparsity, - dtype=out.dtype, + dtype=dense.dtype, + out_dtype=out.dtype, ) out_backup = out @@ -1665,8 +1688,6 @@ def sampled_addmm( return out blocksize = out.values().shape[-2:] - m = mat1.size(-2) - n = mat2.size(-1) k = mat1.size(-1) # NOTE: (m, 0) @ (0, n) == zeros(m, n) @@ -1714,7 +1735,7 @@ def bsr_dense_mm( meta: Optional[dict] = None, ): f_name = "bsr_dense_mm" - m, kl = bsr.shape[-2:] + m, _kl = bsr.shape[-2:] if not skip_checks: check_bsr_layout(f_name, bsr) check_device(f_name, bsr, dense.device) @@ -1729,7 +1750,7 @@ def bsr_dense_mm( f"{f_name}(): dense.size(-1) == {n} should be divisible by 16", ) else: - kr, n = dense.shape[-2:] + _kr, n = dense.shape[-2:] original_batch_dims_broadcasted = broadcast_batch_dims(f_name, bsr, dense) @@ -1993,7 +2014,6 @@ def _scatter_mm2_kernel( allow_tf32: tl.constexpr, ): Ms = M // TILE_M - Ns = N // TILE_N pid_t = tl.program_id(axis=0) @@ -2044,9 +2064,8 @@ def _scatter_mm2( pq_indices: torch.Tensor, accumulators: torch.Tensor, ): - P, M, K = blocks.shape - Q, _, N = others.shape - R, _, _ = accumulators.shape + _P, M, K = blocks.shape + _Q, _, N = others.shape meta = dict( TILE_M=max(16, M // 4), TILE_N=max(16, N // 4), num_stages=1, num_warps=2 @@ -2218,9 +2237,9 @@ def _scatter_mm6( force_contiguous: bool = True, ): SPLIT_N = meta["SPLIT_N"] - P, Ms, Ks = blocks.shape - B, K_, N = others.shape - B_, M, N_ = accumulators.shape + _P, Ms, Ks = blocks.shape + B, _K, N = others.shape + B_, _M, N_ = accumulators.shape assert N_ == N Ns = N // SPLIT_N assert B_ == B diff --git a/torch/sparse/_triton_ops_meta.py b/torch/sparse/_triton_ops_meta.py index 77a8699e62332..5bbd61b373cc3 100644 --- a/torch/sparse/_triton_ops_meta.py +++ b/torch/sparse/_triton_ops_meta.py @@ -379,7 +379,6 @@ def from_key(key, parameters): minimizer_key = ( initial_key if initial_key in minimizer_keys else min(minimizer_keys) ) - minimizer_target = all_values[minimizer_key] parameters = from_key(minimizer_key, parameters) speedup_incr = (1 - minimal_target / reference_target) * 100 if speedup_incr < 0: @@ -555,7 +554,7 @@ def step_meta_parameter(name, value, direction, meta, m=m, n=n, k=k, bm=bm, bk=b return value return next_value - meta, speedup, timing, sensitivity_message = minimize( + meta, speedup, timing, _sensitivity_message = minimize( bench, initial_meta, reference_meta, step_meta_parameter ) if initial_meta is not reference_meta and initial_meta == meta and not force: @@ -644,7 +643,15 @@ def tune_bsr_dense_addmm( # Compute the key of parameters: sparsity = round(1 - bsr._nnz() * BM * BK / (M * K), 2) dtype = bsr.dtype - version = (0, dtype, sparsity) + if out is None: + out_dtype = dtype + else: + out_dtype = out.dtype + if out_dtype is dtype: + version_dtype = dtype + else: + version_dtype = (dtype, out_dtype) + version = (0, version_dtype, sparsity) key = (M, K, N, BM, BK, beta == 0, beta == 1, alpha == 1) # For tuning, for an initial state, use parameters from the @@ -740,6 +747,7 @@ def optimize_bsr_dense_addmm( use_left_alpha=False, use_right_alpha=False, dtype=torch.float16, + out_dtype=None, device="cuda", sparsity=0.5, force=False, @@ -756,6 +764,10 @@ def optimize_bsr_dense_addmm( right_alpha = ( make_tensor(n, dtype=dtype, device=device) if use_right_alpha else None ) + if out_dtype is not None: + out = dense.new_empty((m, n), dtype=out_dtype) + else: + out = None tune_bsr_dense_addmm( input, bsr, @@ -764,6 +776,7 @@ def optimize_bsr_dense_addmm( alpha=alpha, left_alpha=left_alpha, right_alpha=right_alpha, + out=out, store=True, force=force, verbose=verbose, @@ -833,7 +846,7 @@ def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True): raise NotImplementedError(op) except KeyboardInterrupt: break - except Exception as msg: + except Exception: dump() raise dump() diff --git a/torch/storage.py b/torch/storage.py index 0de35499baf5f..c6efb4a7c5095 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -8,7 +8,16 @@ import io import threading import warnings -from typing import Any, cast, Dict as _Dict, Optional as _Optional, Type, TypeVar, Union +from typing import ( + Any, + cast, + Dict as _Dict, + Optional as _Optional, + Type, + TYPE_CHECKING, + TypeVar, + Union, +) from typing_extensions import Self import torch @@ -16,6 +25,10 @@ from torch.types import _bool, _int, Storage +if TYPE_CHECKING: + from torch._prims_common import DeviceLikeType + + __all__ = ["TypedStorage", "UntypedStorage"] @@ -273,9 +286,9 @@ def _to(self, dtype): storage = storage.clone() return storage - def to( - self, *, device: torch.device, non_blocking: _bool = False - ) -> Union[_StorageBase, TypedStorage]: + def to(self, *, device: DeviceLikeType, non_blocking: _bool = False): + if not isinstance(device, torch.device): + device = torch.device(device) return _to(self, device, non_blocking) def double(self): @@ -1061,8 +1074,10 @@ def hpu(self, device=None, non_blocking=False) -> Self: hpu_storage = self._untyped_storage.hpu(device, non_blocking) return self._new_wrapped_storage(hpu_storage) - def to(self, *, device: torch.device, non_blocking: bool = False) -> Self: + def to(self, *, device: DeviceLikeType, non_blocking: bool = False) -> Self: _warn_typed_storage_removal() + if not isinstance(device, torch.device): + device = torch.device(device) if self.dtype in [ torch.quint8, torch.quint4x2, diff --git a/torch/testing/_internal/autocast_test_lists.py b/torch/testing/_internal/autocast_test_lists.py index 2d18da71ec2bd..a75e9c834b70a 100644 --- a/torch/testing/_internal/autocast_test_lists.py +++ b/torch/testing/_internal/autocast_test_lists.py @@ -246,7 +246,6 @@ def __init__(self, dev): # Utility arguments, created as one-element tuples pointwise0_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) pointwise1_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) - pointwise2_bf16 = (torch.randn(n, dtype=torch.bfloat16, device=dev),) mat0_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) mat1_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) mat2_bf16 = (torch.randn((n, n), dtype=torch.bfloat16, device=dev),) @@ -260,14 +259,10 @@ def __init__(self, dev): for dimset in dummy_dimsets] dimsets = ((n, n, n), (n, n, n, n), (n, n, n, n, n)) - conv_args_bf16 = [(torch.randn(dimset, dtype=torch.bfloat16, device=dev), - torch.randn(dimset, dtype=torch.bfloat16, device=dev)) - for dimset in dimsets] conv_args_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev), torch.randn(dimset, dtype=torch.float32, device=dev)) for dimset in dimsets] - bias_fp32 = (torch.randn((n,), dtype=torch.float32, device=dev),) element0_fp32 = (torch.randn(1, dtype=torch.float32, device=dev),) pointwise0_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) pointwise1_fp32 = (torch.randn(n, dtype=torch.float32, device=dev),) @@ -276,8 +271,10 @@ def __init__(self, dev): mat2_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) mat3_fp32 = (torch.randn((n, n), dtype=torch.float32, device=dev),) - dummy_fp32 = [(torch.randn(dimset, dtype=torch.float32, device=dev),) - for dimset in dummy_dimsets] + dummy_fp32 = [ # noqa: F841 + (torch.randn(dimset, dtype=torch.float32, device=dev),) + for dimset in dummy_dimsets + ] # The lists below organize ops that autocast needs to test. # self.list_name corresponds to test_autocast_list_name in test/test_cpu.py. # Each op is associated with a tuple of valid arguments. diff --git a/torch/testing/_internal/autograd_function_db.py b/torch/testing/_internal/autograd_function_db.py index e092c4d9339b7..46abb4bb758dd 100644 --- a/torch/testing/_internal/autograd_function_db.py +++ b/torch/testing/_internal/autograd_function_db.py @@ -68,7 +68,7 @@ def setup_context(ctx, inputs, outputs): @staticmethod def backward(ctx, grad_output, grad_saved): - input, dinput = ctx.saved_tensors + _input, dinput = ctx.saved_tensors result = grad_output * dinput + 6 * dinput return result @@ -213,7 +213,6 @@ def forward(x, dim): x = to_numpy(x) ind = np.argsort(x, axis=dim) ind_inv = np.argsort(ind, axis=dim) - result = np.take_along_axis(x, ind, axis=dim) return ( torch.tensor(x, device=device), torch.tensor(ind, device=device), @@ -222,7 +221,7 @@ def forward(x, dim): @staticmethod def setup_context(ctx, inputs, output): - x, dim = inputs + _x, dim = inputs _, ind, ind_inv = output ctx.mark_non_differentiable(ind, ind_inv) ctx.save_for_backward(ind, ind_inv) @@ -252,7 +251,6 @@ class SortGenVmap(torch.autograd.Function): @staticmethod def forward(x, dim): - device = x.device ind = torch.argsort(x, dim=dim) ind_inv = torch.argsort(ind, axis=dim) result = torch.take_along_dim(x, ind, dim=dim) @@ -301,7 +299,7 @@ def forward(x, ind, ind_inv, dim): @staticmethod def setup_context(ctx, inputs, output): - x, ind, ind_inv, dim = inputs + _x, ind, ind_inv, dim = inputs ctx.save_for_backward(ind, ind_inv) ctx.save_for_forward(ind, ind_inv) ctx.dim = dim @@ -347,7 +345,7 @@ def forward(x, ind, ind_inv, dim): @staticmethod def setup_context(ctx, inputs, outputs): - x, ind, ind_inv, dim = inputs + _x, ind, ind_inv, dim = inputs ctx.save_for_backward(ind, ind_inv) ctx.save_for_forward(ind, ind_inv) ctx.dim = dim diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index ad7f1d1621cb1..ae47eae9c1fff 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -94,7 +94,7 @@ def evaluate_platform_supports_fp8(): try: import numba.cuda TEST_NUMBA_CUDA = numba.cuda.is_available() - except Exception as e: + except Exception: TEST_NUMBA_CUDA = False TEST_NUMBA = False else: @@ -253,6 +253,8 @@ def _check_cusparse_generic_available(): def _check_hipsparse_generic_available(): if not TEST_WITH_ROCM: return False + if not torch.version.hip: + return False rocm_version = str(torch.version.hip) rocm_version = rocm_version.split("-")[0] # ignore git sha diff --git a/torch/testing/_internal/common_device_type.py b/torch/testing/_internal/common_device_type.py index 6df572bd0e116..4b3e3a0d7ec58 100644 --- a/torch/testing/_internal/common_device_type.py +++ b/torch/testing/_internal/common_device_type.py @@ -11,9 +11,21 @@ from collections import namedtuple from enum import Enum from functools import partial, wraps -from typing import Any, ClassVar, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + ClassVar, + Dict, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Union, +) import torch +from torch._inductor.utils import GPU_TYPES from torch.testing._internal.common_cuda import ( _get_torch_cuda_version, _get_torch_rocm_version, @@ -1201,7 +1213,14 @@ def __init__(self, dep, reason, device_type=None): def __call__(self, fn): @wraps(fn) def dep_fn(slf, *args, **kwargs): - if self.device_type is None or self.device_type == slf.device_type: + if ( + self.device_type is None + or self.device_type == slf.device_type + or ( + isinstance(self.device_type, Iterable) + and slf.device_type in self.device_type + ) + ): if (isinstance(self.dep, str) and getattr(slf, self.dep, True)) or ( isinstance(self.dep, bool) and self.dep ): @@ -1230,6 +1249,12 @@ def __init__(self, dep, reason): super().__init__(dep, reason, device_type="xpu") +# Skips a test on XPU or CUDA if the condition is true. +class skipGPUIf(skipIf): + def __init__(self, dep, reason): + super().__init__(dep, reason, device_type=GPU_TYPES) + + # Skips a test on Lazy if the condition is true. class skipLazyIf(skipIf): def __init__(self, dep, reason): diff --git a/torch/testing/_internal/common_dist_composable.py b/torch/testing/_internal/common_dist_composable.py index e7bce5c37f3d9..8b1778a918dc4 100644 --- a/torch/testing/_internal/common_dist_composable.py +++ b/torch/testing/_internal/common_dist_composable.py @@ -107,5 +107,7 @@ def __init__(self, device: torch.device) -> None: ), ) + # FIXME(rec): forward() is not a method, it's a local function inside __init__ + # that is never used. It should probabkly be outdented by four spaces, or removed. def forward(self, x: torch.Tensor) -> torch.Tensor: return self.seq2(self.lin(self.seq1(x))) diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index fb2a5c034b3e7..1f2ffd8c8be48 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -117,10 +117,16 @@ def wrapper(*args, **kwargs): return wrapper +# TODO (kwen2501): what is the purpose of this decorator? Tests with this +# decorator were always skipped. So they may be outdated already. +# Oct 2024: bumping the small-world criteria to < 8, as we are increasing the +# number of GPUs in CI from 2 to 4, and we need to continue skipping those tests +# to keep CI green. But this is just a temporary solution. We should clean up +# those tests somehow. def skip_if_small_worldsize(func): @wraps(func) def wrapper(*args, **kwargs): - if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) <= 2: + if (os.environ["BACKEND"] != "mpi") and int(os.environ["WORLD_SIZE"]) < 8: sys.exit(TEST_SKIPS["small_worldsize"].exit_code) return func(*args, **kwargs) @@ -353,6 +359,22 @@ def skip_if_win32(): ) +def sm_is_or_higher_than(device: torch.device, major: int, minor: int) -> bool: + """ + Returns True if the device's compute capability is (major, minor) or higher. + Error out if the device is not a CUDA device. + Returns False if device is a RoCM device. + """ + if device.type != "cuda": + raise ValueError("sm_is_or_later() is only supported for CUDA devices") + + if torch.version.hip is not None: + # ROCm devices may have different compute capability codes + return False + + return torch.cuda.get_device_capability(device) >= (major, minor) + + @retry_on_connect_failures def create_tcp_store( addr="localhost", @@ -673,7 +695,7 @@ def run_test(self, test_name: str, parent_pipe) -> None: "Process %s skipping test %s for following reason: %s", self.rank, test_name, str(se) ) sys.exit(TEST_SKIPS["generic"].exit_code) - except Exception as e: + except Exception: logger.error( "Caught exception: \n%s exiting " "process %s with exit code: %s", diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 24e70566f8f9b..529835f6d9a1e 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -35,7 +35,7 @@ ) from torch.testing._internal.common_utils import ( make_fullrank_matrices_with_distinct_singular_values, - TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, TEST_SCIPY, + TEST_WITH_ROCM, IS_FBCODE, IS_WINDOWS, IS_MACOS, TEST_SCIPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype, TEST_WITH_ASAN, GRADCHECK_NONDET_TOL, slowTest, TEST_WITH_SLOW, TEST_WITH_TORCHINDUCTOR @@ -732,10 +732,7 @@ def sample_inputs_jiterator(op, device, dtype, requires_grad, **kwargs): for shape_lhs, shape_rhs in shapes: lhs = make_arg(shape_lhs) - - args = [] - for i in range(num_inputs - 1): - args.append(make_arg(shape_rhs)) + args = [make_arg(shape_rhs) for _ in range(num_inputs - 1)] broadcasts_input = (shape_lhs != torch.broadcast_shapes(shape_lhs, shape_rhs)) yield SampleInput(lhs, args=tuple(args), kwargs=sample_kwargs, broadcasts_input=broadcasts_input) @@ -843,8 +840,6 @@ def to_float(start, end, step): yield SampleInput(1, args=(3, 1)) def sample_inputs_randn(op, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False) - shapes = ( (M,), (S, S) @@ -3098,7 +3093,6 @@ def sample_inputs_diff(op_info, device, dtype, requires_grad, **kwargs): ((XS, XS, XS), 2, (XS, XS, 1), (XS, XS, 1)), ((XS, XS, XS), 2, (XS, XS, XS), (XS, XS, XS)),) - sample_inputs = [] for size, dim, size_prepend, size_append in test_cases: prepend_size = 0 if (size_prepend is None) else size_prepend[dim] append_size = 0 if (size_append is None) else size_append[dim] @@ -3126,7 +3120,7 @@ def sample_inputs_histogram(op_info, device, dtype, requires_grad, **kwargs): weight=weight_tensor, density=density) bins_tensor = make_arg((bin_ct + 1,)) - sorted_bins, bins_indices = torch.sort(bins_tensor) + sorted_bins, _bins_indices = torch.sort(bins_tensor) yield SampleInput(input_tensor, sorted_bins, weight=weight_tensor, density=density) @@ -4906,7 +4900,6 @@ def error_inputs_gelu(op, device, **kwargs): def sample_inputs_max_min_reduction_with_dim(op_info, device, dtype, requires_grad, **kwargs): - inputs = [] args_for_reduction_with_dim = ( ((S, S, S), (1,),), ((S, S, S), (1, True, ),), @@ -5236,8 +5229,6 @@ def sample_inputs_dist(op_info, device, dtype, requires_grad, **kwargs): # Missing to test the nondeterminism of the operation # https://github.com/pytorch/pytorch/issues/53352 def sample_inputs_index(op_info, device, dtype, requires_grad, reference=False, **kwargs): - # target.index_select(dim, idx) - select = "index_select" in op_info.name # target.index_add(dim, idx, source, *, alpha=1) add = "index_add" in op_info.name # target.index_copy(dim, idx, source) @@ -5379,8 +5370,6 @@ def make_idx(n, m, dim, d): ((S, S, S), S, (M, M - 1, M + 1)), ] - fill_value = make_tensor([], dtype=dtype, device="cpu").item() - for c in cases: self_shape, high, idx_sizes = c dim = len(self_shape) @@ -7206,7 +7195,6 @@ def sample_inputs_segment_reduce(op_info, device, dtype, requires_grad, *, mode= def _tensor(shape, dtype=dtype, low=None, high=None): return make_tensor(shape, dtype=dtype, device=device, low=low, high=high, requires_grad=requires_grad) - zero = torch.tensor(0, dtype=torch.long, device=device) test_cases = ( # inp_shape, dim, lengths, unsafe ((S,), 0, [0, 1, 2, 2], False), @@ -7249,8 +7237,6 @@ def sample_inputs_ravel(op_info, device, dtype, requires_grad, **kwargs): yield SampleInput(make_arg((S, S, S), noncontiguous=True)) def sample_inputs_unravel_index(op_info, device, dtype, requires_grad, **kwargs): - make_arg = partial(make_tensor, dtype=dtype, device=device, - low=None, high=None, requires_grad=requires_grad) yield SampleInput( torch.tensor( [[3, 8, 13], [0, 5, 10]], @@ -7579,7 +7565,6 @@ def error_inputs_view_reshape(op, device, **kwargs): def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs): - input_list = [] shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),) make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad) for shape in shapes: @@ -7803,7 +7788,8 @@ def reference_inputs_where(op, device, dtype, requires_grad, **kwargs): yield SampleInput(a, args=(c, b)) # type promoting - other_dtype = torch.double if dtype is not torch.double else torch.long + # FIXME(rec): shouldn't other_dtype be used two lines below? + other_dtype = torch.double if dtype is not torch.double else torch.long # noqa: F841 c = make_cond((10, 3), noncontiguous=True) a = make_arg((10, 1), dtype=torch.long) b = make_arg((10, 1)) @@ -8783,7 +8769,7 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ causal_options = [False] # FIXME: Large errors with causal+fp32 else: causal_options = [True, False] - for qkv_shape, is_causal, dropout_p, enable_gqa in product( + for qkv_shape, is_causal, dropout_p, _enable_gqa in product( qkv_shapes, causal_options, [0.0, 0.5], gqa_options): shape_q, shape_kv = qkv_shape samples.append(SampleInput( @@ -8795,7 +8781,8 @@ def sample_inputs_scaled_dot_product_attention(op_info, device, dtype, requires_ )) # Add non standard shapes - diff_v_head_dim = SampleInput( + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 make((batch, num_heads, seq_q, head_dim)), make((batch, num_heads, seq_kv, head_dim)), make((batch, num_heads, seq_kv, head_dim + 8)), @@ -8841,7 +8828,7 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g mask_types = [1, 2] # UpperLeft, LowerRight scales = [None, 1.0] - for qkv_shape, is_causal, dropout_p, mask_type, scale in product( + for qkv_shape, _is_causal, dropout_p, mask_type, scale in product( qkv_shapes, [True, False], [0.0, 0.5], mask_types, scales): shape_q, shape_kv = qkv_shape samples.append(SampleInput( @@ -8861,7 +8848,8 @@ def sample_inputs_efficient_attention_forward(op_info, device, dtype, requires_g )) # Add non standard shapes - diff_v_head_dim = SampleInput( + # FIXME(rec): should diff_v_head_dim be appended to samples? + diff_v_head_dim = SampleInput( # noqa: F841 make((batch, seq_q, num_heads, head_dim)), make((batch, seq_kv, num_heads, head_dim)), make((batch, seq_kv, num_heads, head_dim + 8)), @@ -9046,7 +9034,6 @@ def sample_inputs_allclose(op_info, device, dtype, requires_grad, **kwargs): sample_shapes = [(), (S), (S, S, S)] atols = [1e-2, 1e-16] rtols = [1e-1, 0.5] - eps = 1e-8 for s, rtol, atol in product(sample_shapes, rtols, atols): # close sample t = make_tensor(s, device=device, dtype=dtype, requires_grad=requires_grad) @@ -9474,7 +9461,7 @@ def sample_zero_size_tensor_inputs(self, opinfo, device, dtype, requires_grad, * _foreach_inputs_kwargs = {k: kwargs.pop(k, v) for k, v in _foreach_inputs_default_kwargs.items()} _foreach_inputs_kwargs["requires_grad"] = requires_grad allow_higher_dtype_scalars = kwargs.pop("allow_higher_dtype_scalars", False) - for rightmost_arg_type in self._rightmost_arg_types: + for _rightmost_arg_type in self._rightmost_arg_types: zero_size_foreach_inputs_kwargs = copy.deepcopy(_foreach_inputs_kwargs) zero_size_foreach_inputs_kwargs["zero_size"] = True input = sample_inputs_foreach(None, device, dtype, NUM_SIZE0_TENSORS, **zero_size_foreach_inputs_kwargs) @@ -10866,6 +10853,9 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides"), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace_all_strides"), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), ), ), ForeachFuncInfo( @@ -10884,6 +10874,9 @@ def __call__(self, opinfo, device, dtype, requires_grad, **kwargs): DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_inplace", dtypes=(torch.bool,)), DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_inplace_all_strides", dtypes=(torch.bool,)), + DecorateInfo(unittest.skip("consistently fails internally and causes other tests to appear flaky"), + "TestForeach", "test_parity", dtypes=(torch.complex128,), + active_if=lambda kwargs: IS_FBCODE and not kwargs["noncontiguous"]), ), ), ForeachFuncInfo( @@ -13559,7 +13552,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 5e-1}),), dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), - backward_dtypes=floating_types(), supports_forward_ad=True, supports_fwgrad_bwgrad=True, promotes_int_to_float=True, @@ -16291,14 +16283,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): check_batched_forward_grad=False, decorators=[skipCUDAIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "This platform doesn't support Flash Attention")], skips=( - # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", - dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", - dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", - dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - # Checking the scalar value of the philox seed and offset # Checking the scalar value of the philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), @@ -16326,14 +16310,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): skipCUDAIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "This platform doesn't support efficient attention"), skipCUDAIf(TEST_WITH_ROCM, "Efficient attention on ROCM doesn't support custom_mask_type==2")], skips=( - # for element 1, was torch.Size([4, 4, 11]) but real shape was torch.Size([16, 11]) - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_meta_outplace", device_type="cuda", - dtypes=[torch.float16, torch.bfloat16, torch.float32], - active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_symbolic_meta_outplace", device_type="cuda", - dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), - DecorateInfo(unittest.expectedFailure, "TestMeta", "test_dispatch_meta_outplace", device_type="cuda", - dtypes=[torch.float16, torch.bfloat16], active_if=TEST_WITH_ROCM and PLATFORM_SUPPORTS_FLASH_ATTENTION), # Checking the scaler value of the philox seed and offset DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator', device_type='cuda'), DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples', device_type='cuda'), @@ -17029,6 +17005,19 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_varargs=True, sample_inputs_func=sample_inputs_permute, reference_inputs_func=reference_inputs_permute), + OpInfo('permute_copy', + dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf), + supports_out=True, + assert_autodiffed=True, + assert_jit_shape_analysis=True, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_varargs=False, # torch.permute is also not varargs + sample_inputs_func=sample_inputs_permute, + reference_inputs_func=reference_inputs_permute, + skips=( + DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit', dtypes=(torch.float32,)), + )), BinaryUfuncInfo('pow', dtypes=all_types_and_complex_and(torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf), @@ -19466,6 +19455,25 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_gradgrad=True, supports_out=False, ), + OpInfo('unbind_copy', + dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), + ref=reference_unbind, + sample_inputs_func=sample_inputs_unbind, + error_inputs_func=error_inputs_unbind, + supports_forward_ad=True, + supports_fwgrad_bwgrad=True, + supports_gradgrad=True, + supports_out=True, + check_batched_grad=False, + skips=( + # Expected __torch_dispatch__ for aten::unbind_copy.int_out to return None + # but it returned something else instead. + DecorateInfo( + unittest.expectedFailure, + 'TestProxyTensorOpInfo', + 'test_make_fx_symbolic_exhaustive_out' + ), + )), OpInfo('vstack', aliases=('row_stack',), dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16), @@ -23957,6 +23965,11 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): "_refs.permute", torch_opinfo_name="permute", ), + PythonRefInfo( + "_refs.permute_copy", + torch_opinfo_name="permute_copy", + supports_out=True, + ), ElementwiseUnaryPythonRefInfo( "_refs.rad2deg", torch_opinfo_name="rad2deg", @@ -24062,10 +24075,6 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): PythonRefInfo( "_refs.transpose_copy", torch_opinfo_name="transpose_copy", - skips=( - # RuntimeError: no _refs support for torch.Tensor.is_conj - DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref'), - ), supports_out=True, ), PythonRefInfo( @@ -24082,6 +24091,10 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): torch_opinfo_name="T", error_inputs_func=partial(error_inputs_T, has_ndims_error=True), ), + PythonRefInfo( + "_refs.unbind_copy", + torch_opinfo_name="unbind_copy", + ), PythonRefInfo( "_refs.unfold", torch_opinfo_name="unfold", diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index 63963bab1b050..69dabb07c1c29 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -340,7 +340,10 @@ def reference_fn(m, p, i, t, constructor_kwargs=constructor_kwargs): ) scalar_input = make_input(()).log() - scalar_target = make_input(()) if kwargs.get('log_target', False) else make_input(()).log() + # FIXME(rec): scalar_target is unused, perhaps should be argument to FunctionInput? + scalar_target = ( # noqa: F841 + make_input(()) if kwargs.get('log_target', False) else make_input(()).log() + ) module_inputs.append( ModuleInput(constructor_input=FunctionInput(**constructor_kwargs), forward_input=FunctionInput(scalar_input, scalar_input), diff --git a/torch/testing/_internal/common_optimizers.py b/torch/testing/_internal/common_optimizers.py index bd8f1f2963f52..052fddda8285c 100644 --- a/torch/testing/_internal/common_optimizers.py +++ b/torch/testing/_internal/common_optimizers.py @@ -55,7 +55,9 @@ class OptimizerInput: def __init__( self, - params: Union[List[Parameter], List[Tensor], Dict[Any, Any]], + params: Union[ + List[Parameter], List[Tensor], Dict[Any, Any], List[Dict[str, Any]] + ], kwargs: Dict[str, Any], desc: str = "", ): @@ -244,6 +246,7 @@ def test_wrapper(*args, **kwargs): def get_error_inputs_for_all_optims(device, dtype): if _get_device_type(device) == "cpu": sample_param = Parameter(torch.randn(1, device=device, dtype=dtype)) + sample_param2 = Parameter(torch.randn(1, device=device, dtype=dtype)) return [ ErrorOptimizerInput( OptimizerInput( @@ -281,6 +284,28 @@ def get_error_inputs_for_all_optims(device, dtype): error_type=ValueError, error_regex="Tensor lr must be 1-element", ), + ErrorOptimizerInput( + OptimizerInput( + params=[("weight", sample_param), sample_param2], + kwargs={}, + desc="all optimizer params should be with/without names", + ), + error_type=ValueError, + error_regex="all optimizer params should be with/without names. Some param names are missing", + ), + ErrorOptimizerInput( + OptimizerInput( + params=[ + {"params": [sample_param], "lr": 1e-2}, + {"params": [("weight", sample_param2)]}, + ], + kwargs={}, + desc="all optimizer param groups should be with/without names.", + ), + error_type=ValueError, + error_regex="all optimizer param groups should be with/without names. " + "cannot add param group with names to the optimizer", + ), ] else: return [] diff --git a/torch/testing/_internal/common_pruning.py b/torch/testing/_internal/common_pruning.py index e8a64dfcc3c37..affb0616c9231 100644 --- a/torch/testing/_internal/common_pruning.py +++ b/torch/testing/_internal/common_pruning.py @@ -362,7 +362,7 @@ def __init__( self.linear = nn.Linear(hidden_dim, output_dim) def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - output, hidden = self.lstm(input) + output, _hidden = self.lstm(input) decoded = self.linear(output) return decoded, output diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py index 3435899e5849e..5deaf4e8efcd2 100644 --- a/torch/testing/_internal/common_quantization.py +++ b/torch/testing/_internal/common_quantization.py @@ -67,7 +67,6 @@ import copy import io import functools -import time import os import unittest @@ -125,7 +124,7 @@ def test_only_eval_fn(model, calib_data): input Tensors and run the model on the dataset """ for inp in calib_data: - output = model(*inp) + model(*inp) _default_loss_fn = torch.nn.CrossEntropyLoss() def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): @@ -135,7 +134,7 @@ def test_only_train_fn(model, train_data, loss_fn=_default_loss_fn): """ optimizer = torch.optim.Adam(model.parameters(), lr=0.001) train_loss, correct, total = 0, 0, 0 - for i in range(10): + for _ in range(10): model.train() for data, target in train_data: @@ -194,7 +193,6 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_bat model.train() cnt = 0 for image, target in data_loader: - start_time = time.time() print('.', end='') cnt += 1 image, target = image.to(device), target.to(device) @@ -203,7 +201,7 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, ntrain_bat optimizer.zero_grad() loss.backward() optimizer.step() - acc1, acc5 = accuracy(output, target, topk=(1, 5)) + accuracy(output, target, topk=(1, 5)) if cnt >= ntrain_batches: return return @@ -1183,7 +1181,8 @@ def _create_quantized_model(self, model_class: Type[torch.nn.Module], **kwargs): # Creates quantized model for testing mobile script modules qengine = "qnnpack" with override_quantized_engine(qengine): - qconfig = torch.ao.quantization.get_default_qconfig(qengine) + # FIXME(rec): shouldn't qconfig be passed to quantize? + qconfig = torch.ao.quantization.get_default_qconfig(qengine) # noqa: F841 model = model_class(**kwargs) model = quantize(model, test_only_eval_fn, [self.calib_data]) @@ -2374,7 +2373,7 @@ def __init__(self) -> None: self.conv1 = nn.Conv2d(3, 3, 1) self.relu1 = nn.ReLU(inplace=False) layers = [] - for i in range(3): + for _ in range(3): layers.append(ConvBNReLU()) self.features = nn.Sequential(*layers) head = [nn.Linear(300, 10), nn.ReLU(inplace=False)] diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 71805be228714..764a2fc6f3c01 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -74,6 +74,7 @@ from torch._C import ScriptDict, ScriptList # type: ignore[attr-defined] from torch._dynamo.trace_rules import _as_posix_path from torch._utils_internal import get_writable_path +from torch._logging.scribe import open_source_signpost from torch.nn import ( ModuleDict, ModuleList, @@ -97,6 +98,7 @@ from torch.testing._internal.common_dtype import get_all_dtypes from torch.utils._import_utils import _check_module_exists import torch.utils._pytree as pytree +from torch.utils import cpp_extension try: import pytest has_pytest = True @@ -2220,7 +2222,7 @@ def is_iterable_of_tensors(iterable, include_empty=False): if not isinstance(t, torch.Tensor): return False - except TypeError as te: + except TypeError: return False return True @@ -2329,7 +2331,7 @@ def __exit__(self, exec_type, exec_value, traceback): discrepancy_detected = True # Query memory multiple items to ensure leak was not transient - for n in range(3): + for _ in range(3): caching_allocator_mem_allocated = torch.cuda.memory_allocated(i) bytes_free, bytes_total = torch.cuda.mem_get_info(i) driver_mem_allocated = bytes_total - bytes_free @@ -2396,6 +2398,17 @@ def print_repro_on_failure(repro_parts): sample_isolation_prefix = f"PYTORCH_OPINFO_SAMPLE_INPUT_INDEX={tracked_input.index}" repro_str = " ".join(filter(None, (sample_isolation_prefix, *repro_parts))) + + open_source_signpost( + subsystem="test_repros", + name="test_failure", + parameters=json.dumps( + { + "repro": " ".join(filter(None, (sample_isolation_prefix, *repro_parts))), + } + ), + ) + repro_msg = f""" To execute this test, run the following from the base repo dir: {repro_str} @@ -3033,6 +3046,8 @@ def _run_custom(self, result=None): if strict_mode or should_reset_dynamo: torch._dynamo.reset() + torch.compiler.set_stance("default") + # TODO: Remove this; this is grandfathered in because we suppressed errors # on test suite previously # When strict mode is False, suppress_errors is True @@ -4220,7 +4235,7 @@ def runWithPytorchAPIUsageStderr(code): # CI flag should be set in the parent process only. env.pop("CI", None) env.pop("TEST_SHOWLOCALS", None) - (stdout, stderr) = TestCase.run_process_no_exception(code, env=env) + _stdout, stderr = TestCase.run_process_no_exception(code, env=env) return stderr.decode('ascii') @@ -5365,7 +5380,7 @@ def remove_cpp_extensions_build_root(): """ Removes the default root folder under which extensions are built. """ - default_build_root = torch.utils.cpp_extension.get_default_build_root() + default_build_root = cpp_extension.get_default_build_root() if os.path.exists(default_build_root): if IS_WINDOWS: # rmtree returns permission error: [WinError 5] Access is denied @@ -5373,3 +5388,24 @@ def remove_cpp_extensions_build_root(): subprocess.run(["rm", "-rf", default_build_root], stdout=subprocess.PIPE) else: shutil.rmtree(default_build_root, ignore_errors=True) + +# Decorator to provide a helper to load inline extensions to a temp directory +def scoped_load_inline(func): + + @wraps(func) + def wrapper(*args, **kwargs): + def load_inline(*args, **kwargs): + if IS_WINDOWS: + # TODO(xmfan): even using TemporaryDirectoryName will result in permission error + return cpp_extension.load_inline(*args, **kwargs) + + assert "build_directory" not in kwargs + with TemporaryDirectoryName() as temp_dir_name: + if kwargs.get("verbose", False): + print(f'Using temporary extension directory {temp_dir_name}...', file=sys.stderr) + kwargs["build_directory"] = temp_dir_name + return cpp_extension.load_inline(*args, **kwargs) + + return func(*args, load_inline=load_inline, **kwargs) + + return wrapper diff --git a/torch/testing/_internal/composite_compliance.py b/torch/testing/_internal/composite_compliance.py index ab1a05d4fef40..c0ce944c641d0 100644 --- a/torch/testing/_internal/composite_compliance.py +++ b/torch/testing/_internal/composite_compliance.py @@ -413,8 +413,8 @@ def unwrap(e): def gather_leaf_tensors(args, kwargs): leaf_tensors = [] - args, args_spec = tree_flatten(args) - kwargs, kwargs_spec = tree_flatten(kwargs) + args, _args_spec = tree_flatten(args) + kwargs, _kwargs_spec = tree_flatten(kwargs) args = args + kwargs for arg in args: if not isinstance(arg, torch.Tensor): diff --git a/torch/testing/_internal/custom_op_db.py b/torch/testing/_internal/custom_op_db.py index f15e8312aa5a4..c457a423e0e65 100644 --- a/torch/testing/_internal/custom_op_db.py +++ b/torch/testing/_internal/custom_op_db.py @@ -41,7 +41,7 @@ def _(x): def numpy_cube_setup_context(ctx, inputs, output): x, = inputs - cube, dx = output + _cube, dx = output ctx.save_for_backward(x, dx) def numpy_cube_backward(ctx, grad_out, grad_dx): @@ -131,7 +131,7 @@ def _(x, dim): return torch.empty_like(x), torch.empty_like(x, dtype=torch.long), torch.empty_like(x, dtype=torch.long) def numpy_sort_setup_context(ctx, inputs, output): - out, ind, ind_inv = output + _out, ind, ind_inv = output ctx.dim = inputs[1] ctx.save_for_backward(ind, ind_inv) ctx.mark_non_differentiable(ind, ind_inv) @@ -167,7 +167,7 @@ def _(x, ind, ind_inv, dim): return torch.empty_like(x) def numpy_take_setup_context(ctx, inputs, output): - x, ind, ind_inv, dim = inputs + _x, ind, ind_inv, dim = inputs ctx.dim = dim ctx.save_for_backward(ind, ind_inv) diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 8514be7979190..1a69cdcd2f5e4 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -307,23 +307,28 @@ def backend(self) -> str: def build_device_mesh(self) -> DeviceMesh: return DeviceMesh(self.device_type, list(range(self.world_size))) - def init_pg(self) -> None: + def init_pg(self, eager_init) -> None: if "nccl" in self.backend and torch.cuda.device_count() < self.world_size: sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code) if self.backend not in ["nccl", "gloo", "mpi", "cpu:gloo,cuda:nccl"]: raise RuntimeError(f"Backend {self.backend} not supported!") + if "nccl" in self.backend: + # set device for nccl pg for collectives + torch.cuda.set_device(self.rank) + + # For nccl backend, bind the device to the process if device_id is not None + # so the nccl communicator is immediately formed and we can use `ncclCommSplit` + # for form subgroup to avoid unnecesssary overhead. dist.init_process_group( backend=self.backend, world_size=self.world_size, rank=self.rank, # pyre-ignore[16] init_method=f"file://{self.file_name}", # pyre-ignore[16] + device_id=(torch.device(f"{self.device_type}:{self.rank}") if eager_init else None), ) - # set device for nccl pg for collectives - if "nccl" in self.backend: - torch.cuda.set_device(self.rank) def destroy_pg(self) -> None: # Wait for all ranks to reach here before starting shutdown. @@ -356,30 +361,34 @@ def run_subtests(self, *args, **kwargs): # wrapper to initialize comms (processgroup) -def with_comms(func: TestFunc) -> TestFunc: - assert func is not None +def with_comms(eager_init: bool = False) -> TestFunc: - @wraps(func) # pyre-ignore[6] - def wrapper( - self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc] - ) -> None: - # if enough GPU we can use GPU, otherwise we fallback to CPU - if not torch.cuda.is_available() or torch.cuda.device_count() < self.world_size: - self.device_type = "cpu" - else: - self.device_type = DEVICE_TYPE + def decorator(func): - self.init_pg() + @wraps(func) # pyre-ignore[6] + def wrapper( + self, *args: Tuple[object], **kwargs: Dict[str, Any] # type: ignore[misc] + ) -> None: + # if enough GPU we can use GPU, otherwise we fallback to CPU + if not torch.cuda.is_available() or torch.cuda.device_count() < self.world_size: + self.device_type = "cpu" + else: + self.device_type = DEVICE_TYPE - try: - func(self, *args, **kwargs) # type: ignore[misc] - except Exception as e: - dist.destroy_process_group() - raise e + self.init_pg(eager_init) + + try: + func(self, *args, **kwargs) # type: ignore[misc] + except Exception as e: + dist.destroy_process_group() + raise e + + self.destroy_pg() + + return wrapper - self.destroy_pg() + return decorator - return wrapper class DTensorOpTestBase(MultiThreadedTestCase): diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index bbbec92153df5..53be4c081c175 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -894,50 +894,6 @@ def test_barrier_timeout_full_group(self): if group_id is not None: self._test_barrier_timeout(group_id, timeout) - # This test helper can only be used when using the Gloo or NCCL backend - # **and** both the Gloo and NCCL backends are available. - # See the @skip annotations below. - def _test_group_override_backend(self, initializer): - if BACKEND == "gloo": - new_backend = "nccl" - elif BACKEND == "nccl": - new_backend = "gloo" - elif BACKEND in DistTestCases.backend_feature["plugin"]: - new_backend = "gloo" - - group, group_id, rank = initializer(backend=new_backend) - if group_id is None: - return - - if new_backend == "gloo": - self.assertTrue(group_id._get_backend_name(), "gloo") - if new_backend == "nccl": - self.assertTrue(group_id._get_backend_name(), "nccl") - - self.assertEqual(rank, group[dist.get_rank(group_id)]) - self.assertEqual(len(group), dist.get_world_size(group_id)) - - # Pin device (so we avoid NCCL race conditions/deadlocks). - group_rank = dist.get_rank(group_id) - torch.cuda.set_device(group_rank) - - # Run broadcast of CUDA tensor (so it works for both Gloo and NCCL). - tensor = _build_tensor(2, value=group_rank).cuda() - dist.broadcast(tensor, src=group[0], group=group_id) - self.assertEqual(_build_tensor(2, value=0), tensor.to("cpu")) - - @require_backend_is_available(DistTestCases.backend_feature["gpu"]) - @require_world_size(3) - @skip_if_lt_x_gpu(2) - def test_backend_group(self): - self._test_group_override_backend(self._init_group_test) - - @require_backend_is_available(DistTestCases.backend_feature["gpu"]) - @skip_if_lt_x_gpu(2) - @unittest.skipIf(BACKEND == "ucc", "broken, see https://github.com/pytorch/pytorch/pull/113620") - def test_backend_full_group(self): - self._test_group_override_backend(self._init_full_group_test) - @skip_but_pass_in_sandcastle_if( BACKEND not in DistTestCases.backend_feature["subgroup"], f"The {BACKEND} backend does not support creating subgroups on CUDA devices", @@ -984,7 +940,7 @@ def test_new_subgroups_world_size_not_divisible_by_group_size(self): @require_world_size(4) @skip_if_lt_x_gpu(4) def test_new_subgroups_by_enumeration(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) device_id = rank_to_GPU[rank][0] cur_subgroup, subgroups = dist.new_subgroups_by_enumeration( @@ -1010,9 +966,8 @@ def test_new_subgroups_by_enumeration(self): @require_world_size(4) @skip_if_lt_x_gpu(4) def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self): - group, group_id, rank = self._init_global_test() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - device_id = rank_to_GPU[rank][0] + _group, group_id, _rank = self._init_global_test() + init_multigpu_helper(dist.get_world_size(), BACKEND) world_size = get_world_size(group_id) with self.assertRaisesRegex( @@ -1029,7 +984,7 @@ def test_new_subgroups_by_enumeration_input_rank_exceeds_world_size(self): ) @skip_if_no_gpu def test_new_subgroups_by_enumeration_negative_input_rank(self): - group, group_id, rank = self._init_global_test() + self._init_global_test() with self.assertRaisesRegex( ValueError, @@ -1426,7 +1381,6 @@ def test_batch_isend_irecv_ring_exchange_nccl(self): rank_to_GPU = init_multigpu_helper(world_size, BACKEND) device_id = rank_to_GPU[rank][0] torch.cuda.set_device(device_id) - p2p_op_list = [] send_tensor = _build_tensor(world_size, device_id=device_id) recv_tensor = _build_tensor(world_size, value=-1, device_id=device_id) @@ -1577,8 +1531,7 @@ def test_batch_isend_irecv_op_list_err(self): def test_batch_isend_irecv_mixed_backend_err(self): self._barrier() rank = dist.get_rank() - rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) - device_id = rank_to_GPU[rank][0] + init_multigpu_helper(dist.get_world_size(), BACKEND) group_gloo = dist.new_group(ranks=[0, 1], backend="gloo") group_nccl = dist.new_group(ranks=[0, 1], backend="nccl") if rank == 0: @@ -2597,7 +2550,7 @@ def call_dist_op( # TODO: move this test to use torch.profiler once kineto issues are # fixed internally. - with autograd_profiler_ctx as prof: + with autograd_profiler_ctx: works = [op_call() for op_call in op_calls] if is_async: for work in works: @@ -2788,7 +2741,7 @@ def test_all_reduce_complex_unsupported_ops(self): dist.ReduceOp.BOR, dist.ReduceOp.BXOR, ] - group, group_id, rank = self._init_global_test() + _group, group_id, _rank = self._init_global_test() for unsupported_op in unsupported_ops: with self.assertRaisesRegex( ValueError, "all_reduce does not support" @@ -2954,12 +2907,12 @@ def test_all_reduce_full_group_max(self): # SPARSE ALL REDUCE def _test_sparse_all_reduce_sum(self, fn): - group, group_id, rank = self._init_global_test() + _group, group_id, rank = self._init_global_test() tests = simple_sparse_reduce_tests( rank, dist.get_world_size(), num_inputs=1 ) - for (inputs, outputs) in tests: + for inputs, outputs in tests: tensors = [fn(input) for input in inputs] dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id) self.assertEqual(tensors[0], outputs[0]) @@ -3022,7 +2975,7 @@ def _all_reduce_coalesced_max_test_cases(group_size): BACKEND == "nccl", "Nccl does not support CPU tensors" ) def test_all_reduce_coalesced_max_complex_unsupported(self): - group, group_id, rank = self._init_global_test() + _group, group_id, _rank = self._init_global_test() with self.assertRaisesRegex(ValueError, "all_reduce does not support"): dist.all_reduce_coalesced( [_build_tensor(1, dtype=torch.cfloat)], dist.ReduceOp.MAX, group_id @@ -3238,7 +3191,7 @@ def _test_scatter_helper( BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" ) def test_scatter_checks(self): - group, group_id, rank = self._init_global_test() + group, _group_id, rank = self._init_global_test() one = torch.ones([1]) # Specify scatter_list argument only on source rank. @@ -3357,7 +3310,7 @@ def _test_gather_helper( BACKEND == "ucc", "CPU tensor ops not supported by UCP TL" ) def test_gather_checks(self): - group, group_id, rank = self._init_global_test() + group, _group_id, rank = self._init_global_test() one = torch.ones([1]) # Specify gather_list argument only on destination rank. @@ -4351,7 +4304,7 @@ def _test_DistributedDataParallel( def _test_DistributedDataParallelCPU(self, gradient_as_bucket_view=False): # Run a simple end to end DDP-CPU model, use result of single node # model as baseline - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() # cpu training setup model_base = DDP_NET @@ -4420,7 +4373,7 @@ def __init__(self) -> None: self.net2 = nn.Linear(10, 0) model = ToyModel().to(self.rank) - ddp_model = nn.parallel.DistributedDataParallel( + nn.parallel.DistributedDataParallel( model, device_ids=[self.rank] ) @@ -4537,7 +4490,7 @@ def test_ddp_comm_hook_logging(self): # Hook not registered yet, so should be empty self.assertEqual(ddp_logging_data.get("comm_hook"), None) # After second forward pass, hook should still be empty string - for i in range(2): + for _ in range(2): inp = torch.ones(1, 1, device=self.rank) loss = ddp_model(inp).sum() loss.backward() @@ -4638,7 +4591,7 @@ def _test_ddp_hook_with_optimizer_parity( ) # Run optimizer with hook model. - for i in range(6): + for _ in range(6): ddp_model_with_optimizer_hook.zero_grad() out = ddp_model_with_optimizer_hook(inp) loss = out.sum() @@ -4647,7 +4600,7 @@ def _test_ddp_hook_with_optimizer_parity( dist.barrier() # Run regular model. - for i in range(6): + for _ in range(6): ddp_model_with_no_hook.zero_grad() out = ddp_model_with_no_hook(inp) loss = out.sum() @@ -4768,7 +4721,7 @@ def test_get_data_parallel_params(self): torch.nn.parallel.DistributedDataParallel._set_params_and_buffers_to_ignore_for_model( model, params_to_ignore ) - ddp_model = torch.nn.parallel.DistributedDataParallel( + torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.rank] ) dp_params = torch.nn.parallel.DistributedDataParallel._get_data_parallel_params( @@ -5018,7 +4971,7 @@ def forward(self_, x): # noqa: B902 self.assertEqual(mp_config.param_dtype, p._mp_param.dtype) self.assertEqual(torch.float32, p._fp_param.dtype) - for i in range(6): + for _ in range(6): loss = net(inp).sum() loss.backward() # Verify gradient synchronization and params and grads are fp32. @@ -5269,7 +5222,7 @@ def _test_accumulate_gradients_no_sync( to the ``ddp_model``. The hook fed into this function should not change the resulting gradients. """ - group, group_id, rank = self._init_global_test() + _group, group_id, rank = self._init_global_test() world_size = get_world_size() # FIXME: Add testing for gloo/CUDA @@ -5455,7 +5408,7 @@ def add(fut): ) @skip_if_no_gpu def test_DistributedDataParallel(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() rank_to_GPU = init_multigpu_helper(dist.get_world_size(), BACKEND) gpus = list(rank_to_GPU[rank]) @@ -5845,7 +5798,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Channels_Last(self): def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format( self, memory_format ): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() num_processes = dist.get_world_size() local_bs = 2 bs_offset = int(rank * 2) @@ -5896,7 +5849,7 @@ def _test_DistributedDataParallel_SyncBatchNorm_with_memory_format( ) @skip_if_no_gpu def test_DistributedDataParallel_SyncBatchNorm(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() world_size = dist.get_world_size() # DDP does not support replicating BN layers within a process, hence # testing with one module replica per process @@ -5941,7 +5894,7 @@ def test_DistributedDataParallel_SyncBatchNorm(self): ) @skip_if_no_gpu def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() world_size = dist.get_world_size() # DDP does not support replicating BN layers within a process, hence # testing with one module replica per process @@ -5966,7 +5919,7 @@ def test_DistributedDataParallel_SyncBatchNorm_No_Affine(self): ) @skip_if_no_gpu def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() # DDP does not support replicating BN layers within a process, hence # testing with one module replica per process gpus = [rank] @@ -6013,7 +5966,7 @@ def test_DistributedDataParallel_SyncBatchNorm_2D_Input(self): @skip_if_no_gpu @require_world_size(2) def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() # DDP does not support replicating BN layers within a process, hence # testing with one module replica per process gpus = [rank] @@ -6061,7 +6014,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Single_Input_Per_Process(self): def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( self, ): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() model = nn.parallel.DistributedDataParallel( ONLY_SBN_NET.cuda(rank), device_ids=[rank] ) @@ -6102,13 +6055,11 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_Running_Value( ) @skip_if_no_gpu def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() # only do single GPU per process gpus = [rank] # cpu training setup - model = BN_NET - num_processes = dist.get_world_size() local_bs = rank + 2 bs_offset = int((rank + 3) * rank / 2) @@ -6128,7 +6079,7 @@ def test_DistributedDataParallel_SyncBatchNorm_Diff_Input_Sizes_gradient(self): ) @skip_if_no_gpu def test_DistributedDataParallel_SyncBatchNorm_half(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() model = copy.deepcopy(BN_NET) model = model.half() @@ -6219,7 +6170,7 @@ def parse_env(var): return os.environ[var] if var in os.environ else "N/A" dist.set_debug_level(dist.DebugLevel.INFO) - group, group_id, rank = self._init_global_test() + _, group_id, _ = self._init_global_test() model_DDP = self._test_ddp_logging_data(is_gpu=False) ddp_logging_data = model_DDP._get_ddp_logging_data() @@ -6366,7 +6317,7 @@ def parse_env(var): ) @skip_if_no_gpu def test_ddp_logging_data_gpu(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() model_DDP = self._test_ddp_logging_data(is_gpu=True) ddp_logging_data = model_DDP._get_ddp_logging_data() self.assertEqual(ddp_logging_data.get("device_ids"), str(rank)) @@ -6424,7 +6375,7 @@ def test_static_graph_api_cpu(self): expected_err = "should be called before training loop starts" with self.assertRaisesRegex(RuntimeError, expected_err): local_bs = 2 - batch_size, input, target, loss = self._prepare_dummy_data(local_bs) + _batch_size, input, target, loss = self._prepare_dummy_data(local_bs) offset = dist.get_rank() * local_bs # DDP training, DDP scatters subsets of input to nodes/GPUs @@ -6906,7 +6857,7 @@ def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None): profiler_ctx2 = copy.deepcopy(profiler_ctx) with profiler_ctx as prof: - for i in range(num_iters): + for _ in range(num_iters): loss = net(inp).sum() loss.backward() @@ -6934,7 +6885,7 @@ def _test_ddp_profiling(self, profiler_ctx, profiler_ctx2=None): device_ids=[self.rank], find_unused_parameters=True, ) - for i in range(3): + for _ in range(3): loss = net(inp).sum() loss.backward() # Now enable the profiler. @@ -7071,7 +7022,7 @@ def test_ddp_profiling_execution_trace(self): activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], execution_trace_observer=et ) - prof = self._test_ddp_profiling( + self._test_ddp_profiling( profiler_ctx=torch_profiler_ctx1, profiler_ctx2=torch_profiler_ctx2, ) @@ -7117,7 +7068,7 @@ def test_ddp_join_model_equivalence(self): model.parameters(), lr=learning_rate * dist.get_world_size() ) with net.join(): - for i in range(num_iters): + for _ in range(num_iters): ddp_optim.zero_grad() out = net(inp) loss = out.sum() @@ -7287,7 +7238,7 @@ def forward(self, x): n = 0 with exception_ctx: with model.join(throw_on_early_termination=True): - for i in range(num_iters): + for _ in range(num_iters): loss = model(model_input).sum() loss.backward() self._model_step(model) @@ -7668,7 +7619,6 @@ def forward(self, x): "ignore_buffer", torch.zeros(5 + self.rank, device=self.rank) ) proxy_params = list(model.fc2.parameters()) - proxy_buffers = list(model.fc2.buffers()) model_fc2_name = next( module_name for module_name, module in model.named_modules() @@ -7702,7 +7652,7 @@ def forward(self, x): local_model = copy.deepcopy(ddp.module).cuda(self.rank) inp = torch.ones(1, dtype=torch.float).to(device_id) * (self.rank + 1) - for i in range(6): + for _ in range(6): ddp(inp).sum().backward() local_model(inp).sum().backward() @@ -7816,7 +7766,7 @@ def forward(self, x): static_graph=static, ) inp = torch.randn(20, 10, device=self.rank) - for i in range(6): + for _ in range(6): loss = ddp_model(inp) # To test https://github.com/pytorch/pytorch/issues/61982 loss /= 10 @@ -7825,7 +7775,6 @@ def forward(self, x): @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @skip_if_lt_x_gpu(2) def test_ddp_device(self): - m = nn.Linear(10, 10).to(self.rank) expected_len = 2 class TensorWrapper: @@ -7963,7 +7912,7 @@ def forward(self_, input, expected_type): # noqa: B902 @require_backend_is_available({"gloo"}) def test_grads_same_across_ranks_with_no_sync(self): - group, group_id, rank = self._init_global_test() + _group, _group_id, rank = self._init_global_test() world_size = dist.get_world_size() if world_size < 2: self.skipTest("This test requires at least two ranks.") @@ -8122,7 +8071,6 @@ def test_ddp_control_flow_same_across_ranks(self): @require_backend_is_available(DistTestCases.backend_feature["gpu"]) @skip_if_lt_x_gpu(2) def test_invalid_static_graph(self): - world_size = dist.get_world_size() torch.cuda.set_device(self.rank) model = torch.nn.parallel.DistributedDataParallel( ControlFlowToyModel().cuda(self.rank), @@ -8342,11 +8290,11 @@ def _test_compute_bucket_assignment_by_size(self, use_logger): self._generate_sparse_tensors_for_bucket_assignment_test() ) if use_logger: - result = dist._compute_bucket_assignment_by_size( + dist._compute_bucket_assignment_by_size( tensors_sparse, [400], logger=net.logger ) else: - result = dist._compute_bucket_assignment_by_size( + dist._compute_bucket_assignment_by_size( tensors_sparse, [400] ) if use_logger: @@ -8496,7 +8444,7 @@ def test_ddp_model_diff_shape_across_ranks(self): backend=dist.get_backend(), timeout=timedelta(seconds=10) ) torch.cuda.set_device(self.rank) - ctx, expected_err = self._determine_expected_error_verify_model_across_rank( + ctx, _expected_err = self._determine_expected_error_verify_model_across_rank( group_to_use ) # Creates network with different sized embedding table on different @@ -8522,7 +8470,7 @@ def test_ddp_model_diff_num_params_across_ranks(self): backend=dist.get_backend(), timeout=timedelta(seconds=10) ) torch.cuda.set_device(self.rank) - ctx, expected_err = self._determine_expected_error_verify_model_across_rank( + ctx, _expected_err = self._determine_expected_error_verify_model_across_rank( group_to_use, diff_num_params=True ) @@ -8706,7 +8654,6 @@ def forward(self, x): return F.relu(self.lin1(x)) torch.manual_seed(31415) - world_size = dist.get_world_size() torch.cuda.set_device(self.rank) model = ToyModel(self.rank).cuda(self.rank) ddp_model = torch.nn.parallel.DistributedDataParallel( @@ -8717,7 +8664,7 @@ def forward(self, x): static_graph=static_graph, ) random_input = torch.randn(20, 10, device=self.rank) - for i in range(10): + for _ in range(10): out = ddp_model(random_input) loss = out.sum() loss.backward() @@ -9046,9 +8993,7 @@ def forward(self, x): if ignore_sparse: for module_name, module in model.named_modules(): if module == model.sub_module.embedding_net.embedding: - for parameter_name, param in module.named_parameters( - recurse=False - ): + for parameter_name, _param in module.named_parameters(recurse=False): fqn = f"{module_name}.{parameter_name}" sparse_embedding_fqns.append(fqn) @@ -9069,7 +9014,7 @@ def forward(self, x): fqn_to_param_index = {} index = 0 for module_name, module in model.named_modules(): - for parameter_name, param in module.named_parameters(recurse=False): + for parameter_name, _param in module.named_parameters(recurse=False): fqn = f"{module_name}.{parameter_name}" fqn_to_param_index[fqn] = index if fqn not in sparse_embedding_fqns: @@ -9204,7 +9149,7 @@ def test_ddp_sync_bn_training_vs_eval(self): model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank]) # Test sync occurs in training mode. with torch.autograd.profiler.profile() as prof: - for i in range(6): + for _ in range(6): inp = torch.randn(10, 2, 4, 4).cuda(rank) out = model(inp) loss = out.sum() @@ -9224,7 +9169,7 @@ def test_ddp_sync_bn_training_vs_eval(self): if self.rank == 0: model_inference.eval() with torch.autograd.profiler.profile() as prof: - for i in range(6): + for _ in range(6): inp = torch.randn(10, 2, 4, 4).cuda(rank) out = model_inference(inp) loss = out.sum() @@ -9331,7 +9276,7 @@ def get_loss(model_output): "dict": dict, } for output_type in type_mapping.keys(): - for i in range(6): + for _ in range(6): out = model(inp, output_type=output_type) loss = get_loss(out) loss.backward() @@ -9380,7 +9325,7 @@ def forward(self, x): find_unused_parameters=find_unused, static_graph=static_graph, ) - for i in range(6): + for _ in range(6): out = ddp(inp) self.assertFalse(out[0].requires_grad) o = (out[0] + out[1]).sum() @@ -9546,7 +9491,7 @@ def buffer_comm_hook(ddp, named_buffers): broadcast_buffers=False, ) inp = torch.randn(2, 10, device=rank) - for i in range(2): + for _ in range(2): loss_hook = model_ddp(inp).sum() # Since buffer reduction is done pre-forward, simulate it for # no hook case here. @@ -9626,7 +9571,7 @@ def buffer_comm_hook(ddp, named_buffers): device_ids=[self.rank], ) inp = torch.randn(2, 10, device=rank) - for i in range(2): + for _ in range(2): loss_hook = model_ddp(inp).sum() loss_no_hook = model_ddp_no_hook(inp).sum() self._verify_buffers_equal(model_ddp, model_ddp_no_hook) @@ -9737,46 +9682,11 @@ def forward(self, inp): ddp._check_reducer_finalized() ddp(input) - @skip_if_lt_x_gpu(2) - @skip_but_pass_in_sandcastle_if( - BACKEND != "nccl", - "TORCH_NCCL_USE_COMM_NONBLOCKING only applies to NCCL" - ) - def test_nccl_init_abort(self): - """ - Tests that we can abort a NCCL communicator during initialization and - recover appropriately. - """ - # Reinitialize global process group with TORCH_NCCL_USE_COMM_NONBLOCKING=1 - os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" - dist.destroy_process_group() - timeout = timedelta(seconds=1) - dist.init_process_group( - init_method=INIT_METHOD, - backend=BACKEND, - world_size=int(os.environ["WORLD_SIZE"]), - rank=self.rank, - timeout=timeout, - ) - - # Abort pg in background thread. - running = True - - def abort(device): - pg = _get_default_group() - while running: - pg._get_backend(torch.device(device))._shutdown() - time.sleep(1) - - if self.rank != 1: - import threading - t = threading.Thread(target=abort, args=(self.rank,)) - t.start() - with self.assertRaises(RuntimeError): - # First collective triggers initialization via ncclCommInitRank. - torch.distributed.barrier() - running = False - t.join() + """ + # The set of "test_ddp_update_process_group..." below failed after + # upgrading CI from 2 GPUs to 4 GPUs. + # Commented out for now. + # Test purpose needs better documentation. def _run_ddp_update_process_group(self, new_pg): def get_num_torch_recompiles(): @@ -9960,7 +9870,7 @@ def test_ddp_update_process_group_no_find_unused(self): find_unused_parameters=False, ) ddp._update_process_group(_get_default_group()) - + """ @skip_if_lt_x_gpu(2) @skip_but_pass_in_sandcastle_if( @@ -9989,7 +9899,7 @@ def forward(self, x): device_ids=[self.rank], ) inp = torch.randn(2, 10, device=rank) - for i in range(2): + for _ in range(2): if rank == 0: model_ddp.module.buffer = model_ddp.module.buffer + 1 loss = model_ddp(inp).sum() @@ -10034,18 +9944,19 @@ def forward(self, x): b = model(inp) loss = a.sum() + b.sum() loss.backward() - # Grads should be equal to a local model that ran through inp twice and averaged grads + # Grads should be equal to a local model that ran through inp + # `world_size` times and averaged grads if self.rank == 0: inp_clone = inp.clone() - for _ in range(2): + iters = dist.get_world_size() + for _ in range(iters): a = local_model(inp_clone) b = local_model(inp_clone) loss = a.sum() + b.sum() loss.backward() - ws = dist.get_world_size() for p in local_model.parameters(): - p.grad.data = p.grad / dist.get_world_size() + p.grad.data = p.grad / iters for p_ddp, p_local in zip( model.parameters(), @@ -10415,7 +10326,7 @@ def forward(self, input): ddp._set_ddp_sink_clone(False) input = torch.rand(10, 10).cuda(self.rank) - with OpPatcher() as patcher: + with OpPatcher(): ddp(input).sum().backward() diff --git a/torch/testing/_internal/distributed/nn/api/remote_module_test.py b/torch/testing/_internal/distributed/nn/api/remote_module_test.py index ee3a374c745df..7a889b0db8473 100644 --- a/torch/testing/_internal/distributed/nn/api/remote_module_test.py +++ b/torch/testing/_internal/distributed/nn/api/remote_module_test.py @@ -12,7 +12,7 @@ from torch.distributed.nn.api.remote_module import _REMOTE_MODULE_PICKLED_ATTRIBUTES from torch.distributed.nn.api.remote_module import _RemoteModule from torch.testing._internal.common_distributed import skip_if_lt_x_gpu -from torch.testing._internal.common_utils import TemporaryFileName +from torch.testing._internal.common_utils import TemporaryFileName, TEST_WITH_ROCM from torch.testing._internal.distributed.rpc.rpc_agent_test_fixture import ( RpcAgentTestFixture, ) @@ -535,7 +535,7 @@ def test_send_remote_module_over_the_wire_script_not_supported(self): dst_worker1_name, modes=[ModuleCreationMode.MODULE_CTOR_WITH_INTERFACE] ): # Test querying some simple attributes from worker2. - attrs = rpc.rpc_sync( + rpc.rpc_sync( dst_worker2_name, remote_module_attributes, (remote_module,) ) @@ -563,7 +563,7 @@ def test_create_remote_module_from_module_rref(self): ret2 = rpc.rpc_sync( dst_worker2_name, remote_forward, (remote_module2, args) ) - self.assertEqual(ret2, ret2) + self.assertEqual(ret1, ret2) class CudaRemoteModuleTest(CommonRemoteModuleTest): @@ -613,8 +613,15 @@ def test_invalid_devices(self): ) ] + if TEST_WITH_ROCM: + errorString = (r"HIP error: invalid device ordinal\n" + r"HIP kernel errors might be asynchronously reported at some other API call, " + r"so the stacktrace below might be incorrect.\n" + r"For debugging consider passing AMD_SERIALIZE_KERNEL=3") + else: + errorString = r"CUDA error: invalid device ordinal" with self.assertRaisesRegex( - RuntimeError, r"CUDA error: invalid device ordinal" + RuntimeError, errorString ): [ m.forward() diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py index 0a6b9a843b629..a0e934fae280f 100644 --- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py @@ -343,7 +343,6 @@ def _test_graph_for_py_nested_call(self, exec_mode, sparse): else: t1 = torch.ones(3, 3, requires_grad=True) t2 = torch.zeros(3, 3, requires_grad=True) - nest_dst_rank = (dst_rank + 1) % self.world_size if ExecMode.RPC_SYNC == exec_mode: ret = rpc.rpc_sync( worker_name(dst_rank), @@ -499,11 +498,11 @@ def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse): t1 = torch.ones(3, 3, requires_grad=False) t2 = torch.zeros(3, 3, requires_grad=False) if ExecMode.RPC_SYNC == exec_mode: - ret = rpc.rpc_sync( + rpc.rpc_sync( worker_name(dst_rank), torch.add, args=(t1, t2) ) elif ExecMode.REMOTE == exec_mode: - ret = rpc.remote( + rpc.remote( worker_name(dst_rank), torch.add, args=(t1, t2) ).to_here() else: @@ -531,7 +530,7 @@ def _test_no_graph_with_tensors_not_require_grad(self, exec_mode, sparse): dist.barrier() def _test_rpc_complex_args(self, exec_mode, sparse): - with dist_autograd.context() as context_id: + with dist_autograd.context(): num_tensors = 10 tensors = [] for i in range(num_tensors): @@ -556,7 +555,6 @@ def _test_rpc_complex_args(self, exec_mode, sparse): # Verify appropriate tensors have been attached the autograd graph. next_funcs = next(iter(dist_autograd._current_context()._send_functions().values())).next_functions - idx = 0 for i in range(len(next_funcs)): self.assertEqual( "torch::autograd::AccumulateGrad", next_funcs[i][0].name() @@ -731,7 +729,6 @@ def _test_trainer_ps(self, create_ref_fn, trainer_fn, sparse): self._check_rpc_done(rank_diff) # trainers are done and holding the context for verification - accumulate_grad_func = None for rank_diff in rank_diffs: # make sure grads are accumulated for the same tensors and values # are all correct @@ -890,7 +887,7 @@ def _multiple_backward(self, t1, t2, sparse): else: loss = loss.sum() # Run backward in a loop multiple times. - for i in range(1000): + for _ in range(1000): dist_autograd.backward(context_id, [loss], retain_graph=True) # For current context, this rank sends t1 and t2 tensors to dst_rank, @@ -1279,7 +1276,7 @@ def test_autograd_context(self): ) context_ids = [] - for i in range(200): + for _ in range(200): with dist_autograd.context() as context_id: self.assertEqual( context_id, @@ -1298,12 +1295,12 @@ def test_autograd_context(self): @dist_init def test_nested_context(self): - with dist_autograd.context() as context_id: + with dist_autograd.context(): # Nested contexts not supported. with self.assertRaisesRegex( RuntimeError, "Already have an autograd context id for this thread" ): - with dist_autograd.context() as context_id: + with dist_autograd.context(): pass @dist_init @@ -1438,7 +1435,7 @@ def test_worker_ids_recorded(self): t1.requires_grad = True t2.requires_grad = True for dst_rank in dst_ranks: - ret = rpc.rpc_sync( + rpc.rpc_sync( worker_name(dst_rank), torch.add, args=(t1, t2) ) rpc.rpc_sync( @@ -1475,7 +1472,7 @@ def get_event(partial_key): @dist_init def test_error_in_context(self): - with dist_autograd.context() as context_id: + with dist_autograd.context(): t1 = torch.rand(3, 3, requires_grad=True) t2 = torch.rand(6, 6, requires_grad=True) @@ -1651,7 +1648,7 @@ def _run_test_backward_unused_send_function_in_thread(self): # We don't use the result of an RPC function, as a result the # backward pass would hang in the "FAST" mode. - res = rpc.rpc_sync( + rpc.rpc_sync( worker_name(self._next_rank()), torch.add, args=(t1, t2) ) @@ -1757,7 +1754,6 @@ def test_backward_without_context(self): @dist_init def test_backward_without_rpc(self): - dst_rank = self.rank with dist_autograd.context() as context_id: t1 = torch.rand((3, 3), requires_grad=True) t2 = torch.rand((3, 3), requires_grad=True) @@ -2172,7 +2168,7 @@ def test_async_dist_autograd(self): if self.rank != 0: # All other ranks schedule work on rank 0. threads = [] - for i in range(20): + for _ in range(20): t = threading.Thread(target=DistAutogradTest._workload_thread) t.start() threads.append(t) @@ -2399,7 +2395,7 @@ def backward(ctx, grad): self.assertTrue(p_a == p_g) # Run backwards multiple times. - for i in range(10): + for _ in range(10): dist_autograd.backward(context_id, [loss], retain_graph=True) # non-contiguous indices and value, we should trigger a copy. @@ -2418,7 +2414,7 @@ def backward(ctx, grad): self.assertFalse(p_b == p_g) # Run backwards multiple times to verify accumulation. - for i in range(10): + for _ in range(10): dist_autograd.backward(context_id, [loss], retain_graph=True) @dist_init @@ -2550,7 +2546,7 @@ def test_gpu_to_cpu_continuation(self): t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") t2 = torch.rand(3, 3, requires_grad=True) # Run a few iterations. - for i in range(3): + for _ in range(3): t1.grad = None t2.grad = None # Root is CPU @@ -2574,7 +2570,7 @@ def test_gpu_to_cpu_continuation_gpu_root(self): t1 = torch.rand(3, 3, requires_grad=True, device="cuda:0") t2 = torch.rand(3, 3, requires_grad=True) # Run a few iterations. - for i in range(3): + for _ in range(3): t1.grad = None t2.grad = None # Root is CPU diff --git a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py index 310dc740db680..39c25260887c2 100644 --- a/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py +++ b/torch/testing/_internal/distributed/rpc/dist_optimizer_test.py @@ -135,7 +135,7 @@ def test_dist_optim_exception_on_constructor(self): remote_param2 = remote_method(MyModule.get_w, remote_module2) with self.assertRaisesRegex(Exception, "Error creating optimizer."): - dist_optim = DistributedOptimizer( + DistributedOptimizer( OptimizerFailingOnConstructor, [remote_param1, remote_param2] ) @@ -169,8 +169,6 @@ def _test_dist_optim_base(self, optim_cls, *args, **kwargs): remote_param1 = remote_method(MyModule.get_w, remote_module1) remote_param2 = remote_method(MyModule.get_w, remote_module2) - old_w1_remote = remote_param1.to_here() - # sanity check: local and remote initial weights should match self.assertEqual(old_w1, remote_param1.to_here()) self.assertEqual(old_w2, remote_param2.to_here()) diff --git a/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py b/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py index 5d7e7b1244bce..eab07be49e56b 100644 --- a/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/examples/reinforcement_learning_rpc_test.py @@ -5,7 +5,6 @@ # and https://pytorch.org/tutorials/intermediate/rpc_tutorial.html import numpy as np -from itertools import count import torch import torch.distributed.rpc as rpc @@ -109,8 +108,8 @@ def run_episode(self, agent_rref, n_steps): agent_rref (RRef): an RRef referencing the agent object. n_steps (int): number of steps in this episode """ - state, ep_reward = self.env.reset(), 0 - for step in range(n_steps): + state, _ep_reward = self.env.reset(), 0 + for _ in range(n_steps): # send the state to the agent to get an action action = _remote_method(Agent.select_action, agent_rref, self.id, state) @@ -222,9 +221,9 @@ def finish_episode(self): def run_agent(agent, n_steps): - for i_episode in count(1): + while True: agent.run_episode(n_steps=n_steps) - last_reward = agent.finish_episode() + agent.finish_episode() if agent.running_reward > agent.reward_threshold: print(f"Solved! Running reward is now {agent.running_reward}!") diff --git a/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py index a1163adb97cc8..0b69d9ff75448 100644 --- a/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/dist_autograd_test.py @@ -33,7 +33,6 @@ def fork_add(t1, t2, dst: str): class JitDistAutogradTest(RpcAgentTestFixture): @dist_init def test_get_gradients(self): - dst_rank = self.rank @torch.jit.script def dist_get_gradients(context_id: int) -> (Dict[Tensor, Tensor]): diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py index 2f83eb3311c65..4270f4bcd006f 100644 --- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py @@ -153,7 +153,7 @@ def script_add_ones_with_record_function(x, block: str): @torch.jit.script def record_function_on_caller_rpc_async(dst_worker_name: str, block: str) -> Tensor: t: Tensor = torch.ones(1) - with record_function(block) as rf: + with record_function(block): fut1 = rpc.rpc_async(dst_worker_name, script_add_ones, (t, )) # Extra operator call to avoid de-duplication of the next async call # see https://github.com/pytorch/pytorch/pull/62710#discussion_r694680279 @@ -669,8 +669,6 @@ def test_less_than_needed_args_are_specified(self): if self.rank != 0: return - dst_worker_name = worker_name((self.rank + 1) % self.world_size) - # Notice, args matching happens during scripting. with self.assertRaisesRegex(RuntimeError, "Argument second_arg not provided"): @@ -689,8 +687,6 @@ def test_more_than_needed_args_are_specified(self): if self.rank != 0: return - dst_worker_name = worker_name((self.rank + 1) % self.world_size) - # Notice, args matching happens during scripting. with self.assertRaisesRegex( RuntimeError, @@ -893,10 +889,10 @@ def test_torchscript_function(self): def test_torchscript_function_exception(self): dst_worker_name = worker_name((self.rank + 1) % self.world_size) with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"): - ret = rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20)) + rpc.rpc_sync(dst_worker_name, one_arg, args=(10, 20)) with self.assertRaisesRegex(RuntimeError, r"one_arg\(\) expected at most"): - rref = rpc.remote(dst_worker_name, one_arg, args=(10, 20)) + rpc.remote(dst_worker_name, one_arg, args=(10, 20)) @dist_init def test_torchscript_functions_not_supported(self): @@ -913,13 +909,13 @@ def test_torchscript_functions_not_supported(self): # rpc_sync still accepts script class and run it in # the same code path as python call. - ret = rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,)) + rpc.rpc_sync(dst_worker_name, MyScriptClass, args=(self.rank,)) # rpc_sync does not accept script module method. # Python 3.5 and Python 3.6 throw different error message, the only # common word can be greped is "pickle". with self.assertRaisesRegex(TypeError, "pickle"): - ret = rpc.rpc_async( + rpc.rpc_async( dst_worker_name, my_local_script_module.forward, args=() ) @@ -1070,7 +1066,6 @@ def callback(fut): @dist_init def test_callback_chain(self): n = self.rank + 1 - dst = worker_name(n % self.world_size) def callback(fut): return fut.wait() + 1 @@ -1148,7 +1143,7 @@ def test_call_rpc_with_profiling(self): "worker1", ) with torch.autograd.profiler.record_function(prof_key) as rf: - ret = call_rpc_with_profiling(rf.record, "worker1") + call_rpc_with_profiling(rf.record, "worker1") # TODO: Can't get a reliable time for this profiling event since # it's hard to estimate the execution time on the remote end for non-UDFs. # This can be resolved by https://github.com/pytorch/pytorch/issues/36272. @@ -1297,7 +1292,7 @@ def test_call_fork_in_jit_with_profiling(self): # future from within a script function with torch.jit.fork with _profile() as prof: with torch.autograd.profiler.record_function("foo") as rf: - ret = call_fork_with_profiling(rf.record) + call_fork_with_profiling(rf.record) events = prof.function_events function_event = get_function_event(events, "foo") diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index 413f97d94eb28..752370617241d 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -1130,7 +1130,7 @@ def test_worker_id(self): self.assertEqual(peer_worker_info.name, worker_name(peer_rank)) with self.assertRaisesRegex(RuntimeError, "could not find destination"): - unknown_worker_id = rpc.get_worker_info("WorkerUnknown") + rpc.get_worker_info("WorkerUnknown") @dist_init def test_get_worker_infos(self): @@ -1149,7 +1149,6 @@ def test_get_worker_infos(self): @dist_init def test_self_add(self): self_worker_info = rpc.get_worker_info() - self_worker_name = worker_name(self.rank) fut = rpc.rpc_async(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) ret = rpc.rpc_sync(self_worker_info, torch.add, args=(torch.ones(2, 2), 1)) self.assertEqual(fut.wait(), torch.ones(2, 2) + 1) @@ -1473,18 +1472,18 @@ def test_invalid_names(self): worker_id = 0 with self.assertRaisesRegex(RuntimeError, "Worker name must match"): - info = WorkerInfo("abc*", worker_id) + WorkerInfo("abc*", worker_id) with self.assertRaisesRegex(RuntimeError, "Worker name must match"): - info = WorkerInfo(" ", worker_id) + WorkerInfo(" ", worker_id) with self.assertRaisesRegex(RuntimeError, "must be non-empty"): - info = WorkerInfo("", worker_id) + WorkerInfo("", worker_id) # If the number in the message does not match, it is likely that the # value of MAX_NAME_LEN in RPC WorkerInfo has changed. with self.assertRaisesRegex(RuntimeError, "shorter than 128"): - info = WorkerInfo("".join(["a" for i in range(500)]), worker_id) + WorkerInfo("".join(["a" for i in range(500)]), worker_id) # Test that WorkerInfo can be pickled and sent in RPC call @dist_init @@ -1562,9 +1561,7 @@ def test_multi_rpc(self): @dist_init def test_future_wait_twice(self): dst = worker_name((self.rank + 1) % self.world_size) - futs = [] - for i in range(20): - futs.append(rpc.rpc_async(dst, raise_func)) + futs = [rpc.rpc_async(dst, raise_func) for _ in range(20)] with self.assertRaisesRegex(ValueError, "Expected error"): torch.futures.wait_all(futs) @@ -1724,7 +1721,7 @@ def test_shutdown_followed_by_rpc(self): def test_expected_src(self): dst_rank = (self.rank + 1) % self.world_size expected_src_rank = (self.rank - 1) % self.world_size - ret = rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,)) + rpc.rpc_sync(worker_name(dst_rank), set_value, args=(self.rank,)) value = VALUE_FUTURE.result() self.assertEqual(value, expected_src_rank) @@ -1803,7 +1800,7 @@ def test_profiler_rpc_memory(self): dst_worker = worker_name(dst) with _profile(profile_memory=True) as p: fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) - res = fut.wait() + fut.wait() function_events = p.function_events event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} @@ -1813,7 +1810,7 @@ def test_profiler_rpc_memory(self): # No memory profiled if profile_memory=False with _profile(profile_memory=False) as p: fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) - res = fut.wait() + fut.wait() function_events = p.function_events event_cpu_mem_usages = {event.cpu_memory_usage for event in function_events} @@ -1827,9 +1824,8 @@ def test_profiler_export_trace(self): dst_worker = worker_name(dst) with _profile() as p: fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) - res = fut.wait() + fut.wait() - events = p.function_events with TemporaryFileName() as fname: path = fname p.export_chrome_trace(path) @@ -1920,7 +1916,7 @@ def _run_test_profiler_remote_events_profiled(self): dst_worker = worker_name(dst) with _profile() as prof: fut = rpc.rpc_async(dst_worker, udf_with_torch_ops, args=()) - ret = fut.wait() + fut.wait() events = prof.function_events @@ -1999,7 +1995,7 @@ def _run_rpc_profiling_async_function(self, device="cpu"): ret = rpc.rpc_async( dst1, slow_async_add, args=(dst2, x, y, device), timeout=20 ) - out = ret.wait() + ret.wait() function_events = prof.function_events # slow_async_add resulted in an RPC from dst1 -> dst2, so this should be @@ -2130,7 +2126,7 @@ def _run_test_profiler_with_autograd_context(self): dst = (self.rank + 1) % self.world_size if self.rank == 1: # Cases where we can double wrap messages with profiling information and autograd info. - with dist_autograd.context() as context_id: + with dist_autograd.context(): with _profile() as prof: self.run_profiling_workload(dst) @@ -2139,7 +2135,7 @@ def _run_test_profiler_with_autograd_context(self): # Ensure that flipped order of ctx managers results in events being # recorded as expected. with _profile() as prof: - with dist_autograd.context() as context_id: + with dist_autograd.context(): self.run_profiling_workload(dst) self.validate_profiling_workload(dst, prof) @@ -2168,7 +2164,7 @@ def _profiler_test_with_rpc( "foo" ) ) - with record_function_ctx_mgr as rf: + with record_function_ctx_mgr: if rpc_exec_mode == RPCExecMode.SYNC: rpc.rpc_sync(worker_name(dst), func, args=args) elif rpc_exec_mode == RPCExecMode.ASYNC: @@ -2452,7 +2448,7 @@ def test_async_record_function_double_end_callbacks(self): num_sleep_seconds = 1 if self.rank == 1: # Validate that calling the function twice results in an error. - with _profile() as pf: + with _profile(): with torch.autograd.profiler.record_function("foo") as rf: fut = rpc.rpc_async( worker_name(0), my_sleep_func, args=(num_sleep_seconds,) @@ -2470,7 +2466,7 @@ def test_async_record_function_legacy(self): # Note: These exist for backward compatibility with TorchScript num_sleep_seconds = 1 if self.rank == 1: - with _profile() as pf: + with _profile(): try: handle = torch.ops.profiler._record_function_enter("foo", None) fut = rpc.rpc_async( @@ -2623,7 +2619,7 @@ def test_py_function_exception(self): n = self.rank + 1 dst_rank = n % self.world_size with self.assertRaises(TypeError): - ret = rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,)) + rpc.rpc_sync(worker_name(dst_rank), no_result, args=(10,)) @dist_init def test_py_raise_in_user_func(self): @@ -2840,7 +2836,7 @@ def test_rref_forward_chain(self): ret_rref = rref_forward_chain(dst_rank, self.world_size, rref, ttl) - for i in range(ttl): + for _ in range(ttl): self.assertEqual(len(ret_rref), 1) ret_rref = ret_rref[0].to_here() @@ -3125,7 +3121,7 @@ def _test_rref_leak(self, _mock_delete_all_user_and_unforked_owner_rrefs, ignore # Wait for all init to complete. dist.barrier() - rref = rpc.remote( + rref = rpc.remote( # noqa: F841 worker_name((self.rank + 1) % self.world_size), torch.add, args=(torch.ones(2, 2), 1), @@ -3556,7 +3552,7 @@ def test_wait_all_timeout(self): self.assertTrue(_thread_local_var.future_list == []) dst = worker_name((self.rank + 1) % self.world_size) timeout = 0.1 # 100 ms - fut = rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) + rpc.rpc_async(dst, my_sleep_func, args=(1,), timeout=timeout) self.assertFalse(hasattr(_thread_local_var, "future_list")) @dist_init @@ -3565,7 +3561,7 @@ def test_wait_all_raise_in_user_func(self): with _wait_all(): self.assertTrue(_thread_local_var.future_list == []) dst = worker_name((self.rank + 1) % self.world_size) - fut = rpc.rpc_async(dst, raise_func) + rpc.rpc_async(dst, raise_func) self.assertFalse(hasattr(_thread_local_var, "future_list")) @dist_init @@ -3846,7 +3842,6 @@ def callback(fut): @dist_init def test_callback_wrong_arg_num(self): - set_by_cb = concurrent.futures.Future() n = self.rank + 1 fut = rpc.rpc_async( @@ -3911,7 +3906,6 @@ def callback(idx, fut): @dist_init def test_callback_chain(self): n = self.rank + 1 - dst = worker_name(n % self.world_size) def callback(fut): return fut.wait() + 1 @@ -4030,15 +4024,15 @@ def test_pickle_future(self): errMsg = "Can not pickle torch.futures.Future" dst = worker_name((self.rank + 1) % self.world_size) - with TemporaryFileName() as fname: + with TemporaryFileName(): with self.assertRaisesRegex(RuntimeError, errMsg): rpc.rpc_sync(dst, fail_on_fut, args=(fut,)) - with TemporaryFileName() as fname: + with TemporaryFileName(): with self.assertRaisesRegex(RuntimeError, errMsg): rpc.rpc_async(dst, fail_on_fut, args=(fut,)) - with TemporaryFileName() as fname: + with TemporaryFileName(): with self.assertRaisesRegex(RuntimeError, errMsg): rpc.remote(dst, fail_on_fut, args=(fut,)) @@ -4380,7 +4374,7 @@ def test_wait_all_with_exception(self): futs.append(rpc.rpc_async(dst, raise_func)) with self.assertRaisesRegex(ValueError, "Expected error"): - ret = torch.futures.wait_all(futs) + torch.futures.wait_all(futs) @dist_init def test_wait_all_with_partial_exception(self): @@ -4392,7 +4386,7 @@ def test_wait_all_with_partial_exception(self): futs.append(rpc.rpc_async(dst, raise_func)) with self.assertRaisesRegex(ValueError, "Expected error"): - ret = torch.futures.wait_all(futs) + torch.futures.wait_all(futs) @dist_init(setup_rpc=False) @skip_but_pass_in_sandcastle_if( @@ -4717,7 +4711,7 @@ def test_tensorpipe_options_throw_on_timedelta_timeout(self): timeout = timedelta() # Ensure that constructing TensorPipeRpcBackendOptions with timedelta fails with self.assertRaisesRegex(TypeError, "incompatible constructor arguments"): - rpc_backend_options = rpc.TensorPipeRpcBackendOptions( + rpc.TensorPipeRpcBackendOptions( init_method=self.rpc_backend_options.init_method, num_worker_threads=self.rpc_backend_options.num_worker_threads, rpc_timeout=timeout, @@ -5747,9 +5741,6 @@ def test_device_maps_missing_config(self): @skip_if_lt_x_gpu(1) def test_device_maps_missing_config_not_timeout(self): - dst = worker_name((self.rank + 1) % self.world_size) - options = self.rpc_backend_options - rpc.init_rpc( name=worker_name(self.rank), backend=self.rpc_backend, @@ -5973,7 +5964,7 @@ def test_device_mismatch(self): RuntimeError, "Expected all tensors to be on the same device, but found at least two devices" ): - rets = rpc.rpc_sync( + rpc.rpc_sync( dst, TensorPipeAgentCudaRpcTest._gpu_add_wrong_gpus, args=(x, y) @@ -6284,22 +6275,22 @@ def test_devices_option_mismatch_reverse(self): @skip_if_lt_x_gpu(1) def test_cuda_future_device_as_int(self): - fut = Future(devices=[0]) + Future(devices=[0]) @skip_if_lt_x_gpu(1) def test_cuda_future_device_as_str(self): - fut = Future(devices=["cuda:0"]) + Future(devices=["cuda:0"]) @skip_if_lt_x_gpu(1) def test_cuda_future_device_as_device(self): - fut = Future(devices=[torch.device("cuda", 0)]) + Future(devices=[torch.device("cuda", 0)]) @skip_if_lt_x_gpu(1) def test_cuda_future_device_not_cuda(self): with self.assertRaisesRegex( ValueError, "Expected devices to have indices, got cpu" ): - fut = Future(devices=["cpu"]) + Future(devices=["cpu"]) @skip_if_lt_x_gpu(1) def test_cuda_future_can_extract_cuda_tensor(self): diff --git a/torch/testing/_internal/fake_config_module.py b/torch/testing/_internal/fake_config_module.py new file mode 100644 index 0000000000000..52b2c5f5e10cc --- /dev/null +++ b/torch/testing/_internal/fake_config_module.py @@ -0,0 +1,30 @@ +import sys +from typing import Optional + +from torch.utils._config_module import install_config_module + + +e_bool = True +e_int = 1 +e_float = 1.0 +e_string = "string" +e_list = [1] +e_set = {1} +e_tuple = (1,) +e_dict = {1: 2} +e_none: Optional[bool] = None +e_ignored = True +_e_ignored = True +magic_cache_config_ignored = True +# [@compile_ignored: debug] +e_compile_ignored = True + + +class nested: + e_bool = True + + +_cache_config_ignore_prefix = ["magic_cache_config"] +_save_config_ignore = ["e_ignored"] + +install_config_module(sys.modules[__name__]) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index 9c3ea935d9076..25bf3f16806a6 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -1,47 +1,58 @@ # mypy: ignore-errors -import torch import functools -from torch.testing import make_tensor import unittest + +import torch from functorch.experimental.control_flow import map -from torch.testing._internal.opinfo.core import ( - OpInfo, - SampleInput, -) -from torch.testing._internal.common_dtype import all_types_and, custom_types -from torch.testing._internal.opinfo.core import DecorateInfo +from torch.nn.attention.flex_attention import _create_empty_block_mask, flex_attention +from torch.testing import make_tensor from torch.testing._internal.common_device_type import onlyCUDA -from torch.nn.attention.flex_attention import flex_attention, _create_empty_block_mask +from torch.testing._internal.common_dtype import all_types_and, custom_types +from torch.testing._internal.opinfo.core import DecorateInfo, OpInfo, SampleInput + def sample_inputs_map(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( - make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) - yield SampleInput([make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)], - args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2))) + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput( + [make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)], + args=(make_arg(1, low=0.1, high=2), make_arg(1, low=0.1, high=2)), + ) + def inner_f(x, y0, y1): - return [x[0].cos().add_(1.) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())] + return [x[0].cos().add_(1.0) * y0, (x[1] + y1.sin()).cos_().view(x[1].size())] + def simple_map(xs, y0, y1): def f(x, y0, y1): return inner_f(x, y0, y1) + return map(f, xs, y0, y1) + def nested_map(xs, y0, y1): def f1(xx, y0, y1): def f2(x, y0, y1): return inner_f(x, y0, y1) + return map(f2, xx, y0, y1) + return map(f1, xs, y0, y1) + def triple_nested_map(xs, y0, y1): def f0(xs, y0, y1): def f1(xx, y0, y1): def f2(x, y0, y1): return inner_f(x, y0, y1) + return map(f2, xx, y0, y1) + return map(f1, xs, y0, y1) + return map(f0, xs, y0, y1) @@ -102,11 +113,27 @@ def simple_cond(x): return torch.cond(x.sum() > 2, lambda x: (x.cos(),), lambda x: (x.sin(),), [x]) +def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2)) + + +def simple_invoke_subgraph(x): + def fn(x): + return (torch.sin(x),) + + return torch._higher_order_ops.invoke_subgraph(fn, None, (x,)) + + def sample_inputs_auto_functionalize(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( make_tensor, device=device, dtype=dtype, requires_grad=False ) - yield SampleInput(make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2)) + yield SampleInput( + make_arg(2, 2, 2, low=0.1, high=2), make_arg(2, 2, 2, low=0.1, high=2) + ) def simple_auto_functionalize(x, z): @@ -123,13 +150,8 @@ def score_mod(score, b, h, m, n): q, k, v = (make_arg(2, 2, 128, 8, low=0.1, high=2) for _ in range(3)) block_mask = _create_empty_block_mask(q, k) - yield SampleInput( - q, - k, - v, - score_mod, - block_mask - ) + yield SampleInput(q, k, v, score_mod, block_mask) + def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( @@ -140,6 +162,7 @@ def sample_inputs_while_loop(opinfo, device, dtype, requires_grad, **kwargs): make_arg(2, 3, 4, low=0.1, high=2), ) + def simple_while_loop(iter_t, x): def cond_fn(iter_t, x): return iter_t > 0 @@ -150,7 +173,56 @@ def body_fn(iter_t, x): return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x)) +def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): + make_arg = functools.partial( + make_tensor, device=device, dtype=dtype, requires_grad=requires_grad + ) + yield SampleInput( + make_arg(2, 2, low=0.1, high=2), + make_arg(2, 2, 2, low=0.1, high=2), + ) + + +def simple_scan(init, xs): + + def combine_fn(carry, x): + result = carry @ x + x + return result, carry.clone() + + return torch._higher_order_ops.scan(combine_fn, init, xs) + + hop_db = [ + OpInfo( + name="scan", + variant_test_name="simple", + op=simple_scan, + sample_inputs_func=sample_inputs_scan, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=False, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), + OpInfo( + name="invoke_subgraph", + variant_test_name="simple", + op=simple_invoke_subgraph, + sample_inputs_func=sample_inputs_invoke_subgraph, + dtypes=all_types_and(torch.bool, torch.half), + supports_out=False, + check_batched_grad=False, + check_batched_gradgrad=False, + check_batched_forward_grad=False, + check_inplace_batched_forward_grad=False, + supports_autograd=True, + # "torch.compile with aot_autograd does not currently support double backward." + supports_gradgrad=False, + ), OpInfo( name="map", variant_test_name="simple", @@ -241,7 +313,9 @@ def body_fn(iter_t, x): check_inplace_batched_forward_grad=False, skips=( DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), - DecorateInfo(unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), ), @@ -260,10 +334,12 @@ def body_fn(iter_t, x): check_inplace_batched_forward_grad=False, skips=( DecorateInfo(unittest.expectedFailure, "TestHOP", "test_aot_export"), - DecorateInfo(unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export"), + DecorateInfo( + unittest.expectedFailure, "TestHOP", "test_pre_dispatch_export" + ), DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), ), decorators=[onlyCUDA], - ) + ), ] diff --git a/torch/testing/_internal/hypothesis_utils.py b/torch/testing/_internal/hypothesis_utils.py index 98aa82e1c93d2..139470ccc20e2 100644 --- a/torch/testing/_internal/hypothesis_utils.py +++ b/torch/testing/_internal/hypothesis_utils.py @@ -36,7 +36,7 @@ }) def _get_valid_min_max(qparams): - scale, zero_point, quantized_type = qparams + scale, zero_point, _quantized_type = qparams adjustment = 1 + torch.finfo(torch.float).eps _long_type_info = torch.iinfo(torch.long) long_min, long_max = _long_type_info.min / adjustment, _long_type_info.max / adjustment @@ -317,11 +317,11 @@ def tensor_conv( spatial_dim = draw(st.sampled_from(spatial_dim)) feature_map_shape = [] - for i in range(spatial_dim): + for _ in range(spatial_dim): feature_map_shape.append(draw(st.integers(*feature_map_range))) kernels = [] - for i in range(spatial_dim): + for _ in range(spatial_dim): kernels.append(draw(st.integers(*kernel_range))) tr = False diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index d8e86651fa9a5..5441ef761ce65 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -39,9 +39,11 @@ def test_cpu(): HAS_CPU = LazyVal(test_cpu) -HAS_CUDA = torch.cuda.is_available() and has_triton() +HAS_TRITON = has_triton() -HAS_XPU = torch.xpu.is_available() and has_triton() +HAS_CUDA = torch.cuda.is_available() and HAS_TRITON + +HAS_XPU = torch.xpu.is_available() and HAS_TRITON HAS_GPU = HAS_CUDA or HAS_XPU @@ -102,6 +104,7 @@ def skip_windows_ci(name: str, file: str) -> None: raise unittest.SkipTest("requires sympy/functorch/filelock") requires_gpu = functools.partial(unittest.skipIf, not HAS_GPU, "requires gpu") +requires_triton = functools.partial(unittest.skipIf, not HAS_TRITON, "requires triton") skipCUDAIf = functools.partial(skipDeviceIf, device="cuda") skipXPUIf = functools.partial(skipDeviceIf, device="xpu") diff --git a/torch/testing/_internal/jit_metaprogramming_utils.py b/torch/testing/_internal/jit_metaprogramming_utils.py index 02a9fcc5405e5..30a6b8f8e067a 100644 --- a/torch/testing/_internal/jit_metaprogramming_utils.py +++ b/torch/testing/_internal/jit_metaprogramming_utils.py @@ -508,19 +508,16 @@ def get_nn_functional_compiled_fn_and_inputs(name, self_size, args, variant_name if variant_name != '': test_name = test_name + '_' + variant_name - no_grad = variant_name == 'inplace' - self_variable = create_input((self_size,))[0][0] - kwargs = None # need to record this because methods can change the size (e.g. unsqueeze) - args_variable, kwargs_variable = create_input(args) + args_variable, _kwargs_variable = create_input(args) self_tensor = deepcopy(self_variable.data) args_tensor = deepcopy(unpack_variables(args_variable)) f_args_variable = (self_variable,) + args_variable - f_args_tensor = (self_tensor,) + args_tensor + f_args_tensor = (self_tensor,) + args_tensor # noqa: F841 with torch._jit_internal._disable_emit_hooks(): script_fn, inputs = gen_script_fn_and_args(name, "nn_functional", *f_args_variable) return script_fn, inputs @@ -589,7 +586,7 @@ def forward({}): def create_script_module(self, nn_module, constructor_args, *args, **kwargs): def script_module(*args, **kwargs): - formals, tensors, actuals = get_script_args(args) + _formals, tensors, actuals = get_script_args(args) method_args = ', '.join(['self'] + actuals) call_args_str = ', '.join(actuals) @@ -709,11 +706,14 @@ def try_get_nn_module_compiled_mod_and_inputs(*args, **kwargs): input = (input,) input = input + (kwargs['target_fn'](),) - args_variable, kwargs_variable = create_input(input, dtype=input_dtype) + args_variable, _kwargs_variable = create_input(input, dtype=input_dtype) f_args_variable = deepcopy(unpack_variables(args_variable)) out_var = deepcopy(f_args_variable) - args, mod = f_args_variable, create_script_module(None, nn_module, constructor_args, *f_args_variable)(*f_args_variable) + + _args, mod = f_args_variable, create_script_module( + None, nn_module, constructor_args, *f_args_variable + )(*f_args_variable) return mod, out_var diff --git a/torch/testing/_internal/jit_utils.py b/torch/testing/_internal/jit_utils.py index a8c7fa261f998..f359e81979769 100644 --- a/torch/testing/_internal/jit_utils.py +++ b/torch/testing/_internal/jit_utils.py @@ -230,7 +230,7 @@ def extract_files(buffer): # and it's easier to just work with a fresh copy each time. buffer_copy = buffer.getvalue() - code_files, debug_files = extract_files(buffer) + code_files, _debug_files = extract_files(buffer) except RuntimeError as e: if not self._isHookExceptionOk(e): @@ -247,7 +247,7 @@ def extract_files(buffer): torch.jit.save(imported, saved_module_buffer_2) saved_module_buffer_2.seek(0) - code_files_2, debug_files_2 = extract_files(saved_module_buffer_2) + code_files_2, _debug_files_2 = extract_files(saved_module_buffer_2) for a, b in zip(code_files, code_files_2): self.assertMultiLineEqual(a, b) @@ -503,7 +503,7 @@ def checkScript(self, if capture_output: with self.capture_stdout() as script_stdout: script_outputs = scripted_fn(*recording_inputs) - with self.capture_stdout() as opt_script_stdout: + with self.capture_stdout(): opt_script_outputs = scripted_fn(*recording_inputs) with self.capture_stdout() as _python_stdout: python_outputs = python_fn(*inputs) @@ -740,7 +740,7 @@ def attrs_with_prefix(module, prefix): def warmup_backward(f, *args): profiling_count = 3 results = [] - for i in range(profiling_count): + for _ in range(profiling_count): if len(args) > 0: r = torch.autograd.grad(f, *args) results.append(r) diff --git a/torch/testing/_internal/opinfo/definitions/_masked.py b/torch/testing/_internal/opinfo/definitions/_masked.py index eda339ebfe68a..66a5fb2c2b073 100644 --- a/torch/testing/_internal/opinfo/definitions/_masked.py +++ b/torch/testing/_internal/opinfo/definitions/_masked.py @@ -357,7 +357,6 @@ def sample_inputs_masked_softmax( def sample_inputs_masked_cumops(op_info, device, dtype, requires_grad, **kwargs): """Sample inputs for masked cumsum and cumprod.""" - inputs: List[SampleInput] = [] for sample_input in sample_inputs_softmax_variant( op_info, device, dtype, requires_grad, **kwargs ): diff --git a/torch/testing/_internal/opinfo/definitions/nested.py b/torch/testing/_internal/opinfo/definitions/nested.py index e86daab5f1a95..f13caa8c49eca 100644 --- a/torch/testing/_internal/opinfo/definitions/nested.py +++ b/torch/testing/_internal/opinfo/definitions/nested.py @@ -132,7 +132,6 @@ def _slice_input(t, i=i, inp=nt_inp): def reduction_reference(op, sample): assert sample.input.is_nested dim = sample.kwargs.get("dim", None) - keepdim = sample.kwargs.get("keepdim", False) assert dim != 0, "reductions over the batch dim are not supported" assert "dims" not in sample.kwargs assert sample.input._ragged_idx == 1 @@ -340,6 +339,73 @@ def sample_inputs_to(op_info, device, dtype, requires_grad, op_kwargs=None, **kw ) +def sample_inputs_bmm(op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs): + for njt_3d in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[3] + ): + # (B, j1, D) x (B, D, E) => (B, j1, E) + B, D = njt_3d.shape[0], njt_3d.shape[-1] + E = D + 2 + other = torch.randn(B, D, E, device=device, dtype=dtype) + # used for slicing in unbind_reference() + other._batch_dim = 0 + yield SampleInput(njt_3d.clone().detach(), kwargs={"mat2": other}) + + # TODO (need factory functions): + # (B, D, j1) x (B, j1, E) => (B, D, E) + + +def reference_bmm(op, sample): + # unbind reduces a dim and bmm requires 3D, so use matmul as the reference + matmul_op = copy(op) + matmul_op.op = torch.matmul + # change arg name from mat2 -> other + modified_sample = copy(sample) + other = modified_sample.kwargs["mat2"] + del modified_sample.kwargs["mat2"] + modified_sample.kwargs["other"] = other + return unbind_reference(matmul_op, modified_sample) + + +def sample_inputs_matmul( + op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs +): + # also run bmm samples through + for sample_input in sample_inputs_bmm(op_info, device, dtype, requires_grad): + # change arg name from mat2 -> other + other = sample_input.kwargs["mat2"] + del sample_input.kwargs["mat2"] + sample_input.kwargs["other"] = other + yield sample_input + + # 3D cases not covered by bmm + for njt_3d in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[3] + ): + # (B, j1, D) x (D, E) => (B, j1, E) + D = njt_3d.shape[-1] + E = D + 2 + yield SampleInput( + njt_3d.clone().detach(), + kwargs={"other": torch.randn(D, E, device=device, dtype=dtype)}, + ) + + # 4D cases + for njt_4d in _sample_njts( + device=device, dtype=dtype, requires_grad=requires_grad, dims=[4] + ): + # (B, j1, D, E) x (E, F) => (B, j1, D, F) + E = njt_4d.shape[-1] + F = E + 2 + yield SampleInput( + njt_4d.clone().detach(), + kwargs={"other": torch.randn(E, F, device=device, dtype=dtype)}, + ) + + # TODO (need factory functions): + # (B, j1, D, E) x (B, j1, E, F) => (B, j1, D, F) + + def sample_inputs_masked_select( op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs ): @@ -480,6 +546,7 @@ def sample_inputs_nn_functional_rms_norm( # to specify if they cannot be auto-generated for some reason. Try to keep these sorted # in alphabetical order! njt_sample_inputs = { + "bmm": sample_inputs_bmm, "clone": sample_inputs_clone, **{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)}, "nn.functional.embedding_bag": sample_inputs_nn_functional_embedding_bag, @@ -489,10 +556,12 @@ def sample_inputs_nn_functional_rms_norm( **{f"polygamma.polygamma_n_{n}": sample_inputs_polygamma_n(n=n) for n in range(5)}, "special.polygamma.special_polygamma_n_0": sample_inputs_special_polygamma_n(n=0), "to": sample_inputs_to, + "matmul": sample_inputs_matmul, "masked_select": sample_inputs_masked_select, } njt_references = { + "bmm": reference_bmm, "nn.functional.embedding_bag": reference_nn_functional_embedding_bag, } diff --git a/torch/testing/_internal/opinfo/definitions/sparse.py b/torch/testing/_internal/opinfo/definitions/sparse.py index 3e1f816d9f73f..41c17471d9de2 100644 --- a/torch/testing/_internal/opinfo/definitions/sparse.py +++ b/torch/testing/_internal/opinfo/definitions/sparse.py @@ -237,7 +237,6 @@ def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=Fals if op_info.name in {"masked.amax", "masked.amin", "masked.mean", "masked.prod"}: t_inp = sample.input - batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim() mask = sample.kwargs.get("mask") if ( mask is not None @@ -321,7 +320,7 @@ def _validate_sample_input_sparse_reduction(op_info, sample, check_validate=Fals def _validate_sample_input_sparse_reduction_sum(sample, check_validate=False): # NOTE: When fixing a failing sample case, remove the # corresponding if-block - t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs + t_inp, t_kwargs = sample.input, sample.kwargs dim = t_kwargs.get("dim") keepdim = t_kwargs.get("keepdim") layout = t_inp.layout @@ -569,7 +568,7 @@ def _to_sparse(tensor, **kwargs): def _validate_sample_input_elementwise_binary_sparse_mul(sample): # NOTE: When fixing a failing sample case, remove the # corresponding if-block - t_inp, t_args, t_kwargs = sample.input, sample.args, sample.kwargs + t_inp, t_args = sample.input, sample.args batch_dim = t_inp.dim() - t_inp.dense_dim() - t_inp.sparse_dim() layout = t_inp.layout dtype = t_inp.dtype diff --git a/torch/testing/_internal/opinfo/definitions/special.py b/torch/testing/_internal/opinfo/definitions/special.py index 5b137799db8e5..f153deacaa99e 100644 --- a/torch/testing/_internal/opinfo/definitions/special.py +++ b/torch/testing/_internal/opinfo/definitions/special.py @@ -130,7 +130,6 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): ref=scipy.special.i0e if TEST_SCIPY else None, decorators=(precisionOverride({torch.bfloat16: 3e-1, torch.float16: 3e-1}),), dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), - backward_dtypes=floating_types(), sample_inputs_func=sample_inputs_i0_i1, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -141,8 +140,8 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): ref=np_unary_ufunc_integer_promotion_wrapper(scipy.special.i1) if TEST_SCIPY else None, - dtypes=all_types_and(torch.bool), - dtypesIfCUDA=all_types_and(torch.bool), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + backward_dtypes=floating_types(), sample_inputs_func=sample_inputs_i0_i1, decorators=( DecorateInfo( @@ -169,8 +168,8 @@ def sample_inputs_erfcx(op_info, device, dtype, requires_grad, **kwargs): "special.i1e", aten_name="special_i1e", ref=scipy.special.i1e if TEST_SCIPY else None, - dtypes=all_types_and(torch.bool), - dtypesIfCUDA=all_types_and(torch.bool), + dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16), + backward_dtypes=floating_types(), sample_inputs_func=sample_inputs_i0_i1, supports_forward_ad=True, supports_fwgrad_bwgrad=True, diff --git a/torch/testing/_internal/opinfo/utils.py b/torch/testing/_internal/opinfo/utils.py index 05468e10da2c9..a7b1f61c7d263 100644 --- a/torch/testing/_internal/opinfo/utils.py +++ b/torch/testing/_internal/opinfo/utils.py @@ -86,7 +86,7 @@ def get_supported_dtypes(op, sample_inputs_fn, device_type): for sample in samples: try: op(sample.input, *sample.args, **sample.kwargs) - except RuntimeError as re: + except RuntimeError: # dtype is not supported supported = False break diff --git a/torch/testing/_internal/optests/aot_autograd.py b/torch/testing/_internal/optests/aot_autograd.py index 975ea555a1ec3..d82bbdbee6e37 100644 --- a/torch/testing/_internal/optests/aot_autograd.py +++ b/torch/testing/_internal/optests/aot_autograd.py @@ -48,7 +48,6 @@ def aot_autograd_check( """ flat_args, args_spec = pytree.tree_flatten((args, kwargs)) - args_is_tensor = [isinstance(arg, torch.Tensor) for arg in flat_args] args = [arg for arg in flat_args if isinstance(arg, torch.Tensor)] # We construct a new function that only accepts Tensors as inputs diff --git a/torch/testing/_internal/optests/generate_tests.py b/torch/testing/_internal/optests/generate_tests.py index 7fac1e57c6ac8..183968f697ab9 100644 --- a/torch/testing/_internal/optests/generate_tests.py +++ b/torch/testing/_internal/optests/generate_tests.py @@ -392,9 +392,7 @@ def validate_failures_dict_structure( """ failure_dict = failure_dict.data - qualnames = list(failure_dict.keys()) for test_to_option in failure_dict.values(): - test_names = list(test_to_option.keys()) for test_name, test_dict in test_to_option.items(): if set(test_dict.keys()) != set({"comment", "status"}): raise RuntimeError( diff --git a/torch/testing/_internal/torchbind_impls.py b/torch/testing/_internal/torchbind_impls.py index ad728aa909744..5566b241f5625 100644 --- a/torch/testing/_internal/torchbind_impls.py +++ b/torch/testing/_internal/torchbind_impls.py @@ -75,6 +75,7 @@ def meta_takes_foo_tuple_return(foo, x): def register_fake_classes(): + # noqa: F841 @torch._library.register_fake_class("_TorchScriptTesting::_Foo") class FakeFoo: def __init__(self, x: int, y: int): diff --git a/torch/testing/_internal/triton_utils.py b/torch/testing/_internal/triton_utils.py index d3a8065f29404..0443551ef5106 100644 --- a/torch/testing/_internal/triton_utils.py +++ b/torch/testing/_internal/triton_utils.py @@ -191,6 +191,71 @@ def add_kernel_with_scaling( output = (x + y) * scaling_factor tl.store(out_ptr + offsets, output, mask=mask) + @triton.jit + def add_kernel_with_tma_1d( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE: "tl.constexpr", + ): + pid = tl.program_id(axis=0) + offset = pid * BLOCK_SIZE + + a = tl._experimental_descriptor_load( + in_desc_ptr0, + [offset], + [BLOCK_SIZE], + tl.float32, + ) + b = tl._experimental_descriptor_load( + in_desc_ptr1, + [offset], + [BLOCK_SIZE], + tl.float32, + ) + + output = a + b + + tl._experimental_descriptor_store( + out_desc_ptr, + output, + [offset], + ) + + @triton.jit + def add_kernel_with_tma_2d( + in_desc_ptr0, + in_desc_ptr1, + out_desc_ptr, + BLOCK_SIZE_X: "tl.constexpr", + BLOCK_SIZE_Y: "tl.constexpr", + ): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + offset_x = pid_x * BLOCK_SIZE_X + offset_y = pid_y * BLOCK_SIZE_Y + + x = tl._experimental_descriptor_load( + in_desc_ptr0, + [offset_x, offset_y], + [BLOCK_SIZE_X, BLOCK_SIZE_Y], + tl.float32, + ) + y = tl._experimental_descriptor_load( + in_desc_ptr1, + [offset_x, offset_y], + [BLOCK_SIZE_X, BLOCK_SIZE_Y], + tl.float32, + ) + + output = x + y + + tl._experimental_descriptor_store( + out_desc_ptr, + output, + [offset_x, offset_y], + ) + @triton.jit def mul2_kernel( in_ptr0, diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 8aa9d4063f018..e55caad06e707 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -8,7 +8,7 @@ import unittest import warnings from types import FunctionType, ModuleType -from typing import Any, Callable, Dict, NoReturn, Optional, Set, Union +from typing import Any, Callable, Dict, List, NoReturn, Optional, Set, Union from typing_extensions import deprecated from unittest import mock @@ -144,6 +144,7 @@ def __setattr__(self, name: str, value: object) -> None: raise AttributeError(f"{self.__name__}.{name} does not exist") else: self._config[name] = value + self._is_dirty = True def __getattr__(self, name: str) -> Any: try: @@ -153,29 +154,61 @@ def __getattr__(self, name: str) -> Any: raise AttributeError(f"{self.__name__}.{name} does not exist") from e def __delattr__(self, name: str) -> None: + self._is_dirty = True # must support delete because unittest.mock.patch deletes # then recreate things del self._config[name] + def _get_dict( + self, + ignored_keys: Optional[List[str]] = None, + ignored_prefixes: Optional[List[str]] = None, + skip_default: bool = False, + ) -> Dict[str, Any]: + """Export a dictionary of current configuration keys and values. + + This function is design to provide a single point which handles + accessing config options and exporting them into a dictionary. + This is used by a number of different user facing export methods + which all have slightly different semantics re: how and what to + skip. + + Arguments: + ignored_keys are keys that should not be exported. + ignored_prefixes are prefixes that if a key matches should + not be exported + skip_default does two things. One if a key has not been modified + it skips it. The other is it modified the logging behaviour + to match what codegen already did for modified skipped keys + """ + config: Dict[str, Any] = {} + for key in self._config: + if ignored_keys and key in ignored_keys: + if skip_default and self._config[key] != self._default[key]: + warnings.warn( + f"Skipping serialization of {key} value {self._config[key]}" + ) + continue + if ignored_prefixes: + if any(key.startswith(prefix) for prefix in ignored_prefixes): + continue + if skip_default and self._config[key] == self._default[key]: + continue + config[key] = copy.deepcopy(self._config[key]) + return config + def save_config(self) -> bytes: """Convert config to a pickled blob""" - config = dict(self._config) - for key in config.get("_save_config_ignore", ()): - config.pop(key) - return pickle.dumps(config, protocol=2) + return pickle.dumps( + self._get_dict(ignored_keys=self._config.get("_save_config_ignore", ())), + protocol=2, + ) def save_config_portable(self) -> Dict[str, Any]: """Convert config to portable format""" - config: Dict[str, Any] = {} - for key in sorted(self._config): - if key.startswith("_"): - continue - if any( - key.startswith(e) for e in self._config["_cache_config_ignore_prefix"] - ): - continue - config[key] = self._config[key] - return config + prefixes = ["_"] + prefixes.extend(self._config["_cache_config_ignore_prefix"]) + return self._get_dict(ignored_prefixes=prefixes) def codegen_config(self) -> str: """Convert config to Python statements that replicate current config. @@ -183,39 +216,38 @@ def codegen_config(self) -> str: """ lines = [] mod = self.__name__ - for k, v in self._config.items(): - if k in self._config.get("_save_config_ignore", ()): - if v != self._default[k]: - warnings.warn(f"Skipping serialization of {k} value {v}") - continue - if v == self._default[k]: - continue + for k, v in self._get_dict( + ignored_keys=self._config.get("_save_config_ignore"), skip_default=True + ).items(): lines.append(f"{mod}.{k} = {v!r}") return "\n".join(lines) def get_hash(self) -> bytes: """Hashes the configs that are not compile_ignored""" if self._is_dirty or self._hash_digest is None: - dict_to_hash = { - k: v - for k, v in self._config.items() - if k not in self._compile_ignored_keys - } + dict_to_hash = self._get_dict(ignored_keys=list(self._compile_ignored_keys)) string_to_hash = repr(sorted(dict_to_hash.items())) self._hash_digest = hashlib.md5(string_to_hash.encode("utf-8")).digest() self._is_dirty = False return self._hash_digest @deprecated( - "`config.to_dict()` has been deprecated. It may no longer change the underlying config." - " use `config.shallow_copy_dict()` or `config.get_config_copy()` instead", + "`config.to_dict()` has been deprecated. It no longer changes the underlying config." + " use `config.get_config_copy()` instead if you just want a copy of the config, or " + "config.load_config if you need mutable access", category=FutureWarning, ) def to_dict(self) -> Dict[str, Any]: - return self.shallow_copy_dict() + return self.get_config_copy() + @deprecated( + "`config.shallow_copy_dict()` has been deprecated. It no longer changes the underlying config." + " use `config.get_config_copy()` instead if you just want a copy of the config, or " + "config.load_config if you need mutable access", + category=FutureWarning, + ) def shallow_copy_dict(self) -> Dict[str, Any]: - return {**self._config} + return self.get_config_copy() def load_config(self, maybe_pickled_config: Union[bytes, Dict[str, Any]]) -> None: """Restore from a prior call to save_config() or shallow_copy_dict()""" @@ -226,7 +258,7 @@ def load_config(self, maybe_pickled_config: Union[bytes, Dict[str, Any]]) -> Non self._config.update(config) def get_config_copy(self) -> Dict[str, Any]: - return copy.deepcopy(self._config) + return self._get_dict() def patch( self, @@ -268,23 +300,19 @@ def foo(...): assert isinstance(changes, dict), f"expected `dict` got {type(changes)}" prior: Dict[str, Any] = {} config = self - dirty = False class ConfigPatch(ContextDecorator): def __enter__(self) -> None: assert not prior - nonlocal dirty for key in changes.keys(): # KeyError on invalid entry - prior[key] = config._config[key] - dirty = key not in config._compile_ignored_keys - config._config.update(changes) - config._is_dirty = dirty + prior[key] = config.__getattr__(key) + for k, v in changes.items(): + config.__setattr__(k, v) def __exit__(self, exc_type, exc_val, exc_tb): # type: ignore[no-untyped-def] - nonlocal dirty - config._config.update(prior) - config._is_dirty = dirty + for k, v in prior.items(): + config.__setattr__(k, v) prior.clear() return ConfigPatch() diff --git a/torch/utils/_freeze.py b/torch/utils/_freeze.py index d2d6cea7f2284..60bdbf8b056ec 100644 --- a/torch/utils/_freeze.py +++ b/torch/utils/_freeze.py @@ -113,8 +113,7 @@ def msg(self, path: Path, code: str): # S: skipped (not a package dir) # X: skipped (deny-listed) # N: skipped (not a python file) - for i in range(self.indent): - print(" ", end="") + print(" " * self.indent, end="") print(f"{code} {path}") def write_bytecode(self, install_root): diff --git a/torch/utils/_python_dispatch.py b/torch/utils/_python_dispatch.py index 5460bffb809fd..04604bc6ec59e 100644 --- a/torch/utils/_python_dispatch.py +++ b/torch/utils/_python_dispatch.py @@ -4,7 +4,7 @@ import warnings from dataclasses import dataclass from typing import Any, Dict, List, Optional, Set, Union, Protocol, Tuple, Sequence, overload, Deque, Type -from typing_extensions import TypeGuard +from typing_extensions import TypeIs from collections import deque import torch @@ -365,7 +365,7 @@ def to( -def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]: +def is_traceable_wrapper_subclass(t: object) -> TypeIs[TensorWithFlatten]: """ Returns whether or not a tensor subclass that implements __torch_dispatch__ is 'traceable' with torch.compile. @@ -402,7 +402,7 @@ def is_traceable_wrapper_subclass(t: object) -> TypeGuard[TensorWithFlatten]: and hasattr(t, "__tensor_unflatten__") ) -def is_traceable_wrapper_subclass_type(t: Type) -> TypeGuard[Type[TensorWithFlatten]]: +def is_traceable_wrapper_subclass_type(t: Type) -> TypeIs[Type[TensorWithFlatten]]: """Same as above, but takes a type argument instead of an instance.""" return (issubclass(t, torch.Tensor) and t != torch.Tensor and hasattr(t, "__tensor_flatten__") and hasattr(t, "__tensor_unflatten__")) @@ -463,7 +463,6 @@ def _correct_storage_aliasing(func, schema_info, args, outs): assert isinstance(func, torch._ops.OpOverload) assert isinstance(args, tuple) assert isinstance(outs, (list, tuple)) - flat_outs = torch.utils._pytree.tree_leaves(outs) def alias_non_inplace_storage(arg, ret): # This is hopefully a reasonable assert: diff --git a/torch/utils/_strobelight/cli_function_profiler.py b/torch/utils/_strobelight/cli_function_profiler.py index 2e6ed474efd55..c2e4ae679a941 100644 --- a/torch/utils/_strobelight/cli_function_profiler.py +++ b/torch/utils/_strobelight/cli_function_profiler.py @@ -226,7 +226,7 @@ def _stop_strobelight_no_throw( return self._get_results() - except Exception as error: + except Exception: logger.warning("error during stop_strobelight", exc_info=True) # Return true if strobelight started and is running. Never throw. @@ -240,7 +240,7 @@ def _start_strobelight(self) -> bool: logger.info("strobelight profiling running") return True - except Exception as error: + except Exception: logger.warning("error during start_strobelight:", exc_info=True) if strobelight_started: self._stop_strobelight_no_throw(collect_results=False) diff --git a/torch/utils/_strobelight/examples/cli_function_profiler_example.py b/torch/utils/_strobelight/examples/cli_function_profiler_example.py index d92fa3b8a6031..b67a8abd9f41d 100644 --- a/torch/utils/_strobelight/examples/cli_function_profiler_example.py +++ b/torch/utils/_strobelight/examples/cli_function_profiler_example.py @@ -15,7 +15,7 @@ def fn(x, y, z): @strobelight(sample_each=10000, stop_at_error=False) @torch.compile() def work(): - for i in range(10): + for _ in range(10): torch._dynamo.reset() for j in range(5): torch._dynamo.reset() @@ -29,7 +29,7 @@ def work(): @strobelight(profiler, sample_tags=["something", "another"]) def work2(): sum = 0 - for i in range(100000000): + for _ in range(100000000): sum += 1 work2() diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 6747a314cedd5..e3e248a008328 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -577,6 +577,7 @@ def __new__(cls, *args, **assumptions): args = cls._collapse_arguments(args, **assumptions) # find local zeros args = cls._find_localzeros(args, **assumptions) + args = frozenset(args) if not args: @@ -761,49 +762,44 @@ def _find_localzeros(cls, values, **options): When a value is identified as being more extreme than another member it replaces that member; if this is never true, then the value is simply appended to the localzeros. - """ - localzeros = set() # type: ignore[var-annotated] - for v in values: - is_newzero = True - localzeros_ = list(localzeros) - for z in localzeros_: - if id(v) == id(z): - is_newzero = False - else: - con = cls._is_connected(v, z) - if con: - is_newzero = False - if con is True or con == cls: - localzeros.remove(z) - localzeros.update([v]) - if is_newzero: - localzeros.update([v]) - return localzeros - @classmethod - def _is_connected(cls, x, y): + Unlike the sympy implementation, we only look for zero and one, we don't + do generic is connected test pairwise which is slow """ - Check if x and y are connected somehow. - """ - if x == y: - return True - t, f = Max, Min - for op in "><": - for j in range(2): - try: - if op == ">": - v = x >= y + + # First, collapse all numeric arguments + other_values = set() + num_value = None + for arg in values: + if arg.is_Number: + if num_value is None: + num_value = arg + else: + if cls is Max: + num_value = max(num_value, arg) + elif cls is Min: + num_value = min(num_value, arg) else: - v = x <= y - except TypeError: - return False # non-real arg - if not v.is_Relational: - return t if v else f - t, f = f, t # type: ignore[assignment] - x, y = y, x - x, y = y, x # run next pass with reversed order relative to start - - return False + raise AssertionError(f"impossible {cls}") + else: + other_values.add(arg) + + # Special cases when there is only one symbolic value + if num_value is None: + return other_values + + if len(other_values) == 0: + return {num_value} + + if len(other_values) == 1: + other_value = next(iter(other_values)) + if num_value in (0.0, 0) and other_value.is_nonnegative: + return other_values if cls is Max else {num_value} + if num_value == 1 and other_value.is_positive: + return other_values if cls is Max else {num_value} + + other_values.add(num_value) + return other_values _eval_is_algebraic = lambda s: _torf(i.is_algebraic for i in s.args) # noqa: E731 _eval_is_antihermitian = lambda s: _torf( # noqa: E731 diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index a55c46f798da0..eb03e0697cda2 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -161,6 +161,8 @@ def _run_sympy_handler(analysis, args, expr, index_dtype=torch.int64): r = handler(*args) log.debug("%s(%s) -> %s", handler_name, args, r) return r + except NotImplementedError: + raise except Exception: log.warning("failed while executing %s(%s)", handler_name, args) raise diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 7df7982ab6c04..4e2835fceaf30 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -308,7 +308,7 @@ def sym_sum(args): def _to_dtype(x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: - return torch.ops.aten._to_copy(x, dtype=dtype) + return torch.ops.prims.convert_element_type.default(x, dtype) # Suppose we have some int/float arguments. This diagram commutes: @@ -444,7 +444,7 @@ def int_truediv(a, b): @staticmethod def floordiv(a, b): - return torch.ops.aten.floor_divide(a, b) + return torch.ops.aten.div.Tensor_mode(a, b, rounding_mode="floor") @staticmethod def truncdiv(a, b): @@ -476,6 +476,42 @@ def log(x): def sqrt(x): return torch.ops.aten.sqrt.default(x) + @staticmethod + def sin(x): + return torch.ops.aten.sin.default(x) + + @staticmethod + def cos(x): + return torch.ops.aten.cos.default(x) + + @staticmethod + def tanh(x): + return torch.ops.aten.tanh.default(x) + + @staticmethod + def sinh(x): + return torch.ops.aten.sinh.default(x) + + @staticmethod + def cosh(x): + return torch.ops.aten.cosh.default(x) + + @staticmethod + def tan(x): + return torch.ops.aten.tan.default(x) + + @staticmethod + def acos(x): + return torch.ops.aten.acos.default(x) + + @staticmethod + def atan(x): + return torch.ops.aten.atan.default(x) + + @staticmethod + def asin(x): + return torch.ops.aten.asin.default(x) + @staticmethod def pow(a, b): return torch.ops.aten.pow.Tensor_Tensor(a, b) diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index cdc722d20fa58..707350a68ac90 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -115,7 +115,10 @@ def _try_isolate_lhs( # If we can't tell whether 'other' is negative or positive, we do nothing. # That is because we don't know whether we have mirror the operation or not. - if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None): + # We also divide only when we know 'rhs' is not zero. + if not (isinstance(e, INEQUALITY_TYPES) and other.is_negative is None) and not ( + not isinstance(e, INEQUALITY_TYPES) and rhs.is_zero + ): # Divide both sides by 'other'. lhs = lhs / other rhs = rhs / other diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index b1f2ed7308bde..c9dad120dc935 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -2,6 +2,7 @@ from __future__ import annotations import dataclasses +import functools import itertools import logging import math @@ -224,6 +225,8 @@ def __contains__(self, x: AllIn) -> bool: return ValueRanges.wrap(x).issubset(self) def issubset(self, other): + if other is self.unknown_int(): + return True return sympy_generic_le(other.lower, self.lower) and sympy_generic_le( self.upper, other.upper ) @@ -248,9 +251,9 @@ def __and__( # type: ignore[misc] ... def __and__(self: AllVR, other: AllVR) -> AllVR: - if other == ValueRanges.unknown(): + if other in (ValueRanges.unknown(), ValueRanges.unknown_int()): return self - if self == ValueRanges.unknown(): + if self in (ValueRanges.unknown(), ValueRanges.unknown_int()): return other assert self.is_bool == other.is_bool, (self, other) assert self.is_int == other.is_int, (self, other) @@ -298,14 +301,17 @@ def is_singleton(self) -> bool: return self.lower == self.upper @staticmethod + @functools.lru_cache(maxsize=None) def unknown() -> ValueRanges[sympy.Expr]: return ValueRanges(-sympy.oo, sympy.oo) @staticmethod + @functools.lru_cache(maxsize=None) def unknown_int() -> ValueRanges[sympy.Expr]: return ValueRanges(-int_oo, int_oo) @staticmethod + @functools.lru_cache(maxsize=None) def unknown_bool() -> ValueRanges[SympyBoolean]: return ValueRanges(sympy.false, sympy.true) @@ -445,7 +451,7 @@ def constant(value, dtype): elif dtype.is_floating_point: return ValueRanges.unknown() else: - return ValueRanges(-int_oo, int_oo) + return ValueRanges.unknown_int() if is_python: type_ = dtype_to_type(dtype) diff --git a/torch/utils/_traceback.py b/torch/utils/_traceback.py index aa3944d417086..2a7aa1b56d43b 100644 --- a/torch/utils/_traceback.py +++ b/torch/utils/_traceback.py @@ -237,8 +237,8 @@ def format_all(tbs): rs.append(None) delayed_idxs.append(i) - stbs = torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs]) - for i, stb in zip(delayed_idxs, stbs): + torch._C._profiler.symbolize_tracebacks([tbs[i].tb for i in delayed_idxs]) + for i in delayed_idxs: rs[i] = traceback.format_list(tbs[i].summary()) return rs diff --git a/torch/utils/_triton.py b/torch/utils/_triton.py index 3cc977c848382..fa3431f748996 100644 --- a/torch/utils/_triton.py +++ b/torch/utils/_triton.py @@ -15,6 +15,29 @@ def has_triton_package() -> bool: return False +@functools.lru_cache(None) +def has_triton_tma(): + if has_triton_package(): + import torch + + if ( + torch.cuda.is_available() + and torch.cuda.get_device_capability() >= (9, 0) + and not torch.version.hip + ): + try: + from triton.tools.experimental_descriptor import ( # noqa: F401 + create_1d_tma_descriptor, + create_2d_tma_descriptor, + ) + + return True + except ImportError: + pass + + return False + + @functools.lru_cache(None) def has_triton() -> bool: if not has_triton_package(): diff --git a/torch/utils/benchmark/examples/blas_compare_setup.py b/torch/utils/benchmark/examples/blas_compare_setup.py index 323138d19ddd2..1057037d169a4 100644 --- a/torch/utils/benchmark/examples/blas_compare_setup.py +++ b/torch/utils/benchmark/examples/blas_compare_setup.py @@ -171,7 +171,7 @@ def main(): print(f"Building PyTorch for env: `{env_name}`") # We have to re-run during each build to pick up the new # build config settings. - build_run = subprocess.run( + subprocess.run( f"source activate {env_path} && " f"cd {git_root} && " "python setup.py install --cmake", diff --git a/torch/utils/benchmark/examples/sparse/fuzzer.py b/torch/utils/benchmark/examples/sparse/fuzzer.py index 8f3885839d3fa..8b10fc9fac186 100644 --- a/torch/utils/benchmark/examples/sparse/fuzzer.py +++ b/torch/utils/benchmark/examples/sparse/fuzzer.py @@ -58,7 +58,6 @@ def main(): for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)): x = tensors["x"] - y = tensors["y"] shape = ", ".join(tuple(f'{i:>4}' for i in x.shape)) x_tensor_properties = tensor_properties["x"] description = "".join([ diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py index 5a3e9f635891d..199a49bde20ff 100644 --- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py +++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py @@ -178,7 +178,6 @@ class CallgrindStats: stmt_callgrind_out: Optional[str] def __repr__(self) -> str: - newline = "\n" # `\` cannot appear in fstring code section. base_stats = self.baseline_exclusive_stats output = f""" {super().__repr__()} @@ -665,7 +664,7 @@ def run(args: List[str], **kwargs: Any) -> Tuple[CompletedProcessType, str]: raise OSError(f"Failed to collect callgrind profile:\n{error_report}") def parse_output(fpath: str, inclusive: bool) -> FunctionCounts: - annotate_invocation, annotate_invocation_output = run([ + _annotate_invocation, annotate_invocation_output = run([ "callgrind_annotate", f"--inclusive={'yes' if inclusive else 'no'}", "--threshold=100", diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index ed0e02c4c1b93..3cba71da0df73 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -4,6 +4,7 @@ # This script outputs relevant system environment info # Run it with `python collect_env.py` or `python -m torch.utils.collect_env` import datetime +import json import locale import re import subprocess @@ -47,26 +48,43 @@ 'cpu_info', ]) -DEFAULT_CONDA_PATTERNS = { +COMMON_PATTERNS = [ "torch", "numpy", + "triton", + "optree", +] + +NVIDIA_PATTERNS = [ + "cuda-cudart", + "cuda-cupti", + "cuda-libraries", + "cuda-opencl", + "cuda-nvrtc", + "cuda-runtime", + "cublas", + "cudnn", + "cufft", + "curand", + "cusolver", + "cusparse", + "nccl", + "nvjitlink", + "nvtx", +] + +CONDA_PATTERNS = [ "cudatoolkit", "soumith", "mkl", "magma", - "triton", - "optree", -} +] -DEFAULT_PIP_PATTERNS = { - "torch", - "numpy", +PIP_PATTERNS = [ "mypy", "flake8", - "triton", - "optree", "onnx", -} +] def run(command): @@ -113,7 +131,7 @@ def run_and_return_first_line(run_lambda, command): def get_conda_packages(run_lambda, patterns=None): if patterns is None: - patterns = DEFAULT_CONDA_PATTERNS + patterns = CONDA_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS conda = os.environ.get('CONDA_EXE', 'conda') out = run_and_read_all(run_lambda, "{} list".format(conda)) if out is None: @@ -305,8 +323,25 @@ def get_cpu_info(run_lambda): if get_platform() == 'linux': rc, out, err = run_lambda('lscpu') elif get_platform() == 'win32': - rc, out, err = run_lambda('wmic cpu get Name,Manufacturer,Family,Architecture,ProcessorType,DeviceID, \ - CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision /VALUE') + rc, out, err = run_lambda( + 'powershell.exe "gwmi -Class Win32_Processor | Select-Object -Property Name,Manufacturer,Family,\ + Architecture,ProcessorType,DeviceID,CurrentClockSpeed,MaxClockSpeed,L2CacheSize,L2CacheSpeed,Revision\ + | ConvertTo-Json"' + ) + if rc == 0: + lst = [] + try: + obj = json.loads(out) + if type(obj) is list: + for o in obj: + lst.append("----------------------") + lst.extend([f"{k}: {v}" for (k, v) in o.items()]) + else: + lst.extend([f"{k}: {v}" for (k, v) in obj.items()]) + except ValueError as e: + lst.append(out) + lst.append(str(e)) + out = "\n".join(lst) elif get_platform() == 'darwin': rc, out, err = run_lambda("sysctl -n machdep.cpu.brand_string") cpu_info = 'None' @@ -335,10 +370,17 @@ def get_mac_version(run_lambda): def get_windows_version(run_lambda): - system_root = os.environ.get('SYSTEMROOT', 'C:\\Windows') - wmic_cmd = os.path.join(system_root, 'System32', 'Wbem', 'wmic') - findstr_cmd = os.path.join(system_root, 'System32', 'findstr') - return run_and_read_all(run_lambda, '{} os get Caption | {} /v Caption'.format(wmic_cmd, findstr_cmd)) + ret = run_and_read_all( + run_lambda, + 'powershell.exe "gwmi -Class Win32_OperatingSystem | Select-Object -Property Caption,\ + OSArchitecture,Version | ConvertTo-Json"', + ) + try: + obj = json.loads(ret) + ret = f'{obj["Caption"]} ({obj["Version"]} {obj["OSArchitecture"]})' + except ValueError as e: + ret += f"\n{str(e)}" + return ret def get_lsb_version(run_lambda): @@ -395,7 +437,7 @@ def get_libc_version(): def get_pip_packages(run_lambda, patterns=None): """Return `pip list` output. Note: will also find conda-installed pytorch and numpy packages.""" if patterns is None: - patterns = DEFAULT_PIP_PATTERNS + patterns = PIP_PATTERNS + COMMON_PATTERNS + NVIDIA_PATTERNS # People generally have `pip` as `pip` or `pip3` # But here it is invoked as `python -mpip` diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index b25fedc908146..e1e260bd6448d 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -4,6 +4,7 @@ import importlib import importlib.abc import os +import platform import re import shlex import shutil @@ -994,7 +995,7 @@ def CppExtension(name, sources, *args, **kwargs): libraries.append('torch') libraries.append('torch_cpu') libraries.append('torch_python') - if IS_WINDOWS: + if IS_WINDOWS and platform.machine().lower() != "arm64": libraries.append("sleef") kwargs['libraries'] = libraries @@ -1400,10 +1401,10 @@ def check_compiler_is_gcc(compiler): env['LC_ALL'] = 'C' # Don't localize output try: version_string = subprocess.check_output([compiler, '-v'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) - except Exception as e: + except Exception: try: version_string = subprocess.check_output([compiler, '--version'], stderr=subprocess.STDOUT, env=env).decode(*SUBPROCESS_DECODE_ARGS) - except Exception as e: + except Exception: return False # Check for 'gcc' or 'g++' for sccache wrapper pattern = re.compile("^COLLECT_GCC=(.*)$", re.MULTILINE) diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 87a450461317e..8522583a20d51 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -1268,7 +1268,7 @@ def _try_get_data(self, timeout=_utils.MP_STATUS_CHECK_INTERVAL): # test. # See NOTE [ DataLoader on Linux and open files limit ] fds_limit_margin = 10 - fs = [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] + [tempfile.NamedTemporaryFile() for i in range(fds_limit_margin)] except OSError as e: if e.errno == errno.EMFILE: raise RuntimeError( diff --git a/torch/utils/data/datapipes/_hook_iterator.py b/torch/utils/data/datapipes/_hook_iterator.py index a3f91516038ae..ae42f75885c1d 100644 --- a/torch/utils/data/datapipes/_hook_iterator.py +++ b/torch/utils/data/datapipes/_hook_iterator.py @@ -214,7 +214,7 @@ def wrap_generator(*args, **kwargs): else: # Decided against using `contextlib.nullcontext` for performance reasons _check_iterator_valid(datapipe, iterator_id) response = gen.send(request) - except StopIteration as e: + except StopIteration: return except Exception as e: # TODO: Simplify the traceback message to skip over `response = gen.send(None)` diff --git a/torch/utils/data/datapipes/gen_pyi.py b/torch/utils/data/datapipes/gen_pyi.py index fbed7b5246963..dbe448b65beb1 100644 --- a/torch/utils/data/datapipes/gen_pyi.py +++ b/torch/utils/data/datapipes/gen_pyi.py @@ -188,7 +188,7 @@ def process_signature(line: str) -> str: # Remove the datapipe after 'self' or 'cls' unless it has '*' tokens[i] = "" elif "Callable =" in token: # Remove default argument if it is a function - head, default_arg = token.rsplit("=", 2) + head, _default_arg = token.rsplit("=", 2) tokens[i] = head.strip(" ") + "= ..." tokens = [t for t in tokens if t != ""] line = ", ".join(tokens) diff --git a/torch/utils/data/sampler.py b/torch/utils/data/sampler.py index 6da35f8192b5c..4e89c24aca575 100644 --- a/torch/utils/data/sampler.py +++ b/torch/utils/data/sampler.py @@ -1,4 +1,5 @@ # mypy: allow-untyped-defs +import itertools from typing import ( Generic, Iterable, @@ -333,26 +334,17 @@ def __init__( def __iter__(self) -> Iterator[List[int]]: # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + sampler_iter = iter(self.sampler) if self.drop_last: - sampler_iter = iter(self.sampler) - while True: - try: - batch = [next(sampler_iter) for _ in range(self.batch_size)] - yield batch - except StopIteration: - break + # Create multiple references to the same iterator + args = [sampler_iter] * self.batch_size + for batch_droplast in zip(*args): + yield [*batch_droplast] else: - batch = [0] * self.batch_size - idx_in_batch = 0 - for idx in self.sampler: - batch[idx_in_batch] = idx - idx_in_batch += 1 - if idx_in_batch == self.batch_size: - yield batch - idx_in_batch = 0 - batch = [0] * self.batch_size - if idx_in_batch > 0: - yield batch[:idx_in_batch] + batch = [*itertools.islice(sampler_iter, self.batch_size)] + while batch: + yield batch + batch = [*itertools.islice(sampler_iter, self.batch_size)] def __len__(self) -> int: # Can only be called if self.sampler has __len__ implemented diff --git a/torch/utils/flop_counter.py b/torch/utils/flop_counter.py index a6b2195fdf694..e2d92d75c5044 100644 --- a/torch/utils/flop_counter.py +++ b/torch/utils/flop_counter.py @@ -698,8 +698,8 @@ def process_mod(mod_name, depth): # if there are any FLOPs in there that aren't already fully contained by # a module. if 'Global' in self.flop_counts and not is_global_subsumed: - for idx in range(len(values)): - values[idx][0] = " " + values[idx][0] + for value in values: + value[0] = " " + value[0] values = process_mod('Global', 0) + values diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index b9817b4d5ebdc..7ba128cb217a8 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -1,4 +1,5 @@ import collections +import os from .constants import (API_BLAS, API_C10, API_CAFFE2, API_DRIVER, API_FFT, API_PYTORCH, API_RAND, API_ROCTX, API_RTC, API_RUNTIME, @@ -24,6 +25,12 @@ supported in ROCm/HIP yet. """ +_IS_FBCODE = os.environ.get("IS_FBCODE", "0") == "1" + +# FBCODE compiles against rccl sources instead of an installed rccl package. +# The header location is src/rccl.h versus rccl/rccl.h, respectively. +_RCCL_HEADER = "" if _IS_FBCODE else "" + # List of math functions that should be replaced inside device code only. MATH_TRANSPILATIONS = collections.OrderedDict( [ @@ -603,7 +610,7 @@ ("cufft.h", ("hipfft/hipfft.h", CONV_INCLUDE, API_BLAS)), ("cufftXt.h", ("hipfft/hipfftXt.h", CONV_INCLUDE, API_BLAS)), # PyTorch also has a source file named "nccl.h", so we need to "<"">" to differentiate - ("", ("", CONV_INCLUDE, API_RUNTIME)), + ("", (_RCCL_HEADER, CONV_INCLUDE, API_RUNTIME)), ("nvrtc.h", ("hip/hiprtc.h", CONV_INCLUDE, API_RTC)), ("thrust/system/cuda", ("thrust/system/hip", CONV_INCLUDE, API_BLAS)), ("cub/util_allocator.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)), @@ -7923,6 +7930,7 @@ ("cub::BlockLoad", ("hipcub::BlockLoad", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::BlockStore", ("hipcub::BlockStore", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::BlockRakingLayout", ("hipcub::BlockRakingLayout", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::BlockRadixSort", ("hipcub::BlockRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::Uninitialized", ("hipcub::Uninitialized", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::RowMajorTid", ("hipcub::RowMajorTid", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::CachingDeviceAllocator", ("hipcub::CachingDeviceAllocator", CONV_SPECIAL_FUNC, API_RUNTIME)), @@ -7934,6 +7942,7 @@ ("cub::DeviceSegmentedRadixSort", ("hipcub::DeviceSegmentedRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::DeviceSegmentedReduce", ("hipcub::DeviceSegmentedReduce", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::DeviceSelect", ("hipcub::DeviceSelect", CONV_SPECIAL_FUNC, API_RUNTIME)), + ("cub::FpLimits", ("hipcub::FpLimits", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::KeyValuePair", ("hipcub::KeyValuePair", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::Max", ("hipcub::Max", CONV_SPECIAL_FUNC, API_RUNTIME)), ("cub::Min", ("hipcub::Min", CONV_SPECIAL_FUNC, API_RUNTIME)), diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 31f38dfd9527b..dc19b6fd84d1f 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -868,7 +868,7 @@ def c2_repl(m): def mk_repl(templ, include_current_dir=True): def repl(m): f = m.group(1) - dirpath, filename = os.path.split(f) + filename = os.path.basename(f) if ( f.startswith(("ATen/cuda", "ATen/native/cuda", diff --git a/torch/utils/jit/log_extract.py b/torch/utils/jit/log_extract.py index 51894f495e8e7..88ffe7bc5926d 100644 --- a/torch/utils/jit/log_extract.py +++ b/torch/utils/jit/log_extract.py @@ -10,7 +10,6 @@ def extract_ir(filename: str) -> List[str]: BEGIN = "" END = "" pfx = None - current = "" graphs = [] with open(filename) as f: split_strs = f.read().split(BEGIN) diff --git a/torch/utils/model_dump/__init__.py b/torch/utils/model_dump/__init__.py index f2cd974798f91..de662e794b0db 100644 --- a/torch/utils/model_dump/__init__.py +++ b/torch/utils/model_dump/__init__.py @@ -130,15 +130,12 @@ def hierarchical_pickle(data): } if typename == "torch._utils._rebuild_tensor_v2": assert data.state is None - if len(data.args) == 6: - storage, offset, size, stride, requires_grad, hooks = data.args - else: - storage, offset, size, stride, requires_grad, hooks, metadata = data.args + storage, offset, size, stride, requires_grad, *_ = data.args storage_info = get_storage_info(storage) return {"__tensor_v2__": [storage_info, offset, size, stride, requires_grad]} if typename == "torch._utils._rebuild_qtensor": assert data.state is None - storage, offset, size, stride, quantizer, requires_grad, hooks = data.args + storage, offset, size, stride, quantizer, requires_grad, *_ = data.args storage_info = get_storage_info(storage) assert isinstance(quantizer, tuple) assert isinstance(quantizer[0], torch.utils.show_pickle.FakeClass) diff --git a/torch/utils/tensorboard/_pytorch_graph.py b/torch/utils/tensorboard/_pytorch_graph.py index d3d2f37cad749..502675ef95661 100644 --- a/torch/utils/tensorboard/_pytorch_graph.py +++ b/torch/utils/tensorboard/_pytorch_graph.py @@ -241,9 +241,6 @@ def parse(graph, trace, args=None, omit_useless_nodes=True): args (tuple): input tensor[s] for the model. omit_useless_nodes (boolean): Whether to remove nodes from the graph. """ - n_inputs = len(args) - - scope = {} nodes_py = GraphPy() for node in graph.inputs(): if omit_useless_nodes: @@ -264,7 +261,6 @@ def parse(graph, trace, args=None, omit_useless_nodes=True): if ( parent.kind() == GETATTR_KIND ): # If the parent node is not the top-level "self" node - parent_attr_name = parent.s("name") parent_attr_key = parent.output().debugName() parent_scope = attr_to_scope[parent_attr_key] attr_scope = parent_scope.split("/")[-1] diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index 29d1db2006b26..e5346f5bdcdd6 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -665,7 +665,7 @@ def make_video(tensor, fps): return import tempfile - t, h, w, c = tensor.shape + _t, h, w, c = tensor.shape # encode sequence of images into gif string clip = mpy.ImageSequenceClip(list(tensor), fps=fps) diff --git a/torch/utils/viz/_cycles.py b/torch/utils/viz/_cycles.py index 8c1b9da7a6ad0..1fc53c503ff7b 100644 --- a/torch/utils/viz/_cycles.py +++ b/torch/utils/viz/_cycles.py @@ -279,7 +279,6 @@ def create_graph(objects, *, context=None, filter=None): tidx = id_to_node.get(rid, None) if tidx is None: continue - t = nodes[tidx] labels = references.get(rid, ["?"]) node_referrers[tidx].append(fidx) for label in labels: @@ -320,7 +319,7 @@ def cuda_allocation_context(): addr = seg['address'] for blk in seg['blocks']: if blk['state'] == 'active_allocated': - frames, real_size = _block_extra(blk) + frames, _real_size = _block_extra(blk) addr_to_frame[addr] = frames addr += blk['size'] diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index baf1e844ebc01..380c30bcc2979 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -395,6 +395,24 @@ def synchronize(device: _device_t = None) -> None: return torch._C._xpu_synchronize(device) +def get_arch_list() -> List[str]: + r"""Return list XPU architectures this library was compiled for.""" + if not is_available(): + return [] + arch_flags = torch._C._xpu_getArchFlags() + if arch_flags is None: + return [] + return arch_flags.split() + + +def get_gencode_flags() -> str: + r"""Return XPU AOT(ahead-of-time) build flags this library was compiled with.""" + arch_list = get_arch_list() + if len(arch_list) == 0: + return "" + return f'-device {",".join(arch for arch in arch_list)}' + + def _get_generator(device: torch.device) -> torch._C.Generator: r"""Return the XPU Generator object for the given device. @@ -478,9 +496,11 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: "device_of", "device_count", "empty_cache", + "get_arch_list", "get_device_capability", "get_device_name", "get_device_properties", + "get_gencode_flags", "get_rng_state", "get_rng_state_all", "get_stream", diff --git a/torchgen/_autoheuristic/benchmark_runner.py b/torchgen/_autoheuristic/benchmark_runner.py index 999ea48cbe116..3a1579c493493 100644 --- a/torchgen/_autoheuristic/benchmark_runner.py +++ b/torchgen/_autoheuristic/benchmark_runner.py @@ -68,12 +68,10 @@ def run(self) -> None: self.main(args.num_samples, args.num_reps) @abstractmethod - def run_benchmark(self, *args: Any) -> None: - ... + def run_benchmark(self, *args: Any) -> None: ... @abstractmethod - def create_input(self) -> Tuple[Any, ...]: - ... + def create_input(self) -> Tuple[Any, ...]: ... def main(self, num_samples: int, num_reps: int) -> None: for _ in tqdm(range(num_samples)): diff --git a/torchgen/_autoheuristic/train_decision.py b/torchgen/_autoheuristic/train_decision.py index 31cc7632fac69..bea7bde90ab56 100644 --- a/torchgen/_autoheuristic/train_decision.py +++ b/torchgen/_autoheuristic/train_decision.py @@ -449,8 +449,8 @@ def get_winner_and_speedup(group): for row in group.itertuples(): choice2time[row.choice] = row.median_execution_time - assert len(unique_choices) == len( - group + assert ( + len(unique_choices) == len(group) ), f"len(unique_choices) != len(group): {len(unique_choices)} != {len(group)}" return pd.Series( diff --git a/torchgen/api/cpp.py b/torchgen/api/cpp.py index c657570ee3e24..6cc40d66037d7 100644 --- a/torchgen/api/cpp.py +++ b/torchgen/api/cpp.py @@ -253,9 +253,7 @@ def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType: elif t.name == BaseTy.Scalar: return BaseCType(scalarT) elif isinstance(t, ListType): - assert ( - not mutable - ), "Native functions should never return a mutable tensor list. They should return void." + assert not mutable, "Native functions should never return a mutable tensor list. They should return void." elem = returntype_type(t.elem, mutable=False) assert t.size is None, f"fixed size list returns not supported: {t}" return VectorCType(elem) diff --git a/torchgen/api/lazy.py b/torchgen/api/lazy.py index cfffa516b656b..b6094a2558832 100644 --- a/torchgen/api/lazy.py +++ b/torchgen/api/lazy.py @@ -378,7 +378,8 @@ def __init__( self.generator_arg is None ), "We expect there is only one generator arg" self.generator_arg = NamedCType( - arg.name, arg.type # type:ignore[arg-type] + arg.name, + arg.type, # type:ignore[arg-type] ) keyword_args.extend( LazyArgument(arg, self.properties, symint=symint) diff --git a/torchgen/api/python.py b/torchgen/api/python.py index eb0f074898872..7c27e815b5e97 100644 --- a/torchgen/api/python.py +++ b/torchgen/api/python.py @@ -551,9 +551,9 @@ def from_pairs( # Out overloads in C++ don't have TensorOptions arguments, # so take these from the functional variant - signature_kwargs[ - "tensor_options_args" - ] = functional.signature.tensor_options_args + signature_kwargs["tensor_options_args"] = ( + functional.signature.tensor_options_args + ) return PythonSignatureGroup( signature=type(out.signature)(**signature_kwargs), diff --git a/torchgen/api/translate.py b/torchgen/api/translate.py index 761fb3c7c2b98..6e62816cac693 100644 --- a/torchgen/api/translate.py +++ b/torchgen/api/translate.py @@ -164,42 +164,42 @@ def translate( and isinstance(t.elem.elem, BaseCType) and str(t.elem.elem.type) == "at::Tensor" ): - ctx[ - NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT))) - ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())" + ctx[NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))] = ( + f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())" + ) if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))): - ctx[ - NamedCType(t.name, BaseCType(optionalTensorRefT)) - ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())" + ctx[NamedCType(t.name, BaseCType(optionalTensorRefT))] = ( + f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())" + ) if t.type == ConstRefCType(BaseCType(scalarT)): ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to()" if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): - ctx[ - NamedCType(t.name, BaseCType(optionalScalarRefT)) - ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())" + ctx[NamedCType(t.name, BaseCType(optionalScalarRefT))] = ( + f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())" + ) if t.type == BaseCType(scalar_t): - ctx[ - NamedCType(t.name, BaseCType(opmath_t)) - ] = f"static_cast({b.expr})" + ctx[NamedCType(t.name, BaseCType(opmath_t))] = ( + f"static_cast({b.expr})" + ) # [Note: IOptTensorListRef] if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))): - ctx[ - NamedCType(t.name, BaseCType(iOptTensorListRefT)) - ] = f"at::IOptTensorListRef({b.expr})" + ctx[NamedCType(t.name, BaseCType(iOptTensorListRefT))] = ( + f"at::IOptTensorListRef({b.expr})" + ) # Add implicit bindings if the generated code is inside a Tensor method if method: - ctx[ - NamedCType("self", MutRefCType(BaseCType(tensorT))) - ] = "const_cast(*this)" - ctx[ - NamedCType("self", ConstRefCType(BaseCType(tensorT))) - ] = "const_cast(*this)" + ctx[NamedCType("self", MutRefCType(BaseCType(tensorT)))] = ( + "const_cast(*this)" + ) + ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = ( + "const_cast(*this)" + ) # This is better! Byte-for-byte compat # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this" diff --git a/torchgen/api/types/signatures.py b/torchgen/api/types/signatures.py index f7d85ca6e2fe8..7e0a4b91037a3 100644 --- a/torchgen/api/types/signatures.py +++ b/torchgen/api/types/signatures.py @@ -406,9 +406,7 @@ def kernel_signature( meta = backend_index.get_kernel(f) symint = meta is not None and meta.supports_symint() if symint: - assert ( - f.func.has_symint() - ), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema" + assert f.func.has_symint(), f"attempted to define symint kernel for {backend_index.dispatch_key} without SymInt in schema" if backend_index.external: return DispatcherSignature.from_schema(f.func, prefix=prefix, symint=symint) else: diff --git a/torchgen/api/unboxing.py b/torchgen/api/unboxing.py index 1e649b7517889..edb48ec5d172a 100644 --- a/torchgen/api/unboxing.py +++ b/torchgen/api/unboxing.py @@ -194,9 +194,7 @@ def _gen_code_optional_type( }} else {{ {out_name} = {ctype.cpp_type(strip_ref=True)}(); }} - """.split( - "\n" - ), + """.split("\n"), decl, ) @@ -213,9 +211,7 @@ def _gen_code_list_type( code.extend( f""" {ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name}); - """.split( - "\n" - ) + """.split("\n") ) # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> elif isinstance(t.elem, OptionalType): @@ -226,9 +222,7 @@ def _gen_code_list_type( {connector.join(res_code)} {out_name}.push_back({res_name}); }} - """.split( - "\n" - ) + """.split("\n") ) else: # use ArrayRef as default. @@ -242,8 +236,6 @@ def _gen_code_list_type( {vec_name}.push_back({res_name}); }} {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); - """.split( - "\n" - ) + """.split("\n") ) return code, decl diff --git a/torchgen/context.py b/torchgen/context.py index a20310498164b..d257bf99243da 100644 --- a/torchgen/context.py +++ b/torchgen/context.py @@ -95,7 +95,7 @@ def wrapper(slf: S, f: F) -> T: def method_with_nested_native_function( - func: Callable[[S, F3], T] + func: Callable[[S, F3], T], ) -> Callable[[S, F3], T]: @functools.wraps(func) def wrapper(slf: S, f: F3) -> T: @@ -108,7 +108,7 @@ def wrapper(slf: S, f: F3) -> T: # Convenience decorator for functions that explicitly take in a BackendIndex, # instead of indirectly taking one in as a closure def with_native_function_and_index( - func: Callable[[F, BackendIndex], T] + func: Callable[[F, BackendIndex], T], ) -> Callable[[F, BackendIndex], T]: @functools.wraps(func) def wrapper(f: F, backend_index: BackendIndex) -> T: @@ -120,7 +120,7 @@ def wrapper(f: F, backend_index: BackendIndex) -> T: # Convenience decorator for functions that explicitly take in a Dict of BackendIndices def with_native_function_and_indices( - func: Callable[[F, dict[DispatchKey, BackendIndex]], T] + func: Callable[[F, dict[DispatchKey, BackendIndex]], T], ) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]: @functools.wraps(func) def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T: diff --git a/torchgen/dest/native_functions.py b/torchgen/dest/native_functions.py index a93405555bc22..e9bf2dcb0d074 100644 --- a/torchgen/dest/native_functions.py +++ b/torchgen/dest/native_functions.py @@ -8,6 +8,27 @@ from torchgen.utils import mapMaybe +def torch_api_key_word_prefix(bankend_index: BackendIndex) -> str: + if bankend_index.external: + return "" + + # Although Intel GPU ATen library is out-of-tree, it still utilizes torchgen to produce structrued + # kernels. Regarding these produced structured kernels, they should be visible for the Intel GPU ATen + # library. Therefore, we need to add "TORCH_XPU_API" prefix to these structured kernels, + # rather than "TORCH_API". Because the semantic of "TORCH_API" is "hidden" for out-of-tree backends. + # For other in-tree backends like cpu and cuda, they still use "TORCH_API" prefix with "visible" semantic. + device_torch_api_key_word_mapping = { + "XPU": "TORCH_XPU_API", + } + + return ( + device_torch_api_key_word_mapping.get( + bankend_index.dispatch_key.name, "TORCH_API" + ) + + " " + ) + + @with_native_function_and_index def gen_unstructured(f: NativeFunction, backend_index: BackendIndex) -> str | None: sig = kernel_signature(f, backend_index) @@ -28,7 +49,7 @@ def gen_structured(g: NativeFunctionsGroup, backend_index: BackendIndex) -> list metadata = backend_index.get_kernel(g) if metadata is None: return [] - prefix = "" if backend_index.external else "TORCH_API " + prefix = torch_api_key_word_prefix(backend_index) return [ f"""\ struct {prefix}structured_{metadata.kernel} : public at::meta::structured_{meta_name} {{ diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py index 091bec237238e..cb7dc00a60b85 100644 --- a/torchgen/dest/register_dispatch_key.py +++ b/torchgen/dest/register_dispatch_key.py @@ -515,9 +515,7 @@ def generate_defn(cpp_sig: CppSignature) -> str: # CUDA requires special handling if is_cuda_dispatch_key(self.backend_index.dispatch_key): - device_guard = ( - f"globalContext().lazyInitCUDA();\n{device_guard}" - ) + device_guard = f"globalContext().lazyInitDevice(c10::DeviceType::CUDA);\n{device_guard}" else: # kernel is operating on existing tensors diff --git a/torchgen/executorch/api/et_cpp.py b/torchgen/executorch/api/et_cpp.py index 76cebcd0f0f1d..1d7672715a62c 100644 --- a/torchgen/executorch/api/et_cpp.py +++ b/torchgen/executorch/api/et_cpp.py @@ -184,9 +184,7 @@ def returntype_type(t: Type, *, mutable: bool) -> CType: elif t.name == BaseTy.Scalar: return BaseCType(scalarT) elif isinstance(t, ListType): - assert ( - not mutable - ), "Native functions should never return a mutable tensor list. They should return void." + assert not mutable, "Native functions should never return a mutable tensor list. They should return void." elem = returntype_type(t.elem, mutable=False) assert t.size is None, f"fixed size list returns not supported: {t}" return VectorCType(elem) diff --git a/torchgen/executorch/api/unboxing.py b/torchgen/executorch/api/unboxing.py index 6845e72a22a5d..999147212a1a1 100644 --- a/torchgen/executorch/api/unboxing.py +++ b/torchgen/executorch/api/unboxing.py @@ -127,9 +127,7 @@ def _gen_code_optional_type( return ( f""" auto {out_name} = {arg_name}.toOptional<{base_type.cpp_type(strip_ref=True)}>(); - """.split( - "\n" - ), + """.split("\n"), decl, ) @@ -147,9 +145,7 @@ def _gen_code_list_type( code.extend( f""" auto {out_name} = {arg_name}.toTensorList(); - """.split( - "\n" - ) + """.split("\n") ) elif isinstance(t.elem, BaseType) and ( t.elem.name == BaseTy.int or t.elem.name == BaseTy.SymInt @@ -157,17 +153,13 @@ def _gen_code_list_type( code.extend( f""" auto {out_name} = {arg_name}.toIntList(); - """.split( - "\n" - ) + """.split("\n") ) elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.float: code.extend( f""" auto {out_name} = {arg_name}.toDoubleList(); - """.split( - "\n" - ) + """.split("\n") ) elif isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool: # handle list type with size, e.g., bool[4] @@ -183,9 +175,7 @@ def _gen_code_list_type( #else auto {out_name} = {arg_name}.toBoolList(); #endif - """.split( - "\n" - ) + """.split("\n") ) # pytorch codegen: # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional> @@ -205,9 +195,7 @@ def _gen_code_list_type( #else auto {out_name} = {arg_name}.toListOptionalTensor(); #endif - """.split( - "\n" - ) + """.split("\n") ) else: # use ArrayRef as default. @@ -223,8 +211,6 @@ def _gen_code_list_type( {vec_name}.push_back({res_name}); }} {ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); - """.split( - "\n" - ) + """.split("\n") ) return code, decl diff --git a/torchgen/executorch/model.py b/torchgen/executorch/model.py index 6aadfe41daed2..fe46d04d6c449 100644 --- a/torchgen/executorch/model.py +++ b/torchgen/executorch/model.py @@ -96,7 +96,7 @@ def gen_from_yaml( ) assert ( dim_order in dim_order_alias_map - ), "Undefined dim_order alias: " + str(dim_order) + ), f"Undefined dim_order alias: {dim_order}" dtype_alias_used.add(type_alias) # Generate all permutations of dtype alias values @@ -172,11 +172,11 @@ def grow_from_backend_indices( @staticmethod def from_backend_indices( - backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]] + backend_indices: dict[DispatchKey, dict[OperatorName, BackendMetadata]], ) -> ETKernelIndex: - kernel_index: dict[ - OperatorName, dict[ETKernelKey, BackendMetadata] - ] = defaultdict(dict) + kernel_index: dict[OperatorName, dict[ETKernelKey, BackendMetadata]] = ( + defaultdict(dict) + ) ETKernelIndex.grow_from_backend_indices(kernel_index, backend_indices) return ETKernelIndex(kernel_index) diff --git a/torchgen/gen.py b/torchgen/gen.py index e5870a24fc668..ab918577c6160 100644 --- a/torchgen/gen.py +++ b/torchgen/gen.py @@ -600,19 +600,15 @@ def __call__(self, f: NativeFunction) -> str: using schema = {sig.type()}; using ptr_schema = schema*; // See Note [static constexpr char* members for windows NVCC] - STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(name, "aten::{f.func.name.name}") - STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(overload_name, "{f.func.name.overload_name}") - STATIC_CONSTEXPR_STR_INL_EXCEPT_WIN_CUDA(schema_str, {cpp_string(str(f.func))}) + static constexpr const char* name = "aten::{f.func.name.name}"; + static constexpr const char* overload_name = "{f.func.name.overload_name}"; + static constexpr const char* schema_str = {cpp_string(str(f.func))}; static {sig.defn(name="call", is_redispatching_fn=False)}; static {sig.defn(name="redispatch", is_redispatching_fn=True)}; }};""" elif self.target is Target.DEFINITION: defns = f""" -STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, name, "aten::{f.func.name.name}") -STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, overload_name, "{f.func.name.overload_name}") -STATIC_CONST_STR_OUT_OF_LINE_FOR_WIN_CUDA({name}, schema_str, {cpp_string(str(f.func))}) - // aten::{f.func} static C10_NOINLINE c10::TypedOperatorHandle<{name}::schema> create_{name}_typed_handle() {{ return c10::Dispatcher::singleton() @@ -1362,7 +1358,7 @@ def get_grouped_by_view_native_functions( native_functions: Sequence[NativeFunction], ) -> Sequence[NativeFunction | NativeFunctionsViewGroup]: def maybe_create_view_group( - d: dict[ViewSchemaKind | SchemaKind, NativeFunction] + d: dict[ViewSchemaKind | SchemaKind, NativeFunction], ) -> list[NativeFunction | NativeFunctionsViewGroup]: funcs: list[NativeFunction | NativeFunctionsViewGroup] = [] if ViewSchemaKind.aliasing in d: @@ -1409,7 +1405,7 @@ def get_grouped_native_functions( native_functions: Sequence[NativeFunction], ) -> Sequence[NativeFunction | NativeFunctionsGroup]: def flatten_pre_group( - d: dict[SchemaKind, NativeFunction] + d: dict[SchemaKind, NativeFunction], ) -> Sequence[NativeFunction | NativeFunctionsGroup]: r = NativeFunctionsGroup.from_dict(d) if r is None: @@ -1476,9 +1472,7 @@ def get_native_function_declarations_from_ns_grouped_kernels( {ns_helper.prologue} {newline.join(ordered_kernels)} {ns_helper.epilogue} - """.split( - newline - ) + """.split(newline) ) return declarations @@ -1671,9 +1665,7 @@ def get_namespaced_declaration( {ns_helper.prologue} {newline.join(ordered_kernels)} {ns_helper.epilogue} - """.split( - newline - ) + """.split(newline) ) return declarations @@ -2386,9 +2378,7 @@ def operator_headers() -> list[str]: os.path.join(aoti_fm.install_dir, header_file_name) ) as old_file: old_header = old_file.read() - assert ( - old_header == new_header - ), """ + assert old_header == new_header, """ WARNING: The generated AOTInductor C shim header files have unexpectedly changed. This indicates an AOTInductor fallback operator ABI backward compatibility breakage!!! diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py index 5ba12f88bdd9d..942e4cd8290ce 100644 --- a/torchgen/gen_aoti_c_shim.py +++ b/torchgen/gen_aoti_c_shim.py @@ -71,7 +71,10 @@ # convert args to C types, names in declarations, and expressions in function bodies -def convert_arg_type_and_name(typ: Type, name: str) -> tuple[list[str], list[str], list[str], list[str]]: # type: ignore[return] +def convert_arg_type_and_name( # type: ignore[return] + typ: Type, + name: str, +) -> tuple[list[str], list[str], list[str], list[str]]: if isinstance(typ, BaseType): if typ.name in base_type_to_c_type: return ( diff --git a/torchgen/gen_executorch.py b/torchgen/gen_executorch.py index d29713568e665..902ffa3889e64 100644 --- a/torchgen/gen_executorch.py +++ b/torchgen/gen_executorch.py @@ -295,7 +295,7 @@ def gen_unboxing( ) -> None: # Iterable type for write_sharded is a Tuple of (native_function, (kernel_key, metadata)) def key_func( - item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]] + item: tuple[NativeFunction, tuple[ETKernelKey, BackendMetadata]], ) -> str: return item[0].root_name + ":" + item[1][0].to_native_string() @@ -739,7 +739,7 @@ def parse_yaml( # (2) Return BackendIndices if kernel index is absent def map_index( - m: dict[OperatorName, BackendMetadata] + m: dict[OperatorName, BackendMetadata], ) -> dict[OperatorName, BackendMetadata]: return {op: m[op] for op in m if op in op_names} diff --git a/torchgen/gen_functionalization_type.py b/torchgen/gen_functionalization_type.py index fbc9459eb5e64..afa4218002b55 100644 --- a/torchgen/gen_functionalization_type.py +++ b/torchgen/gen_functionalization_type.py @@ -278,13 +278,13 @@ def is_alias(a: Argument) -> bool: args = func.arguments.flat_non_out # The first argument is a tensor with an alias semantics (annotations) - assert len(args) > 0 and args[0].type == BaseType( - BaseTy.Tensor + assert ( + len(args) > 0 and args[0].type == BaseType(BaseTy.Tensor) ), f"""In the functionalization codegen, we expect the first argument of every view operator to be a tensor, but found an argument of type {str(args[0].type)} for operator: {str(func.name)}.""" # No other arguments have aliasing semantics - assert is_alias(args[0]) and not any( - is_alias(a) for a in args[1:] + assert ( + is_alias(args[0]) and not any(is_alias(a) for a in args[1:]) ), """In the functionalization codegen, we expect the first argument of every view operator to alias the output. View operators with multiple aliasing inputs aren't supported yet. Found an operator that doesn't satisfy this constraint""" diff --git a/torchgen/gen_lazy_tensor.py b/torchgen/gen_lazy_tensor.py index 884f645cc4b5b..a4223ad505707 100644 --- a/torchgen/gen_lazy_tensor.py +++ b/torchgen/gen_lazy_tensor.py @@ -176,9 +176,9 @@ class default_args: tensor_class: str = "torch::lazy::LazyTensor" tensor_class_hdr: str = "torch/csrc/lazy/core/tensor.h" lazy_ir_generator: type[GenLazyIR] = GenLazyIR - native_func_definition_generator: type[ + native_func_definition_generator: type[GenLazyNativeFuncDefinition] = ( GenLazyNativeFuncDefinition - ] = GenLazyNativeFuncDefinition + ) backend_name: str = "TorchScript" @@ -257,9 +257,9 @@ def main() -> None: lazy_ir_generator: type[GenLazyIR] = default_args.lazy_ir_generator if options.gen_ts_lowerings: lazy_ir_generator = GenTSLazyIR - native_func_definition_generator: type[ - GenLazyNativeFuncDefinition - ] = default_args.native_func_definition_generator + native_func_definition_generator: type[GenLazyNativeFuncDefinition] = ( + default_args.native_func_definition_generator + ) run_gen_lazy_tensor( aten_path, diff --git a/torchgen/model.py b/torchgen/model.py index 956949343101a..4fe4f9cab5569 100644 --- a/torchgen/model.py +++ b/torchgen/model.py @@ -1484,14 +1484,15 @@ def __post_init__(self) -> None: else: # mutable keyword arguments whose name has _scratch_ prefix are # scratch tensors for memory planning and should not be returned - assert len( - [ - arg - for arg in self.arguments.out - if not arg.name.startswith("_scratch_") - ] - ) == len( - self.returns + assert ( + len( + [ + arg + for arg in self.arguments.out + if not arg.name.startswith("_scratch_") + ] + ) + == len(self.returns) ), "Must return as many arguments as there are out arguments, or no return at all" if self.name.name.inplace: @@ -1590,9 +1591,7 @@ def kind(self) -> SchemaKind: ), "invariant: all scratch operators are expected to be out= operators too" return SchemaKind.scratch elif is_out: - assert ( - not is_scratch - ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" + assert not is_scratch, "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" # noqa: B950 return SchemaKind.out elif is_mutable: return SchemaKind.mutable @@ -2701,9 +2700,7 @@ def __post_init__(self) -> None: ) if self.view.has_composite_implicit_autograd_nested_tensor_kernel: if self.view_inplace is not None: - assert ( - self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel - ), ( + assert self.view_inplace.has_composite_implicit_autograd_nested_tensor_kernel, ( f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" " both have CompositeImplicitAutogradNestedTensor kernels, or both not have composite kernels." ) diff --git a/torchgen/native_function_generation.py b/torchgen/native_function_generation.py index a44efab68426d..1ae4599407c02 100644 --- a/torchgen/native_function_generation.py +++ b/torchgen/native_function_generation.py @@ -1,5 +1,6 @@ from __future__ import annotations +import string from collections import defaultdict from typing import Sequence @@ -194,9 +195,7 @@ def generate_out_args_from_schema( lambda a: [] if a.annotation is None else a.annotation.alias_set, func.arguments.flat_all, ) - valid_annotations = [ - x for x in "abcdefghijklmnopqrstuvwxyz" if x not in used_annotations - ] + valid_annotations = [x for x in string.ascii_lowercase if x not in used_annotations] all_rets_are_tensors = all(r.type == BaseType(BaseTy.Tensor) for r in func.returns) diff --git a/torchgen/operator_versions/gen_mobile_upgraders.py b/torchgen/operator_versions/gen_mobile_upgraders.py index 362ce427d508c..5e4034bc4d61e 100644 --- a/torchgen/operator_versions/gen_mobile_upgraders.py +++ b/torchgen/operator_versions/gen_mobile_upgraders.py @@ -263,7 +263,7 @@ def construct_register_size(register_size_from_yaml: int) -> str: def construct_version_maps( - upgrader_bytecode_function_to_index_map: dict[str, Any] + upgrader_bytecode_function_to_index_map: dict[str, Any], ) -> str: version_map = torch._C._get_operator_version_map() sorted_version_map_ = sorted(version_map.items(), key=itemgetter(0)) # type: ignore[no-any-return] @@ -305,7 +305,7 @@ def construct_version_maps( def get_upgrader_bytecode_function_to_index_map( - upgrader_dict: list[dict[str, Any]] + upgrader_dict: list[dict[str, Any]], ) -> dict[str, Any]: upgrader_bytecode_function_to_index_map = {} index = 0 diff --git a/torchgen/static_runtime/config.py b/torchgen/static_runtime/config.py index 1e7b541fa2c12..9fe129f9754dd 100644 --- a/torchgen/static_runtime/config.py +++ b/torchgen/static_runtime/config.py @@ -366,9 +366,9 @@ def override_test_values(arg_map: dict[str, str], op_name: str, index: int) -> N arg_map["out_int32"] = "false" else: arg_map["crow_indices"] = "torch::tensor({0}, torch::kInt32)" - arg_map[ - "col_indices" - ] = "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)" + arg_map["col_indices"] = ( + "torch::tensor({0, 1, 0, 2, 1, 2, 0, 1, 0, 2, 1, 2}, torch::kInt32)" + ) arg_map["out_int32"] = "false" return if op_name == "_convert_indices_from_coo_to_csr":