From 8043e67026b1cd5b5f1d17c46cd6fe579c322168 Mon Sep 17 00:00:00 2001 From: Haroun Habeeb Date: Fri, 15 Nov 2024 05:03:20 +0000 Subject: [PATCH] catch tensor.numel() == 0 in nan detector (#140741) Context: we are trying to pass an empty tensor through the system now (sometimes;... its an edge case); and it seems to cause all_reduce to seg fault, which is unexpected to me Deep Shah and Pavan identified the issue, I'm just pushing for a fix :) Test Plan: idk what i'm doing here, someone help Reviewed By: shuqiangzhang Differential Revision: D65956095 Pull Request resolved: https://github.com/pytorch/pytorch/pull/140741 Approved by: https://github.com/shuqiangzhang --- torch/csrc/distributed/c10d/NanCheck.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torch/csrc/distributed/c10d/NanCheck.cu b/torch/csrc/distributed/c10d/NanCheck.cu index d256413d60a10..2506ccd1ad094 100644 --- a/torch/csrc/distributed/c10d/NanCheck.cu +++ b/torch/csrc/distributed/c10d/NanCheck.cu @@ -233,6 +233,10 @@ void checkForNan(const at::Tensor& tensor, at::cuda::CUDAStream& stream) { const size_t numThreadsPerBlock = std::min(maxNumThreadsPerBlock, tensor.numel()); + if (!(numThreadsPerBlock > 0)) { + return; + } + const size_t numBlocks = std::min( maxNumBlocks, (tensor.numel() + numThreadsPerBlock - 1) / numThreadsPerBlock);